In [4]:
import torch
import torch.nn.functional as F
import numpy as np

from datasets import load_dataset
from gpt2 import GPT, GPTConfig
from hellaswag import render_example, iterate_examples
import tiktoken

In [5]:
# Pick the device before anything is loaded; fall back to CPU so the notebook
# still runs on a machine without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the fine-tuned checkpoint onto the CPU first: load_state_dict copies the
# tensors into the model anyway, and model.to(device) moves them afterwards, so
# mapping to CPU avoids holding an extra GPU copy of the checkpoint and works
# even if the GPU it was saved from is not present.
# weights_only=False is spelled out because newer torch versions change the
# default; the checkpoint may contain non-tensor objects (TODO confirm). This
# unpickles arbitrary objects, so only load checkpoints you produced/trust.
checkpoint = torch.load("log/model_00079.pt", map_location="cpu", weights_only=False)
weights = checkpoint['model']

# Init the model and restore the fine-tuned weights
model = GPT(GPTConfig())
model.load_state_dict(weights)

# Move the model to the selected device (last expression: the cell shows the repr)
model.to(device)

  checkpoint = torch.load("log/model_00079.pt")


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (h): ModuleList(
      (0-11): 12 x Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=768, out_features=2304, bias=True)
          (c_proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (gelu): GELU(approximate='tanh')
          (c_proj): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [6]:
# Seed torch's global RNGs (CPU, plus CUDA when a GPU is present) so runs are
# reproducible. streaming_generate samples from its own seeded torch.Generator,
# so this mainly covers any other randomness in the notebook.
SEED = 1337
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

In [41]:
## Generator Function to generate from the model
def streaming_generate(model, prompt, max_length=500, top_k=50, seed=42,
                       block_size=1024, enc=None, device=None):
    """Stream a top-k sampled continuation of ``prompt`` piece by piece.

    Args:
        model: autoregressive LM whose ``forward(idx)`` returns ``(logits, loss)``
            with logits indexed as (B, T, vocab_size). It is switched to eval
            mode as a side effect.
        prompt: text prompt, encoded with ``enc``.
        max_length: stop once prompt + generated tokens reach this many tokens.
            If the prompt is already that long, nothing is generated.
        top_k: sample only among the ``top_k`` most probable tokens (capped at
            the vocabulary size).
        seed: seed of the private sampling RNG; the same prompt and seed give the
            same text on repeated calls.
        block_size: maximum context fed to the model; older tokens are cropped.
        enc: tokenizer; defaults to ``tiktoken.get_encoding("gpt2")``. It must
            provide ``encode``, ``decode_single_token_bytes`` and ``eot_token``.
        device: device to run on; defaults to the device of the model's first
            parameter (CPU if the model has none).

    Yields:
        Non-empty text pieces. Bytes of a UTF-8 character split across tokens
        are buffered until the character is complete. Generation stops at the
        end-of-text token, which is not yielded.
    """
    import codecs

    if enc is None:
        enc = tiktoken.get_encoding("gpt2")
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()

    xgen = torch.tensor([enc.encode(prompt)], dtype=torch.long, device=device)
    sample_rng = torch.Generator(device=device)
    sample_rng.manual_seed(seed)
    # A single token can carry only part of a multi-byte UTF-8 character, so
    # decoding tokens independently would emit U+FFFD; decode incrementally.
    utf8 = codecs.getincrementaldecoder("utf-8")(errors="replace")

    while xgen.size(1) < max_length:
        # Keep `yield` outside this block: a generator suspended inside the
        # `with` would leave autograd disabled in the caller until resumed.
        with torch.no_grad():
            logits, _ = model(xgen[:, -block_size:])  # (B, T, vocab_size)
            logits = logits[:, -1, :]  # (B, vocab_size): last position only
            probs = F.softmax(logits, dim=-1)
            k = min(top_k, probs.size(-1))
            topk_probs, topk_indices = torch.topk(probs, k, dim=-1)  # (B, k) each
            # note: multinomial does not demand the input to sum to 1
            ix = torch.multinomial(topk_probs, 1, generator=sample_rng)  # (B, 1)
            xcol = torch.gather(topk_indices, -1, ix)  # (B, 1)

        token_id = xcol[0, 0].item()
        if token_id == enc.eot_token:
            break
        piece = utf8.decode(enc.decode_single_token_bytes(token_id))
        if piece:
            yield piece
        # append to the sequence
        xgen = torch.cat((xgen, xcol), dim=1)

    # Flush a dangling partial character, if any (rendered as U+FFFD).
    tail = utf8.decode(b"", final=True)
    if tail:
        yield tail

In [23]:
def preprocess_function(example):
    """
    Render one instruction example as a single prompt string (not a list).

    `example` must provide the keys 'instruction', 'input' and 'output'. Each
    value is interpolated verbatim into the "### Instruction / ### Input /
    ### Response" template. All three sections are always emitted, even when
    a field is empty. With empty 'input' and 'output' the text ends right after
    "### Response:\n", so a model can continue from there.

    NOTE(review): presumably this is the same template used for fine-tuning;
    keep the two in sync.

    Returns:
        str: the formatted text.
    """
    text = f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:\n{example['output']}"
    return text

In [24]:
def generate(model, query):
    """Print the model's streamed response to `query`; returns None.

    The query is placed in the 'instruction' field with empty 'input' and
    'output', so the prompt ends at the response header and the model writes
    the answer. Pieces are flushed as they arrive so the output really streams,
    even when stdout is buffered. A final newline ends the answer.
    """
    prompt = preprocess_function({"instruction": query, "input": "", "output": ""})
    for piece in streaming_generate(model, prompt):
        print(piece, end='', flush=True)
    print()

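## Base Model Generation

For comparison, the stock pretrained GPT-2 weights (no instruction tuning) are prompted with the same instruction template.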

In [39]:
# Baseline: pretrained GPT-2 weights ("gpt2"), placed on the same device as the
# fine-tuned model. nn.Module.to returns the module itself, so chaining is safe.
base_model = GPT.from_pretrained("gpt2").to(device)
base_model

loading weights from pretrained gpt: gpt2


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (h): ModuleList(
      (0-11): 12 x Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=768, out_features=2304, bias=True)
          (c_proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (gelu): GELU(approximate='tanh')
          (c_proj): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [42]:
# Same instruction-formatted prompt, answered by the untuned base model
generate(model=base_model, query="What is 2 + 2?")


What is 2+2?

### Input:

What is 2+2?

### Input:

What is 3 + 3?

### Input:

What is 3+3?

### Input:

What is 4 + 4?

### Input:

What is 4+4?

### Input:

What is 5 + 5?

### Input:

What is 5+5?

### Input:

What is 6 + 6?

### Input:

What is 6+6?

### Input:

What is 7 + 7?

### Input:

What is 8 + 8?

### Input:

What is 8+8?

### Input:

What is 9 + 9?

### Input:

What is 9+9?

### Input:

What is 10 + 10?

### Input:

What is 10+10?

### Input:

What is 11 + 11?

### Input:

What is 11+11?

### Input:

What is 12 + 12?

### Input:

What is 11+12?

### Input:

What is 12+12?

### Input:

What is 13 + 13?

### Input:

What is 13+13?

### Input:

What is 14 + 14?

### Input:

What is 14+14?

### Input:

What is 15 + 15?

### Input:

What is 15+15?

### Input:

What is 16 + 16?

### Input:

What is 16+16?

### Input:

What is 17 + 17?

### Input:

What is 17+17?

### Input:

What is 18 + 18?

### Input:

What is 18+18?

### Input:

What is 19 + 19?

### Input:

What is 19+19?


## Fine-Tuned Model Generation

The same kind of prompts, answered by the fine-tuned checkpoint loaded at the top of the notebook.

In [43]:
# Arithmetic prompt for the fine-tuned model (compare with the base model above)
generate(model=model, query="What is 2 + 2?")

2 + 2 is equivalent to 4.


In [47]:
# Simple factual-recall prompt
generate(model=model, query="Where is Eiffel Tower?")

The Eiffel Tower is located in Paris, France.


In [48]:
# Time-sensitive question: the answer reflects whatever the model learned in
# training, not current events
generate(model=model, query="Who is the president of U.S?")

The president of the United States is Barack Obama.


In [49]:
# NOTE(review): the sampled answer is not factually reliable (the recorded
# output answers a different question with invented details)
generate(model=model, query="What's the tallest building in the world?")

The tallest building in China is the 9,888-metre high Minglong Tower located in Jiechi Province in the Shandong province of China.
