# Compare models

In [None]:
import sys

from utils import *
from visual_utils import *

from padding_generator import generators

from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import plot_model

import matplotlib.pyplot as plt
import pandas as pd
import math

import os
import glob

# Show full cell contents (e.g. long paths / model names) without truncation.
# `None` is the supported way to disable the width limit; a negative value
# was deprecated in pandas 1.0 and is rejected by newer pandas versions.
pd.set_option('display.max_colwidth', None)

Display available trained models

In [None]:
# Bare call as the cell's last expression, so any return value is rendered
# as the cell output. `available_models` is not defined in this notebook;
# presumably it comes from one of the star imports (utils / visual_utils)
# and lists the trained models that can be compared -- TODO confirm.
available_models()

Select models to compare

In [None]:
# Names of the models to compare; edit this list to change the selection.
model_names = ['baseline', 'resnet_34', 'densenet_2', 'resnet_101_lds']

# One data_paths() result per selected model, in the same order as
# `model_names`. create_if_missing=False is passed so that looking up a
# model does not ask data_paths() to create anything.
PATHS = [
    data_paths(name, create_if_missing=False)
    for name in model_names
]

Load the models and evaluate them. Be aware that selecting too many models and/or too large a batch size can lead to insufficient-memory problems.

In [None]:
# --- Evaluation settings ---------------------------------------------------
batch_size = 64  # reduce if the memory is insufficient

set_type = SetType.TEST  # TRAIN, VALID, TEST

# Generator codes:
#   b - no augmentation
#   1 - modest augmentation
#   2 - strong augmentation
#
# Optional suffixes:
#   e - equalization, e.g. '1e'
#   g - Gaussian noise, e.g. '2g'
gen_code = 'b'
gen = generators[gen_code]

# Column labels for the values returned by study_eval().
# NOTE(review): assumes study_eval() returns these four metrics in this
# order -- confirm against its definition.
metric_columns = ['ind_accuracy', 'ind_cohen_kappa', 'study_accuracy', 'study_cohen_kappa']

# --- Evaluation ------------------------------------------------------------
results = []

for model_name in model_names:
    print('Evaluating ' + model_name)

    # Drop the reference to the previously evaluated model and reset Keras'
    # global state before loading the next one, so memory used by earlier
    # models can be reclaimed instead of accumulating across the loop.
    # The last model stays bound to `model` after the loop finishes.
    model = None
    K.clear_session()

    model = load_model(data_paths(model_name)['best'], compile=True)
    e = study_eval(model, set_type, batch_size = batch_size, generator = gen)

    results.append(e)

# Index the table by model name so every row can be matched to its model.
res_df = pd.DataFrame(results, columns = metric_columns, index = model_names)

display(res_df)

Display learning graphs

In [None]:
# Metrics to plot; each one gets its own subplot. For every metric the
# training log is expected to contain both `<metric>` and `val_<metric>`.
metrics = [
    'cohen_kappa',
    #'accuracy'
]
fig, axes = plt.subplots(1, len(metrics))
fig.set_size_inches(14, 6)

# plt.subplots returns a bare Axes (not an array) when there is only one
# subplot; normalise to something iterable once, up front.
metric_axes = axes if len(metrics) > 1 else [axes]

# One matplotlib colour code per model; reused cyclically when there are
# more models than colours (previously this raised IndexError).
colors = "brgcmykk"

# Titles and axis labels depend only on the metric, so set them once
# instead of on every model iteration.
for metric, ax in zip(metrics, metric_axes):
    ax.set_title(metric)
    ax.set_ylabel(metric)
    ax.set_xlabel('Epoch')

for p, (model_name, paths) in enumerate(zip(model_names, PATHS)):
    data = get_log(paths)
    color = colors[p % len(colors)]

    for metric, ax in zip(metrics, metric_axes):
        # The first logged row is skipped ([1:]), as in the original plot.
        # Training values are drawn as small dots, validation values as a
        # line, both in the model's colour.
        ax.scatter(data.index[1:], data[metric][1:], label= model_name + ' train', c = color, s = 1)
        ax.plot(data.index[1:], data['val_'+metric][1:], label= model_name + ' validation', c = color)

# Build each legend once, after all models have been drawn.
for ax in metric_axes:
    ax.legend()
        