## Spectra Encoder: Transformer

Primary reference: https://www.nature.com/articles/s42004-023-00932-3#Sec19

Using this paper as a framework, the purpose of this transformer is to take input GC-MS spectral data and output embeddings to be passed to the SMILES decoder. The reference used images of GC-MS data and implemented a CNN; we intend to use a transformer instead.  

#### Supplemental references:
https://jalammar.github.io/illustrated-transformer/ (Illustrated overview of Transformer function)

https://nlp.seas.harvard.edu/2018/04/03/attention.html (Harvard coding annotation of original Transformation paper)

https://www.datacamp.com/tutorial/building-a-transformer-with-py-torch (Datacamp Transformer tutorial)

Notebook overview:
1. Define model building blocks
2. Encoding
3. Decoding
4. Training
5. Evaluation


## Preparing the input data

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy
import pandas as pd
import numpy as np
from torch.autograd import Variable
import matplotlib.pyplot as plt
import seaborn
import csv

In [2]:
# finding unique characters in the SMILES column of training data 

unique_characters = set() 

with open('dataset/filtered_gc_spec.csv', 'r') as f:
    reader = csv.DictReader(f)  
    for row in reader:
        for char in row["SMILES"]:
            unique_characters.add(char)  # Add each character to the set

print(len(unique_characters))  


45


In [3]:
# finding unique tuples in the spectral training data
unique_tuples = set()  

with open('dataset/filtered_gc_spec.csv', 'r') as f:
    reader = csv.DictReader(f)
    for row in reader:
        spectrum_data = row["Spectrum"]
        tuples = spectrum_data.split() 
        for tup in tuples:
            if ':' in tup and tup.count(':') == 1: 
                unique_tuples.add(tup)

print(len(unique_tuples))  


517627


So we need to go from a "vocabulary" of 517627 unique tuples to 45 unique characters

In [2]:
# load input and output datasets for training

# load dataset
data = pd.read_csv('dataset/filtered_gc_spec.csv')
input_MS = pd.Series(data["Spectrum"][:5])
output_SMILES = pd.Series(data["SMILES"][:5])

assert len(input_MS)==len(output_SMILES) #sanity check to ensure correct loading

In [3]:
#filter input by length of SMILES (<77 as per SMILES encoder)

output_SMILES_filtered = output_SMILES[output_SMILES.str.len() < 77]

#filter input by the same indices
input_MS_filtered = input_MS.loc[output_SMILES_filtered.index]

assert len(input_MS_filtered)==len(output_SMILES_filtered) #sanity check to ensure correct filtering

print(f"Number of GC-MS Spectra for input: {len(input_MS_filtered)}")
print(f"Number of SMILES sequences for output: {len(output_SMILES_filtered)}")

Number of GC-MS Spectra for input: 5
Number of SMILES sequences for output: 5


## Tokenizing the SMILES data for training

In [145]:
pip install deepchem

Collecting deepchem
  Downloading deepchem-2.8.0-py3-none-any.whl.metadata (2.0 kB)
Downloading deepchem-2.8.0-py3-none-any.whl (1.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m0:01[0m
[?25hInstalling collected packages: deepchem
Successfully installed deepchem-2.8.0
Note: you may need to restart the kernel to use updated packages.


In [146]:
from deepchem.feat.smiles_tokenizer import SmilesTokenizer

No normalization for SPS. Feature removed!
No normalization for AvgIpc. Feature removed!
2024-12-01 19:19:06.842836: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1733109546.953096 1786219 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1733109546.988163 1786219 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-01 19:19:07.245322: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Instructions for updating:
experimental_relax_shapes is deprecated, use reduce_retracing instead


Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'torch_geometric'
Skipped loading modules with pytorch-geometric dependency, missing a dependency. cannot import name 'DMPNN' from 'deepchem.models.torch_models' (/home/marieanand/miniconda3/envs/msse-python/lib/python3.11/site-packages/deepchem/models/torch_models/__init__.py)
Skipped loading modules with pytorch-lightning dependency, missing a dependency. No module named 'lightning'
Skipped loading some Jax models, missing a dependency. No module named 'jax'


In [6]:
pip install transformers

Collecting transformers
  Downloading transformers-4.46.3-py3-none-any.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.1/44.1 kB[0m [31m902.4 kB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
Collecting tokenizers<0.21,>=0.20 (from transformers)
  Downloading tokenizers-0.20.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Collecting safetensors>=0.4.1 (from transformers)
  Downloading safetensors-0.4.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)
Downloading transformers-4.46.3-py3-none-any.whl (10.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.0/10.0 MB[0m [31m29.1 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[?25hDownloading safetensors-0.4.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (435 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m435.0/435.0 kB[0m [31m28.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading tokenizers

In [9]:
pip install tokenizers

Note: you may need to restart the kernel to use updated packages.


In [4]:
smiles_list = output_SMILES_filtered.tolist()

len(smiles_list)

5

In [5]:
# Requriments - transformers, tokenizers
# Right now, the Smiles Tokenizer uses an exiesting vocab file from rxnfp that is fairly comprehensive and from the USPTO dataset.
# The vocab may be expanded in the near future

import collections
import os
import re
import pkg_resources
from typing import List
from transformers import BertTokenizer
from logging import getLogger

logger = getLogger(__name__)
"""
SMI_REGEX_PATTERN: str
    SMILES regex pattern for tokenization. Designed by Schwaller et. al.

References

.. [1]  Philippe Schwaller, Teodoro Laino, Théophile Gaudin, Peter Bolgar, Christopher A. Hunter, Costas Bekas, and Alpha A. Lee
        ACS Central Science 2019 5 (9): Molecular Transformer: A Model for Uncertainty-Calibrated Chemical Reaction Prediction
        1572-1583 DOI: 10.1021/acscentsci.9b00576

"""

SMI_REGEX_PATTERN = r"""(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|
#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"""

# add vocab_file dict
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}


class SmilesTokenizer(BertTokenizer):
    def __init__(self, vocab_file: str = '', max_len: int = 512, **kwargs):
        """Constructs a SmilesTokenizer.

        Parameters
        ----------
        vocab_file: str
            Path to a SMILES character per line vocabulary file.
            Default vocab file is found in deepchem/feat/tests/data/vocab.txt
        max_len: int
            Maximum length for tokenized sequences.
        """
        self.max_len = max_len

        super().__init__(vocab_file=vocab_file, max_len=max_len, **kwargs)
        
        # take into account special tokens in max length
        #self.max_len_single_sentence = self.max_len - 2
        #self.max_len_sentences_pair = self.max_len - 3

        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocab file at path '{}'.".format(vocab_file)
            )
        self.vocab = load_vocab(vocab_file)
        self.highest_unused_index = max(
            [i for i, v in enumerate(self.vocab.keys()) if v.startswith("[unused")])
        self.ids_to_tokens = collections.OrderedDict(
            [(ids, tok) for tok, ids in self.vocab.items()])
        self.basic_tokenizer = BasicSmilesTokenizer()
        self.init_kwargs["max_len"] = self.max_len


    @property
    def vocab_size(self):
      return len(self.vocab)

    @property
    def vocab_list(self):
      return list(self.vocab.keys())

    def _tokenize(self, text: str):
      """
          Tokenize a string into a list of tokens.

          Parameters
          ----------
          text: str
              Input string sequence to be tokenized.
          """

      split_tokens = [token for token in self.basic_tokenizer.tokenize(text)]
      return split_tokens

    def _convert_token_to_id(self, token):
      """
          Converts a token (str/unicode) in an id using the vocab.

          Parameters
          ----------
          token: str
              String token from a larger sequence to be converted to a numerical id.
          """

      return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
      """
          Converts an index (integer) in a token (string/unicode) using the vocab.

          Parameters
          ----------
          index: int
              Integer index to be converted back to a string-based token as part of a larger sequence.
          """

      return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]):
      """ Converts a sequence of tokens (string) in a single string.

          Parameters
          ----------
          tokens: List[str]
              List of tokens for a given string sequence.

          Returns
          -------
          out_string: str
              Single string from combined tokens.
          """

      out_string: str = " ".join(tokens).replace(" ##", "").strip()
      return out_string

    def add_special_tokens_ids_single_sequence(self, token_ids: List[int]):
      """
          Adds special tokens to the a sequence for sequence classification tasks.
          A BERT sequence has the following format: [CLS] X [SEP]

          Parameters
          ----------

          token_ids: list[int]
              list of tokenized input ids. Can be obtained using the encode or encode_plus methods.
          """

      return [self.cls_token_id] + token_ids + [self.sep_token_id]

    def add_special_tokens_single_sequence(self, tokens: List[str]):
      """
          Adds special tokens to the a sequence for sequence classification tasks.
          A BERT sequence has the following format: [CLS] X [SEP]

          Parameters
          ----------
          tokens: List[str]
              List of tokens for a given string sequence.

          """
      return [self.cls_token] + tokens + [self.sep_token]

    def add_special_tokens_ids_sequence_pair(self, token_ids_0: List[int],
                                            token_ids_1: List[int]) -> List[int]:
      """
          Adds special tokens to a sequence pair for sequence classification tasks.
          A BERT sequence pair has the following format: [CLS] A [SEP] B [SEP]

          Parameters
          ----------
          token_ids_0: List[int]
              List of ids for the first string sequence in the sequence pair (A).

          token_ids_1: List[int]
              List of tokens for the second string sequence in the sequence pair (B).
          """

      sep = [self.sep_token_id]
      cls = [self.cls_token_id]

      return cls + token_ids_0 + sep + token_ids_1 + sep

    def add_padding_tokens(self,
                          token_ids: List[int],
                          length: int,
                          right: bool = True) -> List[int]:
      """
          Adds padding tokens to return a sequence of length max_length.
          By default padding tokens are added to the right of the sequence.

          Parameters
          ----------
          token_ids: list[int]
              list of tokenized input ids. Can be obtained using the encode or encode_plus methods.

          length: int

          right: bool (True by default)

          Returns
          ----------
          token_ids :
              list of tokenized input ids. Can be obtained using the encode or encode_plus methods.

          padding: int
              Integer to be added as padding token

          """
      padding = [self.pad_token_id] * (length - len(token_ids))

      if right:
        return token_ids + padding
      else:
        return padding + token_ids

    def save_vocabulary(
        self, vocab_path: str
    ):  # -> tuple[str]: doctest issue raised with this return type annotation
      """
          Save the tokenizer vocabulary to a file.

          Parameters
          ----------
          vocab_path: obj: str
              The directory in which to save the SMILES character per line vocabulary file.
              Default vocab file is found in deepchem/feat/tests/data/vocab.txt

          Returns
          ----------
          vocab_file: :obj:`Tuple(str)`:
              Paths to the files saved.
              typle with string to a SMILES character per line vocabulary file.
              Default vocab file is found in deepchem/feat/tests/data/vocab.txt

          """
      index = 0
      if os.path.isdir(vocab_path):
        vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
      else:
        vocab_file = vocab_path
      with open(vocab_file, "w", encoding="utf-8") as writer:
        for token, token_index in sorted(
            self.vocab.items(), key=lambda kv: kv[1]):
          if index != token_index:
            logger.warning(
                "Saving vocabulary to {}: vocabulary indices are not consecutive."
                " Please check that the vocabulary is not corrupted!".format(
                    vocab_file))
            index = token_index
          writer.write(token + "\n")
          index += 1
      return (vocab_file,)


class BasicSmilesTokenizer(object):
  """

    Run basic SMILES tokenization using a regex pattern developed by Schwaller et. al. This tokenizer is to be used
    when a tokenizer that does not require the transformers library by HuggingFace is required.

    Examples
    --------
    >>> from deepchem.feat.smiles_tokenizer import BasicSmilesTokenizer
    >>> tokenizer = BasicSmilesTokenizer()
    >>> print(tokenizer.tokenize("CC(=O)OC1=CC=CC=C1C(=O)O"))
    ['C', 'C', '(', '=', 'O', ')', 'O', 'C', '1', '=', 'C', 'C', '=', 'C', 'C', '=', 'C', '1', 'C', '(', '=', 'O', ')', 'O']


    References
    ----------
    .. [1]  Philippe Schwaller, Teodoro Laino, Théophile Gaudin, Peter Bolgar, Christopher A. Hunter, Costas Bekas, and Alpha A. Lee
            ACS Central Science 2019 5 (9): Molecular Transformer: A Model for Uncertainty-Calibrated Chemical Reaction Prediction
            1572-1583 DOI: 10.1021/acscentsci.9b00576

    """

  def __init__(self, regex_pattern: str = SMI_REGEX_PATTERN):
    """ Constructs a BasicSMILESTokenizer.
        Parameters
        ----------

        regex: string
            SMILES token regex

        """
    self.regex_pattern = regex_pattern
    self.regex = re.compile(self.regex_pattern)

  def tokenize(self, text):
    """ Basic Tokenization of a SMILES.
        """
    tokens = [token for token in self.regex.findall(text)]
    return tokens


def load_vocab(vocab_file):
  """Loads a vocabulary file into a dictionary."""
  vocab = collections.OrderedDict()
  with open(vocab_file, "r", encoding="utf-8") as reader:
    tokens = reader.readlines()
  for index, token in enumerate(tokens):
    token = token.rstrip("\n")
    vocab[token] = index
  return vocab

  import pkg_resources


In [6]:
vocab_file = 'vocab.txt'
tokenizer = SmilesTokenizer(vocab_file)

In [7]:
#define a function to pad the sequence if they're not the max length (6))

max_length = 64

def pad_seq(tokens, max_length):
    return tokens + [0] * (max_length - len(tokens))  # padding


#tokenize SMILES data
tokenized_smiles = [tokenizer.encode(smiles) for smiles in smiles_list]

#pad all entries
padded_smiles = [pad_seq(tokens, max_length) for tokens in tokenized_smiles]

#convert to tensor
smiles_tensor = torch.tensor(padded_smiles)

In [8]:
print(smiles_tensor.shape)

torch.Size([5, 64])


In [9]:
#converting the spectrum data to tensors


def spec_2_tensor(spectrum: str, max_length: 64):
    #takes in a string that has the spectrum and the max length
    spectrum_tuples = [(float(mz), float(intensity)) for mz, intensity in (item.split(":") for item in spectrum.split())]
    return spectrum_tuples[:max_length] + [(0, 0)] * (max_length - len(spectrum_tuples))  # Padding

input_MS_data = input_MS_filtered.apply(lambda x: spec_2_tensor(x, 64))

ms_tensor = torch.tensor(input_MS_data.tolist())


In [10]:
print(ms_tensor.shape)

torch.Size([5, 64, 2])


In [11]:
#create batches

from torch.utils.data import DataLoader, TensorDataset



# create a dataset and DataLoader
dataset = TensorDataset(ms_tensor, smiles_tensor)
batch_size = 8
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


# Defining model components

In [37]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
       
        # model parameters
        self.d_model = d_model # dimensions of the model
        self.num_heads = num_heads # number of attention heads
        self.d_k = d_model // num_heads # dimension of each head's key, query, and value
        
        # transformations
        self.W_q = nn.Linear(d_model, d_model) # query transformation
        self.W_k = nn.Linear(d_model, d_model) # key transformation
        self.W_v = nn.Linear(d_model, d_model) # value transformation
        self.W_o = nn.Linear(d_model, d_model) # output transformation
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # attention scores
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        # apply mask if necessary
        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(2)
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        
        # softmax to get attention probabilities
        attn_probs = torch.softmax(attn_scores, dim=-1)
        
        # compute final output
        output = torch.matmul(attn_probs, V)
        return output
        
    def split_heads(self, x):
        # Reshape the input to have num_heads
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
        
    def combine_heads(self, x):
        # Combine heads back to original shape
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        
    def forward(self, Q, K, V, mask=None):
        # Apply linear transformations and split heads
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        
        # Apply scaled dot-product attention
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        
        # Combine heads and apply output transformation
        output = self.W_o(self.combine_heads(attn_output))
        return output


In [13]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff): #inputs - dimensions and inner-layer dimensions
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

In [14]:
#embeddings class converts input/output tokens to vectors specified by the dimensions of our model
#softmax converts the output to probabilities

class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)

In [15]:
#positional encoding is used to inject token position info into the input
#otherwise transformer has no info about token position in the input sequence
#essentially uses offset sin/cos graphs based on position. freq/offset is different for each dimension

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()
        
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe.unsqueeze(0))
        
    def forward(self, x):
    # Ensure the dimensions match before adding
        if x.size(2) != self.pe.size(2):
            raise ValueError(f"Dimension mismatch: {x.size(2)} != {self.pe.size(2)}")
        return x + self.pe[:, :x.size(1)]
    #adds positional info to the input

## Implementing model

In [16]:
#defining encoder layer for class
#steps: multiattention, position feed forward, 2x layer normalization, dropout

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads) #self attention mechanism
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff) #attention output appended to original input
        self.norm1 = nn.LayerNorm(d_model) #normalization
        self.norm2 = nn.LayerNorm(d_model) #normalization
        self.dropout = nn.Dropout(dropout) #dropout
        
    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

In [17]:
#defining decoder layer for class
#steps: self attention, normalize, cross attention, normalize, feed forward, normalize

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        #initialize with previously designated parameters
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, src_mask, tgt_mask):
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x


In [30]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super(Transformer, self).__init__()

        # Embedding layers
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)  # src embedding
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)  # tgt embedding layer

        # Positional encoding
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

        # Transformer layers
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])  # encoder layers
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])  # decoder layers

        # Output layer
        self.fc = nn.Linear(d_model, tgt_vocab_size)  # linear layer to map to target vocab size
        self.dropout = nn.Dropout(dropout)  # dropout layer

    def generate_mask(self, src, tgt, pad_idx):
        # sequence lengths for source and target
        src_len = src.size(1)  # source sequence length
        tgt_len = tgt.size(1)  # target sequence length

        # create source mask (1 for real tokens, 0 for padding tokens)
        src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)  # (batch_size, 1, 1, seq_len)

        # generate nopeak mask for the target sequence (upper triangular matrix)
        nopeak_mask = torch.triu(torch.ones(tgt_len, tgt_len), diagonal=1).bool()  # Shape: (tgt_len, tgt_len)

        # adjust nopeak_mask size to match the target sequence length dynamically
        nopeak_mask = nopeak_mask[:tgt_len, :tgt_len]  # ensure nopeak_mask size matches target length

        # ensure the tgt_mask has compatible dimensions for broadcasting
        tgt_mask = (tgt != pad_idx).unsqueeze(1).unsqueeze(3)  # Shape: (batch_size, 1, seq_len, 1)
        tgt_mask = tgt_mask & nopeak_mask.unsqueeze(0).unsqueeze(0)  # Shape: (batch_size, 1, seq_len, seq_len)

        return src_mask, tgt_mask


    def forward(self, src, tgt, pad_idx=0):
        # Generate masks
        src_mask, tgt_mask = self.generate_mask(src, tgt, pad_idx)
        
        # apply embedding + positional encoding
        src_embedded = self.encoder_embedding(src)
        src_embedded = self.positional_encoding(src_embedded)  # Add positional encoding to source embeddings
        src_embedded = self.dropout(src_embedded)  # Apply dropout to source embeddings

        tgt_embedded = self.decoder_embedding(tgt)
        tgt_embedded = self.positional_encoding(tgt_embedded)  # Add positional encoding to target embeddings
        tgt_embedded = self.dropout(tgt_embedded)  # Apply dropout to target embeddings

        # Pass through encoder layers
        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        # Pass through decoder layers
        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        # Linear layer to output the final predictions
        output = self.fc(dec_output)
        return output


## Training

In [42]:
# define model parameters
src_vocab_size = 64
tgt_vocab_size = max(smiles_tensor.max() + 1, 64)  # ensure vocab size is at least the largest index in smiles_tensor
d_model = 64
num_heads = 8
num_layers = 6
d_ff = 2048  # From original paper
max_seq_len = max(ms_tensor.size(1), smiles_tensor.size(1))
dropout = 0.1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [45]:
linear_layer = nn.Linear(2, src_vocab_size)  # map from 2 features to vocab size
ms_tensor_indices = linear_layer(ms_tensor.view(-1, 2))  # flatten for batch processing

ms_tensor_indices = ms_tensor_indices.argmax(dim=1)  # select the index with the highest value
ms_tensor_indices = ms_tensor_indices.view(5, 64)  # reshape back to (batch_size, seq_length)

In [46]:
transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_len, dropout)

transformer = transformer.to(device)
ms_tensor_indices = ms_tensor_indices.to(device)
smiles_tensor = smiles_tensor.to(device)

ms_tensor_indices = ms_tensor_indices.long()
smiles_tensor = smiles_tensor.long()

# Forward pass
output = transformer(ms_tensor_indices, smiles_tensor, pad_idx=0)

ValueError: too many values to unpack (expected 4)

## Evaluation

In [None]:
#code from example, haven't tested yet

transformer.eval()
#setting transformer into evaluation mode

# Generate random sample validation data
val_src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
val_tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)

with torch.no_grad():

    val_output = transformer(val_src_data, val_tgt_data[:, :-1])
    val_loss = criterion(val_output.contiguous().view(-1, tgt_vocab_size), val_tgt_data[:, 1:].contiguous().view(-1))
    print(f"Validation Loss: {val_loss.item()}")