In the 'markdown cell' below, replace the `???` with the names of those in your group.

???

# Assignment (part 2): classification of cell morphological changes with ResNet
_by Phil Harrison (February 2021), verified by David Holmberg (2023)_

## Load packages

In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals

# TensorFlow and Keras
import tensorflow as tf
from tensorflow.keras import models, layers, optimizers
from tensorflow.keras.regularizers import l2
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Helper libraries
import random
import itertools
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from sklearn.utils import class_weight
import pandas as pd
from PIL import Image

from tensorflow.python.util import deprecation
deprecation._PRINT_DEPRECATION_WARNINGS = False
print(tf.__version__)

## Functions
Don't worry too much about the code in the functions below, but you may want to go through each function when it is called later on, so that you roughly understand what it does.

In [None]:
def load_dataset():
    dirname = 'bbbc021v1_images'
    x_orig = np.zeros((660, 256, 256, 3), dtype=np.float32)

    for f in range(x_orig.shape[0]):
        img    = Image.open(dirname + '/bbbc021v1_%s.png' % str(f))
        img    = np.array(img)
        x_orig[f] = img

    labels = pd.read_csv('bbbc021v1_labels.csv',
                          usecols=["compound", "concentration", "moa"],
                          sep=";")
    y_orig = np.array(labels['moa'])

    return x_orig, y_orig

def convert_to_one_hot(y, C):
    moa_dict = {'Aurora kinase inhibitors': 0, 'Cholesterol-lowering': 1,
                'Eg5 inhibitors': 2, 'Protein synthesis': 3, 'DNA replication': 4, 'DNA damage': 5}

    y = np.asarray([moa_dict[item] for item in y])
    y = np.eye(C)[y]
    y = y.astype('float32')

    return y

def plot_history(model_history, model_name):
    fig = plt.figure(figsize=(15, 5), facecolor='w')
    ax = fig.add_subplot(131)
    ax.plot(model_history.history['loss'])
    ax.plot(model_history.history['val_loss'])
    ax.set(title=model_name + ': Model loss', ylabel='Loss', xlabel='Epoch')
    ax.legend(['train', 'valid'], loc='upper right')
    
    ax = fig.add_subplot(132)
    ax.plot(np.log(model_history.history['loss']))
    ax.plot(np.log(model_history.history['val_loss']))
    ax.set(title=model_name + ': Log model loss', ylabel='Log loss', xlabel='Epoch')
    ax.legend(['train', 'valid'], loc='upper right')

    ax = fig.add_subplot(133)
    ax.plot(model_history.history['accuracy'])
    ax.plot(model_history.history['val_accuracy'])
    ax.set(title=model_name + ': Model accuracy', ylabel='Accuracy', xlabel='Epoch')
    ax.legend(['train', 'valid'], loc='upper right')
    plt.show()
    plt.close()

def plot_confusion_matrix(cm, classes, model_name,
                          cmap=plt.cm.Blues):
    title = model_name + ': Confusion Matrix'
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    
def valid_evaluate(model, model_name):
    y_pred = model.predict(X_valid)
    y_pred = y_pred.argmax(axis=-1)
    y_true = Y_valid.argmax(axis=-1)
    
    class_names = ['Aur', 'Ch', 'Eg5', 'PS', 'DR', 'DS']
    cnf_matrix = confusion_matrix(y_true, y_pred)
    np.set_printoptions(precision=2)
    plt.figure(figsize=(15,5), facecolor='w')
    plot_confusion_matrix(cnf_matrix, classes=class_names, model_name=model_name)
    plt.show()
    plt.close()
    
    print('')
    print('classification report for validation data:')
    print(classification_report(y_true, y_pred, digits=3))
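
`convert_to_one_hot` maps each MoA string to an integer index and then selects the matching rows of a `C x C` identity matrix. The sketch below repeats that indexing trick on a few hypothetical labels so you can see the shape and content of the result. `moa_dict` is copied from the function above; the `labels` list is made up for illustration.

```python
import numpy as np

# Same mapping as in convert_to_one_hot above
moa_dict = {'Aurora kinase inhibitors': 0, 'Cholesterol-lowering': 1,
            'Eg5 inhibitors': 2, 'Protein synthesis': 3,
            'DNA replication': 4, 'DNA damage': 5}

# Hypothetical labels standing in for the 'moa' column of the CSV file
labels = ['Eg5 inhibitors', 'DNA damage', 'Aurora kinase inhibitors']

idx = np.asarray([moa_dict[m] for m in labels])   # integer class indices: [2, 5, 0]
one_hot = np.eye(6)[idx].astype('float32')        # row k of the identity is the one-hot vector for class k

print(one_hot.shape)             # (3, 6)
print(one_hot[0])                # [0. 0. 1. 0. 0. 0.]
# argmax undoes the encoding, which is what valid_evaluate relies on
print(one_hot.argmax(axis=-1))   # [2 5 0]
```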

## Read in and preprocess the data

In [None]:
X_orig, y_orig = load_dataset()
Y = convert_to_one_hot(y_orig, 6)
X = X_orig/255.

n_train = 500

random.seed(5026)
indices = np.arange(len(Y))
random.shuffle(indices)

X_train, X_valid = X[indices[:n_train]], X[indices[n_train:]]
Y_train, Y_valid = Y[indices[:n_train]], Y[indices[n_train:]]
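
The split above shuffles the 660 image indices once, uses the first `n_train` (500) for training and the remaining 160 for validation. With only 160 validation images it can be worth checking how many images of each class ended up in the validation set, since the per-class rows of the classification report are computed from them. A minimal sketch of that check follows; `y_int` is a hypothetical stand-in for `Y.argmax(axis=-1)`.

```python
import random
import numpy as np

random.seed(5026)
n_total, n_train = 660, 500

# Hypothetical integer class labels (0-5) standing in for Y.argmax(axis=-1)
y_int = np.arange(n_total) % 6

indices = np.arange(n_total)
random.shuffle(indices)            # in-place shuffle of a 1-D array, as in the cell above
train_idx, valid_idx = indices[:n_train], indices[n_train:]

# The two subsets should be disjoint and together cover every image
assert len(set(train_idx) & set(valid_idx)) == 0
assert len(train_idx) + len(valid_idx) == n_total

# Number of validation images per class
print(np.bincount(y_int[valid_idx], minlength=6))
```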

## ResNet - a deep residual CNN
In this section you will implement a deeper, state-of-the-art CNN. Specifically, the CNN you will implement is similar to the residual CNN presented by He _et al._ (https://arxiv.org/pdf/1512.03385.pdf).

A deep residual network is built by stacking many residual blocks (see Figure 1), with batch normalization of the intermediate layers.

<p>
    <img src="figs/residual_block.png" alt="drawing" style="width:400px;"/>
    <center>Figure 1. A residual block - the building block of a residual network.</center>
</p>

The identity mapping is often called a _skip-connection_ or _shortcut_. It has been shown to help avoid the degradation problem of very deep networks (training accuracy getting worse as more layers are added), a problem that is most apparent in "plain", non-residual networks. The residual design thus made it possible to successfully train very deep CNNs (over 100 layers in He _et al._) that outperformed the CNNs that came before them and won the ILSVRC 2015 classification task.
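
In short, a residual block computes `relu(F(x) + x)`, where `F` is the stack of layers in Figure 1. The NumPy sketch below is a simplified, hypothetical stand-in: two dense matrix products replace the two convolutions, and batch normalization is left out. It illustrates the intuition He _et al._ give for the shortcut: if the weights of `F` are zero, the block passes its (non-negative) input through unchanged, so an extra block can at worst act as an identity mapping.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def residual_block(x, W1, W2):
    """Simplified residual block: dense layers stand in for the convolutions, no batch norm."""
    f = W2 @ relu(W1 @ x)   # the residual branch F(x)
    return relu(f + x)      # add the identity shortcut, then apply the final ReLU

rng = np.random.default_rng(0)
x = relu(rng.normal(size=4))   # non-negative input, e.g. the output of an earlier ReLU

# With all-zero weights, F(x) = 0 and the block reduces to the identity
W_zero = np.zeros((4, 4))
print(np.allclose(residual_block(x, W_zero, W_zero), x))   # True

# With non-zero weights, the output is x plus a correction F(x) (before the final ReLU)
W1, W2 = rng.normal(size=(4, 4)), rng.normal(size=(4, 4))
print(residual_block(x, W1, W2).shape)   # (4,)
```

When a block uses a stride greater than 1, `F(x)` and `x` no longer have the same shape, which is why `ResNet_block` below passes the shortcut through its own strided convolution in that case.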

<p>
    <img src="figs/resnet.png" alt="drawing" style="width:1200px;"/>
    <center>Figure 2. Example of a ResNet architecture.</center>
</p>

### Batch-normalization

Normalizing the input can help neural networks train better. The idea of batch normalization is to extend this normalization to the intermediate layers. Specifically, batch normalization normalizes a layer's output before it goes through the activation function (hence we no longer apply the activation inside the convolutional layer, but add a separate activation layer after the normalization). This tends to make training more stable and faster.

Note: although batch normalization and dropout are sometimes used in the same network, the two approaches can interfere with one another. Hence we will not use dropout in our ResNet models.
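
For a convolutional layer, batch normalization computes one mean and one variance per channel over the batch and spatial dimensions, normalizes with them, and then applies a learned per-channel scale (gamma) and shift (beta). The ReLU comes afterwards. The NumPy sketch below shows only the training-time computation, and assumes `eps = 1e-3` to match the Keras default. At inference time, Keras' `BatchNormalization` layer instead uses moving averages of the mean and variance that it tracks during training. These moving statistics are the non-trainable parameters reported in the model summary.

```python
import numpy as np

def batch_norm_train(x, gamma, beta, eps=1e-3):
    """Training-mode batch norm for a batch of feature maps shaped (N, H, W, C)."""
    mean = x.mean(axis=(0, 1, 2))              # one mean per channel
    var = x.var(axis=(0, 1, 2))                # one variance per channel
    x_hat = (x - mean) / np.sqrt(var + eps)    # normalize each channel
    return gamma * x_hat + beta                # learned per-channel scale and shift

rng = np.random.default_rng(1)
x = rng.normal(loc=2.0, scale=5.0, size=(8, 4, 4, 3))   # hypothetical conv outputs, 3 channels
gamma, beta = np.ones(3), np.zeros(3)

out = batch_norm_train(x, gamma, beta)
act = np.maximum(out, 0.0)                              # ReLU applied after normalization

print(np.round(out.mean(axis=(0, 1, 2)), 6))   # approximately zero per channel
print(np.round(out.var(axis=(0, 1, 2)), 3))    # approximately one per channel
```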

## Useful ResNet functions
Fill in the `???` in the code below (based on the guidance given after them).

In [None]:
def conv(inputs, num_filters, strides):
    x = layers.Conv2D(num_filters, kernel_size=(3, 3), strides=strides, padding='same',
                      kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(inputs)
    return x

def ResNet_block(inputs, num_filters, strides):
    x = conv(inputs, num_filters=num_filters, strides=strides)
    ??? # add batch normalisation
    ??? # add a relu activation
    x = conv(x, num_filters=num_filters, strides=1)
    ??? # add batch normalisation
    if strides > 1:
        y = conv(inputs, num_filters=num_filters, strides=strides)
        ??? # add batch normalisation (for y this time)
    else:
        y = inputs
    z = layers.add([x, y])
    z = layers.Activation('relu')(z)
    return z

## Define ResNet model
Coding ResNet is rather more involved than coding LeNet. Note that Figure 2 is for illustrative purposes only: the number of filters, blocks per stack, etc., shown there differ from what we would like you to implement. Your task is to fill in the `???` again (based on the guidance given after them) and then jot down the architecture of the model based on the code in this cell and the one above. Keep this succinct: don't try to describe what's going on in every layer, or you'll be here all day (!), just note the basics of how the model is structured. Once you've done that, run the cell and compare your output with ours.

Hint: again look to the code given in the lectures for guidance here.

In [None]:
num_classes = 6
inps = layers.Input((256, 256, 3))
x = layers.Conv2D(32, kernel_size=(5, 5), strides=2, padding='same',
                 kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(inps)

# add batch normalisation
# add a relu activation
# add max pooling (3x3, with a stride of 2 and 'same' padding)

num_stacks = 3
num_blocks_per_stack = [3, 4, 3]
num_filters_in_stack = [64, 128, 256]

for i in range(num_stacks):
    num_filters = num_filters_in_stack[i]
    for j in range(num_blocks_per_stack[i]):
        if j == 0:
            strides = 2
        else:
            strides = 1
        #x = ???  # call the ResNet_block function here with the appropriate inputs
        # hint for the call above, look to where we call the conv function within the ResNet block

# Add pooling and define output
# define a softmax dense output layer (named "preds")


ResNet = models.Model(inputs=inps, outputs=preds)
ResNet.summary()
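
With `padding='same'`, a Keras convolution or pooling layer with stride `s` maps a spatial size `n` to `ceil(n / s)`, whatever the kernel size. Tracing this through the model gives the feature-map sizes you should see in `ResNet.summary()`. The sketch below assumes, as the loop above does, that only the first block of each stack uses stride 2; `same_out` is a hypothetical helper.

```python
import math

def same_out(n, stride):
    """Spatial output size of a 'same'-padded convolution or pooling layer."""
    return math.ceil(n / stride)

num_blocks_per_stack = [3, 4, 3]
num_filters_in_stack = [64, 128, 256]

size = 256
size = same_out(size, 2)   # 5x5 conv, stride 2
print('first conv:', size, 'x', size, 'x', 32)
size = same_out(size, 2)   # 3x3 max pooling, stride 2
print('max pool:  ', size, 'x', size, 'x', 32)

for i, n_blocks in enumerate(num_blocks_per_stack):
    for j in range(n_blocks):
        size = same_out(size, 2 if j == 0 else 1)   # only the first block of a stack downsamples
    print(f'stack {i}:   ', size, 'x', size, 'x', num_filters_in_stack[i])
```

Under these assumptions the last stack outputs 8 x 8 x 256 feature maps, which the final pooling and dense layers then reduce to the six class probabilities.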

## Did you code in ResNet correctly?
At the end of the summary you should see the following:

Total params: 4,960,838

Trainable params: 4,953,990

Non-trainable params: 6,848
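
The non-trainable count comes entirely from batch normalization: each `BatchNormalization` layer keeps a moving mean and a moving variance per channel (non-trainable), alongside the trainable gamma and beta. The sketch below counts the batch-normalized channels implied by the first convolution and the `ResNet_block` template (two batch normalizations per block, plus one on the shortcut in the stride-2 block), then doubles the total. It is a consistency check on the template, not a replacement for `summary()`.

```python
num_blocks_per_stack = [3, 4, 3]
num_filters_in_stack = [64, 128, 256]

bn_channels = 32   # batch normalization after the first 5x5 convolution (32 filters)
for n_blocks, n_filters in zip(num_blocks_per_stack, num_filters_in_stack):
    for j in range(n_blocks):
        n_bn = 3 if j == 0 else 2   # the stride-2 block also normalizes its shortcut conv
        bn_channels += n_bn * n_filters

non_trainable = 2 * bn_channels   # moving mean + moving variance per channel
print(f'{non_trainable:,}')       # 6,848
```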

## ResNet model description
In the 'markdown' cell below, replace `???` with a brief summary of your ResNet model architecture.

???

## Compile the ResNet model
Note that for the ResNet model we will use a lower learning rate of 0.0001, rather than the default value of 0.001 that we used for the LeNet model.

In [None]:
lr = 0.0001
#ResNet.compile(???)

## Fit and evaluate the ResNet model

In [None]:
batch_size = 32
n_epochs = 50

history = ResNet.fit(X_train, Y_train,
                     batch_size=batch_size,
                     epochs=n_epochs,
                     validation_data=(X_valid, Y_valid),
                     verbose=2)

plot_history(history, 'ResNet')
valid_evaluate(ResNet, 'ResNet')

## Multiple runs of the same model
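Again, as we did for the LeNet model, re-run your ResNet model 5 times and compute the mean 'weighted avg f1-score'. Rebuild and recompile the model before each run, so that training starts from freshly initialized weights rather than continuing from the previous run. Show this in the code cell below. For my five runs I got: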

(0.769 + 0.815 + 0.720 + 0.777 + 0.842) / 5 = 0.785
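
The 'weighted avg' F1 in the classification report is the per-class F1 score averaged with weights equal to each class's support (its number of true examples). The sketch below re-implements that definition in NumPy for illustration and then averages the five scores quoted above. In practice you can read the value directly from scikit-learn with `classification_report(y_true, y_pred, output_dict=True)['weighted avg']['f1-score']`; `weighted_f1` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

def weighted_f1(y_true, y_pred, num_classes=6):
    """Support-weighted mean of per-class F1 scores."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    f1s, supports = [], []
    for c in range(num_classes):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        denom = 2 * tp + fp + fn
        f1s.append(2 * tp / denom if denom > 0 else 0.0)   # F1 = 2TP / (2TP + FP + FN)
        supports.append(tp + fn)                           # number of true examples of class c
    supports = np.asarray(supports, dtype=float)
    return float(np.dot(f1s, supports) / supports.sum())

print(weighted_f1([0, 0, 1, 1], [0, 1, 1, 1]))   # (2/3 + 4/5) / 2, roughly 0.733

# Mean over repeated runs, using the five scores quoted above
runs = [0.769, 0.815, 0.720, 0.777, 0.842]
print(round(float(np.mean(runs)), 3))   # 0.785
```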

## Notes
You probably noticed that training was quite erratic for your ResNet model above (i.e. the training curves were quite jagged). We could combat this with changes to the model architecture, the batch size, and the learning rate. We could also perhaps improve things with data augmentation. We will not explore these modifications here, though.

This was a big model that we trained, with almost 5 million parameters. With only 500 images for training, this was rather ambitious! A better way to deal with a small dataset, when we still want to fit a fairly complex and descriptive model (that will hopefully make good predictions), is to use transfer learning. This will be the focus of part 4 of the assignment.