In [1]:
import torch
import torch.nn.functional as F
import numpy as np

from datasets import load_dataset
from model import Transformer, MiniLlamaArgs
from hellaswag import render_example, iterate_examples
from tokenizer import Tokenizer

In [2]:
# Checkpoint to evaluate. (An earlier run's checkpoint lived at
# "log/model_19072.pt"; swap the path here to evaluate that one instead.)
CHECKPOINT_PATH = "fineweb_pretrain/model_19072.pt"

# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# end-to-end on a machine without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Deserialize onto the CPU regardless of where the tensors were saved from:
# this works on CPU-only machines and avoids staging a second copy of the
# weights on the GPU before they are copied into the model.
# NOTE(review): torch.load unpickles arbitrary Python objects, so only load
# checkpoints you trust. If this checkpoint holds nothing but tensors and
# plain containers, passing weights_only=True would be safer -- confirm
# against the training script before enabling it.
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
weights = checkpoint['model']

# Build the architecture with its default arguments and restore the weights.
model = Transformer(MiniLlamaArgs())
model.load_state_dict(weights)

# This notebook only runs inference, so put the model in eval mode up front
# instead of relying on a later cell to do it.
model.eval()

# Move the model to the target device (last expression: displays the module).
model.to(device)

  checkpoint = torch.load("fineweb_pretrain/model_19072.pt")


Transformer(
  (token_embeddings): Embedding(32000, 768)
  (dropout): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0-11): 12 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=768, out_features=768, bias=False)
        (wk): Linear(in_features=768, out_features=768, bias=False)
        (wv): Linear(in_features=768, out_features=768, bias=False)
        (wo): Linear(in_features=768, out_features=768, bias=False)
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_dropout): Dropout(p=0.0, inplace=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=768, out_features=2048, bias=False)
        (w2): Linear(in_features=2048, out_features=768, bias=False)
        (w3): Linear(in_features=768, out_features=2048, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=

In [3]:
# Pin PyTorch's global random state (and the CUDA one when a GPU is available)
# so that anything drawing from the default generators is repeatable.
SEED = 1337
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

In [4]:
## Generator Function to generate from the model
def generate(model, prompt, num_return_sequences=4, max_length=100, top_k=50,
             seed=42, device=None):
    """Sample continuations of ``prompt`` from ``model`` and print them.

    Sampling is top-k: at each step only the ``top_k`` most probable next
    tokens are kept and one of them is drawn in proportion to its probability.
    The draws come from a dedicated ``torch.Generator`` seeded with ``seed``,
    so the samples do not depend on the global RNG state.

    Args:
        model: Module for which ``model(tokens)`` returns a ``(logits, loss)``
            pair, with logits of shape (B, T, vocab_size). It is switched to
            eval mode and left that way.
        prompt: Text to continue.
        num_return_sequences: How many independent samples to draw.
        max_length: Total length in tokens (prompt included) at which
            generation stops. A prompt that is already this long is printed
            truncated to ``max_length`` tokens.
        top_k: Number of candidate tokens kept per step; clamped to the
            vocabulary size so small vocabularies do not raise.
        seed: Seed for the sampling generator.
        device: Device to run on. Defaults to the device holding the model's
            parameters (CPU if the model has none).

    Returns:
        None. Each sample is printed as ``Sample <i>: <text>`` followed by a
        blank line.

    Raises:
        ValueError: If ``num_return_sequences`` or ``top_k`` is less than 1.
    """
    if num_return_sequences < 1:
        raise ValueError("num_return_sequences must be >= 1")
    if top_k < 1:
        raise ValueError("top_k must be >= 1")

    if device is None:
        # Follow the model rather than a notebook-level global, so the
        # function works wherever the model happens to live.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    enc = Tokenizer()
    model.eval()

    # NOTE(review): the two positional flags are presumably (bos=True,
    # eos=False) -- confirm against Tokenizer.encode.
    prompt_ids = enc.encode(prompt, True, False)
    prompt_ids = torch.tensor(prompt_ids, dtype=torch.long)
    # One identical copy of the prompt per requested sample: (B, T_prompt)
    xgen = prompt_ids.unsqueeze(0).repeat(num_return_sequences, 1).to(device)

    sample_rng = torch.Generator(device=device)
    sample_rng.manual_seed(seed)

    with torch.no_grad():
        while xgen.size(1) < max_length:
            # forward the model to get the logits
            logits, _ = model(xgen)  # (B, T, vocab_size)
            # only the last position predicts the next token
            logits = logits[:, -1, :]  # (B, vocab_size)
            probs = F.softmax(logits, dim=-1)
            # keep the k most probable tokens: both results are (B, k)
            k = min(top_k, probs.size(-1))
            topk_probs, topk_indices = torch.topk(probs, k, dim=-1)
            # note: multinomial does not demand the input to sum to 1
            ix = torch.multinomial(topk_probs, 1, generator=sample_rng)  # (B, 1)
            # map the sampled position within the top-k back to a vocabulary id
            xcol = torch.gather(topk_indices, -1, ix)  # (B, 1)
            # append to the sequence
            xgen = torch.cat((xgen, xcol), dim=1)

    # print the generated text
    for i in range(num_return_sequences):
        tokens = xgen[i, :max_length].tolist()
        decoded = enc.decode(tokens)
        print(f"Sample {i}: {decoded}")
        print()

In [5]:
# Qualitative check #1: sample continuations of a short first-person prompt
# from the checkpoint loaded above. generate() prints each decoded sample.
generate(model, "Hello I'm an Llama")

Sample 0: Hello I'm an Llama and I live in a remote place (at your expense). Can you give me a summary of my current life as a Llama? I want to know, what is the importance of a living person on this planet and whether he or she is a living person?
Do you know where your mother is or where she is born? Your mother is probably dead and therefore you have two unspoken wishes to keep her alive. How do you

Sample 1: Hello I'm an Llama! The course is designed to teach parents how to interact with children at home in the best way. Through a series of simple, practical activities families can all make together to build confidence.
Your child will develop important skills in their everyday language by having a 'home' at school. The home is where you and your child will be at different times of the week and you and your child will be in different situations. Your child's English,

Sample 2: Hello I'm an Llama, and like so many who teach kids. I'm here to help you! I'm a teacher and I'm not a t

In [6]:
# Qualitative check #2: same checkpoint, different prompt (an open-ended
# health statement), to see how the continuations vary with the topic.
generate(model, "To stay healthy, I have to")

Sample 0: To stay healthy, I have to eat a whole foods diet – like chicken, eggs, grains, fruits and nuts. This includes all of the “healthy” foods, soy, dairy products, and fruits and vegetables, as well as refined grains, meats, milk, and dairy products. These foods are good. You can keep them part of a healthy diet by eating a whole food

Sample 1: To stay healthy, I have to work all day. But I also require a little push back from people who have been hurt.
How many times have you tried to help an elder baby?
I have done so many times in my life.
How many times have you thought you were cared for?
An older person’s decision to stay home from work can also result in a high-risk situation called drowning. This is a serious medical emergency.

Sample 2: To stay healthy, I have to be healthy!
In this blog post, I will look at a few key health benefits of eating healthily.
1. Eating Healthy
Healthy foods are great for your whole body and helps fight off disease. One easy way to get healt

## Eval: HellaSwag

In [7]:
def get_most_likely_row(tokens, mask, logits):
    """Return the index of the row whose masked region the model finds most likely.

    Each row is scored by its mean next-token cross-entropy over the positions
    selected by ``mask``; the row with the smallest mean wins.

    Args:
        tokens: Token ids, one candidate sequence per row (assumed (B, T)).
        mask: Same layout as ``tokens``; 1 marks completion tokens that count
            toward the score, 0 marks tokens to ignore.
        logits: Model outputs aligned with ``tokens`` (assumed (B, T, vocab)).

    Returns:
        int: index of the lowest-loss row.
    """
    num_rows = tokens.size(0)
    vocab_size = logits.size(-1)

    # Position t predicts token t + 1, so drop the final logit and the first
    # token to line predictions up with their targets.
    predictions = logits[..., :-1, :].reshape(-1, vocab_size)
    targets = tokens[..., 1:].reshape(-1)
    per_token_loss = F.cross_entropy(predictions, targets, reduction='none')
    per_token_loss = per_token_loss.view(num_rows, -1)

    # The mask is shifted the same way as the targets, so scoring starts at
    # the prediction made from the last prompt token.
    completion_mask = mask[..., 1:]
    completion_loss = (per_token_loss * completion_mask).sum(dim=1)
    completion_len = completion_mask.sum(dim=1)

    # Length-normalised loss per row; the smallest one is the prediction.
    mean_loss = completion_loss / completion_len
    return mean_loss.argmin().item()

In [8]:
# Score the HellaSwag validation split: for every example, run all candidate
# endings through the model in one batch and count how often the ending with
# the lowest length-normalised loss matches the labelled one.

# generate() in an earlier cell happens to leave the model in eval mode; set it
# explicitly so this cell gives the same result when that cell is skipped.
model.eval()

num_correct_norm = 0
num_total = 0
for i, example in enumerate(iterate_examples("val")):
    # lightweight progress indicator
    if i % 500 == 0:
        print(f"Example: {i}")
    # render the example into tokens and labels (first returned item is unused)
    _, tokens, mask, label = render_example(example)
    tokens = tokens.to(device)
    mask = mask.to(device)
    # forward pass only -- no gradients are needed for scoring
    with torch.no_grad():
        logits, _ = model(tokens)
        pred_norm = get_most_likely_row(tokens, mask, logits)
    num_total += 1
    num_correct_norm += int(pred_norm == label)

# Report NaN instead of raising ZeroDivisionError when the split yields nothing.
acc_norm = num_correct_norm / num_total if num_total else float("nan")
print(f"HellaSwag accuracy: {num_correct_norm}/{num_total}={acc_norm:.4f}")

Example: 0
Example: 500
Example: 1000
Example: 1500
Example: 2000
Example: 2500
Example: 3000
Example: 3500
Example: 4000
Example: 4500
Example: 5000
Example: 5500
Example: 6000
Example: 6500
Example: 7000
Example: 7500
Example: 8000
Example: 8500
Example: 9000
Example: 9500
Example: 10000
HellaSwag accuracy: 3047/10042=0.3034
