## Spectra Encoder: Transformer

Primary reference: https://www.nature.com/articles/s42004-023-00932-3#Sec19

Using this paper as a framework, this transformer takes GC-MS spectral data as input and outputs embeddings to be passed to the SMILES decoder. The reference paper used images of GC-MS data as input to a CNN; we intend to use a transformer instead.
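The goal above is an encoder that turns a spectrum into embeddings for the SMILES decoder. Below is a minimal encoder-only sketch of that idea in PyTorch. It assumes each spectrum arrives as a padded tensor of continuous (m/z, intensity) pairs plus a boolean padding mask. The class names, the learned `nn.Linear(2, d_model)` peak projection, and all hyperparameters are illustrative assumptions, not the reference paper's method.

```python
import math

import torch
import torch.nn as nn


class SinusoidalPositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for batch-first inputs of shape (batch, seq_len, d_model)."""

    def __init__(self, d_model: int, max_len: int = 512):
        super().__init__()
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model), broadcasts over batch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.pe[:, : x.size(1)]  # index the sequence dimension (dim 1)


class SpectrumEncoder(nn.Module):
    """Hypothetical encoder-only model: padded (m/z, intensity) peaks -> one embedding per peak."""

    def __init__(self, d_model: int = 64, nhead: int = 4, num_layers: int = 2, max_peaks: int = 512):
        super().__init__()
        self.peak_proj = nn.Linear(2, d_model)  # learned projection of each (m/z, intensity) pair
        self.pos = SinusoidalPositionalEncoding(d_model, max_peaks)
        layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers)

    def forward(self, peaks: torch.Tensor, pad_mask: torch.Tensor = None) -> torch.Tensor:
        # peaks: (batch, n_peaks, 2); pad_mask: (batch, n_peaks), True at padded peaks
        x = self.pos(self.peak_proj(peaks))
        return self.encoder(x, src_key_padding_mask=pad_mask)  # (batch, n_peaks, d_model)


peaks = torch.rand(3, 10, 2)  # random stand-in for 3 spectra with 10 peak slots each
pad_mask = torch.zeros(3, 10, dtype=torch.bool)
pad_mask[:, 8:] = True        # pretend the last 2 slots are padding
memory = SpectrumEncoder()(peaks, pad_mask)
print(memory.shape)           # torch.Size([3, 10, 64])
```

The output holds one `d_model`-dimensional vector per peak and could be passed as the `memory` argument of an `nn.TransformerDecoder`. In practice, raw m/z and intensity values would likely need scaling before the linear projection.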

#### Supplemental references:
https://jalammar.github.io/illustrated-transformer/ (Illustrated overview of Transformer function)

https://nlp.seas.harvard.edu/2018/04/03/attention.html (Harvard NLP annotated implementation of the original Transformer paper)

https://www.datacamp.com/tutorial/building-a-transformer-with-py-torch (Datacamp Transformer tutorial)

Notebook overview:
1. Define model building blocks
2. Encoding
3. Decoding
4. Training
5. Evaluation


## Preparing the input data

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy
import pandas as pd
import numpy as np
from torch.autograd import Variable
import matplotlib.pyplot as plt
import seaborn
import csv

In [2]:
# finding unique characters in the SMILES column of training data 

unique_characters = set() 

with open('dataset/filtered_gc_spec.csv', 'r') as f:
    reader = csv.DictReader(f)  
    for row in reader:
        for char in row["SMILES"]:
            unique_characters.add(char)  # Add each character to the set

print(len(unique_characters))  


45


In [3]:
# finding unique tuples in the spectral training data
unique_tuples = set()  

with open('dataset/filtered_gc_spec.csv', 'r') as f:
    reader = csv.DictReader(f)
    for row in reader:
        spectrum_data = row["Spectrum"]
        tuples = spectrum_data.split() 
        for tup in tuples:
            if tup.count(':') == 1:  # keep only well-formed "mz:intensity" pairs
                unique_tuples.add(tup)

print(len(unique_tuples))  


517627


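The spectral input therefore has a "vocabulary" of 517,627 unique m/z:intensity pairs, while the SMILES output uses only 45 unique characters. The model has to map from the former to the latter.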
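Nearly every `m/z:intensity` pair is unique, presumably because intensities are recorded as continuous values. One common way to shrink the input vocabulary is to drop the intensity from the token and round each m/z to its nominal integer mass, which is reasonable for unit-resolution GC-MS. This is a hedged sketch, not the approach this notebook takes. `nominal_mass_tokens` and its `min_rel_intensity` cutoff are hypothetical, and it assumes the whitespace-separated `mz:intensity` format parsed above.

```python
from typing import List


def nominal_mass_tokens(spectrum: str, min_rel_intensity: float = 0.0) -> List[int]:
    """Hypothetical helper: turn an 'mz:intensity mz:intensity ...' string into integer m/z tokens.

    Peaks below min_rel_intensity (as a fraction of the base peak) are dropped.
    """
    peaks = []
    for item in spectrum.split():
        if item.count(":") != 1:  # skip malformed entries, like the counting cell does
            continue
        mz, intensity = (float(v) for v in item.split(":"))
        peaks.append((mz, intensity))
    if not peaks:
        return []
    base = max(intensity for _, intensity in peaks)
    if base <= 0:
        return []
    # round() maps each m/z to its nominal (integer) mass
    return [round(mz) for mz, intensity in peaks if intensity / base >= min_rel_intensity]


spectra = ["41.0:120 43.1:999 58.0:15", "41.02:80 58.0:999 59.9:3"]
vocab = {tok for s in spectra for tok in nominal_mass_tokens(s)}
print(sorted(vocab))  # [41, 43, 58, 60]
```

With this scheme the vocabulary size is bounded by the m/z range rather than by the number of distinct intensity readings. Intensity information would then have to be supplied some other way, for example as a separate continuous feature.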

## Tokenizing the SMILES data for training

In [7]:
# Requirements: transformers, tokenizers
# The SmilesTokenizer below uses an existing vocab file from rxnfp that is fairly comprehensive and built from the USPTO dataset.
# The vocab may be expanded in the near future.

import collections
import os
import re
from typing import List
from transformers import BertTokenizer
from logging import getLogger

logger = getLogger(__name__)
"""
SMI_REGEX_PATTERN: str
    SMILES regex pattern for tokenization. Designed by Schwaller et al.

References

.. [1]  Philippe Schwaller, Teodoro Laino, Théophile Gaudin, Peter Bolgar, Christopher A. Hunter, Costas Bekas, and Alpha A. Lee
        ACS Central Science 2019 5 (9): Molecular Transformer: A Model for Uncertainty-Calibrated Chemical Reaction Prediction
        1572-1583 DOI: 10.1021/acscentsci.9b00576

"""

SMI_REGEX_PATTERN = r"""(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"""

# add vocab_file dict
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}


class SmilesTokenizer(BertTokenizer):
    def __init__(self, vocab_file: str = '', max_len: int = 512, **kwargs):
        """Constructs a SmilesTokenizer.

        Parameters
        ----------
        vocab_file: str
            Path to a SMILES character per line vocabulary file.
            Default vocab file is found in deepchem/feat/tests/data/vocab.txt
        max_len: int
            Maximum length for tokenized sequences.
        """
        self.max_len = max_len

        super().__init__(vocab_file=vocab_file, max_len=max_len, **kwargs)
        
        # take into account special tokens in max length
        #self.max_len_single_sentence = self.max_len - 2
        #self.max_len_sentences_pair = self.max_len - 3

        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocab file at path '{}'.".format(vocab_file)
            )
        self.vocab = load_vocab(vocab_file)
        self.highest_unused_index = max(
            [i for i, v in enumerate(self.vocab.keys()) if v.startswith("[unused")])
        self.ids_to_tokens = collections.OrderedDict(
            [(ids, tok) for tok, ids in self.vocab.items()])
        self.basic_tokenizer = BasicSmilesTokenizer()
        self.init_kwargs["max_len"] = self.max_len


    @property
    def vocab_size(self):
      return len(self.vocab)

    @property
    def vocab_list(self):
      return list(self.vocab.keys())

    def _tokenize(self, text: str):
      """
          Tokenize a string into a list of tokens.

          Parameters
          ----------
          text: str
              Input string sequence to be tokenized.
          """

      split_tokens = [token for token in self.basic_tokenizer.tokenize(text)]
      return split_tokens

    def _convert_token_to_id(self, token):
      """
          Converts a token (str/unicode) into an id using the vocab.

          Parameters
          ----------
          token: str
              String token from a larger sequence to be converted to a numerical id.
          """

      return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
      """
          Converts an index (integer) into a token (string/unicode) using the vocab.

          Parameters
          ----------
          index: int
              Integer index to be converted back to a string-based token as part of a larger sequence.
          """

      return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]):
      """ Converts a sequence of tokens (string) into a single string.

          Parameters
          ----------
          tokens: List[str]
              List of tokens for a given string sequence.

          Returns
          -------
          out_string: str
              Single string from combined tokens.
          """

      out_string: str = " ".join(tokens).replace(" ##", "").strip()
      return out_string

    def add_special_tokens_ids_single_sequence(self, token_ids: List[int]):
      """
          Adds special tokens to a sequence for sequence classification tasks.
          A BERT sequence has the following format: [CLS] X [SEP]

          Parameters
          ----------

          token_ids: list[int]
              list of tokenized input ids. Can be obtained using the encode or encode_plus methods.
          """

      return [self.cls_token_id] + token_ids + [self.sep_token_id]

    def add_special_tokens_single_sequence(self, tokens: List[str]):
      """
          Adds special tokens to a sequence for sequence classification tasks.
          A BERT sequence has the following format: [CLS] X [SEP]

          Parameters
          ----------
          tokens: List[str]
              List of tokens for a given string sequence.

          """
      return [self.cls_token] + tokens + [self.sep_token]

    def add_special_tokens_ids_sequence_pair(self, token_ids_0: List[int],
                                            token_ids_1: List[int]) -> List[int]:
      """
          Adds special tokens to a sequence pair for sequence classification tasks.
          A BERT sequence pair has the following format: [CLS] A [SEP] B [SEP]

          Parameters
          ----------
          token_ids_0: List[int]
              List of ids for the first string sequence in the sequence pair (A).

          token_ids_1: List[int]
              List of ids for the second string sequence in the sequence pair (B).
          """

      sep = [self.sep_token_id]
      cls = [self.cls_token_id]

      return cls + token_ids_0 + sep + token_ids_1 + sep

    def add_padding_tokens(self,
                          token_ids: List[int],
                          length: int,
                          right: bool = True) -> List[int]:
      """
          Adds padding tokens to return a sequence of the given length.
          By default padding tokens are added to the right of the sequence.

          Parameters
          ----------
          token_ids: list[int]
              list of tokenized input ids. Can be obtained using the encode or encode_plus methods.

          length: int
              Target length of the padded sequence.

          right: bool (True by default)
              If True, pad on the right; otherwise pad on the left.

          Returns
          ----------
          token_ids: list[int]
              The input ids padded with pad_token_id up to the given length
              (inputs that are already longer are returned unchanged).

          """
      padding = [self.pad_token_id] * (length - len(token_ids))

      if right:
        return token_ids + padding
      else:
        return padding + token_ids

    def save_vocabulary(
        self, vocab_path: str
    ):  # -> tuple[str]: doctest issue raised with this return type annotation
      """
          Save the tokenizer vocabulary to a file.

          Parameters
          ----------
          vocab_path: str
              The directory (or file path) in which to save the SMILES character per line vocabulary file.
              Default vocab file is found in deepchem/feat/tests/data/vocab.txt

          Returns
          ----------
          vocab_file: :obj:`Tuple(str)`:
              Paths to the files saved.
              tuple containing the path to a SMILES character per line vocabulary file.
              Default vocab file is found in deepchem/feat/tests/data/vocab.txt

          """
      index = 0
      if os.path.isdir(vocab_path):
        vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
      else:
        vocab_file = vocab_path
      with open(vocab_file, "w", encoding="utf-8") as writer:
        for token, token_index in sorted(
            self.vocab.items(), key=lambda kv: kv[1]):
          if index != token_index:
            logger.warning(
                "Saving vocabulary to {}: vocabulary indices are not consecutive."
                " Please check that the vocabulary is not corrupted!".format(
                    vocab_file))
            index = token_index
          writer.write(token + "\n")
          index += 1
      return (vocab_file,)


class BasicSmilesTokenizer(object):
  """

    Run basic SMILES tokenization using a regex pattern developed by Schwaller et al. This tokenizer can be used
    when a tokenizer that does not depend on the HuggingFace transformers library is needed.

    Examples
    --------
    >>> from deepchem.feat.smiles_tokenizer import BasicSmilesTokenizer
    >>> tokenizer = BasicSmilesTokenizer()
    >>> print(tokenizer.tokenize("CC(=O)OC1=CC=CC=C1C(=O)O"))
    ['C', 'C', '(', '=', 'O', ')', 'O', 'C', '1', '=', 'C', 'C', '=', 'C', 'C', '=', 'C', '1', 'C', '(', '=', 'O', ')', 'O']


    References
    ----------
    .. [1]  Philippe Schwaller, Teodoro Laino, Théophile Gaudin, Peter Bolgar, Christopher A. Hunter, Costas Bekas, and Alpha A. Lee
            ACS Central Science 2019 5 (9): Molecular Transformer: A Model for Uncertainty-Calibrated Chemical Reaction Prediction
            1572-1583 DOI: 10.1021/acscentsci.9b00576

    """

  def __init__(self, regex_pattern: str = SMI_REGEX_PATTERN):
    """ Constructs a BasicSmilesTokenizer.
        Parameters
        ----------

        regex_pattern: str
            SMILES token regex

        """
    self.regex_pattern = regex_pattern
    self.regex = re.compile(self.regex_pattern)

  def tokenize(self, text):
    """ Basic Tokenization of a SMILES.
        """
    tokens = [token for token in self.regex.findall(text)]
    return tokens


def load_vocab(vocab_file):
  """Loads a vocabulary file into a dictionary."""
  vocab = collections.OrderedDict()
  with open(vocab_file, "r", encoding="utf-8") as reader:
    tokens = reader.readlines()
  for index, token in enumerate(tokens):
    token = token.rstrip("\n")
    vocab[token] = index
  return vocab



In [8]:
vocab_file = 'vocab.txt'
tokenizer = SmilesTokenizer(vocab_file)
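The sketch below checks the regex tokenization that `SmilesTokenizer` relies on. It also mimics what `tokenizer.encode` plus `pad_seq` produce for the SMILES column: regex tokens are mapped to vocabulary ids, wrapped as `[CLS] X [SEP]` (the default behavior of `BertTokenizer.encode`), and right-padded. It is pure Python with a hypothetical toy vocabulary (`TOY_VOCAB`; real ids come from `vocab.txt`). `encode` and `pad_to` are illustrative stand-ins, not the HuggingFace API.

```python
import re
from typing import Dict, List

# Schwaller et al. SMILES pattern, kept on a single line so no newline ends up inside it
SMI_REGEX_PATTERN = r"""(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"""
SMI_REGEX = re.compile(SMI_REGEX_PATTERN)

# Hypothetical toy vocabulary; the real token ids come from vocab.txt
TOY_VOCAB: Dict[str, int] = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3,
                             "C": 4, "O": 5, "=": 6, "(": 7, ")": 8}


def tokenize(smiles: str) -> List[str]:
    """Split a SMILES string into regex tokens (multi-character atoms like Cl, Br, [NH3+] stay whole)."""
    return SMI_REGEX.findall(smiles)


def encode(smiles: str, vocab: Dict[str, int]) -> List[int]:
    """Map tokens to ids (unknown tokens -> [UNK]) and wrap them as [CLS] X [SEP]."""
    ids = [vocab.get(tok, vocab["[UNK]"]) for tok in tokenize(smiles)]
    return [vocab["[CLS]"]] + ids + [vocab["[SEP]"]]


def pad_to(ids: List[int], length: int, pad_id: int) -> List[int]:
    """Truncate to `length`, then right-pad with pad_id so the result is exactly `length` long."""
    ids = ids[:length]
    return ids + [pad_id] * (length - len(ids))


print(tokenize("ClC(Br)[NH3+]"))            # ['Cl', 'C', '(', 'Br', ')', '[NH3+]']
ids = encode("CC(=O)N", TOY_VOCAB)
print(ids)                                  # [2, 4, 4, 7, 6, 5, 8, 1, 3]
print(pad_to(ids, 12, TOY_VOCAB["[PAD]"]))  # [2, 4, 4, 7, 6, 5, 8, 1, 3, 0, 0, 0]
```

Truncating before padding means an over-long sequence can lose its `[SEP]` token, but it never yields a row longer than `length`. That keeps every row the same size when the rows are stacked into a tensor.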

# Defining model components

## Implementing model

In [85]:
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset

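# load data (named df so it does not shadow the torch.utils.data module imported as `data`)
df = pd.read_csv('dataset/filtered_gc_spec.csv')
input_MS = df["Spectrum"].iloc[:100]  # only the first 100 rows are used here
output_SMILES = df["SMILES"].iloc[:100]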

assert len(input_MS) == len(output_SMILES)  # sanity check to ensure correct loading

# filter input by length of SMILES (<77 as per SMILES encoder)
output_SMILES_filtered = output_SMILES[output_SMILES.str.len() < 77]
input_MS_filtered = input_MS.loc[output_SMILES_filtered.index]

assert len(input_MS_filtered) == len(output_SMILES_filtered)  # sanity check to ensure correct filtering

print(f"Number of GC-MS Spectra for input: {len(input_MS_filtered)}")
print(f"Number of SMILES sequences for output: {len(output_SMILES_filtered)}")

smiles_list = output_SMILES_filtered.tolist()

# tokenize SMILES data
tokenized_smiles = [tokenizer.encode(smiles) for smiles in smiles_list]

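# maximum number of SMILES tokens (including [CLS] and [SEP]) and of spectral peaks per sample
max_length = 64

def pad_seq(tokens, max_length, pad_id=tokenizer.pad_token_id):
    tokens = tokens[:max_length]  # truncate so every row has exactly max_length ids
    return tokens + [pad_id] * (max_length - len(tokens))  # right-pad with the [PAD] id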

# pad all SMILES entries and convert to tensor
padded_smiles = [pad_seq(tokens, max_length) for tokens in tokenized_smiles]
smiles_tensor = torch.tensor(padded_smiles, dtype=torch.long)

# converting MS data to tensor
def spec_2_tensor(spectrum, max_length):
    spectrum_tuples = [(float(mz), float(intensity)) for mz, intensity in (item.split(":") for item in spectrum.split())]
    return spectrum_tuples[:max_length] + [(0, 0)] * (max_length - len(spectrum_tuples))  # padding as zeroes

input_MS_data = input_MS_filtered.apply(lambda x: spec_2_tensor(x, max_length))
ms_tensor = torch.tensor(input_MS_data.tolist(), dtype=torch.float32)

# map each (m/z, intensity) peak to one of src_vocab_size integer ids.
# NOTE: this linear layer is randomly initialized and never trained, and argmax is not
# differentiable, so the peak-to-id assignment is fixed by the random initialization, not learned
src_vocab_size = 64
linear_layer = nn.Linear(2, src_vocab_size)

# flatten, transform, and reshape back
with torch.no_grad():
    ms_tensor_flat = ms_tensor.view(-1, 2)  # (num_spectra * max_length, 2)
    ms_tensor_transformed = linear_layer(ms_tensor_flat)
    ms_tensor_indices = ms_tensor_transformed.argmax(dim=1)  # index of the highest logit per peak
    ms_tensor_indices = ms_tensor_indices.view(len(input_MS_filtered), max_length)  # reshape back

# train/test split
ms_train, ms_test, smiles_train, smiles_test = train_test_split(ms_tensor_indices.numpy(), smiles_tensor.numpy(), test_size=0.2, random_state=42)

# convert back to tensors
ms_train = torch.tensor(ms_train, dtype=torch.long)
ms_test = torch.tensor(ms_test, dtype=torch.long)
smiles_train = torch.tensor(smiles_train, dtype=torch.long)
smiles_test = torch.tensor(smiles_test, dtype=torch.long)

# create datasets and dataloaders
batch_size_train = 5
batch_size_test = 5

train_dataset = TensorDataset(ms_train, smiles_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True)

test_dataset = TensorDataset(ms_test, smiles_test)
test_loader = DataLoader(test_dataset, batch_size=batch_size_test, shuffle=False)  # no need to shuffle for evaluation


Number of GC-MS Spectra for input: 100
Number of SMILES sequences for output: 100
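`spec_2_tensor` keeps the first `max_length` peaks in the order they are listed and pads with `(0, 0)`. The padded positions are never masked in the model, because `train_loop` passes no `src_pad_mask`. The pure-Python sketch below also returns a boolean padding mask. It optionally keeps the most intense peaks instead of the first ones, a hypothetical alternative that matters only if spectra are listed in ascending m/z and have more than `max_length` peaks. `parse_spectrum` and its `keep` option are illustrative assumptions.

```python
from typing import List, Tuple


def parse_spectrum(spectrum: str, max_peaks: int,
                   keep: str = "first") -> Tuple[List[Tuple[float, float]], List[bool]]:
    """Hypothetical helper: parse 'mz:intensity' pairs, truncate/pad to max_peaks, and build a pad mask.

    keep="first" keeps the first max_peaks peaks in listed order (the same rule as spec_2_tensor);
    keep="top" keeps the max_peaks most intense peaks and re-sorts them by m/z.
    The mask is True at padded positions, the convention used by src_key_padding_mask.
    """
    peaks = [(float(mz), float(i)) for mz, i in (item.split(":") for item in spectrum.split())]
    if keep == "top":
        by_intensity = sorted(peaks, key=lambda p: p[1], reverse=True)
        peaks = sorted(by_intensity[:max_peaks])  # back to ascending m/z
    else:
        peaks = peaks[:max_peaks]
    n_pad = max_peaks - len(peaks)
    return peaks + [(0.0, 0.0)] * n_pad, [False] * len(peaks) + [True] * n_pad


s = "39.0:50 41.0:300 43.0:999 57.0:400"
print(parse_spectrum(s, 3))              # first three peaks, no padding
print(parse_spectrum(s, 3, keep="top"))  # the 57.0 peak replaces the weak 39.0 peak
print(parse_spectrum(s, 6))              # four peaks plus two (0.0, 0.0) pads, mask ends in True, True
```

Converting the mask with `torch.tensor(mask, dtype=torch.bool)` gives a tensor that can be passed as `src_pad_mask` to the model's `forward`.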


In [86]:
import torch
import torch.nn as nn
import math

#positional encoder
class PositionalEncoding(nn.Module):
    def __init__(self, dim_model, dropout_p, max_len):
        super().__init__()
        self.dropout = nn.Dropout(dropout_p)
        
        pos_encoding = torch.zeros(max_len, dim_model)
        positions_list = torch.arange(0, max_len, dtype=torch.float).view(-1, 1)
        division_term = torch.exp(torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model)
        
        pos_encoding[:, 0::2] = torch.sin(positions_list * division_term)
        pos_encoding[:, 1::2] = torch.cos(positions_list * division_term)
        
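        # shape (1, max_len, dim_model): broadcasts over the batch dimension of
        # batch-first inputs of shape (batch, seq_len, dim_model)
        pos_encoding = pos_encoding.unsqueeze(0)
        self.register_buffer("pos_encoding", pos_encoding)
        
    def forward(self, token_embedding: torch.Tensor) -> torch.Tensor:
        # slice along the sequence dimension (dim 1), not the batch dimension
        return self.dropout(token_embedding + self.pos_encoding[:, :token_embedding.size(1), :])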

class Transformer(nn.Module):
    def __init__(self, num_tokens, dim_model, num_heads, num_encoder_layers, num_decoder_layers, dropout_p):
        super().__init__()

        self.model_type = "Transformer"
        self.dim_model = dim_model

        self.positional_encoder = PositionalEncoding(dim_model=dim_model, dropout_p=dropout_p, max_len=5000)
        self.embedding = nn.Embedding(num_tokens, dim_model)
        
        # Setting batch_first=True for better performance with nested tensors
        self.transformer = nn.Transformer(
            d_model=dim_model, 
            nhead=num_heads, 
            num_encoder_layers=num_encoder_layers, 
            num_decoder_layers=num_decoder_layers, 
            dropout=dropout_p, 
            batch_first=True  # setting this to True for nested tensors, error otherwise
        )
        self.out = nn.Linear(dim_model, num_tokens)
        
    def forward(self, src, tgt, tgt_mask=None, src_pad_mask=None, tgt_pad_mask=None):
        src = self.embedding(src) * math.sqrt(self.dim_model)
        tgt = self.embedding(tgt) * math.sqrt(self.dim_model)
        src = self.positional_encoder(src)
        tgt = self.positional_encoder(tgt)

        transformer_out = self.transformer(src, tgt, tgt_mask=tgt_mask, src_key_padding_mask=src_pad_mask, tgt_key_padding_mask=tgt_pad_mask)
        out = self.out(transformer_out)
        
        return out
    
    def get_tgt_mask(self, size) -> torch.Tensor:
        # additive causal mask: 0 where attention is allowed, -inf above the diagonal
        mask = torch.tril(torch.ones(size, size) == 1)
        mask = mask.float()
        mask = mask.masked_fill(mask == 0, float('-inf'))
        mask = mask.masked_fill(mask == 1, float(0.0))
        return mask
    
    def create_pad_mask(self, matrix: torch.Tensor, pad_token: int) -> torch.Tensor:
        # True at padded positions, the convention for src/tgt key padding masks
        return (matrix == pad_token)

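device = "cuda" if torch.cuda.is_available() else "cpu"
# src (spectral ids < src_vocab_size) and tgt (SMILES token ids) share one embedding table and
# one output layer, so num_tokens must cover the larger of the two vocabularies
num_tokens = max(src_vocab_size, tokenizer.vocab_size)
model = Transformer(num_tokens=num_tokens, dim_model=8, num_heads=2, num_encoder_layers=3, num_decoder_layers=3, dropout_p=0.1).to(device)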
opt = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

def train_loop(model, opt, loss_fn, dataloader):
    model.train()
    total_loss = 0
    
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        y_input = y[:, :-1]
        y_expected = y[:, 1:]
        
        sequence_length = y_input.size(1)
        tgt_mask = model.get_tgt_mask(sequence_length).to(device)

        pred = model(X, y_input, tgt_mask)
        
        # Flatten the predictions and expected outputs for computing the loss
        pred = pred.view(-1, pred.size(-1))
        y_expected = y_expected.contiguous().view(-1)
        
        loss = loss_fn(pred, y_expected)

        opt.zero_grad()
        loss.backward()
        opt.step()
    
        total_loss += loss.item()
        
    return total_loss / len(dataloader)

def validation_loop(model, loss_fn, dataloader):
    model.eval()
    total_loss = 0
    
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            y_input = y[:, :-1]
            y_expected = y[:, 1:]
            
            sequence_length = y_input.size(1)
            tgt_mask = model.get_tgt_mask(sequence_length).to(device)

            pred = model(X, y_input, tgt_mask)
            
            # Flatten the predictions and expected outputs for computing the loss
            pred = pred.view(-1, pred.size(-1))
            y_expected = y_expected.contiguous().view(-1)
            
            loss = loss_fn(pred, y_expected)
            total_loss += loss.item()
        
    return total_loss / len(dataloader)

def fit(model, opt, loss_fn, train_dataloader, val_dataloader, epochs):
    train_loss_list, validation_loss_list = [], []
    
    print("Training and validating model")
    for epoch in range(epochs):
        print("-"*25, f"Epoch {epoch + 1}","-"*25)
        
        train_loss = train_loop(model, opt, loss_fn, train_dataloader)
        train_loss_list.append(train_loss)
        
        validation_loss = validation_loop(model, loss_fn, val_dataloader)
        validation_loss_list.append(validation_loss)
        
        print(f"Training loss: {train_loss:.4f}")
        print(f"Validation loss: {validation_loss:.4f}")
        print()
        
    return train_loss_list, validation_loss_list

train_loss_list, validation_loss_list = fit(model, opt, loss_fn, train_loader, test_loader, 10)


Training and validating model
------------------------- Epoch 1 -------------------------
Training loss: 3.5624
Validation loss: 3.0133

------------------------- Epoch 2 -------------------------
Training loss: 2.9007
Validation loss: 2.6253

------------------------- Epoch 3 -------------------------
Training loss: 2.5886
Validation loss: 2.3684

------------------------- Epoch 4 -------------------------
Training loss: 2.3817
Validation loss: 2.1973

------------------------- Epoch 5 -------------------------
Training loss: 2.2283
Validation loss: 2.0614

------------------------- Epoch 6 -------------------------
Training loss: 2.1030
Validation loss: 1.9429

------------------------- Epoch 7 -------------------------
Training loss: 1.9904
Validation loss: 1.8249

------------------------- Epoch 8 -------------------------
Training loss: 1.8837
Validation loss: 1.7295

------------------------- Epoch 9 -------------------------
Training loss: 1.7870
Validation loss: 1.6298

-------
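With `loss_fn = nn.CrossEntropyLoss()` as configured above, every target position counts toward the loss, including the padding that can make up a large share of each 63-position target. `train_loop` also passes no padding masks to the model, so the losses printed above include the pad positions. A minimal sketch of the teacher-forcing shift, a causal mask, and a padding-aware loss follows. It uses toy ids and random logits in place of the model, and assumes the padding id is 0. `causal_mask` is a hypothetical helper that produces the same values as `get_tgt_mask`.

```python
import torch
import torch.nn as nn

PAD_ID = 0  # assumption: the padding id is 0


def causal_mask(size: int) -> torch.Tensor:
    """Additive causal mask: 0 on/below the diagonal, -inf above (same values as get_tgt_mask)."""
    return torch.triu(torch.full((size, size), float("-inf")), diagonal=1)


# Toy target batch (batch=2, seq=6): [CLS]=2 ... [SEP]=3, right-padded with PAD_ID
y = torch.tensor([[2, 4, 4, 5, 3, 0],
                  [2, 4, 3, 0, 0, 0]])
y_input, y_expected = y[:, :-1], y[:, 1:]  # teacher forcing, as in train_loop
tgt_mask = causal_mask(y_input.size(1))
tgt_pad_mask = y_input == PAD_ID           # True = ignore, for tgt_key_padding_mask

vocab_size = 8
torch.manual_seed(0)
logits = torch.randn(2, y_input.size(1), vocab_size)  # stand-in for model(X, y_input, tgt_mask)
flat_logits, flat_targets = logits.reshape(-1, vocab_size), y_expected.reshape(-1)

loss_all = nn.CrossEntropyLoss()(flat_logits, flat_targets)  # pads count toward the mean
loss_masked = nn.CrossEntropyLoss(ignore_index=PAD_ID)(flat_logits, flat_targets)

# ignore_index gives the same value as averaging over the non-pad targets only
keep = flat_targets != PAD_ID
manual = nn.CrossEntropyLoss()(flat_logits[keep], flat_targets[keep])
print(f"{loss_all.item():.4f} {loss_masked.item():.4f} {manual.item():.4f}")
```

Passing `ignore_index` to the loss keeps the model from being rewarded for predicting padding. `tgt_pad_mask` (and a matching `src_pad_mask`) could be passed to the model's `forward` alongside `tgt_mask`.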