In [1]:
import torch
import torchvision
import glob
import numpy as np
from PIL import Image

In [8]:
class CustomImageDataset(torch.utils.data.Dataset):
    
    def __init__(self, root, transforms=None, target_transforms= None):
        self.root = root
        self.transforms = transforms
        self.target_transforms = None
        female = list(glob.glob(root+'/female/*.jpg'))
        male = list(glob.glob(root+'/male/*.jpg'))
        
        self.imgs = female + male
        self.targets = list(np.zeros(len(female))) + list(np.ones(len(male)))
    
    def __getitem__(self, idx):
        image = Image.open(self.imgs[idx]).convert("RGB")
        image = np.array(image)
        image.resize((300, 300, 3))
        label = self.targets[idx]
        if self.transforms:
            image = self.transforms(image)
        if self.target_transforms:
            label = self.target_transforms(label)
        return image/255., label
    
    def __len__(self):
        return len(self.imgs)

In [9]:
from torch.utils.data import DataLoader

training_data = CustomImageDataset(root='Datasets/Gender Classification/Training')
test_data = CustomImageDataset(root='Datasets/Gender Classification/Validation')

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

In [146]:
# next(iter(train_dataloader))
next(iter(train_dataloader))

[tensor([[[[0.0078, 0.0118, 0.0314],
           [0.0157, 0.0118, 0.0353],
           [0.0353, 0.0314, 0.0549],
           ...,
           [0.5176, 0.3020, 0.2235],
           [0.5098, 0.2941, 0.2157],
           [0.4980, 0.2824, 0.2039]],
 
          [[0.4706, 0.2627, 0.1922],
           [0.4431, 0.2431, 0.1765],
           [0.4157, 0.2314, 0.1686],
           ...,
           [0.5608, 0.4039, 0.3569],
           [0.5255, 0.3686, 0.3216],
           [0.5216, 0.3804, 0.3255]],
 
          [[0.5490, 0.4078, 0.3529],
           [0.5686, 0.4196, 0.3686],
           [0.5765, 0.4196, 0.3804],
           ...,
           [0.0000, 0.0196, 0.0314],
           [0.0039, 0.0235, 0.0353],
           [0.0157, 0.0392, 0.0392]],
 
          ...,
 
          [[0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000],
           ...,
           [0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000]],
 
          [[0.0000,