This notebook is a baseline encoder-decoder model written in TensorFlow and running on a TPU. Several notebooks, examples, and documentation pages were used as sources of inspiration.

In [None]:
!pip install -q --upgrade pip
!pip install -q efficientnet

In [None]:
import pandas as pd
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import efficientnet.tfkeras as efn

from tensorflow.keras.mixed_precision import experimental as mixed_precision
from kaggle_datasets import KaggleDatasets
from tqdm.notebook import tqdm
from multiprocessing import cpu_count

import numpy as np
import os
import io
import time
import pickle
import math
import random

In [None]:
# Try to locate a TPU cluster. A ValueError from the resolver is treated as
# "no TPU": TPU is then set to None and the next cell falls back to the
# default distribution strategy.
try:
    TPU = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU',TPU.master())
except ValueError:
    # NOTE(review): this branch is taken for any non-TPU runtime, even though the message says GPU.
    print('Running on GPU')
    TPU = None

In [None]:
# Build the distribution strategy: a TPUStrategy when a TPU resolver was found,
# otherwise whatever strategy TensorFlow currently has active.
if TPU:
    tf.config.experimental_connect_to_cluster(TPU)
    tf.tpu.experimental.initialize_tpu_system(TPU)
    strategy = tf.distribute.experimental.TPUStrategy(TPU)
else:
    strategy = tf.distribute.get_strategy()

# Number of replicas kept in sync by the strategy; used later to scale the global batch size.
REPLICAS = strategy.num_replicas_in_sync
print(f'REPLICAS: {REPLICAS}')

# Global Keras precision policy: mixed bfloat16 on TPU, plain float32 otherwise.
mixed_precision.set_policy('mixed_bfloat16' if TPU else 'float32')

print(f'Compute dtype: {mixed_precision.global_policy().compute_dtype}')
print(f'Variable dtype: {mixed_precision.global_policy().variable_dtype}')

In [None]:
# Set to True for a quick run with a smaller batch size and validation set.
DEBUG = False

# Image geometry expected by the input pipeline and the encoder.
IMG_HEIGHT = 256
IMG_WIDTH = 448
N_CHANNELS = 3

# Fixed length of the integer-encoded InChI label stored in each record.
MAX_INCHI_LEN = 200

# Per-replica batch size and the resulting global batch size.
BATCH_SIZE_BASE = 6 if DEBUG else (64 if TPU else 12)
BATCH_SIZE = BATCH_SIZE_BASE*REPLICAS
# Tiny batch used by the CPU smoke tests of the encoder/decoder.
BATCH_SIZE_DEBUG = 2

N_TEST_IMGS = 1616107
# Ceil division: number of batches needed to cover the test set. The previous
# `// BATCH_SIZE + 1` counted one step too many whenever the image count was an
# exact multiple of the batch size.
N_TEST_STEPS = (N_TEST_IMGS + BATCH_SIZE - 1) // BATCH_SIZE

# Image dtype fed to the model and dtype of the integer labels.
TARGET_DTYPE = tf.bfloat16 if TPU else tf.float32
LABEL_DTYPE = tf.uint8

# Number of validation samples evaluated and the matching number of full batches.
VAL_SIZE = int(1e3) if DEBUG else int(100e3)
VAL_STEPS = VAL_SIZE // BATCH_SIZE

# Per-channel statistics used to normalize the images.
IMAGENET_MEAN = tf.constant([0.485, 0.456, 0.406], dtype=tf.float32)
IMAGENET_STD = tf.constant([0.229, 0.224, 0.225], dtype=tf.float32)

# The GCS path of the TFRecords dataset is only resolved when running on TPU.
if TPU:
    GCS_DS_PATH = KaggleDatasets().get_gcs_path('molecular-translation-images-cleaned-tfrecords')
    
AUTO = tf.data.experimental.AUTOTUNE

In [None]:
def _read_pickle(path):
    """Return the object stored in the pickle file at `path`."""
    # NOTE(review): pickle.load can execute arbitrary code; only use with trusted dataset files.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Lookup tables between vocabulary tokens and their integer ids.
vocabulary_to_int = _read_pickle('../input/molecular-translation-images-cleaned-tfrecords/vocabulary_to_int.pkl')
int_to_vocabulary = _read_pickle('../input/molecular-translation-images-cleaned-tfrecords/int_to_vocabulary.pkl')

# Show the first few entries of each table as a sanity check.
print(f'vocabulary_to_int head: {list(vocabulary_to_int.items())[:5]}')
print(f'int_to_vocabulary head: {list(int_to_vocabulary.items())[:5]}')

In [None]:
# Decoder hyperparameters.
DECODER_DIM = 512
CHAR_EMBEDDING_DIM = 256
ATTENTION_UNITS = 256

# The output sequence length equals the fixed label length.
SEQ_LEN_OUT = MAX_INCHI_LEN

# One output class per vocabulary entry.
VOCAB_SIZE = len(vocabulary_to_int)

print(f'VOCAB_SIZE:{VOCAB_SIZE}')

In [None]:
@tf.function
def decode_tfrecord(record_bytes):
    """Parse one serialized training example into an (image, InChI) pair.

    The record holds a PNG-encoded image and MAX_INCHI_LEN integer token ids.
    Returns the normalized image and the token ids cast to LABEL_DTYPE.
    """
    feature_spec = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'InChI': tf.io.FixedLenFeature([MAX_INCHI_LEN], tf.int64),
    }
    example = tf.io.parse_single_example(record_bytes, feature_spec)

    # Labels: pin the static shape, then narrow to the label dtype.
    token_ids = tf.reshape(example['InChI'], [MAX_INCHI_LEN])
    token_ids = tf.cast(token_ids, LABEL_DTYPE)

    # Image: decode, pin the static shape (required on TPU) and scale to [0, 1].
    pixels = tf.io.decode_png(example['image'])
    pixels = tf.reshape(pixels, [IMG_HEIGHT, IMG_WIDTH, 1])
    pixels = tf.cast(pixels, tf.float32) / 255.0
    # The single channel broadcasts against the 3-element mean/std,
    # which yields a 3-channel normalized image.
    pixels = (pixels - IMAGENET_MEAN) / IMAGENET_STD

    # On TPU the model consumes images in TARGET_DTYPE.
    if TPU:
        pixels = tf.cast(pixels, TARGET_DTYPE)

    return pixels, token_ids

In [None]:
def get_dataset(bs=BATCH_SIZE, val=False):
    """Build the repeating, batched input pipeline for training or validation.

    Args:
        bs: global batch size of the returned dataset.
        val: read the `val` TFRecord shards when True, the `train` shards otherwise.

    Returns:
        An endlessly repeating tf.data.Dataset of (image, InChI) batches;
        incomplete batches are dropped so every batch has a static size.
    """
    ignore_order = tf.data.Options()
    ignore_order.experimental_deterministic = False

    # On TPU the records are read from GCS, otherwise from the locally mounted
    # copy of the same dataset (mirrors get_test_dataset).
    if TPU:
        root = GCS_DS_PATH
    else:
        root = '/kaggle/input/molecular-translation-images-cleaned-tfrecords'

    # Pick exactly one split; previously the train glob always overwrote the
    # validation glob, so "validation" silently ran on training data.
    split = 'val' if val else 'train'
    fnames_tfrecords = tf.io.gfile.glob(f'{root}/{split}/*.tfrecords')

    dataset = tf.data.TFRecordDataset(fnames_tfrecords, num_parallel_reads=AUTO)
    dataset = dataset.with_options(ignore_order)
    dataset = dataset.prefetch(AUTO)
    dataset = dataset.repeat()
    dataset = dataset.map(decode_tfrecord, num_parallel_calls=AUTO)
    # honor the requested batch size instead of the global constant
    dataset = dataset.batch(bs, drop_remainder=True)
    dataset = dataset.prefetch(1)

    return dataset

train_dataset = get_dataset()

In [None]:
# Second pipeline from the same helper, requested with val=True.
val_dataset = get_dataset(val=True)

In [None]:
class Encoder(tf.keras.Model):
    """Image encoder: EfficientNetB0 feature maps flattened into a sequence.

    Constructing an Encoder also sets the module-level ENCODER_DIM to the
    channel count of the backbone's last layer.
    """

    def __init__(self):
        super().__init__()
        global ENCODER_DIM

        # Headless EfficientNetB0 with the 'noisy-student' weights.
        backbone = efn.EfficientNetB0(include_top=False, weights='noisy-student')
        self.feature_maps = backbone

        # Publish the channel dimension; the decoder is built from it later.
        ENCODER_DIM = backbone.layers[-1].output_shape[-1]

        # Collapse the spatial grid into one axis: (batch, positions, ENCODER_DIM).
        self.reshape = tf.keras.layers.Reshape([-1, ENCODER_DIM], name='reshape_featuere_maps')

    def call(self, x, training, debug=False):
        """Return the flattened feature maps for image batch `x`.

        With debug=True the shapes before and after flattening are printed.
        """
        features = self.feature_maps(x, training=training)
        if debug:
            print(f'feature maps shape: {features.shape}')

        flattened = self.reshape(features, training=training)
        if debug:
            print(f'feature maps reshaped shape: {flattened.shape}')

        return flattened

In [None]:
# Pull one batch from the training pipeline; `imgs` and `lbls` are reused by
# later cells to build and smoke-test the models.
imgs, lbls = next(iter(train_dataset))
print(f'imgs.shape: {imgs.shape}, lbls.shape: {lbls.shape}')
# Convert the first image to float32 NumPy so the statistics below can be computed and printed.
img0 = imgs[0].numpy().astype(np.float32)
train_batch_info = (img0.mean(), img0.std(), img0.min(), img0.max(), img0.dtype)
print('train img0 mean: %.3f, std: %.3f, min: %.3f, max: %.3f, %s'%train_batch_info)

In [None]:
# Smoke test on CPU: build an encoder and run it on BATCH_SIZE_DEBUG images
# with debug printing of the intermediate shapes. Constructing the Encoder
# also defines the global ENCODER_DIM used by the decoder cells.
with tf.device('/CPU:0'):
    encoder = Encoder()
    encoder_res = encoder(imgs[:BATCH_SIZE_DEBUG], debug = True)
    
print('Encode output shape: (batch_size, seq_len, units) {}'.format(encoder_res.shape))

In [None]:
class Decoder(keras.Model):
    """Single-step LSTM decoder that predicts the next token of the sequence.

    Args:
        vocab_size: number of output classes and embedding rows.
        encoder_dim: feature dimension of the encoder output.
        decoder_dim: number of LSTM units (size of h and c).
        char_embedding_dim: size of the token embedding.
    """

    def __init__(self, vocab_size, encoder_dim, decoder_dim, char_embedding_dim):
        super().__init__()
        # Projections from the averaged encoder features to the initial LSTM states.
        self.init_h = keras.layers.Dense(units=decoder_dim, input_shape=[encoder_dim], name='encoder_res_to_hiddent_init')
        self.init_c = keras.layers.Dense(units=decoder_dim, input_shape=[encoder_dim], name='encoder_res_to_inp_act_init')
        # Recurrent core, regularization and the output head (kept in float32).
        self.lstm_cell = keras.layers.LSTMCell(decoder_dim, name='lstm_char_predictor')
        self.do = keras.layers.Dropout(0.3, name='prediction_dropout')
        self.fcn = keras.layers.Dense(units=vocab_size, input_shape=[decoder_dim], dtype=tf.float32, name='lstm_output_to_char_probs')
        # Token id -> dense vector.
        self.embedding = keras.layers.Embedding(vocab_size, char_embedding_dim, name='char_embedding')
        # An attention layer is not part of this baseline.

    def call(self, char, h, c, enc_output, training, debug=False):
        """Run one decoding step.

        Args:
            char: token ids with a length-1 axis at position 1 (it is squeezed after embedding).
            h, c: current LSTM states.
            enc_output: encoder output; only used for debug printing here.
            training: forwarded to the sub-layers.
            debug: print input shapes when True.

        Returns:
            (output of the final dense layer, new h, new c).
        """
        if debug:
            print(f'char shape: {char.shape}, h shape: {h.shape}, c shape: {c.shape}, enc_output shape: {enc_output.shape}')

        # Embed the token and drop the length-1 axis.
        lstm_input = tf.squeeze(self.embedding(char, training=training), axis=1)

        if debug:
            print(f'lstm_input shape: {lstm_input.shape}')

        _, (h_new, c_new) = self.lstm_cell(lstm_input, (h, c), training=training)
        dropped = self.do(h_new, training=training)
        output = self.fcn(dropped, training=training)

        return output, h_new, c_new

    def init_hidden_state(self, encoder_out, training):
        """Derive the initial (h, c) from the encoder output averaged over axis 1."""
        pooled = tf.math.reduce_mean(encoder_out, axis=1)
        return self.init_h(pooled, training=training), self.init_c(pooled, training=training)

In [None]:
# Smoke test on CPU: build a decoder, initialize its LSTM states from the
# encoder output and run a single decoding step on the first label token.
with tf.device('/CPU:0'):
    decoder = Decoder(VOCAB_SIZE, ENCODER_DIM, DECODER_DIM, CHAR_EMBEDDING_DIM)
    h, c = decoder.init_hidden_state(encoder_res[:BATCH_SIZE_DEBUG], training=False)
    preds, h, c = decoder(lbls[:BATCH_SIZE_DEBUG, :1], h, c, encoder_res, debug=True)
    # closing parenthesis added so the label matches the encoder message format
    print('Decoder output shape: (batch_size, vocab_size) {}'.format(preds.shape))

In [None]:
# Integer ids of the special tokens as int64 tensors; they are compared
# against int64 token tensors when the Levenshtein distance is computed.
START_TOKEN = tf.constant(vocabulary_to_int.get('<start>'), dtype=tf.int64)
END_TOKEN = tf.constant(vocabulary_to_int.get('<end>'), dtype=tf.int64)
PAD_TOKEN = tf.constant(vocabulary_to_int.get('<pad>'), dtype=tf.int64)

In [None]:
# Reset Keras state left over from the CPU smoke tests before the real models are built.
tf.keras.backend.clear_session()

# Everything below is created inside the strategy scope: loss, metrics,
# encoder, decoder and optimizer.
with strategy.scope():
    mixed_precision.set_policy('mixed_bfloat16' if TPU else 'float32')
    
    # enable XLA JIT compilation
    tf.config.optimizer.set_jit(True)
    
    print(f'Compute dtype: {mixed_precision.global_policy().compute_dtype}')
    print(f'Variable dtype: {mixed_precision.global_policy().variable_dtype}')
    
    # Per-example loss on logits; reduction is done manually in loss_function.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
    
    def loss_function(real, pred):
        """Cross-entropy summed over the examples and divided by the global BATCH_SIZE."""
        per_example_loss = loss_object(real, pred)

        return tf.nn.compute_average_loss(per_example_loss, global_batch_size=BATCH_SIZE)
    
    # Metrics
    # The loss metrics are running sums; callers divide them by a step count when reporting.
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    train_loss = tf.keras.metrics.Sum()
    val_loss = tf.keras.metrics.Sum()


    # Encoder
    # Called once on two sample images so its variables are created.
    encoder = Encoder()
    encoder.build(input_shape=[BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, N_CHANNELS])
    encoder_res = encoder(imgs[:2], training=False)
    
    # Decoder
    # One decoding step on the first label token creates the decoder variables.
    decoder = Decoder(VOCAB_SIZE, ENCODER_DIM, DECODER_DIM, CHAR_EMBEDDING_DIM)
    h, c = decoder.init_hidden_state(encoder_res, training=False)
    preds, h, c = decoder(lbls[:2, :1], h, c, encoder_res, training=False)
    
    # Adam Optimizer
    # The learning rate is overwritten later by the LRReduce scheduler.
    optimizer = tf.keras.optimizers.Adam()

In [None]:
# Training length and logging cadence.
EPOCHS = 1
# Number of steps of the learning-rate warmup.
WARMUP_STEPS = 500
# Training steps per epoch.
TRAIN_STEPS = 1000
# Number of training steps executed per call of distributed_train_step (one log line each).
VERBOSE_FREQ = 100
# Number of distributed_train_step calls (and log lines) per epoch.
STEPS_PER_EPOCH = TRAIN_STEPS // VERBOSE_FREQ
# Length of the precomputed learning-rate schedule.
TOTAL_STEPS = EPOCHS * TRAIN_STEPS

In [None]:
def lrfn(step, WARMUP_LR_START, LR_START, LR_FINAL, DECAYS):
    """Return the learning rate for training step `step`, rounded to 8 decimals.

    For step < WARMUP_STEPS the rate rises quadratically from WARMUP_LR_START
    to LR_START. Afterwards the remaining TOTAL_STEPS - WARMUP_STEPS steps are
    split into DECAYS + 1 equal stairs and the rate is divided by a constant
    factor per completed stair, so that it equals LR_FINAL after DECAYS stairs.

    DECAYS must be non-zero (it is used as a divisor).
    """
    if step >= WARMUP_STEPS:
        # staircase decay
        stair_width = (TOTAL_STEPS - WARMUP_STEPS) / (DECAYS + 1)
        stairs_done = (step - WARMUP_STEPS) // stair_width
        per_stair_ratio = (LR_START / LR_FINAL) ** (1 / DECAYS)
        lr = LR_START / (per_stair_ratio ** stairs_done)
    else:
        # quadratic warmup
        progress = (step / WARMUP_STEPS) ** 2
        lr = WARMUP_LR_START + (LR_START - WARMUP_LR_START) * progress

    return round(lr, 8)

In [None]:
def dense_to_sparse(dense):
    """Wrap a dense tensor in a SparseTensor that lists every element explicitly.

    Unlike a zero-dropping conversion, all positions (including zeros) are
    kept as explicit entries.
    """
    shape = dense.shape
    # indices of every position: tf.where on an all-ones tensor of the same shape
    every_index = tf.where(tf.ones(shape))
    return tf.SparseTensor(every_index, tf.gather_nd(dense, every_index), shape)

def get_levenshtein_distance(preds, lbls):
    """Mean un-normalized Levenshtein distance between predictions and labels.

    `lbls` is gathered across replicas here; `preds` is used as passed in.
    <start>, <end> and <pad> ids are replaced by 0 in both tensors (not
    removed), so the compared sequences keep their full length.
    """
    def zero_special_tokens(tokens):
        # keep regular tokens, overwrite the three special tokens with 0
        is_regular = tf.not_equal(tokens, START_TOKEN) & tf.not_equal(tokens, END_TOKEN) & tf.not_equal(tokens, PAD_TOKEN)
        return tf.where(is_regular, tokens, y=0)

    preds = zero_special_tokens(tf.cast(preds, tf.int64))

    lbls = tf.cast(strategy.gather(lbls, axis=0), tf.int64)
    lbls = zero_special_tokens(lbls)

    per_sample_distance = tf.edit_distance(
        dense_to_sparse(preds), dense_to_sparse(lbls), normalize=False)

    return tf.math.reduce_mean(per_sample_distance)

In [None]:
@tf.function()
def distributed_train_step(dataset):
    """Run VERBOSE_FREQ training steps on batches drawn from the iterator `dataset`.

    The train loss/accuracy metrics are reset at the start, so after the call
    they cover exactly these VERBOSE_FREQ steps.
    """
    def train_step(inp, targ):
        """One teacher-forced optimization step on a per-replica batch."""
        total_loss = 0.0
        
        with tf.GradientTape() as tape:
            enc_output = encoder(inp, training=True)
            h, c = decoder.init_hidden_state(enc_output, training=True)
            # the first target token is the first decoder input
            dec_input = tf.expand_dims(targ[:, 0], 1)
            for idx in range(1, SEQ_LEN_OUT):
                t = targ[:, idx]
                # pin the static per-replica batch dimension
                t = tf.reshape(t, [BATCH_SIZE_BASE])
                predictions, h, c = decoder(dec_input, h, c, enc_output, training=True)
                total_loss += loss_function(t, predictions)
                train_accuracy.update_state(t, predictions)
                # teacher forcing: the ground-truth token is the next input
                dec_input = tf.expand_dims(t, 1)
                
        variables = encoder.trainable_variables + decoder.trainable_variables
        gradients = tape.gradient(total_loss, variables)
        # clip the global gradient norm to 10
        gradients, _ = tf.clip_by_global_norm(gradients, 10.0)
        optimizer.apply_gradients(zip(gradients, variables))
        
        # record the loss averaged over the predicted positions
        batch_loss = total_loss/(SEQ_LEN_OUT-1)
        train_loss.update_state(batch_loss)
        
    train_loss.reset_states()
    train_accuracy.reset_states()
    
    # tensor-valued range, so the loop is not unrolled into VERBOSE_FREQ copies when traced
    for _ in tf.range(tf.convert_to_tensor(VERBOSE_FREQ)):
        strategy.run(train_step, args=next(dataset))

In [None]:
def validation_step(inp, targ):
    """Greedy-decode one per-replica validation batch and record its loss.

    Unlike training, the decoder is fed its own previous prediction (not the
    target). The loss averaged over the predicted positions is added to
    `val_loss`.

    Returns:
        The predicted token ids, starting with the first target token, as a
        LABEL_DTYPE tensor with SEQ_LEN_OUT columns.
    """
    enc_output = encoder(inp, training=False)
    h, c = decoder.init_hidden_state(enc_output, training=False)

    # the first target token seeds both the decoder input and the output sequence
    first_tokens = tf.expand_dims(targ[:, 0], 1)
    dec_input = first_tokens
    predictions_seq = first_tokens

    total_loss = 0.0
    for pos in range(1, SEQ_LEN_OUT):
        logits, h, c = decoder(dec_input, h, c, enc_output, training=False)

        # loss of this position against the ground truth
        total_loss += loss_function(targ[:, pos], logits)

        # the most likely token becomes the next decoder input
        best_token = tf.math.argmax(logits, axis=1, output_type=tf.int32)
        dec_input = tf.cast(tf.expand_dims(best_token, axis=1), LABEL_DTYPE)
        predictions_seq = tf.concat([predictions_seq, dec_input], axis=1)

    val_loss.update_state(total_loss / (SEQ_LEN_OUT - 1))

    return predictions_seq

In [None]:
@tf.function
def distributed_val_step(dataset):
    """Run validation_step on the next batch of the iterator `dataset`.

    Returns:
        (predictions gathered across replicas along axis 0,
         the targets as drawn from the iterator, i.e. not gathered).
    """
    batch = next(dataset)
    inp_val, targ_val = batch
    per_replica_seq = strategy.run(validation_step, args=(inp_val, targ_val))

    return strategy.gather(per_replica_seq, axis=0), targ_val

In [None]:
def get_val_metrics(val_dist_dataset):
    """Evaluate VAL_STEPS validation batches.

    Resets `val_loss` (which the validation steps then accumulate into) and
    returns the Levenshtein distance averaged over the VAL_STEPS batches.
    """
    val_loss.reset_states()

    running_distance = 0.0
    for _ in range(VAL_STEPS):
        # distributed_val_step yields (predictions, targets)
        running_distance += get_levenshtein_distance(*distributed_val_step(val_dist_dataset))

    return running_distance / VAL_STEPS

In [None]:
def log(batch, t_start_batch, val_ls_distance=False):
    """Print one progress line for the training loop.

    Args:
        batch: index of the logging interval; multiplied by VERBOSE_FREQ to
            show the number of training steps done.
        t_start_batch: time.time() value taken at the start of the interval.
        val_ls_distance: mean validation Levenshtein distance, or False
            (default) / None when no validation was run for this line.
    """
    print(
        'Step %s|' % f'{batch * VERBOSE_FREQ}/{TRAIN_STEPS}'.ljust(10, ' '),
        # train_loss is a running sum over VERBOSE_FREQ steps
        'loss: %.3f,' % (train_loss.result() / VERBOSE_FREQ),
        'acc: %.3f, ' % train_accuracy.result(),
    end='')

    # Explicit sentinel check: a legitimate distance of 0 must still be reported.
    if val_ls_distance is not False and val_ls_distance is not None:
        print(
            # val_loss is a running sum over VAL_STEPS validation batches,
            # so that is the count to average by (not VERBOSE_FREQ)
            'val_loss: %.3f, ' % (val_loss.result() / VAL_STEPS),
            'val lsd: %s,' % ('%.1f' % val_ls_distance).ljust(5, ' '),
        end='')
    # always end with learning rate, interval duration and a line break
    print(
        'lr: %s,' % ('%.1E' % LRREDUCE.get_lr()).ljust(7),
        't: %s sec' % int(time.time() - t_start_batch),
    )

In [None]:
class Stats():
    """Collects the training loss/accuracy history and plots it."""

    def __init__(self):
        # one list per tracked metric, appended to once per logging interval
        self.stats = dict(train_loss=[], train_acc=[])

    def update_stats(self):
        """Append the current train metrics (loss averaged over VERBOSE_FREQ steps)."""
        history = self.stats
        history['train_loss'].append(train_loss.result() / VERBOSE_FREQ)
        history['train_acc'].append(train_accuracy.result())

    def get_stats(self, metric):
        """Return the recorded history list of `metric`."""
        return self.stats[metric]

    def plot_stat(self, metric):
        """Plot the recorded history of `metric` as a line chart."""
        series = self.stats[metric]
        plt.figure(figsize=(15,8))
        plt.xticks(fontsize=16)
        plt.yticks(fontsize=16)
        plt.plot(series)
        plt.grid()
        plt.title(f'{metric} stats', size=24)
        plt.show()


# Shared history object updated by the training loop.
STATS = Stats()

In [None]:
# Precompute the learning rate of every training step with lrfn:
# warmup start 1e-8, peak 2e-3, final 1e-4, and EPOCHS as the number of decays.
LR_SCHEDULE = [lrfn(step, 1e-8, 2e-3, 1e-4 ,EPOCHS) for step in range(TOTAL_STEPS)]

class LRReduce():
    """Applies a precomputed learning-rate schedule to an optimizer.

    Args:
        optimizer: object whose `learning_rate` supports `.assign(value)`.
        lr_schedule: sequence of learning rates indexed by training step.
    """

    def __init__(self, optimizer, lr_schedule):
        self.opt = optimizer
        self.lr_schedule = lr_schedule
        # index of the schedule entry currently in use; read by get_counter()
        # (previously never initialized, so get_counter raised AttributeError)
        self.c = 0
        # assign initial learning rate
        self.lr = lr_schedule[0]
        self.opt.learning_rate.assign(self.lr)

    def step(self, step):
        """Switch to the learning rate stored at index `step` of the schedule.

        Raises:
            IndexError: if `step` is outside the schedule.
        """
        self.lr = self.lr_schedule[step]
        self.c = step
        # assign learning rate to optimizer
        self.opt.learning_rate.assign(self.lr)

    def get_counter(self):
        """Return the schedule index most recently applied (0 before any step)."""
        return self.c

    def get_lr(self):
        """Return the learning rate most recently assigned to the optimizer."""
        return self.lr
        
# Scheduler instance used by the training loop and by log(); constructing it
# immediately assigns the first schedule entry to the optimizer.
LRREDUCE = LRReduce(optimizer, LR_SCHEDULE)

In [None]:
# Training loop: every iteration of the inner loop runs VERBOSE_FREQ training
# steps, checkpoints both models, logs one line and advances the LR schedule.
# Validation is run once, on the last iteration of each epoch.
step_total = 0
for epoch in range(EPOCHS):
    print(f'*****EPOCH: {epoch+1}*****')
    t_start = time.time()
    t_start_batch = time.time()
    total_loss = 0
    
    # fresh distributed iterators for this epoch
    train_dist_dataset = iter(strategy.experimental_distribute_dataset(train_dataset))
    val_dist_dataset = iter(strategy.experimental_distribute_dataset(val_dataset))
    
    for step in range(1, STEPS_PER_EPOCH+1):
        distributed_train_step(train_dist_dataset)
        STATS.update_stats()
        total_loss += train_loss.result()
        
        # Stop BEFORE checkpointing when the loss diverged, so the last good
        # weights on disk are not overwritten with NaN weights.
        if np.isnan(total_loss):
            log(step, t_start_batch)
            break
        
        encoder.save_weights(f'./encoder_epoch_{epoch+1}.h5')
        decoder.save_weights(f'./decoder_epoch_{epoch+1}.h5')
        
        if step == STEPS_PER_EPOCH:
            val_ls_distance = get_val_metrics(val_dist_dataset)
            log(step, t_start_batch, val_ls_distance)
        else:
            log(step, t_start_batch)
            # reset start time batch
            t_start_batch = time.time()
            
        # schedule index of the last training step executed so far
        LRREDUCE.step(epoch * TRAIN_STEPS + step * VERBOSE_FREQ - 1)
            
    # a diverged epoch also ends the whole training run
    if np.isnan(total_loss):
        break

    # one-based epoch number, consistent with the epoch header printed above
    print(f'Epoch {epoch+1} Loss {round(total_loss.numpy() / TRAIN_STEPS, 3)}, time: {int(time.time() - t_start)} sec\n')

In [None]:
# Rebind the special-token ids as plain values from the vocabulary (they were
# int64 tensors before); int2char compares them against the predicted ids.
END_TOKEN = vocabulary_to_int.get('<end>')
START_TOKEN = vocabulary_to_int.get('<start>')
PAD_TOKEN =  vocabulary_to_int.get('<pad>')

def int2char(i_str):
    """Decode a sequence of token ids into an InChI string.

    The result always starts with 'InChI=1S/'. Decoding stops at the first
    END_TOKEN; START_TOKEN and PAD_TOKEN ids are skipped.
    """
    pieces = []
    for token in i_str:
        if token == END_TOKEN:
            break
        if token == START_TOKEN or token == PAD_TOKEN:
            continue
        pieces.append(int_to_vocabulary.get(token))

    return 'InChI=1S/' + ''.join(pieces)

In [None]:
@tf.function
def decode_tfrecord_test(record_bytes):
    """Parse one serialized test example into an (image, image_id) pair.

    The image is decoded and normalized like the training images and always
    cast to TARGET_DTYPE; the id is returned as the stored string.
    """
    feature_spec = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'image_id': tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(record_bytes, feature_spec)

    image_id = example['image_id']

    # decode, pin the static shape, scale to [0, 1] and normalize per channel
    pixels = tf.io.decode_png(example['image'])
    pixels = tf.reshape(pixels, [IMG_HEIGHT, IMG_WIDTH, 1])
    pixels = tf.cast(pixels, tf.float32) / 255.0
    pixels = tf.cast((pixels - IMAGENET_MEAN) / IMAGENET_STD, TARGET_DTYPE)

    return pixels, image_id

In [None]:
def get_test_dataset(bs=BATCH_SIZE):
    """Build the single-pass, batched input pipeline over the test TFRecords.

    Args:
        bs: global batch size of the returned dataset.

    Returns:
        A tf.data.Dataset of (image, image_id) batches. The last batch is kept
        even when it is smaller than `bs`; records are read in
        non-deterministic order.
    """
    ignore_order = tf.data.Options()
    ignore_order.experimental_deterministic = False
    
    # GCS copy on TPU, locally mounted copy otherwise
    if TPU:
        fnames_tfrecords = tf.io.gfile.glob(f'{GCS_DS_PATH}/test/*.tfrecords')
    else:
        fnames_tfrecords = tf.io.gfile.glob('/kaggle/input/molecular-translation-images-cleaned-tfrecords/test/*.tfrecords')
        
    # off TPU the parallelism is tied to the number of CPU cores
    parallelism = AUTO if TPU else cpu_count()
    test_dataset = tf.data.TFRecordDataset(fnames_tfrecords, num_parallel_reads=parallelism)
    test_dataset = test_dataset.with_options(ignore_order)
    test_dataset = test_dataset.prefetch(AUTO)
    test_dataset = test_dataset.map(decode_tfrecord_test, num_parallel_calls=parallelism)
    # honor the requested batch size instead of the global constant
    test_dataset = test_dataset.batch(bs)
    test_dataset = test_dataset.prefetch(1)
    
    return test_dataset

test_dataset = get_test_dataset()

In [None]:
# Pull one batch from the test pipeline. This rebinds `imgs` (previously a
# training batch); the inference models below are built on it.
imgs, img_ids = next(iter(test_dataset))
print(f'imgs.shape: {imgs.shape}, img_ids.shape: {img_ids.shape}')
print(f'imgs dtype: {imgs.dtype}, img_ids dtype: {img_ids.dtype}')
img0 = imgs[0].numpy().astype(np.float32)
# NOTE(review): the variable name and the printed label say "train", but this is a test image.
train_batch_info = (img0.mean(), img0.std(), img0.min(), img0.max())
print('train img 0 mean: %.3f, 0 std: %.3f, min: %.3f, max: %.3f' % train_batch_info)

In [None]:
# Models
# Rebuild encoder and decoder from scratch and load the weights saved during training.
tf.keras.backend.clear_session()

# enable XLA optimizations
tf.config.optimizer.set_jit(True)

with strategy.scope():
    encoder = Encoder()
    encoder.build(input_shape=[BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, N_CHANNELS])
    # one forward pass before the weights are loaded
    encoder_res = encoder(imgs[:BATCH_SIZE])
    encoder.load_weights('./encoder_epoch_1.h5')
    # inference only: freeze the weights
    encoder.trainable = False
    encoder.compile()

    decoder = Decoder(VOCAB_SIZE, ENCODER_DIM, DECODER_DIM, CHAR_EMBEDDING_DIM)
    h, c = decoder.init_hidden_state(encoder_res, training=False)
    # one decoding step on dummy tokens before the weights are loaded
    preds, h, c = decoder(tf.ones([BATCH_SIZE, 1]), h, c, encoder_res)
    decoder.load_weights('./decoder_epoch_1.h5')
    decoder.trainable = False
    decoder.compile()

In [None]:
def prediction_step(imgs):
    """Greedy-decode a batch of images into token id sequences.

    Decoding starts from the <start> token and always runs for SEQ_LEN_OUT - 1
    steps; each step feeds the decoder its own previous prediction.

    Returns:
        int32 tensor with SEQ_LEN_OUT columns whose first column is <start>.
    """
    # feature maps from the encoder and the LSTM states derived from them
    encoder_res = encoder(imgs)
    h, c = decoder.init_hidden_state(encoder_res, training=False)

    start_id = vocabulary_to_int.get('<start>')
    n_imgs = len(imgs)

    # both the result and the first decoder input begin with <start>
    predictions_seq = tf.cast(tf.fill([n_imgs, 1], value=start_id), tf.int32)
    dec_input = tf.expand_dims([start_id] * n_imgs, 1)

    for _ in range(1, SEQ_LEN_OUT):
        # predict the next token and receive the new LSTM states
        logits, h, c = decoder(dec_input, h, c, encoder_res)

        # most likely class per sample, shaped as the next decoder input
        best_token = tf.math.argmax(logits, axis=1, output_type=tf.int32)
        dec_input = tf.expand_dims(best_token, axis=1)

        predictions_seq = tf.concat([predictions_seq, dec_input], axis=1)

    return predictions_seq

In [None]:
@tf.function
def distributed_test_step(imgs):
    """Run prediction_step on every replica and gather the results along axis 0."""
    per_replica_seq = strategy.run(prediction_step, args=[imgs])

    return strategy.gather(per_replica_seq, axis=0)

In [None]:
@tf.function
def test_step_last_batch(imgs):
    """Run prediction_step directly on `imgs`, without going through strategy.run.

    Used by the prediction loop for the final test batch, whose size differs
    from the regular batches.
    """
    return prediction_step(imgs)

In [None]:
# Decoded InChI strings and the matching image ids, filled batch by batch.
predictions_inchi = []
predictions_img_ids = []

# Distributed test set, needed for TPU
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

# Prediction Loop
for step, (per_replica_imgs, per_replica_img_ids) in tqdm(enumerate(test_dist_dataset), total=N_TEST_STEPS):
    if TPU and step == N_TEST_STEPS - 1:
        # The last batch has a different size: collect it on a single device
        # and run the non-distributed step (compiling it takes a while).
        preds = test_step_last_batch(strategy.gather(per_replica_imgs, axis=0))
    else:
        preds = distributed_test_step(per_replica_imgs)

    img_ids = strategy.gather(per_replica_img_ids, axis=0)

    # integer sequences -> InChI strings, byte ids -> str
    predictions_inchi.extend(int2char(p) for p in preds.numpy())
    predictions_img_ids.extend(e.decode() for e in img_ids.numpy())

In [None]:
# Assemble the submission table: one row per test image, both columns as pandas 'string' dtype.
submission = pd.DataFrame({ 'image_id': predictions_img_ids, 'InChI': predictions_inchi }, dtype='string')
submission.head()

In [None]:
# Write the submission file without the DataFrame index column.
submission.to_csv('submission.csv', index=False)