In [1]:
import torch
import torch.nn as nn

**Attention mechanism**
1. Produce queries ,keys, values,
2. Calculate `attention_scores`, mask those above diagonal to prevent cheating
3. Put `attention_scores` into softmax function, then do weighted sum of values 

In [2]:
class SimpleAttention(nn.Module):
    def __init__(self, cfg, qkv_bias=False):
        super().__init__()
        d_in = cfg['n_embd']
        d_out = cfg['n_embd']
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    
    def forward(self, x):
        # Step1: produce queries, keys, values
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Step2: calculate attention_scores
        attn_scores = queries @ keys.transpose(-1, -2)
        mask = torch.tril(torch.ones_like(attn_scores), diagonal=0)
        attn_scores.masked_fill_(mask==0, -torch.inf)

        # Step3
        weight_scores = torch.softmax(attn_scores, dim=-1)
        attn_out = weight_scores @ values

        return attn_out

In [3]:
GPT2_CONFIG = {
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_layer": 12,
  "n_positions": 1024,
  "resid_pdrop": 0.1,
  "summary_activation": None,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": True,
  "summary_type": "cls_index",
  "summary_use_proj": True,
  "task_specific_params": {
    "text-generation": {
      "do_sample": True,
      "max_length": 50
    }
  },
  "vocab_size": 50257
}

In [4]:
import tiktoken
tokenizer = tiktoken.get_encoding('gpt2')
print(tokenizer.n_vocab)

50257


In [5]:
embedding_layer = nn.Embedding(GPT2_CONFIG['vocab_size'], GPT2_CONFIG['n_embd'])
pos_embedding_layer = nn.Embedding(GPT2_CONFIG['n_positions'], GPT2_CONFIG['n_embd'])
txt = 'Do you know who am i?'
tokens =  tokenizer.encode(txt)
token_len = len(tokens)
tokens = torch.tensor(tokens) # convert to pytorch tensor for compapility
embeded_vec = embedding_layer(tokens)
pos_vec = embedding_layer(torch.arange(token_len))
input_vec = embeded_vec + pos_vec
print(input_vec.shape)
print(input_vec)

torch.Size([7, 768])
tensor([[-1.5461,  1.7754, -1.7155,  ..., -1.6903,  0.1360,  1.1227],
        [-0.5807, -0.0902,  1.2370,  ..., -3.0944, -1.6663,  0.7922],
        [ 0.1952, -0.4214, -0.4008,  ...,  0.8180, -2.9210, -0.1329],
        ...,
        [-0.6921, -0.5745, -0.9969,  ...,  0.9060,  2.9993,  1.8954],
        [-1.0469, -0.6413, -0.7287,  ..., -1.4326, -0.2984,  2.2459],
        [-1.0589,  0.3241,  2.0155,  ...,  1.3143,  0.3874,  1.2294]],
       grad_fn=<AddBackward0>)


In [6]:
sa = SimpleAttention(GPT2_CONFIG)
outputs = sa(input_vec)
print(input_vec.shape)
print(input_vec)

torch.Size([7, 768])
tensor([[-1.5461,  1.7754, -1.7155,  ..., -1.6903,  0.1360,  1.1227],
        [-0.5807, -0.0902,  1.2370,  ..., -3.0944, -1.6663,  0.7922],
        [ 0.1952, -0.4214, -0.4008,  ...,  0.8180, -2.9210, -0.1329],
        ...,
        [-0.6921, -0.5745, -0.9969,  ...,  0.9060,  2.9993,  1.8954],
        [-1.0469, -0.6413, -0.7287,  ..., -1.4326, -0.2984,  2.2459],
        [-1.0589,  0.3241,  2.0155,  ...,  1.3143,  0.3874,  1.2294]],
       grad_fn=<AddBackward0>)


**Multihead attention**
Above one is good, for multihead mechanism, jsut put it into a list, but computation is not efficient, for efficiency we can parallize it.  
    *Instruction: Same as above*

In [21]:
class MultiHeadAttention(nn.Module):
    def __init__(self, cfg, qkv_bias=False):
        super().__init__()
        d_in = cfg['n_embd']
        d_out = cfg['n_embd']
        self.n_head = cfg['n_head']
        
        assert d_out % self.n_head == 0

        self.head_dim = d_out // self.n_head
        self.W_keys = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_queries = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_values = nn.Linear(d_in, d_out, bias=qkv_bias)
    
    def forward(self, x):
        b, n_tokens, dim = x.shape
        # Step 1: Calculate keys, queries, values for batch
        keys = self.W_keys(x)
        queries = self.W_queries(x)
        values = self.W_values(x)

        # Step 2: calculate attn_scores, feed it to softmax, mask it, get weight_scores
        keys = keys.view(b, n_tokens, self.n_head, self.head_dim)
        queries = queries.view(b, n_tokens, self.n_head, self.head_dim)
        values = values.view(b, n_tokens, self.n_head, self.head_dim)

        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = keys @ queries.transpose(-1, -2)
        weight_scores = torch.softmax(attn_scores, dim=-1)

        weighted_sum = weight_scores @ keys

        weighted_sum = weighted_sum.reshape(b, n_tokens, -1)
        return weighted_sum

In [22]:
mha = MultiHeadAttention(GPT2_CONFIG)
text = "Who am i?"
ids = tokenizer.encode(text)
context_length = len(ids)
ids = torch.tensor(ids) # convert to tensor pytorch for compability
embed_vec = embedding_layer(ids)
pos_vec = pos_embedding_layer(torch.arange(context_length))
input_vec = embed_vec + pos_vec
print(input_vec.shape)

torch.Size([4, 768])


In [23]:
input_vec = input_vec.unsqueeze(0)
mha_output = mha(input_vec)
print(mha_output.shape)

torch.Size([1, 4, 768])
