In [1]:
import torch
from sqrll.sqrllm import SqrLLM
from tqdm import tqdm
import math

assert torch.cuda.is_available()

device = torch.device('cuda')
dtype = torch.float32

model = SqrLLM(
    n_embed = 768,
    n_mem = 1024,
    n_layer = 12,
).float().to(device)

params = sum(p.numel() for p in model.parameters())
print(f'{params=:,}')

# Resume from a checkpoint saved under the same parameter count, if one exists.
# Failures are reported instead of silently swallowed, but training still
# proceeds from freshly initialised weights (best-effort resume).
ckpt_path = f'models/model{params}.pt'
try:
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    print('loaded')
except FileNotFoundError:
    print(f'no checkpoint at {ckpt_path}; starting from scratch')
except Exception as e:
    # load_state_dict raises RuntimeError on missing/unexpected keys or shape
    # mismatches; torch.load can fail on a truncated/corrupt file.
    print(f'could not load {ckpt_path} ({type(e).__name__}: {e}); starting from scratch')


params=75,986,176


In [26]:
from libzim.reader import Archive
from bs4 import BeautifulSoup

from os.path import expanduser
import random
import re
import numpy as np

# Simple-English Wikipedia ZIM dump (no pictures, 2024-05) used as training text.
zname = expanduser(
    '~/dev/data/text/wiki/zim/wikipedia_en_simple_all_nopic_2024-05.zim'
)
zim = Archive(zname)

# Matches a run of two or more newlines with only whitespace between them;
# read_zim collapses each such run into one blank line.
space_re = re.compile(r'\n\s*\n')


def read_zim():
    """Yield each HTML entry of `zim`, in random order, as a 1-D byte tensor.

    The HTML is reduced to plain text, blank-line runs are collapsed to a
    single blank line (via `space_re`), and the UTF-8 bytes are returned as an
    int64 tensor on `device` (one token per byte). Non-HTML entries and
    entries whose extracted text is empty are skipped.
    """
    order = list(range(zim.all_entry_count))
    random.shuffle(order)
    for entry_id in order:
        # NOTE(review): confirm how get_item() treats redirect entries; if it
        # follows them, redirected articles are yielded more than once.
        item = zim._get_entry_by_id(entry_id).get_item()
        if item.mimetype != 'text/html':
            continue
        html = item.content.tobytes()
        # Name the parser explicitly: otherwise bs4 picks whichever parser is
        # installed (results differ across machines) and warns on every call.
        text = BeautifulSoup(html, 'html.parser').get_text()
        text = space_re.sub('\n\n', text)
        # bytearray gives torch.frombuffer a writable buffer (read-only `bytes`
        # triggers a UserWarning).
        raw = bytearray(text.encode('utf-8'))
        if not raw:
            # torch.frombuffer rejects zero-length buffers.
            continue
        yield torch.frombuffer(raw, dtype=torch.uint8).long().to(device)

def tetris(data, batch=16, seqlen=256):
    """Pack a stream of 1-D token tensors into (batch, seqlen) batches.

    Each incoming tensor is appended to whichever of the `batch` rows is
    currently shortest. A batch is emitted once every row holds at least
    `seqlen` tokens; the emitted prefix is removed and the rest carries over,
    so row i of consecutive batches is a contiguous continuation of the same
    row stream. Tokens still buffered when `data` is exhausted are dropped.

    Args:
        data: iterable of 1-D tensors (concatenated with torch.cat, so they
            must share dtype and device).
        batch: rows per batch; must be positive.
        seqlen: tokens per row; must be positive.

    Yields:
        Tensors of shape (batch, seqlen).

    Raises:
        ValueError: on first iteration, if `batch` or `seqlen` is not positive
            (seqlen <= 0 would otherwise yield empty batches forever).
    """
    if batch <= 0 or seqlen <= 0:
        raise ValueError(f'batch and seqlen must be positive, got {batch=}, {seqlen=}')

    # One independent slot per row; None marks a row that has never been filled.
    seqs = [None] * batch
    seq_len = np.zeros(batch, dtype=np.int64)
    for d in data:
        insert = int(np.argmin(seq_len))
        if seq_len[insert] == 0:
            # Row is empty (never filled or fully consumed): start it fresh.
            seqs[insert] = d
        else:
            seqs[insert] = torch.cat((seqs[insert], d), dim=0)
        seq_len[insert] += len(d)

        # A single long input can fill several batches' worth at once.
        while seq_len.min() >= seqlen:
            yield torch.stack([s[:seqlen] for s in seqs], dim=0)
            seqs = [s[seqlen:] for s in seqs]
            seq_len -= seqlen


# (4, 2048) byte batches covering one shuffled pass over the archive. This is a
# generator: once the training loop consumes it, re-run this cell for a new pass.
trainset = tetris(
    read_zim(),
    batch=4,
    seqlen=2048,
)

In [27]:
import plotly.graph_objects as go

# Training-progress state, updated by the training-loop cell.
step = 0          # optimizer steps taken so far
bpc_avg = 0       # mean bits-per-byte over the most recent (up to 100) steps
bpc_curve = []    # bpc_avg sampled every 64 steps; drawn in the figure below
off_curve = []    # NOTE(review): not read by any visible cell
bpc_win = []      # per-step bpc values backing bpc_avg

# Live-updating loss curve; the training loop assigns sfig.data[0].y.
sfig = go.FigureWidget()
sfig.add_scatter(name='bpc')
sfig.update_layout(
    margin=dict(l=20, r=20, t=20, b=20),
    xaxis_title='sample (one per 64 steps)',
    yaxis_title='bits per byte (running mean)',
    # Grow the tight margins as needed so the axis titles are not clipped.
    xaxis_automargin=True,
    yaxis_automargin=True,
)
sfig

FigureWidget({
    'data': [{'type': 'scatter', 'uid': 'e770a778-bbc6-41f2-a1d8-9f5146757392'}],
    'layout': {'margin': {'b': 20, 'l': 20, 'r': 20, 't': 20}, 'template': '...'}
})

In [None]:
# Move back to the GPU in train mode (the generation cell leaves the model on CPU).
model.train().float().to(device)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.99),
    weight_decay=0.0,  # alternative: lr / 100
)
loss_func = torch.nn.CrossEntropyLoss()

# Row i of each batch continues row i of the previous one (see tetris), so the
# model's memory is threaded from step to step, starting empty.
mem = None

for data in (prog := tqdm(trainset)):

    optimizer.zero_grad()

    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        # NOTE(review): if SqrLLM does not detach `mem` itself, the next
        # backward() would reach into this step's already-freed graph — confirm.
        outputs, mem = model(data, mem)

        # Next-byte prediction: logits at position t are scored against byte t+1.
        targets = data[:, 1:].flatten()
        outputs = outputs[:, :-1].flatten(0, 1)

        loss = loss_func(outputs, targets)

    loss.backward()

    optimizer.step()

    step += 1
    # Cross-entropy is in nats per byte; dividing by ln 2 converts to bits.
    bpc = loss.item() / math.log(2)
    bpc_win = bpc_win[-99:] + [bpc]
    bpc_avg = sum(bpc_win) / len(bpc_win)
    prog.set_description(f'{(bpc_avg):.4f} bpc')

    if step % 64 == 0:
        bpc_curve += [bpc_avg]
        sfig.data[0].y = bpc_curve

        # Keep weights bounded. Mutate under no_grad rather than through `.data`,
        # which bypasses autograd's version tracking.
        with torch.no_grad():
            for p in model.parameters():
                p.clamp_(-20, 20)

In [9]:
import os

# Persist the weights under the parameter count, the same path the setup cell
# tries to load from. torch.save does not create missing directories.
os.makedirs('models', exist_ok=True)
torch.save(model.state_dict(), f'models/model{params}.pt')

In [None]:
gen = b'The meaning of life is '
entropy = 1  # sampling temperature; 0 (or less) means greedy argmax decoding

model.eval().cpu()

# bytearray: torch.frombuffer warns on read-only `bytes` buffers.
gen = torch.frombuffer(bytearray(gen), dtype=torch.uint8)
gen = gen.long()[None, :]  # (1, prompt_len) int64 byte ids
prev_len = 0
mem = None
with torch.no_grad():
    for _ in tqdm(range(100)):
        # Feed only the bytes not yet seen; `mem` carries the earlier context.
        pred, mem = model(gen[:, prev_len:], mem)
        pred = pred[0, -1:]  # logits for the last position, kept 2-D: (1, vocab)
        prev_len = gen.shape[1]

        if entropy > 0:
            choose = torch.multinomial((pred / entropy).softmax(dim=-1), 1)
        else:
            choose = pred.argmax(dim=-1, keepdim=True)

        gen = torch.cat((gen, choose), dim=-1)

# Sampled bytes need not form valid UTF-8 (e.g. a multi-byte character cut off
# at the end), so substitute U+FFFD instead of raising UnicodeDecodeError.
out = bytes(gen[0].tolist()).decode('utf-8', errors='replace')
print(out)