In [None]:
import tensorflow as tf
from tensorflow import keras
import keras.layers as layers
import os
import matplotlib.pyplot as plt
import glob
import numpy as np

In [None]:
try:
    # Attach to a TPU if the runtime exposes one, and distribute over it.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:
    # NOTE(review): relies on "no TPU available" surfacing as ValueError.
    # Fall back to the default strategy, which runs on CPU or a single GPU.
    strategy = tf.distribute.get_strategy()

print("Number of accelerators: ", strategy.num_replicas_in_sync)

In [None]:
# Display the distribution strategy object chosen in the detection cell.
strategy

In [None]:
# Root of the 512x512 JPEG TFRecord shards.
# NOTE(review): TPU workers generally cannot read the notebook's local
# filesystem; on a real TPU runtime these shards likely need a GCS path — confirm.
DATA_DIR = '../input/tpu-getting-started/tfrecords-jpeg-512x512'

# glob.glob returns files in arbitrary, filesystem-dependent order; sort so the
# shard (and therefore record) order is reproducible across runs.
TRAIN_FILES = sorted(glob.glob(os.path.join(DATA_DIR, 'train', '*.tfrec')))
VAL_FILES = sorted(glob.glob(os.path.join(DATA_DIR, 'val', '*.tfrec')))
TEST_FILES = sorted(glob.glob(os.path.join(DATA_DIR, 'test', '*.tfrec')))

# Fail fast on a wrong path instead of silently building empty datasets.
assert TRAIN_FILES and VAL_FILES and TEST_FILES, 'no .tfrec files found under ' + DATA_DIR

In [None]:
def decode_image(raw_image, image_size=(512, 512)):
    """Decode a JPEG byte string into a float32 RGB image tensor.

    Args:
        raw_image: scalar string tensor holding JPEG-encoded bytes.
        image_size: expected ``(height, width)`` of every image. It is used to
            give the result a static shape. The default matches the 512x512
            shards this notebook reads.

    Returns:
        A ``(height, width, 3)`` float32 tensor of the raw pixel values. No
        normalization is applied; the model applies ResNet ``preprocess_input``.
    """
    image = tf.image.decode_jpeg(raw_image, channels=3)
    # Use float32, not float64: TPUs do not support float64, and doubles would
    # also double host memory and host->device transfer for no accuracy benefit.
    image = tf.cast(image, dtype=tf.float32)
    # decode_jpeg leaves height/width unknown at trace time. TPUs need static
    # shapes, and a fixed shape also gives batches a fully defined shape.
    image = tf.reshape(image, [*image_size, 3])
    return image

def parse_labeled(item):
    """Parse one serialized labeled TFRecord example.

    Args:
        item: a serialized ``tf.train.Example`` with an ``image`` (JPEG bytes)
            and an int64 ``class`` feature.

    Returns:
        An ``(image, label)`` pair. ``image`` comes from ``decode_image``, and
        ``label`` is the scalar int64 class id.
    """
    example = tf.io.parse_single_example(
        item,
        {
            'image': tf.io.FixedLenFeature([], tf.string),
            'class': tf.io.FixedLenFeature([], tf.int64),
        },
    )
    decoded = decode_image(example['image'])
    return decoded, example['class']

def parse_unlabeled(item):
    """Parse one serialized unlabeled (test) TFRecord example.

    Args:
        item: a serialized ``tf.train.Example`` with a string ``id`` and an
            ``image`` (JPEG bytes) feature.

    Returns:
        ``([id], image)``. The id is deliberately wrapped in a one-element list.
        NOTE(review): tf.data presumably turns that list into a rank-1 tensor.
        The submission cell later calls ``.unbatch()`` on the ids and relies
        on that rank-1 shape, so keep the wrapper.
    """
    example = tf.io.parse_single_example(
        item,
        {
            'image': tf.io.FixedLenFeature([], tf.string),
            'id': tf.io.FixedLenFeature([], tf.string),
        },
    )
    decoded = decode_image(example['image'])
    return [example['id']], decoded
    

In [None]:
# One stream of still-serialized records per split, read in file-list order.
train_ds_unparsed, val_ds_unparsed, test_ds_unparsed = (
    tf.data.TFRecordDataset(shard_files)
    for shard_files in (TRAIN_FILES, VAL_FILES, TEST_FILES)
)

In [None]:
AUTO = tf.data.experimental.AUTOTUNE
# Global batch size. Under TPUStrategy it is split across
# strategy.num_replicas_in_sync cores.
BATCH_SIZE = 32
# Number of records held in the shuffle buffer. Larger gives better mixing but
# uses more host RAM.
SHUFFLE_BUFFER = 2048

# Shuffle the training split so each epoch sees a new record order. The
# original replayed the shards in the same fixed order every epoch. Shuffling
# happens on the serialized records, before decoding, so the buffer holds small
# JPEG byte strings rather than multi-megabyte decoded float images.
train_ds = (
    train_ds_unparsed
    .shuffle(SHUFFLE_BUFFER)
    .map(parse_labeled, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)
# Validation metrics do not depend on order, so this split is not shuffled.
val_ds = (
    val_ds_unparsed
    .map(parse_labeled, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

In [None]:
# ResNet50-specific input preprocessing, applied inside the model graph.
preprocess = keras.applications.resnet.preprocess_input

# Create the backbone under the distribution strategy so its variables are
# created and placed by that strategy. Variables built outside
# `strategy.scope()` are not distributed, which breaks or degrades training
# under TPUStrategy.
with strategy.scope():
    resnet_head = keras.applications.ResNet50(
        include_top=False,
        weights="imagenet",
        input_shape=(512, 512, 3),
    )
# Freeze the pretrained ImageNet weights so that fitting only updates the
# layers added on top of this backbone.
resnet_head.trainable = False

In [None]:
# Stop once val_loss has not improved for 5 epochs. Also roll back to the
# best-val_loss weights; without restore_best_weights the model would keep the
# weights from the last, worse epoch.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)

with strategy.scope():
    # float32 input: TPUs do not support float64 tensors.
    inp = keras.Input(shape=(512, 512, 3), dtype=tf.float32)
    x = preprocess(inp)
    x = resnet_head(x)
    # Use `keras.layers` (the `keras` name here is tensorflow.keras) rather
    # than the separately imported standalone `keras.layers`, so the model is
    # not built from two different Keras packages.
    x = keras.layers.GlobalAveragePooling2D()(x)
    # 104-way softmax classifier trained with integer labels (sparse loss).
    x = keras.layers.Dense(104, activation='softmax')(x)
    model = keras.Model(inp, x)
    # Compile inside the scope so the optimizer and metric variables are
    # created under the same strategy as the model's weights.
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy'],
    )


In [None]:
# Render the model architecture diagram, annotated with per-layer tensor shapes.
keras.utils.plot_model(model=model, show_shapes=True)

In [None]:
# Upper bound on training epochs; the early_stopping callback may end sooner.
EPOCHS = 30

# Keep the History object so the learning curves can be inspected or plotted
# later. Otherwise it is reachable only via the transient cell output `_`.
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=[early_stopping],
)

In [None]:
# Test split as ([id], image) pairs, decoded with autotuned parallelism.
test_ds = test_ds_unparsed.map(parse_unlabeled, num_parallel_calls=AUTO)

In [None]:
# Batch size for inference. The original batch(1) ran one image per step,
# which is very slow, especially on TPU. In inference mode the batch size does
# not change the outputs, up to floating-point rounding.
PREDICT_BATCH_SIZE = 32

# Drop the ids and keep only images. Order is preserved, because map is
# deterministic by default, so predictions line up with the ids gathered later.
test_images_ds = (
    test_ds
    .map(lambda idnum, image: image, num_parallel_calls=AUTO)
    .batch(PREDICT_BATCH_SIZE)
    .prefetch(AUTO)
)
probabilities = model.predict(test_images_ds)
# Predicted class index for each test image, in dataset order.
predictions = np.argmax(probabilities, axis=-1)

In [None]:
# Collect every test id in the same dataset order used for `predictions`.
# parse_unlabeled wraps each id in a one-element list, so unbatch to scalars.
test_ids_ds = test_ds.map(lambda idnum, image: idnum).unbatch()
# Gather all ids rather than assuming a fixed test-set size. The previous
# hard-coded batch of 7382 breaks whenever the actual count differs.
test_ids = np.concatenate(
    [id_batch.numpy() for id_batch in test_ids_ds.batch(1024)]
).astype('U')  # bytes -> str so savetxt writes plain ids
assert len(test_ids) == len(predictions), (len(test_ids), len(predictions))

np.savetxt(
    'submission.csv',
    np.rec.fromarrays([test_ids, predictions]),
    fmt=['%s', '%d'],
    delimiter=',',
    header='id,label',
    comments='',
)