In [52]:
# import os
# os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import torch
from sqrll.sqrllm import SqrllConfig, SqrLLM
# from tokenizers import Tokenizer
import btok

assert torch.cuda.is_available()

device = torch.device('cuda')
dtype = torch.float32

n_vocab = 2048
model_file = f'models/model{n_vocab}wu.pt'
token_file = f'models/bpe{n_vocab}wu.pack'


# tokenizer = Tokenizer.from_file(token_file)
# n_vocab = tokenizer.get_vocab_size()
with open(token_file, 'rb') as f:
    tokenizer = btok.Tokenizer(f.read())
n_vocab = tokenizer.num_tokens()
print(f'{n_vocab=}')

try:
    model = SqrLLM.load(model_file)
    print('loaded')
except:
    cfg = SqrllConfig(
        n_in = n_vocab,
        n_out = n_vocab,
        n_embed = 256,
        n_mem = 384,
        n_ffn = 384,
        ffn_rate = 3,
        n_layer = 18,
        dropout = 0.05,
    )
    model = SqrLLM(cfg)

model = model.float().to(device)
params = sum(p.numel() for p in model.parameters())
print(f'{params=:,}')

n_vocab=2048
loaded
params=11,698,432


In [53]:
vocab = ([tokenizer.token(i) for i in range(n_vocab)])
vocab_lens = [len(v) for v in vocab]
vocab_lens = torch.tensor(vocab_lens, device=device)
vocab_lens.float().mean().item()

3.3505859375

In [54]:
from sqrll import zimloader
from sqrll import dataloaders as dl
from os.path import expanduser

def tokenize(x):
    x = bytes(x, 'utf8') if type(x)==str else x
    x = tokenizer.encode(x)
    return torch.tensor(x)

def train_iter():
    # data_zim = [expanduser('~/dev/data/text/wiki/zim/wikipedia_en_all_nopic_2024-04.shuf.txt')]
    # data_zim = dl.read_raw(data_zim, chunk=65536)

    data_zim = [expanduser('~/dev/data/text/wiki/zim/wikipedia_en_all_nopic_2024-04.zim')]
    data_zim = zimloader.read_zims(data_zim, nthreads=6)

    data_simp = [expanduser('~/dev/data/text/wiki/zim/wikipedia_en_simple_all_nopic_2024-05.zim')]
    data_simp = zimloader.read_zims(data_simp, nthreads=6)

    data_qa = [expanduser('~/dev/data/text/squad/train-v2.0-flat.txt')] * 10
    data_qa = dl.read_raw(data_qa, chunk=2048)
    data_qa = dl.shuffle(data_qa, bufsize=4096)

    data = dl.mix(data_zim, data_simp, data_qa)
    data = dl.shuffle(data, bufsize=16)
    # data = dl.mix(data_zim)
    
    # data = dl.tokenize(data, tokenizer)

    data = map(tokenize, data)

    data = dl.tetris(data, batch=8, seqlen=2048)
    data = dl.tensor_to(data, device)
    return data

    
trainset = train_iter()

In [55]:
import plotly.graph_objects as go

step = 0
bpc_avg = 0
tot_bytes = 0
bpc_curve = []
bytes_curve = []

sfig = go.FigureWidget()
sfig.add_scatter()
sfig.update_layout(
    margin=dict(l=20, r=20, t=20, b=20),
)
sfig

FigureWidget({
    'data': [{'type': 'scatter', 'uid': '2dca34d0-8aef-4f28-92fe-567cdbb13850'}],
    'layout': {'margin': {'b': 20, 'l': 20, 'r': 20, 't': 20}, 'template': '...'}
})

In [57]:
import math
from tqdm import tqdm

model.train().float().to(device)

lr = 1e-4
print(f'{lr=}')
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    betas=(0.9, 0.999),
    weight_decay=4e-6,
)
loss_func = torch.nn.CrossEntropyLoss()

mem = None

for data in (prog := tqdm(trainset)):

    optimizer.zero_grad()

    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        outputs, mem = model(data, mem)

        # targets = data[:, 1:].flatten()
        # outputs = outputs[:, :-1].flatten(0,1)
        # loss = loss_func(outputs, targets)
            
        targets = data[:, 1:, None]
        outputs = outputs[:, :-1, :].log_softmax(dim=-1)
        bits = -torch.gather(outputs, 2, targets)

        loss = bits.mean()

    loss.backward()

    optimizer.step()
    bpc = (bits / vocab_lens[targets]).mean().item() / math.log(2)

    step += 1
    bpc_avg += (bpc - bpc_avg) / min(step, 500)

    tot_bytes += data.numel()
    prog.set_description(f'{(bpc_avg):.4f} bpc')

    if step % 64 == 0:
        bpc_curve += [bpc_avg]
        bytes_curve += [tot_bytes]
        sfig.data[0].y = bpc_curve
        sfig.data[0].x = bytes_curve

        for p in model.parameters():
            p.data.clamp_(-20, 20)
        
        mem = None
            
    if step % 8192 == 0:
        model.save(model_file)

lr=0.0001


1.4007 bpc: : 86456it [1:39:54, 14.42it/s]


KeyboardInterrupt: 

In [51]:
model.save(model_file)

In [42]:
# gen = b'The meaning of life is '
# gen = b'What is the meaning of life?\n'
# gen = bytes('Where is 中华人民共和国?\n', 'utf8')
# gen = bytes('中华人民共和国', 'utf8')
# gen = b'death destruction hate and murder'
gen = tokenize(gen)
gen = dl.batch([gen], 1)
gen = next(gen).to(device)

entropy = 0.5

model.eval()

# gen = torch.frombuffer(gen, dtype=torch.uint8)
# gen = gen.long()[None, :]
prev_len = 0
mem = None
with torch.no_grad():
    for t in tqdm(range(200)):
        pred, mem = model(gen[:, prev_len:], mem)
        pred = pred[0, -1:] / entropy
        
        choose = torch.multinomial(pred.softmax(dim=-1), 1)
        
        gen = torch.cat((gen, choose), dim=-1)

# out = bytes(gen[0].tolist()).decode('utf-8')

gen = gen[0].tolist()
out = tokenizer.decode(gen)
out = str(out, 'utf8', errors="ignore")

print(out)

100%|██████████| 200/200 [00:00<00:00, 202.20it/s]

中华人民共和国公國史合实在安国公原光子形世界光南女加皇国語国公六人辰人兵平國加南大学元大一式大牡人理元务德公天博大子政実女寺，宮语李伝元安家太宮復





In [None]:
acc = 0
bpc = 0
count = 0

fname = expanduser('~/dev/data/text/wiki/enwik8_test.txt')

with open(fname, 'r', encoding='utf-8') as f:
    x = f.read()
    x = tokenizer.encode(x).ids
    x = torch.tensor(x).to(device)

    chunksz = 8192
    nchunk = len(x) // chunksz
    x = x[:nchunk * chunksz]
    x = x.view(nchunk, 1, chunksz)

testset = x

model.eval().to(device)
mem = None

with torch.no_grad():
    for data in (prog := tqdm(testset)):
        
        outputs, mem = model(data, mem)
        outputs = outputs[:, :-1, :].log_softmax(dim=-1)
        targets = data[:, 1:, None]
        
        argmax = outputs.argmax(dim=-1)
        bits = -torch.gather(outputs, 2, targets)
        
        # if bits.isnan().any() or any(m.isnan().any() for m in mem):
        #     print(f'{mem=}')
        #     print(f'{outputs=}')
        #     print(f'{logmax=}')
        #     print(f'{bits=}')
        #     break

        
        bpc += (bits / vocab_lens[targets]).mean().item() / np.log(2)

        acc += (targets[:,:,0] == argmax).float().mean()
        count += 1
        prog.set_description(f'{bpc/count:.4f}')

acc = acc.item() / count
bpc = bpc.item() / count
print(f'{acc=} {bpc=}')

1.6812: 100%|██████████| 421/421 [00:32<00:00, 12.90it/s]

acc=0.4699897675502895 bpc=1.6811940598583646



