<a href="https://colab.research.google.com/github/telanan/test/blob/master/transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import math

In [2]:
# Sanity check that torch works: draw a 4x4 matrix of uniform [0, 1) samples.
random_torch = torch.rand(size=(4, 4))
print(random_torch)

tensor([[0.6738, 0.5451, 0.0017, 0.2785],
        [0.6992, 0.3364, 0.8491, 0.5756],
        [0.9297, 0.9176, 0.3056, 0.1768],
        [0.6977, 0.0268, 0.2305, 0.4321]])


In [3]:
class TokenEmbedding(nn.Embedding):
    """Learned lookup table from vocabulary indices to ``d_model``-sized vectors.

    A thin ``nn.Embedding`` subclass that always configures a padding index,
    so the padding row is initialised to zeros and receives no gradient.

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Size of each embedding vector.
        padding_idx: Vocabulary index reserved for padding. Defaults to ``1``,
            the value that used to be hard-coded here.
    """

    def __init__(self, vocab_size, d_model, padding_idx=1):
        super().__init__(vocab_size, d_model, padding_idx=padding_idx)

In [6]:
class PositionalEncoding(nn.Module):
    """Fixed (non-learned) sinusoidal positional encoding table.

    A ``(max_len, d_model)`` table is precomputed once: even columns hold
    ``sin(pos / 10000**(2i / d_model))`` and odd columns hold the matching
    ``cos`` values.

    Args:
        d_model: Width of each encoding vector (odd values are supported).
        max_len: Number of positions to precompute.
        device: Device on which the table is initially created.
    """

    def __init__(self, d_model, max_len, device):
        super().__init__()
        position = torch.arange(0, max_len, device=device).float().unsqueeze(dim=1)
        _2i = torch.arange(0, d_model, step=2, device=device).float()
        angle = position / 10000 ** (_2i / d_model)

        encoding = torch.zeros(max_len, d_model, device=device)
        encoding[:, 0::2] = torch.sin(angle)
        # With an odd d_model there is one fewer odd column than even column,
        # so trim the angle table to the number of odd columns.
        encoding[:, 1::2] = torch.cos(angle[:, : d_model // 2])

        # A buffer (rather than a plain attribute) follows the module through
        # .to()/.cuda(); persistent=False keeps it out of state_dict so the
        # checkpoint layout is unchanged. Buffers never require grad.
        self.register_buffer("encoding", encoding, persistent=False)

    def forward(self, x):
        """Return the encodings for the first ``x.size(1)`` positions.

        Only the size of ``x`` along dim 1 is used; its values are ignored.
        The result has shape ``(seq_len, d_model)`` with no batch dimension
        (at most ``max_len`` rows, because the table is sliced).
        """
        seq_len = x.size(1)
        return self.encoding[:seq_len, :]

In [7]:
class TransformerEmbedding(nn.Module):
    """Token embedding plus positional encoding, followed by dropout.

    Args:
        vocab_size: Vocabulary size passed to the token embedding.
        d_model: Embedding width shared by both components.
        max_len: Number of positions the positional table covers.
        drop_prob: Dropout probability applied to the summed embedding.
        device: Device handed to the positional encoding.
    """

    def __init__(self, vocab_size, d_model, max_len, drop_prob, device):
        super().__init__()
        self.tok_emb = TokenEmbedding(vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model, max_len, device)
        self.drop_out = nn.Dropout(p=drop_prob)

    def forward(self, x):
        """Embed the index tensor ``x`` and add the positional signal."""
        # The positional term has no batch dimension and is broadcast over it.
        embedded = self.tok_emb(x) + self.pos_emb(x)
        return self.drop_out(embedded)

In [9]:
# Demo configuration: model width and head count, plus a random
# (128, 32, d_model) input batch to feed through the attention module.
d_model, n_head = 512, 8
x = torch.rand(128, 32, d_model)

In [12]:
class MutilHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are linearly projected, split into ``n_head``
    heads of width ``d_model // n_head``, attended per head, re-merged and
    passed through a final linear layer.

    Args:
        d_model: Feature width of the inputs and the output.
        n_head: Number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_head``.
    """

    def __init__(self, d_model, n_head):
        super().__init__()
        if d_model % n_head != 0:
            # Otherwise the per-head reshape in forward() cannot succeed.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_head ({n_head})"
            )
        self.d_model = d_model
        self.n_head = n_head
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_combine = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Query tensor of shape ``(batch, q_len, d_model)``.
            k: Key tensor of shape ``(batch, k_len, d_model)``.
            v: Value tensor of shape ``(batch, k_len, d_model)``.
            mask: Optional tensor broadcastable to
                ``(batch, n_head, q_len, k_len)``; positions where it equals
                0 are filled with ``-1e9`` before the softmax.

        Returns:
            Tensor of shape ``(batch, q_len, d_model)``.
        """
        batch, q_len, dimension = q.shape
        # Keys/values may have a different sequence length than the queries
        # (cross-attention), so each is reshaped with its own length.
        k_len = k.size(1)
        v_len = v.size(1)
        n_d = self.d_model // self.n_head

        q = self.w_q(q).view(batch, q_len, self.n_head, n_d).permute(0, 2, 1, 3)
        k = self.w_k(k).view(batch, k_len, self.n_head, n_d).permute(0, 2, 1, 3)
        v = self.w_v(v).view(batch, v_len, self.n_head, n_d).permute(0, 2, 1, 3)

        # (batch, n_head, q_len, k_len) scaled similarity scores.
        score = torch.matmul(q, k.permute(0, 1, 3, 2)) / math.sqrt(n_d)
        if mask is not None:
            score = score.masked_fill(mask == 0, -1e9)

        context = self.softmax(score) @ v
        # Merge the heads back into a single feature dimension.
        context = context.permute(0, 2, 1, 3).contiguous().view(batch, q_len, dimension)
        return self.w_combine(context)

# Smoke test: run self-attention on the random batch and confirm that the
# output keeps the input's shape.
attention = MutilHeadAttention(d_model=d_model, n_head=n_head)
output = attention(q=x, k=x, v=x)
print(output.shape)

torch.Size([128, 32, 512])


In [13]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable scale/shift.

    Args:
        d_model: Size of the last (normalised) dimension.
        eps: Small constant added to the variance for numerical stability.
    """

    def __init__(self, d_model, eps=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` along its last dimension, then scale and shift."""
        mean = x.mean(-1, keepdim=True)
        # Layer norm uses the biased (population) variance, with eps added
        # inside the square root. The unbiased estimator is undefined (NaN)
        # when the last dimension has a single element.
        var = x.var(-1, unbiased=False, keepdim=True)
        normalized = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * normalized + self.beta

In [14]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: Input and output feature width.
        hidden: Width of the intermediate layer.
        drop_prob: Dropout probability applied after the activation.
    """

    def __init__(self, d_model, hidden, drop_prob=0.1):
        super().__init__()
        self.linear1 = nn.Linear(in_features=d_model, out_features=hidden)
        self.linear2 = nn.Linear(in_features=hidden, out_features=d_model)
        self.drop_out = nn.Dropout(p=drop_prob)

    def forward(self, x):
        """Expand to ``hidden`` features, activate, drop, project back."""
        activated = F.relu(self.linear1(x))
        return self.linear2(self.drop_out(activated))

In [17]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention and a feed-forward network.

    Each sub-layer is followed by dropout, a residual connection and layer
    normalisation (post-norm arrangement).

    Args:
        d_model: Feature width of the block.
        ffn_hidden: Hidden width of the feed-forward sub-layer.
        n_head: Number of attention heads.
        drop_prob: Dropout probability used throughout the block.
    """

    def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
        super().__init__()
        # Self-attention sub-layer.
        self.attention = MutilHeadAttention(d_model, n_head)
        self.norm1 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(p=drop_prob)
        # Feed-forward sub-layer.
        self.ffn = PositionwiseFeedForward(d_model, ffn_hidden, drop_prob)
        self.norm2 = LayerNorm(d_model)
        self.dropout2 = nn.Dropout(p=drop_prob)

    def forward(self, x, mask):
        """Apply the block to ``x``, passing ``mask`` to the self-attention."""
        attended = self.dropout1(self.attention(x, x, x, mask))
        hidden = self.norm1(attended + x)

        transformed = self.dropout2(self.ffn(hidden))
        return self.norm2(transformed + hidden)

In [19]:
class Encoder(nn.Module):
    """Embedding followed by a stack of ``n_layer`` encoder layers.

    Args:
        enc_voc_size: Source vocabulary size.
        d_model: Feature width of the model.
        max_len: Number of positions covered by the positional encoding.
        ffn_hidden: Hidden width of each layer's feed-forward network.
        n_head: Number of attention heads per layer.
        n_layer: Number of stacked encoder layers.
        drop_prob: Dropout probability.
        device: Device handed to the embedding.
    """

    def __init__(self, enc_voc_size, d_model, max_len, ffn_hidden, n_head, n_layer, drop_prob, device):
        super().__init__()
        self.embedding = TransformerEmbedding(enc_voc_size, d_model, max_len, drop_prob, device)
        self.layers = nn.ModuleList()
        for _ in range(n_layer):
            self.layers.append(EncoderLayer(d_model, ffn_hidden, n_head, drop_prob))

    def forward(self, x, s_mask):
        """Embed ``x`` and run it through every layer using ``s_mask``."""
        hidden = self.embedding(x)
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, s_mask)
        return hidden

In [20]:
class DecoderLayer(nn.Module):
    """One decoder block: self-attention, cross-attention, feed-forward.

    Every sub-layer is followed by dropout, a residual connection and layer
    normalisation, mirroring the encoder layer's structure.

    Args:
        d_model: Feature width of the block.
        ffn_hidden: Hidden width of the feed-forward sub-layer.
        n_head: Number of attention heads.
        drop_prob: Dropout probability used throughout the block.
    """

    def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
        super().__init__()
        # Self-attention over the decoder input.
        self.attention1 = MutilHeadAttention(d_model, n_head)
        self.norm1 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(p=drop_prob)
        # Attention from the decoder state over the encoder output.
        self.cross_attention = MutilHeadAttention(d_model, n_head)
        self.norm2 = LayerNorm(d_model)
        self.dropout2 = nn.Dropout(p=drop_prob)
        # Feed-forward sub-layer.
        self.ffn = PositionwiseFeedForward(d_model, ffn_hidden, drop_prob)
        self.norm3 = LayerNorm(d_model)
        self.dropout3 = nn.Dropout(p=drop_prob)

    def forward(self, dec, enc_out, s_mask, t_mask):
        """Run the block.

        ``t_mask`` is passed to the self-attention over ``dec``; ``s_mask``
        is passed to the cross-attention, whose keys and values are
        ``enc_out``.
        """
        self_attended = self.dropout1(self.attention1(dec, dec, dec, t_mask))
        hidden = self.norm1(self_attended + dec)

        cross_attended = self.dropout2(
            self.cross_attention(hidden, enc_out, enc_out, s_mask)
        )
        hidden = self.norm2(cross_attended + hidden)

        transformed = self.dropout3(self.ffn(hidden))
        return self.norm3(transformed + hidden)

In [21]:
class Decoder(nn.Module):
    """Embedding, a stack of decoder layers, and a projection to the vocabulary.

    Args:
        dec_voc_size: Target vocabulary size (also the output width).
        d_model: Feature width of the model.
        max_len: Number of positions covered by the positional encoding.
        ffn_hidden: Hidden width of each layer's feed-forward network.
        n_head: Number of attention heads per layer.
        n_layer: Number of stacked decoder layers.
        drop_prob: Dropout probability.
        device: Device handed to the embedding.
    """

    def __init__(self, dec_voc_size, d_model, max_len, ffn_hidden, n_head, n_layer, drop_prob, device):
        super().__init__()
        self.embedding = TransformerEmbedding(dec_voc_size, d_model, max_len, drop_prob, device)
        self.layers = nn.ModuleList()
        for _ in range(n_layer):
            self.layers.append(DecoderLayer(d_model, ffn_hidden, n_head, drop_prob))
        self.fc = nn.Linear(in_features=d_model, out_features=dec_voc_size)

    def forward(self, dec, enc_out, s_mask, t_mask):
        """Embed ``dec``, apply every layer, and project with ``fc``."""
        hidden = self.embedding(dec)
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, enc_out, s_mask, t_mask)
        return self.fc(hidden)

In [23]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer that builds its own attention masks.

    Args:
        src_pad_idx: Padding token index in the source vocabulary.
        trg_pad_idx: Padding token index in the target vocabulary.
        enc_voc_size: Source vocabulary size.
        dec_voc_size: Target vocabulary size.
        d_model: Feature width of the model.
        max_len: Number of positions covered by the positional encodings.
        n_head: Number of attention heads.
        ffn_hidden: Hidden width of the feed-forward networks.
        n_layer: Number of layers in both the encoder and the decoder.
        drop_out: Dropout probability.
        device: Device handed to the encoder and decoder.
    """

    def __init__(self, src_pad_idx, trg_pad_idx, enc_voc_size, dec_voc_size,
                 d_model, max_len, n_head, ffn_hidden, n_layer, drop_out, device):
        super().__init__()
        self.encoder = Encoder(enc_voc_size, d_model, max_len, ffn_hidden, n_head, n_layer, drop_out, device)
        self.decoder = Decoder(dec_voc_size, d_model, max_len, ffn_hidden, n_head, n_layer, drop_out, device)
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_pad_mask(self, q, k, pad_idx_q, pad_idx_k):
        """Build a boolean mask that is True where attention is allowed.

        Args:
            q: Query-side token indices, shape ``(batch, q_len)``.
            k: Key-side token indices, shape ``(batch, k_len)``.
            pad_idx_q: Padding index used in ``q``.
            pad_idx_k: Padding index used in ``k``.

        Returns:
            Bool tensor of shape ``(batch, 1, q_len, k_len)`` that is True
            only where both the query token and the key token are not
            padding.
        """
        # The attention module fills positions where the mask equals 0, so
        # the mask must be True for real tokens (``ne``), not for padding.
        q_valid = q.ne(pad_idx_q).unsqueeze(1).unsqueeze(3)  # (batch, 1, q_len, 1)
        k_valid = k.ne(pad_idx_k).unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, k_len)
        # Broadcasting expands the result to (batch, 1, q_len, k_len).
        return q_valid & k_valid

    def make_casual_mask(self, q, k):
        """Build a lower-triangular (causal) boolean mask.

        Returns a ``(q_len, k_len)`` bool tensor, on the same device as
        ``q``, in which entry ``(i, j)`` is True only when ``j <= i``.
        """
        q_len, k_len = q.size(1), k.size(1)
        return torch.tril(torch.ones(q_len, k_len, device=q.device)).bool()

    def forward(self, src, trg):
        """Encode ``src`` and decode ``trg`` against it.

        Args:
            src: Source token indices, shape ``(batch, src_len)``.
            trg: Target token indices, shape ``(batch, trg_len)``.

        Returns:
            Whatever the decoder returns for the masked inputs.
        """
        src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx)
        # Cross-attention scores are (trg_len, src_len), so the mask needs
        # target queries against source keys rather than the square
        # source-only mask.
        src_trg_mask = self.make_pad_mask(trg, src, self.trg_pad_idx, self.src_pad_idx)
        trg_mask = self.make_pad_mask(trg, trg, self.trg_pad_idx, self.trg_pad_idx)
        trg_mask = trg_mask & self.make_casual_mask(trg, trg)

        enc_src = self.encoder(src, src_mask)
        return self.decoder(trg, enc_src, src_trg_mask, trg_mask)