Train a multi-objective image classifier using data from https://www.kaggle.com/jangedoo/utkface-new  

In [2]:
!pip install pytorch-lightning neptune-client

In [4]:
import torch
from torchvision import datasets, transforms, io, models
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
from PIL import ImageFile
from lime import lime_image
from skimage.segmentation import mark_boundaries

from collections import OrderedDict
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd 
import os
import seaborn as sns
import time

# EDA

Since the dataset contains only of pictures and their titles, we have to create several helper functions to help us extract information from titles. Provided information is: image, race, gender and age. I will read all the pictures and create a dataframe with available information.

In [5]:
root_dir = '../input/utkface-new/UTKFace'
files = os.listdir(root_dir)

In [7]:
def read_image(root_dir: str, img_name: str) -> (int, int, int, torch.Tensor, str):
    """
    Helper function to read image information
    """
    full_img_name = root_dir + "/" + img_name
    splitted = img_name.split('_')
    try:
        age, gender, race, image, image_name = int(splitted[0]), int(splitted[1]), int(splitted[2]), io.read_image(full_img_name), full_img_name
    except:
        age, gender, race, image, image_name = int(splitted[0]), int(splitted[1]), None, io.read_image(full_img_name), full_img_name
    return age, gender, race, image, image_name


def race(integer: int) -> str:
    """
    Helper function to convert race number to text
    """
    age_dict = {0:"White", 1:"Black", 2:"Asian", 3:"Indian", 4:"Other"}
    try:
        return age_dict[integer]
    except:
        return age_dict[4]


def gender(integer: int) -> str:
    """
    Helper function to convert gender number to text
    """
    gender_dict = {0:"Male", 1:"Female"}
    return gender_dict[integer]


def age_bin(integer: int) -> str:
    """
    Helper function to classify age into bins
    """
    age_bin_dict = {0:"0-9", 10:"10-19", 20:"20-29", 30:"30-39", 40:"40-49",
                   50:"50-59", 60:"60-69", 70:"70-79", 80:"80-89", 90:"90 and above"}
    rounddown = int(np.floor(integer / 10.0)) * 10
    age_bin = rounddown if rounddown <= 90 else 90
    return age_bin_dict[age_bin]


def autolabel(bar_plot):
    """
    Attach a text label above each bar displaying its height
    """
    for p in bar_plot.patches:
        bar_plot.annotate(format(p.get_height(), '.0f'), 
                       (p.get_x() + p.get_width() / 2., p.get_height()), 
                       ha = 'center', va = 'center', 
                       xytext = (0, 9), 
                       textcoords = 'offset points')

In [8]:
ages = []
genders = []
races = []
images = []
image_names = []

for img_name in files:
    photo_info = read_image(root_dir, img_name)
    ages.append(photo_info[0])
    genders.append(photo_info[1])
    races.append(photo_info[2])
    images.append(photo_info[3])
    image_names.append(photo_info[4])
    
df = pd.DataFrame({"Age": ages, "Gender": genders, "Race": races, "image_name": image_names})

df['Age_bin'] = df['Age'].apply(lambda x: age_bin(x))
df['Gender_str'] = df['Gender'].apply(lambda x: gender(x))
df['Race_str'] = df['Race'].apply(lambda x: race(x))

In [10]:
len(images), len(df)

We have 23708 pictures in our dataset containing faces of individuals, but not all the photos are what we would expect. Some of them are showing only one eye, others - some numbers or nothing distinguishable. After manual inspection I will remove such pictures from the dataset and we will be left with 23702 pictures.

In [12]:
plt.figure()
f, axarr = plt.subplots(2,3) 
plt.axis('off')

axarr[0][0].imshow(images[df.index[df['image_name'].str.contains("1_0_0_20170109193052283.jpg.chip")][0]].permute(1, 2, 0))
axarr[0][1].imshow(images[df.index[df['image_name'].str.contains("1_0_0_20170109194120301.jpg.chip")][0]].permute(1, 2, 0))
axarr[0][2].imshow(images[df.index[df['image_name'].str.contains("1_1_4_20170109194502921.jpg.chip")][0]].permute(1, 2, 0))
axarr[1][0].imshow(images[df.index[df['image_name'].str.contains("90_0_0_20170111210338948.jpg.chip")][0]].permute(1, 2, 0))
axarr[1][1].imshow(images[df.index[df['image_name'].str.contains("80_0_2_20170111210646563.jpg.chip")][0]].permute(1, 2, 0))
axarr[1][2].imshow(images[df.index[df['image_name'].str.contains("90_0_0_20170111210338948.jpg.chip")][0]].permute(1, 2, 0))

In [14]:
bad_indexes = [
    df.index[df['image_name'].str.contains("1_0_0_20170109194120301.jpg.chip")][0],
    df.index[df['image_name'].str.contains("1_0_0_20170109193052283.jpg.chip")][0],
    df.index[df['image_name'].str.contains("90_0_0_20170111210338948.jpg.chip")][0],
    df.index[df['image_name'].str.contains("1_1_4_20170109194502921.jpg.chip")][0],
    df.index[df['image_name'].str.contains("80_0_2_20170111210646563.jpg.chip")][0],
    df.index[df['image_name'].str.contains("28_0_1_20170117020012900.jpg.chip")][0]
]

In [15]:
df.drop(df.index[bad_indexes], inplace=True)

for bad_index in bad_indexes:
    del images[bad_index]

In [16]:
len(images), len(df)

Let's see how our data is distributed. I will check age, gender and race distribution.

In [17]:
sns.set(rc={'figure.figsize':(10,6)})
ax = sns.countplot(x="Age_bin",data=df, order=['0-9','10-19','20-29','30-39','40-49','50-59','60-69', '70-79', '80-89', '90 and above'])
ax.set(xlabel='Age',
       ylabel='Age count',
       title='Age distribution')
autolabel(ax)
plt.tight_layout()
plt.show()

In [18]:
sns.set(rc={'figure.figsize':(4,6)})
ax2 = sns.countplot(x="Gender_str",data=df)
ax2.set(xlabel='Gender',
       ylabel='Gender count',
       title='Gender distribution')
autolabel(ax2)
plt.tight_layout()
plt.show()

In [19]:
sns.set(rc={'figure.figsize':(10,6)})
ax3 = sns.countplot(x="Race_str",data=df, order=['White', 'Black', 'Asian', 'Indian', 'Other'])
ax3.set(xlabel='Race',
       ylabel='Race count',
       title='Race distribution')
autolabel(ax3)
plt.tight_layout()
plt.show()

In [20]:
df.describe()

One important thing is age - pictures vary from 1 to 116 years. I will have to keep this in mind when training a model.

# Data preprocessing

Data preprocessing step consists only of FaceDataset class, which performs the following transformations: image resizing and tensor normalization according to ImageNet dataset. Since gender values are 0 (male) and 1 (female), it would be nice to have age values on the same scale. We saw earlier that age ranges from 1 to 116. By applying natural logarithm on biggest age value (116) and dividing it by 4.75 we get approx. 1, so by formula log(age)/4.75 we can squeeze all age values between 0 and 1. 

In [21]:
class FaceDataset(Dataset):
    
    def __init__(self, imgs: list, df: pd.DataFrame):
        super().__init__()
        self.imgs = imgs
        self.df = df
        self.ages = df['Age']
        self.genders = df['Gender']
        self.transforms = transforms.Compose([transforms.Resize(48), 
                                              transforms.ToTensor(),
                                              transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                                             ])

        
    def __len__(self):
        return self.df.shape[0]
    
    
    def __getitem__(self,idx):
        age = torch.tensor(self.ages.iloc[idx], dtype=torch.float32)
        gender = torch.tensor(self.genders.iloc[idx], dtype=torch.int64)
        img_tensor = self.transforms(transforms.ToPILImage()(self.imgs[idx]))
        return {"input": img_tensor, "labels": (np.log(age)/4.75, gender)}

# Training

First I split data to train, validation and test sets. For training I will be using resnet34 pretrained neural network. Since the last layer consists of 1000 outputs (and for our task we only need 2 - one for gender and one for age) I will have to remove last layer of resnet34 and replace it with my own linear layers. For loss logging I am using third party logger - Neptune AI. 

In [22]:
x_train, x_test, y_train, y_test = train_test_split(images, df, test_size=0.1, random_state=1, shuffle=True)

x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.1, random_state=1, shuffle=True) 

len(x_train), len(x_val), len(y_train), len(y_val), len(x_test), len(y_test)

In [25]:
class FaceModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        backbone = models.resnet34(pretrained=True)
        num_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]
        self.feature_extractor = torch.nn.Sequential(*layers)
        self.fc_age = torch.nn.Linear(num_filters, 1)
        self.fc_gender = torch.nn.Linear(num_filters, 1)
        self.criterion_age = torch.nn.MSELoss()
        self.criterion_gender = torch.nn.BCEWithLogitsLoss()
        self.batch_size = 32
        self.base_grad(requires_grad=False)
        
    def base_grad(self, requires_grad: bool):
        """
        Used to able or disable gradient calculation during backpropagation.
        """
        for param in self.feature_extractor.parameters():
            param.requires_grad = requires_grad


    def forward(self, image_tensor):
        
        self.feature_extractor.eval()
        with torch.no_grad():
            x = self.feature_extractor(image_tensor).flatten(1)
        age_pred = self.fc_age(x)
        gender_pred = self.fc_gender(x)
        
        return age_pred, gender_pred

    
    def training_step(self, batch: dict, batch_idx):
        
        input_tensor = batch["input"]
        age = batch["labels"][0]
        gender = batch["labels"][1]
        
        age_pred, gender_pred = self(input_tensor)
        age_loss = self.criterion_age(age_pred, age)
        gender_loss = self.criterion_gender(gender_pred, gender.unsqueeze(1).float())
        loss = age_loss + gender_loss
        
        return {"loss": loss, "age_pred": age_pred, "gender_pred": gender_pred}
    
    
    def validation_step(self, val_batch: dict, val_batch_idx):
        
        input_tensor = val_batch["input"]
        age = val_batch["labels"][0]
        gender = val_batch["labels"][1]
        
        age_pred, gender_pred = self(input_tensor)
        age_loss = self.criterion_age(age_pred, age)
        gender_loss = self.criterion_gender(gender_pred, gender.unsqueeze(1).float())
        loss = age_loss + gender_loss
        
        return {"loss": loss, "age_pred": age_pred, "gender_pred": gender_pred}

    
    def configure_optimizers(self):
        
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
    
    
    def train_dataloader(self):
        
        dataset = FaceDataset(x_train, y_train)
        loader = DataLoader(dataset, batch_size = self.batch_size, shuffle=True, num_workers=8)
        return loader

    
    def val_dataloader(self):
        
        dataset = FaceDataset(x_val, y_val)
        loader = DataLoader(dataset, batch_size = self.batch_size, shuffle=False, num_workers=8)
        return loader
    
    
    def training_epoch_end(self, outputs: dict) -> None:
        """
        Contains the outputs of the training_step, collected for all batches over the epoch.
        Tracks the progress of the entire epoch by calculating average loss.
        """
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        print(f"Training loss is {avg_loss} for this epoch")
#         neptune_logger["train_loss"].log(avg_loss)


    def validation_epoch_end(self, outputs: dict) -> None:
        """
        Contains the outputs of the validation_step, collected for all batches over the epoch.
        Tracks the progress of the entire epoch by calculating average loss.
        """
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("val_loss", avg_loss)
        print(f"Validation loss is {avg_loss} for this epoch")
#         neptune_logger["val_loss"].log(avg_loss)

In [None]:
# import neptune.new as neptune
# import datetime

# api_token = 'eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiIyYmE3YWY2MS1jN2VlLTRkMjctYTAyOS05ZTUxM2M5M2ZhY2QifQ=='
    
# neptune_logger = neptune.init(project="tiskutis/FaceClassifier", api_token=api_token)

In [None]:
# used to stop logger instance in neptune service
# neptune_logger.stop()

I will train for 5 and finetune for 10 epochs with learning rates 1e-3 and 1e-5 respectively. Model will be trained and saved in runtime environment.

In [28]:
def train(model, epochs, lr):
    model.lr = lr
    trainer = pl.Trainer(max_epochs=epochs, 
                         gpus=torch.cuda.device_count(), 
                         callbacks=[EarlyStopping(monitor="val_loss", patience=2)])
    trainer.fit(model)
    
    return trainer


def train_and_save(train_epochs: int = 5, finetune_epochs: int = 10, train_lr: float = 1e-3, finetune_lr: float = 1e-5):
    
    model = FaceModel() 
    print("=" * 25, f"Training - ", "=" * 25)
    train(model, epochs=train_epochs, lr=train_lr)
    
    print("=" * 25, f"Fine-tuning - ", "=" * 25)
    model.base_grad(True)
    trainer = train(model, epochs=finetune_epochs, lr=finetune_lr)
    
    trainer.save_checkpoint(f'resnet342.ckpt')

In [29]:
resnet34_checkpoint = train_and_save()

# Testing

In [30]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = FaceModel().load_from_checkpoint("./resnet342.ckpt")
model.eval().to(device)

In [31]:
test_loader = DataLoader(FaceDataset(x_test, y_test), batch_size = 32, shuffle=False, num_workers=8)

In [32]:
age_preds = []
age_trues = []
gender_preds = []
gender_trues = []
inference_time = []


with torch.no_grad():
    for i, batch in enumerate(test_loader):
        batch = {k: v for k, v in batch.items()}
        start = time.time()
        
        images = batch["input"].to(device)
        age = batch["labels"][0]
        gender = batch["labels"][1]
        age_pred, gender_pred = model(images)
        
        end = time.time()
        
        age_preds.append(age_pred)
        gender_preds.append(gender_pred)
        
        age_trues.append(age)
        gender_trues.append(gender)
        
        inference_time.append(end - start)

In [33]:
print(f"Inference time for batch of 32 samples: {sum(inference_time)/len(inference_time)} s")

In [34]:
age_preds = torch.cat(age_preds).detach().cpu().numpy().reshape(-1,1)
gender_preds = torch.nn.Sigmoid()(torch.cat(gender_preds)).detach().cpu().numpy().reshape(-1,1)
age_trues = torch.cat(age_trues).detach().cpu().numpy().reshape(-1,1)
gender_trues = torch.cat(gender_trues).detach().cpu().numpy().reshape(-1,1)

After predictions are made I am going to create a new dataframe consisting of test data and test values as well as predicted values. I will also include differences between real and predicted values so that I could inspect best and worst samples.

In [35]:
df_test = y_test.copy()
df_test["Gender_pred"] = gender_preds
df_test["Gender_pred_binary"] = df_test.apply(lambda x: 0 if x["Gender_pred"] < 0.5 else 1, axis=1)

f = lambda x: np.round(np.exp(x*4.75))
df_test["Age_pred_sig"] = age_preds
df_test["Age_pred_real"] = f(age_preds)

In [36]:
df_test.describe()

# Results

In [37]:
gender_binary_preds = np.where(np.array(gender_preds) > 0.5, 1, 0)

### Classification report

In [38]:
print(classification_report(gender_trues, gender_binary_preds))

### Confusion matrix

In [39]:
disp = ConfusionMatrixDisplay(confusion_matrix(gender_trues, gender_binary_preds), display_labels=["Male", "Female"])
disp.plot()
disp.ax_.set_title(f"Confusion matrix for gender classification")

### Age predictions

In [40]:
sns.scatterplot(data=df_test, x="Age", y="Age_pred_real").set_title("Age predictions vs real age")

In [41]:
df_test["Gender_pred_diff"] = df_test.apply(lambda x: abs(x["Gender_pred"] - x["Gender"]), axis=1)
df_test["Age_pred_diff"] = df_test.apply(lambda x: abs(x["Age"] - x["Age_pred_real"]), axis=1)

In [42]:
df_test.head()

Let's sort the dataframe by best/worst predictions and plot them. For this case I will create 4 separate dataframes with sorted values. 

In [43]:
df_gender_worst = df_test.sort_values(by=['Gender_pred_diff'], ascending=False)
df_age_worst = df_test.sort_values(by=['Age_pred_diff'], ascending=False)
df_gender_best = df_test.sort_values(by=['Gender_pred_diff'], ascending=True)
df_age_best = df_test.sort_values(by=['Age_pred_diff'], ascending=True)

In [44]:
df_gender_worst.head(5)

In [45]:
def plot_gender_predictions(df):
    fig = plt.figure(figsize=(20, 14))
    rows = 4
    columns = 5

    for i in range(0, len(df.head(20))):
        fig.add_subplot(rows, columns, i+1)
        plt.imshow(io.read_image(df.iloc[i]["image_name"]).permute(1, 2, 0))
        plt.axis('off')
        gender_real = df.iloc[i]["Gender_str"]
        gender_pred = "Male" if df.iloc[i]["Gender_pred"] < 0.5 else "Female"

        plt.title(f"Predicted: {gender_pred}, real : {gender_real}")
        

def plot_age_predictions(df):
    fig = plt.figure(figsize=(20, 14))
    rows = 4
    columns = 5

    for i in range(0, len(df.head(20))):
        fig.add_subplot(rows, columns, i+1)
        plt.imshow(io.read_image(df.iloc[i]["image_name"]).permute(1, 2, 0))
        plt.axis('off')
        age_real = df.iloc[i]["Age"]
        age_pred = df.iloc[i]["Age_pred_real"] 

        plt.title(f"Predicted: {age_pred}, real : {age_real}")

### Worst / best performing gender classification samples

In [46]:
plot_gender_predictions(df_gender_worst)

- It seems that the majority of misclassified samples had incorrect gender labels (while it is not obvious for babies, most of the cases in worst 20 samples are women, which had male label in the original data).
- Some of predicted females, when real label is male, have quite peculiar feminine face features. 

In [47]:
plot_gender_predictions(df_gender_best)

- Out of 20 best gender classification sampels 19 are men. 
- Most of these men tend to have quite noticeable facial hair, especially beard.

### Worst / best performing age classification samples

In [48]:
plot_age_predictions(df_age_worst)

- Worst age predictions are dominated by photos with old people. Training data consisted mostly of young people, so model probably didn't learn old people's features well enough.

In [50]:
plot_age_predictions(df_age_best)

- Best predictions are between the ages of 20-35

# Lime explanations

I will try to use Lime explanation tool to see which areas contribute to model prediction the most.
Code is written by Lime examples in Pytorch: https://github.com/marcotcr/lime/blob/master/doc/notebooks/Tutorial%20-%20images%20-%20Pytorch.ipynb

In [51]:
def get_pil_transform():
    
    transf = transforms.Compose(
      [transforms.Resize((256, 256)), transforms.CenterCrop(224)]
    )

    return transf

def get_preprocess_transform():
    
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])     
    transf = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])    

    return transf    

pill_transf = get_pil_transform()
preprocess_transform = get_preprocess_transform()

def age_predict(images: np.array):
    """
    Returns person's age given image as numpy array.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    batch = torch.stack(tuple(preprocess_transform(i) for i in images), dim=0)
    batch = batch.to(device)
    age, _ = model(batch)
    age = torch.cat(tuple(age.unsqueeze(0))).detach().cpu().reshape(-1, 1)
    age = np.round(np.exp(age*4.75))
    return age

def gender_predict(images: np.array):
    """
    Returns person's gender (0 for male and 1 for female) given image as numpy array.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    
    batch = torch.stack(tuple(preprocess_transform(i) for i in images), dim=0)
    batch = batch.to(device)
    _, gender = model(batch)
    gender = torch.nn.Sigmoid()(torch.cat(tuple(gender.unsqueeze(0)))).detach().cpu().numpy().reshape(-1, 1)
    return gender

In [53]:
def plot_gender_explainability(df):
    
    fig = plt.figure(figsize=(20, 14))
    rows = 4
    columns = 5
    
    for i in range(0, len(df.head(20))):
    
        explainer = lime_image.LimeImageExplainer()
        worst_gender_explanation = explainer.explain_instance(
        io.read_image(df['image_name'].iloc[i]).permute(1, 2, 0).numpy(),
        gender_predict,
        top_labels=5,
        hide_color=0,
        num_samples=1000
        )

        worst_temp_gender, worst_mask_gender = worst_gender_explanation.get_image_and_mask(
            worst_gender_explanation.top_labels[0],
            positive_only=False,
            num_features=3,
            hide_rest=False
        )

        worst_img_boundary_gender = mark_boundaries(worst_temp_gender/255.0, worst_mask_gender)

        fig.add_subplot(rows, columns, i+1)
        plt.title(f"Gender pred: {df['Gender_pred_binary'].iloc[i]}, real: {df['Gender'].iloc[i]}")
        plt.imshow(worst_img_boundary_gender)
        plt.axis('off')

In [54]:
plot_gender_explainability(df_gender_worst)

In [56]:
plot_gender_explainability(df_gender_best)

# Conclusions

- Gender-wise model performs very well (f-1 score being about 0.8);
- Main pain points in gender misclassification are incorrect labels in train/test data;
- Some men exhibit strong feminine features - men with long hair and smooth skin may be interpreted as females, whereas men with dense facial hair are among the best classification samples.
- Age predictions are poorest among old people - this could have been expected having in mind low number of elder training samples. 
- Age predictions are best for 20-40 year age groups. Since this age category was abundant in training, model learned their features very well so bias towards younger age is apparent. 
- Although indians/asians tend to dominate worst gender predictions, the underlying problem could be low picture resolution and mislabeling. Nevertheless, this could cause ethnical issues.