In [157]:
!pip install datasets
!pip install transformers
!pip install torch torchvision torchaudio

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [158]:
from datasets import load_dataset
#NOTE: they used train split, wiht ~3.5M examples
ds = load_dataset("wmt/wmt14", "de-en", split='train')
training_data = ds['translation']

In [159]:
import torch
import math

num_layers = 6
#TODO: is seq_len supposed to equal d_model??
seq_len = 512
d_model = 512

class FFN(torch.nn.Module):
    def __init__(self):
        super().__init__()            
        
        self.d_ff = d_model*4
        self.W_1 = torch.nn.Linear(d_model, self.d_ff)
        self.relu = torch.nn.ReLU()
        self.W_2 = torch.nn.Linear(self.d_ff, d_model)
        
    def forward(self, x):
        # assert x.shape == (seq_len, d_model)    
        out = self.W_1(x)
        out = self.relu(out)
        out = self.W_2(out)

        return out

class MHA(torch.nn.Module):
    def __init__(self, h=8, has_mask=False):  
        super().__init__()          
        
        self.has_mask = has_mask
        self.d_k = d_model // h

        self.d_v = self.d_k        
        
        self.scale = 1/math.sqrt(self.d_k)

        assert d_model == 512
        assert self.d_k == 64
        assert self.d_v == self.d_k

        self.W_Q = torch.nn.Parameter(torch.nn.init.xavier_uniform_(torch.zeros(size=(d_model, self.d_k))))
        self.W_K = torch.nn.Parameter(torch.nn.init.xavier_uniform_(torch.zeros(size=(d_model, self.d_k))))
        self.W_V = torch.nn.Parameter(torch.nn.init.xavier_uniform_(torch.zeros(size=(d_model, self.d_v))))
        self.W_O = torch.nn.Parameter(torch.nn.init.xavier_uniform_(torch.zeros(size=(h*self.d_v, d_model))))
                
    #TODO: shouldnt need separate x: x_k and x_v
    def attn(self, x_k, x_v):
        Q = x_k @ self.W_Q
        K = x_k @ self.W_K
        V = x_v @ self.W_V

        mask = torch.ones(seq_len, seq_len) 
        if self.has_mask:
            mask = torch.tril(mask)
        
        # TODO: sanity check - softmax along seq_len dim
        sm = torch.softmax(input=self.scale*mask*(Q@K.T), dim=0)
        head = sm @ V
        return head

    def forward(self, x_k, x_v):
        # assert x_k.shape == (seq_len, d_model)
        # assert x_v.shape == (seq_len, d_model)
        
        heads = torch.cat([self.attn(x_k, x_v) for _ in range(8)], dim=1)
        res = heads @ self.W_O
        return res

In [160]:
class Encoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.M = MHA(has_mask=False)
        self.F = FFN()
        
    def forward(self, x):
        mha = self.M(x_k=x, x_v=x)

        #TODO: add layernorms
        
        #sl1 = torch.nn.LayerNorm(x + mha)
        sl1 = x + mha        
        # print(sl1.shape)

        ffn = self.F(sl1)
        
        #sl2 = torch.nn.LayerNorm(sl1 + ffn)
        sl2 = sl1 + ffn
        # print(sl2.shape)

        return sl2

class Decoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.M1 = MHA(has_mask=True)
        self.M2 = MHA(has_mask=False)
        self.F = FFN()
    
    def forward(self, x, enc_out):
        #TODO: add LayerNorm after each sl{i}
        mmha = self.M1(x_k=x, x_v=x)

        sl1 = x + mmha

        # print("sl1", sl1)

        mha = self.M2(x_k=enc_out, x_v=sl1)

        sl2 = sl1 + mha

        # print("sl2", sl2)        

        ffn = self.F(sl2)

        sl3 = sl2 + ffn

        return sl3

class EncoderDecoder(torch.nn.Module):      
    def __init__(self):
        super().__init__()        
        self.encs = torch.nn.ModuleList([Encoder() for _ in range(num_layers)])
        self.decs = torch.nn.ModuleList([Decoder() for _ in range(num_layers)])

    def forward(self, x):
        # assert x.shape == (seq_len, d_model)

        enc_out = x
        for enc in self.encs:
            enc_out = enc(enc_out)

        # print("Enc_out", enc_out)
        dec_out = enc_out

        for dec in self.decs:
            dec_out = dec(enc_out, dec_out)
            
        # print("Dec_out", dec_out)
        return dec_out   

In [161]:
vocab_dim = 37000 #50257

class Embed(torch.nn.Module):
    def __init__(self, vocab_dim, d_model):
        super().__init__()
        self.emb = torch.nn.Embedding(num_embeddings=vocab_dim, embedding_dim=d_model)#, padding_idx=50257)

    def forward(self, x):
        res = self.emb(x)
        return res

In [162]:
class Transformer(torch.nn.Module):
    def __init__(self):
       super().__init__()
       self.input_embed = Embed(vocab_dim=vocab_dim, d_model=d_model)
       self.enc_dec = EncoderDecoder()       
       self.linear = torch.nn.Linear(vocab_dim, seq_len) 
        
    def forward(self, x):
       #TODO: torch.nn.Sequential
       print(f"x: {x.shape}")

       out = self.input_embed(x)
       #TODO: add PosEmbed
       print(f"out: {out.shape}")                
       out = self.enc_dec(out) 
       print(f"out: {out.shape}")      
       #TODO: figure out what linear is supposed to do
       #out = self.linear(out)
       #print(f"out: {out.shape}")

       #TODO: once again this is along seq_len
       softmax = torch.softmax(input=out, dim=0) 
       print(f"softmax: {softmax.shape}")        
       return softmax

In [163]:
from transformers import AutoTokenizer

model = Transformer()

tokenizer = AutoTokenizer.from_pretrained("gpt2")

#TODO: figure out tokenizer
# tokenizer.add_special_tokens({'pad_token': '[PAD]'})
tokenizer.add_special_tokens({'pad_token': '?'})

num_epochs = 10
criterion = torch.nn.CrossEntropyLoss()

optim = torch.optim.Adam(params=model.parameters(), betas=(0.9,0.98), eps=10E-9)

for epoch in range(num_epochs):
    for sample in training_data[:1]:
        optim.zero_grad()

        tokenized_de = tokenizer.encode(sample['de'], padding='max_length', max_length=seq_len)
        tokenized_en = tokenizer.encode(sample['en'], padding='max_length', max_length=seq_len)
        
        targets = torch.tensor(tokenized_de)
        inputs = torch.tensor(tokenized_en)


        outputs = model(inputs)
        
        print(targets[:5])        
        print(outputs[:5])

        print(f"targets: {targets.shape}")
        tgt_embed = Embed(vocab_dim=vocab_dim, d_model=d_model)
        target_embeddings = tgt_embed(targets)
        print(f"target embeddings: {target_embeddings.shape}")
        #TODO: add PosEmbed        
        print(f"outputs: {outputs.shape}")        

        loss = criterion(target_embeddings, outputs)
        print(f"loss: {loss}")

        loss.backward()        
        optim.step()

x: torch.Size([512])
out: torch.Size([512, 512])
out: torch.Size([512, 512])
softmax: torch.Size([512, 512])
tensor([   54,   798,   263,   559, 22184])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 1., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 1.]], grad_fn=<SliceBackward0>)
targets: torch.Size([512])
target embeddings: torch.Size([512, 512])
outputs: torch.Size([512, 512])
loss: 6.7332563400268555
x: torch.Size([512])
out: torch.Size([512, 512])
out: torch.Size([512, 512])
softmax: torch.Size([512, 512])
tensor([   54,   798,   263,   559, 22184])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], grad_fn=<SliceBackward0>)
targets: torch.Size([512])
target embeddings: torch.Size([512, 512])
outputs: torch.Size([512, 512])
loss: 6.80129