# Training a model.

In this notebook, we will train a convolutional neural network for facial recognition using a technique known as transfer learning. Specifically the model will classify whether an image of a detected face is the designated celebrity or is not the designated celebrity.

Transfer learning is a technique where we can use a model pre-trained for one task, and repurpose it on a new and related task. This will allow us to achieve higher performance than training a model from scratch.

Below is a nice introduction into the definition of transfer learning, taken from the following resource: https://cs231n.github.io/transfer-learning/

 "In practice, very few people train an entire Convolutional Network from scratch (with random initialization), because it is relatively rare to have a dataset of sufficient size. Instead, it is common to pretrain a ConvNet on a very large dataset (e.g. ImageNet, which contains 1.2 million images with 1000 categories), and then use the ConvNet either as an initialization or a fixed feature extractor for the task of interest."

To implement the training of our model, we will use the PyTorch framework. Here is a reference for an example using PyTorch: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

The pre-trained model we will leverage on is Inception ResNet v1, which has been trained on the VGGFace 2 dataset. The VGGFace2 dataset consists of approximately 3.3M faces and 9000 classes. This model has been trained to extract features from face images. Here is the VGGFace2 dataset website: https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/

## 1. Install requirements

In [1]:
# Mount drive.
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

ModuleNotFoundError: No module named 'google'

In [None]:
!ls

In [None]:
# Change to your folder directory.
%cd gdrive/MyDrive/Project-1

In [None]:
!pip install facenet_pytorch

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from PIL import Image
from tempfile import TemporaryDirectory

cudnn.benchmark = True
plt.ion()   # Interactive mode.

import cv2
import copy
from facenet_pytorch import InceptionResnetV1

# 2. Load data and visualize sample images.

The dataset should be split into train and val folders, with each folder containing a folder for each class.

In [None]:
# Data augmentation and normalization for training data.
# Normalization only for validation data.
data_transforms = {
  'train': transforms.Compose([
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  ]),
  'val': transforms.Compose([
      transforms.Resize(299),
      transforms.CenterCrop(299),
      transforms.ToTensor(),
      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  ]),
}

data_dir = 'face_images'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4, drop_last=True)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

class_names = image_datasets['train'].classes

In [None]:
print(f"Number of images in train dataloader: {len(dataloaders['train'])}")
print(f"Number of images in val dataloader: {len(dataloaders['val'])}")
print(f'Class names: {class_names}')

In [None]:
# Set device to GPUs if avaialable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using " + str(device))

In [None]:
# Function to display images.
def imshow(inp, title=None):
  """Imshow for Tensor."""
  inp  = inp.numpy().transpose((1, 2, 0))
  mean = np.array([0.485, 0.456, 0.406])
  std  = np.array([0.229, 0.224, 0.225])
  inp  = std * inp + mean
  inp  = np.clip(inp, 0, 1)
  plt.imshow(inp)
  if title is not None:
      plt.title(title)
  plt.pause(5)  # Pause a bit so that visualizations are updated.

In [None]:
# Get a batch of training data.
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch.
out = torchvision.utils.make_grid(inputs)

# Visualize sampel images.
imshow(out, title=[class_names[x] for x in classes])

## 3. Train and evaluate the model.

In [None]:
# Function to train model.
def train_model(model, criterion, optimizer, scheduler, num_epochs=30):
  since = time.time()

  #best_model_wts = copy.deepcopy(model.state_dict())
  best_acc = 0.0

  for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    # Each epoch has a training and validation phase
    for phase in ['train', 'val']:
      if phase == 'train':
        model.train()  # Set model to training mode.
      else:
        model.eval()   # Set model to evaluate mode.

      running_loss = 0.0
      running_corrects = 0

      # Iterate over data.
      for inputs, labels in dataloaders[phase]:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero the parameter gradients.
        optimizer.zero_grad()

        # Forward pass.
        # Track history if only in train.
        with torch.set_grad_enabled(phase == 'train'):
          outputs = model(inputs)
          _, preds = torch.max(outputs, 1)
          loss = criterion(outputs, labels)

          # Backpropagation only in train.
          if phase == 'train':
            loss.backward()
            optimizer.step()

          # statistics
          running_loss += loss.item() * inputs.size(0)
          running_corrects += torch.sum(preds == labels.data)

        if phase == 'train':
          scheduler.step()

        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = running_corrects.double() / dataset_sizes[phase]

        print('{} Loss: {:.4f} Acc: {:.4f}'.format(
          phase, epoch_loss, epoch_acc))

        # Deep copy the model.
        if phase == 'val' and epoch_acc > best_acc:
          best_acc = epoch_acc
          best_model_wts = copy.deepcopy(model.state_dict())

      print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

  return model

In [None]:
# Function to visualize model predictions on validation set.
def visualize_model(model, num_images=6):
  was_training = model.training
  model.eval()
  images_so_far = 0
  fig = plt.figure()

  with torch.no_grad():
    for i, (inputs, labels) in enumerate(dataloaders['val']):
      inputs = inputs.to(device)
      labels = labels.to(device)

      outputs = model(inputs)
      _, preds = torch.max(outputs, 1)
      print(f'preds: {preds}')

      for j in range(inputs.size()[0]):
        images_so_far += 1
        ax = plt.subplot(num_images//2, 2, images_so_far)
        ax.axis('off')
        ax.set_title('predicted: {}'.format(class_names[preds[j]]))
        imshow(inputs.cpu().data[j])

        if images_so_far == num_images:
          model.train(mode=was_training)
          return
    model.train(mode=was_training)

In [None]:
# Setup model
# Specify the number of classes.
model_ft = InceptionResnetV1(pretrained='vggface2', device=device, classify= True, num_classes=len(class_names))

print(model_ft.logits)

model_ft = model_ft.to(device)

# Setup loss function.
criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized.
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

In [None]:
# Train model.
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=10)

# Save the trained model to be used later for inference.
model_path = "trained_model.pt"
print("Saving model "+model_path);
torch.save(model_ft.state_dict(), model_path)

In [None]:
# Visualize model on validation data.
visualize_model(model_ft)