# End-to-end Transformer

In this notebook I extend my previous example on image captioning with transformers (https://www.kaggle.com/juansensio/e2e-transformer-image-captioning-example) for this particular challenge.

# Data preprocessing

Generate the vocabulary. The tokenization code is taken from https://www.kaggle.com/yasufuminakama/inchi-preprocess-2.

In [None]:
import os
from pathlib import Path

# Root directory of the competition data.
path = Path('/kaggle/input/bms-molecular-translation')
# Show what the dataset folder contains.
os.listdir(path)

In [None]:
import pandas as pd

# Training labels; later cells read the 'image_id' and 'InChI' columns.
train_labels = pd.read_csv(path / 'train_labels.csv')
train_labels

In [None]:
from tqdm.auto import tqdm
import re
tqdm.pandas()

def split_form(form):
    """Tokenize a chemical formula into space-separated symbols and counts.

    The formula is cut in front of every uppercase letter. Within each piece
    the leading run of non-digit characters becomes one token and whatever
    follows it becomes a second token, e.g. ``"C13H20OS"`` ->
    ``"C 13 H 20 O S"``. Text before the first uppercase letter is ignored.
    """
    tokens = []
    for piece in re.findall(r"[A-Z][^A-Z]*", form):
        symbol = re.match(r"\D+", piece).group()
        count = piece.replace(symbol, "")
        tokens.append(symbol if count == "" else f"{symbol} {count}")
    return " ".join(tokens).rstrip(' ')

def split_form2(form):
    """Tokenize the InChI layers that follow the formula.

    The text is cut in front of every lowercase letter; that letter becomes
    the layer marker ``/<letter>``. Inside a layer, '/' characters are dropped
    and the remainder is scanned for digit runs: each run is one token and
    every non-digit character trailing it is its own token. Characters that
    precede the first digit run of a layer are not emitted.

    Example: ``"c1-9(2)8/h5-7,9H"`` -> ``"/c 1 - 9 ( 2 ) 8 /h 5 - 7 , 9 H"``.
    """
    layers = []
    for chunk in re.findall(r"[a-z][^a-z]*", form):
        prefix = chunk[0]
        body = chunk.replace(prefix, "").replace('/', "")
        pieces = []
        for group in re.findall(r"[0-9]+[^0-9]*", body):
            digit_runs = re.findall(r'\d+', group)
            assert len(digit_runs) == 1, f"len(num_list) != 1"
            number = digit_runs[0]
            if group == number:
                pieces.append(f"{number} ")
            else:
                trailing = group.replace(number, "")
                pieces.append(f"{number} {' '.join(trailing)} ")
        layers.append(f"/{prefix} {''.join(pieces)}")
    return "".join(layers).rstrip(' ')

In [None]:
# Formula layer: the text between the first and second '/' of each InChI.
train_labels['InChI_1'] = train_labels.InChI.progress_apply(lambda inchi: inchi.split('/')[1])
# Tokenize the formula, then everything that follows the formula layer.
formula_tokens = train_labels['InChI_1'].progress_apply(split_form)
remaining_layers = train_labels['InChI'].apply(lambda inchi: '/'.join(inchi.split('/')[2:]))
layer_tokens = remaining_layers.progress_apply(split_form2).values
train_labels['InChI_text'] = formula_tokens + ' ' + layer_tokens

In [None]:
# Inspect the labels together with the new 'InChI_1' and 'InChI_text' columns.
train_labels

In [None]:
def compute_vocab(InChIs):
    """Return the vocabulary for an iterable of token sequences.

    The special tokens 'PAD', 'SOS' and 'EOS' come first (indices 0, 1, 2),
    followed by every distinct token in sorted order.
    """
    unique_tokens = set()
    for tokens in InChIs:
        unique_tokens.update(tokens)
    return ['PAD', 'SOS', 'EOS'] + sorted(unique_tokens)

In [None]:
# Vocabulary over the space-separated tokens of every tokenized InChI.
VOCAB = compute_vocab(train_labels.InChI_text.map(lambda x: x.split(' ')))
VOCAB, len(VOCAB)

In [None]:
# Number of tokens per tokenized InChI (special tokens not included).
lens = train_labels.InChI_text.map(lambda x: x.split(' ')).map(len)
# The maximum matters when choosing 'max_len' for the model: the Dataset wraps
# every sequence in SOS and EOS before it reaches the model.
lens.min(), lens.max()

In [None]:
# Persist the tokenized labels; the DataModule reads this file back in setup().
train_labels.to_csv('train_labels_tokenized.csv', index=False)

# DataModule

In [None]:
def get_image_path(image_id, path=Path('data'), mode="train"):
    """Return the PNG location of ``image_id`` inside the ``mode`` split.

    Images are nested three directories deep, one level per leading character
    of the id: ``<path>/<mode>/<id[0]>/<id[1]>/<id[2]>/<id>.png``.
    """
    folder = path / mode
    for level in range(3):
        folder = folder / image_id[level]
    return folder / f'{image_id}.png'

In [None]:
import torch
from skimage import io
import pytorch_lightning as pl
from pathlib import Path
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import albumentations as A 
import numpy as np

class Dataset(torch.utils.data.Dataset):
    """Serve images, optionally paired with their token-id sequences.

    Args:
        images: indexable collection of image locations passed to
            ``skimage.io.imread``.
        inchis: indexable collection of token-id lists; only read when
            ``train`` is true.
        max_len: stored on the instance; not used by the methods of this class.
        trans: optional callable invoked as ``trans(image=...)`` whose result
            holds the transformed image under the ``'image'`` key.
        train: when true, items and batches include the target sequence.
        tokens: ``(PAD, SOS, EOS)`` token ids.
    """

    def __init__(self, images, inchis=None, max_len=512, trans=None, train=True, tokens=(0, 1, 2)):
        self.images, self.inchis = images, inchis
        self.trans, self.train = trans, train
        self.max_len = max_len
        self.PAD, self.SOS, self.EOS = tokens

    def __len__(self):
        return len(self.images)

    def __getitem__(self, ix):
        """Load sample ``ix``; in train mode also return ``SOS + tokens + EOS``."""
        pixels = io.imread(self.images[ix])
        if self.trans:
            pixels = self.trans(image=pixels)['image']
        # Divide by 255 (presumably 8-bit pixel values - TODO confirm) and add
        # a leading axis, which is the channel axis if the image is 2-D.
        image = torch.tensor(pixels / 255., dtype=torch.float).unsqueeze(0)
        if not self.train:
            return image
        token_ids = [self.SOS] + self.inchis[ix] + [self.EOS]
        return image, torch.tensor(token_ids, dtype=torch.long)

    def collate(self, batch):
        """Stack a list of samples into batch tensors.

        In train mode every sequence is right-padded with ``PAD`` up to the
        longest sequence of the batch, and the batch is ordered by decreasing
        sequence length. Otherwise the images are simply stacked.
        """
        if not self.train:
            return torch.stack(list(batch))
        images, sequences = zip(*batch)
        lengths = [len(seq) for seq in sequences]
        longest = max(lengths)
        padded = [
            torch.nn.functional.pad(seq, (0, longest - len(seq)), 'constant', self.PAD)
            for seq in sequences
        ]
        # longest sequence first
        order = torch.argsort(torch.tensor(lengths), descending=True)
        return torch.stack(list(images))[order], torch.stack(padded)[order]

class DataModule(pl.LightningDataModule):
    """Lightning data module for the tokenized InChI labels and their images.

    Args:
        data_file: CSV with the tokenized labels, joined onto ``path`` (an
            absolute ``data_file`` therefore replaces ``path`` entirely).
        path: root folder that holds the images.
        text_column: CSV column with the space-separated tokens.
        test_size: fraction of the rows held out for validation.
        random_state: seed for subsampling and for the train/val split.
        batch_size, num_workers, pin_memory: forwarded to the ``DataLoader``s.
        shuffle_train: whether the training loader shuffles.
        val_with_train: validate on the training dataset instead of the
            held-out split.
        train_trans, val_trans: optional mapping from albumentations
            transform name to its keyword arguments.
        subset: optional fraction of the rows to sample before splitting.
        max_len: forwarded to ``Dataset``.
        **kwargs: ignored, so a shared config dict can be unpacked here.
    """

    def __init__(
        self,
        data_file = 'train_labels_tokenized.csv',
        path=Path('data'),
        text_column="InChI_text",
        test_size=0.1,
        random_state=42,
        batch_size=64,
        num_workers=0,
        pin_memory=True,
        shuffle_train=True,
        val_with_train=False,
        train_trans=None,
        val_trans=None,
        subset=None,
        max_len=512,
        **kwargs
    ):
        super().__init__()
        self.data_file = data_file
        self.path = path
        self.test_size=test_size
        self.random_state=random_state
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.shuffle_train = shuffle_train
        self.val_with_train = val_with_train
        self.train_trans = train_trans
        self.val_trans = val_trans
        self.subset = subset
        self.max_len = max_len
        self.text_column = text_column
        self.stoi = {}
        self.itos = {}

    def encode(self, InChI):
        """Map a sequence of tokens to their vocabulary indices."""
        return [self.stoi[token] for token in InChI]

    def decode(self, ixs):
        """Turn token ids back into a string.

        ``ixs`` may hold 0-d tensors or plain ints. PAD and SOS are skipped and
        decoding stops at the first EOS, so anything a sampler produced after
        the end-of-sequence marker never leaks into the decoded text.
        """
        eos = self.stoi['EOS']
        skip = {self.stoi['PAD'], self.stoi['SOS']}
        tokens = []
        for ix in ixs:
            ix = int(ix)
            if ix == eos:
                break
            if ix not in skip:
                tokens.append(self.itos[ix])
        return ''.join(tokens)

    @staticmethod
    def _build_transform(spec):
        """Build an ``A.Compose`` from a ``{transform name: kwargs}`` mapping, or ``None``."""
        if not spec:
            return None
        return A.Compose([getattr(A, name)(**params) for name, params in spec.items()])

    def setup(self, stage=None):
        """Build the vocabulary indices, load the CSV and create both datasets."""
        # build indices
        self.stoi = {token: i for i, token in enumerate(VOCAB)}
        self.itos = {i: token for token, i in self.stoi.items()}
        # read csv file with data
        df = pd.read_csv(self.path / self.data_file)
        if self.subset:
            df = df.sample(int(len(df)*self.subset), random_state=self.random_state)
        # build images paths
        df['image_id'] = df['image_id'].map(lambda x: get_image_path(x, self.path))
        # encode inchis: the 'InChI' column ends up holding lists of token ids
        df['InChI'] = df[self.text_column].map(lambda x: x.split(' ')).map(self.encode)
        # train / val splits
        train, val = train_test_split(df, test_size=self.test_size, random_state=self.random_state, shuffle=True)
        print("Training samples: ", len(train))
        print("Validation samples: ", len(val))
        # datasets
        tokens = (self.stoi['PAD'], self.stoi['SOS'], self.stoi['EOS'])
        self.train_ds = Dataset(
            train.image_id.values, train.InChI.values, self.max_len,
            tokens=tokens, trans=self._build_transform(self.train_trans),
        )
        self.val_ds = Dataset(
            val.image_id.values, val.InChI.values, self.max_len,
            tokens=tokens, trans=self._build_transform(self.val_trans),
        )
        if self.val_with_train:
            self.val_ds = self.train_ds

    def train_dataloader(self):
        return DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=self.shuffle_train,
            pin_memory=self.pin_memory,
            collate_fn=self.train_ds.collate
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=self.pin_memory,
            collate_fn=self.val_ds.collate
        )


# Model

In [None]:
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
import numpy as np
import torch.nn as nn

# https://github.com/jankrepl/mildlyoverfitted/blob/master/github_adventures/vision_transformer/custom.py

class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A convolution whose kernel size and stride both equal ``patch_size`` maps
    every patch to an ``embed_dim`` vector. ``n_patches`` is computed for a
    square ``img_size`` x ``img_size`` input.
    """

    def __init__(self, img_size, patch_size, in_chans, embed_dim):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side ** 2
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # (B, C, H, W) -> (B, embed_dim, H', W') -> (B, embed_dim, H'*W') -> (B, H'*W', embed_dim)
        return self.proj(x).flatten(2).transpose(1, 2)

class Transformer(pl.LightningModule):
    """Image-to-sequence model: patch embeddings feed ``torch.nn.Transformer``.

    ``config`` is stored with ``save_hyperparameters`` and read back through
    ``self.hparams``. The keys used are ``img_size``, ``patch_size``,
    ``embed_dim``, ``nhead``, ``num_encoder_layers``, ``num_decoder_layers``,
    ``dropout``, ``max_len``, ``optimizer``, ``lr`` and, optionally,
    ``scheduler``. The output size comes from the module-level ``VOCAB``.
    """

    def __init__(self, config=None):
        super().__init__()
        self.save_hyperparameters(config)
        self.len_vocab = len(VOCAB)

        # encoder side: single-channel images -> patch tokens + learned positions
        self.patch_embed = PatchEmbedding(self.hparams.img_size, self.hparams.patch_size, 1, self.hparams.embed_dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.n_patches, self.hparams.embed_dim))

        # decoder side: token embedding + learned positions (at most max_len of them)
        self.trg_emb = nn.Embedding(self.len_vocab, self.hparams.embed_dim)
        self.trg_pos_emb = nn.Embedding(self.hparams.max_len, self.hparams.embed_dim)

        dim_feedforward = 4 * self.hparams.embed_dim
        self.transformer = torch.nn.Transformer(
            self.hparams.embed_dim, self.hparams.nhead, self.hparams.num_encoder_layers, self.hparams.num_decoder_layers, dim_feedforward, self.hparams.dropout
        )

        self.l = nn.LayerNorm(self.hparams.embed_dim)
        self.fc = nn.Linear(self.hparams.embed_dim, self.len_vocab)

        self.apply(self._init_weights)

    def forward(self, images, captions):
        """Return logits of shape ``(B, T, len(VOCAB))`` for ``(B, T)`` captions.

        Caption positions equal to 0 (the PAD id) are masked out as attention
        keys, and a causal mask keeps each position from seeing later ones.
        """
        # embed images
        embed_imgs = self.patch_embed(images)
        embed_imgs = embed_imgs + self.pos_embed
        # embed captions
        B, trg_seq_len = captions.shape
        trg_positions = (torch.arange(0, trg_seq_len).expand(B, trg_seq_len).to(self.device))
        embed_trg = self.trg_emb(captions) + self.trg_pos_emb(trg_positions)
        trg_mask = self.transformer.generate_square_subsequent_mask(trg_seq_len).to(self.device)
        tgt_padding_mask = captions == 0 # PAD token !!!
        # transformer (constructed without batch_first, hence the permutes to
        # sequence-first and back)
        y = self.transformer(
            embed_imgs.permute(1,0,2),
            embed_trg.permute(1,0,2),
            tgt_mask=trg_mask,
            tgt_key_padding_mask = tgt_padding_mask
        ).permute(1,0,2)
        # head
        return self.fc(self.l(y))

    # https://github.com/karpathy/minGPT/blob/master/mingpt/model.py

    def _init_weights(self, module):
        """Normal(0, 0.02) weights for Linear/Embedding, zero biases, unit LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def predict(self, images, SOS=1, EOS=2, temp=1., PAD=0):
        """Sample one token sequence per image.

        Generation starts from ``SOS`` and draws each next token from the
        softmax of the last-step logits divided by ``temp``. Once a sequence
        has produced ``EOS`` it only receives ``PAD`` afterwards, so nothing
        sampled past the end marker ends up in the result. The loop stops when
        every sequence has finished or the length reaches ``hparams.max_len``.

        Returns:
            A ``(B, L)`` long tensor of token ids that starts with ``SOS``.
        """
        self.eval()
        with torch.no_grad():
            images = images.to(self.device)
            B = images.shape[0]
            # start of sentence
            trg_input = torch.tensor([SOS], dtype=torch.long, device=self.device).expand(B, 1)
            finished = torch.zeros(B, dtype=torch.bool, device=self.device)
            while True:
                # get latest prediction
                logits = self(images, trg_input)[:,-1,:] / temp
                probs = F.softmax(logits, dim=-1)
                # sample
                pred = torch.multinomial(probs, num_samples=1)
                # sequences that already ended are only padded from here on
                pred = pred.masked_fill(finished.unsqueeze(1), PAD)
                # add new prediction
                trg_input = torch.cat([trg_input, pred], 1)
                finished = finished | (pred.squeeze(1) == EOS)
                if finished.all() or trg_input.shape[1] >= self.hparams.max_len:
                    return trg_input

    def _compute_loss(self, batch):
        """Teacher-forced cross entropy: feed ``y[:, :-1]``, predict ``y[:, 1:]``."""
        x, y = batch
        y_hat = self(x, y[:,:-1])
        # cross_entropy expects the class dimension second: (B, C, T)
        return F.cross_entropy(y_hat.transpose(1,2), y[:,1:])

    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log('loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log('val_loss', loss, prog_bar=True)

    def configure_optimizers(self):
        """Create the optimizer named in hparams, plus any configured schedulers."""
        optimizer = getattr(torch.optim, self.hparams.optimizer)(self.parameters(), lr=self.hparams.lr)
        if 'scheduler' in self.hparams:
            schedulers = [
                getattr(torch.optim.lr_scheduler, scheduler)(optimizer, **params)
                for scheduler, params in self.hparams.scheduler.items()
            ]
            return [optimizer], schedulers
        return optimizer

# Train

In [None]:
# One dict drives the whole run: it is unpacked into DataModule (which ignores
# the keys it does not declare), stored as the Transformer's hyperparameters,
# and indexed for the Trainer arguments.
config = {
    # optimization
    'lr': 0.001,
    'optimizer': "Adam",  # attribute name looked up on torch.optim
    'batch_size': 64,
    'gradient_clip_val': 1.0,
    # data loading
    'num_workers': 4,
    'pin_memory': True,
    'subset': 0.1,  # fraction of the CSV rows sampled in DataModule.setup
    # model
    'img_size': 128,  # must agree with the Resize transforms below: the positional embedding is sized from img_size // patch_size
    'patch_size': 16,
    'embed_dim': 64,
    'nhead': 1,
    'num_encoder_layers': 1,
    'num_decoder_layers': 1,
    'dropout': 0.,
    'max_len': 277,  # size of the target positional-embedding table; also caps the generated length in predict()
    # albumentations transforms given as {transform name: kwargs}
    'train_trans': {
      'Resize': {
        'width': 128,
        'height': 128,
      }
    },
    'val_trans': {
      'Resize': {
        'width': 128,
        'height': 128,
      }
    },
    # trainer
    'gpus': 1,
    'precision': 16,
    'max_epochs': 10
}


In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep only the single best checkpoint, ranked by the lowest 'val_loss'.
# The doubled braces leave a literal '{val_loss:.4f}' placeholder in the
# filename template.
checkpoint = ModelCheckpoint(
    dirpath='./', 
    filename=f'transformer-{{val_loss:.4f}}',
    save_top_k=1, 
    monitor='val_loss', 
    mode='min'
)

In [None]:
# Data: the tokenized labels were written to the working directory, the images
# live in the competition input folder. Because data_file is absolute, joining
# it onto `path` inside setup() yields data_file itself. Keys of `config` that
# DataModule does not declare are absorbed by its **kwargs.
dm = DataModule(
    data_file = '/kaggle/working/train_labels_tokenized.csv', 
    path=Path('/kaggle/input/bms-molecular-translation'), 
    **config
)

# The model stores the whole config as its hyperparameters.
model = Transformer(config)

trainer = pl.Trainer(
    gpus=config['gpus'],
    precision=config['precision'],
    max_epochs=config['max_epochs'],
    gradient_clip_val=config['gradient_clip_val'],
    callbacks=[checkpoint]
)

trainer.fit(model, dm)

# Predictions

Compute metric on validation set.

In [None]:
from tqdm import tqdm

# Sample a prediction for every validation batch and keep the matching labels.
preds, labels = [], []
model.cuda()
for imgs, labs in tqdm(dm.val_dataloader()):
    # Move the result off the GPU: the accumulated rows would otherwise keep
    # every batch's output tensor alive in device memory.
    outputs = model.predict(imgs).cpu()
    preds += outputs
    labels += labs.tolist()

len(preds)

In [None]:
!pip install python-Levenshtein

In [None]:
# Decode the sampled token ids back into InChI strings.
preds_decoded = [dm.decode(pred) for pred in preds]
preds_inchis = ['InChI=1S/' + pred for pred in preds_decoded]

# Decode the labels of ALL validation batches (`labels`). `labs` only holds the
# last batch of the loop, so decoding it would score the first batch of
# predictions against unrelated references. `labels` stores plain lists, hence
# the conversion to a tensor before decoding.
labs_decoded = [dm.decode(torch.as_tensor(lab)) for lab in labels]
inchis = ['InChI=1S/' + lab for lab in labs_decoded]

In [None]:
from Levenshtein import distance

# Levenshtein distance between every predicted InChI and its reference.
metric = [distance(pred, inchi) for pred, inchi in zip(preds_inchis, inchis)]

np.mean(metric)

Better than the sample submission !!!

In [None]:
# The sample submission lists the test image ids; its 'InChI' column is filled in below.
sample_submission = pd.read_csv(path / 'sample_submission.csv')
sample_submission

Generate predictions for the first 25% of the test images (are these the ones used for the leaderboard, or is that a random sample?)

In [None]:
# Number of leading rows of the sample submission to predict (25% of it).
limit = int(0.25*len(sample_submission))
limit

In [None]:
# Paths of the first `limit` test images.
test_images = sample_submission.image_id[:limit]
test_images = test_images.apply(lambda i: get_image_path(i, path, mode="test"))

# Resize to the resolution the model was built for (taken from the config
# rather than hard-coded), and hand the Dataset a plain array as the
# DataModule does, so samples are looked up by position rather than by label.
ds = Dataset(
    test_images.values,
    train=False,
    trans=A.Compose([A.Resize(config['img_size'], config['img_size'])]),
)

dl = torch.utils.data.DataLoader(ds, batch_size=100, num_workers=4, pin_memory=True, shuffle=False)

In [None]:
# this step takes so long with big models and all test images ...

preds = []
model.cuda()
for batch in tqdm(dl):
    # Move the result off the GPU: the accumulated rows would otherwise keep
    # every batch's output tensor alive in device memory.
    outputs = model.predict(batch).cpu()
    preds += outputs

In [None]:
preds_decoded = [dm.decode(pred) for pred in preds]
# Write through .iloc on the frame itself. The chained form
# `sample_submission.InChI[:limit] = ...` assigns into an intermediate Series,
# which pandas warns about and which is not guaranteed to update the frame.
inchi_col = sample_submission.columns.get_loc('InChI')
sample_submission.iloc[:limit, inchi_col] = ['InChI=1S/'+pred for pred in preds_decoded]
sample_submission

In [None]:
# Rows beyond `limit` still carry the InChI values read from the sample submission.
sample_submission.to_csv('submission.csv', index=False)

# Where to go from here ?


Some ideas:

- Train more epochs
- Better prediction method (Beam Search?)
- More layers
- Data augmentation
- Higher image resolution
- Cross Validation and model ensembling
- Train with all the dataset
- Generate predictions for all test set

For me, the biggest blocker is generating predictions reliably and efficiently. Advice is welcome :)

Multi-gpu or even TPU training should be relatively straightforward thanks to Pytorch Lightning.