In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import tensorflow as tf
from matplotlib import pyplot as plt
from tensorflow.keras import layers
import os,math,re

# Outline
1. TPU Setting
2. Data Import and Visualization
3. Training: Transfer Learning (Xception and ResNet50V2)
4. Validation
5. Predicting Test Data
6. References

# TPU Setting

In [None]:
# Try to resolve an attached TPU; a ValueError from the resolver is treated
# as "no TPU available" and the notebook falls back to the default strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if not tpu:
    # Whatever strategy TensorFlow currently considers the default.
    strategy = tf.distribute.get_strategy()
else:
    # Connect to the TPU cluster and initialise it before building the strategy.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)

print("REPLICAS: ", strategy.num_replicas_in_sync)

# Data Import and Visualization

In [None]:
# `kaggle_datasets` is imported here rather than in the top import cell.
# NOTE(review): presumably it is only available inside Kaggle kernels - confirm.
from kaggle_datasets import KaggleDatasets

# Look up the GCS location of the 'tpu-getting-started' dataset.
GCS_DS_PATH = KaggleDatasets().get_gcs_path('tpu-getting-started')
print(GCS_DS_PATH) # what do gcs paths look like?

In [None]:
# Spatial size every image is resized to before it reaches a model.
#   VGG16    : input size 224 x 224
#   Xception : input size 299 x 299
IMAGE_SIZE = [299, 299]
AUTO = tf.data.experimental.AUTOTUNE

# The 512x512 JPEG TFRecords are read and later resized down to IMAGE_SIZE.
GCS_PATH = GCS_DS_PATH + '/tfrecords-jpeg-512x512'

# One glob per split sub-directory.
TRAINING_FILENAMES, VALIDATION_FILENAMES, TEST_FILENAMES = (
    tf.io.gfile.glob(GCS_PATH + split_pattern)
    for split_pattern in ('/train/*.tfrec', '/val/*.tfrec', '/test/*.tfrec')
)

In [None]:
# Flower class names, indexed by integer class id: CLASSES[class_id] gives the
# display name. The trailing comment on each row is the id range it covers.
CLASSES = ['pink primrose',    'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',     'wild geranium',     'tiger lily',           'moon orchid',              'bird of paradise', 'monkshood',        'globe thistle',         # 00 - 09
           'snapdragon',       "colt's foot",               'king protea',      'spear thistle', 'yellow iris',       'globe-flower',         'purple coneflower',        'peruvian lily',    'balloon flower',   'giant white arum lily', # 10 - 19
           'fire lily',        'pincushion flower',         'fritillary',       'red ginger',    'grape hyacinth',    'corn poppy',           'prince of wales feathers', 'stemless gentian', 'artichoke',        'sweet william',         # 20 - 29
           'carnation',        'garden phlox',              'love in the mist', 'cosmos',        'alpine sea holly',  'ruby-lipped cattleya', 'cape flower',              'great masterwort', 'siam tulip',       'lenten rose',           # 30 - 39
           'barberton daisy',  'daffodil',                  'sword lily',       'poinsettia',    'bolero deep blue',  'wallflower',           'marigold',                 'buttercup',        'daisy',            'common dandelion',      # 40 - 49
           'petunia',          'wild pansy',                'primula',          'sunflower',     'lilac hibiscus',    'bishop of llandaff',   'gaura',                    'geranium',         'orange dahlia',    'pink-yellow dahlia',    # 50 - 59
           'cautleya spicata', 'japanese anemone',          'black-eyed susan', 'silverbush',    'californian poppy', 'osteospermum',         'spring crocus',            'iris',             'windflower',       'tree poppy',            # 60 - 69
           'gazania',          'azalea',                    'water lily',       'rose',          'thorn apple',       'morning glory',        'passion flower',           'lotus',            'toad lily',        'anthurium',             # 70 - 79
           'frangipani',       'clematis',                  'hibiscus',         'columbine',     'desert-rose',       'tree mallow',          'magnolia',                 'cyclamen ',        'watercress',       'canna lily',            # 80 - 89
           'hippeastrum ',     'bee balm',                  'pink quill',       'foxglove',      'bougainvillea',     'camellia',             'mallow',                   'mexican petunia',  'bromelia',         'blanket flower',        # 90 - 99
           'trumpet creeper',  'blackberry lily',           'common tulip',     'wild rose']                                                                                                                                               # 100 - 102

In [None]:
def decode_image(image_data):
    """Turn encoded JPEG bytes into a float32 image tensor.

    The image is decoded with 3 channels, resized to IMAGE_SIZE, divided by
    255.0 and reshaped to ``[*IMAGE_SIZE, 3]``.
    """
    target_size = [*IMAGE_SIZE]
    pixels = tf.image.decode_jpeg(image_data, channels=3)
    # Resize to the dimensions needed for the pretrained model.
    resized = tf.image.resize(pixels, target_size)
    scaled = tf.cast(resized, tf.float32) / 255.0
    return tf.reshape(scaled, [*IMAGE_SIZE, 3])

def read_labeled_tfrecord(example):
    """Parse one serialized labeled record into an ``(image, label)`` pair.

    The record is read with an "image" bytestring feature and an int64 "class"
    feature. The image goes through decode_image and the class id is cast to
    int32.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),  # tf.string means bytestring
        "class": tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed["image"]), tf.cast(parsed["class"], tf.int32)

def read_unlabeled_tfrecord(example):
    """Parse one serialized unlabeled record into an ``(image, id)`` pair.

    The record is read with "image" and "id" bytestring features. The image
    goes through decode_image and the id is returned as parsed.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "id": tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed["image"]), parsed["id"]

def load_dataset(filenames, labeled=True, ordered=False):
    """Build a parsed ``tf.data`` dataset from TFRecord files.

    Args:
        filenames: TFRecord file paths passed to ``tf.data.TFRecordDataset``.
        labeled: if true, records are parsed with read_labeled_tfrecord,
            otherwise with read_unlabeled_tfrecord.
        ordered: if false, deterministic ordering is switched off in the
            dataset options; if true, the options are left at their defaults.

    Returns:
        A dataset of ``(image, label)`` or ``(image, id)`` pairs.
    """
    read_options = tf.data.Options()
    if not ordered:
        read_options.experimental_deterministic = False

    parse_record = read_labeled_tfrecord if labeled else read_unlabeled_tfrecord
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    records = records.with_options(read_options)
    return records.map(parse_record, num_parallel_calls=AUTO)

In [None]:
def data_augment(image, label):
    """Apply a random left-right flip to the image; the label passes through unchanged."""
    flipped = tf.image.random_flip_left_right(image)
    return flipped, label

def get_training_dataset():
    """Return the training pipeline: augment, repeat, shuffle, batch, prefetch.

    Reads TRAINING_FILENAMES as labeled records and batches by BATCH_SIZE.
    The dataset repeats without a count, so it has no natural end.
    """
    return (
        load_dataset(TRAINING_FILENAMES, labeled=True)
        .map(data_augment, num_parallel_calls=AUTO)
        .repeat()                   # repeats for several epochs
        .shuffle(buffer_size=2048)
        .batch(BATCH_SIZE)
        .prefetch(AUTO)             # prefetch next batch while training
    )

def get_validation_dataset(ordered=False):
    """Return the validation pipeline: batch, cache, prefetch.

    Args:
        ordered: forwarded to load_dataset to control deterministic ordering.
    """
    return (
        load_dataset(VALIDATION_FILENAMES, labeled=True, ordered=ordered)
        .batch(BATCH_SIZE)
        .cache()
        .prefetch(AUTO)
    )

def get_test_dataset(ordered = False):
    """Return the unlabeled test pipeline: batch, prefetch.

    Args:
        ordered: forwarded to load_dataset to control deterministic ordering.
    """
    return (
        load_dataset(TEST_FILENAMES, labeled=False, ordered=ordered)
        .batch(BATCH_SIZE)
        .prefetch(AUTO)
    )

def count_data_items(filenames):
    """Sum the item counts embedded in TFRecord file names.

    Each name must contain ``-<digits>.``; the first such group is read as
    the number of items in that file (e.g. ``...-798.tfrec`` counts as 798).

    Args:
        filenames: iterable of file path strings.

    Returns:
        The total as a Python ``int``; ``0`` when ``filenames`` is empty.

    Raises:
        ValueError: if a file name carries no ``-<digits>.`` count.
    """
    # Compile once rather than once per file name; `+` (not `*`) so that a
    # bare "-." can never yield an empty count.
    count_pattern = re.compile(r"-([0-9]+)\.")
    total = 0
    for filename in filenames:
        match = count_pattern.search(filename)
        if match is None:
            raise ValueError(f"Cannot read item count from file name: {filename!r}")
        total += int(match.group(1))
    return total

# Item totals per split, derived from the counts encoded in the file names.
fTrainImages = count_data_items(TRAINING_FILENAMES)
fValidationImages = count_data_items(VALIDATION_FILENAMES)
fTestImages = count_data_items(TEST_FILENAMES)
print(f"{fTrainImages} training images, {fValidationImages} validation images, {fTestImages} test images ")

In [None]:
# Global batch size: 16 per replica reported by the distribution strategy.
# The get_*_dataset helpers read BATCH_SIZE when called, so it must be set first.
BATCH_SIZE = 16 * strategy.num_replicas_in_sync

ds_train = get_training_dataset()
ds_valid = get_validation_dataset()
ds_test = get_test_dataset()

for split_label, split_ds in (("Training: ", ds_train),
                              ("Validation: ", ds_valid),
                              ("Test: ", ds_test)):
    print(split_label, split_ds)

In [None]:
# Peek at three training batches and print the image / label array shapes.
for image, label in ds_train.take(3):
    print(image.numpy().shape, label.numpy().shape)
    
# `label` still refers to the last batch taken by the loop above.
print("Training data label examples:", label.numpy())


In [None]:
# Peek at three test batches and print the image / id array shapes.
for image, idnum in ds_test.take(3):
    print(image.numpy().shape, idnum.numpy().shape)
# `idnum` still refers to the last batch taken by the loop above.
print("Test data IDs: ", idnum.numpy().astype('U')) # U = unicode string

## Image Visualization 

In [None]:
def batch_to_numpy_images_and_labels(data):
    """Convert an ``(images, labels)`` batch into NumPy form.

    Both elements are converted with ``.numpy()``. When the converted labels
    have ``object`` dtype they are replaced by a list of ``None`` with one
    entry per image; otherwise the label array is returned as is.
    """
    image_batch, label_batch = data
    np_images = image_batch.numpy()
    np_labels = label_batch.numpy()
    if np_labels.dtype != object:
        return np_images, np_labels
    # Object dtype: not numeric class ids, so there is no label to show.
    return np_images, [None] * len(np_images)

def title_from_label_and_target(label, correct_label):
    """Build a plot title for a prediction and report whether it is correct.

    Args:
        label: predicted class index into CLASSES.
        correct_label: true class index, or ``None`` when it is unknown.

    Returns:
        ``(title, correct)``. With no true label the title is just the
        predicted class name and ``correct`` is ``True``. Otherwise the title
        is ``"<predicted> [OK]"`` or ``"<predicted> [NO→<true class>]"``.
    """
    predicted_name = CLASSES[label]
    if correct_label is None:
        return predicted_name, True

    correct = (label == correct_label)
    if correct:
        verdict, arrow, expected_name = 'OK', '', ''
    else:
        verdict, arrow, expected_name = 'NO', u"\u2192", CLASSES[correct_label]
    return "{} [{}{}{}]".format(predicted_name, verdict, arrow, expected_name), correct
def display_one_flower(image, title, subplot, red=False, titlesize=16):
    """Draw one image into the given subplot slot and return the next slot.

    Args:
        image: image data passed to ``plt.imshow``.
        title: title text; nothing is drawn for an empty title.
        subplot: ``(rows, cols, index)`` triple selecting the subplot.
        red: if true, the title is drawn in red at a reduced font size.
        titlesize: base font size for the title.

    Returns:
        ``(rows, cols, index + 1)``, the slot for the next image.
    """
    plt.subplot(*subplot)
    plt.axis('off')
    plt.imshow(image)
    if len(title) > 0:
        if red:
            font_size, font_color = int(titlesize/1.2), 'red'
        else:
            font_size, font_color = int(titlesize), 'black'
        plt.title(title, fontsize=font_size, color=font_color,
                  fontdict={'verticalalignment':'center'},
                  pad=int(titlesize/1.5))
    n_rows, n_cols, slot = subplot[0], subplot[1], subplot[2]
    return (n_rows, n_cols, slot + 1)
def display_batch_of_images(databatch,predictions=None):
    """Show a batch of images in a near-square grid, with optional predictions.

    Args:
        databatch: ``(images, labels)`` pair accepted by
            batch_to_numpy_images_and_labels.
        predictions: optional sequence of predicted class indices, indexed in
            step with the images. When given, each title compares the
            prediction with the label and wrong ones are drawn in red.

    The grid has ``rows = int(sqrt(n))`` and ``cols = n // rows``; images
    beyond ``rows * cols`` are not shown. An empty batch draws nothing.
    """
    images,labels = batch_to_numpy_images_and_labels(databatch)
    if labels is None:
        labels = [None] * len(images)

    # An empty batch would give rows == 0 and a ZeroDivisionError below.
    if len(images) == 0:
        return

    rows = int(math.sqrt(len(images)))
    cols = len(images)//rows
    shown = rows*cols

    FIGSIZE = 13.0
    SPACING = 0.1
    subplot = (rows,cols,1)
    if rows < cols:
        plt.figure(figsize=(FIGSIZE, FIGSIZE/cols*rows))
    else:
        plt.figure(figsize=(FIGSIZE/rows*cols, FIGSIZE))

    # Title size depends only on the grid shape, so compute it once.
    dynamic_titlesize = FIGSIZE*SPACING/max(rows,cols)*40+3

    # display
    for i, (image,label) in enumerate(zip(images[:shown], labels[:shown])):
        title = '' if label is None else CLASSES[label]
        correct = True
        if predictions is not None:
            title, correct = title_from_label_and_target(predictions[i],label)
        subplot = display_one_flower(image,title,subplot, not correct,
                                     titlesize = dynamic_titlesize)

    plt.tight_layout()
    # Same decision as before (last displayed label), but without relying on
    # the loop variable still being bound after the loop.
    last_label = labels[shown - 1]
    if last_label is None and predictions is None:
        # No titles at all: pack the images edge to edge.
        plt.subplots_adjust(wspace=0,hspace=0)
    else:
        plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()
               

In [None]:
# Re-batch the training stream into groups of 20 images for display.
ds_iter = iter(ds_train.unbatch().batch(20))

In [None]:
# Each run of this cell pulls the next 20 images from ds_iter and displays them.
one_batch = next(ds_iter)
display_batch_of_images(one_batch)

# Training Model

In [None]:
def display_training_curves(training, validation, title, subplot):
    """Plot a training curve and a validation curve on one subplot.

    Args:
        training: per-epoch values of the training metric.
        validation: per-epoch values of the validation metric.
        title: metric name, used for the y label and the subplot title.
        subplot: three-digit Matplotlib subplot code (e.g. ``211``). A code
            whose last digit is 1 starts a new figure.
    """
    if subplot%10==1: # start a new figure on the first call
        # plt.figure() rather than plt.subplots(): the latter also creates a
        # full-figure Axes, which plt.subplot() below would overlap (recent
        # Matplotlib releases no longer remove such an Axes automatically).
        plt.figure(figsize=(10,10), facecolor='#F0F0F0')
    ax = plt.subplot(subplot)
    ax.set_facecolor('#F8F8F8')
    ax.plot(training)
    ax.plot(validation)
    ax.set_title('model '+ title)
    ax.set_ylabel(title)
    #ax.set_ylim(0.28,1.05)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'valid.'])
    # Lay out after drawing so titles and axis labels are taken into account.
    plt.tight_layout()

In [None]:
# Number of training epochs; also the length of the learning-rate preview plot.
EPOCHS = 10

In [None]:
## Learning Rate Schedule
        
def exponential_lr(epoch, start_lr = 1e-5, min_lr = 1e-5, max_lr = 5e-5, 
                   rampup_epochs=5, sustain_epochs=0, exp_decay=0.8):
    """Learning rate for ``epoch``: linear ramp-up, optional plateau, exponential decay.

    Args:
        epoch: zero-based epoch index.
        start_lr: rate at epoch 0.
        min_lr: floor that the decay phase approaches.
        max_lr: rate reached at the end of the ramp-up.
        rampup_epochs: number of epochs spent ramping from start_lr to max_lr.
        sustain_epochs: number of epochs held at max_lr after the ramp-up.
        exp_decay: per-epoch multiplier applied to ``max_lr - min_lr`` afterwards.

    Returns:
        The learning rate for the given epoch.
    """
    # Phase 1: linear ramp from start_lr towards max_lr.
    if epoch < rampup_epochs:
        return (max_lr-start_lr)/rampup_epochs*epoch + start_lr
    # Phase 2: hold at the peak.
    if epoch < rampup_epochs + sustain_epochs:
        return max_lr
    # Phase 3: decay exponentially towards min_lr.
    return (max_lr - min_lr)*exp_decay**(epoch- rampup_epochs - sustain_epochs) + min_lr

# Keras callback that applies exponential_lr at the start of each epoch.
lr_callback = tf.keras.callbacks.LearningRateScheduler(exponential_lr, verbose=False)

# Preview the schedule over the planned number of epochs.
rng = list(range(EPOCHS))
y = [exponential_lr(epoch) for epoch in rng]
plt.plot(rng, y)
print("Learning Rate Schedule: {:.3g} to {:.3g} to {:.3g}".format(y[0], max(y), y[-1]))
    

## Pretrained Xception Model

In [None]:
# Build and compile the Xception transfer-learning model inside the
# distribution strategy scope.
with strategy.scope():
    # decode_image() in this notebook scales pixels to [0, 1], while
    # xception.preprocess_input expects values in [0, 255]. Undo that scaling
    # before applying the model's own preprocessing.
    img_adjust_layer = tf.keras.layers.Lambda(
        lambda data: tf.keras.applications.xception.preprocess_input(
            tf.cast(data, tf.float32) * 255.0),
        input_shape=[*IMAGE_SIZE, 3])
    xce_pretrained_model = tf.keras.applications.Xception(weights='imagenet', include_top=False)
    xce_pretrained_model.trainable = True  # fine-tune the whole backbone

    xce_model = tf.keras.Sequential()
    # The preprocessing layer was previously defined but never added.
    xce_model.add(img_adjust_layer)
    xce_model.add(xce_pretrained_model)
    xce_model.add(layers.GlobalAveragePooling2D())
    xce_model.add(layers.Dense(len(CLASSES), activation='softmax'))

    # Compile inside the scope so optimizer and metric variables are created
    # under the same strategy as the model weights.
    xce_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])

xce_model.summary()

In [None]:
# The training dataset repeats without end, so each epoch is bounded explicitly:
# one epoch = the number of full batches that fit in the training set.
STEPS_PER_EPOCH = fTrainImages // BATCH_SIZE
xce_history = xce_model.fit(ds_train, validation_data=ds_valid, 
                            epochs=EPOCHS, steps_per_epoch = STEPS_PER_EPOCH, 
                            callbacks=[lr_callback], verbose=1)

In [None]:
# Xception training history: loss on the top panel (211), accuracy below (212).
display_training_curves(xce_history.history['loss'], xce_history.history['val_loss'], 'loss', 211)
display_training_curves(xce_history.history['sparse_categorical_accuracy'], xce_history.history['val_sparse_categorical_accuracy'],'accuracy',212)

## Pretrained ResNet Model

In [None]:
# Build and compile the ResNet50V2 transfer-learning model inside the
# distribution strategy scope.
with strategy.scope():
    # decode_image() in this notebook scales pixels to [0, 1], while
    # resnet_v2.preprocess_input expects values in [0, 255]. Undo that scaling
    # before applying the model's own preprocessing.
    img_adjust_layer = tf.keras.layers.Lambda(
        lambda data: tf.keras.applications.resnet_v2.preprocess_input(
            tf.cast(data, tf.float32) * 255.0),
        input_shape=[*IMAGE_SIZE, 3])
    resnet_pretrained_model = tf.keras.applications.ResNet50V2(weights='imagenet', include_top=False)
    resnet_pretrained_model.trainable = True  # fine-tune the whole backbone

    resnet_model = tf.keras.Sequential()
    # The preprocessing layer was previously defined but never added.
    resnet_model.add(img_adjust_layer)
    resnet_model.add(resnet_pretrained_model)
    resnet_model.add(layers.GlobalAveragePooling2D())
    resnet_model.add(layers.Dense(len(CLASSES), activation='softmax'))

    # Compile inside the scope so optimizer and metric variables are created
    # under the same strategy as the model weights.
    resnet_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])

resnet_model.summary()

In [None]:
# Train ResNet50V2 with the same data, epoch count, steps and LR schedule.
# STEPS_PER_EPOCH is defined in the Xception training cell, which must run first.
resnet_history = resnet_model.fit(ds_train, validation_data=ds_valid, 
                            epochs=EPOCHS, steps_per_epoch = STEPS_PER_EPOCH, 
                            callbacks=[lr_callback], verbose=1)

In [None]:
# ResNet50V2 training history: loss on the top panel (211), accuracy below (212).
display_training_curves(resnet_history.history['loss'], resnet_history.history['val_loss'], 'loss', 211)
display_training_curves(resnet_history.history['sparse_categorical_accuracy'], resnet_history.history['val_sparse_categorical_accuracy'],'accuracy',212)

# Validation

## Macro F1 Scores

In [None]:
from sklearn.metrics import f1_score

# ordered=True leaves deterministic ordering on in load_dataset, so the image
# pass and the label pass below are expected to see records in the same sequence.
validation_ds = get_validation_dataset(ordered=True)
validation_images_ds = validation_ds.map(lambda image, label: image)
val_label_ds = validation_ds.map(lambda image, label: label).unbatch()
# Collect every validation label into one array via a single batch of fValidationImages.
val_labels = next(iter(val_label_ds.batch(fValidationImages))).numpy()

In [None]:
# Xception: predicted class = argmax over the class probabilities.
validation_proba = xce_model.predict(validation_images_ds)
predictions = np.argmax(validation_proba, axis=-1)
xce_f1 = f1_score(val_labels, predictions,average='macro')
print(f"Macro F1 scores: {xce_f1}")

# Loss and accuracy on the same validation set.
xce_model.evaluate(validation_ds)

In [None]:
# ResNet50V2: predicted class = argmax over the class probabilities.
# Note: this overwrites `validation_proba` and `predictions` from the Xception cell.
validation_proba = resnet_model.predict(validation_images_ds)
predictions = np.argmax(validation_proba, axis=-1)
resnet_f1 = f1_score(val_labels, predictions,average='macro')
print(f"Macro F1 scores: {resnet_f1}")

# Loss and accuracy on the same validation set.
resnet_model.evaluate(validation_ds)

In [None]:
# Re-batch the validation set into groups of 20 for visual inspection.
# Order does not matter here: images and labels travel together in each batch.
dataset = get_validation_dataset()
dataset = dataset.unbatch().batch(20)
batch = iter(dataset)

In [None]:
# Each run of this cell shows the next 20 validation images with the Xception
# prediction in the title.
images, labels = next(batch)
proba  = xce_model.predict(images)
predictions = np.argmax(proba, axis=-1)
display_batch_of_images((images,labels), predictions)

# Predicting Test Data

In [None]:
# ordered=True leaves deterministic ordering on in load_dataset, so the image
# pass and the id pass below are expected to see records in the same sequence.
test_ds = get_test_dataset(ordered=True)
test_images_ds = test_ds.map(lambda image, idnum: image)
test_ids_ds = test_ds.map(lambda image, idnum: idnum).unbatch()
# Collect every id in one batch of fTestImages and convert to unicode strings.
test_ids = next(iter(test_ids_ds.batch(fTestImages))).numpy().astype('U')

In [None]:
# ResNet50V2 test predictions, written as an `id,label` CSV.
proba = resnet_model.predict(test_images_ds)
predictions = np.argmax(proba, axis=-1)

# comments='' stops savetxt from prefixing the header line with '# '.
np.savetxt('submission_resnet50.csv',
          np.rec.fromarrays([test_ids,predictions]),
           fmt=['%s', '%d'],
           delimiter=',',
           header='id,label',
           comments='',)



In [None]:
# Xception test predictions (this overwrites `proba` / `predictions` from the
# ResNet cell).
proba = xce_model.predict(test_images_ds)
predictions = np.argmax(proba, axis=-1)

# The same Xception predictions are written twice: once under a model-specific
# name and once as 'submission.csv'.
for submission_path in ('submission_xce.csv', 'submission.csv'):
    np.savetxt(submission_path,
               np.rec.fromarrays([test_ids, predictions]),
               fmt=['%s', '%d'],
               delimiter=',',
               header='id,label',
               comments='')

# References
1. https://www.kaggle.com/ryanholbrook/create-your-first-submission
2. https://www.kaggle.com/mgornergoogle/five-flowers-with-keras-and-xception-on-tpu
3. https://www.kaggle.com/philculliton/a-simple-petals-tf-2-2-notebook