# CNN to GRU with Attention and Teacher Forcing

This is a baseline image-captioning model that feeds a CNN encoder into a GRU decoder. The GRU decoder attends over the CNN activations at each decoding step. Teacher forcing is used during training to speed up convergence.

To avoid a long notebook run time, by default this notebook neither trains on the full training dataset nor predicts on every item in the test dataset. Where this happens, I've left the full-dataset version in place, commented out.

Running this notebook on the entire dataset (four training epochs) took about 30 hours on a 2080 Ti GPU. The resulting submission scored 55.6 on the public leaderboard (mean Levenshtein distance, so lower is better).

This notebook draws on [this example](https://github.com/fastai/course-nlp/blob/master/7b-seq2seq-attention-translation.ipynb).

Note that installing RDKit on a GPU-accelerated instance usually takes around 15 minutes due to conda install issues.

In [None]:
!conda install -y -c rdkit rdkit

In [None]:
import pandas as pd
import numpy as np
from multiprocessing import Pool
import os
from skimage import color
import matplotlib.pyplot as plt
from PIL import Image
import skimage
import math
import random
from functools import partial
import gc

from rdkit import Chem
from rdkit.Chem import Draw
from rdkit import RDLogger

RDLogger.DisableLog('rdApp.*') 

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms import Pad, Resize
from torchvision.models import resnet34


In [None]:
from fastai import *
from fastai.data.core import DataLoaders
from fastai.learner import *
from fastai.callback.all import *
from fastai.losses import *
from fastai.metrics import *

In [None]:
def smile_to_mol(smile):
    return Chem.MolFromSmiles(smile)

def inchi_to_mol(inchi):
    return Chem.inchi.MolFromInchi(inchi)

def mol_to_smile(mol):
    return Chem.MolToSmiles(mol)

def mol_to_inchi(mol):
    return Chem.inchi.MolToInchi(mol)

In [None]:
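def pad_square(image):
    # image is a (C, H, W) tensor; pad the shorter side so that H and W match
    h, w = image.shape[-2:]
    max_hw = max(h, w)
    pad_lr = (max_hw - w) // 2  # left/right padding
    pad_tb = (max_hw - h) // 2  # top/bottom padding
    # a 2-tuple means (left/right, top/bottom); 'edge' repeats the border pixels.
    # For odd differences the result is one pixel off square; the later Resize fixes that.
    return Pad((pad_lr, pad_tb), padding_mode='edge')(image)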

In [None]:
class ImageDataset(Dataset):
    def __init__(self, filenames, file_prefix, y_vals, itos, stoi, size, return_inchi=True):
        self.filenames = filenames
        self.file_prefix = file_prefix
        self.y_vals = y_vals
        self.itos = itos
        self.stoi = stoi
        self.size = size
        self.return_inchi = return_inchi
        self.resize = Resize((size, size))
        
    def __len__(self):
        return len(self.filenames)
    
    def __getitem__(self, idx):
        
        filename = self.filenames[idx]

        # images live under {prefix}/{id[0]}/{id[1]}/{id[2]}/{id}.png
        path = f'{self.file_prefix}/{filename[0]}/{filename[1]}/{filename[2]}/{filename}.png'
        # convert('L') gives 8-bit grayscale (0-255), so dividing by 255 maps pixels to [0, 1]
        image = np.array(Image.open(path).convert('L'))
        # (H, W) -> (3, H, W) so the image fits a 3-channel ResNet backbone
        image = torch.FloatTensor(image)[None, :, :].repeat((3, 1, 1)) / 255.
        
        if image.shape[-1] != image.shape[-2]:
            image = pad_square(image)
            
        image = self.resize(image)
            
        output = self.y_vals[idx]
        is_inchi = 'InChI' in output

        # convert the label with RDKit only when its format differs from the requested one
        if is_inchi == self.return_inchi:
            out_string = output
        elif is_inchi:
            out_string = mol_to_smile(inchi_to_mol(output))
        else:
            out_string = mol_to_inchi(smile_to_mol(output))

        # character-level token ids wrapped in bos/eos
        out_ints = [self.stoi['bos']] + [self.stoi[c] for c in out_string] + [self.stoi['eos']]

        return image, out_ints

## Data Setup

One thing I've experimented with is generating InChI strings directly versus generating a SMILES string (a shorter sequence) that is then converted to InChI. The SMILES approach seems to produce more valid structures (predictions that resolve to a real compound), but predicting InChI gives better performance overall, because an invalid SMILES string is effectively lost: it can't be converted to an InChI at all.

The `ImageDataset` dataset will return InChI strings by default, but will return SMILES strings if `return_inchi=False` is passed.

I also removed compounds whose InChI string is 250 characters or longer. These long-tail sequences can sneak up on you and cause a CUDA out-of-memory error.
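`ImageDataset` tokenizes each label one character at a time and wraps it in `bos`/`eos`, and the length filter below drops any InChI of 250 or more characters. The sketch reproduces both steps in plain Python so they can be checked without loading images. `encode`, `decode` and `keep` are illustrative helper names, not functions from the notebook; `decode` mirrors the post-processing done in the Prediction section (drop the leading `bos`, cut at the first `eos` or `pad`).

```python
# Same special tokens and InChI characters as the notebook's `itos`
itos = ['bos', 'pad', 'eos',
        '=', '/', '+',
        '6', 'O', 'C', '5', 'D', ',', 'c',
        '7', '2', 'b', 'i', '0', 'B', '1',
        '8', 'h', ')', 'n', '9', '-', 'S',
        '(', '4', 'H', 'm', '3', 'F', 't',
        'P', 'l', 'N', 'T', 's', 'r', 'I']
stoi = {tok: idx for idx, tok in enumerate(itos)}

MAX_CHARS = 250  # labels with len >= MAX_CHARS are dropped, as in `df[lens<250]`

def encode(s):
    """Character-level token ids wrapped in bos/eos (what ImageDataset returns)."""
    return [stoi['bos']] + [stoi[ch] for ch in s] + [stoi['eos']]

def decode(ids):
    """Drop the leading bos and stop at the first eos or pad token."""
    out = []
    for i in ids[1:]:
        tok = itos[i]
        if tok in ('eos', 'pad'):
            break
        out.append(tok)
    return ''.join(out)

def keep(s):
    return len(s) < MAX_CHARS

example = 'InChI=1S/H2O/h1H2'
ids = encode(example)
print(len(example), len(ids))                       # 17 19
print(decode(ids + [stoi['pad']] * 3) == example)   # True: padding added by collate is ignored
```

Every encoded label is two tokens longer than its string, which is why a 250-character cutoff bounds the decoder length at roughly 250 steps.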

In [None]:
# smiles itos
# itos = ['bos', 'pad', 'eos',
#  'N', '1', '(', 'P', '8', 'S',
#  'H', ']', 'B', '#', '=', ')', 
#  's', '-', 'r', '7', '4',
#  '3', '[', 'c', '6', 'n', '2',
#  'i', 'o', 'O', '@', 'F', '/',
#  'I', 'l', '\\', '5', '+', 'C',
#  '%', '.', '0', '9', 'b', 'p']

# inchi itos
itos = ['bos', 'pad', 'eos',
 '=', '/', '+',
 '6', 'O', 'C', '5', 'D', ',', 'c',
 '7', '2', 'b', 'i', '0', 'B', '1',
 '8', 'h', ')', 'n', '9', '-', 'S',
 '(', '4', 'H', 'm', '3', 'F', 't',
 'P', 'l', 'N', 'T', 's', 'r', 'I']
stoi = {itos[i]:i for i in range(len(itos))}

In [None]:
# loading just the first 10000 rows, uncomment full load below for using the entire dataset
df = next(pd.read_csv('../input/bms-molecular-translation/train_labels.csv', chunksize=10000))

# df = pd.read_csv('../input/bms-molecular-translation/train_labels.csv')

In [None]:
lens = df.InChI.str.len()
df = df[lens<250]
df.shape

In [None]:
cut = int(0.97*len(df))
df_train = df[:cut]
df_valid = df[cut:]

In [None]:
train_prefix = '../input/bms-molecular-translation/train'

# Note image size is hard-coded to 256x256
train_data = ImageDataset(df_train.image_id.values, train_prefix, df_train.InChI.values,
                          itos, stoi, 256, return_inchi=True)

valid_data = ImageDataset(df_valid.image_id.values, train_prefix, df_valid.InChI.values,
                          itos, stoi, 256, return_inchi=True)

In [None]:
plt.imshow(train_data[0][0].permute(1,2,0))

In [None]:
def collate_function(batch, pad=1):
    
    images = torch.stack([i[0] for i in batch])
    
    max_len = max([len(i[1]) for i in batch])
    res_y = torch.zeros(len(batch), max_len).long() + pad
    
    for i, s in enumerate(batch):
        res_y[i,:len(s[1])] = torch.LongTensor(s[1])
    
    return images, res_y

In [None]:
train_dl = DataLoader(train_data, batch_size=120, collate_fn=collate_function, shuffle=True, num_workers=4)
valid_dl = DataLoader(valid_data, batch_size=120, collate_fn=collate_function, shuffle=False, num_workers=4)

In [None]:
dls = DataLoaders(train_dl, valid_dl)

In [None]:
x,y = next(iter(dls.loaders[0]))

In [None]:
x.shape, y.shape

In [None]:
x.device

## Model Architecture

The model first maps each input image to a grid of activations using the `image_encoder`, which in this notebook is a ResNet-34 with its final pooling and classification layers removed. The model then decodes the output sequence one token at a time using a GRU. At each decoding step, the model computes attention over the image activations.
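The attention step inside `ImageCaption.decoder` is additive (Bahdanau-style) attention: the projected image activations and the projected top-layer hidden state are added, passed through `tanh`, scored against a learned vector `V`, and softmaxed over the image positions. The NumPy sketch below mirrors that computation with random arrays standing in for the learned projections; the shapes assume a 256x256 input, for which ResNet-34 yields an 8x8 grid (64 positions).

```python
import numpy as np

def attention_step(enc_att, enc_out, hid_att, V):
    """One additive-attention step.

    enc_att: (bs, sl, d)  projected image activations
    enc_out: (bs, sl, c)  activations averaged into the context vector
    hid_att: (bs, d)      projection of the decoder's top hidden state
    V:       (d,)         learned scoring vector
    """
    u = np.tanh(enc_att + hid_att[:, None, :])            # (bs, sl, d)
    scores = u @ V                                         # (bs, sl)
    scores = scores - scores.max(axis=1, keepdims=True)    # stabilise the softmax
    wgts = np.exp(scores)
    wgts = wgts / wgts.sum(axis=1, keepdims=True)          # softmax over the sl image positions
    ctx = (wgts[..., None] * enc_out).sum(axis=1)          # (bs, c) weighted sum of activations
    return wgts, ctx

rng = np.random.default_rng(0)
bs, sl, d, c = 2, 64, 128, 512
enc_att = rng.normal(size=(bs, sl, d))
enc_out = rng.normal(size=(bs, sl, c))
hid_att = rng.normal(size=(bs, d))
V = rng.normal(size=d)

wgts, ctx = attention_step(enc_att, enc_out, hid_att, V)
print(wgts.shape, ctx.shape)  # (2, 64) (2, 512)
```

The context vector is then concatenated with the token embedding to form the GRU input, which is why `gru_dec` takes `emb_sz_dec + 2*nh` input features.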

This model also has a teacher forcing parameter, `pr_force`. At each decoding step, with probability `pr_force`, the model is fed the ground-truth token instead of its own previous prediction. This helps speed up convergence early in training. `TeacherForcingCallback` holds `pr_force=1` for the first `start_batch` training iterations and then decays it linearly to zero by `end_batch`.
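The schedule and the per-step choice can be checked in isolation. The plain-Python sketch below restates the logic of `TeacherForcingCallback.before_batch` and of the teacher-forcing branch in `ImageCaption.forward`; the function names are hypothetical. Note that the coin flip happens once per decoding step and applies to the whole batch, not to each example separately.

```python
import random

def teacher_forcing_prob(n_iter, start_batch, end_batch):
    """pr_force: 1 before start_batch, linear decay to 0 at end_batch, 0 afterwards."""
    if n_iter < start_batch:
        return 1.0
    if n_iter > end_batch:
        return 0.0
    return 1 - (n_iter - start_batch) / (end_batch - start_batch)

def next_decoder_input(prediction, target_token, pr_force, rng=random):
    """With probability pr_force feed the ground truth, else the model's own argmax."""
    return target_token if rng.random() < pr_force else prediction

for it in (0, 4000, 14500, 25000, 30000):
    print(it, teacher_forcing_prob(it, 4000, 25000))
# 0 1.0
# 4000 1.0
# 14500 0.5
# 25000 0.0
# 30000 0.0
```

With `start_batch=4000, end_batch=25000` as used below, the model is fully teacher-forced for 4000 batches and runs entirely on its own predictions after 25000.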

In [None]:
class ImageCaption(nn.Module):
    def __init__(self, image_encoder, d_enc_out, nh, emb_sz_dec, voc_sz_dec, out_sl, nl=2, bos_idx=0, pad_idx=1):
        super().__init__()
        
        self.nl = nl
        self.nh = nh
        self.bos_idx = bos_idx
        self.pad_idx = pad_idx
        self.out_sl = out_sl              # max decoding length when no target is given
        self.emb_sz_dec = emb_sz_dec
        self.voc_sz_dec = voc_sz_dec
        self.pr_force = 0.                # teacher forcing probability, set by TeacherForcingCallback
        
        self.encoder = image_encoder
        # initial GRU hidden state is computed from the mean image activation
        self.init_hidden = nn.Linear(d_enc_out, nh*nl*2)
        self.out_enc = nn.Linear(2*nh, self.emb_sz_dec, bias=False)
        
        self.emb_dec = nn.Embedding(voc_sz_dec, emb_sz_dec)
        self.gru_dec = nn.GRU(self.emb_sz_dec + 2*nh, self.emb_sz_dec, num_layers=nl,
                              dropout=0.1, batch_first=True)
        self.out_drop = nn.Dropout(0.35)
        self.out = nn.Linear(self.emb_sz_dec, self.voc_sz_dec)
        # tie output and embedding weights: one shared Parameter, one gradient, one optimizer state
        self.out.weight = self.emb_dec.weight
        
        # additive attention: score = V . tanh(enc_att(enc) + hid_att(hid))
        self.enc_projection = nn.Linear(d_enc_out, nh*2)
        self.enc_att = nn.Linear(nh*2, self.emb_sz_dec, bias=False)
        self.hid_att = nn.Linear(self.emb_sz_dec, self.emb_sz_dec)
        self.V = self.init_param(self.emb_sz_dec)
        
    def decoder(self, dec_inp, hid, enc_att, enc_out):
        # dec_inp: (bs,), hid: (nl, bs, emb_sz_dec), enc_att: (bs, sl, emb_sz_dec), enc_out: (bs, sl, 2*nh)
        hid_att = self.hid_att(hid[-1])
        u = torch.tanh(enc_att + hid_att[:,None])
        attn_wgts = F.softmax(u @ self.V, 1)             # (bs, sl), softmax over image positions
        ctx = (attn_wgts[...,None] * enc_out).sum(1)     # (bs, 2*nh)
        emb = self.emb_dec(dec_inp)
        outp, hid = self.gru_dec(torch.cat([emb, ctx], 1)[:,None], hid)
        outp = self.out(self.out_drop(outp[:,0]))        # (bs, voc_sz_dec) logits
        return hid, outp
        
    def forward(self, x, targ=None):
        
        enc_out = self.encoder(x)                        # (bs, sl, d_enc_out)
        bs, sl, _ = enc_out.shape
        
        out_sl = self.out_sl if targ is None else targ.shape[1]
        
        mean_encoder_out = enc_out.mean(dim=1)
        hid = self.init_hidden(mean_encoder_out)         # (bs, nl*2*nh)
        # split per layer without mixing batch elements: (bs, nl, 2*nh) -> (nl, bs, 2*nh)
        hid = hid.view(bs, self.nl, 2*self.nh).transpose(0, 1).contiguous()
        hid = self.out_enc(hid)                          # (nl, bs, emb_sz_dec)

        enc_out = self.enc_projection(enc_out)
        enc_att = self.enc_att(enc_out)
        
        dec_inp = x.new_full((bs,), self.bos_idx, dtype=torch.long)
        res = []
        
        for i in range(out_sl):
            hid, outp = self.decoder(dec_inp, hid, enc_att, enc_out)
            res.append(outp)
            dec_inp = outp.argmax(1)
            if (dec_inp==self.pad_idx).all(): break
            # teacher forcing: one coin flip per step, shared across the batch
            if (targ is not None) and (random.random()<self.pr_force):
                dec_inp = targ[:,i]
        return torch.stack(res, dim=1)
            
    def init_param(self, *sz): return nn.Parameter(torch.randn(sz)/math.sqrt(sz[0]))

In [None]:
class ImageEncoder(nn.Module):
    # wrapper for torchvision model
    def __init__(self, image_encoder):
        super().__init__()
        
        modules = list(image_encoder.children())[:-2]
        self.image_encoder = nn.Sequential(*modules)
        
    def forward(self, x):
        x = self.image_encoder(x) # (bs, ch, h, w)
        x = x.permute(0, 2, 3, 1) # (bs, h, w, ch)
        x = x.view(x.size(0), -1, x.size(-1)) # (bs, h*w, ch)

        return x

In [None]:
image_encoder = ImageEncoder(resnet34())

In [None]:
d_enc_out = 512
nh = 256
emb_sz_dec = 128
voc_sz_dec = len(itos)
out_sl = 220

ic = ImageCaption(image_encoder, d_enc_out, nh, emb_sz_dec, voc_sz_dec, out_sl)

In [None]:
def seq2seq_loss(out, targ, pad_idx=1):
    bs,targ_len = targ.size()
    _,out_len,vs = out.size()
    if targ_len>out_len: out  = F.pad(out,  (0,0,0,targ_len-out_len,0,0), value=pad_idx)
    if out_len>targ_len: targ = F.pad(targ, (0,out_len-targ_len,0,0), value=pad_idx)
    return CrossEntropyLossFlat()(out, targ)

def seq2seq_acc(out, targ, pad_idx=1):
    bs,targ_len = targ.size()
    _,out_len,vs = out.size()
    if targ_len>out_len: out  = F.pad(out,  (0,0,0,targ_len-out_len,0,0), value=pad_idx)
    if out_len>targ_len: targ = F.pad(targ, (0,out_len-targ_len,0,0), value=pad_idx)
    out = out.argmax(2)
    return (out==targ).float().mean()


In [None]:
class TeacherForcingCallback(Callback):
    # teacher forcing callback
    def __init__(self, start_batch, end_batch):
        self.start_batch = start_batch
        self.end_batch = end_batch

    def before_batch(self):
        
        n_iter = self.train_iter
        
        if n_iter < self.start_batch:
            self.learn.model.pr_force = 1.
            
        elif n_iter > self.end_batch:
            self.learn.model.pr_force = 0.
            
        else:
            self.learn.model.pr_force = 1 - (n_iter - self.start_batch)/(self.end_batch - self.start_batch)
        
        if self.training:
            x,y = self.x, self.y
            self.learn.xb = (x,y)

In [None]:
learn = Learner(dls, ic, loss_func=seq2seq_loss, cbs=[CudaCallback, 
                                                      TeacherForcingCallback(start_batch=4000, end_batch=25000)], 
                metrics=[seq2seq_acc, CorpusBLEUMetric(len(itos))])

## Training

In [None]:
learn.lr_find()

In [None]:
learn.fit_one_cycle(4, 3e-3)

## Prediction

In [None]:
test_df = pd.read_csv('../input/bms-molecular-translation/sample_submission.csv')

In [None]:
test_df.head()

In [None]:
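test_prefix = '../input/bms-molecular-translation/test'

test_df = test_df[:1000] # predict on the first 1000 images; delete this line to predict on the full test set

# the sample submission's InChI column is only a placeholder, so keep it as InChI
# (return_inchi=True) instead of converting every label to SMILES with RDKit
test_data = ImageDataset(test_df.image_id.values, test_prefix, test_df.InChI.values,
                          itos, stoi, 256, return_inchi=True)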

In [None]:
test_dl = DataLoader(test_data, batch_size=256, collate_fn=collate_function, shuffle=False)

In [None]:
learn.model.eval();

In [None]:
preds_list = []

with torch.no_grad():
    for i, batch in enumerate(test_dl):
        if i%500 == 0:
            print(i)
        
        x,y = batch
        preds = learn.model(x.cuda())
        preds = preds.argmax(-1)  # softmax is monotonic, so the argmax of the logits is the same token
        preds_list.append(preds.detach().cpu())

In [None]:
pred_strs = []

for k, pred in enumerate(preds_list):
    if k%1000 == 0:
        print(k)
    
    gc.collect()
    for p in pred.tolist():
        # drop the leading bos, then cut at the first eos or pad token
        pred_str = [itos[i] for i in p][1:]
        
        if 'eos' in pred_str:
            pred_str = pred_str[:pred_str.index('eos')]
            
        if 'pad' in pred_str:
            pred_str = pred_str[:pred_str.index('pad')]
        
        pred_strs.append(''.join(pred_str))

In [None]:
test_df['preds'] = pred_strs

In [None]:
submission = test_df[['image_id', 'preds']]
submission.columns = ['image_id', 'InChI']

In [None]:
submission.head()

In [None]:
submission.to_csv('submission.csv', index=False)

## Suggested Improvements

This notebook is fairly basic. There are a lot of simple improvements to be made from here.

### Data

No data augmentation is used. Augmentation could be added, along with synthetic data from other sources (e.g. using RDKit to render images of molecules from other datasets). The choice between predicting InChI directly and predicting SMILES can also be revisited.
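As one possible starting point, here is a hypothetical augmentation for a binary structure image: a random rotation by a multiple of 90 degrees (the depicted molecule, and therefore its InChI label, is unchanged) plus a small fraction of flipped pixels to mimic scanning noise. The function name, the noise rate and the choice of transforms are assumptions for illustration; in `ImageDataset.__getitem__` it would sit before `pad_square` and the resize.

```python
import numpy as np

def augment(img, rng, flip_prob=0.001):
    """Hypothetical augmentation for a binary (0/1) image array.

    - rotate by a random multiple of 90 degrees (shape swaps for odd multiples)
    - flip each pixel independently with probability flip_prob (salt-and-pepper noise)
    """
    img = np.rot90(img, k=rng.integers(4))
    noise = rng.random(img.shape) < flip_prob
    return np.where(noise, 1 - img, img)

rng = np.random.default_rng(0)
img = (rng.random((300, 200)) < 0.05).astype(np.uint8)  # stand-in for a loaded structure image
aug = augment(img, rng)
print(aug.shape)  # (300, 200) or (200, 300), depending on the rotation drawn
```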

### Model

The current setup is hard-coded to work with 256x256 images. This can be changed by adding an adaptive pooling layer after the CNN encoder. The decoder is actually fairly small; most of the parameters are in the image encoder, so there is room to scale up the decoder. It would also be good to investigate larger ResNet backbones.
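In PyTorch the layer in question is `nn.AdaptiveAvgPool2d`, which could go between the truncated ResNet and the flatten step in `ImageEncoder`. The NumPy sketch below only illustrates what it computes: any `(channels, h, w)` feature map is averaged down to a fixed `(out_h, out_w)` grid, using the same bin boundaries PyTorch uses (bins may overlap when the sizes don't divide evenly).

```python
import numpy as np

def adaptive_avg_pool2d(x, out_h, out_w):
    """Average-pool a (channels, h, w) feature map to (channels, out_h, out_w) for any h, w.

    Output bin i along an axis of length n covers indices floor(i*n/out) up to
    ceil((i+1)*n/out), the bin boundaries used by torch.nn.AdaptiveAvgPool2d.
    """
    c, h, w = x.shape
    out = np.empty((c, out_h, out_w), dtype=float)
    for i in range(out_h):
        h0, h1 = (i * h) // out_h, ((i + 1) * h + out_h - 1) // out_h
        for j in range(out_w):
            w0, w1 = (j * w) // out_w, ((j + 1) * w + out_w - 1) // out_w
            out[:, i, j] = x[:, h0:h1, w0:w1].mean(axis=(1, 2))
    return out

# ResNet-34 gives an 8x8 grid for 256x256 input and a 10x10 grid for 320x320 input;
# pooling both to 8x8 keeps the encoder output at 64 positions of 512 channels.
for grid in (8, 10):
    feats = np.random.default_rng(0).normal(size=(512, grid, grid))
    print(grid, adaptive_avg_pool2d(feats, 8, 8).shape)  # (512, 8, 8) in both cases
```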

### Metrics

I used BLEU and accuracy to evaluate the model. A better metric would be the competition metric itself: the Levenshtein distance between predicted and true InChI strings (converting SMILES to InChI first if SMILES are being predicted).
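The competition scores submissions by the mean Levenshtein distance between predicted and true InChI strings. A minimal pure-Python sketch of that metric, suitable for scoring decoded validation predictions, is below; `mean_levenshtein` is a hypothetical helper name.

```python
def levenshtein(a, b):
    """Minimum number of single-character insertions, deletions and substitutions turning a into b."""
    if len(a) < len(b):
        a, b = b, a                      # distance is symmetric; keep the DP row short
    prev = list(range(len(b) + 1))       # distances from the empty prefix of a
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(prev[j] + 1,                # deletion
                            curr[j - 1] + 1,            # insertion
                            prev[j - 1] + (ca != cb)))  # substitution (free if equal)
        prev = curr
    return prev[-1]

def mean_levenshtein(preds, targets):
    """Average distance over paired predictions and ground-truth strings."""
    return sum(levenshtein(p, t) for p, t in zip(preds, targets)) / len(targets)

print(levenshtein('kitten', 'sitting'))  # 3
```

For the full validation set, a compiled implementation such as `Levenshtein.distance` from the `Levenshtein` package would be much faster than this O(len(a) * len(b)) Python loop.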

### Inference

Inference shown here is extremely simple: greedy decoding that takes the argmax at each step. This could be improved by adding beam search decoding and test-time augmentation (TTA) of the input images. If you wanted to get really fancy, you could use TTA plus beam search to generate several predictions, then find the consensus string among all of them.
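A generic beam search keeps the `beam_width` highest-scoring partial sequences (by summed log-probability) at every step instead of only the single argmax. The sketch below works against a hypothetical `step_fn(prefix) -> {token: log_prob}`; the toy probabilities are made up to show a case where greedy decoding (`beam_width=1`) picks a worse sequence than a beam of two.

```python
import math

def beam_search(step_fn, bos, eos, beam_width=3, max_len=10):
    """Return the highest-scoring (tokens, log_prob) pair found by beam search."""
    beams = [([bos], 0.0)]          # live hypotheses: (tokens, summed log-prob)
    finished = []                   # hypotheses that have emitted eos
    for _ in range(max_len):
        candidates = [(prefix + [tok], score + logp)
                      for prefix, score in beams
                      for tok, logp in step_fn(prefix).items()]
        # keep the beam_width best extensions across all live hypotheses
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for prefix, score in candidates[:beam_width]:
            (finished if prefix[-1] == eos else beams).append((prefix, score))
        if not beams:
            break
    finished.extend(beams)          # hypotheses cut off by max_len still compete
    return max(finished, key=lambda c: c[1])

# Toy next-token probabilities (made up for illustration)
table = {
    ('<s>',): {'A': 0.6, 'B': 0.4},
    ('<s>', 'A'): {'</s>': 0.55, 'x': 0.45},
    ('<s>', 'A', 'x'): {'</s>': 1.0},
    ('<s>', 'B'): {'</s>': 1.0},
}

def toy_step(prefix):
    return {tok: math.log(p) for tok, p in table[tuple(prefix)].items()}

greedy = beam_search(toy_step, '<s>', '</s>', beam_width=1)
beam = beam_search(toy_step, '<s>', '</s>', beam_width=2)
print(greedy[0], beam[0])  # ['<s>', 'A', '</s>'] ['<s>', 'B', '</s>']
```

Greedy commits to `A` (0.6) and ends with probability 0.6 * 0.55 = 0.33, while the beam also keeps `B` and finds the 0.4 sequence. Applied to `ImageCaption`, `step_fn` would have to run `decoder` once per live hypothesis and carry that hypothesis's GRU hidden state alongside its tokens, which this sketch leaves out. Length normalization is also omitted, so shorter sequences are slightly favored.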