In [1]:
import torch
from sqrll.sqrllm import SqrLLM
from tqdm import tqdm
import math

assert torch.cuda.is_available()

device = torch.device('cuda')
dtype = torch.float32

model = SqrLLM(
    n_embed = 256,
    n_mem = 512,
    n_ffn = 256,
    ffn_rate = 4,
    n_layer = 6,
).float().to(device)

params = sum(p.numel() for p in model.parameters())
print(f'{params=:,}')

try:
    model.load_state_dict(torch.load(f'./model{params}.pt'))
    print('loaded')
except:
    pass

params=4,468,736


In [21]:

from os.path import expanduser

# files = ['finnegans_wake.txt'] * 4
# files = files + [expanduser('~/dev/data/text/wiki/enwik8s/x08')] + files

files = ['finnegans_wake.txt']
files += [expanduser('~/dev/data/text/alice_in_wonderland.txt')] * 10
files *= 20

def file_iter(files=files, chunk=16384):
    for fname in files:
        with open(fname, 'rb') as f:
            offset = torch.randint(0, chunk, ())
            f.seek(offset, 0)
            while len(x := f.read(chunk)) == chunk:
                x = torch.frombuffer(x, dtype=torch.uint8)
                yield x.to(device).long()

def shuf_iter(src=file_iter(), bufsize=2048):
    buffer = [None] * bufsize
    for x in src:
        i = torch.randint(0, bufsize, ())
        out = buffer[i]
        buffer[i] = x
        if out is not None:
            yield out
    for out in buffer:
        if out is not None:
            yield out

def batch_iter(src=shuf_iter(), batch=4):
    buf = []
    for x in src:
        buf += [x]
        if len(buf) == batch:
            yield torch.stack(buf)
            buf = []

def train_iter(src=batch_iter(), chunk=4096):
    for x in src:
        for o in range(0, x.shape[1]-chunk+1, chunk):
            yield x[:, o:o+chunk]

trainset = train_iter()

In [22]:
import plotly.graph_objects as go

step = 0
bpc_avg = 0
bpc_curve = []
off_curve = []
bpc_win = []

sfig = go.FigureWidget()
sfig.add_scatter()
sfig.update_layout(
    margin=dict(l=20, r=20, t=20, b=20),
)
sfig

FigureWidget({
    'data': [{'type': 'scatter', 'uid': '460fda97-c968-444e-89b2-14595a99e307'}],
    'layout': {'margin': {'b': 20, 'l': 20, 'r': 20, 't': 20}, 'template': '...'}
})

In [23]:


model.train().float().to(device)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.99),
    weight_decay=0.0 # lr / 100"
)
loss_func = torch.nn.CrossEntropyLoss()

# model.reset_states()
mem = None

for data in (prog := tqdm(trainset)):

    optimizer.zero_grad()

    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        outputs, mem = model(data, mem)

        targets = data[:, 1:].flatten()
        outputs = outputs[:, :-1].flatten(0,1)

        loss = loss_func(outputs, targets)

    loss.backward()

    optimizer.step()

    step += 1
    bpc = loss.item() / math.log(2)
    bpc_win = bpc_win[-299:] + [bpc]
    bpc_avg = sum(bpc_win) / len(bpc_win)
    prog.set_description(f'{(bpc_avg):.4f} bpc')

    if step % 64 == 0:
        bpc_curve += [bpc_avg]
        sfig.data[0].y = bpc_curve

        for p in model.parameters():
            p.data.clamp_(-20, 20)

0.4369 bpc: : 3500it [01:48, 32.40it/s]


In [24]:
torch.save(model.state_dict(), f'./model{params}.pt')

In [25]:
gen = b'What is the meaning of life? Well, you see '
entropy = 1

model.eval().cpu()

gen = torch.frombuffer(gen, dtype=torch.uint8)
gen = gen.long()[None, :]
prev_len = 0
mem = None
with torch.no_grad():
    for t in tqdm(range(200)):
        pred, mem = model(gen[:, prev_len:], mem)
        pred = pred[0, -1:] / entropy
        
        choose = torch.multinomial(pred.softmax(dim=-1), 1)
        
        gen = torch.cat((gen, choose), dim=-1)

out = bytes(gen[0].tolist()).decode('utf-8')
print(out)

100%|██████████| 200/200 [00:01<00:00, 123.58it/s]

What is the meaning of life? Well, you see me is come to you cuttinrunner on porpoise. He fell from him! Unhim! With his threestar chapelite soul is heavenly girled from cigarha beautiful science saltles, in Moly Saint Delestian’s hairs for 



