In [2]:
!pip install datasets
!pip install transformers
!pip install torch torchvision torchaudio



In [99]:
import torch
import math

num_layers = 1
#TODO: is seq_len supposed to equal d_model??
seq_len = 512
d_model = 512

class FFN(torch.nn.Module):
    def __init__(self):
        super().__init__()            
        
        self.d_ff = d_model*4
    
        self.W_1 = torch.nn.Linear(d_model, self.d_ff)
        self.b_1 = torch.nn.Parameter(torch.zeros(seq_len, self.d_ff))
        self.W_2 = torch.nn.Linear(self.d_ff, d_model)
        self.b_2 = torch.nn.Parameter(torch.zeros(seq_len, d_model))
        
    def forward(self, x):
        # assert x.shape == (seq_len, d_model)        
        #TODO: max(0, ...)
        out1 = self.W_1(x) + self.b_1
        out2 = self.W_2(out1) + self.b_2

        return out2
        # print(out2.shape)     

class MHA(torch.nn.Module):
    def __init__(self, h=8, has_mask=False):  
        super().__init__()          
        
        self.has_mask = has_mask
        self.d_k = 8
        # self.d_k = d_model // h

        self.d_v = self.d_k        
        
        #self.scale = 1/math.sqrt(self.d_k)
        self.scale = 1

        assert d_model == 512
        assert self.d_k == 8
        assert self.d_v == self.d_k

        self.W_Q = torch.nn.Parameter(torch.nn.init.xavier_uniform_(torch.zeros(size=(d_model, self.d_k))))
        self.W_K = torch.nn.Parameter(torch.nn.init.xavier_uniform_(torch.zeros(size=(d_model, self.d_k))))
        self.W_V = torch.nn.Parameter(torch.nn.init.xavier_uniform_(torch.zeros(size=(d_model, self.d_v))))
        self.W_O = torch.nn.Parameter(torch.nn.init.xavier_uniform_(torch.zeros(size=(h*self.d_v, d_model))))
                
    #TODO: shouldnt need separate x: x_k and x_v
    def attn(self, x_k, x_v):
        print(f"x: {x_k.shape}")
        print(f"W_Q: {self.W_Q}")
        Q = x_k @ self.W_Q
        K = x_k @ self.W_K
        V = x_v @ self.W_V

        mask = torch.ones(seq_len, seq_len) 
        # if self.has_mask:
        #     for i in range(seq_len):
        #         for j in range(seq_len):
        #             if i == j and i < seq_len-1:
        #                 mask[i+1:, :] = -float('inf')

        # TODO: sanity check - softmax along seq_len dim
        sm = torch.softmax(input=self.scale*mask*(Q@K.T), dim=0)
        print(f"Q: {Q.shape}")
        print(f"K.T: {K.T.shape}")        
        print(f"Q@K.T: {(Q@K.T).shape}")
        print(f"V: {V.shape}")
        head = sm @ V
        return head

    def forward(self, x_k, x_v):
        # assert x_k.shape == (seq_len, d_model)
        # assert x_v.shape == (seq_len, d_model)
        
        heads = torch.cat([self.attn(x_k, x_v) for _ in range(8)], dim=1)
        res = heads @ self.W_O
        return res

In [100]:
class Encoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.M = MHA(has_mask=False)
        self.F = FFN()
        
    def forward(self, x):
        mha = self.M(x_k=x, x_v=x)

        #sl1 = torch.nn.LayerNorm(x + mha)
        sl1 = x + mha        
        # print(sl1.shape)

        ffn = self.F(sl1)
        
        #sl2 = torch.nn.LayerNorm(sl1 + ffn)
        sl2 = sl1 + ffn
        # print(sl2.shape)
        
        return sl2

class Decoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.M1 = MHA(has_mask=True)
        self.M2 = MHA(has_mask=False)
        self.F = FFN()
    
    def forward(self, x):
        #TODO: add LayerNorm after each sl{i}
        mmha = self.M1(x_k=x, x_v=x)

        sl1 = x + mmha

        mha = self.M2(x_k=x, x_v=sl1)

        sl2 = sl1 + mha

        ffn = self.F(sl2)

        sl3 = sl2 + ffn

        return sl3

class EncoderDecoder(torch.nn.Module):      
    def __init__(self):
        super().__init__()        
        self.encs = torch.nn.ModuleList([Encoder() for _ in range(num_layers)])
        self.decs = torch.nn.ModuleList([Decoder() for _ in range(num_layers)])

    def forward(self, x):
        # assert x.shape == (seq_len, d_model)        
        for i in range(num_layers):
            if i == 0:
                self.encs[i] = self.encs[i](x)
            else:
                self.encs[i] = self.encs[i](self.encs[i-1])
                
            self.decs[i] = self.decs[i](self.encs[i])

        return self.decs[num_layers-1]      

In [101]:
class Transformer(torch.nn.Module):
    def __init__(self):
       super().__init__()
       self.enc_dec = EncoderDecoder()                  
        
    def forward(self, x):
       #TODO: add linear layer - torch.nn.Sequential
       return torch.softmax(input=self.enc_dec(x), dim=0) #TODO: once again this is along seq_len

In [102]:
model = Transformer()
[param for param in model.parameters()]
# model(torch.Tensor(seq_len, d_model))

[Parameter containing:
 tensor([[-0.1036, -0.0180, -0.0282,  ...,  0.0656,  0.0069, -0.0363],
         [-0.0805,  0.0056, -0.0150,  ...,  0.0313, -0.0909,  0.1062],
         [ 0.0084,  0.0085,  0.0661,  ...,  0.0457,  0.0272, -0.1020],
         ...,
         [-0.0434,  0.0085,  0.0411,  ..., -0.0610,  0.0642,  0.0168],
         [-0.0841,  0.0927, -0.0093,  ..., -0.0941, -0.0790,  0.0642],
         [-0.0241,  0.0683,  0.0180,  ...,  0.1024,  0.0413,  0.0603]],
        requires_grad=True),
 Parameter containing:
 tensor([[ 0.0566, -0.0924,  0.0604,  ...,  0.0898,  0.0738, -0.0218],
         [ 0.0376,  0.1005,  0.0214,  ..., -0.0261, -0.0470,  0.0928],
         [-0.1011,  0.0783, -0.0246,  ..., -0.0312, -0.0458,  0.0581],
         ...,
         [-0.1035, -0.0899,  0.0718,  ..., -0.0428,  0.0426, -0.0373],
         [ 0.0787,  0.0682, -0.0895,  ...,  0.1059, -0.0511,  0.0518],
         [ 0.0091,  0.0351, -0.0589,  ...,  0.0681, -0.0370,  0.1045]],
        requires_grad=True),
 Parameter con

In [103]:
from datasets import load_dataset
from transformers import AutoTokenizer

#NOTE: they used train split, wiht ~3.5M examples
ds = load_dataset("wmt/wmt14", "de-en", split='validation')
training_data = ds['translation']

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

um_epochs = 100
criterion = torch.nn.CrossEntropyLoss()

optim = torch.optim.Adam(params=model.parameters(), betas=(0.9,0.98), eps=10E-9)

for epoch in range(num_epochs):
    for sample in training_data[:3]:
        optim.zero_grad()

        real = torch.tensor(tokenizer.encode(sample['de'], padding='max_length', max_length=seq_len)).float()
        fake = torch.tensor(tokenizer.encode(sample['en'], padding='max_length', max_length=seq_len)).float()
        # print(f"real: {real}")
        
        outputs = model(fake)

        loss = criterion(real, fake)

        loss.backward()        
        optim.step()

x: torch.Size([512])
W_Q: Parameter containing:
tensor([[-0.1036, -0.0180, -0.0282,  ...,  0.0656,  0.0069, -0.0363],
        [-0.0805,  0.0056, -0.0150,  ...,  0.0313, -0.0909,  0.1062],
        [ 0.0084,  0.0085,  0.0661,  ...,  0.0457,  0.0272, -0.1020],
        ...,
        [-0.0434,  0.0085,  0.0411,  ..., -0.0610,  0.0642,  0.0168],
        [-0.0841,  0.0927, -0.0093,  ..., -0.0941, -0.0790,  0.0642],
        [-0.0241,  0.0683,  0.0180,  ...,  0.1024,  0.0413,  0.0603]],
       requires_grad=True)
Q: torch.Size([8])
K.T: torch.Size([8])
Q@K.T: torch.Size([])
V: torch.Size([8])


RuntimeError: size mismatch, got input (512), mat (512x512), vec (8)