In [None]:
import os, pickle, torch
from contextlib import nullcontext
from model import GPT

In [None]:
# Paths, device selection and the choice of which model to sample from.
DATA_DIR = "data/"
MODEL_DIR = "best_models/"
CHECKPOINT = "gpt.pt"  # file inside MODEL_DIR; loaded only when sample_from_base is None
# Use CUDA when available, otherwise CPU. (An Apple-silicon MPS branch could be
# added via torch.backends.mps.is_available(); it is not enabled here.)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device =", device)
# Name of a pretrained base model to sample from (e.g. "gpt2-large"), or None
# to load the fine-tuned checkpoint MODEL_DIR/CHECKPOINT instead.
sample_from_base = "gpt2-large"

In [None]:
# Only CUDA runs get torch.compile, TF32 matmuls and an autocast context;
# every other device runs in the default precision with a no-op context.
# NOTE: `compile` shadows the Python builtin of the same name; later cells read it.
compile = device == "cuda"
if compile:
    # Let matmul and cuDNN kernels use TF32 where the GPU supports it.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # Prefer bfloat16 autocast; fall back to float16 on GPUs without bf16 support.
    amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    ctx = torch.amp.autocast(device_type=device, dtype=amp_dtype)
else:
    ctx = nullcontext()

In [None]:
# Build `model` and `config`: either from the fine-tuned checkpoint in MODEL_DIR
# (sample_from_base is None) or from the pretrained weights named by sample_from_base.
BASE_BLOCK_SIZE = 128  # context length the pretrained base model is cropped to
if sample_from_base is None:
    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted checkpoints.
    checkpoint = torch.load(os.path.join(MODEL_DIR, CHECKPOINT), map_location=device)
    # float() accepts a 0-dim tensor as well as a plain Python number (.item() needs a tensor).
    print("best val loss:", float(checkpoint["best_val_loss"]))
    config = checkpoint["config"]
    model = GPT(config)
    state_dict = checkpoint["model"]
    # Weights saved from a torch.compile'd module carry an "_orig_mod." key prefix
    # that the plain GPT module does not expect; strip it before loading.
    unwanted_prefix = "_orig_mod."
    for k in list(state_dict):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
else:
    config = dict(dropout=0.0)
    model = GPT.from_pretrained(sample_from_base, config)
    model.crop_block_size(BASE_BLOCK_SIZE)
    # NOTE(review): later cells call config.get(...) / config[...], so this assumes
    # GPT keeps its config as a dict -- confirm against model.py.
    config = model.config
model.eval()
model = model.to(device)
if compile:
    print("compiling the model... (takes a ~minute)")
    model = torch.compile(model)

In [None]:
# Display the model's module hierarchy as this cell's output.
model

In [None]:
# List every parameter with requires_grad=True, i.e. the weights that are not frozen.
print("--- learnable parameters ---")
trainable_names = (name for name, param in model.named_parameters() if param.requires_grad)
for name in trainable_names:
    print(name)

In [None]:
import tiktoken

# Extend the stock GPT-2 BPE encoding with the task-specific special tokens used
# to frame concept prompts. Their ids follow GPT-2's end-of-text id (50256).
gpt2 = tiktoken.get_encoding("gpt2")

end_text_token = 50256
start_input_token, end_input_token, concept_delimiter_token, pad_token = range(50257, 50261)

extra_special_tokens = {
    "<|start_of_input|>": start_input_token,
    "<|end_of_input|>": end_input_token,
    "<|concept_delimiter|>": concept_delimiter_token,
    "<|padding|>": pad_token,
}
enc = tiktoken.Encoding(
    name="gpt_modified",
    pat_str=gpt2._pat_str,
    mergeable_ranks=gpt2._mergeable_ranks,
    special_tokens={**gpt2._special_tokens, **extra_special_tokens},
)

In [None]:
# Sampling settings. The commented-out prompt shows the special-token framing
# (<|start_of_input|> ... <|concept_delimiter|> ... <|end_of_input|>) that `enc` accepts.
# start = "<|start_of_input|>mirzapur<|concept_delimiter|>traffic<|concept_delimiter|>late<|end_of_input|>"
start = "a sentence using word morning and car is"
num_samples = 5
max_new_tokens = 50
temperature = 1.0
top_k = 25
seed = None  # set to an int (e.g. 1337) to make the samples reproducible

if seed is not None:
    torch.manual_seed(seed)

# Encode the prompt as a batch of one sequence; the custom framing tokens may
# appear literally in `start`.
x = torch.tensor(
    enc.encode(
        start,
        allowed_special={
            "<|start_of_input|>",
            "<|end_of_input|>",
            "<|concept_delimiter|>",
        },
    ),
    dtype=torch.long,
    device=device,
)[None, ...]

# When the config records a prompt_vocab_size, pass ids 0..prompt_vocab_size-1
# (as a batch of one) to generate() -- presumably learned prompt ids; confirm in model.py.
# NOTE(review): assumes `config` is dict-like. `or 0` also covers an explicit None.
prompt_vocab_size = config.get("prompt_vocab_size") or 0
if prompt_vocab_size > 0:
    prompt = torch.arange(prompt_vocab_size, dtype=torch.long, device=device)[None, ...]
else:
    prompt = None

with torch.no_grad():
    for sample_idx in range(num_samples):
        with ctx:
            y = model.generate(
                x,
                max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                end_token=end_text_token,
                prompt=prompt,
            )
        output = enc.decode(y[0].tolist())
        # Drop everything up to and including the FIRST occurrence of the prompt.
        # partition keeps the continuation intact even if the model repeats the
        # prompt, and falls back to the full text if the prompt is not found.
        prompt_prefix, found, continuation = output.partition(start)
        output = continuation if found else output
        print("-----", output + "\n")