In [None]:
%%capture
!pip install timm
!pip install python-Levenshtein

In [None]:
import os
import re
import time

import pandas as pd
import numpy as np

import timm

import cv2

from tqdm import tqdm
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.data import Dataset, DataLoader
import torch.backends.cudnn as cudnn

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

from sklearn.model_selection import train_test_split

from Levenshtein import distance as levenshtein_distance

In [None]:
# Root directory of the training images. BMSDataset builds each image path as
# BASE_DIR/<id[0]>/<id[1]>/<id[2]>/<image_id>.png.
BASE_DIR = '../input/bms-molecular-translation/train'

In [None]:
# Load the training labels; later cells use the `image_id` and `InChI` columns.
label_df = pd.read_csv("../input/bms-molecular-translation/train_labels.csv")
label_df.head()

In [None]:
# Remove the "InChI=1S/" prefix, which is common to all notations.
# The prefix is matched explicitly (anchored at the start) instead of slicing off
# nine characters, so re-running this cell is a no-op rather than eating into
# the formula layer.
label_df["InChI"] = label_df["InChI"].str.replace(r"^InChI=1S/", "", regex=True)

In [None]:
%%time
# Layer extractors for a prefix-stripped InChI string:
#   formula   : everything before the first "/c"
#   atom_conn : text between "/c" and the following "/h" (or the end of the string)
#   hydrogen  : everything after "/h"
formula_regex = re.compile(r"(.*?)/c")
atom_conn_regex = re.compile(r"/c(.*?)(/h|$)")
hydrogen_regex = re.compile(r"/h(.*)")


def _first_group(pattern, inChI):
    """Return group 1 of the first match of `pattern` in `inChI`, or None.

    None is returned both when `inChI` is not a string (e.g. a missing value)
    and when the pattern does not match, so callers can rely on `dropna()`.
    """
    if not isinstance(inChI, str):
        return None
    match = pattern.search(inChI)
    return match.group(1) if match else None


def get_formulae(inChI):
    """Return the chemical-formula layer of `inChI`, or None if it has no "/c" layer."""
    return _first_group(formula_regex, inChI)


def get_atom_conn(inChI):
    """Return the atom-connection ("/c") layer of `inChI`, or None if absent."""
    return _first_group(atom_conn_regex, inChI)


def get_hydrogen(inChI):
    """Return everything following "/h" in `inChI`, or None if there is no "/h" layer."""
    return _first_group(hydrogen_regex, inChI)

# Split every InChI string into its three layers; a layer that cannot be
# extracted is stored as None (shows up as a missing value).
label_df["formula"] = label_df["InChI"].map(get_formulae)
label_df["atom_conn"] = label_df["InChI"].map(get_atom_conn)
label_df["hydrogen"] = label_df["InChI"].map(get_hydrogen)

label_df.head()

In [None]:
# Number of missing values per column, i.e. rows where a layer was not found.
label_df.isna().sum()

A few training samples have no hydrogen layer (`/h`) in their InChI string, so their `hydrogen` value is missing. Those rows are dropped.

In [None]:
# Drop every row that has a missing value in any column.
label_df = label_df.dropna(axis=0)
label_df.shape

In [None]:
# Currently using only the first 100k rows (90k for training + 10k for validation)
label_df = label_df.iloc[:100000]

In [None]:
# Train on 90k, validate on 10k.
# shuffle=False keeps the row order, so the split does not depend on a random seed.
train_df, test_df = train_test_split(label_df, test_size=0.1, shuffle=False)
# Reset the indices so each frame is indexed 0..len-1.
train_df = train_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

In [None]:
# Sanity check of the split sizes.
train_df.shape, test_df.shape

## Build Vocabulary

The InChI notation uses only a limited set of characters; see this discussion for details: https://www.kaggle.com/c/bms-molecular-translation/discussion/223471. I'll therefore build a character-based vocabulary class, which will be used to preprocess the targets.

In [None]:
class Vocabulary:
    """Character-level vocabulary mapping characters to integer ids and back.

    Ids 0-3 are reserved for the special tokens ``<pad>``, ``<sos>``, ``<eos>``
    and ``<unk>``. Other characters are added by :meth:`build_vocabulary`.

    :param freq_threshold: a character is added once it has been seen at least
        this many times within one ``build_vocabulary`` call
    :param reverse: if True, :meth:`tokenize` returns the characters in reverse order
    """

    def __init__(self, freq_threshold=2, reverse=False):
        self.itos = {0: "<pad>", 1: "<sos>", 2: "<eos>", 3: "<unk>"}
        self.stoi = {"<pad>": 0, "<sos>": 1, "<eos>": 2, "<unk>": 3}
        self.freq_threshold = freq_threshold
        self.reverse = reverse
        self.tokenizer = self._tokenizer

    def __len__(self):
        """Number of entries in the vocabulary, special tokens included."""
        return len(self.itos)

    def _tokenizer(self, text):
        """Yield the characters of `text` one by one."""
        return (char for char in text)

    def tokenize(self, text):
        """Return the list of characters of `text` (reversed if ``self.reverse``)."""
        tokens = list(self.tokenizer(text))
        if self.reverse:
            tokens.reverse()
        return tokens

    def build_vocabulary(self, sentence_list):
        """Add every character that occurs at least ``freq_threshold`` times.

        Ids are handed out in the order in which characters reach the threshold.
        Characters that are already registered keep their id, so calling this
        method again (e.g. when a notebook cell is re-run) never creates
        duplicate entries.

        :param sentence_list: iterable of strings to scan
        """
        frequencies = {}
        idx = len(self.itos)

        for sentence in sentence_list:
            for char in self.tokenize(sentence):
                frequencies[char] = frequencies.get(char, 0) + 1

                # `>=` together with the membership test registers each character
                # exactly once, also for thresholds below 1.
                if char not in self.stoi and frequencies[char] >= self.freq_threshold:
                    self.stoi[char] = idx
                    self.itos[idx] = char
                    idx += 1

    def numericalize(self, text):
        """Convert `text` to a list of ids; unknown characters map to ``<unk>``."""
        unk = self.stoi["<unk>"]
        return [self.stoi.get(token, unk) for token in self.tokenize(text)]

In [None]:
%%time
# Build vocab using training data only.
# One vocabulary per InChI layer, since each layer is predicted by its own decoder.
freq_threshold = 2
vocab_formula = Vocabulary(freq_threshold=freq_threshold, reverse=False)
vocab_atom_conn = Vocabulary(freq_threshold=freq_threshold, reverse=False)
vocab_hydrogen = Vocabulary(freq_threshold=freq_threshold, reverse=False)

# Scan the corresponding training column to populate each vocabulary.
vocab_formula.build_vocabulary(train_df['formula'].to_list())
vocab_atom_conn.build_vocabulary(train_df['atom_conn'].to_list())
vocab_hydrogen.build_vocabulary(train_df['hydrogen'].to_list())

In [None]:
# Side length (in pixels) that every image is resized to when it is loaded.
IMG_SIZE = 128

## Augmentations

In [None]:
def get_train_augs():
    """Build the training augmentation pipeline.

    The dataset scales images to float32 in [0, 1] *before* the augmentations
    run, so every intensity-related parameter here is expressed on that scale:

    * ``Cutout`` fills the holes with 1.0 (white) rather than 255.
    * ``GaussNoise`` uses the library's usual 10-50 variance range divided by
      255**2; the unscaled default is meant for 0-255 intensities and would
      swamp a [0, 1] image.

    :return: an ``albumentations.Compose`` that ends with ``ToTensorV2``
    """
    return A.Compose(
        [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.ShiftScaleRotate(p=0.5),
            A.GaussNoise(var_limit=(10.0 / 255 ** 2, 50.0 / 255 ** 2), p=0.5),
            A.Cutout(num_holes=8, max_h_size=64, max_w_size=64, fill_value=1.0, p=0.5),
            ToTensorV2(p=1.0),
        ],
        p=1.0
    )

def get_valid_augs():
    """Build the validation pipeline: no augmentation, only conversion to a tensor."""
    to_tensor = ToTensorV2(p=1.0)
    return A.Compose([to_tensor], p=1.0)

In [None]:
class BMSDataset(Dataset):
    """Dataset yielding an image plus its three numericalized InChI layers.

    :param df: DataFrame with columns ``image_id``, ``formula``, ``atom_conn`` and ``hydrogen``
    :param augs: callable invoked as ``augs(image=img)["image"]``; skipped when falsy
    :param vocab_formula: vocabulary for the formula layer
    :param vocab_atom_conn: vocabulary for the atom-connection layer
    :param vocab_hydrogen: vocabulary for the hydrogen layer
    """

    def __init__(self, df, augs, vocab_formula, vocab_atom_conn, vocab_hydrogen):
        super().__init__()
        self.df = df
        self.augs = augs
        self.vocab_formula = vocab_formula
        self.vocab_atom_conn = vocab_atom_conn
        self.vocab_hydrogen = vocab_hydrogen

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(img, formula, atom_conn, hydrogen, formula_len, atom_conn_len, hydrogen_len)``.

        The three label tensors hold token ids wrapped in ``<sos>`` / ``<eos>``;
        the three length tensors are scalars that count those two extra tokens.
        """
        # Look the row up once instead of once per column.
        row = self.df.iloc[idx]
        img_id = row['image_id']

        # Split formula, atom connection and hydrogen sections
        label_formula = row['formula']
        label_atom_conn = row['atom_conn']
        label_hydrogen = row['hydrogen']

        label_formula_len = len(label_formula) + 2  # (2 for <sos> and <eos>)
        label_atom_conn_len = len(label_atom_conn) + 2
        label_hydrogen_len = len(label_hydrogen) + 2

        # Images are stored in nested folders named after the first three id characters.
        img_path = os.path.join(BASE_DIR, img_id[0], img_id[1], img_id[2], f'{img_id}.png')
        img = self._load_from_file(img_path)

        # Apply image augmentations if available
        if self.augs:
            img = self.augs(image=img)["image"]

        # Convert label to numbers
        label_formula = self._get_numericalized(label_formula, self.vocab_formula)
        label_atom_conn = self._get_numericalized(label_atom_conn, self.vocab_atom_conn)
        label_hydrogen = self._get_numericalized(label_hydrogen, self.vocab_hydrogen)

        return (
            img,
            torch.tensor(label_formula),
            torch.tensor(label_atom_conn),
            torch.tensor(label_hydrogen),
            torch.tensor(label_formula_len),
            torch.tensor(label_atom_conn_len),
            torch.tensor(label_hydrogen_len),
        )

    def _get_numericalized(self, sentence, vocab):
        """Return the ids of `sentence` under `vocab`, wrapped in ``<sos>`` ... ``<eos>``."""
        numericalized = [vocab.stoi["<sos>"]]
        numericalized.extend(vocab.numericalize(sentence))
        numericalized.append(vocab.stoi["<eos>"])
        return numericalized

    def _load_from_file(self, img_path):
        """Read an image as RGB float32, resize to IMG_SIZE x IMG_SIZE and scale to [0, 1].

        :raises FileNotFoundError: if the file is missing or cannot be decoded
        """
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
        image /= 255.0  # Normalize
        return image

In [None]:
def bms_collate(batch):
    """Collate dataset samples into one padded batch.

    Each sample is the 7-tuple produced by the dataset: image, three label
    sequences and their three lengths. Label sequences are right-padded with the
    ``<pad>`` id of the matching module-level vocabulary; lengths are stacked
    and reshaped into ``(batch, 1)`` columns.
    """
    imgs = [sample[0] for sample in batch]
    formulas = [sample[1] for sample in batch]
    atom_conns = [sample[2] for sample in batch]
    hydrogens = [sample[3] for sample in batch]
    formula_lens = [sample[4] for sample in batch]
    atom_conn_lens = [sample[5] for sample in batch]
    hydrogen_lens = [sample[6] for sample in batch]

    padded_formulas = pad_sequence(
        formulas, batch_first=True, padding_value=vocab_formula.stoi["<pad>"]
    )
    padded_atom_conns = pad_sequence(
        atom_conns, batch_first=True, padding_value=vocab_atom_conn.stoi["<pad>"]
    )
    padded_hydrogens = pad_sequence(
        hydrogens, batch_first=True, padding_value=vocab_hydrogen.stoi["<pad>"]
    )

    return (
        torch.stack(imgs),
        padded_formulas,
        padded_atom_conns,
        padded_hydrogens,
        torch.stack(formula_lens).reshape(-1, 1),
        torch.stack(atom_conn_lens).reshape(-1, 1),
        torch.stack(hydrogen_lens).reshape(-1, 1),
    )
    

# Training pipeline: random augmentations, shuffled batches of 64, padded by bms_collate.
train_dataset = BMSDataset(train_df, get_train_augs(), vocab_formula, vocab_atom_conn, vocab_hydrogen)
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    pin_memory=True,
    num_workers=4,
    shuffle=True,
    collate_fn=bms_collate
)

# Validation pipeline: tensor conversion only, fixed order, same vocabularies as training.
val_dataset = BMSDataset(test_df, get_valid_augs(), vocab_formula, vocab_atom_conn, vocab_hydrogen)
val_loader = DataLoader(
    val_dataset,
    batch_size=64,
    pin_memory=True,
    num_workers=4,
    shuffle=False,
    collate_fn=bms_collate
)

In [None]:
# Use the GPU when one is available. `device` is read as a global by the
# training/validation functions and by the decoder's forward pass.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Available device: {device}")

## Model Architecture

This code is heavily based on this great tutorial: https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Image-Captioning

In [None]:
class Encoder(nn.Module):
    """
    Image encoder: a truncated CNN backbone followed by adaptive average pooling.
    """

    def __init__(self, back_bone, encoded_image_size=8):
        """
        :param back_bone: CNN whose last two child modules are discarded
        :param encoded_image_size: spatial size (height and width) of the returned feature map
        """
        super().__init__()
        self.enc_image_size = encoded_image_size

        # Keep every child module except the final two
        # (the pooling and classification layers are not needed here).
        feature_layers = list(back_bone.children())[:-2]
        self.back_bone = nn.Sequential(*feature_layers)

        # Pool to a fixed grid so that inputs of any size give the same output shape.
        pooled_shape = (encoded_image_size, encoded_image_size)
        self.adaptive_pool = nn.AdaptiveAvgPool2d(pooled_shape)

        # Start with the whole backbone frozen.
        self.fine_tune(False)

    def forward(self, images):
        """
        Encode a batch of images.
        :param images: images, a tensor of dimensions (batch_size, 3, image_size, image_size)
        :return: features of shape (batch_size, encoded_image_size, encoded_image_size, channels)
        """
        features = self.back_bone(images)  # (batch_size, channels, H', W')
        pooled = self.adaptive_pool(features)  # (batch_size, channels, encoded_image_size, encoded_image_size)
        return pooled.permute(0, 2, 3, 1)  # channels moved to the last axis

    def fine_tune(self, fine_tune=True):
        """
        Freeze the whole backbone, then set `requires_grad` to `fine_tune` for the
        parameters of every backbone child from index 5 onwards.
        :param fine_tune: whether those later children should receive gradients
        """
        for param in self.back_bone.parameters():
            param.requires_grad = False

        later_children = list(self.back_bone.children())[5:]
        for child in later_children:
            for param in child.parameters():
                param.requires_grad = fine_tune

In [None]:
class Attention(nn.Module):
    """
    Additive attention over the encoder's pixel features, conditioned on the decoder state.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        """
        :param encoder_dim: feature size of encoded images
        :param decoder_dim: size of decoder's RNN
        :param attention_dim: size of the attention network
        """
        super().__init__()
        # Project image features and decoder state into a common attention space.
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        # Collapse the attention space into one score per pixel.
        self.full_att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        # Normalise the scores over the pixel axis.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden):
        """
        Compute the attention weights and the resulting weighted encoding.
        :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
        :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim)
        :return: attention weighted encoding (batch_size, encoder_dim), weights (batch_size, num_pixels)
        """
        pixel_proj = self.encoder_att(encoder_out)  # (batch_size, num_pixels, attention_dim)
        state_proj = self.decoder_att(decoder_hidden)  # (batch_size, attention_dim)

        # Broadcast the decoder projection across all pixels before scoring.
        combined = self.relu(pixel_proj + state_proj.unsqueeze(1))
        scores = self.full_att(combined).squeeze(2)  # (batch_size, num_pixels)
        alpha = self.softmax(scores)

        weighted = encoder_out * alpha.unsqueeze(2)
        attention_weighted_encoding = weighted.sum(dim=1)  # (batch_size, encoder_dim)
        return attention_weighted_encoding, alpha

In [None]:
class DecoderWithAttention(nn.Module):
    """
    LSTM decoder that attends over the encoded image at every time step.
    """

    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.4):
        """
        :param attention_dim: size of attention network
        :param embed_dim: embedding size
        :param decoder_dim: size of decoder's RNN
        :param vocab_size: size of vocabulary
        :param encoder_dim: feature size of encoded images
        :param dropout: dropout probability applied before the output layer
        """
        super(DecoderWithAttention, self).__init__()

        self.encoder_dim = encoder_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size

        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)  # attention network

        self.embedding = nn.Embedding(vocab_size, embed_dim)  # embedding layer
        # `self.dropout` is the Dropout module itself (built straight from the argument).
        self.dropout = nn.Dropout(p=dropout)
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)  # decoding LSTMCell
        self.init_h = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find initial hidden state of LSTMCell
        self.init_c = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find initial cell state of LSTMCell
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)  # linear layer to create a sigmoid-activated gate
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)  # linear layer to find scores over vocabulary
        self.init_weights()  # initialize some layers with the uniform distribution

    def init_weights(self):
        """
        Initializes the embedding and output layers from U(-0.1, 0.1) and zeroes the output bias.
        """
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def load_pretrained_embeddings(self, embeddings):
        """
        Loads embedding layer with pre-trained embeddings.
        :param embeddings: pre-trained embeddings
        """
        self.embedding.weight = nn.Parameter(embeddings)

    def fine_tune_embeddings(self, fine_tune=True):
        """
        Allow fine-tuning of embedding layer? (Only makes sense to not-allow if using pre-trained embeddings).
        :param fine_tune: Allow?
        """
        for p in self.embedding.parameters():
            p.requires_grad = fine_tune

    def init_hidden_state(self, encoder_out):
        """
        Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images.
        :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
        :return: hidden state, cell state
        """
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """
        Forward propagation with teacher forcing.
        :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim)
        :param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length)
        :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1)
        :return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices
        """
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size
        # Allocate the outputs next to the inputs rather than on a notebook-level
        # global, so the module works on whatever device it is called on.
        out_device = encoder_out.device

        # Flatten image
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)  # (batch_size, num_pixels, encoder_dim)
        num_pixels = encoder_out.size(1)

        # Sort by decreasing length so that, at step t, the sequences that are
        # still being decoded form a prefix of the batch.
        caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]

        # Embedding
        embeddings = self.embedding(encoded_captions)  # (batch_size, max_caption_length, embed_dim)

        # Initialize LSTM state
        h, c = self.init_hidden_state(encoder_out)  # (batch_size, decoder_dim)

        # We won't decode at the <eos> position, since generation is finished once it is produced.
        # So, decoding lengths are actual lengths - 1
        decode_lengths = (caption_lengths - 1).tolist()
        max_steps = max(decode_lengths)

        # Tensors holding the prediction scores and attention weights; positions
        # that are never decoded keep their zero initialisation.
        predictions = torch.zeros(batch_size, max_steps, vocab_size, device=out_device)
        alphas = torch.zeros(batch_size, max_steps, num_pixels, device=out_device)

        # At each time-step:
        #  1. attend over the encoder output using the previous hidden state,
        #  2. gate the attended encoding,
        #  3. feed it, together with the ground-truth previous token, to the LSTM cell.
        for t in range(max_steps):
            batch_size_t = sum(length > t for length in decode_lengths)
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
                                                                h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  # gating scalar, (batch_size_t, encoder_dim)
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t]))  # (batch_size_t, decoder_dim)
            preds = self.fc(self.dropout(h))  # (batch_size_t, vocab_size)
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha

        return predictions, encoded_captions, decode_lengths, alphas, sort_ind

## Training

In [None]:
class AverageMeter:
    """
    Keeps track of most recent, average, sum, and count of a metric.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record `val` as observed `n` times and refresh the running average.

        :param val: latest value of the metric (a mean over `n` items)
        :param n: weight of this observation, e.g. the number of items it covers
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. update(..., n=0) on a fresh meter) leaves the
        # average at 0 instead of raising ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0

In [None]:
def accuracy(scores, targets, k):
    """
    Top-k accuracy in percent.
    :param scores: scores from the model, one row per sample
    :param targets: true labels, one per sample
    :param k: k in top-k accuracy
    :return: percentage of samples whose true label is among the k highest scores
    """
    n_samples = targets.size(0)

    # Indices of the k best-scoring classes for each sample.
    top_indices = scores.topk(k, 1, True, True)[1]

    # Compare each of the k candidates against the (broadcast) true label.
    expanded_targets = targets.view(-1, 1).expand_as(top_indices)
    hits = top_indices.eq(expanded_targets)

    n_hits = hits.view(-1).float().sum().item()
    return n_hits * (100.0 / n_samples)

In [None]:
def clip_gradient(optimizer, grad_clip):
    """
    Clamp, in place, every gradient held by the optimizer's parameters to
    the range [-grad_clip, grad_clip]. Parameters without a gradient are skipped.
    :param optimizer: optimizer with the gradients to be clipped
    :param grad_clip: clip value
    """
    all_params = (p for group in optimizer.param_groups for p in group['params'])
    for p in all_params:
        if p.grad is None:
            continue
        p.grad.data.clamp_(-grad_clip, grad_clip)

In [None]:
def train_epoch(train_loader, encoder, decoders, criterion, encoder_optimizer, decoder_optimizers, epoch):
    """
    Performs one epoch's training of the three decoders (and optionally the encoder).

    Reads the notebook-level globals `device`, `alpha_c` and `grad_clip`.

    :param train_loader: DataLoader for training data
    :param encoder: encoder model
    :param decoders: (formula, atom_conn, hydrogen) decoder models
    :param criterion: loss layer
    :param encoder_optimizer: optimizer to update encoder's weights, or None if the encoder is frozen
    :param decoder_optimizers: (formula, atom_conn, hydrogen) decoder optimizers
    :param epoch: epoch number (currently unused)
    :return: loss meters for formula, atom_conn and hydrogen, followed by the batch-time meter
    """
    decoder_formula_optimizer, decoder_atom_conn_optimizer, decoder_hydrogen_optimizer = decoder_optimizers
    decoder_formula, decoder_atom_conn, decoder_hydrogen = decoders

    batch_time = AverageMeter()  # forward prop. + back prop. time
    losses_formula = AverageMeter()  # loss (per token decoded)
    losses_atom_conn = AverageMeter()
    losses_hydrogen = AverageMeter()

    # train mode (dropout is active)
    for decoder in (decoder_formula, decoder_atom_conn, decoder_hydrogen):
        decoder.train()
    # Set the encoder mode explicitly so it is the same in every epoch: train mode
    # only when it is being fine-tuned, otherwise eval mode (validate_epoch leaves
    # it in eval mode, so the first epoch would otherwise differ from the rest).
    if encoder_optimizer is not None:
        encoder.train()
    else:
        encoder.eval()

    # Decoders first, encoder last; the same order is used for every optimizer action.
    optimizers = [decoder_formula_optimizer, decoder_atom_conn_optimizer, decoder_hydrogen_optimizer]
    if encoder_optimizer is not None:
        optimizers.append(encoder_optimizer)

    start = time.time()

    for (imgs, label_formula, label_atom_conn, label_hydrogen,
         label_formula_len, label_atom_conn_len, label_hydrogen_lens) in tqdm(
            train_loader, total=len(train_loader), position=0, leave=True):
        # Move to GPU, if available, and encode once for all three decoders
        imgs = encoder(imgs.to(device))

        branches = (
            (decoder_formula, label_formula, label_formula_len, losses_formula),
            (decoder_atom_conn, label_atom_conn, label_atom_conn_len, losses_atom_conn),
            (decoder_hydrogen, label_hydrogen, label_hydrogen_lens, losses_hydrogen),
        )

        total_loss = 0.
        step_stats = []
        for decoder, labels, lengths, meter in branches:
            scores, caps_sorted, decode_lengths, alphas, _ = decoder(
                imgs, labels.to(device), lengths.to(device))

            # Since we decoded starting with <sos>, the targets are all tokens after <sos>, up to <eos>
            targets = caps_sorted[:, 1:]

            # Remove timesteps that we didn't decode at, or are pads
            # pack_padded_sequence is an easy trick to do this
            scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
            targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

            # Cross-entropy plus doubly stochastic attention regularization
            loss = criterion(scores, targets)
            loss = loss + alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

            total_loss = total_loss + loss
            step_stats.append((meter, loss, sum(decode_lengths)))

        # Back prop.
        for optimizer in optimizers:
            optimizer.zero_grad()

        # One backward pass over the summed loss. Each decoder still only receives
        # the gradient of its own loss, while the shared encoder graph is traversed
        # a single time; separate backward() calls would fail on the second call
        # whenever the encoder takes part in the graph (fine-tuning).
        total_loss.backward()

        # Clip gradients
        if grad_clip is not None:
            for optimizer in optimizers:
                clip_gradient(optimizer, grad_clip)

        # Update weights
        for optimizer in optimizers:
            optimizer.step()

        # Keep track of metrics
        for meter, loss, n_tokens in step_stats:
            meter.update(loss.item(), n_tokens)
        batch_time.update(time.time() - start)

        start = time.time()

    return losses_formula, losses_atom_conn, losses_hydrogen, batch_time

In [None]:
def validate_epoch(val_loader, encoder, decoders, criterion):
    """
    Computes the teacher-forced validation loss of the three decoders.

    Reads the notebook-level globals `device` and `alpha_c`.

    :param val_loader: DataLoader for validation data
    :param encoder: encoder model, or None if the loader already yields encoded images
    :param decoders: (formula, atom_conn, hydrogen) decoder models
    :param criterion: loss layer
    :return: loss meters for formula, atom_conn and hydrogen, followed by the batch-time meter
    """
    decoder_formula, decoder_atom_conn, decoder_hydrogen = decoders

    # eval mode (no dropout)
    for decoder in (decoder_formula, decoder_atom_conn, decoder_hydrogen):
        decoder.eval()
    if encoder is not None:
        encoder.eval()

    batch_time = AverageMeter()
    losses_formula = AverageMeter()  # loss (per token decoded)
    losses_atom_conn = AverageMeter()
    losses_hydrogen = AverageMeter()

    start = time.time()

    with torch.no_grad():
        # Batches
        for (imgs, label_formula, label_atom_conn, label_hydrogen,
             label_formula_len, label_atom_conn_len, label_hydrogen_lens) in tqdm(
                val_loader, total=len(val_loader), position=0, leave=True):

            # Move to device, if available
            imgs = imgs.to(device)

            # Forward prop.
            if encoder is not None:
                imgs = encoder(imgs)

            branches = (
                (decoder_formula, label_formula, label_formula_len, losses_formula),
                (decoder_atom_conn, label_atom_conn, label_atom_conn_len, losses_atom_conn),
                (decoder_hydrogen, label_hydrogen, label_hydrogen_lens, losses_hydrogen),
            )

            for decoder, labels, lengths, meter in branches:
                scores, caps_sorted, decode_lengths, alphas, _ = decoder(
                    imgs, labels.to(device), lengths.to(device))

                # Since we decoded starting with <sos>, the targets are all tokens after <sos>, up to <eos>
                targets = caps_sorted[:, 1:]

                # Remove timesteps that we didn't decode at, or are pads
                # pack_padded_sequence is an easy trick to do this
                scores = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
                targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

                # Cross-entropy plus doubly stochastic attention regularization
                loss = criterion(scores, targets)
                loss = loss + alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()

                # Keep track of metrics
                meter.update(loss.item(), sum(decode_lengths))

            batch_time.update(time.time() - start)
            start = time.time()

    return losses_formula, losses_atom_conn, losses_hydrogen, batch_time

In [None]:
# Model parameters (shared by all three decoders)
emb_dim = 512  # dimension of the token embeddings
attention_dim = 512  # dimension of attention linear layers
decoder_dim = 512  # dimension of decoder RNN
dropout = 0.4  # dropout probability passed to each decoder
cudnn.benchmark = False  # set to true only if inputs to model are fixed size; otherwise lot of computational overhead

# Training parameters
start_epoch = 0
epochs = 20  # number of epochs to train for
epochs_since_improvement = 0  # epochs without improvement; NOTE(review): only referenced in commented-out code in the visible part of the notebook
workers = 4  # intended number of data-loading workers; NOTE(review): the DataLoaders above hard-code num_workers=4 and do not read this
encoder_lr = 2e-4  # learning rate for encoder if fine-tuning
decoder_lr = 2e-4  # learning rate for decoder
grad_clip = 5.  # clip gradients at this absolute value (read as a global by train_epoch)
alpha_c = 1.  # weight of the 'doubly stochastic attention' regularization term (read as a global by train_epoch / validate_epoch)
best_bleu4 = 0.  # NOTE(review): no BLEU score is computed in the visible part of the notebook; looks like a leftover
fine_tune_encoder = False  # fine-tune encoder?

In [None]:
# Train 3 separate decoders to reduce the sequence length while predicting.
# They share the hyper-parameters and differ only in the vocabulary size of their layer.
decoder_formula = DecoderWithAttention(
    attention_dim=attention_dim,
    embed_dim=emb_dim,
    decoder_dim=decoder_dim,
    vocab_size=len(vocab_formula),
    dropout=dropout
)
decoder_atom_conn = DecoderWithAttention(
    attention_dim=attention_dim,
    embed_dim=emb_dim,
    decoder_dim=decoder_dim,
    vocab_size=len(vocab_atom_conn),
    dropout=dropout
)
decoder_hydrogen = DecoderWithAttention(
    attention_dim=attention_dim,
    embed_dim=emb_dim,
    decoder_dim=decoder_dim,
    vocab_size=len(vocab_hydrogen),
    dropout=dropout
)

# One Adam optimizer per decoder, over the parameters that require gradients.
decoder_formula_optimizer = torch.optim.Adam(
    params=filter(lambda p: p.requires_grad, decoder_formula.parameters()),
    lr=decoder_lr
)

decoder_atom_conn_optimizer = torch.optim.Adam(
    params=filter(lambda p: p.requires_grad, decoder_atom_conn.parameters()),
    lr=decoder_lr
)

decoder_hydrogen_optimizer = torch.optim.Adam(
    params=filter(lambda p: p.requires_grad, decoder_hydrogen.parameters()),
    lr=decoder_lr
)

# Single image encoder shared by all three decoders. It only gets an optimizer
# when fine-tuning is enabled; otherwise encoder_optimizer is None.
encoder = Encoder(timm.create_model('resnet101', pretrained=True))
encoder.fine_tune(fine_tune_encoder)
encoder_optimizer = torch.optim.Adam(
    params=filter(lambda p: p.requires_grad, encoder.parameters()),
    lr=encoder_lr
) if fine_tune_encoder else None

In [None]:
# Move all four models to the selected device.
encoder = encoder.to(device)
decoder_formula = decoder_formula.to(device)
decoder_atom_conn = decoder_atom_conn.to(device)
decoder_hydrogen = decoder_hydrogen.to(device)

In [None]:
# Loss function: token-level cross-entropy, shared by the three decoders.
criterion = nn.CrossEntropyLoss().to(device)

In [None]:
def save_checkpoint(epoch, encoder, decoder_formula, decoder_atom_conn, decoder_hydrogen, encoder_optimizer, decoder_formula_optimizer, decoder_atom_conn_optimizer, decoder_hydrogen_optimizer):
    """
    Saves model checkpoint.

    All arguments are stored as given (whole objects, not state dicts) in one
    dictionary that is written with torch.save to `checkpoint_{epoch}.pth.tar`,
    a path relative to the current working directory.

    :param epoch: epoch number; also used in the checkpoint file name
    :param encoder: encoder model
    :param decoder_formula: decoder model for the formula part
    :param decoder_atom_conn: decoder model for the atom-connection part
    :param decoder_hydrogen: decoder model for the hydrogen part
    :param encoder_optimizer: optimizer for the encoder's weights (the notebook passes None when not fine-tuning)
    :param decoder_formula_optimizer: optimizer for the formula decoder's weights
    :param decoder_atom_conn_optimizer: optimizer for the atom-connection decoder's weights
    :param decoder_hydrogen_optimizer: optimizer for the hydrogen decoder's weights
    """
    # Key order matches the order in which the entries are stored.
    names = (
        'epoch',
        'encoder',
        'decoder_formula',
        'decoder_atom_conn',
        'decoder_hydrogen',
        'encoder_optimizer',
        'decoder_formula_optimizer',
        'decoder_atom_conn_optimizer',
        'decoder_hydrogen_optimizer',
    )
    objects = (
        epoch,
        encoder,
        decoder_formula,
        decoder_atom_conn,
        decoder_hydrogen,
        encoder_optimizer,
        decoder_formula_optimizer,
        decoder_atom_conn_optimizer,
        decoder_hydrogen_optimizer,
    )
    torch.save(dict(zip(names, objects)), f'checkpoint_{epoch}.pth.tar')

In [None]:
# For inference only: the training loop below iterates over range(start_epoch, epochs),
# so epochs = 0 presumably makes it run zero iterations (assuming start_epoch >= 0) and
# the notebook goes straight to the inference section. Set a positive value to train.
epochs = 0

In [None]:
# Lowest mean validation loss seen so far; a checkpoint is written whenever it improves.
best_score = float('inf')
for epoch in range(start_epoch, epochs):
    # NOTE: learning-rate decay / early stopping on stalled validation is not
    # active in this loop (no epochs_since_improvement bookkeeping is done here).

    # Train for one epoch
    train_losses_formula, train_losses_atom_conn, train_losses_hydrogen, batch_time = train_epoch(
        train_loader=train_loader,
        encoder=encoder,
        decoders=(decoder_formula, decoder_atom_conn, decoder_hydrogen),
        criterion=criterion,
        encoder_optimizer=encoder_optimizer,
        decoder_optimizers=(decoder_formula_optimizer, decoder_atom_conn_optimizer, decoder_hydrogen_optimizer),
        epoch=epoch
    )

    # Evaluate on the validation split
    val_losses_formula, val_losses_atom_conn, val_losses_hydrogen, _ = validate_epoch(
        val_loader,
        encoder,
        decoders=(decoder_formula, decoder_atom_conn, decoder_hydrogen),
        criterion=criterion
    )

    # Mean of the three per-part average losses
    train_loss_avg = (train_losses_formula.avg + train_losses_atom_conn.avg + train_losses_hydrogen.avg)/3
    val_loss_avg = (val_losses_formula.avg + val_losses_atom_conn.avg + val_losses_hydrogen.avg)/3

    if val_loss_avg < best_score:
        best_score = val_loss_avg
        print(f"Saving checkpoint. Best score: {best_score:.4f}")
        save_checkpoint(
            epoch+1,
            encoder,
            decoder_formula,
            decoder_atom_conn,
            decoder_hydrogen,
            encoder_optimizer,
            decoder_formula_optimizer,
            decoder_atom_conn_optimizer,
            decoder_hydrogen_optimizer
        )

    print(f'Epoch: {epoch+1:02} | Time: {batch_time.avg} sec\n')
    print(f'\t    Train Losses (Avg.): {train_loss_avg:.4f}')
    print(f'\t    \tFormula: {train_losses_formula.avg:.4f} | Atom Connections: {train_losses_atom_conn.avg:.4f} | Hydrogen: {train_losses_hydrogen.avg:.4f}')
    print(f'\t    Validation Losses (Avg.): {val_loss_avg:.4f}')
    print(f'\t    \tFormula: {val_losses_formula.avg:.4f} | Atom Connections: {val_losses_atom_conn.avg:.4f} | Hydrogen: {val_losses_hydrogen.avg:.4f}\n')

In [None]:
# List the working directory (e.g. to see which checkpoint_*.pth.tar files were written).
!ls

## Inference

In [None]:
# Restore the trained models from a saved checkpoint.
# The checkpoint stores whole pickled modules (see save_checkpoint), so loading it
# unpickles arbitrary objects and needs the model classes defined in this notebook:
# only load checkpoint files you trust.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True, which
# rejects pickled modules -- confirm the installed version before relying on this call.
# map_location puts every tensor on `device`, so a checkpoint written on a GPU also
# loads on a CPU-only machine and the models end up where inference() sends the images.
checkpoint = torch.load(
    '../input/bms-mt-show-attend-and-tell-pytorch-baseline/checkpoint_20.pth.tar',
    map_location=device
)
# eval() switches off dropout and makes batch-norm use its running statistics;
# the unpickled modules otherwise keep whatever mode they were saved in.
encoder = checkpoint['encoder'].eval()
decoder_formula = checkpoint['decoder_formula'].eval()
decoder_atom_conn = checkpoint['decoder_atom_conn'].eval()
decoder_hydrogen = checkpoint['decoder_hydrogen'].eval()

In [None]:
# Load the sample submission; its 'image_id' column is what BMSDatasetTest reads
# to locate each test image.
sample_sub_df = pd.read_csv("../input/bms-molecular-translation/sample_submission.csv")
sample_sub_df.head()

In [None]:
class BMSDatasetTest(Dataset):
    """
    Test-time dataset: yields one image tensor (no label) per row of `df`.

    :param df: DataFrame with an 'image_id' column
    :param img_dir: root directory of the test images; an image is looked up at
        <img_dir>/<id[0]>/<id[1]>/<id[2]>/<id>.png
    """
    def __init__(self, df, img_dir='../input/bms-molecular-translation/test'):
        super().__init__()
        self.df = df
        self.img_dir = img_dir

    def __len__(self):
        """Number of rows (images) in the DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """
        Load the image for row `idx`.

        :return: float32 tensor in channel-first layout, (3, IMG_SIZE, IMG_SIZE), scaled to [0, 1]
        :raises FileNotFoundError: if the image file is missing or cannot be decoded
        """
        img_id = self.df.iloc[idx]['image_id']
        img = self._load_from_file(self._image_path(img_id))
        # HWC -> CHW; from_numpy reuses the freshly created array instead of copying it
        return torch.from_numpy(img).permute(2, 0, 1)

    def _image_path(self, img_id):
        """Build the path of an image from its id (the first three characters are nested directories)."""
        return os.path.join(self.img_dir, img_id[0], img_id[1], img_id[2], f'{img_id}.png')

    def _load_from_file(self, img_path):
        """Read an image as RGB float32, resize it to IMG_SIZE x IMG_SIZE and scale it to [0, 1]."""
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising
            raise FileNotFoundError(f'Could not read image: {img_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
        image /= 255.0  # Normalize
        return image

In [None]:
# Test-time loader. shuffle=False keeps batches in DataFrame row order, which the
# later column assignment of the predictions relies on.
test_batch_size = 128
test_dataset = BMSDatasetTest(sample_sub_df)
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=test_batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Upper bound on the decoded sequence length for each InChI part.
# The bound is the LONGEST label, not the mean: with the mean, roughly half of all
# molecules could never be decoded in full and their predictions were cut short.
# +2 leaves room for the <sos> and <eos> tokens around the characters.
# (Assumes at most one token per character -- TODO confirm against the vocab tokenizer;
# if tokens span several characters this is still a safe upper bound.)
# Trade-off: greedy decoding runs up to this many steps per batch, so inference is slower.
max_pred_len_formula = int(label_df['formula'].str.len().max()) + 2
max_pred_len_atom_conn = int(label_df['atom_conn'].str.len().max()) + 2
max_pred_len_hydrogen = int(label_df['hydrogen'].str.len().max()) + 2

print(f'Max formula pred. length: {max_pred_len_formula}')
print(f'Max atom conn. pred. length: {max_pred_len_atom_conn}')
print(f'Max hydrogen pred. length: {max_pred_len_hydrogen}')

In [None]:
def decode(decoder, vocab, encoder_out, batch_size, max_pred_len):
    """
    Greedy decoding: at every step feed the previously predicted token back in
    and keep the highest-scoring next token.

    :param decoder: decoder exposing init_hidden_state, embedding, attention,
        f_beta, sigmoid, decode_step and fc
    :param vocab: vocabulary whose `stoi` maps '<sos>' and '<eos>' to token ids
    :param encoder_out: encoded images; indexed by batch on dim 0 and passed as-is
        to decoder.attention (presumably (batch, num_pixels, encoder_dim) -- confirm)
    :param batch_size: number of sequences to decode
    :param max_pred_len: total length of each returned sequence, <sos> included
    :return: long tensor (batch_size, max_pred_len) on encoder_out's device. Column 0
        is <sos>. If decoding stops early because every sequence has produced <eos>,
        the remaining columns hold <eos>.
    """
    # Work on the encoder output's device instead of relying on a global.
    dev = encoder_out.device
    sos = vocab.stoi['<sos>']
    eos = vocab.stoi['<eos>']

    # Pre-filled with <eos> so positions skipped by an early stop still decode cleanly.
    pred = torch.full((batch_size, max_pred_len), eos, dtype=torch.long, device=dev)
    if max_pred_len < 1:
        return pred
    pred[:, 0] = sos

    # Explicit long dtype: embedding lookups need integer indices.
    tokens = torch.full((batch_size, 1), sos, dtype=torch.long, device=dev)
    finished = torch.zeros(batch_size, dtype=torch.bool, device=dev)

    # Inference only: do not record an autograd graph across the decoding steps.
    with torch.no_grad():
        h, c = decoder.init_hidden_state(encoder_out)

        # range(1, max_pred_len) never writes past the last column (max_pred_len == 1 -> no steps)
        for idx in range(1, max_pred_len):
            embeddings = decoder.embedding(tokens).squeeze(1)

            awe, _ = decoder.attention(encoder_out, h)

            gate = decoder.sigmoid(decoder.f_beta(h))  # gating scalar, (s, encoder_dim)
            awe = gate * awe

            h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c))
            scores = decoder.fc(h)  # (s, vocab_size)

            # argmax over raw scores picks the same token as argmax over log-softmax
            tokens = scores.argmax(dim=1, keepdim=True)
            pred[:, idx] = tokens.squeeze(1)

            # Stop as soon as every sequence in the batch has emitted <eos>.
            finished |= (tokens.squeeze(1) == eos)
            if finished.all():
                break

    return pred

In [None]:
def inference(encoder, decoders, imgs, vocabs):
    """
    Encode a batch of images and greedily decode the three InChI parts.

    :param encoder: image encoder; its output is read as 4-D with the feature
        dimension last (dim 3 is used as encoder_dim)
    :param decoders: tuple (decoder_formula, decoder_atom_conn, decoder_hydrogen)
    :param imgs: batch of images; moved to `device` here
    :param vocabs: tuple (vocab_formula, vocab_atom_conn, vocab_hydrogen)
    :return: tuple (pred_formula, pred_atom_conn, pred_hydrogen) of token-id tensors,
        each of shape (batch, max_pred_len_<part>) as produced by decode()
    """
    imgs = imgs.to(device)
    batch_size = len(imgs)

    decoder_formula, decoder_atom_conn, decoder_hydrogen = decoders
    vocab_formula, vocab_atom_conn, vocab_hydrogen = vocabs

    # Inference only: skip autograd bookkeeping for the encoder and decoders.
    with torch.no_grad():
        encoder_out = encoder(imgs)  # expected (batch, enc_image_size, enc_image_size, encoder_dim)
        encoder_dim = encoder_out.size(3)

        # Flatten the spatial grid so attention sees one sequence of pixels per image
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)

        ######################################################
        # Decode each InChI part with its own decoder / vocab
        ######################################################
        pred_formula = decode(decoder_formula, vocab_formula, encoder_out, batch_size, max_pred_len_formula)
        pred_atom_conn = decode(decoder_atom_conn, vocab_atom_conn, encoder_out, batch_size, max_pred_len_atom_conn)
        pred_hydrogen = decode(decoder_hydrogen, vocab_hydrogen, encoder_out, batch_size, max_pred_len_hydrogen)

    return pred_formula, pred_atom_conn, pred_hydrogen

In [None]:
def batch_stringify(batch, vocab):
    """
    Turn a batch of token-id sequences into strings.

    Each id is mapped through `vocab.itos`. The first token of every sequence is
    dropped, and so is everything from the first '<eos>' onwards (if there is no
    '<eos>', the rest of the sequence is kept). The remaining tokens are joined
    without a separator.

    :param batch: iterable of 1-D id sequences
    :param vocab: vocabulary with an `itos` dict (id -> token string)
    :return: list with one decoded string per sequence
    """
    to_token = np.vectorize(vocab.itos.get)
    decoded = []
    for ids in batch:
        tokens = to_token(ids)
        eos_positions = np.flatnonzero(tokens == '<eos>')
        # Cut at the first <eos>; with none present keep the whole tail
        stop = eos_positions[0] if eos_positions.size else len(tokens)
        decoded.append("".join(tokens[1:stop]))
    return decoded

## Calculating Competition Metric on validation data

In [None]:
# Bundle the decoders and their vocabularies in the (formula, atom_conn, hydrogen)
# order that inference() unpacks them in.
decoders = (decoder_formula, decoder_atom_conn, decoder_hydrogen)
vocabs = (vocab_formula, vocab_atom_conn, vocab_hydrogen)

In [None]:
# Decode the whole validation split and compare predictions with the ground truth.
preds_formula, preds_atom_conn, preds_hydrogen = [], [], []
gts_formula, gts_atom_conn, gts_hydrogen = [], [], []

for imgs, label_formula, label_atom_conn, label_hydrogen, label_formula_len, label_atom_conn_len, label_hydrogen_lens in tqdm(val_loader):
    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        pred_formula, pred_atom_conn, pred_hydrogen = inference(encoder, decoders, imgs, vocabs)

    pred_formula = pred_formula.cpu().detach().numpy()
    pred_atom_conn = pred_atom_conn.cpu().detach().numpy()
    pred_hydrogen = pred_hydrogen.cpu().detach().numpy()

    preds_formula.extend(batch_stringify(pred_formula, vocab_formula))
    preds_atom_conn.extend(batch_stringify(pred_atom_conn, vocab_atom_conn))
    preds_hydrogen.extend(batch_stringify(pred_hydrogen, vocab_hydrogen))

    gts_formula.extend(batch_stringify(label_formula, vocab_formula))
    gts_atom_conn.extend(batch_stringify(label_atom_conn, vocab_atom_conn))
    gts_hydrogen.extend(batch_stringify(label_hydrogen, vocab_hydrogen))


# Mean edit distance per InChI part
lev_dist_formula = np.mean(np.vectorize(levenshtein_distance)(preds_formula, gts_formula))
lev_dist_atom_conn = np.mean(np.vectorize(levenshtein_distance)(preds_atom_conn, gts_atom_conn))
lev_dist_hydrogen = np.mean(np.vectorize(levenshtein_distance)(preds_hydrogen, gts_hydrogen))

# The submission consists of the full InChI string, so also score the re-assembled
# string (same template as used for submission.csv). The average of the three
# per-part distances above is a different number from the distance of the whole string.
preds_inchi = [
    f'InChI=1S/{formula}/c{atom_conn}/h{hydrogen}'
    for formula, atom_conn, hydrogen in zip(preds_formula, preds_atom_conn, preds_hydrogen)
]
gts_inchi = [
    f'InChI=1S/{formula}/c{atom_conn}/h{hydrogen}'
    for formula, atom_conn, hydrogen in zip(gts_formula, gts_atom_conn, gts_hydrogen)
]
lev_dist_inchi = np.mean([levenshtein_distance(p, g) for p, g in zip(preds_inchi, gts_inchi)])

print(f'Levenshtein distance (Formula): {lev_dist_formula}')
print(f'Levenshtein distance (Atom Conn.): {lev_dist_atom_conn}')
print(f'Levenshtein distance (Hydrogen): {lev_dist_hydrogen}\n')
print(f'Avg. Validation Levenshtein distance {(lev_dist_formula + lev_dist_atom_conn + lev_dist_hydrogen)/3}')
print(f'Full-InChI Validation Levenshtein distance {lev_dist_inchi}')

## Prediction on Test Data

In [None]:
# Decode every test image. Results are collected in loader order, which matches the
# rows of sample_sub_df because test_dataloader does not shuffle.
preds_formula, preds_atom_conn, preds_hydrogen = [], [], []
for imgs in tqdm(test_dataloader):
    # Prediction only: without no_grad every batch would record an autograd graph,
    # wasting memory over the whole test set.
    with torch.no_grad():
        pred_formula, pred_atom_conn, pred_hydrogen = inference(encoder, decoders, imgs, vocabs)

    pred_formula = pred_formula.cpu().detach().numpy()
    pred_atom_conn = pred_atom_conn.cpu().detach().numpy()
    pred_hydrogen = pred_hydrogen.cpu().detach().numpy()

    preds_formula.extend(batch_stringify(pred_formula, vocab_formula))
    preds_atom_conn.extend(batch_stringify(pred_atom_conn, vocab_atom_conn))
    preds_hydrogen.extend(batch_stringify(pred_hydrogen, vocab_hydrogen))

In [None]:
# Attach the decoded strings for each InChI part as temporary columns. The lists are
# in loader order, which matches the DataFrame rows because test_dataloader uses shuffle=False.
sample_sub_df['formula'] = preds_formula
sample_sub_df['atom_conn'] = preds_atom_conn
sample_sub_df['hydrogen'] = preds_hydrogen

In [None]:
# Preview the per-part predictions before they are merged into the InChI column.
sample_sub_df.head()

In [None]:
# Re-assemble the full InChI string from its three predicted parts, restoring the
# "InChI=1S/" prefix and the "/c" and "/h" layer markers.
sample_sub_df['InChI'] = [
    f'InChI=1S/{formula}/c{atom_conn}/h{hydrogen}'
    for formula, atom_conn, hydrogen in zip(
        sample_sub_df['formula'], sample_sub_df['atom_conn'], sample_sub_df['hydrogen']
    )
]
# The helper columns are not part of the submission format.
sample_sub_df = sample_sub_df.drop(columns=['formula', 'atom_conn', 'hydrogen'])
sample_sub_df.to_csv("submission.csv", index=False)
sample_sub_df.head()