In [1]:
import torch 
import torch.nn as nn 
import math 
from torch.utils.data import Dataset,DataLoader 
import torch.optim as optim 
import copy


### Positional encoding

In [2]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    A table of shape ``(1, seq_len, d_model)`` is computed once in the
    constructor and stored as the buffer ``pe``. Even feature columns hold
    sines and odd feature columns hold cosines of ``position * frequency``.
    """

    def __init__(self,d_model:int,seq_len:int)->None:
        """Precompute the encoding table.

        Args:
            d_model: Size of the feature (last) dimension. May be even or odd.
            seq_len: Number of positions for which encodings are precomputed.
        """
        super().__init__()
        self.d_model=d_model
        self.seq_len=seq_len

        # Create a matrix of size (seq_len,d_model)
        pe = torch.zeros(seq_len,d_model)
        # Positions are float64 so the angles are evaluated in double precision
        # before being written into the float32 table.
        position = torch.arange(0,seq_len,dtype=torch.float64).unsqueeze(1)
        div_term = torch.exp(torch.arange(0,d_model,2).float()*(-math.log(10000.0)/d_model))
        angles = position*div_term  # (seq_len, ceil(d_model/2))

        pe[:,0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so drop the surplus frequency; for even d_model the slice is a no-op.
        pe[:,1::2] = torch.cos(angles[:,:d_model//2])

        # Buffer: moves with the module across devices but is not a parameter.
        self.register_buffer("pe",pe.unsqueeze(0))

    def forward(self,x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dimension 1 indexes positions and whose last
                dimension is ``d_model``; ``x.size(1)`` must not exceed
                ``seq_len`` or the addition cannot broadcast.

        Returns:
            Tensor of the same shape as ``x``.
        """
        return x+self.pe[:,:x.size(1)]


### Feedforward neural network

In [3]:
class FeedForwardNN(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Linear.

    The first linear layer maps ``d_model`` features to ``ff_dim`` and the
    second maps them back to ``d_model``, so the output has the same shape as
    the input.
    """

    def __init__(self,d_model:int,ff_dim:int)->None:
        """
        Args:
            d_model: Size of the input and output feature dimension.
            ff_dim: Size of the hidden feature dimension.
        """
        super().__init__()
        self.d_model = d_model 
        self.ff_dim = ff_dim 
        self.linear_1 = nn.Linear(d_model,ff_dim)
        self.relu = nn.ReLU()
        self.linear_2 = nn.Linear(ff_dim,d_model)

    def forward(self,x):
        """Apply the two-layer network to the last dimension of ``x``.

        Args:
            x: Tensor whose last dimension has size ``d_model``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = self.linear_1(x)
        x = self.relu(x)
        x = self.linear_2(x)
        # Without this return the module would yield None and break every
        # residual connection that adds the feed-forward output.
        return x


### Multihead attention

In [4]:
class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are each passed through their own linear
    projection, split into ``num_heads`` heads of size ``d_k``, attended
    per head, merged back to ``d_model`` features and passed through a final
    output projection.
    """

    def __init__(self,d_model:int,num_heads:int)->None:
        """
        Args:
            d_model: Feature size of the inputs and of the output.
            num_heads: Number of attention heads; must divide ``d_model``.
        """
        super().__init__()

        assert d_model%num_heads==0 , "model dimention must be divisible by number of heads."
        self.d_model = d_model
        self.num_heads = num_heads
        # Feature size handled by each individual head.
        self.d_k = d_model // num_heads

        # Input projections for queries, keys, values, then the output projection.
        self.w_Q = nn.Linear(d_model, d_model)
        self.w_K = nn.Linear(d_model, d_model)
        self.w_V = nn.Linear(d_model, d_model)
        self.w_W = nn.Linear(d_model, d_model)

    def split_heads(self,X):
        """Reshape ``(batch, seq, d_model)`` into ``(batch, num_heads, seq, d_k)``."""
        n_batch, n_tokens, _ = X.size()
        per_head = X.view(n_batch, n_tokens, self.num_heads, self.d_k)
        return per_head.transpose(1, 2)

    def scaled_dot_product_attention(self,Q,K,V,mask=None):
        """Compute ``softmax(Q K^T / sqrt(d_k)) V``.

        When ``mask`` is given, positions where ``mask == 0`` have their
        score replaced by ``-1e9`` before the softmax.
        """
        logits = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            logits = logits.masked_fill(mask == 0, -1e9)

        weights = logits.softmax(dim=-1)
        return weights @ V

    def concatinate_heads(self,X):
        """Merge ``(batch, num_heads, seq, d_k)`` back into ``(batch, seq, d_model)``."""
        n_batch, _, n_tokens, _ = X.size()
        # transpose leaves the tensor non-contiguous, so copy before view.
        merged = X.transpose(1, 2).contiguous()
        return merged.view(n_batch, n_tokens, self.d_model)


    def forward(self,Q,K,V,mask=None):
        """Project the inputs, attend per head and project the merged result.

        Args:
            Q, K, V: Tensors of shape ``(batch, seq, d_model)``.
            mask: Optional tensor broadcastable to the attention scores;
                zeros mark positions to be masked out.

        Returns:
            Tensor of shape ``(batch, query_seq, d_model)``.
        """
        q_heads = self.split_heads(self.w_Q(Q))  # (batch, num_heads, seq, d_k)
        k_heads = self.split_heads(self.w_K(K))
        v_heads = self.split_heads(self.w_V(V))

        context = self.scaled_dot_product_attention(q_heads, k_heads, v_heads, mask)
        merged = self.concatinate_heads(context)

        return self.w_W(merged)


### Encoder Layer

In [5]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward.

    Each sub-layer's output goes through dropout, is added to the sub-layer's
    input, and the sum is layer-normalised.
    """

    def __init__(self,d_model:int,num_heads:int,ff_dim:int,dropout:float)->None:
        """
        Args:
            d_model: Feature size of the layer input and output.
            num_heads: Number of heads for the self-attention sub-layer.
            ff_dim: Hidden size passed to the feed-forward sub-layer.
            dropout: Dropout probability applied to each sub-layer output.
        """
        super().__init__()

        self.self_attention = MultiheadAttention(d_model, num_heads)
        self.feed_forward = FeedForwardNN(d_model, ff_dim)
        self.layer_norm_1 = nn.LayerNorm(d_model)
        self.layer_norm_2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self,X,mask):
        """Run the block on ``X``, using ``mask`` in the self-attention step."""
        # Sub-layer 1: self-attention, residual connection, layer norm.
        attended = self.dropout(self.self_attention(X, X, X, mask))
        hidden = self.layer_norm_1(X + attended)

        # Sub-layer 2: feed-forward, residual connection, layer norm.
        transformed = self.dropout(self.feed_forward(hidden))
        return self.layer_norm_2(hidden + transformed)



### Decoder Layer

In [None]:
class DecoderLayer(nn.Module):
    """One decoder block: self-attention, cross-attention, feed-forward.

    Each sub-layer's output goes through dropout, is added to the sub-layer's
    input, and the sum is layer-normalised.
    """

    def __init__(self,d_model:int,num_heads:int,ff_dim:int,dropout:float)->None:
        """
        Args:
            d_model: Feature size of the layer input and output.
            num_heads: Number of heads for both attention sub-layers.
            ff_dim: Hidden size passed to the feed-forward sub-layer.
            dropout: Dropout probability applied to each sub-layer output.
        """
        super().__init__()
        self.self_attention = MultiheadAttention(d_model, num_heads)
        self.cross_attention = MultiheadAttention(d_model, num_heads)
        self.ff_nn = FeedForwardNN(d_model, ff_dim)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self,x,enc_out,src_mask,trg_mask):
        """Run the block.

        Args:
            x: Decoder-side input to this layer.
            enc_out: Encoder output, used as keys and values in cross-attention.
            src_mask: Mask applied in the cross-attention step.
            trg_mask: Mask applied in the self-attention step.
        """
        # Sub-layer 1: self-attention over the decoder input.
        self_ctx = self.dropout(self.self_attention(x, x, x, trg_mask))
        hidden = self.norm1(x + self_ctx)

        # Sub-layer 2: queries from the decoder, keys/values from the encoder.
        cross_ctx = self.dropout(self.cross_attention(hidden, enc_out, enc_out, src_mask))
        hidden = self.norm2(hidden + cross_ctx)

        # Sub-layer 3: feed-forward network.
        transformed = self.dropout(self.ff_nn(hidden))
        return self.norm3(hidden + transformed)

### Transformer (Encoder + Decoder)

In [6]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer producing target-vocabulary logits.

    Source and target token ids are embedded, given positional encodings and
    dropout, passed through stacks of encoder and decoder layers, and the
    decoder output is projected to ``target_vocab_size`` logits per position.
    Token id 0 is treated as padding when the attention masks are built.
    """

    def __init__(self,d_model:int,seq_len:int,num_layers:int,num_heads:int,ff_dim:int,dropout:float,source_vocab_size:int,target_vocab_size:int):
        """
        Args:
            d_model: Embedding / hidden feature size.
            seq_len: Number of positions the positional encoding is built for.
            num_layers: Number of encoder layers and of decoder layers.
            num_heads: Attention heads per layer.
            ff_dim: Hidden size of each layer's feed-forward network.
            dropout: Dropout probability used for the embeddings and layers.
            source_vocab_size: Number of rows in the source embedding table.
            target_vocab_size: Number of rows in the target embedding table
                and size of the output logits.
        """
        super().__init__()
        self.encoder_embedding = nn.Embedding(source_vocab_size,d_model)
        self.decoder_embedding = nn.Embedding(target_vocab_size,d_model)
        self.positional_encoding = PositionalEncoding(d_model,seq_len)
        self.dropout = nn.Dropout(dropout)

        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model,num_heads,ff_dim,dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model,num_heads,ff_dim,dropout) for _ in range(num_layers)])

        self.fc_layer = nn.Linear(d_model,target_vocab_size)

    def generate_mask(self, src, tgt):
        """Build the padding mask for ``src`` and the padding+causal mask for ``tgt``.

        Args:
            src: Integer tensor of shape ``(batch, src_len)``.
            tgt: Integer tensor of shape ``(batch, tgt_len)``.

        Returns:
            ``(src_mask, tgt_mask)`` boolean tensors of shapes
            ``(batch, 1, 1, src_len)`` and ``(batch, 1, tgt_len, tgt_len)``;
            ``True`` marks positions that may be attended to.
        """
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        # Lower-triangular mask so position i only sees positions <= i. It is
        # created on tgt's device; otherwise the `&` below fails off-CPU.
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length, device=tgt.device), diagonal=1)).bool()
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def forward(self,src,tgt):
        """Encode ``src``, decode ``tgt`` against it and return logits.

        Args:
            src: Integer tensor of source token ids, ``(batch, src_len)``.
            tgt: Integer tensor of target token ids, ``(batch, tgt_len)``.

        Returns:
            Tensor of shape ``(batch, tgt_len, target_vocab_size)``.
        """
        src_mask,tgt_mask = self.generate_mask(src,tgt)

        src_embedding = self.encoder_embedding(src)
        tgt_embedding = self.decoder_embedding(tgt)

        src_embedding = self.dropout(self.positional_encoding(src_embedding))
        tgt_embedding = self.dropout(self.positional_encoding(tgt_embedding))

        src_out = src_embedding
        for encoder_layer in self.encoder_layers:
            src_out = encoder_layer(src_out,src_mask)

        tgt_out = tgt_embedding
        for decoder_layer in self.decoder_layers:
            # DecoderLayer.forward expects (x, enc_out, src_mask, trg_mask):
            # the encoder output and source mask feed the cross-attention.
            tgt_out = decoder_layer(tgt_out,src_out,src_mask,tgt_mask)

        output = self.fc_layer(tgt_out)
        return output

            
