In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import math
import numpy as np

In [17]:
# Preferred GPU index; clamped to the last available GPU so single-GPU hosts still work.
CUDA_INDEX = 1
# NOTE: `torch.cuda.is_available` must be *called* — the bare function object is always
# truthy, which would select CUDA even on CPU-only machines.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda', min(CUDA_INDEX, torch.cuda.device_count() - 1))
else:
    DEVICE = torch.device('cpu')
DTYPE = torch.float32

### Scaled Dot Product Attention

In [84]:
class SDPA(nn.Module):
    """Single-head scaled dot-product self-attention.

    Projects the input into queries/keys (``kq_dimension`` features) and values
    (``out_dimension`` features), then returns ``softmax(Q K^T / sqrt(kq_dimension)) V``.

    Args:
        in_dimension: size of the last axis of the input ``x``.
        out_dimension: size of the last axis of the returned tensor.
        kq_dimension: size of the query/key projections.
        masked: if True, apply a causal mask so position ``i`` only attends to
            positions ``<= i``.
        device, dtype: forwarded to the projection layers.
    """

    def __init__(self, in_dimension, out_dimension, kq_dimension,
                 masked=False, device=DEVICE, dtype=DTYPE):
        super(SDPA, self).__init__()

        self.in_dimension = in_dimension
        self.out_dimension = out_dimension
        self.kq_dimension = kq_dimension

        self.masked = masked
        self.device = device
        self.dtype = dtype

        self.query = nn.Linear(in_features=self.in_dimension, out_features=self.kq_dimension,
                               device=self.device, dtype=self.dtype
        )

        self.key = nn.Linear(in_features=self.in_dimension, out_features=self.kq_dimension,
                             device=self.device, dtype=self.dtype
        )

        self.value = nn.Linear(in_features=self.in_dimension, out_features=self.out_dimension,
                               device=self.device, dtype=self.dtype
        )

    def forward(self, x):
        """Attend over ``x``.

        Args:
            x: tensor whose last two axes are ``(seq_len, in_dimension)``; any number
               of leading batch axes (including none) is accepted.

        Returns:
            Tensor with the same leading axes and last axis ``out_dimension``.
        """
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        # transpose(-2, -1) instead of permute(0, 2, 1): works for unbatched (2-D)
        # and multi-batch-axis inputs, not only exactly 3-D ones.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.kq_dimension)
        if self.masked:
            # A single (seq_q, seq_k) boolean upper-triangular mask, broadcast over
            # the batch axes, hides future positions.
            seq_q, seq_k = scores.shape[-2], scores.shape[-1]
            causal = torch.ones(seq_q, seq_k, dtype=torch.bool, device=scores.device).triu(diagonal=1)
            scores = scores.masked_fill(causal, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        return torch.matmul(attn, V)
        

### Multi Head Attention

In [99]:
class MHA(nn.Module):
    """Multi-head self-attention built from independent ``SDPA`` heads.

    Each of the ``num_heads`` heads maps ``in_dimension -> out_dimension``. The head
    outputs are concatenated (``num_heads * out_dimension`` features) and projected
    back to ``out_dimension`` by ``self.out``.

    Args:
        in_dimension, out_dimension, kq_dimension: per-head sizes (see ``SDPA``).
        num_heads: number of parallel heads.
        masked: forwarded to every head (causal masking).
        device, dtype: placement of every parameter in this module.
    """

    def __init__(self, in_dimension, out_dimension, kq_dimension, num_heads=8,
                 masked=False, device=DEVICE, dtype=DTYPE
                 ):
        super(MHA, self).__init__()

        self.in_dimension = in_dimension
        self.out_dimension = out_dimension
        self.kq_dimension = kq_dimension
        self.num_heads = num_heads

        self.masked = masked
        self.device = device
        self.dtype = dtype

        # Heads must honour this module's device/dtype arguments (not the global
        # DEVICE/DTYPE), otherwise heads and `self.out` can live on different devices.
        self.sdpa_head_list = nn.ModuleList([
            SDPA(self.in_dimension, self.out_dimension, self.kq_dimension,
                 masked=self.masked, device=self.device, dtype=self.dtype)
            for _ in range(self.num_heads)
        ])

        self.out = nn.Linear(in_features=num_heads * out_dimension, out_features=out_dimension,
                             device=self.device, dtype=self.dtype
                             )

    def forward(self, x):
        """Run every head on ``x``, concatenate along the feature axis and project."""
        head_outputs = [head(x) for head in self.sdpa_head_list]
        return self.out(torch.cat(head_outputs, dim=-1))

### Single Encoder Layer

In [110]:
class EncoderLayer(nn.Module):
    """Post-LayerNorm Transformer encoder layer (self-attention + feed-forward).

    Each sub-layer is applied as ``LayerNorm(residual + Dropout(Sublayer(input)))``,
    i.e. dropout acts on the sub-layer output before the residual add, as in
    Vaswani et al. (2017), rather than on the normalized residual stream.

    Args:
        in_dimension: feature size of the input.
        out_dimension: feature size of the output (model width).
        kq_dimension: per-head query/key size.
        num_heads: number of attention heads.
        linear_stretch: hidden size of the feed-forward net is
            ``linear_stretch * out_dimension``.
        dropout: dropout probability.
        device, dtype: placement of every parameter.
    """

    def __init__(self, in_dimension, out_dimension, kq_dimension,
                 num_heads=8, linear_stretch=2, dropout=0.1,
                 device=DEVICE, dtype=DTYPE
                 ):
        super(EncoderLayer, self).__init__()

        self.in_dimension = in_dimension
        self.out_dimension = out_dimension
        self.kq_dimension = kq_dimension
        self.num_heads = num_heads
        self.linear_stretch = linear_stretch
        self.dropout = dropout

        self.device = device
        self.dtype = dtype

        self.mha = MHA(self.in_dimension, self.out_dimension, self.kq_dimension, self.num_heads,
                       device=self.device, dtype=self.dtype)

        self.ff1 = nn.Linear(self.out_dimension, self.linear_stretch * self.out_dimension,
                             device=self.device, dtype=self.dtype
                             )

        self.ff2 = nn.Linear(self.linear_stretch * self.out_dimension, self.out_dimension,
                             device=self.device, dtype=self.dtype
                             )

        self.layernorm1 = nn.LayerNorm(self.out_dimension, device=self.device, dtype=self.dtype)
        self.layernorm2 = nn.LayerNorm(self.out_dimension, device=self.device, dtype=self.dtype)

        self.dropoutL = nn.Dropout(self.dropout)

        # The first residual adds the input (in_dimension) to the attention output
        # (out_dimension). Project the input when the sizes differ. nn.Identity has no
        # parameters, so the state_dict is unchanged in the in == out case.
        if self.in_dimension == self.out_dimension:
            self.residual_proj = nn.Identity()
        else:
            self.residual_proj = nn.Linear(self.in_dimension, self.out_dimension,
                                           device=self.device, dtype=self.dtype)

    def forward(self, x):
        """Apply attention then feed-forward sub-layers; last axis of ``x`` is in_dimension."""
        # Attention sub-layer.
        Y = self.layernorm1(self.residual_proj(x) + self.dropoutL(self.mha(x)))

        # Position-wise feed-forward sub-layer.
        Z = self.ff2(F.relu(self.ff1(Y)))
        return self.layernorm2(Y + self.dropoutL(Z))

In [112]:
# Smoke test: a single EncoderLayer keeps the (batch, seq, dim) shape of its input.
enc_layer_input = torch.rand((3, 15, 35), device=DEVICE, dtype=DTYPE)
enc_layer = EncoderLayer(35, 35, 50)
# Call the module itself (not .forward) so registered hooks run.
enc_layer_output = enc_layer(enc_layer_input)
assert enc_layer_output.shape == enc_layer_input.shape
enc_layer_output.shape

torch.Size([3, 15, 35])

In [7]:
# nn.Embedding maps integer ids of shape (2, 3) to vectors, giving shape (2, 3, 13).
# Ids must lie in [0, num_embeddings) = [0, 6).
# Descriptive names keep this cell from clobbering variables defined by earlier cells.
token_ids = torch.tensor([[1, 2, 3],
                          [3, 4, 5]])
embedding = nn.Embedding(6, 13)
embedded = embedding(token_ids)
embedded.shape

torch.Size([2, 3, 13])

### TRANSFORMER ENCODER

In [103]:
class ENCODER(nn.Module):
    """Transformer encoder: token embedding + sinusoidal positions + stacked EncoderLayers.

    Args:
        encoder_dimension: model width used by the embedding and every layer.
        kq_dimension: per-head query/key size.
        vocab_size: number of token ids accepted by the embedding.
        seq_len: maximum input sequence length (size of the positional table).
        num_heads, linear_stretch: forwarded to every ``EncoderLayer``.
        num_layers: number of stacked ``EncoderLayer`` modules.
        padding_index: embedding row kept at zero (``nn.Embedding(padding_idx=...)``).
        use_pos_enc: add sinusoidal positional encodings when True.
        device, dtype: placement of every parameter.
        dropout: dropout probability inside each layer (new trailing keyword, so
            existing positional callers are unaffected).

    NOTE(review): padding positions are not masked out of attention. Only their
    embeddings are zero. A key-padding mask would have to be threaded through
    EncoderLayer/MHA/SDPA.
    """

    def __init__(self, encoder_dimension, kq_dimension, vocab_size, seq_len,
                 num_heads=8, linear_stretch=2, num_layers=6, padding_index=0,
                 use_pos_enc=True, device=DEVICE, dtype=DTYPE, dropout=0.1
                 ):
        super(ENCODER, self).__init__()

        self.encoder_dim = encoder_dimension
        self.kq_dimension = kq_dimension
        self.vocab_size = vocab_size
        self.seq_len = seq_len

        self.num_heads = num_heads
        self.linear_stretch = linear_stretch
        self.num_layers = num_layers
        self.use_pos_enc = use_pos_enc
        self.padding_index = padding_index
        self.dropout = dropout

        self.device = device
        self.dtype = dtype

        self.embd = nn.Embedding(self.vocab_size, self.encoder_dim, padding_idx=self.padding_index,
                                 dtype=self.dtype, device=self.device)

        self.layer_list = nn.ModuleList([
            EncoderLayer(self.encoder_dim, self.encoder_dim, self.kq_dimension,
                         num_heads=self.num_heads, linear_stretch=self.linear_stretch,
                         dropout=self.dropout, device=self.device, dtype=self.dtype)
            for _ in range(self.num_layers)
        ])

    def positional_encoding(self):
        """Return the sinusoidal positional table of shape ``(1, seq_len, encoder_dim)``."""
        position = torch.arange(0, self.seq_len, dtype=self.dtype, device=self.device).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.encoder_dim, 2, dtype=self.dtype, device=self.device)
                             * -(math.log(10000.0) / self.encoder_dim))
        angles = position * div_term
        pe = torch.zeros(self.seq_len, self.encoder_dim, dtype=self.dtype, device=self.device)
        pe[:, 0::2] = torch.sin(angles)
        # For an odd encoder_dim there is one fewer odd column than div_term entries.
        pe[:, 1::2] = torch.cos(angles[:, : self.encoder_dim // 2])
        return pe.unsqueeze(0)

    def forward(self, x):
        """Encode token ids ``x``; the sequence axis is ``x``'s last axis and must be <= seq_len.

        Raises:
            ValueError: if positional encoding is enabled and the input is longer than seq_len.
        """
        Z = self.embd(x) * math.sqrt(self.encoder_dim)
        if self.use_pos_enc:
            src_len = Z.size(-2)
            if src_len > self.seq_len:
                raise ValueError(f"sequence length {src_len} exceeds seq_len={self.seq_len}")
            # Slice the table to the actual length so shorter inputs also work.
            Z = Z + self.positional_encoding()[0, :src_len]
        for layer in self.layer_list:
            Z = layer(Z)
        return Z

### Cross SDPA (encoder–decoder attention)

In [97]:
class Cross_SDPA(nn.Module):
    """Single-head scaled dot-product cross-attention.

    Queries are projected from the decoder-side tensor ``y``. Keys and values are
    projected from the encoder-side tensor ``z``.

    Args:
        enc_inp_dim: size of the last axis of ``z``.
        dec_inp_dim: size of the last axis of ``y``.
        out_dimension: size of the last axis of the returned tensor.
        kq_dimension: size of the query/key projections.
        device, dtype: forwarded to the projection layers.
    """

    def __init__(self, enc_inp_dim, dec_inp_dim, out_dimension, kq_dimension,
                 device=DEVICE, dtype=DTYPE):
        super(Cross_SDPA, self).__init__()

        self.enc_inp_dim = enc_inp_dim
        self.dec_inp_dim = dec_inp_dim
        self.out_dimension = out_dimension
        self.kq_dimension = kq_dimension

        self.device = device
        self.dtype = dtype

        self.query = nn.Linear(in_features=self.dec_inp_dim, out_features=self.kq_dimension,
                               device=self.device, dtype=self.dtype
                               )

        self.key = nn.Linear(in_features=self.enc_inp_dim, out_features=self.kq_dimension,
                             device=self.device, dtype=self.dtype
                             )

        self.value = nn.Linear(in_features=self.enc_inp_dim, out_features=self.out_dimension,
                               device=self.device, dtype=self.dtype
                               )

    def forward(self, z, y):
        """Attend from ``y`` (queries) over ``z`` (keys/values).

        Args:
            z: encoder tensor, last two axes ``(src_len, enc_inp_dim)``.
            y: decoder tensor, last two axes ``(tgt_len, dec_inp_dim)``; leading batch
               axes must broadcast with those of ``z``.

        Returns:
            Tensor with last two axes ``(tgt_len, out_dimension)``.
        """
        Q = self.query(y)
        K = self.key(z)
        V = self.value(z)

        # transpose(-2, -1) rather than permute(0, 2, 1) so inputs of any rank >= 2 work.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.kq_dimension)
        attn = F.softmax(scores, dim=-1)
        return torch.matmul(attn, V)

### Cross MHA

In [100]:
class Cross_MHA(nn.Module):
    """Multi-head encoder-decoder (cross) attention.

    ``num_heads`` independent ``Cross_SDPA`` heads each attend from the decoder
    tensor over the encoder tensor. Their outputs are joined on the feature axis
    and mapped from ``num_heads * out_dimension`` back to ``out_dimension``.
    """

    def __init__(self, enc_inp_dim, dec_inp_dim, out_dimension, kq_dimension, num_heads=8,
                 device=DEVICE, dtype=DTYPE
                 ):
        super(Cross_MHA, self).__init__()

        # Configuration, kept as attributes for introspection.
        self.enc_inp_dim, self.dec_inp_dim = enc_inp_dim, dec_inp_dim
        self.out_dimension, self.kq_dimension = out_dimension, kq_dimension
        self.num_heads = num_heads
        self.device, self.dtype = device, dtype

        placement = dict(device=self.device, dtype=self.dtype)

        # Heads are created first, then the output projection (same registration order).
        self.sdpa_head_list = nn.ModuleList(
            Cross_SDPA(self.enc_inp_dim, self.dec_inp_dim, self.out_dimension,
                       self.kq_dimension, **placement)
            for _ in range(num_heads)
        )
        self.out = nn.Linear(num_heads * out_dimension, out_dimension, **placement)

    def forward(self, z, y):
        """``z``: encoder tensor (keys/values); ``y``: decoder tensor (queries)."""
        joined = torch.cat([head(z, y) for head in self.sdpa_head_list], dim=-1)
        return self.out(joined)

### Single Decoder Layer

In [106]:
class DecoderLayer(nn.Module):
    """Post-LayerNorm Transformer decoder layer.

    It has three sub-layers, each applied as ``LayerNorm(residual + Dropout(Sublayer(...)))``:

    1. causal multi-head self-attention over the decoder input ``x``;
    2. cross-attention from that result (queries) over the encoder output ``e``;
    3. a position-wise feed-forward network.

    Args:
        enc_inp_dim: feature size of the encoder output ``e``.
        dec_inp_dim: feature size of the decoder input ``x``.
        out_dimension: feature size of every sub-layer output (model width).
        kq_dimension: per-head query/key size.
        num_heads: number of attention heads.
        linear_stretch: feed-forward hidden size is ``linear_stretch * out_dimension``.
        dropout: dropout probability.
        device, dtype: placement of every parameter.
    """

    def __init__(self, enc_inp_dim, dec_inp_dim, out_dimension, kq_dimension,
                 num_heads=8, linear_stretch=2, dropout=0.1,
                 device=DEVICE, dtype=DTYPE
                 ):
        super(DecoderLayer, self).__init__()

        self.enc_inp_dim = enc_inp_dim
        self.dec_inp_dim = dec_inp_dim

        self.out_dimension = out_dimension
        self.kq_dimension = kq_dimension
        self.num_heads = num_heads
        self.linear_stretch = linear_stretch

        self.device = device
        self.dtype = dtype

        self.masked_mha = MHA(self.dec_inp_dim, self.out_dimension, self.kq_dimension, self.num_heads,
                              masked=True, device=self.device, dtype=self.dtype
                              )

        # Cross-attention queries are the masked-MHA output, which has out_dimension
        # features (not dec_inp_dim). The two sizes coincide in the equal-dims case.
        self.cross_mha = Cross_MHA(self.enc_inp_dim, self.out_dimension, self.out_dimension,
                                   self.kq_dimension, self.num_heads, device=self.device,
                                   dtype=self.dtype
                                   )

        self.ff1 = nn.Linear(self.out_dimension, self.linear_stretch * self.out_dimension,
                             device=self.device, dtype=self.dtype
                             )

        self.ff2 = nn.Linear(self.linear_stretch * self.out_dimension, self.out_dimension,
                             device=self.device, dtype=self.dtype
                             )

        self.layernorm1 = nn.LayerNorm(self.out_dimension, device=self.device, dtype=self.dtype)
        self.layernorm2 = nn.LayerNorm(self.out_dimension, device=self.device, dtype=self.dtype)
        self.layernorm3 = nn.LayerNorm(self.out_dimension, device=self.device, dtype=self.dtype)

        # Dropout module (probability available as self.dropout.p).
        self.dropout = nn.Dropout(dropout)

        # The first residual adds x (dec_inp_dim) to an out_dimension tensor. Project x
        # when the sizes differ. nn.Identity has no parameters (state_dict unchanged).
        if self.dec_inp_dim == self.out_dimension:
            self.residual_proj = nn.Identity()
        else:
            self.residual_proj = nn.Linear(self.dec_inp_dim, self.out_dimension,
                                           device=self.device, dtype=self.dtype)

    def forward(self, x, e):
        """Decode one layer.

        Args:
            x: decoder input, last axis dec_inp_dim (presumably (batch, tgt_len, dec_inp_dim)).
            e: encoder output, last axis enc_inp_dim.

        Returns:
            Tensor with the leading axes of ``x`` and last axis out_dimension.
        """
        # Causal self-attention sub-layer; dropout on the sub-layer output before the add.
        Y = self.layernorm1(self.residual_proj(x) + self.dropout(self.masked_mha(x)))

        # Cross-attention sub-layer (queries from Y, keys/values from e).
        Z = self.layernorm2(Y + self.dropout(self.cross_mha(e, Y)))

        # Position-wise feed-forward sub-layer.
        W = self.ff2(F.relu(self.ff1(Z)))
        return self.layernorm3(Z + self.dropout(W))

### TRANSFORMER DECODER

In [107]:
class DECODER(nn.Module):
    """Transformer decoder: embedding + sinusoidal positions + stacked DecoderLayers + vocab head.

    Args:
        decoder_dimension: model width of the decoder.
        encoder_dimension: feature size of the encoder output passed to ``forward``.
        kq_dimension: per-head query/key size.
        vocab_size: number of output tokens (embedding rows and output classes).
        max_seq_len: maximum target length (size of the positional table).
        num_heads, linear_stretch, dropout: forwarded to every ``DecoderLayer``.
        num_layers: number of stacked ``DecoderLayer`` modules.
        padding_index: embedding row kept at zero.
        use_pos_enc: add sinusoidal positional encodings when True.
        device, dtype: placement of every parameter.
    """

    def __init__(self, decoder_dimension, encoder_dimension, kq_dimension, vocab_size,
                 max_seq_len, num_heads=8, linear_stretch=2, num_layers=6, padding_index=0,
                 use_pos_enc=True, dropout=0.1, device=DEVICE, dtype=DTYPE):
        super(DECODER, self).__init__()
        self.decoder_dim = decoder_dimension
        self.encoder_dim = encoder_dimension
        self.kq_dimension = kq_dimension
        self.vocab_size = vocab_size

        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.linear_stretch = linear_stretch
        self.num_layers = num_layers
        self.use_pos_enc = use_pos_enc
        self.padding_index = padding_index
        # Previously never assigned, yet read below -> AttributeError at construction.
        self.dropout = dropout

        self.device = device
        self.dtype = dtype

        self.embd = nn.Embedding(self.vocab_size, self.decoder_dim, padding_idx=self.padding_index,
                                 dtype=self.dtype, device=self.device)

        self.layer_list = nn.ModuleList([
            DecoderLayer(self.encoder_dim, self.decoder_dim, self.decoder_dim, self.kq_dimension,
                         num_heads=self.num_heads, linear_stretch=self.linear_stretch,
                         dropout=self.dropout, device=self.device, dtype=self.dtype)
            for _ in range(self.num_layers)
        ])

        self.final = nn.Linear(self.decoder_dim, self.vocab_size,
                               device=self.device, dtype=self.dtype)

    def positional_encoding(self):
        """Return the sinusoidal positional table of shape ``(1, max_seq_len, decoder_dim)``."""
        pe = torch.zeros(self.max_seq_len, self.decoder_dim, device=self.device, dtype=self.dtype)
        position = torch.arange(0, self.max_seq_len, dtype=self.dtype, device=self.device).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.decoder_dim, 2, dtype=self.dtype, device=self.device) *
                             -(math.log(10000.0) / self.decoder_dim))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For an odd decoder_dim there is one fewer odd column than div_term entries.
        pe[:, 1::2] = torch.cos(angles[:, : self.decoder_dim // 2])
        return pe.unsqueeze(0)

    def forward(self, x, encoder_output, return_logits=False):
        """Decode target ids ``x`` against ``encoder_output``.

        Args:
            x: target token ids; the sequence axis is the last axis and must be <= max_seq_len.
            encoder_output: encoder tensor with last axis encoder_dimension.
            return_logits: if True, return raw vocabulary scores instead of softmax
                probabilities (use this with ``nn.CrossEntropyLoss``, which applies
                log-softmax itself).

        Returns:
            Probabilities (default) or logits over the vocabulary, last axis vocab_size.

        Raises:
            ValueError: if positional encoding is enabled and the target is longer than max_seq_len.
        """
        Z = self.embd(x) * math.sqrt(self.decoder_dim)
        if self.use_pos_enc:
            tgt_len = Z.size(-2)
            if tgt_len > self.max_seq_len:
                raise ValueError(f"sequence length {tgt_len} exceeds max_seq_len={self.max_seq_len}")
            Z = Z + self.positional_encoding()[0, :tgt_len]

        for layer in self.layer_list:
            Z = layer(Z, encoder_output)

        logits = self.final(Z)
        if return_logits:
            return logits
        return F.softmax(logits, dim=-1)