# InChI Tokenisation

Here we explore a couple of tokenisation strategies for InChI strings in the BMS Molecular Translation challenge. The aim is to keep the sequences as short as possible while retaining all of the molecular information.

Shorter sequences speed up training and also reduce memory usage. This matters for transformer models, since the memory used by self-attention scales quadratically with sequence length. Shorter sequences may also improve model accuracy, because the dependencies between tokens are generally shorter-range. However, shorter sequences often come at the cost of a larger tokeniser vocabulary.
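To illustrate the quadratic scaling, the sketch below estimates the memory taken by the attention score tensor of a single self-attention layer. The batch size, number of heads and float32 storage are illustrative assumptions, not settings used in this notebook. Real training also stores activations, gradients and weights, which this estimate ignores.

```python
def attention_scores_bytes(seq_len, batch_size=32, num_heads=8, bytes_per_value=4):
    """Bytes needed for one layer's attention scores, a tensor of shape
    (batch_size, num_heads, seq_len, seq_len); everything else is ignored."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_value

# Illustrative lengths: halving the sequence length quarters this memory
for seq_len in (400, 200, 100):
    mib = attention_scores_bytes(seq_len) / 2**20
    print(f"seq_len={seq_len}: {mib:.1f} MiB")
```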

In [None]:
import re
import random
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

In [None]:
path = Path("/kaggle/input/bms-molecular-translation/train_labels.csv")
train_labels = pd.read_csv(path)

In [None]:
inchis = train_labels["InChI"].tolist()

In [None]:
# First, let's take a random subset of the full dataset to speed things up a bit

random.seed(1)
num_samples = 500000
rand_inchis = random.sample(inchis, k=num_samples)

In [None]:
# Check lengths using naive (single character) tokenisation

naive_inchi_lens = [len(inchi) for inchi in rand_inchis]

print(f"Max length: {max(naive_inchi_lens)}")
print(f"Min length: {min(naive_inchi_lens)}")
print(f"Avg length: {sum(naive_inchi_lens) / len(naive_inchi_lens)}")

In [None]:
plt.hist(naive_inchi_lens, bins=100)
plt.show()

InChIs can be very long strings since each molecular structure is required to have a unique InChI. As you can see above, a naive approach to tokenising the InChIs will lead to some very long sequences. Below we show how we can use knowledge of the InChI standard to produce shorter sequences.
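The main structure of a standard InChI is made of layers separated by `/`. The `InChI=1S` prefix comes first, followed by the chemical formula, the `c` layer (atom connections) and the `h` layer (hydrogen atoms). These can optionally be followed by further layers, such as charge, stereochemistry and isotopes. Atoms are referred to by number in these layers, so digits, `-`, `,` and brackets account for a large share of the characters.

The sketch below splits an InChI into its layers, using naphthalene's standard InChI as the example. The helper `split_layers` is hypothetical and assumes a well-formed standard InChI.

```python
def split_layers(inchi):
    """Split a standard InChI into its prefix, formula and (letter, content) layers."""
    prefix, formula, *rest = inchi.split("/")
    # Each layer after the formula starts with a one-letter lowercase prefix.
    # A list (not a dict) is used because some letters can occur more than once,
    # e.g. the isotopic layer can carry its own "h" sublayer.
    layers = [(layer[0], layer[1:]) for layer in rest if layer]
    return prefix, formula, layers

naphthalene_inchi = "InChI=1S/C10H8/c1-2-6-10-8-4-3-7-9(10)5-1/h1-8H"
prefix, formula, layers = split_layers(naphthalene_inchi)
print(prefix)   # InChI=1S
print(formula)  # C10H8
print(layers)   # [('c', '1-2-6-10-8-4-3-7-9(10)5-1'), ('h', '1-8H')]
```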

First, we use a regex that splits off the `InChI=1S` prefix as a single token and keeps two-character element symbols (such as `Cl` and `Br`) together.

In [None]:
# Raw string, so escapes such as \+ reach the regex engine unchanged (no invalid-escape warnings)
regex = r"InChI=1S|/|[0-9]|-|\+|,|\(|\)|[A-Z][a-z]|."
prog = re.compile(regex)
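The following sketch applies this pattern to a single string, chlorobenzene's standard InChI. At each position the alternatives are tried from left to right. This means `InChI=1S` is matched as one token before the catch-all `.` gets a chance, and `Cl` is caught by `[A-Z][a-z]`.

```python
import re

# Same pattern as the cell above, written as a raw string
INCHI_PATTERN = re.compile(r"InChI=1S|/|[0-9]|-|\+|,|\(|\)|[A-Z][a-z]|.")

example_inchi = "InChI=1S/C6H5Cl/c7-6-4-2-1-3-5-6/h1-5H"  # chlorobenzene
example_tokens = INCHI_PATTERN.findall(example_inchi)

print(example_tokens)
print(f"{len(example_inchi)} characters -> {len(example_tokens)} tokens")  # 38 characters -> 30 tokens

# The final "." alternative matches any character except a newline, so nothing
# is dropped and joining the tokens recovers the original string exactly
assert "".join(example_tokens) == example_inchi
```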

In [None]:
tokens = [prog.findall(inchi) for inchi in rand_inchis]

In [None]:
# Have a look at some examples :)

for i, ts in enumerate(tokens[:10]):
    print(rand_inchis[i])
    print(ts)
    print()

In [None]:
regex_inchi_lens = [len(ts) for ts in tokens]

print(f"Max length: {max(regex_inchi_lens)}")
print(f"Min length: {min(regex_inchi_lens)}")
print(f"Avg length: {sum(regex_inchi_lens) / len(regex_inchi_lens)}")

In [None]:
plt.hist(regex_inchi_lens, bins=100)
plt.show()

We can see a small improvement in sequence length, but most of it probably comes from collapsing the 8-character `InChI=1S` prefix into a single token. Perhaps we can merge 2- and 3-digit numbers into single tokens too, at the cost of a larger vocabulary...

In [None]:
regex = r"InChI=1S|/|[0-9]{3}|[0-9]{2}|[0-9]|-|\+|,|\(|\)|[A-Z][a-z]|."
prog = re.compile(regex)
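To see where the saving comes from, the sketch below applies both InChI patterns (as raw strings) to naphthalene's standard InChI. This string contains the two-digit atom number 10 three times.

```python
import re

# The two InChI patterns used in this notebook, as raw strings
single_digit = re.compile(r"InChI=1S|/|[0-9]|-|\+|,|\(|\)|[A-Z][a-z]|.")
multi_digit = re.compile(r"InChI=1S|/|[0-9]{3}|[0-9]{2}|[0-9]|-|\+|,|\(|\)|[A-Z][a-z]|.")

naphthalene = "InChI=1S/C10H8/c1-2-6-10-8-4-3-7-9(10)5-1/h1-8H"
a = single_digit.findall(naphthalene)
b = multi_digit.findall(naphthalene)

print(len(naphthalene), len(a), len(b))  # 47 40 37
print([t for t in b if len(t) > 1])      # ['InChI=1S', '10', '10', '10']
```

With this pattern, any run of up to three digits becomes a single token. In principle, the vocabulary can therefore gain up to 10 + 100 + 1000 numeric tokens, although only numbers that actually occur in the data will appear. Longer runs are split greedily; for example, `1234` becomes `123` and `4`.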

In [None]:
tokens = [prog.findall(inchi) for inchi in rand_inchis]

In [None]:
for i, ts in enumerate(tokens[:10]):
    print(rand_inchis[i])
    print(ts)
    print()

In [None]:
regex_inchi_lens = [len(ts) for ts in tokens]

print(f"Max length: {max(regex_inchi_lens)}")
print(f"Min length: {min(regex_inchi_lens)}")
print(f"Avg length: {sum(regex_inchi_lens) / len(regex_inchi_lens)}")

In [None]:
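plt.hist(naive_inchi_lens, bins=100, alpha=0.5, label="InChI (characters)")
plt.hist(regex_inchi_lens, bins=100, alpha=0.5, label="InChI (regex tokens)")
plt.xlabel("Sequence length")
plt.ylabel("Count")
plt.legend()
plt.show()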

We can see a much bigger improvement here. Crucially, the longest sequence is now only 273 tokens, compared with the 386 we had before. But perhaps we can do even better using a different molecular notation...
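These length savings come with the larger vocabulary mentioned earlier. The sketch below builds a token-to-index vocabulary from tokenised sequences, so the trade-off can be measured directly (for example, by calling `build_vocab(tokens)` after each regex cell). The special tokens `<pad>`, `<sos>` and `<eos>`, and the helper names, are assumptions and are not defined elsewhere in this notebook.

```python
from collections import Counter

def build_vocab(token_lists, specials=("<pad>", "<sos>", "<eos>")):
    """Map each token to an integer id: special tokens first, then the
    remaining tokens in order of decreasing frequency."""
    counts = Counter(tok for toks in token_lists for tok in toks)
    vocab = {tok: i for i, tok in enumerate(specials)}
    for tok, _ in counts.most_common():
        if tok not in vocab:
            vocab[tok] = len(vocab)
    return vocab

def encode(tokens, vocab):
    """Wrap a token sequence in start/end markers and convert it to ids."""
    return [vocab["<sos>"]] + [vocab[t] for t in tokens] + [vocab["<eos>"]]

example_token_lists = [
    ["InChI=1S", "/", "C", "10", "H", "8"],
    ["InChI=1S", "/", "C", "6", "H", "5", "Cl"],
]
vocab = build_vocab(example_token_lists)
print(len(vocab))                              # 3 specials + 9 distinct tokens = 12
print(encode(example_token_lists[0], vocab))
```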

### Tokenising SMILES

SMILES (**S**implified **M**olecular-**I**nput **L**ine-**E**ntry **S**ystem)[1] is another commonly used molecular representation, which can easily be constructed from a molecule's InChI string. Converting a SMILES back to InChI is similarly straightforward. Unlike an InChI string, however, a SMILES representation is not unique: the same molecule can be written in several valid ways. Even so, we may be able to use SMILES strings to significantly reduce the number of tokens needed to represent a molecule.

[1] Weininger, David. "SMILES, a chemical language and information system. 1. Introduction to methodology and encoding rules." Journal of Chemical Information and Computer Sciences 28.1 (1988): 31-36.

First, we need to enlist the help of RDKit!

In [None]:
!conda install -y rdkit -c rdkit

In [None]:
from rdkit import Chem

In [None]:
# Let's look at a few examples first...

small_sample_inchis = random.sample(rand_inchis, k=5)
mols = [Chem.rdinchi.InchiToMol(inchi)[0] for inchi in small_sample_inchis]
smiles = [Chem.MolToSmiles(mol) for mol in mols]

In [None]:
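# Pair each SMILES with the InChI it was generated from (small_sample_inchis, not rand_inchis)
for inchi, smi in zip(small_sample_inchis, smiles):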
    print(inchi)
    print(smi)
    print()

In [None]:
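def process_inchi(inchi):
    """Convert an InChI to a SMILES string; return None if RDKit cannot parse it."""
    try:
        mol = Chem.rdinchi.InchiToMol(inchi)[0]
    except ValueError:
        return None
    # InchiToMol returns None as the molecule on failure; MolToSmiles(None) would raise
    return None if mol is None else Chem.MolToSmiles(mol)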

In [None]:
# This will take a few minutes...

smiles = [process_inchi(inchi) for inchi in rand_inchis]

In [None]:
# Make sure everything worked as expected

invalids = [smi is None or smi == "" for smi in smiles]
print(f"Number of invalid mols: {sum(invalids)}")
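The check above only catches conversions that fail outright. Submissions must ultimately be InChI strings, so a stricter sanity check is to convert each SMILES back to InChI and compare it with the original. The sketch below does this using RDKit's `Chem.MolFromSmiles` and `Chem.MolToInchi`. The helper name `roundtrips` is hypothetical. Some mismatches may be expected where information does not survive the conversion, so treat a mismatch as something to inspect rather than necessarily as an error.

```python
from rdkit import Chem

def roundtrips(inchi, smi):
    """Return True if converting `smi` back to InChI reproduces `inchi` exactly."""
    if not smi:
        return False
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        return False
    return Chem.MolToInchi(mol) == inchi

ethanol_inchi = "InChI=1S/C2H6O/c1-2-3/h3H,2H2,1H3"
print(roundtrips(ethanol_inchi, "CCO"))  # True
print(roundtrips(ethanol_inchi, "CCC"))  # False (propane)

# In the notebook this could be run over (a slice of) the converted data, e.g.
# mismatches = [i for i, (inchi, smi) in enumerate(zip(rand_inchis, smiles))
#               if not roundtrips(inchi, smi)]
```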

In [None]:
# Now we need a new tokeniser for SMILES
# Here's one possible tokenisation scheme

# Raw string: "\\" matches a literal backslash, "/" needs no escaping
smi_regex = r"\[|\]|Br|Cl|\(|\)|\.|=|#|-|\+|\\|/|:|~|@|\?|>|\*|\$|[0-9]{2}|[0-9]|."
smi_prog = re.compile(smi_regex)
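This scheme has one subtlety. In SMILES, each digit after an atom is a separate ring-closure label, so `C12` opens ring bonds 1 and 2. Two-digit labels are written with a `%` prefix, as in `%10`. The `[0-9]{2}` alternative therefore turns two independent ring closures into a single `12` token, while `%10` is split into `%` and `10`. The tokenisation is still lossless, but a token's meaning then depends on context. The sketch below contrasts the pattern above with a hypothetical variant that only groups digits that follow `%`.

```python
import re

# Pattern from the cell above (raw string)
current_pattern = re.compile(
    r"\[|\]|Br|Cl|\(|\)|\.|=|#|-|\+|\\|/|:|~|@|\?|>|\*|\$|[0-9]{2}|[0-9]|."
)
# Hypothetical variant: only "%" followed by two digits forms a multi-character
# ring-closure token; every other digit is its own token
percent_pattern = re.compile(
    r"\[|\]|Br|Cl|\(|\)|\.|=|#|-|\+|\\|/|:|~|@|\?|>|\*|\$|%[0-9]{2}|[0-9]|."
)

two_rings = "C12CCC1CC2"  # the first atom opens ring bonds 1 and 2
ring_ten = "C%10CC%10"    # a ring bond labelled 10

print(current_pattern.findall(two_rings))  # ['C', '12', 'C', 'C', 'C', '1', 'C', 'C', '2']
print(percent_pattern.findall(two_rings))  # ['C', '1', '2', 'C', 'C', 'C', '1', 'C', 'C', '2']
print(current_pattern.findall(ring_ten))   # ['C', '%', '10', 'C', 'C', '%', '10']
print(percent_pattern.findall(ring_ten))   # ['C', '%10', 'C', 'C', '%10']
```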

In [None]:
smi_tokens = [smi_prog.findall(smi) for smi in smiles]

In [None]:
for smi, ts in zip(smiles[:10], smi_tokens[:10]):
    print(smi)
    print(ts)
    print()

In [None]:
smiles_lens = [len(ts) for ts in smi_tokens]

print(f"Max length: {max(smiles_lens)}")
print(f"Min length: {min(smiles_lens)}")
print(f"Avg length: {sum(smiles_lens) / len(smiles_lens)}")

In [None]:
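plt.hist(naive_inchi_lens, bins=100, alpha=0.5, label="InChI (characters)")
plt.hist(regex_inchi_lens, bins=100, alpha=0.5, label="InChI (regex tokens)")
plt.hist(smiles_lens, bins=100, alpha=0.5, label="SMILES (regex tokens)")
plt.xlabel("Sequence length")
plt.ylabel("Count")
plt.legend()
plt.show()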

Using SMILES greatly reduces sequence length compared with the original InChI strings, and converting between the two representations is straightforward. However, SMILES strings have the disadvantage that each molecule can be represented in many different ways.
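RDKit mitigates this in practice because `Chem.MolToSmiles` writes canonical SMILES by default. A given molecule, as RDKit perceives it, is therefore always written the same way, which keeps training targets consistent. The sketch below shows three different spellings of ethanol collapsing to one canonical string. The helper name `canonical` is hypothetical. The canonical form may differ between RDKit versions, so it is safest to use the same version for building targets and decoding predictions.

```python
from rdkit import Chem

def canonical(smi):
    """Return RDKit's canonical SMILES for `smi`, or None if it cannot be parsed."""
    mol = Chem.MolFromSmiles(smi)
    # MolToSmiles writes canonical SMILES by default (canonical=True)
    return None if mol is None else Chem.MolToSmiles(mol)

spellings = ["OCC", "C(O)C", "CCO"]  # three valid SMILES for ethanol
print({canonical(s) for s in spellings})  # {'CCO'}
print(canonical("C1CC"))  # None: ring bond 1 is never closed
```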