In [2]:
import torch
from sqrll.sqrllm import SqrllConfig, SqrLLM
from tqdm import tqdm
import math
from os.path import expanduser

# Everything below (autocast in the training cell, the evaluation cell) runs on
# the GPU, so fail immediately if none is visible.
assert torch.cuda.is_available(), 'this notebook requires a CUDA device'

device = torch.device('cuda')
dtype = torch.float32

model_file = 'models/joyce.pt'

# Resume from the checkpoint when it can be loaded; otherwise fall back to a
# freshly initialised model.
try:
    model = SqrLLM.load(model_file)
    print('loaded')
except Exception as err:
    # Catch `Exception` rather than using a bare `except:` so that
    # KeyboardInterrupt / SystemExit still propagate, and report why the load
    # failed instead of silently discarding the reason.
    print(f'could not load {model_file!r} ({err!r}); creating a new model')
    cfg = SqrllConfig(
        n_embed = 128,
        n_mem = 256,
        n_ffn = 256,
        ffn_rate = 4,
        n_layer = 12,
        dropout = 0.05,
    )
    model = SqrLLM(cfg)

# Keep the weights in float32 on the GPU regardless of how they were stored.
model = model.float().to(device)
params = sum(p.numel() for p in model.parameters())
print(f'{params=:,}')

loaded
params=2,339,200


In [3]:
import random
import sqrll.dataloaders as dl

def train_iter():
    """Assemble the training-data pipeline and return its final stage.

    The corpus list holds 'finnegans_wake.txt' 50 times (presumably so the
    text is traversed 50 times per call -- depends on `dl.read_raw`, confirm).
    The list is then passed through these `sqrll.dataloaders` stages in order:
    `read_raw` (chunk=16384), `shuffle` (bufsize=2048), `str_tensor`,
    `tetris` (batch=4, seqlen=2048), and two `tensor_to` conversions that
    target the global `device` and then `torch.int64`.

    Returns the object produced by the last `tensor_to` stage; the training
    cell iterates over it.
    """
    corpus = ['finnegans_wake.txt'] * 50
    # Alternative corpora, currently disabled:
    # corpus += [expanduser('~/dev/data/text/wiki/enwik8_train.txt')]
    # corpus += [expanduser('~/dev/data/text/alice_in_wonderland.txt')] * 10
    random.shuffle(corpus)

    raw_chunks = dl.read_raw(corpus, chunk=16384)
    mixed_chunks = dl.shuffle(raw_chunks, bufsize=2048)
    as_tensors = dl.str_tensor(mixed_chunks)
    batches = dl.tetris(as_tensors, batch=4, seqlen=2048)
    on_device = dl.tensor_to(batches, device)
    return dl.tensor_to(on_device, torch.int64)

In [4]:
import plotly.graph_objects as go

# Running training statistics. The training cell reads and updates these, so
# re-running this cell starts the curve over from zero.
step, tot_bytes = 0, 0
bpc_avg = 0
bpc_curve, bytes_curve = [], []

# Live plot widget with a single (initially empty) scatter trace; the training
# cell fills in its x/y data as it goes.
sfig = go.FigureWidget()
sfig.add_scatter()
sfig.update_layout(margin={'l': 20, 'r': 20, 't': 20, 'b': 20})
sfig

FigureWidget({
    'data': [{'type': 'scatter', 'uid': '1f17dd85-e84e-4835-8c4d-f53f057ed860'}],
    'layout': {'margin': {'b': 20, 'l': 20, 'r': 20, 't': 20}, 'template': '...'}
})

In [8]:
# Training cell. Relies on `step`, `bpc_avg`, `tot_bytes`, `bpc_curve`,
# `bytes_curve` and `sfig` from the statistics/plot cell; because those are not
# reset here, re-running this cell continues the same curve.
model.train().float().to(device)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
    weight_decay=1e-5,
)
loss_func = torch.nn.CrossEntropyLoss()

# Materialise the whole pipeline up front so tqdm knows the total step count.
trainset = list(train_iter())

# Recurrent state returned by the model and fed back on the next step.
mem = None

for data in (prog := tqdm(trainset)):

    optimizer.zero_grad()

    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        outputs, mem = model(data, mem)

        # Next-token objective: the output at position i is scored against the
        # input at position i + 1, so drop the first target and last output.
        targets = data[:, 1:].flatten()
        outputs = outputs[:, :-1].flatten(0, 1)

        loss = loss_func(outputs, targets)

    loss.backward()

    optimizer.step()
    # Cross-entropy is in nats; dividing by ln(2) gives bits per character.
    bpc = loss.item() / math.log(2)

    step += 1
    # Plain running mean for the first 500 steps, then an exponential moving
    # average with weight 1/500.
    bpc_avg += (bpc - bpc_avg) / min(step, 500)

    tot_bytes += data.numel()
    prog.set_description(f'{(bpc_avg):.4f} bpc')

    if step % 64 == 0:
        bpc_curve.append(bpc_avg)
        bytes_curve.append(tot_bytes)
        sfig.data[0].y = bpc_curve
        sfig.data[0].x = bytes_curve

        # Keep every weight inside [-20, 20]. Done in place under no_grad
        # instead of through `p.data`, which bypasses autograd's tracking of
        # in-place modifications.
        with torch.no_grad():
            for p in model.parameters():
                p.clamp_(-20, 20)

        # Drop the carried state every 64 steps.
        mem = None

    if step % 8192 == 0:
        model.save(model_file)

0.9214 bpc: 100%|██████████| 8001/8001 [02:34<00:00, 51.66it/s]


In [6]:
# Write the current model to the checkpoint path chosen in the setup cell.
model.save(model_file)

In [7]:
# Sample 200 bytes from the model, starting from a fixed prompt.
gen = b'The meaning of life is '
entropy = 1  # divisor applied to the logits before softmax (sampling temperature)

model.eval().cpu()

# torch.frombuffer warns when handed an immutable `bytes` object, so wrap the
# prompt in a bytearray; `.long()` then copies it into an int64 tensor.
gen = torch.frombuffer(bytearray(gen), dtype=torch.uint8)
gen = gen.long()[None, :]
prev_len = 0
mem = None
with torch.no_grad():
    for t in tqdm(range(200)):
        # Feed only what the model has not seen yet: the whole prompt on the
        # first step, then just the newly sampled token. `mem` carries the
        # state for everything before `prev_len`.
        pred, mem = model(gen[:, prev_len:], mem)
        # Mark everything currently in `gen` as consumed. Without this update
        # the full sequence would be replayed every step on top of a `mem`
        # that already includes it.
        prev_len = gen.shape[-1]
        pred = pred[0, -1:] / entropy

        choose = torch.multinomial(pred.softmax(dim=-1), 1)

        gen = torch.cat((gen, choose), dim=-1)

# Sampling works on raw bytes and may stop in the middle of a multi-byte
# character or emit an invalid sequence; substitute U+FFFD rather than raising.
out = bytes(gen[0].tolist()).decode('utf-8', errors='replace')
print(out)



  0%|          | 0/200 [00:00<?, ?it/s]

100%|██████████| 200/200 [00:02<00:00, 90.07it/s] 

The meaning of life is down, scaldbrother, before he will be ground, nievre you, we say. Who would ontrifan, since we’re regularly abroadside in betinned the whole fairness of promise with considerable patriarch? That was





In [30]:
# Evaluation cell: next-byte accuracy and bits per character on a held-out
# file, processed as consecutive chunks with the model state carried across.
acc = 0.0
bpc = 0.0
count = 0

fname = expanduser('~/dev/data/text/wiki/enwik8_test.txt')
chunksz = 8192

with open(fname, 'rb') as f:
    x = f.read()

# Only whole chunks are evaluated; the trailing partial chunk is discarded.
nchunk = len(x) // chunksz
if nchunk == 0:
    # Otherwise the loop below would never run and the final averages would
    # divide by zero.
    raise ValueError(f'{fname} holds fewer than {chunksz} bytes; nothing to evaluate')

# bytearray: torch.frombuffer warns on immutable `bytes` buffers.
x = torch.frombuffer(bytearray(x[:nchunk * chunksz]), dtype=torch.uint8)
x = x.to(device).long()
x = x.view(nchunk, 1, chunksz)

testset = x

model.eval().to(device)
mem = None

with torch.no_grad():
    for data in (prog := tqdm(testset)):

        outputs, mem = model(data, mem)
        # The output at position i is scored against the byte at i + 1; the
        # last output of each chunk has no target inside the chunk.
        outputs = outputs[:, :-1]
        targets = data[:, 1:]

        argmax = outputs.argmax(dim=-1)
        logmax = outputs.log_softmax(dim=-1)
        # Natural-log probability assigned to each true next byte.
        logp = torch.gather(logmax, 2, targets[:, :, None])

        # Accumulate as Python floats so the counters keep one type for the
        # whole run; dividing by ln(2) converts nats to bits.
        bpc -= logp.mean().item() / math.log(2)
        acc += (targets == argmax).float().mean().item()
        count += 1
        prog.set_description(f'{bpc/count:.4f}')

# Every chunk has the same length, so the mean of per-chunk means equals the
# mean over all scored positions.
acc = acc / count
bpc = bpc / count
print(f'{acc=} {bpc=}')

1.3054: 100%|██████████| 1220/1220 [01:56<00:00, 10.44it/s]

acc=0.7388890500928534 bpc=1.305364990234375



