In [27]:
import torch
import torch.nn as nn

**Attention mechanism**
1. Produce queries ,keys, values,
2. Calculate `attention_scores`, mask those above diagonal to prevent cheating
3. Put `attention_scores` into softmax function, then do weighted sum of values 

In [28]:
class SimpleAttention(nn.Module):
    def __init__(self, cfg, qkv_bias=False):
        super().__init__()
        d_in = cfg['n_embd']
        d_out = cfg['n_embd']
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    
    def forward(self, x):
        # Step1: produce queries, keys, values
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Step2: calculate attention_scores
        attn_scores = queries @ keys.transpose(-1, -2)
        mask = torch.tril(torch.ones_like(attn_scores), diagonal=0)
        attn_scores.masked_fill_(mask==0, -torch.inf)

        # Step3
        weight_scores = torch.softmax(attn_scores, dim=-1)
        attn_out = weight_scores @ values

        return attn_out

In [29]:
GPT2_CONFIG = {
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_layer": 12,
  "n_positions": 1024,
  "resid_pdrop": 0.1,
  "summary_activation": None,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": True,
  "summary_type": "cls_index",
  "summary_use_proj": True,
  "task_specific_params": {
    "text-generation": {
      "do_sample": True,
      "max_length": 50
    }
  },
  "vocab_size": 50257
}

In [30]:
import tiktoken
tokenizer = tiktoken.get_encoding('gpt2')
print(tokenizer.n_vocab)

50257


In [31]:
embedding_layer = nn.Embedding(GPT2_CONFIG['vocab_size'], GPT2_CONFIG['n_embd'])
pos_embedding_layer = nn.Embedding(GPT2_CONFIG['n_positions'], GPT2_CONFIG['n_embd'])
txt = 'Do you know who am i?'
tokens =  tokenizer.encode(txt)
token_len = len(tokens)
tokens = torch.tensor(tokens) # convert to pytorch tensor for compapility
embeded_vec = embedding_layer(tokens)
pos_vec = embedding_layer(torch.arange(token_len))
input_vec = embeded_vec + pos_vec
print(input_vec.shape)
print(input_vec)

torch.Size([7, 768])
tensor([[ 3.1114,  0.4252, -1.0459,  ..., -1.6001,  0.7646,  0.4166],
        [ 0.7925,  1.1849, -0.0425,  ...,  1.8415, -0.8759, -3.4306],
        [-0.5999,  1.8658, -0.6810,  ..., -0.0278, -0.5934, -2.6497],
        ...,
        [ 0.6929, -1.1271, -1.5086,  ...,  0.0370,  1.1239, -1.1810],
        [ 0.0143,  2.1574, -0.2989,  ..., -0.2977,  1.8317,  0.3929],
        [-0.2165, -1.9488,  0.0999,  ...,  0.1306, -3.2307, -1.2602]],
       grad_fn=<AddBackward0>)


In [32]:
sa = SimpleAttention(GPT2_CONFIG)
outputs = sa(input_vec)
print(input_vec.shape)
print(input_vec)

torch.Size([7, 768])
tensor([[ 3.1114,  0.4252, -1.0459,  ..., -1.6001,  0.7646,  0.4166],
        [ 0.7925,  1.1849, -0.0425,  ...,  1.8415, -0.8759, -3.4306],
        [-0.5999,  1.8658, -0.6810,  ..., -0.0278, -0.5934, -2.6497],
        ...,
        [ 0.6929, -1.1271, -1.5086,  ...,  0.0370,  1.1239, -1.1810],
        [ 0.0143,  2.1574, -0.2989,  ..., -0.2977,  1.8317,  0.3929],
        [-0.2165, -1.9488,  0.0999,  ...,  0.1306, -3.2307, -1.2602]],
       grad_fn=<AddBackward0>)


**Multihead attention**
Above one is good, for multihead mechanism, jsut put it into a list, but computation is not efficient, for efficiency we can parallize it.  
    *Instruction: Same as above*

In [33]:
class MultiHeadAttention(nn.Module):
    def __init__(self, cfg, qkv_bias=False):
        super().__init__()
        d_in = cfg['n_embd']
        d_out = cfg['n_embd']
        self.n_head = cfg['n_head']
        
        assert d_out % self.n_head == 0

        self.head_dim = d_out // self.n_head
        self.W_keys = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_queries = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_values = nn.Linear(d_in, d_out, bias=qkv_bias)
    
    def forward(self, x, attn_mask=None):
        b, n_tokens, dim = x.shape
        # Step 1: Calculate keys, queries, values for batch
        keys = self.W_keys(x)
        queries = self.W_queries(x)
        values = self.W_values(x)

        # Step 2: calculate attn_scores, feed it to softmax, mask it, get weight_scores
        keys = keys.view(b, n_tokens, self.n_head, self.head_dim)
        queries = queries.view(b, n_tokens, self.n_head, self.head_dim)
        values = values.view(b, n_tokens, self.n_head, self.head_dim)

        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = queries @ keys.transpose(-1, -2)
        subsequent_mask = torch.tril(torch.ones(n_tokens, n_tokens), diagonal=0)
        attn_scores.masked_fill_(subsequent_mask == 0, -torch.inf) # mask subsequent tokens
        if attn_mask is not None: # mask padded tokens
            attn_scores.masked_fill_(~attn_mask, -torch.inf)
        weight_scores = torch.softmax(attn_scores, dim=-1)

        weighted_sum = weight_scores @ values

        weighted_sum = weighted_sum.reshape(b, n_tokens, -1)
        return weighted_sum, weight_scores

In [34]:
from blocks import InputPreprocess
mha = MultiHeadAttention(GPT2_CONFIG)
text = ["Who am i?",
        "Tell me what's the transformers in 5 minutes"]
tokenizer = tiktoken.get_encoding('gpt2')
ip = InputPreprocess(tokenizer ,GPT2_CONFIG)
inputs, attn_mask = ip(text)
out_attn, weight_scores = mha(inputs, attn_mask)

In [35]:
print(out_attn.shape)
print(out_attn)

torch.Size([2, 10, 768])
tensor([[[ 0.5148, -0.7446,  0.6746,  ...,  1.0137, -0.2017, -1.2037],
         [-1.1692,  0.3488,  1.2602,  ..., -0.1721,  0.1469, -0.3785],
         [-0.1008, -0.7582, -0.0482,  ..., -0.1435, -0.3514,  0.1864],
         ...,
         [ 0.2218, -0.5004,  0.1220,  ..., -0.1558,  0.7435, -0.6291],
         [ 0.1820,  0.1435,  0.1762,  ..., -0.2816, -0.6893,  0.0477],
         [ 0.6674,  0.8395, -0.0776,  ..., -0.1935, -0.3169,  1.3297]],

        [[-0.3730, -0.9076,  1.2012,  ..., -0.2333,  0.8082, -0.0533],
         [ 0.4610,  0.7832,  1.8192,  ..., -0.9667,  0.5596, -0.5501],
         [-0.5123,  1.0797,  0.1452,  ...,  1.9712,  0.1371, -0.3466],
         ...,
         [-0.0801, -0.0315, -0.5084,  ...,  0.5341, -0.0106,  0.4413],
         [ 0.6552, -0.8133, -0.6314,  ..., -0.0660, -0.5160, -0.1251],
         [-0.7225,  0.2390,  0.7831,  ..., -0.4127,  0.4306, -0.2961]]],
       grad_fn=<ViewBackward0>)


In [36]:
print(attn_mask)

tensor([[[[ True,  True,  True,  True, False, False, False, False, False, False]]],


        [[[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]]]])


In [37]:
print(weight_scores)

tensor([[[[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [5.9474e-01, 4.0526e-01, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [1.6968e-02, 3.7213e-04, 9.8266e-01,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          ...,
          [2.3043e-06, 2.5264e-03, 9.9747e-01,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [5.2163e-01, 4.4168e-01, 3.6689e-02,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [9.5386e-01, 1.6825e-03, 4.4439e-02,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00]],

         [[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [1.0000e+00, 5.5031e-07, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [1.6840e-08, 9.9999e-01, 1.2233e-05,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          ...,
          [1.2996e-01, 8.5078e-01, 1.9189e-02,  ..., 0.0000