In [2]:
import torch
import numpy as np, pandas as pd, glob,  time
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms, models, datasets
from torch.utils.data import Dataset, DataLoader
import cv2

In [None]:
# Default to the CPU; switch to the GPU when PyTorch can see one.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [None]:
class GenderAge(Dataset):
  """Face-image dataset yielding raw ``(RGB image, age, is_female)`` samples.

  ``df`` must provide ``file``, ``age`` and ``gender`` columns; ``gender`` is
  compared against the string ``'Female'``. ``__getitem__`` returns raw samples,
  and ``collate_fn`` (intended as a ``DataLoader`` collate function) turns a list
  of them into batched, normalized tensors.
  """

  # Ages are divided by this value to scale the regression target.
  AGE_SCALE = 80

  def __init__(self, df):
    self.df = df
    # Standard ImageNet channel mean/std, as used by torchvision's pretrained models.
    self.normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )

  def __len__(self):
    return len(self.df)

  def __getitem__(self, ix):
    """Return ``(im, age, gen)`` for row ``ix``.

    ``im`` is the image read from ``row.file`` converted from OpenCV's BGR order
    to RGB, ``age`` is the raw ``age`` value, and ``gen`` is ``row.gender == 'Female'``.

    Raises:
      FileNotFoundError: if OpenCV cannot read the image file.
    """
    row = self.df.iloc[ix]
    file = row.file
    gen = row.gender == 'Female'
    age = row.age
    im = cv2.imread(file)
    if im is None:
      # cv2.imread reports a missing/unreadable file by returning None; fail with a
      # clear message instead of an opaque error from cvtColor below.
      raise FileNotFoundError(f'could not read image: {file}')
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    return im, age, gen

  def preprocess_image(self, im):
    """Resize an HxWx3 image to 224x224, scale to [0, 1] and normalize it.

    Returns a float tensor of shape (1, 3, 224, 224). The leading batch axis lets
    ``collate_fn`` join images with ``torch.cat``, and it means a single
    preprocessed image already carries a batch dimension.
    """
    im = cv2.resize(im, (224, 224))
    im = torch.tensor(im).permute(2, 0, 1).float() / 255.  # HWC uint8 -> CHW float
    im = self.normalize(im)
    return im[None]  # add the batch axis, keeping all three channels

  def collate_fn(self, batch):
    """Collate raw ``(im, age, gender)`` samples into batched tensors.

    Returns ``(ims, ages, genders)``:
      ims:     float tensor of shape (B, 3, 224, 224)
      ages:    float tensor of shape (B,), each ``int(age) / AGE_SCALE``
      genders: float tensor of shape (B,), 1.0 for female and 0.0 otherwise
    All three are moved to the notebook-level ``device``.
    """
    ims, ages, genders = [], [], []
    for im, age, gender in batch:
      ims.append(self.preprocess_image(im))
      ages.append(int(age) / self.AGE_SCALE)
      genders.append(float(gender))

    # NOTE(review): this relies on the global `device` defined in an earlier cell.
    # Creating CUDA tensors inside collate_fn is also problematic if the DataLoader
    # uses worker processes (num_workers > 0).
    ages, genders = [torch.tensor(x).to(device).float() for x in [ages, genders]]
    ims = torch.cat(ims).to(device)
    return ims, ages, genders