In [7]:
import torch
from torchvision.models import efficientnet_b0, mobilenet_v2
import torchvision.transforms.v2 as v2
from torch import nn, float32
import sys
from pathlib import Path
sys.path.append(str(Path('../').resolve()))
from utils.transform_utils import DetectFace

In [8]:
class ShuffleNet_V2_X0_5_FaceTransforms(nn.Module):
    """Augmentation + preprocessing pipeline for a ShuffleNetV2 (x0.5) face model.

    Steps, applied in this order:
      1. random horizontal flip (p=0.5),
      2. random rotation in [-15, 15] degrees,
      3. ``DetectFace(detector, pad)`` (project helper from
         ``utils.transform_utils`` -- presumably crops to the detected face;
         confirm there),
      4. bilinear resize to 224x224,
      5. conversion to a float32 image tensor scaled to [0, 1],
      6. per-channel normalization with the ImageNet mean/std.

    Note: steps 1-2 are random and are always applied; there is no separate
    deterministic (evaluation) path in this module.

    Args:
        detector: face detector forwarded to ``DetectFace``; also stored as
            ``self.detector``.
        pad: padding value forwarded to ``DetectFace`` (its meaning is defined
            by that helper).
    """

    def __init__(self, detector = None, pad: int = 0):
        super().__init__()

        self.detector = detector

        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        target_size = (224, 224)

        pipeline = [
            # Random augmentations.
            v2.RandomHorizontalFlip(p=0.5),
            v2.RandomRotation(degrees=15),
            # Face localisation, then fixed-size resize.
            DetectFace(detector, pad),
            v2.Resize(target_size, interpolation=v2.InterpolationMode.BILINEAR),
            # Tensor conversion and normalization.
            v2.ToImage(),
            v2.ToDtype(dtype=float32, scale=True),
            v2.Normalize(mean=channel_mean, std=channel_std),
        ]
        self.transforms = v2.Compose(pipeline)

    def forward(self, x):
        """Run ``x`` through the composed pipeline and return the result."""
        return self.transforms(x)

In [None]:
class EfficientNet_B0_FaceTransforms(nn.Module):
    """Augmentation + preprocessing pipeline for an EfficientNet-B0 face model.

    Steps, applied in this order:
      1. random horizontal flip (p=0.5),
      2. random rotation in [-15, 15] degrees,
      3. ``DetectFace(detector, pad)`` (project helper from
         ``utils.transform_utils`` -- presumably crops to the detected face;
         confirm there),
      4. bicubic resize to 224x224,
      5. conversion to a float32 image tensor scaled to [0, 1],
      6. per-channel normalization with the ImageNet mean/std.

    Note: steps 1-2 are random and are always applied; there is no separate
    deterministic (evaluation) path in this module.

    Args:
        detector: face detector forwarded to ``DetectFace``; also stored as
            ``self.detector``.
        pad: padding value forwarded to ``DetectFace`` (its meaning is defined
            by that helper).
    """

    def __init__(self, detector = None, pad: int = 0):
        super().__init__()

        self.detector = detector

        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        target_size = (224, 224)

        pipeline = [
            # Random augmentations.
            v2.RandomHorizontalFlip(p=0.5),
            v2.RandomRotation(degrees=15),
            # Face localisation, then fixed-size resize.
            DetectFace(detector, pad),
            v2.Resize(target_size, interpolation=v2.InterpolationMode.BICUBIC),
            # Tensor conversion and normalization.
            v2.ToImage(),
            v2.ToDtype(dtype=float32, scale=True),
            v2.Normalize(mean=channel_mean, std=channel_std),
        ]
        self.transforms = v2.Compose(pipeline)

    def forward(self, x):
        """Run ``x`` through the composed pipeline and return the result."""
        return self.transforms(x)

In [None]:
from torch.utils.data import Dataset
import torch
from typing import Union, Optional, Callable, Tuple, Any
import pandas as pd
import os
from PIL import Image
from pathlib import Path

class CelebA(Dataset):
    """CelebA attribute dataset driven by ``list_attr_celeba.csv``.

    ``root`` must contain ``list_attr_celeba.csv`` and an ``img_align_celeba/``
    directory. The CSV's first column is used as the index (image file names,
    joined onto ``root/img_align_celeba/``); every remaining column is treated
    as one attribute label.

    Labels are mapped with ``floor((a + 1) / 2)``, i.e. -1 -> 0 and 1 -> 1,
    and stored as a float tensor ``self.attr`` of shape (num_images, num_attrs).

    Args:
        root: dataset root directory.
        transform: optional callable applied to each RGB ``PIL.Image`` in
            ``__getitem__``.
    """

    def __init__(
            self,
            root: Union[str, Path],
            transform: Optional[Callable] = None,
    ):
        self.root = root
        self.transform = transform

        # Read image file names and raw attribute labels.
        files, attr = self._read_csv('list_attr_celeba.csv')

        # Map {-1, 1} labels to {0., 1.} floats.
        self.attr = torch.div(attr + 1, 2, rounding_mode='floor').float()
        self.files = [os.path.join(self.root, 'img_align_celeba/', file) for file in files]

    def get_pos_weights(self) -> torch.Tensor:
        """Return the per-attribute ratio ``#negatives / #positives``.

        This is the weighting recommended for the ``pos_weight`` argument of
        ``torch.nn.BCEWithLogitsLoss``. An attribute with no positive samples
        would otherwise yield ``inf``. To keep the result finite, its positive
        count is clamped to 1, which gives a weight equal to its negative count.

        Returns:
            Float tensor of shape (num_attrs,).
        """
        num_of_labels = len(self.attr)
        num_of_pos_labels = torch.sum(self.attr, dim=0)
        num_of_neg_labels = num_of_labels - num_of_pos_labels
        # Guard against division by zero for attributes that never occur.
        return num_of_neg_labels / num_of_pos_labels.clamp(min=1)

    def _read_csv(
            self,
            filename: str,
    ) -> Tuple[Any, torch.Tensor]:
        """Load ``root/filename``.

        Returns:
            ``(files, attr)``: the index column as a numpy array of file names,
            and the remaining columns as a tensor built from ``df.values``.
        """
        df = pd.read_csv(os.path.join(self.root, filename), index_col=0, header=0)
        attr = torch.from_numpy(df.values)
        files = df.index.values

        return files, attr

    def __getitem__(self, index) -> Tuple[Any, Any]:
        """Return ``(image, attributes)`` for sample ``index``.

        The image is decoded eagerly and converted to RGB inside a ``with``
        block. This closes the file handle right away (a lazy ``Image.open``
        keeps it open) and guarantees 3 channels for downstream transforms.
        """
        with Image.open(self.files[index]) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, self.attr[index]

    def __len__(self) -> int:
        """Number of images (rows in the attribute CSV)."""
        return len(self.attr)

    


202599