In [None]:
import os
import re
import math
import random
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from kaggle_datasets import KaggleDatasets

In [None]:
# use 224x224 image data, since input for EfficientNet B0 is 224x224x3

IMAGE_SIZE = [224, 224]
BATCH_SIZE = 64
AUTO = tf.data.experimental.AUTOTUNE

# Seed every RNG the notebook touches so shuffling and random augmentation are
# (more) reproducible under Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

GCS_DS_PATH = KaggleDatasets().get_gcs_path('tpu-getting-started')
GCS_PATH = GCS_DS_PATH + '/tfrecords-jpeg-224x224'
TRAINING_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/train/*.tfrec')
VALIDATION_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/val/*.tfrec')
TEST_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/test/*.tfrec')

# tf.io.gfile.glob returns [] (no error) when nothing matches, e.g. if the dataset
# is not attached; fail here instead of silently building empty pipelines.
assert TRAINING_FILENAMES, f"no .tfrec files found under {GCS_PATH}/train"
assert VALIDATION_FILENAMES, f"no .tfrec files found under {GCS_PATH}/val"
assert TEST_FILENAMES, f"no .tfrec files found under {GCS_PATH}/test"

In [None]:
# Code to load data. Note that Keras's EfficientNet expects inputs in the 0-255 range, while other models might expect 0-1.

def decode_image(image_data):
    """Decode one JPEG byte string into an RGB image tensor of shape IMAGE_SIZE + [3].

    The pixels are left as decoded by ``tf.image.decode_jpeg`` (no rescaling).
    The reshape pins the static shape; it raises at run time if the decoded
    pixel count does not match IMAGE_SIZE x 3.
    """
    rgb = tf.image.decode_jpeg(image_data, channels=3)
    return tf.reshape(rgb, [*IMAGE_SIZE, 3])

def read_labeled_tfrecord(example):
    """Parse one serialized labeled example into an ``(image, label)`` pair.

    Args:
        example: a scalar string tensor holding one serialized ``tf.train.Example``.

    Returns:
        The image decoded by ``decode_image`` and the ``class`` feature cast to int32.
    """
    features = tf.io.parse_single_example(
        example,
        {
            # tf.string is a byte string (the JPEG); shape [] means a single element
            "image": tf.io.FixedLenFeature([], tf.string),
            "class": tf.io.FixedLenFeature([], tf.int64),
        },
    )
    return decode_image(features['image']), tf.cast(features['class'], tf.int32)

def read_unlabeled_tfrecord(example):
    """Parse one serialized test example into an ``(image, id)`` pair.

    Test records have no "class" feature (predicting the flower class is this
    competition's task), so the record's string "id" is returned in its place.

    Args:
        example: a scalar string tensor holding one serialized ``tf.train.Example``.

    Returns:
        The image decoded by ``decode_image`` and the ``id`` feature as a string tensor.
    """
    features = tf.io.parse_single_example(
        example,
        {
            "image": tf.io.FixedLenFeature([], tf.string),  # JPEG byte string
            "id": tf.io.FixedLenFeature([], tf.string),     # single string identifier
        },
    )
    return decode_image(features['image']), features['id']

def load_dataset(filenames, labeled=True):
    """Read TFRecord files into an unbatched dataset of parsed examples.

    Args:
        filenames: list of .tfrec paths (e.g. as returned by ``tf.io.gfile.glob``);
            files are read one after another in list order.
        labeled: if True, parse with ``read_labeled_tfrecord`` and yield
            ``(image, label)`` pairs; otherwise parse with ``read_unlabeled_tfrecord``
            and yield ``(image, id)`` pairs.

    Returns:
        A ``tf.data.Dataset`` of parsed, decoded elements.
    """
    parse_fn = read_labeled_tfrecord if labeled else read_unlabeled_tfrecord
    dataset = tf.data.TFRecordDataset(filenames)
    # Decode JPEGs in parallel. Dataset.map keeps element order by default, so the
    # test ids and predictions taken from separate passes over the data stay aligned.
    dataset = dataset.map(parse_fn, num_parallel_calls=AUTO)
    return dataset

In [None]:
# one-hot encoding for label
# image augmentation (reduces overfitting)
# shuffle
# batch
# prefetch or cache

NUM_CLASSES = 104  # length of the one-hot label vector (number of flower classes)
CROP_SIZE = [200, 200]  # [height, width] of the random crop before resizing back


def data_augment(image, label):
    """Randomly flip an image and, with probability 0.5, random-crop it and resize back.

    This runs inside ``Dataset.map``, which traces the function into a graph once.
    Python-side randomness (e.g. ``random.getrandbits``) would be evaluated a single
    time during tracing, cropping every image or none; the coin flip must be a TF op
    so that it is drawn per element.

    The image is cast to float32 first so both ``tf.cond`` branches share a dtype
    (``tf.image.resize`` returns float32). Values stay in the original 0-255 range.

    Args:
        image: image tensor of shape IMAGE_SIZE + [3].
        label: passed through unchanged.

    Returns:
        ``(augmented_image, label)`` with the image as float32, shape IMAGE_SIZE + [3].
    """
    flipped = tf.image.random_flip_left_right(tf.cast(image, tf.float32))

    def _crop_then_resize():
        cropped = tf.image.random_crop(flipped, [*CROP_SIZE, 3])
        return tf.image.resize(cropped, IMAGE_SIZE)

    # Sampled per element at run time, unlike a Python-level random draw.
    do_crop = tf.random.uniform([]) < 0.5
    augmented = tf.cond(do_crop, _crop_then_resize, lambda: flipped)
    return augmented, label

def input_preprocess(image, label):
    """Leave the image untouched and one-hot encode the integer label into NUM_CLASSES slots."""
    return image, tf.one_hot(label, NUM_CLASSES)

def get_training_dataset():
    """Build the training pipeline of batched ``(image, one-hot label)`` pairs.

    Stages: parse -> one-hot labels -> random augmentation -> shuffle -> batch -> prefetch.
    The last batch may hold fewer than BATCH_SIZE elements (no drop_remainder).
    """
    return (
        load_dataset(TRAINING_FILENAMES, labeled=True)
        .map(input_preprocess, num_parallel_calls=AUTO)
        .map(data_augment, num_parallel_calls=AUTO)
        .shuffle(2048)  # shuffle within a 2048-element buffer
        .batch(BATCH_SIZE)
        .prefetch(AUTO)
    )

def get_validation_dataset():
    """Build the validation pipeline of batched ``(image, one-hot label)`` pairs.

    No augmentation or shuffling. Batches are cached in memory after the first
    full pass, so later epochs skip file reads and JPEG decoding.
    """
    dataset = load_dataset(VALIDATION_FILENAMES, labeled=True)
    dataset = dataset.map(input_preprocess, num_parallel_calls=AUTO)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.cache()
    # Overlap input preparation with evaluation, as the training pipeline does.
    dataset = dataset.prefetch(AUTO)
    return dataset

def get_test_dataset():
    """Build the test pipeline of batched ``(image, id)`` pairs.

    The pipeline is not shuffled: the submission cell relies on predictions and ids
    coming out in the same order on separate passes.
    """
    dataset = load_dataset(TEST_FILENAMES, labeled=False)
    dataset = dataset.batch(BATCH_SIZE)
    # Prefetch overlaps decoding with prediction; element order is unchanged.
    dataset = dataset.prefetch(AUTO)
    return dataset

def count_data_items(filenames):
    """Return the total number of records across ``filenames``.

    Each .tfrec file name encodes its record count as ``-<digits>.``, e.g.
    ``flowers00-230.tfrec`` holds 230 items. Only the base name is searched, so
    dashes and digits in bucket or directory names cannot be mistaken for a count.

    Args:
        filenames: iterable of file paths (local or gs://).

    Returns:
        The summed count as a Python int (0 for an empty list).

    Raises:
        ValueError: if a file name carries no record count.
    """
    count_pattern = re.compile(r"-([0-9]+)\.")  # "+" so an empty digit run is not a match
    total = 0
    for filename in filenames:
        match = count_pattern.search(os.path.basename(filename))
        if match is None:
            raise ValueError(f"cannot read a record count from file name {filename!r}")
        total += int(match.group(1))
    return total

# Record counts come from the .tfrec file names (see count_data_items).
NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES = (
    count_data_items(split_files)
    for split_files in (TRAINING_FILENAMES, VALIDATION_FILENAMES, TEST_FILENAMES)
)
print('Dataset: {} training images, {} validation images, {} unlabeled test images'.format(NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES))

# Build the three input pipelines.
ds_train, ds_valid, ds_test = get_training_dataset(), get_validation_dataset(), get_test_dataset()

In [None]:
# use noisy-student version of EfficientNet B0 weights (better performance than standard "imagenet" weights)

# Download and unpack the noisy-student EfficientNet-B0 TensorFlow checkpoint.
# NOTE(review): plain `wget` saves a second copy (*.tar.gz.1) when the cell is re-run
# instead of overwriting; consider `wget -nc` for idempotent re-runs.
!wget https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/noisystudent/noisy_student_efficientnet-b0.tar.gz
!tar -xf noisy_student_efficientnet-b0.tar.gz

# Fetch TensorFlow's EfficientNet weight-conversion script and convert the checkpoint
# into a Keras weight file without the classification top (--notop), written to
# efficientnetb0_notop.h5, which the model-building cell loads.
# NOTE(review): this URL targets the unpinned `master` branch of tensorflow/tensorflow;
# Keras code was moved out of tensorflow/python/keras in later releases, so this path
# may no longer exist - pin a commit/tag matching the installed TF version.
!wget https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/python/keras/applications/efficientnet_weight_update_util.py
!python efficientnet_weight_update_util.py --model b0 --notop --ckpt noisy_student_efficientnet-b0/model.ckpt --o efficientnetb0_notop.h5

In [None]:
# load EfficientNet B0 without top
# freeze weights on EfficientNet B0
# add a softmax top (dropout reduces overfitting)

# Backbone: EfficientNet-B0 without its classification top, initialised from the
# converted noisy-student weight file, and frozen so only the new head trains.
pretrained_model = tf.keras.applications.efficientnet.EfficientNetB0(
    include_top=False,
    input_shape=[*IMAGE_SIZE, 3],
    weights='efficientnetb0_notop.h5',
)
pretrained_model.trainable = False

# Full model: frozen feature extractor followed by a new classifier head that pools
# the spatial features, normalises them, applies dropout against overfitting, and
# outputs softmax probabilities over NUM_CLASSES.
model = tf.keras.Sequential(
    name="EfficientNetB0",
    layers=[
        pretrained_model,
        tf.keras.layers.GlobalAveragePooling2D(name='avg_pool'),
        tf.keras.layers.BatchNormalization(name='batch_norm2'),
        tf.keras.layers.Dropout(0.2, name="dropout2"),
        tf.keras.layers.Dense(NUM_CLASSES, activation='softmax', name='output'),
    ],
)

In [None]:
# Adam with a fixed learning rate; ideally this would use a learning-rate schedule instead.

# 10x Adam's default step size; only the head is trainable at this point, and the
# rate is lowered later for the second training phase.
LEARNING_RATE = 0.01

model.compile(
    loss='categorical_crossentropy',  # labels are one-hot encoded by input_preprocess
    metrics=['accuracy'],
    optimizer=tf.optimizers.Adam(learning_rate=LEARNING_RATE),
)
model.summary()

In [None]:
EPOCHS = 20

# Phase 1: train the classifier head on top of the frozen backbone.
history = model.fit(
    x=ds_train,
    epochs=EPOCHS,
    validation_data=ds_valid,
)

In [None]:
def plot_history(hist):
    """Plot training vs. validation accuracy per epoch from a Keras ``History``.

    The x-axis uses the epoch indices recorded in ``hist.epoch``. A resumed run
    (``initial_epoch > 0``) is therefore plotted at its true epochs instead of
    restarting at 0.

    Args:
        hist: the object returned by ``model.fit``; must have ``"accuracy"`` and
            ``"val_accuracy"`` entries in ``hist.history``.
    """
    train_acc = hist.history["accuracy"]
    val_acc = hist.history["val_accuracy"]
    # Fall back to 0..n-1 for history-like objects without an `epoch` list.
    epochs = getattr(hist, "epoch", None) or range(len(train_acc))
    plt.plot(epochs, train_acc)
    plt.plot(epochs, val_acc)
    plt.title("model accuracy")
    plt.ylabel("accuracy")
    plt.xlabel("epoch")
    plt.legend(["train", "validation"], loc="upper left")
    plt.show()


plot_history(history)

In [None]:
# Phase 2: continue training the head with a lower fixed learning rate.
# NOTE: NEXT_EPOCHES is the index at which training stops (Keras `epochs`), not a
# count. With initial_epoch=EPOCHS this runs epochs EPOCHS..NEXT_EPOCHES-1, i.e.
# NEXT_EPOCHES - EPOCHS additional epochs.
NEXT_EPOCHES = 40

model.optimizer.learning_rate = 0.0001  # down from LEARNING_RATE

history_continued = model.fit(
    x=ds_train,
    epochs=NEXT_EPOCHES,
    initial_epoch=EPOCHS,
    validation_data=ds_valid,
)

In [None]:
# Accuracy curves for the second, reduced-learning-rate training phase only.
plot_history(history_continued)

In [None]:
# code snippet to load a saved model, to continue from last version

# !wget https://kaggledatastore.blob.core.windows.net/data/flowers/enet0_epoch30.h5

# model = tf.keras.models.load_model('enet0_epoch30.h5')

# !rm enet0_epoch30.h5

In [None]:
# save model

# Persist the model after the second training phase (.h5 extension -> Keras HDF5 format).
# NOTE(review): the epoch count in the file name is hard-coded; keep it in sync with NEXT_EPOCHES.
model.save('enet0_epoch40.h5')

In [None]:
# predict on test data

# Unshuffled test pipeline. The id column is dropped for prediction and re-read from
# this same ordered dataset when writing the submission.
test_ds = get_test_dataset()

test_images_ds = test_ds.map(lambda img, _record_id: img)
probabilities = model.predict(test_images_ds)
predictions = probabilities.argmax(axis=-1)  # most probable class index per image

In [None]:
# make submission.csv

test_ids_ds = test_ds.map(lambda image, idnum: idnum).unbatch()
# Drain the whole id stream instead of taking one batch of NUM_TEST_IMAGES. That count
# is parsed from file names and would silently truncate the ids if it were too small.
test_ids = np.concatenate(
    [id_batch.numpy() for id_batch in test_ids_ds.batch(BATCH_SIZE)]
).astype('U')  # tf.string ids arrive as bytes; convert to str for the CSV

# Both passes read the same unshuffled test_ds, so rows must line up one-to-one.
if len(test_ids) != len(predictions):
    raise ValueError(
        f"got {len(test_ids)} test ids but {len(predictions)} predictions"
    )

np.savetxt(
    'submission.csv',
    np.rec.fromarrays([test_ids, predictions]),
    fmt=['%s', '%d'],
    delimiter=',',
    header='id,label',
    comments=''
)