In [9]:
"""
Debugging exercise: Fixing GPT-7.
(Basic knowledge of PyTorch and transformer models is required for this task.)

It's the night right before we want to launch our latest and greatest model, GPT-7.
But, oh no! It seems some bugs have crept in at the last minute.
We need you to fix them in time before our big launch.
Three sections in the model are marked, each of which contains some bugs.
Run the training/sampling code to check whether the model is working.
"""
from dataclasses import dataclass
from typing import Optional, List

import math
import numpy as np
import time
import torch
import torch.nn as nn
from torch.nn import functional as F


class CausalSelfAttention(nn.Module):

    def __init__(self, hiddens, n_heads):
        super().__init__()
        self.key = nn.Linear(hiddens, hiddens)
        self.query = nn.Linear(hiddens, hiddens)
        self.value = nn.Linear(hiddens, hiddens)
        self.out = nn.Linear(hiddens, hiddens)
        self.n_head = n_heads

    def forward(self, x):
        b, t, c = x.size()

        q = self.query(x).view(b, t, self.n_head, c // self.n_head).transpose(1, 2)
        k = self.key(x).view(b, t, self.n_head, c // self.n_head).transpose(1, 2)
        v = self.value(x).view(b, t, self.n_head, c // self.n_head).transpose(1, 2)
        
        # >>> THIS SECTION CONTAINS 2 BUGS.
        att = torch.matmul(q, k.transpose(-2, -1))
        att = att / math.sqrt(k.shape[-1])
        mask = torch.tril(torch.ones(t, t, dtype=torch.bool)).view(1, 1, t, t)
        att = torch.where(mask, att, 0) 
        att = F.softmax(att, dim=-1) # (B, n_head, t, t)
        y = torch.matmul(att, v) 
        y = v.transpose(1, 2)
        y = y.reshape(b, t, c)
        y = self.out(y)
        # <<< SECTION ENDS HERE.

        return y


class Block(nn.Module):

    def __init__(self, hiddens, n_heads):
        super().__init__()
        self.ln1 = nn.LayerNorm(hiddens)
        self.ln2 = nn.LayerNorm(hiddens)
        self.attn = CausalSelfAttention(hiddens, n_heads)
        self.mlp = nn.Sequential(
            nn.Linear(hiddens, 4 * hiddens),
            nn.GELU(),
            nn.Linear(4 * hiddens, hiddens),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class GPT(nn.Module):

    def __init__(self,
                 n_tokens=64,
                 hiddens=64,
                 n_heads=4,
                 layers=4,
                 vocab=256,
        ):
        super().__init__()
        # >>> THIS SECTION CONTAINS A BUG.
        self.layers = layers
        self.tok_emb = nn.Embedding(vocab, hiddens)
        self.pos_emb = nn.Parameter(torch.empty(1, n_tokens, hiddens))
        self.blocks = nn.ModuleList([Block(hiddens, n_heads) for _ in range(layers)])
        self.ln_f = nn.LayerNorm(hiddens)
        self.final = nn.Linear(hiddens, vocab, bias=False)
        # <<< SECTION ENDS HERE.

    def forward(self, x):
        b, t = x.size()
        token_embeddings = self.tok_emb(x)
        position_embeddings = self.pos_emb[:, :t, :]
        x = token_embeddings + position_embeddings
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        return self.final(x)

    def loss(self, tokens):
        # >>> THIS SECTION CONTAINS A BUG.
        targets = tokens[:, :-1]
        logits = self.forward(tokens[:, :-1])
        # print(targets.shape, logits.shape, logits.reshape(-1, logits.size(-1)))
        return F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.flatten())
        # <<< SECTION ENDS HERE.


#################################################################
# Trains the model and samples from it. There are no bugs here. #
#################################################################

def get_prompt():
    text = torch.tensor(np.array([[ord(c) for c in '{:02d}_'.format(i)] for i in range(1, 16)]))
    return text


def sample(gpt):
    print()
    print('Samples:')
    text = get_prompt()
    print(text)
    for _ in range(60):
        logits = gpt(text)[:, -1:]
        new_tokens = torch.distributions.Categorical(F.softmax(logits / 0.8, dim=-1)).sample()
        text = torch.cat([text, new_tokens], axis=1)
    for t in text:
        print(''.join([chr(c) for c in t]))


def train():
    data = [
        "09_Over every mistake",
        "08_But there's no sense crying",
        "07_Because we can",
        "12_And the science gets done",
        "04_My satisfaction",
        "10_You just keep on trying",
        "15_Still alive",
        "01_This was a triumph",
        "13_And you make a neat gun",
        "11_Till you run out of cake",
        "06_We do what we must",
        "03_It's hard to overstate",
        "14_For the people who are",
        "05_Aperture Science:",
        "02_I'm making a note here; 'Huge success'",
    ]
    data = [[ord(c) for c in d] for d in data]
    data = torch.tensor(np.array([np.pad(x, (0, 64 - len(x)), constant_values=ord('_')) for x in data]))

    gpt = GPT()
    optimizer = torch.optim.Adam(gpt.parameters(), lr=1e-2)
    print('Loss:')
    for i in range(100):
        optimizer.zero_grad()
        l = gpt.loss(data)
        if i % 10 == 0:
            print('{:02d} {}'.format(i, l.detach().numpy()))
        l.backward()
        optimizer.step()

    return gpt


def run():
    gpt = train()
    sample(gpt)

run()


Loss:
00 6.02623987197876
10 0.09828782826662064
20 0.006855769082903862
30 0.0016643410781398416
40 0.0008433522889390588
50 0.0005899668904021382
60 0.0004798386653419584
70 0.0004194046196062118
80 0.0003798803372774273
90 0.00035048392601311207

Samples:
tensor([[48, 49, 95],
        [48, 50, 95],
        [48, 51, 95],
        [48, 52, 95],
        [48, 53, 95],
        [48, 54, 95],
        [48, 55, 95],
        [48, 56, 95],
        [48, 57, 95],
        [49, 48, 95],
        [49, 49, 95],
        [49, 50, 95],
        [49, 51, 95],
        [49, 52, 95],
        [49, 53, 95]])
01_____________________________________________________________
02_____________________________________________________________
03_____________________________________________________________
04_____________________________________________________________
05_____________________________________________________________
06_____________________________________________________________
07________________________

In [12]:
"""
Debugging exercise: Fixing GPT-7.
(Basic knowledge of PyTorch and transformer models is required for this task.)

It's the night right before we want to launch our latest and greatest model, GPT-7.
But, oh no! It seems some bugs have crept in at the last minute.
We need you to fix them in time before our big launch.
Three sections in the model are marked, each of which contains some bugs.
Run the training/sampling code to check whether the model is working.
"""
from dataclasses import dataclass
from typing import Optional, List

import math
import numpy as np
import time
import torch
import torch.nn as nn
from torch.nn import functional as F


class CausalSelfAttention(nn.Module):

    def __init__(self, hiddens, n_heads):
        super().__init__()
        self.key = nn.Linear(hiddens, hiddens)
        self.query = nn.Linear(hiddens, hiddens)
        self.value = nn.Linear(hiddens, hiddens)
        self.out = nn.Linear(hiddens, hiddens)
        self.n_head = n_heads

    def forward(self, x):
        b, t, c = x.size() # b=15, t=63, c=64, n_head=4
        # print(b,t,c, self.n_head)

        q = self.query(x).view(b, t, self.n_head, c // self.n_head).transpose(1, 2)
        k = self.key(x).view(b, t, self.n_head, c // self.n_head).transpose(1, 2)
        v = self.value(x).view(b, t, self.n_head, c // self.n_head).transpose(1, 2)
        
        # >>> THIS SECTION CONTAINS 2 BUGS.
        att = torch.matmul(q, k.transpose(-2, -1))
        att = att / math.sqrt(k.shape[-1])
        mask = torch.tril(torch.ones(t, t, dtype=torch.bool)).view(1, 1, t, t)
        #att = torch.where(mask, att, 0) masking attention to 0 is wrong 
        att = torch.where(mask, att, -1e9)
        att = F.softmax(att, dim=-1) # (B, n_head, t, t)
        y = torch.matmul(att, v) 
        #y = v.transpose(1, 2) typo found in original code
        y = y.transpose(1, 2)
        y = y.reshape(b, t, c)
        y = self.out(y)
        # <<< SECTION ENDS HERE.

        return y


class Block(nn.Module):

    def __init__(self, hiddens, n_heads):
        super().__init__()
        self.ln1 = nn.LayerNorm(hiddens)
        self.ln2 = nn.LayerNorm(hiddens)
        self.attn = CausalSelfAttention(hiddens, n_heads)
        self.mlp = nn.Sequential(
            nn.Linear(hiddens, 4 * hiddens),
            nn.GELU(),
            nn.Linear(4 * hiddens, hiddens),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class GPT(nn.Module):

    def __init__(self,
                 n_tokens=64,
                 hiddens=64,
                 n_heads=4,
                 layers=4,
                 vocab=256,
        ):
        super().__init__()
        # >>> THIS SECTION CONTAINS A BUG.
        self.layers = layers
        self.tok_emb = nn.Embedding(vocab, hiddens)
        # self.pos_emb = nn.Parameter(torch.empty(1, n_tokens, hiddens)) initialization of positional embedding is wrong
        self.pos_emb = nn.Parameter(torch.zeros(1, n_tokens, hiddens))
        self.blocks = nn.ModuleList([Block(hiddens, n_heads) for _ in range(layers)])
        self.ln_f = nn.LayerNorm(hiddens)
        self.final = nn.Linear(hiddens, vocab, bias=False)
        # <<< SECTION ENDS HERE.

    def forward(self, x):
        b, t = x.size()
        token_embeddings = self.tok_emb(x)
        position_embeddings = self.pos_emb[:, :t, :]
        x = token_embeddings + position_embeddings
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        return self.final(x)

    def loss(self, tokens):
        # >>> THIS SECTION CONTAINS A BUG.
        # targets = tokens[:, :-1] targets should be the next token
        targets = tokens[:, 1:]
        logits = self.forward(tokens[:, :-1])
        # print(targets.shape, logits.shape, logits.reshape(-1, logits.size(-1)))
        return F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.flatten())
        # <<< SECTION ENDS HERE.


#################################################################
# Trains the model and samples from it. There are no bugs here. #
#################################################################

def get_prompt():
    text = torch.tensor(np.array([[ord(c) for c in '{:02d}_'.format(i)] for i in range(1, 16)]))
    return text


def sample(gpt):
    print()
    print('Samples:')
    text = get_prompt()
    for _ in range(60):
        logits = gpt(text)[:, -1:]
        new_tokens = torch.distributions.Categorical(F.softmax(logits / 0.8, dim=-1)).sample()
        text = torch.cat([text, new_tokens], axis=1)
    for t in text:
        print(''.join([chr(c) for c in t]))


def train():
    data = [
        "09_Over every mistake",
        "08_But there's no sense crying",
        "07_Because we can",
        "12_And the science gets done",
        "04_My satisfaction",
        "10_You just keep on trying",
        "15_Still alive",
        "01_This was a triumph",
        "13_And you make a neat gun",
        "11_Till you run out of cake",
        "06_We do what we must",
        "03_It's hard to overstate",
        "14_For the people who are",
        "05_Aperture Science:",
        "02_I'm making a note here; 'Huge success'",
    ]
    data = [[ord(c) for c in d] for d in data]
    data = torch.tensor(np.array([np.pad(x, (0, 64 - len(x)), constant_values=ord('_')) for x in data]))

    gpt = GPT()
    optimizer = torch.optim.Adam(gpt.parameters(), lr=1e-2)
    print('Loss:')
    for i in range(100):
        optimizer.zero_grad()
        l = gpt.loss(data)
        if i % 10 == 0:
            print('{:02d} {}'.format(i, l.detach().numpy()))
        l.backward()
        optimizer.step()

    return gpt


def run():
    gpt = train()
    sample(gpt)


# model = CausalSelfAttention(64, 4)
# x = torch.randn(5, 6, 64)
# print(model(x).shape)

run()


Loss:
00 6.317563056945801
10 1.0246399641036987
20 0.6647700071334839
30 0.35471609234809875
40 0.11384560167789459
50 0.057370759546756744
60 0.04002990201115608
70 0.0340501144528389
80 0.0331902876496315
90 0.03291754052042961

Samples:
01_This was a triumph__________________________________________
02_I'm making a note here; 'Huge success'______________________
03_It's hard to overstate______________________________________
04_My satisfaction_____________________________________________
05_Aperture Science:___________________________________________
06_We do what we must__________________________________________
07_Because we can______________________________________________
08_But there's no sense crying_________________________________
09_Over every mistake__________________________________________
10_You just keep on trying_____________________________________
11_Till you run out of cake____________________________________
12_And the science gets done___________________________