In [None]:
!pip install -q efficientnet

In [None]:
import gc
import math, re, os, time
import tensorflow as tf, tensorflow.keras.backend as K
#tf.config.experimental_run_functions_eagerly(True)
import numpy as np
from collections import namedtuple
from matplotlib import pyplot as plt
from kaggle_datasets import KaggleDatasets
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix
print("Tensorflow version " + tf.__version__)
AUTO = tf.data.experimental.AUTOTUNE

# TPU or GPU detection

In [None]:
# Detect hardware and choose a distribution strategy: TPUStrategy when a TPU
# cluster can be resolved, otherwise TensorFlow's default strategy.
try:
    # With no arguments the resolver uses the TPU_NAME environment variable,
    # which Kaggle sets on TPU kernels.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    # Treated as "no TPU available".
    tpu = None

if not tpu:
    # Default strategy: works on CPU and a single GPU.
    strategy = tf.distribute.get_strategy()
else:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)

print("REPLICAS: ", strategy.num_replicas_in_sync)

# Competition data access
TPUs read data directly from Google Cloud Storage (GCS). This Kaggle utility will copy the dataset to a GCS bucket co-located with the TPU. If you have multiple datasets attached to the notebook, you can pass the name of a specific dataset to the get_gcs_path function. The name of the dataset is the name of the directory it is mounted in. Use `!ls /kaggle/input/` to list attached datasets.

In [None]:
# GCS location of the attached 'flower-classification' dataset (TPUs read their
# input from GCS, see the markdown above).
GCS_DS_PATH = KaggleDatasets().get_gcs_path('flower-classification') # you can list the bucket with "!gsutil ls $GCS_DS_PATH"

In [None]:
#EXT_DS_PATH = KaggleDatasets().get_gcs_path('oxford-102-for-tpu-competition')

In [None]:
# GCS location of the external 'tf-flower-photo-tfrec' dataset. Only read by
# get_ext_filenames() when USE_EXTERNAL is True, but resolved unconditionally here.
# NOTE(review): presumably this requires the dataset to be attached even when
# USE_EXTERNAL is False — confirm.
EXT_DS_PATH = KaggleDatasets().get_gcs_path('tf-flower-photo-tfrec')

# Configuration

In [None]:
USE_EXTERNAL = False    # add the external flower TFRecords to the training files
SKIP_VALIDATION = True  # fold the validation files into training (nothing held out)

# Image side length in px: 512 only on TPU (a GPU runs out of memory at that
# size); 192 on CPU / GPU.
SIZE = 512 if tpu else 192
IMAGE_SIZE = [SIZE, SIZE]

EPOCHS = 18
BATCH_SIZE = 16 * strategy.num_replicas_in_sync  # 16 images per replica
AUG_BATCH = BATCH_SIZE  # batch size used by the batch-level augmentations

In [None]:
# Available pre-resized TFRecord variants of the competition data, keyed by
# image side length in px.
GCS_PATH_SELECT = {
    side: GCS_DS_PATH + '/tfrecords-jpeg-{0}x{0}'.format(side)
    for side in (192, 224, 331, 512)
}


In [None]:
def get_filenames(size):
    """Return the (train, valid, test) lists of .tfrec paths for one image size.

    `size` must be a key of GCS_PATH_SELECT (192, 224, 331 or 512).
    """
    base = GCS_PATH_SELECT[size]
    train, valid, test = (
        tf.io.gfile.glob('{}/{}/*.tfrec'.format(base, split))
        for split in ('train', 'val', 'test')
    )
    return train, valid, test

In [None]:
# Sub-directory suffix of each pre-resized external TFRecord variant, keyed by
# image side length in px. Defined unconditionally: get_ext_filenames() reads
# it, so guarding it with `if USE_EXTERNAL:` made that function fail with a
# NameError whenever the flag was off.
EXT_PATH_SELECT = {
    192: '/tfrecords-jpeg-192x192',
    224: '/tfrecords-jpeg-224x224',
    331: '/tfrecords-jpeg-331x331',
    512: '/tfrecords-jpeg-512x512',
}

In [None]:
def get_ext_filenames(size, imagenet=False, inaturalist=False, openimage=False, oxford=False, tfflowers=False):
    """Collect the external training .tfrec paths for the enabled sources.

    Each flag enables one sub-directory of EXT_DS_PATH. Files are appended in
    the fixed order imagenet, inaturalist, openimage, oxford, tfflowers.
    `size` must be a key of EXT_PATH_SELECT.
    """
    suffix = EXT_PATH_SELECT[size]
    sources = (
        (imagenet, '/imagenet'),
        (inaturalist, '/inaturalist_1'),
        (openimage, '/openimage'),
        (oxford, '/oxford_102'),
        (tfflowers, '/tf_flowers'),
    )
    collected = []
    for enabled, subdir in sources:
        if enabled:
            collected.extend(tf.io.gfile.glob(EXT_DS_PATH + subdir + suffix + '/*.tfrec'))
    return collected
        

In [None]:
# Human-readable class names. The list index is the integer label (it is used
# as CLASSES[label] when titling plots). 104 entries, labels 0-103.
# 'cyclamen ' and 'hippeastrum ' carry a trailing space; left untouched here.
CLASSES = ['pink primrose',    'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',     'wild geranium',     'tiger lily',           'moon orchid',              'bird of paradise', 'monkshood',        'globe thistle',         # 00 - 09
           'snapdragon',       "colt's foot",               'king protea',      'spear thistle', 'yellow iris',       'globe-flower',         'purple coneflower',        'peruvian lily',    'balloon flower',   'giant white arum lily', # 10 - 19
           'fire lily',        'pincushion flower',         'fritillary',       'red ginger',    'grape hyacinth',    'corn poppy',           'prince of wales feathers', 'stemless gentian', 'artichoke',        'sweet william',         # 20 - 29
           'carnation',        'garden phlox',              'love in the mist', 'cosmos',        'alpine sea holly',  'ruby-lipped cattleya', 'cape flower',              'great masterwort', 'siam tulip',       'lenten rose',           # 30 - 39
           'barberton daisy',  'daffodil',                  'sword lily',       'poinsettia',    'bolero deep blue',  'wallflower',           'marigold',                 'buttercup',        'daisy',            'common dandelion',      # 40 - 49
           'petunia',          'wild pansy',                'primula',          'sunflower',     'lilac hibiscus',    'bishop of llandaff',   'gaura',                    'geranium',         'orange dahlia',    'pink-yellow dahlia',    # 50 - 59
           'cautleya spicata', 'japanese anemone',          'black-eyed susan', 'silverbush',    'californian poppy', 'osteospermum',         'spring crocus',            'iris',             'windflower',       'tree poppy',            # 60 - 69
           'gazania',          'azalea',                    'water lily',       'rose',          'thorn apple',       'morning glory',        'passion flower',           'lotus',            'toad lily',        'anthurium',             # 70 - 79
           'frangipani',       'clematis',                  'hibiscus',         'columbine',     'desert-rose',       'tree mallow',          'magnolia',                 'cyclamen ',        'watercress',       'canna lily',            # 80 - 89
           'hippeastrum ',     'bee balm',                  'pink quill',       'foxglove',      'bougainvillea',     'camellia',             'mallow',                   'mexican petunia',  'bromelia',         'blanket flower',        # 90 - 99
           'trumpet creeper',  'blackberry lily',           'common tulip',     'wild rose']                                                                                                                                               # 100 - 102

## Visualization utilities
Data → pixels: plotting helpers only. Nothing of much interest to the machine learning practitioner in this section.

In [None]:
# numpy and matplotlib defaults
np.set_printoptions(threshold=15, linewidth=80)

def batch_to_numpy_images_and_labels(data):
    """Convert a batch to numpy for display.

    Accepts either an ``(images, labels)`` pair or a bare images tensor (the
    ``display_batch_of_images(images)`` call form). A bare tensor must not be
    tuple-unpacked: that would iterate over its batch axis.

    Returns:
        ``(numpy_images, numpy_labels)``, where ``numpy_labels`` is:
        - ``None`` when only images were given;
        - a list of ``None`` (one per image) when the second element holds
          byte-string image IDs (the test-set case) rather than class labels;
        - otherwise the labels as a numpy array.
    """
    if isinstance(data, (tuple, list)):
        images, labels = data
    else:
        images, labels = data, None
    numpy_images = images.numpy()
    if labels is None:
        return numpy_images, None
    numpy_labels = labels.numpy()
    if numpy_labels.dtype == object:  # binary strings: image IDs, not labels
        numpy_labels = [None] * len(numpy_images)
    return numpy_images, numpy_labels

def title_from_label_and_target(label, correct_label):
    """Build a subplot title for a predicted class index.

    Returns ``(title, correct)``. Without a ground truth (`correct_label` is
    None) the title is the predicted class name and `correct` is True.
    Otherwise the title reads "<pred> [OK]" or "<pred> [NO→<truth>]".
    """
    if correct_label is None:
        return CLASSES[label], True
    correct = (label == correct_label)
    if correct:
        verdict, arrow, truth = 'OK', '', ''
    else:
        verdict, arrow, truth = 'NO', u"\u2192", CLASSES[correct_label]
    return "{} [{}{}{}]".format(CLASSES[label], verdict, arrow, truth), correct

def display_one_flower(image, title, subplot, red=False, titlesize=16):
    """Draw `image` in the (rows, cols, index) slot `subplot` and return the next slot.

    A non-empty `title` is drawn in black, or in red and slightly smaller when
    `red` is True (display_batch_of_images passes `not correct`).
    """
    plt.subplot(*subplot)
    plt.axis('off')
    plt.imshow(image)
    if len(title) > 0:
        if red:
            fontsize, color = int(titlesize/1.2), 'red'
        else:
            fontsize, color = int(titlesize), 'black'
        plt.title(title, fontsize=fontsize, color=color,
                  fontdict={'verticalalignment':'center'}, pad=int(titlesize/1.5))
    next_index = subplot[2] + 1
    return (subplot[0], subplot[1], next_index)
    
def display_batch_of_images(databatch, predictions=None):
    """Show a batch of images in a near-square grid with optional titles.

    This will work with:
    display_batch_of_images(images)
    display_batch_of_images(images, predictions)
    display_batch_of_images((images, labels))
    display_batch_of_images((images, labels), predictions)

    Labels may be scalar class ids (Python or numpy integers) or label vectors
    (one-hot, or mixed as produced by cutmix/mixup). Vectors are reduced to
    their argmax for titling and for comparison with `predictions`. Images that
    do not fit the rows x cols grid are dropped.
    """
    # data
    images, labels = batch_to_numpy_images_and_labels(databatch)
    if labels is None:
        labels = [None for _ in enumerate(images)]
    if len(images) == 0:
        return  # nothing to draw (and rows would be 0 below)

    # auto-squaring: this will drop data that does not fit into square or square-ish rectangle
    rows = int(math.sqrt(len(images)))
    cols = len(images)//rows

    # size and spacing
    FIGSIZE = 13.0
    SPACING = 0.1
    subplot=(rows,cols,1)
    if rows < cols:
        plt.figure(figsize=(FIGSIZE,FIGSIZE/cols*rows))
    else:
        plt.figure(figsize=(FIGSIZE/rows*cols,FIGSIZE))
    dynamic_titlesize = FIGSIZE*SPACING/max(rows,cols)*40+3 # magic formula tested to work from 1x1 to 10x10 images

    # display
    any_label = False
    for i, (image, label) in enumerate(zip(images[:rows*cols], labels[:rows*cols])):
        if label is None:
            label_idx = None
            title = ''
        elif np.ndim(label) == 0:
            # Scalar class id. numpy integer scalars are not `int` instances, so
            # an isinstance(label, int) check would wrongly send them to argmax.
            label_idx = int(label)
            title = CLASSES[label_idx]
        else:
            # One-hot or mixed (cutmix/mixup) label vector.
            label_idx = int(np.argmax(label, axis=0))
            title = CLASSES[label_idx]
        any_label = any_label or label_idx is not None
        correct = True
        if predictions is not None:
            # Compare against the class index, never the raw (possibly one-hot) vector.
            title, correct = title_from_label_and_target(predictions[i], label_idx)
        subplot = display_one_flower(image, title, subplot, not correct, titlesize=dynamic_titlesize)

    #layout
    plt.tight_layout()
    if not any_label and predictions is None:
        plt.subplots_adjust(wspace=0, hspace=0)
    else:
        plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()

def display_confusion_matrix(cmat, score, precision, recall):
    """Plot confusion matrix `cmat` with class-name ticks and optional metrics.

    Any of `score` (F1), `precision` or `recall` may be None to omit it from the
    text block drawn in the upper-right area of the matrix.
    """
    plt.figure(figsize=(15,15))
    ax = plt.gca()
    ax.matshow(cmat, cmap='Reds')
    tick_positions = range(len(CLASSES))

    ax.set_xticks(tick_positions)
    ax.set_xticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(ax.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")

    ax.set_yticks(tick_positions)
    ax.set_yticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(ax.get_yticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    metric_lines = (
        ('f1 = {:.3f} ', score),
        ('\nprecision = {:.3f} ', precision),
        ('\nrecall = {:.3f} ', recall),
    )
    titlestring = "".join(template.format(value) for template, value in metric_lines if value is not None)
    if titlestring:
        ax.text(101, 1, titlestring, fontdict={'fontsize': 18, 'horizontalalignment':'right', 'verticalalignment':'top', 'color':'#804040'})
    plt.show()
    
def display_training_curves(training, validation, title, subplot):
    """Plot a train/validation curve pair into the 3-digit matplotlib slot `subplot`.

    A new figure is created when `subplot` is the first slot of its grid
    (last digit 1, e.g. 211).
    """
    is_first_panel = subplot % 10 == 1
    if is_first_panel:
        plt.subplots(figsize=(10,10), facecolor='#F0F0F0')
        plt.tight_layout()
    ax = plt.subplot(subplot)
    ax.set_facecolor('#F8F8F8')
    for series in (training, validation):
        ax.plot(series)
    ax.set_title('model ' + title)
    ax.set_ylabel(title)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'valid.'])

# Datasets

In [None]:
def get_batch_transformatioin_matrix(rotation, shear, height_zoom, width_zoom, height_shift, width_shift):
    """Build one 3x3 index-transformation matrix per batch element.

    For element i the result is R_i @ S_i @ Z_i @ T_i (rotation, shear, zoom,
    shift). batch_transform applies it to destination pixel indices to find
    their source indices.

    Args:
        rotation: 1-D Tensor [batch_size], angle in degrees (float32 expected).
        shear: 1-D Tensor [batch_size], angle in degrees.
        height_zoom: 1-D Tensor [batch_size]; Z scales the first index coordinate by 1/height_zoom.
        width_zoom: 1-D Tensor [batch_size]; Z scales the second index coordinate by 1/width_zoom.
        height_shift: 1-D Tensor [batch_size]; T adds it to the first index coordinate.
        width_shift: 1-D Tensor [batch_size]; T adds it to the second index coordinate.

    Returns:
        A 3-D Tensor with shape [batch_size, 3, 3].
    """
    # Batch size as an int64 tensor: the number of entries in `rotation`.
    batch_size = tf.cast(tf.reduce_sum(tf.ones_like(rotation)), tf.int64)

    def to_matrices(entries):
        # `entries`: 9 row-major components, each [batch_size].
        # stack -> [9, batch_size], transpose -> [batch_size, 9], reshape -> [batch_size, 3, 3]
        flat = tf.transpose(tf.stack(entries, axis=0))
        return tf.reshape(flat, shape=(batch_size, 3, 3))

    # Degrees -> radians.
    pi = tf.constant(math.pi)
    rotation = pi * rotation / 180.0
    shear = pi * shear / 180.0

    one = tf.ones_like(rotation, dtype=tf.float32)
    zero = tf.zeros_like(rotation, dtype=tf.float32)

    cos_r, sin_r = tf.math.cos(rotation), tf.math.sin(rotation)
    rotation_matrix = to_matrices([cos_r, sin_r, zero, -sin_r, cos_r, zero, zero, zero, one])

    cos_s, sin_s = tf.math.cos(shear), tf.math.sin(shear)
    shear_matrix = to_matrices([one, sin_s, zero, zero, cos_s, zero, zero, zero, one])

    zoom_matrix = to_matrices([one / height_zoom, zero, zero, zero, one / width_zoom, zero, zero, zero, one])

    shift_matrix = to_matrices([one, zero, height_shift, zero, one, width_shift, zero, zero, one])

    rotate_shear = tf.linalg.matmul(rotation_matrix, shear_matrix)
    zoom_shift = tf.linalg.matmul(zoom_matrix, shift_matrix)
    return tf.linalg.matmul(rotate_shear, zoom_shift)

In [None]:
def batch_transform(images, labels):
    """Randomly rotate, shear, zoom and shift every image of a batch.

    Each image gets its own random parameters:
    - rotation ~ 15*N(0,1) degrees;
    - shear ~ 5*N(0,1) degrees;
    - zoom ~ 1 + N(0,1)/10 per axis;
    - shift ~ 16*N(0,1) px per axis.

    Every output pixel's index is mapped through the transformation matrix,
    truncated to int32, and clipped to the image. The source pixel at that index
    is then gathered, so there is no interpolation, and out-of-range
    coordinates repeat the border pixels.

    Args:
        images: 4-D Tensor with shape (batch_size, DIM, DIM, 3). Only
            images.shape[1] is used as the side length, so square images are
            assumed, and the final reshape fixes depth to 3.
        labels: passed through unchanged.

    Returns:
        (new_images, labels), new_images with shape (batch_size, DIM, DIM, 3).
    """ 
    
    # input `images`: a batch of images [batch_size, dim, dim, 3]
    # output: images randomly rotated, sheared, zoomed, and shifted
    DIM = images.shape[1]
    XDIM = DIM % 2  # fix for size 331: 1 for odd DIM, raises the lower clip bound so gather indices stay in [0, DIM-1]
    
    # A trick to get batch_size: total element count / per-image element count,
    # which yields a tensor even when the static batch dimension is unknown.
    batch_size = tf.cast(tf.reduce_sum(tf.ones_like(images)) / (images.shape[1] * images.shape[2] * images.shape[3]), tf.int64)
    
    rot = 15.0 * tf.random.normal([batch_size], dtype='float32')
    shr = 5.0 * tf.random.normal([batch_size], dtype='float32') 
    h_zoom = 1.0 + tf.random.normal([batch_size], dtype='float32') / 10.0
    w_zoom = 1.0 + tf.random.normal([batch_size], dtype='float32') / 10.0
    h_shift = 16.0 * tf.random.normal([batch_size], dtype='float32') 
    w_shift = 16.0 * tf.random.normal([batch_size], dtype='float32') 
  
    # GET TRANSFORMATION MATRIX
    # shape = (batch_size, 3, 3)
    m = get_batch_transformatioin_matrix(rot, shr, h_zoom, w_zoom, h_shift, w_shift) 

    # LIST DESTINATION PIXEL INDICES
    # Centered coordinates: x runs from DIM//2 downwards (row axis), y from
    # -DIM//2 upwards (column axis); z is the homogeneous coordinate 1.
    x = tf.repeat(tf.range(DIM // 2, -DIM // 2, -1), DIM)  # shape = (DIM * DIM,)
    y = tf.tile(tf.range(-DIM // 2, DIM // 2), [DIM])  # shape = (DIM * DIM,)
    z = tf.ones([DIM * DIM], dtype='int32')  # shape = (DIM * DIM,)
    idx = tf.stack([x, y, z])  # shape = (3, DIM * DIM)
    
    # ROTATE DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = tf.linalg.matmul(m, tf.cast(idx, dtype='float32'))  # shape = (batch_size, 3, DIM ** 2)
    idx2 = K.cast(idx2, dtype='int32')  # shape = (batch_size, 3, DIM ** 2)
    # Clamp to the image so every gathered index is valid (border pixels repeat).
    idx2 = K.clip(idx2, -DIM // 2 + XDIM + 1, DIM // 2)  # shape = (batch_size, 3, DIM ** 2)
    
    # FIND ORIGIN PIXEL VALUES
    # Convert centered coordinates back to (row, col) array indices.
    # shape = (batch_size, 2, DIM ** 2)
    idx3 = tf.stack([DIM // 2 - idx2[:, 0, ], DIM // 2 - 1 + idx2[:, 1, ]], axis=1)  
    
    # shape = (batch_size, DIM ** 2, 3)
    d = tf.gather_nd(images, tf.transpose(idx3, perm=[0, 2, 1]), batch_dims=1)
        
    # shape = (batch_size, DIM, DIM, 3)
    new_images = tf.reshape(d, (batch_size, DIM, DIM, 3))

    return new_images, labels

In [None]:
def onehot(image,label):
    """Return (image, one-hot label) with depth 104, the number of classes."""
    num_classes = 104
    encoded = tf.one_hot(label, num_classes)
    return image, encoded

In [None]:
def batch_cutmix(images, labels, PROBABILITY=1.0, batch_size=0):
    """Apply CutMix to a batch: paste a random square patch from another image.

    For each image, with probability PROBABILITY, a square of side
    DIM*sqrt(1-b) (b ~ U[0, 1)) centred at a random point is replaced by the
    same region of a randomly chosen image of the batch (possibly itself).
    Labels are mixed in proportion to the fraction of pixels actually replaced.

    Args:
        images: float Tensor [batch_size, DIM, DIM, 3] with DIM = IMAGE_SIZE[0].
        labels: int class ids [batch_size] or one-hot/soft labels [batch_size, 104].
        PROBABILITY: chance that a given image receives a patch.
        batch_size: static batch size; 0 means AUG_BATCH.

    Returns:
        (new_images, new_labels), new_labels with shape [batch_size, 104].
    """
    DIM = IMAGE_SIZE[0]
    CLASSES = 104

    if batch_size == 0:
        batch_size = AUG_BATCH

    # 1 where CutMix is applied to that image, 0 otherwise. shape = [batch_size]
    do_cutmix = tf.cast(tf.random.uniform([batch_size], 0, 1) <= PROBABILITY, tf.int32)

    # Index of the image each patch is taken from. shape = [batch_size]
    new_image_indices = tf.cast(tf.random.uniform([batch_size], 0, batch_size), tf.int32)

    # Patch centre. shape = [batch_size]
    new_x = tf.cast(tf.random.uniform([batch_size], 0, DIM), tf.int32)
    new_y = tf.cast(tf.random.uniform([batch_size], 0, DIM), tf.int32)

    # Patch side, forced to 0 when CutMix is not applied. shape = [batch_size]
    b = tf.random.uniform([batch_size], 0, 1) # this is beta dist with alpha=1.0
    new_width = tf.cast(DIM * tf.math.sqrt(1-b), tf.int32) * do_cutmix

    # Half-open patch bounds [y0, y1) x [x0, x1), clipped to the image. shape = [batch_size]
    new_y0 = tf.math.maximum(0, new_y - new_width // 2)
    new_y1 = tf.math.minimum(DIM, new_y + new_width // 2)
    new_x0 = tf.math.maximum(0, new_x - new_width // 2)
    new_x1 = tf.math.minimum(DIM, new_x + new_width // 2)

    # shape = [batch_size, DIM]
    target = tf.broadcast_to(tf.range(DIM), shape=(batch_size, DIM))

    # Strict upper bound: with new_width == 0 the patch is empty. An inclusive
    # bound would still overwrite one pixel of every image that skips CutMix.
    mask_y = tf.math.logical_and(new_y0[:, tf.newaxis] <= target, target < new_y1[:, tf.newaxis])
    mask_x = tf.math.logical_and(new_x0[:, tf.newaxis] <= target, target < new_x1[:, tf.newaxis])

    # 1.0 inside the patch. shape = [batch_size, DIM, DIM, 1]
    mask = tf.cast(tf.math.logical_and(mask_y[:, :, tf.newaxis], mask_x[:, tf.newaxis, :]), tf.float32)[:, :, :, tf.newaxis]

    # shape = [batch_size, DIM, DIM, 3]
    new_images = images * (1 - mask) + tf.gather(images, new_image_indices) * mask

    # Label weight = fraction of pixels actually replaced. new_width**2 / DIM**2
    # overstates it whenever the patch is clipped at the image border.
    a = tf.reduce_mean(mask, axis=[1, 2, 3])

    # Make labels
    if len(labels.shape) == 1:
        labels = tf.one_hot(labels, CLASSES)

    new_labels = (1-a)[:, tf.newaxis] * labels + a[:, tf.newaxis] * tf.gather(labels, new_image_indices)

    return new_images, new_labels

In [None]:
def batch_mixup(images, labels, PROBABILITY=1.0, batch_size=0):
    """Apply MixUp to a batch: blend each image with a random partner from the batch.

    With probability PROBABILITY an image gets partner weight a ~ U[0, 1)
    (Beta(1, 1)); otherwise a = 0 and the image is unchanged. Images and labels
    are blended as (1 - a) * own + a * partner.

    Args:
        images: float Tensor [batch_size, DIM, DIM, 3].
        labels: int class ids [batch_size] or one-hot/soft labels [batch_size, 104].
        PROBABILITY: chance that a given image is mixed.
        batch_size: static batch size; 0 means AUG_BATCH.

    Returns:
        (new_images, new_labels), new_labels with shape [batch_size, 104].
    """
    num_classes = 104
    if batch_size == 0:
        batch_size = AUG_BATCH

    # 1 where mixup fires for that image. shape = [batch_size]
    fires = tf.cast(tf.random.uniform([batch_size], 0, 1) <= PROBABILITY, tf.int32)
    # Partner image of each element (may be itself). shape = [batch_size]
    partner = tf.cast(tf.random.uniform([batch_size], 0, batch_size), tf.int32)
    # Partner weight, zeroed where mixup does not fire. shape = [batch_size]
    a = tf.random.uniform([batch_size], 0, 1) * tf.cast(fires, tf.float32)

    a_img = a[:, tf.newaxis, tf.newaxis, tf.newaxis]
    new_images = (1 - a_img) * images + a_img * tf.gather(images, partner)

    if len(labels.shape) == 1:
        labels = tf.one_hot(labels, num_classes)
    a_lab = a[:, tf.newaxis]
    new_labels = (1 - a_lab) * labels + a_lab * tf.gather(labels, partner)

    return new_images, new_labels

In [None]:
def transform_cut_mix(image,label):
    """Per element of an AUG_BATCH batch, use either CutMix or MixUp.

    batch_cutmix (probability CUTMIX_PROB) and batch_mixup (probability
    MIXUP_PROB) are both computed for the whole batch. Each element then keeps
    the CutMix result with probability SWITCH and the MixUp result otherwise.

    Returns:
        (images [AUG_BATCH, DIM, DIM, 3], labels [AUG_BATCH, 104]).
    """
    DIM = IMAGE_SIZE[0]
    CLASSES = 104
    SWITCH = 0.5
    CUTMIX_PROB = 0.666
    MIXUP_PROB = 0.666
    image2, label2 = batch_cutmix(image, label, CUTMIX_PROB)
    image3, label3 = batch_mixup(image, label, MIXUP_PROB)
    # One Bernoulli(SWITCH) draw per element, vectorized. A Python loop over
    # elements would unroll AUG_BATCH slice/stack ops into the traced graph.
    P = tf.cast(tf.random.uniform([AUG_BATCH], 0, 1) <= SWITCH, tf.float32)
    P_img = P[:, tf.newaxis, tf.newaxis, tf.newaxis]
    P_lab = P[:, tf.newaxis]
    image4 = P_img * image2 + (1 - P_img) * image3
    label4 = P_lab * label2 + (1 - P_lab) * label3
    # RESHAPE SO THE TPU COMPILER KNOWS THE STATIC SHAPE OF THE OUTPUT TENSORS
    image4 = tf.reshape(image4, (AUG_BATCH, DIM, DIM, 3))
    label4 = tf.reshape(label4, (AUG_BATCH, CLASSES))
    return image4, label4

In [None]:
def decode_image(image_data):
    """Decode JPEG bytes into a float32 [*IMAGE_SIZE, 3] tensor with values in [0, 1]."""
    pixels = tf.image.decode_jpeg(image_data, channels=3)
    scaled = tf.cast(pixels, tf.float32) / 255.0
    # A fully static shape is required on TPU.
    return tf.reshape(scaled, [*IMAGE_SIZE, 3])

def read_labeled_tfrecord(example):
    """Parse one serialized labeled example into (image, int32 class id)."""
    features = tf.io.parse_single_example(example, {
        "image": tf.io.FixedLenFeature([], tf.string),  # JPEG byte string
        "class": tf.io.FixedLenFeature([], tf.int64),   # scalar class id
    })
    image = decode_image(features['image'])
    return image, tf.cast(features['class'], tf.int32)

def read_unlabeled_tfrecord(example):
    """Parse one serialized test example into (image, id byte string).

    Test records carry no "class" feature; predicting it is the task.
    """
    features = tf.io.parse_single_example(example, {
        "image": tf.io.FixedLenFeature([], tf.string),  # JPEG byte string
        "id": tf.io.FixedLenFeature([], tf.string),     # scalar image id
    })
    image = decode_image(features['image'])
    return image, features['id']

def load_dataset(filenames, labeled=True, ordered=False):
    """Read TFRecord files into a dataset of (image, label) or (image, id) pairs.

    Files are read in parallel. Unless `ordered` is True, deterministic ordering
    is disabled so elements are used as soon as they stream in (faster).
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False
    parser = read_labeled_tfrecord if labeled else read_unlabeled_tfrecord
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    return records.with_options(options).map(parser, num_parallel_calls=AUTO)

def data_augment(image, label):
    """Per-image augmentation: a random horizontal flip. The label passes through.

    Runs inside the tf.data pipeline, i.e. on the host CPU while the TPU
    computes gradients.
    """
    flipped = tf.image.random_flip_left_right(image)
    return flipped, label

def get_training_dataset(dataset, simple_aug=False, advance_aug=False, cut_mix_aug=0):
    """Build the infinite, shuffled, batched training pipeline from `dataset`.

    Args:
        dataset: (image, label) dataset, e.g. from load_dataset(..., labeled=True).
        simple_aug: per-image data_augment (random flip).
        advance_aug: batch_transform (rotate/shear/zoom/shift) on AUG_BATCH batches.
        cut_mix_aug: 0 none, 1 batch_cutmix, 2 batch_mixup,
            3 transform_cut_mix (CutMix or MixUp chosen per element).
    """
    batch_mixers = {1: batch_cutmix, 2: batch_mixup, 3: transform_cut_mix}

    if simple_aug:
        dataset = dataset.map(data_augment, num_parallel_calls=AUTO)
    # Repeat for as many epochs as needed, then form AUG_BATCH batches for the
    # batch-level augmentations.
    dataset = dataset.repeat().batch(AUG_BATCH)
    if advance_aug:
        dataset = dataset.map(batch_transform, num_parallel_calls=AUTO)
    mixer = batch_mixers.get(cut_mix_aug)
    if mixer is not None:
        dataset = dataset.map(mixer, num_parallel_calls=AUTO)
    # Back to single examples so they are shuffled across augmentation batches.
    dataset = dataset.unbatch().shuffle(2048).batch(BATCH_SIZE)
    return dataset.prefetch(AUTO)

def get_validation_dataset(dataset, ordered=False, repeated=False):
    """Batch and prefetch a labeled validation dataset.

    With repeated=True the data repeats indefinitely, is shuffled, and is batched
    with drop_remainder so every batch has exactly BATCH_SIZE elements.
    `ordered` is accepted but unused here (load_dataset decides ordering).
    """
    if repeated:
        dataset = dataset.repeat().shuffle(2048)
    batched = dataset.batch(BATCH_SIZE, drop_remainder=repeated)
    return batched.prefetch(AUTO)

def get_test_dataset(ordered=False):
    """Load the unlabeled test set from TEST_FILENAMES as batched (image, id) pairs.

    NOTE(review): this function always reads TEST_FILENAMES itself. A Dataset
    passed positionally (e.g. get_test_dataset(test_ds)) lands in `ordered` and
    is otherwise ignored; a Dataset is presumably truthy, which selects ordered
    loading — pass `ordered=...` explicitly instead.
    """
    ds = load_dataset(TEST_FILENAMES, labeled=False, ordered=ordered)
    return ds.batch(BATCH_SIZE).prefetch(AUTO)

def count_data_items(filenames):
    """Total number of examples in TFRecord files, read from their file names.

    The count is the first '-<digits>.' group in each file's base name, e.g.
    flowers00-230.tfrec holds 230 items and 00-512x512-798.tfrec holds 798.

    Returns:
        int; 0 for an empty list (np.sum([]) would have given the float 0.0).

    Raises:
        ValueError: if a file name contains no '-<digits>.' count.
    """
    pattern = re.compile(r"-([0-9]+)\.")
    total = 0
    for filename in filenames:
        # Search only the base name so directory components cannot match.
        basename = filename.rsplit('/', 1)[-1]
        match = pattern.search(basename)
        if match is None:
            raise ValueError('cannot read item count from TFRecord file name: {}'.format(filename))
        total += int(match.group(1))
    return total


def int_div_round_up(a, b):
    """Ceiling of a / b for a non-negative `a` and positive `b` (computed as (a + b - 1) // b)."""
    quotient, _ = divmod(a + b - 1, b)
    return quotient


In [None]:
# NOTE(review): this overrides the hardware-dependent SIZE from the
# configuration cell (192 on CPU/GPU) with 512 unconditionally, which the
# configuration cell describes as too large for GPU memory.
SIZE = 512
IMAGE_SIZE = [SIZE, SIZE]

TRAINING_FILENAMES, VALIDATION_FILENAMES, TEST_FILENAMES = get_filenames(SIZE)
print(len(TRAINING_FILENAMES))

if USE_EXTERNAL:
    EXT_TRAINING_FILENAMES = get_ext_filenames(SIZE, imagenet=True, inaturalist=True, openimage=False, oxford=False, tfflowers=False)
    print(len(EXT_TRAINING_FILENAMES))
    TRAINING_FILENAMES = TRAINING_FILENAMES + EXT_TRAINING_FILENAMES

if SKIP_VALIDATION:
    # Validation files become extra training data; nothing is held out.
    TRAINING_FILENAMES = TRAINING_FILENAMES + VALIDATION_FILENAMES

NUM_TRAINING_IMAGES = count_data_items(TRAINING_FILENAMES)
NUM_VALIDATION_IMAGES = 0 if SKIP_VALIDATION else count_data_items(VALIDATION_FILENAMES)
NUM_TEST_IMAGES = count_data_items(TEST_FILENAMES)
STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // BATCH_SIZE  # a final partial batch is not counted
VALIDATION_STEPS = int_div_round_up(NUM_VALIDATION_IMAGES, BATCH_SIZE)
print('Dataset: {} training images, {} validation images, {} unlabeled test images'.format(NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES))    

# Dataset visualizations

In [None]:
# Raw, unbatched datasets; the get_*_dataset helpers add batching and augmentation.
train_ds = load_dataset(TRAINING_FILENAMES, labeled=True)
if SKIP_VALIDATION:
    valid_ds = None  # the validation files were added to TRAINING_FILENAMES
else:
    valid_ds = load_dataset(VALIDATION_FILENAMES, labeled=True)
test_ds = load_dataset(TEST_FILENAMES, labeled=False)

In [None]:
# data dump: print the shapes of the first 3 batches of each split.
print("Training data shapes:")
for image, label in get_training_dataset(train_ds).take(3):
    print(image.numpy().shape, label.numpy().shape)
print("Training data label examples:", label.numpy())
if not SKIP_VALIDATION:
    print("Validation data shapes:")
    for image, label in get_validation_dataset(valid_ds).take(3):
        print(image.numpy().shape, label.numpy().shape)
    print("Validation data label examples:", label.numpy())
print("Test data shapes:")
# get_test_dataset() reads TEST_FILENAMES itself; its only parameter is the
# `ordered` flag, so pass that explicitly rather than a dataset.
for image, idnum in get_test_dataset(ordered=True).take(3):
    print(image.numpy().shape, idnum.numpy().shape)
print("Test data IDs:", idnum.numpy().astype('U')) # U=unicode string

In [None]:
if False:  # set to True to preview cutmix-augmented training images
    # Peek at training data, re-batched to 20 images per display batch.
    training_dataset = (get_training_dataset(train_ds, simple_aug=True, advance_aug=False, cut_mix_aug=1)
                        .unbatch()
                        .batch(20))
    train_batch = iter(training_dataset)

In [None]:
if False:  # needs `train_batch` from the training-peek cell; re-run for the next 20 images
    display_batch_of_images(next(train_batch))

In [None]:
if 0:
    # peek at test data. get_test_dataset() reads TEST_FILENAMES itself and its
    # only parameter is the `ordered` flag, so pass that explicitly.
    test_dataset = get_test_dataset(ordered=True)
    test_dataset = test_dataset.unbatch().batch(20)
    test_batch = iter(test_dataset)

In [None]:
if False:  # needs `test_batch` from the test-peek cell; re-run for the next 20 images
    display_batch_of_images(next(test_batch))

In [None]:
if False:  # preview batch_cutmix; needs `training_dataset` from the training-peek cell
    sample_images, sample_labels = next(iter(training_dataset.take(1)))
    cutmixed = batch_cutmix(sample_images, sample_labels, PROBABILITY=1.0, batch_size=20)
    display_batch_of_images(cutmixed)

In [None]:
if False:  # preview batch_mixup; needs `training_dataset` from the training-peek cell
    sample_images, sample_labels = next(iter(training_dataset.take(1)))
    mixed = batch_mixup(sample_images, sample_labels, PROBABILITY=1.0, batch_size=20)
    display_batch_of_images(mixed)

# Model

In [None]:
# Learning-rate schedule by epoch:
# - linear warm-up from LR_START to LR_MAX over LR_RAMPUP_EPOCHS;
# - hold LR_MAX for LR_SUSTAIN_EPOCHS;
# - exponential decay by LR_EXP_DECAY per epoch towards LR_MIN.
LR_START = 0.00003
LR_MAX = 0.00005 * strategy.num_replicas_in_sync  # peak LR scales with the replica count
LR_MIN = 0.00001
LR_RAMPUP_EPOCHS = 5
LR_SUSTAIN_EPOCHS = 3
LR_EXP_DECAY = .8
        
@tf.function
def lrfn(epoch):
    """Learning rate for `epoch` (warm-up, sustain, then exponential decay).

    Wrapped in tf.function so the `if` chain below can be traced when `epoch`
    is a tensor; the LRSchedule in the model cell calls it with
    epoch=step//STEPS_PER_EPOCH.
    """
    if epoch < LR_RAMPUP_EPOCHS:
        lr = (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS * epoch + LR_START
    elif epoch < LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS:
        lr = LR_MAX
    else:
        # Equals LR_MAX when the sustain phase ends and approaches LR_MIN afterwards.
        lr = (LR_MAX - LR_MIN) * LR_EXP_DECAY**(epoch - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS) + LR_MIN
    return lr

rng = [i for i in range(EPOCHS)]
y = [lrfn(x) for x in rng]
plt.plot(rng, y)
print("Learning rate schedule: {:.3g} to {:.3g} to {:.3g}".format(y[0], max(y), y[-1]))

**DenseNet201 (model 1)**

In [None]:
with strategy.scope():
    # Model 1: ImageNet-pretrained DenseNet201 backbone, fully fine-tuned, with a softmax head.
    pretrained_model = tf.keras.applications.DenseNet201(weights='imagenet', include_top=False, input_shape=[*IMAGE_SIZE, 3])
    pretrained_model.trainable = True  # False = transfer learning, True = fine-tuning

    model1 = tf.keras.Sequential([
        pretrained_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(len(CLASSES), activation='softmax')
    ])

    class LRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
        """Per-step Keras schedule that delegates to the per-epoch `lrfn`."""

        def __call__(self, step):
            # `step` is the optimizer's iteration count; integer-divide it to get the epoch.
            return lrfn(epoch=step//STEPS_PER_EPOCH)

        def get_config(self):
            # The schedule takes no constructor arguments (it reads module-level constants).
            # Keras needs get_config to serialize the optimizer (e.g. model.save); the base
            # class does not implement it.
            return {}

    optimizer1 = tf.keras.optimizers.Adam(learning_rate=LRSchedule())

    model1.compile(
        optimizer=optimizer1,
        loss='sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy']
    )
    model1.summary()

    # Metrics for the custom TPU training loop.
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    valid_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    # Sum metrics: they accumulate every loss value passed to update_state.
    train_loss = tf.keras.metrics.Sum()
    valid_loss = tf.keras.metrics.Sum()

    # Per-replica loss scaled by the global batch size, so that summing it across replicas
    # gives the mean loss over the global batch.
    loss_fn = lambda a,b: tf.nn.compute_average_loss(tf.keras.losses.sparse_categorical_crossentropy(a,b), global_batch_size=BATCH_SIZE)

In [None]:
if tpu:
    STEPS_PER_TPU_CALL = 99
    VALIDATION_STEPS_PER_TPU_CALL = 29

    @tf.function
    def train_step(model, optimizer, data_iter):
        """Run STEPS_PER_TPU_CALL distributed training steps inside a single tf.function call."""
        def replica_train(batch_images, batch_labels):
            with tf.GradientTape() as tape:
                preds = model(batch_images, training=True)
                step_loss = loss_fn(batch_labels, preds)
            weights = model.trainable_variables
            optimizer.apply_gradients(zip(tape.gradient(step_loss, weights), weights))

            train_accuracy.update_state(batch_labels, preds)
            train_loss.update_state(step_loss)

        for _ in tf.range(STEPS_PER_TPU_CALL):
            strategy.experimental_run_v2(replica_train, next(data_iter))

    @tf.function
    def valid_step(model, data_iter):
        """Run VALIDATION_STEPS_PER_TPU_CALL distributed evaluation steps in one call."""
        def replica_eval(batch_images, batch_labels):
            preds = model(batch_images, training=False)
            step_loss = loss_fn(batch_labels, preds)

            valid_accuracy.update_state(batch_labels, preds)
            valid_loss.update_state(step_loss)

        for _ in tf.range(VALIDATION_STEPS_PER_TPU_CALL):
            strategy.experimental_run_v2(replica_eval, next(data_iter))


# Training

**Model1 (DenseNet)**

In [None]:
EPOCHS = 24 #override global setting
start_time = epoch_start_time = time.time()

if tpu:
    # Custom training loop. Each train_step call runs STEPS_PER_TPU_CALL steps. An epoch
    # report is emitted whenever the cumulative step count crosses an epoch boundary.
    train_dist_ds = strategy.experimental_distribute_dataset(get_training_dataset(train_ds, simple_aug=True, advance_aug=True, cut_mix_aug=0))
    valid_dist_ds = strategy.experimental_distribute_dataset(get_validation_dataset(valid_ds, repeated=True)) if not SKIP_VALIDATION else None

    print("Training steps per epoch:", STEPS_PER_EPOCH, "in increment of:", STEPS_PER_TPU_CALL)
    if not SKIP_VALIDATION:
        print("Validation images:", NUM_VALIDATION_IMAGES,
              "Batch size:", BATCH_SIZE,
              "Validation steps:", NUM_VALIDATION_IMAGES//BATCH_SIZE, "in increments of", VALIDATION_STEPS_PER_TPU_CALL)
        print("Repeated validation images:", int_div_round_up(NUM_VALIDATION_IMAGES, BATCH_SIZE*VALIDATION_STEPS_PER_TPU_CALL)*VALIDATION_STEPS_PER_TPU_CALL*BATCH_SIZE-NUM_VALIDATION_IMAGES)

    History = namedtuple('History', 'history')
    history = History(history={'loss': [], 'val_loss': [], 'sparse_categorical_accuracy': [], 'val_sparse_categorical_accuracy': []}) if not SKIP_VALIDATION else History(history={'loss': [], 'val_loss': [], 'sparse_categorical_accuracy': []})

    epoch = 0
    train_data_iter = iter(train_dist_ds)
    valid_data_iter = iter(valid_dist_ds) if not SKIP_VALIDATION else None

    step = 0
    epoch_steps = 0  # training steps accumulated in the train metrics since the last report
    while True:
        train_step(model1, optimizer1, train_data_iter)
        epoch_steps += STEPS_PER_TPU_CALL
        step += STEPS_PER_TPU_CALL
        print('=', end='', flush=True)

        if (step//STEPS_PER_EPOCH) > epoch:
            print('|', end='', flush=True)

            if not SKIP_VALIDATION:
                # One validation call of VALIDATION_STEPS_PER_TPU_CALL steps per epoch.
                valid_epoch_steps = 0
                valid_step(model1, valid_data_iter)
                valid_epoch_steps += VALIDATION_STEPS_PER_TPU_CALL
                print('=', end='', flush=True)

                history.history['val_sparse_categorical_accuracy'].append(valid_accuracy.result().numpy())
                # valid_loss sums one mean loss per step: average over the steps actually run.
                history.history['val_loss'].append(valid_loss.result().numpy() / valid_epoch_steps)

            history.history['sparse_categorical_accuracy'].append(train_accuracy.result().numpy())
            # Calls of STEPS_PER_TPU_CALL steps do not generally align with STEPS_PER_EPOCH, so
            # average over the steps accumulated since the last reset, not STEPS_PER_EPOCH.
            history.history['loss'].append(train_loss.result().numpy() / epoch_steps)

            epoch_time = time.time() - epoch_start_time
            # Pass a float tensor so the tf.function `lrfn` is not retraced for every epoch value.
            epoch_lr = float(lrfn(tf.constant(epoch, dtype=tf.float32)))
            print('\nEPOCH {:d}/{:d}'.format(epoch+1, EPOCHS))
            report = ['time: {:0.1f}s'.format(epoch_time),
                      'loss: {:0.4f}'.format(history.history['loss'][-1]),
                      'accuracy: {:0.4f}'.format(history.history['sparse_categorical_accuracy'][-1])]
            if not SKIP_VALIDATION:
                report += ['val_loss: {:0.4f}'.format(history.history['val_loss'][-1]),
                           'val_acc: {:0.4f}'.format(history.history['val_sparse_categorical_accuracy'][-1])]
            report.append('lr: {:0.4g}'.format(epoch_lr))
            print(*report, flush=True)

            # Set up the next epoch. Use `step`, not `step+1`: the latter can jump one epoch
            # ahead and silently merge two epochs into a single report.
            epoch = step // STEPS_PER_EPOCH
            epoch_steps = 0
            epoch_start_time = time.time()
            train_accuracy.reset_states()
            train_loss.reset_states()
            if not SKIP_VALIDATION:
                valid_accuracy.reset_states()
                valid_loss.reset_states()

            if epoch >= EPOCHS:
                break

else:
    EPOCHS = 15
    history = model1.fit(
        get_training_dataset(train_ds),
        steps_per_epoch=STEPS_PER_EPOCH,
        epochs=EPOCHS,
        validation_data=get_validation_dataset(valid_ds) if not SKIP_VALIDATION else None
    )

simple_ctl_training_time = time.time() - start_time
print("OPTIMIZED CTL TRAINING TIME: {:0.1f}s".format(simple_ctl_training_time))

In [None]:
# Persist the fine-tuned DenseNet201 (model 1) weights.
DENSENET_WEIGHTS_PATH = 'densenet-flower-tpu.hdf5'
model1.save_weights(DENSENET_WEIGHTS_PATH)

In [None]:
# Loss / accuracy curves. Without a validation split, the training curve is also passed in
# the validation slot.
curves = history.history
if SKIP_VALIDATION:
    val_loss_curve = curves['loss']
    val_acc_curve = curves['sparse_categorical_accuracy']
else:
    val_loss_curve = curves['val_loss']
    val_acc_curve = curves['val_sparse_categorical_accuracy']
display_training_curves(curves['loss'], val_loss_curve, 'loss', 211)
display_training_curves(curves['sparse_categorical_accuracy'], val_acc_curve, 'accuracy', 212)

In [None]:
# Ordered test pipeline: predictions are later paired with test ids purely by position.
# NOTE(review): the data-dump cell calls get_test_dataset(test_ds) with a dataset argument;
# confirm get_test_dataset's first parameter has a default, otherwise this call fails.
test_ds = get_test_dataset(ordered=True)
test_images_ds = test_ds.map(lambda img, _id: img)

In [None]:
print('Computing predictions model1...')
preds1 = model1.predict(test_images_ds)

# Release model 1 before the next model is built. `pretrained_model` (the DenseNet201
# backbone inside model1) and `optimizer1` are global references too; deleting only model1
# would keep the backbone's weights alive.
del model1, pretrained_model, optimizer1
if tpu:
    # The traced TPU step functions were called with model1/optimizer1 and may still hold
    # references to their variables.
    del train_step, valid_step
gc.collect()

if tpu:
    # Re-initialise the TPU system (presumably to free TPU memory used by model 1). Guarded:
    # with tpu=None this would try to resolve a TPU and fail on CPU/GPU.
    tf.tpu.experimental.initialize_tpu_system(tpu)

# Reload the dataset at 331x331 for model 2

In [None]:
# Model 2 data: 331x331 images (external data and validation files optionally folded in).
SIZE = 331
IMAGE_SIZE = [SIZE, SIZE]

TRAINING_FILENAMES, VALIDATION_FILENAMES, TEST_FILENAMES = get_filenames(SIZE)
print(len(TRAINING_FILENAMES))

if USE_EXTERNAL:
    EXT_TRAINING_FILENAMES = get_ext_filenames(SIZE, imagenet=False, inaturalist=False, openimage=True, oxford=True, tfflowers=True)
    print(len(EXT_TRAINING_FILENAMES))
    TRAINING_FILENAMES = TRAINING_FILENAMES + EXT_TRAINING_FILENAMES

if SKIP_VALIDATION:
    # No held-out split: train on the validation files as well.
    TRAINING_FILENAMES = TRAINING_FILENAMES + VALIDATION_FILENAMES

NUM_TRAINING_IMAGES = count_data_items(TRAINING_FILENAMES)
NUM_VALIDATION_IMAGES = 0 if SKIP_VALIDATION else count_data_items(VALIDATION_FILENAMES)
NUM_TEST_IMAGES = count_data_items(TEST_FILENAMES)
STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // BATCH_SIZE
VALIDATION_STEPS = int_div_round_up(NUM_VALIDATION_IMAGES, BATCH_SIZE)
print('Dataset: {} training images, {} validation images, {} unlabeled test images'.format(NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES))

In [None]:
train_ds = load_dataset(TRAINING_FILENAMES, labeled=True)
if SKIP_VALIDATION:
    valid_ds = None
else:
    valid_ds = load_dataset(VALIDATION_FILENAMES, labeled=True)

**EfficientNetB7 (model 2)**

In [None]:
import efficientnet.tfkeras as efn

In [None]:
with strategy.scope():
    # Model 2: EfficientNetB7 backbone (noisy-student weights), fully fine-tuned, softmax head.
    pretrained_model2 = efn.EfficientNetB7(weights='noisy-student', include_top=False, input_shape=[*IMAGE_SIZE, 3])
    pretrained_model2.trainable = True  # False = transfer learning, True = fine-tuning

    model2 = tf.keras.Sequential([
        pretrained_model2,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(len(CLASSES), activation='softmax')
    ])

    class LRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
        """Per-step Keras schedule that delegates to the per-epoch `lrfn`."""

        def __call__(self, step):
            # `step` is the optimizer's iteration count; integer-divide it to get the epoch.
            return lrfn(epoch=step//STEPS_PER_EPOCH)

        def get_config(self):
            # The schedule takes no constructor arguments (it reads module-level constants).
            # Keras needs get_config to serialize the optimizer (e.g. model.save); the base
            # class does not implement it.
            return {}

    optimizer2 = tf.keras.optimizers.Adam(learning_rate=LRSchedule())

    model2.compile(
        optimizer=optimizer2,
        loss='sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy']
    )
    model2.summary()

    # Fresh metrics for the model-2 custom TPU training loop.
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    valid_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    # Sum metrics: they accumulate every loss value passed to update_state.
    train_loss = tf.keras.metrics.Sum()
    valid_loss = tf.keras.metrics.Sum()

    # Per-replica loss scaled by the global batch size, so that summing it across replicas
    # gives the mean loss over the global batch.
    loss_fn = lambda a,b: tf.nn.compute_average_loss(tf.keras.losses.sparse_categorical_crossentropy(a,b), global_batch_size=BATCH_SIZE)

In [None]:
if tpu:
    STEPS_PER_TPU_CALL = 99
    VALIDATION_STEPS_PER_TPU_CALL = 29

    @tf.function
    def train_step2(model, optimizer, data_iter):
        """Run STEPS_PER_TPU_CALL distributed training steps inside a single tf.function call."""
        def replica_train(batch_images, batch_labels):
            with tf.GradientTape() as tape:
                preds = model(batch_images, training=True)
                step_loss = loss_fn(batch_labels, preds)
            weights = model.trainable_variables
            optimizer.apply_gradients(zip(tape.gradient(step_loss, weights), weights))

            train_accuracy.update_state(batch_labels, preds)
            train_loss.update_state(step_loss)

        for _ in tf.range(STEPS_PER_TPU_CALL):
            strategy.experimental_run_v2(replica_train, next(data_iter))

    @tf.function
    def valid_step2(model, data_iter):
        """Run VALIDATION_STEPS_PER_TPU_CALL distributed evaluation steps in one call."""
        def replica_eval(batch_images, batch_labels):
            preds = model(batch_images, training=False)
            step_loss = loss_fn(batch_labels, preds)

            valid_accuracy.update_state(batch_labels, preds)
            valid_loss.update_state(step_loss)

        for _ in tf.range(VALIDATION_STEPS_PER_TPU_CALL):
            strategy.experimental_run_v2(replica_eval, next(data_iter))

In [None]:
start_time = epoch_start_time = time.time()

if tpu:
    # Custom training loop. Each train_step2 call runs STEPS_PER_TPU_CALL steps. An epoch
    # report is emitted whenever the cumulative step count crosses an epoch boundary.
    train_dist_ds = strategy.experimental_distribute_dataset(get_training_dataset(train_ds, simple_aug=True, advance_aug=True, cut_mix_aug=0))
    valid_dist_ds = strategy.experimental_distribute_dataset(get_validation_dataset(valid_ds, repeated=True)) if not SKIP_VALIDATION else None

    print("Training steps per epoch:", STEPS_PER_EPOCH, "in increment of:", STEPS_PER_TPU_CALL)
    if not SKIP_VALIDATION:
        print("Validation images:", NUM_VALIDATION_IMAGES,
              "Batch size:", BATCH_SIZE,
              "Validation steps:", NUM_VALIDATION_IMAGES//BATCH_SIZE, "in increments of", VALIDATION_STEPS_PER_TPU_CALL)
        print("Repeated validation images:", int_div_round_up(NUM_VALIDATION_IMAGES, BATCH_SIZE*VALIDATION_STEPS_PER_TPU_CALL)*VALIDATION_STEPS_PER_TPU_CALL*BATCH_SIZE-NUM_VALIDATION_IMAGES)

    History = namedtuple('History', 'history')
    history = History(history={'loss': [], 'val_loss': [], 'sparse_categorical_accuracy': [], 'val_sparse_categorical_accuracy': []}) if not SKIP_VALIDATION else History(history={'loss': [], 'val_loss': [], 'sparse_categorical_accuracy': []})

    epoch = 0
    train_data_iter = iter(train_dist_ds)
    valid_data_iter = iter(valid_dist_ds) if not SKIP_VALIDATION else None

    step = 0
    epoch_steps = 0  # training steps accumulated in the train metrics since the last report
    while True:
        train_step2(model2, optimizer2, train_data_iter)
        epoch_steps += STEPS_PER_TPU_CALL
        step += STEPS_PER_TPU_CALL
        print('=', end='', flush=True)

        if (step//STEPS_PER_EPOCH) > epoch:
            print('|', end='', flush=True)

            if not SKIP_VALIDATION:
                # One validation call of VALIDATION_STEPS_PER_TPU_CALL steps per epoch.
                valid_epoch_steps = 0
                valid_step2(model2, valid_data_iter)
                valid_epoch_steps += VALIDATION_STEPS_PER_TPU_CALL
                print('=', end='', flush=True)

                history.history['val_sparse_categorical_accuracy'].append(valid_accuracy.result().numpy())
                # valid_loss sums one mean loss per step: average over the steps actually run.
                history.history['val_loss'].append(valid_loss.result().numpy() / valid_epoch_steps)

            history.history['sparse_categorical_accuracy'].append(train_accuracy.result().numpy())
            # Calls of STEPS_PER_TPU_CALL steps do not generally align with STEPS_PER_EPOCH, so
            # average over the steps accumulated since the last reset, not STEPS_PER_EPOCH.
            history.history['loss'].append(train_loss.result().numpy() / epoch_steps)

            epoch_time = time.time() - epoch_start_time
            # Pass a float tensor so the tf.function `lrfn` is not retraced for every epoch value.
            epoch_lr = float(lrfn(tf.constant(epoch, dtype=tf.float32)))
            print('\nEPOCH {:d}/{:d}'.format(epoch+1, EPOCHS))
            report = ['time: {:0.1f}s'.format(epoch_time),
                      'loss: {:0.4f}'.format(history.history['loss'][-1]),
                      'accuracy: {:0.4f}'.format(history.history['sparse_categorical_accuracy'][-1])]
            if not SKIP_VALIDATION:
                report += ['val_loss: {:0.4f}'.format(history.history['val_loss'][-1]),
                           'val_acc: {:0.4f}'.format(history.history['val_sparse_categorical_accuracy'][-1])]
            report.append('lr: {:0.4g}'.format(epoch_lr))
            print(*report, flush=True)

            # Set up the next epoch. Use `step`, not `step+1`: the latter can jump one epoch
            # ahead and silently merge two epochs into a single report.
            epoch = step // STEPS_PER_EPOCH
            epoch_steps = 0
            epoch_start_time = time.time()
            train_accuracy.reset_states()
            train_loss.reset_states()
            if not SKIP_VALIDATION:
                valid_accuracy.reset_states()
                valid_loss.reset_states()

            if epoch >= EPOCHS:
                break

else:
    EPOCHS = 15
    history = model2.fit(
        get_training_dataset(train_ds),
        steps_per_epoch=STEPS_PER_EPOCH,
        epochs=EPOCHS,
        validation_data=get_validation_dataset(valid_ds) if not SKIP_VALIDATION else None
    )

simple_ctl_training_time = time.time() - start_time
print("OPTIMIZED CTL TRAINING TIME: {:0.1f}s".format(simple_ctl_training_time))

In [None]:
# Persist the fine-tuned EfficientNetB7 (model 2) weights.
EFFNETB7_WEIGHTS_PATH = 'effnetb7-flower-tpu.hdf5'
model2.save_weights(EFFNETB7_WEIGHTS_PATH)

In [None]:
# Loss / accuracy curves. Without a validation split, the training curve is also passed in
# the validation slot.
curves = history.history
if SKIP_VALIDATION:
    val_loss_curve = curves['loss']
    val_acc_curve = curves['sparse_categorical_accuracy']
else:
    val_loss_curve = curves['val_loss']
    val_acc_curve = curves['val_sparse_categorical_accuracy']
display_training_curves(curves['loss'], val_loss_curve, 'loss', 211)
display_training_curves(curves['sparse_categorical_accuracy'], val_acc_curve, 'accuracy', 212)

In [None]:
# Ordered test pipeline: predictions are later paired with test ids purely by position.
# NOTE(review): confirm get_test_dataset picks up the current IMAGE_SIZE / TEST_FILENAMES
# when called without a dataset argument.
test_ds = get_test_dataset(ordered=True)
test_images_ds = test_ds.map(lambda img, _id: img)

In [None]:
print('Computing predictions model2...')
preds2 = model2.predict(test_images_ds)

# Release model 2 before the next model is built. `pretrained_model2` (the EfficientNetB7
# backbone inside model2) and `optimizer2` are global references too; deleting only model2
# would keep the backbone's weights alive.
del model2, pretrained_model2, optimizer2
if tpu:
    # The traced TPU step functions were called with model2/optimizer2 and may still hold
    # references to their variables.
    del train_step2, valid_step2
gc.collect()

if tpu:
    # Re-initialise the TPU system (presumably to free TPU memory used by model 2). Guarded:
    # with tpu=None this would try to resolve a TPU and fail on CPU/GPU.
    tf.tpu.experimental.initialize_tpu_system(tpu)

**EfficientNetB6 (model 3): 512x512 images, plus external data when `USE_EXTERNAL` is enabled**

In [None]:
# Model 3 data: 512x512 images, validation files folded into training (external data optional).
SIZE = 512
IMAGE_SIZE = [SIZE, SIZE]
SKIP_VALIDATION = True

TRAINING_FILENAMES, VALIDATION_FILENAMES, TEST_FILENAMES = get_filenames(SIZE)
print(len(TRAINING_FILENAMES))

if USE_EXTERNAL:
    EXT_TRAINING_FILENAMES = get_ext_filenames(SIZE, imagenet=False, inaturalist=True, openimage=False, oxford=False, tfflowers=True)
    print(len(EXT_TRAINING_FILENAMES))
    TRAINING_FILENAMES = TRAINING_FILENAMES + EXT_TRAINING_FILENAMES

if SKIP_VALIDATION:
    # No held-out split: train on the validation files as well.
    TRAINING_FILENAMES = TRAINING_FILENAMES + VALIDATION_FILENAMES

NUM_TRAINING_IMAGES = count_data_items(TRAINING_FILENAMES)
NUM_VALIDATION_IMAGES = 0 if SKIP_VALIDATION else count_data_items(VALIDATION_FILENAMES)
NUM_TEST_IMAGES = count_data_items(TEST_FILENAMES)
STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // BATCH_SIZE
VALIDATION_STEPS = int_div_round_up(NUM_VALIDATION_IMAGES, BATCH_SIZE)
print('Dataset: {} training images, {} validation images, {} unlabeled test images'.format(NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES))

train_ds = load_dataset(TRAINING_FILENAMES, labeled=True)
if SKIP_VALIDATION:
    valid_ds = None
else:
    valid_ds = load_dataset(VALIDATION_FILENAMES, labeled=True)

In [None]:
with strategy.scope():
    # Model 3: EfficientNetB6 backbone (noisy-student weights), fully fine-tuned, softmax head.
    pretrained_model3 = efn.EfficientNetB6(weights='noisy-student', include_top=False, input_shape=[*IMAGE_SIZE, 3])
    pretrained_model3.trainable = True  # False = transfer learning, True = fine-tuning

    model3 = tf.keras.Sequential([
        pretrained_model3,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(len(CLASSES), activation='softmax')
    ])

    class LRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
        """Per-step Keras schedule that delegates to the per-epoch `lrfn`."""

        def __call__(self, step):
            # `step` is the optimizer's iteration count; integer-divide it to get the epoch.
            return lrfn(epoch=step//STEPS_PER_EPOCH)

        def get_config(self):
            # The schedule takes no constructor arguments (it reads module-level constants).
            # Keras needs get_config to serialize the optimizer (e.g. model.save); the base
            # class does not implement it.
            return {}

    optimizer3 = tf.keras.optimizers.Adam(learning_rate=LRSchedule())

    model3.compile(
        optimizer=optimizer3,
        loss='sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy']
    )
    model3.summary()

    # Fresh metrics for the model-3 custom TPU training loop.
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    valid_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    # Sum metrics: they accumulate every loss value passed to update_state.
    train_loss = tf.keras.metrics.Sum()
    valid_loss = tf.keras.metrics.Sum()

    # Per-replica loss scaled by the global batch size, so that summing it across replicas
    # gives the mean loss over the global batch.
    loss_fn = lambda a,b: tf.nn.compute_average_loss(tf.keras.losses.sparse_categorical_crossentropy(a,b), global_batch_size=BATCH_SIZE)

In [None]:
if tpu:
    STEPS_PER_TPU_CALL = 99
    VALIDATION_STEPS_PER_TPU_CALL = 29

    @tf.function
    def train_step3(model, optimizer, data_iter):
        """Run STEPS_PER_TPU_CALL distributed training steps inside a single tf.function call."""
        def replica_train(batch_images, batch_labels):
            with tf.GradientTape() as tape:
                preds = model(batch_images, training=True)
                step_loss = loss_fn(batch_labels, preds)
            weights = model.trainable_variables
            optimizer.apply_gradients(zip(tape.gradient(step_loss, weights), weights))

            train_accuracy.update_state(batch_labels, preds)
            train_loss.update_state(step_loss)

        for _ in tf.range(STEPS_PER_TPU_CALL):
            strategy.experimental_run_v2(replica_train, next(data_iter))

    @tf.function
    def valid_step3(model, data_iter):
        """Run VALIDATION_STEPS_PER_TPU_CALL distributed evaluation steps in one call."""
        def replica_eval(batch_images, batch_labels):
            preds = model(batch_images, training=False)
            step_loss = loss_fn(batch_labels, preds)

            valid_accuracy.update_state(batch_labels, preds)
            valid_loss.update_state(step_loss)

        for _ in tf.range(VALIDATION_STEPS_PER_TPU_CALL):
            strategy.experimental_run_v2(replica_eval, next(data_iter))

In [None]:
start_time = epoch_start_time = time.time()

if tpu:
    # Custom training loop. Each train_step3 call runs STEPS_PER_TPU_CALL steps. An epoch
    # report is emitted whenever the cumulative step count crosses an epoch boundary.
    train_dist_ds = strategy.experimental_distribute_dataset(get_training_dataset(train_ds, simple_aug=True, advance_aug=True, cut_mix_aug=0))
    valid_dist_ds = strategy.experimental_distribute_dataset(get_validation_dataset(valid_ds, repeated=True)) if not SKIP_VALIDATION else None

    print("Training steps per epoch:", STEPS_PER_EPOCH, "in increment of:", STEPS_PER_TPU_CALL)
    if not SKIP_VALIDATION:
        print("Validation images:", NUM_VALIDATION_IMAGES,
              "Batch size:", BATCH_SIZE,
              "Validation steps:", NUM_VALIDATION_IMAGES//BATCH_SIZE, "in increments of", VALIDATION_STEPS_PER_TPU_CALL)
        print("Repeated validation images:", int_div_round_up(NUM_VALIDATION_IMAGES, BATCH_SIZE*VALIDATION_STEPS_PER_TPU_CALL)*VALIDATION_STEPS_PER_TPU_CALL*BATCH_SIZE-NUM_VALIDATION_IMAGES)

    History = namedtuple('History', 'history')
    history = History(history={'loss': [], 'val_loss': [], 'sparse_categorical_accuracy': [], 'val_sparse_categorical_accuracy': []}) if not SKIP_VALIDATION else History(history={'loss': [], 'val_loss': [], 'sparse_categorical_accuracy': []})

    epoch = 0
    train_data_iter = iter(train_dist_ds)
    valid_data_iter = iter(valid_dist_ds) if not SKIP_VALIDATION else None

    step = 0
    epoch_steps = 0  # training steps accumulated in the train metrics since the last report
    while True:
        train_step3(model3, optimizer3, train_data_iter)
        epoch_steps += STEPS_PER_TPU_CALL
        step += STEPS_PER_TPU_CALL
        print('=', end='', flush=True)

        if (step//STEPS_PER_EPOCH) > epoch:
            print('|', end='', flush=True)

            if not SKIP_VALIDATION:
                # One validation call of VALIDATION_STEPS_PER_TPU_CALL steps per epoch.
                valid_epoch_steps = 0
                valid_step3(model3, valid_data_iter)
                valid_epoch_steps += VALIDATION_STEPS_PER_TPU_CALL
                print('=', end='', flush=True)

                history.history['val_sparse_categorical_accuracy'].append(valid_accuracy.result().numpy())
                # valid_loss sums one mean loss per step: average over the steps actually run.
                history.history['val_loss'].append(valid_loss.result().numpy() / valid_epoch_steps)

            history.history['sparse_categorical_accuracy'].append(train_accuracy.result().numpy())
            # Calls of STEPS_PER_TPU_CALL steps do not generally align with STEPS_PER_EPOCH, so
            # average over the steps accumulated since the last reset, not STEPS_PER_EPOCH.
            history.history['loss'].append(train_loss.result().numpy() / epoch_steps)

            epoch_time = time.time() - epoch_start_time
            # Pass a float tensor so the tf.function `lrfn` is not retraced for every epoch value.
            epoch_lr = float(lrfn(tf.constant(epoch, dtype=tf.float32)))
            print('\nEPOCH {:d}/{:d}'.format(epoch+1, EPOCHS))
            report = ['time: {:0.1f}s'.format(epoch_time),
                      'loss: {:0.4f}'.format(history.history['loss'][-1]),
                      'accuracy: {:0.4f}'.format(history.history['sparse_categorical_accuracy'][-1])]
            if not SKIP_VALIDATION:
                report += ['val_loss: {:0.4f}'.format(history.history['val_loss'][-1]),
                           'val_acc: {:0.4f}'.format(history.history['val_sparse_categorical_accuracy'][-1])]
            report.append('lr: {:0.4g}'.format(epoch_lr))
            print(*report, flush=True)

            # Set up the next epoch. Use `step`, not `step+1`: the latter can jump one epoch
            # ahead and silently merge two epochs into a single report.
            epoch = step // STEPS_PER_EPOCH
            epoch_steps = 0
            epoch_start_time = time.time()
            train_accuracy.reset_states()
            train_loss.reset_states()
            if not SKIP_VALIDATION:
                valid_accuracy.reset_states()
                valid_loss.reset_states()

            if epoch >= EPOCHS:
                break

else:
    EPOCHS = 15
    history = model3.fit(
        get_training_dataset(train_ds),
        steps_per_epoch=STEPS_PER_EPOCH,
        epochs=EPOCHS,
        validation_data=get_validation_dataset(valid_ds) if not SKIP_VALIDATION else None
    )

simple_ctl_training_time = time.time() - start_time
print("OPTIMIZED CTL TRAINING TIME: {:0.1f}s".format(simple_ctl_training_time))

In [None]:
# Persist the fine-tuned EfficientNetB6 (model 3) weights.
EFFNETB6_WEIGHTS_PATH = 'effnetb6-flower-tpu.hdf5'
model3.save_weights(EFFNETB6_WEIGHTS_PATH)

In [None]:
# Loss / accuracy curves. Without a validation split, the training curve is also passed in
# the validation slot.
curves = history.history
if SKIP_VALIDATION:
    val_loss_curve = curves['loss']
    val_acc_curve = curves['sparse_categorical_accuracy']
else:
    val_loss_curve = curves['val_loss']
    val_acc_curve = curves['val_sparse_categorical_accuracy']
display_training_curves(curves['loss'], val_loss_curve, 'loss', 211)
display_training_curves(curves['sparse_categorical_accuracy'], val_acc_curve, 'accuracy', 212)

In [None]:
# Ordered test pipeline. It is also the source of the submission ids, which are paired with
# predictions purely by position.
# NOTE(review): confirm get_test_dataset picks up the current IMAGE_SIZE / TEST_FILENAMES
# when called without a dataset argument.
test_ds = get_test_dataset(ordered=True)
test_images_ds = test_ds.map(lambda img, _id: img)

In [None]:
print('Computing predictions model3...')
preds3 = model3.predict(test_images_ds)

# Release model 3. `pretrained_model3` (the EfficientNetB6 backbone inside model3) and
# `optimizer3` are global references too; deleting only model3 would keep the weights alive.
del model3, pretrained_model3, optimizer3
if tpu:
    # The traced TPU step functions were called with model3/optimizer3 and may still hold
    # references to their variables.
    del train_step3, valid_step3
gc.collect()

if tpu:
    # Re-initialise the TPU system (presumably to free TPU memory used by model 3). Guarded:
    # with tpu=None this would try to resolve a TPU and fail on CPU/GPU.
    tf.tpu.experimental.initialize_tpu_system(tpu)

# Confusion matrix

In [None]:
if 0 and not SKIP_VALIDATION:  # disabled
    # NOTE(review): model1 and model2 are deleted right after their test predictions are
    # computed, so enabling this cell as written raises NameError.
    ordered_valid = load_dataset(VALIDATION_FILENAMES, labeled=True, ordered=True)
    # Order matters: images and labels are iterated separately and must stay aligned.
    cmdataset = get_validation_dataset(ordered_valid, ordered=True)
    images_ds = cmdataset.map(lambda img, lbl: img)
    labels_ds = cmdataset.map(lambda img, lbl: lbl).unbatch()
    cm_correct_labels = next(iter(labels_ds.batch(NUM_VALIDATION_IMAGES))).numpy()  # everything as one batch
    cm_probabilities = (model1.predict(images_ds) + model2.predict(images_ds)) / 2
    cm_predictions = cm_probabilities.argmax(axis=-1)
    print("Correct   labels: ", cm_correct_labels.shape, cm_correct_labels)
    print("Predicted labels: ", cm_predictions.shape, cm_predictions)

In [None]:
if 0 and not SKIP_VALIDATION:  # disabled
    # Macro-averaged metrics and a row-normalised confusion matrix over all classes.
    class_ids = range(len(CLASSES))
    cmat = confusion_matrix(cm_correct_labels, cm_predictions, labels=class_ids)
    score = f1_score(cm_correct_labels, cm_predictions, labels=class_ids, average='macro')
    precision = precision_score(cm_correct_labels, cm_predictions, labels=class_ids, average='macro')
    recall = recall_score(cm_correct_labels, cm_predictions, labels=class_ids, average='macro')
    # Row-normalise so each row sums to 1. A class with no true samples has an all-zero row;
    # keep it at 0 instead of producing NaN from 0/0.
    row_totals = cmat.sum(axis=1, keepdims=True)
    cmat = np.divide(cmat, row_totals, out=np.zeros(cmat.shape, dtype=float), where=row_totals > 0)
    display_confusion_matrix(cmat, score, precision, recall)
    print('f1 score: {:.3f}, precision: {:.3f}, recall: {:.3f}'.format(score, precision, recall))

# Predictions

In [None]:
# Ensemble weight of model 1 (DenseNet201, preds1) against model 2 (EfficientNetB7, preds2):
# the pair is blended as alpha * preds1 + (1 - alpha) * preds2.
alpha = 0.55

In [None]:
#test_ds = get_test_dataset(ordered=True) # since we are splitting the dataset and iterating separately on images and ids, order matters.

print('Computing predictions...')
#test_images_ds = test_ds.map(lambda image, idnum: image)
#probabilities = alpha * model1.predict(test_images_ds) + (1-alpha) * model2.predict(test_images_ds)
probabilities = 0.45 * (alpha * preds1 + (1-alpha) * preds2) + 0.55 * preds3
predictions = np.argmax(probabilities, axis=-1)
print(predictions)

print('Generating submission.csv file...')
test_ids_ds = test_ds.map(lambda image, idnum: idnum).unbatch()
test_ids = next(iter(test_ids_ds.batch(NUM_TEST_IMAGES))).numpy().astype('U') # all in one batch
np.savetxt('submission.csv', np.rec.fromarrays([test_ids, predictions]), fmt=['%s', '%d'], delimiter=',', header='id,label', comments='')
!head submission.csv

# Visual validation

In [None]:
if 0 and not SKIP_VALIDATION:  # disabled
    # NOTE(review): `model` is not defined in this notebook (the models are model1..model3, each
    # deleted after predicting), so enabling this cell as written raises NameError.
    dataset = get_validation_dataset(valid_ds).unbatch().batch(20)
    batch = iter(dataset)

    # run this cell again for next set of images
    images, labels = next(batch)
    probabilities = model.predict(images)
    predictions = probabilities.argmax(axis=-1)
    display_batch_of_images((images, labels), predictions)