## BERT model of Amino-Seq Masked-Language-Model

## 0. init

In [2]:
import numpy as np
import random
import pandas as pd
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import tqdm
from torch.optim import Adam
import time

import string
from typing import Iterable, Tuple

## 1. Build Dataset and DataLoader

In [3]:
file_path = 'X_set.txt'

# Initialize lists to hold the phylogenetic position strings and amino acid sequences
specie_code = []
amino_acid_sequences = []

# Read the file
with open(file_path, 'r') as file:
    for line in file:
        parts = line.strip().split(' ')
        specie_code.append(parts[0])
        amino_acid_sequences.append(parts[1])

amino_acid_sequences[0:3]

['---LSQF--LLMLWVPGSKGEIVLTQSPASVSVSPGERVTISCQASESVGNTYLNWLQQKSGQSPRWLIYQVSKLESGIPARFRGSGSGTDFTFTISRVEAEDVAHYYSQQ-----',
 'MESLSQC--LLMLWVPVSRGAIVLTQSPALVSVSPGERVTISCKASQSVGNTYLSWFRQKPGQSPRGLIYKVSNLPSGVPSRFRGSGAEKDFTLTISRVEAVDGAVYYCAQASYSP',
 'MESLSQC--LLMLWVPVSRGAIVLTQSPASVSVSPGERVTISCKASQSLGNTYLHWFQQKPGQSPRRLIYQVSNLLSGVPSRFSGSGAGKDFSLTISSVEAGDGAVYYCFQGSYDP']

### 1.1. Tokenizer

- There are 20 amino acids, each letter in the chain represents one of them. 
- Converting them into 20 tokens, meaning each amino acid would get a number associated with it. 
- Would also need a special character token, which is "-", something related to multiple-sequence-alignment 

In [4]:
# Creating a set of all amino-acids

amino_acid_set = set()

for seq in amino_acid_sequences:
    for acid in seq:
        if acid != "-":
            amino_acid_set.add(acid)

# 20 amino acids
print(f"Num of Amino Acids: {len(amino_acid_set) }")
amino_acids_list = list(amino_acid_set)

Num of Amino Acids: 20


In [5]:
# Creating a Tokenzer class, which ennodes and decodes an amino acid sequence 

class Tokenizer:
    ''' 
    To encode and decode any amino acid string
    '''
    # class attribute 
    amino_acids = amino_acids_list

    def __init__(self, special_tokens = Iterable[str]):
        # define a vocab
        self.vocab = Tokenizer.amino_acids + list(special_tokens)
        # mapping each vocab to a token (a numeric value)
        self.token2idx = {token:i for i, token in enumerate(self.vocab)} 
        # mapping numeric value back to a token
        self.idx2token = {i:token for token, i  in self.token2idx.items()}

    def encode(self, inputs: Iterable[str]) -> Iterable[int]:
        return [self.token2idx[token] for token in inputs]
    
    def decode(self, inputs: Iterable[int]) -> Iterable[str]:
        return [self.idx2token[idx] for idx in inputs]

    def __len__(self):
        return len(self.vocab)

In [6]:
# creating an instance of the Tokenizer. 
amino_acid_tokenizer = Tokenizer(special_tokens=["-", "[MASK]"])

# let's encode the first amino-acid-sequence and see the first 10 positions
print(f"First 20 amino acids         : {[i for i in amino_acid_sequences[0][0:20]]}")
print(f"First 20 encoded amino acids : {amino_acid_tokenizer.encode(amino_acid_sequences[0])[0:20]}")
print(f"First 20 decoded amino acids : {amino_acid_tokenizer.decode(amino_acid_tokenizer.encode(amino_acid_sequences[0])[0:20])}")

First 20 amino acids         : ['-', '-', '-', 'L', 'S', 'Q', 'F', '-', '-', 'L', 'L', 'M', 'L', 'W', 'V', 'P', 'G', 'S', 'K', 'G']
First 20 encoded amino acids : [20, 20, 20, 7, 15, 0, 4, 20, 20, 7, 7, 9, 7, 17, 6, 18, 8, 15, 19, 8]
First 20 decoded amino acids : ['-', '-', '-', 'L', 'S', 'Q', 'F', '-', '-', 'L', 'L', 'M', 'L', 'W', 'V', 'P', 'G', 'S', 'K', 'G']


In [7]:
print(amino_acid_tokenizer.token2idx)

{'Q': 0, 'I': 1, 'R': 2, 'D': 3, 'F': 4, 'Y': 5, 'V': 6, 'L': 7, 'G': 8, 'M': 9, 'T': 10, 'E': 11, 'C': 12, 'A': 13, 'H': 14, 'S': 15, 'N': 16, 'W': 17, 'P': 18, 'K': 19, '-': 20, '[MASK]': 21}


### 1.2 Creating a Tensor for all amino-seqs

In [8]:
# making sure that the size of each amino-acid-seq is same

len_amino_acid_seq = set()
for seq in amino_acid_sequences:
    len_amino_acid_seq.add(len(seq))

# this set should have only one value 
len_amino_acid_seq
# perfect! all the seq are 116 character long

{116}

In [9]:

def create_amino_acids_tensor(amino_acid_sequences:list, my_tokenizer:Tokenizer):

    amino_acid_tensors = []

    for seq in amino_acid_sequences:
        amino_acid_tensors.append(torch.Tensor(my_tokenizer.encode(seq)).to(torch.int64))

    # stacking them 
    stacked_tensor =  torch.stack(amino_acid_tensors)

    return stacked_tensor

In [10]:
all_amino_acids_tensor = create_amino_acids_tensor(amino_acid_sequences, amino_acid_tokenizer)

In [11]:
all_amino_acids_tensor.shape

# 1001 seqs, each with the length of 116

torch.Size([1001, 116])

### 1.3 Create Training data 

- So what we need is to mask a random position in seq
- let's only mask one posiiton as of now

In [12]:
from torch.utils.data import Dataset, DataLoader

class MaskedAminoSeqDataset(Dataset):
    def __init__(self, input_tensor: torch.Tensor, mask_token: int):
            """
            Dataset for masked amino acid sequence prediction.

            Args:
            input_tensor (torch.Tensor): Input tensor of shape (num_sequences, sequence_length).
            mask_token (int): The token used for masking.
            """
            self.input_tensor = input_tensor
            self.mask_token = mask_token

    def __len__(self):
        return self.input_tensor.shape[0] 

    def __getitem__(self, idx):
        input_seqs, target_amino_acids, mask_positions = \
            self._create_training_data(self.input_tensor, batch_size=1, mask_token=self.mask_token)
        return input_seqs.squeeze(0), target_amino_acids.squeeze(0), mask_positions.squeeze(0)

    def _create_training_data(self, input_tensor: torch.Tensor, batch_size: int, mask_token: int, min_masks: int = 1, max_masks: int = 5):
        """
        Creates masked training data efficiently using vectorized operations with a random number of masks per sequence.

        Args:
        input_tensor (torch.Tensor): Input tensor of shape (num_sequences, sequence_length)
        batch_size (int): The desired batch size.
        mask_token (int): The token used for masking.
        min_masks (int): Minimum number of positions to mask in each sequence. Default is 1.
        max_masks (int): Maximum number of positions to mask in each sequence. Default is 5.

        Returns:
        tuple: (input_seqs, target_amino_acids, mask_positions)
            - input_seqs: Tensor of shape (batch_size, sequence_length) with masked sequences.
            - target_amino_acids: Tensor of shape (batch_size, sequence_length) with scattered target values.
            - mask_positions: Tensor of shape (batch_size, max_masks) indicating mask positions.
        """
        rows, seq_len = input_tensor.shape
        
        # Randomly select 'batch_size' rows (amino acid sequences)
        idx = torch.randint(rows, size=(batch_size,))
        input_seqs = input_tensor[idx].clone()

        # Generate random number of masks for each sequence
        num_masks_per_seq = torch.randint(min_masks, max_masks + 1, (batch_size,))

        # Create mask_positions tensor
        mask_positions = torch.zeros((batch_size, max_masks), dtype=torch.long)
        for i, num_masks in enumerate(num_masks_per_seq):
            mask_positions[i, :num_masks] = torch.randperm(seq_len)[:num_masks]

        # Create target_amino_acids tensor with the same shape as input_seqs
        target_amino_acids = torch.zeros_like(input_seqs)
        # TODO: This is wrong, as 0 is one of the tokens of Amino Acids

        # Create a mask for the selected positions
        mask = torch.zeros_like(input_seqs, dtype=torch.bool)

        # Use advanced indexing to set the target values and create mask
        for i in range(batch_size):
            positions = mask_positions[i, :num_masks_per_seq[i]]
            target_amino_acids[i, positions] = input_seqs[i, positions]
            mask[i, positions] = True

        # Apply the mask to replace the target positions with the mask_token
        input_seqs[mask] = mask_token

        return input_seqs, target_amino_acids, mask_positions

In [13]:
# token id for the MASK
amino_acid_tokenizer.encode(["[MASK]"])

[21]

In [14]:
# Assuming input_tensor is your tensor of amino acid sequences
masked_amino_seq_dataset = MaskedAminoSeqDataset(all_amino_acids_tensor, mask_token=21) 
masked_amino_seq_dataloader = DataLoader(masked_amino_seq_dataset, batch_size=32, shuffle=True)

In [15]:
## each iteration now gives a batch with 32 data points.
i, t, m = 0, 0, 0 
for data in masked_amino_seq_dataloader:
    print(f"amino seqs with masked: \n shape: {data[0].shape} \n {data[0]}")
    print(f"targets amino acid:  \n shape: {data[1].shape} \n{data[1]}")
    print(f"mask posittions:  \n shape: {data[2].shape} \n{data[2]}")

    i = data[0]
    t = data[1]
    m = data[2]
    break

amino seqs with masked: 
 shape: torch.Size([32, 116]) 
 tensor([[20, 20, 20,  ..., 11,  7, 18],
        [ 9,  2,  6,  ..., 15, 10, 18],
        [20, 20, 20,  ..., 15, 13, 18],
        ...,
        [ 9,  2,  6,  ..., 11,  5, 18],
        [ 9,  6,  4,  ..., 15, 17, 18],
        [20, 20, 20,  ..., 10, 10,  1]])
targets amino acid:  
 shape: torch.Size([32, 116]) 
tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])
mask posittions:  
 shape: torch.Size([32, 5]) 
tensor([[ 89,   3,  31, 106,   0],
        [ 40, 111,  71, 100,  38],
        [ 25,  26,  39,  87,   0],
        [ 56,  40,  85,  64,  93],
        [ 52,  89,  50,   0,   0],
        [ 80,  78,   0,   0,   0],
        [ 54,  94,  37,   0,   0],
        [ 77,  29,   0,   0,   0],
        [ 15,  70,  17,  61, 100],
        [ 65, 106,  82,  41,  62],
        [  4,  39,  

## 2. Embeddings

We need to embeddings

- amino acid embeddings 
- position embeddings



In [16]:
class SinusoidalPositionEncoding(nn.Module):
    def __init__(self, embed_size, max_seq_length=5000):
        super().__init__()
        self.embed_size = embed_size
        
        pe = torch.zeros(max_seq_length, embed_size)
        position = torch.arange(0, max_seq_length, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_size, 2).float() * (-math.log(10000.0) / embed_size))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return self.pe[:, :x.size(1)]


In [17]:
class BERTEmbeddings(nn.Module):

    def __init__(self, vocab_size, embed_size, max_seq_length, dropout=0.1):
        super().__init__()

        self.embed_size = embed_size
        self.token = torch.nn.Embedding(vocab_size, embed_size, dtype=torch.float32)
        self.position = SinusoidalPositionEncoding(embed_size, max_seq_length=max_seq_length)
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, x):
    
        word_embed = self.token(x) 
        pos_embed = self.position(x)
        out = word_embed + pos_embed

        return self.dropout(out)

In [18]:
vocab_size = len(amino_acid_tokenizer)
d_model = 64 # embedding size 
max_seq_length = masked_amino_seq_dataset.input_tensor.shape[1]

In [19]:
test_emb = BERTEmbeddings(vocab_size=vocab_size, embed_size=d_model, max_seq_length=max_seq_length)

In [20]:
## each iteration now gives a batch with 32 data points.
for i in masked_amino_seq_dataloader:
    print(f"input batch shape:     {i[0].shape} ")

    print(f"embedded batch shape: {test_emb(i[0]).shape}")

    break

input batch shape:     torch.Size([32, 116]) 
embedded batch shape: torch.Size([32, 116, 64])


## 3. Multi Headed Attention 

In [21]:
class MultiHeadedAttention(nn.Module):
    
    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        
        assert d_model % heads == 0
        self.d_k = d_model // heads
        self.heads = heads
        self.dropout = torch.nn.Dropout(dropout)

        self.query = torch.nn.Linear(d_model, d_model, dtype=torch.float32)
        self.key = torch.nn.Linear(d_model, d_model, dtype=torch.float32)
        self.value = torch.nn.Linear(d_model, d_model, dtype=torch.float32)
        self.output_linear = torch.nn.Linear(d_model, d_model, dtype=torch.float32)
        
    def forward(self, query, key, value, mask):
        """
        query, key, value of shape: (batch_size, max_len, d_model)
        mask of shape: (batch_size, 1, 1, max_words)
            # Note: mask if not used, it is mainly to tell attention the locations on which 
                it should not learn much, like padding indexes
                - we dont have padding here as of now, so no need it. 
        """
        # (batch_size, max_len, d_model)
        query = self.query(query)
        key = self.key(key)        
        value = self.value(value)   
        
        # (batch_size, max_len, d_model) --> (batch_size, max_len, h, d_k) --> (batch_size, h, max_len, d_k)
        query = query.view(query.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)   
        key = key.view(key.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)  
        value = value.view(value.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)  
        
        # (batch_size, h, max_len, d_k) matmul (batch_size, h, d_k, max_len) --> (batch_size, h, max_len, max_len)
        scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / math.sqrt(query.size(-1))

        # to mask the pads (diff from the other mask) so the attention does not learn from it
        # # fill 0 mask with super small number so it wont affect the softmax weight
        # # (batch_size, h, max_len, max_len)
        # scores = scores.masked_fill(mask == 0, -1e9)    

        # (batch_size, h, max_len, max_len)
        # softmax to put attention weight for all non-pad tokens
        # max_len X max_len matrix of attention
        weights = F.softmax(scores, dim=-1)           
        weights = self.dropout(weights)

        # (batch_size, h, max_len, max_len) matmul (batch_size, h, max_len, d_k) --> (batch_size, h, max_len, d_k)
        context = torch.matmul(weights, value)

        # (batch_size, h, max_len, d_k) --> (batch_size, max_len, h, d_k) --> (batch_size, max_len, d_model)
        context = context.permute(0, 2, 1, 3).contiguous().view(context.shape[0], -1, self.heads * self.d_k)

        # (batch_size, max_len, d_model)
        return self.output_linear(context)

class FeedForward(torch.nn.Module):
    "Implements FFN equation."

    def __init__(self, d_model, middle_dim=2048, dropout=0.1):
        super(FeedForward, self).__init__()
        
        self.fc1 = torch.nn.Linear(d_model, middle_dim)
        self.fc2 = torch.nn.Linear(middle_dim, d_model)
        self.dropout = torch.nn.Dropout(dropout)
        self.activation = torch.nn.GELU()

    def forward(self, x):
        out = self.activation(self.fc1(x))
        out = self.fc2(self.dropout(out))
        return out

class EncoderLayer(torch.nn.Module):
    def __init__(
        self, 
        d_model=768,
        heads=12, 
        feed_forward_hidden=768 * 4, 
        dropout=0.1
        ):
        super(EncoderLayer, self).__init__()
        self.layernorm = torch.nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadedAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model, middle_dim=feed_forward_hidden)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, embeddings, mask):
        # embeddings: (batch_size, max_len, d_model)
        # encoder mask: (batch_size, 1, 1, max_len)
        # result: (batch_size, max_len, d_model)
        interacted = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, mask))
        # residual layer
        interacted = self.layernorm(interacted + embeddings)
        # bottleneck
        feed_forward_out = self.dropout(self.feed_forward(interacted))
        encoded = self.layernorm(feed_forward_out + interacted)
        return encoded

In [22]:
heads = MultiHeadedAttention(heads = 16, d_model=d_model)

In [23]:
## each iteration now gives a batch with 32 data points.
for i in masked_amino_seq_dataloader:
    print(f"input batch shape:     {i[0].shape} ")
    print(f"mask posiitons shape:   {i[2].shape}")

    print(f"embedded batch shape: {test_emb(i[0]).shape}")
    embded = test_emb(i[0])
    mask = i[2]
    attention_output = heads(embded, embded, embded,  mask)

    print(f"The output from the Attention : {attention_output.shape}")

    break

input batch shape:     torch.Size([32, 116]) 
mask posiitons shape:   torch.Size([32, 5])
embedded batch shape: torch.Size([32, 116, 64])
The output from the Attention : torch.Size([32, 116, 64])


- The output from multiheaded attention goes through a little bit of forward passes, because why not!! 
- so below is a simple Feedforward pass code

In [24]:
class FeedForward(nn.Module):
    "Implements FFN equation."

    def __init__(self, d_model, middle_dim=2048, dropout=0.1):
        super().__init__()
        
        self.fc1 = torch.nn.Linear(d_model, middle_dim)
        self.fc2 = torch.nn.Linear(middle_dim, d_model)
        self.dropout = torch.nn.Dropout(dropout)
        self.activation = torch.nn.GELU()

    def forward(self, x):
        out = self.activation(self.fc1(x))
        out = self.fc2(self.dropout(out))
        return out

#### Encoder Layer

- putting all together, 

- embedded matrix comes, first it goes to Attention module 
- Then layer normalization 
- then a feed forward part 
- and it again goes through a layer normalization

In [25]:
class EncoderLayer(nn.Module):
    def __init__(
        self, 
        d_model=768,
        heads=12, 
        feed_forward_hidden=768 * 4, 
        dropout=0.1
        ):
        super(EncoderLayer, self).__init__()
        self.layernorm = torch.nn.LayerNorm(d_model, dtype=torch.float32)
        self.self_multihead = MultiHeadedAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model, middle_dim=feed_forward_hidden)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, embeddings, mask):
        # embeddings: (batch_size, max_len, d_model)
        # encoder mask: (batch_size, 1, 1, max_len)
        # result: (batch_size, max_len, d_model)
        interacted = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, mask))
        # residual layer
        interacted = interacted.to(torch.float32)
        embeddings = embeddings.to(torch.float32)

        interacted = self.layernorm(interacted + embeddings)
        # bottleneck
        feed_forward_out = self.dropout(self.feed_forward(interacted))
        encoded = self.layernorm(feed_forward_out + interacted)
        return encoded

## 3. Forward pass

In [26]:
class BERT(nn.Module):
    """
    BERT model : Bidirectional Encoder Representations from Transformers.
    """

    def __init__(self, vocab_size, d_model=768, n_layers=12, heads=12, max_seq_length=500, dropout=0.1):
        """
        :param vocab_size: vocab_size of total words
        :param hidden: BERT model hidden size
        :param n_layers: numbers of Transformer blocks(layers)
        :param attn_heads: number of attention heads
        :param dropout: dropout rate
        """

        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers
        self.heads = heads

        # paper noted they used 4 * hidden_size for ff_network_hidden_size
        self.feed_forward_hidden = d_model * 4

        # embedding for BERT, sum of positional, segment, token embeddings
        self.embedding = BERTEmbeddings(vocab_size=vocab_size, embed_size=d_model, max_seq_length=max_seq_length)

        # multi-layers transformer blocks, deep network
        self.encoder_blocks = torch.nn.ModuleList(
            [EncoderLayer(d_model, heads, d_model * 4, dropout) for _ in range(n_layers)])

    def forward(self, x):
        # attention masking for padded token

        # (batch_size, 1, seq_len, seq_len)
        # mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
        # as of now mask has no role to play, it's for not paying attention to Padding idx
        mask = torch.Tensor([1])

        # embedding the indexed sequence to sequence of vectors
        x = self.embedding(x)

        # running over multiple transformer blocks
        for encoder in self.encoder_blocks:
            x = encoder.forward(x, mask)
        return x

In [27]:
vocab_size = len(amino_acid_tokenizer)
d_model = 64 # embedding size 
max_seq_length = masked_amino_seq_dataset.input_tensor.shape[1]

In [28]:
bert_encoder_test = BERT(vocab_size=vocab_size, d_model=d_model, n_layers=3, heads=4, max_seq_length=116)

In [29]:
# let's see if we can run oe forward pass

## each iteration now gives a batch with 32 data points.
for i in masked_amino_seq_dataloader:
    print(f"input batch shape:     {i[0].shape} ")
    # So far we have not used this .. it to be used at for defining loss
    print(f"mask posiitons shape:   {i[2].shape}")

    # just pass on the raw x data and the bert model does everythign

    bert_output = bert_encoder_test(i[0])
    print(f"bert encoder output shape: {bert_output.shape}")

    
    break

input batch shape:     torch.Size([32, 116]) 
mask posiitons shape:   torch.Size([32, 5])
bert encoder output shape: torch.Size([32, 116, 64])


In [30]:
# predicting the masked amino-acid

class MaskedAminoModel(nn.Module):
    """
    predicting origin token from masked input sequence
    n-class classification problem, n-class = vocab_size
    """

    def __init__(self, hidden, vocab_size):
        """
        :param hidden: output size of BERT model
        :param vocab_size: total vocab size
        """
        super().__init__()
        self.linear = torch.nn.Linear(hidden, vocab_size)
        self.log_softmax = torch.nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self.log_softmax(self.linear(x))

In [31]:
class BERTAmino(nn.Module):
    """
    BERT Language Model
    Masked Language Model
    """

    def __init__(self, bert: BERT, vocab_size):
        """
        :param bert: BERT model which should be trained
        :param vocab_size: total vocab size for masked_lm
        """
        super().__init__()
        self.bert = bert
        self.mask_lm = MaskedAminoModel(self.bert.d_model, vocab_size)

    def forward(self, x):
        x = self.bert(x)
        return self.mask_lm(x)

In [32]:
mask_pred_test = BERTAmino(bert = bert_encoder_test, vocab_size=vocab_size)

In [33]:
# let's see if we can run oe forward pass
fianl_output = 0
## each iteration now gives a batch with 32 data points.
for i in masked_amino_seq_dataloader:
    print(f"input batch shape:     {i[0].shape} ")
    # So far we have not used this .. it to be used at for defining loss
    print(f"mask posiitons shape:   {i[2].shape}")

    # just pass on the raw x data and the bert model does everythig

    bert_mask_pred = mask_pred_test(i[0])
    print(f"final output shape: {bert_mask_pred.shape}")

    fianl_output = bert_mask_pred

    # wohooo... it predicted the the amino acid vector for each of location
    break

input batch shape:     torch.Size([32, 116]) 
mask posiitons shape:   torch.Size([32, 5])
final output shape: torch.Size([32, 116, 22])


In [34]:
fianl_output.shape

torch.Size([32, 116, 22])

## 4. Create a Loss Func

In [35]:
input, target, masks = 0, 0, 0 
for data in masked_amino_seq_dataloader:
    # print(f"amino seqs with masked: \n shape: {data[0].shape} \n {data[0]}")
    # print(f"targets amino acid:  \n shape: {data[1].shape} \n{data[1]}")
    # print(f"mask posittions:  \n shape: {data[2].shape} \n{data[2]}")

    input = data[0]
    target = data[1]
    masks = data[2]
    break

In [36]:
target

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])

In [37]:
mask_lm_output =  mask_pred_test(input)
print(f"lm output shape{mask_lm_output.shape}")
print(f"target shape: {target.shape}")

lm output shapetorch.Size([32, 116, 22])
target shape: torch.Size([32, 116])


In [38]:
mask_lm_output[0,2,:]

tensor([-2.9999, -3.3467, -2.8532, -2.6441, -3.4476, -4.1010, -3.1388, -3.2681,
        -3.9283, -4.5574, -3.1541, -2.8913, -3.9380, -2.1659, -3.0111, -3.1769,
        -3.1969, -2.9439, -2.8219, -2.9521, -3.7809, -2.5906],
       grad_fn=<SliceBackward0>)

In [39]:
mask_lm_output.permute(0,2,1).shape

torch.Size([32, 22, 116])

In [40]:
# because pytorch requires this in this shape where num of classes comes before the dimentions 

mask_lm_output.transpose(1, 2).shape

torch.Size([32, 22, 116])

In [41]:
# Negative log likelyhood function

criterion = nn.NLLLoss(ignore_index=0)
criterion(mask_lm_output.transpose(1, 2), target)

tensor(3.2629, grad_fn=<NllLoss2DBackward0>)

In [42]:
# the same if used as a funcitonal form
F.nll_loss(mask_lm_output.transpose(1, 2), target, ignore_index=0)

tensor(3.2629, grad_fn=<NllLoss2DBackward0>)

## 5. Traning Loop

In [43]:
class ScheduledOptim():
    '''A simple wrapper class for learning rate scheduling'''

    def __init__(self, optimizer, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = 0
        self.init_lr = np.power(d_model, -0.5)

    def step_and_update_lr(self):
        "Step with the inner optimizer"
        self._update_learning_rate()
        self._optimizer.step()

    def zero_grad(self):
        "Zero out the gradients by the inner optimizer"
        self._optimizer.zero_grad()

    def _get_lr_scale(self):
        return np.min([
            np.power(self.n_current_steps, -0.5),
            np.power(self.n_warmup_steps, -1.5) * self.n_current_steps])

    def _update_learning_rate(self):
        ''' Learning rate scheduling per step '''

        self.n_current_steps += 1
        lr = self.init_lr * self._get_lr_scale()

        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr

In [44]:
class BERTTrainer:
    def __init__(
        self, 
        model, 
        train_dataloader, 
        test_dataloader=None, 
        lr= 1e-4,
        weight_decay=0.01,
        betas=(0.9, 0.999),
        warmup_steps=1000,
        log_freq=10000,
        device='mps'
        ):

        self.device = device
        self.model = model.to(self.device)
        self.train_data = train_dataloader
        self.test_data = test_dataloader

        # Setting the Adam optimizer with hyper-param
        self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
        self.optim_schedule = ScheduledOptim(
            self.optim, self.model.bert.d_model, n_warmup_steps=warmup_steps
            )

        # Using Negative Log Likelihood Loss function for predicting the masked_token
        self.criterion = torch.nn.NLLLoss(ignore_index=0)
        self.log_freq = log_freq
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
    
    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        
        avg_loss = 0.0
        
        mode = "train" if train else "test"

        # progress bar
        data_iter = tqdm.tqdm(
            enumerate(data_loader),
            desc="EP_%s:%d" % (mode, epoch),
            total=len(data_loader),
            bar_format="{l_bar}{r_bar}"
        )
        start_time = time.time()
        for _, data in enumerate(data_loader):

            # 0. batch_data will be sent into the device(GPU or cpu)
            #data = data.to(self.device)
            #data = {key: value.to(self.device) for key, value in data.items()}

            # 1. forward the next_sentence_prediction and masked_lm model
            mask_lm_output = self.model.forward(data[0].to(self.device)) # data at 0 is input

            # 2-1. NLL(negative log likelihood) loss of is_next classification result
            # next_loss = self.criterion(next_sent_output, data["is_next"])

            # 2-2. NLLLoss of predicting masked token word
            # transpose to (m, vocab_size, seq_len) vs (m, seq_len)
            # criterion(mask_lm_output.view(-1, mask_lm_output.size(-1)), data["bert_label"].view(-1))
            mask_loss = self.criterion(mask_lm_output.transpose(1, 2), data[1].to(self.device)) # data at 1 is target

            # 2-3. mask_loss : 3.4 Pre-training Procedure
            loss =  mask_loss

            # 3. backward and optimization only in train
            if train:
                self.optim_schedule.zero_grad()
                loss.backward()
                self.optim_schedule.step_and_update_lr()

            # next sentence prediction accuracy

            avg_loss += loss.item()
            elapsed = time.time() - start_time

        print(f"EP{epoch}, {mode}: avg_loss={avg_loss:.6f}, time={elapsed:.2f}s")

            # post_fix = {
            #     "epoch": epoch,
            #     "iter": i,
            #     "avg_loss": avg_loss / (i + 1),
            #     "loss": loss.item()
            # }

            # if i % self.log_freq == 0:
            #     data_iter.write(str(post_fix))
            #     print(
            #         f"EP{epoch}, {mode}: \
            #         avg_loss={avg_loss / len(data_iter)}" )
        

In [45]:
# checking if the mps is availabe
print(torch.backends.mps.is_available())

True


In [46]:
# training run

vocab_size = len(amino_acid_tokenizer)
d_model = 64 # embedding size 
heads = 4
n_layers = 2 
device = 'mps' if torch.backends.mps.is_available() else 'cpu'

max_seq_length = masked_amino_seq_dataset.input_tensor.shape[1]

bert_model = BERT(vocab_size=vocab_size, d_model=d_model, n_layers=n_layers, heads=heads, max_seq_length=max_seq_length)

bert_lm = BERTAmino(bert_model, vocab_size)
bert_trainer = BERTTrainer(bert_lm, masked_amino_seq_dataloader, device=device)
epochs = 20

for epoch in range(epochs):
  bert_trainer.train(epoch)

Total Parameters: 102550


EP_train:0:   0%|| 0/32 [00:07<?, ?it/s]


EP0, train: avg_loss=98.830059, time=7.36s


EP_train:1:   0%|| 0/32 [00:00<?, ?it/s]


EP1, train: avg_loss=92.435315, time=0.84s


EP_train:2:   0%|| 0/32 [00:00<?, ?it/s]


EP2, train: avg_loss=89.371938, time=0.72s


EP_train:3:   0%|| 0/32 [00:00<?, ?it/s]


EP3, train: avg_loss=86.914500, time=0.67s


EP_train:4:   0%|| 0/32 [00:00<?, ?it/s]


EP4, train: avg_loss=81.975970, time=0.69s


EP_train:5:   0%|| 0/32 [00:00<?, ?it/s]


EP5, train: avg_loss=77.620508, time=0.73s


EP_train:6:   0%|| 0/32 [00:00<?, ?it/s]


EP6, train: avg_loss=75.232580, time=0.69s


EP_train:7:   0%|| 0/32 [00:00<?, ?it/s]


EP7, train: avg_loss=74.009016, time=0.65s


EP_train:8:   0%|| 0/32 [00:00<?, ?it/s]


EP8, train: avg_loss=72.849724, time=0.72s


EP_train:9:   0%|| 0/32 [00:00<?, ?it/s]


EP9, train: avg_loss=70.343302, time=0.70s


EP_train:10:   0%|| 0/32 [00:00<?, ?it/s]


EP10, train: avg_loss=69.166628, time=0.67s


EP_train:11:   0%|| 0/32 [00:00<?, ?it/s]


EP11, train: avg_loss=67.990653, time=0.67s


EP_train:12:   0%|| 0/32 [00:00<?, ?it/s]


EP12, train: avg_loss=66.967089, time=0.64s


EP_train:13:   0%|| 0/32 [00:00<?, ?it/s]


EP13, train: avg_loss=66.730579, time=0.66s


EP_train:14:   0%|| 0/32 [00:00<?, ?it/s]


EP14, train: avg_loss=64.778711, time=0.70s


EP_train:15:   0%|| 0/32 [00:00<?, ?it/s]


EP15, train: avg_loss=63.854363, time=0.75s


EP_train:16:   0%|| 0/32 [00:00<?, ?it/s]


EP16, train: avg_loss=63.232361, time=0.76s


EP_train:17:   0%|| 0/32 [00:00<?, ?it/s]


EP17, train: avg_loss=62.540929, time=0.67s


EP_train:18:   0%|| 0/32 [00:00<?, ?it/s]


EP18, train: avg_loss=60.240434, time=0.64s


EP_train:19:   0%|| 0/32 [00:00<?, ?it/s]

EP19, train: avg_loss=58.693107, time=0.68s





In [47]:
bert_model

BERT(
  (embedding): BERTEmbeddings(
    (token): Embedding(22, 64)
    (position): SinusoidalPositionEncoding()
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder_blocks): ModuleList(
    (0-1): 2 x EncoderLayer(
      (layernorm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (self_multihead): MultiHeadedAttention(
        (dropout): Dropout(p=0.1, inplace=False)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (key): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (output_linear): Linear(in_features=64, out_features=64, bias=True)
      )
      (feed_forward): FeedForward(
        (fc1): Linear(in_features=64, out_features=256, bias=True)
        (fc2): Linear(in_features=256, out_features=64, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (activation): GELU(approximate='none')
      )
      (dropout): Dropout(p=0.1, inplace=False)
    

## Prediction from the trained model

In [48]:
def prepare_input(sequence, mask_token_id, mask_position=None):
    # Convert the sequence to a list of tokens
    tokens = list(sequence)
    
    # If mask_position is not provided, choose a random position to mask
    if mask_position is None:
        mask_position = random.randint(0, len(tokens) - 1)

    # Replace the chosen token with the mask token
    original_token = tokens[mask_position]
    tokens[mask_position] = "[MASK]"
    
    # Convert tokens to ids
    input_ids = amino_acid_tokenizer.encode(tokens)
    
    return torch.tensor([input_ids]), mask_position, original_token

In [49]:

# Example usage
ex_sequence = "MASLTQC--LLLFWLAGSQGEVVLTQSPASVSVSLGERVTIKCKASQSLGKTYLHWFQQKLGKSIKRTIYQVSNLDSGVPPRFSGSGSGTDFTLTISSLEPEDAAMYYCGQHTHWP"

# Mask a specific position (e.g., the 10th amino acid)
input_tensor, mask_position, original_token = prepare_input(ex_sequence,
                                                            amino_acid_tokenizer.token2idx["[MASK]"],
                                                            mask_position=2)


In [50]:
print(f"mask_position = {mask_position}")
print(f"original token = {original_token}")

mask_position = 2
original token = S


In [51]:
preds = bert_lm(input_tensor.to(device))
preds.shape

torch.Size([1, 116, 22])

In [52]:
# getting the probabilites at the mask_positing 

# Predicted scores for each of the 
preds[:,2,:]

tensor([[-6.1034, -3.6875, -4.7891, -2.8770, -4.9880, -6.6007, -2.2603, -4.1596,
         -2.5774, -4.4833, -4.4368, -4.6260, -6.9109, -2.1862, -4.9076, -2.4845,
         -5.3131, -5.1837, -0.8550, -3.7203, -5.3274, -6.1643]],
       device='mps:0', grad_fn=<SliceBackward0>)

In [53]:
def predict_masked_token(model, input_tensor, mask_position):
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():
        input_tensor = input_tensor.to(device)
        output = model(input_tensor)
        
        # Get the prediction for the masked position
        masked_token_logits = output[0, mask_position, :]
        
        # to get back the probabilites from the log-softmax (model's output)
        token_probabilities = torch.exp(masked_token_logits)
        
        return token_probabilities

In [54]:
token_prob = predict_masked_token(model=bert_lm, input_tensor=input_tensor, mask_position=mask_position)
token_prob

tensor([0.0022, 0.0427, 0.0078, 0.0438, 0.0105, 0.0022, 0.1419, 0.0437, 0.0503,
        0.0230, 0.0169, 0.0068, 0.0011, 0.1246, 0.0078, 0.1010, 0.0066, 0.0087,
        0.3165, 0.0208, 0.0191, 0.0021], device='mps:0')

In [55]:
def interpret_predictions(token_probabilities, tokenizer):
    sorted_probs, sorted_indices = torch.sort(token_probabilities, descending=True)
    for prob, index in zip(sorted_probs, sorted_indices):
        token = tokenizer.idx2token[index.item()]
        print(f"{token}: {prob.item():.4f}")

# Example usage
interpret_predictions(token_prob, amino_acid_tokenizer)

P: 0.3165
V: 0.1419
A: 0.1246
S: 0.1010
G: 0.0503
D: 0.0438
L: 0.0437
I: 0.0427
M: 0.0230
K: 0.0208
-: 0.0191
T: 0.0169
F: 0.0105
W: 0.0087
R: 0.0078
H: 0.0078
E: 0.0068
N: 0.0066
Y: 0.0022
Q: 0.0022
[MASK]: 0.0021
C: 0.0011


In [56]:
# Productionalization of this whole code. 

""" amino_bert/
   ├── data/
   │   ├── __init__.py
   │   ├── dataset.py
   │   └── tokenizer.py
   ├── models/
   │   ├── __init__.py
   │   ├── bert.py
   │   ├── embeddings.py
   │   └── attention.py
   ├── training/
   │   ├── __init__.py
   │   ├── trainer.py
   │   └── optimizer.py
   ├── utils/
   │   ├── __init__.py
   │   └── helpers.py
   ├── config.py
   ├── main.py
   └── requirements.txt
"""

' amino_bert/\n   ├── data/\n   │   ├── __init__.py\n   │   ├── dataset.py\n   │   └── tokenizer.py\n   ├── models/\n   │   ├── __init__.py\n   │   ├── bert.py\n   │   ├── embeddings.py\n   │   └── attention.py\n   ├── training/\n   │   ├── __init__.py\n   │   ├── trainer.py\n   │   └── optimizer.py\n   ├── utils/\n   │   ├── __init__.py\n   │   └── helpers.py\n   ├── config.py\n   ├── main.py\n   └── requirements.txt\n'