In [3]:
!pip install datasets
!pip install transformers
!pip install torch torchvision torchaudio



In [None]:
from datasets import load_dataset
#NOTE: they used train split, wiht ~3.5M examples
ds = load_dataset("wmt/wmt14", "de-en", split='train')
training_data = ds['translation']

In [4]:
import torch
import math

num_layers = 1 #TODO: change to 6
#TODO: is seq_len supposed to equal d_model??
seq_len = 512
d_model = 512

class FFN(torch.nn.Module):
    def __init__(self):
        super().__init__()            
        
        self.d_ff = d_model*4

        #TODO: can nn.Linear be wrapped in Parameter - probably
        self.W_1 = torch.nn.Linear(d_model, self.d_ff)
        self.b_1 = torch.nn.Parameter(torch.zeros(seq_len, self.d_ff))
        self.W_2 = torch.nn.Linear(self.d_ff, d_model)
        self.b_2 = torch.nn.Parameter(torch.zeros(seq_len, d_model))
        
    def forward(self, x):
        # assert x.shape == (seq_len, d_model)        
        #TODO: max(0, ...)
        out1 = self.W_1(x) + self.b_1
        out2 = self.W_2(out1) + self.b_2

        return out2
        # print(out2.shape)     

class MHA(torch.nn.Module):
    def __init__(self, h=8, has_mask=False):  
        super().__init__()          
        
        self.has_mask = has_mask
        self.d_k = d_model // h

        self.d_v = self.d_k        
        
        #self.scale = 1/math.sqrt(self.d_k)
        self.scale = 1

        assert d_model == 512
        assert self.d_k == 64
        assert self.d_v == self.d_k

        self.W_Q = torch.nn.Parameter(torch.nn.init.xavier_uniform_(torch.zeros(size=(d_model, self.d_k))))
        self.W_K = torch.nn.Parameter(torch.nn.init.xavier_uniform_(torch.zeros(size=(d_model, self.d_k))))
        self.W_V = torch.nn.Parameter(torch.nn.init.xavier_uniform_(torch.zeros(size=(d_model, self.d_v))))
        self.W_O = torch.nn.Parameter(torch.nn.init.xavier_uniform_(torch.zeros(size=(h*self.d_v, d_model))))
                
    #TODO: shouldnt need separate x: x_k and x_v
    def attn(self, x_k, x_v):
        # print(f"x: {x_k.shape}")
        # print(f"W_Q: {self.W_Q}")
        Q = x_k @ self.W_Q
        K = x_k @ self.W_K
        V = x_v @ self.W_V

        mask = torch.ones(seq_len, seq_len) 
        if self.has_mask:
            for i in range(seq_len):
                for j in range(seq_len):
                    if i == j and i < seq_len-1:
                        mask[i+1:, :] = -float('inf')

        # TODO: sanity check - softmax along seq_len dim
        sm = torch.softmax(input=self.scale*mask*(Q@K.T), dim=0)
        # print(f"Q: {Q.shape}")
        # print(f"K.T: {K.T.shape}")        
        # print(f"Q@K.T: {(Q@K.T).shape}")
        # print(f"V: {V.shape}")
        head = sm @ V
        return head

    def forward(self, x_k, x_v):
        # assert x_k.shape == (seq_len, d_model)
        # assert x_v.shape == (seq_len, d_model)
        
        heads = torch.cat([self.attn(x_k, x_v) for _ in range(8)], dim=1)
        res = heads @ self.W_O
        return res

In [11]:
class Encoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.M = MHA(has_mask=False)
        self.F = FFN()
        
    def forward(self, x):
        mha = self.M(x_k=x, x_v=x)

        #sl1 = torch.nn.LayerNorm(x + mha)
        sl1 = x + mha        
        # print(sl1.shape)

        ffn = self.F(sl1)
        
        #sl2 = torch.nn.LayerNorm(sl1 + ffn)
        sl2 = sl1 + ffn
        # print(sl2.shape)
        
        return sl2

class Decoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.M1 = MHA(has_mask=True)
        self.M2 = MHA(has_mask=False)
        self.F = FFN()
    
    def forward(self, x):
        #TODO: add LayerNorm after each sl{i}
        mmha = self.M1(x_k=x, x_v=x)

        sl1 = x + mmha

        mha = self.M2(x_k=x, x_v=sl1)

        sl2 = sl1 + mha

        ffn = self.F(sl2)

        sl3 = sl2 + ffn

        return sl3

class EncoderDecoder(torch.nn.Module):      
    def __init__(self):
        super().__init__()        
        self.encs = torch.nn.ModuleList([Encoder() for _ in range(num_layers)])
        self.decs = torch.nn.ModuleList([Decoder() for _ in range(num_layers)])

    def forward(self, x):
        # assert x.shape == (seq_len, d_model)

        enc_out = x
        for enc in self.encs:
            enc_out = enc(enc_out)

        dec_out = enc_out

        for dec in self.decs:
            dec_out = dec(dec_out)
            
        return dec_out   

In [12]:
vocab_dim = 37000 #50257

class Embed(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(num_embeddings=vocab_dim, embedding_dim=d_model)#, padding_idx=50257)

    def forward(self, x):
        res = self.emb(x)
        print(f"In emb: {res}")
        return res

In [13]:
class Transformer(torch.nn.Module):
    def __init__(self):
       super().__init__()
       self.input_embed = Embed()
       self.enc_dec = EncoderDecoder()       
       self.linear = torch.nn.Linear(seq_len, vocab_dim)
        
    def forward(self, x):
       #TODO: add linear layer - torch.nn.Sequential
       print(f"x: {x.shape}")

       out = self.input_embed(x)
       out = self.enc_dec(out) 
       out = self.linear(out) 
        
       return out

       #TODO: once again this is along seq_len
       # final_softmax = torch.softmax(input=ed, dim=0)
       # return final_softmax

In [14]:
from transformers import AutoTokenizer

model = Transformer()

tokenizer = AutoTokenizer.from_pretrained("gpt2")

#TODO: figure out tokenizer
# tokenizer.add_special_tokens({'pad_token': '[PAD]'})
tokenizer.add_special_tokens({'pad_token': '?'})

num_epochs = 1
criterion = torch.nn.CrossEntropyLoss()

optim = torch.optim.Adam(params=model.parameters(), betas=(0.9,0.98), eps=10E-9)

for epoch in range(num_epochs):
    for sample in training_data[:1]:
        optim.zero_grad()

        tokenized_de = tokenizer.encode(sample['de'], padding='max_length', max_length=seq_len)
        tokenized_en = tokenizer.encode(sample['en'], padding='max_length', max_length=seq_len)

        # print(tokenized_en)
        
        targets = torch.tensor(tokenized_de)
        inputs = torch.tensor(tokenized_en)
        
        outputs = model(inputs)

        loss = criterion(targets, inputs)

        loss.backward()        
        optim.step()



x: torch.Size([512])
In emb: tensor([[-1.2018,  0.0822, -1.0743,  ..., -0.0747,  1.2666,  0.4669],
        [-0.6373,  0.7430, -0.5468,  ..., -1.6726, -0.7601,  0.3060],
        [-1.1435,  0.5712, -0.6182,  ..., -0.0038, -0.5941, -1.3934],
        ...,
        [-0.2326,  0.1957, -0.5240,  ..., -0.4980, -2.1349,  0.1179],
        [-0.2326,  0.1957, -0.5240,  ..., -0.4980, -2.1349,  0.1179],
        [-0.2326,  0.1957, -0.5240,  ..., -0.4980, -2.1349,  0.1179]],
       grad_fn=<EmbeddingBackward0>)
emb: torch.Size([512, 512])
enc-dec: tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], grad_fn=<AddBackward0>)


RuntimeError: Expected floating point type for target with class probabilities, got Long