In [None]:
%%capture
!pip install timm
!pip install python-Levenshtein

In [None]:
import os
import time

import pandas as pd
import numpy as np

import timm

import cv2

from tqdm import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.data import Dataset, DataLoader
import torch.backends.cudnn as cudnn

from sklearn.model_selection import train_test_split

from Levenshtein import distance as levenshtein_distance

In [None]:
# Root directory of the training images; BMSDataset nests each file under the
# first three characters of its image id.
BASE_DIR = '../input/bms-molecular-translation/train'

In [None]:
# Training labels; the columns used below are 'image_id' and the target 'InChI' string.
label_df = pd.read_csv("../input/bms-molecular-translation/train_labels.csv")
label_df.head()

In [None]:
# (rows, columns) of the full label table
label_df.shape

In [None]:
# Use only the first 100,000 labelled images (split below into 90k train / 10k validation)
label_df = label_df.iloc[:100000]

In [None]:
# Train on 90k, validate on 10k. With shuffle=False the validation set is simply
# the last 10% of the rows kept above.
train_df, test_df = train_test_split(label_df, test_size=0.1, shuffle=False)
train_df = train_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

In [None]:
# Expected: 90,000 train rows and 10,000 validation rows (given at least 100k labels)
train_df.shape, test_df.shape

## Build Vocabulary

InChI strings use only a limited set of characters; see https://www.kaggle.com/c/bms-molecular-translation/discussion/223471 for details. We therefore build a character-level vocabulary class, which is used to preprocess the targets.

In [None]:
class Vocabulary:
    """
    Character-level vocabulary mapping InChI characters to integer ids.

    Ids 0-3 are reserved for <pad>, <sos>, <eos> and <unk>. Every character that
    build_vocabulary sees at least `freq_threshold` times gets the next free id.
    Characters that never reach the threshold are numericalized as <unk>.
    """

    def __init__(self, freq_threshold=2, reverse=False):
        """
        :param freq_threshold: minimum number of occurrences before a character gets its own id
        :param reverse: if True, tokenize() returns the characters in reverse order
        """
        self.itos = {0: "<pad>", 1: "<sos>", 2: "<eos>", 3: "<unk>"}
        self.stoi = {"<pad>": 0, "<sos>": 1, "<eos>": 2, "<unk>": 3}
        self.freq_threshold = freq_threshold
        self.reverse = reverse
        self.tokenizer = self._tokenizer

    def __len__(self):
        return len(self.itos)

    def _tokenizer(self, text):
        """Lazily split text into single characters."""
        return (char for char in text)

    def tokenize(self, text):
        """Return the list of character tokens of `text`, reversed if self.reverse is set."""
        tokens = list(self.tokenizer(text))
        return tokens[::-1] if self.reverse else tokens

    def build_vocabulary(self, sentence_list):
        """
        Count character frequencies over `sentence_list` and register every character
        seen at least freq_threshold times. Ids are assigned in the order in which
        characters first reach the threshold.

        Characters already in the vocabulary keep their id. Calling this more than once
        therefore never creates duplicate ids. Counts are not carried over between calls.

        :param sentence_list: iterable of strings (InChI labels)
        """
        frequencies = {}
        for sentence in sentence_list:
            for char in self.tokenize(sentence):
                frequencies[char] = frequencies.get(char, 0) + 1
                # `>=` also handles freq_threshold <= 0. The membership check prevents
                # re-registering a character (with a new id) on a repeated call.
                if frequencies[char] >= self.freq_threshold and char not in self.stoi:
                    idx = len(self.itos)
                    self.stoi[char] = idx
                    self.itos[idx] = char

    def numericalize(self, text):
        """
        Convert text to a list of ids. Characters outside the vocabulary become <unk>.

        :param text: string to encode
        :return: list of int ids, one per character (no <sos>/<eos> added)
        """
        unk_idx = self.stoi["<unk>"]
        return [self.stoi.get(token, unk_idx) for token in self.tokenize(text)]

In [None]:
%%time
# Build the character vocabulary from the training split only. Characters seen fewer
# than freq_threshold times (or only in validation/test) are encoded as <unk>.
freq_threshold = 2
vocab = Vocabulary(freq_threshold=freq_threshold, reverse=False)

# build vocab
vocab.build_vocabulary(train_df['InChI'].to_list())

In [None]:
# Side length (pixels) that every train/validation/test image is resized to
IMG_SIZE = 128

In [None]:
class BMSDataset(Dataset):
    """
    Training/validation dataset yielding (image, token ids, token count) per row of `df`.

    Relies on the notebook-level `vocab`, `BASE_DIR` and `IMG_SIZE`.
    """
    def __init__(self, df):
        """
        :param df: DataFrame with 'image_id' and 'InChI' columns
        """
        super().__init__()
        self.df = df
        self.vocab = vocab

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """
        :return: (float32 (IMG_SIZE, IMG_SIZE, 3) RGB array scaled by 1/255,
                  LongTensor of ids wrapped in <sos>/<eos>,
                  0-d tensor holding the number of ids)
        """
        row = self.df.iloc[idx]
        img_id = row['image_id']
        label = row['InChI']
        # Images are nested under the first three characters of their id
        img_path = os.path.join(BASE_DIR, img_id[0], img_id[1], img_id[2], f'{img_id}.png')

        img = self._load_from_file(img_path)

        # Convert label to numbers; the length is taken from the ids themselves so it
        # always matches the tensor (includes <sos> and <eos>).
        token_ids = self._get_numericalized(label, self.vocab)
        return img, torch.tensor(token_ids), torch.tensor(len(token_ids))

    def _get_numericalized(self, sentence, vocab):
        """Numericalize given text using prebuilt vocab, wrapped in <sos> ... <eos>."""
        numericalized = [vocab.stoi["<sos>"]]
        numericalized.extend(vocab.numericalize(sentence))
        numericalized.append(vocab.stoi["<eos>"])
        return numericalized

    def _load_from_file(self, img_path):
        """Read an image as RGB float32, resize to IMG_SIZE x IMG_SIZE and scale by 1/255."""
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread reports unreadable/missing files by returning None instead of raising
            raise FileNotFoundError(f'Could not read image: {img_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
        image /= 255.0  # Normalize
        return image

In [None]:
def bms_collate(batch):
    """
    Collate BMSDataset samples into batch tensors.

    :param batch: list of (HWC float32 image array, token id tensor, 0-d length tensor)
    :return: images (B, 3, H, W), token ids right-padded with <pad> to (B, max_len),
             lengths (B, 1)
    """
    # HWC numpy arrays -> CHW tensors (from_numpy shares memory with the array)
    image_tensors = [torch.from_numpy(sample[0]).permute(2, 0, 1) for sample in batch]
    token_seqs = [sample[1] for sample in batch]
    lengths = [sample[2] for sample in batch]

    padded_tokens = pad_sequence(token_seqs, batch_first=True, padding_value=vocab.stoi["<pad>"])
    length_column = torch.stack(lengths).reshape(-1, 1)
    return torch.stack(image_tensors), padded_tokens, length_column
    

train_dataset = BMSDataset(train_df)

# Shuffled training batches; bms_collate pads the variable-length token sequences.
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    pin_memory=True,
    num_workers=4,
    shuffle=True,
    collate_fn=bms_collate
)

# Validation keeps row order (shuffle=False) and uses a larger batch (no gradients needed).
val_dataset = BMSDataset(test_df)
val_loader = DataLoader(
    val_dataset,
    batch_size=128,
    pin_memory=True,
    num_workers=4,
    shuffle=False,
    collate_fn=bms_collate
)

In [None]:
# Use CUDA when available; later cells move models and batches to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Available device: {device}")

## Model Architecture

This code is heavily based on this great tutorial: https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning

In [None]:
class Encoder(nn.Module):
    """
    Image encoder: a CNN backbone with its last two children removed, followed by
    adaptive average pooling to a fixed encoded_image_size x encoded_image_size grid.
    """

    def __init__(self, back_bone, encoded_image_size=8):
        """
        :param back_bone: CNN whose last two children (presumably global pool and
                          classifier head) are discarded
        :param encoded_image_size: side length of the pooled feature grid
        """
        super().__init__()
        self.enc_image_size = encoded_image_size

        # Keep only the convolutional trunk of the backbone
        children = list(back_bone.children())
        self.back_bone = nn.Sequential(*children[:-2])

        # Fixed output grid regardless of the input image size
        self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))

        self.fine_tune(True)

    def forward(self, images):
        """
        :param images: (batch_size, 3, H, W) image tensor
        :return: channels-last features, (batch_size, encoded_image_size, encoded_image_size, C);
                 C is the trunk's output channels (the decoder defaults to encoder_dim=2048)
        """
        pooled = self.adaptive_pool(self.back_bone(images))  # (batch_size, C, enc_size, enc_size)
        return pooled.permute(0, 2, 3, 1)

    def fine_tune(self, fine_tune=True):
        """
        Freeze the whole trunk, then set requires_grad=fine_tune on trunk children from
        index 5 onward (for a ResNet presumably layer2-layer4 — confirm for other backbones).
        :param fine_tune: whether those later children should receive gradients
        """
        for param in self.back_bone.parameters():
            param.requires_grad = False
        trainable_children = list(self.back_bone.children())[5:]
        for child in trainable_children:
            for param in child.parameters():
                param.requires_grad = fine_tune

In [None]:
class Attention(nn.Module):
    """
    Additive attention over encoder pixels, conditioned on the decoder hidden state.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        """
        :param encoder_dim: feature size of each encoded pixel
        :param decoder_dim: hidden size of the decoder LSTM
        :param attention_dim: size of the shared attention projection
        """
        super().__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)  # projects encoder features
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)  # projects decoder hidden state
        self.full_att = nn.Linear(attention_dim, 1)  # one attention logit per pixel
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)  # normalise logits over the pixel axis

    def forward(self, encoder_out, decoder_hidden):
        """
        :param encoder_out: (batch_size, num_pixels, encoder_dim)
        :param decoder_hidden: (batch_size, decoder_dim)
        :return: context vector (batch_size, encoder_dim) and weights (batch_size, num_pixels)
        """
        projected_pixels = self.encoder_att(encoder_out)  # (batch_size, num_pixels, attention_dim)
        projected_state = self.decoder_att(decoder_hidden).unsqueeze(1)  # (batch_size, 1, attention_dim)
        logits = self.full_att(self.relu(projected_pixels + projected_state)).squeeze(2)
        weights = self.softmax(logits)  # (batch_size, num_pixels)
        context = (encoder_out * weights.unsqueeze(2)).sum(dim=1)  # (batch_size, encoder_dim)
        return context, weights

In [None]:
class DecoderWithAttention(nn.Module):
    """
    LSTM decoder that attends over the encoded image at every step (teacher-forced in forward).
    """

    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.4):
        """
        :param attention_dim: size of attention network
        :param embed_dim: embedding size
        :param decoder_dim: size of decoder's RNN
        :param vocab_size: size of vocabulary
        :param encoder_dim: feature size of encoded images
        :param dropout: dropout probability applied to the hidden state before the output layer
        """
        super().__init__()

        self.encoder_dim = encoder_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size

        # Layer creation order is kept as-is so seeded random initialisation is unchanged.
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)  # attention network

        self.embedding = nn.Embedding(vocab_size, embed_dim)  # embedding layer
        self.dropout = nn.Dropout(p=dropout)
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)  # decoding LSTMCell
        self.init_h = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find initial hidden state of LSTMCell
        self.init_c = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find initial cell state of LSTMCell
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)  # linear layer to create a sigmoid-activated gate
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)  # linear layer to find scores over vocabulary
        self.init_weights()  # initialize some layers with the uniform distribution

    def init_weights(self):
        """
        Initializes some parameters with values from the uniform distribution, for easier convergence.
        """
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def load_pretrained_embeddings(self, embeddings):
        """
        Loads embedding layer with pre-trained embeddings.
        :param embeddings: pre-trained embeddings
        """
        self.embedding.weight = nn.Parameter(embeddings)

    def fine_tune_embeddings(self, fine_tune=True):
        """
        Allow fine-tuning of embedding layer? (Only makes sense to not-allow if using pre-trained embeddings).
        :param fine_tune: Allow?
        """
        for p in self.embedding.parameters():
            p.requires_grad = fine_tune

    def init_hidden_state(self, encoder_out):
        """
        Creates the initial hidden and cell states for the decoder's LSTM from the mean encoder feature.
        :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
        :return: hidden state, cell state, each (batch_size, decoder_dim)
        """
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """
        Teacher-forced forward propagation.
        :param encoder_out: encoded images, (batch_size, enc_image_size, enc_image_size, encoder_dim)
        :param encoded_captions: token ids, (batch_size, max_caption_length)
        :param caption_lengths: caption lengths including <sos>/<eos>, (batch_size, 1)
        :return: scores (batch_size, max_decode_length, vocab_size), captions sorted by
                 decreasing length, decode lengths (list of int), attention weights
                 (batch_size, max_decode_length, num_pixels), sort indices
        """
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)

        # Flatten image grid into pixels
        encoder_out = encoder_out.reshape(batch_size, -1, encoder_dim)  # (batch_size, num_pixels, encoder_dim)
        num_pixels = encoder_out.size(1)

        # Sort by decreasing length so the sequences still being decoded at step t are
        # always the first batch_size_t rows; each step then works on a prefix slice.
        caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]

        embeddings = self.embedding(encoded_captions)  # (batch_size, max_caption_length, embed_dim)

        h, c = self.init_hidden_state(encoder_out)  # (batch_size, decoder_dim)

        # No prediction is made from the <eos> position, so decode lengths are lengths - 1
        decode_lengths = (caption_lengths - 1).tolist()
        max_decode_length = max(decode_lengths)

        # Allocate outputs on the inputs' device rather than a notebook-level global
        out_device = encoder_out.device
        predictions = torch.zeros(batch_size, max_decode_length, self.vocab_size, device=out_device)
        alphas = torch.zeros(batch_size, max_decode_length, num_pixels, device=out_device)

        for t in range(max_decode_length):
            batch_size_t = sum(l > t for l in decode_lengths)
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
                                                                h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  # gating scalar, (batch_size_t, encoder_dim)
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t]))  # (batch_size_t, decoder_dim)
            preds = self.fc(self.dropout(h))  # (batch_size_t, vocab_size)
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha

        return predictions, encoded_captions, decode_lengths, alphas, sort_ind

## Training

In [None]:
class AverageMeter:
    """
    Running tracker of a metric: last value, weighted sum, total weight and weighted mean.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out every statistic."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """
        Record `val` with weight `n` (e.g. a per-token loss weighted by the token count).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [None]:
def accuracy(scores, targets, k):
    """
    Top-k accuracy in percent.

    :param scores: (N, num_classes) model scores
    :param targets: (N,) true class indices
    :param k: how many top-scoring classes count as a hit
    :return: percentage of rows whose target is among the k highest scores
    """
    n_rows = targets.size(0)
    topk_indices = scores.topk(k, dim=1, largest=True, sorted=True).indices
    hits = topk_indices.eq(targets.view(-1, 1).expand_as(topk_indices))
    n_hits = hits.view(-1).float().sum()  # 0-d tensor
    return n_hits.item() * (100.0 / n_rows)

In [None]:
def clip_gradient(optimizer, grad_clip):
    """
    Clamp, in place, every gradient of the optimizer's parameters to [-grad_clip, grad_clip].
    Parameters without a gradient are skipped.

    :param optimizer: optimizer whose parameter gradients are clamped
    :param grad_clip: absolute clamp value
    """
    gradients = (
        param.grad
        for group in optimizer.param_groups
        for param in group['params']
        if param.grad is not None
    )
    for grad in gradients:
        grad.data.clamp_(-grad_clip, grad_clip)

In [None]:
def train_epoch(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epoch):
    """
    Runs one teacher-forced training epoch.

    Reads the notebook-level `device`, `alpha_c` and `grad_clip` settings.

    :param train_loader: DataLoader yielding (images, padded token ids, lengths)
    :param encoder: encoder model
    :param decoder: decoder model
    :param criterion: loss layer applied to the packed (non-padded) predictions
    :param encoder_optimizer: optimizer for the encoder, or None when it is frozen
    :param decoder_optimizer: optimizer for the decoder
    :param epoch: epoch number (not used inside the function)
    :return: (losses, top5accs, batch_time) AverageMeters; losses and top5accs are
             weighted by the number of decoded tokens in each batch
    """
    batch_time = AverageMeter()  # seconds per batch (forward + backward + step)
    losses = AverageMeter()  # loss per decoded token
    top5accs = AverageMeter()  # top-5 accuracy per decoded token

    # Decoder first, then encoder (if any): this order is used for zero_grad, clipping and step.
    active_optimizers = [decoder_optimizer]
    if encoder_optimizer is not None:
        active_optimizers.append(encoder_optimizer)

    decoder.train()
    encoder.train()
    tick = time.time()

    batches = tqdm(enumerate(train_loader), total=len(train_loader), position=0, leave=True)
    for _, (imgs, caps, caplens) in batches:
        imgs, caps, caplens = imgs.to(device), caps.to(device), caplens.to(device)

        features = encoder(imgs)
        scores, caps_sorted, decode_lengths, alphas, _ = decoder(features, caps, caplens)

        # Predictions start after <sos>; packing drops time steps past each sequence's length.
        targets = caps_sorted[:, 1:]
        scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
        targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

        loss = criterion(scores, targets)
        # Doubly stochastic attention regularization
        loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

        for optimizer in active_optimizers:
            optimizer.zero_grad()
        loss.backward()

        if grad_clip is not None:
            for optimizer in active_optimizers:
                clip_gradient(optimizer, grad_clip)

        for optimizer in active_optimizers:
            optimizer.step()

        n_tokens = sum(decode_lengths)
        top5accs.update(accuracy(scores, targets, 5), n_tokens)
        losses.update(loss.item(), n_tokens)
        batch_time.update(time.time() - tick)
        tick = time.time()

    return losses, top5accs, batch_time

In [None]:
def validate_epoch(val_loader, encoder, decoder, criterion):
    """
    Runs one teacher-forced validation epoch without gradients.

    Reads the notebook-level `device` and `alpha_c` settings.

    :param val_loader: DataLoader yielding (images, padded token ids, lengths)
    :param encoder: encoder model, or None if the loader already yields encoded images
    :param decoder: decoder model
    :param criterion: loss layer applied to the packed (non-padded) predictions
    :return: (losses, top5accs, batch_time) AverageMeters; losses and top5accs are
             weighted by the number of decoded tokens in each batch
    """
    decoder.eval()  # eval mode (no dropout or batchnorm updates)
    if encoder is not None:
        encoder.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    top5accs = AverageMeter()

    start = time.time()

    with torch.no_grad():
        for imgs, caps, caplens in tqdm(val_loader, total=len(val_loader), position=0, leave=True):
            imgs = imgs.to(device)
            caps = caps.to(device)
            caplens = caplens.to(device)

            if encoder is not None:
                imgs = encoder(imgs)
            scores, caps_sorted, decode_lengths, alphas, _ = decoder(imgs, caps, caplens)

            # Targets are everything after <sos>; packing drops padded time steps.
            # (No copy of the unpacked scores is kept: nothing downstream used it.)
            targets = caps_sorted[:, 1:]
            scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
            targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

            loss = criterion(scores, targets)
            # Same doubly stochastic attention term as in training, so losses are comparable
            loss += alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

            n_tokens = sum(decode_lengths)
            losses.update(loss.item(), n_tokens)
            top5accs.update(accuracy(scores, targets, 5), n_tokens)
            batch_time.update(time.time() - start)
            start = time.time()

    return losses, top5accs, batch_time

In [None]:
# Model parameters
emb_dim = 512  # dimension of character embeddings
attention_dim = 512  # dimension of attention linear layers
decoder_dim = 512  # dimension of decoder RNN
dropout = 0.4  # dropout on the decoder hidden state before the output layer
cudnn.benchmark = True  # worthwhile here: every image is resized to IMG_SIZE, so input shapes are fixed

# Training parameters
start_epoch = 0
epochs = 5  # number of epochs to train for (if early stopping is not triggered)
epochs_since_improvement = 0  # epochs since the validation score last improved
workers = 4  # NOTE(review): unused in the visible cells; the DataLoaders hard-code num_workers=4
encoder_lr = 1e-3  # learning rate for encoder if fine-tuning
decoder_lr = 4e-3  # learning rate for decoder
grad_clip = 5.  # clamp each gradient element to [-grad_clip, grad_clip]; None disables clipping
alpha_c = 1.  # regularization parameter for 'doubly stochastic attention', as in the paper
best_bleu4 = 0.  # NOTE(review): leftover from the captioning tutorial; not read by the visible cells
fine_tune_encoder = False  # fine-tune encoder?

In [None]:
# Decoder: LSTM with attention, predicting over the character vocabulary built above
decoder = DecoderWithAttention(attention_dim=attention_dim,
                               embed_dim=emb_dim,
                               decoder_dim=decoder_dim,
                               vocab_size=len(vocab),
                               dropout=dropout)
decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.parameters()),
                                     lr=decoder_lr)
# Encoder: pretrained ResNet-101 from timm. fine_tune(False) freezes the whole backbone,
# so an encoder optimizer is only created when fine_tune_encoder is True.
encoder = Encoder(timm.create_model('resnet101', pretrained=True))
encoder.fine_tune(fine_tune_encoder)
encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
                                     lr=encoder_lr) if fine_tune_encoder else None

In [None]:
# Move both models to the selected device
decoder = decoder.to(device)
encoder = encoder.to(device)

In [None]:
# Loss function: cross-entropy over the packed (non-padded) per-character predictions
criterion = nn.CrossEntropyLoss().to(device)

In [None]:
def save_checkpoint(epoch, encoder, decoder, encoder_optimizer, decoder_optimizer, filename=None):
    """
    Saves a model checkpoint with torch.save.

    The whole encoder/decoder modules and optimizers are pickled (not only their
    state_dicts); the inference section reads checkpoint['encoder'] and
    checkpoint['decoder'] back as ready-to-use modules.

    :param epoch: epoch number, stored in the checkpoint and used in the default filename
    :param encoder: encoder model
    :param decoder: decoder model
    :param encoder_optimizer: optimizer for the encoder (None when the encoder is frozen)
    :param decoder_optimizer: optimizer for the decoder
    :param filename: output path; defaults to 'checkpoint_{epoch}.pth.tar' in the working directory
    :return: the path the checkpoint was written to
    """
    state = {'epoch': epoch,
             'encoder': encoder,
             'decoder': decoder,
             'encoder_optimizer': encoder_optimizer,
             'decoder_optimizer': decoder_optimizer}
    if filename is None:
        filename = f'checkpoint_{epoch}.pth.tar'
    torch.save(state, filename)
    return filename

In [None]:
# For inference only: with start_epoch = 0, setting epochs = 0 makes the training
# loop below run zero iterations; a saved checkpoint is loaded afterwards instead.
epochs = 0

In [None]:
def adjust_learning_rate(optimizer, shrink_factor):
    """
    Multiplies the learning rate of every parameter group by `shrink_factor`.
    :param optimizer: optimizer whose learning rate should be reduced
    :param shrink_factor: multiplier applied to each group's current 'lr'
    """
    for param_group in optimizer.param_groups:
        param_group['lr'] *= shrink_factor


# Train with validation after every epoch. A checkpoint is written whenever the
# validation top-5 accuracy improves. The learning rate is decayed after every 8
# consecutive epochs without improvement, and training stops after 20.
best_score = 0
for epoch in range(start_epoch, epochs):

    if epochs_since_improvement == 20:
        break
    if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
        adjust_learning_rate(decoder_optimizer, 0.8)
        if fine_tune_encoder:
            adjust_learning_rate(encoder_optimizer, 0.8)

    # One epoch's training
    train_loss, train_top5acc, batch_time = train_epoch(
        train_loader=train_loader,
        encoder=encoder,
        decoder=decoder,
        criterion=criterion,
        encoder_optimizer=encoder_optimizer,
        decoder_optimizer=decoder_optimizer,
        epoch=epoch
    )

    val_loss, val_top5acc, _ = validate_epoch(
        val_loader,
        encoder,
        decoder,
        criterion
    )

    if best_score < val_top5acc.avg:
        best_score = val_top5acc.avg
        epochs_since_improvement = 0
        print(f"Saving checkpoint. Best score: {best_score:.4f}")
        save_checkpoint(epoch+1, encoder, decoder, encoder_optimizer, decoder_optimizer)
    else:
        # Count stagnant epochs; this drives the LR decay and early stop at the top of the loop
        epochs_since_improvement += 1

    # batch_time tracks seconds per training batch, not per epoch
    print(f'Epoch: {epoch+1:02} | Avg. batch time: {batch_time.avg:.3f} sec')
    print(f'\t    Train Loss: {train_loss.avg:.4f} | Val. Loss: {val_loss.avg:.4f}')
    print(f'\t    Top5 Acc.: {train_top5acc.avg:.3f} | Val. Top5 Acc.: {val_top5acc.avg:.3f} \n')

## Inference

In [None]:
CHECKPOINT_PATH = '../input/bmsmt/checkpoint_12.pth.tar'
# The checkpoint pickles whole nn.Module objects (see save_checkpoint), so it needs full
# unpickling: only load checkpoints you trust. map_location lets a checkpoint saved on
# GPU load on a CPU-only machine.
try:
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=False)
except TypeError:
    # Older PyTorch releases do not accept the weights_only argument
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
# Inference needs eval mode (e.g. BatchNorm running statistics in the encoder)
encoder = checkpoint['encoder'].to(device).eval()
decoder = checkpoint['decoder'].to(device).eval()

In [None]:
# Test image ids; the 'InChI' column is overwritten with predictions before saving
sample_sub_df = pd.read_csv("../input/bms-molecular-translation/sample_submission.csv")
sample_sub_df.head()

In [None]:
class BMSDatasetTest(Dataset):
    """
    Test-set dataset: yields only the preprocessed (3, IMG_SIZE, IMG_SIZE) image per row of `df`.
    """
    def __init__(self, df, img_dir='../input/bms-molecular-translation/test'):
        """
        :param df: DataFrame with an 'image_id' column
        :param img_dir: root directory of the test images
        """
        super().__init__()
        self.df = df
        self.img_dir = img_dir

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_id = self.df.iloc[idx]['image_id']
        # Images are nested under the first three characters of their id
        img_path = os.path.join(self.img_dir, img_id[0], img_id[1], img_id[2], f'{img_id}.png')

        img = self._load_from_file(img_path)
        # HWC -> CHW; from_numpy avoids the extra copy torch.tensor would make
        return torch.from_numpy(img).permute(2, 0, 1)

    def _load_from_file(self, img_path):
        """Read an image as RGB float32, resize to IMG_SIZE x IMG_SIZE and scale by 1/255."""
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread reports unreadable/missing files by returning None instead of raising
            raise FileNotFoundError(f'Could not read image: {img_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
        image /= 255.0  # Normalize
        return image

In [None]:
test_dataset = BMSDatasetTest(sample_sub_df)
test_batch_size = 128
# shuffle=False keeps predictions aligned with the rows of sample_sub_df
test_dataloader = DataLoader(
    test_dataset,
    batch_size=test_batch_size,
    pin_memory=True,
    num_workers=4,
    shuffle=False,
)

In [None]:
# Size of the per-image prediction buffer: <sos> plus up to max_pred_len - 1 decoded
# tokens. Longer outputs are truncated.
max_pred_len = 100

In [None]:
@torch.no_grad()
def inference(encoder, decoder, imgs, vocab):
    """
    Greedy decoding of token ids for a batch of images.

    Runs without building an autograd graph. Decoding stops once every sequence in the
    batch has emitted <eos>, or after max_pred_len - 1 steps (notebook-level constant).
    Uses the notebook-level `device`.

    :param encoder: Encoder module
    :param decoder: DecoderWithAttention module
    :param imgs: image batch, (batch_size, 3, H, W)
    :param vocab: Vocabulary providing the <sos>/<eos> ids
    :return: LongTensor (batch_size, max_pred_len) on `device`. Column 0 is <sos>;
             positions that were never decoded stay 0.
    """
    imgs = imgs.to(device)
    batch_size = imgs.size(0)

    encoder_out = encoder(imgs)  # (batch_size, enc_image_size, enc_image_size, encoder_dim)
    encoder_dim = encoder_out.size(-1)
    encoder_out = encoder_out.reshape(batch_size, -1, encoder_dim)  # (batch_size, num_pixels, encoder_dim)

    h, c = decoder.init_hidden_state(encoder_out)

    eos_idx = vocab.stoi['<eos>']
    prev_tokens = torch.full((batch_size, 1), vocab.stoi['<sos>'], dtype=torch.long, device=device)
    pred = torch.zeros((batch_size, max_pred_len), dtype=torch.long, device=device)
    pred[:, 0] = prev_tokens.squeeze(1)
    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

    for idx in range(1, max_pred_len):
        embeddings = decoder.embedding(prev_tokens).squeeze(1)  # (batch_size, embed_dim)

        awe, _ = decoder.attention(encoder_out, h)
        gate = decoder.sigmoid(decoder.f_beta(h))  # gating scalar, (batch_size, encoder_dim)
        awe = gate * awe

        h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c))
        # No dropout at inference; argmax over raw scores equals argmax over log-softmax
        scores = decoder.fc(h)  # (batch_size, vocab_size)
        prev_tokens = scores.argmax(dim=1, keepdim=True)

        pred[:, idx] = prev_tokens.squeeze(1)

        finished |= prev_tokens.squeeze(1) == eos_idx
        if bool(finished.all()):
            break

    return pred

In [None]:
def batch_stringify(batch, itos=None):
    """
    Convert a batch of token-id sequences to strings.

    The first token of each sequence (expected to be <sos>) is dropped, and everything
    from the first <eos> onward is discarded. Sequences without <eos> keep every token
    after the first.

    :param batch: iterable of 1-D integer sequences (numpy array, list or CPU tensor)
    :param itos: id -> token mapping; defaults to the notebook-level vocab.itos
    :return: list of strings, one per sequence
    """
    if itos is None:
        itos = vocab.itos
    preds = []
    for item in batch:
        # Look tokens up one at a time. np.vectorize sized its string dtype from the first
        # token, which could silently truncate longer tokens such as '<eos>'.
        # Unknown ids map to '' instead of the text 'None'.
        tokens = [itos.get(int(token_id), '') for token_id in item]
        end = tokens.index('<eos>') if '<eos>' in tokens else len(tokens)
        preds.append(''.join(tokens[1:end]))
    return preds

## Computing the Competition Metric (Mean Levenshtein Distance) on Validation Data

In [None]:
# Greedy-decode the validation set and score it with the competition metric
# (mean Levenshtein distance to the ground truth; lower is better).
encoder.eval()
decoder.eval()
preds, gts = [], []
with torch.no_grad():  # no autograd graph for the 100-step decoding loop
    for imgs, caps, _ in tqdm(val_loader):
        preds.extend(batch_stringify(inference(encoder, decoder, imgs, vocab).cpu().numpy()))
        gts.extend(batch_stringify(caps))

print(f'Levenshtein distance: {np.mean([levenshtein_distance(p, g) for p, g in zip(preds, gts)])}')

## Prediction on Test Data

In [None]:
# Greedy-decode every test image. The loader is unshuffled, so preds follow sample_sub_df order.
encoder.eval()
decoder.eval()
preds = []
with torch.no_grad():  # no autograd graph for the 100-step decoding loop
    for imgs in tqdm(test_dataloader):
        preds.extend(batch_stringify(inference(encoder, decoder, imgs, vocab).cpu().numpy()))

In [None]:
# preds are in sample_sub_df row order because the test loader does not shuffle
sample_sub_df['InChI'] = preds
sample_sub_df.head()

In [None]:
# Write the submission file without the DataFrame index column
sample_sub_df.to_csv("submission.csv", index=False)