Loading libraries and setting up the environment

In [6]:
import sys
import os

# Absolute path of the parent of the current working directory; appended to
# sys.path so that packages located under it can be imported by later cells.
root_path = os.path.abspath(os.path.join(os.getcwd(), '..'))
# Guard the append so that re-running this cell does not accumulate
# duplicate entries in sys.path.
if root_path not in sys.path:
    sys.path.append(root_path)

In [7]:
import torch
import pickle
from contextlib import nullcontext
from baselines.nanogpt.model import GPT, GPTConfig

Configuring the model for sampling

In [8]:
# --- Sampling configuration: the values a reader is expected to tune ---

# Checkpoint and vocabulary metadata, located relative to root_path.
MODEL_PATH = f"{root_path}/baselines/nanogpt/shakespeare-char/baseline_model.pt"
META_PATH = f"{root_path}/baselines/nanogpt/shakespeare-char/meta.pkl"

# Text the model is asked to continue.
START_PROMPT = "to be, or not to be -that is the question:\n"

NUM_SAMPLES = 1        # number of samples to generate
MAX_NEW_TOKENS = 500   # forwarded to model.generate as max_new_tokens
TEMPERATURE = 0.8      # forwarded to model.generate as temperature
TOP_K = 200            # forwarded to model.generate as top_k
SEED = 42              # seed handed to the torch random generators

# Prefer CUDA when it is available. bfloat16 is chosen only when the GPU
# reports support for it; every other case (including CPU) gets 'float16'.
if torch.cuda.is_available():
    DEVICE = 'cuda'
    DTYPE = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float16'
else:
    DEVICE = 'cpu'
    DTYPE = 'float16'

COMPILE = False        # when True, the model is wrapped with torch.compile

In [9]:
# Seed the torch CPU and CUDA random generators with SEED.
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

# Allow TF32 for CUDA matmuls and for cuDNN operations.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Substring match (rather than equality) so that indexed device strings such
# as 'cuda:0' are still classified as CUDA and get the autocast context.
device_type = 'cuda' if 'cuda' in DEVICE else 'cpu'
# Translate the DTYPE name into the corresponding torch dtype object.
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[DTYPE]
# On CPU use a do-nothing context; otherwise autocast to the selected dtype.
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

Loading the model and tokenizer

In [10]:
# --- Load model ---
# The checkpoint is deserialized onto the CPU; the model is moved to DEVICE
# only after its weights have been loaded.
checkpoint = torch.load(MODEL_PATH, map_location='cpu')

# Rebuild the architecture from the arguments stored in the checkpoint.
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)

# Strip the '_orig_mod.' prefix from any parameter names that carry it
# (presumably left by saving a torch.compile-wrapped model -- TODO confirm).
# The dict is edited in place, so the names to rename are collected first.
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
prefixed_names = [name for name in state_dict if name.startswith(unwanted_prefix)]
for name in prefixed_names:
    state_dict[name[len(unwanted_prefix):]] = state_dict.pop(name)

model.load_state_dict(state_dict)
model.eval()
model.to(DEVICE)

# Optional compilation, controlled by the COMPILE flag.
if COMPILE:
    model = torch.compile(model)

# --- Load tokenizer ---
# NOTE(security): pickle.load can execute arbitrary code while deserializing;
# META_PATH must only ever point at a trusted file.
with open(META_PATH, 'rb') as meta_file:
    meta = pickle.load(meta_file)

# Lookup tables taken from the metadata: 'stoi' is indexed by character and
# 'itos' by token id (see encode/decode below).
stoi = meta['stoi']
itos = meta['itos']


def encode(s):
    """Return the list of `stoi` entries for each character of `s`, in order.

    A character that is missing from `stoi` makes the lookup raise.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Return the string formed by concatenating `itos[i]` for each `i` in `l`."""
    return ''.join(itos[i] for i in l)


number of parameters: 10.65M


Example: Shakespeare text generation

In [11]:
# Encoding prompt
# A prompt of the form "FILE:<path>" is replaced by the UTF-8 text read from
# <path> (the 5-character "FILE:" marker is sliced off to obtain the path).
if START_PROMPT.startswith("FILE:"):
    with open(START_PROMPT[5:], 'r', encoding='utf-8') as f:
        START_PROMPT = f.read()

start_ids = encode(START_PROMPT)
# [None, ...] prepends a dimension of size 1 in front of the token ids.
x = torch.tensor(start_ids, dtype=torch.long, device=DEVICE)[None, ...]

# Generating text samples
# Generation runs without gradient tracking and inside `ctx`, so that the
# autocast precision selected through DTYPE is actually applied; previously
# `ctx` was constructed but never entered. On CPU `ctx` is a no-op context.
with torch.no_grad(), ctx:
    for i in range(NUM_SAMPLES):
        y = model.generate(x, max_new_tokens=MAX_NEW_TOKENS, temperature=TEMPERATURE, top_k=TOP_K)
        print(f"--- Sample {i+1} ---")
        # y[0] is the first row of the generated ids; decode it back to text.
        print(decode(y[0].tolist()))
        print()

--- Sample 1 ---
to be, or not to be -that is the question:
And I do you at not say it is pomper to
betreman.

MENENIUS:
He was it in the change sword.

First Citizen:
There is a duke enemy comfort in the purpose.

CORIOLANUS:
I come, he hath readed with the custy gentle have
for the worst to the moreign of the poor cours'
'Said and far to she percupe of a fire to the prominius?
We have fury of the cattery writ of the princes of the incarce,
And more the feasts of him well but to me words.

TYRREL:
I think your honour father.

BENVOLIO:
I thus lean very 

