# About this kernel, about this competition

## Intro

This competition will be over in about 2 days, and it has been my first Kaggle competition. I am still a beginner in ML, and I want to thank Kaggle for this great opportunity to learn about TensorFlow, classification and TPUs.

I started with a public ensemble kernel from Wojtek Rosa and did a lot of basic hyperparameter tuning. As new discussion entries appeared, I learned a lot about augmentation techniques, under/oversampling, optimizers and other things. Kudos to the nice people who fed the community with their knowledge. I will mention the most important contributions in the later sections.

## 0) TPU stuff

TPUs are impressive, and @mgornergoogle made it easy to understand the basic code that is necessary to get started with TPUs.

https://www.kaggle.com/mgornergoogle/getting-started-with-100-flowers-on-tpu

Later he provided a kernel that implemented a custom training loop, which could speed up training by up to 20%. Unfortunately, TensorFlow 2.1 turned out to be unstable when training on 512x512 images. This should be fixed in TF 2.2, which was released 3 days ago but has not yet made its way into the Kaggle environment as I write this.

https://www.kaggle.com/mgornergoogle/custom-training-loop-with-100-flowers-on-tpu


## 1) Datasets

At the beginning we had an unbalanced dataset of flowers with 104 classes, nicely assembled from 5 public flower datasets by Martin Goerner.

As the competition went on, people incorporated one or more of these public datasets into their training and later published them, and it turned out that using those datasets could greatly improve LB scores. Thanks to Heng CherKeng, Kirill Blinov and all the others for their contributions.

https://www.kaggle.com/c/flower-classification-with-tpus/discussion/140866

https://www.kaggle.com/kirillblinov/tf-flower-photo-tfrec

In the last few days there has even been some discussion about whether these datasets may be used at all.

https://www.kaggle.com/c/flower-classification-with-tpus/discussion/148329



## 2) Augmentations

Chris Deotte contributed greatly to this topic, providing notebooks that showed us implementations of the GridMask, CutMix and MixUp augmentations along with his spatial affine transformations. I tried them all and found them very interesting; they also led me to learn about label smoothing (a regularization technique that softens one-hot encoded class labels).

https://www.kaggle.com/cdeotte/rotation-augmentation-gpu-tpu-0-96

https://www.kaggle.com/cdeotte/cutmix-and-mixup-on-gpu-tpu

https://www.kaggle.com/yihdarshieh/make-chris-deotte-s-data-augmentation-faster

https://www.kaggle.com/yihdarshieh/batch-implementation-of-more-data-augmentations

https://www.kaggle.com/xiejialun/gridmask-data-augmentation-with-tensorflow
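The label smoothing and MixUp ideas mentioned above fit in a few lines of NumPy. This is a minimal sketch on a batch of images with one-hot labels, not the code from the linked notebooks (which use TF ops inside the `tf.data` pipeline); the helper names `smooth_labels` and `mixup` and the defaults `epsilon=0.1` and `alpha=0.2` are my own assumptions.

```python
import numpy as np

def smooth_labels(one_hot, epsilon=0.1):
    """Label smoothing: move `epsilon` of the probability mass from the true
    class to a uniform distribution over all K classes."""
    num_classes = one_hot.shape[-1]
    return one_hot * (1.0 - epsilon) + epsilon / num_classes

def mixup(images, one_hot, alpha=0.2, rng=None):
    """MixUp: blend every example with a random partner from the same batch,
    using one mixing weight per example drawn from Beta(alpha, alpha)."""
    rng = np.random.default_rng() if rng is None else rng
    batch_size = images.shape[0]
    lam = rng.beta(alpha, alpha, size=batch_size)  # one weight per example
    partner = rng.permutation(batch_size)          # partner index for each example
    lam_img = lam.reshape(-1, 1, 1, 1)             # broadcast over H, W, C
    lam_lbl = lam.reshape(-1, 1)                   # broadcast over classes
    mixed_images = lam_img * images + (1.0 - lam_img) * images[partner]
    mixed_labels = lam_lbl * one_hot + (1.0 - lam_lbl) * one_hot[partner]
    return mixed_images, mixed_labels

# usage on random data: 8 images of 32x32 pixels, 104 flower classes
rng = np.random.default_rng(0)
images = rng.random((8, 32, 32, 3))
labels = np.eye(104)[rng.integers(0, 104, size=8)]
x, y = mixup(images, smooth_labels(labels), rng=rng)
print(x.shape, y.shape)
```

In Keras the smoothing step does not have to be done by hand: `tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1)` applies the same formula to the targets inside the loss.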


Thanks to Michał Szachniewicz, who implemented AugMix to run under TensorFlow 2.x. It is a pity that my experiments with it did not produce good results.

https://www.kaggle.com/szacho/augmix-data-augmentation-on-tpu

I found it interesting how well the rather simple cutout augmentation worked; see the random_blockout() function from a competitor below.
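Cutout itself is simple: blank out one randomly placed square patch per image. The sketch below is a NumPy version under my own assumptions (a hypothetical `cutout` helper, patch side given as a fraction of the image height, patch clipped at the image border); it is not the competitor's `random_blockout()`, which works on TF tensors inside the input pipeline.

```python
import numpy as np

def cutout(image, size=0.2, rng=None):
    """Cutout: zero one square patch at a random position of an HxWxC image.
    `size` is the patch side length as a fraction of the image height."""
    rng = np.random.default_rng() if rng is None else rng
    h, w = image.shape[:2]
    side = max(1, int(size * h))
    # pick the patch centre anywhere in the image; the patch is clipped at the borders
    cy, cx = rng.integers(0, h), rng.integers(0, w)
    top, left = cy - side // 2, cx - side // 2
    y0, y1 = max(0, top), min(h, top + side)
    x0, x1 = max(0, left), min(w, left + side)
    out = image.copy()              # leave the input untouched
    out[y0:y1, x0:x1, :] = 0.0
    return out

augmented = cutout(np.ones((64, 64, 3)), size=0.25, rng=np.random.default_rng(0))
```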


I also stumbled upon AutoAugment (AutoAug), a technique used by Google researchers, in AutoML for classification and now even for object detection, to find the best-fitting augmentation parameters for a given dataset. AutoAugment can be found in the TensorFlow repository on GitHub, but it is implemented in TensorFlow 1.x and I did not have the time to invest in that.


## 3) Models and Techniques

### 3.1 Models

The state of the art is to use the EfficientNet family of models. Pretrained weights for these models exist for both ImageNet and Noisy Student training, and both variants are available in the Keras version on GitHub. Some people combined one or two EfficientNet models with other models in an ensemble, as is done in this kernel.

Wojtek Rosa provided this starter kernel: https://www.kaggle.com/wrrosa/tpu-enet-b7-densenet


### 3.2 Optimizers

There is a lot of research going on in this field, and researchers are proposing many new optimizers these days.
I started with Adam and did some experiments, especially with the so-called Ranger optimizer (a combination of RectifiedAdam and Lookahead, both of which can be found in the TensorFlow Addons library).

All in all, I did not have success with these, maybe because they converge more slowly and training time is limited in this competition, so I went back to plain Adam.
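To illustrate what Lookahead adds on top of an inner optimizer, here is a minimal NumPy sketch that wraps plain SGD instead of RectifiedAdam (an assumption to keep it short; Ranger uses RAdam as the inner optimizer). Every `k` fast steps, the slow weights move a fraction `alpha` towards the fast weights, and the fast weights restart from there. With TensorFlow Addons, the Ranger combination would be built as `tfa.optimizers.Lookahead(tfa.optimizers.RectifiedAdam())`.

```python
import numpy as np

def lookahead_sgd(grad_fn, w0, lr=0.1, k=5, alpha=0.5, steps=100):
    """Lookahead around plain SGD: k fast steps, then sync the slow weights."""
    slow = np.array(w0, dtype=float)
    fast = slow.copy()
    for step in range(1, steps + 1):
        fast -= lr * grad_fn(fast)         # inner ("fast") optimizer step
        if step % k == 0:
            slow += alpha * (fast - slow)  # move slow weights towards the fast weights
            fast = slow.copy()             # restart the fast weights from the slow weights
    return slow

# minimize f(w) = 0.5 * ||w||^2, whose gradient is w itself
w = lookahead_sgd(lambda w: w, [3.0, -4.0])
print(np.linalg.norm(w))
```

The slow weights only change at sync points, which smooths the trajectory of the inner optimizer at the cost of some raw speed; that trade-off may be one reason it did not pay off within a short training budget.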


### 3.3 Learning Rate and other parameters

When fine-tuning a model (training all weights of a pretrained model), one should use a ramp-up phase of a few epochs with a lower learning rate, so as not to destroy the pretrained features. The starter notebook provides a LearningRateScheduler with exponential decay.

An alternative is a cosine-decaying learning rate, which I implemented below in this notebook. I did not try cyclic learning rates, which would have been interesting as well. By the way, the Ranger optimizer likes a high, flat learning rate at the beginning followed by cosine annealing.

> On TPU, the starter kernel's batch size could be doubled even with 512x512 images, which was a big improvement (`16 * strategy.num_replicas_in_sync * 2`).

> I got slightly better results by multiplying the learning rate schedule proposed in the starter kernel by 1.2.
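With the constants used in this notebook (`BATCH_SIZE = 16 * strategy.num_replicas_in_sync * BATCH_SIZE_MULT` and `LR_MAX = 0.00005 * strategy.num_replicas_in_sync`), the numbers work out as follows, assuming 8 replicas as on a Kaggle TPU v3-8. The helper `batch_and_peak_lr` is hypothetical and only reproduces that arithmetic.

```python
def batch_and_peak_lr(num_replicas=8, batch_size_mult=2, lr_multiplier=1.0):
    """Reproduce the notebook's global batch size and peak learning rate."""
    batch_size = 16 * num_replicas * batch_size_mult  # 16 images per core, doubled on TPU
    lr_max = 0.00005 * num_replicas * lr_multiplier   # peak LR scaled with the number of cores
    return batch_size, lr_max

print(batch_and_peak_lr())                   # 256 images per step, peak LR 4e-4
print(batch_and_peak_lr(lr_multiplier=1.2))  # same batch, peak LR about 4.8e-4
print(batch_and_peak_lr(num_replicas=1, batch_size_mult=1))  # single GPU/CPU setting
```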




### 3.4 Class Weights

Class weights tell the optimizer to down-weight the influence of overrepresented classes. A short piece of code is shown below, but in the end I did not use it, because the loss decreased more slowly. More epochs might show that this method leads to a good model, but this competition restricts the runtime to 3 hours, so it is not effective here. Furthermore, it is doubtful whether class weights work at all in TF 2.1 on TPU.
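A common way to compute such weights is the inverse class frequency `n_samples / (n_classes * count_c)`, the formula behind scikit-learn's `class_weight='balanced'` heuristic. The sketch below is a hypothetical NumPy helper, not the snippet from this notebook; Keras accepts the resulting dictionary through the `class_weight` argument of `model.fit`.

```python
import numpy as np

def balanced_class_weights(labels, num_classes):
    """Inverse-frequency class weights n_samples / (num_classes * count_c).
    Classes without any example get weight 0 to avoid division by zero."""
    labels = np.asarray(labels)
    counts = np.bincount(labels, minlength=num_classes).astype(float)
    weights = np.zeros(num_classes)
    present = counts > 0
    weights[present] = len(labels) / (num_classes * counts[present])
    return {c: float(w) for c, w in enumerate(weights)}

print(balanced_class_weights([0, 0, 0, 1], num_classes=3))

# usage (hypothetical names): train_labels holds one int label per training image
# weights = balanced_class_weights(train_labels, num_classes=len(CLASSES))
# model.fit(train_dataset, epochs=EPOCHS, class_weight=weights)
```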


### 3.5 Oversampling/Undersampling

For unbalanced datasets, people have had success training on data where examples of the overrepresented classes are filtered out (undersampling) or where the dataset is extended with (modified) copies of the underrepresented examples (oversampling). It did not work out for me in this competition.

https://www.kaggle.com/yihdarshieh/tutorial-oversample
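The core of oversampling is deciding how often each example is repeated. The sketch below works on example indices in NumPy rather than on a `tf.data` pipeline (the linked tutorial covers that); the helper name and the `min_count` rule, which tops up every class to at least `min_count` examples, are my own assumptions.

```python
import numpy as np

def oversample_indices(labels, min_count, rng=None):
    """Return example indices in which every class appears at least `min_count`
    times, by repeating randomly chosen examples of the rare classes."""
    rng = np.random.default_rng() if rng is None else rng
    labels = np.asarray(labels)
    indices = [np.arange(len(labels))]  # keep every original example once
    for c in np.unique(labels):
        idx_c = np.flatnonzero(labels == c)
        missing = min_count - len(idx_c)
        if missing > 0:                 # draw extra copies with replacement
            indices.append(rng.choice(idx_c, size=missing, replace=True))
    out = np.concatenate(indices)
    rng.shuffle(out)                    # mix the copies into the epoch
    return out

labels = np.array([0] * 10 + [1] * 2)
idx = oversample_indices(labels, min_count=5, rng=np.random.default_rng(0))
print(np.bincount(labels[idx]))
```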


### 3.6 Progressive Resizing

During the competition I read about progressive resizing (training a model first with a smaller image size and then again with a larger one). In the last few days, a notebook that implemented this with the fastai library brought me back to the idea, so I implemented it in this kernel.

https://www.kaggle.com/kurianbenoy/classifying-flowers-with-fastaiv2-0-96


### 3.7 Custom Training Loop

As mentioned above, the custom training loop from https://www.kaggle.com/mgornergoogle/custom-training-loop-with-100-flowers-on-tpu can save up to about 20% of training time.


### 3.8 KFolds

Using KFolds means merging the training and validation data into one set, splitting this set into K parts, training K models (each validated on a different part and trained on the remaining ones) and then aggregating the predictions of these K models.

https://www.kaggle.com/ragnar123/4-kfold-densenet201
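In this notebook's setting, the folds could be built over TFRecord files rather than individual images, since `TRAINING_FILENAMES` and `VALIDATION_FILENAMES` are just lists of file paths. The sketch below assumes scikit-learn's `KFold` (commented out in the imports further down) and averages the fold models' predicted probabilities; the helper names and the file names in the usage lines are made up.

```python
import numpy as np
from sklearn.model_selection import KFold

def kfold_file_splits(filenames, k=5, seed=42):
    """Split a list of TFRecord files into k (train_files, val_files) folds."""
    filenames = np.asarray(filenames)
    kfold = KFold(n_splits=k, shuffle=True, random_state=seed)
    return [(filenames[train_idx].tolist(), filenames[val_idx].tolist())
            for train_idx, val_idx in kfold.split(filenames)]

def average_fold_predictions(fold_probs):
    """Aggregate the k models' class probabilities by averaging them."""
    return np.mean(np.stack(fold_probs), axis=0)

files = [f"flowers{i:02d}-230.tfrec" for i in range(16)]  # hypothetical file names
for fold, (train_files, val_files) in enumerate(kfold_file_splits(files, k=4)):
    print(fold, len(train_files), len(val_files))         # 12 training and 4 validation files per fold
```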


### 3.9 TTA

TTA (test-time augmentation) means augmenting the test data several times and aggregating the predictions on these augmented copies. I did not have success with it, but many successful competitors use it. There is a nice notebook from Caleb about this technique:

https://www.kaggle.com/calebeverett/comparison-of-tta-prediction-procedures
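A minimal TTA scheme is sketched below: predict on the original batch plus a few randomly flipped copies and average the class probabilities. It assumes NHWC NumPy batches and a `predict_fn` callable (for example a wrapper around `model.predict`); the helper name, the flip-only augmentation and `num_passes=4` are my own choices, not the procedure from the linked notebook.

```python
import numpy as np

def predict_with_tta(predict_fn, images, num_passes=4, rng=None):
    """Average class probabilities over the original images and
    (num_passes - 1) copies in which about half the images are left-right flipped."""
    rng = np.random.default_rng() if rng is None else rng
    probs = [predict_fn(images)]                          # pass 0: un-augmented images
    for _ in range(num_passes - 1):
        flip = rng.random(len(images)) < 0.5              # choose images to flip
        augmented = images.copy()
        augmented[flip] = augmented[flip][:, :, ::-1, :]  # reverse the width axis (NHWC)
        probs.append(predict_fn(augmented))
    return np.mean(probs, axis=0)

# toy "model": probabilities derived from the mean pixel value
def toy_model(x):
    m = x.mean(axis=(1, 2, 3))
    return np.stack([m, 1.0 - m], axis=1)

imgs = np.random.default_rng(0).random((6, 8, 8, 3))
print(predict_with_tta(toy_model, imgs).shape)
```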


### 3.10 Pseudo labeling test data

I did not try this, but I suspect that the top competitors do.

Some time ago Chris Deotte provided a nice summary of how it is done:

https://www.kaggle.com/cdeotte/pseudo-labeling-qda-0-969
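The selection step at the heart of pseudo-labeling can be sketched as follows: keep only the test images on which the model is confident and treat its predicted class as a label for another round of training. The helper name and the threshold of 0.9 are my own assumptions, not taken from the linked notebook.

```python
import numpy as np

def select_pseudo_labels(test_probs, threshold=0.9):
    """Return the indices of confidently predicted test examples
    and their predicted classes, to be used as (pseudo) labels."""
    confidence = test_probs.max(axis=1)
    keep = np.flatnonzero(confidence > threshold)
    return keep, test_probs[keep].argmax(axis=1)

probs = np.array([[0.95, 0.05], [0.6, 0.4], [0.02, 0.98]])
keep, pseudo = select_pseudo_labels(probs)
print(keep, pseudo)
```

The kept test images and their pseudo labels would then be appended to the training data before retraining.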


### 3.11 Using Mish activations

EfficientNet uses the relatively new swish activation function. In the last few days I came across an alternative that seems to do better: mish().

As exchanging activation functions in Keras on the fly seems to be difficult, I am not sure whether I can try something in that direction in the next few days.
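For comparison, both activations are one-liners: swish(x) = x · sigmoid(x) and mish(x) = x · tanh(softplus(x)). The NumPy sketch below just evaluates the two formulas; in TensorFlow the same could be written with `tf.nn.swish` and `x * tf.math.tanh(tf.math.softplus(x))`, but how to patch them into the EfficientNet layers is not shown here.

```python
import numpy as np

def swish(x):
    # swish(x) = x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def mish(x):
    # mish(x) = x * tanh(softplus(x)); logaddexp(0, x) = log(1 + exp(x)) computed stably
    return x * np.tanh(np.logaddexp(0.0, x))

xs = np.linspace(-5.0, 5.0, 1001)
print(swish(xs).min(), mish(xs).min())  # both dip slightly below zero for negative inputs
```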


## Final words

My best result so far (before the last weekend) came from training a 4-model ensemble for 13 epochs on 512x512 images with over 40000 training images and no TTA (LB score 0.975). I am fairly sure that training for more epochs and with more training images would improve the score a lot; probably one has to use 224x224 images for that.

Next weekend I plan to try smaller image sizes with more epochs and images, just for fun, because the external datasets are probably not allowed for the final LB run.
Maybe some TTA experiments too, if time allows.

**Update**

I am sorry, I cannot remember whose kernel I copied, but the 4-model ensemble works great. Thanks!
Training the 4-model ensemble with all external training data pushed me to an LB score of 0.9837 with a training time of 2h20min; this should be a good base for further experiments :)

**Update**

Got a 0.984+ score by training on 224x224 images.

**Update**

I could not evaluate mish() yet, but with my simple patch training seems to be about 10% slower on EfficientNet models.

I could not get TTA to work; I think I simply did it wrong. There is no way to climb the LB without these techniques :)

I tried to run 48 epochs with a custom training loop, but it seems to take more than 3 hours... what a pity :)

I would really like to try these improvements next week...


This notebook serves as a summary of some of the knowledge I built up and should be able to reach an LB score of 0.967 with the base dataset in one way or another.
One should try different models, run them with validation to find the best alpha for ensembling, and then submit a version trained on both the train and val datasets.

Thanks for all the fish :)   (for everybody who doesn't know this quote, please google Douglas Adams)


In [None]:
!pip install -q efficientnet

In [None]:
import math, re, os, time
import tensorflow as tf, tensorflow.keras.backend as K
import numpy as np
from matplotlib import pyplot as plt
from kaggle_datasets import KaggleDatasets
import efficientnet.tfkeras as efn
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix
from tensorflow.keras.applications import DenseNet201
from collections import namedtuple
#from sklearn.model_selection import KFold
print("Tensorflow version " + tf.__version__)

# Configurations

In [None]:
AUTO = tf.data.experimental.AUTOTUNE

# Detect hardware, return appropriate distribution strategy
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection. No parameters necessary if TPU_NAME environment variable is set. On Kaggle this is always the case.
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    strategy = tf.distribute.get_strategy() # default distribution strategy in Tensorflow. Works on CPU and single GPU.

print("REPLICAS: ", strategy.num_replicas_in_sync)

if tpu:
    BATCH_SIZE_MULT = 2
else:
    BATCH_SIZE_MULT = 1

# Data access
GCS_DS_PATH = KaggleDatasets().get_gcs_path('flower-classification')
print(GCS_DS_PATH)

# Configuration
#IMAGE_SIZE = [512, 512]
IMAGE_SIZE = [331, 331]
#IMAGE_SIZE = [192, 192]

img_size=IMAGE_SIZE[0]

# check every place where EPOCHS is changed if you want to train for higher scores
EPOCHS = 3 #14

# bigger batch size is really useful
BATCH_SIZE = 16 * strategy.num_replicas_in_sync * BATCH_SIZE_MULT # doubling the batch_size rocks

# flag for TTA predictions
DO_TTA = False


In [None]:
if not tpu:
    MIXED_PRECISION = False
    XLA_ACCELERATE = False
else:
    MIXED_PRECISION = True
    XLA_ACCELERATE = True

if MIXED_PRECISION:
    from tensorflow.keras.mixed_precision import experimental as mixed_precision
    # bfloat16 is the reduced-precision type native to TPUs; float16 is used on GPUs
    policy = mixed_precision.Policy('mixed_bfloat16' if tpu else 'mixed_float16')
    mixed_precision.set_policy(policy)
    print('Mixed precision enabled')

if XLA_ACCELERATE: 
    # I cannot remember why, probably got a problem without this.
    # Note: with the flags above, XLA_ACCELERATE is only True on TPU, so this branch never runs.
    if not tpu:
        tf.config.optimizer.set_jit(True)
        print('Accelerated Linear Algebra enabled')

# Custom LR schedule

### Standard learning rate function

In [None]:
LR_START = 0.00001
LR_MAX = 0.00005 * strategy.num_replicas_in_sync
LR_MIN = 0.00001 #0.00001
LR_RAMPUP_EPOCHS = 5
LR_SUSTAIN_EPOCHS = 0
LR_EXP_DECAY = .8

# increasing this value often helps a bit (e.g. 1.2)
LR_MULTIPLIER = 1.0

@tf.function
def lrfn(epoch):
    if epoch < LR_RAMPUP_EPOCHS:
        lr = (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS * epoch + LR_START
    elif epoch < LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS:
        lr = LR_MAX
    else:
        lr = (LR_MAX - LR_MIN) * LR_EXP_DECAY**(epoch - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS) + LR_MIN
    return lr * LR_MULTIPLIER
    
lr_callback = tf.keras.callbacks.LearningRateScheduler(lrfn, verbose=True)

rng = [i for i in range(EPOCHS)]
y = [lrfn(x) for x in rng]
plt.plot(rng, y)
print("Learning rate schedule: {:.3g} to {:.3g} to {:.3g}".format(y[0], max(y), y[-1]))


### Learning rate function for a follow up training

In [None]:
# this is very basic and not optimized
LR_START2 = 0.00001
LR_MAX2 = 0.000025 * strategy.num_replicas_in_sync
LR_MIN2 = 0.00001 #0.00001
LR_RAMPUP_EPOCHS2 = 5
LR_SUSTAIN_EPOCHS2 = 0
LR_EXP_DECAY2 = .8

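# LR_MULTIPLIER is not reset here, so the value chosen above still applies to lrfn and lrfnCosineDecay (lrfn2 does not use it)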

@tf.function
def lrfn2(epoch):
    if epoch < LR_RAMPUP_EPOCHS2:
        lr = (LR_MAX2 - LR_START2) / LR_RAMPUP_EPOCHS2 * epoch + LR_START2
    elif epoch < LR_RAMPUP_EPOCHS2 + LR_SUSTAIN_EPOCHS2:
        lr = LR_MAX2
    else:
        lr = (LR_MAX2 - LR_MIN2) * LR_EXP_DECAY2**(epoch - LR_RAMPUP_EPOCHS2 - LR_SUSTAIN_EPOCHS2) + LR_MIN2
    return lr
    
lr_callback2 = tf.keras.callbacks.LearningRateScheduler(lrfn2, verbose=True)

rng = [i for i in range(EPOCHS)]
y = [lrfn2(x) for x in rng]
plt.plot(rng, y)
print("Learning rate schedule: {:.3g} to {:.3g} to {:.3g}".format(y[0], max(y), y[-1]))

### Learning rate schedule with cosine anneal

In [None]:
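@tf.function
def lrfnCosineDecay(epoch):
    if epoch < LR_RAMPUP_EPOCHS:
        lr = (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS * epoch + LR_START
    else:
        # cosine anneal: starts at LR_MAX right after the ramp-up and decays towards LR_MIN by the last epoch
        progress = (epoch - LR_RAMPUP_EPOCHS) / max(1, EPOCHS - LR_RAMPUP_EPOCHS)
        lr = LR_MIN + (LR_MAX - LR_MIN) * (1 + math.cos(math.pi * progress)) / 2
    return lr * LR_MULTIPLIER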

rng = [i for i in range(EPOCHS)]
y = [lrfnCosineDecay(x) for x in rng]
plt.plot(rng, y)
print("Learning rate schedule: {:.3g} to {:.3g} to {:.3g}".format(y[0], max(y), y[-1]))

lr_callbackCosineDecay = tf.keras.callbacks.LearningRateScheduler(lrfnCosineDecay, verbose=True)

In [None]:
GCS_PATH_SELECT = { # available image sizes
    192: GCS_DS_PATH + '/tfrecords-jpeg-192x192',
    224: GCS_DS_PATH + '/tfrecords-jpeg-224x224',
    331: GCS_DS_PATH + '/tfrecords-jpeg-331x331',
    512: GCS_DS_PATH + '/tfrecords-jpeg-512x512'
}
GCS_PATH = GCS_PATH_SELECT[IMAGE_SIZE[0]]


TRAINING_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/train/*.tfrec')
VALIDATION_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/val/*.tfrec')
TEST_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/test/*.tfrec') # predictions on this dataset should be submitted for the competition

# watch out for overfitting!
SKIP_VALIDATION = False
if SKIP_VALIDATION:
    TRAINING_FILENAMES = TRAINING_FILENAMES + VALIDATION_FILENAMES

# in the beginning I tried to remove suspicious samples - a relic
#VALIDATION_MISMATCHES_IDS = ['55a883e16','f4ec48685','2023d3cac','f8eab6777','741999f79','861282b96','28594d9ce','bab3ef1f5','617a30d60','4571b9509','6a3a28a06','9b8f2f5bd','293c37e25','7472eb523','0bf0b39b3','c846d8649','9ee42218f','f4ec48685']



In [None]:
CLASSES = ['pink primrose',    'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',     'wild geranium',     'tiger lily',           'moon orchid',              'bird of paradise', 'monkshood',        'globe thistle',         # 00 - 09
           'snapdragon',       "colt's foot",               'king protea',      'spear thistle', 'yellow iris',       'globe-flower',         'purple coneflower',        'peruvian lily',    'balloon flower',   'giant white arum lily', # 10 - 19
           'fire lily',        'pincushion flower',         'fritillary',       'red ginger',    'grape hyacinth',    'corn poppy',           'prince of wales feathers', 'stemless gentian', 'artichoke',        'sweet william',         # 20 - 29
           'carnation',        'garden phlox',              'love in the mist', 'cosmos',        'alpine sea holly',  'ruby-lipped cattleya', 'cape flower',              'great masterwort', 'siam tulip',       'lenten rose',           # 30 - 39
           'barberton daisy',  'daffodil',                  'sword lily',       'poinsettia',    'bolero deep blue',  'wallflower',           'marigold',                 'buttercup',        'daisy',            'common dandelion',      # 40 - 49
           'petunia',          'wild pansy',                'primula',          'sunflower',     'lilac hibiscus',    'bishop of llandaff',   'gaura',                    'geranium',         'orange dahlia',    'pink-yellow dahlia',    # 50 - 59
           'cautleya spicata', 'japanese anemone',          'black-eyed susan', 'silverbush',    'californian poppy', 'osteospermum',         'spring crocus',            'iris',             'windflower',       'tree poppy',            # 60 - 69
           'gazania',          'azalea',                    'water lily',       'rose',          'thorn apple',       'morning glory',        'passion flower',           'lotus',            'toad lily',        'anthurium',             # 70 - 79
           'frangipani',       'clematis',                  'hibiscus',         'columbine',     'desert-rose',       'tree mallow',          'magnolia',                 'cyclamen ',        'watercress',       'canna lily',            # 80 - 89
           'hippeastrum ',     'bee balm',                  'pink quill',       'foxglove',      'bougainvillea',     'camellia',             'mallow',                   'mexican petunia',  'bromelia',         'blanket flower',        # 90 - 99
           'trumpet creeper',  'blackberry lily',           'common tulip',     'wild rose']                                                                                                                                               # 100 - 103

# Helper Functions
## Visualization

In [None]:
# numpy and matplotlib defaults
np.set_printoptions(threshold=15, linewidth=80)

def batch_to_numpy_images_and_labels(data):
    images, labels = data
    numpy_images = images.numpy()
    numpy_labels = labels.numpy()
    if numpy_labels.dtype == object: # binary string in this case, these are image ID strings
        numpy_labels = [None for _ in enumerate(numpy_images)]
    # If no labels, only image IDs, return None for labels (this is the case for test data)
    return numpy_images, numpy_labels

def title_from_label_and_target(label, correct_label):
    if correct_label is None:
        return CLASSES[label], True
    correct = (label == correct_label)
    return "{} [{}{}{}]".format(CLASSES[label], 'OK' if correct else 'NO', u"\u2192" if not correct else '',
                                CLASSES[correct_label] if not correct else ''), correct

def display_one_flower(image, title, subplot, red=False, titlesize=16):
    plt.subplot(*subplot)
    plt.axis('off')
    plt.imshow(image)
    if len(title) > 0:
        plt.title(title, fontsize=int(titlesize) if not red else int(titlesize/1.2), color='red' if red else 'black', fontdict={'verticalalignment':'center'}, pad=int(titlesize/1.5))
    return (subplot[0], subplot[1], subplot[2]+1)
    
def display_batch_of_images(databatch, predictions=None, figsize  = 13.0):
    """This will work with:
    display_batch_of_images(images)
    display_batch_of_images(images, predictions)
    display_batch_of_images((images, labels))
    display_batch_of_images((images, labels), predictions)
    """
    # data
    images, labels = batch_to_numpy_images_and_labels(databatch)
    if labels is None:
        labels = [None for _ in enumerate(images)]
        
    # auto-squaring: this will drop data that does not fit into square or square-ish rectangle
    rows = int(math.sqrt(len(images)))
    cols = len(images)//rows
        
    # size and spacing
    FIGSIZE =  figsize
    SPACING = 0.1
    subplot=(rows,cols,1)
    if rows < cols:
        plt.figure(figsize=(FIGSIZE,FIGSIZE/cols*rows))
    else:
        plt.figure(figsize=(FIGSIZE/rows*cols,FIGSIZE))
    
    # display
    for i, (image, label) in enumerate(zip(images[:rows*cols], labels[:rows*cols])):
        title = '' if label is None else CLASSES[label]
        correct = True
        if predictions is not None:
            title, correct = title_from_label_and_target(predictions[i], label)
        dynamic_titlesize = FIGSIZE*SPACING/max(rows,cols)*40+3 # magic formula tested to work from 1x1 to 10x10 images
        subplot = display_one_flower(image, title, subplot, not correct, titlesize=dynamic_titlesize)
    
    #layout
    plt.tight_layout()
    if label is None and predictions is None:
        plt.subplots_adjust(wspace=0, hspace=0)
    else:
        plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()

def display_confusion_matrix(cmat, score, precision, recall):
    plt.figure(figsize=(15,15))
    ax = plt.gca()
    ax.matshow(cmat, cmap='Reds')
    ax.set_xticks(range(len(CLASSES)))
    ax.set_xticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(ax.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")
    ax.set_yticks(range(len(CLASSES)))
    ax.set_yticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(ax.get_yticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    titlestring = ""
    if score is not None:
        titlestring += 'f1 = {:.3f} '.format(score)
    if precision is not None:
        titlestring += '\nprecision = {:.3f} '.format(precision)
    if recall is not None:
        titlestring += '\nrecall = {:.3f} '.format(recall)
    if len(titlestring) > 0:
        ax.text(101, 1, titlestring, fontdict={'fontsize': 18, 'horizontalalignment':'right', 'verticalalignment':'top', 'color':'#804040'})
    plt.show()
    
def display_training_curves(training, validation, title, subplot):
    if subplot%10==1: # set up the subplots on the first call
        plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
        plt.tight_layout()
    ax = plt.subplot(subplot)
    ax.set_facecolor('#F8F8F8')
    ax.plot(training)
    ax.plot(validation)
    ax.set_title('model '+ title)
    ax.set_ylabel(title)
    #ax.set_ylim(0.28,1.05)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'valid.'])

## Datasets Functions

In [None]:
def decode_image(image_data):
    image = tf.image.decode_jpeg(image_data, channels=3)
    image = tf.cast(image, tf.float32) / 255.0  # convert image to floats in [0, 1] range
    image = tf.reshape(image, [*IMAGE_SIZE, 3]) # explicit size needed for TPU
    return image

def read_labeled_tfrecord(example):
    LABELED_TFREC_FORMAT = {
        "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
        "class": tf.io.FixedLenFeature([], tf.int64),  # shape [] means single element
    }
    example = tf.io.parse_single_example(example, LABELED_TFREC_FORMAT)
    image = decode_image(example['image'])
    label = tf.cast(example['class'], tf.int32)
    return image, label # returns a dataset of (image, label) pairs

def read_unlabeled_tfrecord(example):
    UNLABELED_TFREC_FORMAT = {
        "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
        "id": tf.io.FixedLenFeature([], tf.string),  # shape [] means single element
        # class is missing, this competitions's challenge is to predict flower classes for the test dataset
    }
    example = tf.io.parse_single_example(example, UNLABELED_TFREC_FORMAT)
    image = decode_image(example['image'])
    idnum = example['id']
    return image, idnum # returns a dataset of image(s)

def read_labeled_id_tfrecord(example):
    LABELED_ID_TFREC_FORMAT = {
        "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
        "class": tf.io.FixedLenFeature([], tf.int64),  # shape [] means single element
        "id": tf.io.FixedLenFeature([], tf.string),  # shape [] means single element
    }
    example = tf.io.parse_single_example(example, LABELED_ID_TFREC_FORMAT)
    image = decode_image(example['image'])
    label = tf.cast(example['class'], tf.int32)
    idnum =  example['id']
    return image, label, idnum # returns a dataset of (image, label, idnum) triples

def load_dataset(filenames, labeled=True, ordered=False):
    # Read from TFRecords. For optimal performance, reading from multiple files at once and
    # disregarding data order. Order does not matter since we will be shuffling the data anyway.

    ignore_order = tf.data.Options()
    if not ordered:
        ignore_order.experimental_deterministic = False # disable order, increase speed

    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO) # automatically interleaves reads from multiple files
    dataset = dataset.with_options(ignore_order) # uses data as soon as it streams in, rather than in its original order
    dataset = dataset.map(read_labeled_id_tfrecord  if labeled else read_unlabeled_tfrecord, num_parallel_calls=AUTO)
    # returns a dataset of (image, label, idnum) triples if labeled=True or (image, id) pairs if labeled=False
        
    return dataset

def load_dataset_with_id(filenames, ordered=False):
    # Read from TFRecords. For optimal performance, reading from multiple files at once and
    # disregarding data order. Order does not matter since we will be shuffling the data anyway.

    ignore_order = tf.data.Options()
    if not ordered:
        ignore_order.experimental_deterministic = False # disable order, increase speed

    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO) # automatically interleaves reads from multiple files
    dataset = dataset.with_options(ignore_order) # uses data as soon as it streams in, rather than in its original order
    dataset = dataset.map(read_labeled_id_tfrecord, num_parallel_calls=AUTO)
    # returns a dataset of (image, label, idnum) triples
    return dataset

def data_augment(image, label):
    # data augmentation. Thanks to the dataset.prefetch(AUTO) statement in the next function (below),
    # this happens essentially for free on TPU. Data pipeline code is executed on the "CPU" part
    # of the TPU while the TPU itself is computing gradients.
    image = tf.image.random_flip_left_right(image)
    #image = tf.image.random_brightness(image, 0.05, seed=None)
    #image = tf.image.random_contrast(image, 0.8, 1.2, seed=None)
    
    #random cut
    image= random_blockout(image)
    
    return image, label 


def get_training_dataset(do_aug=True,do_repeat=True):
    dataset = load_dataset(TRAINING_FILENAMES, labeled=True)
    
    #dataset = dataset.filter(lambda image, label, idnum: tf.reduce_sum(tf.cast(idnum == VALIDATION_MISMATCHES_IDS, tf.int32))==0)
    dataset = dataset.map(lambda image, label, idnum: [image, label])
    
    # add an additional argument to this function with False default and call it with True if you want to try undersampling
    #if do_undersample:
    #    dataset = dataset.filter(undersample_filter)
    
    if do_repeat:
        dataset = dataset.repeat() # the training dataset must repeat for several epochs
    
    if do_aug:
        dataset = dataset.map(data_augment, num_parallel_calls=AUTO)
        #dataset = dataset.map(cropandresize, num_parallel_calls=AUTO)
        
    if do_repeat:
        dataset = dataset.shuffle(2048)
        dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
        dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
    return dataset


def get_validation_dataset(ordered=False, repeated=False):
    dataset = load_dataset(VALIDATION_FILENAMES, labeled=True, ordered=ordered)
    
    #dataset = dataset.filter(lambda image, label, idnum: tf.reduce_sum(tf.cast(idnum == VALIDATION_MISMATCHES_IDS, tf.int32))==0)
    dataset = dataset.map(lambda image, label, idnum: [image, label])
    
    if repeated:
        dataset = dataset.repeat()
        dataset = dataset.shuffle(2048)
        
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=repeated)
    dataset = dataset.cache()
    dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
    return dataset

def get_validation_dataset_with_id(ordered=False):
    dataset = load_dataset_with_id(VALIDATION_FILENAMES, ordered=ordered)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.cache()
    dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
    return dataset

def get_test_dataset(ordered=False, do_aug=False):
    dataset = load_dataset(TEST_FILENAMES, labeled=False, ordered=ordered)
    
    if do_aug:
        dataset = dataset.map(data_augment, num_parallel_calls=AUTO)
        dataset = dataset.map(transform, num_parallel_calls=AUTO)
        
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
    return dataset

def count_data_items(filenames):
    #dataset = load_dataset(filenames,labeled = False)
    #dataset = dataset.map(lambda image, idnum: idnum)
    #dataset = dataset.filter(lambda idnum: tf.reduce_sum(tf.cast(idnum == VALIDATION_MISMATCHES_IDS, tf.int32))==0)
    #uids = next(iter(dataset.batch(26000))).numpy().astype('U') 
    #return len(np.unique(uids))    
    # the number of data items is written in the name of the .tfrec files, e.g. flowers00-230.tfrec = 230 data items
    n = [int(re.compile(r"-([0-9]*)\.").search(filename).group(1)) for filename in filenames]
    return np.sum(n)

def int_div_round_up(a, b):
    return (a + b - 1) // b

NUM_TRAINING_IMAGES = count_data_items(TRAINING_FILENAMES)
NUM_VALIDATION_IMAGES = (1 - SKIP_VALIDATION) * count_data_items(VALIDATION_FILENAMES)
NUM_TEST_IMAGES = count_data_items(TEST_FILENAMES)
STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // BATCH_SIZE
VALIDATION_STEPS = int_div_round_up(NUM_VALIDATION_IMAGES, BATCH_SIZE)
print('Dataset: {} training images, {} validation images, {} unlabeled test images'.format(NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES))


## Augmentations

### 1) Chris Deotte's affine transforms
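`get_mat` below composes four 3×3 matrices in homogeneous pixel coordinates, in the same order as its `K.dot` calls. Written out in my own notation (θ and φ are the rotation and shear angles converted to radians, z_h and z_w the zoom factors, t_h and t_w the shifts):

$$
M = R\,S\,Z\,T =
\begin{pmatrix}\cos\theta & \sin\theta & 0\\ -\sin\theta & \cos\theta & 0\\ 0 & 0 & 1\end{pmatrix}
\begin{pmatrix}1 & \sin\varphi & 0\\ 0 & \cos\varphi & 0\\ 0 & 0 & 1\end{pmatrix}
\begin{pmatrix}1/z_h & 0 & 0\\ 0 & 1/z_w & 0\\ 0 & 0 & 1\end{pmatrix}
\begin{pmatrix}1 & 0 & t_h\\ 0 & 1 & t_w\\ 0 & 0 & 1\end{pmatrix}
$$

`transform` applies M to every destination pixel coordinate (x, y, 1) to find the origin pixel whose value is copied there. Because the matrix maps destination coordinates back to origin coordinates, a zoom factor above 1 samples a smaller region of the source image and therefore enlarges its content.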

In [None]:
def get_mat(rotation, shear, height_zoom, width_zoom, height_shift, width_shift):
    # returns a 3x3 transformation matrix which transforms indices
        
    # CONVERT DEGREES TO RADIANS
    rotation = math.pi * rotation / 180.
    shear = math.pi * shear / 180.
    
    # ROTATION MATRIX
    c1 = tf.math.cos(rotation)
    s1 = tf.math.sin(rotation)
    one = tf.constant([1],dtype='float32')
    zero = tf.constant([0],dtype='float32')
    rotation_matrix = tf.reshape( tf.concat([c1,s1,zero, -s1,c1,zero, zero,zero,one],axis=0),[3,3] )
        
    # SHEAR MATRIX
    c2 = tf.math.cos(shear)
    s2 = tf.math.sin(shear)
    shear_matrix = tf.reshape( tf.concat([one,s2,zero, zero,c2,zero, zero,zero,one],axis=0),[3,3] )    
    
    # ZOOM MATRIX
    zoom_matrix = tf.reshape( tf.concat([one/height_zoom,zero,zero, zero,one/width_zoom,zero, zero,zero,one],axis=0),[3,3] )
    
    # SHIFT MATRIX
    shift_matrix = tf.reshape( tf.concat([one,zero,height_shift, zero,one,width_shift, zero,zero,one],axis=0),[3,3] )
    
    return K.dot(K.dot(rotation_matrix, shear_matrix), K.dot(zoom_matrix, shift_matrix))

In [None]:
def transform(image,label):
    # input image - is one image of size [dim,dim,3] not a batch of [b,dim,dim,3]
    # output - image randomly rotated, sheared, zoomed, and shifted
    
    DIM = IMAGE_SIZE[0]
    XDIM = DIM%2 #fix for size 331
    
    rot = 15. * tf.random.normal([1],dtype='float32')
    shr = 5. * tf.random.normal([1],dtype='float32') 
    h_zoom = 1.0 + tf.random.normal([1],dtype='float32')/10.
    w_zoom = 1.0 + tf.random.normal([1],dtype='float32')/10.
    #w_zoom = h_zoom
    h_shift = 16. * tf.random.normal([1],dtype='float32') 
    w_shift = 16. * tf.random.normal([1],dtype='float32') 
  
    # GET TRANSFORMATION MATRIX
    m = get_mat(rot,shr,h_zoom,w_zoom,h_shift,w_shift) 

    # LIST DESTINATION PIXEL INDICES
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    
    # ROTATE DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = K.dot(m,tf.cast(idx,dtype='float32'))
    idx2 = K.cast(idx2,dtype='int32')
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)
    
    # FIND ORIGIN PIXEL VALUES           
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image,tf.transpose(idx3))
    #print('transform')
        

    return tf.reshape(d,[DIM,DIM,3]),label

### Unused augmentation (maybe for TTA)

In [None]:
#cropandresize
def cropandresize(image,label):

    #box[0] = [0,0,int(IMAGE_SIZE[0]/2),int(IMAGE_SIZE[1]/2)]
    #box[1] = [int(IMAGE_SIZE[0]/2),0,int(IMAGE_SIZE[0]/2),int(IMAGE_SIZE[1]/2)]
    #box[2] = [0,int(IMAGE_SIZE[0]/2),int(IMAGE_SIZE[0]/2),int(IMAGE_SIZE[1]/2)]
    #box[3] = [int(IMAGE_SIZE[0]/2),int(IMAGE_SIZE[0]/2),int(IMAGE_SIZE[0]/2),int(IMAGE_SIZE[1]/2)]
    #box[4] = [int(IMAGE_SIZE[0]/4),int(IMAGE_SIZE[0]/4),int(IMAGE_SIZE[0]/2),int(IMAGE_SIZE[1]/2)]
    
    rnd = tf.random.uniform(shape=[], minval=0, maxval=7, dtype=tf.int64) 
    
    if rnd == 0:
        image = tf.image.crop_to_bounding_box(image, 0, 0, int(2*IMAGE_SIZE[0]/3),int(2*IMAGE_SIZE[0]/3))
    elif rnd == 1:
        image = tf.image.crop_to_bounding_box(image, int(IMAGE_SIZE[0]/3), 0, int(2*IMAGE_SIZE[0]/3),int(2*IMAGE_SIZE[0]/3))
    elif rnd == 2:
        image = tf.image.crop_to_bounding_box(image, 0, int(IMAGE_SIZE[0]/3), int(2*IMAGE_SIZE[0]/3),int(2*IMAGE_SIZE[0]/3))
    elif rnd == 3:
        image = tf.image.crop_to_bounding_box(image, int(IMAGE_SIZE[0]/3), int(IMAGE_SIZE[0]/3), int(2*IMAGE_SIZE[0]/3),int(2*IMAGE_SIZE[0]/3))
    elif rnd == 4:
        image = tf.image.crop_to_bounding_box(image, int(IMAGE_SIZE[0]/6), int(IMAGE_SIZE[0]/6), int(2*IMAGE_SIZE[0]/3),int(2*IMAGE_SIZE[0]/3))
    else:
        image =  image
    
    #image = tf.image.resize(image, size=[IMAGE_SIZE[0],IMAGE_SIZE[0]])
    return tf.image.resize(image, size=[IMAGE_SIZE[0],IMAGE_SIZE[0]]),label     #tf.reshape(image,[IMAGE_SIZE[0],IMAGE_SIZE[0],3]),label
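For TTA, a deterministic version is probably more useful than a random one. The same five crop boxes can be kept in a table and applied one after another, for example one prediction pass per box.

Below is a NumPy sketch of that table, assuming square images as in the code above. `crop_boxes` and `crop` are hypothetical helpers; `crop` mimics `tf.image.crop_to_bounding_box` for a single image.

```python
import numpy as np

def crop_boxes(size):
    # (offset_height, offset_width, crop_height, crop_width) for the five crops used in cropandresize
    third, sixth, side = int(size / 3), int(size / 6), int(2 * size / 3)
    offsets = [(0, 0), (third, 0), (0, third), (third, third), (sixth, sixth)]
    return [(oh, ow, side, side) for oh, ow in offsets]

def crop(image, box):
    # NumPy equivalent of tf.image.crop_to_bounding_box for one [H, W, C] image
    oh, ow, h, w = box
    return image[oh:oh + h, ow:ow + w]

boxes = crop_boxes(512)
print(boxes)  # five 341x341 boxes with offsets 0, 170 (= int(512/3)) and 85 (= int(512/6))
```

Indexing into such a table (with `tf.gather` inside a tf.data pipeline, for instance) would also replace the if/elif chain.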

### Used augmentation

In [None]:
#simple and effective

def random_blockout(img, sl=0.1, sh=0.2, rl=0.4):

    h, w, c = img_size, img_size, 3
    origin_area = tf.cast(h*w, tf.float32)

    e_size_l = tf.cast(tf.round(tf.sqrt(origin_area * sl * rl)), tf.int32)
    e_size_h = tf.cast(tf.round(tf.sqrt(origin_area * sh / rl)), tf.int32)

    e_height_h = tf.minimum(e_size_h, h)
    e_width_h = tf.minimum(e_size_h, w)

    erase_height = tf.random.uniform(shape=[], minval=e_size_l, maxval=e_height_h, dtype=tf.int32)
    erase_width = tf.random.uniform(shape=[], minval=e_size_l, maxval=e_width_h, dtype=tf.int32)

    erase_area = tf.zeros(shape=[erase_height, erase_width, c])
    erase_area = tf.cast(erase_area, tf.uint8)

    pad_h = h - erase_height
    pad_top = tf.random.uniform(shape=[], minval=0, maxval=pad_h, dtype=tf.int32)
    pad_bottom = pad_h - pad_top

    pad_w = w - erase_width
    pad_left = tf.random.uniform(shape=[], minval=0, maxval=pad_w, dtype=tf.int32)
    pad_right = pad_w - pad_left

    erase_mask = tf.pad([erase_area], [[0,0],[pad_top, pad_bottom], [pad_left, pad_right], [0,0]], constant_values=1)
    erase_mask = tf.squeeze(erase_mask, axis=0)
    erased_img = tf.multiply(tf.cast(img,tf.float32), tf.cast(erase_mask, tf.float32))

    return tf.cast(erased_img, img.dtype)
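`random_blockout` is a form of random erasing (cutout). A rectangle with side lengths between round(sqrt(area·sl·rl)) and round(sqrt(area·sh/rl)) is set to zero through a multiplicative mask.

Below is a NumPy version of the same mask logic, handy for checking the sizes on CPU. Two differences from the TF code:

- `random_blockout_np` is a hypothetical name, and it reads the size from the array shape instead of the global `img_size`.
- The `rng` argument is an addition for reproducibility.

```python
import numpy as np

def random_blockout_np(img, sl=0.1, sh=0.2, rl=0.4, rng=None):
    # NumPy mirror of random_blockout: zero out one random rectangle of an [H, W, C] image
    rng = np.random.default_rng() if rng is None else rng
    h, w, _ = img.shape
    area = h * w
    e_size_l = int(np.round(np.sqrt(area * sl * rl)))   # smallest side length
    e_size_h = int(np.round(np.sqrt(area * sh / rl)))   # upper bound of the side length (exclusive)
    erase_h = rng.integers(e_size_l, min(e_size_h, h))
    erase_w = rng.integers(e_size_l, min(e_size_h, w))
    top = rng.integers(0, h - erase_h)
    left = rng.integers(0, w - erase_w)
    mask = np.ones((h, w, 1), dtype=img.dtype)   # 1 = keep, 0 = erase; broadcast over the channels
    mask[top:top + erase_h, left:left + erase_w] = 0
    return img * mask

out = random_blockout_np(np.ones((64, 64, 3), np.float32), rng=np.random.default_rng(0))
print((out[..., 0] == 0).sum(), "pixels erased per channel")
```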


In [None]:
### Peek at training data

training_dataset = get_training_dataset(do_aug=False)
training_dataset = training_dataset.unbatch().batch(20)
train_batch = iter(training_dataset)

In [None]:
### Run this cell again for next set of images
display_batch_of_images(next(train_batch))

### Class weights

How to calculate them:

In [None]:
import datetime
import tqdm
import json
from collections import Counter
import gc
gc.enable()

def get_training_dataset_raw():

    dataset = load_dataset(TRAINING_FILENAMES, labeled=True, ordered=False)
    return dataset


raw_training_dataset = get_training_dataset_raw()

label_counter = Counter()
for images, labels, id in raw_training_dataset:
    label_counter.update([labels.numpy()])

del raw_training_dataset    

TARGET_NUM_PER_CLASS = 122

def get_weight_for_class(class_id):
    counting = label_counter[class_id]
    
    weight = TARGET_NUM_PER_CLASS / counting
    
    return weight

# one version for TPU, one for GPU (tf 2.1 bug)
weight_per_class = {class_id: get_weight_for_class(class_id) for class_id in range(104)}
# a list (not a set): keeps the class order and does not merge equal weights, so index == class id
weight_per_class_wo_id = [get_weight_for_class(class_id) for class_id in range(104)]
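The weight of a class is TARGET_NUM_PER_CLASS / count. Each class therefore contributes count * weight = TARGET_NUM_PER_CLASS to the weighted loss, so all classes count equally.

Below is a pure-Python sketch with made-up counts. `class_weights` is a hypothetical helper, not part of the notebook.

```python
from collections import Counter

def class_weights(label_counter, num_classes, target_per_class):
    # weight = target / count: rare classes get weights > 1, frequent classes < 1
    return {c: target_per_class / label_counter[c] for c in range(num_classes)}

counts = Counter({0: 61, 1: 244, 2: 122})   # made-up label counts
weights = class_weights(counts, num_classes=3, target_per_class=122)
print(weights)                                   # {0: 2.0, 1: 0.5, 2: 1.0}
weights_list = [weights[c] for c in range(3)]    # ordered list, index == class id
```

A class that never occurs in the training data would raise a ZeroDivisionError here, because a Counter returns 0 for missing keys. If not all 104 classes are present, the helper needs a guard.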

In [None]:
print(weight_per_class)

### Undersampling

In [None]:
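# not used here - see get_training_dataset()
# the percentages are the share of samples of these classes that undersample_filter drops
UNDERSAMPLE_CLASS_IDS0 = [103]                  # drop 30%
UNDERSAMPLE_CLASS_IDS1 = [67, 4]                # drop 35%
UNDERSAMPLE_CLASS_IDS2 = [49, 13, 0]            # drop 50%
UNDERSAMPLE_CLASS_IDS3 = [53, 48, 73, 47, 102]  # drop 60% (rnd >= 0.4)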

def undersample_filter(image, label):
    rnd = tf.random.uniform([1], minval=0, maxval=1, dtype='float32', seed=0)
    
    res = tf.math.equal(tf.reduce_sum(tf.cast(label == UNDERSAMPLE_CLASS_IDS0, tf.int32)), 0)
    if not res and rnd >= 0.7:
        return False
    
    res = tf.math.equal(tf.reduce_sum(tf.cast(label == UNDERSAMPLE_CLASS_IDS1, tf.int32)), 0)
    if not res and rnd >= 0.65:
        return False
    
    res = tf.math.equal(tf.reduce_sum(tf.cast(label == UNDERSAMPLE_CLASS_IDS2, tf.int32)), 0)
    if not res and rnd >= 0.5:
        return False
    
    res = tf.math.equal(tf.reduce_sum(tf.cast(label == UNDERSAMPLE_CLASS_IDS3, tf.int32)), 0)
    if not res and rnd >= 0.4:
        return False

    return True
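`undersample_filter` keeps a sample of a listed class only if a uniform random number falls below a class-specific threshold. The same rule can be written as a lookup table of keep probabilities.

Below is a NumPy sketch that checks the resulting keep rate on fake, uniformly distributed labels. `keep_prob` and `undersample_mask` are hypothetical names.

```python
import numpy as np

NUM_CLASSES = 104
# keep probabilities equivalent to undersample_filter: a sample is kept if rnd < threshold
keep_prob = np.ones(NUM_CLASSES)
keep_prob[[103]] = 0.7
keep_prob[[67, 4]] = 0.65
keep_prob[[49, 13, 0]] = 0.5
keep_prob[[53, 48, 73, 47, 102]] = 0.4

def undersample_mask(labels, rng):
    # one uniform draw per sample, compared with the keep probability of its class
    return rng.random(len(labels)) < keep_prob[labels]

rng = np.random.default_rng(0)
labels = rng.integers(0, NUM_CLASSES, size=1_000_000)   # fake labels, uniform over the classes
kept = undersample_mask(labels, rng)
print(kept[labels == 103].mean())   # close to 0.7
```

In a tf.data pipeline, the same table could be used inside `dataset.filter`, e.g. by comparing `tf.random.uniform([])` with `tf.gather(keep_prob, label)`. That would replace the chain of if statements.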

## Training Model

### Load Model into TPU

In [None]:
# Need this line so Google will recite some incantations
# for Turing to magically load the model onto the TPU
with strategy.scope():
    enet = efn.EfficientNetB7(
        input_shape=(None,None,3), #setting the shape to None allows multiple input sizes
        weights='noisy-student',
        include_top=False
    )

    model = tf.keras.Sequential([
        enet,
        tf.keras.layers.GlobalAveragePooling2D(),
        #tf.keras.layers.Dense(2000, activation='relu'), # extra layer or dropout did not improve the model 
        tf.keras.layers.Dense(len(CLASSES), activation='softmax',dtype='float32')
    ])
        
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss = 'sparse_categorical_crossentropy',
    metrics=['sparse_categorical_accuracy']
)
model.summary()

## Training

**Additional callbacks could be added here, as well as class weights for training (see the commented-out lines).**

In [None]:
#scheduler = tf.keras.callbacks.ReduceLROnPlateau(patience=3, verbose=1)
#earlystopping = tf.keras.callbacks.EarlyStopping(monitor='val_sparse_categorical_accuracy', min_delta=0.001, patience=4, verbose=1, mode='auto', baseline=None, restore_best_weights=True)

history = model.fit(
    get_training_dataset(), 
    steps_per_epoch=STEPS_PER_EPOCH,
    epochs=EPOCHS,
    callbacks=[lr_callback], #earlystopping],
    validation_data=None if SKIP_VALIDATION else get_validation_dataset()
    #class_weight=weight_per_class
)

In [None]:
# save the weights and restart training with different image size and learning rate for progressive resizing
#model.save_weights('11.h5')

IMAGE_SIZE = [512, 512] #changing input image sizes on the fly
img_size=IMAGE_SIZE[0]
EPOCHS = 3 # 18
GCS_PATH = GCS_PATH_SELECT[IMAGE_SIZE[0]]

TRAINING_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/train/*.tfrec')
VALIDATION_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/val/*.tfrec')


In [None]:
history = model.fit(
    get_training_dataset(), 
    steps_per_epoch=STEPS_PER_EPOCH,
    epochs=EPOCHS,
    callbacks=[lr_callback2], #earlystopping],
    validation_data=None if SKIP_VALIDATION else get_validation_dataset()
    #class_weight=weight_per_class
)

In [None]:
if not SKIP_VALIDATION:
    display_training_curves(history.history['loss'], history.history['val_loss'], 'loss', 211)
    display_training_curves(history.history['sparse_categorical_accuracy'], history.history['val_sparse_categorical_accuracy'], 'accuracy', 212)

In [None]:
# if we run the prediction on the validation set now, we can delete the model to save memory and use the stored predictions later
#del model
#gc.collect()

## Training model 2 - using the custom training loop:

In [None]:
IMAGE_SIZE = [331, 331]
img_size=IMAGE_SIZE[0]
EPOCHS = 3 #40
GCS_PATH = GCS_PATH_SELECT[IMAGE_SIZE[0]]

TRAINING_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/train/*.tfrec')
VALIDATION_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/val/*.tfrec')


### this would be the standard way
with strategy.scope():
    rnet = efn.EfficientNetB6(
        input_shape=(None,None,3), #(IMAGE_SIZE[0], IMAGE_SIZE[1], 3),
        #weights='imagenet',
        weights='noisy-student',
        include_top=False
    )

    model2 = tf.keras.Sequential([
        rnet,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(len(CLASSES), activation='softmax',dtype='float32')
    ])
        
model2.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss = 'sparse_categorical_crossentropy',
    metrics=['sparse_categorical_accuracy']
)
model2.summary()

### this would be the standard way - part2
history2 = model2.fit(
    get_training_dataset(), 
    steps_per_epoch=STEPS_PER_EPOCH,
    epochs=EPOCHS, 
    callbacks=[lr_callback],
    validation_data=None if SKIP_VALIDATION else get_validation_dataset()
    #class_weight=weight_per_class_wo_id
)

### Martin Goerner's custom training loop code

In [None]:
with strategy.scope():
    pretrained_model = efn.EfficientNetB5(weights='noisy-student', include_top=False ,input_shape=(None,None,3)) #[*IMAGE_SIZE, 3])
    #pretrained_model.trainable = True # False = transfer learning, True = fine-tuning
    
    model2 = tf.keras.Sequential([
        pretrained_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(len(CLASSES), activation='softmax', dtype='float32') # setting dtype='float32' is necessary for mixed precision usage
    ])
    model2.summary()
    
    # Instantiate optimizer with learning rate schedule
    class LRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
        def __call__(self, step):
            return lrfn(epoch=step//STEPS_PER_EPOCH)
    optimizer = tf.keras.optimizers.Adam(learning_rate=LRSchedule())
        
    # this also works but is not very readable
    #optimizer = tf.keras.optimizers.Adam(learning_rate=lambda: lrfn(tf.cast(optimizer.iterations, tf.float32)//STEPS_PER_EPOCH))
    
    # Instantiate metrics
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    valid_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    train_loss = tf.keras.metrics.Sum()
    valid_loss = tf.keras.metrics.Sum()
    
    # Loss
    # The recommendation from the TensorFlow custom training loop documentation is:
    # loss_fn = lambda a,b: tf.nn.compute_average_loss(tf.keras.losses.sparse_categorical_crossentropy(a,b), global_batch_size=BATCH_SIZE)
    # https://www.tensorflow.org/tutorials/distribute/custom_training#define_the_loss_function
    # This works too and shifts all the averaging to the training loop which is easier:
    loss_fn = tf.keras.losses.sparse_categorical_crossentropy
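The two loss variants in the comments above differ only in where the division by the global batch size happens. Below is a NumPy sketch with made-up per-example losses; 8 replicas of 16 images are assumed sizes, not read from this notebook. It shows that both variants give the mean loss.

For the gradients, the picture is slightly different:

- `tape.gradient` on the per-example loss vector returns the gradient of its sum.
- Without `tf.nn.compute_average_loss`, the gradient therefore has the same direction but is scaled by the global batch size.
- Adam divides each update by a running estimate of the gradient magnitude, so this constant factor largely cancels (apart from the epsilon term).

That is presumably why the simpler version trains fine.

```python
import numpy as np

# illustrative sizes (assumed): 8 replicas with 16 images each
NUM_REPLICAS, PER_REPLICA_BATCH = 8, 16
GLOBAL_BATCH_SIZE = NUM_REPLICAS * PER_REPLICA_BATCH

rng = np.random.default_rng(0)
per_example_loss = rng.random(GLOBAL_BATCH_SIZE)      # stand-in for sparse_categorical_crossentropy values
per_replica = np.split(per_example_loss, NUM_REPLICAS)

# (a) tf.nn.compute_average_loss style: each replica divides its own sum by the GLOBAL batch size,
#     then the per-replica results are summed across replicas
recommended = sum(r.sum() / GLOBAL_BATCH_SIZE for r in per_replica)

# (b) this notebook's style: accumulate raw per-example sums (tf.keras.metrics.Sum) and divide once,
#     like train_loss.result() / (BATCH_SIZE * epoch_steps) in the training loop
sum_then_divide = sum(r.sum() for r in per_replica) / GLOBAL_BATCH_SIZE

print(recommended, sum_then_divide, per_example_loss.mean())   # all three agree up to rounding
```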

In [None]:
STEPS_PER_TPU_CALL = int(NUM_TRAINING_IMAGES/BATCH_SIZE) #99 original
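# 29 presumably comes from the original notebook (3712 validation images / batch size 128 = 29 steps).
# Any value works with the repeated validation dataset, but a divisor of the number of validation steps keeps the repeated images to a minimum.
VALIDATION_STEPS_PER_TPU_CALL = 29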

@tf.function
def train_step(data_iter):
    def train_step_fn(images, labels):
        with tf.GradientTape() as tape:
            probabilities = model2(images, training=True)
            loss = loss_fn(labels, probabilities)
        grads = tape.gradient(loss, model2.trainable_variables)
        optimizer.apply_gradients(zip(grads, model2.trainable_variables))
        
        #update metrics
        train_accuracy.update_state(labels, probabilities)
        train_loss.update_state(loss)
        
    # this loop runs on the TPU
    for _ in tf.range(STEPS_PER_TPU_CALL):
        strategy.experimental_run_v2(train_step_fn, next(data_iter))

@tf.function
def valid_step(data_iter):
    def valid_step_fn(images, labels):
        probabilities = model2(images, training=False)
        loss = loss_fn(labels, probabilities)
        
        # update metrics
        valid_accuracy.update_state(labels, probabilities)
        valid_loss.update_state(loss)
        
    # this loop runs on the TPU
    for _ in tf.range(VALIDATION_STEPS_PER_TPU_CALL):
        strategy.experimental_run_v2(valid_step_fn, next(data_iter))
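The "Repeated validation images" formula printed in the training cell shows why the value of VALIDATION_STEPS_PER_TPU_CALL matters. Validation always runs in whole TPU calls. If steps_per_call does not divide the number of validation steps, the last call re-evaluates images from the repeated dataset.

The numbers below are illustrative, and the helper names are hypothetical.

```python
def int_div_round_up(a, b):
    return (a + b - 1) // b

def repeated_validation_images(num_images, batch_size, steps_per_call):
    # same formula as the 'Repeated validation images' print in the training cell
    calls = int_div_round_up(num_images, batch_size * steps_per_call)
    return calls * steps_per_call * batch_size - num_images

def divisors(n):
    return [d for d in range(1, n + 1) if n % d == 0]

print(repeated_validation_images(3712, 128, 29))  # 0: 29 steps of 128 cover 3712 images exactly
print(repeated_validation_images(3800, 128, 29))  # 3624: a second TPU call of 29 steps is needed
steps = int_div_round_up(3800, 128)               # 30 validation steps
print(divisors(steps))                            # step counts per call that avoid a partial extra call
print(repeated_validation_images(3800, 128, 30))  # 40: only the padding of the last batch
```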

In [None]:
import time
from collections import namedtuple  # used for the History record below

start_time = epoch_start_time = time.time()

# distribute the dataset according to the strategy
train_dist_ds = strategy.experimental_distribute_dataset(get_training_dataset())
# Hitting End Of Dataset exceptions is a problem in this setup. Using a repeated validation set instead.
# This will introduce a slight inaccuracy because the validation dataset now has some repeated elements.
valid_dist_ds = strategy.experimental_distribute_dataset(get_validation_dataset(repeated=True))

print("Training steps per epoch:", STEPS_PER_EPOCH, "in increments of", STEPS_PER_TPU_CALL)
print("Validation images:", NUM_VALIDATION_IMAGES,
      "Batch size:", BATCH_SIZE,
      "Validation steps:", NUM_VALIDATION_IMAGES//BATCH_SIZE, "in increments of", VALIDATION_STEPS_PER_TPU_CALL)
print("Repeated validation images:", int_div_round_up(NUM_VALIDATION_IMAGES, BATCH_SIZE*VALIDATION_STEPS_PER_TPU_CALL)*VALIDATION_STEPS_PER_TPU_CALL*BATCH_SIZE-NUM_VALIDATION_IMAGES)
History = namedtuple('History', 'history')
history = History(history={'loss': [], 'val_loss': [], 'sparse_categorical_accuracy': [], 'val_sparse_categorical_accuracy': []})

epoch = 0
train_data_iter = iter(train_dist_ds) # the training data iterator is repeated and it is not reset
                                      # for each validation run (same as model.fit)
valid_data_iter = iter(valid_dist_ds) # the validation data iterator is repeated and it is not reset
                                      # for each validation run (different from model.fit where the
                                      # recommendation is to use a non-repeating validation dataset)

step = 0
epoch_steps = 0
while True:
    
    # run training step
    train_step(train_data_iter)
    epoch_steps += STEPS_PER_TPU_CALL
    step += STEPS_PER_TPU_CALL
    print('=', end='', flush=True)

    # validation run at the end of each epoch
    if (step // STEPS_PER_EPOCH) > epoch:
        print('|', end='', flush=True)
        
        # validation run
        valid_epoch_steps = 0
        for _ in range(int_div_round_up(NUM_VALIDATION_IMAGES, BATCH_SIZE*VALIDATION_STEPS_PER_TPU_CALL)):
            valid_step(valid_data_iter)
            valid_epoch_steps += VALIDATION_STEPS_PER_TPU_CALL
            print('=', end='', flush=True)

        # compute metrics
        history.history['sparse_categorical_accuracy'].append(train_accuracy.result().numpy())
        history.history['val_sparse_categorical_accuracy'].append(valid_accuracy.result().numpy())
        history.history['loss'].append(train_loss.result().numpy() / (BATCH_SIZE*epoch_steps))
        history.history['val_loss'].append(valid_loss.result().numpy() / (BATCH_SIZE*valid_epoch_steps))
        
        # report metrics
        epoch_time = time.time() - epoch_start_time
        print('\nEPOCH {:d}/{:d}'.format(epoch+1, EPOCHS))
        print('time: {:0.1f}s'.format(epoch_time),
              'loss: {:0.4f}'.format(history.history['loss'][-1]),
              'accuracy: {:0.4f}'.format(history.history['sparse_categorical_accuracy'][-1]),
              'val_loss: {:0.4f}'.format(history.history['val_loss'][-1]),
              'val_acc: {:0.4f}'.format(history.history['val_sparse_categorical_accuracy'][-1]),
              'lr: {:0.4g}'.format(lrfn(epoch)),
              'steps/val_steps: {:d}/{:d}'.format(epoch_steps, valid_epoch_steps), flush=True)
        
        # set up next epoch
        epoch = step // STEPS_PER_EPOCH
        epoch_steps = 0
        epoch_start_time = time.time()
        train_accuracy.reset_states()
        valid_accuracy.reset_states()
        valid_loss.reset_states()
        train_loss.reset_states()
        if epoch >= EPOCHS:
            break

optimized_ctl_training_time = time.time() - start_time
print("OPTIMIZED CTL TRAINING TIME: {:0.1f}s".format(optimized_ctl_training_time))

### the custom training loop above records its metrics in `history` (history2 only exists if the model2.fit cell was run)
if not SKIP_VALIDATION:
    display_training_curves(history.history['loss'], history.history['val_loss'], 'loss', 211)
    display_training_curves(history.history['sparse_categorical_accuracy'], history.history['val_sparse_categorical_accuracy'], 'accuracy', 212)
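The custom loop drives the learning rate through `LRSchedule`. The optimizer calls it with its step counter, and `step // STEPS_PER_EPOCH` turns that into an epoch index. The learning rate therefore stays constant within an epoch, presumably like the per-epoch `lr_callback` used with `model.fit`.

The real `lrfn` is defined earlier in the notebook. The warm-up/plateau/decay function below is a hypothetical stand-in with made-up constants.

```python
# hypothetical lrfn: linear warm-up, short plateau, then exponential decay (constants are made up)
LR_START, LR_MAX, LR_MIN = 1e-5, 4e-4, 1e-5
RAMPUP_EPOCHS, SUSTAIN_EPOCHS, DECAY = 3, 1, 0.8

def lrfn(epoch):
    if epoch < RAMPUP_EPOCHS:
        return LR_START + (LR_MAX - LR_START) * epoch / RAMPUP_EPOCHS
    if epoch < RAMPUP_EPOCHS + SUSTAIN_EPOCHS:
        return LR_MAX
    return LR_MIN + (LR_MAX - LR_MIN) * DECAY ** (epoch - RAMPUP_EPOCHS - SUSTAIN_EPOCHS)

def lr_at_step(step, steps_per_epoch):
    # what LRSchedule.__call__ does: integer-divide the optimizer step to get the epoch
    return lrfn(step // steps_per_epoch)

for step in (0, 98, 99, 297, 495):
    print(step, lr_at_step(step, steps_per_epoch=99))
```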

**B6 with 512x512 images runs out of memory (OOM).**

# Finding best alpha
Our final model is a blend of the two models presented above. In the first commit, it was the arithmetic mean (alpha = 0.5). Note that if the validation data is also used for training, the models will fit it with an accuracy of 1.0. The weights of the linear combination below can therefore only be tuned on validation data that was held out from training:

$$p = \alpha \, p_{\text{model}} + (1 - \alpha) \, p_{\text{model2}}$$

In [None]:
IMAGE_SIZE = [512, 512] 
img_size=IMAGE_SIZE[0] 
GCS_PATH = GCS_PATH_SELECT[IMAGE_SIZE[0]]

TRAINING_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/train/*.tfrec') 
VALIDATION_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/val/*.tfrec')
TEST_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/test/*.tfrec')

In [None]:
if not SKIP_VALIDATION:
    cmdataset = get_validation_dataset(ordered=True) # since we are splitting the dataset and iterating separately on images and labels, order matters.
    images_ds = cmdataset.map(lambda image, label: image)
    labels_ds = cmdataset.map(lambda image, label: label).unbatch()
    cm_correct_labels = next(iter(labels_ds.batch(NUM_VALIDATION_IMAGES))).numpy() # get everything as one batch
    m = model.predict(images_ds)
    m2 = model2.predict(images_ds)
    alphas = np.linspace(0, 1, 101)  # alpha = 0.00, 0.01, ..., 1.00
    scores = []
    for alpha in alphas:
        cm_probabilities = alpha*m+(1-alpha)*m2
        cm_predictions = np.argmax(cm_probabilities, axis=-1)
        scores.append(f1_score(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)), average='macro'))

    plt.plot(alphas, scores)
    # look up the best alpha in the grid itself (np.argmax(scores)/100 is off for a 100-point linspace, whose step is 1/99)
    best_alpha = alphas[np.argmax(scores)]
    cm_probabilities = best_alpha*m+(1-best_alpha)*m2
    cm_predictions = np.argmax(cm_probabilities, axis=-1)
    print("Correct   labels: ", cm_correct_labels.shape, cm_correct_labels)
    print("Predicted labels: ", cm_predictions.shape, cm_predictions)
else:
    best_alpha = 0.5 #0.44

In [None]:
print(best_alpha)

# there should be code in the competition discussion or a sample model which shows how to implement this for 3 or more models
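For three or more models, the alpha search generalizes to a search over weights w_1, ..., w_n >= 0 that sum to 1. Below is a brute-force grid sketch in NumPy. `best_blend_weights` and `macro_f1` are hypothetical helpers; `macro_f1` mirrors `f1_score(..., average='macro')` with classes absent from both labels and predictions scored as 0.

```python
import itertools
import numpy as np

def macro_f1(y_true, y_pred, num_classes):
    # macro-averaged F1 in plain NumPy; a class absent from both y_true and y_pred scores 0
    scores = []
    for c in range(num_classes):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom > 0 else 0.0)
    return float(np.mean(scores))

def best_blend_weights(probs, y_true, num_classes, step=0.1):
    """Grid search over non-negative weights summing to 1 for a list of [N, C] probability arrays."""
    ticks = np.round(np.arange(0.0, 1.0 + step / 2, step), 10)
    best_score, best_weights = -1.0, None
    for head in itertools.product(ticks, repeat=len(probs) - 1):
        last = 1.0 - sum(head)
        if last < -1e-9:              # the last weight would have to be negative
            continue
        weights = np.array(list(head) + [max(last, 0.0)])
        blended = sum(w * p for w, p in zip(weights, probs))
        score = macro_f1(y_true, np.argmax(blended, axis=-1), num_classes)
        if score > best_score:
            best_score, best_weights = score, weights
    return best_weights, best_score

# made-up example: one perfect model, one that always says class 0, one uniform model
y_true = np.array([0, 1, 2, 0, 1, 2])
good = np.eye(3)[y_true]
always_zero = np.tile([1.0, 0.0, 0.0], (6, 1))
uniform = np.full((6, 3), 1 / 3)
weights, score = best_blend_weights([good, always_zero, uniform], y_true, num_classes=3)
print(weights, score)
```

With two models and step=0.01, this is the same search as the alpha loop above. The grid has (1/step + 1)^(n-1) points, so for many models a coarser step or randomly sampled weights are more practical.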


# Mismatches on the validation data

In [None]:
#best_alpha=0.60

In [None]:
if not SKIP_VALIDATION:
    
    cmdataset_with_id = get_validation_dataset_with_id(ordered=True)
    ids_ds = cmdataset_with_id.map(lambda image, label, idnum: idnum).unbatch()
    ids = next(iter(ids_ds.batch(NUM_VALIDATION_IMAGES))).numpy().astype('U') # get everything as one batch

    val_batch = iter(cmdataset.unbatch().batch(1))
    noip = sum(cm_predictions!=cm_correct_labels)
    print('Number of incorrect predictions: ' + str(noip) + ' ('+str(round(noip/NUM_VALIDATION_IMAGES*100,1))+'%)')
    for fi in range(NUM_VALIDATION_IMAGES):
        x = next(val_batch)
        if cm_predictions[fi] != cm_correct_labels[fi]:
            print("Image id: '" + ids[fi] + "'")
            display_batch_of_images(x,np.array([cm_predictions[fi]]),figsize = 4)

# Confusion matrix

In [None]:
# better colors for the cmat would be nice 
if not SKIP_VALIDATION:
    cmat = confusion_matrix(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)))
    score = f1_score(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)), average='macro')
    precision = precision_score(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)), average='macro')
    recall = recall_score(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)), average='macro')
    #cmat = (cmat.T / cmat.sum(axis=1)).T # normalized
    display_confusion_matrix(cmat, score, precision, recall)
    print('f1 score: {:.3f}, precision: {:.3f}, recall: {:.3f}'.format(score, precision, recall))
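The commented-out line in the confusion-matrix cell normalizes each row by the number of true samples of that class. Below is a NumPy sketch of that normalization, guarded against classes without validation samples (which would otherwise produce NaN). It also includes a helper that lists the most frequent confusions. Both helpers are hypothetical.

```python
import numpy as np

def normalize_rows(cmat):
    # divide each row (true class) by its total; rows without samples stay zero instead of becoming NaN
    totals = cmat.sum(axis=1, keepdims=True)
    return np.divide(cmat, totals, out=np.zeros_like(cmat, dtype=float), where=totals > 0)

def most_confused(cmat, k=5):
    # largest off-diagonal entries as (true_class, predicted_class, count)
    off = cmat.copy()
    np.fill_diagonal(off, 0)
    flat = np.argsort(off, axis=None)[::-1][:k]
    rows, cols = np.unravel_index(flat, off.shape)
    return [(int(r), int(c), int(off[r, c])) for r, c in zip(rows, cols) if off[r, c] > 0]

example = np.array([[2, 0, 0], [1, 1, 0], [0, 0, 0]])   # made-up 3-class confusion matrix
print(normalize_rows(example))
print(most_confused(example))   # class 1 was predicted as class 0 once
```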

# Predictions

In [None]:
if not DO_TTA:
    test_ds = get_test_dataset(ordered=True) # since we are splitting the dataset and iterating separately on images and ids, order matters.

    print('Computing predictions...')

    test_images_ds = test_ds.map(lambda image, idnum: image)
    probabilities = best_alpha*model.predict(test_images_ds) + (1-best_alpha)*model2.predict(test_images_ds)
    predictions = np.argmax(probabilities, axis=-1)
    print(predictions)

    print('Generating submission.csv file...')
    test_ids_ds = test_ds.map(lambda image, idnum: idnum).unbatch()
    test_ids = next(iter(test_ids_ds.batch(NUM_TEST_IMAGES))).numpy().astype('U') # all in one batch
    np.savetxt('submission.csv', np.rec.fromarrays([test_ids, predictions]), fmt=['%s', '%d'], delimiter=',', header='id,label', comments='')
    
else:
    
    print('Computing predictions with TTA ...')
    preds_tta = []
    
    for i in range(8):
        # did not have time to work out the TTA code - my implemented augmentations are probably too strong
        test_ds = get_test_dataset(ordered=True, do_aug=True) # since we are splitting the dataset and iterating separately on images and ids, order matters.
        test_images_ds = test_ds.map(lambda image, idnum: image)
        probabilities = best_alpha*model.predict(test_images_ds) + (1-best_alpha)*model2.predict(test_images_ds)
        preds_tta.append(probabilities)  # keep the class probabilities, not the argmax
    
    # average the probabilities over all TTA passes, then pick the most likely class
    # (averaging argmax class ids would produce "in-between" class ids that no pass predicted)
    final_pred = np.argmax(np.mean(preds_tta, axis=0), axis=-1)
    
    print(final_pred)

    print('Generating submission.csv file...')
    test_ids_ds = test_ds.map(lambda image, idnum: idnum).unbatch()
    test_ids = next(iter(test_ids_ds.batch(NUM_TEST_IMAGES))).numpy().astype('U') # all in one batch
    np.savetxt('submission.csv', np.rec.fromarrays([test_ids, final_pred]), fmt=['%s', '%d'], delimiter=',', header='id,label', comments='')

print('Done')
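For TTA, the per-pass class probabilities should be averaged before taking the argmax. Averaging the argmax class ids instead produces labels that no pass predicted. Below is a minimal NumPy sketch with two made-up passes for a single image; `tta_average` is a hypothetical helper.

```python
import numpy as np

def tta_average(prob_rounds):
    # prob_rounds: list of [num_images, num_classes] probability arrays, one per TTA pass
    return np.argmax(np.mean(prob_rounds, axis=0), axis=-1)

round1 = np.array([[0.1, 0.0, 0.9]])   # pass 1 predicts class 2
round2 = np.array([[0.6, 0.0, 0.4]])   # pass 2 predicts class 0
print(tta_average([round1, round2]))   # class 2: the mean probabilities are [0.35, 0.0, 0.65]
# averaging the class ids instead gives (2 + 0) / 2 = 1, a class that neither pass predicted
print(np.mean([r.argmax(axis=-1) for r in [round1, round2]], axis=0))
```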