# Melanoma Classification

Kaggle competition page: https://www.kaggle.com/c/siim-isic-melanoma-classification/overview


<div class="alert alert-block alert-warning">  
<b> Imports.  </b>
</div>

In [None]:
import os
import shutil
import random
from pathlib import Path
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.applications.densenet import DenseNet121
from tensorflow.keras.applications.nasnet import NASNetMobile, NASNetLarge
from tensorflow.keras.applications.inception_resnet_v2 import InceptionResNetV2
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Conv2D,MaxPool2D,GlobalAveragePooling2D,AveragePooling2D
import numpy as np
import cv2
import matplotlib.pyplot as plt
from datetime import datetime, date
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
import json
from sklearn.metrics import roc_curve, auc, precision_recall_curve, plot_precision_recall_curve, confusion_matrix
import itertools
from tqdm import tqdm
import warnings
# Silence FutureWarnings (presumably deprecation notices from the TF / sklearn
# APIs used below) so they do not flood the training logs.
warnings.simplefilter(action='ignore', category=FutureWarning)

# Progress marker for the notebook run.
print("IMPORTS DONE")

<div class="alert alert-block alert-info">  
<b> Functions.  </b>
</div>

In [None]:
def check_image(path_to_images, fileName):
    """Return the image's full path if the file exists, otherwise ``False``.

    The path is ``path_to_images + fileName + IMAGE_TYPE`` (plain string
    concatenation with the module-level extension), so ``path_to_images``
    must already end with a path separator.
    """
    candidate = "".join((path_to_images, fileName, IMAGE_TYPE))
    if not Path(candidate).is_file():
        return False
    return candidate

def load_data(path_to_images, csv_path, mode):
    """Read ``csv_path`` and keep only the rows whose image exists on disk.

    Adds an ``image_path`` column (the result of ``check_image`` for each
    ``image_name``) and drops the rows for which ``check_image`` returned
    ``False``. ``mode`` is only used in the progress messages.
    """
    print("Loading " + mode)
    frame = pd.read_csv(csv_path)
    frame['image_path'] = frame['image_name'].apply(
        lambda name: check_image(path_to_images, name)
    )
    # The column mixes path strings and the sentinel ``False``.
    found = frame['image_path'] != False  # noqa: E712
    frame = frame[found]
    print("valid rows in " + mode, frame.shape[0])
    return frame

def getTrainData(img_pixels):
    """Return the training metadata for images of side ``img_pixels``.

    The first call for a given size reads ``train.csv`` through ``load_data``.
    The result is cached in the module-level ``train_224_backup`` /
    ``train_331_backup`` globals. Every call returns a copy, so callers may
    mutate the DataFrame freely.

    Args:
        img_pixels: Image side length in pixels; only 224 and 331 are supported.

    Returns:
        pandas.DataFrame with the ``train.csv`` rows whose image exists, plus
        an ``image_path`` column.

    Raises:
        ValueError: If ``img_pixels`` is neither 224 nor 331.
    """
    global train_224_backup, train_331_backup
    if img_pixels == 224:
        cached = train_224_backup
    elif img_pixels == 331:
        cached = train_331_backup
    else:
        # A bare ``raise`` outside an ``except`` block only yields
        # "RuntimeError: No active exception to reraise"; be explicit.
        raise ValueError(
            "Unsupported image size " + repr(img_pixels) + ": expected 224 or 331"
        )
    if cached is not None:
        return cached.copy()

    # Images live in "<BASE_PATH_TO_IMAGES><size>/train/".
    path_to_images = BASE_PATH_TO_IMAGES + str(img_pixels) + "/train/"
    csv_path = os.path.join(BASE_INPUT_PATH, "train.csv")
    data = load_data(path_to_images, csv_path, "train")

    if img_pixels == 224:
        train_224_backup = data.copy()
    else:
        train_331_backup = data.copy()
    return data
    
def getTestData(img_pixels):
    """Return the test metadata for images of side ``img_pixels``.

    The first call for a given size reads ``test.csv`` through ``load_data``.
    The result is cached in the module-level ``test_224_backup`` /
    ``test_331_backup`` globals. Every call returns a copy, so callers may
    mutate the DataFrame freely.

    Args:
        img_pixels: Image side length in pixels; only 224 and 331 are supported.

    Returns:
        pandas.DataFrame with the ``test.csv`` rows whose image exists, plus
        an ``image_path`` column.

    Raises:
        ValueError: If ``img_pixels`` is neither 224 nor 331.
    """
    global test_224_backup, test_331_backup
    if img_pixels == 224:
        cached = test_224_backup
    elif img_pixels == 331:
        cached = test_331_backup
    else:
        # A bare ``raise`` outside an ``except`` block only yields
        # "RuntimeError: No active exception to reraise"; be explicit.
        raise ValueError(
            "Unsupported image size " + repr(img_pixels) + ": expected 224 or 331"
        )
    if cached is not None:
        return cached.copy()

    # Images live in "<BASE_PATH_TO_IMAGES><size>/test/".
    path_to_images = BASE_PATH_TO_IMAGES + str(img_pixels) + "/test/"
    csv_path = os.path.join(BASE_INPUT_PATH, "test.csv")
    data = load_data(path_to_images, csv_path, "test")

    if img_pixels == 224:
        test_224_backup = data.copy()
    else:
        test_331_backup = data.copy()
    return data

def doUndersampling(dataset, balance=1):
    """Keep every positive row plus a random subset of the negative rows.

    ``int(balance * n_positives)`` negatives are drawn without replacement
    with ``random.sample``, so the result depends on the ``random`` seed.
    If there are not more negatives than that, all of them are kept. The
    returned frame lists the positive rows first, then the negatives.

    Args:
        dataset: DataFrame whose ``target`` column is compared against the
            module-level ``POSITIVE_CLASS`` / ``NEGATIVE_CLASS`` labels.
        balance: Desired number of negatives per positive.

    Returns:
        The selected rows of ``dataset`` (via ``.loc``).
    """
    positive_idx = dataset[dataset.target == POSITIVE_CLASS].index.tolist()
    negative_idx = dataset[dataset.target == NEGATIVE_CLASS].index.tolist()
    wanted = int(balance * len(positive_idx))
    if wanted < len(negative_idx):
        negative_idx = random.sample(negative_idx, wanted)
    return dataset.loc[positive_idx + negative_idx]
    
def trainTestPatientCheck(train, test):
    """Count the patients that appear in both ``train`` and ``test``.

    This is a leakage check on the ``patient_id`` column. The count is
    returned, and a ``UserWarning`` listing up to ten of the shared ids is
    emitted when it is non-zero. Callers that ignore the return value behave
    as before.

    Args:
        train: DataFrame with a ``patient_id`` column.
        test: DataFrame with a ``patient_id`` column.

    Returns:
        Number of distinct patient ids present in both frames.
    """
    shared = set(train.patient_id.values) & set(test.patient_id.values)
    n_overlap = len(shared)
    if n_overlap:
        warnings.warn(
            "%d patient(s) appear in both train and test, e.g. %s"
            % (n_overlap, list(shared)[:10]),
            stacklevel=2,
        )
    return n_overlap


def timestamp_and_experimentId(model_name, do_undersampling, prepro, prepro_rotation, prepro_blur, prepro_brightness, prepro_zoom, batch_size, epochs, base_model_trainable):
    """Build a run timestamp and a descriptive experiment id, and create its folder.

    The id encodes the configuration, e.g.
    ``vgg16_Under_PreFlipsRotBrightBlurZoom_16B_20E_Extractor``. A folder with
    that name is created in the current working directory, or reused if it
    already exists. An empty marker file named after the timestamp is written
    inside it.

    Returns:
        ``(timestamp, experiment_id)``; ``timestamp`` looks like
        ``YYYY-MM-DD_HH:MM:SS``.
    """
    # Read the clock once so the date and time parts cannot straddle midnight.
    timestamp = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
    experiment_id = (
        model_name
        + ("_Under" if do_undersampling else "_CW")
        + ("_PreFlips" if prepro else "")
        + ("Rot" if prepro_rotation else "")
        + ("Bright" if prepro_brightness else "")
        + ("Blur" if prepro_blur else "")
        + ("Zoom" if prepro_zoom else "")
        + "_" + str(batch_size) + "B_"
        + str(epochs) + "E_"
        + ("FineTuning" if base_model_trainable else "Extractor")
    )
    # exist_ok: re-running a configuration must still record its new marker
    # file (catching FileExistsError around both calls skipped it).
    os.makedirs(experiment_id, exist_ok=True)
    with open(os.path.join(experiment_id, timestamp), 'w'):
        pass
    return (timestamp, experiment_id)
    

def create_splits(df, test_size, classToPredict):
    """Split ``df`` into ``(train, validation)`` frames stratified on ``classToPredict``.

    ``random_state`` is the module-level ``SEED``, so the split is reproducible.
    """
    stratify_on = df[classToPredict]
    train_part, val_part = train_test_split(
        df,
        test_size=test_size,
        random_state=SEED,
        stratify=stratify_on,
    )
    return train_part, val_part

def preprocess_rotation(image):
    """With probability 1/2, rotate ``image`` by -90, +90 or 180 degrees.

    The coin flip uses Python's ``random`` and the angle uses NumPy's global
    RNG. It is used as an ``ImageDataGenerator`` ``preprocessing_function``.
    """
    coin = random.getrandbits(1)
    if not coin:
        return image
    quarter_turns = np.random.choice([-1, 1, 2])
    return np.rot90(image, quarter_turns)

def preprocess_blur(image):
    """With probability 1/2, apply a 3x3 box blur (``cv2.blur``) to ``image``."""
    return cv2.blur(image, (3, 3)) if random.getrandbits(1) else image

def preprocess_both(image):
    """Randomly rotate, then randomly blur ``image``; each step has probability 1/2.

    The random numbers are drawn in the same order as calling the rotation
    step and then the blur step separately.
    """
    rotate_now = bool(random.getrandbits(1))
    if rotate_now:
        quarter_turns = np.random.choice([-1, 1, 2])
        image = np.rot90(image, quarter_turns)
    blur_now = bool(random.getrandbits(1))
    return cv2.blur(image, (3, 3)) if blur_now else image
 
def get_training_gen(df, prepro, prepro_brightness, prepro_zoom, prepro_rotation, prepro_blur, batch_size, img_size, doShuffle=True):
    """Build the training ``flow_from_dataframe`` iterator for ``df``.

    Pixels are always rescaled by 1/255. If ``prepro`` is true, the generator
    adds random horizontal and vertical flips. It can also add a brightness
    range of [0.75, 1], a zoom range of [0.75, 1], and a custom
    ``preprocessing_function`` (``preprocess_rotation``, ``preprocess_blur``
    or ``preprocess_both``), each enabled by its own flag. Without
    ``prepro`` all other augmentation flags are ignored.

    Images are read from the absolute paths in ``df['image_path']`` and labels
    from ``df['target']``, using the module-level ``CLASS_MODE`` and ``SEED``.
    """
    if prepro:
        custom_fn_by_flags = {
            (True, True): preprocess_both,
            (True, False): preprocess_rotation,
            (False, True): preprocess_blur,
            (False, False): None,
        }
        generator = ImageDataGenerator(
            rescale=1 / 255.0,
            horizontal_flip=True,
            vertical_flip=True,
            brightness_range=[0.75, 1] if prepro_brightness else None,
            zoom_range=[0.75, 1] if prepro_zoom else 0.0,
            preprocessing_function=custom_fn_by_flags[(bool(prepro_rotation), bool(prepro_blur))],
        )
    else:
        generator = ImageDataGenerator(rescale=1. / 255.0)

    return generator.flow_from_dataframe(
        dataframe=df,
        directory=None,
        x_col='image_path',
        y_col='target',
        class_mode=CLASS_MODE,
        shuffle=doShuffle,
        target_size=img_size,
        batch_size=batch_size,
        seed=SEED,
        validate_filenames=False,
    )

def get_validation_gen(df, batch_size, img_size):
    """Build a deterministic (unshuffled, unaugmented) iterator over ``df``.

    Only rescales pixels by 1/255. Paths come from ``df['image_path']`` and
    labels from ``df['target']``, using the module-level ``CLASS_MODE`` and
    ``SEED``.
    """
    flow_kwargs = dict(
        dataframe=df,
        directory=None,
        x_col='image_path',
        y_col='target',
        class_mode=CLASS_MODE,
        shuffle=False,
        target_size=img_size,
        batch_size=batch_size,
        seed=SEED,
        validate_filenames=False,
    )
    return ImageDataGenerator(rescale=1. / 255.0).flow_from_dataframe(**flow_kwargs)

def create_model(base_model, base_model_trainable, model_name, trainable_layers=100):
    """Stack a small binary-classification head on top of a pretrained backbone.

    Freezing policy for ``base_model.layers``:
      * ``base_model_trainable`` false: every layer is frozen (feature extractor).
      * ``base_model_trainable`` true: only the last ``trainable_layers`` layers
        stay trainable; a backbone with fewer layers stays fully trainable.
        Layers are only ever set to frozen, never re-enabled.

    Head: global average pooling for ``nasnetlarge`` (2x2 average pooling plus
    flatten otherwise), then Dense(64)-Dropout(0.3)-Dense(16)-Dropout(0.3) and
    a sigmoid layer with ``OUTPUT_NEURONS`` units.

    Args:
        base_model: Keras backbone without its classifier head.
        base_model_trainable: Whether to fine-tune the backbone's last layers.
        model_name: Backbone key; only ``"nasnetlarge"`` changes the pooling.
        trainable_layers: Number of trailing backbone layers left trainable
            when fine-tuning.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    model_length = len(base_model.layers)
    if base_model_trainable:
        # Clamp at 0: a raw ``model_length - trainable_layers`` turns negative
        # for backbones with 51..99 layers, and as a slice bound it would freeze
        # the *first* layers instead of none.
        end_index = max(model_length - trainable_layers, 0)
    else:
        end_index = model_length
    for layer in base_model.layers[:end_index]:
        layer.trainable = False

    model = Sequential()
    model.add(base_model)
    if model_name == "nasnetlarge":
        model.add(GlobalAveragePooling2D())
    else:
        model.add(AveragePooling2D((2), name='avg_pool'))
        model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu', kernel_initializer=tf.keras.initializers.GlorotUniform()))
    model.add(layers.Dropout(0.3))
    model.add(layers.Dense(16, activation='relu', kernel_initializer=tf.keras.initializers.GlorotUniform()))
    model.add(layers.Dropout(0.3))
    # Keep the output in float32: under the 'mixed_float16' policy set in the
    # globals cell, a float16 sigmoid is numerically fragile for the loss.
    # Without a mixed policy this is already the default dtype.
    model.add(layers.Dense(OUTPUT_NEURONS, activation='sigmoid',
                           kernel_initializer=tf.keras.initializers.GlorotUniform(),
                           dtype='float32'))
    return model

def earlyStoppingPatienteCalculator(do_undersampling, base_model_trainable):
    """Return the EarlyStopping patience, in epochs, for an experiment.

    The score is ``(not do_undersampling) + base_model_trainable``, i.e. one
    point for class-weighted (not undersampled) training and one for
    fine-tuning. It maps to a patience as 0 -> 15, 1 -> 10, 2 -> 7.

    Raises:
        ValueError: If the score has no mapping, e.g. ``base_model_trainable=2``
            together with ``do_undersampling=False``. Returning an error string
            here would silently end up as ``EarlyStopping(patience=...)``.
    """
    patience_by_score = {0: 15, 1: 10, 2: 7}
    score = (not do_undersampling) + base_model_trainable
    try:
        return patience_by_score[score]
    except KeyError:
        raise ValueError(
            "Invalid data for early stopping: do_undersampling=%r, base_model_trainable=%r"
            % (do_undersampling, base_model_trainable)
        ) from None

def save_history(history, timestamp, base_output_path):
    """Plot training curves and, when ``SAVE_OUTPUT`` is set, persist the results.

    Draws loss and accuracy (train vs. validation) side by side. When saving,
    it also does three things:
      * appends the last-epoch loss/accuracy/auc values (train, then val) to
        ``<prefix>2finalResults.txt``;
      * saves the figure to ``<prefix>2history.png``;
      * dumps the whole history to ``<prefix>2history.json``.

    Args:
        history: ``History.history`` dict returned by ``model.fit``.
        timestamp: Unused; kept so existing call sites keep working.
        base_output_path: Output prefix (folder + experiment id + "-").
    """
    fig = plt.figure()
    fig.set_figwidth(15)
    fig.add_subplot(1, 2, 1)
    plt.plot(history['val_loss'], label='val loss')
    plt.plot(history['loss'], label='train loss')
    plt.legend()
    plt.title("Model Loss")
    fig.add_subplot(1, 2, 2)
    plt.plot(history['val_accuracy'], label='val accuracy')
    plt.plot(history['accuracy'], label='train accuracy')
    plt.legend()
    plt.title("Model Accuracy")
    if not SAVE_OUTPUT:
        return

    last = len(history["loss"]) - 1
    metrics = ["loss", "accuracy", "auc", "val_loss", "val_accuracy", "val_auc"]
    # Context managers: the files are closed even if a metric is missing or
    # json.dump fails part-way (the figure no longer shares the name ``f``).
    with open(base_output_path + "2finalResults.txt", "a") as results_file:
        for metric in metrics:
            value = round(history[metric][last], 4)
            # The blank line after "auc" separates training from validation metrics.
            results_file.write(metric + ":" + str(value) + ("\n\n" if metric == "auc" else "\n"))
    plt.savefig(base_output_path + "2history.png")
    with open(base_output_path + "2history.json", 'w') as json_file:
        json.dump(history, json_file)
            
def plot_auc(y_true_classes, y_pred_probs, base_output_path):
    """Draw the ROC curve (positive label 1) and save it as ``<prefix>5auc.png``.

    The figure is saved unconditionally, regardless of ``SAVE_OUTPUT``.
    """
    fpr, tpr, _ = roc_curve(y_true_classes, y_pred_probs, pos_label=1)
    area = auc(fpr, tpr)
    _, ax = plt.subplots(1, 1, figsize=(8, 8))
    ax.plot(fpr, tpr, label='%s (AUC:%0.2f)' % ('Target', area))
    ax.plot([0, 1], [0, 1], color='navy', lw=1, linestyle='--')  # chance diagonal
    ax.legend()
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    plt.savefig(base_output_path + "5auc.png")


def calc_f1(prec, recall):
    """Return the F1 score, the harmonic mean of ``prec`` and ``recall``.

    Returns 0 when either value is zero (falsy), which also avoids dividing
    by zero when both are 0.
    """
    if not (recall and prec):
        return 0
    return 2 * (prec * recall) / (prec + recall)

def plot_confusion_matrix(cm, labels, base_output_path):
    """Draw ``cm`` as an annotated heat map and, if ``SAVE_OUTPUT``, save it.

    Args:
        cm: Square integer matrix (rows = true label, columns = predicted),
            such as ``sklearn.metrics.confusion_matrix`` returns.
        labels: Tick labels, one per class.
        base_output_path: Output prefix; the image goes to ``<prefix>4confMatrix.png``.
    """
    # Start a new figure. Otherwise imshow/colorbar draw on whatever figure is
    # current; in evaluateExperiment that is the ROC figure from plot_auc.
    plt.figure()
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.colorbar()
    tick_marks = np.arange(len(labels))
    plt.xticks(tick_marks, labels, rotation=55)
    plt.yticks(tick_marks, labels)

    # Cell counts: white text on dark cells, black on light ones.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], 'd'), horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    if SAVE_OUTPUT:
        plt.savefig(base_output_path + "4confMatrix.png")
        
        

def createSubmissionFile(model, dataset, img_size, filePath):
    """Predict every image in ``dataset`` and write a Kaggle submission CSV.

    Each image is loaded at ``img_size``, rescaled to [0, 1] and scored
    individually. The first output value, rounded to 4 decimals, becomes the
    ``target`` column. When the module-level ``fast_run`` flag is set, only
    the first 50 rows (or fewer, if the dataset is smaller) are predicted.

    Args:
        model: Trained Keras model with a single sigmoid output.
        dataset: DataFrame with ``image_name`` and ``image_path`` columns.
        img_size: ``(height, width)`` to resize images to.
        filePath: Destination CSV path.

    Returns:
        The submission DataFrame (``image_name``, ``target``) that was written.
    """
    n_rows = dataset.shape[0]
    if fast_run:
        # Clamp, so smoke tests on small frames do not raise IndexError.
        n_rows = min(50, n_rows)
    rows = []
    for i in tqdm(range(n_rows)):
        record = dataset.iloc[i]
        img = tf.keras.preprocessing.image.load_img(record.image_path, target_size=img_size)
        img = tf.keras.preprocessing.image.img_to_array(img) / 255
        pred = model.predict(tf.expand_dims(img, 0))
        rows.append([record.image_name, round(pred[0][0], 4)])
    sub_df = pd.DataFrame(rows, columns=['image_name', 'target'])
    sub_df.to_csv(filePath, index=False)
    # Returned instead of calling ``sub_df.head()``, which is a no-op inside a function.
    return sub_df

def clearWD(path='/kaggle/working'):
    """Delete every non-hidden entry directly inside ``path``.

    Real directories are removed recursively. Files and symlinks (including
    symlinks to directories) are unlinked. Names starting with a dot are not
    matched by the ``*`` glob and are therefore kept.

    Args:
        path: Directory to empty; defaults to the Kaggle working directory.
    """
    import glob
    import os
    import shutil

    for entry in glob.glob(os.path.join(path, '*')):
        # Decide explicitly instead of using a bare ``except`` around
        # os.remove as the "it is a directory" signal, which also hid
        # unrelated errors.
        if os.path.isdir(entry) and not os.path.islink(entry):
            shutil.rmtree(entry)
        else:
            os.remove(entry)
# Progress marker for the notebook run: all helper functions above are defined.
print("ALL FUNCTIONS LOADED")

<div class="alert alert-block alert-info">  
<b> Experiment Function.  </b>
</div>

In [None]:
def _predict_image(model, image_path, img_size):
    """Load one image, rescale it to [0, 1] and run ``model.predict`` on it.

    Returns:
        ``(image, prediction)``: the rescaled image array and the raw
        ``model.predict`` output for a batch containing only that image.
    """
    img = tf.keras.preprocessing.image.load_img(image_path, target_size=img_size)
    img = tf.keras.preprocessing.image.img_to_array(img) / 255
    return img, model.predict(tf.expand_dims(img, 0))


def _save_augmentation_preview(train, prepro, prepro_brightness, prepro_zoom, prepro_rotation, prepro_blur, batch_size, img_size, png_path):
    """Save a 4x4 grid of the first (unshuffled) augmented training batch to ``png_path``."""
    preview_gen = get_training_gen(train, prepro, prepro_brightness, prepro_zoom, prepro_rotation, prepro_blur, batch_size, img_size, doShuffle=False)
    t_x, t_y = next(preview_gen)
    _, m_axs = plt.subplots(4, 4, figsize=(16, 16))
    for c_x, c_y, c_ax in zip(t_x, t_y, m_axs.flatten()):
        c_ax.imshow(c_x, cmap='bone')
        # The binary-mode iterator yields numeric labels (0.0 / 1.0), so a
        # ``c_y == "1"`` test never matches. int() accepts numeric and string labels.
        if int(c_y) == 1:
            c_ax.set_title(str(c_y) + "-MALIGNANT")
        else:
            c_ax.set_title(str(c_y) + "-BENIGN")
        c_ax.axis('off')
    plt.savefig(png_path)


def evaluateExperiment(fast_run, model_name, do_undersampling, prepro, prepro_brightness, prepro_zoom, prepro_rotation, prepro_blur, batch_size, epochs, base_model_trainable, base_model, learning_rate):
    """Run one experiment: train, evaluate on the validation split, archive outputs.

    Outputs go to ``./<experiment_id>/`` (created by
    ``timestamp_and_experimentId``). That folder is finally zipped to
    ``<experiment_id>.zip`` in the working directory and then deleted.

    Args:
        fast_run: Smoke-test mode: 3 epochs and at most 50 validation images.
        model_name: Key of ``modelpixels``; selects the input resolution.
        do_undersampling: Keep 1 negative per positive if true, otherwise 10
            per positive. Balanced class weights are applied in both cases.
        prepro, prepro_brightness, prepro_zoom, prepro_rotation, prepro_blur:
            Augmentation switches forwarded to ``get_training_gen``.
        batch_size: Batch size of the training and validation generators.
        epochs: Maximum number of epochs (early stopping may end sooner).
        base_model_trainable: Fine-tune the backbone's last layers if true,
            otherwise use it as a frozen feature extractor.
        base_model: Pretrained Keras backbone without its classifier head.
        learning_rate: Adam learning rate.
    """
    (timestamp, experiment_id) = timestamp_and_experimentId(model_name, do_undersampling, prepro, prepro_rotation, prepro_blur, prepro_brightness, prepro_zoom, batch_size, epochs, base_model_trainable)
    base_output = "./" + experiment_id + "/"
    base_output_path = base_output + experiment_id + "-"
    img_pixels = modelpixels[model_name]
    img_size = (img_pixels, img_pixels)

    # Keep Adam: it is the optimizer that worked best. ``lr`` is a deprecated
    # alias that newer Keras versions reject, so pass ``learning_rate``.
    OPTIMIZER = Adam(learning_rate=learning_rate)
    LOSS = 'binary_crossentropy'
    METRICS = ['accuracy', 'AUC']

    train = getTrainData(img_pixels)
    test = getTestData(img_pixels)

    if IS_CLASS_MODE_BINARY:
        # flow_from_dataframe's binary mode expects string class labels.
        train['target'] = train['target'].apply(str)
    # Negatives kept per positive: 1 = real undersampling, 10 = mild
    # undersampling for the class-weighted runs.
    balance = 1 if do_undersampling else 10
    train = doUndersampling(train, balance)

    trainTestPatientCheck(train, test)

    if SAVE_OUTPUT:
        _save_augmentation_preview(train, prepro, prepro_brightness, prepro_zoom, prepro_rotation, prepro_blur, batch_size, img_size, base_output_path + "1dataAug.png")

    model = create_model(base_model, base_model_trainable, model_name)

    esPatience = earlyStoppingPatienteCalculator(do_undersampling, base_model_trainable)
    # restore_best_weights: the evaluation and submission below should use
    # the epoch with the best val_auc, not the last one before stopping.
    callback_list = [EarlyStopping(monitor='val_auc', mode='max', patience=esPatience, restore_best_weights=True)]
    if SAVE_OUTPUT:
        callback_list.append(ModelCheckpoint(
            base_output_path + "3model.hdf5",
            save_weights_only=True,
            verbose=VERBOSE_LEVEL,
            save_best_only=True,
            monitor='val_auc',
            overwrite=True,
            mode='max',
        ))

    train_df, val_df = create_splits(train, 0.2, 'target')
    train_gen = get_training_gen(train_df, prepro, prepro_brightness, prepro_zoom, prepro_rotation, prepro_blur, batch_size, img_size)
    # Validate on the whole held-out split every epoch. Using only its first
    # batch (batch_size images) made val_auc, and therefore early stopping
    # and checkpointing, very noisy.
    val_gen = get_validation_gen(val_df, batch_size, img_size)

    train_targets = np.array(train_df['target'])
    class_weights_computed = class_weight.compute_class_weight(
        'balanced', classes=np.unique(train_targets), y=train_targets)
    # Keys 0/1 follow the sorted class order used by the generators.
    class_weights = dict(enumerate(class_weights_computed))
    print("earlyStopPatience", esPatience)

    model.compile(loss=LOSS, metrics=METRICS, optimizer=OPTIMIZER)

    if fast_run:
        epochs = 3
    print(experiment_id)
    history = model.fit(
        train_gen,
        epochs=epochs,
        class_weight=class_weights,
        verbose=VERBOSE_LEVEL,
        callbacks=callback_list,
        validation_data=val_gen,
    )

    save_history(history.history, timestamp, base_output_path)

    # Per-image predictions on the validation split for the ROC / PR /
    # confusion-matrix reports.
    n_eval = val_df.shape[0] if not fast_run else min(50, val_df.shape[0])
    y_true_classes = []  # true labels
    y_pred_classes = []  # predicted labels (probability rounded to 0/1)
    y_pred_probs = []  # predicted probabilities
    for i in range(n_eval):
        row = val_df.iloc[i]
        _, y_pred = _predict_image(model, row.image_path, img_size)
        y_pred_prob = round(y_pred[0][0], 4)
        y_true_classes.append(int(row.target))
        y_pred_classes.append(int(round(y_pred_prob, 0)))
        y_pred_probs.append(y_pred_prob)

    plot_auc(y_true_classes, y_pred_probs, base_output_path)

    # Operating point with the best F1 on the validation split (computed for
    # inspection; it is currently not printed or saved).
    precision, recall, thresholds = precision_recall_curve(y_true_classes, y_pred_probs)
    f1score = [calc_f1(precision[i], recall[i]) for i in range(len(thresholds))]
    idx = np.argmax(f1score)
    precision = round(precision[idx], 4)
    recall = round(recall[idx], 4)
    threshold = round(thresholds[idx], 4)
    f1score = round(f1score[idx], 4)

    cm_plot_label = ['benign', 'malignant']
    # labels=[0, 1] keeps the matrix 2x2 even if a class is absent (e.g. fast_run).
    plot_confusion_matrix(confusion_matrix(y_true_classes, y_pred_classes, labels=[0, 1]), cm_plot_label, base_output_path)

    # Show the prediction for one random test image.
    img, pred = _predict_image(model, test.sample().iloc[0].image_path, img_size)
    prediction = round(pred[0][0], 2)
    finding = "Diagnosis: BENIGN"
    if not prediction < 0.5:
        finding = "Diagnosis: MALIGNANT"
    plt.figure(figsize=(5, 5))
    plt.imshow(img)
    plt.title(finding)
    plt.axis("off")

    if SAVE_OUTPUT:
        # Architecture as JSON; the weights come from the checkpoint above.
        with open(base_output_path + "3model.json", "w") as json_file:
            json_file.write(model.to_json())
        createSubmissionFile(model, test, img_size, base_output + "submission.csv")

    # Zip the experiment folder where timestamp_and_experimentId created it
    # (the current working directory, /kaggle/working on Kaggle) and remove
    # the unzipped copy.
    archive_dir = os.path.abspath(experiment_id)
    shutil.make_archive(archive_dir, 'zip', archive_dir)
    shutil.rmtree(archive_dir)
# Progress marker for the notebook run: evaluateExperiment is defined.
print("EVALUATE EXPERIMENT LOADED")

<div class="alert alert-block alert-success">  
    <b> Global variables</b>
</div>

In [None]:
# Global settings shared by every experiment.
SEED = 1  # seeds every RNG, the train/validation split and the generators
OUTPUT_NEURONS = 1  # single sigmoid output unit (benign vs. malignant)
VERBOSE_LEVEL = 2  # verbosity passed to model.fit and ModelCheckpoint
CLASS_MODE = "binary"
# Avoid "raw"; "categorical" should behave the same, but since it one-hot
# encodes the labels, avoid it as well for now.
IS_CLASS_MODE_BINARY = CLASS_MODE == "binary"
# In binary mode evaluateExperiment converts `target` to str before
# undersampling, so the class labels to compare against are strings there.
POSITIVE_CLASS = "1" if IS_CLASS_MODE_BINARY else 1
NEGATIVE_CLASS = "0" if IS_CLASS_MODE_BINARY else 0


# Dataset location. getTrainData/getTestData append e.g. "224/train/" directly
# to BASE_PATH_TO_IMAGES, so images are expected under ".../jpeg224/train/".
BASE_INPUT_PATH = '/kaggle/input/tfmmelanomapreprocessed'
BASE_PATH_TO_IMAGES = '/kaggle/input/tfmmelanomapreprocessed/dataset/jpeg'
IMAGE_TYPE = ".jpg"

# Input resolution (pixels per side) used for each backbone; getTrainData and
# getTestData only support 224 and 331.
modelpixels = {"vgg16" : 224, "densenet121":224, "nasnetmobile":224, "nasnetlarge":331, "inceptionresnetv2":224}
INPUT_224_SHAPE = (224, 224, 3)
INPUT_331_SHAPE = (331, 331, 3)
def getVGG16():
    """Return an ImageNet-pretrained VGG16 backbone without its top, for 224x224x3 input."""
    return VGG16(input_shape=INPUT_224_SHAPE,include_top=False,weights='imagenet')
def getDenseNet121():
    """Return an ImageNet-pretrained DenseNet121 backbone without its top, for 224x224x3 input."""
    return DenseNet121(input_shape=INPUT_224_SHAPE,include_top=False,weights='imagenet')
def getNASNetMobile():
    """Return an ImageNet-pretrained NASNetMobile backbone without its top, for 224x224x3 input."""
    return NASNetMobile(input_shape=INPUT_224_SHAPE,include_top=False,weights='imagenet')
def getNASNetLarge():
    """Return an ImageNet-pretrained NASNetLarge backbone without its top, for 331x331x3 input."""
    return NASNetLarge(input_shape=INPUT_331_SHAPE,include_top=False,weights='imagenet')
def getInceptionResNetV2():
    """Return an ImageNet-pretrained InceptionResNetV2 backbone without its top, for 224x224x3 input."""
    return InceptionResNetV2(input_shape=INPUT_224_SHAPE,include_top=False,weights='imagenet')
# Model name -> factory. Each call builds a fresh backbone; the keys match modelpixels.
allmodelsfunc = {"vgg16" : getVGG16, "densenet121":getDenseNet121, "nasnetmobile":getNASNetMobile, "nasnetlarge":getNASNetLarge, "inceptionresnetv2":getInceptionResNetV2}

# TODO: review. Keep the metadata caches used by getTrainData/getTestData
# across re-runs of this cell: a cache that already exists and is not None
# keeps its contents (as a fresh copy); otherwise it is initialised to None.
for _cache_name in ("train_224_backup", "test_224_backup", "train_331_backup", "test_331_backup"):
    _cached = globals().get(_cache_name)
    globals()[_cache_name] = _cached.copy() if _cached is not None else None
del _cache_name, _cached
    
# TensorFlow execution optimizations; they are only enabled when a GPU is present.
# Source: https://www.tensorflow.org/guide/mixed_precision & https://www.tensorflow.org/xla
MIXED_PRECISION = True
XLA_ACCELERATE = True

GPUS = len(tf.config.list_physical_devices('GPU'))
DEVICE = 'GPU' if GPUS else 'CPU'
if GPUS:
    if MIXED_PRECISION:
        # Compute in float16 while keeping variables in float32. Prefer the
        # stable API; the ``experimental`` namespace was removed in later TF
        # releases and is only a fallback for old versions.
        if hasattr(tf.keras.mixed_precision, 'set_global_policy'):
            tf.keras.mixed_precision.set_global_policy('mixed_float16')
        else:
            from tensorflow.keras.mixed_precision import experimental as mixed_precision
            mixed_precision.set_policy(mixed_precision.Policy('mixed_float16'))
    if XLA_ACCELERATE:
        # Enable XLA JIT compilation globally.
        tf.config.optimizer.set_jit(True)

# print("Tensorflow version " + tf.__version__)


# Seed the RNGs used by the pipeline so runs are repeatable: Python's `random`
# (undersampling, rotation/blur coin flips), NumPy's global RNG (rotation
# angle) and TensorFlow.
random.seed(SEED)
np.random.seed(SEED)
# NOTE(review): PYTHONHASHSEED is normally only read at interpreter start-up,
# so setting it here presumably has no effect on this process — confirm.
os.environ['PYTHONHASHSEED'] = str(SEED)
# NOTE(review): TF_KERAS is not read anywhere in this notebook (presumably it is
# meant for third-party Keras add-ons) — confirm it is still needed.
os.environ['TF_KERAS'] = str(SEED)
# These reuse SEED (= 1) as an on/off flag, presumably to request deterministic
# TF / cuDNN kernels — confirm against the TF version in use.
os.environ['TF_DETERMINISTIC_OPS'] = str(SEED)
os.environ['TF_CUDNN_DETERMINISTIC'] = str(SEED)
tf.random.set_seed(SEED)

<div class="alert alert-block alert-success">  
    <b> Experiment configuration</b>
</div>

In [None]:
# Backbones to benchmark; any subset of allmodelsfunc's keys:
# "vgg16", "densenet121", "nasnetmobile", "nasnetlarge", "inceptionresnetv2"
modelsfunc = {k: allmodelsfunc[k] for k in ["vgg16", "densenet121", "nasnetmobile", "nasnetlarge", "inceptionresnetv2"]}


SAVE_OUTPUT = False
fast_run = False  # True: 3 epochs and only 50 validation/submission images per run

do_undersamplings = [True, False]
base_model_trainables = [False, True]
preprocesses = [True]
prepro_rotation = True
prepro_blur = True
prepro_brightness = True
prepro_zoom = True
# Debug override: forces every experiment to a single epoch, ignoring the
# 20/40-epoch schedule below. Set to False for real training runs.
FORCE_SINGLE_EPOCH = True

for model_name in modelsfunc:
    for preprocess in preprocesses:
        # Per-iteration copies. Reassigning the configuration flags themselves
        # meant a single preprocess=False pass disabled augmentation for
        # every later run.
        use_rotation = prepro_rotation and preprocess
        use_blur = prepro_blur and preprocess
        use_brightness = prepro_brightness and preprocess
        use_zoom = prepro_zoom and preprocess
        for do_undersampling in do_undersamplings:
            for base_model_trainable in base_model_trainables:
                print("\n\n=========================================" + model_name + "=========================================")
                # Fresh backbone per run so frozen/trained layers do not leak between experiments.
                base_model = modelsfunc[model_name]()
                batch_size = 16 if do_undersampling else (32 if not base_model_trainable else 64)
                epochs = 20 if do_undersampling else 40
                learning_rate = 1e-4 if not base_model_trainable else 1e-5
                if FORCE_SINGLE_EPOCH:
                    epochs = 1

                evaluateExperiment(fast_run, model_name, do_undersampling, preprocess, use_brightness, use_zoom, use_rotation, use_blur,
                                   batch_size, epochs, base_model_trainable, base_model, learning_rate)

In [None]:
# clearWD()