In [47]:
import numpy as np
import torch
import os
import pickle
from model import GPT, GPTConfig
import matplotlib.pyplot as plt
import seaborn as sns

In [48]:
ALPHABET = [chr(i) for i in range(ord('a'), ord('z') + 1)]
SEP_BAR, SEP_Q = '|', '?'
batch_size = 64
block_size = 2048
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cuda'

In [49]:
dataset = 'openwebtext'
# poor man's data loader
data_dir = os.path.join('data', dataset)
with open(os.path.join(data_dir, 'meta.pkl'), 'rb') as f:
    meta = pickle.load(f)
stoi, itos = meta['stoi'], meta['itos']
vocab_size = meta['vocab_size'] 
print(f"Using vocab size of {vocab_size} (a-z + separators)")

# ---------------- helper: random mono‑alphabetic key --------
alpha_ids = np.array([stoi[c] for c in ALPHABET], dtype=np.uint8)
def random_key():
    perm = np.random.permutation(26)
    enc  = {alpha_ids[i]: alpha_ids[perm[i]] for i in range(26)}   # plain→cipher
    dec  = {v: k for k, v in enc.items()}                          # cipher→plain
    return enc, dec

def get_batch(split):
    mmap = np.memmap(os.path.join(data_dir, f'{split}.bin'),
                     dtype=np.uint8, mode='r')

    k_pairs   = 1024                              # desired number of pairs
    known_k   = k_pairs - 1                       # last one is the query
    prompt_sz = 2 * k_pairs                       # 2048 tokens
    assert prompt_sz == block_size, "block_size must be 2*k_pairs"

    X = torch.full((batch_size, block_size), stoi['|'],  dtype=torch.long)
    Y = torch.full((batch_size, block_size), -1,          dtype=torch.long)

    for b in range(batch_size):
        # ----- 1. grab k plaintext letters from corpus -------------------
        start = np.random.randint(0, len(mmap) - k_pairs - 1)
        plain = mmap[start:start + k_pairs].copy()          # np.uint8, shape (k_pairs,)

        # ----- 2. fresh random key for this sample -----------------------
        enc, _ = random_key()

        # ----- 3. build prompt ------------------------------------------
        buf, tgt = [], []
        for i, p in enumerate(plain):
            c = enc[p]
            if i < known_k:                                 # give answer
                buf.extend([c, p]);      tgt.extend([-1, p])
            else:                                           # query pair
                buf.extend([c, stoi['?']])
                tgt.extend([-1, p])

        X[b] = torch.from_numpy(np.asarray(buf,  np.uint8))
        Y[b] = torch.from_numpy(np.asarray(tgt, np.int64))

    if device_type == 'cuda':
        X, Y = X.pin_memory().to(device, non_blocking=True), \
               Y.pin_memory().to(device, non_blocking=True)
    else:
        X, Y = X.to(device), Y.to(device)
    return X, Y

Using vocab size of 28 (a-z + separators)


In [50]:
x, y = get_batch('train')

In [51]:
for j in range(x.shape[0]):
    for i in range(1, x.shape[-1], 2):
        print(itos[x[j, i].item()], end='')
    print("\n")

odstothequalityoflifepredictedforgenerationsbyeconomistslikejohnstuartmillandjohnmaynardkeynesistakingplacebeforeoureyesintheusandotheradvancedcapitalistnationsifowningacarandahousedefinedmiddleclassstatusinthethcenturyinthestcenturymiddleclassstatusmaybedefinedbyhavingaccesstostateofthearthealtheducationandrecreationalservicesprovidedbyotherpeopleservicesthatbytheirnaturearesharedbynumerousconsumerswhethertheyareprovidedbytheprivatesectorthepublicsectororthenonprofitsectorifthisisrightthenthepoliticaldebateinshouldhavebeenabouttwosubjectsthathavebeenslightedsofarweneedtoboostproductivityasmuchaswecanwhileensuringthatthegainsfromproductivitydrivengrowtharewidelysharedwitheveryoneinthegrowingservicesectorworkforcetheonceinamilleniumdramaofthetransitionfromasocietyoffarmerstoasocietyofurbanworkersisunlikelytoberepeatedneverthelesswecantrytoclearobstaclestotheinventionofthenextworldchanginginventionlikethesteamengineorthesiliconchipwithoutrelyingupontechnologicalwonderstobailusouteveninth

In [52]:
x, y = get_batch('val')

In [53]:
enc, _ = random_key()

In [54]:
checkpoint = torch.load('out/lr-1e-3/final_lr-1e-3_524M_ckpt.pt', map_location=device)
checkpoint_model_args = checkpoint['model_args']
n_layer=8
n_head=8
n_embd=256
bias=False
dropout=0.0
model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
                  bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line
# force these config attributes to be equal otherwise we can't even resume training
# the rest of the attributes (e.g. dropout) can stay as desired from command line
for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
    model_args[k] = checkpoint_model_args[k]
# create the model
gptconf = GPTConfig(**model_args)
model = GPT(gptconf)
state_dict = checkpoint['model']
# fix the keys of the state dictionary :(
# honestly no idea how checkpoints sometimes get this prefix, have to debug more
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)
state_dict = None
iter_num = checkpoint['iter_num']
best_val_loss = checkpoint['best_val_loss']
total_tokens = checkpoint['total_tokens']
model.eval()

number of parameters: 6.30M


GPT(
  (transformer): ModuleDict(
    (wte): Embedding(28, 256)
    (wpe): Embedding(2048, 256)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-7): 8 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=256, out_features=768, bias=False)
          (c_proj): Linear(in_features=256, out_features=256, bias=False)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=256, out_features=1024, bias=False)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=1024, out_features=256, bias=False)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=256, out_features=28, bias=False)
)

In [55]:
encode_vec    = np.vectorize(lambda c: stoi[c], otypes=[np.uint8])

In [56]:
inp = "abcdefghijklmnopqrstuvwxyz"
text = ''.join(ch for ch in inp.lower() if 'a' <= ch <= 'z')
inp = encode_vec(list(text))
known_k = len(inp) - 1
buf, tgt = [], []
for i, p in enumerate(inp):
    c = enc[p]
    if i < known_k:                                 # give answer
        buf.extend([c, p]);      
        tgt.extend([-1, p])
    else:                                           # query pair
        buf.extend([c, stoi['?']])
        tgt.extend([-1, p])
print(buf)
print(tgt)
x = torch.full((1, len(buf)), stoi['|'], dtype=torch.long)
x[0] = torch.from_numpy(np.asarray(buf,  np.uint8))

[np.uint8(2), np.uint8(0), np.uint8(5), np.uint8(1), np.uint8(15), np.uint8(2), np.uint8(12), np.uint8(3), np.uint8(0), np.uint8(4), np.uint8(8), np.uint8(5), np.uint8(24), np.uint8(6), np.uint8(10), np.uint8(7), np.uint8(13), np.uint8(8), np.uint8(19), np.uint8(9), np.uint8(14), np.uint8(10), np.uint8(17), np.uint8(11), np.uint8(11), np.uint8(12), np.uint8(9), np.uint8(13), np.uint8(3), np.uint8(14), np.uint8(21), np.uint8(15), np.uint8(4), np.uint8(16), np.uint8(1), np.uint8(17), np.uint8(23), np.uint8(18), np.uint8(6), np.uint8(19), np.uint8(20), np.uint8(20), np.uint8(25), np.uint8(21), np.uint8(7), np.uint8(22), np.uint8(22), np.uint8(23), np.uint8(18), np.uint8(24), np.uint8(16), 27]
[-1, np.uint8(0), -1, np.uint8(1), -1, np.uint8(2), -1, np.uint8(3), -1, np.uint8(4), -1, np.uint8(5), -1, np.uint8(6), -1, np.uint8(7), -1, np.uint8(8), -1, np.uint8(9), -1, np.uint8(10), -1, np.uint8(11), -1, np.uint8(12), -1, np.uint8(13), -1, np.uint8(14), -1, np.uint8(15), -1, np.uint8(16), -1, 

In [57]:
for j in range(x.shape[0]):
    for i in range(0, x.shape[-1], 2):
        print(f'({itos[x[j, i].item()], itos[x[j, i+1].item()]})')
    print("\n")

(('c', 'a'))
(('f', 'b'))
(('p', 'c'))
(('m', 'd'))
(('a', 'e'))
(('i', 'f'))
(('y', 'g'))
(('k', 'h'))
(('n', 'i'))
(('t', 'j'))
(('o', 'k'))
(('r', 'l'))
(('l', 'm'))
(('j', 'n'))
(('d', 'o'))
(('v', 'p'))
(('e', 'q'))
(('b', 'r'))
(('x', 's'))
(('g', 't'))
(('u', 'u'))
(('z', 'v'))
(('h', 'w'))
(('w', 'x'))
(('s', 'y'))
(('q', '?'))




In [58]:
output = model(x)

In [59]:
itos[output[0].argmax(dim=-1).item()]

'z'