In [None]:
import torch
import numpy as np, pandas as pd, glob,  time
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms, models, datasets
from torch.utils.data import Dataset, DataLoader
import cv2

In [None]:
# Pick the compute device once; later cells move tensors/models onto it.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

In [None]:
class GenderAge(Dataset):
  """Face-image dataset that yields raw ``(image, age, is_female)`` samples.

  Each row of ``df`` is expected to have a ``file`` column (an image path
  readable by OpenCV), a ``gender`` column (compared against the string
  'Female') and an ``age`` column. ``__getitem__`` returns the RGB image as
  loaded by OpenCV. Resizing, normalisation and batching are done in
  ``collate_fn``, which is meant to be passed to a ``DataLoader``.

  Args:
    df: DataFrame with ``file``, ``gender`` and ``age`` columns.
    image_size: side length (pixels) that images are resized to, as a square,
      in ``preprocess_image``.
    max_age: ages are divided by this value in ``collate_fn``.
  """

  def __init__(self, df, image_size=224, max_age=80):
    self.df = df
    self.image_size = image_size
    self.max_age = max_age
    # ImageNet channel statistics (presumably to match an ImageNet-pretrained
    # backbone downstream -- confirm against the model used).
    self.normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )

  def __len__(self):
    """Number of rows (samples) in the DataFrame."""
    return len(self.df)

  def __getitem__(self, ix):
    """Load sample ``ix``.

    Returns:
      (im, age, gen): ``im`` is the image converted from OpenCV's BGR channel
      order to RGB, ``age`` is the raw ``age`` value of the row, and ``gen`` is
      True when the row's ``gender`` equals 'Female'.

    Raises:
      FileNotFoundError: if OpenCV cannot read the image file.
    """
    row = self.df.iloc[ix].squeeze()
    file = row.file
    gen = row.gender == 'Female'
    age = row.age
    im = cv2.imread(file)
    if im is None:
      # cv2.imread returns None instead of raising on a missing/unreadable
      # file; report the path here rather than failing later in cvtColor.
      raise FileNotFoundError(f'Could not read image file: {file}')
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    return im, age, gen

  def preprocess_image(self, im):
    """Resize, divide by 255 and normalise a single HxWx3 image.

    Returns a float tensor of shape (1, 3, image_size, image_size). The
    leading axis lets ``collate_fn`` stack samples with ``torch.cat``.
    """
    im = cv2.resize(im, (self.image_size, self.image_size))
    # HWC (OpenCV layout) -> CHW (PyTorch layout).
    im = torch.tensor(im).permute(2, 0, 1)
    im = self.normalize(im / 255.)
    # Add a leading batch axis -> (1, 3, H, W). Indexing with [1] here would
    # keep only the green channel and break the batch shape.
    return im[None]

  def collate_fn(self, batch):
    """Collate ``(im, age, gender)`` samples into batched tensors on ``device``.

    Returns:
      ims: float tensor of shape (B, 3, image_size, image_size).
      ages: float tensor of shape (B,), each ``int(age) / max_age``.
      genders: float tensor of shape (B,), 1.0 for female and 0.0 otherwise.
    """
    ims, ages, genders = [], [], []
    for im, age, gender in batch:
      ims.append(self.preprocess_image(im))
      ages.append(float(int(age) / self.max_age))
      genders.append(float(gender))

    ages, genders = [torch.tensor(x).to(device).float() for x in [ages, genders]]
    # Each preprocessed image is (1, 3, H, W); concatenating on dim 0 gives (B, 3, H, W).
    ims = torch.cat(ims).to(device)
    return ims, ages, genders

### U-Net Architecture

In [None]:
class UNet(nn.Module):
  """U-Net style segmentation model whose encoder is torchvision's VGG16-BN.

  The VGG16-BN feature extractor is sliced into five stages plus a
  bottleneck. Their outputs are concatenated (skip connections) with the
  up-sampled decoder features, and a final 1x1 convolution maps to
  ``out_channels``.

  ``conv`` and ``up_conv`` are block factories that are not defined in this
  cell. They must exist in the notebook before the model is instantiated.

  Args:
    pretrained: forwarded to ``vgg16_bn`` to load pretrained encoder weights.
    out_channels: number of channels produced by the final 1x1 convolution
      (presumably one per segmentation class -- confirm against the dataset).
  """

  def __init__(self, pretrained=True, out_channels=12):
    super().__init__()
    # Use the already-imported torchvision `models` namespace, so this cell
    # does not depend on a separate `from torchvision.models import vgg16_bn`.
    self.encoder = models.vgg16_bn(pretrained=pretrained).features
    # Encoder stages: each one's output is reused as a skip connection.
    self.block1 = nn.Sequential(*self.encoder[:6])
    self.block2 = nn.Sequential(*self.encoder[6:13])
    self.block3 = nn.Sequential(*self.encoder[13:20])
    self.block4 = nn.Sequential(*self.encoder[20:27])
    self.block5 = nn.Sequential(*self.encoder[27:34])
    self.bottleneck = nn.Sequential(*self.encoder[34:])
    self.conv_bottleneck = conv(512, 1024)
    # Decoder: each up_convN is followed by convN, which consumes the
    # up-sampled features concatenated with the matching encoder block.
    self.up_conv6 = up_conv(1024, 512)
    self.conv6 = conv(512 + 512, 512)
    self.up_conv7 = up_conv(512, 256)
    self.conv7 = conv(256 + 512, 256)
    self.up_conv8 = up_conv(256, 128)
    self.conv8 = conv(128 + 256, 128)
    self.up_conv9 = up_conv(128, 64)
    self.conv9 = conv(64 + 128, 64)
    self.up_conv10 = up_conv(64, 32)
    self.conv10 = conv(32 + 64, 32)
    self.conv11 = nn.Conv2d(32, out_channels, kernel_size=1)

  def forward(self, x):
    """Run the encoder, then the decoder with skip connections.

    Args:
      x: image batch, assumed (B, 3, H, W). H and W presumably need to be
        multiples of 32 so that each up-sampled map matches its skip
        connection in ``torch.cat`` -- confirm against ``up_conv``.

    Returns:
      Output of the final 1x1 convolution, with ``out_channels`` channels.
    """
    block1 = self.block1(x)
    block2 = self.block2(block1)
    block3 = self.block3(block2)
    block4 = self.block4(block3)
    block5 = self.block5(block4)
    bottleneck = self.bottleneck(block5)
    x = self.conv_bottleneck(bottleneck)
    x = self.up_conv6(x)
    x = torch.cat([x, block5], dim=1)
    x = self.conv6(x)
    x = self.up_conv7(x)
    x = torch.cat([x, block4], dim=1)
    x = self.conv7(x)
    x = self.up_conv8(x)
    x = torch.cat([x, block3], dim=1)
    x = self.conv8(x)
    x = self.up_conv9(x)
    x = torch.cat([x, block2], dim=1)
    x = self.conv9(x)
    x = self.up_conv10(x)
    x = torch.cat([x, block1], dim=1)
    x = self.conv10(x)
    x = self.conv11(x)
    return x