In [None]:
import os, pickle, torch
from model_exercise5_solution import GPT

In [None]:
# Notebook-wide configuration: input locations and the inference device.
# Both directory strings keep their trailing slash because later cells build
# file paths by plain string concatenation (DATA_DIR + "meta.pkl", etc.).
DATA_DIR = "data/"
MODEL_DIR = "best_models/"
CHECKPOINT = "instruction_tuning.pt"

# Use CUDA when it is available, otherwise fall back to the CPU.
# NOTE(review): this cell previously assigned device = "mps" just before this
# line, but that value was always overwritten; assign it explicitly after this
# statement if Apple-GPU inference is actually wanted.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device =", device)

In [None]:
# Load the vocabulary tables (stoi / itos) stored in meta.pkl.
# NOTE: pickle.load can execute arbitrary code -- only open a meta.pkl that
# comes from a trusted source.
meta_path = DATA_DIR + "meta.pkl"
with open(meta_path, "rb") as f:
    meta = pickle.load(f)
stoi = meta["stoi"]
itos = meta["itos"]


def encode(s):
    """Return the list of ids obtained by looking up each item of ``s`` in ``stoi``."""
    return [stoi[c] for c in s]


def decode(l):
    """Return the string formed by joining ``itos[i]`` for every id ``i`` in ``l``."""
    return "".join([itos[i] for i in l])

In [None]:
# Read the checkpoint onto the selected device and report what it contains.
checkpoint_path = MODEL_DIR + CHECKPOINT
checkpoint = torch.load(checkpoint_path, map_location=device)
print("best val loss:", checkpoint["best_val_loss"].item())
config = checkpoint["config"]
print(config)

# Build the model from the stored configuration.
model = GPT(config)

# Strip the "_orig_mod." prefix from any parameter names that carry it so the
# keys match the freshly built model.
# NOTE(review): presumably this prefix comes from saving a torch.compile()d
# model -- confirm against the training script.
state_dict = checkpoint["model"]
unwanted_prefix = "_orig_mod."
prefixed_keys = [key for key in state_dict if key.startswith(unwanted_prefix)]
for old_key in prefixed_keys:
    new_key = old_key[len(unwanted_prefix) :]
    state_dict[new_key] = state_dict.pop(old_key)

# Load the weights, switch to evaluation mode, and move to the target device.
model.load_state_dict(state_dict)
model.eval()
model = model.to(device)

In [None]:
# Bare expression as the cell's last line: the notebook displays repr(model).
model

In [None]:
# Print the name of every model parameter whose requires_grad flag is set.
print("--- learnable parameters ---")
learnable_names = [
    name for name, param in model.named_parameters() if param.requires_grad
]
for name in learnable_names:
    print(name)

In [None]:
# Sampling settings.
# NOTE(review): top_k = 1 presumably restricts sampling to the single most
# likely token, which would make all samples identical -- confirm against
# GPT.generate.
start = "G{julina="
num_samples = 5
max_new_tokens = 25
temperature = 1.0
top_k = 1

# Encode the start string and add a leading dimension of size 1.
start_ids = encode(start)
x = torch.tensor(start_ids, dtype=torch.long, device=device).unsqueeze(0)

# When the config declares a positive "prompt_vocab_size", pass the ids
# 0 .. prompt_vocab_size - 1 (with a leading dimension of size 1) as the
# prompt; otherwise pass None.
prompt_vocab_size = config.get("prompt_vocab_size", 0)
prompt = (
    torch.arange(prompt_vocab_size, dtype=torch.long, device=device).unsqueeze(0)
    if prompt_vocab_size > 0
    else None
)

# Keyword arguments shared by every generate() call; "}" is the end token.
generate_kwargs = dict(
    temperature=temperature,
    top_k=top_k,
    end_token=stoi["}"],
    prompt=prompt,
)

# Generate and print num_samples completions with gradient tracking disabled.
with torch.no_grad():
    for sample_idx in range(num_samples):
        y = model.generate(x, max_new_tokens, **generate_kwargs)
        print(decode(y[0].tolist()))

In [None]:
# Demonstrate how dividing logits by a temperature reshapes the softmax
# distribution: for each temperature, print the resulting probabilities
# rounded to two decimals.
a = torch.tensor([0.1, 0.1, 0.8, 0.1, 0.1])

# The loop variable is deliberately NOT called `temperature`, so the sampling
# setting of that name defined in the generation cell is left untouched.
for demo_temp in (0.2, 0.5, 1.0, 1.5, 5.0):
    probs = torch.nn.functional.softmax(a / demo_temp, dim=0)
    print(
        f"temperature = {demo_temp} ==>",
        [round(p, 2) for p in probs.tolist()],
    )