# Transfer Learning for Image Classification

This notebook uses image classification models from [torchvision](https://pytorch.org/vision/stable/index.html) that were originally trained using [ImageNet](https://image-net.org/) and performs transfer learning on either a torchvision dataset or your own raw images.

The notebook performs the following steps:
1. [Import dependencies and setup parameters](#1.-Import-dependencies-and-setup-parameters)
2. [Prepare the dataset](#2.-Prepare-the-dataset)
3. [Predict using the original model](#3.-Predict-using-the-original-model)
4. [Transfer learning](#4.-Transfer-learning)
5. [Visualize the model output](#5.-Visualize-the-model-output)
6. [Export the saved model](#6.-Export-the-saved-model)

## 1. Import dependencies and setup parameters

In [None]:
import os
import time
import math
import numpy as np
import pandas as pd
import torch
import torchvision
from torchvision import datasets, models, transforms
from PIL import Image
from pydoc import locate
import warnings

import intel_extension_for_pytorch as ipex
import matplotlib.pyplot as plt

from model_utils import torchvision_model_map, get_retrainable_model

warnings.filterwarnings("ignore")

print('Supported models:')
print('\n'.join(torchvision_model_map.keys()))

In [None]:
# Specify a model from the list above
model_name = "efficientnet_b0"

# Specify the parent directory for the custom or torchvision dataset
dataset_directory = os.environ["DATASET_DIR"] if "DATASET_DIR" in os.environ else \
    os.path.join(os.environ["HOME"], "dataset")
    
# Specify a directory for output
output_directory = os.environ["OUTPUT_DIR"] if "OUTPUT_DIR" in os.environ else \
    os.path.join(os.environ["HOME"], "output")

# Batch size
batch_size = 32

In [None]:
if model_name not in torchvision_model_map.keys():
    raise ValueError("The specified model_name ({}) is invalid. Please select from: {}".
                     format(model_name, torchvision_model_map.keys()))
    
print("Pretrained Image Classification Model:", model_name)   

## 2. Prepare the dataset

Define transforms for data resizing and augmentation. The normalization means and standard deviations `[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]` are specific to torchvision image classification models and are explained in the [documentation](https://pytorch.org/vision/stable/models.html#classification).
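For each channel, `T.Normalize` computes `(x - mean) / std` on the `[0, 1]` values produced by `T.ToTensor()`. The plotting cells later in the notebook undo this with `std * x + mean` before calling `imshow`. A minimal pure-Python sketch of both directions for a single RGB pixel (the pixel value and the `normalize`/`denormalize` helper names are made up for illustration):

```python
# ImageNet channel statistics used by torchvision classification models
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def normalize(pixel):
    """Map an RGB pixel with values in [0, 1] to the normalized space."""
    return [(x - m) / s for x, m, s in zip(pixel, MEAN, STD)]

def denormalize(pixel):
    """Invert normalize(), as the plotting cells do before calling imshow."""
    return [s * z + m for z, m, s in zip(pixel, MEAN, STD)]

pixel = [0.5, 0.25, 1.0]  # hypothetical RGB value after ToTensor()
z = normalize(pixel)
print([round(v, 3) for v in z])
print([round(v, 6) for v in denormalize(z)])
```

A pixel equal to the channel means maps to zero, which is why denormalized images need `np.clip(inp, 0, 1)` only to absorb small numeric overshoot.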

In [None]:
# Preprocessing transforms
import torchvision.transforms as T

def get_transform(train):
    # Use a local name that does not shadow the `transforms` module imported above
    transform_list = [T.Resize([256, 256])]
    if train:
        # Random horizontal flips augment the training data only
        transform_list.append(T.RandomHorizontalFlip())
    transform_list.append(T.ToTensor())
    transform_list.append(T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))

    return T.Compose(transform_list)

### Option A: Use a torchvision dataset

To use a torchvision dataset, load it from `torchvision.datasets` and apply the transforms for image augmentation, normalization, and resizing. This example uses the Food101 dataset from the [torchvision datasets for image classification](https://pytorch.org/vision/stable/datasets.html#image-classification), but you can choose from a variety of options. If the dataset is not found in the dataset directory, it is downloaded; subsequent runs reuse the downloaded copy.

Note: Some torchvision datasets use a `train=True/False` argument and others use a `split="train"/"test"` convention. Check the torchvision documentation for how to select the subset you want to use.
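If you switch datasets often, one option is to pick the keyword by inspecting the constructor's signature. This is a minimal sketch under that assumption; `subset_kwargs` and the two stand-in classes are hypothetical helpers, not torchvision APIs, and datasets whose splits use other names (for example `"val"`) still need a manual check.

```python
import inspect

def subset_kwargs(dataset_cls, train):
    """Return the keyword argument that selects the train/test subset.

    Hypothetical helper: assumes the class accepts either `split` or `train`.
    """
    params = inspect.signature(dataset_cls).parameters
    if "split" in params:
        return {"split": "train" if train else "test"}
    if "train" in params:
        return {"train": train}
    raise ValueError(f"{dataset_cls.__name__} takes neither 'split' nor 'train'")

# Stand-in classes that mimic the two conventions
class SplitStyle:
    def __init__(self, root, split="train", transform=None, download=False): ...

class TrainFlagStyle:
    def __init__(self, root, train=True, transform=None, download=False): ...

print(subset_kwargs(SplitStyle, train=False))
print(subset_kwargs(TrainFlagStyle, train=False))
```

With torchvision installed, the result could be unpacked into the constructor, e.g. `torchvision.datasets.Food101(dataset_directory, **subset_kwargs(torchvision.datasets.Food101, True), ...)`; Food101 uses `split`, while datasets such as CIFAR10 use `train`.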

In [None]:
dataset = torchvision.datasets.Food101(dataset_directory, split='train',
                                       transform=get_transform(True), download=True)
dataset_test = torchvision.datasets.Food101(dataset_directory, split='test',
                                            transform=get_transform(False), download=True)
class_names = dataset.classes

print('Training data size: {}'.format(len(dataset)))
print('Validation data size: {}'.format(len(dataset_test)))

Now skip ahead to the [Predict using the original model](#3.-Predict-using-the-original-model) section.

### Option B: Use a downloaded or custom dataset

To use your own image dataset for transfer learning with the rest of this notebook, format your images as `.jpg` files and save them in folders named after the classes that you want the model to predict. To provide a working example of this layout, we will download and extract a flower species dataset. After extraction, your dataset directory will contain the following subdirectories, and each species subfolder will contain numerous `.jpg` files:

```
dataset_directory
└── flower_photos
    ├── daisy
    ├── dandelion
    ├── roses
    ├── sunflowers
    └── tulips
```

Use this as an example to organize your own image files accordingly.
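`datasets.ImageFolder` derives the labels from these folder names: in current torchvision releases it sorts the subdirectory names and numbers them from 0, so `class_names[i]` is the folder for label `i`. The sketch below reproduces that mapping with the standard library on a throwaway copy of the flower layout; `find_classes` here is an illustrative re-implementation, not the torchvision function itself.

```python
import os
import tempfile

def find_classes(root):
    """Sorted subdirectory names become class labels 0..N-1."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return classes, {name: idx for idx, name in enumerate(classes)}

# Build empty class folders in a temporary directory, in arbitrary order
with tempfile.TemporaryDirectory() as root:
    for name in ["tulips", "daisy", "sunflowers", "roses", "dandelion"]:
        os.makedirs(os.path.join(root, "flower_photos", name))
    classes, class_to_idx = find_classes(os.path.join(root, "flower_photos"))

print(classes)
print(class_to_idx)
```

Because the order is alphabetical rather than creation order, renaming a folder can change every label index, which matters if you reload saved weights later.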

In [None]:
# When you have your own properly organized subdirectory of images, adjust this variable
dataset_subdir = os.path.join(dataset_directory, "flower_photos")

In [None]:
# Only run this if you want to use the example flowers dataset
if not os.path.exists(dataset_subdir):
    # Make sure the parent dataset directory exists before extracting into it
    os.makedirs(dataset_directory, exist_ok=True)
    !apt-get update && apt-get -qq install wget
    !wget https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    !tar xvf flower_photos.tgz --directory $dataset_directory

In [None]:
dataset = datasets.ImageFolder(dataset_subdir, get_transform(True))
dataset_test = datasets.ImageFolder(dataset_subdir, get_transform(False))
class_names = dataset.classes

In [None]:
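# Use 75% for training and the remaining 25% for validation
indices = torch.randperm(len(dataset)).tolist()
num_training_samples = math.floor(len(dataset) * .75)

# Take the validation subset from dataset_test so it uses the non-augmenting transform,
# and from the indices not used for training so the two subsets do not overlap
dataset_test = torch.utils.data.Subset(dataset_test, indices[num_training_samples:])
dataset = torch.utils.data.Subset(dataset, indices[:num_training_samples])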
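The split works by shuffling the indices once and cutting the list at a single point, so every sample lands in exactly one subset. The sketch below shows the same idea with Python's `random` module instead of `torch.randperm` and adds a fixed seed so the split is reproducible across runs; `split_indices` and the 100-image dataset size are hypothetical.

```python
import math
import random

def split_indices(n, train_fraction=0.75, seed=0):
    """Shuffle 0..n-1 and cut once, so train and validation never overlap."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    num_train = math.floor(n * train_fraction)
    return indices[:num_train], indices[num_train:]

train_idx, val_idx = split_indices(100)
print(len(train_idx), len(val_idx))
```

In the notebook itself, a similar reproducible split could be obtained with `torch.randperm(len(dataset), generator=torch.Generator().manual_seed(0))`.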

## 3. Predict using the original model

In [None]:
# Create a data loader just for visualization
data_loader = torch.utils.data.DataLoader(dataset, batch_size=30,
                                          shuffle=True, num_workers=4)

In [None]:
# Get the ImageNet labels for displaying with the predictions
labels_file = 'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'
if not os.path.exists('ImageNetLabels.txt'):
    !wget $labels_file

with open('ImageNetLabels.txt') as f:
    imagenet_classes = [line.strip() for line in f]

# This file lists 1001 labels starting with "background", while torchvision models
# output 1000 classes, so drop the first entry to align the indices
imagenet_classes = imagenet_classes[1:]

In [None]:
# Get the pretrained torchvision model
pretrained_model_class = locate('torchvision.models.{}'.format(model_name))
model = pretrained_model_class(pretrained=True)

# Get a batch of training data
inputs, classes = next(iter(data_loader))

# Get predictions from the pretrained model (no gradients are needed for inference)
model.eval()
with torch.no_grad():
    outputs = model(inputs)

In [None]:
# List of the actual labels for this batch
actual_label_batch = [class_names[int(id)] for id in classes]

# List of the predicted labels for this batch
_, predicted_id = torch.max(outputs, 1)
predicted_label_batch = [imagenet_classes[id] for id in predicted_id]
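`torch.max(outputs, 1)` keeps only the single highest-scoring class. Because the ImageNet labels rarely match the new dataset's classes exactly, the top few guesses can be more informative; in PyTorch that is `torch.topk(outputs, k, dim=1)`. The pure-Python sketch below shows the same idea on one row of scores; the labels, scores, and the `top_k` helper are made up for illustration.

```python
import heapq

def top_k(logits, labels, k=3):
    """Return the k highest-scoring (label, score) pairs for one image."""
    best = heapq.nlargest(k, range(len(logits)), key=logits.__getitem__)
    return [(labels[i], logits[i]) for i in best]

# Hypothetical scores over a tiny label set
labels = ["daisy", "bee", "sunflower", "pot"]
logits = [2.1, 0.3, 3.7, -1.0]
print(top_k(logits, labels, k=2))
```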

In [None]:
# Create a results table to list out the ImageNet class prediction vs the actual dataset label
results_table = []
for prediction, actual in zip(predicted_label_batch, actual_label_batch):
    results_table.append([prediction, actual])

pd.DataFrame(results_table, columns=["ImageNet Prediction", "Actual Label"])

In [None]:
plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
    plt.subplot(6,5,n+1)
    inp = inputs[n]
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    plt.title(predicted_label_batch[n].title(), fontsize=9)
    plt.axis('off')
_ = plt.suptitle("ImageNet predictions")
plt.show()

## 4. Transfer learning

Replace the pretrained classification head of the network with a new layer whose output size equals the number of classes in the dataset. Then train the model on the new dataset for the specified number of epochs.
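`get_retrainable_model` comes from the notebook's `model_utils` module, which is not shown here. Assuming it swaps the final layer for a fully connected `Linear(in_features, num_classes)` and, with `do_fine_tuning=False`, freezes everything else, the trainable-parameter count printed after the model is built would be that layer's weights plus biases. For `efficientnet_b0`, whose torchvision classifier ends in a `Linear` layer with 1280 input features, a quick check of the arithmetic (`linear_head_params` is a hypothetical helper):

```python
def linear_head_params(in_features, num_classes):
    """Weights (in_features x num_classes) plus one bias per class."""
    return in_features * num_classes + num_classes

print(linear_head_params(1280, 101))  # Food101 has 101 classes
print(linear_head_params(1280, 5))    # the five flower species
```

If the printed count is much larger, the assumption does not hold: either fine-tuning is enabled or the helper replaces more than the last layer.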

In [None]:
# Number of training epochs
num_epochs = 1

# To reduce training time, the feature extractor layer can remain frozen (do_fine_tuning=False).
# Fine-tuning can be enabled to potentially get better accuracy. Note that enabling fine-tuning
# will increase training time.
do_fine_tuning = False

In [None]:
def main(model, criterion, optimizer, dataset, dataset_test, num_epochs=10):
    since = time.time()
    
    device = torch.device("cpu")
    model = model.to(device)
    best_acc = 0.0

    # Create data loaders for training and validation
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                              shuffle=True, num_workers=4)
    data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size,
                                                   shuffle=False, num_workers=4)
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Training phase
        model.train()
        running_loss = 0.0
        running_corrects = 0

        # Iterate over data.
        for inputs, labels in data_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward and backward pass
            with torch.set_grad_enabled(True):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

            # Statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / len(dataset)
        epoch_acc = running_corrects.double() / len(dataset)

        print(f'Training Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

        # Evaluation phase
        model.eval()
        running_loss = 0.0
        running_corrects = 0
            
        # Iterate over data.
        for inputs, labels in data_loader_test:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Forward pass
            with torch.set_grad_enabled(False):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                    
            # Statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
            
        epoch_loss = running_loss / len(dataset_test)
        epoch_acc = running_corrects.double() / len(dataset_test)

        if epoch_acc > best_acc:
            best_acc = epoch_acc
                
        print(f'Validation Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
        print()
        

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best Validation Accuracy: {best_acc:.4f}')

    return model
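The loop multiplies each batch's mean loss by the batch size before summing, then divides by the dataset length. That gives a true per-sample average even when the last batch is smaller than `batch_size`. A pure-Python sketch of the same bookkeeping, with hypothetical batch numbers and a hypothetical `epoch_metrics` helper:

```python
def epoch_metrics(batches):
    """batches: (mean_loss, num_correct, batch_size) for each batch.

    Scales each batch's mean loss back to a sum, so a short final batch
    is weighted by how many samples it actually contains.
    """
    total_loss = sum(loss * size for loss, _, size in batches)
    total_correct = sum(correct for _, correct, _ in batches)
    total_samples = sum(size for _, _, size in batches)
    return total_loss / total_samples, total_correct / total_samples

# Two full batches of 32 and a final partial batch of 6
batches = [(0.9, 20, 32), (0.7, 24, 32), (1.5, 3, 6)]
loss, acc = epoch_metrics(batches)
print(f"Loss: {loss:.4f} Acc: {acc:.4f}")
```

Averaging the three batch means directly would give about 1.033 instead of 0.86, overweighting the six-sample batch.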

In [None]:
model = get_retrainable_model(model_name, len(class_names), do_fine_tuning)
criterion = torch.nn.CrossEntropyLoss()

# Adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

print('Trainable parameters: {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))

In [None]:
model, optimizer = ipex.optimize(model, optimizer=optimizer)
model = main(model, criterion, optimizer, dataset, dataset_test, num_epochs)

## 5. Visualize the model output

In [None]:
model.eval()
with torch.no_grad():
    outputs = model(inputs)
_, predicted_id = torch.max(outputs, 1)
predicted_label_batch = [class_names[id] for id in predicted_id]

# Display the results
plt.figure(figsize=(10,9))
plt.subplots_adjust(hspace=0.5)
for n in range(30):
    plt.subplot(6,5,n+1)
    inp = inputs[n]
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    correct_prediction = actual_label_batch[n] == predicted_label_batch[n]
    color = "darkgreen" if correct_prediction else "crimson"
    title = predicted_label_batch[n].title() if correct_prediction else "{}\n({})".format(predicted_label_batch[n], actual_label_batch[n]) 
    plt.title(title, fontsize=9, color=color)
    plt.axis('off')
_ = plt.suptitle("Model predictions")
plt.show()
print("Correct predictions are shown in green")
print("Incorrect predictions are shown in red with the actual label in parentheses")


## 6. Export the saved model

In [None]:
os.makedirs(output_directory, exist_ok=True)
file_path = os.path.join(output_directory, "image_classification.pt")
torch.save(model.state_dict(), file_path)
print("Saved to {}".format(file_path))

## Dataset citations
```
@inproceedings{bossard14,
  title = {Food-101 -- Mining Discriminative Components with Random Forests},
  author = {Bossard, Lukas and Guillaumin, Matthieu and Van Gool, Luc},
  booktitle = {European Conference on Computer Vision},
  year = {2014}
}

@online{tfflowers,
  author = {The TensorFlow Team},
  title = {Flowers},
  month = {jan},
  year = {2019},
  url = {http://download.tensorflow.org/example_images/flower_photos.tgz}
}
```