## Train

In [None]:
import itertools
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.errors import ResourceExhaustedError
from tensorflow.keras import Model
from tensorflow.keras.applications import EfficientNetB4
from tensorflow.keras.applications.efficientnet import preprocess_input as preprocess_input
from tensorflow.keras.callbacks import CSVLogger, EarlyStopping
from tensorflow.keras.layers import GlobalAveragePooling2D, BatchNormalization, Dropout, Dense, Flatten
from tensorflow.keras.layers.experimental import preprocessing
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator, DirectoryIterator


def calc_class_weights(train_iterator, max_weight=None, min_samples=None):
    """
    Calculate per-class weights to use as input for the CNN training. This is useful if the training set is
    imbalanced.

    The weight of class "i" is the number of samples in the most populated class divided by the number of
    samples in class i (max_class_frequency / class_frequency), so the most populated class gets weight 1.

    By default the raw ratios are returned unchanged. Two optional safeguards avoid placing too much weight on a
    small fraction of the dataset:

    * ``max_weight`` caps every weight at the given value.
    * ``min_samples`` sets the weight to 1 for any class with fewer than that many samples.

    :param train_iterator: Object exposing a ``classes`` attribute with one integer class index per training
        sample (``train_model`` passes the iterator returned by ``flow_from_directory``).
    :param max_weight: Optional upper bound applied to every weight. ``None`` (default) disables the cap.
    :param min_samples: Optional sample-count threshold below which a class gets weight 1. ``None`` (default)
        disables the rule.
    :return: A list with one weight per class present in ``train_iterator.classes``, ordered by ascending class
        index.
    :raises ValueError: If ``train_iterator.classes`` is empty.
    """
    class_ids, class_counts = np.unique(train_iterator.classes, return_counts=True)
    if class_counts.size == 0:
        raise ValueError("train_iterator.classes is empty; cannot compute class weights")

    # Integer / integer division yields a float64 array: max_freq / count for every class.
    weights = class_counts.max() / class_counts
    if max_weight is not None:
        weights = np.minimum(weights, max_weight)
    if min_samples is not None:
        # Applied after the cap so that tiny classes always end up at exactly 1.
        weights = np.where(class_counts < min_samples, 1.0, weights)
    class_weights = list(weights)

    print("Classes: " + str(class_ids))
    print("Samples per class: " + str(class_counts))
    print("Class weights: " + str(class_weights))

    return class_weights


def unfreeze_layers(model, last_fixed_layer):
    """
    Mark every layer after ``last_fixed_layer`` as trainable and every layer up to and including it as frozen.

    BatchNormalization layers are skipped, so they keep whatever ``trainable`` flag they already had.

    :param model: Model whose ``layers`` list is updated in place.
    :param last_fixed_layer: Name of the last layer that must stay frozen (looked up with ``model.get_layer``).
    :return: The same ``model`` object, for convenience.
    """
    # Position of the last layer that stays frozen; everything strictly after it is unfrozen.
    boundary = model.layers.index(model.get_layer(last_fixed_layer))

    for position, layer in enumerate(model.layers):
        if isinstance(layer, BatchNormalization):
            continue
        layer.trainable = position > boundary
    return model


def build_model(optimiser, last_fixed_layer):
    """
    Build and compile an EfficientNetB4 classifier with a fresh 6-way softmax head.

    The ImageNet backbone (input shape 400x300x3, no top) is frozen first; ``unfreeze_layers`` then re-enables
    training for the layers after ``last_fixed_layer``.

    :param optimiser: Optimizer instance handed to ``compile``.
    :param last_fixed_layer: Name of the last backbone layer to keep frozen.
    :return: The compiled model, named "EfficientNet".
    """
    backbone = EfficientNetB4(weights="imagenet", include_top=False, input_shape=(400, 300, 3))

    # Start from a fully frozen backbone.
    backbone.trainable = False

    # New classification head: pooling -> batch norm -> dropout -> softmax.
    features = GlobalAveragePooling2D(name="avg_pool")(backbone.output)
    features = BatchNormalization()(features)
    features = Dropout(0.2, name="top_dropout")(features)
    predictions = Dense(6, activation="softmax", name="pred")(features)

    backbone = unfreeze_layers(backbone, last_fixed_layer)

    classifier = Model(backbone.input, predictions, name="EfficientNet")
    classifier.compile(
        loss="categorical_crossentropy",
        optimizer=optimiser,
        metrics=["accuracy"],
    )
    return classifier


def train_model(rotation, shear, zoom, brightness, lr, last_fixed_layer, batch_size):
    """
    Train one EfficientNetB4 configuration and save it as ``<model_name>.h5`` in the working directory.

    Nothing is done if ``<model_name>.h5`` already exists, so an interrupted sweep can be re-run safely.
    Per-epoch metrics are written to ``<model_name>.csv`` by a ``CSVLogger``; training stops early when
    ``val_loss`` has not improved by at least 0.02 for 5 epochs, and the best weights are restored.

    :param rotation: Value for ``rotation_range`` of the training augmentation.
    :param shear: Value for ``shear_range`` of the training augmentation.
    :param zoom: Value for ``zoom_range`` of the training augmentation.
    :param brightness: Value for ``brightness_range`` of the training augmentation.
    :param lr: Learning rate for the Adam optimizer.
    :param last_fixed_layer: Name of the last backbone layer kept frozen (see ``build_model``).
    :param batch_size: Batch size used by both the training and the validation iterator.
    :return: None.
    """
    model_name = f'efficientnetb4_r{rotation}_s{shear}_z{zoom}_b{brightness}_lr{lr}_l{last_fixed_layer}'
    if os.path.exists(Path('.') / (model_name + '.h5')):
        print(f'{model_name} already trained')
        return
    print(f'Now training {model_name}')

    # Many models are built in one process; release the global Keras state left behind by earlier runs
    # (including a run that has just failed with an out-of-memory error) before building the next one.
    tf.keras.backend.clear_session()

    train_generator = ImageDataGenerator(
        horizontal_flip=True,
        vertical_flip=True,
        rotation_range=rotation,
        shear_range=shear,
        zoom_range=zoom,
        brightness_range=brightness,
        fill_mode='nearest',
        preprocessing_function=preprocess_input,
    )
    train_iterator = train_generator.flow_from_directory(
        '/home/ubuntu/store/internal-neurips/hpo/train',
        target_size=(400, 300),
        class_mode='categorical',
        batch_size=batch_size,
        follow_links=True,
        interpolation='bilinear',
    )

    # Validation images are only preprocessed, never augmented.
    valid_generator = ImageDataGenerator(
        fill_mode='nearest',
        preprocessing_function=preprocess_input
    )
    valid_iterator = valid_generator.flow_from_directory(
        '/home/ubuntu/store/internal-neurips/hpo/valid',
        batch_size=batch_size,
        target_size=(400, 300),
        class_mode='categorical',
        follow_links=True,
        interpolation='bilinear',
    )

    loss_weights = calc_class_weights(train_iterator)

    # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
    optimiser = Adam(learning_rate=lr)
    model = build_model(optimiser, last_fixed_layer)

    logger = CSVLogger(model_name + '.csv')
    early_stopping = EarlyStopping(
        monitor='val_loss',
        min_delta=0.02,
        patience=5,
        verbose=1,
        mode='auto',
        restore_best_weights=True,
    )

    # The iterators already produce batches, so `batch_size` is not passed to fit().
    model.fit(
        x=train_iterator,
        epochs=100,
        verbose=1,
        validation_data=valid_iterator,
        # Map class index -> weight for however many classes were found.
        class_weight=dict(enumerate(loss_weights)),
        workers=8,
        callbacks=[logger, early_stopping]
    )
    model.save(model_name + '.h5')


# Hyper-parameter grid: every combination of the values below is trained once.
rotation_ranges = [10, 20]
shear_ranges = [0, 0.25, 0.5]
zoom_ranges = [0.25, 0.5]
brightness_ranges = [[0.25, 0.5], [0.5, 1], [0.25, 1]]
learning_rates = [0.001, 0.0001]
last_fixed_layers = ['top_conv', 'block6d_add']

# itertools.product varies the last iterable fastest, i.e. the same order as nesting the loops
# with `rotation` outermost and `last_fixed_layer` innermost.
for rotation, shear, zoom, brightness, lr, last_fixed_layer in itertools.product(
    rotation_ranges,
    shear_ranges,
    zoom_ranges,
    brightness_ranges,
    learning_rates,
    last_fixed_layers,
):
    try:
        train_model(rotation, shear, zoom, brightness, lr, last_fixed_layer, 64)
    except ResourceExhaustedError:
        # Out of memory at batch size 64: retry the same configuration with half the batch.
        print('Using batch size 32')
        train_model(rotation, shear, zoom, brightness, lr, last_fixed_layer, 32)

## Validate

In [None]:
import tensorflow as tf

from tensorflow.keras.preprocessing.image import ImageDataGenerator, DirectoryIterator
from tensorflow.keras.applications import EfficientNetB4
from tensorflow.keras.layers.experimental import preprocessing
from tensorflow.keras.applications.efficientnet import preprocess_input as preprocess_input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import load_model
from tensorflow.keras.callbacks import CSVLogger, EarlyStopping
from tensorflow.keras import Model
from tensorflow.errors import ResourceExhaustedError

import numpy as np
import os
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
import glob

# Directory holding the trained models; predictions are written next to them.
base_path = "/home/ubuntu/store/efficientnet-hpo"
model_names = glob.glob(str(Path(base_path) / '*.h5'))

for model_path in model_names:
    model_name = Path(model_path).stem
    # NOTE: despite the ".csv" extension this file is a pandas pickle; the comparison cell reads it
    # with pd.read_pickle, so the format is kept as is.
    preds_path = Path(base_path) / (model_name + '_preds.csv')
    if preds_path.exists():
        print(f'{model_name} already validated')
        continue
    print('Now validating', model_name)
    valid_generator = ImageDataGenerator(
        fill_mode='nearest',
        preprocessing_function=preprocess_input
    )
    # shuffle=False keeps the prediction order aligned with valid_iterator.labels.
    valid_iterator = valid_generator.flow_from_directory(
        '/home/ubuntu/store/internal-neurips/hpo/valid',
        batch_size=8,
        target_size=(400, 300),
        class_mode='categorical',
        follow_links=True,
        interpolation='bilinear',
        shuffle=False
    )

    # Release the previously loaded model before loading the next one.
    tf.keras.backend.clear_session()
    model = load_model(Path(model_path))

    # Predicted class = index of the highest score in each output row.
    preds = np.argmax(model.predict(valid_iterator), axis=1)
    actual = valid_iterator.labels
    pd.DataFrame.from_dict({'actual': actual, 'pred': preds}).to_pickle(preds_path)
    

## Compare models

In [4]:
import glob
import pandas as pd
from pathlib import Path
from sklearn.metrics import classification_report

# Directory holding the per-model prediction files written by the validation cell.
base_path = "/home/ubuntu/store/efficientnet-hpo"
model_preds = glob.glob(str(Path(base_path) / '*_preds.csv'))
model_comparison_dict = {}

# Class index -> name, in the order used for `labels` below.
class_names = ['acne', 'actinic_keratosis', 'psoriasis_no_pustular', 'seborrheic_dermatitis', 'vitiligo', 'wart']

for model_pred in model_preds:
    # NOTE: the "*_preds.csv" files are pandas pickles, not CSV text. Unpickling executes arbitrary
    # code, so only read files produced by the validation cell of this notebook.
    model_preds_df = pd.read_pickle(Path(model_pred))
    report = classification_report(
        model_preds_df['actual'],
        model_preds_df['pred'],
        labels=[0, 1, 2, 3, 4, 5],
        target_names=class_names,
        output_dict=True,
        # Same value as the default for undefined metrics (0), without the warning
        # raised when a class has no predicted samples.
        zero_division=0,
    )
    # Keep only the macro-averaged precision / recall / f1-score / support.
    model_comparison_dict[Path(model_pred).stem] = report['macro avg']

model_comparison_df = pd.DataFrame.from_dict(model_comparison_dict, orient='index')
model_comparison_df

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Unnamed: 0,precision,recall,f1-score,support
"efficientnetb4_r10_s0.5_z0.5_b[0.25, 1]_lr0.001_lblock6d_add_preds",0.793494,0.777861,0.781666,641
"efficientnetb4_r10_s0_z0.5_b[0.5, 1]_lr0.001_lblock6d_add_preds",0.737619,0.764948,0.743953,641
"efficientnetb4_r10_s0.25_z0.5_b[0.25, 1]_lr0.0001_lblock6d_add_preds",0.757570,0.734771,0.739731,641
"efficientnetb4_r10_s0_z0.5_b[0.25, 1]_lr0.0001_ltop_conv_preds",0.563621,0.639468,0.570758,641
"efficientnetb4_r10_s0.25_z0.5_b[0.25, 0.5]_lr0.001_lblock6d_add_preds",0.643750,0.639668,0.624256,641
...,...,...,...,...
"efficientnetb4_r10_s0_z0.25_b[0.25, 1]_lr0.0001_ltop_conv_preds",0.563897,0.677711,0.573536,641
"efficientnetb4_r20_s0.25_z0.5_b[0.25, 0.5]_lr0.0001_ltop_conv_preds",0.508510,0.564646,0.485648,641
"efficientnetb4_r10_s0.5_z0.25_b[0.5, 1]_lr0.001_lblock6d_add_preds",0.814395,0.749505,0.754685,641
"efficientnetb4_r20_s0_z0.5_b[0.5, 1]_lr0.0001_lblock6d_add_preds",0.749392,0.750045,0.745791,641


In [6]:
# Keep the models whose macro precision, recall and f1-score are all above 0.75, best f1-score first.
above_threshold = (model_comparison_df[['precision', 'recall', 'f1-score']] > 0.75).all(axis=1)
best_models_df = model_comparison_df[above_threshold].sort_values('f1-score', ascending=False)
best_models_df

Unnamed: 0,precision,recall,f1-score,support
"efficientnetb4_r20_s0_z0.5_b[0.5, 1]_lr0.001_lblock6d_add_preds",0.798606,0.777725,0.786215,641
"efficientnetb4_r10_s0.5_z0.5_b[0.25, 1]_lr0.001_lblock6d_add_preds",0.793494,0.777861,0.781666,641
"efficientnetb4_r20_s0.5_z0.5_b[0.5, 1]_lr0.0001_lblock6d_add_preds",0.777884,0.784633,0.779742,641
"efficientnetb4_r20_s0.5_z0.5_b[0.5, 1]_lr0.001_lblock6d_add_preds",0.803701,0.757207,0.777556,641
"efficientnetb4_r10_s0.25_z0.25_b[0.5, 1]_lr0.001_lblock6d_add_preds",0.786794,0.769685,0.773469,641
"efficientnetb4_r10_s0.25_z0.25_b[0.25, 0.5]_lr0.001_lblock6d_add_preds",0.785545,0.764286,0.772651,641
"efficientnetb4_r20_s0.5_z0.25_b[0.5, 1]_lr0.001_lblock6d_add_preds",0.791573,0.756057,0.768411,641
"efficientnetb4_r20_s0.25_z0.5_b[0.25, 1]_lr0.0001_lblock6d_add_preds",0.750105,0.798738,0.768339,641
"efficientnetb4_r10_s0_z0.25_b[0.5, 1]_lr0.0001_lblock6d_add_preds",0.763526,0.776867,0.767004,641
"efficientnetb4_r20_s0.5_z0.5_b[0.25, 1]_lr0.001_lblock6d_add_preds",0.765928,0.776757,0.765638,641


rotation 20, shear 0.5, zoom 0.5, brightness [0.5, 1], lr 0.001, last fixed layer 'block6d_add'

In [7]:
from IPython.display import display, HTML

# Show the per-epoch training log of each selected model, followed by its name and number of logged epochs.
for model_name in best_models_df.index.values:
    # Index entries end in "_preds"; dropping that suffix gives the training log's base name.
    log_file = model_name[:-len('_preds')] + '.csv'
    logs_df = pd.read_csv(log_file)
    log_table = HTML(logs_df.to_html())
    display(log_table)
    print(model_name, len(logs_df))

Unnamed: 0,epoch,accuracy,loss,val_accuracy,val_loss
0,0,0.401477,4.184147,0.525741,2.500236
1,1,0.484648,3.055843,0.577223,1.348631
2,2,0.595025,2.480273,0.691108,1.5702
3,3,0.640886,2.124576,0.513261,1.391444
4,4,0.663039,1.935008,0.634945,1.08854
5,5,0.717062,1.710684,0.74571,0.799861
6,6,0.739604,1.56462,0.74259,0.86859
7,7,0.795569,1.106989,0.828393,0.67069
8,8,0.791294,1.296559,0.815913,0.701485
9,9,0.806063,1.300761,0.815913,0.704059


efficientnetb4_r20_s0_z0.5_b[0.5, 1]_lr0.001_lblock6d_add_preds 19


Unnamed: 0,epoch,accuracy,loss,val_accuracy,val_loss
0,0,0.450836,3.964501,0.553822,1.686082
1,1,0.55344,2.815843,0.595944,1.101737
2,2,0.589195,2.475139,0.527301,1.302883
3,3,0.647105,2.209675,0.683307,0.90933
4,4,0.689468,1.966991,0.792512,0.674695
5,5,0.683638,1.858347,0.75663,0.76996
6,6,0.75204,1.41086,0.75195,0.718228
7,7,0.775748,1.21967,0.804992,0.558002
8,8,0.813836,1.098167,0.839314,0.529723
9,9,0.803731,1.068731,0.74415,0.688391


efficientnetb4_r10_s0.5_z0.5_b[0.25, 1]_lr0.001_lblock6d_add_preds 14


Unnamed: 0,epoch,accuracy,loss,val_accuracy,val_loss
0,0,0.408084,3.778816,0.581903,1.271987
1,1,0.586086,2.341193,0.624025,1.115347
2,2,0.654878,1.915306,0.694228,0.957469
3,3,0.70618,1.602301,0.74103,0.837674
4,4,0.743879,1.370994,0.74883,0.768395
5,5,0.769141,1.272129,0.804992,0.652311
6,6,0.772639,1.124426,0.812793,0.563436
7,7,0.809561,0.956396,0.804992,0.587823
8,8,0.834435,0.852871,0.822153,0.528858
9,9,0.848815,0.806624,0.819033,0.571236


efficientnetb4_r20_s0.5_z0.5_b[0.5, 1]_lr0.0001_lblock6d_add_preds 16


Unnamed: 0,epoch,accuracy,loss,val_accuracy,val_loss
0,0,0.453167,3.690572,0.680187,1.094021
1,1,0.56471,2.90284,0.468019,2.082849
2,2,0.640886,2.050738,0.697348,0.860923
3,3,0.691799,1.791114,0.636505,1.130367
4,4,0.719394,1.804484,0.723869,1.10572
5,5,0.719394,1.858013,0.722309,1.081925
6,6,0.750097,1.279463,0.74103,0.83976
7,7,0.810727,1.065125,0.769111,0.715667
8,8,0.832491,0.947534,0.815913,0.674481
9,9,0.846094,0.813652,0.834633,0.634047


efficientnetb4_r20_s0.5_z0.5_b[0.5, 1]_lr0.001_lblock6d_add_preds 20


Unnamed: 0,epoch,accuracy,loss,val_accuracy,val_loss
0,0,0.439953,4.030571,0.399376,1.416088
1,1,0.606685,2.530653,0.641186,0.867079
2,2,0.670424,2.169568,0.647426,0.980157
3,3,0.743879,1.480806,0.848674,0.522767
4,4,0.814225,0.972894,0.781591,0.744824
5,5,0.825107,1.032908,0.826833,0.615946
6,6,0.846871,0.837719,0.798752,0.725874
7,7,0.79246,1.25642,0.765991,0.879748
8,8,0.7637,1.758729,0.658346,0.8978


efficientnetb4_r10_s0.25_z0.25_b[0.5, 1]_lr0.001_lblock6d_add_preds 9


Unnamed: 0,epoch,accuracy,loss,val_accuracy,val_loss
0,0,0.355616,4.257462,0.666147,1.168263
1,1,0.495142,3.085498,0.50546,1.3293
2,2,0.552274,3.006044,0.641186,1.369301
3,3,0.565876,2.647481,0.616225,1.118353
4,4,0.664982,1.974423,0.765991,0.853278
5,5,0.715896,1.646955,0.695788,0.926716
6,6,0.722114,1.490834,0.786271,0.834683
7,7,0.7637,1.366637,0.76287,0.717985
8,8,0.801788,1.09781,0.804992,0.68143
9,9,0.821609,0.921063,0.708268,0.767749


efficientnetb4_r10_s0.25_z0.25_b[0.25, 0.5]_lr0.001_lblock6d_add_preds 19


Unnamed: 0,epoch,accuracy,loss,val_accuracy,val_loss
0,0,0.473766,3.827008,0.722309,0.976201
1,1,0.611349,2.313862,0.74103,0.707895
2,2,0.72911,1.45712,0.709828,0.737564
3,3,0.777691,1.186815,0.814353,0.543914
4,4,0.813059,0.927369,0.74415,0.720181
5,5,0.856976,0.77635,0.808112,0.59361
6,6,0.874854,0.623733,0.848674,0.512532
7,7,0.865527,0.735697,0.815913,0.553305
8,8,0.898951,0.544376,0.815913,0.599666
9,9,0.916051,0.429283,0.815913,0.673551


efficientnetb4_r20_s0.5_z0.25_b[0.5, 1]_lr0.001_lblock6d_add_preds 12


Unnamed: 0,epoch,accuracy,loss,val_accuracy,val_loss
0,0,0.41508,3.835193,0.472699,1.409796
1,1,0.551496,2.567791,0.686427,1.086545
2,2,0.627283,2.093856,0.736349,0.944417
3,3,0.662651,1.852224,0.74727,0.826172
4,4,0.714341,1.549939,0.717629,0.815508
5,5,0.744267,1.389949,0.797192,0.599323
6,6,0.750097,1.31191,0.75351,0.67455
7,7,0.776914,1.182927,0.770671,0.616052
8,8,0.805286,1.019422,0.792512,0.577861
9,9,0.811115,0.987967,0.798752,0.534012


efficientnetb4_r20_s0.25_z0.5_b[0.25, 1]_lr0.0001_lblock6d_add_preds 18


Unnamed: 0,epoch,accuracy,loss,val_accuracy,val_loss
0,0,0.464827,3.392236,0.625585,1.193047
1,1,0.626506,2.115282,0.74883,0.966693
2,2,0.706957,1.620419,0.708268,0.946528
3,3,0.755927,1.302782,0.776911,0.751356
4,4,0.792849,1.05452,0.795632,0.667017
5,5,0.831325,0.875331,0.837754,0.550739
6,6,0.860085,0.713795,0.839314,0.500648
7,7,0.865527,0.672334,0.843994,0.509463
8,8,0.872911,0.628071,0.845554,0.472963
9,9,0.892732,0.56198,0.850234,0.510639


efficientnetb4_r10_s0_z0.25_b[0.5, 1]_lr0.0001_lblock6d_add_preds 14


Unnamed: 0,epoch,accuracy,loss,val_accuracy,val_loss
0,0,0.439953,3.897685,0.49298,3.38796
1,1,0.565099,2.690736,0.692668,1.04422
2,2,0.621842,2.240409,0.606864,1.135364
3,3,0.661485,2.005903,0.833073,0.620494
4,4,0.700738,1.708317,0.778471,0.759255
5,5,0.731442,1.651646,0.803432,0.667986
6,6,0.732219,1.435765,0.783151,0.640602
7,7,0.784298,1.252723,0.726989,0.773545
8,8,0.783521,1.294715,0.731669,0.805558


efficientnetb4_r20_s0.5_z0.5_b[0.25, 1]_lr0.001_lblock6d_add_preds 9


Unnamed: 0,epoch,accuracy,loss,val_accuracy,val_loss
0,0,0.433346,3.595486,0.633385,1.290572
1,1,0.62534,2.159081,0.731669,0.966052
2,2,0.697629,1.637824,0.726989,0.933326
3,3,0.748931,1.363872,0.75195,0.80452
4,4,0.783521,1.09943,0.770671,0.719102
5,5,0.815002,0.881846,0.789392,0.641511
6,6,0.840264,0.817984,0.820593,0.572713
7,7,0.859308,0.728471,0.819033,0.527151
8,8,0.887291,0.565942,0.829953,0.510678
9,9,0.905169,0.490061,0.834633,0.528199


efficientnetb4_r10_s0.5_z0.25_b[0.5, 1]_lr0.0001_lblock6d_add_preds 16


Unnamed: 0,epoch,accuracy,loss,val_accuracy,val_loss
0,0,0.398756,3.756074,0.620905,1.251268
1,1,0.566265,2.538315,0.627145,1.131342
2,2,0.662262,1.912611,0.684867,0.931595
3,3,0.687136,1.663923,0.75819,0.768506
4,4,0.736494,1.450996,0.783151,0.63992
5,5,0.766032,1.216486,0.804992,0.586036
6,6,0.795181,1.099574,0.784711,0.611881
7,7,0.810727,0.987264,0.815913,0.531227
8,8,0.83288,0.831178,0.820593,0.531951
9,9,0.827439,0.847613,0.836193,0.51449


efficientnetb4_r10_s0.25_z0.5_b[0.5, 1]_lr0.0001_lblock6d_add_preds 17
