# **Import Libraries**

In [None]:
#Set path to MAIN FOLDER OF EXPERIMENT
#cd /path/to/EXPERIMENT_FOLDER/

/content/drive/My Drive/PSR


In [5]:
from sklearn.metrics import accuracy_score
import pandas as pd
import numpy as np
import os

#Import custom modules
from utils.ensemble import get_ensemble, add_preprocess_layer
from utils.eval_utils import get_generator, get_prediction
from utils.load_utils import init_models
import utils.ensemble as en

In [6]:
#Define Constants
BATCH_SIZE = 4
TOTAL_CLASSES = 199
IMAGE_SIZE = (256,256)

#Model path
TEACHER_MODEL_PATH = 'models/Teacher_Models/'

#Data paths
VAL_DATA_PATH = 'ds/val/'
TEST_DATA_PATH = 'ds/test/'

**Load Models**

In [7]:
#Initialize models
MODEL_FILES = ['DenseNet201', 'Xception', 'InceptionResNetV2', 'ResNet152V2', 'EfficientNetB7', 'NASNetLarge']
models = init_models(MODEL_FILES)

[INFO] Loading Model: DenseNet201
[INFO] Loading Model: Xception
[INFO] Loading Model: InceptionResNetV2
[INFO] Loading Model: ResNet152V2
[INFO] Loading Model: EfficientNetB7
[INFO] Loading Model: NASNetLarge


# **Teacher Ensemble Selection**

In [None]:
print("ITERATE OVER ALL POSSIBLE ABLATIONS\n")

ablation = {'VAL':{}, 'TEST':{}}
max_i = 2**len(MODEL_FILES)

for i in range(1, max_i):
    to_ensemble = []
    for j in range(len(MODEL_FILES)):
        if i & (1 << j):
            to_ensemble.append(MODEL_FILES[j])

    if len(to_ensemble) <= 1: continue
    en.TO_ENSEMBLE = to_ensemble
    teacher_ensemble = get_ensemble(models)
    print("Ablation :", to_ensemble)

    #Validation
    val_datagen, nb_val_samples = get_generator(VAL_DATA_PATH, 'sparse')
    y_true, y_prob, y_soft_prob, y_pred = get_prediction(teacher_ensemble, val_datagen, nb_val_samples)
    score = accuracy_score(y_true, y_pred)*100
    ablation['VAL']['teacher_ensemble' + str(i)] = score

    #Testing
    test_datagen, nb_test_samples = get_generator(TEST_DATA_PATH, 'sparse')
    y_true, y_prob, y_soft_prob, y_pred = get_prediction(teacher_ensemble, test_datagen, nb_test_samples)
    score = accuracy_score(y_true, y_pred)*100
    ablation['TEST']['teacher_ensemble' + str(i)] = score

    print('TEST ACCURACY: {0:.4f}'.format(ablation['VAL']['teacher_ensemble' + str(i)]))
    print('VAL ACCURACY: {0:.4f}\n'.format(ablation['TEST']['teacher_ensemble' + str(i)]))

Ablation : ['DenseNet201', 'Xception']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 90.5360
VAL ACCURACY: 89.6985
Ablation : ['DenseNet201', 'InceptionResNetV2']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 89.9497
VAL ACCURACY: 90.9548
Ablation : ['Xception', 'InceptionResNetV2']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 90.1173
VAL ACCURACY: 89.9497
Ablation : ['DenseNet201', 'Xception', 'InceptionResNetV2']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 91.4573
VAL ACCURACY: 91.0804
Ablation : ['DenseNet201', 'ResNet152V2']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 89.1122
VAL ACCURACY: 87.6884
Ablation : ['Xception', 'ResNet152V2']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 89.0285
VAL ACCURACY: 86.8090
Ablation : ['DenseNet201', 'Xception', 'ResNet152V2']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 90.8710
VAL ACCURACY: 88.6935
Ablation : ['InceptionResNetV2', 'ResNet152V2']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 89.9497
VAL ACCURACY: 88.4422
Ablation : ['DenseNet201', 'InceptionResNetV2', 'ResNet152V2']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 90.9548
VAL ACCURACY: 90.0754
Ablation : ['Xception', 'InceptionResNetV2', 'ResNet152V2']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 91.0385
VAL ACCURACY: 89.0704
Ablation : ['DenseNet201', 'Xception', 'InceptionResNetV2', 'ResNet152V2']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 91.7085
VAL ACCURACY: 90.9548
Ablation : ['DenseNet201', 'EfficientNetB7']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 93.4673
VAL ACCURACY: 92.4623
Ablation : ['Xception', 'EfficientNetB7']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 92.0436
VAL ACCURACY: 90.7035
Ablation : ['DenseNet201', 'Xception', 'EfficientNetB7']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 92.8811
VAL ACCURACY: 91.8342
Ablation : ['InceptionResNetV2', 'EfficientNetB7']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 92.2111
VAL ACCURACY: 91.3317
Ablation : ['DenseNet201', 'InceptionResNetV2', 'EfficientNetB7']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 92.5461
VAL ACCURACY: 92.7136
Ablation : ['Xception', 'InceptionResNetV2', 'EfficientNetB7']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 92.9648
VAL ACCURACY: 91.9598
Ablation : ['DenseNet201', 'Xception', 'InceptionResNetV2', 'EfficientNetB7']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 92.7973
VAL ACCURACY: 92.4623
Ablation : ['ResNet152V2', 'EfficientNetB7']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 91.1223
VAL ACCURACY: 90.2010
Ablation : ['DenseNet201', 'ResNet152V2', 'EfficientNetB7']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 92.3786
VAL ACCURACY: 90.8291
Ablation : ['Xception', 'ResNet152V2', 'EfficientNetB7']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 91.7085
VAL ACCURACY: 90.2010
Ablation : ['DenseNet201', 'Xception', 'ResNet152V2', 'EfficientNetB7']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 92.1273
VAL ACCURACY: 91.0804
Ablation : ['InceptionResNetV2', 'ResNet152V2', 'EfficientNetB7']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 91.8760
VAL ACCURACY: 91.4573
Ablation : ['DenseNet201', 'InceptionResNetV2', 'ResNet152V2', 'EfficientNetB7']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 92.2948
VAL ACCURACY: 92.4623
Ablation : ['Xception', 'InceptionResNetV2', 'ResNet152V2', 'EfficientNetB7']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 91.9598
VAL ACCURACY: 91.7085
Ablation : ['DenseNet201', 'Xception', 'InceptionResNetV2', 'ResNet152V2', 'EfficientNetB7']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 92.2111
VAL ACCURACY: 92.3367
Ablation : ['DenseNet201', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 90.0335
VAL ACCURACY: 91.3317
Ablation : ['Xception', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 89.7822
VAL ACCURACY: 90.0754
Ablation : ['DenseNet201', 'Xception', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 92.0436
VAL ACCURACY: 92.3367
Ablation : ['InceptionResNetV2', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 89.8660
VAL ACCURACY: 90.8291
Ablation : ['DenseNet201', 'InceptionResNetV2', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 90.6198
VAL ACCURACY: 92.7136
Ablation : ['Xception', 'InceptionResNetV2', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 91.2060
VAL ACCURACY: 90.9548
Ablation : ['DenseNet201', 'Xception', 'InceptionResNetV2', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 91.7085
VAL ACCURACY: 92.2111
Ablation : ['ResNet152V2', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 89.1122
VAL ACCURACY: 89.5729
Ablation : ['DenseNet201', 'ResNet152V2', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 90.7035
VAL ACCURACY: 91.2060
Ablation : ['Xception', 'ResNet152V2', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 90.7873
VAL ACCURACY: 89.9497
Ablation : ['DenseNet201', 'Xception', 'ResNet152V2', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 91.7923
VAL ACCURACY: 91.0804
Ablation : ['InceptionResNetV2', 'ResNet152V2', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 90.7873
VAL ACCURACY: 90.7035
Ablation : ['DenseNet201', 'InceptionResNetV2', 'ResNet152V2', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 91.2898
VAL ACCURACY: 91.4573
Ablation : ['Xception', 'InceptionResNetV2', 'ResNet152V2', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 91.3735
VAL ACCURACY: 90.9548
Ablation : ['DenseNet201', 'Xception', 'InceptionResNetV2', 'ResNet152V2', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 91.7923
VAL ACCURACY: 91.9598
Ablation : ['EfficientNetB7', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 91.4573
VAL ACCURACY: 92.5879
Ablation : ['DenseNet201', 'EfficientNetB7', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 92.6298
VAL ACCURACY: 93.7186
Ablation : ['Xception', 'EfficientNetB7', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 92.4623
VAL ACCURACY: 92.7136
Ablation : ['DenseNet201', 'Xception', 'EfficientNetB7', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 93.0486
VAL ACCURACY: 92.9648
Ablation : ['InceptionResNetV2', 'EfficientNetB7', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 91.7923
VAL ACCURACY: 92.7136
Ablation : ['DenseNet201', 'InceptionResNetV2', 'EfficientNetB7', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 92.3786
VAL ACCURACY: 93.3417
Ablation : ['Xception', 'InceptionResNetV2', 'EfficientNetB7', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 92.2948
VAL ACCURACY: 92.7136
Ablation : ['DenseNet201', 'Xception', 'InceptionResNetV2', 'EfficientNetB7', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 92.7973
VAL ACCURACY: 92.8392
Ablation : ['ResNet152V2', 'EfficientNetB7', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 91.7923
VAL ACCURACY: 92.3367
Ablation : ['DenseNet201', 'ResNet152V2', 'EfficientNetB7', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 92.4623
VAL ACCURACY: 93.0905
Ablation : ['Xception', 'ResNet152V2', 'EfficientNetB7', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 92.2111
VAL ACCURACY: 92.5879
Ablation : ['DenseNet201', 'Xception', 'ResNet152V2', 'EfficientNetB7', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 92.8811
VAL ACCURACY: 92.4623
Ablation : ['InceptionResNetV2', 'ResNet152V2', 'EfficientNetB7', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 92.0436
VAL ACCURACY: 92.7136
Ablation : ['DenseNet201', 'InceptionResNetV2', 'ResNet152V2', 'EfficientNetB7', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 92.1273
VAL ACCURACY: 92.8392
Ablation : ['Xception', 'InceptionResNetV2', 'ResNet152V2', 'EfficientNetB7', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 92.1273
VAL ACCURACY: 92.3367
Ablation : ['DenseNet201', 'Xception', 'InceptionResNetV2', 'ResNet152V2', 'EfficientNetB7', 'NASNetLarge']
Found 1194 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=299.0), HTML(value='')))


Found 796 images belonging to 199 classes.


HBox(children=(FloatProgress(value=0.0, max=199.0), HTML(value='')))


TEST ACCURACY: 92.5461
VAL ACCURACY: 92.8392


In [None]:
#Show results of 5 Best ablations
print("TOP 6 RESULTS FROM ALL POSSIBLE ABLATIONS")
df = pd.DataFrame(ablation)
df.sort_values(['VAL', 'TEST'], ascending=False).head()

Unnamed: 0,VAL,TEST
teacher_ensemble17,93.467337,92.462312
teacher_ensemble51,93.048576,92.964824
teacher_ensemble22,92.964824,91.959799
teacher_ensemble59,92.881072,92.462312
teacher_ensemble19,92.881072,91.834171


In [9]:
#Show Best Ensemble (Please Select best_ablation value from above table)
best_ablation = 17
best_ensemble = []
for j in range(len(MODEL_FILES)):
    if 17 & (1 << j):
        best_ensemble.append(MODEL_FILES[j])

print("BEST ENSEMBLE", best_ensemble, '\n')
en.TO_ENSEMBLE = best_ensemble
teacher_ensemble = get_ensemble(models)
teacher_ensemble.summary()

BEST ENSEMBLE ['DenseNet201', 'EfficientNetB7'] 

Model: "model_11"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
ensemble_input (InputLayer)     [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
preprocess_DenseNet201 (D201_Pr (None, 256, 256, 3)  0           ensemble_input[0][0]             
__________________________________________________________________________________________________
preprocess_EfficientNetB7 (ENB7 (None, 256, 256, 3)  0           ensemble_input[0][0]             
__________________________________________________________________________________________________
densenet201_DenseNet201 (Functi (None, 8, 8, 1920)   18321984    preprocess_DenseNet201[0][0]     
_________________________________________

In [None]:
print("SAVE MODEL")
teacher_ensemble.save(TEACHER_MODEL_PATH + 'EnsembleModel/' + 'model.h5')

SAVE MODEL


