In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os, re

from kaggle_datasets import KaggleDatasets

In [None]:
# NEW on TPU in TensorFlow 2.4: shorter cross-compatible TPU/GPU/multi-GPU/cluster-GPU detection code
# Picks the tf.distribute strategy used by the rest of the notebook: a TPUStrategy when the
# TPU path succeeds, otherwise a MirroredStrategy.

try: # detect TPUs
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect() # TPU detection
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError: # a ValueError from either TPU call above is treated as "no TPU": fall back to GPUs
    strategy = tf.distribute.MirroredStrategy() # for GPU or multi-GPU machines
    # Alternative fallbacks, kept for reference:
    #strategy = tf.distribute.get_strategy() # default strategy that works on CPU and single GPU
    #strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() # for clusters of multi-GPU machines

# num_replicas_in_sync is reused below to scale BATCH_SIZE and max_lr.
print("Number of accelerators: ", strategy.num_replicas_in_sync)

In [None]:
# Root of the 'tpu-getting-started' dataset as reported by KaggleDatasets.get_gcs_path.
basedir = KaggleDatasets().get_gcs_path('tpu-getting-started')

# TFRecord shards live in one sub-directory per split.
tfrecordsdir = os.path.join(basedir, "tfrecords-jpeg-512x512")
traindir, testdir, valdir = (
    os.path.join(tfrecordsdir, split_name) for split_name in ("train", "test", "val")
)

# Sample submission file, read again when the final submission is written.
submission_file = os.path.join(basedir, "sample_submission.csv")

In [None]:
# Two spatial dimensions of every image; IMAGE_SHAPE appends the 3 colour channels
# that the parse functions request from decode_jpeg.
IMAGE_SIZE = (512, 512)
IMAGE_SHAPE = (*IMAGE_SIZE, 3)

# Global batch size: 16 examples per replica of the active strategy.
BATCH_SIZE = strategy.num_replicas_in_sync * 16

In [None]:
EPOCHS = 12

# Knobs of the learning-rate schedule implemented by lrfn():
# linear ramp start_lr -> max_lr, optional plateau, then exponential decay towards min_lr.
start_lr = 1e-5
min_lr = 1e-5
max_lr = 5e-5 * strategy.num_replicas_in_sync  # peak rate scales with the replica count
rampup_epochs = 5
sustain_epochs = 0
exp_decay = 0.8

def lrfn(epoch):
    """Return the learning rate for the given epoch index.

    The schedule has three phases, driven by the module-level settings:
    a linear ramp from ``start_lr`` to ``max_lr`` over ``rampup_epochs``,
    a plateau at ``max_lr`` lasting ``sustain_epochs``, and afterwards an
    exponential decay (factor ``exp_decay`` per epoch) towards ``min_lr``.
    """
    # Phase 1: linear warm-up.
    if epoch < rampup_epochs:
        slope = (max_lr - start_lr) / rampup_epochs
        return slope * epoch + start_lr

    # Phase 2: hold the peak rate.
    if epoch < rampup_epochs + sustain_epochs:
        return max_lr

    # Phase 3: exponential decay of the part of the rate above min_lr.
    epochs_into_decay = epoch - rampup_epochs - sustain_epochs
    return (max_lr - min_lr) * exp_decay ** epochs_into_decay + min_lr
    
# Keras callback that applies lrfn() at the start of every epoch and logs the chosen rate.
lr_callback = tf.keras.callbacks.LearningRateScheduler(
    schedule=lambda epoch: lrfn(epoch),
    verbose=True,
)

# Preview of the schedule over the planned number of epochs.
rang = np.arange(EPOCHS)
y = list(map(lrfn, rang))
plt.plot(rang, y)
print('Learning rate per epoch:')

In [None]:
def get_tfrecord_ds(path):
    """Build a TFRecordDataset over every file found directly under ``path``.

    Files are listed with ``tf.io.gfile.glob`` and read with an autotuned
    number of parallel readers. The returned dataset yields the raw
    serialized records; parsing is left to the caller.
    """
    shard_pattern = os.path.join(path, "*")
    shard_files = tf.io.gfile.glob(shard_pattern)
    records = tf.data.TFRecordDataset(shard_files, num_parallel_reads=tf.data.AUTOTUNE)
    return records

In [None]:
# Feature specs handed to tf.io.parse_single_example.
# Labelled records (used for train and val): an encoded image plus a scalar int64 'class'.
image_feature_description_train = {
    'image': tf.io.FixedLenFeature(shape=[], dtype=tf.string),
    'class': tf.io.FixedLenFeature(shape=[], dtype=tf.int64),
}

# Unlabelled test records: an encoded image plus a scalar string 'id'.
image_feature_description_test = {
    'image': tf.io.FixedLenFeature(shape=[], dtype=tf.string),
    'id': tf.io.FixedLenFeature(shape=[], dtype=tf.string),
}

def parse_image_train(proto):
    """Decode one labelled serialized example into an ``(image, label)`` pair.

    The 'image' feature is decoded as a 3-channel JPEG and reshaped to
    ``IMAGE_SHAPE``; the label is the record's 'class' feature, returned as parsed.
    """
    features = tf.io.parse_single_example(proto, image_feature_description_train)
    pixels = tf.reshape(
        tf.image.decode_jpeg(features["image"], channels=3),
        IMAGE_SHAPE,  # gives the decoded tensor a fully static shape
    )
    return pixels, features["class"]

def parse_image_test(proto):
    """Decode one unlabelled serialized example into an ``(image, id)`` pair.

    The 'image' feature is decoded as a 3-channel JPEG and reshaped to
    ``IMAGE_SHAPE``; the second element is the record's 'id' string feature.
    """
    features = tf.io.parse_single_example(proto, image_feature_description_test)
    pixels = tf.reshape(
        tf.image.decode_jpeg(features["image"], channels=3),
        IMAGE_SHAPE,  # gives the decoded tensor a fully static shape
    )
    return pixels, features["id"]

In [None]:
# Seeded generator that supplies a fresh seed to each stateless random op below.
rng = tf.random.Generator.from_seed(123)

def augment(image, label):
    """Apply random crop, brightness and horizontal-flip augmentation to ``image``.

    Each stateless op gets its own seed drawn from the module-level ``rng``.
    The label is passed through untouched.
    """
    # NOTE(review): the crop size equals IMAGE_SHAPE, so for inputs that already
    # have that shape (as parse_image_train produces) the crop cannot move.
    crop_seed = rng.make_seeds(2)[0]
    image = tf.image.stateless_random_crop(image, size=IMAGE_SHAPE, seed=crop_seed)

    brightness_seed = rng.make_seeds(2)[0]
    image = tf.image.stateless_random_brightness(image, max_delta=0.5, seed=brightness_seed)

    flip_seed = rng.make_seeds(2)[0]
    image = tf.image.stateless_random_flip_left_right(image, seed=flip_seed)

    return image, label

In [None]:
# Record order is irrelevant for training and validation, so allow tf.data to
# produce elements non-deterministically for better throughput.
ignore_order = tf.data.Options()
ignore_order.experimental_deterministic = False

# Training pipeline.
# cache() must come BEFORE repeat(): a cache placed after repeat() wraps an
# infinite stream, never finalises, and keeps buffering elements until memory
# runs out. It must also come before augment(), otherwise any cached pass would
# replay the same "random" augmentations every epoch.
# Caching and shuffling the serialized records (instead of decoded images) keeps
# the cache and the 15000-element shuffle buffer small.
ds_train = (
    get_tfrecord_ds(traindir)
    .with_options(ignore_order)
    .cache()
    .shuffle(15000)
    .repeat()
    .map(parse_image_train, num_parallel_calls=tf.data.AUTOTUNE)
    .map(augment, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# Validation pipeline: cache one finite pass of decoded examples, then repeat it.
# Batching stays after repeat() so every batch is full, as before.
ds_val = (
    get_tfrecord_ds(valdir)
    .with_options(ignore_order)
    .map(parse_image_train, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .repeat()
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# Test pipeline: finite and read in deterministic order, because predictions and
# ids are collected in two separate passes and must line up.
ds_test = (
    get_tfrecord_ds(testdir)
    .map(parse_image_test, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Shard file names carry their record count just before the extension,
# e.g. "00-512x512-798.tfrec" holds 798 records.
_SHARD_COUNT_PATTERN = re.compile(r"-([0-9]+)\.")

def count_data_items(path):
    """Return the total number of records in the TFRecord shards under ``path``.

    The count is read from each shard's file name rather than by scanning the
    data. Only the base name is inspected, so digits that appear in a bucket or
    directory name (for example ``gs://kds-1.bucket/...``) cannot be mistaken
    for a record count.

    Args:
        path: Directory (local or GCS) that contains the shard files.

    Returns:
        The summed record count as a plain ``int``; ``0`` when no files match.

    Raises:
        ValueError: If a file name does not contain a ``-<digits>.`` count.
    """
    filenames = tf.io.gfile.glob(os.path.join(path, "*"))
    total = 0
    for filename in filenames:
        found = _SHARD_COUNT_PATTERN.search(os.path.basename(filename))
        if found is None:
            raise ValueError(
                "Cannot read the record count from file name: {!r}".format(filename)
            )
        total += int(found.group(1))
    return total

# Batches needed to cover each split once: ceiling division of the record count by BATCH_SIZE.
TRAIN_STEPS, VAL_STEPS, TEST_STEPS = (
    (count_data_items(split_dir) + BATCH_SIZE - 1) // BATCH_SIZE
    for split_dir in (traindir, valdir, testdir)
)

In [None]:
# Show the first nine images of one training batch in a 3x3 grid, titled with their labels.
plt.figure(figsize=(10, 10))
for ds in ds_train.take(1):
    batch_images, batch_labels = ds[0], ds[1]
    for cell in range(1, 10):
        plt.subplot(3, 3, cell)
        plt.axis("off")
        plt.imshow(batch_images[cell - 1])
        plt.title(batch_labels[cell - 1].numpy())
plt.show()

In [None]:
# Build and compile the model inside the strategy scope.
with strategy.scope():
    preprocess_input = tf.keras.applications.xception.preprocess_input

    # ImageNet-pretrained Xception without its classification head, fully trainable.
    base_model = tf.keras.applications.Xception(
        weights='imagenet',
        include_top=False,
        input_shape=IMAGE_SHAPE,
    )
    base_model.trainable = True

    model = tf.keras.Sequential()
    # Xception's own input preprocessing is applied inside the model.
    model.add(tf.keras.layers.Lambda(lambda x: preprocess_input(x), input_shape=IMAGE_SHAPE))
    model.add(base_model)
    model.add(tf.keras.layers.GlobalAveragePooling2D())
    # 104 output logits; no softmax here because the loss below uses from_logits=True.
    model.add(tf.keras.layers.Dense(104))

    logits_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(
        loss=logits_loss,
        optimizer="adam",
        metrics=['accuracy'],
        steps_per_execution=32,
    )
    model.summary()

In [None]:
# Both datasets repeat indefinitely, so the number of steps per epoch is given explicitly.
history = model.fit(
    ds_train,
    validation_data=ds_val,
    epochs=EPOCHS,
    steps_per_epoch=TRAIN_STEPS,
    validation_steps=VAL_STEPS,
    callbacks=[lr_callback],  # applies the lrfn() schedule each epoch
)

In [None]:
# Per-epoch curves recorded by model.fit().
training_log = history.history
acc, val_acc = training_log['accuracy'], training_log['val_accuracy']
loss, val_loss = training_log['loss'], training_log['val_loss']

epochs_range = range(EPOCHS)

# One panel per metric: (train curve, train label, val curve, val label, legend position, title).
panels = (
    (acc, 'Training Accuracy', val_acc, 'Validation Accuracy',
     'lower right', 'Training and Validation Accuracy'),
    (loss, 'Training Loss', val_loss, 'Validation Loss',
     'upper right', 'Training and Validation Loss'),
)

plt.figure(figsize=(10, 10))
for row, (train_curve, train_label, val_curve, val_label, legend_loc, title) in enumerate(panels, start=1):
    plt.subplot(2, 1, row)
    plt.plot(epochs_range, train_curve, label=train_label)
    plt.plot(epochs_range, val_curve, label=val_label)
    plt.legend(loc=legend_loc)
    plt.title(title)
plt.show()

In [None]:
# Drop the ids so the model only receives images, then keep the highest-scoring class per row.
ds_test_image = ds_test.map(lambda pixels, _image_id: pixels)
pred = model.predict(ds_test_image, steps=TEST_STEPS)
pred_label = tf.math.argmax(pred, axis=1)

In [None]:
# Second pass over the test set: collect the ids one by one and decode the byte strings.
ds_test_id = ds_test.map(lambda _pixels, image_id: image_id).unbatch()
ids = [raw_id.decode("utf-8") for raw_id in ds_test_id.as_numpy_iterator()]

In [None]:
# Start from the sample submission, overwrite its 'label' and 'id' columns with the
# predictions and the ids read from the test set, then write the result without the index.
df = pd.read_csv(submission_file).assign(label=pred_label, id=ids)
df.to_csv("submission.csv", index=False)