In [1]:
import torch
import numpy as np, cv2, pandas as pd, glob, time
import matplotlib.pyplot as plt
%matplotlib inline
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms, models, datasets

In [2]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

In [3]:
#@ Fetching Datasets:
from pydrive2.auth import GoogleAuth
from pydrive2.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Sign the Colab user in, then attach the application-default credentials
# to a PyDrive auth object so that `drive` can fetch files by their Drive id.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

def getfile(file_id, name):
  """Download the Drive file identified by `file_id` into the local path `name`.

  Uses the module-level `drive` client; returns nothing.
  """
  drive.CreateFile({'id': file_id}).GetContentFile(name)

getfile('1Z1RqRo0_JiavaZw2yzZG6WETdZQ8qX86', 'fairface-img-margin025-trainval.zip')
getfile('1k5vvyREmHDW5TSM9QgB04Bvc8C8_7dl-', 'fairface-label-train.csv')
getfile('1_rtz1M1zhvS0d5vVoXUamnohB6cJ02iJ', 'fairface-label-val.csv')

!unzip -qq fairface-img-margin025-trainval.zip

In [4]:
# Load the training and validation label tables, then preview the training one.
trn_df, val_df = map(pd.read_csv, ['fairface-label-train.csv', 'fairface-label-val.csv'])
trn_df.head()

Unnamed: 0,file,age,gender,race,service_test
0,train/1.jpg,59,Male,East Asian,True
1,train/2.jpg,39,Female,Indian,False
2,train/3.jpg,11,Female,Black,False
3,train/4.jpg,26,Female,Indian,True
4,train/5.jpg,26,Female,Indian,True


In [5]:
IMAGE_SIZE=224
#@ Class for Gender and Age:
class GenderAge(Dataset):
  """Dataset of face images with an age and a gender label per sample.

  Each row of `df` must provide the columns `file` (image path), `age` and
  `gender`. Indexing returns the raw RGB image together with its labels;
  resizing, normalisation and batching happen in `collate_func`, which is
  meant to be passed to a DataLoader as `collate_fn`.
  """

  def __init__(self, df, tfms=None):
    """Store the label table and build the per-channel normaliser.

    Args:
      df: DataFrame with `file`, `age` and `gender` columns.
      tfms: accepted for interface compatibility; currently unused.
    """
    self.df=df
    # Mean/std values commonly used with ImageNet-pretrained torchvision models.
    self.normalize=transforms.Normalize(
                     mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225])

  def __len__(self):
    """Return the number of rows in the label table."""
    return len(self.df)

  def __getitem__(self, ix):
    """Return `(image, age, is_female)` for row `ix`.

    The image is read with OpenCV and converted from BGR to RGB; `is_female`
    is True when the row's gender equals 'Female'.

    Raises:
      FileNotFoundError: if OpenCV cannot read the image file.
    """
    row=self.df.iloc[ix].squeeze()
    path=row.file
    is_female=row.gender=='Female'
    age=row.age
    im=cv2.imread(path)
    if im is None:
      # cv2.imread signals an unreadable/missing file by returning None
      # rather than raising; fail here with the offending path.
      raise FileNotFoundError(f'Could not read image: {path}')
    im=cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    return im, age, is_female

  def preprocess(self, im):
    """Turn one H x W x C image array into a normalised 1 x C x H x W tensor."""
    im=cv2.resize(im, (IMAGE_SIZE, IMAGE_SIZE))
    im=torch.tensor(im).permute(2, 0, 1)  # channels-last -> channels-first
    im=self.normalize(im/255.)            # scale to [0, 1] before normalising
    return im[None]                       # add the leading batch dimension

  def collate_func(self, batch):
    """Collate `(image, age, gender)` samples into batched tensors on `device`.

    Args:
      batch: iterable of samples as returned by `__getitem__`.

    Returns:
      Tuple `(ims, ages, genders)`: the stacked preprocessed images, the ages
      divided by 80 as a float tensor, and the gender flags as a float tensor
      (1.0 for female, 0.0 otherwise).
    """
    ims, ages, genders=[], [], []
    for im, age, gender in batch:
      ims.append(self.preprocess(im))
      # Ages are scaled by a fixed factor of 80 before being used as targets.
      ages.append(float(int(age)/80))
      genders.append(float(gender))

    ages, genders=[torch.tensor(x).to(device).float() for x in [ages, genders]]
    ims=torch.cat(ims).to(device)
    # Return the batched tensor, not the last sample's scalar `gender`.
    return ims, ages, genders