## Spectra Encoder: Transformer

Primary reference: https://www.nature.com/articles/s42004-023-00932-3#Sec19

Using this paper as a framework, the purpose of this transformer is to take input GC-MS spectral data and output embeddings to be passed to the SMILES decoder. The reference used images of GC-MS data and implemented a CNN; we intend to use a transformer instead.  

#### Supplemental references:
https://jalammar.github.io/illustrated-transformer/ (Illustrated overview of Transformer function)

https://nlp.seas.harvard.edu/2018/04/03/attention.html (Harvard coding annotation of original Transformation paper)

https://www.datacamp.com/tutorial/building-a-transformer-with-py-torch (Datacamp Transformer tutorial)

Notebook overview:
1. Define model building blocks
2. Encoding
3. Decoding
4. Training
5. Evaluation


## Preparing the input data

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy
import pandas as pd
import numpy as np
from torch.autograd import Variable
import matplotlib.pyplot as plt
import seaborn
import csv

In [2]:
# finding unique characters in the SMILES column of training data 

unique_characters = set() 

with open('dataset/filtered_gc_spec.csv', 'r') as f:
    reader = csv.DictReader(f)  
    for row in reader:
        for char in row["SMILES"]:
            unique_characters.add(char)  # Add each character to the set

print(len(unique_characters))  


45


In [3]:
# finding unique tuples in the spectral training data
unique_tuples = set()  

with open('dataset/filtered_gc_spec.csv', 'r') as f:
    reader = csv.DictReader(f)
    for row in reader:
        spectrum_data = row["Spectrum"]
        tuples = spectrum_data.split() 
        for tup in tuples:
            if ':' in tup and tup.count(':') == 1: 
                unique_tuples.add(tup)

print(len(unique_tuples))  


517627


So we need to go from a "vocabulary" of 517627 unique tuples to 45 unique characters

## Implementing model

In [None]:
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from transformers import BartTokenizer  

# BART tokenizer
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")

# load data
data = pd.read_csv('dataset/filtered_gc_spec.csv')
input_MS = pd.Series(data["Spectrum"][:100])
output_SMILES = pd.Series(data["SMILES"][:100])

assert len(input_MS) == len(output_SMILES)  # sanity check to ensure correct loading

# filter input by length of SMILES (<77 as per SMILES encoder)
output_SMILES_filtered = output_SMILES[output_SMILES.str.len() < 77]
input_MS_filtered = input_MS.loc[output_SMILES_filtered.index]

assert len(input_MS_filtered) == len(output_SMILES_filtered)  # sanity check to ensure correct filtering

print(f"Number of MS Spectra for input: {len(input_MS_filtered)}")
print(f"Number of SMILES sequences for output: {len(output_SMILES_filtered)}")

smiles_list = output_SMILES_filtered.tolist()

# tokenize SMILES data using BART tokenizer
tokenized_smiles = [tokenizer.encode(smiles, add_special_tokens=True, padding='max_length', truncation=True, max_length=77) for smiles in smiles_list]

smiles_tensor = torch.tensor(tokenized_smiles, dtype=torch.long)

# set max length of MS to 200
max_length = 200

# converting MS data to tensor
def spec_2_tensor(spectrum, max_length):
    spectrum_tuples = [(float(mz), float(intensity)) for mz, intensity in (item.split(":") for item in spectrum.split())]
    return spectrum_tuples[:max_length] + [(0, 0)] * (max_length - len(spectrum_tuples))  # padding as zeroes

input_MS_data = input_MS_filtered.apply(lambda x: spec_2_tensor(x, max_length))
ms_tensor = torch.tensor(input_MS_data.tolist(), dtype=torch.float32)

# linear layer to map from 2 features to vocab size
src_vocab_size = 5000 #should be 500,000, but my computer can't handle it
linear_layer = nn.Linear(2, src_vocab_size)

# flatten, transform, and reshape back
ms_tensor_flat = ms_tensor.view(-1, 2)  # flatten for batch processing
ms_tensor_transformed = linear_layer(ms_tensor_flat)
ms_tensor_indices = ms_tensor_transformed.argmax(dim=1)  # select the index with the highest value
ms_tensor_indices = ms_tensor_indices.view(len(input_MS_filtered), max_length)  # reshape back

# train/test split
ms_train, ms_test, smiles_train, smiles_test = train_test_split(ms_tensor_indices.numpy(), smiles_tensor.numpy(), test_size=0.2, random_state=42)

# convert back to tensors
ms_train = torch.tensor(ms_train, dtype=torch.long)
ms_test = torch.tensor(ms_test, dtype=torch.long)
smiles_train = torch.tensor(smiles_train, dtype=torch.long)
smiles_test = torch.tensor(smiles_test, dtype=torch.long)

# create datasets and dataloaders
batch_size_train = 10 #keep this low to avoid kernel crashing
batch_size_test = 10

train_dataset = TensorDataset(ms_train, smiles_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True)

test_dataset = TensorDataset(ms_test, smiles_test)
test_loader = DataLoader(test_dataset, batch_size=batch_size_test, shuffle=True)

Number of MS Spectra for input: 100
Number of SMILES sequences for output: 100


In [None]:
import torch
import torch.nn as nn
import math
from transformers import BartTokenizer  # import BartTokenizer

# initialize the BART tokenizer
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")

# Positional encoder
class PositionalEncoding(nn.Module):
    def __init__(self, dim_model, dropout_p, max_len):
        super().__init__()
        self.dropout = nn.Dropout(dropout_p)
        
        pos_encoding = torch.zeros(max_len, dim_model)
        positions_list = torch.arange(0, max_len, dtype=torch.float).view(-1, 1)
        division_term = torch.exp(torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model)
        
        pos_encoding[:, 0::2] = torch.sin(positions_list * division_term)
        pos_encoding[:, 1::2] = torch.cos(positions_list * division_term)
        
        pos_encoding = pos_encoding.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pos_encoding", pos_encoding)
        
    def forward(self, token_embedding: torch.tensor) -> torch.tensor:
        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])

class Transformer(nn.Module):
    def __init__(self, num_tokens, dim_model, num_heads, num_encoder_layers, num_decoder_layers, dropout_p):
        super().__init__()

        self.model_type = "Transformer"
        self.dim_model = dim_model

        self.positional_encoder = PositionalEncoding(dim_model=dim_model, dropout_p=dropout_p, max_len=5000)
        self.embedding = nn.Embedding(num_tokens, dim_model)
        
        self.transformer = nn.Transformer(
            d_model=dim_model, 
            nhead=num_heads, 
            num_encoder_layers=num_encoder_layers, 
            num_decoder_layers=num_decoder_layers, 
            dropout=dropout_p, 
            batch_first=True  # setting this to True for nested tensors, error otherwise
        )
        self.out = nn.Linear(dim_model, num_tokens)
        
    def forward(self, src, tgt, tgt_mask=None, src_pad_mask=None, tgt_pad_mask=None):
        src = self.embedding(src) * math.sqrt(self.dim_model)
        tgt = self.embedding(tgt) * math.sqrt(self.dim_model)
        src = self.positional_encoder(src)
        tgt = self.positional_encoder(tgt)

        transformer_out = self.transformer(src, tgt, tgt_mask=tgt_mask, src_key_padding_mask=src_pad_mask, tgt_key_padding_mask=tgt_pad_mask)
        out = self.out(transformer_out)
        
        return out
    
    def get_tgt_mask(self, size) -> torch.tensor:
        mask = torch.tril(torch.ones(size, size) == 1)
        mask = mask.float()
        mask = mask.masked_fill(mask == 0, float('-inf'))
        mask = mask.masked_fill(mask == 1, float(0.0))
        return mask
    
    def create_pad_mask(self, matrix: torch.tensor, pad_token: int) -> torch.tensor:
        return (matrix == pad_token)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(num_tokens=tokenizer.vocab_size, dim_model=8, num_heads=8, num_encoder_layers=12, num_decoder_layers=12, dropout_p=0.1).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

from sklearn.metrics import accuracy_score, precision_score, f1_score


def untokenize_smiles(token_ids):
    #remove autogenerated padding tokens
    token_ids = [token_id for token_id in token_ids if token_id != tokenizer.pad_token_id]
    
    # convert token ids back to tokens
    tokens = tokenizer.convert_ids_to_tokens(token_ids)

    # tokens into a string 
    smiles = tokenizer.convert_tokens_to_string(tokens)
    
    return smiles

def calculate_metrics(predictions, targets, average='macro'):    
    # flatten the lists of token IDs to compare tokens at each position in the sequence
    predictions_flat = [item for sublist in predictions for item in sublist]
    targets_flat = [item for sublist in targets for item in sublist]

    # token-level comparison
    accuracy = accuracy_score(targets_flat, predictions_flat)

    # token-level comparisons
    precision = precision_score(targets_flat, predictions_flat, average=average, zero_division=1)
    f1 = f1_score(targets_flat, predictions_flat, average=average, zero_division=1)

    return accuracy, precision, f1

def train_loop(model, opt, loss_fn, train_dataloader, tokenizer):
    model.train()
    total_loss = 0.0
    all_predictions = []
    all_targets = []

    for batch in train_dataloader:
        X, y = batch
        opt.zero_grad()

        pred = model(X, y)  # [batch_size, seq_len, vocab_size]
        
        # reshape pred to [batch_size * seq_len, vocab_size] 
        pred_flat = pred.view(-1, pred.size(-1))  # flatten sequence and vocab dimensions
        y_flat = y.view(-1)  # flatten target sequence

        # loss cross entropy expects input: [batch_size * seq_len, vocab_size] targets: [batch_size * seq_len])
        loss = loss_fn(pred_flat, y_flat)
        
        loss.backward()
        opt.step()

        # pred token ids
        pred_token_ids = pred.argmax(dim=-1).tolist()

        # tokenized seq metrics
        accuracy, precision, f1 = calculate_metrics(pred_token_ids, y.tolist())
        print(f"Train Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, F1: {f1:.4f}")

        # store preds and targets
        all_predictions.extend(pred_token_ids)
        all_targets.extend(y.tolist())

        total_loss += loss.item()

    return total_loss / len(train_dataloader), all_predictions, all_targets

def val_loop(model, loss_fn, val_dataloader, tokenizer):
    model.eval()
    total_loss = 0.0
    all_predictions = []
    all_targets = []

    with torch.no_grad():
        for batch in val_dataloader:
            X, y = batch
            pred = model(X, y)

            pred_flat = pred.view(-1, pred.size(-1))  
            y_flat = y.view(-1)  

            loss = loss_fn(pred_flat, y_flat)

            pred_token_ids = pred.argmax(dim=-1).tolist()

            accuracy, precision, f1 = calculate_metrics(pred_token_ids, y.tolist())
            print(f"Validation Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, F1: {f1:.4f}")

            all_predictions.extend(pred_token_ids)
            all_targets.extend(y.tolist())

            total_loss += loss.item()

    return total_loss / len(val_dataloader), all_predictions, all_targets

def fit(model, opt, loss_fn, train_loader, val_loader, epochs, tokenizer):
    train_loss_list = []
    validation_loss_list = []

    for epoch in range(epochs):
        print("-"*25, f"Epoch {epoch + 1}","-"*25)

        # train 
        train_loss, train_preds, train_targets = train_loop(model, opt, loss_fn, train_loader, tokenizer)
        train_loss_list.append(train_loss)

        # validation 
        val_loss, val_preds, val_targets = val_loop(model, loss_fn, val_loader, tokenizer)
        validation_loss_list.append(val_loss)

        # save generated SMILES
        #generated_smiles = [untokenize_smiles(p) for p in val_preds]  
        #smiles_df = pd.DataFrame({'Generated_SMILES': generated_smiles})
        #smiles_df.to_csv(f'generated_smiles_epoch_{epoch+1}.csv', index=False)

        print(f"Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")
        print("-"*50)

    return train_loss_list, validation_loss_list

train_loss_list, validation_loss_list = fit(model, opt, loss_fn, train_loader, test_loader, 20, tokenizer)


------------------------- Epoch 1 -------------------------
Train Accuracy: 0.0000, Precision: 0.4894, F1: 0.0000
Train Accuracy: 0.0000, Precision: 0.3529, F1: 0.0000
Train Accuracy: 0.0000, Precision: 0.2727, F1: 0.0000
Train Accuracy: 0.0000, Precision: 0.2273, F1: 0.0000
Train Accuracy: 0.0000, Precision: 0.6970, F1: 0.0000
Train Accuracy: 0.0000, Precision: 0.6286, F1: 0.0000
Train Accuracy: 0.0000, Precision: 0.4783, F1: 0.0000
Train Accuracy: 0.0000, Precision: 0.4800, F1: 0.0000
Validation Accuracy: 0.0000, Precision: 0.9565, F1: 0.0000
Validation Accuracy: 0.0000, Precision: 0.9565, F1: 0.0000
Train Loss: 10.5278, Validation Loss: 10.0218
--------------------------------------------------
------------------------- Epoch 2 -------------------------
Train Accuracy: 0.0000, Precision: 0.3958, F1: 0.0000
Train Accuracy: 0.0000, Precision: 0.4792, F1: 0.0000
Train Accuracy: 0.0000, Precision: 0.4318, F1: 0.0000
Train Accuracy: 0.0000, Precision: 0.5227, F1: 0.0000
Train Accuracy: 0