In [11]:
import torch
from sqrll.sqrllm import SqrLLM
from tqdm import tqdm
import math

assert torch.cuda.is_available()

device = torch.device('cuda')
dtype = torch.float32

model = SqrLLM(
    n_embed = 384,
    n_mem = 512,
    n_layer = 12,
).float().to(device)

params = sum(p.numel() for p in model.parameters())
print(f'{params=:,}')

try:
    model.load_state_dict(torch.load(f'./model{params}.pt'))
    print('loaded')
except:
    pass

loaded
params=19,118,848


In [15]:
def file_iter(fname='finnegans_wake.txt', chunk=8192):
    for _ in range(10):
        with open(fname, 'rb') as f:
            while len(x := f.read(chunk)) == chunk:
                x = torch.frombuffer(x, dtype=torch.uint8)
                yield x.to(device).long()

def shuf_iter(src=file_iter(), bufsize=512):
    buffer = [None] * bufsize
    for x in src:
        i = torch.randint(0, bufsize, ())
        out = buffer[i]
        buffer[i] = x
        if out is not None:
            yield out
    for out in buffer:
        if out is not None:
            yield out

def batch_iter(src=shuf_iter(), batch=4):
    buf = []
    for x in src:
        buf += [x]
        if len(buf) == batch:
            yield torch.stack(buf)
            buf = []

def train_iter(src=batch_iter(), chunk=512):
    for x in src:
        for o in range(0, x.shape[1]-chunk+1, chunk):
            yield x[:, o:o+chunk]

trainset = train_iter()

In [16]:
import plotly.graph_objects as go

step = 0
bpc_avg = 0
bpc_curve = []
off_curve = []
bpc_win = []

sfig = go.FigureWidget()
sfig.add_scatter()
sfig.update_layout(
    margin=dict(l=20, r=20, t=20, b=20),
)
sfig

FigureWidget({
    'data': [{'type': 'scatter', 'uid': '9efffecc-fb74-411d-8263-59aa1d75b974'}],
    'layout': {'margin': {'b': 20, 'l': 20, 'r': 20, 't': 20}, 'template': '...'}
})

In [None]:


model.train().float().to(device)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=2e-4,
    betas=(0.9, 0.99),
    weight_decay=0.0 # lr / 100"
)
loss_func = torch.nn.CrossEntropyLoss()

# model.reset_states()
mem = None

for data in (prog := tqdm(trainset)):

    optimizer.zero_grad()

    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        outputs, mem = model(data, mem)

        targets = data[:, 1:].flatten()
        outputs = outputs[:, :-1].flatten(0,1)

        loss = loss_func(outputs, targets)

    loss.backward()

    optimizer.step()

    step += 1
    bpc = loss.item() / math.log(2)
    bpc_win = bpc_win[-99:] + [bpc]
    bpc_avg = sum(bpc_win) / len(bpc_win)
    prog.set_description(f'{(bpc_avg):.4f} bpc')

    if step % 64 == 0:
        bpc_curve += [bpc_avg]
        sfig.data[0].y = bpc_curve

        for p in model.parameters():
            p.data.clamp_(-20, 20)

In [18]:
torch.save(model.state_dict(), f'./model{params}.pt')

In [20]:
gen = b'The meaning of life is '
entropy = 1

model.eval().cpu()

gen = torch.frombuffer(gen, dtype=torch.uint8)
gen = gen.long()[None, :]
prev_len = 0
mem = None
with torch.no_grad():
    for t in tqdm(range(200)):
        pred, mem = model(gen[:, prev_len:], mem)
        pred = pred[0, -1:] / entropy
        
        choose = torch.multinomial(pred.softmax(dim=-1), 1)
        
        gen = torch.cat((gen, choose), dim=-1)

out = bytes(gen[0].tolist()).decode('utf-8')
print(out)

100%|██████████| 1024/1024 [01:10<00:00, 14.42it/s]

The meaning of life is fatal as the Jacob’s great way to the way With a Grandest Street in a hurtwhyed have Tunderloon snowdow. And still a light and last perhaps. The dame dowager’s tay in the dyings and they say. Notorious, there is no strong and if she showed no more scheinish might fall a delltangle. Declare to present. And was theirs to be continued. For as Punch, hand and rarring, rouge. And it was not a boundless either of the younging panes from the bird of the three ballows so was feeling with the forest. Though his free link has a stroke to lay and his frokerfor. But the ruck mack that would she the charmhaloosum? Pass the pipette whereas he lags a toll a tarnpike. Adversarian! The swabsister Kates for the Clunkey soft Danno, she said the shortlegman may have been tourned by the sundawn. And the prankquean went and all the way how it was in the barrel, read the strangewrote anaglyptics of his slow polishments and threehailed concerning and not a few eggs in begging quite havi


