## Train models

In [None]:
import tensorflow as tf

from tensorflow.keras.preprocessing.image import ImageDataGenerator, DirectoryIterator
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers.experimental import preprocessing
from tensorflow.keras.applications.resnet import preprocess_input as resnet_preprocessing
from tensorflow.keras.layers import Dense, Flatten, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import CSVLogger, EarlyStopping
from tensorflow.keras import Model
from tensorflow.errors import ResourceExhaustedError

import numpy as np
import os
import matplotlib.pyplot as plt
from pathlib import Path

def calc_class_weights(train_iterator, max_class_weight=10, min_class_samples=5):
    """
    Calculate class weights to use as input for the cnn training. This is useful if the training set is
    imbalanced.

    The weight of class "i" is calculated as the number of samples in the most populated class divided by the number of
    samples in class i (max_class_frequency / class_frequency).
    The class weights are capped at ``max_class_weight``. This is done in order to avoid placing too much weight on a
    small fraction of the dataset. For the same reason, the weight is set to 1 for any class in the training set that
    contains fewer than ``min_class_samples`` samples.

    Only classes that occur in ``train_iterator.classes`` get a weight; the returned list is ordered by ascending
    class index.

    :param train_iterator: Object exposing a ``classes`` array holding the integer class index of every sample.
    :param max_class_weight: Upper bound applied to every weight. Pass ``None`` to disable the cap.
    :param min_class_samples: Classes with fewer samples than this get a weight of 1. Pass 0 to disable.
    :return: A list of floats with one weight per class present in the iterator.
    :raises ValueError: If the iterator contains no samples.
    """
    class_ids, class_counts = np.unique(train_iterator.classes, return_counts=True)
    if class_counts.size == 0:
        raise ValueError("Cannot compute class weights: the iterator contains no samples")

    max_freq = float(class_counts.max())
    class_weights = []
    for count in class_counts:
        if count < min_class_samples:
            # Too few samples to justify up-weighting: keep the class neutral.
            weight = 1.0
        else:
            weight = max_freq / float(count)
            if max_class_weight is not None:
                weight = min(weight, float(max_class_weight))
        class_weights.append(weight)

    print("Classes: " + str(class_ids))
    print("Samples per class: " + str(class_counts))
    print("Class weights: " + str(class_weights))

    return class_weights


def unfreeze_layers(model, last_fixed_layer):
    """
    Mark the layers after ``last_fixed_layer`` as trainable and all layers up to and including it as frozen.

    BatchNormalization layers are skipped: their ``trainable`` flag is left exactly as it was.

    :param model: Model whose ``layers`` are updated in place.
    :param last_fixed_layer: Name of the last layer that must stay frozen.
    :return: The same model instance, for convenience.
    """
    # Position of the last frozen layer; everything strictly after it is trained.
    boundary = model.layers.index(model.get_layer(last_fixed_layer))

    for position, layer in enumerate(model.layers):
        if isinstance(layer, BatchNormalization):
            continue
        layer.trainable = position > boundary
    return model


def train_model(rotation, shear, zoom, brightness, lr, last_fixed_layer, batch_size, idx,
                data_dir='/home/ubuntu/store/internal-neurips/full', epochs=30):
    """
    Fine-tune an ImageNet-initialised ResNet50 classifier and save it as ``resnet50_{idx}.h5``.

    Training is skipped if the output file already exists in the current working directory. A per-epoch log is
    written to ``resnet50_{idx}.csv``.

    :param rotation: ``rotation_range`` passed to the augmentation generator.
    :param shear: ``shear_range`` passed to the augmentation generator.
    :param zoom: ``zoom_range`` passed to the augmentation generator.
    :param brightness: ``brightness_range`` passed to the augmentation generator.
    :param lr: Learning rate of the Adam optimiser.
    :param last_fixed_layer: Name of the last layer kept frozen (see ``unfreeze_layers``).
    :param batch_size: Number of images per batch produced by the training iterator.
    :param idx: Run index, used only to build the output file names.
    :param data_dir: Directory with one sub-directory of images per class.
    :param epochs: Number of training epochs.
    :return: None.
    """
    model_name = f'resnet50_{idx}'
    if Path(model_name + '.h5').exists():
        print(f'{model_name} already trained')
        return
    print(f'Now training {model_name}')

    train_generator = ImageDataGenerator(
        horizontal_flip=True,
        vertical_flip=True,
        rotation_range=rotation,
        shear_range=shear,
        zoom_range=zoom,
        brightness_range=brightness,
        fill_mode='nearest',
        preprocessing_function=resnet_preprocessing,
    )
    train_iterator = train_generator.flow_from_directory(
        data_dir,
        target_size=(300, 400),
        class_mode='categorical',
        batch_size=batch_size,
        follow_links=True,
        interpolation='bilinear',
    )

    class_weights = calc_class_weights(train_iterator)

    # Size the classification head from the data instead of hard-coding the number of classes.
    num_classes = train_iterator.num_classes

    base_model = ResNet50(include_top=False, weights='imagenet', input_shape=(300, 400, 3))
    top_model = Flatten()(base_model.output)
    top_model = Dense(num_classes, activation='softmax', name='diagnosis')(top_model)
    model = Model(inputs=base_model.input, outputs=top_model)
    model = unfreeze_layers(model, last_fixed_layer)

    # `lr` is a deprecated alias of `learning_rate`.
    optimiser = Adam(learning_rate=lr)
    # `loss_weights` in `compile` weighs the losses of different model *outputs*, not classes, so the
    # per-class weights must not be passed here; they are applied through `class_weight` in `fit`.
    model.compile(
        optimizer=optimiser,
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )

    logger = CSVLogger(model_name + '.csv')

    # `batch_size` is not passed to `fit`: the iterator already yields batches of `batch_size` images.
    model.fit(
        x=train_iterator,
        epochs=epochs,
        verbose=True,
        class_weight=dict(enumerate(class_weights)),
        workers=8,
        callbacks=[logger]
    )
    model.save(model_name + '.h5')
    
# Train five models with identical hyperparameters; the index only selects the output file names.
# train_model returns nothing, so its result is not bound to a name.
for idx in range(5):
    train_model(
        rotation=20,
        shear=0,
        zoom=0.25,
        brightness=[0.25, 1],
        lr=0.01,
        last_fixed_layer='conv5_block2_add',
        batch_size=64,
        idx=idx,
    )

## Validate models

In [2]:
import tensorflow as tf

from tensorflow.keras.preprocessing.image import ImageDataGenerator, DirectoryIterator
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers.experimental import preprocessing
from tensorflow.keras.applications.resnet import preprocess_input as resnet_preprocessing
from tensorflow.keras.layers import Dense, Flatten, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import load_model
from tensorflow.keras.callbacks import CSVLogger, EarlyStopping
from tensorflow.keras import Model
from tensorflow.errors import ResourceExhaustedError

import numpy as np
import os
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
import glob
from sklearn.metrics import classification_report

base_path = "/home/ubuntu/store/resnet-final-size"
# Every saved model in the results directory is validated once.
model_names = glob.glob(str(Path(base_path) / "*.h5"))

for model_path in model_names:
    model_name = Path(model_path).stem
    preds_path = Path(base_path) / (model_name + '_preds.csv')
    if preds_path.exists():
        print(f'{model_name} already validated')
        continue
    print('Now validating', model_name)
    # No augmentation at validation time, only the ResNet input preprocessing.
    valid_generator = ImageDataGenerator(
        fill_mode='nearest',
        preprocessing_function=resnet_preprocessing
    )
    # shuffle=False keeps the predictions aligned with `labels` and `filenames`.
    valid_iterator = valid_generator.flow_from_directory(
        '/home/ubuntu/store/DermX-test-set/test',
        batch_size=8,
        target_size=(300, 400),
        class_mode='categorical',
        follow_links=True,
        interpolation='bilinear',
        shuffle=False
    )

    model = load_model(Path(model_path))
    # Predicted class = index of the highest softmax score of each row.
    preds = np.argmax(model.predict(valid_iterator), axis=1)
    actual = valid_iterator.labels
    preds_df = pd.DataFrame({'actual': actual, 'pred': preds, 'filenames': valid_iterator.filenames})
    # NOTE: despite the `.csv` suffix the file is a pickle; the comparison step loads it with
    # `pd.read_pickle`, so the name and format are kept as they are.
    preds_df.to_pickle(preds_path)
    

resnet50_0 already validated
resnet50_2 already validated
resnet50_1 already validated
resnet50_4 already validated
resnet50_3 already validated


## Compare models

In [3]:
# NOTE(review): `base_path` is assigned here but not used by this cell, and it differs from the directory the
# prediction files are read from below - confirm which one is intended.
base_path = "/home/ubuntu/store/resnet-final"
model_preds = glob.glob("/home/ubuntu/store/resnet-final-size/*_preds.csv")
model_comparison_dict = {}

# Build one summary row per model: macro-averaged precision/recall/f1/support plus overall accuracy.
for model_pred in model_preds:
    run_name = Path(model_pred).stem
    # The prediction files are pickled DataFrames despite their `.csv` suffix.
    model_preds_df = pd.read_pickle(Path(model_pred))
    actual_labels = model_preds_df['actual']
    predicted_labels = model_preds_df['pred']

    full_report = classification_report(
        actual_labels,
        predicted_labels,
        labels=[0, 1, 2, 3, 4, 5],
        target_names=['acne', 'actinic_keratosis', 'psoriasis_no_pustular', 'seborrheic_dermatitis', 'vitiligo', 'wart'],
        output_dict=True
    )
    summary = full_report['macro avg']
    # Accuracy = share of rows whose prediction equals the ground truth.
    n_correct = int((actual_labels == predicted_labels).sum())
    summary['accuracy'] = n_correct / len(model_preds_df)
    model_comparison_dict[run_name] = summary

model_comparison_df = pd.DataFrame.from_dict(model_comparison_dict, orient='index')
model_comparison_df

Unnamed: 0,precision,recall,f1-score,support,accuracy
resnet50_3_preds,0.380841,0.303244,0.263395,566,0.318021
resnet50_2_preds,0.399679,0.301546,0.275586,566,0.316254
resnet50_0_preds,0.420597,0.373788,0.34599,566,0.390459
resnet50_1_preds,0.373701,0.361956,0.356411,566,0.378092
resnet50_4_preds,0.430262,0.298456,0.267258,566,0.312721
