Inspiré de ce [kernel](https://www.kaggle.com/atamazian/fc-ensemble-external-data-effnet-densenet/) et de [celui-ci](https://www.kaggle.com/chankhavu/a-beginner-s-tpu-kernel-single-model-0-97)

Importations des librairies

In [None]:
!pip install --quiet tensorflow-addons --upgrade
!pip install -q efficientnet

import tensorflow as tf
from kaggle_datasets import KaggleDatasets
import numpy as np
import re
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense,\
                                    BatchNormalization, Dropout, GlobalAveragePooling2D

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import he_uniform
import matplotlib.pyplot as plt
import tensorflow_addons as tfa
import efficientnet.tfkeras as efficientnet
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix


Pour se connecter au TPU de Kaggle

In [None]:
# Detect a TPU and build the matching tf.distribute strategy. `strategy` is
# used further down to size the global batch and to scope model creation.
try:
    # TPU detection. No parameters necessary if TPU_NAME environment variable is set. On Kaggle this is always the case.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  
    print('Running on TPU ', tpu.master())
except ValueError:
    # A ValueError here is treated as "no TPU available".
    tpu = None

if tpu:
    # Connect to the TPU cluster and initialise it before creating the strategy.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    strategy = tf.distribute.get_strategy() # default distribution strategy in Tensorflow. Works on CPU and single GPU.

print("REPLICAS: ", strategy.num_replicas_in_sync)

Récupération des jeux de données, ainsi que choix de la taille des images.

Des jeux de données provenant d'autres compétitions sont utilisés pour améliorer nos modèles. Étant donné qu'ils sont structurés exactement de la même manière que les données d'origine, il n'est pas nécessaire d'effectuer des traitements supplémentaires.

In [None]:
# Target image size; it also selects which pre-resized TFRecord folder is read.
IMAGE_SIZE = [224, 224]
# NOTE(review): both sides of "<n>x<n>" come from IMAGE_SIZE[0]; presumably only
# square folders exist in the datasets -- confirm before using a non-square size.
img_size_path = f'/tfrecords-jpeg-{IMAGE_SIZE[0]}x{IMAGE_SIZE[0]}'

# Competition data: train / val / test TFRecord shards.
or_path = KaggleDatasets().get_gcs_path('tpu-getting-started') + img_size_path
TRAINING_FILENAMES = tf.io.gfile.glob(f'{or_path}/train/*.tfrec')
VALIDATION_FILENAMES = tf.io.gfile.glob(f'{or_path}/val/*.tfrec')
TEST_FILENAMES = tf.io.gfile.glob(f'{or_path}/test/*.tfrec')

# External data: one sub-folder per source dataset, same folder naming scheme.
ext_gcs = KaggleDatasets().get_gcs_path('tf-flower-photo-tfrec')
imagenet_files = tf.io.gfile.glob(f'{ext_gcs}/imagenet{img_size_path}/*.tfrec')
inaturelist_files = tf.io.gfile.glob(f'{ext_gcs}/inaturalist{img_size_path}/*.tfrec')
openimage_files = tf.io.gfile.glob(f'{ext_gcs}/openimage{img_size_path}/*.tfrec')
oxford_files = tf.io.gfile.glob(f'{ext_gcs}/oxford_102{img_size_path}/*.tfrec')
tensorflow_files = tf.io.gfile.glob(f'{ext_gcs}/tf_flowers{img_size_path}/*.tfrec')

# The extra shards are kept in their own list; merging them into the training
# list is left disabled, as in the commented-out line below.
EXTRA_FILES = [
    *imagenet_files,
    *inaturelist_files,
    *openimage_files,
    *oxford_files,
    *tensorflow_files,
]
#TRAINING_FILENAMES = TRAINING_FILENAMES + EXTRA_FILES

Déterminer les types des variables pour les TPU

In [None]:
# Read the first serialized record of the training shards and print the start
# of its protobuf text form, to see which features a record contains.
raw_dataset = tf.data.TFRecordDataset(TRAINING_FILENAMES)
serialized_example = next(iter(raw_dataset))
example = tf.train.Example()
example.ParseFromString(serialized_example.numpy())
# Only the first 300 characters are shown; the rest is elided.
print(str(example)[:300] + ' ...')

Les fonctions ci-dessous permettent d'extraire les images et de les mettre au format tenseur pour que les TPU puissent les prendre en compte.

In [None]:
def decode_image(image_data):
    """Decode JPEG bytes into a float32 image tensor of shape [*IMAGE_SIZE, 3].

    The decoded pixel values are divided by 255 and the result is reshaped to
    the configured image size with three channels.
    """
    pixels = tf.image.decode_jpeg(image_data, channels=3)
    scaled = tf.cast(pixels, tf.float32) / 255.0
    return tf.reshape(scaled, [*IMAGE_SIZE, 3])

def read_labeled_tfrecord(example):
    """Parse one serialized labeled record into an (image, label) pair.

    The record must carry an "image" bytestring and an int64 "class". The
    image goes through decode_image and the class is cast to int32.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),  # JPEG bytes
        "class": tf.io.FixedLenFeature([], tf.int64),   # scalar class id
    }
    features = tf.io.parse_single_example(example, feature_spec)
    decoded = decode_image(features['image'])
    class_id = tf.cast(features['class'], tf.int32)
    return decoded, class_id

def read_unlabeled_tfrecord(example):
    """Parse one serialized unlabeled record into an (image, id) pair.

    The record must carry an "image" bytestring and a string "id". There is no
    "class" feature: predicting it is the purpose of the test set.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),  # JPEG bytes
        "id": tf.io.FixedLenFeature([], tf.string),     # scalar identifier
    }
    features = tf.io.parse_single_example(example, feature_spec)
    return decode_image(features['image']), features['id']

Les 104 classes présentes dans les datasets

In [None]:
# Human-readable class names; the position in the list is the integer class id
# used to index it when titling plots and sizing the model output layer.
# NOTE(review): 'cyclamen ' and 'hippeastrum ' carry a trailing space; they are
# left untouched here in case other code matches on the exact strings.
CLASSES = [
    'pink primrose', 'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea', 
    'wild geranium', 'tiger lily', 'moon orchid', 'bird of paradise', 'monkshood', 
    'globe thistle', 'snapdragon', "colt's foot", 'king protea', 'spear thistle', 
    'yellow iris', 'globe-flower', 'purple coneflower', 'peruvian lily', 
    'balloon flower', 'giant white arum lily', 'fire lily', 'pincushion flower', 
    'fritillary', 'red ginger', 'grape hyacinth', 'corn poppy', 
    'prince of wales feathers', 'stemless gentian', 'artichoke', 'sweet william', 
    'carnation', 'garden phlox', 'love in the mist', 'cosmos',  'alpine sea holly', 
    'ruby-lipped cattleya', 'cape flower', 'great masterwort',  'siam tulip', 
    'lenten rose', 'barberton daisy', 'daffodil',  'sword lily', 'poinsettia', 
    'bolero deep blue',  'wallflower', 'marigold', 'buttercup', 'daisy', 
    'common dandelion', 'petunia', 'wild pansy', 'primula',  'sunflower', 
    'lilac hibiscus', 'bishop of llandaff', 'gaura',  'geranium', 'orange dahlia', 
    'pink-yellow dahlia', 'cautleya spicata',  'japanese anemone', 
    'black-eyed susan', 'silverbush', 'californian poppy',  'osteospermum', 
    'spring crocus', 'iris', 'windflower',  'tree poppy', 'gazania', 'azalea', 
    'water lily',  'rose', 'thorn apple', 'morning glory', 'passion flower',  
    'lotus', 'toad lily', 'anthurium', 'frangipani',  'clematis', 'hibiscus', 
    'columbine', 'desert-rose', 'tree mallow', 'magnolia', 'cyclamen ', 
    'watercress',  'canna lily', 'hippeastrum ', 'bee balm', 'pink quill',  
    'foxglove', 'bougainvillea', 'camellia', 'mallow',  'mexican petunia',  
    'bromelia', 'blanket flower', 'trumpet creeper',  'blackberry lily', 
    'common tulip', 'wild rose']

Pour éviter l'overfitting et améliorer nos modèles, nous employons la data augmentation.
Nous allons donc jouer sur :
- la luminosité
- le contraste
- l'effet miroir
- la rotation
- le zoom
- les cutouts : instancier aléatoirement des blocs noirs

In [None]:
# Height and width of the black block erased by the random cutout.
SIZE_CUTOUT = (50,50)

def aug_img(image, label):
    """Randomly augment a single image; the label is returned unchanged.

    Applies, in order: random brightness, random contrast, a random rotation
    in [-pi/4, pi/4], a central zoom, a random horizontal flip and one random
    cutout block of size SIZE_CUTOUT.

    The rotation angle and zoom amount are sampled with ``tf.random`` rather
    than NumPy. When this function is passed to ``Dataset.map`` it is traced
    once into a graph, so a NumPy draw would be frozen at trace time and every
    image would get the same rotation and zoom. TensorFlow random ops are
    re-evaluated for every element.

    Args:
        image: rank-3 image tensor (height, width, channels). Assumed to be
            the float image produced by decode_image -- TODO confirm for
            other callers.
        label: any value; passed through untouched.

    Returns:
        A tuple (augmented image resized to IMAGE_SIZE, label).
    """
    img = tf.image.random_brightness(image, 0.1)
    img = tf.image.random_contrast(img, 0.8, 2.2)

    rand_rad = tf.random.uniform([], -np.pi / 4, np.pi / 4)
    img = tfa.image.rotate(img, rand_rad)

    # Zoom by keeping only a central fraction of the image. The fraction
    # shrinks with the rotation angle, hiding more of the empty corners that
    # rotation creates. Worst case is 1 - ((pi/4)**2 + 0.2) ~= 0.18, so the
    # crop is never empty.
    rand_zoom = tf.random.uniform([], 0.0, 0.2)
    keep_fraction = 1.0 - (rand_rad ** 2 + rand_zoom)

    # tf.image.central_crop needs a Python float, so the same crop is done
    # with a slice whose bounds are tensors.
    dims = tf.shape(img)
    height, width = dims[0], dims[1]
    offset_h = tf.cast(tf.cast(height, tf.float32) * (1.0 - keep_fraction) / 2.0, tf.int32)
    offset_w = tf.cast(tf.cast(width, tf.float32) * (1.0 - keep_fraction) / 2.0, tf.int32)
    img = img[offset_h:height - offset_h, offset_w:width - offset_w, :]
    img = tf.image.resize(img, IMAGE_SIZE)

    img = tf.image.random_flip_left_right(img)

    # random_cutout works on batches: add a batch axis, then remove exactly
    # that axis again.
    img = tf.expand_dims(img, 0)
    img = tfa.image.random_cutout(img, SIZE_CUTOUT)
    img = tf.squeeze(img, axis=0)

    return img, label

Les fonctions ci-dessous permettent de créer des batchs pour chaque dataset pour l'entraînement et les prédictions des modèles.

In [None]:
def load_dataset(filenames, labeled=True, ordered=False):
    """Build an unbatched dataset of parsed records from TFRecord files.

    Args:
        filenames: TFRecord file paths to read.
        labeled: parse records as (image, label) when True, otherwise as
            (image, id).
        ordered: value given to the dataset's deterministic-order option.

    Returns:
        A tf.data dataset yielding the parsed pairs.
    """
    autotune = tf.data.experimental.AUTOTUNE
    parse_fn = read_labeled_tfrecord if labeled else read_unlabeled_tfrecord

    read_options = tf.data.Options()
    read_options.experimental_deterministic = ordered

    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=autotune)
    records = records.with_options(read_options)
    return records.map(parse_fn, num_parallel_calls=autotune)

def get_training_dataset(file, batch_size, augmentation=False):
    """Build the infinite, shuffled and batched training pipeline.

    Args:
        file: TFRecord file paths to read (labeled records).
        batch_size: number of examples per batch.
        augmentation: when True, every example goes through aug_img.

    Returns:
        A repeating tf.data dataset of (images, labels) batches.
    """
    dataset = load_dataset(file, labeled=True)

    if augmentation:
        dataset = dataset.map(aug_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # Shuffle BEFORE repeating so each pass over the files is emitted in full
    # before the next one starts. With repeat() first, the shuffle buffer
    # mixes consecutive passes, so a batch sized to the whole dataset could
    # contain duplicates and miss other images.
    dataset = dataset.shuffle(2024)
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    return dataset

def get_validation_dataset(batch_size, ordered=False):
    """Return the validation set as cached, prefetched (images, labels) batches.

    Args:
        batch_size: number of examples per batch.
        ordered: forwarded to load_dataset to request deterministic order.
    """
    records = load_dataset(VALIDATION_FILENAMES, labeled=True, ordered=ordered)
    # Batches are cached after batching, then prefetched.
    return (
        records
        .batch(batch_size)
        .cache()
        .prefetch(tf.data.experimental.AUTOTUNE)
    )

def get_test_dataset(batch_size):
    """Return the test set as ordered, prefetched (images, ids) batches.

    Args:
        batch_size: number of examples per batch.
    """
    records = load_dataset(TEST_FILENAMES, labeled=False, ordered=True)
    return records.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)

def count_data_items(filenames):
    """Sum the item counts embedded in the given file names.

    Each name must contain "-<digits>." (for example "...-798.tfrec"); the
    digits of the first such match are read as that file's item count.
    """
    count_pattern = re.compile(r"-([0-9]*)\.")
    per_file_counts = []
    for path in filenames:
        per_file_counts.append(int(count_pattern.search(path).group(1)))
    return np.sum(per_file_counts)
    
# Number of images per split, read from the counts embedded in the file names.
(
    ORIGINAL_NUM_TRAINING_IMAGES,
    NUM_VALIDATION_IMAGES,
    EXTRA_NUM_TRAINING_IMAGES,
    NUM_TEST_IMAGES,
) = (
    count_data_items(split_files)
    for split_files in (TRAINING_FILENAMES, VALIDATION_FILENAMES, EXTRA_FILES, TEST_FILENAMES)
)

# Global batch size: 16 examples per replica.
BATCH_SIZE = 16 * strategy.num_replicas_in_sync
# Steps needed to see each training set roughly once.
STEPS_PER_EPOCH = ORIGINAL_NUM_TRAINING_IMAGES // BATCH_SIZE
STEPS_PER_EPOCH_EXTRA = EXTRA_NUM_TRAINING_IMAGES // BATCH_SIZE


print(
    f'Original Dataset:\n\n{ORIGINAL_NUM_TRAINING_IMAGES} training images\n'
    f'{NUM_VALIDATION_IMAGES} validation images\n'
    f'{NUM_TEST_IMAGES} unlabeled test images'
)
print(EXTRA_NUM_TRAINING_IMAGES, "extra images")

Nous voyons que les jeux de données supplémentaires sont en quantité non négligeable et donc potentiellement utile pour les modèles.

Jetons un coup d'oeil à nos données.

In [None]:
# Show one batch of 12 training images in a 3x4 grid, titled with class names.
plt.figure(figsize=(10,7))

train_dataset = get_training_dataset(TRAINING_FILENAMES, batch_size=12)
train_iter = iter(train_dataset)
batch = next(train_iter)

sample_images, sample_labels = batch[0], batch[1].numpy()
for position, picture in enumerate(sample_images, start=1):
    plt.subplot(3, 4, position)
    plt.imshow(picture)
    plt.title(CLASSES[sample_labels[position - 1]])
    plt.axis("off")

Regardons l'utilisation de la data augmentation sur la deuxième image

In [None]:
# Second image of the batch displayed in the previous cell.
image = batch[0][1]

plt.figure(figsize=(7,5))

# Draw 12 independently augmented versions of that image in a 3x4 grid.
for i in range(12):
    ax = plt.subplot(3, 4, i + 1)
    # The label is irrelevant here; aug_img passes it through untouched.
    img, _ = aug_img(image, "")
    
    plt.imshow(img)
    plt.axis("off")

Les images générées sont bien toutes différentes les unes des autres, avec des niveaux de contraste ou de luminosité distincts.

Nous allons récupérer toutes les images pour analyser les données.

In [None]:
# Build each dataset with a batch size equal to its image count, then take the
# first batch so that all labels of a split are available as one array for
# the distribution plots.
training_dataset = get_training_dataset(TRAINING_FILENAMES, ORIGINAL_NUM_TRAINING_IMAGES)
validation_dataset = get_validation_dataset(NUM_VALIDATION_IMAGES)
extra_dataset = get_training_dataset(EXTRA_FILES, EXTRA_NUM_TRAINING_IMAGES)
# Each element is an (images, labels) pair.
train_batch = next(iter(training_dataset))
val_batch = next(iter(validation_dataset))
extra_batch = next(iter(extra_dataset))

Affichons la distribution des labels pour le jeu de données de validation et celui d'entraînement.

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
# Overlay the label histograms of the training and validation batches, one bin
# per class, drawn horizontally so class names can be used as y tick labels.
plt.figure(figsize=(12,12))
# NOTE(review): sns.distplot is deprecated in recent seaborn releases -- confirm
# the installed version before re-running.
ax = sns.distplot(train_batch[1].numpy(), bins=len(CLASSES), kde=False, label='Train', vertical=True)
ax = sns.distplot(val_batch[1].numpy(), bins=len(CLASSES), kde=False, label='Validation', vertical=True)

# One tick per class id, labelled with the class name.
ax.set_yticks(np.arange(len(CLASSES)))
ax.set_yticklabels(CLASSES, fontsize=7)

plt.legend(prop={'size': 12})
plt.title('Distribution of labels')
plt.xlabel('Occurrences')
plt.ylabel('Labels')
plt.show()

Nous pouvons noter la présence d'un déséquilibre de classes. Cependant, les distributions des deux datasets sont identiques.

Comparons celle d'entraînement à celle des jeux de données supplémentaires.

In [None]:
# Same histogram overlay as for train/validation, this time comparing the
# training batch with the batch of external images.
plt.figure(figsize=(12,12))
ax = sns.distplot(train_batch[1].numpy(), bins=len(CLASSES), kde=False, label='Train', vertical=True)
ax = sns.distplot(extra_batch[1].numpy(), bins=len(CLASSES), kde=False, label='Extra', vertical=True)

# One tick per class id, labelled with the class name.
ax.set_yticks(np.arange(len(CLASSES)))
ax.set_yticklabels(CLASSES, fontsize=7)

plt.legend(prop={'size': 12})
plt.title('Distribution of labels')
plt.xlabel('Occurrences')
plt.ylabel('Labels')
plt.show()

Nous observons que la distribution des images supplémentaires n'est pas exactement pareille à celle d'entraînement.

En effectuant des tests au préalable, il semblerait que les images tests (de soumission) aient la même distribution que celle d'entraînement.

Par conséquent, nous allons utiliser les données supplémentaires en corrigeant le déséquilibre de classes à l'aide de coefficients de pondération.

In [None]:
from sklearn.utils import class_weight

# Balanced class weights computed from the labels of the external images.
extra_labels = extra_batch[1].numpy()
present_classes = np.unique(extra_labels)

# Keyword arguments: recent scikit-learn releases reject the positional form.
class_weights = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=present_classes,
    y=extra_labels,
)

# Map each weight to its real class id instead of its position in
# `present_classes`; the two differ as soon as one class id is absent from the
# extra data. Absent ids get a neutral weight of 1.0 so the dict always covers
# ids 0..len(CLASSES)-1.
dict_weights = {class_id: 1.0 for class_id in range(len(CLASSES))}
dict_weights.update(
    {int(class_id): float(weight) for class_id, weight in zip(present_classes, class_weights)}
)
dict_weights

Une technique utilisée est le learning rate schedule, où le taux d'apprentissage augmente lentement en fonction du numéro d'epoch courant pour ensuite effectuer une descente exponentielle.

In [None]:
# Learning rate schedule for TPU, GPU and CPU.
# Using an LR ramp up because fine-tuning a pre-trained model.
# Starting with a high LR would break the pre-trained weights.

# Learning rate at epoch 0, where the linear ramp starts.
LR_START = 0.00001
# Peak learning rate, scaled by the number of replicas.
LR_MAX = 0.00005 * strategy.num_replicas_in_sync
# Floor that the exponential decay approaches.
LR_MIN = 0.00001
# Number of epochs spent ramping linearly from LR_START towards LR_MAX.
LR_RAMPUP_EPOCHS = 5
# Number of epochs held at LR_MAX after the ramp.
LR_SUSTAIN_EPOCHS = 0
# Per-epoch multiplicative decay applied to (LR_MAX - LR_MIN).
LR_EXP_DECAY = .8

def lrfn(epoch):
    """Return the learning rate for the given epoch index.

    Three phases: a linear ramp from LR_START towards LR_MAX over
    LR_RAMPUP_EPOCHS epochs, a plateau at LR_MAX for LR_SUSTAIN_EPOCHS epochs,
    then an exponential decay towards LR_MIN with factor LR_EXP_DECAY.
    """
    # Phase 1: linear ramp-up.
    if epoch < LR_RAMPUP_EPOCHS:
        return (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS * epoch + LR_START

    # Phase 2: plateau.
    if epoch < LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS:
        return LR_MAX

    # Phase 3: exponential decay.
    return (LR_MAX - LR_MIN) * LR_EXP_DECAY**(epoch - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS) + LR_MIN
    
# Keras callback that sets the learning rate from lrfn for each epoch.
lr_callback = tf.keras.callbacks.LearningRateScheduler(lrfn, verbose = True)

# Plot the schedule over 50 epochs and print its first, peak and last values.
rng = list(range(50))
y = list(map(lrfn, rng))
plt.plot(rng, y)
print("Learning rate schedule: {:.3g} to {:.3g} to {:.3g}".format(y[0], max(y), y[-1]))

Comme nous utilisons la méthode ensembliste en prenant en compte les prédictions de deux modèles, nous en choisissons deux ayant généralement de bons résultats : DenseNet201 et EfficientNetB7.

L'entraînement de chaque modèle s'effectue en 2 étapes :
- Apprendre sur les jeux de données supplémentaires ayant une correction d'imbalance de classes
- Apprendre sur le jeu de données d'apprentissage

In [None]:
# Augmented training pipelines (competition data and external data) and the
# validation pipeline, all at the global batch size.
training_dataset = get_training_dataset(TRAINING_FILENAMES, BATCH_SIZE, augmentation=True)
validation_dataset = get_validation_dataset(BATCH_SIZE)
training_dataset_extra = get_training_dataset(EXTRA_FILES, BATCH_SIZE, augmentation=True)

with strategy.scope():

    # ImageNet-pretrained DenseNet201 backbone, fully fine-tuned.
    base_model = tf.keras.applications.DenseNet201(weights='imagenet', include_top = False, input_shape=[*IMAGE_SIZE, 3])
    base_model.trainable = True

    model = tf.keras.Sequential([
        base_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(len(CLASSES), activation='softmax')
    ])

    # Compile inside the strategy scope so the optimizer and metric state are
    # created under the same distribution strategy as the model variables.
    model.compile(
        optimizer='adam',
        loss = 'sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy']
    )


# Stage 1: a few epochs on the external data, re-weighted per class.
history = model.fit(training_dataset_extra, 
          steps_per_epoch=STEPS_PER_EPOCH_EXTRA, 
          epochs=3, 
          validation_data=validation_dataset,
          class_weight=dict_weights,                    
          callbacks=[lr_callback]
)


# Stage 2: training on the competition data with early stopping.
# restore_best_weights=True rolls the model back to the epoch with the lowest
# validation loss; without it the model keeps the weights reached `patience`
# epochs after the best one.
history = model.fit(training_dataset, 
          steps_per_epoch=STEPS_PER_EPOCH, 
          epochs=50, 
          validation_data=validation_dataset,                    
          callbacks=[
              tf.keras.callbacks.EarlyStopping('val_loss', patience=5, restore_best_weights=True),
              lr_callback,
          ]
)

In [None]:
with strategy.scope():

    # EfficientNetB7 backbone with 'noisy-student' weights, fully fine-tuned.
    base_model = efficientnet.EfficientNetB7(weights='noisy-student', include_top = False, input_shape=[*IMAGE_SIZE, 3])
    base_model.trainable = True

    model2 = tf.keras.Sequential([
        base_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(len(CLASSES), activation='softmax')
    ])

    # Compile inside the strategy scope so the optimizer and metric state are
    # created under the same distribution strategy as the model variables.
    model2.compile(
        optimizer='adam',
        loss = 'sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy']
    )


# Stage 1: a few epochs on the external data, re-weighted per class.
history = model2.fit(training_dataset_extra, 
          steps_per_epoch=STEPS_PER_EPOCH_EXTRA, 
          epochs=3, 
          validation_data=validation_dataset,
          class_weight=dict_weights,                    
          callbacks=[lr_callback]
)

# Stage 2: training on the competition data with early stopping.
# restore_best_weights=True rolls the model back to the epoch with the lowest
# validation loss; without it the model keeps the weights reached `patience`
# epochs after the best one.
history = model2.fit(training_dataset, 
          steps_per_epoch=STEPS_PER_EPOCH, 
          epochs=50, 
          validation_data=validation_dataset,                    
          callbacks=[
              tf.keras.callbacks.EarlyStopping('val_loss', patience=5, restore_best_weights=True),
              lr_callback,
          ]
)

Pour utiliser la méthode ensembliste, il faut déterminer un coefficient permettant de déterminer la meilleure balance entre les deux modèles.

Pour cela, nous allons utiliser le dataset de validation pour comparer les résultats.

In [None]:
cmdataset = get_validation_dataset(BATCH_SIZE, ordered=True) # since we are splitting the dataset and iterating separately on images and labels, order matters.
images_ds = cmdataset.map(lambda image, label: image)
labels_ds = cmdataset.map(lambda image, label: label).unbatch()
cm_correct_labels = next(iter(labels_ds.batch(NUM_VALIDATION_IMAGES))).numpy() # get everything as one batch
# Class probabilities of each model on the validation images.
m1 = model.predict(images_ds)
m2 = model2.predict(images_ds)

cm_predictions = np.argmax(m1, axis=-1)
print("M1:", f1_score(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)), average='macro'))

cm_predictions = np.argmax(m2, axis=-1)
print("M2:", f1_score(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)), average='macro'))

# Grid search of the blending weight: alpha weights model 1 and (1 - alpha)
# weights model 2; the score is the macro F1 on the validation set.
alphas = np.linspace(0, 1, 100)
scores = []
for alpha in alphas:
    cm_probabilities = alpha*m1+(1-alpha)*m2
    cm_predictions = np.argmax(cm_probabilities, axis=-1)
    scores.append(f1_score(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)), average='macro'))

# Look the winner up in the grid itself: 100 points over [0, 1] are spaced by
# 1/99, so "index / 100" would not be the alpha that was actually scored.
best_alpha = alphas[np.argmax(scores)]
print('Best alpha:', str(best_alpha))

Cet alpha sera donc conservé pour les prédictions.

Regardons la matrice de confusion.

In [None]:
def display_confusion_matrix(cmat, score, precision, recall):
    """Draw a confusion matrix with class-name ticks and an optional summary.

    Args:
        cmat: confusion matrix to display.
        score: F1 score to print, or None to omit it.
        precision: precision to print, or None to omit it.
        recall: recall to print, or None to omit it.
    """
    plt.figure(figsize=(15,15))
    axes = plt.gca()
    axes.matshow(cmat, cmap='Reds')

    # One tick per class on both axes, labelled with the class name at 45 degrees.
    axes.set_xticks(range(len(CLASSES)))
    axes.set_xticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(axes.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")
    axes.set_yticks(range(len(CLASSES)))
    axes.set_yticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(axes.get_yticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Assemble the summary text from whichever metrics were provided.
    summary_parts = []
    if score is not None:
        summary_parts.append('f1 = {:.3f} '.format(score))
    if precision is not None:
        summary_parts.append('\nprecision = {:.3f} '.format(precision))
    if recall is not None:
        summary_parts.append('\nrecall = {:.3f} '.format(recall))
    titlestring = "".join(summary_parts)

    if titlestring:
        axes.text(101, 1, titlestring, fontdict={'fontsize': 18, 'horizontalalignment':'right', 'verticalalignment':'top', 'color':'#804040'})
    plt.show()

# Blend both models with the selected alpha and evaluate on the validation set.
cm_probabilities = best_alpha*m1 + (1-best_alpha)*m2
cm_predictions = np.argmax(cm_probabilities, axis=-1)
cmat = confusion_matrix(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)))
# Macro-averaged F1, precision and recall over all class ids.
score, precision, recall = (
    metric(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)), average='macro')
    for metric in (f1_score, precision_score, recall_score)
)
#cmat = (cmat.T / cmat.sum(axis=1)).T # normalized
display_confusion_matrix(cmat, score, precision, recall)
print(f'f1 score: {score:.3f}, precision: {precision:.3f}, recall: {recall:.3f}')


Les classes ayant le plus d'instances sont celles qui ont obtenu les meilleures prédictions.

Test Time Augmentation est également employé. Cette technique consiste à utiliser la data augmentation sur les fichiers tests.

In [None]:
def tta_data_aug(images):
    """Randomly augment a batch of images for test-time augmentation.

    Applies, in order: random brightness, random contrast, a random rotation
    in [-pi/6, pi/6], a central zoom, a random horizontal flip and a random
    cutout of size SIZE_CUTOUT.

    The rotation angle and zoom amount are sampled with ``tf.random`` rather
    than NumPy. When this function is passed to ``Dataset.map`` it is traced
    once into a graph, so a NumPy draw would be frozen at trace time and every
    batch of a pass would get the same values. TensorFlow random ops are
    re-evaluated for every batch.

    Args:
        images: image batch; assumed to be rank 4
            (batch, height, width, channels), as produced by the batched
            datasets this is mapped over -- TODO confirm for other callers.

    Returns:
        The augmented batch, resized to IMAGE_SIZE.
    """
    img = tf.image.random_brightness(images, 0.1)
    img = tf.image.random_contrast(img, 0.8, 2.2)

    # One angle and one zoom amount for the whole batch.
    rand_rad = tf.random.uniform([], -np.pi / 6, np.pi / 6)
    img = tfa.image.rotate(img, rand_rad)
    rand_zoom = tf.random.uniform([], 0.0, 0.2)
    keep_fraction = 1.0 - (rand_rad ** 2 + rand_zoom)

    # tf.image.central_crop needs a Python float, so the same crop is done
    # with a slice whose bounds are tensors.
    dims = tf.shape(img)
    height, width = dims[-3], dims[-2]
    offset_h = tf.cast(tf.cast(height, tf.float32) * (1.0 - keep_fraction) / 2.0, tf.int32)
    offset_w = tf.cast(tf.cast(width, tf.float32) * (1.0 - keep_fraction) / 2.0, tf.int32)
    img = img[..., offset_h:height - offset_h, offset_w:width - offset_w, :]
    img = tf.image.resize(img, IMAGE_SIZE)

    img = tf.image.random_flip_left_right(img)

    img = tfa.image.random_cutout(img, SIZE_CUTOUT)

    return img



def tta_predictions(model, model2, ds, n):
    """Run both models on n randomly augmented versions of a dataset.

    Args:
        model: first model; its predict method is called on each version.
        model2: second model, used the same way.
        ds: dataset of image batches to augment with tta_data_aug.
        n: number of augmented passes.

    Returns:
        Two lists of length n: the predictions of model and of model2, one
        entry per pass.
    """
    first_model_probs, second_model_probs = [], []
    for _ in range(n):
        # A fresh mapping per pass; both models are fed the same mapped dataset.
        augmented_ds = ds.map(tta_data_aug, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        first_model_probs.append(model.predict(augmented_ds))
        second_model_probs.append(model2.predict(augmented_ds))

    return first_model_probs, second_model_probs



Prédiction des images tests pour la soumission

In [None]:
cmdataset = get_validation_dataset(BATCH_SIZE, ordered=True) # since we are splitting the dataset and iterating separately on images and labels, order matters.
images_ds = cmdataset.map(lambda image, label: image)
labels_ds = cmdataset.map(lambda image, label: label).unbatch()
cm_correct_labels = next(iter(labels_ds.batch(NUM_VALIDATION_IMAGES))).numpy() # get everything as one batch
# Validation predictions of both models over 5 test-time-augmentation passes.
m1, m2 = tta_predictions(model, model2, images_ds, 5)

# Average the passes of each model.
m1 = np.mean(m1, axis=0)
m2 = np.mean(m2, axis=0)

cm_predictions = np.argmax(m1, axis=-1)
print("M1:", f1_score(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)), average='macro'))

cm_predictions = np.argmax(m2, axis=-1)
print("M2:", f1_score(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)), average='macro'))

# Grid search of the blending weight on the TTA-averaged validation
# predictions: alpha weights model 1 and (1 - alpha) weights model 2.
alphas = np.linspace(0, 1, 100)
scores = []
for alpha in alphas:
    cm_probabilities = alpha*m1+(1-alpha)*m2
    cm_predictions = np.argmax(cm_probabilities, axis=-1)
    scores.append(f1_score(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)), average='macro'))

# Look the winner up in the grid itself: 100 points over [0, 1] are spaced by
# 1/99, so "index / 100" would not be the alpha that was actually scored.
best_alpha = alphas[np.argmax(scores)]
print('Best alpha:', str(best_alpha))


# Predict the test set with the same TTA procedure and blend with best_alpha.
test_ds = get_test_dataset(BATCH_SIZE)
test_images_ds = test_ds.map(lambda image, idnum: image)
m1, m2 = tta_predictions(model, model2, test_images_ds, 5)
probs1 = np.mean(m1, axis=0)
probs2 = np.mean(m2, axis=0)
probabilities = best_alpha*probs1 + (1-best_alpha)*probs2
predictions = np.argmax(probabilities, axis=-1)

# Collect the ids in the same (ordered) pass order and write the submission.
test_ids_ds = test_ds.map(lambda image, idnum: idnum).unbatch()
test_ids = next(iter(test_ids_ds.batch(NUM_TEST_IMAGES))).numpy().astype('U') # all in one batch
np.savetxt('submission.csv', np.rec.fromarrays([test_ids, predictions]), fmt=['%s', '%d'], delimiter=',', header='id,label', comments='')