# BERT Model for Understanding Mutations in Protein Sequences

## 0. init

In [1]:
import numpy as np
import random
import pandas as pd
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import tqdm
from torch.optim import Adam
import time

import string
from typing import Iterable, Tuple, Optional, List

## 1. Building Dataset and DataLoader

In [2]:
file_path = 'X_set.txt'

# Each non-blank line is expected to be "<species_code> <aligned_amino_acid_sequence>".
# Initialize lists to hold the phylogenetic position strings and amino acid sequences
specie_code = []
amino_acid_sequences = []

# Read the file
with open(file_path, 'r') as file:
    for line_no, line in enumerate(file, start=1):
        # split() without an argument tolerates repeated spaces/tabs and the newline
        parts = line.split()
        if not parts:
            # skip blank lines (e.g. a trailing empty line) instead of failing on parts[1]
            continue
        if len(parts) < 2:
            raise ValueError(
                f"{file_path}:{line_no}: expected '<species_code> <sequence>', got {line.strip()!r}"
            )
        specie_code.append(parts[0])
        amino_acid_sequences.append(parts[1])

amino_acid_sequences[0:3]

['---LSQF--LLMLWVPGSKGEIVLTQSPASVSVSPGERVTISCQASESVGNTYLNWLQQKSGQSPRWLIYQVSKLESGIPARFRGSGSGTDFTFTISRVEAEDVAHYYSQQ-----',
 'MESLSQC--LLMLWVPVSRGAIVLTQSPALVSVSPGERVTISCKASQSVGNTYLSWFRQKPGQSPRGLIYKVSNLPSGVPSRFRGSGAEKDFTLTISRVEAVDGAVYYCAQASYSP',
 'MESLSQC--LLMLWVPVSRGAIVLTQSPASVSVSPGERVTISCKASQSLGNTYLHWFQQKPGQSPRRLIYQVSNLLSGVPSRFSGSGAGKDFSLTISSVEAGDGAVYYCFQGSYDP']

In [3]:
# Number of sequences read from the file.
len(amino_acid_sequences)

1001

In [4]:
# Besides the sequences themselves, two more per-sequence inputs are needed.

# 1. Protein type: every sequence is labelled 'A1' here.
protein_types = len(amino_acid_sequences) * ['A1']

# 2. Per-species weight: uniform random placeholder values in [0, 1) for now.
specie_weight = torch.rand(len(amino_acid_sequences))


### 1.1. Tokenizer

- There are 20 standard amino acids; each letter in a sequence represents one of them.
- Each amino acid is mapped to its own integer token, giving 20 amino-acid tokens.
- One extra token is needed for "-", the gap character produced by multiple sequence alignment (MSA). `[MASK]` and `[PAD]` tokens are also added for masked-language-model training and batching.

In [5]:
# Collect the distinct residue letters, ignoring the alignment gap character "-".
# We expect exactly 20 (the standard amino acids).
amino_acid_set = {acid for seq in amino_acid_sequences for acid in seq if acid != "-"}

print(f"Num of Amino Acids: {len(amino_acid_set)}")
# sorted() makes the order reproducible: iteration order of a set of str depends on
# per-process hash randomisation, so list(amino_acid_set) changes between kernel runs.
amino_acids_list = sorted(amino_acid_set)

Num of Amino Acids: 20


In [6]:
# Creating a Tokenizer class, which encodes and decodes an amino acid sequence

class AminoAcidTokenizer:
    '''
    To encode and decode any amino acid string.

    Each of the 20 amino acids, followed by any optional special tokens, is
    mapped to the integer index of its position in the vocabulary.

    Args:
        special_tokens: extra tokens (e.g. "-", "[MASK]", "[PAD]") appended to the
            vocabulary after the amino acids, in the given order.
    '''
    # class attribute
    # all 20 types of amino acids (they receive indices 0..19 in this order)
    amino_acids = ['S','D','H','L','T','E','W','N','Y','Q','C','G','V','K','I','R','M','F','A','P']

    def __init__(self, special_tokens: Optional[Iterable[str]] = None):
        # Copy the class-level list: `+=` on the shared list would mutate
        # AminoAcidTokenizer.amino_acids itself, so every later instance would
        # start with the previous instance's special tokens already appended.
        self.vocab = list(AminoAcidTokenizer.amino_acids)
        if special_tokens:
            self.vocab += list(special_tokens)
        # mapping each vocab entry to a token id (a numeric value)
        self.token2idx = {token: i for i, token in enumerate(self.vocab)}
        # mapping numeric value back to a token
        self.idx2token = {i: token for token, i in self.token2idx.items()}

    def encode(self, inputs: Iterable[str]) -> Iterable[int]:
        '''Map each token in `inputs` to its index (KeyError for unknown tokens).'''
        return [self.token2idx[token] for token in inputs]

    def decode(self, inputs: Iterable[int]) -> Iterable[str]:
        '''Map each index in `inputs` back to its token (KeyError for unknown indices).'''
        return [self.idx2token[idx] for idx in inputs]

    def __len__(self):
        return len(self.vocab)

    @property
    def vocab_size(self) -> int:
        return len(self.vocab)


In [7]:
# Build the tokenizer with the gap, mask and padding tokens appended to the vocabulary.
amino_acid_tokenizer = AminoAcidTokenizer(special_tokens=["-", "[MASK]", "[PAD]"])

# Round-trip the first 20 positions of the first sequence: raw -> encoded -> decoded.
print(f"First 20 amino acids         : {list(amino_acid_sequences[0][:20])}")
print(f"First 20 encoded amino acids : {amino_acid_tokenizer.encode(amino_acid_sequences[0])[:20]}")
print(f"First 20 decoded amino acids : {amino_acid_tokenizer.decode(amino_acid_tokenizer.encode(amino_acid_sequences[0])[:20])}")

First 20 amino acids         : ['-', '-', '-', 'L', 'S', 'Q', 'F', '-', '-', 'L', 'L', 'M', 'L', 'W', 'V', 'P', 'G', 'S', 'K', 'G']
First 20 encoded amino acids : [20, 20, 20, 3, 0, 9, 17, 20, 20, 3, 3, 16, 3, 6, 12, 19, 11, 0, 13, 11]
First 20 decoded amino acids : ['-', '-', '-', 'L', 'S', 'Q', 'F', '-', '-', 'L', 'L', 'M', 'L', 'W', 'V', 'P', 'G', 'S', 'K', 'G']


In [8]:
# Show the full token -> index mapping (amino acids first, then the special tokens).
print(amino_acid_tokenizer.token2idx)

{'S': 0, 'D': 1, 'H': 2, 'L': 3, 'T': 4, 'E': 5, 'W': 6, 'N': 7, 'Y': 8, 'Q': 9, 'C': 10, 'G': 11, 'V': 12, 'K': 13, 'I': 14, 'R': 15, 'M': 16, 'F': 17, 'A': 18, 'P': 19, '-': 20, '[MASK]': 21, '[PAD]': 22}


In [9]:
# Vocabulary size: 20 amino acids + 3 special tokens ("-", "[MASK]", "[PAD]").
amino_acid_tokenizer.vocab_size

23

In [10]:
# similar to the Amino Acid tokenizer, creating a Protein tokenizer

class ProteinTokenizer:
    '''
    To encode and decode protein-type labels (e.g. 'A1', 'A2').

    Each protein type, followed by any optional special tokens, is mapped to
    the integer index of its position in the vocabulary.

    Args:
        special_tokens: extra tokens appended to the vocabulary after the
            protein types, in the given order.
    '''
    # class attribute
    # known protein types (attribute name kept as-is for compatibility)
    protiens = ['A1', 'A2']

    def __init__(self, special_tokens: Optional[Iterable[str]] = None):
        # Copy the class-level list so `+=` never mutates ProteinTokenizer.protiens,
        # which would leak special tokens into every later instance.
        self.vocab = list(ProteinTokenizer.protiens)
        if special_tokens:
            self.vocab += list(special_tokens)
        # mapping each vocab entry to a token id (a numeric value)
        self.token2idx = {token: i for i, token in enumerate(self.vocab)}
        # mapping numeric value back to a token
        self.idx2token = {i: token for token, i in self.token2idx.items()}

    def encode(self, inputs: Iterable[str]) -> Iterable[int]:
        '''Map each label in `inputs` to its index (KeyError for unknown labels).'''
        return [self.token2idx[token] for token in inputs]

    def decode(self, inputs: Iterable[int]) -> Iterable[str]:
        '''Map each index in `inputs` back to its label (KeyError for unknown indices).'''
        return [self.idx2token[idx] for idx in inputs]

    def __len__(self):
        return len(self.vocab)

    @property
    def vocab_size(self) -> int:
        return len(self.vocab)


In [11]:
# Protein-type tokenizer without special tokens: vocabulary is just the known protein types.
protein_tokenizer = ProteinTokenizer()
protein_tokenizer.token2idx

{'A1': 0, 'A2': 1}

### 1.3 Create Training Data

In [12]:
# Tokenize every amino-acid sequence together with its protein type.
def create_encoded_tensors(
    amino_acid_sequences: List[str],
    protein_types: List[str],
    amino_acid_tokenizer: AminoAcidTokenizer,
    protein_tokenizer: ProteinTokenizer
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """
    Turn parallel lists of sequences and protein-type labels into int64 tensors.

    Sequences and labels are paired positionally with zip, so the shorter of
    the two lists determines how many pairs are produced.

    Args:
        amino_acid_sequences (List[str]): Amino-acid strings, one per sample.
        protein_types (List[str]): Protein-type label for each sample.
        amino_acid_tokenizer (AminoAcidTokenizer): Encodes each residue character.
        protein_tokenizer (ProteinTokenizer): Encodes the protein-type label.

    Returns:
        Tuple[List[torch.Tensor], List[torch.Tensor]]:
            - 1-D int64 tensors of residue ids, one per sequence.
            - 1-element int64 tensors holding each protein-type id.
    """
    encoded_pairs = [
        (
            torch.tensor(amino_acid_tokenizer.encode(seq), dtype=torch.int64),
            torch.tensor(protein_tokenizer.encode([p_type]), dtype=torch.int64),
        )
        for seq, p_type in zip(amino_acid_sequences, protein_types)
    ]
    amino_acid_tensors = [aa_tensor for aa_tensor, _ in encoded_pairs]
    protein_type_tensors = [type_tensor for _, type_tensor in encoded_pairs]
    return amino_acid_tensors, protein_type_tensors

In [13]:
# Encode every sequence and its protein type into int64 tensors.
encoded_amino_acids, encoded_protein_types = create_encoded_tensors(
    amino_acid_sequences=amino_acid_sequences,
    protein_types=protein_types,
    amino_acid_tokenizer=amino_acid_tokenizer,
    protein_tokenizer=protein_tokenizer,
)

In [14]:
# both lists must have the same length (one protein type per sequence)
print(len(encoded_amino_acids))
print(len(encoded_protein_types))
assert len(encoded_amino_acids) == len(encoded_protein_types), "sequence / protein-type count mismatch"

1001
1001


In [15]:
from torch.utils.data import Dataset, DataLoader

class MaskedAminoSeqDataset(Dataset):
    """
    Dataset for masked amino acid sequence prediction (BERT-style masked LM).

    Masked positions are re-sampled on every __getitem__ call, so the same
    sequence yields different training examples on different accesses.
    """

    def __init__(self, encoded_amino_acids: list,
                 encoded_protein_types: list,
                 specie_weight: torch.Tensor,
                 mask_token: int,
                 pad_token: int,
                 max_len: int,
                 min_masks: int = 1,
                 max_masks: int = 5):
        """
        Args:
            encoded_amino_acids (list): Encoded amino acid sequences (expected 1-D tensors).
            encoded_protein_types (list): Encoded protein types, one per sequence.
            specie_weight (torch.Tensor): Weight associated with each species, indexed per sample.
            mask_token (int): The token used for masking.
            pad_token (int): The token used for padding.
            max_len (int): Length every sequence is padded or truncated to.
            min_masks (int, optional): Minimum number of masked tokens per sample. Defaults to 1.
            max_masks (int, optional): Maximum number of masked tokens per sample; also the fixed
                length of the returned mask-position tensor. Defaults to 5.

        Raises:
            ValueError: if not 0 <= min_masks <= max_masks.
        """
        if min_masks < 0 or max_masks < min_masks:
            raise ValueError(f"need 0 <= min_masks <= max_masks, got {min_masks} and {max_masks}")
        self.encoded_amino_acids = encoded_amino_acids
        self.encoded_protein_types = encoded_protein_types
        self.specie_weight = specie_weight
        self.mask_token = mask_token
        self.pad_token = pad_token
        self.max_len = max_len
        self.min_masks = min_masks
        self.max_masks = max_masks

    def __len__(self):
        return len(self.encoded_amino_acids)

    def __getitem__(self, idx):
        """Return (input_seq, target_seq, mask_positions, protein_type, weight, padding_mask) for sample `idx`."""
        return self._create_training_data(self.encoded_amino_acids[idx],
                                          self.encoded_protein_types[idx],
                                          self.specie_weight[idx],
                                          self.mask_token,
                                          self.pad_token,
                                          self.max_len,
                                          self.min_masks,
                                          self.max_masks)

    def _create_training_data(self, encoded_amino_acid: torch.Tensor,
                              encoded_protein_type: torch.Tensor,
                              specie_weight: torch.Tensor,
                              mask_token: int,
                              pad_token: int,
                              max_len: int,
                              min_masks: int = 1,
                              max_masks: int = 5):
        """
        Create one training example for masked amino acid sequence prediction.

        Randomly masks between `min_masks` and `max_masks` positions of the
        sequence, capped at the sequence length. The masked sequence is then
        padded or truncated to `max_len`.

        Args:
            encoded_amino_acid (torch.Tensor): Encoded amino acid sequence (1-D).
            encoded_protein_type (torch.Tensor): Encoded protein type.
            specie_weight (torch.Tensor): Weight associated with the species.
            mask_token (int): Token used for masking.
            pad_token (int): Token used for padding.
            max_len (int): Output sequence length.
            min_masks (int, optional): Minimum number of tokens to mask. Defaults to 1.
            max_masks (int, optional): Maximum number of tokens to mask. Defaults to 5.

        Returns:
            tuple:
                - input_seq (torch.Tensor): (max_len,) sequence with masked tokens and padding.
                - target_seq (torch.Tensor): (max_len,) original token at masked positions, -100 elsewhere.
                - mask_positions (torch.Tensor): (max_masks,) masked indices, right-padded with -1.
                - encoded_protein_type (torch.Tensor): Encoded protein type, at least 1-D.
                - specie_weight (torch.Tensor): Weight associated with the species.
                - padding_masks (torch.Tensor): (max_len,) bool, True for real tokens, False for padding.

        Notes:
            - Masks are sampled over the full original sequence. Masks that land
              beyond `max_len` are dropped by truncation, so a truncated sample
              may end up with fewer than `min_masks` masks.
            - An empty sequence yields zero masks instead of raising.
        """
        seq_len = encoded_amino_acid.shape[0]

        # Number of masks: uniform in [min_masks, max_masks], but never more than the
        # sequence has positions (0 for an empty sequence, where randint(1, 1) would fail).
        upper = min(max_masks, seq_len)
        lower = min(min_masks, upper)
        num_masks = torch.randint(lower, upper + 1, (1,)).item()

        # Mask positions are sampled over the original (unpadded, untruncated) sequence.
        mask_positions = torch.randperm(seq_len)[:num_masks]

        masked_seq = encoded_amino_acid.clone()
        masked_seq[mask_positions] = mask_token

        # Pad or truncate the masked sequence to max_len.
        if seq_len < max_len:
            padding = torch.full((max_len - seq_len,), pad_token, dtype=encoded_amino_acid.dtype)
            input_seq = torch.cat([masked_seq, padding])
        else:
            input_seq = masked_seq[:max_len]

        # Drop masks that fell into a truncated tail (a no-op when padding), then right-pad
        # with -1 to a fixed length of max_masks so samples can be collated into a batch.
        mask_positions = mask_positions[mask_positions < max_len]
        mask_positions = torch.cat([
            mask_positions,
            torch.full((max_masks - mask_positions.shape[0],), -1, dtype=torch.long),
        ])

        # Targets hold the original token at each masked position and -100 elsewhere
        # (positions to ignore in the loss, matching CrossEntropyLoss's default ignore_index).
        target_seq = torch.full((max_len,), -100, dtype=input_seq.dtype)
        valid_positions = mask_positions[mask_positions != -1]
        target_seq[valid_positions] = encoded_amino_acid[valid_positions]

        # True = real data, False = padded token (should not be attended to / learned from)
        padding_masks = torch.arange(max_len) < min(seq_len, max_len)

        # Ensure encoded_protein_type is at least 1-D so batches collate to (B, 1)
        if encoded_protein_type.dim() == 0:
            encoded_protein_type = encoded_protein_type.unsqueeze(0)

        return input_seq, target_seq, mask_positions, encoded_protein_type, specie_weight, padding_masks

In [16]:
# Wrap the encoded sequences in the masking dataset; the special-token ids are looked up
# from the tokenizer instead of being hard-coded.
masked_amino_seq_dataset = MaskedAminoSeqDataset(
    encoded_amino_acids=encoded_amino_acids,
    encoded_protein_types=encoded_protein_types,
    specie_weight=specie_weight,
    mask_token=amino_acid_tokenizer.token2idx["[MASK]"],
    pad_token=amino_acid_tokenizer.token2idx["[PAD]"],
    max_len=120,
)
masked_amino_seq_dataloader = DataLoader(masked_amino_seq_dataset, batch_size=32, shuffle=True)

In [17]:
## Each iteration of the dataloader yields a batch of 32 samples; capture and inspect the first one.
for data in masked_amino_seq_dataloader:
    print(f"amino seqs with masked: shape: {data[0].shape}")
    print(f"targets amino acid:     shape: {data[1].shape}")
    print(f"mask posittions:        shape: {data[2].shape} ")
    print(f"encoded protein type:   shpae: {data[3].shape}")
    print(f"specie_weight:          shape: {data[4].shape}")
    print(f"padding mask indicator  shape: {data[5].shape}")

    # Unpack the six collated fields (in the order the dataset returns them).
    amino_seqs, targets, mask_pos, protien, weight, paddx = data
    break

amino seqs with masked: shape: torch.Size([32, 120])
targets amino acid:     shape: torch.Size([32, 120])
mask posittions:        shape: torch.Size([32, 5]) 
encoded protein type:   shpae: torch.Size([32, 1])
specie_weight:          shape: torch.Size([32])
padding mask indicator  shape: torch.Size([32, 120])


In [18]:
# Show every field of the first sample in the captured batch, one line per field.
print("\n".join(
    f"{label}{value}"
    for label, value in (
        ("single amino acid seq: ", amino_seqs[0]),
        ("target amino acid: ", targets[0]),
        ("mask position: ", mask_pos[0]),
        ("protien : ", protien[0]),
        ("weight : ", weight[0]),
        ("padding masks : ", paddx[0]),
    )
))

single amino acid seq: tensor([16,  5,  3, 14,  0,  9, 17, 12, 17, 20,  3,  3,  3,  6,  3,  0, 11, 18,
         1, 11, 21, 14, 12, 16,  4,  9,  0, 21, 13,  0, 16,  0, 14,  0, 12, 11,
         1, 15, 12,  4, 16,  7, 10, 13, 18,  0,  9,  7, 12, 20,  8,  7,  7, 14,
        18, 21,  8,  9,  9, 13, 19, 11,  9,  0, 19, 13,  3,  3, 14,  8,  8, 18,
         0,  7, 15,  8,  7, 11, 12, 19,  1, 15, 17,  4, 11,  0, 11,  8, 11,  4,
         1, 17,  4,  3,  4, 14,  7,  0, 12, 21, 18,  5,  1, 18, 18, 17,  8,  8,
        10,  9, 15, 14,  8,  7, 21, 19, 22, 22, 22, 22])
target amino acid: tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100,    7, -100, -100, -100,
        -100, -100, -100,   19, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100,    6, -100, -100, -100, -100,
        -100, -100, -

## 2. Embeddings

- amino acid embeddings 
- position embeddings 
- protein type embeddings

In [19]:
class SinusoidalPositionEncoding(nn.Module):
    """
    Sinusoidal Positional Encoding module.

    Precomputes a (1, max_seq_length, embed_size) table where even feature
    columns hold sin(pos * w_i) and odd columns hold cos(pos * w_i), with
    w_i = 10000 ** (-2i / embed_size). The table is registered as a buffer,
    so it moves with the module across devices but is not trained.

    Args:
        embed_size (int): The size of each embedding vector (odd sizes are supported).
        max_seq_length (int, optional): The maximum sequence length to support.
            Defaults to 5000.

    Attributes:
        embed_size (int): The size of each embedding vector.
        max_seq_length (int): Number of precomputed positions.
        pe (Tensor): The pre-computed position encoding matrix of shape
            (1, max_seq_length, embed_size).
    """
    def __init__(self, embed_size, max_seq_length=5000):
        super().__init__()
        self.embed_size = embed_size
        self.max_seq_length = max_seq_length

        pe = torch.zeros(max_seq_length, embed_size)
        position = torch.arange(0, max_seq_length, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_size, 2).float() * (-math.log(10000.0) / embed_size))

        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd embed_size there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: embed_size // 2])

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """
        Return the encodings for the first x.size(1) positions.

        Args:
            x (Tensor): Tensor whose dim 1 is the sequence length, e.g. (B, L) token ids.
                Only its size is used.

        Returns:
            Tensor of shape (1, L, embed_size), broadcastable over the batch.

        Raises:
            ValueError: if L exceeds the number of precomputed positions. Slicing
                would otherwise silently return fewer than L rows.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_length={self.pe.size(1)}"
            )
        return self.pe[:, :seq_len]


In [20]:
class BERTEmbeddings(nn.Module):
    """
    Input embedding layer: amino-acid token embedding + sinusoidal position
    encoding + protein-type embedding, summed and passed through dropout.

    Args:
        amino_vocab_size (int): Number of amino-acid tokens, including special tokens.
        protien_vocab_size (int): Number of protein-type tokens.
        embed_size (int): Embedding dimension (d_model).
        max_seq_length (int): Longest sequence the position encoding supports.
        dropout (float, optional): Dropout probability on the summed embedding. Defaults to 0.1.
    """

    def __init__(self, amino_vocab_size, protien_vocab_size, embed_size, max_seq_length, dropout=0.1):
        super().__init__()

        self.embed_size = embed_size
        self.amino_acid_token = torch.nn.Embedding(amino_vocab_size, embed_size, dtype=torch.float32)
        self.position = SinusoidalPositionEncoding(embed_size, max_seq_length=max_seq_length)
        self.protien_token = torch.nn.Embedding(protien_vocab_size, embed_size, dtype=torch.float32)
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, amino_acid_seqs, protiens):
        """
        Args:
            amino_acid_seqs (Tensor): (B, L) amino-acid token ids.
            protiens (Tensor): (B, 1) or (B,) protein-type ids, one per sequence.

        Returns:
            Tensor of shape (B, L, embed_size).
        """
        amino_acid_embed = self.amino_acid_token(amino_acid_seqs)  # (B, L, d)
        pos_embed = self.position(amino_acid_seqs)                  # (1, L, d)
        if protiens.dim() == 1:
            # (B,) -> (B, 1) so the embedding is (B, 1, d) and broadcasts over L.
            # A (B, d) embedding would align with the trailing (L, d) dims instead:
            # an error in general, or the wrong protein per position when B == L.
            protiens = protiens.unsqueeze(1)
        protien_embed = self.protien_token(protiens)                # (B, 1, d)
        out = amino_acid_embed + pos_embed + protien_embed

        return self.dropout(out)

In [21]:
# Example embedding layer, sized from the two tokenizers.
amino_vocab_size = amino_acid_tokenizer.vocab_size
protein_vocab_size = protein_tokenizer.vocab_size
d_model = 64  # embedding size
# Upper bound for the positional-encoding table only; it need not match the data exactly.
max_seq_length = 200

test_emb = BERTEmbeddings(
    amino_vocab_size=amino_vocab_size,
    protien_vocab_size=protein_vocab_size,
    embed_size=d_model,
    max_seq_length=max_seq_length,
)

In [22]:
## Take the first batch (32 samples) and check the shape the embedding layer produces.
data = next(iter(masked_amino_seq_dataloader))
print(f"input batch shape:     {data[0].shape} ")
print(f"embedded batch shape: {test_emb(data[0], data[3]).shape}")

input batch shape:     torch.Size([32, 120]) 
embedded batch shape: torch.Size([32, 120, 64])


## 3. Multi-Headed Attention Module

In [23]:
class MultiHeadedAttention(nn.Module):
    """
    Multi-head scaled dot-product attention with an optional key padding mask.

    Args:
        heads (int): Number of attention heads; must divide d_model.
        d_model (int): Model / embedding dimension.
        dropout (float, optional): Dropout probability applied to the attention weights.
            Defaults to 0.1.
    """

    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()

        assert d_model % heads == 0
        self.d_k = d_model // heads
        self.heads = heads
        self.dropout = torch.nn.Dropout(dropout)

        self.query = torch.nn.Linear(d_model, d_model, dtype=torch.float32)
        self.key = torch.nn.Linear(d_model, d_model, dtype=torch.float32)
        self.value = torch.nn.Linear(d_model, d_model, dtype=torch.float32)
        self.output_linear = torch.nn.Linear(d_model, d_model, dtype=torch.float32)

    def forward(self, query, key, value, mask=None):
        """
        Args:
            query, key, value (Tensor): (batch_size, seq_len, d_model).
            mask (Tensor, optional): key padding mask of shape (batch_size, seq_len)
                (or any shape with batch_size * seq_len elements, e.g.
                (batch_size, 1, 1, seq_len)). Positions equal to 0 / False are padding
                and receive zero attention weight. None disables masking.

        Returns:
            Tensor of shape (batch_size, seq_len, d_model).
        """
        # (batch_size, seq_len, d_model)
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # (batch_size, seq_len, d_model) --> (batch_size, seq_len, h, d_k) --> (batch_size, h, seq_len, d_k)
        query = query.view(query.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
        key = key.view(key.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
        value = value.view(value.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)

        # (batch_size, h, seq_len, d_k) matmul (batch_size, h, d_k, seq_len) --> (batch_size, h, seq_len, seq_len)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # Derive the shape from the mask itself (not a fixed 32 x 120), so partial last
            # batches and other sequence lengths work: (batch_size, 1, 1, seq_len) broadcasts
            # over heads and query rows.
            mask = mask.reshape(mask.size(0), 1, 1, -1)
            # The dtype's most negative finite value gives exactly zero weight after softmax
            # and, unlike -1e9, does not overflow in half precision.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # (batch_size, h, seq_len, seq_len): attention weights over the non-pad keys
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        # (batch_size, h, seq_len, seq_len) matmul (batch_size, h, seq_len, d_k) --> (batch_size, h, seq_len, d_k)
        context = torch.matmul(weights, value)

        # (batch_size, h, seq_len, d_k) --> (batch_size, seq_len, h, d_k) --> (batch_size, seq_len, d_model)
        context = context.permute(0, 2, 1, 3).contiguous().view(context.shape[0], -1, self.heads * self.d_k)

        # (batch_size, seq_len, d_model)
        return self.output_linear(context)


In [24]:
# 16 attention heads splitting the d_model features evenly between them.
heads = MultiHeadedAttention(16, d_model)

In [25]:
## Run one batch through embedding + attention. The embeddings and the padding mask
## must come from the SAME batch, otherwise the mask hides the wrong positions.
for batch in masked_amino_seq_dataloader:
    print(f"input batch shape:     {batch[0].shape} ")
    print(f"mask posiitons shape:   {batch[2].shape}")
    embded = test_emb(batch[0], batch[3])
    print(f"embedded batch shape: {embded.shape}")
    mask = batch[5]
    attention_output = heads(embded, embded, embded, mask)

    print(f"The output from the Attention : {attention_output.shape}")

    break

input batch shape:     torch.Size([32, 120]) 
mask posiitons shape:   torch.Size([32, 5])
embedded batch shape: torch.Size([32, 120, 64])
The output from the Attention : torch.Size([32, 120, 64])
