In [765]:
import torch

from tqdm import trange

import warnings
warnings.filterwarnings("ignore", category=UserWarning) 

In [777]:
class Transformer():
    
    def __init__(self, d_model, vocab, n_heads, n_stack):
        if(d_model % n_heads != 0):
            raise Exception("Model dimensions must be divisible by number of attention heads")
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_stack = n_stack
        self.vocab = sorted(set(vocab))
        
        self.d_head = int(self.d_model / self.n_heads)
        self.d_head_t = torch.tensor(self.d_head)
        self.ff_hidden_size = self.d_model*4
        self.vocab_size = len(self.vocab)
        
        self.v_to_i = {ch:idx for (idx, ch) in enumerate(self.vocab)}
        self.i_to_v = {idx:ch for (idx, ch) in enumerate(self.vocab)}
        
        self.encoder_params = [self.__generate_encoder_params() for _ in range(self.n_stack)]
        self.decoder_params = [self.__generate_decoder_params() for _ in range(self.n_stack)]
        self.embeddings = torch.randn((self.vocab_size, d_model), dtype=torch.float)
        self.projection = torch.randn((d_model, self.vocab_size), dtype=torch.float)
        self.params = [
            *[p for param_group in self.encoder_params for p in param_group],
            *[p for param_group in self.decoder_params for p in param_group],
            self.embeddings,
            self.projection
        ]
        for p in self.params:
            p.requires_grad = True
        self.n_params = sum([p.nelement() for p in self.params])
        print(f"Transformer created with {self.n_params} parameters")
        
    def predict_next_token(self, decoder_in):
        decoder_out = self.decode(decoder_in)
        last_token = decoder_out[-1]
        
        pred_out = (last_token @ self.projection).softmax(dim=0)
        next_idx = torch.argmax(pred_out)
        return self.i_to_v[next_idx.item()]    
    
    def predict_all(self, decoder_in):
        decoder_out = self.decode(decoder_in)
        pred_out = (decoder_out @ self.projection).softmax(dim=1)
        return pred_out
        
    def encode(self, inp):
        encoder_outs = torch.empty((self.n_stack, inp.shape[0], inp.shape[1]))
        for (idx, encoder_param_group) in enumerate(self.encoder_params):
            [mha_params, ln_1_params, ff_params, ln_2_params] = self.__interpret_encoder_params(encoder_param_group)
            
            # multi-head attention
            mha = self.__multi_head_attention(inp, inp, inp, mha_params, False)
            
            # add & norm
            mha_out = self.__layer_norm(mha + inp, ln_1_params)
            
            # feed-forward
            ff = self.__feed_forward(mha_out, ff_params)
            
            # add & norm
            encoder_outs[idx] = self.__layer_norm(ff + mha_out, ln_2_params)
            
        self.encoder_outs = encoder_outs
        return self.encoder_outs
    
    def decode(self, inp):
        if(self.n_stack > self.encoder_outs.shape[0]):
            raise Exception("Encode before decoding")
            
        stack_in = inp
        for (idx, decoder_param_group) in enumerate(self.decoder_params):
            [ 
                masked_mha_params,
                ln_1_params,
                mha_params,
                ln_2_params,
                ff_params,
                ln_3_params
            ] = self.__interpret_decoder_params(decoder_param_group)
            
            # masked multi-head attention
            masked_mha = self.__multi_head_attention(stack_in, stack_in, stack_in, masked_mha_params, True)
            
            # add & norm
            masked_mha_out= self.__layer_norm(masked_mha + stack_in, ln_1_params)
            
            # multi-head attention (Q - from masked_mha_out, K - from encoder, V - from encoder)
            mha = self.__multi_head_attention(masked_mha_out, self.encoder_outs[idx], self.encoder_outs[idx], mha_params, False)
            
            # add & norm
            mha_out = self.__layer_norm(mha + masked_mha_out, ln_2_params)
            
            # feed-forward
            ff = self.__feed_forward(mha_out, ff_params)
            
            # add & norm
            stack_in = self.__layer_norm(ff + mha_out, ln_3_params)
        return stack_in
            
    def __feed_forward(self, inp, ff_params):
        [W1, b1, W2, b2] = ff_params
        return (((inp @ W1) + b1).relu() @ W2) + b2
    
    def __layer_norm(self, inp, ln_params):
        [scale, bias] = ln_params
        mean = inp.mean(dim=1).unsqueeze(dim=1)
        std = inp.std(dim=1).unsqueeze(dim=1)

        norm = (inp - mean)/std

        return (norm * scale) + bias
    
    def __multi_head_attention(self, q_inp, k_inp, v_inp, params, masked=False):
        W_o = params[0]
        W_qkv_groups = params[1:]
        
        if(len(W_qkv_groups) != self.n_heads):
            raise Exception(f"Params doesn't match num of heads. Expected: {self.n_heads}. Found: {len(W_qkv_groups)}")
            
        multi_head_out = torch.tensor([])
        for head_idx in range(self.n_heads):
            [W_q, W_k, W_v] = W_qkv_groups[head_idx]        
            head_out = self.__single_head_attention(q_inp, k_inp, v_inp, W_q, W_k, W_v, masked)
            multi_head_out = torch.cat([multi_head_out, head_out], dim=1)
            
        return multi_head_out @ W_o
    
    def __single_head_attention(self, q_inp, k_inp, v_inp, W_q, W_k, W_v, masked):
        # linear layers
        Q = q_inp @ W_q
        K = k_inp @ W_k
        V = v_inp @ W_v
        
        # scaled dot-product attention
        mat_mul = Q @ K.T
        if(masked):
            mask = torch.tril(torch.ones(mat_mul.shape))
            mat_mul = torch.where(mask == 0, float("-inf"), mat_mul)
            
        scale = mat_mul/torch.sqrt(self.d_head_t)
        softmax = torch.nn.functional.softmax(scale, dim=None)
        return softmax @ V
        
    def embed_seq(self, seq):
        one_hot = torch.nn.functional.one_hot(torch.tensor([self.v_to_i[v] for v in seq]), self.vocab_size).float()
        embedded = one_hot @ self.embeddings
        return torch.stack([emb + self.__get_positional_encoding(pos) for (pos, emb) in enumerate(embedded)])
        
    def __get_positional_encoding(self, pos):
        v_dims = torch.arange(0, (self.d_model+1)/2, step=1, dtype=torch.float)

        def get_pe(v):
            exponent = (2*v)/self.d_model
            return pos/torch.pow(10000,exponent)

        pre_sinusoid = get_pe(v_dims)
        pe_even = pre_sinusoid.sin()
        pe_odd = pre_sinusoid.cos() 

        pe = torch.stack([pe_even, pe_odd])

        return pe.mT.reshape(-1)[:self.d_model]
    
    def __generate_encoder_params(self):
        return [
            ### attention ###
            # linear
            torch.randn((self.d_model, self.d_model), dtype=torch.float),
            # heads
            *[torch.randn((self.d_model, self.d_head), dtype=torch.float) for _ in range(3) for _ in range(self.n_heads)],
            ### layer norm 1 ###
            torch.randn((self.d_model), dtype=torch.float),
            torch.randn((self.d_model), dtype=torch.float),
            ### feed forward ###
            # W1
            torch.randn((self.d_model, self.ff_hidden_size), dtype=torch.float),
            # B1
            torch.randn((self.ff_hidden_size), dtype=torch.float),
            # W2
            torch.randn((self.ff_hidden_size, self.d_model), dtype=torch.float),
            # B4
            torch.randn((self.d_model), dtype=torch.float),
            ### layer norm 2 ###
            torch.randn((self.d_model), dtype=torch.float),
            torch.randn((self.d_model), dtype=torch.float),
        ]
    
    def __interpret_encoder_params(self, param_group):
        attention_params = param_group[:(3*self.n_heads)+1]
        param_group = param_group[(3*self.n_heads)+1:]
        
        grouped_heads = [attention_params[i+1: i+4] for i in range(0, 3*self.n_heads, 3)]
        
        layer_norm_1_params = param_group[:2]
        param_group = param_group[2:]
        
        ff_params = param_group[:4]
        param_group = param_group[4:]
        
        layer_norm_2_params = param_group[:2]
        param_group = param_group[2:]
        
        return [
            # attention group
            [
                attention_params[0],
                *grouped_heads
            ],
            # layer norm 1 group
            layer_norm_1_params,
            # feed forward group
            ff_params,
            # layer norm 2 group
            layer_norm_2_params
        ]
    
    def __generate_decoder_params(self):
        return [
            ### masked attention ###
            # linear
            torch.randn((self.d_model, self.d_model), dtype=torch.float),
            # heads
            *[torch.randn((self.d_model, self.d_head), dtype=torch.float) for _ in range(3) for _ in range(self.n_heads)],
            ### layer norm 1 ###
            torch.randn((self.d_model), dtype=torch.float),
            torch.randn((self.d_model), dtype=torch.float),
            ### attention ###
            # linear
            torch.randn((self.d_model, self.d_model), dtype=torch.float),
            # heads
            *[torch.randn((self.d_model, self.d_head), dtype=torch.float) for _ in range(3) for _ in range(self.n_heads)],
            ### layer norm 2 ###
            torch.randn((self.d_model), dtype=torch.float),
            torch.randn((self.d_model), dtype=torch.float),
            ### feed forward ###
            # W1
            torch.randn((self.d_model, self.ff_hidden_size), dtype=torch.float),
            # B1
            torch.randn((self.ff_hidden_size), dtype=torch.float),
            # W2
            torch.randn((self.ff_hidden_size, self.d_model), dtype=torch.float),
            # B4
            torch.randn((self.d_model), dtype=torch.float),
            ### layer norm 3 ###
            torch.randn((self.d_model), dtype=torch.float),
            torch.randn((self.d_model), dtype=torch.float),
        ]
    
    def __interpret_decoder_params(self, param_group):
        masked_attention_params = param_group[:(3*self.n_heads)+1]
        param_group = param_group[(3*self.n_heads)+1:]
        
        grouped_masked_heads = [masked_attention_params[i+1: i+4] for i in range(0, 3*self.n_heads, 3)]
        
        layer_norm_1_params = param_group[:2]
        param_group = param_group[2:]
        
        attention_params = param_group[:(3*self.n_heads)+1]
        param_group = param_group[(3*self.n_heads)+1:]
        
        grouped_heads = [attention_params[i+1: i+4] for i in range(0, 3*self.n_heads, 3)]
        
        layer_norm_2_params = param_group[:2]
        param_group = param_group[2:]
                
        ff_params = param_group[:4]
        param_group = param_group[4:]
        
        layer_norm_3_params = param_group[:2]
        param_group = param_group[2:]
        
        
        return [
            # masked attention group
            [
                masked_attention_params[0],
                *grouped_masked_heads
            ],
            # layer norm 1 group
            layer_norm_1_params,
            # attention group
            [
                attention_params[0],
                *grouped_heads
            ],
            # layer norm 2 group
            layer_norm_2_params,
            # feed forward group
            ff_params,
            # layer norm 3 group
            layer_norm_3_params
        ]
        
    

In [837]:
raw_names = open("names.txt").read().split("\n")
names = [list(n) for n in raw_names]

open_token = "<o>"
close_token = "<c>"
vocab = sorted(set([open_token, close_token] + [ch for name in names for ch in name]))

In [838]:
dataset_size = 100
X = names[:dataset_size]
Y = [[open_token] + n[::-1] + [close_token] for n in names[:dataset_size]]

In [844]:
transformer = Transformer(2, vocab, 1, 1)

Transformer created with 264 parameters


In [845]:
X_encode_emb = [transformer.embed_seq(x) for x in X]
X_decode_emb = [transformer.embed_seq(y) for y in Y]

def one_hot_target(seq, transformer):
    one_hot = torch.nn.functional.one_hot(torch.tensor([transformer.v_to_i[v] for v in seq]), transformer.vocab_size).float()
    return one_hot

Y_targets = [one_hot_target(y, transformer)[1:] for y in Y]

In [855]:
cross_entropy_loss = torch.nn.CrossEntropyLoss()
optimiser = torch.optim.Adamax(transformer.params, lr=0.1)

In [859]:
with trange(1000) as pbar:
    for it in pbar:
        batch_size = 16
        ix = torch.randint(0, dataset_size, (batch_size,))

        optimiser.zero_grad(set_to_none=True)
        
        loss = 0
        for idx in ix:
            x_encode = transformer.embed_seq(X[idx])
            x_decode = transformer.embed_seq(Y[idx])

            transformer.encode(x_encode)

            out = transformer.predict_all(x_decode)
            loss += cross_entropy_loss(out[:-1], Y_targets[idx])

        loss /= batch_size

        pbar.set_description(f"Loss: {loss.item():.4f}")

        loss.backward()

        optimiser.step()   

Loss: 3.2508:   5%|█▎                         | 48/1000 [00:01<00:38, 24.57it/s]


KeyboardInterrupt: 

In [858]:
for p in transformer.params:
    print(p.grad.mean())

tensor(-1.1535e-32)
tensor(-8.9768e-31)
tensor(1.3967e-31)
tensor(6.0826e-32)
tensor(-4.4275e-25)
tensor(5.7013e-25)
tensor(4.8852e-26)
tensor(2.9182e-26)
tensor(5.5932e-25)
tensor(5.0250e-25)
tensor(-7.6124e-19)
tensor(5.2250e-18)
tensor(-6.8561e-22)
tensor(-6.0481e-22)
tensor(1.0923e-22)
tensor(4.3129e-22)
tensor(-2.5657e-18)
tensor(-1.6265e-18)
tensor(1.6910e-18)
tensor(-4.3059e-19)
tensor(-4.0898e-19)
tensor(-6.3976e-19)
tensor(3.3482e-11)
tensor(4.5965e-11)
tensor(-3.3333e-12)
tensor(-1.4998e-12)
tensor(1.4114e-11)
tensor(1.2960e-11)
tensor(-6.6900e-06)
tensor(-0.0002)
tensor(6.0268e-24)
tensor(7.0575e-11)


In [800]:
transformer.encode(transformer.embed_seq(names[0]))
print(names[0])
ch = open_token
word = [ch]
while(ch != close_token and len(word) < 100):
    ch = transformer.predict_next_token(transformer.embed_seq([open_token]))
    word += [ch]

print(word[1:-1])

['a']
['a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a']
