In [1]:
import torch

In [2]:
import torch.nn as nn

In [3]:
import math

In [71]:
class InputEmbeddings(nn.Module):
    """Token-id to vector lookup whose output is scaled by ``sqrt(d_model)``.

    Args:
        d_model: Width of each embedding vector.
        vocab_size: Number of rows in the embedding table.
    """

    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.d_model, self.vocab_size = d_model, vocab_size
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    def forward(self, x):
        """Look up the embeddings for the ids in ``x`` and multiply them by ``sqrt(d_model)``."""
        scale = math.sqrt(self.d_model)
        token_vectors = self.embedding(x)
        return token_vectors * scale

In [74]:
class PositionalEncoding(nn.Module):
    """Adds a fixed sinusoidal position table to its input, then applies dropout.

    The table is built once in ``__init__`` with shape ``(1, seq_len, d_model)``
    and stored as the non-trainable buffer ``pe``. Even feature columns hold
    ``sin(position * div_term)`` and odd feature columns hold
    ``cos(position * div_term)``.

    Args:
        d_model: Size of the last (feature) dimension. May be even or odd.
        seq_len: Number of positions precomputed in the table.
        dropout: Dropout probability applied to the summed output.
    """

    def __init__(self, d_model,seq_len, dropout):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even columns, so
        # trim div_term to match; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Leading axis of size 1 so the table broadcasts over dimension 0 of the input.
        pe = pe.unsqueeze(0)

        self.register_buffer('pe', pe)

    def forward(self,x):
        """Return ``dropout(x + pe[:, :x.shape[1], :])``.

        ``x.shape[1]`` is used as the number of positions, so it must not
        exceed ``seq_len``.
        """
        # ``pe`` is a registered buffer, not a parameter, so it carries no gradient.
        x = x + self.pe[:, :x.shape[1], :]
        return self.dropout(x)

In [76]:
class LayerNormalisation(nn.Module):
    """Normalises the last dimension, then applies one learnable scalar gain and one scalar bias.

    Args:
        eps: Constant added to the standard deviation before dividing.
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(self, eps: float = 10**-6, *args, **kwargs) -> None:
        super().__init__()
        self.eps = eps
        # Both parameters have a single element, shared by every feature.
        self.alpha = nn.Parameter(torch.ones(1))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``alpha * (x - mean) / (std + eps) + bias`` computed over ``dim=-1``."""
        centred = x - x.mean(dim=-1, keepdim=True)
        # ``Tensor.std`` is called with its default settings here.
        spread = x.std(dim=-1, keepdim=True) + self.eps
        return self.alpha * centred / spread + self.bias

In [78]:
class FeedForwardBlock(nn.Module):
    """Two linear layers with a ReLU and dropout in between: ``d_model -> d_ff -> d_model``.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden feature size.
        dropout: Dropout probability applied after the ReLU.
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float, *args, **kwargs) -> None:
        super().__init__()
        self.linear1 = nn.Linear(in_features=d_model, out_features=d_ff)
        self.dropout = nn.Dropout(p=dropout)
        self.linear2 = nn.Linear(in_features=d_ff, out_features=d_model)

    def forward(self, x):
        """Apply ``linear1``, ReLU, dropout and ``linear2`` in that order."""
        hidden = torch.relu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

In [80]:
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention computed over ``h`` heads.

    ``d_model`` is split into ``h`` heads of width ``d_k = d_model // h``.
    Queries, keys and values are each projected by their own ``d_model x d_model``
    linear layer, attended per head, merged back to ``d_model`` and passed
    through the output projection ``w_o``.

    Args:
        d_model: Feature size of the inputs and of the output.
        h: Number of heads; must divide ``d_model``.
        dropout: Dropout probability applied to the attention weights.
        *args, **kwargs: Accepted and ignored.

    Raises:
        AssertionError: If ``d_model`` is not divisible by ``h``.
    """

    def __init__(self,d_model:int,h:int,dropout:float, *args, **kwargs) -> None:
        super().__init__()
        self.d_model = d_model
        self.h = h
        assert d_model % h == 0, 'dmodel is not divisible by h'
        self.d_k = d_model // h
        self.w_k = nn.Linear(d_model, d_model)
        self.w_q = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query,key,value,mask,dropout:nn.Dropout):
        """Compute ``softmax(Q K^T / sqrt(d_k)) V``.

        Args:
            query, key, value: Tensors whose last dimension is the per-head width.
            mask: Optional tensor broadcastable to the score matrix; positions
                where ``mask == 0`` are filled with ``-1e9`` before the softmax.
            dropout: Optional dropout module applied to the attention weights.

        Returns:
            A tuple ``(weighted values, attention weights)``.
        """
        d_k = query.shape[-1]
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # In place is safe: ``attention_scores`` was freshly created above.
            attention_scores.masked_fill_(mask == 0, -1e9)
        attention_scores = attention_scores.softmax(dim=-1)
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        return (attention_scores @ value), attention_scores

    def forward(self,q,k,v,mask):
        """Run multi-head attention and return a tensor of shape ``(batch, seq_q, d_model)``.

        The attention weights of the most recent call are kept on
        ``self.attention_scores``.
        """
        query = self.w_q(q)
        key = self.w_k(k)
        value = self.w_v(v)

        # (batch, seq, d_model) -> (batch, seq, h, d_k) -> (batch, h, seq, d_k)
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)

        x, self.attention_scores = MultiHeadAttention.attention(query, key, value, mask, self.dropout)

        # Merge the heads back: (batch, h, seq, d_k) -> (batch, seq, h * d_k).
        # The last two axes must be collapsed into one so that ``w_o`` sees a
        # final dimension of ``d_model`` rather than ``d_k``.
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)

        return self.w_o(x)

In [85]:
class ResidualConnection(nn.Module):
    """Skip connection around a sub-layer: ``x + dropout(sublayer(norm(x)))``.

    Args:
        dropout: Dropout probability applied to the sub-layer output.
    """

    def __init__(self, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.norm = LayerNormalisation()

    def forward(self, x, sublayer):
        """Normalise ``x``, call ``sublayer`` on it, apply dropout and add the result to ``x``."""
        normalised = self.norm(x)
        branch = self.dropout(sublayer(normalised))
        return x + branch

In [86]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention followed by a feed-forward block.

    Each of the two sub-layers is wrapped in its own ``ResidualConnection``.

    Args:
        features: Accepted for interface compatibility; not used, because
            ``ResidualConnection`` in this file takes only ``dropout``.
        self_attention_block: Attention module called as ``(q, k, v, mask)``.
        feed_forward_block: Module applied after attention.
        dropout: Dropout probability passed to each residual connection.
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(self,features:int,self_attention_block:MultiHeadAttention,feed_forward_block:FeedForwardBlock,dropout:float,*args, **kwargs) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        # ResidualConnection.__init__ takes ``dropout`` only; passing
        # ``features`` as well raised a TypeError.
        self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])

    def forward(self,x,src_mask):
        """Apply self-attention (masked by ``src_mask``) and then the feed-forward block."""
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
        # The residual connection needs the running tensor as well as the sub-layer.
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x

In [87]:
class Encoder(nn.Module):
    """Runs the input through each layer in order, then applies a final normalisation.

    Args:
        features: Accepted for interface compatibility; not used, because
            ``LayerNormalisation`` in this file has no per-feature parameters.
        layers: The layers to apply in order; each is called as ``layer(x, mask)``.
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(self,features:int,layers:nn.ModuleList, *args, **kwargs) -> None:
        super().__init__()
        self.layers = layers
        # The first positional parameter of LayerNormalisation is ``eps``, so
        # passing ``features`` here would silently use it as the epsilon.
        self.norm = LayerNormalisation()

    def forward(self,x,mask):
        """Return ``norm(layer_n(... layer_1(x, mask) ..., mask))``."""
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

In [111]:
class DecoderBlock(nn.Module):
    """One decoder layer: self-attention, cross-attention, then a feed-forward block.

    Each of the three sub-layers is wrapped in its own ``ResidualConnection``.

    Args:
        features: Accepted for interface compatibility; not used, because
            ``ResidualConnection`` in this file takes only ``dropout``.
        self_attention_block: Attention over the decoder input itself.
        cross_attention_block: Attention whose keys and values are the encoder output.
        feed_forward_block: Module applied after both attention steps.
        dropout: Dropout probability passed to each residual connection.
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(self,features:int,self_attention_block:MultiHeadAttention,cross_attention_block:MultiHeadAttention,feed_forward_block:FeedForwardBlock,dropout:float, *args, **kwargs) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        # ResidualConnection.__init__ takes ``dropout`` only; passing
        # ``features`` as well raised a TypeError.
        self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])

    def forward(self,x,encoder_output,srs_mask,tgt_mask):
        """Apply the three sub-layers in order and return the result.

        Args:
            x: Decoder-side input.
            encoder_output: Used as keys and values in cross-attention.
            srs_mask: Mask passed to cross-attention.
            tgt_mask: Mask passed to self-attention.
        """
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, tgt_mask))
        x = self.residual_connections[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, srs_mask))
        x = self.residual_connections[2](x, self.feed_forward_block)
        return x

    # The method used to be misspelled, so calling the module never reached it.
    # The old name is kept as an alias for any code that invoked it directly.
    forawrd = forward

In [112]:
class Decoder(nn.Module):
    """Runs the input through each decoder layer in order, then applies a final normalisation.

    Args:
        features: Accepted for interface compatibility; not used, because
            ``LayerNormalisation`` in this file has no per-feature parameters.
        layers: The layers to apply in order; each is called as
            ``layer(x, encoder_output, srs_mask, tgt_mask)``.
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(self,features:int,layers:nn.ModuleList, *args, **kwargs) -> None:
        super().__init__()
        self.layers = layers
        # The first positional parameter of LayerNormalisation is ``eps``, so
        # passing ``features`` here would silently use it as the epsilon.
        self.norm = LayerNormalisation()

    def forward(self,x,encoder_output,srs_mask,tgt_mask):
        """Feed ``x`` through every layer with the shared encoder output and masks, then normalise."""
        for layer in self.layers:
            x = layer(x, encoder_output, srs_mask, tgt_mask)
        return self.norm(x)

In [113]:
class ProjectionLayer(nn.Module):
    """Single linear map from ``d_model`` features to ``vocab_size`` outputs.

    Args:
        d_model: Size of the incoming feature dimension.
        vocab_size: Size of the outgoing dimension.
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(self, d_model, vocab_size, *args, **kwargs) -> None:
        super().__init__()
        self.proj = nn.Linear(in_features=d_model, out_features=vocab_size)

    def forward(self, x):
        """Return ``proj(x)`` with no further activation applied."""
        projected = self.proj(x)
        return projected

In [114]:
class Transformer(nn.Module):
    """Bundles the embeddings, positional encodings, encoder, decoder and projection.

    The class exposes three separate steps (``encode``, ``decode``, ``project``)
    and defines no ``forward`` of its own.

    Args:
        encoder: Module called as ``encoder(x, srs_mask)``.
        decoder: Module called as ``decoder(x, encoder_output, srs_mask, tgt_mask)``.
        srs_emb, trg_emb: Embedding callables for source and target ids.
        srs_pos, trg_pos: Positional-encoding callables for source and target.
        projection_layer: Callable applied by ``project``.
        *args, **kwargs: Accepted and ignored.

    NOTE(review): the defaults of the five optional parameters are the classes
    themselves, not instances. Leaving one at its default makes the matching
    call below invoke that class's constructor on the tensor. Callers are
    presumably expected to always pass constructed modules; confirm.
    """

    def __init__(self, encoder:Encoder, decoder:Decoder, srs_emb=InputEmbeddings, trg_emb=InputEmbeddings, srs_pos=PositionalEncoding, trg_pos=PositionalEncoding, projection_layer=ProjectionLayer, *args, **kwargs) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.srs_emb = srs_emb
        self.trg_emb = trg_emb
        self.srs_pos = srs_pos
        self.trg_pos = trg_pos
        self.projection_layer = projection_layer

    def encode(self, srs, srs_mask):
        """Embed ``srs``, add positions and run the encoder with ``srs_mask``."""
        embedded = self.srs_pos(self.srs_emb(srs))
        return self.encoder(embedded, srs_mask)

    def decode(self, encoder_output:torch.Tensor, srs_mask:torch.Tensor, tgt:torch.Tensor, tgt_mask:torch.Tensor):
        """Embed ``tgt``, add positions and run the decoder against ``encoder_output``."""
        embedded = self.trg_pos(self.trg_emb(tgt))
        return self.decoder(embedded, encoder_output, srs_mask, tgt_mask)

    def project(self, x):
        """Apply the projection layer to ``x``."""
        return self.projection_layer(x)

In [None]:
def build_transformer(srs_vocab_size:int,tgt_vocab_size:int,srs_seq_len:int,tgt_seq_len:int,d_model=512,N=6,h=8,dropout=0.1,d_ff=2048):
    """Assemble a ``Transformer`` from the building blocks defined in this file.

    Args:
        srs_vocab_size: Vocabulary size for the source embedding.
        tgt_vocab_size: Vocabulary size for the target embedding and the projection layer.
        srs_seq_len: Number of positions in the source positional encoding.
        tgt_seq_len: Number of positions in the target positional encoding.
        d_model: Feature size used throughout.
        N: Number of encoder blocks and of decoder blocks.
        h: Number of attention heads per attention module.
        dropout: Dropout probability passed to every sub-module that takes one.
        d_ff: Hidden size of each feed-forward block.

    Returns:
        The constructed ``Transformer``, with every parameter of more than one
        dimension re-initialised by Xavier-uniform.
    """
    source_embed = InputEmbeddings(d_model, srs_vocab_size)
    target_embed = InputEmbeddings(d_model, tgt_vocab_size)

    source_position = PositionalEncoding(d_model, srs_seq_len, dropout)
    target_position = PositionalEncoding(d_model, tgt_seq_len, dropout)

    # Arguments are evaluated left to right, so each block's sub-modules are
    # created in the order attention -> feed-forward.
    encoder_stack = [
        EncoderBlock(
            d_model,
            MultiHeadAttention(d_model, h, dropout),
            FeedForwardBlock(d_model, d_ff, dropout),
            dropout,
        )
        for _ in range(N)
    ]

    decoder_stack = [
        DecoderBlock(
            d_model,
            MultiHeadAttention(d_model, h, dropout),
            MultiHeadAttention(d_model, h, dropout),
            FeedForwardBlock(d_model, d_ff, dropout),
            dropout,
        )
        for _ in range(N)
    ]

    model = Transformer(
        Encoder(d_model, nn.ModuleList(encoder_stack)),
        Decoder(d_model, nn.ModuleList(decoder_stack)),
        source_embed,
        target_embed,
        source_position,
        target_position,
        ProjectionLayer(d_model, tgt_vocab_size),
    )

    # Only tensors with more than one dimension are re-initialised; others keep their defaults.
    for weight in (p for p in model.parameters() if p.dim() > 1):
        nn.init.xavier_uniform_(weight)

    return model

