In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import math
import numpy as np

In [17]:
# Runtime configuration shared by every module below.
# Prefer GPU index 1 when it exists, fall back to the first GPU, otherwise CPU.
# NOTE: `torch.cuda.is_available` must be *called* -- the bare function object is
# always truthy, which would select CUDA even on CPU-only machines.
PREFERRED_CUDA_INDEX = 1
if torch.cuda.is_available():
    DEVICE = torch.device(
        'cuda',
        PREFERRED_CUDA_INDEX if torch.cuda.device_count() > PREFERRED_CUDA_INDEX else 0,
    )
else:
    DEVICE = torch.device('cpu')
DTYPE = torch.float32

### Scaled Dot Product Attention

In [18]:
class SDPA(nn.Module):
    """Single-head scaled dot-product self-attention.

    Projects the input into queries and keys (``kq_dimension`` features each)
    and values (``out_dimension`` features), then returns
    ``softmax(Q K^T / sqrt(kq_dimension)) V``.

    Args:
        in_dimension: size of the last axis of the input.
        out_dimension: size of the last axis of the output (value width).
        kq_dimension: width of the query/key projections; also the ``d_k`` used
            to scale the scores.
        device, dtype: forwarded to the three projection layers.
    """

    def __init__(self, in_dimension, out_dimension, kq_dimension, device=DEVICE, dtype=DTYPE):
        super(SDPA, self).__init__()

        self.in_dimension = in_dimension
        self.out_dimension = out_dimension
        self.kq_dimension = kq_dimension
        self.device = device
        self.dtype = dtype

        # Creation order (query, key, value) is kept so seeded initialisation is unchanged.
        self.query = nn.Linear(in_features=self.in_dimension, out_features=self.kq_dimension,
                               device=self.device, dtype=self.dtype)
        self.key = nn.Linear(in_features=self.in_dimension, out_features=self.kq_dimension,
                             device=self.device, dtype=self.dtype)
        self.value = nn.Linear(in_features=self.in_dimension, out_features=self.out_dimension,
                               device=self.device, dtype=self.dtype)

    def forward(self, x):
        """Self-attend over the second-to-last axis of ``x``.

        Args:
            x: tensor of shape ``(..., seq_len, in_dimension)``.

        Returns:
            Tensor of shape ``(..., seq_len, out_dimension)``.
        """
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        # Scores have shape (..., seq_len, seq_len). transpose(-2, -1) works for
        # any number of leading batch axes (permute(0, 2, 1) required 3-D input).
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.kq_dimension)
        # Normalise over the key axis so each query's weights sum to 1.
        # Softmax over dim=0 would instead normalise across the batch.
        attn = F.softmax(scores, dim=-1)

        return torch.matmul(attn, V)
        

### Multi Head Attention

In [19]:
class MHA(nn.Module):
    """Multi-head self-attention built from independent ``SDPA`` heads.

    Each of the ``num_heads`` heads maps ``in_dimension -> out_dimension``.
    The head outputs are concatenated along the last axis (width
    ``num_heads * out_dimension``) and linearly projected back to
    ``out_dimension``.

    Args:
        in_dimension: size of the last axis of the input.
        out_dimension: per-head output width and final output width.
        kq_dimension: query/key width inside each head.
        num_heads: number of attention heads (must be >= 1).
        device, dtype: where the heads and output projection are created.

    Raises:
        ValueError: if ``num_heads`` is less than 1.
    """

    def __init__(self, in_dimension, out_dimension, kq_dimension, num_heads=8,
                 device=DEVICE, dtype=DTYPE
                 ):
        super(MHA, self).__init__()

        if num_heads < 1:
            raise ValueError(f"num_heads must be >= 1, got {num_heads}")

        self.in_dimension = in_dimension
        self.out_dimension = out_dimension
        self.kq_dimension = kq_dimension
        self.num_heads = num_heads

        self.device = device
        self.dtype = dtype

        # Build every head on the device/dtype given to *this* module rather than
        # the module-level defaults, so MHA(..., device=d, dtype=t) is honoured.
        self.sdpa_head_list = nn.ModuleList([
            SDPA(self.in_dimension, self.out_dimension, self.kq_dimension,
                 device=self.device, dtype=self.dtype)
            for _ in range(self.num_heads)
        ])

        self.out = nn.Linear(in_features=self.num_heads * self.out_dimension,
                             out_features=self.out_dimension,
                             device=self.device, dtype=self.dtype)

    def forward(self, x):
        """Run every head on ``x``, concatenate on the last axis, and project.

        Returns:
            Tensor whose last axis has size ``out_dimension``.
        """
        Y = torch.cat([head(x) for head in self.sdpa_head_list], dim=-1)
        return self.out(Y)
                
        

### Single Encoder Layer

In [20]:
class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    ``x -> MHA -> (+ skip) -> LayerNorm -> Linear/ReLU/Linear -> (+ skip) -> LayerNorm``

    Args:
        in_dimension: size of the last axis of the input.
        out_dimension: size of the last axis of the output (model width).
        kq_dimension: query/key width inside each attention head.
        num_heads: number of attention heads.
        linear_stretch: hidden width of the feed-forward block is
            ``linear_stretch * out_dimension``.
        device, dtype: where parameters are created.
    """

    def __init__(self, in_dimension, out_dimension, kq_dimension,
                 num_heads=8, linear_stretch=2,
                 device=DEVICE, dtype=DTYPE
                 ):
        super(EncoderLayer, self).__init__()

        self.in_dimension = in_dimension
        self.out_dimension = out_dimension
        self.kq_dimension = kq_dimension
        self.num_heads = num_heads
        self.linear_stretch = linear_stretch

        self.device = device
        self.dtype = dtype

        self.mha = MHA(self.in_dimension, self.out_dimension, self.kq_dimension,
                       num_heads=self.num_heads, device=self.device, dtype=self.dtype)

        hidden = self.linear_stretch * self.out_dimension
        self.ff1 = nn.Linear(self.out_dimension, hidden, device=self.device, dtype=self.dtype)
        self.ff2 = nn.Linear(hidden, self.out_dimension, device=self.device, dtype=self.dtype)

        self.layernorm1 = nn.LayerNorm(self.out_dimension, device=self.device, dtype=self.dtype)
        self.layernorm2 = nn.LayerNorm(self.out_dimension, device=self.device, dtype=self.dtype)

        # The attention sub-layer maps in_dimension -> out_dimension, so the first
        # skip connection needs a projection when the widths differ. When they are
        # equal this is an Identity (no parameters, no RNG use). It is created last
        # so parameter initialisation order is unchanged.
        if self.in_dimension == self.out_dimension:
            self.residual_proj = nn.Identity()
        else:
            self.residual_proj = nn.Linear(self.in_dimension, self.out_dimension,
                                           device=self.device, dtype=self.dtype)

    def forward(self, x):
        """Apply attention and feed-forward sub-layers, each with residual + LayerNorm.

        Returns:
            Tensor whose last axis has size ``out_dimension``.
        """
        # Attention sub-layer with (possibly projected) skip connection.
        Y = self.layernorm1(self.mha(x) + self.residual_proj(x))

        # Position-wise feed-forward sub-layer with skip connection.
        Z = self.ff2(F.relu(self.ff1(Y)))
        return self.layernorm2(Z + Y)

In [6]:
# Smoke test: one encoder layer on a (batch=3, seq=15, dim=35) input.
x_demo = torch.rand((3, 15, 35), device=DEVICE, dtype=DTYPE)
encoder_layer = EncoderLayer(35, 35, 50)
# Call the module itself (not .forward) so registered hooks run.
y_demo = encoder_layer(x_demo)
assert y_demo.shape == x_demo.shape
y_demo.shape

torch.Size([3, 15, 35])

In [7]:
# nn.Embedding maps each integer id to a learned vector: a (2, 3) id tensor
# becomes a (2, 3, 13) float tensor (ids must be < num_embeddings = 6).
token_ids = torch.tensor([[1, 2, 3],
                          [3, 4, 5]])
embedding = nn.Embedding(6, 13)
embedded = embedding(token_ids)
embedded.shape

torch.Size([2, 3, 13])

### Transformer Encoder

In [21]:
class ENCODER(nn.Module):
    """Transformer encoder: token embedding, optional sinusoidal positional
    encoding, then ``num_layers`` stacked ``EncoderLayer`` blocks.

    Args:
        model_dimension: embedding / hidden width used by every layer.
        kq_dimension: query/key width inside each attention head.
        vocab_size: number of token ids; id 0 is the embedding's padding index.
        num_heads: attention heads per layer.
        linear_stretch: feed-forward expansion factor per layer.
        num_layers: number of stacked encoder layers.
        use_pos_enc: add sinusoidal positional encodings to the embeddings.
        device, dtype: where parameters and intermediate tensors are created.
    """

    def __init__(self, model_dimension, kq_dimension, vocab_size,
                 num_heads=8, linear_stretch=2, num_layers=6,
                 use_pos_enc=True, device=DEVICE, dtype=DTYPE
                 ):
        super(ENCODER, self).__init__()

        self.model_dim = model_dimension
        self.kq_dimension = kq_dimension

        self.num_heads = num_heads
        self.linear_stretch = linear_stretch
        self.num_layers = num_layers
        self.use_pos_enc = use_pos_enc

        self.device = device
        self.dtype = dtype

        self.embd = nn.Embedding(vocab_size, self.model_dim, padding_idx=0,
                                 dtype=self.dtype, device=self.device)

        self.layer_list = nn.ModuleList([
            EncoderLayer(self.model_dim, self.model_dim, self.kq_dimension,
                         num_heads=self.num_heads, linear_stretch=self.linear_stretch,
                         device=self.device, dtype=self.dtype)
            for _ in range(self.num_layers)
        ])

    def positional_encoding(self, seq_len, d_model):
        """Return sinusoidal positional encodings of shape ``(1, seq_len, d_model)``.

        Even feature columns hold ``sin(pos / 10000**(2i / d_model))`` and odd
        columns the matching ``cos``. Odd ``d_model`` is supported.
        """
        position = torch.arange(seq_len, dtype=self.dtype, device=self.device).unsqueeze(1)
        # 10000 ** (-2i / d_model) for i = 0, 1, ..., ceil(d_model / 2) - 1
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=self.dtype, device=self.device)
            * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (seq_len, ceil(d_model / 2))

        pe = torch.zeros(seq_len, d_model, dtype=self.dtype, device=self.device)
        pe[:, 0::2] = torch.sin(angles)
        # There are only floor(d_model / 2) odd columns; trim when d_model is odd.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        return pe.unsqueeze(0)

    def forward(self, x):
        """Encode a batch of token ids.

        Args:
            x: integer tensor of token ids, shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, seq_len, model_dimension)``.
        """
        # Scale embeddings by sqrt(d_model) before adding positional encodings.
        Z = self.embd(x) * math.sqrt(self.model_dim)
        if self.use_pos_enc:
            Z = Z + self.positional_encoding(x.size(1), self.model_dim)
        for layer in self.layer_list:
            Z = layer(Z)
        return Z
        
        
        
        
        

In [22]:
# Smoke test: full encoder on random token ids (id 0 is the padding index),
# created on the configured DEVICE rather than a hard-coded GPU.
demo_vocab_size = 500
token_ids = torch.randint(0, demo_vocab_size, (10, 15), device=DEVICE)
encoder = ENCODER(model_dimension=128, kq_dimension=256, vocab_size=demo_vocab_size)
encoded = encoder(token_ids)
assert encoded.shape == (10, 15, 128)
encoded.shape

torch.Size([10, 15, 128])


torch.Size([10, 15, 128])