In [None]:
%pip install -r requirements.txt

In [10]:
import numba
import torch
import numpy as np

@numba.jit
def sinusoids(dt, resolution=0.01, dim=64, min_period=0.05, max_period=60):
    t = np.cumsum(dt * resolution).astype(np.float32)
    lmin = np.log(np.pi * 2 / min_period)
    lmax = np.log(np.pi * 2 / max_period)
    f = np.exp(np.linspace(lmin, lmax, dim // 2)).astype(np.float32)
    t = t[:, None] * f[None, :]
    s = np.stack((np.sin(t), np.cos(t)), axis=-1)
    return s.reshape(s.shape[0], dim)

token_dtype = np.dtype([
    ('dt', 'u1'),
    ('note', 'u1'),
    ('program', 'u1'),
], align=True)

@numba.jit
def tokenize(track, resolution=0.01, max_dt=63, max_toks=16384):
    out = np.zeros(max_toks, dtype=token_dtype)
    fill = 0
    t = 0
    for m in track:
        tick = int(m.time / resolution)
        
        # fill extended pauses with note=0 tokens
        while tick - t >= max_dt:
            out[fill].dt = max_dt
            out[fill].note = 0
            fill += 1
            t += max_dt
            if fill >= max_toks:
                return out
        
        out[fill].dt = tick - t
        # use note+128 for note_off
        out[fill].note = m.key if m.type==tensormidi.NOTE_ON else m.key+128
        # use program 128 for drums
        out[fill].program = 128 if m.channel==9 else m.program
        fill += 1
        t = tick
        if fill >= max_toks:
            return out

    return out[:fill]

@numba.jit
def note_states(notes):
    n = len(notes)
    out = np.empty((n,128), dtype=np.uint8)
    s = np.zeros((128), dtype=np.uint8)
    for i, m in enumerate(notes):
        if m == 0: # note:0 is no-op
            pass
        elif m < 128:
            s[m] = 1
        else:
            s[m-128] = 0
        out[i] = s
    return out

def midi_tensors(track):
    tokens = tokenize(track).view(np.recarray)
    pos_embd = sinusoids(tokens.dt)
    note_on = note_states(tokens.note)
    return (
        torch.tensor(tokens.dt).cuda().long(),
        torch.tensor(tokens.note).cuda().long(),
        torch.tensor(tokens.program).cuda().long(),
        torch.tensor(note_on).cuda().bfloat16(),
        torch.tensor(pos_embd).cuda().bfloat16(),
    )

In [11]:
import sqrll.dataloaders as dl
import tensormidi
import os

# use whatever dir you want
defaultdir = '~/dev/data/midi/bread-midi-dataset'

def dir_iter(rootdir=defaultdir):
    root = os.path.expanduser(rootdir)
    for curdir, _, files in os.walk(root):
        for f in files:
            yield os.path.join(curdir, f)

def track_iter(data):
    for fname in data:
        try:
            for t in tensormidi.load(fname, merge_tracks=False):
                yield t
        except:
            pass # ignore corrupt files


def data_iter(batch=8, seqlen=512):
    data = dir_iter()
    data = track_iter(data)
    data = map(midi_tensors, data)
    data = dl.multitetris(data, 5, batch=batch, seqlen=seqlen)
    return data


In [16]:
from sqrll.sqrllm import SqrllConfig, SqrLLM
from dataclasses import dataclass


@dataclass
class SqrllJamConfig(SqrllConfig):
    n_rest: int = 64
    n_prog: int = 130
    n_pos: int = 64

class SqrllJam(torch.nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # bypass the SqrLLM input projections because
        # we do our own (more complex) input projection
        cfg.n_tokens_in = 0
        cfg.n_vector_in = 0
        cfg.n_out = 256 + cfg.n_rest

        self.config = cfg
        self.e_dt = torch.nn.Embedding(cfg.n_rest, cfg.n_embed)
        self.e_note = torch.nn.Embedding(256, cfg.n_embed)
        self.e_prog = torch.nn.Embedding(cfg.n_prog, cfg.n_embed)
        self.w_on = torch.nn.Linear(128, cfg.n_embed, bias=False)
        self.w_pos = torch.nn.Linear(cfg.n_pos, cfg.n_embed, bias=False)
        self.idropout = torch.nn.Dropout(p=cfg.dropout)
        self.sqrll = SqrLLM(cfg)

    def forward(self, dt, note, program, note_on, pos_embd, mem=None):
        x = 0
        x = x + self.e_dt(dt)
        x = x + self.e_note(note)
        x = x + self.e_prog(program)
        x = x + self.w_on(note_on)
        x = x + self.w_pos(pos_embd)
        x = self.idropout(x)

        y, mem = self.sqrll(in_raw=x, mem=mem)
        p_note, p_dt = y[..., :256], y[..., 256:]
        return p_note, p_dt, mem

    def loss(self, dt, note, program, note_on, pos_embd, mem=None):
        p_note, p_dt, mem = self.forward(
            dt, note, program, note_on, pos_embd, mem)

        # auto regressive cross entropy loss
        p_note = p_note[:, :-1].log_softmax(dim=-1)
        p_dt = p_dt[:, :-1].log_softmax(dim=-1)
        label_note = note[:, 1:, None]
        label_dt = dt[:, 1:, None]
        loss = -torch.gather(p_note, -1, label_note).mean()
        loss = loss - torch.gather(p_dt, -1, label_dt).mean()
        
        return loss, mem

    def save(self, filename):
        model_dict = {
            'config': self.config,
            'weights': self.state_dict(),
        }
        torch.save(model_dict, filename)

    @staticmethod
    def load(filename):
        d = torch.load(filename)
        model = SqrllJam(d['config'])
        model.load_state_dict(d['weights'])
        return model
    

In [18]:
model_file = 'model.pt'

try:
    model = SqrllJam.load(model_file)
    print('loaded')
except:
    cfg = SqrllJamConfig(
        n_rest = 64,
        n_pos = 64,
        # from SqrllmConfig:
        n_embed = 384,
        n_mem = 384,
        n_layer = 18,
        n_ffn = 384,
        ffn_rate = 2,
        dropout = 0.1,
    )
    model = SqrllJam(cfg)
    print('new')

model = model.cuda()
params = sum(p.numel() for p in model.parameters())
print(f'{params=:,}')

new
params=17,660,160


In [21]:
import plotly.graph_objects as go

step = 0
loss_avg = 0
loss_curve = []

sfig = go.FigureWidget()
sfig.add_scatter()
sfig.update_layout(
    margin=dict(l=20, r=20, t=20, b=20),
)
sfig

FigureWidget({
    'data': [{'type': 'scatter', 'uid': '2cd432ee-4e17-49f4-ad5a-48bc999ccf2f'}],
    'layout': {'margin': {'b': 20, 'l': 20, 'r': 20, 't': 20}, 'template': '...'}
})

In [22]:
from tqdm import tqdm

model.train()

lr = 4e-4
print(f'{lr=}')
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    betas=(0.9, 0.996),
    weight_decay=1e-6,
)
mem = None
dataset = data_iter(batch=32, seqlen=512)

for pack in (bar := tqdm(dataset)):

    optimizer.zero_grad()

    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        loss, mem = model.loss(*pack, mem)

    if not (loss.item() < 9e9):
        print(f'{loss.item()=}')
        continue

    loss.backward()
    optimizer.step()

    step += 1
    loss_avg += (loss.item() - loss_avg) / min(step, 1000)
    bar.set_description(f'{(loss_avg):.4f}')

    if step % 64 == 0:
        loss_curve += [loss_avg]
        sfig.data[0].y = loss_curve
        mem = None
        
    if step % 2048 == 0:
        model.save(model_file)
        with open('loss.txt', 'a') as f:
            f.write(f'{step=}\n')
            f.write(f'{loss_avg=}\n')


lr=0.0004


0it [00:00, ?it/s]Exception ignored in: <generator object track_iter at 0x7b6cdbcd40b0>
Traceback (most recent call last):
  File "/tmp/ipykernel_74695/2911003321.py", line 16, in <module>
RuntimeError: generator ignored GeneratorExit
1.4054: : 1732it [02:19, 12.37it/s]


KeyboardInterrupt: 