## Train models

In [None]:
import tensorflow as tf

from tensorflow.keras.preprocessing.image import ImageDataGenerator, DirectoryIterator
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers.experimental import preprocessing
from tensorflow.keras.applications.resnet import preprocess_input as resnet_preprocessing
from tensorflow.keras.layers import Dense, Flatten, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import CSVLogger, EarlyStopping
from tensorflow.keras import Model
from tensorflow.errors import ResourceExhaustedError

import numpy as np
import os
import matplotlib.pyplot as plt
from pathlib import Path

def calc_class_weights(train_iterator):
    """
    Calculate class weighs dictionary to use as input for the cnn training. This is useful if the training set is
    imbalanced.

    The weight of class "i" is calculated as the number of samples in the most populated class divided by the number of
    samples in class i (max_class_frequency / class_frequency).
    Note that the class weights are capped at 10. This is done in order to avoid placing too much weight on
    small fraction of the dataset. For the same reason, the weight is set to 1 for any class in the training set that
    contains fewer than 5 samples.

    :param class_counts: A list with the number of files for each class.
    :return:
    """

    # Fixed parameters
    class_counts = np.unique(train_iterator.classes, return_counts=True)
    class_weights = []
    max_freq = max(class_counts[1])
    class_weights = [max_freq / count for count in class_counts[1]]
    
    print("Classes: " + str(class_counts[0]))
    print("Samples per class: " + str(class_counts[1]))
    print("Class weights: " + str(class_weights))

    return class_weights


def unfreeze_layers(model, last_fixed_layer):
    # Retrieve the index of the last fixed layer and add 1 so that it is also set to not trainable
    first_trainable = model.layers.index(model.get_layer(last_fixed_layer)) + 1

    # Set which layers are trainable.
    for layer_idx, layer in enumerate(model.layers):
        if not isinstance(layer, BatchNormalization):
            layer.trainable = layer_idx >= first_trainable
    return model


def train_model(rotation, shear, zoom, brightness, lr, last_fixed_layer, batch_size):
    model_name = f'resnet50_r{rotation}_s{shear}_z{zoom}_b{brightness}_lr{lr}_l{last_fixed_layer}'
    if os.path.exists(Path('.') / (model_name + '.h5')):
        print(f'{model_name} already trained')
        return
    print(f'Now training {model_name}')
    
    train_generator = ImageDataGenerator(
        horizontal_flip=True,
        vertical_flip=True,
        rotation_range=rotation,
        shear_range=shear,
        zoom_range=zoom,
        brightness_range=brightness,
        fill_mode='nearest',
        preprocessing_function=resnet_preprocessing,
    )
    train_iterator = train_generator.flow_from_directory(
        '/home/ubuntu/store/internal-neurips/hpo/train', 
        target_size=(400, 300),
        class_mode='categorical',
        batch_size=batch_size,
        follow_links=True,
        interpolation='bilinear',
    )

    valid_generator = ImageDataGenerator(
        fill_mode='nearest',
        preprocessing_function=resnet_preprocessing
    )
    valid_iterator = valid_generator.flow_from_directory(
        '/home/ubuntu/store/internal-neurips/hpo/valid', 
        batch_size=batch_size, 
        target_size=(400, 300),
        class_mode='categorical',
        follow_links=True,
        interpolation='bilinear',
    )

    loss_weights = calc_class_weights(train_iterator)

    base_model = ResNet50(include_top=False, weights='imagenet', input_shape=(400, 300, 3))
    top_model = Flatten()(base_model.output)
    top_model = Dense(6, activation='softmax', name='diagnosis')(top_model)
    model = Model(inputs=base_model.input, outputs=top_model)
    model = unfreeze_layers(model, last_fixed_layer)
    
    optimiser = Adam(lr=lr)
    model.compile(
        optimizer=optimiser,
        loss='categorical_crossentropy',
        metrics=['accuracy'],
        loss_weights=loss_weights,
    )
    
    logger = CSVLogger(model_name + '.csv')
    early_stopping = EarlyStopping(monitor='val_loss', min_delta=0.5, patience=5, verbose=1, mode='auto', restore_best_weights=True)

    model.fit(
        x=train_iterator,
        batch_size=batch_size,
        epochs=100,
        verbose=True,
        validation_data=valid_iterator,
        class_weight=dict(zip(range(6), loss_weights)),
        workers=8,
        callbacks=[logger, early_stopping]
    )
    model.save(model_name + '.h5')


rotation_ranges = [10, 20]
shear_ranges = [0, 0.25, 0.5]
zoom_ranges = [0.25, 0.5]
brightness_ranges = [[0.25, 0.5], [0.5, 1], [0.25, 1]]
learning_rates = [0.01, 0.001, 0.0001]
last_fixed_layers = ['conv5_block3_out', 'conv5_block2_add']

for rotation in rotation_ranges:
    for shear in shear_ranges:
        for zoom in zoom_ranges:
            for brightness in brightness_ranges:
                for lr in learning_rates:
                    for last_fixed_layer in last_fixed_layers:
                        try:
                            model_name = train_model(rotation, shear, zoom, brightness, lr, last_fixed_layer, 64)
                        except ResourceExhaustedError:
                            print('Using batch size 32')
                            train_model(rotation, shear, zoom, brightness, lr, last_fixed_layer, 32)

## Validate models

In [None]:
import tensorflow as tf

from tensorflow.keras.preprocessing.image import ImageDataGenerator, DirectoryIterator
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers.experimental import preprocessing
from tensorflow.keras.applications.resnet import preprocess_input as resnet_preprocessing
from tensorflow.keras.layers import Dense, Flatten, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import load_model
from tensorflow.keras.callbacks import CSVLogger, EarlyStopping
from tensorflow.keras import Model
from tensorflow.errors import ResourceExhaustedError

import numpy as np
import os
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
import glob

base_path = "/home/ubuntu/store/resnet50-hpo"
model_names = glob.glob("/home/ubuntu/store/resnet50-hpo/*.h5")

for model_path in model_names:
    model_name = Path(model_path).stem
    if os.path.exists(Path(base_path) / (model_name + '_preds.csv')):
        print(f'{model_name} already validated')
        continue
    print('Now validating', model_name)
    valid_generator = ImageDataGenerator(
        fill_mode='nearest',
        preprocessing_function=resnet_preprocessing
    )
    valid_iterator = valid_generator.flow_from_directory(
        '/home/ubuntu/store/internal-neurips/hpo/valid', 
        batch_size=8, 
        target_size=(400, 300),
        class_mode='categorical',
        follow_links=True,
        interpolation='bilinear',
        shuffle=False
    )
    
    model = load_model(Path(model_path))
    preds = [np.argmax(pred) for pred in model.predict(valid_iterator)]
    actual = valid_iterator.labels
    preds_df = pd.DataFrame.from_dict({'actual': actual, 'pred': preds}).to_pickle(Path(base_path) / (model_name + '_preds.csv'))
    

## Compare models

In [None]:
import glob
import pandas as pd
from pathlib import Path
from sklearn.metrics import classification_report

base_path = "/home/ubuntu/store/resnet50-hpo"
model_preds = glob.glob("/home/ubuntu/store/resnet50-hpo/*_preds.csv")
model_comparison_dict = {}

for model_pred in model_preds:
    model_preds_df = pd.read_pickle(Path(model_pred))
    model_comparison_dict[Path(model_pred).stem] = classification_report(
        model_preds_df['actual'], 
        model_preds_df['pred'],
        labels=[0, 1, 2, 3, 4, 5],
        target_names=['acne', 'actinic_keratosis', 'psoriasis_no_pustular', 'seborrheic_dermatitis', 'vitiligo', 'wart'],
        output_dict=True
    )['macro avg']
    
model_comparison_df = pd.DataFrame.from_dict(model_comparison_dict, orient='index')
model_comparison_df

In [9]:
best_models_df = model_comparison_df[(model_comparison_df['precision'] > 0.7) & (model_comparison_df['recall'] > 0.7) & (model_comparison_df['f1-score'] > 0.7)].sort_values('f1-score', ascending=False)
best_models_df

Unnamed: 0,precision,recall,f1-score,support
"resnet50_r20_s0_z0.25_b[0.5, 1]_lr0.01_lconv5_block3_out_preds",0.729912,0.735866,0.729631,641
"resnet50_r20_s0.25_z0.25_b[0.25, 1]_lr0.01_lconv5_block3_out_preds",0.745634,0.724153,0.726307,641
"resnet50_r20_s0.5_z0.5_b[0.25, 1]_lr0.01_lconv5_block2_add_preds",0.728156,0.725648,0.72424,641
"resnet50_r10_s0_z0.25_b[0.25, 1]_lr0.001_lconv5_block2_add_preds",0.722007,0.733677,0.723806,641
"resnet50_r20_s0.25_z0.5_b[0.5, 1]_lr0.001_lconv5_block3_out_preds",0.728203,0.721818,0.72199,641
"resnet50_r10_s0.5_z0.5_b[0.5, 1]_lr0.01_lconv5_block3_out_preds",0.71308,0.739585,0.720816,641
"resnet50_r20_s0.5_z0.25_b[0.5, 1]_lr0.01_lconv5_block3_out_preds",0.707829,0.721732,0.712132,641
"resnet50_r20_s0.5_z0.5_b[0.25, 1]_lr0.01_lconv5_block3_out_preds",0.707968,0.731797,0.710242,641
"resnet50_r10_s0.5_z0.25_b[0.5, 1]_lr0.01_lconv5_block3_out_preds",0.705382,0.715649,0.709712,641
"resnet50_r20_s0_z0.25_b[0.25, 1]_lr0.01_lconv5_block3_out_preds",0.713733,0.721145,0.706567,641


In [11]:
from IPython.display import display, HTML

for model_name in best_models_df.index.values:    
    logs_df = pd.read_csv(model_name[:-6] + '.csv')
    display(HTML(logs_df.to_html()))
    print(model_name, len(logs_df))

Unnamed: 0,epoch,accuracy,loss,val_accuracy,val_loss
0,0,0.487369,414.064178,0.455538,194.195251
1,1,0.675476,121.848412,0.595944,146.393692
2,2,0.704236,70.723602,0.733229,59.304382
3,3,0.75515,53.026775,0.720749,43.929764
4,4,0.772639,43.271111,0.75507,36.612503
5,5,0.819277,24.796917,0.781591,24.423382
6,6,0.80412,19.86515,0.784711,24.246513
7,7,0.822386,16.299892,0.776911,24.489233
8,8,0.841042,12.726364,0.783151,19.484083
9,9,0.856976,10.645636,0.809672,17.028006


resnet50_r20_s0_z0.25_b[0.5, 1]_lr0.01_lconv5_block3_out_preds 25


Unnamed: 0,epoch,accuracy,loss,val_accuracy,val_loss
0,0,0.487369,416.181458,0.574103,109.584122
1,1,0.637777,154.860748,0.650546,70.147728
2,2,0.72328,76.987877,0.720749,45.54903
3,3,0.751652,53.539047,0.75195,34.473633
4,4,0.767586,38.848713,0.731669,37.715549
5,5,0.774582,38.943649,0.784711,31.306559
6,6,0.814225,25.856995,0.723869,50.586929
7,7,0.802176,26.890816,0.76287,29.279236
8,8,0.845317,20.086269,0.776911,24.900476
9,9,0.833269,18.187519,0.733229,27.268709


resnet50_r20_s0.25_z0.25_b[0.25, 1]_lr0.01_lconv5_block3_out_preds 24


Unnamed: 0,epoch,accuracy,loss,val_accuracy,val_loss
0,0,0.461329,317.858643,0.425897,423.073822
1,1,0.62262,90.337578,0.647426,56.506836
2,2,0.67159,63.533806,0.709828,36.610832
3,3,0.700738,36.576229,0.74103,27.785564
4,4,0.718228,25.25395,0.74883,29.275126
5,5,0.729499,25.05715,0.74883,21.445164
6,6,0.748931,17.558575,0.787831,13.311394
7,7,0.777691,13.440453,0.75351,14.316554
8,8,0.75787,12.982221,0.75195,10.788032
9,9,0.770696,10.383837,0.780031,9.563357


resnet50_r20_s0.5_z0.5_b[0.25, 1]_lr0.01_lconv5_block2_add_preds 20


Unnamed: 0,epoch,accuracy,loss,val_accuracy,val_loss
0,0,0.476098,51.000229,0.382215,275.169556
1,1,0.683638,13.420738,0.544462,18.308062
2,2,0.727167,9.843394,0.639626,10.271335
3,3,0.753595,7.730604,0.728549,10.337798
4,4,0.785853,7.657276,0.605304,24.410791
5,5,0.797124,7.078298,0.723869,9.656201
6,6,0.839876,5.337606,0.789392,7.028327
7,7,0.860863,4.508052,0.784711,7.246513
8,8,0.862029,5.037836,0.731669,8.258326
9,9,0.889623,2.944255,0.76131,6.801443


resnet50_r10_s0_z0.25_b[0.25, 1]_lr0.001_lconv5_block2_add_preds 20


Unnamed: 0,epoch,accuracy,loss,val_accuracy,val_loss
0,0,0.441508,63.95343,0.630265,15.557788
1,1,0.619899,34.969814,0.681747,14.765365
2,2,0.659541,33.460552,0.737909,13.417816
3,3,0.701904,28.431211,0.772231,12.54357
4,4,0.70618,30.964615,0.74259,13.329675
5,5,0.716285,28.692139,0.789392,11.495328
6,6,0.761757,23.273226,0.75039,14.471905
7,7,0.773416,21.344955,0.765991,12.881423
8,8,0.780412,22.264132,0.812793,10.541589
9,9,0.782355,18.599251,0.801872,12.335538


resnet50_r20_s0.25_z0.5_b[0.5, 1]_lr0.001_lconv5_block3_out_preds 14


Unnamed: 0,epoch,accuracy,loss,val_accuracy,val_loss
0,0,0.440731,528.930847,0.597504,113.306519
1,1,0.652157,135.989075,0.652106,74.826157
2,2,0.683249,86.368828,0.622465,94.583458
3,3,0.682472,77.680756,0.678627,46.039829
4,4,0.750097,46.98988,0.720749,41.066307
5,5,0.755927,40.071758,0.798752,26.689423
6,6,0.765643,39.087505,0.773791,34.096508
7,7,0.794792,29.238258,0.769111,31.217899
8,8,0.822386,22.296017,0.74103,58.614883
9,9,0.815779,19.709564,0.797192,24.933765


resnet50_r10_s0.5_z0.5_b[0.5, 1]_lr0.01_lconv5_block3_out_preds 34


Unnamed: 0,epoch,accuracy,loss,val_accuracy,val_loss
0,0,0.473766,592.13678,0.597504,96.499954
1,1,0.683638,101.136803,0.634945,66.383965
2,2,0.759425,63.234089,0.653666,60.129517
3,3,0.75515,45.099049,0.684867,60.736897
4,4,0.785464,31.588005,0.731669,30.959131
5,5,0.813447,22.620234,0.76131,22.859409
6,6,0.822386,19.387192,0.723869,33.43581
7,7,0.851535,12.059542,0.76287,24.975794
8,8,0.859697,11.784209,0.767551,21.356123
9,9,0.847649,14.752234,0.778471,22.407398


resnet50_r20_s0.5_z0.25_b[0.5, 1]_lr0.01_lconv5_block3_out_preds 21


Unnamed: 0,epoch,accuracy,loss,val_accuracy,val_loss
0,0,0.475709,471.859344,0.517941,250.262695
1,1,0.621454,156.072006,0.609984,90.902893
2,2,0.628449,121.758064,0.677067,47.728428
3,3,0.705791,75.744896,0.586583,100.818001
4,4,0.698795,60.591217,0.687988,56.571484
5,5,0.736106,46.882404,0.709828,42.874454
6,6,0.748931,34.804932,0.731669,36.19772
7,7,0.769141,28.886871,0.664587,48.154182
8,8,0.792849,22.894073,0.776911,25.966173
9,9,0.758648,23.287191,0.770671,18.656345


resnet50_r20_s0.5_z0.5_b[0.25, 1]_lr0.01_lconv5_block3_out_preds 18


Unnamed: 0,epoch,accuracy,loss,val_accuracy,val_loss
0,0,0.535562,443.392578,0.723869,115.900131
1,1,0.68869,109.374115,0.636505,65.684372
2,2,0.721726,73.426147,0.622465,48.457268
3,3,0.783133,39.26506,0.695788,34.617832
4,4,0.795569,24.554098,0.687988,30.928776
5,5,0.82433,19.464151,0.776911,19.195822
6,6,0.847649,12.787503,0.764431,15.726027
7,7,0.851924,15.039034,0.789392,15.983541
8,8,0.853867,11.328144,0.798752,14.733813
9,9,0.879907,8.120616,0.74883,19.551027


resnet50_r10_s0.5_z0.25_b[0.5, 1]_lr0.01_lconv5_block3_out_preds 25


Unnamed: 0,epoch,accuracy,loss,val_accuracy,val_loss
0,0,0.513797,383.442902,0.436817,121.834412
1,1,0.649436,146.841415,0.692668,61.798668
2,2,0.711621,86.498055,0.75351,59.805553
3,3,0.761757,60.807564,0.683307,74.995628
4,4,0.773805,46.667713,0.75351,48.519901
5,5,0.787408,39.333702,0.736349,43.298141
6,6,0.781967,39.537136,0.776911,41.042229
7,7,0.810727,29.126343,0.798752,32.860729
8,8,0.84143,18.012793,0.765991,29.873812
9,9,0.82705,19.897484,0.726989,32.346447


resnet50_r20_s0_z0.25_b[0.25, 1]_lr0.01_lconv5_block3_out_preds 29


Unnamed: 0,epoch,accuracy,loss,val_accuracy,val_loss
0,0,0.486203,39.440475,0.581903,21.95878
1,1,0.620288,20.258921,0.653666,11.162831
2,2,0.661096,17.059778,0.606864,15.187747
3,3,0.715119,14.662372,0.75663,9.976078
4,4,0.739992,14.31179,0.769111,7.696912
5,5,0.743879,12.92852,0.73947,12.210646
6,6,0.796735,10.195427,0.728549,9.559237
7,7,0.784298,11.23291,0.667707,14.297707
8,8,0.779635,10.361997,0.74259,9.129226
9,9,0.797513,9.690266,0.772231,8.592833


resnet50_r20_s0.25_z0.5_b[0.5, 1]_lr0.001_lconv5_block2_add_preds 10


rotation 20, shear 0, zoom 0.25, brightness [0.25, 1], learning rate 0.01, last fixed layer conv5_block3_out