In [4]:
import os
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

In [5]:
class LFWDataset(Dataset):
    """Map-style dataset of (image, label) pairs described by a CSV file.

    For row ``idx`` of the CSV, column 0 is joined onto ``root_dir`` to form
    the image path and column 1 is returned as the label.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.face_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the annotation CSV."""
        return len(self.face_frame)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx (int or torch.Tensor): Row position in the annotation frame.
                A tensor index is converted with ``tolist()`` first.

        Returns:
            tuple: ``(image, label)`` where ``image`` is an RGB PIL image, or
            whatever ``transform`` returns for it, and ``label`` is the value
            in column 1 of the selected row.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                self.face_frame.iloc[idx, 0])
        label = self.face_frame.iloc[idx, 1]

        # Image.open is lazy and keeps the file open; the context manager
        # releases the handle once convert() has read the pixel data.
        # Converting to RGB guarantees three channels, so grayscale, palette
        # or RGBA files do not break channel-wise transforms downstream.
        with Image.open(img_name) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label


In [6]:
# Preprocessing pipeline applied to every image the dataset returns.
transform = transforms.Compose([
    # Geometry: resize, then take a centered 224x224 crop.
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    # PIL image -> tensor, followed by per-channel normalization.
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    # Augmentation: with probability 0.5, erase a random rectangle of the tensor.
    transforms.RandomErasing(
        p=0.5,
        scale=(0.02, 0.33),
        ratio=(0.3, 3.3),
    ),
])


In [None]:
# Locations of the image directory and the annotation CSV, relative to the
# notebook's working directory.
dataset_path = './data/lfw/lfw-deepfunneled/'
csv_file = './data/lfw/lfw_allnames.csv'

# NOTE(review): LFWDataset reads CSV column 0 as an image path relative to
# root_dir and column 1 as the label — confirm this CSV has that layout.
lfw_dataset = LFWDataset(csv_file, dataset_path, transform)

# Batches of 4 samples, drawn in shuffled order.
data_loader = DataLoader(dataset=lfw_dataset, batch_size=4, shuffle=True)
