In [1]:
import os, pickle, torch
from model import GPT

In [2]:
DATA_DIR = "data/"
MODEL_DIR = "best_models/"
CHECKPOINT = "instruction_tuning.pt"

# Prefer the Apple-silicon GPU (the original target), then CUDA, then CPU, so the
# notebook still runs on machines without MPS instead of failing at torch.load/.to().
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
# Load the vocabulary mappings (stoi: char -> id, itos: id -> char) from DATA_DIR.
# NOTE(review): pickle.load can execute arbitrary code from the file — only load
# meta.pkl files you produced yourself.
with open(os.path.join(DATA_DIR, "meta.pkl"), "rb") as f:
    meta = pickle.load(f)
stoi, itos = meta["stoi"], meta["itos"]


def encode(s):
    """Map a string to a list of token ids, one per character.

    Raises KeyError if `s` contains a character missing from `stoi`.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Map a sequence of token ids back to the string they encode."""
    return "".join(itos[i] for i in l)

In [4]:
# Restore the trained GPT from its checkpoint and move it to `device` in eval mode.
checkpoint = torch.load(os.path.join(MODEL_DIR, CHECKPOINT), map_location=device)
# float() accepts either a 0-d tensor or a plain Python number, unlike .item().
print("best val loss:", float(checkpoint["best_val_loss"]))
config = checkpoint["config"]
print(config)
model = GPT(config)
state_dict = checkpoint["model"]
# Some saved keys carry a "_orig_mod." prefix (presumably from saving a
# torch.compile-wrapped model); strip it so keys match the plain GPT module.
unwanted_prefix = "_orig_mod."
for key in list(state_dict):  # snapshot the keys: the dict is mutated in the loop
    if key.startswith(unwanted_prefix):
        state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
model.load_state_dict(state_dict)
model.eval()  # inference mode: turns off the Dropout layers (dropout=0.2 in config)
model = model.to(device)

best val loss: 1.3631967306137085
{'n_layer': 6, 'n_head': 6, 'n_embd': 150, 'block_size': 25, 'bias': False, 'vocab_size': 38, 'dropout': 0.2, 'pad_token': 26}
total number of parameters: 1627650 learnable: 1627650


In [5]:
# Show the model's module hierarchy (last expression -> rendered as the cell output).
model

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(38, 150, padding_idx=26)
    (wpe): Embedding(25, 150)
    (drop): Dropout(p=0.2, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=150, out_features=450, bias=False)
          (c_proj): Linear(in_features=150, out_features=150, bias=False)
          (attn_dropout): Dropout(p=0.2, inplace=False)
          (resid_dropout): Dropout(p=0.2, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=150, out_features=600, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=600, out_features=150, bias=False)
          (dropout): Dropout(p=0.2, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=150, out_features=38, bias=False)
)

In [6]:
# List the names of every parameter that will receive gradients.
print("--- learnable parameters ---")
learnable_names = [name for name, param in model.named_parameters() if param.requires_grad]
for name in learnable_names:
    print(name)

--- learnable parameters ---
transformer.wte.weight
transformer.wpe.weight
transformer.h.0.ln_1.weight
transformer.h.0.attn.c_attn.weight
transformer.h.0.attn.c_proj.weight
transformer.h.0.ln_2.weight
transformer.h.0.mlp.c_fc.weight
transformer.h.0.mlp.c_proj.weight
transformer.h.1.ln_1.weight
transformer.h.1.attn.c_attn.weight
transformer.h.1.attn.c_proj.weight
transformer.h.1.ln_2.weight
transformer.h.1.mlp.c_fc.weight
transformer.h.1.mlp.c_proj.weight
transformer.h.2.ln_1.weight
transformer.h.2.attn.c_attn.weight
transformer.h.2.attn.c_proj.weight
transformer.h.2.ln_2.weight
transformer.h.2.mlp.c_fc.weight
transformer.h.2.mlp.c_proj.weight
transformer.h.3.ln_1.weight
transformer.h.3.attn.c_attn.weight
transformer.h.3.attn.c_proj.weight
transformer.h.3.ln_2.weight
transformer.h.3.mlp.c_fc.weight
transformer.h.3.mlp.c_proj.weight
transformer.h.4.ln_1.weight
transformer.h.4.attn.c_attn.weight
transformer.h.4.attn.c_proj.weight
transformer.h.4.ln_2.weight
transformer.h.4.mlp.c_fc.weight

In [25]:
# Sample `num_samples` continuations of `start` from the model.
start = "Ean{"
num_samples = 5
max_new_tokens = 25
temperature = 1.0
top_k = 3
seed = 1337  # fixed RNG seed so re-running this cell reproduces the same samples

torch.manual_seed(seed)
# encode() yields one id per character; unsqueeze(0) adds a batch dim -> (1, len(start)).
x = torch.tensor(encode(start), dtype=torch.long, device=device).unsqueeze(0)
if config.get("prompt_vocab_size", 0) > 0:
    # NOTE(review): presumably the ids of learned prompt tokens (0..prompt_vocab_size-1)
    # consumed by model.generate — confirm against GPT.generate.
    prompt = torch.arange(config["prompt_vocab_size"], dtype=torch.long, device=device).unsqueeze(0)
else:
    prompt = None
end_token = stoi["}"]  # loop-invariant; passed so generation can stop at the closing brace
with torch.no_grad():
    for _ in range(num_samples):
        y = model.generate(
            x,
            max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            end_token=end_token,
            prompt=prompt,
        )
        print(decode(y[0].tolist()))

Ean{sunanda}
Ean{saran}
Ean{santonioni}
Ean{anastacia}
Ean{manina}


In [8]:
def softmax_with_temperature(logits: torch.Tensor, temperature: float) -> list:
    """Return softmax(logits / temperature) along dim 0 as floats rounded to 2 decimals.

    Lower temperatures sharpen the distribution toward the largest logit;
    higher temperatures flatten it toward uniform.
    """
    probs = torch.nn.functional.softmax(logits / temperature, dim=0)
    return [round(p, 2) for p in probs.tolist()]


# Illustrate how temperature reshapes the same logits (one line per temperature).
demo_logits = torch.tensor([0.1, 0.1, 0.8, 0.1, 0.1])
for temp in (0.2, 0.5, 1.0, 1.5, 5.0):
    print(f"temperature = {temp} ==>", softmax_with_temperature(demo_logits, temp))

temperature = 0.2 ==> [0.03, 0.03, 0.89, 0.03, 0.03]
temperature = 0.5 ==> [0.12, 0.12, 0.5, 0.12, 0.12]
temperature = 1.0 ==> [0.17, 0.17, 0.33, 0.17, 0.17]
temperature = 1.5 ==> [0.18, 0.18, 0.29, 0.18, 0.18]
temperature = 5.0 ==> [0.19, 0.19, 0.22, 0.19, 0.19]
