In [1]:
# if running on Google colab
!pip install einops
!pip install torchtyping
!pip install transformers
import torch as t
import torch.nn as nn
from torch import einsum
from einops import rearrange, repeat, reduce
import math

from google.colab import drive
drive.mount('/content/gdrive')
%cd /content/gdrive/MyDrive/mlab/days/w2d3
import gpt_tests


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Installing collected packages: einops
Successfully installed einops-0.4.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchtyping
  Downloading torchtyping-0.1.4-py3-none-any.whl (17 kB)
Collecting typeguard>=2.11.1
  Downloading typeguard-2.13.3-py3-none-any.whl (17 kB)
Installing collected packages: typeguard, torchtyping
  Attempting uninstall: typeguard
    Found existing installation: typeguard 2.7.1
    Uninstalling typeguard-2.7.1:
      Successfully uninstalled typeguard-2.7.1
Successfully installed torchtyping-0.1.4 typeguard-2.13.3
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.19.2-py3-none-any.whl (4.2 MB)
[K     |█████████████████████████

# 1 Making the GPT-2 module

In [2]:
class MultiHeadedSelfAttention(nn.Module):
  """Causal (unidirectional) multi-head self-attention with an optional key/value cache.

  Args:
    hidden_size: width of the input and output vectors.
    num_heads: number of attention heads; should evenly divide hidden_size.
  """

  def __init__(self, hidden_size, num_heads):
    super().__init__()
    self.head_size = hidden_size // num_heads
    self.attention = nn.Linear(hidden_size, 3 * num_heads * self.head_size) # output is concatenated query, key, value
    self.project = nn.Linear(hidden_size, hidden_size) # things will break if num_heads doesn't evenly divide hidden_size

  def forward(self, input, past_key_values=None, return_key_values=False):
    """Attend over `input`, optionally continuing from cached keys/values.

    Args:
      input: tensor of shape (batch, seq_len, hidden_size) holding the new positions.
      past_key_values: tensor of shape (num_heads, past_len, 2 * head_size) with the
        keys and values of earlier positions concatenated on the last axis. It is only
        consulted when return_key_values is True; None is treated as an empty cache.
      return_key_values: when True, use the cache and also return the new keys/values.

    Returns:
      The attention output of shape (batch, seq_len, hidden_size); when
      return_key_values is True, a tuple of that output and the keys/values of the
      new positions only, shaped (batch, num_heads, seq_len, 2 * head_size).
    """
    batch, seq_len = input.shape[0], input.shape[1]

    # (b, sl, 3 * nh * hs) -> (qkv, b, nh, sl, hs); unpacking splits along the qkv axis
    qkv = self.attention(input).reshape(batch, seq_len, 3, -1, self.head_size)
    q, k, v = qkv.permute(2, 0, 3, 1, 4)

    if return_key_values:
      new_kv = t.cat((k, v), dim=-1)
      if past_key_values is None:
        past = new_kv[:, :, :0] # nothing cached yet
      else:
        # the cache carries no batch axis, so share it across the batch
        past = past_key_values.unsqueeze(0).expand(batch, -1, -1, -1)
      k, v = t.cat((past, new_kv), dim=-2).split(self.head_size, dim=-1)

    # calculate raw attention scores
    attn_pattern = einsum('bhqi,bhki->bhqk', q, k) / math.sqrt(self.head_size)

    # mask the attention pattern so tokens only attend to past tokens. The queries are
    # the LAST num_queries positions of the key sequence, so query row r sits at absolute
    # position (num_keys - num_queries + r). Without a cache this is the usual upper
    # triangle; with a cache and a single new token nothing is masked; with a cache and
    # several new tokens the new tokens are still prevented from seeing each other's future.
    num_queries, num_keys = attn_pattern.shape[-2:]
    query_pos = t.arange(num_keys - num_queries, num_keys, device=input.device)
    key_pos = t.arange(num_keys, device=input.device)
    future = key_pos[None, :] > query_pos[:, None]
    attn_pattern = attn_pattern.masked_fill(future, -1e4)

    attn_scores = attn_pattern.softmax(-1)
    attn = einsum('bhqk,bhki->bhqi', attn_scores, v)
    # (b, nh, sl, hs) -> (b, sl, nh * hs)
    out = self.project(attn.transpose(1, 2).reshape(batch, seq_len, -1))
    if return_key_values:
      return out, new_kv
    else:
      return out

# Run the course-provided checks (gpt_tests is imported in the setup cell) on the attention
# class: presumably the plain causal path first, then the key/value-cache path — confirm
# against gpt_tests.
gpt_tests.test_unidirectional_attn(MultiHeadedSelfAttention)
gpt_tests.test_attn_cache(MultiHeadedSelfAttention)

Congrats! You've passed the test!
Checking encoding:
Congrats! You've passed the test!
Checking new key and value:
Congrats! You've passed the test!


In [3]:
class Block(nn.Module):
  """One transformer block: LayerNorm -> self-attention, then an MLP, each added back residually.

  The MLP carries its own LayerNorm as its first layer and applies dropout to its output.
  """

  def __init__(self, hidden_size, num_heads, dropout=0., layer_norm_epsilon=1e-5):
    super().__init__()
    mlp_width = 4 * hidden_size
    self.ln = nn.LayerNorm(hidden_size, layer_norm_epsilon)
    self.attention = MultiHeadedSelfAttention(hidden_size, num_heads)
    mlp_layers = [
      nn.LayerNorm(hidden_size, layer_norm_epsilon),
      nn.Linear(hidden_size, mlp_width),
      nn.GELU(),
      nn.Linear(mlp_width, hidden_size),
      nn.Dropout(dropout),
    ]
    self.mlp = nn.Sequential(*mlp_layers)

  def forward(self, input, past_key_values=None, return_key_values=False):
    """Apply the block; also hand back the attention layer's new keys/values when asked.

    Returns the block output, or an (output, new_key_values) tuple when
    return_key_values is True.
    """
    attn_result = self.attention(self.ln(input), past_key_values, return_key_values)
    if not return_key_values:
      return self.mlp(input + attn_result) + input + attn_result
    # with return_key_values set, the attention layer returns an (output, new_kv) pair
    attended, fresh_kv = attn_result
    return self.mlp(input + attended) + input + attended, fresh_kv

# Run the course-provided check (gpt_tests) against the Block class defined in this cell.
gpt_tests.test_gpt_block(Block)

Congrats! You've passed the test!


In [4]:
from dataclasses import dataclass
from torchtyping import TensorType

@dataclass
class GPT2Output:
    """Pair of tensors returned by GPT2.forward."""
    # Scores over the vocabulary for the last position; annotated as (batch_size, vocab_size).
    logits: TensorType["batch_size", "vocab_size"]
    # Layer-normed hidden vector of the last position, from which the logits are computed;
    # annotated as (batch_size, hidden_size).
    final_encoding: TensorType["batch_size", "hidden_size"]

In [5]:
class Embedding(nn.Module):
  """Sum of a learned token embedding and a learned absolute-position embedding.

  Both tables are drawn from a standard normal distribution at construction time.
  """

  def __init__(self, vocab_size, hidden_size, max_position_embeddings):
    super().__init__()
    token_table = t.randn(vocab_size, hidden_size)
    position_table = t.randn(max_position_embeddings, hidden_size)
    self.token_embed = nn.Parameter(token_table)
    self.pos_embed = nn.Parameter(position_table)

  def forward(self, input):
    """Embed integer token ids; the last axis of `input` is the sequence axis."""
    # positions 0 .. seq_len - 1, created on the same device as the ids
    positions = t.arange(input.shape[-1], device=input.device)
    token_vectors = self.token_embed[input]
    position_vectors = self.pos_embed[positions]
    return token_vectors + position_vectors

In [206]:
class GPT2(nn.Module):
  """GPT-2 style decoder: embeddings -> stack of Blocks -> LayerNorm -> tied unembedding.

  Only the last position is unembedded, so forward() returns next-token logits.

  With use_cache=True the model remembers every layer's keys/values between forward()
  calls and only runs the tokens it has not seen yet. Call clear_cache() before starting
  a new sequence. The cache is a plain attribute (not a parameter or buffer), so it does
  not appear in state_dict(). Cached mode supports batch size 1 only.
  """

  def __init__(self, num_layers, num_heads, vocab_size, hidden_size, 
               max_position_embeddings, dropout, layer_norm_epsilon, use_cache=False):
    super().__init__()
    self.num_layers = num_layers
    self.num_heads = num_heads
    self.head_size = hidden_size // num_heads
    self.embed = Embedding(vocab_size, hidden_size, max_position_embeddings)
    self.dropout = nn.Dropout(dropout)
    self.blocks = nn.Sequential(
        *[Block(hidden_size, num_heads, dropout, layer_norm_epsilon) for _ in range(num_layers)]
    )
    # use the configured epsilon here too, like the LayerNorms inside each Block
    self.ln = nn.LayerNorm(hidden_size, layer_norm_epsilon)
    self.use_cache = use_cache
    self.clear_cache()

  def forward(self, input):
    """Return a GPT2Output with the logits and final encoding of the last position.

    Args:
      input: integer token ids of shape (batch, seq_len). In cached mode this is the
        whole sequence so far; positions already in the cache are skipped.

    Raises:
      ValueError: in cached mode, if the batch size is not 1.
    """
    x = self.dropout(self.embed(input))

    if not self.use_cache:
      x = self.blocks(x)

    else:
      if x.size(0) != 1:
        raise ValueError(f"use_cache=True supports batch size 1 only, got {x.size(0)}")
      # keep the cache on the same device/dtype as the activations (the model may have moved)
      cache = self.kv_cache.to(device=x.device, dtype=x.dtype)
      x = x[..., cache.size(-2):, :] # remove previously processed tokens
      new_kvs = []
      for block, past in zip(self.blocks, cache):
        x, new_kv = block(x, past_key_values=past, return_key_values=True)
        new_kvs.append(new_kv)
      # each new_kv has a leading batch axis of size 1, so cat stacks them per layer
      self.kv_cache = t.cat((cache, t.cat(new_kvs)), dim=-2)

    final_encoding = self.ln(x)[...,-1,:]
    logits = einsum('ij,bj->bi', self.embed.token_embed, final_encoding)
    return GPT2Output(logits=logits, final_encoding=final_encoding)

  def clear_cache(self):
    """Reset the key/value cache to zero cached positions."""
    reference = self.embed.token_embed # allocate alongside the model's weights
    self.kv_cache = t.zeros(self.num_layers, self.num_heads, 0, 2 * self.head_size,
                            device=reference.device, dtype=reference.dtype)

  def next_token(self, input_ids, temperature, freq_penalty=2.0):
    """Pick the next token id for a 1-D tensor of token ids.

    Each vocabulary entry's logit is reduced by freq_penalty times the number of
    times that id already occurs in input_ids. With temperature > 0 the id is sampled
    from the softmax of the penalised, temperature-scaled logits; temperature == 0
    picks the highest penalised logit instead of dividing by zero.

    Returns:
      A LongTensor of shape (1,) holding the chosen id.
    """
    logits = self.forward(input_ids.unsqueeze(0)).logits.squeeze(0)

    # tally frequencies
    id_frequencies = t.bincount(input_ids, minlength=logits.size(0)).to(logits)
    penalty = id_frequencies * freq_penalty

    if temperature == 0:
      return (logits - penalty).argmax(-1, keepdim=True)
    token_dist = (logits/temperature - penalty).softmax(-1)
    return t.multinomial(token_dist, num_samples=1)

  def generate(self, text, max_length=30, temperature=1.0, freq_penalty=2.0):
    """Continue `text` by up to max_length tokens and return the decoded string.

    Generation stops early at the tokenizer's end-of-sequence token or when the
    position-embedding table is exhausted. The model is switched to eval mode (so
    dropout is off) with gradients disabled while generating, and its previous
    train/eval mode is restored afterwards.
    """
    import transformers # imported here so the method does not depend on a later notebook cell

    tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
    device = self.embed.token_embed.device
    input_ids = t.tensor(tokenizer(text).input_ids, dtype=t.long, device=device)
    max_positions = self.embed.pos_embed.size(0)

    self.clear_cache()
    was_training = self.training
    self.eval()
    try:
      with t.no_grad():
        for _ in range(max_length):
          if input_ids.size(0) >= max_positions: break # no position embedding left
          next_token = self.next_token(input_ids, temperature, freq_penalty)
          input_ids = t.cat((input_ids, next_token))
          if next_token[0] == tokenizer.eos_token_id : break
    finally:
      self.train(was_training)
    return tokenizer.decode(input_ids)
    

# Run the course-provided checks (gpt_tests) on the full model; per the output below they
# compare logits and final encodings, then compare results and timing with and without the cache.
gpt_tests.test_gpt(GPT2)
gpt_tests.test_gpt_cache(GPT2)

Checking logits:
Congrats! You've passed the test!
Checking final encodings:
Congrats! You've passed the test!
Congrats! Your GPT returns the same results with and without cache.
It took 7.406s to generate a 500-token sentence without cache and 1.176s with cache.


# 2 Loading pretrained weights

In [93]:
# Build this notebook's GPT2 (12 layers, 12 heads, hidden size 768; use_cache stays at its
# default of False) and fetch the reference model whose weights are copied into it in the next cell.
my_gpt = GPT2(num_layers=12, num_heads=12, vocab_size=50257, hidden_size=768, max_position_embeddings=1024, dropout=0.1, layer_norm_epsilon=1e-5)
pretrained_gpt = gpt_tests.get_pretrained_gpt()

In [94]:
# Copy the pretrained weights over by position, not by name: the i-th tensor of the
# pretrained state_dict is stored under the i-th key of my_gpt's state_dict. This only
# lines up if both models register their parameters in the same order.
load_dict = dict(zip(my_gpt.state_dict().keys(), pretrained_gpt.state_dict().values()))
my_gpt.load_state_dict(load_dict)

<All keys matched successfully>

# 3 Efficient text generation

In [112]:
import transformers
# Load the pretrained "gpt2" tokenizer; the next cell uses it to build the `tokens` batch.
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")

In [10]:
# Prompts of increasing length, padded to a common length so they stack into one LongTensor.
# The stock GPT-2 tokenizer defines no padding token and `padding=True` raises without one,
# so fall back to the end-of-text token (only if no pad token has been set already).
if tokenizer.pad_token_id is None:
  tokenizer.pad_token = tokenizer.eos_token

tokens = t.LongTensor(tokenizer([
                                 "My life motto:",
                                 "My life motto: Fortune",
                                 "My life motto: Fortune favors",
                                 "My life motto: Fortune favors the",
                                 "My life motto: Fortune favors the bold"
                                ],padding=True).input_ids)

In [207]:
# Rebuild the model with the key/value cache enabled, then copy the pretrained weights in
# again. The copy is positional: the i-th pretrained tensor goes under my_gpt's i-th
# state_dict key, so both models must register their parameters in the same order.
my_gpt = GPT2(num_layers=12, num_heads=12, vocab_size=50257, hidden_size=768, max_position_embeddings=1024, dropout=0.1, layer_norm_epsilon=1e-5, use_cache=True)
load_dict = dict(zip(my_gpt.state_dict().keys(), pretrained_gpt.state_dict().values()))
my_gpt.load_state_dict(load_dict)

<All keys matched successfully>

In [216]:
# Sample a continuation of the prompt using generate()'s defaults (max_length=30,
# temperature=1.0, freq_penalty=2.0). Tokens are drawn with t.multinomial and no seed is
# set in the cells shown, so the text differs from run to run.
my_gpt.generate("I woke up and got out of")

'I woke up and got out of my bed, hit 8 liters as I went fro-up as fire swept through the bag. At 50 liters one could even first on your'