<a href="https://colab.research.google.com/github/wronnyhuang/minGPT/blob/master/mingpt/practice/notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title Import
import torch
import torch.nn as nn
from torch.nn import functional as F
import importlib

def printt(*tensors):
  return [f'{tensor.dtype}, {tensor.shape}' for tensor in tensors]


In [None]:
# @title Layer norm
class LayerNormFun(torch.autograd.Function):
  @staticmethod
  def forward(ctx, x, weight, bias, eps=1e-5):
    mean = x.mean(dim=-1, keepdim=True)
    std = (x.var(dim=-1, keepdim=True, unbiased=False) + eps).sqrt()
    normalized = (x - mean) / std
    output = weight * normalized + bias
    ctx.save_for_backward(normalized, std, weight)
    return output
  @staticmethod
  def backward(ctx, grad_output):
    normalized, std, weight = ctx.saved_tensors
    grad_normalized = grad_output * weight
    grad_input = 1 / std * (grad_normalized - grad_normalized.mean(dim=-1, keepdim=True) - normalized * (grad_normalized * normalized).mean(dim=-1, keepdim=True))
    sum_dims = tuple(range(grad_output.dim() - 1))
    grad_weight = (grad_output * normalized).sum(dim=sum_dims)
    grad_bias = grad_output.sum(dim=sum_dims)
    return grad_input, grad_weight, grad_bias, None

b, d = 2, 4
x = torch.randn(b, d, requires_grad=True)
weight = torch.ones(d, requires_grad=True)
bias = torch.zeros(d, requires_grad=True)

# option 1: use my custom impl
output = LayerNormFun.apply(x, weight, bias)

# options 2: use torch's impl
# ln = nn.LayerNorm(d)
# output = ln(x)

loss = output.sum()
loss.backward()
print(ln.weight.grad)

In [None]:
# @title Linear layer
class Linear(torch.autograd.Function):
  @staticmethod
  def forward(ctx, x, weight, bias):
    output = x @ weight.transpose(0, 1) + bias
    ctx.save_for_backward(x, weight, bias)
    return output
  @staticmethod
  def backward(ctx, grad_output):
    x, weight, bias = ctx.saved_tensors
    grad_input = grad_output @ weight
    grad_weight = (x.transpose(0, 1) @ grad_output).transpose(0, 1)
    sum_dim = tuple(range(grad_output.dim() - 1))
    grad_bias = grad_output.sum(dim=sum_dim)
    return grad_input, grad_weight, grad_bias

b, d = 2, 4
x = torch.randn(b, d, requires_grad=True)
weight = torch.ones((2 * d, d), requires_grad=True)
bias = torch.ones(2 * d, requires_grad=True)

# output = Linear.apply(x, weight, bias)

# uncomment to check the correct answer
linear = nn.Linear(d, 2 * d)
output = linear(x)


loss = output.sum()
loss.backward()
print(ln.weight.grad)
print(ln.bias.grad)

In [None]:
# @title Softmax
class SoftMax(torch.autograd.Function):
  @staticmethod
  def forward(ctx, x):
    x_max, _ = x.max(dim=-1, keepdim=True)
    x = x - x_max
    sum_exp = torch.logsumexp(x, dim=-1, keepdim=True).exp()
    output = x.exp() / sum_exp
    ctx.save_for_backward(output)
    return output
  @staticmethod
  def backward(ctx, grad_output):
    output, = ctx.saved_tensors
    grad_input = output * (grad_output - (output * grad_output).sum(dim=-1, keepdim=True))
    return grad_input

b, d = 2, 3
x0 = torch.randn(b, d, requires_grad=True)
x1 = torch.tensor(x0, requires_grad=True)

output = SoftMax.apply(x0)
loss = output[0][0]
loss.backward()
print(x0.grad)

output = F.softmax(x1)
loss = output[0][0]
loss.backward()
print(x1.grad)

In [None]:
# @title Cross entropy
class CrossEntropy(torch.autograd.Function):
  @staticmethod
  def forward(ctx, x, one_hot_labels):
    probs = SoftMax.apply(x)
    loss = -(one_hot_labels * torch.log(probs)).sum(dim=-1).mean()
    ctx.save_for_backward(probs, one_hot_labels)
    return loss
  @staticmethod
  def backward(ctx, grad_output):
    probs, one_hot_labels = ctx.saved_tensors
    grad_input = grad_output * (probs - one_hot_labels)
    return grad_input, None

x0 = torch.randn(2, 4, requires_grad=True)
one_hot_labels = torch.tensor([[0,0,0,1], [1,0,0,0]], dtype=torch.float32, requires_grad=False)
loss = CrossEntropy.apply(x0, one_hot_labels)
loss.backward()
print(loss)
print(x.grad)

x1 = torch.tensor(x0, requires_grad=True)
loss = F.cross_entropy(x1, one_hot_labels)
loss.backward()
print(loss)
print(x.grad)


In [None]:
# @title Attention backprop
class AttentionFun(torch.autograd.Function):
  @staticmethod
  def forward(ctx, x, QKV, num_heads):
    b, t, d = x.shape
    qkv = x @ QKV.transpose(-1, -2)
    q, k, v = qkv.split(d, dim=-1)
    q = q.view(b, t, num_heads, d // num_heads).transpose(1, 2)
    k = k.view(b, t, num_heads, d // num_heads).transpose(1, 2)
    v = v.view(b, t, num_heads, d // num_heads).transpose(1, 2)
    logits = q @ k.transpose(-1, -2) / k.size(-1) ** 0.5
    logits = torch.where(torch.tril(torch.ones_like(logits)) == 1, logits, -float('inf'))
    scores = SoftMax.apply(logits)
    output = scores @ v
    output = output.transpose(1, 2).contiguous().view(b, t, d)
    ctx.save_for_backward(x, QKV, q, k, v, scores)
    return output

  def backward(ctx, grad_output):
    b, t, d = grad_output.shape
    x, QKV, q, k, v, scores = ctx.saved_tensors
    num_heads = k.size(1)

    grad_output = grad_output.view(b, t, num_heads, d // num_heads).transpose(1, 2)
    grad_scores = grad_output @ v.transpose(-1, -2)
    grad_v = scores.transpose(-1, -2) @ grad_output

    grad_logits = scores * (grad_scores - (scores * grad_scores).sum(dim=-1, keepdim=True))
    grad_q = grad_logits @ k / k.size(-1) ** 0.5
    grad_k = (q.transpose(-1, -2) @ grad_logits).transpose(-1, -2) / k.size(-1) ** 0.5

    grad_q = grad_q.transpose(1, 2).contiguous().view(b, t, d)
    grad_k = grad_k.transpose(1, 2).contiguous().view(b, t, d)
    grad_v = grad_v.transpose(1, 2).contiguous().view(b, t, d)
    grad_qkv = torch.cat([grad_q, grad_k, grad_v], dim=-1)
    grad_x = grad_qkv @ QKV
    grad_QKV = (x.transpose(-1, -2) @ grad_qkv).transpose(-1, -2)
    return grad_x, grad_QKV, None

b = 2
t = 5
d = 12
num_heads = 3

x = torch.randn((b, t, d), requires_grad=True)
QKV = torch.randn((3 * d, d), requires_grad=True)

x_clone = torch.tensor(x, requires_grad=True)
QKV_clone = torch.tensor(QKV, requires_grad=True)
output = AttentionFun.apply(x_clone, QKV_clone, num_heads)
loss = output.sum()
loss.backward()
print(x_clone.grad.shape)
print(x_clone.grad[1, :3, :3])
print(QKV_clone.grad[:3, :3])

x_clone = torch.tensor(x, requires_grad=True)
QKV_clone = torch.tensor(QKV, requires_grad=True)
attn_module = Attention(num_heads, d)
attn_module.qkv.weight.data = QKV_clone
attn_module.qkv.bias.data = torch.zeros_like(attn_module.qkv.bias.data)
output = attn_module(x_clone)
loss = output.sum()
loss.backward()
print(x_clone.grad.shape)
print(x_clone.grad[1, :3, :3])
print(attn_module.qkv.weight.grad[:3, :3])

In [None]:
# @title Put it all together
class Attention(nn.Module):
  def __init__(self, num_heads, dim):
    super().__init__()
    self.num_heads = num_heads
    self.QKV = nn.Linear(dim, 3 * dim)
  def forward(self, x):
    b, t, d = x.shape
    qkv = self.QKV(x)
    q, k, v = torch.split(qkv, d, dim=-1)
    q = q.view(b, t, num_heads, d // num_heads).transpose(1, 2)
    k = k.view(b, t, num_heads, d // num_heads).transpose(1, 2)
    v = v.view(b, t, num_heads, d // num_heads).transpose(1, 2)
    logits = q @ k.transpose(-1, -2) / d ** 0.5
    logits = torch.where(torch.tril(torch.ones_like(logits)) == 1, logits, -float('inf'))
    attn = F.softmax(logits, dim=-1)
    x = attn @ v
    x = x.transpose(1, 2).contiguous().view(b, t, d)
    return x

class Ffn(nn.Module):
  def __init__(self, ffn_dim, dim):
    super().__init__()
    self.fc1 = nn.Linear(dim, ffn_dim)
    self.fc2 = nn.Linear(ffn_dim, dim)
  def forward(self, x):
    x = self.fc1(x)
    x = F.gelu(x)
    x = self.fc2(x)
    x = F.gelu(x)
    return x

class LayerNorm(nn.Module):
  def __init__(self, dim, eps=1e-5):
    super().__init__()
    self.eps = eps
    self.weight = torch.nn.Parameter(torch.ones(dim))
    self.bias = torch.nn.Parameter(torch.zeros(dim))
  def forward(self, x):
    return LayerNormFun.apply(x, self.weight, self.bias, self.eps)


class Block(nn.Module):
  def __init__(self, num_heads, ffn_dim, dim):
    super().__init__()
    self.attn = Attention(num_heads, dim)
    self.ffn = Ffn(ffn_dim, dim)
    self.ln1 = LayerNorm(dim)
    self.ln2 = LayerNorm(dim)
  def forward(self, x):
    x = x + self.attn(self.ln1(x))
    x = x + self.ffn(self.ln2(x))
    return x


class Transformer(nn.Module):
  def __init__(self, num_layers, num_heads, ffn_dim, dim, vocab_size):
    super().__init__()
    self.layers = nn.ModuleList([Block(num_heads, ffn_dim, dim) for _ in range(num_layers)])
    self.lnf = LayerNorm(dim)
    self.lm_head = nn.Linear(dim, vocab_size)
  def forward(self, x, targets=None):
    b, t, d = x.shape
    for layer in self.layers:
      x = layer(x)
    x = self.lnf(x)
    logits = self.lm_head(x)

    loss = None
    if targets is not None:
      loss = F.cross_entropy(logits.view(-1, logits.size(-1), targets.view(-1), ignore_index=-1)
    return logits, loss

vocab_size = 10
batch_size = 2
dim = 12
ffn_dim = 36
num_heads = 3
num_layers = 2
transformer = Transformer(num_layers, num_heads, ffn_dim, dim, vocab_size)

query_len = 6
x = torch.randn(batch_size, query_len, dim)
logits, _ = transformer(x)
logits.shape

In [493]:
# @title Karpathy's dataset
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from mingpt.utils import set_seed
set_seed(3407)
import pickle

class SortDataset(Dataset):
    """
    Dataset for the Sort problem. E.g. for problem length 6:
    Input: 0 0 2 1 0 1 -> Output: 0 0 0 1 1 2
    Which will feed into the transformer concatenated as:
    input:  0 0 2 1 0 1 0 0 0 1 1
    output: I I I I I 0 0 0 1 1 2
    where I is "ignore", as the transformer is reading the input sequence
    """

    def __init__(self, split, length=6, num_digits=3):
        assert split in {'train', 'test'}
        self.split = split
        self.length = length
        self.num_digits = num_digits

    def __len__(self):
        return 10000 # ...

    def get_vocab_size(self):
        return self.num_digits

    def get_block_size(self):
        # the length of the sequence that will feed into transformer,
        # containing concatenated input and the output, but -1 because
        # the transformer starts making predictions at the last input element
        return self.length * 2 - 1

    def __getitem__(self, idx):

        # use rejection sampling to generate an input example from the desired split
        while True:
            # generate some random integers
            inp = torch.randint(self.num_digits, size=(self.length,), dtype=torch.long)
            # half of the time let's try to boost the number of examples that
            # have a large number of repeats, as this is what the model seems to struggle
            # with later in training, and they are kind of rate
            if torch.rand(1).item() < 0.5:
                if inp.unique().nelement() > self.length // 2:
                    # too many unqiue digits, re-sample
                    continue
            # figure out if this generated example is train or test based on its hash
            h = hash(pickle.dumps(inp.tolist()))
            inp_split = 'test' if h % 4 == 0 else 'train' # designate 25% of examples as test
            if inp_split == self.split:
                break # ok

        # solve the task: i.e. sort
        sol = torch.sort(inp)[0]

        # concatenate the problem specification and the solution
        cat = torch.cat((inp, sol), dim=0)

        # the inputs to the transformer will be the offset sequence
        x = cat[:-1].clone()
        y = cat[1:].clone()
        # we only want to predict at output locations, mask out the loss at the input locations
        y[:self.length-1] = -1
        return x, y

# print an example instance of the dataset
train_dataset = SortDataset('train')
test_dataset = SortDataset('test')
x, y = train_dataset[0]
for a, b in zip(x,y):
    print(int(a),int(b))


1 -1
1 -1
0 -1
2 -1
0 -1
1 0
0 0
0 1
1 1
1 1
1 2


In [515]:
import importlib
from mingpt.practice import my_functions
from mingpt.practice import my_modules
importlib.reload(my_functions)
importlib.reload(my_modules)


block_size = train_dataset.get_block_size() + 1
vocab_size = 3
batch_size = 2
dim = 12
ffn_dim = 36
num_heads = 3
num_layers = 2
model = my_modules.Transformer(num_layers, num_heads, ffn_dim, dim, vocab_size, block_size)

logits, loss = model(torch.tensor(x)[None, :], torch.tensor(y)[None, :])
print(logits.shape)
print(loss)

torch.Size([1, 11, 3])
tensor(1.1027, grad_fn=<NllLossBackward0>)


  logits, loss = model(torch.tensor(x)[None, :], torch.tensor(y)[None, :])


In [518]:
# @title create Karpathy Train
import importlib
from mingpt import trainer
importlib.reload(trainer)
Trainer = trainer.Trainer


train_config = Trainer.get_default_config()
train_config.learning_rate = 5e-4 # the model we're using is so small that we can go a bit faster
train_config.max_iters = 2000
train_config.num_workers = 0
trainer = Trainer(train_config, model, train_dataset)

def batch_end_callback(trainer):
    if trainer.iter_num % 100 == 0:
        print(f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}")
trainer.set_callback('on_batch_end', batch_end_callback)

trainer.run()


running on device cpu
iter_dt 0.00ms; iter 0: train loss 0.05245
iter_dt 8.03ms; iter 100: train loss 0.01490
iter_dt 8.07ms; iter 200: train loss 0.00260
iter_dt 7.72ms; iter 300: train loss 0.00060
iter_dt 7.40ms; iter 400: train loss 0.00017
iter_dt 14.36ms; iter 500: train loss 0.00005
iter_dt 7.85ms; iter 600: train loss 0.00001
iter_dt 9.10ms; iter 700: train loss 0.00000
iter_dt 7.72ms; iter 800: train loss 0.00000
iter_dt 14.68ms; iter 900: train loss 0.00000
iter_dt 7.59ms; iter 1000: train loss 0.00000
iter_dt 9.20ms; iter 1100: train loss 0.00000
iter_dt 7.76ms; iter 1200: train loss 0.00000
iter_dt 7.80ms; iter 1300: train loss 0.00000
iter_dt 10.61ms; iter 1400: train loss 0.00000
iter_dt 9.77ms; iter 1500: train loss 0.00000
iter_dt 10.76ms; iter 1600: train loss 0.00000
iter_dt 11.10ms; iter 1700: train loss 0.00000
iter_dt 14.40ms; iter 1800: train loss 0.00000
iter_dt 8.66ms; iter 1900: train loss 0.00000


In [519]:
# @title Karpathy Eval
def eval_split(trainer, split, max_batches):
    dataset = {'train':train_dataset, 'test':test_dataset}[split]
    n = train_dataset.length # naugy direct access shrug
    results = []
    mistakes_printed_already = 0
    loader = DataLoader(dataset, batch_size=100, num_workers=0, drop_last=False)
    for b, (x, y) in enumerate(loader):
        x = x.to(trainer.device)
        y = y.to(trainer.device)
        # isolate the input pattern alone
        inp = x[:, :n]
        sol = y[:, -n:]
        # let the model sample the rest of the sequence
        cat = model.generate(inp, n, do_sample=False) # using greedy argmax, not sampling
        sol_candidate = cat[:, n:] # isolate the filled in sequence
        # compare the predicted sequence to the true sequence
        correct = (sol == sol_candidate).all(1).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line haha
        for i in range(x.size(0)):
            results.append(int(correct[i]))
            if not correct[i] and mistakes_printed_already < 3: # only print up to 5 mistakes to get a sense
                mistakes_printed_already += 1
                print("GPT claims that %s sorted is %s but gt is %s" % (inp[i].tolist(), sol_candidate[i].tolist(), sol[i].tolist()))
        if max_batches is not None and b+1 >= max_batches:
            break
    rt = torch.tensor(results, dtype=torch.float)
    print("%s final score: %d/%d = %.2f%% correct" % (split, rt.sum(), len(results), 100*rt.mean()))
    return rt.sum()

# now let's perform some evaluation
model.eval();

# run a lot of examples from both train and test through the model and verify the output correctness
with torch.no_grad():
    train_score = eval_split(trainer, 'train', max_batches=50)
    test_score  = eval_split(trainer, 'test',  max_batches=50)

train final score: 5000/5000 = 100.00% correct
GPT claims that [2, 2, 2, 2, 2, 2] sorted is [1, 2, 2, 2, 2, 2] but gt is [2, 2, 2, 2, 2, 2]
GPT claims that [2, 2, 2, 2, 2, 2] sorted is [1, 2, 2, 2, 2, 2] but gt is [2, 2, 2, 2, 2, 2]
GPT claims that [2, 2, 2, 2, 2, 2] sorted is [1, 2, 2, 2, 2, 2] but gt is [2, 2, 2, 2, 2, 2]
test final score: 4977/5000 = 99.54% correct
