In [1]:
import torch
from sqrll.sqrllm import SqrLLM
from tqdm import tqdm
import math

# The training cell uses CUDA autocast, so fail fast when no GPU is visible.
assert torch.cuda.is_available()

device = torch.device('cuda')
dtype = torch.float32

# Model hyper-parameters; their exact meaning is defined by sqrll.sqrllm.SqrLLM.
model = SqrLLM(
    n_embed = 512,
    n_mem = 768,
    n_ffn = 1024,
    ffn_rate = 4,
    n_layer = 16,
).float().to(device)

# The parameter count doubles as the checkpoint key: a different architecture
# size resolves to a different file name.
params = sum(p.numel() for p in model.parameters())
print(f'{params=:,}')

# Best-effort resume. A missing checkpoint is the normal "fresh start" case and
# stays silent; any other failure (corrupt file, mismatched state dict, ...) is
# reported instead of being swallowed, but still does not abort the notebook.
try:
    model.load_state_dict(torch.load(f'models/model{params}.pt'))
    print('loaded')
except FileNotFoundError:
    pass
except Exception as err:
    print(f'checkpoint not loaded: {err!r}')


params=36,491,008


In [2]:
from libzim.reader import Archive
from bs4 import BeautifulSoup
import multiprocessing
import collections

from os.path import expanduser
import random
import re
import numpy as np

# Path of the ZIM archive used as the training corpus. The commented-out line
# is an alternative archive file that can be swapped in.
# zname = expanduser('~/dev/data/text/wiki/zim/wikipedia_en_simple_all_nopic_2024-05.zim')
zname = expanduser('~/dev/data/text/wiki/zim/wikipedia_en_all_nopic_2024-04.zim')
# Module-level archive handle; read_zim_entry reads it as a global.
# NOTE(review): read_zim_entry runs inside multiprocessing workers, so this
# presumably relies on workers inheriting the open handle (fork) — confirm.
zim = Archive(zname)
# Matches a run of blank / whitespace-only lines between two newlines;
# read_zim_entry replaces each match with a single empty line.
space_re = re.compile(r'\n\s*\n')


def map_parallel(func, data, lookahead=128, workers=16, timeout=20):
    """Lazily apply ``func`` to every element of ``data`` in a process pool.

    Results are yielded in input order. At most ``lookahead`` tasks are kept
    in flight: once that many are queued, the oldest one is awaited before
    the next input is submitted. Results equal to ``None`` are skipped.

    Args:
        func: Picklable callable taking one positional argument.
        data: Iterable of inputs; consumed lazily.
        lookahead: Queue length at which the oldest task is awaited.
        workers: Number of worker processes in the pool.
        timeout: Seconds to wait for a single result; forwarded to
            ``AsyncResult.get``.

    Yields:
        Every non-``None`` result of ``func``, in submission order.
    """
    pending = collections.deque()
    with multiprocessing.Pool(workers) as pool:
        for item in data:
            pending.append(pool.apply_async(func, (item,)))
            if len(pending) < lookahead:
                # Window not full yet: keep submitting before waiting.
                continue
            result = pending.popleft().get(timeout=timeout)
            if result is not None:
                yield result

        # Input exhausted: drain whatever is still in flight.
        while pending:
            result = pending.popleft().get(timeout=timeout)
            if result is not None:
                yield result


def tetris(data, batch=16, seqlen=256):
    """Pack variable-length 1-D tensors into fixed-size ``(batch, seqlen)`` blocks.

    Each incoming tensor is appended to whichever of the ``batch`` rows is
    currently shortest. Whenever every row holds at least ``seqlen`` elements,
    the first ``seqlen`` elements of each row are stacked and yielded, and the
    rows keep their remainder. Elements left over when ``data`` runs out are
    never yielded.

    Args:
        data: Iterable of 1-D tensors.
        batch: Number of parallel rows.
        seqlen: Number of elements per row in every yielded block.

    Yields:
        Tensors of shape ``(batch, seqlen)``.
    """
    rows = [[] for _ in range(batch)]
    fill = np.zeros(batch, dtype=int)

    for chunk in data:
        slot = np.argmin(fill)
        # An empty row is simply replaced; otherwise the chunk is appended.
        if fill[slot] == 0:
            rows[slot] = chunk
        else:
            rows[slot] = torch.cat((rows[slot], chunk), dim=0)
        fill[slot] += len(chunk)

        # Emit as many full blocks as the shortest row allows.
        while fill.min() >= seqlen:
            heads = [row[:seqlen] for row in rows]
            yield torch.stack(heads, dim=0)
            rows = [row[seqlen:] for row in rows]
            fill -= seqlen


def read_zim_entry(page_id):
    """Load one archive entry and return its text as a uint8 byte tensor.

    Reads the module-level ``zim`` archive and ``space_re`` pattern.

    Args:
        page_id: Entry id passed to the archive lookup.

    Returns:
        A 1-D ``torch.uint8`` tensor holding the UTF-8 bytes of the page text,
        or ``None`` when the entry is not ``text/html`` or yields no text.
        ``map_parallel`` drops ``None`` results.
    """
    item = zim._get_entry_by_id(page_id).get_item()
    if item.mimetype != 'text/html':
        return None

    html = item.content.tobytes().decode('utf-8')
    # NOTE(review): no parser is named here, so BeautifulSoup chooses one from
    # what is installed; pin it explicitly if extraction must be reproducible.
    text = BeautifulSoup(html).get_text()
    # Collapse runs of blank lines into a single empty line.
    text = space_re.sub('\n\n', text)

    # bytearray gives torch.frombuffer a writable buffer (an immutable bytes
    # object triggers a non-writable-buffer warning).
    raw = bytearray(text, 'utf-8')
    if not raw:
        # torch.frombuffer raises ValueError on a zero-length buffer; a page
        # without text is skipped instead of crashing the pipeline.
        return None
    return torch.frombuffer(raw, dtype=torch.uint8)


def dataset(batch, seqlen):
    """Build the training stream of ``(batch, seqlen)`` byte-id tensors.

    All entry ids of the module-level ``zim`` archive are shuffled, read in
    parallel by ``read_zim_entry``, packed into fixed-size blocks by
    ``tetris``, and finally converted to ``long`` and moved to the
    module-level ``device``.

    Args:
        batch: Rows per block, forwarded to ``tetris``.
        seqlen: Elements per row, forwarded to ``tetris``.

    Returns:
        A lazy iterator over tensors on ``device``.
    """
    def _to_device(block):
        return block.long().to(device)

    order = list(range(zim.all_entry_count))
    random.shuffle(order)

    pages = map_parallel(
        read_zim_entry, order, workers=4, lookahead=128, timeout=30)
    blocks = tetris(pages, batch, seqlen)
    return map(_to_device, blocks)


In [3]:
import plotly.graph_objects as go

# Lazy stream of (4, 2048) training blocks.
trainset = dataset(batch=4, seqlen=2048)

# Running training statistics; kept at module level so re-running the
# training cell continues from the current values.
step, tot_bytes = 0, 0
bpc_avg = 0
bpc_curve, bytes_curve = [], []

# Live figure with one (initially empty) scatter trace that the training
# cell fills with bytes-seen vs. averaged bits-per-character.
sfig = go.FigureWidget()
sfig.add_scatter()
sfig.update_layout(margin=dict.fromkeys(('l', 'r', 't', 'b'), 20))
sfig

FigureWidget({
    'data': [{'type': 'scatter', 'uid': '78bd1577-0c45-4a0c-8a9f-cfb0b642f99e'}],
    'layout': {'margin': {'b': 20, 'l': 20, 'r': 20, 't': 20}, 'template': '...'}
})

In [4]:
# Put the model in training mode, in float32, on the GPU.
model.train().float().to(device)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=4e-4,
    betas=(0.9, 0.99),
    weight_decay=0.0,  # alternative that was considered: lr / 100
)
loss_func = torch.nn.CrossEntropyLoss()

# State returned by the model and fed back into the next call.
mem = None

for data in (prog := tqdm(trainset)):

    optimizer.zero_grad()

    # Forward pass and loss run under bfloat16 autocast.
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        outputs, mem = model(data, mem)

        # Next-byte prediction: the output at position t is scored against
        # the input at position t + 1, so the last output has no target.
        targets = data[:, 1:].flatten()
        outputs = outputs[:, :-1].flatten(0,1)

        loss = loss_func(outputs, targets)

    loss.backward()

    optimizer.step()

    tot_bytes += data.numel()
    step += 1
    # Cross-entropy is in nats; dividing by ln(2) converts it to bits.
    bpc = loss.item() / math.log(2)
    # Plain running mean for the first 500 steps, then an exponential
    # moving average with weight 1/500.
    bpc_avg += (bpc - bpc_avg) / min(step, 500)
    prog.set_description(f'{(bpc_avg):.4f} bpc')

    if step % 64 == 0:
        bpc_curve += [bpc_avg]
        bytes_curve += [tot_bytes]
        sfig.data[0].y = bpc_curve
        sfig.data[0].x = bytes_curve

        # Keep every weight inside [-20, 20]. Done in-place under no_grad
        # rather than through the discouraged `.data` attribute.
        with torch.no_grad():
            for p in model.parameters():
                p.clamp_(-20, 20)

















2.8853 bpc: : 1089it [01:06, 16.85it/s]

In [6]:
# Make sure the checkpoint directory exists so a finished run is not lost to
# a missing folder, then write the weights under the parameter-count key
# that the load cell looks for.
import os
os.makedirs('models', exist_ok=True)
torch.save(model.state_dict(), f'models/model{params}.pt')

In [10]:
# Prompt, as raw bytes (the model works on byte ids).
gen = b'The meaning of life is '
# Divides the model outputs before the softmax, i.e. acts as a sampling
# temperature: lower values make sampling greedier.
entropy = .5

# Sampling runs on the CPU in eval mode; note this moves the model off the GPU.
model.eval().cpu()

# bytearray gives torch.frombuffer a writable buffer (an immutable bytes
# object triggers a non-writable-buffer warning).
gen = torch.frombuffer(bytearray(gen), dtype=torch.uint8)
gen = gen.long()[None, :]
prev_len = 0
mem = None
with torch.no_grad():
    for t in tqdm(range(200)):
        # Feed only the bytes the model has not seen yet; earlier context is
        # carried in `mem`.
        pred, mem = model(gen[:, prev_len:], mem)
        pred = pred[0, -1:] / entropy
        prev_len = gen.shape[1]

        choose = torch.multinomial(pred.softmax(dim=-1), 1)

        gen = torch.cat((gen, choose), dim=-1)

# A byte-level sample can stop in the middle of a multi-byte character or
# contain an invalid sequence; substitute U+FFFD instead of raising
# UnicodeDecodeError.
out = bytes(gen[0].tolist()).decode('utf-8', errors='replace')
print(out)

100%|██████████| 200/200 [00:01<00:00, 186.67it/s]

The meaning of life is the largest in the world. The following is a separate unit of settlement in the world, but the lowest point is to prevent a lot of time. The two main is a part of the weather and the stream of a stron



