In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
#for dirname, _, filenames in os.walk('/kaggle/input'):
    #for filename in filenames:
        #print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
from IPython.core.magic import register_cell_magic
@register_cell_magic
def skip(line, cell=None):
    '''Cell magic ``%%skip <expr>``: run the cell body unless ``<expr>`` is truthy.

    ``line`` is the text after the magic name; it is evaluated with ``eval``
    in this function's global namespace (the notebook namespace), so it can
    reference notebook flags such as ``SUBMIT``.
    NOTE(review): ``eval`` executes arbitrary code -- only use it with trusted
    notebook text.
    '''
    should_skip = eval(line)
    if not should_skip:
        get_ipython().run_cell(cell)

In [None]:
import pandas as pd
import tensorflow as tf
from tensorflow import keras 
from tensorflow.keras import backend as K 
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import matplotlib.pyplot as plt
import seaborn as sns
import os, cv2
import random
import dill
import gc
import time
from kaggle_datasets import KaggleDatasets
import tensorflow_datasets.public_api as tfds
import math
from tqdm.notebook import tqdm
import tensorflow_addons as tfa
from tensorflow.python.distribute import values as value_lib
from mt_utils import CstTokenizer

In [None]:
# Fixed seed shared by TensorFlow, NumPy and Python's `random`; it is also
# passed as the op-level seed of the random flips in `data_augment`.
seed = 123456789
tf.random.set_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Ask TensorFlow to use automatic mixed precision.
# NOTE(review): this is set after `import tensorflow` -- confirm TF still reads it.
os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'

In [None]:
# Accelerator detection: use a TPU when one is reachable, otherwise the default
# strategy (CPU / single GPU). TPUClusterResolver.connect() is the shorter TPU
# set-up introduced in TensorFlow 2.4.
tpu = None
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:
    # No TPU found. Other options:
    #   tf.distribute.MirroredStrategy()                          -> one machine, several GPUs
    #   tf.distribute.experimental.MultiWorkerMirroredStrategy()  -> clusters of GPU machines
    strategy = tf.distribute.get_strategy()

REPLICAS = strategy.num_replicas_in_sync
print("Number of accelerators: ", REPLICAS)

AUTO = tf.data.experimental.AUTOTUNE
print(f'REPLICAS: {REPLICAS}')

# References:
https://www.tensorflow.org/tutorials/distribute/custom_training

https://www.tensorflow.org/guide/tpu

https://www.tensorflow.org/tutorials/text/image_captioning

https://www.tensorflow.org/addons/tutorials/networks_seq2seq_nmt

#### Dataset links: https://www.kaggle.com/tchaye59/mt-tfrecord-custom-vocab and https://www.kaggle.com/tchaye59/mtcustomvocabimg
#### Image-model pretraining: https://www.kaggle.com/tchaye59/mt-pretraining

In [None]:
# Location of the 'mtcustomvocabimg' Kaggle dataset; used as the TFDS
# data_dir by the train/test dataset builders below.
GCS_PATH = KaggleDatasets().get_gcs_path('mtcustomvocabimg')
# When True, the `%%skip not SUBMIT` cells at the end run and write submission.csv.
SUBMIT = True

In [None]:
# Special tokens that open/close every target sequence, and the fixed padded
# length of a target sequence (the `target` feature shape in TrainDataset).
start = '<start>'
end = '<end>'
max_seq = 393

In [None]:
# tokenizer
# Project tokenizer (from mt_utils). Look up the ids of the special tokens
# used to seed and stop decoding.
tokenizer = CstTokenizer()
start_index = tokenizer.word_index[start]
end_index = tokenizer.word_index[end]
# Display the token -> id mapping.
tokenizer.word_index

# Dataset

In [None]:
class TrainDataset(tfds.core.GeneratorBasedBuilder):
    """Read-only TFDS builder for the already-prepared training split.

    Only the schema is declared here. ``_generate_examples`` does nothing, so
    this presumably relies on the 'train' split already existing under the
    builder's ``data_dir``.
    """
    VERSION = tfds.core.Version('0.1.0')

    def _info(self):
        # Schema: a variable-size single-channel image and a fixed-length
        # int8 vector of token ids.
        feature_spec = tfds.features.FeaturesDict({
            "image": tfds.features.Image(shape=(None, None, 1)),
            "target": tfds.features.Tensor(shape=(max_seq,), dtype=tf.int8),
        })
        return tfds.core.DatasetInfo(builder=self, description="", features=feature_spec)

    def _split_generators(self, dl_manager):
        train_split = tfds.core.SplitGenerator(name='train', gen_kwargs={})
        return [train_split]

    def _generate_examples(self, **args):
        pass
    

class TestDataset(tfds.core.GeneratorBasedBuilder):
    """Read-only TFDS builder for the already-prepared test split.

    Only the schema is declared here. ``_generate_examples`` does nothing, so
    this presumably relies on the 'test' split already existing under the
    builder's ``data_dir``.
    """
    VERSION = tfds.core.Version('0.1.0')

    def _info(self):
        # Schema: a variable-size single-channel image and its id string.
        feature_spec = tfds.features.FeaturesDict({
            "image": tfds.features.Image(shape=(None, None, 1)),
            "image_id": tfds.features.Text(),
        })
        return tfds.core.DatasetInfo(builder=self, description="", features=feature_spec)

    def _split_generators(self, dl_manager):
        test_split = tfds.core.SplitGenerator(name='test', gen_kwargs={})
        return [test_split]

    def _generate_examples(self, **args):
        pass

In [None]:
# Number of training examples per epoch.
# NOTE(review): hard-coded -- confirm it matches the training TFRecords.
NUM_TRAIN_EXAMPLES = 2424186

BATCH_SIZE_PER_REPLICA = 512
# Every replica consumes its own batch, so one step covers all replicas' batches.
GLOBAL_BATCH_SIZE = strategy.num_replicas_in_sync * BATCH_SIZE_PER_REPLICA
train_steps = NUM_TRAIN_EXAMPLES // GLOBAL_BATCH_SIZE
BUFFER_SIZE = 10000  # shuffle buffer size
# When True, the image model runs in training mode inside the gradient tape.
TRAIN_IMAGE_MODEL = True

prefetch = 50  # batches prefetched by the input pipelines
# Input resolution fed to the image model.
HEIGHT, WIDTH = 320, 320

In [None]:
# Feel free to change these parameters according to your system's configuration
# Model hyper-parameters -- adjust to your system's configuration.
rnn_units = 512
# Number of feature vectors the image model's Reshape layer produces.
attention_features_shape = 256
embedding_dim = 256
# One slot more than the tokenizer vocabulary; id 0 is presumably reserved
# for padding (the loss masks id 0).
vocab_size = 1 + len(tokenizer.index_word)

In [None]:
def data_augment(image):
    """Randomly augment one training image and resize it to the model input size.

    Args:
        image: float image tensor of shape (H, W, C) (C == 1 in this
            notebook's pipeline -- TODO confirm for other callers).

    Returns:
        The augmented image resized to (WIDTH, HEIGHT).

    Each augmentation is gated by its own uniform draw in [0, 1): rotation
    with probability 0.8, each flip gate with probability 0.6 (the flip op
    itself then flips with probability 0.5), noise with probability 0.6.
    The Python ``if`` statements on tensors presumably rely on AutoGraph
    converting them when this runs inside ``Dataset.map``.
    """
    # Draw all gate values up front (order kept: it affects the seeded RNG stream).
    p_rotation = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_noise = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_flip1 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_flip2 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)

    # Rotation (random small angle, then crop back to content)
    if p_rotation > .2:
        image = rotation(image)

    # Horizontal flip
    if p_flip1 > .4:
        image = tf.image.random_flip_left_right(image, seed)

    # Vertical flip
    if p_flip2 > .4:
        image = tf.image.random_flip_up_down(image, seed)

    # Resize. tf.image.resize takes (height, width); the order does not
    # matter while WIDTH == HEIGHT.
    image = tf.image.resize(image,(WIDTH, HEIGHT))

    # Noise (applied after resizing)
    if p_noise >= .4:
        image = random_noise(image)

    return image

def rotation(img, rotation=0.2):
    """Rotate ``img`` by a random angle, then crop it back to its content.

    Args:
        img: image tensor of shape (H, W, C).
        rotation: maximum absolute angle. The angle is drawn uniformly from
            [-rotation, rotation) and passed to ``tfa.image.rotate``.

    Returns:
        The rotated image cropped by ``remove_pad``, so its size may differ
        from the input.
    """
    angle = tf.random.uniform([], -1.0, 1.0, dtype=tf.float32) * rotation
    dims = tf.shape(img)
    height, width = dims[0], dims[1]
    # Place the image on a zero canvas of twice its height and width so the
    # rotated corners are not clipped.
    canvas = tf.image.pad_to_bounding_box(img, height // 2, width // 2, height * 2, width * 2)
    rotated = tfa.image.rotate(canvas, angle, fill_value=0)
    # Crop the zero border away again.
    return remove_pad(rotated)


def remove_pad(arr, pad_value=0.0):
    """Crop ``arr`` to the bounding box of its non-padding pixels.

    Args:
        arr: image tensor of shape (H, W, C).
        pad_value: value treated as padding / background.

    Returns:
        ``arr[y_min:y_max, x_min:x_max]``: the tightest crop containing every
        pixel whose channels all differ from ``pad_value``. If there is no
        such pixel, ``arr`` is returned unchanged.
    """
    # (H, W) mask: True where the pixel is content.
    content = tf.reduce_all(arr != pad_value, axis=-1)
    # Row / column indices that contain at least one content pixel.
    # reduce_any (rather than argmax) also keeps rows/columns whose first
    # content pixel sits at index 0.
    rows = tf.where(tf.reduce_any(content, axis=1))[:, 0]
    cols = tf.where(tf.reduce_any(content, axis=0))[:, 0]

    def _crop():
        # Bounding box is inclusive of the last content row/column.
        return arr[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]

    # An all-padding image has no bounding box; keep it as is instead of
    # indexing an empty tensor.
    return tf.cond(tf.size(rows) > 0, _crop, lambda: arr)

def random_noise(img,p=0.01):
    """Replace a random fraction ``p`` of the elements of ``img`` with 0/1 noise.

    Each element is independently selected with probability ``p``. Selected
    elements become 0 or 1 with equal probability; all others are kept.
    """
    shape = tf.shape(img)
    # One draw per element: 0 with probability p ("replace"), 1 otherwise ("keep").
    choice = tf.random.categorical(tf.math.log([[p, 1-p]]), tf.size(img),dtype=tf.int32)
    # Replacement values: 0 or 1 with equal probability (equal logits).
    noise = tf.random.categorical(tf.math.log([[1., 1.]]), tf.size(img),dtype=tf.int32)
    choice = tf.reshape(choice,shape)
    noise = tf.reshape(noise,shape)
    # Zero the replacement value wherever the element is kept (choice == 1).
    noise = tf.abs(choice-1)*noise
    choice = tf.cast(choice,img.dtype)
    noise = tf.cast(noise,img.dtype)
    # Kept elements: img + 0; replaced elements: 0 + noise.
    return (choice*img)+noise

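## Build the image model (EfficientNetB1 backbone, randomly initialized here; trained weights are restored later from `image_model.h5`)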

In [None]:
with strategy.scope():
    # Access the filesystem from the local host when restoring weights.
    load_locally = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
    # Alternative (disabled): load a pretrained model instead of building one:
    # image_model = tf.keras.models.load_model('../input/mt-pretraining/EfficientNetB0.h5',options=load_locally,compile=False)

    # EfficientNetB1 feature extractor for single-channel WIDTH x HEIGHT
    # images; weights=None means random initialization.
    pixels = tf.keras.layers.Input(shape=(WIDTH, HEIGHT, 1))
    backbone = tf.keras.applications.EfficientNetB1(include_top=False, weights=None,
                                                    input_shape=(WIDTH, HEIGHT, 1))
    feature_map = backbone(pixels)
    # Regroup the feature map into `attention_features_shape` vectors for the
    # attention decoder (the last dimension is inferred).
    features = tf.keras.layers.Reshape((attention_features_shape, -1))(feature_map)

    image_model = tf.keras.Model(pixels, features)
    image_model.compile()
    image_model.summary()

In [None]:
def get_dataset(_):
    """Build the infinite, shuffled, augmented training pipeline.

    The ignored positional argument lets this function serve as the
    ``dataset_fn`` of ``strategy.experimental_distribute_datasets_from_function``.

    Returns:
        A ``tf.data.Dataset`` of ``(image, target)`` batches of size
        ``BATCH_SIZE_PER_REPLICA``. Images are binarized, augmented and resized
        by ``data_augment``; targets are int32 token ids.
    """
    builder = TrainDataset(data_dir=GCS_PATH)
    # Let TFDS prepare the dataset under GCS_PATH before reading it.
    builder.download_and_prepare()
    dataset = builder.as_dataset()['train']

    def preprocess(x):
        """Binarize and augment the image; cast the target to int32."""
        # Binarize: pixels equal to 0 become 1.0, everything else 0.0
        # (source images are presumably two-valued, 0 and 255).
        img = tf.cast(x['image'] == 0, tf.float32)
        target = tf.cast(x['target'], tf.int32)
        return data_augment(img), target

    # Repeat forever, shuffle, augment in parallel, then batch and prefetch.
    dataset = dataset.repeat().shuffle(BUFFER_SIZE).map(preprocess, num_parallel_calls=AUTO)
    dataset = dataset.batch(BATCH_SIZE_PER_REPLICA).prefetch(prefetch)
    return dataset


with strategy.scope():
    if tpu is None:
        # Single device: iterate the plain dataset.
        dataset = get_dataset(0)
    else:
        # One pipeline per input worker, each yielding per-replica batches.
        dataset = strategy.experimental_distribute_datasets_from_function(get_dataset)
    train_iterator = iter(dataset)

In [None]:
%%time
# Pull one batch to check (and, via %%time, time) the input pipeline.
inputs = next(train_iterator)

# Model

In [None]:
class CNN_Encoder(tf.keras.Model):
    """Sequence encoder over image feature vectors.

    Runs one LSTM over the feature vectors produced by ``image_model``. It
    returns the full output sequence (attached as attention memory) and the
    final hidden and cell states (used to initialise the decoder).
    """

    def __init__(self, embedding_dim):
        super(CNN_Encoder, self).__init__()
        # `embedding_dim` is not used by any layer. It is stored so the
        # constructor signature round-trips through get_config().
        self.embedding_dim = embedding_dim
        self.lstm_layer = tf.keras.layers.LSTM(rnn_units,
                                               return_sequences=True,
                                               return_state=True,
                                               recurrent_initializer='glorot_uniform')

    def call(self, x):
        """Encode ``x``, presumably (batch, steps, features).

        Returns:
            (output, h, c): per-step LSTM outputs plus the final hidden and
            cell states.
        """
        output, h, c = self.lstm_layer(x)
        return output, h, c

    def get_config(self):
        # Keras convention: return the constructor arguments, so that
        # `cls.from_config(config)` == `cls(**config)` rebuilds the model.
        # Returning the layer object matched neither __init__ nor JSON.
        return {'embedding_dim': self.embedding_dim}

In [None]:
class RNN_Decoder(tf.keras.Model):
    """Attention LSTM decoder built on tfa.seq2seq.

    Embeds token ids and runs them through an LSTM cell wrapped in an
    AttentionWrapper (Luong attention by default). Each step is projected to
    vocabulary logits by ``fc`` (dropout + dense).

    Before calling: attach the encoder outputs with
    ``self.attention_mechanism.setup_memory(...)`` and create the initial
    state with ``build_initial_state``.
    """

    def __init__(self, embedding_dim, vocab_size):
        super(RNN_Decoder, self).__init__()
        # Constructor arguments, kept for get_config().
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size

        self.dec_units = rnn_units
        self.attention_type = 'luong'

        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)

        # Final projection to raw vocabulary logits (no softmax here).
        self.fc = tf.keras.Sequential([
            tf.keras.layers.Dropout(0.5),
            tf.keras.layers.Dense(vocab_size)
        ])

        # Attention mechanism created without memory; setup_memory() attaches
        # the encoder outputs before each decode.
        memory_sequence_length = None
        self.attention_mechanism = self.build_attention_mechanism(self.dec_units, None, memory_sequence_length, self.attention_type)

        # Base recurrent cell, then wrapped with the attention mechanism.
        self.decoder_rnn_cell = tf.keras.layers.LSTMCell(rnn_units)
        self.rnn_cell = self.build_rnn_cell()

        # Teacher forcing: the training sampler feeds the given (ground-truth)
        # embeddings step by step.
        self.sampler = tfa.seq2seq.sampler.TrainingSampler()

        self.decoder = tfa.seq2seq.BasicDecoder(self.rnn_cell, sampler=self.sampler, output_layer=self.fc, maximum_iterations=max_seq-1)

    def call(self, inputs, initial_state):
        """Teacher-forced decode.

        Args:
            inputs: token ids, presumably (batch, max_seq - 1).
            initial_state: AttentionWrapperState from ``build_initial_state``.

        Returns:
            The BasicDecoder output; ``.rnn_output`` holds the per-step logits.
        """
        x = self.embedding(inputs)
        # Every row is decoded for exactly max_seq - 1 steps.
        sequence_length = tf.repeat(max_seq - 1, tf.shape(x)[0])
        outputs, _, _ = self.decoder(x, initial_state=initial_state,
                                     sequence_length=sequence_length)
        return outputs

    def build_rnn_cell(self):
        """Wrap the base LSTM cell with the attention mechanism."""
        rnn_cell = tfa.seq2seq.AttentionWrapper(self.decoder_rnn_cell,
                                                self.attention_mechanism, attention_layer_size=self.dec_units)
        return rnn_cell

    def build_attention_mechanism(self, dec_units, memory, memory_sequence_length=None, attention_type='luong'):
        """Create the attention mechanism.

        Args:
            dec_units: number of attention units.
            memory: encoder outputs of shape (batch, max_length_input, enc_units), or None.
            memory_sequence_length: optional 1-D tensor of shape (batch,) used for masking.
            attention_type: 'bahdanau' selects Bahdanau attention; anything
                else selects Luong attention.
        """
        if (attention_type == 'bahdanau'):
            return tfa.seq2seq.BahdanauAttention(units=dec_units, memory=memory,
                                                 memory_sequence_length=memory_sequence_length)
        else:
            return tfa.seq2seq.LuongAttention(units=dec_units, memory=memory,
                                              memory_sequence_length=memory_sequence_length)

    def build_initial_state(self, batch_sz, encoder_state):
        """Zero AttentionWrapperState whose cell state is ``encoder_state`` ([h, c])."""
        decoder_initial_state = self.rnn_cell.get_initial_state(batch_size=batch_sz, dtype=tf.float32)
        decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state)
        return decoder_initial_state

    def reset_state(self, batch_size):
        """Return a zero tensor of shape (batch_size, rnn_units)."""
        return tf.zeros((batch_size, rnn_units))

    def get_config(self):
        # Keras convention: the constructor arguments. The previous version
        # read attributes that are never set (units, rnn, attention) and so
        # always raised AttributeError.
        return {'embedding_dim': self.embedding_dim, 'vocab_size': self.vocab_size}


In [None]:
with strategy.scope():
    optimizer = tf.keras.optimizers.Adam(0.0025)
    # Per-token losses (no reduction) so that padding can be masked below.
    crossentropy = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    def loss_function(real, pred):
        """Masked cross-entropy, scaled for distributed training.

        Positions where ``real`` is 0 (padding) contribute nothing. The masked
        per-token losses are summed and divided by GLOBAL_BATCH_SIZE via
        ``tf.nn.compute_average_loss``.
        """
        per_token = crossentropy(real, pred)
        not_pad = tf.cast(tf.not_equal(real, 0), dtype=per_token.dtype)
        return tf.nn.compute_average_loss(per_token * not_pad, global_batch_size=GLOBAL_BATCH_SIZE)

### Create Encoder

In [None]:
with strategy.scope():
    # Build the encoder by pushing a dummy batch through image_model -> encoder.
    bs = 2
    encoder = CNN_Encoder(embedding_dim,)
    dummy_images = tf.zeros((bs, WIDTH, HEIGHT, 1), tf.float32)
    sample_output, sample_h, sample_c = encoder(image_model(dummy_images))
    print("Encoder Outputs Shape: ", sample_output.shape)

### Create Decoder

In [None]:
with strategy.scope():
    decoder = RNN_Decoder(embedding_dim, vocab_size)
    # Dummy token ids shaped like the training decoder input
    # (target[:, :-1] -> max_seq - 1 steps). A float tensor in [0, 1) would
    # presumably only work because Embedding casts it to ids that are all 0.
    sample_x = tf.random.uniform((bs, max_seq - 1), maxval=vocab_size, dtype=tf.int32)
    decoder.attention_mechanism.setup_memory(sample_output)
    initial_state = decoder.build_initial_state(bs, [sample_h, sample_c],)

    sample_decoder_outputs = decoder(sample_x, initial_state)

    print("Decoder Outputs Shape: ", sample_decoder_outputs.rnn_output.shape)


# Checkpoint

In [None]:
# Base name for the encoder/decoder weight files written by save_model().
checkpoint_name = "./model"
# Copy loss history (*.dill) and weight files (*.h5) from the input dataset
# into the working directory -- presumably the output of a previous run.
!cp -r ../input/mt-fast-distributed-training-tpu/*.dill .
!cp -r ../input/mt-fast-distributed-training-tpu/*.h5 .

In [None]:
def _weight_files(name):
    """Pair each global model with the .h5 file holding its weights."""
    return ((encoder, f'{name}_encoder.h5'),
            (decoder, f'{name}_decoder.h5'),
            (image_model, 'image_model.h5'))


def save_model(name=checkpoint_name):
    """Save encoder/decoder weights to ``{name}_encoder.h5``/``{name}_decoder.h5``
    and the image model's weights to the fixed file ``image_model.h5``."""
    for model, path in _weight_files(name):
        model.save_weights(path, options=save_locally)


def load_model(name=checkpoint_name):
    """Load the files written by ``save_model`` into the global models.

    Returns:
        The (encoder, decoder, image_model) globals, now holding the weights.
    """
    for model, path in _weight_files(name):
        model.load_weights(path, options=load_locally)
    return encoder, decoder, image_model

In [None]:
with strategy.scope():
    # Access the filesystem from the local host for weight I/O.
    load_locally = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
    save_locally = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
    # Restore the previously saved weights into the freshly built models.
    encoder, decoder, image_model = load_model()

In [None]:
@tf.function(experimental_relax_shapes=True)
def train_step(inputs):
    """Run one optimisation step on this replica's batch.

    Args:
        inputs: ``(images, target)`` from the training pipeline; ``target``
            holds token ids whose first position is ``<start>``.

    Returns:
        This replica's masked loss, already divided by GLOBAL_BATCH_SIZE
        (see ``loss_function``).
    """
    img_tensor, target = inputs
    batch_sz = tf.shape(img_tensor)[0]

    # Frozen backbone: extract features outside the tape.
    if not TRAIN_IMAGE_MODEL:
        img_tensor = image_model(img_tensor, training=False)

    with tf.GradientTape() as tape:
        if TRAIN_IMAGE_MODEL:
            img_tensor = image_model(img_tensor, training=True)

        enc_output, enc_h, enc_c = encoder(img_tensor)

        # Teacher forcing: feed target[:-1] (drop <end>), predict target[1:] (drop <start>).
        dec_input = target[:, :-1]
        real = target[:, 1:]

        # Attach encoder outputs as attention memory and start from the encoder state.
        decoder.attention_mechanism.setup_memory(enc_output)
        decoder_initial_state = decoder.build_initial_state(batch_sz, [enc_h, enc_c])
        logits = decoder(dec_input, decoder_initial_state).rnn_output
        loss = loss_function(real, logits)

    variables = encoder.trainable_variables + decoder.trainable_variables
    if TRAIN_IMAGE_MODEL:
        # The backbone runs in training mode inside the tape, so it must also
        # be updated; it was previously left out of the optimised variables.
        variables += image_model.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    return loss


@tf.function(experimental_relax_shapes=True)
def distributed_train_step(inputs,):
    """Run ``train_step`` on every replica and return the global-batch loss.

    Each replica's loss is already divided by GLOBAL_BATCH_SIZE, so summing
    the replicas gives the average over the global batch. ReduceOp.MEAN would
    under-report it by a factor of REPLICAS. NOTE: losses recorded by earlier
    runs that used MEAN are therefore smaller by that factor.
    """
    per_replica_losses = strategy.run(train_step, args=(inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

# Training

In [None]:
# Resume the per-epoch loss history if a previous run's file is present.
loss_plot = []
if os.path.exists('losses.dill'):
    # `with` closes the file deterministically.
    # NOTE: dill (like pickle) can execute code on load -- only load files
    # you produced yourself.
    with open('losses.dill', 'rb') as fh:
        loss_plot = dill.load(fh)

In [None]:
EPOCHS = 10
# Lowest mean epoch loss seen so far. Start at +inf so the first epoch is
# saved; starting at 0 meant no epoch could ever beat it.
best_loss = float('inf')
for epoch in range(EPOCHS):
    # Separate name so the global `start` token string is not overwritten.
    epoch_start = time.time()
    total_loss = 0.0
    print(f'Epoch {epoch + 1}/{EPOCHS}')
    for batch in range(train_steps):
        inputs = next(train_iterator)
        loss = distributed_train_step(inputs).numpy()
        total_loss += loss
        # Progress line, rewritten in place after every batch.
        print(f'{batch+1}/{train_steps} Total Loss: {total_loss/(batch+1):.4f}  Batch Loss: {loss:.4f}', end='\r')

    epoch_loss = total_loss / train_steps
    # storing the epoch end loss value to plot later
    loss_plot.append(epoch_loss)
    print()
    if epoch_loss <= best_loss:
        print("Saving...")
        save_model()
        best_loss = epoch_loss

    print(f'Epoch {epoch + 1} Loss {epoch_loss:.6f}')
    print(f'Time taken for 1 epoch {time.time() - epoch_start} sec\n')

In [None]:
# Save the final weights unconditionally (the loop only saves improving epochs).
save_model()

In [None]:
# Per-epoch mean training loss (including any history loaded from losses.dill).
fig, ax = plt.subplots()
ax.plot(loss_plot)
ax.set(xlabel='Epochs', ylabel='Loss', title='Loss Plot')
plt.show()

In [None]:
# Persist the loss history so a later run can resume it; `with` flushes and
# closes the file instead of leaving that to garbage collection.
with open('losses.dill', 'wb') as fh:
    dill.dump(loss_plot, fh)

# Testing

In [None]:
def get_test_dataset(_):
    """Build the (non-repeating, unshuffled) test pipeline.

    The ignored argument is the input context passed by
    ``experimental_distribute_datasets_from_function``.

    Returns:
        Batches of ``(image, image_id)``: images binarized, possibly rotated
        by 90 degrees, and resized to (WIDTH, HEIGHT).
    """
    builder = TestDataset(data_dir=GCS_PATH)
    # Let TFDS prepare the dataset under GCS_PATH before reading it.
    builder.download_and_prepare()
    split = builder.as_dataset()['test']

    def to_model_input(x):
        # Binarize: 0-valued pixels -> 1.0, everything else -> 0.0.
        img = tf.cast(x['image'] == 0, tf.float32)
        dims = tf.shape(img)
        n_rows, n_cols = dims[0], dims[1]
        # NOTE(review): this rotates images that are WIDER than they are tall
        # (n_cols > n_rows). Confirm that matches the orientation of the
        # training images.
        if n_cols > n_rows:
            # transpose followed by a vertical flip == a 90-degree rotation
            img = tf.image.flip_up_down(tf.image.transpose(img))
        img = tf.image.resize(img, (WIDTH, HEIGHT))
        return img, x['image_id']

    return (split
            .map(to_model_input, num_parallel_calls=AUTO)
            .batch(BATCH_SIZE_PER_REPLICA)
            .prefetch(prefetch))


with strategy.scope():
    test_dataset = strategy.experimental_distribute_datasets_from_function(get_test_dataset)

In [None]:
@tf.function(experimental_relax_shapes=True)
def greedy_predict(img_tensor):
    """Greedy-decode token ids for a batch of preprocessed images.

    Args:
        img_tensor: image batch as produced by the test pipeline,
            presumably (batch, WIDTH, HEIGHT, 1).

    Returns:
        The decoder's ``sample_id`` tensor of predicted token ids.
    """
    # Dynamic batch size: once tf.function relaxes shapes (e.g. for the
    # smaller final batch) the static shape[0] is None, which tf.fill rejects.
    inference_batch_size = tf.shape(img_tensor)[0]
    img_tensor = image_model(img_tensor, training=False)
    enc_out, enc_h, enc_c = encoder(img_tensor)

    start_tokens = tf.fill([inference_batch_size], start_index)
    end_token = end_index

    # BasicDecoder wraps only the decoder's attention cell; the greedy
    # embedding sampler does the embedding lookup, so it is given the
    # decoder's embedding matrix at call time.
    greedy_sampler = tfa.seq2seq.GreedyEmbeddingSampler()
    decoder_instance = tfa.seq2seq.BasicDecoder(cell=decoder.rnn_cell, sampler=greedy_sampler,
                                                output_layer=decoder.fc, maximum_iterations=max_seq)

    # Attach the encoder outputs as attention memory; start from the encoder state.
    decoder.attention_mechanism.setup_memory(enc_out)
    decoder_initial_state = decoder.build_initial_state(inference_batch_size, [enc_h, enc_c])

    decoder_embedding_matrix = decoder.embedding.variables[0]
    outputs, _, _ = decoder_instance(decoder_embedding_matrix, start_tokens=start_tokens,
                                     end_token=end_token, initial_state=decoder_initial_state)
    return outputs.sample_id

In [None]:
@tf.function(experimental_relax_shapes=True)
def beam_predict(img_tensor, beam_width=3):
    """Beam-search decode a batch of preprocessed images.

    Args:
        img_tensor: image batch as produced by the test pipeline.
        beam_width: number of beams kept per image.

    Returns:
        ``(final_outputs, beam_scores)``: predicted ids and beam scores, both
        transposed to (batch, beam_width, time_steps).
    """
    # Dynamic batch size (static shape[0] may be None after shape relaxation).
    inference_batch_size = tf.shape(img_tensor)[0]
    img_tensor = image_model(img_tensor, training=False)
    enc_out, enc_h, enc_c = encoder(img_tensor)

    start_tokens = tf.fill([inference_batch_size], start_index)
    end_token = end_index

    # Requirements for BeamSearchDecoder with an AttentionWrapper cell (TFA docs):
    #  * the encoder output is tiled to beam_width with tfa.seq2seq.tile_batch (NOT tf.tile);
    #  * get_initial_state gets batch_size == true_batch_size * beam_width;
    #  * the initial cell_state holds the tiled final encoder state.
    enc_out = tfa.seq2seq.tile_batch(enc_out, multiplier=beam_width)
    decoder.attention_mechanism.setup_memory(enc_out)

    hidden_state = tfa.seq2seq.tile_batch([enc_h, enc_c], multiplier=beam_width)
    decoder_initial_state = decoder.rnn_cell.get_initial_state(
        batch_size=beam_width * inference_batch_size, dtype=tf.float32)
    decoder_initial_state = decoder_initial_state.clone(cell_state=hidden_state)

    decoder_instance = tfa.seq2seq.BeamSearchDecoder(decoder.rnn_cell, beam_width=beam_width,
                                                     output_layer=decoder.fc, maximum_iterations=max_seq)
    decoder_embedding_matrix = decoder.embedding.variables[0]
    outputs, _, _ = decoder_instance(decoder_embedding_matrix, start_tokens=start_tokens,
                                     end_token=end_token, initial_state=decoder_initial_state)

    # predicted_ids and scores are (batch, time, beam); move beam before time.
    final_outputs = tf.transpose(outputs.predicted_ids, perm=(0, 2, 1))
    beam_scores = tf.transpose(outputs.beam_search_decoder_output.scores, perm=(0, 2, 1))
    return final_outputs, beam_scores

In [None]:
@tf.function(experimental_relax_shapes=True)
def distributed_greedy_predict(inputs):
    """Greedy-decode a distributed batch; returns per-replica sample ids."""
    replica_args = (inputs,)
    return strategy.run(greedy_predict, args=replica_args)


@tf.function(experimental_relax_shapes=True)
def distributed_beam_predict(inputs):
    """Beam-decode a distributed batch; returns per-replica (ids, scores)."""
    replica_args = (inputs,)
    return strategy.run(beam_predict, args=replica_args)

In [None]:
def process_predict(x):
    """Turn a batch of decoded token ids into InChI strings.

    Args:
        x: tensor (anything with ``.numpy()``) of token ids, one row per image.

    Returns:
        A list with one ``'InChI=1S/<caption>'`` string per row. Each row is
        cut at its first ``end_index`` token; padding (0), ``start_index`` and
        ``end_index`` tokens are dropped before detokenizing.
    """
    def _strip(row):
        # Truncate at the first end-of-sequence token, if any.
        end_positions = np.flatnonzero(row == end_index)
        if end_positions.size:
            row = row[:end_positions[0]]
        # Drop padding and any remaining special tokens.
        keep = (row != 0) & (row != start_index) & (row != end_index)
        return row[keep]

    token_rows = [_strip(row) for row in x.numpy()]
    return [f'InChI=1S/{c}' for c in tokenizer.detokenize(token_rows)]

# Submit

In [None]:
%%skip not SUBMIT
results = []
images = []
# Set when a distributed batch has an empty replica shard. That final batch
# is then decoded shard by shard after the loop.
pending_last_batch = False

# 1616107 is presumably the number of test images (used only for the progress bar).
for (image, iids) in tqdm(test_dataset, total=1616107//GLOBAL_BATCH_SIZE):
    if isinstance(image, value_lib.PerReplica):
        # Batches with an empty replica shard are not sent through strategy.run.
        if any(v.shape[0] == 0 for v in image.values):
            print('Stopping...')
            pending_last_batch = True
            break
        res = distributed_greedy_predict(image)
        res = tf.concat(strategy.unwrap(res), axis=0)
        iids = tf.concat(strategy.unwrap(iids), axis=0)
    else:
        res = greedy_predict(image)

    results.extend(process_predict(res))
    images.extend([s.decode('utf-8') for s in iids.numpy()])

# Decode the non-empty shards of the batch that stopped the loop. Only done
# after a `break`: if the loop finished normally every batch was already
# processed, and repeating the last one would duplicate rows.
if pending_last_batch:
    for image_value, iids_value in zip(image.values, iids.values):
        if image_value.shape[0] == 0:
            continue
        results.extend(process_predict(greedy_predict(image_value)))
        images.extend([s.decode('utf-8') for s in iids_value.numpy()])

In [None]:
%%skip not SUBMIT
sample_submission_path = '../input/bms-molecular-translation/sample_submission.csv'
submission_df = pd.read_csv(sample_submission_path, index_col=0)
# Fill the predicted InChI strings by image_id (the index), then write the file.
submission_df.loc[images, 'InChI'] = results
submission_df.to_csv('submission.csv')