# Top

In [None]:
import numpy as np
import pandas as pd
import re
import sys, time, os, gc
import cv2
import numba
import torch
import torchvision
from Levenshtein import distance

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from albumentations.pytorch import ToTensorV2
import albumentations as A
import torch.nn as nn

from sklearn.model_selection import train_test_split as tts
from tqdm.notebook import tqdm

# Register `progress_apply` on pandas objects. `pandas` is called on the tqdm
# class itself so that no (empty) progress-bar instance is created and rendered.
tqdm.pandas()

SEED = 1023

# Seed both RNGs used by this notebook: NumPy (sampling / splitting) and
# PyTorch (weight init, dropout, DataLoader shuffling).
np.random.seed(SEED)
torch.manual_seed(SEED)

# Wall-clock budget: the training and inference loops compare
# `time.time() - start_time` against `time_limit` and stop once it is exceeded.
start_time = time.time()
time_limit = 8.5 * 3600  # seconds (8.5 hours)

# Read Datasets

In [None]:
# Training labels. Later cells read the `image_id` column to build image paths
# and the `InChI` column as the target string.
train = pd.read_csv("../input/bms-molecular-translation/train_labels.csv")
print(train.shape)
# Last expression of the cell: shows the first rows when run in a notebook.
train.head()

In [None]:
# The sample submission doubles as the test index: its `image_id` column is
# used to locate the test images and its `InChI` column is overwritten with
# predictions in the inference cell.
test = pd.read_csv("../input/bms-molecular-translation/sample_submission.csv")
print(test.shape)
# Last expression of the cell: shows the first rows when run in a notebook.
test.head()

# Define Constants

In [None]:
TRAIN = True #To control whether to train the model or not
PREDICT = False #To control whether to predict submission or not
# When True, the next cell replaces `train` with a random subsample.
DEBUG = True

# NOTE(review): IMG_SIZE is not passed to ImageDataLoader in the cells below;
# that class falls back to its own default of 224.
IMG_SIZE = 224
# Number of decoding steps requested from `decoder.predict`.
MAX_LEN = 400
# Size of the decoder's embedding table and output layer; must equal len(VOCAB).
VOCAB_SIZE = 40
# Training batch size (the validation loader uses twice this).
BATCH_SIZE = 128
# Number of DataLoader worker processes.
WORKERS = 4
# Number of passes over the training set.
EPOCHS = 3

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Current device:", DEVICE.upper())

In [None]:
if DEBUG:
    # Shrink the training table to a fixed-seed random subset so that a full
    # run finishes quickly while debugging.
    sample_factor = 0.20
    n = int(len(train) * sample_factor)
    train = train.sample(n=n, random_state=SEED * 3)
    train = train.reset_index(drop=True)
    print("New train size:", train.shape)

In [None]:
# Token inventory, in index order: special markers, punctuation, InChI layer
# prefixes, single digits, then element symbols. The position of each token in
# this list is its integer id.
VOCAB = (
    ['<pad>', '<sos>', '<eos>']
    + ['(', ')', '+', ',', '-', '=']
    + ['/b', '/c', '/h', '/i', '/m', '/s', '/t']
    + list('0123456789')
    + ['B', 'Br', 'C', 'Cl', 'D', 'F', 'H', 'I', 'N', 'O', 'P', 'S', 'Si', 'T']
)

# Lookup tables in both directions: token -> id and id -> token.
itos = dict(enumerate(VOCAB))
stoi = {token: index for index, token in itos.items()}

# Helper Functions

In [None]:
def bms_collate(batch):
    """Collate (image, label) samples into one batch.

    Images are stacked along a new leading dimension. Label sequences are
    right-padded with 0 (the index of "<pad>") to the longest one in the batch.

    Returns a tuple (images, labels).
    """
    image_list = [sample[0] for sample in batch]
    label_list = [sample[1] for sample in batch]
    padded_labels = pad_sequence(label_list, batch_first=True, padding_value=0)
    return torch.stack(image_list), padded_labels

def generate_inchi(pred):
    """Convert one sequence of token indices into an InChI string.

    pred is a 1-D tensor of indices into `itos`. Decoding stops at the first
    "<eos>". "<sos>" and "<pad>" markers are skipped rather than written into
    the text, so a label sequence (which starts with "<sos>") and a predicted
    sequence (which does not) are rendered the same way and can be compared
    fairly.

    Returns the string "InChI=1S/" followed by the concatenated tokens.
    """
    tokens = []
    for index in pred.to("cpu").tolist():
        token = itos[index]
        if token == "<eos>":
            break
        if token in ("<sos>", "<pad>"):
            # Structural markers are not part of the InChI text.
            continue
        tokens.append(token)
    return "InChI=1S/" + "".join(tokens)

def levenshtein_score(y_pred, y):
    """Return the per-sample Levenshtein distances between two batches.

    y_pred and y are indexable batches of token-index sequences (y_pred must
    already hold indices, not logits). Row i of each is rendered with
    `generate_inchi` and the two strings are compared with `distance`.

    Returns a list with one distance per row of y.
    """
    return [
        distance(generate_inchi(y_pred[row]), generate_inchi(y[row]))
        for row in range(len(y))
    ]

# Dataloader

In [None]:
class ImageDataLoader(Dataset):
    '''Loads the image dataset and prepares it for feeding to the model.

    Each item is the preprocessed image tensor and, in train mode, the
    tokenized InChI label wrapped in "<sos>" ... "<eos>".
    '''

    def __init__(self, df, img_size=224, train=True):
        '''
        df         => DataFrame with an "image_id" and an "InChI" column
        img_size   => Side length of the square image produced for the model
        train      => True: read from the "train" folder and return
                      (image, tokens); False: read from the "test" folder and
                      return the image only

        Attributes set here:
        paths         => Series of image file paths built from "image_id"
        inchi         => Array of InChI strings
        inchi_pattern => RegEx pattern for tokenizing InChI
        transform     => Resize / normalize / to-tensor pipeline, built once
        '''
        self.df = df
        self.img_size = img_size
        self.is_train = train
        loc = "train" if self.is_train else "test"
        self.paths = df["image_id"].apply(
            lambda x: f"../input/bms-molecular-translation/{loc}/{x[0]}/{x[1]}/{x[2]}/{x}.png"
        )

        self.inchi = self.df["InChI"].values
        self.inchi_pattern = r"([A-Z][a-z]?|[0-9]|\/[a-z]?|[\S])"

        # The pipeline does not depend on the sample, so build it once here
        # instead of on every __getitem__ call.
        self.transform = A.Compose([
            A.Resize(self.img_size, self.img_size),
            A.Normalize(),
            ToTensorV2()
        ])

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        '''Fetch image at position idx. If train, fetch the labels too.

        Raises FileNotFoundError when the image file cannot be read.
        '''
        # Positional lookup, so it matches __len__ even when df's index has
        # not been reset.
        path = self.paths.iloc[idx]
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError(f"Could not read image: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        # Invert: the majority of the image is white, so swap white and black.
        image = 255.0 - image
        image = self.rotate_if_not_landscape(image)
        image = self.transform(image=image)["image"]

        if not self.is_train:
            return image

        tokens = [stoi["<sos>"]]
        tokens.extend(stoi[s] for s in self.tokenize_inchi(self.inchi[idx]))
        tokens.append(stoi["<eos>"])
        return image, torch.LongTensor(tokens)

    def tokenize_inchi(self, inchi):
        '''Apply the regex pattern to split an InChI string into tokens.'''
        # The leading "InChI=1S/" (9 characters) is dropped before matching.
        # The pattern then picks out, in order of preference: an element
        # symbol, a single digit, a "/" optionally followed by a lowercase
        # letter, or any other non-space character.
        return re.findall(self.inchi_pattern, inchi[9:])

    def rotate_if_not_landscape(self, image):
        '''If the image is long vertically (portrait), rotate 90 deg.

        image is an array whose first two axes are (height, width). Landscape
        and square images are returned unchanged.
        '''
        height, width = image.shape[0], image.shape[1]
        if height > width:
            # np.rot90 returns a strided view; copy it into a contiguous array
            # for the image libraries that consume it next.
            return np.ascontiguousarray(np.rot90(image, axes=(0, 1)))
        return image

In [None]:
# Hold out 10% of the labelled rows for validation. The split is seeded
# explicitly so it does not depend on how much of NumPy's global random state
# earlier cells have consumed.
df_train, df_valid = tts(train, test_size = 0.1, shuffle = True, random_state = SEED)
df_train, df_valid = df_train.reset_index(drop=True), df_valid.reset_index(drop=True)
print("Train:", df_train.shape)
print("Validation:", df_valid.shape)

# `prefetch_factor` counts *batches* buffered per worker, not samples. Keep it
# at 2 so that only a handful of image batches sit in memory at once, and only
# pass it when worker processes are actually used.
_loader_kwargs = {"num_workers": WORKERS}
if WORKERS > 0:
    _loader_kwargs["prefetch_factor"] = 2

train_data = ImageDataLoader(df_train)
valid_data = ImageDataLoader(df_valid)
# Training batches are shuffled and the incomplete last batch is dropped.
train_loader = DataLoader(train_data, batch_size = BATCH_SIZE, shuffle = True,
                          drop_last = True, collate_fn = bms_collate, **_loader_kwargs)

# Validation keeps every sample in a fixed order and uses a doubled batch size.
valid_loader = DataLoader(valid_data, batch_size = BATCH_SIZE * 2, shuffle = False,
                          drop_last = False, collate_fn = bms_collate, **_loader_kwargs)

df_test = test.reset_index(drop=True)
print("Test:", df_test.shape)
# In test mode the dataset yields images only, so no custom collate function
# is passed.
test_data = ImageDataLoader(df_test, train=False)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False,
                         drop_last = False, **_loader_kwargs)

# Model

## Encoder

In [None]:
class Encoder(nn.Module):
    '''Encodes a batch of images into one 256-dimensional feature vector each.

    The backbone is a pretrained ResNet-18 whose final fully connected layer
    is replaced by a fresh linear layer with 256 outputs.
    '''
    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=True)
        backbone.fc = nn.Linear(backbone.fc.in_features, 256)
        self.resnet = backbone.to(DEVICE)

    def forward(self, images):
        '''Return features of shape (batch, 1, 256) for the given images.'''
        # The backbone outputs (batch, 256); add a length-1 middle axis.
        return self.resnet(images).unsqueeze(1)

## Decoder

In [None]:
class Decoder(nn.Module):
    '''LSTM decoder that turns encoder features into a token sequence.

    At every step the LSTM cell receives the embedding of the previous token
    concatenated with the mean-pooled encoder features, and a linear layer
    maps the hidden state to vocabulary logits.
    '''
    def __init__(self, vocab_size, embed_size, encoder_dim, decoder_dim):
        '''
        vocab_size  => Number of tokens (embedding rows and output logits)
        embed_size  => Size of each token embedding
        encoder_dim => Size of the last axis of the encoder features
        decoder_dim => Size of the LSTM hidden and cell state
        '''
        super().__init__()
        self.vocab_size = vocab_size
        self.decoder_dim = decoder_dim
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.init_h = nn.Linear(in_features = encoder_dim, out_features = decoder_dim)
        self.init_c = nn.Linear(in_features = encoder_dim, out_features = decoder_dim)
        self.lstm = nn.LSTMCell(input_size = (embed_size + encoder_dim), hidden_size = decoder_dim, bias = True)
        self.drop = nn.Dropout(p = 0.3)
        self.linear = nn.Linear(in_features = decoder_dim, out_features = vocab_size)

    def _init_state(self, features):
        '''Mean-pool features over dim 1 and derive the initial (h, c) state.'''
        pooled = features.mean(dim = 1)
        return pooled, self.init_h(pooled), self.init_c(pooled)

    def _step(self, token_embeds, pooled, h, c):
        '''Run one LSTM step; return (logits, h, c).'''
        lstm_input = torch.cat((token_embeds, pooled), dim = 1)
        h, c = self.lstm(lstm_input, (h, c))
        logits = self.linear(self.drop(h))
        return logits, h, c

    def forward(self, features, inchis):
        '''Teacher-forced pass.

        features => (batch, regions, encoder_dim) encoder output
        inchis   => (batch, seq) token indices; step s is fed token s and its
                    logits are meant to predict token s + 1

        Returns logits of shape (batch, seq - 1, vocab_size).
        '''
        embeds = self.embedding(inchis)
        pooled, h, c = self._init_state(features)

        batch_size = inchis.size(0)
        seq_length = inchis.size(1) - 1
        # Allocate on the same device as the inputs rather than on a global.
        preds = torch.zeros(batch_size, seq_length, self.vocab_size,
                            device = features.device)

        for s in range(seq_length):
            logits, h, c = self._step(embeds[:, s], pooled, h, c)
            preds[:, s] = logits
        return preds

    def predict(self, features, max_len, sos_idx = None, eos_idx = None):
        '''Greedy decoding.

        features => (batch, regions, encoder_dim) encoder output
        max_len  => Maximum number of tokens to generate per sample
        sos_idx  => Index of the start token (default: stoi["<sos>"])
        eos_idx  => Index of the end token (default: stoi["<eos>"])

        Returns a long tensor of shape (batch, max_len) with the chosen token
        indices. Decoding stops as soon as every sample has produced eos_idx;
        positions that were never written keep the value 0.
        '''
        if sos_idx is None:
            sos_idx = stoi["<sos>"]
        if eos_idx is None:
            eos_idx = stoi["<eos>"]

        batch_size = features.size(0)
        device = features.device
        pooled, h, c = self._init_state(features)

        tokens = torch.full((batch_size,), sos_idx, dtype = torch.long, device = device)
        preds = torch.zeros((batch_size, max_len), dtype = torch.long, device = device)
        finished = torch.zeros(batch_size, dtype = torch.bool, device = device)

        for i in range(max_len):
            logits, h, c = self._step(self.embedding(tokens), pooled, h, c)
            tokens = logits.argmax(dim = 1)
            preds[:, i] = tokens
            # Nothing after a sample's first <eos> is used, so stop once all
            # samples have emitted one.
            finished |= tokens == eos_idx
            if finished.all():
                break
        return preds

# Train

In [None]:
# Networks: the image encoder and the LSTM decoder, both moved to DEVICE.
encoder = Encoder().to(DEVICE)
decoder = Decoder(
    vocab_size=VOCAB_SIZE,
    embed_size=256,
    encoder_dim=256,
    decoder_dim=512,
).to(DEVICE)

# One Adam optimizer per network; only the decoder uses the AMSGrad variant.
encoder_optimizer = torch.optim.Adam(
    encoder.parameters(), lr=1e-4, weight_decay=1e-6, amsgrad=False,
)
decoder_optimizer = torch.optim.Adam(
    decoder.parameters(), lr=1e-4, weight_decay=1e-6, amsgrad=True,
)

# Matching cosine-annealing schedules (same settings) for both optimizers.
encoder_scheduler, decoder_scheduler = (
    torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=4, eta_min=1e-6, last_epoch=-1,
    )
    for optimizer in (encoder_optimizer, decoder_optimizer)
)

# Index 0 is "<pad>"; padded target positions do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)

## Train Loop

In [None]:
if TRAIN:
    # Warm start: continue from previously trained weights.
    encoder.load_state_dict(torch.load(
        "../input/pytorch-train-with-lstm/bms_encoder.pth", map_location = DEVICE))
    decoder.load_state_dict(torch.load(
        "../input/pytorch-train-with-lstm/bms_decoder.pth", map_location = DEVICE))
    best_loss = np.inf
    for epoch in range(EPOCHS):
        print(f"Epoch [{epoch+1}/{EPOCHS}]:")

        if time.time() - start_time > time_limit:
            break #Avoid TimeLimitExceeded Error

        # ---- Training pass (teacher forcing) ----
        encoder.train()
        decoder.train()
        for images, inchis in tqdm(train_loader):
            images = images.to(DEVICE)
            inchis = inchis.to(DEVICE)
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()
            encoded = encoder(images)
            preds = decoder(encoded, inchis)
            # Logits at step s are scored against token s + 1; CrossEntropyLoss
            # expects the class axis second, hence the permute.
            loss = criterion(preds.permute(0, 2, 1), inchis[:, 1:])
            loss.backward()
            encoder_optimizer.step()
            decoder_optimizer.step()

        # The schedulers only change the learning rate when stepped; advance
        # both once per epoch.
        encoder_scheduler.step()
        decoder_scheduler.step()

        # ---- Validation pass ----
        encoder.eval()
        decoder.eval()
        valid_loss = 0
        distances = []  # one Levenshtein distance per validation sample

        with torch.no_grad():
            for images, inchis in tqdm(valid_loader):
                images = images.to(DEVICE)
                inchis = inchis.to(DEVICE)
                encoded = encoder(images)
                preds = decoder(encoded, inchis)
                loss = criterion(preds.permute(0, 2, 1), inchis[:, 1:])
                valid_loss += loss.item()

                pred_vals = decoder.predict(encoded, max_len = MAX_LEN)
                # Drop the leading <sos> from the targets: greedy predictions
                # do not contain it, so keeping it would inflate every distance.
                distances.extend(levenshtein_score(pred_vals, inchis[:, 1:]))

        valid_loss /= len(valid_loader)
        print(f"[epoch {epoch+1}/{EPOCHS}] loss:{valid_loss}")
        # Average over samples rather than over per-batch means, so the
        # shorter last batch is not over-weighted.
        valid_levenshtein = float(np.mean(distances)) if distances else float("nan")
        print(f"[epoch {epoch+1}/{EPOCHS}] Lavenshtein Score:{valid_levenshtein}")

        # Keep the checkpoint with the lowest validation loss seen in this run.
        if valid_loss < best_loss:
            best_loss = valid_loss
            torch.save(encoder.state_dict(), "bms_encoder.pth")
            torch.save(decoder.state_dict(), "bms_decoder.pth")
            print("Saved...")
        else:
            print("Loss did not improve...")

## Inference

In [None]:
if PREDICT:
    encoder = Encoder().to(DEVICE)
    decoder = Decoder(vocab_size = VOCAB_SIZE, embed_size = 256, encoder_dim = 256, decoder_dim = 512).to(DEVICE)
    #Must Load weights from train_kernel
    encoder.load_state_dict(torch.load("./bms_encoder.pth", map_location = DEVICE)) #Change Path!
    decoder.load_state_dict(torch.load("./bms_decoder.pth", map_location = DEVICE)) #Change Path!

    # Freshly built modules start in training mode; switch to eval so that
    # dropout is disabled and BatchNorm uses its running statistics.
    encoder.eval()
    decoder.eval()

    submit_preds = []
    with torch.no_grad():
        for images in tqdm(test_loader):
            if time.time() - start_time > time_limit:
                break #Avoid TimeLimitExceeded Error

            encoded = encoder(images.to(DEVICE))
            pred = decoder.predict(encoded, max_len = MAX_LEN)
            # Decode each batch to strings right away (one device-to-host copy
            # per batch) instead of keeping every prediction tensor on DEVICE.
            submit_preds.extend(generate_inchi(p) for p in pred.cpu())

    submit_preds = np.array(submit_preds, dtype = str)
    print(submit_preds.shape)

    submit = test[["image_id", "InChI"]].copy().reset_index(drop=True)
    # If the time limit cut the loop short, only the rows that were predicted
    # are overwritten; the remaining rows keep the sample-submission value.
    if len(submit_preds):
        submit.iloc[:len(submit_preds), 1] = submit_preds
    submit.to_csv("submission.csv", index = False)
    print(submit.head(10))