In [6]:
!pip install datasets
!pip install transformers
!pip install torch torchvision torchaudio

Collecting torch
  Downloading torch-2.4.1-cp312-cp312-manylinux1_x86_64.whl.metadata (26 kB)
Collecting torchvision
  Downloading torchvision-0.19.1-cp312-cp312-manylinux1_x86_64.whl.metadata (6.0 kB)
Collecting torchaudio
  Downloading torchaudio-2.4.1-cp312-cp312-manylinux1_x86_64.whl.metadata (6.4 kB)
Collecting sympy (from torch)
  Downloading sympy-1.13.2-py3-none-any.whl.metadata (12 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 k

In [36]:
# Czech-English configuration of the WMT14 dataset (fetched via `datasets`).
from datasets import load_dataset

ds = load_dataset(path="wmt/wmt14", name="cs-en")

In [37]:
# Pretrained GPT-2 tokenizer. NOTE(review): not yet used by the model cells shown here.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="gpt2")



In [42]:
import torch
import math

num_layers=6
seq_len = 1024 #TODO: this should be input dependent
d_model = 512

class FFN:
    """Position-wise feed-forward network: ``W_2(relu(W_1(x)))``.

    Args:
        x: input tensor of shape (seq_len, d_model), stored and used by ``fwd``.
    """

    def __init__(self, x):
        assert x.shape == (seq_len, d_model)
        self.x = x
        self.d_ff = d_model * 4

        # Build both layers once so repeated fwd() calls reuse the same weights.
        # nn.Linear already carries a bias, so no separate bias tensors are needed.
        self.W_1 = torch.nn.Linear(d_model, self.d_ff)
        self.W_2 = torch.nn.Linear(self.d_ff, d_model)

    def fwd(self):
        """Apply the FFN to the input given at construction.

        Returns:
            Tensor of shape (seq_len, d_model).
        """
        hidden = torch.relu(self.W_1(self.x))  # max(0, x W_1 + b_1)
        return self.W_2(hidden)

class MHA:
    """Multi-head scaled dot-product attention over (seq_len, d_model) inputs.

    Queries come from ``x_q`` (defaulting to ``x_k``), keys from ``x_k`` and
    values from ``x_v``. Each of the ``h`` heads has its own randomly
    initialised projections of width ``d_k = d_model // h``. The concatenated
    heads are mapped back to ``d_model`` by ``W_O``.

    Args:
        x_k: key input, shape (seq_len, d_model); also the query input when
            ``x_q`` is None.
        x_v: value input, shape (seq_len, d_model).
        h: number of heads; must divide ``d_model``.
        has_mask: if True, query position i attends only to key positions <= i.
        x_q: optional separate query input, shape (seq_len, d_model).
    """

    def __init__(self, x_k, x_v, h=8, has_mask=False, x_q=None):
        if x_q is None:
            x_q = x_k
        for label, t in (("x_q", x_q), ("x_k", x_k), ("x_v", x_v)):
            assert t.shape == (seq_len, d_model), f"{label} has shape {tuple(t.shape)}"
        assert d_model % h == 0, "h must divide d_model"

        self.h = h
        self.has_mask = has_mask
        self.d_k = d_model // h
        self.d_v = self.d_k
        # 1/sqrt(d_k) keeps the logits' scale independent of the head width.
        self.scale = 1 / math.sqrt(self.d_k)

        self.x_q = x_q
        self.x_k = x_k
        self.x_v = x_v

        # One projection per head (leading dim h). Random, not zero, init:
        # all-zero weights make every head output exactly zero.
        init_std = 1 / math.sqrt(d_model)
        self.W_Q = torch.randn(h, d_model, self.d_k) * init_std
        self.W_K = torch.randn(h, d_model, self.d_k) * init_std
        self.W_V = torch.randn(h, d_model, self.d_v) * init_std
        self.W_O = torch.randn(h * self.d_v, d_model) / math.sqrt(h * self.d_v)

    def attn(self, i=0):
        """Return the output of head ``i``, shape (seq_len, d_v)."""
        Q = self.x_q @ self.W_Q[i]
        K = self.x_k @ self.W_K[i]
        V = self.x_v @ self.W_V[i]

        scores = self.scale * (Q @ K.T)  # (num_queries, num_keys)
        if self.has_mask:
            # Causal mask applied to the scores, not the inputs. This leaves the
            # stored inputs untouched and avoids 0 * inf = NaN.
            n_q, n_k = scores.shape
            future = torch.triu(torch.ones(n_q, n_k), diagonal=1).bool()
            scores = scores.masked_fill(future, float("-inf"))

        # Normalise over keys (last dim) so each query's weights sum to 1.
        weights = torch.softmax(scores, dim=-1)
        return weights @ V

    def multi(self):
        """Concatenate all ``h`` heads and project to shape (seq_len, d_model)."""
        heads = torch.cat([self.attn(i) for i in range(self.h)], dim=1)
        return heads @ self.W_O


class EncLayer(torch.nn.Module):
    """One encoder layer: self-attention, then FFN.

    Each sub-layer is wrapped as ``LayerNorm(x + sublayer(x))``.
    """

    def __init__(self):
        super().__init__()
        self.norm1 = torch.nn.LayerNorm(d_model)
        self.norm2 = torch.nn.LayerNorm(d_model)

    def forward(self, x):
        """Encode ``x`` (expected shape (seq_len, d_model)); returns the same shape."""
        # NOTE(review): MHA/FFN take their input in the constructor. Their weights
        # are therefore rebuilt on every call and are not parameters of this module.
        mha = MHA(x_k=x, x_v=x).multi()
        sl1 = self.norm1(x + mha)

        ffn = FFN(sl1).fwd()
        sl2 = self.norm2(sl1 + ffn)
        return sl2

class DecLayer(torch.nn.Module):
    """One decoder layer: masked self-attention, a second attention, then FFN.

    Each sub-layer is wrapped as ``LayerNorm(x + sublayer(x))``.
    """

    def __init__(self):
        super().__init__()
        self.norm1 = torch.nn.LayerNorm(d_model)
        self.norm2 = torch.nn.LayerNorm(d_model)
        self.norm3 = torch.nn.LayerNorm(d_model)

    def forward(self, x):
        """Decode ``x`` (expected shape (seq_len, d_model)); returns the same shape."""
        mmha = MHA(has_mask=True, x_k=x, x_v=x).multi()
        sl1 = self.norm1(x + mmha)

        # NOTE(review): in the standard encoder-decoder Transformer this sub-layer
        # takes queries from the decoder state (sl1) and keys/values from the
        # encoder output. Here MHA derives Q and K from x_k and V from x_v, so
        # this is not yet true cross-attention.
        mha = MHA(has_mask=False, x_k=x, x_v=sl1).multi()
        sl2 = self.norm2(sl1 + mha)

        ffn = FFN(sl2).fwd()
        sl3 = self.norm3(sl2 + ffn)
        return sl3

class EncoderDecoder(torch.nn.Module):
    """A stack of ``num_layers`` encoder layers followed by ``num_layers`` decoder layers.

    Args:
        x: input tensor of shape (seq_len, d_model).

    Attributes:
        encs, decs: per-layer outputs recorded by the most recent forward() call.
    """

    def __init__(self, x):
        assert x.shape == (seq_len, d_model)
        super().__init__()
        self.x = x
        # Register the layers once so they are sub-modules of this model.
        self.enc_layers = torch.nn.ModuleList([EncLayer() for _ in range(num_layers)])
        self.dec_layers = torch.nn.ModuleList([DecLayer() for _ in range(num_layers)])
        self.encs = [None] * num_layers
        self.decs = [None] * num_layers

    def forward(self):
        """Run the encoder stack, then the decoder stack on the encoder's output.

        Each decoder layer consumes the previous decoder layer's output. (Feeding
        decoder i from encoder i and returning only the last result would discard
        all but one decoder layer.)

        Returns:
            Output of the final decoder layer.
        """
        hidden = self.x
        for i, layer in enumerate(self.enc_layers):
            hidden = layer(hidden)
            self.encs[i] = hidden

        for i, layer in enumerate(self.dec_layers):
            hidden = layer(hidden)
            self.decs[i] = hidden

        return hidden

class Transformer:
    """Thin wrapper around EncoderDecoder; the encoder-decoder pass runs in __init__.

    Args:
        x: input tensor of shape (seq_len, d_model).
    """

    def __init__(self, x):
        self.x = x
        self.model = EncoderDecoder(self.x)
        # `enc_dec` stays the computed output tensor (not the module) so existing
        # readers of this attribute keep working.
        self.enc_dec = self.model()

    def fwd(self):
        """Softmax over the last (feature) dim of the encoder-decoder output.

        Each row of the result sums to 1.
        TODO: add the final linear projection to vocabulary size before this.
        """
        return torch.softmax(input=self.enc_dec, dim=-1)
        

In [44]:
# Smoke run: a uniform-random (seq_len, d_model) input through the full model;
# the displayed value is the output's shape.
x = torch.rand(seq_len, d_model)
trans = Transformer(x)
res = trans.fwd()
res.shape

torch.Size([1024, 512])