In [21]:
%reload_ext autoreload
%autoreload 2

import torch

# Seed PyTorch's global random number generator with a fixed value so that any
# random draws made through it in later cells repeat from run to run.
torch.manual_seed(42)


<torch._C.Generator at 0x7659c00a0550>

In [22]:
from learning.gpt2.model import GPT2, PretrainedName


# Use the GPU when PyTorch can see one; otherwise fall back to the CPU so the
# notebook still runs end-to-end on machines without CUDA. On a CUDA host this
# resolves to exactly the same device as before.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# from_pretrained hands back two models that the cells below compare side by side.
# NOTE(review): judging by the names, `model` is this project's GPT2 and
# `pretrained_model` is the reference checkpoint it was loaded from — confirm.
model, pretrained_model = GPT2.from_pretrained(PretrainedName.GPT2_SMALL, device=device)
# Explicitly place the second model on the same device as the inputs built later.
pretrained_model = pretrained_model.to(device)


In [23]:
import tiktoken
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

# Shared tiktoken encoder (the "gpt2" encoding) used by the helper functions in
# this notebook to turn prompts into token ids and token ids back into text.
tokenizer = tiktoken.get_encoding("gpt2")


def generate(model, s: str, times: int):
    """Run ``model.generate`` on a prompt repeatedly and print every decoded result.

    The model is switched to eval mode, the prompt is encoded with the
    module-level ``tokenizer`` and wrapped into a single-row batch on the
    module-level ``device``. ``model.generate(batch, 2)`` is then called
    ``times`` times on that same batch, and each row of each result is decoded
    and printed.

    Args:
        model: Object exposing ``eval()`` and ``generate(inputs, n)``.
        s: Prompt text.
        times: How many independent ``generate`` calls to make.

    Returns:
        None. All results are written to stdout.

    NOTE(review): the literal ``2`` is presumably the number of tokens to
    generate — confirm against ``generate``'s signature.
    """
    model.eval()
    with torch.no_grad():
        prompt_ids = tokenizer.encode(s)
        batch = torch.tensor([prompt_ids], dtype=torch.long, device=device)

        for _attempt in range(times):
            generated = model.generate(batch, 2)
            # One decoded line per row of the returned batch.
            for row_ids in generated.tolist():
                print(tokenizer.decode(row_ids))


def top_logits(model, s: str, k: int):
    """Print a prompt followed by the ``k`` most probable next tokens.

    The prompt is echoed, the model is put in eval mode and called on a
    single-row batch of the prompt's token ids (built on the module-level
    ``device`` with the module-level ``tokenizer``). The logits at the last
    sequence position are softmaxed, and the ``k`` largest probabilities are
    printed one per line as ``"<prob to 2 decimals> <decoded token>"``.

    Args:
        model: Callable returning either a logits tensor or a
            ``CausalLMOutputWithCrossAttentions`` whose ``.logits`` is used.
        s: Prompt text.
        k: Number of top candidates to print.

    Returns:
        None. All results are written to stdout.

    Raises:
        AssertionError: If the logits obtained from the model are not a
            ``torch.Tensor``.
    """
    print(s)
    model.eval()
    with torch.no_grad():
        prompt_ids = tokenizer.encode(s)
        batch = torch.tensor([prompt_ids], dtype=torch.long, device=device)

        result = model(batch)
        # Unwrap the structured output type; anything else is taken as the logits.
        logits = result.logits if isinstance(result, CausalLMOutputWithCrossAttentions) else result
        assert isinstance(logits, torch.Tensor)

        # Keep only the final position, then drop the leading batch axis.
        next_token_probs = torch.softmax(logits[:, -1, :], dim=-1).squeeze(0)
        best_probs, best_ids = torch.topk(next_token_probs, k=k)
        for prob, token_id in zip(best_probs.tolist(), best_ids.tolist()):
            print(f"{prob:.2f} {tokenizer.decode([token_id])}")

In [24]:
def run_test(model):
    """Print the top-3 next-token predictions of ``model`` for a fixed prompt set.

    Each prompt names two people and ends just before the recipient of an
    action; the prompts vary the names, which person acts, and the punctuation
    joining the two clauses. Every prompt is passed to ``top_logits`` with
    ``k=3``.

    Args:
        model: The model to probe; forwarded unchanged to ``top_logits``.

    Returns:
        None. ``top_logits`` prints the results.
    """
    # From the paper: https://arxiv.org/abs/2211.00593
    prompts = (
        "When Mary and John went to the store, John gave a drink to",
        "When Vincent and Vanessa went to the park, Vincent gave a leaf to",
        "When Vincent and Vanessa went to the park, Vanessa gave a leaf to",
        "Mary and John went to the store. John gave a drink to",
        "Mary and John went to the store; John gave a drink to",
        "Mary and John went to the store! John gave a drink to",
        "Mary and John went to the store. Mary gave a drink to",
        "Mary and John went to the store; Mary gave a drink to",
        "Mary and John went to the store! Mary gave a drink to",
    )

    top_k = 3
    for prompt in prompts:
        top_logits(model, prompt, top_k)


# Run the same prompt suite on both models, each preceded by a separator line,
# so their predictions can be compared in the cell output.
for _model_under_test in (model, pretrained_model):
    print("-" * 80)
    run_test(_model_under_test)

--------------------------------------------------------------------------------
When Mary and John went to the store, John gave a drink to
0.45  Mary
0.21  them
0.07  the
When Vincent and Vanessa went to the park, Vincent gave a leaf to
0.51  Vanessa
0.12  the
0.06  a
When Vincent and Vanessa went to the park, Vanessa gave a leaf to
0.54  Vincent
0.09  the
0.03  her
Mary and John went to the store. John gave a drink to
0.31  them
0.15  John
0.13  the
Mary and John went to the store; John gave a drink to
0.33  them
0.15  John
0.14  the
Mary and John went to the store! John gave a drink to
0.27  them
0.13  the
0.10  John
Mary and John went to the store. Mary gave a drink to
0.46  John
0.14  them
0.12  the
Mary and John went to the store; Mary gave a drink to
0.29  John
0.21  them
0.14  the
Mary and John went to the store! Mary gave a drink to
0.40  John
0.14  the
0.11  them
--------------------------------------------------------------------------------
When Mary and John went to the st

In [25]:
# Print the model's module structure for reference.
print(model)

# Three probes that share the first clause and differ only in who gives the drink.
# The commented-out calls are disabled experiment toggles, kept next to the probe
# they would affect.
# NOTE(review): set_capture_output / set_use_frozen_output are not defined in this
# notebook; their exact behavior should be confirmed against the GPT2 class.
print("-" * 80)
# model.set_capture_output(True)
# After this, the heads output is for case P(ABC)
# (the giver, Tom, is a third name that does not appear in the first clause).
top_logits(model, "When Mary and John went to the store, Tom gave a drink to", 3)

print("-" * 80)
# The giver is John, the second name in the first clause.
top_logits(model, "When Mary and John went to the store, John gave a drink to", 3)

print("-" * 80)
# model.set_capture_output(False)
# model.set_use_frozen_output(True)
# The giver is Mary, the first name in the first clause.
top_logits(model, "When Mary and John went to the store, Mary gave a drink to", 3)

GPT2(
  (embedding): Embedding(50257, 768)
  (positional_embedding): Embedding(1024, 768)
  (dropout): Dropout(p=0.1, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attention): MultiHeadAttention(
        (heads): ModuleList(
          (0-11): 12 x AttentionHead(
            (query): Linear(in_features=768, out_features=64, bias=True)
            (key): Linear(in_features=768, out_features=64, bias=True)
            (value): Linear(in_features=768, out_features=64, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (projection): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (feed_forward): FeedForward(
        (linear): Linear(in_features=768, out_features=3072, bias=True)
        (gelu): GELU(approximate