# About This Notebook

This notebook briefly explores the allowed external InChIs and shows how to use `rdkit` to generate additional training images from them. It is similar to [this notebook](https://www.kaggle.com/stainsby/improved-synthetic-data-for-bms-competition-v3), except that I try to generate images that more closely resemble the provided training images, and I use only the approved InChIs.

Ways to use this new data:
* Pretrain a model with it and finetune on given training data
* Concatenate with given training data and train model
* Upsample existing data

# Libraries

In [None]:
# Install RDKit from the `rdkit` conda channel (IPython shell escape; notebook only).
!conda install -y -c rdkit rdkit

In [None]:
import numpy as np
import pandas as pd
import re
from tqdm.auto import tqdm
tqdm.pandas()

import torch
from torch.utils.data import Dataset
import cv2

from albumentations import (
    Compose, OneOf, Normalize, Resize, HorizontalFlip, VerticalFlip, Rotate, RandomRotate90, CenterCrop
    )

from albumentations.pytorch import ToTensorV2

import rdkit.Chem as Chem

from matplotlib import pyplot as plt

# Helper Functions

The Tokenizer class and the preprocessing code are taken from [@yasufuminakama](https://www.kaggle.com/yasufuminakama)'s notebook [here](https://www.kaggle.com/yasufuminakama/inchi-preprocess-2).

In [None]:
class Tokenizer(object):
    """Map space-separated InChI tokens to integer ids and back.

    The vocabulary is built by :meth:`fit_on_texts`: the sorted set of regular
    tokens, followed by the special tokens ``<sos>``, ``<eos>`` and ``<pad>``
    (so the specials always get the three highest ids).
    """

    def __init__(self):
        self.stoi = {}  # token -> id
        self.itos = {}  # id -> token

    def __len__(self):
        return len(self.stoi)

    def fit_on_texts(self, texts):
        """Build the vocabulary from an iterable of space-separated strings.

        Any previously fitted vocabulary is discarded, so refitting never
        leaves stale tokens behind whose ids collide with the new ones.
        """
        vocab = sorted({token for text in texts for token in text.split(' ')})
        vocab += ['<sos>', '<eos>', '<pad>']
        self.stoi = {s: i for i, s in enumerate(vocab)}
        self.itos = {i: s for s, i in self.stoi.items()}

    def text_to_sequence(self, text):
        """Encode a space-separated string as ``[<sos>, *token_ids, <eos>]``.

        Raises:
            KeyError: if ``text`` contains a token outside the fitted vocabulary.
        """
        return [
            self.stoi['<sos>'],
            *(self.stoi[s] for s in text.split(' ')),
            self.stoi['<eos>'],
        ]

    def texts_to_sequences(self, texts):
        """Encode each string in ``texts`` with :meth:`text_to_sequence`."""
        return [self.text_to_sequence(text) for text in texts]

    def sequence_to_text(self, sequence):
        """Decode ids back to tokens and concatenate them with no separator.

        Special-token ids are decoded like any other id (e.g. ``'<sos>'``).
        """
        return ''.join(self.itos[i] for i in sequence)

    def sequences_to_texts(self, sequences):
        """Decode each id sequence with :meth:`sequence_to_text`."""
        return [self.sequence_to_text(sequence) for sequence in sequences]

    def predict_caption(self, sequence):
        """Decode ids up to (not including) the first ``<eos>`` or ``<pad>``."""
        # Hoisted out of the loop; equality (not set membership) is kept so
        # numpy / tensor scalars compare the same way as plain ints.
        eos_id = self.stoi['<eos>']
        pad_id = self.stoi['<pad>']
        tokens = []
        for i in sequence:
            if i == eos_id or i == pad_id:
                break
            tokens.append(self.itos[i])
        return ''.join(tokens)

    def predict_captions(self, sequences):
        """Apply :meth:`predict_caption` to each id sequence."""
        return [self.predict_caption(sequence) for sequence in sequences]
    
def split_form(form):
    """Tokenize a molecular formula into space-separated symbols and counts.

    Each chunk starting at an uppercase letter is split into its leading
    non-digit part (the element symbol) and the remainder (its count, if any),
    e.g. ``'C13H20OS'`` -> ``'C 13 H 20 O S'``.
    """
    tokens = []
    for chunk in re.findall(r"[A-Z][^A-Z]*", form):
        symbol = re.match(r"\D+", chunk).group()
        tokens.append(symbol)
        count = chunk[len(symbol):]
        if count:
            tokens.append(count)
    return ' '.join(tokens).rstrip(' ')

def split_form2(form):
    """Tokenize the InChI layers that follow the formula layer.

    Each layer starts with a lowercase tag, emitted as ``/tag``. Inside a layer
    (with its ``/`` separators removed), every digit run stays one token and
    each non-digit character after it becomes its own token, e.g.
    ``'c1-2-3(4)5/h1H,2-3H2'`` -> ``'/c 1 - 2 - 3 ( 4 ) 5 /h 1 H , 2 - 3 H 2'``.
    Characters in a layer before its first digit are dropped.
    """
    layers = []
    for chunk in re.findall(r"[a-z][^a-z]*", form):
        tag = chunk[0]
        body = chunk.replace(tag, "").replace('/', "")
        tokens = []
        for piece in re.findall(r"[0-9]+[^0-9]*", body):
            digit_runs = re.findall(r'\d+', piece)
            assert len(digit_runs) == 1, "len(num_list) != 1"
            number = digit_runs[0]
            if piece == number:
                tokens.append(f"{number} ")
            else:
                tail = piece.replace(number, "")
                tokens.append(f"{number} {' '.join(tail)} ")
        layers.append(f"/{tag} {''.join(tokens)}")
    return ''.join(layers).rstrip(' ')

def get_atom_counts(dataframe):
    """Add one integer column per element in TARGETS holding its atom count.

    Counts are parsed from the formula strings in ``dataframe["InChI_1"]``
    (e.g. ``"C13H20OS"``): an element without a trailing number counts as 1,
    and an element absent from the formula counts as 0. Adapted from
    https://www.kaggle.com/ttahara/bms-mt-chemical-formula-regression-training

    The columns are written into ``dataframe`` in place, row by row in order
    (independent of the frame's index labels), and the same frame is returned.
    Every TARGETS column is created even if that element never occurs.

    NOTE(review): if an element symbol appears more than once in a formula
    (e.g. a multi-component ``A.B`` formula), the last count wins rather than
    being summed -- confirm whether that matters for this data.
    """
    TARGETS = [
        'B', 'Br', 'C', 'Cl',
        'F', 'H', 'I', 'N',
        'O', 'P', 'S', 'Si']
    elem_regex = re.compile(r"[A-Z][a-z]?[0-9]*")  # symbol + optional count, e.g. "Cl2"
    atom_regex = re.compile(r"[A-Z][a-z]?")  # element symbol only
    dgts_regex = re.compile(r"[0-9]*")  # count digits only

    atom_dict_list = []
    for fml in tqdm(dataframe["InChI_1"].values):
        atom_dict = {}
        for elem in elem_regex.findall(fml):
            atom = dgts_regex.sub("", elem)
            dgts = atom_regex.sub("", elem)
            atom_dict[atom] = int(dgts) if dgts else 1
        atom_dict_list.append(atom_dict)

    # reindex guarantees every TARGETS column exists (NaN -> 0 below), so an
    # element that never occurs no longer raises KeyError.
    atom_df = (
        pd.DataFrame(atom_dict_list)
        .reindex(columns=TARGETS)
        .fillna(0)
        .astype(int)
    )
    for atom in TARGETS:
        # Assign positionally: atom_df has a fresh RangeIndex, so index
        # alignment would misplace values for frames with any other index.
        dataframe[atom] = atom_df[atom].values
    return dataframe

In [None]:
def get_train_file_path(image_id):
    """Return the PNG path of a training image.

    Images are sharded into nested directories named after the first three
    characters of the id, e.g. ``'abc123'`` ->
    ``'../input/bms-molecular-translation/train/a/b/c/abc123.png'``.
    """
    shard = f"{image_id[0]}/{image_id[1]}/{image_id[2]}"
    return f"../input/bms-molecular-translation/train/{shard}/{image_id}.png"

# Data Loading

In [None]:
# Directory prefix for images rendered by ExternalDataset ('{OUTPUT_DIR}{idx}.png').
OUTPUT_DIR = './'

In [None]:
# Load the preprocessed training labels (pickle from the inchi-preprocess-2
# input), attach each image's file path, and read the competition's extra
# approved InChIs; the trailing .head() is displayed by the notebook.
train = pd.read_pickle('../input/inchi-preprocess-2/train2.pkl')
train['file_path'] = train['image_id'].apply(get_train_file_path)
allowed_inchi = pd.read_csv('../input/bms-molecular-translation/extra_approved_InChIs.csv')
print(allowed_inchi.shape)
allowed_inchi.head()

In [None]:
# InChI_1: the formula layer (second '/'-separated field of the InChI string).
# InChI_text: space-separated formula tokens (split_form) followed by the
# tokens of all remaining layers (split_form2); this is the tokenizer input.
allowed_inchi['InChI_1'] = allowed_inchi['InChI'].progress_apply(lambda x: x.split('/')[1])
allowed_inchi['InChI_text'] = allowed_inchi['InChI_1'].progress_apply(split_form) + ' ' + \
                        allowed_inchi['InChI'].apply(lambda x: '/'.join(x.split('/')[2:])).progress_apply(split_form2).values

In [None]:
# Add per-element atom-count columns to both frames (used by the EDA below).
train, allowed_inchi = get_atom_counts(train), get_atom_counts(allowed_inchi)
allowed_inchi.head()

In [None]:
# ====================================================
# create tokenizer
# ====================================================
# The vocabulary is fit on the external (approved) InChI texts only, then the
# tokenizer object is pickled to tokenizer2.pth.
# NOTE(review): later cells also tokenize train['InChI_text'] with this
# tokenizer; any token that only occurs in train.csv would raise KeyError in
# text_to_sequence -- consider fitting on both datasets if that happens.
tokenizer = Tokenizer()
tokenizer.fit_on_texts(allowed_inchi['InChI_text'].values)
torch.save(tokenizer, 'tokenizer2.pth')
print('Saved tokenizer')
# ====================================================
# preprocess allowed_inchi.csv
# ====================================================
# Record each label's token count; text_to_sequence adds <sos> and <eos>,
# hence the "- 2". The processed frame is then pickled.
lengths = []
for text in tqdm(allowed_inchi['InChI_text'].values, total=len(allowed_inchi)):
    seq = tokenizer.text_to_sequence(text)
    length = len(seq) - 2
    lengths.append(length)
allowed_inchi['InChI_length'] = lengths
allowed_inchi.to_pickle('allowed_inchi_processed.pkl')

# Brief EDA

In [None]:
# compare length distribution to given training sequences
# (prints the external min/max token length, then plots 20-bin histograms
# side by side; the trailing ';' suppresses the notebook's repr output)
print(allowed_inchi['InChI_length'].min(), allowed_inchi['InChI_length'].max())
plt.figure(figsize=(15, 5))
plt.subplot(1, 2, 1)
plt.title('External')
plt.hist(allowed_inchi['InChI_length'].values, bins=20)
plt.subplot(1, 2, 2)
plt.title('Given')
plt.hist(train['InChI_length'].values, bins=20);

In [None]:
# ensure no overlap between inchis in train.csv
# Displays the allowed_inchi rows whose InChI_text also occurs in train; an
# empty frame means there is no overlap.
allowed_inchi[allowed_inchi['InChI_text'].isin(train['InChI_text']).values]

In [None]:
# ensure no duplicate inchis
# Evaluates (and displays) True when every InChI string is unique.
allowed_inchi['InChI'].nunique() == len(allowed_inchi)

In [None]:
# explore atom count distribution: density-normalized histograms of each
# element's count, external data (left column) vs given training data (right)
atoms = [
    'B', 'Br', 'C', 'Cl',
    'F', 'H', 'I', 'N',
    'O', 'P', 'S', 'Si']
fig, ax = plt.subplots(len(atoms), 2, figsize=(15, 20))
for i, atom in enumerate(atoms):
    ax[i, 0].hist(allowed_inchi[atom].values, bins=20, density=True)
    ax[i, 0].set_title(f'External - {atom}')
    ax[i, 1].hist(train[atom].values, bins=20, density=True)
    ax[i, 1].set_title(f'Given - {atom}')
# Lay out the grid once, after every subplot has been drawn, rather than
# recomputing the whole layout on each loop iteration.
fig.tight_layout()

# Generate Additional Training Images

In [None]:
def sp_noise(image, *, black_prob=.00015, white_prob=.15):
    """Return a copy of ``image`` with salt-and-pepper noise applied.

    Adapted from https://gist.github.com/lucaswiman/1e877a164a69f78694f845eab45c381a

    Each pixel independently becomes black with probability ``black_prob`` and
    white with probability ``white_prob`` (the two events are disjoint when
    ``black_prob + white_prob <= 1``). The input array is not modified.

    Args:
        image: array of shape (H, W) or (H, W, C). When C == 4 the last channel
            is treated as alpha and set to 255 on noisy pixels.
        black_prob: per-pixel probability of setting the pixel to black.
        white_prob: per-pixel probability of setting the pixel to white.

    Returns:
        A new array with the same shape and dtype as ``image``.
    """
    output = image.copy()
    if image.ndim == 2:
        black, white = 0, 255
    else:
        channels = image.shape[2]
        black = np.zeros(channels, dtype='uint8')
        white = np.full(channels, 255, dtype='uint8')
        if channels == 4:  # RGBA: keep noisy pixels fully opaque
            black[3] = 255
    # One draw per pixel (not per channel) so all channels change together.
    probs = np.random.random(image.shape[:2])
    output[probs < black_prob] = black
    output[probs > 1 - white_prob] = white
    return output

def noisy_inchi(inchi, inchi_path, add_noise=True, crop_and_pad=True):
    """Render an InChI with RDKit, post-process it, and save it as a PNG.

    The molecule is drawn on a 300x300 Cairo canvas with a black-and-white
    atom palette and a randomized multiple-bond offset and atom-label padding,
    then written to ``inchi_path``.

    Args:
        inchi: InChI string to render.
        inchi_path: PNG path to write. When any post-processing is applied, the
            file is overwritten with the final image so disk and return agree.
        add_noise: apply ``sp_noise`` to the image.
        crop_and_pad: crop the surrounding white space and add a 30 px white
            border (the result is converted to 3 identical channels).

    Returns:
        The final image as a 3-channel array.

    Raises:
        ValueError: if RDKit cannot parse ``inchi``.
        FileNotFoundError: if the rendered PNG cannot be read back.
    """
    # Import the drawing submodules explicitly rather than relying on them
    # being reachable as attributes after `import rdkit.Chem`.
    from rdkit.Chem import rdDepictor
    from rdkit.Chem.Draw import rdMolDraw2D

    mol = Chem.MolFromInchi(inchi)
    if mol is None:
        # MolFromInchi returns None on failure; fail here with the offending
        # InChI instead of deep inside the drawing code.
        raise ValueError(f"RDKit could not parse InChI: {inchi!r}")

    drawer = rdMolDraw2D.MolDraw2DCairo(300, 300)
    # Drawing settings adapted from
    # https://www.kaggle.com/stainsby/improved-synthetic-data-for-bms-competition-v3
    rdDepictor.SetPreferCoordGen(True)
    opts = drawer.drawOptions()
    opts.maxFontSize = 14
    opts.multipleBondOffset = np.random.uniform(0.05, 0.2)  # randomized per image
    opts.useBWAtomPalette()
    opts.bondLineWidth = 1
    opts.additionalAtomLabelPadding = np.random.uniform(0, .2)  # randomized per image
    drawer.DrawMolecule(mol)
    drawer.FinishDrawing()
    drawer.WriteDrawingText(inchi_path)

    if crop_and_pad:
        img = cv2.imread(inchi_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            raise FileNotFoundError(f"Could not read rendered image: {inchi_path}")
        ink = img != 255
        # An all-white canvas has nothing to crop around; skip cropping so
        # OpenCV is never handed an empty array.
        if ink.any():
            img = img[ink.any(axis=1)][:, ink.any(axis=0)]
        img = cv2.copyMakeBorder(img, 30, 30, 30, 30, cv2.BORDER_CONSTANT, value=255)
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    else:
        img = cv2.imread(inchi_path)
        if img is None:
            raise FileNotFoundError(f"Could not read rendered image: {inchi_path}")
    if add_noise:
        img = sp_noise(img)
    if add_noise or crop_and_pad:
        cv2.imwrite(inchi_path, img)
    return img

In [None]:
# ====================================================
# Dataset
# ====================================================
class ExternalDataset(Dataset):
    """Dataset that renders each training image on the fly from an InChI.

    Images come from ``noisy_inchi`` (which also writes them to
    ``{OUTPUT_DIR}{idx}.png``), are optionally passed through an
    albumentations-style ``transform``, and are paired with the tokenized
    ``InChI_text`` label.

    Items are ``(image, label_ids, label_length)``: ``label_ids`` is a
    LongTensor including ``<sos>``/``<eos>``, and ``label_length`` is a
    one-element LongTensor holding its length.
    """

    def __init__(self, df, tokenizer, transform=None):
        super().__init__()
        self.df = df
        self.tokenizer = tokenizer
        self.labels = df['InChI_text'].values
        self.inchis = df['InChI'].values
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image = noisy_inchi(self.inchis[idx], inchi_path=f'{OUTPUT_DIR}{idx}.png')
        if self.transform:
            image = self.transform(image=image)['image']
        token_ids = self.tokenizer.text_to_sequence(self.labels[idx])
        return image, torch.LongTensor(token_ids), torch.LongTensor([len(token_ids)])
    
class TrainDataset(Dataset):
    """Dataset over the provided competition training images.

    Each image is read from ``df['file_path']`` with OpenCV, converted from
    BGR to RGB float32, optionally passed through an albumentations-style
    ``transform``, and paired with the tokenized ``InChI_text`` label.

    Items are ``(image, label_ids, label_length)``: ``label_ids`` is a
    LongTensor including ``<sos>``/``<eos>``, and ``label_length`` is a
    one-element LongTensor holding its length.
    """

    def __init__(self, df, tokenizer, transform=None):
        super().__init__()
        self.df = df
        self.tokenizer = tokenizer
        self.file_paths = df['file_path'].values
        self.labels = df['InChI_text'].values
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        image = cv2.imread(file_path)
        if image is None:
            # cv2.imread signals a missing/unreadable file by returning None,
            # which would otherwise surface as an opaque cvtColor error.
            raise FileNotFoundError(f"Could not read image: {file_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        label = self.tokenizer.text_to_sequence(self.labels[idx])
        label_length = torch.LongTensor([len(label)])
        return image, torch.LongTensor(label), label_length
    
# ====================================================
# Augmentations
# ====================================================
def get_transforms(*, data):
    """Return the albumentations pipeline for the given split.

    Both splits resize to 224x224, normalize with ImageNet-style mean/std and
    convert to a tensor via ToTensorV2. ``'train'`` additionally applies,
    with probability 0.5, either a vertical flip or a random 90-degree rotation.

    Args:
        data: ``'train'`` or ``'valid'``.

    Raises:
        ValueError: for any other value, rather than returning None, which the
            datasets would silently treat as "no transform".
    """
    if data not in ('train', 'valid'):
        raise ValueError(f"data must be 'train' or 'valid', got {data!r}")

    steps = [Resize(224, 224)]
    if data == 'train':
        steps.append(
            OneOf([
                VerticalFlip(),
                RandomRotate90(),
            ], p=.5)
        )
    steps += [
        Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(),
    ]
    return Compose(steps)

In [None]:
# Preview a few synthetic images rendered from a random window of approved InChIs.
n_preview = 5
ext_dataset = ExternalDataset(allowed_inchi, tokenizer, transform=get_transforms(data='valid'))
# Draw the window start within the dataset bounds; a fixed upper bound could
# index past the end of a smaller dataset.
start = np.random.randint(max(len(ext_dataset) - n_preview, 0) + 1)

for i in range(start, min(start + n_preview, len(ext_dataset))):
    plt.figure(figsize=(5, 5))
    image, _, _ = ext_dataset[i]
    # CHW tensor -> HWC layout expected by matplotlib
    plt.imshow(image.permute(1, 2, 0))
    plt.show()

# Compare to Original Images

In [None]:
# Render the first few *given* training InChIs with the synthetic pipeline and
# show each next to the corresponding provided image.
# NOTE(review): the tokenizer was fit on the external InChIs only, so a train
# token missing from that vocabulary would raise KeyError inside the datasets.
start = 0
n_compare = 8
ext_dataset = ExternalDataset(train, tokenizer, transform=get_transforms(data='valid'))
orig_dataset = TrainDataset(train, tokenizer, transform=get_transforms(data='valid'))

for i in range(start, start + n_compare):
    plt.figure(figsize=(10, 10))
    image, _, _ = ext_dataset[i]
    orig_image, _, _ = orig_dataset[i]
    plt.subplot(1, 2, 1)
    plt.title('Generated')
    plt.imshow(image.permute(1, 2, 0))  # CHW -> HWC
    plt.subplot(1, 2, 2)
    plt.title('Given')
    plt.imshow(orig_image.permute(1, 2, 0))  # CHW -> HWC
    plt.show()

Not too bad. For further improvements, you can increase the noise and tune the `rdkit` drawing parameters.