In [1]:
from IPython.core.interactiveshell import InteractiveShell
# Display the value of every bare expression statement in a cell, not only the last one.
# Later cells rely on this to show several shapes/values from a single cell.
InteractiveShell.ast_node_interactivity = "all"
# Autoreload mode 2: re-import edited modules before each cell execution.
%reload_ext autoreload
%autoreload 2

In [2]:
from datasets import load_dataset
# Load the validation split of "Rowan/hellaswag"; `dataset` is read as a global by the cells below.
dataset = load_dataset("Rowan/hellaswag", split="validation")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import tiktoken
# GPT-2 encoding from tiktoken; render_example reads this module-level `enc` to tokenize text.
enc = tiktoken.get_encoding("gpt2")
import jax.numpy as jnp

def render_example(example):
    """
    Render one multiple-choice example as padded JAX arrays for likelihood scoring.

    Args:
        example: dict with the keys
            - "ctx": the context string shared by all candidates
            - "endings": list of candidate completion strings
            - "label": index of the correct ending; passed through unchanged
              (this dataset stores it as a string such as '3', so callers
              convert it with int() before comparing)

    Returns:
        A 4-tuple ``(data, tokens, mask, label)``:
            - data: dict holding "label", "ctx_tokens" and "ending_tokens"
              (the raw, unpadded token id lists)
            - tokens: int32 array of shape (num_endings, max_len); row i is
              context + ending i, right-padded with 0
            - mask: float32 array of the same shape; 1 over the ending tokens
              (the region whose likelihood is evaluated), 0 over the context
              and the padding
            - label: ``example["label"]`` exactly as stored

    Raises:
        ValueError: if the example has no endings.
    """
    ctx = example["ctx"]
    label = example["label"]
    endings = example["endings"]
    if not endings:
        raise ValueError("example has no candidate endings to render")

    ctx_tokens = enc.encode(ctx)
    # note: prepending " " because of the GPT-2 tokenizer
    ending_tokens = [enc.encode(" " + end) for end in endings]

    # raw token lists, kept so the eval can be reproduced outside this notebook
    data = {
        "label": label,
        "ctx_tokens": ctx_tokens,
        "ending_tokens": ending_tokens,
    }

    # Rows differ in length, so right-pad every row to the longest one in plain
    # Python and build each array once. JAX arrays are immutable, so filling
    # them row by row with .at[].set would rebuild the whole array per row.
    n_ctx = len(ctx_tokens)
    max_len = n_ctx + max(len(end_tokens) for end_tokens in ending_tokens)
    tok_rows = []
    mask_rows = []
    for end_tokens in ending_tokens:
        padding = [0] * (max_len - n_ctx - len(end_tokens))
        tok_rows.append(ctx_tokens + end_tokens + padding)
        mask_rows.append([0] * n_ctx + [1] * len(end_tokens) + padding)

    tokens = jnp.array(tok_rows, dtype=jnp.int32)
    mask = jnp.array(mask_rows, dtype=jnp.float32)

    return data, tokens, mask, label

In [4]:
# Sanity check of JAX's functional update syntax: .at[...].set(...) returns a new
# array instead of writing in place, so the result has to be rebound to x.
x = jnp.zeros((4,3))
x = x.at[3, :2].set(jnp.array([1,2]))
x

Array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [1., 2., 0.]], dtype=float32)

In [5]:
from transformers import FlaxGPT2LMHeadModel
# Pretrained 'gpt2' checkpoint loaded through the Flax GPT-2 LM-head class.
model = FlaxGPT2LMHeadModel.from_pretrained('gpt2')

# Render the first validation example and run all of its candidate rows in one forward pass.
data, tokens, mask, label = render_example(dataset[0])
logits = model(tokens).logits
logits.shape  # one logit vector per row and per token position

(4, 20, 50257)

In [6]:
# Next-token alignment: dropping the last position of the logits and the first
# token of the inputs leaves both with the same (rows, positions) leading shape.
logits[..., :-1, :].shape
tokens.shape
tokens[..., 1:].shape

(4, 19, 50257)

(4, 20)

(4, 19)

In [7]:
import torch
import numpy as np
# Torch reference for the flattening used in the JAX cells: .view(-1, ...) merges
# the row and position axes into one.
# np.array(...) makes a writable copy. np.asarray on a JAX array yields a read-only
# view, and torch.from_numpy emits a UserWarning for non-writable arrays.
x = torch.from_numpy(np.array(logits[..., :-1, :]))
x.shape
x.view(-1, x.size(-1)).shape    # -1 lets torch infer the merged length (rows * positions)
y = torch.from_numpy(np.array(tokens[..., 1:]))
y.shape
y.view(-1).shape

  x = torch.from_numpy(np.asarray(logits[..., :-1, :]))


torch.Size([4, 19, 50257])

torch.Size([76, 50257])

torch.Size([4, 19])

torch.Size([76])

In [8]:
# JAX counterpart of the torch .view calls: reshape with -1 merges the row and
# position axes, giving one logit vector / one target token per prediction.
shift_logits = logits[..., :-1, :]
shift_tokens = tokens[..., 1:]
shift_logits.reshape([-1, shift_logits.shape[-1]]).shape
shift_tokens.reshape([-1]).shape

(76, 50257)

(76,)

In [9]:
import optax
# Flatten predictions and targets so each entry is one next-token prediction.
flat_shift_logits = shift_logits.reshape([-1, shift_logits.shape[-1]])
flat_shift_tokens = shift_tokens.reshape([-1])
# Peek at the first logit vector together with its integer target token id.
flat_shift_logits[0], jnp.int32(flat_shift_tokens)[0]
# Per-token cross-entropy; the integer-label variant takes token ids directly
# instead of one-hot targets.
shift_losses = optax.softmax_cross_entropy_with_integer_labels(
    flat_shift_logits, 
    jnp.int32(flat_shift_tokens))
shift_losses.shape
shift_losses[0]
# Restore the per-row layout: one row of losses per candidate ending.
shift_losses.reshape(tokens.shape[0], -1).shape

(Array([-33.570633, -32.76895 , -35.450985, ..., -40.98076 , -40.18672 ,
        -33.21528 ], dtype=float32),
 Array(582, dtype=int32))

(76,)

Array(7.302154, dtype=float32)

(4, 19)

In [10]:
# Per-token losses of the first candidate row (context, ending and any padding positions alike).
shift_losses.reshape(tokens.shape[0], -1)[0]

Array([7.3021541e+00, 3.3172197e+00, 6.1763878e+00, 1.3953547e+00,
       8.9701486e-01, 5.5496802e+00, 3.7710347e+00, 7.8738132e+00,
       1.5087028e+00, 5.8585443e+00, 1.1923371e+01, 2.1637678e+00,
       2.7706001e+00, 2.6047549e+00, 5.8060970e+00, 1.1142294e-02,
       7.5942540e+00, 1.6238186e+00, 1.5568147e+00], dtype=float32)

In [11]:
# Keep only the losses on the ending tokens. The mask is shifted by one like the
# targets, so entry j says whether the token predicted at position j is part of
# the ending.
shift_mask = mask[..., 1:] # we must shift mask, so we start at the last prompt token
masked_shift_losses = shift_losses.reshape(tokens.shape[0], -1) * shift_mask
masked_shift_losses.shape
shift_mask.shape
# sum and divide by the number of 1s in the mask
# sum_loss is the total loss of each ending; avg_loss is the same total
# normalized by the ending's token count.
sum_loss = masked_shift_losses.sum(axis=1)
avg_loss = sum_loss / shift_mask.sum(axis=1)
sum_loss
avg_loss

(4, 19)

(4, 19)

Array([43.421867, 34.639187, 23.588688, 36.63145 ], dtype=float32)

Array([3.9474425, 5.7731977, 2.948586 , 4.0701613], dtype=float32)

In [12]:
# The predicted ending is the row with the lowest loss, by total and by per-token average.
sum_loss.argmin().item()
avg_loss.argmin().item()
# The label is displayed as a string ('3'), while the predictions above are ints;
# convert with int(label) before comparing.
label

2

2

'3'

In [16]:
def evaluate(max_examples=52, num_debug_examples=9, model=None):
    """
    Score GPT-2 on the module-level `dataset` by completion likelihood.

    Each example is rendered with `render_example`, run through the model, and
    the candidate ending with the lowest loss over its own tokens is taken as
    the prediction. Two variants are tracked: lowest summed loss ("acc") and
    lowest per-token average loss ("acc_norm").

    Args:
        max_examples: stop after this many examples; None evaluates the whole
            dataset. The default of 52 matches the previous hard-coded cut-off.
        num_debug_examples: number of leading examples for which the context,
            the endings and their average losses are pretty-printed.
        model: a model whose call returns an object with a `.logits` array of
            shape (rows, positions, vocab). When None, the pretrained 'gpt2'
            Flax model is loaded.

    Returns:
        dict with "num_total", "num_correct", "num_correct_norm", "acc" and
        "acc_norm" (both accuracies are 0.0 when no example was evaluated).
    """
    if model is None:
        model = FlaxGPT2LMHeadModel.from_pretrained('gpt2')
    num_correct_norm = 0
    num_correct = 0
    num_total = 0
    for example in dataset:
        if max_examples is not None and num_total >= max_examples:
            break

        data, tokens, mask, label = render_example(example)
        target = int(label)  # the dataset stores the label as a string

        # get the logits
        logits = model(tokens).logits
        # autoregressive loss at all positions: the logits at position t predict
        # the token at t + 1. optax accepts leading batch dimensions, so no
        # flatten / reshape round trip is needed.
        shift_logits = logits[..., :-1, :]
        shift_tokens = tokens[..., 1:]
        shift_losses = optax.softmax_cross_entropy_with_integer_labels(
            shift_logits,
            jnp.int32(shift_tokens),
        )
        # keep only the completion region (mask == 1); the mask is shifted the
        # same way as the targets, so it starts at the last prompt token
        shift_mask = mask[..., 1:]
        masked_shift_losses = shift_losses * shift_mask
        # sum and divide by the number of 1s in the mask
        sum_loss = masked_shift_losses.sum(axis=1)
        avg_loss = sum_loss / shift_mask.sum(axis=1)
        # one loss per candidate completion; the lowest loss is the most likely
        pred = sum_loss.argmin().item()
        pred_norm = avg_loss.argmin().item()

        # accumulate stats
        num_total += 1
        num_correct += int(pred == target)
        num_correct_norm += int(pred_norm == target)
        print(f"{num_total} acc_norm: {num_correct_norm}/{num_total}={num_correct_norm/num_total:.4f}")

        # debug: pretty print a few examples, and the losses in each case
        if num_total <= num_debug_examples:
            print("---")
            print(f"Context:\n {example['ctx']}")
            print(f"Endings:")
            for i, end in enumerate(example["endings"]):
                print(f"{i} (loss: {avg_loss[i].item():.4f}) {end}")
            print(f"predicted: {pred_norm}, actual: {label}")

    return {
        "num_total": num_total,
        "num_correct": num_correct,
        "num_correct_norm": num_correct_norm,
        "acc": num_correct / num_total if num_total else 0.0,
        "acc_norm": num_correct_norm / num_total if num_total else 0.0,
    }

In [17]:
# Run the evaluation loop; it prints the running acc_norm after every example.
evaluate()

1 acc_norm: 0/1=0.0000
---
Context:
 A man is sitting on a roof. he
Endings:
0 (loss: 3.9474) is using wrap to wrap a pair of skis.
1 (loss: 5.7732) is ripping level tiles off.
2 (loss: 2.9486) is holding a rubik's cube.
3 (loss: 4.0702) starts pulling up roofing on a roof.
predicted: 2, actual: 3
2 acc_norm: 0/2=0.0000
---
Context:
 A lady walks to a barbell. She bends down and grabs the pole. the lady
Endings:
0 (loss: 3.6541) swings and lands in her arms.
1 (loss: 2.5001) pulls the barbell forward.
2 (loss: 2.3630) pulls a rope attached to the barbell.
3 (loss: 2.8399) stands and lifts the weight over her head.
predicted: 2, actual: 3
3 acc_norm: 1/3=0.3333
---
Context:
 Two women in a child are shown in a canoe while a man pulls the canoe while standing in the water, with other individuals visible in the background. the child and a different man
Endings:
0 (loss: 2.9822) are then shown paddling down a river in a boat while a woman talks.
1 (loss: 3.3812) are driving the canoe, they

In [15]:
from jax_gpt2 import GPT
# NOTE(review): this cell has execution count 15 but sits after cells 16 and 17,
# so the notebook was not run top to bottom; re-check under Restart & Run All.
# Rebinds the global name `model` to the GPT class from the local jax_gpt2 module.
model = GPT.from_pretrained('gpt2')
data, tokens, mask, label = render_example(dataset[0])
tokens[0][0]  # first token id of the first candidate row
tokens.shape
# The call result is bound to `logits` directly, with no `.logits` attribute access
# as with the Flax model above. Presumably it is the logits array; confirm in jax_gpt2.
logits = model(tokens)

loading weights from pretrained gpt: gpt2
Length of pytorch state dict: 149
Length of prepared JAX modules dict: 76
Total JAX matrices: 149
Transposing  lm_head


Array(32, dtype=int32)

(4, 20)