In the markdown cell below, replace the `???` with the names of the people in your group.

???

# Assignment (part 3): Classification of cell morphological changes with transfer learning
_by Phil Harrison (February 2021)_

Note: only students who are aiming for a `VG` grade need to do this part of the assignment.

## Load packages

In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals

# TensorFlow and Keras
import tensorflow as tf
from tensorflow.keras import models, layers, optimizers
from tensorflow.keras.regularizers import l2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.resnet50 import preprocess_input

# Helper libraries
import random
import itertools
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from sklearn.utils import class_weight
import pandas as pd
from PIL import Image

from tensorflow.python.util import deprecation
deprecation._PRINT_DEPRECATION_WARNINGS = False
print(tf.__version__)

## Functions
Don't worry too much about the code in the functions below, but you might want to revisit them when they are called later on so that you roughly understand what they are doing.

In [None]:
def load_dataset():
    dirname = 'bbbc021v1_images'
    x_orig = np.zeros((660, 256, 256, 3), dtype=np.float32)

    for f in range(x_orig.shape[0]):
        img    = Image.open(dirname + '/bbbc021v1_%s.png' % str(f))
        img    = np.array(img)
        x_orig[f] = img

    labels = pd.read_csv('bbbc021v1_labels.csv',
                          usecols=["compound", "concentration", "moa"],
                          sep=";")
    y_orig = np.array(labels['moa'])

    return x_orig, y_orig

def convert_to_one_hot(y, C):
    moa_dict = {'Aurora kinase inhibitors': 0, 'Cholesterol-lowering': 1,
                'Eg5 inhibitors': 2, 'Protein synthesis': 3, 'DNA replication': 4, 'DNA damage': 5}

    y = np.asarray([moa_dict[item] for item in y])
    y = np.eye(C)[y]
    y = y.astype('float32')

    return y

def plot_history(model_history, model_name):
    fig = plt.figure(figsize=(15, 5), facecolor='w')
    ax = fig.add_subplot(131)
    ax.plot(model_history.history['loss'])
    ax.plot(model_history.history['val_loss'])
    ax.set(title=model_name + ': Model loss', ylabel='Loss', xlabel='Epoch')
    ax.legend(['train', 'valid'], loc='upper right')
    
    ax = fig.add_subplot(132)
    ax.plot(np.log(model_history.history['loss']))
    ax.plot(np.log(model_history.history['val_loss']))
    ax.set(title=model_name + ': Log model loss', ylabel='Log loss', xlabel='Epoch')
    ax.legend(['train', 'valid'], loc='upper right')  # second curve is validation loss, not test loss

    ax = fig.add_subplot(133)
    ax.plot(model_history.history['accuracy'])
    ax.plot(model_history.history['val_accuracy'])
    ax.set(title=model_name + ': Model accuracy', ylabel='Accuracy', xlabel='Epoch')
    ax.legend(['train', 'valid'], loc='upper right')
    plt.show()
    plt.close()

def plot_confusion_matrix(cm, classes, model_name,
                          cmap=plt.cm.Blues):
    title = model_name + ': Confusion Matrix'
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    
def valid_evaluate(model, model_name):
    y_pred = model.predict(X_valid)
    y_pred = y_pred.argmax(axis=-1)
    y_true = Y_valid.argmax(axis=-1)
    
    class_names = ['Aur', 'Ch', 'Eg5', 'PS', 'DR', 'DD']  # same order as moa_dict; 'DD' = DNA damage
    cnf_matrix = confusion_matrix(y_true, y_pred)
    np.set_printoptions(precision=2)
    plt.figure(figsize=(15,5), facecolor='w')
    plot_confusion_matrix(cnf_matrix, classes=class_names, model_name=model_name)
    plt.show()
    plt.close()
    
    print('')
    print('classification report for validation data:')
    print(classification_report(y_true, y_pred, digits=3))

## Read in and preprocess the data

In [None]:
X_orig, y_orig = load_dataset()
Y = convert_to_one_hot(y_orig, 6)
X = preprocess_input(X_orig)

n_train = 500

random.seed(5026)
indices = np.arange(len(Y))
random.shuffle(indices)

X_train, X_valid = X[indices[:n_train]], X[indices[n_train:]]
Y_train, Y_valid = Y[indices[:n_train]], Y[indices[n_train:]]
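
To make the split above concrete, here is a minimal standalone sketch of the same seeded shuffle-and-slice idea. It uses a plain Python list in place of the NumPy index array (an assumption made only to keep the example dependency-free), and the names `n_images`, `train_idx` and `valid_idx` are illustrative rather than taken from the notebook.

```python
import random

n_images = 660   # number of images that load_dataset() reads
n_train = 500    # size of the training split, as in the cell above

# Seed the generator so the shuffle (and hence the split) is reproducible
random.seed(5026)
indices = list(range(n_images))
random.shuffle(indices)

# First n_train shuffled indices go to training, the rest to validation
train_idx, valid_idx = indices[:n_train], indices[n_train:]

print(len(train_idx), len(valid_idx))
print(set(train_idx).isdisjoint(valid_idx))
```

Note that this kind of split is random but not stratified, so the class proportions in the validation set can differ a little from those in the training set.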

## Transfer learning
TRANSFER LEARNING: "_For transfer learning, large annotated datasets, like ImageNet, can be used to pre‐train state‐of‐the‐art CNNs such as Resnet and Inception. The transferred parameter values—providing good initial values for gradient descent—can be fine‐tuned to fit the target data. Alternatively, the pre‐trained parameters in the initial layers can be frozen—capturing generic image representations—while the parameters in the final layers can be fine‐tuned to the current task. Relative to training from scratch, transfer learning allows the fitting of deeper networks, using fewer task‐specific annotated images, for improved classification performance and generalizability_".

Read this TensorFlow tutorial (https://www.tensorflow.org/tutorials/images/transfer_learning). They do some things a little differently from how we have done them; don't worry about those. Modify the code cell where they define the `base_model` so that it suits our case (i.e. replace the `???` below accordingly): use the pretrained ResNet50 as the base model and change the image shape to match our data.

In [None]:
???

In the code cell below, make sure that the base model is not trained (i.e. to start with we freeze the base model and train only the part we add on top). Also add a line of code to summarise your base model.

In [None]:
???

Your model above should have a whopping 23,587,712 parameters!

In [None]:
num_classes = 6
x = base_model.output
??? # add a global average pooling layer
preds = ??? # add your final dense layer with a softmax activation

ResNet50 = models.Model(inputs=base_model.input, outputs=preds)
ResNet50.summary()

Your model should now have 23,600,006 parameters.
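
If your total does not match, the arithmetic behind that figure may help. The sketch below assumes that the global average pooling layer passes ResNet50's 2048 output channels straight into a 6-unit dense layer with biases; the variable names are illustrative only.

```python
base_params = 23_587_712   # parameter count of the base model, as quoted above
n_features = 2048          # channels in ResNet50's final feature map
num_classes = 6

# Global average pooling only averages over space, so it adds no parameters
pooling_params = 0

# Dense layer: one weight per (feature, class) pair plus one bias per class
dense_params = n_features * num_classes + num_classes

total_params = base_params + pooling_params + dense_params
print(dense_params, total_params)
```

While the base is frozen, only the dense layer's parameters are trainable.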

## Training with the base frozen
In the code cell below, compile your `ResNet50` model with the Adam optimizer (learning rate 0.001) and `categorical_crossentropy` loss, and track `accuracy` as a metric.

In [None]:
???

In the code cell below, fit, plot and evaluate your model. Use a batch size of 32 and train for 10 epochs. We don't need to train for long because we are essentially training only the final layers that we added on top of the base model.

In [None]:
???

## Training with the base unfrozen
In the code cell below, set your base model to trainable, re-compile the model with a much lower learning rate of 0.00001, re-fit it for 30 epochs, and then plot and evaluate it.

Provided you don't re-define the model here, the training will continue where it left off above.

We set such a low learning rate so as not to move the base model's weights too far from where they started; we just want to tweak them a little to adapt the entire model better to our data.

Some practitioners advocate training only the later layers of the base model, given that the early layers tend to capture simple features, such as edges and blobs, that are common to all image data. With this data, however, we have found it best to train the entire base model, so that is what we'll do here.
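
To get a feel for what the lower learning rate buys us, here is a rough back-of-the-envelope sketch. It relies on the common rule of thumb that a single Adam update moves each weight by roughly the learning rate at most (a heuristic, not a guarantee), and it assumes the batch size of 32 from the frozen stage is reused here.

```python
import math

n_train = 500
batch_size = 32   # assumption: same batch size as in the frozen stage
epochs = 30

# Number of weight updates performed over the whole fine-tuning run
steps_per_epoch = math.ceil(n_train / batch_size)
total_steps = steps_per_epoch * epochs

for lr in (1e-3, 1e-5):
    # Heuristic upper bound on how far any single weight can drift
    max_drift = total_steps * lr
    print(f"lr={lr:g}: {total_steps} updates, drift per weight of roughly {max_drift:.4f} at most")
```

Under these assumptions the lower learning rate limits how far the pretrained weights can move by a factor of 100, which is the "tweak them a little" behaviour we are after.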

In [None]:
???

## So...
For my run of the above model I got a weighted avg f1-score of 0.988! I'm guessing you got something equally impressive. So of all the tricks we've learnt this week, transfer learning stands out as a very good thing to do, especially when our datasets are not particularly large, as is often the case in image cytometry.
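
In case the "weighted avg" row of the classification report is unfamiliar, the sketch below shows how that figure is put together: each class's F1 score is weighted by its support (the number of true examples of that class). The precision, recall and support values are made up for illustration and are not results from this assignment.

```python
# Hypothetical per-class (precision, recall, support) values, for illustration only
per_class = [
    (1.00, 0.96, 25),
    (0.95, 1.00, 20),
    (1.00, 1.00, 30),
    (0.97, 1.00, 35),
    (1.00, 0.96, 25),
    (1.00, 1.00, 25),
]

def f1_score(precision, recall):
    """Harmonic mean of precision and recall (0 if both are 0)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

total_support = sum(support for _, _, support in per_class)

# Weighted average: each class contributes in proportion to its support
weighted_f1 = sum(f1_score(p, r) * s for p, r, s in per_class) / total_support
print(round(weighted_f1, 3))
```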

I hope you enjoyed this week and are inspired to dive deeper into deep learning. Either way, that's it. Now you're free to run away and enjoy your weekend.

Cheers, - Phil

# THE END