<a href="https://colab.research.google.com/github/tfglynn/piano-nn/blob/master/October_piano_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Lilypond

In [None]:
!apt install lilypond

# Server

In [None]:
# Copied from elsewhere

import os, sys, requests, json
from multiprocessing import Process
from flask import Flask, request, abort, logging

# The Flask application object that the routes below are attached to.
app = Flask(__name__)
# True: start_server runs the app in a background process; False: it blocks this kernel.
run_thread = True

@app.route("/", methods=['GET'])
def test():
    """Health check: respond with a fixed JSON status message."""
    return json.dumps({'status': 'OK', 'message': 'Test'})

@app.route("/echo", methods=["GET"])
def echo():
    """Echo the `message` request header back as JSON.

    Responds 400 Bad Request when the header is missing. Previously the view
    returned None in that case, which Flask reports as a 500 error.
    """
    msg = request.headers.get("message")
    if msg is None:
        abort(400)
    return json.dumps({"echo": msg})

def stop_server():
    """Terminate and reap the background server process, if one is running.

    Safe to call on a fresh kernel, before `start_server` has ever created the
    `server` global, and safe to call more than once.
    """
    global server
    proc = globals().get("server")
    if proc is not None:
        proc.terminate()
        proc.join()
    # Forget the handle so a second call doesn't try to stop the same process again.
    server = None

def start_server(run_thread):
    """Serve `app` on localhost:8000.

    run_thread: if true, run the app in a `multiprocessing.Process`, stored in the
    global `server`, so the notebook stays responsive. Otherwise run it in the
    foreground, blocking this kernel.
    """
    global server
    if not run_thread:
        app.run(host="localhost", port=8000)
        return
    server = Process(target=app.run, kwargs={"host": "localhost", "port": 8000})
    server.start()

In [None]:
# Stop any server process left over from an earlier run of the next cell.
stop_server()

In [None]:
# Launch the Flask app on localhost:8000 (in the background when run_thread is true).
start_server(run_thread=run_thread)

 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on http://localhost:8000
INFO:werkzeug:[33mPress CTRL+C to quit[0m


In [None]:
!npm install -g localtunnel

[K[?25h/tools/node/bin/lt -> /tools/node/lib/node_modules/localtunnel/bin/lt.js
+ localtunnel@2.0.2
added 22 packages from 22 contributors in 2.282s


In [None]:
# Print this machine's public IPv4 address.
# NOTE(review): presumably needed to get past localtunnel's landing page — confirm.
!curl ipv4.icanhazip.com

104.199.147.125


In [None]:
# Expose the Flask server (port 8000, as in start_server) through localtunnel, detached
# in the background; stdout/stderr (including the public URL) go to lt.log.
!nohup lt --port 8000 >lt.log 2>&1 &

In [None]:
# Show the tunnel log: the public URL, or an error if the tunnel failed to start.
!cat lt.log

your url is: https://shy-plums-poke.loca.lt


In [None]:
!ps -ef | grep lt

root          83       7  0 01:15 ?        00:00:01 /usr/local/bin/dap_multiplexer --domain_socket_path=/tmp/debugger_26lqol2dmg
root       13389       1  0 02:10 ?        00:00:00 node /tools/node/bin/lt --port 8000
root       15505    5834  0 02:18 ?        00:00:00 /bin/bash -c ps -ef | grep lt
root       15507   15505  0 02:18 ?        00:00:00 grep lt


In [None]:
!kill 13389

# Preamble

## Installing libraries we need

In [None]:
%%capture
!pip install pretty_midi
!pip install rotary-embedding-torch
#!pip install hdbscan

## Imports

In [None]:
import bisect
import copy
import datetime as dt
import matplotlib.pyplot as plt
import numpy as np
import os
import pathlib
import pickle
import PIL
import pretty_midi as pm
import subprocess
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
from collections import Counter
from enum import Enum
from functools import partial
from google.colab import drive
#from hdbscan import HDBSCAN
from IPython.display import display, HTML, Image
from matplotlib import collections as mc
from operator import attrgetter, itemgetter
from sklearn.manifold import TSNE
from torch import optim
from torch.distributions import Beta
from torch.utils.data import Dataset, DataLoader
from rotary_embedding_torch import RotaryEmbedding

# Globals

In [None]:
# --- Event vocabulary ---------------------------------------------------------
N_KEYS = 88  # all keys of a piano
MIN_PITCH = 21  # lowest A (MIDI note number)
MAX_PITCH = MIN_PITCH + N_KEYS - 1  # highest C

# Event ids: 0 advances the clock, 1..N_KEYS are piano keys, and N_KEYS + 1 marks
# the start/end of a song.
NEXT_TOKEN = 0
BOUNDARY_TOKEN = N_KEYS + 1
INCLUDE_SONG_BOUNDARIES = False
N_EVENTS = N_KEYS + (2 if INCLUDE_SONG_BOUNDARIES else 1)

# Lilypond names
SHIFT_NAMES = ["32", "16", "16.", "8", "8.", "4", "4.", "2", "2.", "1", "1.", "breve"]
PITCH_NAMES = ["a", "ais", "b", "c", "cis", "d", "dis", "e", "f", "fis", "g", "gis"]
N_SHIFTS = len(SHIFT_NAMES)  # 12 standard note lengths, shortest to longest

DEFAULT_EPSILON = 1e-6

# --- Training -----------------------------------------------------------------
BATCH_SIZE = 4
CONTEXT_WINDOW = 1024
N_TRAINING_STEPS = 100_000
LEARNING_RATE = 1e-3

# Setup

In [None]:
# Use the GPU when one is available; everything else in the notebook moves tensors to `device`.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
print("GPU support enabled" if device_name == "cuda" else "Using CPU only")
device = torch.device(device_name)

GPU support enabled


In [None]:
# Mount Google Drive; the MIDI files, pickles and model checkpoints used below live under
# /content/drive/My Drive.
drive.mount("/content/drive")

Mounted at /content/drive


## Utilities

In [None]:
def pitch_name(p):
    """Lilypond pitch-class name of MIDI pitch `p`.

    MIN_PITCH (the lowest piano A) maps to "a", and the names repeat every 12 semitones.
    """
    semitones_above_lowest_a = p - MIN_PITCH
    return PITCH_NAMES[semitones_above_lowest_a % 12]

In [None]:
def ma(x, w):
    """Moving average of `x` over windows of length `w`.

    Only full windows are used, so the result has len(x) - w + 1 entries.
    """
    window_sums = np.convolve(x, np.ones(w), mode="valid")
    return window_sums / w

In [None]:
def rolling(x, r):
    """Stack `r` progressively time-shifted copies of `x` along a new axis.

    x: (B?, T, D) — only 2-D or 3-D inputs are handled (see `reps`).
    r: a positive integer

    Returns: (B?, r, T, D). Copy k is x delayed by k steps, so
    out[..., k, t, :] == x[..., t - k, :] whenever t >= k. Entries with t < k
    are filler (a zero row, or rows wrapped around from the end of x), so
    callers must mask them out.
    """
    t = x.shape[-2]
    reps = (1, r, 1, 1) if x.dim() == 3 else (r, 1, 1)
    x = x.unsqueeze(-3).repeat(reps)
    # Give every copy one extra zero row, flatten the copies into one long buffer,
    # and read it back in rows of length T. Each successive copy then starts one
    # step earlier in the buffer, which delays it by one more step.
    x = F.pad(x, (0, 0, 0, 1)).flatten(-3, -2)[..., :r*t, :].view(x.shape)
    return x

In [None]:
def l2_normalize(x, eps=1e-5):
    """Scale each vector along the last axis of `x` to (approximately) unit L2 norm.

    `eps` is added to the squared norm before the reciprocal square root, so
    all-zero vectors stay finite.
    """
    squared_norm = x.pow(2).sum(dim=-1, keepdim=True)
    inverse_norm = torch.rsqrt(eps + squared_norm)
    return x * inverse_norm

In [None]:
def chord_indices(x):
    """1-based index of the chord that each token belongs to.

    A chord is closed by (and includes) a NEXT_TOKEN. When INCLUDE_SONG_BOUNDARIES
    is set, a BOUNDARY_TOKEN also closes a chord.
    Example: x = [5, 6, 0, 7, 0] -> [1, 1, 1, 2, 2].
    The result is an int tensor on `device`.
    """
    # Use the named tokens rather than the magic values 0 / N_EVENTS - 1, and read the
    # flag at call time so toggling it doesn't require re-running this cell.
    punctuation = x == NEXT_TOKEN
    if INCLUDE_SONG_BOUNDARIES:
        punctuation = punctuation | (x == BOUNDARY_TOKEN)
    punctuation = punctuation.int().to(device)
    # The cumsum counts separators up to and including this token. Subtracting the
    # token's own flag keeps each separator in the chord it closes.
    return punctuation.cumsum(dim=-1) - punctuation + 1

In [None]:
def chord_position(x):
    """0-based position of each token within its chord.

    A chord is the run of tokens up to and including a NEXT token (id 0); the
    count restarts right after each 0.
    Example: x = [5, 6, 0, 7, 0] -> [0, 1, 2, 0, 1].
    Accepts (T,) or (B, T) input; the result is a long tensor on `device`.
    """
    punctuation = (x == 0).long().to(device)
    n = torch.arange(x.shape[-1], dtype=torch.long).to(device)
    if x.dim() == 2:
        n = n.unsqueeze(0)
    # At each NEXT token, (n + 1) is the index where the following chord starts.
    # A running max of those values, shifted right by one step, gives the start index
    # of the chord containing each token; subtracting it from n gives the offset.
    return n - F.pad(((n + 1) * punctuation)[..., :-1].cummax(dim=-1)[0], (1, 0))

## Progress bar

In [None]:
class Progress:
    """Wrap an iterable and show an HTML progress bar that updates as it is iterated.

    it:  the iterable to wrap.
    max: the item count that represents a full bar; defaults to `len(it)`.
         If the iterable yields more than `max` items, the bar wraps around.
    """

    def __init__(self, it, max=None):
        if max is None:
            max = len(it)
        self.max = max
        self.bar = display(self.html(0, max), display_id=True)
        self.i = 0
        self.it = iter(it)

    def html(self, value, max):
        """Render a `<progress>` element showing `value` out of `max`."""
        # No comma between attributes: `max="...", style=...` is invalid HTML.
        return HTML("""
            <progress value="{value}" max="{max}" style="width: 90%; background-color: black">
            {value}
            </progress>
        """.format(value=value, max=max))

    def __iter__(self):
        for x in self.it:
            self.i += 1
            if self.max > 0:
                # 1..max, so the last item shows a full bar instead of wrapping back to 0.
                shown = (self.i - 1) % self.max + 1
            else:
                shown = self.i
            self.bar.update(self.html(shown, self.max))
            yield x

def progress(it, max=None):
    """Iterate over `it` while displaying a progress bar (see `Progress`)."""
    return Progress(it, max)

## MIDI conversion

In [None]:
def dedup(events, shifts):
    """Collapse runs of consecutive identical events into a single event.

    events, shifts: parallel sequences of equal length.
    Each run of equal consecutive events becomes one event, whose shift is the
    largest shift in the run.

    Returns (deduped_events, deduped_shifts) as lists. Empty input returns ([], []);
    the old version raised IndexError.
    """
    deduped_events = []
    deduped_shifts = []
    for e, s in zip(events, shifts):
        if deduped_events and deduped_events[-1] == e:
            deduped_shifts[-1] = max(deduped_shifts[-1], s)
        else:
            deduped_events.append(e)
            deduped_shifts.append(s)
    return (deduped_events, deduped_shifts)

In [None]:
def midi_to_instructions(mid, eps_scale=(2 ** 3)):
    """Encode a MIDI file as parallel (event, shift) sequences.

    mid: a pretty_midi.PrettyMIDI. It is deep-copied first, so the caller's
        object is not modified.
    eps_scale: timing tolerance as a fraction of a quarter note
        (eps = qnote / eps_scale).

    Returns (events, shifts, tempo):
        events: int array. BOUNDARY_TOKEN at both ends; NEXT_TOKEN advances the
            clock; any other value is a 1-based piano key index (pitch clipped to
            [MIN_PITCH, MAX_PITCH], then shifted so MIN_PITCH -> 1).
        shifts: for a note, its duration; for a NEXT_TOKEN, the time gap. Each is
            given as the index (0..11) of the nearest of 12 standard lengths,
            measured relative to the estimated quarter note. Boundary tokens get 0.
        tempo: tempo re-estimated after the timing cleanup in step 2.

    NOTE(review): the shifts are quantized against `qnote` (a representative note
    duration), but only `tempo` is returned, and instructions_to_midi decodes with
    60 / tempo. The two can differ slightly.
    NOTE(review): will raise (e.g. from estimate_tempo or `moments[0]`) on MIDI
    with no or too few non-drum notes.
    """
    mid = copy.deepcopy(mid)
    tempo = mid.estimate_tempo()
    qnote = 60 / tempo
    notes = sum([i.notes for i in mid.instruments if not i.is_drum], start=[])
    notes = sorted(notes, key=lambda n: n.start)

    #
    # Step 1: Cleanup
    #

    # The estimated tempo is probably slightly off, so we'll try to find
    # a representative note.
    for n in notes:
        if np.abs((duration := n.end - n.start) - qnote) < qnote / eps_scale:
            qnote = duration
            break
    eps = qnote / eps_scale

    # If 1/10 of the notes are about to get thrown out, shrink the quarter note.
    if np.mean([n.end - n.start < 1.5 * eps for n in notes]) >= 0.1:
        qnote /= 2
        eps /= 2

    # Throw out any garbage (notes shorter than 1.5 * eps).
    notes = [n for n in notes if n.end - n.start >= 1.5 * eps]

    #
    # Step 2: Adjusting everything
    #

    # Snap nearly simultaneous note boundaries together. Each moment is
    # (time, "s" for start / "e" for end, index into `notes`).
    moments = []
    for i, n in enumerate(notes):
        moments.append((n.start, "s", i))
        moments.append((n.end, "e", i))
    moments = sorted(moments, key=lambda m: m[0])
    clusters = []
    current_cluster = [moments[0]]
    for moment in moments[1:]:
        # Grow the cluster while the next moment is within eps of the cluster's
        # mean time and belongs to a note not already in the cluster.
        if (
            moment[0] < np.mean([m[0] for m in current_cluster]) + eps and
            moment[-1] not in [m[-1] for m in current_cluster]):
            current_cluster.append(moment)
            # Is the cluster too big?
            if current_cluster[-1][0] - current_cluster[0][0] > eps:
                # Split it: each interior moment joins whichever end (first or
                # last) it is closer to, with ties going to the first. The first
                # part is finished; the second keeps growing.
                c1, c2 = [current_cluster[0]], [current_cluster[-1]]
                for c in current_cluster[1:-1]:
                    d1 = c[0] - current_cluster[0][0]
                    d2 = current_cluster[-1][0] - c[0]
                    if d1 > d2:
                        c2.append(c)
                    else:
                        c1.append(c)
                clusters.append(c1)
                current_cluster = c2
        else:
            clusters.append(current_cluster)
            current_cluster = [moment]
    clusters.append(current_cluster)

    # Move every start/end in a cluster to the cluster's mean time (this mutates
    # the Note objects of the deep copy).
    for cluster in clusters:
        avg = np.mean([m[0] for m in cluster])
        for _, kind, i in cluster:
            if kind == "s":
                notes[i].start = avg
            else: # e
                notes[i].end = avg

    mid.instruments = [pm.Instrument(program=0)]
    mid.instruments[0].notes = notes

    # This is the last chance to get a different quarter note
    tempo = mid.estimate_tempo()
    qnote = 60 / tempo
    for n in notes:
        if np.abs((duration := n.end - n.start) - qnote) < qnote / eps_scale:
            qnote = duration
            break
    eps = qnote / eps_scale

    #
    # Step 3: Actually discretize it
    #

    # Within each onset, notes are emitted high to low. A NEXT_TOKEN, carrying the
    # time gap since the previous onset, precedes each new onset.
    notes = sorted(notes, key=lambda n: (n.start, -n.pitch)) # negative pitch so it's high to low
    events = []
    shifts = []
    clock = notes[0].start
    for n in notes:
        if n.end - n.start < eps:
            continue
        if (start := n.start) > clock:
            events.append(NEXT_TOKEN)
            shifts.append(start - clock)
            clock = start
        events.append(np.clip(n.pitch, MIN_PITCH, MAX_PITCH) - MIN_PITCH + 1)
        shifts.append(n.end - start)
    # A final NEXT_TOKEN advances the clock to the end of the last note.
    end = max(n.end for n in notes)
    events.append(NEXT_TOKEN)
    shifts.append(end - clock)

    # Make sure no notes got doubled up.
    events, shifts = dedup(events, shifts)

    events = np.array(events)
    shifts = np.array(shifts)
    # We need to make sure all the standard shifts are in order,
    # for the relative encoding to work.
    # Standard lengths in quarter notes: 1/8, then p and 1.5p for p in 1/4..4, then 8.
    powers = np.power(2., np.array([-2, -1, 0, 1, 2]))
    standard_shifts = np.empty(2 + 2 * powers.size, dtype=np.float32)
    standard_shifts[0] = 2 ** -3
    standard_shifts[-1] = 2 ** 3
    standard_shifts[1:-1:2] = powers
    standard_shifts[2:-1:2] = 1.5 * powers
    standard_shifts *= qnote
    # Replace each raw duration (seconds) with the index of the nearest standard length.
    shifts = np.argmin(
        np.abs(shifts.reshape(-1, 1) - standard_shifts.reshape(1, -1)),
        axis=-1)

    events = np.concatenate([[BOUNDARY_TOKEN], events, [BOUNDARY_TOKEN]])
    shifts = np.concatenate([[0], shifts, [0]])

    return (events, shifts, tempo)

In [None]:
powers = np.power(2., np.array([-2, -1, 0, 1, 2]))
# The 12 standard note lengths in quarter notes, shortest first, in the same order
# as SHIFT_NAMES: 1/8, then p and 1.5 * p for each power p in 1/4 .. 4, then 8.
SHIFT_LENGTHS = np.array(
    [2 ** -3] + [length for p in powers for length in (p, 1.5 * p)] + [2 ** 3],
    dtype=np.float32)

SHIFT_LENGTHS_TENSOR = torch.tensor(SHIFT_LENGTHS).to(device)

def instructions_to_midi(events, shifts, tempo=120):
    """Decode (event, shift) sequences back into a single-piano PrettyMIDI.

    Boundary tokens are skipped. A NEXT_TOKEN advances the clock by its shift
    length. Any other event becomes a full-velocity note at the current clock,
    lasting its shift length. Shift indices are converted to seconds assuming a
    quarter note lasts 60 / tempo seconds.
    """
    seconds_per_shift = SHIFT_LENGTHS * (60 / tempo)

    mid = pm.PrettyMIDI()
    piano = pm.Instrument(0) # piano
    mid.instruments.append(piano)
    notes = []
    clock = 0
    for event, shift in zip(events, shifts):
        if event == BOUNDARY_TOKEN:
            continue
        length = seconds_per_shift[shift]
        if event == NEXT_TOKEN:
            clock += length
            continue
        pitch = int(event) - 1 + MIN_PITCH
        notes.append(pm.Note(127, pitch, clock, clock + length))
    piano.notes = notes
    return mid

def plot_midi(mid, title=None, filename=None, extrapolated=0, crop=None, truncate=0, beginning_marks=False, trim=True):
    """Piano-roll plot of a MIDI file: one horizontal segment per non-drum note.

    mid: a pretty_midi.PrettyMIDI, or a path to a MIDI file.
    title: optional plot title.
    filename: if given, save the figure there and close it; otherwise show it.
    extrapolated: colour the last `extrapolated` notes (in instrument order) red
        and the rest blue.
    crop: x-axis limit, either a maximum time or a (min, max) tuple.
    truncate: shorten each segment by this many seconds at both ends.
    beginning_marks: draw the first 1% of the song length of each note at full
        opacity and the rest at half opacity, so note onsets stand out.
    trim: shift times so the earliest note starts at 0. Only the plotted
        coordinates are shifted; the caller's notes are left unchanged.

    Raises ValueError if there are no non-drum notes.
    """
    if not isinstance(mid, pm.PrettyMIDI):
        mid = pm.PrettyMIDI(str(mid))
    notes = sum([i.notes for i in mid.instruments if not i.is_drum], start=[])
    if not notes:
        raise ValueError("No non-drum notes to plot")
    # Previously `trim` did `n.start -= ...` on the Note objects themselves, which
    # silently shifted every note of a PrettyMIDI passed in by the caller.
    offset = min(n.start for n in notes) if trim else 0
    cropmax = np.inf
    if crop is not None:
        cropmax = crop[1] if isinstance(crop, tuple) else crop
    song_length = min(max(n.end for n in notes) - offset, cropmax)
    lines = [
        [(n.start - offset + truncate, n.pitch), (n.end - offset - truncate, n.pitch)]
        for n in notes]
    max_velocity = max(n.velocity for n in notes)
    alphas = [n.velocity / max_velocity for n in notes]
    colors = ["b"] * (len(notes) - extrapolated) + ["r"] * extrapolated
    if beginning_marks:
        beginning_length = song_length * 0.01
        beg_lc = mc.LineCollection(
            [[(x1, y1), (min(x2, x1 + beginning_length), y2)] for [(x1, y1), (x2, y2)] in lines],
            alpha=alphas, colors=colors)
        lc = mc.LineCollection(lines, alpha=[a * 0.5 for a in alphas], colors=colors)
    else:
        lc = mc.LineCollection(lines, alpha=alphas, colors=colors)
    fig, ax = plt.subplots()
    ax.set_xlabel("Seconds")
    ax.set_ylabel("Pitch")
    ax.add_collection(lc)
    if beginning_marks:
        ax.add_collection(beg_lc)
    ax.autoscale()
    if crop is not None:
        if isinstance(crop, tuple):
            ax.set_xlim(*crop)
        else:
            ax.set_xlim(0, crop)
    if title is not None:
        ax.set_title(title)
    if filename is None:
        plt.show()
    else:
        fig.savefig(filename)
        plt.close(fig)

In [None]:
def convert_midi(inpath, outpath, eps_scale=(2 ** 3)):
    """Round-trip the MIDI file at `inpath` through the (event, shift) encoding and
    write the quantized result to `outpath`."""
    source = pm.PrettyMIDI(inpath)
    events, shifts, tempo = midi_to_instructions(source, eps_scale=eps_scale)
    instructions_to_midi(events, shifts, tempo).write(outpath)

## Dataset

In [None]:
def shuffle_chords(events, shifts):
    """Shuffle, in place, the order of the notes inside each chord.

    events, shifts: parallel numpy arrays that are permuted together. Index 0 and
    the last index are assumed to be boundary tokens and are never touched. A chord
    is the run of events up to (not including) the next NEXT_TOKEN. Uses the
    global `np.random` state.
    """
    n = len(events)
    pos = 1 # Index 0 should be a BOS token
    while pos < n - 1: # Last token should also be BOS
        # Find the NEXT_TOKEN that closes this chord (or fall back to the last index).
        stop = pos
        while stop < n - 1 and events[stop] != NEXT_TOKEN:
            stop += 1
        if stop - pos >= 2: # Only bother if there's >1 note
            chord = np.arange(pos, stop)
            order = chord.copy()
            np.random.shuffle(order)
            events[chord] = events[order]
            shifts[chord] = shifts[order]
        pos = stop + 1

class PickleSet(Dataset):
    """Dataset of pickled `(events, shifts, tempo)` triples, as written by refresh_dataset.

    file_list: paths of the pickle files.
    augment: randomly transpose each item by up to 7 semitones either way,
        staying on the keyboard (key indices 1..N_KEYS).
    seed: seed for the transposition RNG.
    random_chords: shuffle the note order inside each chord (via shuffle_chords).
    include_boundaries: keep the BOUNDARY_TOKEN at both ends of each song.

    Items are `(filename, events, shifts)`, with events/shifts as long tensors.
    """

    def __init__(self, file_list, augment=True, seed=None, random_chords=False, include_boundaries=False):
        self.file_list = list(file_list)
        self.augment = augment
        self.rng = np.random.default_rng(seed)
        self.random_chords = random_chords
        self.include_boundaries = include_boundaries

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        filename = str(self.file_list[idx])
        # NOTE(review): pickle.load can execute arbitrary code; only load pickles you created.
        with open(filename, "rb") as infile:
            events, shifts, _tempo = pickle.load(infile)
        # shuffle_chords assumes a boundary token at each end: it skips index 0 and stops
        # before the last index. So shuffle *before* stripping the boundaries. Stripping
        # first meant the first note of the first chord could never be shuffled.
        if self.random_chords:
            shuffle_chords(events, shifts)
        if not self.include_boundaries:
            events = events[1:-1]
            shifts = shifts[1:-1]
        if self.augment:
            punctuation = (events == NEXT_TOKEN) | (events == BOUNDARY_TOKEN)
            pitches = events[~punctuation]
            # Largest transposition (capped at +/-7) that keeps every note on the keyboard.
            low = max(-7, 1 - pitches.min())
            high = min(7, N_KEYS - pitches.max())
            semitones = self.rng.integers(low=low, high=high+1)
            events[~punctuation] = pitches + semitones
        return (
            filename,
            torch.tensor(events, dtype=torch.long),
            torch.tensor(shifts, dtype=torch.long))

def _drive_glob(subdir, pattern):
    """All paths under "My Drive/<subdir>" on the mounted Drive that match `pattern`, as a list."""
    return list(pathlib.Path("/content/drive/My Drive", subdir).glob(pattern))

def get_midi_list():
    """Every .mid file under My Drive/midi, searched recursively."""
    return _drive_glob("midi", "**/*.mid")

def get_pickle_list():
    """Pickles directly in My Drive/pickle, not yet split into training/validation."""
    return _drive_glob("pickle", "*.pickle")

def get_training_set():
    """Pickles in My Drive/pickle/training."""
    return _drive_glob("pickle/training", "*.pickle")

def get_validation_set():
    """Pickles in My Drive/pickle/validation."""
    return _drive_glob("pickle/validation", "*.pickle")

def collate(max_length, b, seed=None):
    """Batch `(filename, events, shifts)` items into fixed-length tensors on `device`.

    max_length: length of every output sequence.
    b: iterable of dataset items, e.g. from PickleSet.
    seed: seed for choosing crop windows.

    Sequences longer than max_length are cropped to a random window. 10% of the
    time the window is flush with the start or the end. Shorter sequences are
    right-padded with zeros.

    Returns (filenames, masks, events, shifts, chords). masks is 1.0 for real
    tokens and 0.0 for padding. chords is each token's position within its
    chord, computed on the full sequence before cropping.
    """
    rng = np.random.default_rng(seed)
    masks_batch = []
    event_batch = []
    shift_batch = []
    chord_batch = []
    filenames = []
    for filename, events, shifts in b:
        chords = chord_position(events)
        filenames.append(filename)
        length = len(events)
        # Default: every position is real. Previously `mask` was never assigned when
        # length == max_length, which reused the previous item's (possibly padded)
        # mask, or raised UnboundLocalError on the first item.
        mask = torch.ones(max_length)
        if (excess := length - max_length) > 0:
            if (rng.random() < 0.1):
                # Take from one of the ends
                i = 0 if rng.random() < 0.5 else excess
            else:
                # Take a random slice
                i = rng.integers(low=0, high=excess+1)
            events = events[i:i+max_length]
            shifts = shifts[i:i+max_length]
            chords = chords[i:i+max_length]
            length = max_length
        if length < max_length:
            padding = max_length - length
            events = F.pad(events, (0, padding))
            shifts = F.pad(shifts, (0, padding))
            chords = F.pad(chords, (0, padding))
            mask = torch.cat([torch.ones(length), torch.zeros(padding)])
        masks_batch.append(mask)
        event_batch.append(events)
        shift_batch.append(shifts)
        chord_batch.append(chords)
    return (
        filenames,
        torch.stack(masks_batch).to(device),
        torch.stack(event_batch).to(device),
        torch.stack(shift_batch).to(device),
        torch.stack(chord_batch).to(device))

def batch_cycle(dataloader):
    """Yield batches from `dataloader` forever, starting a new pass whenever one ends."""
    while True:
        yield from dataloader

In [None]:
def refresh_dataset(convert=False):
    """Rebuild the pickled dataset on Drive from the MIDI files.

    Deletes every existing pickle and re-encodes each MIDI file with
    midi_to_instructions. Then moves roughly 1% of the new pickles into
    pickle/validation and the rest into pickle/training.

    convert: also write each round-tripped MIDI to My Drive/converted/. A file
        that fails to convert is reported and gets no pickle.

    Files that fail to load or encode are reported and skipped, so one bad file
    doesn't abort the rebuild after the old pickles are already gone.
    NOTE(review): the train/validation split uses the unseeded global np.random.
    """
    print("Deleting old pickles")
    print("   ... from the training set ...")
    for path in progress(get_training_set()):
        os.remove(path)
    print("   ... from the validation set ...")
    for path in progress(get_validation_set()):
        os.remove(path)
    print("   ... from the parent folder ...")
    for path in progress(get_pickle_list()):
        os.remove(path)

    print("Creating new pickles ...")
    for path in progress(get_midi_list()):
        base = str(path.parent / path.stem).replace("/", "_")
        pickle_name = base + ".pickle"
        filename = str(path)
        try:
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=RuntimeWarning)
                mid = pm.PrettyMIDI(filename)
            stuff = midi_to_instructions(mid)
        except Exception as e:
            print(f"Caught exception while reading file {path}: {e}")
            continue
        if convert:
            try:
                mid = instructions_to_midi(*stuff)
                converted_name = base + ".mid"
                mid.write("/content/drive/My Drive/converted/" + converted_name)
            except Exception as e:
                print(f"Caught exception while converting file {path}: {e}")
                continue
        with open("/content/drive/My Drive/pickle/" + pickle_name, "wb") as outfile:
            pickle.dump(stuff, outfile)

    print("Making training and validation sets ...")
    for path in progress(get_pickle_list()):
        if np.random.rand() < 0.01:
            path.rename(path.parent / "validation" / path.name)
        else:
            path.rename(path.parent / "training" / path.name)
    print("Done!")

In [None]:
#refresh_dataset(convert=True)

## Plots

In [None]:
def plot_attention(A, filename=None, standardize=False, title="Attention heads"):
    """Heat-map each attention head in a grid of subplots.

    A: (heads, T, T) or (T, T) tensor. Rows are treated as queries: entries
        above the diagonal are masked as future positions. It may be on any device.
    filename: if given, also save the figure there.
    standardize: multiply row i by i + 1, so that uniform attention over the
        first i + 1 positions plots as 1.
    title: figure title.

    Raises ValueError for other ranks or for zero heads.
    """
    # .cpu() so CUDA tensors (e.g. straight from the model) can be turned into numpy.
    A = A.detach().cpu()
    if (d := A.dim()) not in [2, 3]:
        raise ValueError(f"Attention must have dimension 2 or 3, but got {d}")
    if d == 2:
        A = A.unsqueeze(0)
    if standardize:
        A = A * torch.arange(1, A.shape[1] + 1).view(1, A.shape[1], 1)
    n_heads = A.shape[0]
    if n_heads == 0:
        raise ValueError("Attention has no heads to plot")
    ncol = int(np.ceil(np.sqrt(n_heads)))
    # Enough rows for every head. floor(sqrt(n)) rows gave too few axes for
    # 3, 7 or 8 heads and crashed with IndexError.
    nrow = int(np.ceil(n_heads / ncol))
    fig, axs = plt.subplots(nrow, ncol, layout="constrained")
    # Copy so set_bad doesn't change matplotlib's globally registered colormap.
    cmap = copy.copy(plt.cm.hot)
    cmap.set_bad("k", 0.5)
    for h in range(n_heads):
        try:
            ax = axs.flat[h]
        except AttributeError: # No method flat
            ax = axs
        ax.invert_yaxis()
        a = A[h].numpy()
        a = np.ma.masked_array(a, mask=np.triu(np.ones_like(a), 1))
        pcm = ax.pcolormesh(a, cmap=cmap)
        ax.set_title(f"Head {h+1}")
    for unused in range(n_heads, nrow * ncol):
        axs.flat[unused].set_visible(False)
    try:
        for i, ax in enumerate(axs.flat):
            if i % ncol != 0:
                ax.get_yaxis().set_visible(False)
            if i < (nrow - 1) * ncol:
                ax.get_xaxis().set_visible(False)
    except AttributeError: # No flat
        fig.colorbar(pcm)
    else:
        fig.colorbar(pcm, ax=axs[..., -1])
    fig.suptitle(title)
    if filename is not None:
        plt.savefig(filename)

# Modules

## Parts

### Relative position

In [None]:
def rot(x):
    """Rotate each consecutive pair (a, b) along the last axis to (-b, a), i.e. by 90 degrees.

    Combined as `x * cos + rot(x) * sin`, this rotates every pair by the given angle.
    """
    assert x.shape[-1] % 2 == 0, f"Input must have even last dimension, got {x.shape[-1]}"
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    first, second = pairs[..., 0], pairs[..., 1]
    return torch.stack((-second, first), dim=-1).reshape(x.shape)

In [None]:
class RotaryEncoding(nn.Module):
    """Rotary position encoding with a per-frequency decay, applied to queries and keys.

    Config keys read: d_qk, n_heads, max_context, decay_base, zeta_floor.
    Each head has d_head = d_qk // n_heads dimensions (which must be a multiple
    of 4) and uses d_head / 4 angular frequencies,
        theta_k = 2*pi * 2**(-k / M),  M = d_head / (4 * log2(max_context)),
    whose periods 2**(k / M) grow geometrically from 1 token up to about
    max_context tokens.
    Per-frequency decay: zeta_k = max(decay_base ** (1 / period_k), zeta_floor).

    NOTE(review): theta_0 = 2*pi, so at integer positions the first frequency
    pair is never actually rotated — confirm this is intended.
    theta/zeta are plain tensors (not buffers); forward moves derived values to `device`.
    """
    def __init__(self, **config):
        super().__init__()
        for hp in ["d_qk", "n_heads", "max_context", "decay_base", "zeta_floor"]:
            setattr(self, hp, config[hp])
        if (self.d_qk // self.n_heads) % 4 != 0:
            raise ValueError(f"`d_qk / n_heads` must be a multiple of 4, got {self.d_qk // self.n_heads}")
        M = (self.d_qk // self.n_heads) / (4 * np.log2(self.max_context))
        self.theta = 2 * torch.pi * 2 ** (-torch.arange((self.d_qk // self.n_heads) // 4) / M)
        self.zeta = torch.maximum(self.decay_base ** (self.theta / (2 * torch.pi)), torch.tensor(self.zeta_floor))

    def forward(self, Q, K, m=None):
        """Rotate and scale Q and K by position.

        Q, K: (..., T, d_head), with T read from Q.shape[-2].
        m: optional positions, flattened to (T, 1); defaults to arange(T).
        Returns (Q', K'), where Q' is scaled by zeta**m and K' by zeta**-m. The
        products Q'_i . K'_j therefore pick up per-dimension factors zeta**(i - j)
        (similar in spirit to xPos — TODO confirm).
        """
        if m is None:
            m = torch.arange(Q.shape[-2])
        m = m.view(-1, 1)
        arg = m * self.theta.view(1, -1)
        # Duplicate each frequency over an adjacent pair of dims (matching `rot`),
        # then again over the two halves of the head dimension.
        c = torch.cos(arg); c = torch.stack([c, c], dim=-1).flatten(-2)
        s = torch.sin(arg); s = torch.stack([s, s], dim=-1).flatten(-2)
        Cq = torch.cat([c,  c], dim=-1).to(device)
        Sq = torch.cat([s,  s], dim=-1).to(device)
        # In the second half, K uses (cos, sin) = (-sin, cos), i.e. its rotation
        # angle is offset by an extra pi/2 relative to Q's.
        Ck = torch.cat([c, -s], dim=-1).to(device)
        Sk = torch.cat([s,  c], dim=-1).to(device)
        T = torch.pow(self.zeta.view(1, -1), m)
        T = torch.stack([T, T], dim=-1).flatten(-2)
        T = torch.cat([T, T], dim=-1).to(device)
        # Broadcast the (T, d_head) tables over Q's leading (batch/head) dims.
        for _ in range(Q.dim() - 2):
            Cq = Cq.unsqueeze(0); Sq = Sq.unsqueeze(0)
            Ck = Ck.unsqueeze(0); Sk = Sk.unsqueeze(0)
            T = T.unsqueeze(0)
        Q = (Q * Cq + rot(Q) * Sq) * T
        K = (K * Ck + rot(K) * Sk) / T
        return (Q, K)

### Attention

In [None]:
class Drophead(nn.Module):
    """Randomly zero whole attention heads while training.

    Input y: (..., heads, T, D), as produced inside RotaryAttention. Each head is
    kept with probability 1 - p, and the output is scaled by
    heads / (number of kept heads + eps). The module is the identity in eval mode.
    """

    def __init__(self, p, eps=DEFAULT_EPSILON):
        super().__init__()
        self.p = p
        self.eps = eps

    def forward(self, y):
        if self.training:
            *lead, n_heads = y.shape[:-2]
            keep_prob = torch.full((*lead, n_heads), 1 - self.p)
            keep = torch.bernoulli(keep_prob).to(device).view(*lead, n_heads, 1, 1)
            rescale = n_heads / (keep.sum(dim=-3, keepdim=True) + self.eps)
            y = y * keep * rescale
        return y

In [None]:
class MusicAttention(nn.Module):
    """Local, chord-aware multi-head attention over a sliding window.

    Config keys read: d_model, n_heads, initial_window, d_qk, d_v.
    Each position attends to itself and the previous `initial_window - 1`
    positions. Keys are L2-normalized per head. Keys in the same chord as the
    query (per chord_indices) also get a learned additive vector `same_chord`
    (split across heads). Queries are not scaled.
    Note: `initialize` is not called from __init__; the owner must call it.
    """
    def __init__(self, **config):
        super().__init__()
        for hp in ["d_model", "n_heads", "initial_window", "d_qk", "d_v"]:
            setattr(self, hp, config[hp])
        if self.d_qk % self.n_heads != 0:
            raise ValueError(f"`d_qk` must be a multiple of `n_heads`, got {self.d_qk} and {self.n_heads}")
        if self.d_v % self.n_heads != 0:
            raise ValueError(f"`d_v` must be a multiple of `n_heads`, got {self.d_v} and {self.n_heads}")
        self.same_chord = nn.Parameter(torch.randn(self.d_qk) / np.sqrt(self.d_qk))
        self.Q = nn.Linear(self.d_model, self.d_qk, bias=False)
        self.K = nn.Linear(self.d_model, self.d_qk, bias=False)
        self.V = nn.Linear(self.d_model, self.d_v , bias=False)
        self.unify_heads = nn.Linear(self.d_v, self.d_model)

    def initialize(self):
        # Zero queries -> uniform attention within the window at the start of training.
        nn.init.zeros_(self.Q.weight.data)
        for m in [self.K, self.V, self.unify_heads]:
            nn.init.xavier_normal_(m.weight.data)
        nn.init.zeros_(self.unify_heads.bias.data)

    def forward(self, q, kv, events, shifts):
        """Windowed attention from `q` to `kv`.

        q, kv: (B?, T, d_model); rolling() allows at most one leading batch dim.
        events: (B?, T) event ids, used only to find chord membership.
        shifts: unused.
        Returns (y, A): y is (B?, T, d_model), and A is (B?, W, T, heads), the
        attention weights over window offsets, where offset w means "w steps back".
        """
        *b, t = q.shape[:-1]
        w = self.initial_window
        h = self.n_heads
        ch = chord_indices(events).unsqueeze(-1)

        # matches[..., w, t] = 1 if the token w steps before t is in t's chord
        # (copy 0 of the rolled tensor is the unshifted query chord).
        matches = rolling(ch, w).squeeze(-1)
        matches = (matches == matches[..., 0, :].unsqueeze(-2)).unsqueeze(-1).float()
        Km = matches * self.same_chord.view((1,) * (matches.dim() - 1) + (-1,))
        Km = Km.view(*Km.shape[:-1], h, -1)

        Q = self.Q(q).view(*b, t, h, -1)
        # Head-wise normalize before rolling, to avoid duplicating work.
        K = rolling(l2_normalize(self.K(kv).view(*b, t, h, -1)).flatten(-2), w).view(*b, w, t, h, -1)
        V = rolling(self.V(kv), w).view(*b, w, t, h, -1)

        dot = torch.einsum("...THD,...WTHD->...WTH", Q, K + Km)
        # Mask offsets w > t, which would reach before the start of the sequence
        # (rolling() fills those with filler values).
        dot = dot - torch.full(dot.shape[:-1], torch.inf).tril(diagonal=-1).unsqueeze(-1).to(device)
        # Softmax over the window axis.
        A = F.softmax(dot, dim=-3)

        y = torch.einsum("...WTH,...WTHD->...THD", A, V)
        y = self.unify_heads(y.flatten(-2))
        return (y, A)

In [None]:
class RotaryAttention(nn.Module):
    """Causal multi-head attention whose positional transform is supplied at call time.

    Config keys read: d_model, n_heads, d_qk, d_v, drophead_p.
    Keys are L2-normalized per head. Whole heads are randomly dropped in training
    (Drophead). The query projection starts at zero.
    """

    def __init__(self, **config):
        super().__init__()
        for name in ("d_model", "n_heads", "d_qk", "d_v", "drophead_p"):
            setattr(self, name, config[name])
        self.Q = nn.Linear(self.d_model, self.d_qk, bias=False)
        self.K = nn.Linear(self.d_model, self.d_qk, bias=False)
        self.V = nn.Linear(self.d_model, self.d_v, bias=False)
        self.drophead = Drophead(self.drophead_p)
        self.unify_heads = nn.Linear(self.d_v, self.d_model)
        self.initialize()

    def initialize(self):
        """Zero the query weights, Xavier-init K/V/output weights, zero the output bias."""
        nn.init.zeros_(self.Q.weight.data)
        nn.init.xavier_normal_(self.K.weight.data)
        nn.init.xavier_normal_(self.V.weight.data)
        nn.init.xavier_normal_(self.unify_heads.weight.data)
        nn.init.zeros_(self.unify_heads.bias.data)

    def forward(self, q, kv, rot, t=None):
        """Attend from `q` to `kv`.

        rot: callable (Q, K, positions) -> (Q, K), e.g. a RotaryEncoding; `t` is
        passed through as the positions. Returns (output, None).
        """
        n_heads = self.n_heads

        def split_heads(x):
            # (..., T, n_heads * d) -> (..., T, n_heads, d)
            return x.view(*x.shape[:-1], n_heads, -1)

        def heads_first(x):
            # (..., T, n_heads, d) -> (..., n_heads, T, d)
            return x.transpose(-3, -2).contiguous()

        Q = heads_first(split_heads(self.Q(q)))
        K = heads_first(l2_normalize(split_heads(self.K(kv))))
        V = heads_first(split_heads(self.V(kv)))
        Q, K = rot(Q, K, t)
        out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
        out = self.drophead(out).transpose(-2, -3)
        return (self.unify_heads(out.flatten(-2)), None)

In [None]:
class TransformerBlock(nn.Module):
    """Residual block: attention, then a feed-forward layer, each added back through a
    learned scalar gate (`alpha1`, `alpha2`) that starts at 0.

    Because both gates start at 0, a freshly built block returns its query input unchanged.
    attention: a module with `initialize()` whose forward returns `(output, extra)`.
    Config keys read: d_model, d_v, d_ff, n_heads, dropout_p.
    """

    def __init__(self, attention, **config):
        super().__init__()
        for name in ("d_model", "d_v", "d_ff", "n_heads", "dropout_p"):
            setattr(self, name, config[name])
        self.attention = attention
        self.dropout_attention = nn.Dropout(self.dropout_p)
        layers = [
            nn.Linear(self.d_model, self.d_ff),
            nn.Dropout(self.dropout_p),
            nn.GELU(),
            nn.Linear(self.d_ff, self.d_model),
            nn.Dropout(self.dropout_p),
        ]
        self.feedforward = nn.Sequential(*layers)
        self.alpha1 = nn.Parameter(torch.tensor(0.))
        self.alpha2 = nn.Parameter(torch.tensor(0.))
        self.initialize()

    def initialize(self):
        """Re-initialize the attention module and both feed-forward linear layers."""
        self.attention.initialize()
        for linear in (self.feedforward[0], self.feedforward[3]):
            nn.init.xavier_normal_(linear.weight.data)
            nn.init.zeros_(linear.bias.data)

    def forward(self, q, kv, *args):
        attended, _ = self.attention(q, kv, *args)
        x = q + self.alpha1 * self.dropout_attention(attended)
        return x + self.alpha2 * self.feedforward(x)

## Model class

In [None]:
# Hyperparameters used by Model; any keys passed to Model(**config) override these.
default_config = dict(
    d_model=256,
    max_chord=7,          # chord positions >= max_chord - 1 share one embedding
    initial_window=17,    # window length of the first (MusicAttention) block
    d_qk=256,
    d_v=512,
    d_ff=512,
    n_heads=8,
    n_shared_blocks=6,
    n_event_blocks=3,
    n_shift_blocks=3,
    dropout_p=0.1,
    drophead_p=0.1,
    decay_base=0.99,      # RotaryEncoding decay
    zeta_floor=0.995,     # lower bound on the RotaryEncoding decay factor
    max_context=2048,     # sets the longest RotaryEncoding period
)

class Model(nn.Module):
    """Two-branch transformer over (event, shift) token sequences.

    A shared trunk (one MusicAttention block, then `n_shared_blocks` rotary
    blocks) feeds two branches:
      * the event branch predicts the next event;
      * the shift branch uses the *next* event's embedding as its queries and
        attends to the trunk output, so it predicts that event's shift.
    Config keys default to `default_config`.
    """

    def __init__(self, **config):
        super().__init__()
        config = default_config | config

        for hp in ["n_shared_blocks", "n_event_blocks", "n_shift_blocks", "d_model", "max_chord", "d_qk", "n_heads"]:
            setattr(self, hp, config[hp])

        self.event_embedding = nn.Embedding(N_EVENTS, self.d_model)
        self.shift_embedding = nn.Embedding(N_SHIFTS, self.d_model)
        self.chord_embedding = nn.Embedding(self.max_chord, self.d_model)
        # Positional transform shared by every RotaryAttention block.
        self.rot = RotaryEncoding(**config)

        self.initial_block = TransformerBlock(MusicAttention(**config), **config)
        self.shared_blocks = nn.ModuleList()
        self.event_blocks = nn.ModuleList()
        self.shift_blocks = nn.ModuleList()

        for _ in range(self.n_shared_blocks):
            self.shared_blocks.append(TransformerBlock(RotaryAttention(**config), **config))
        for _ in range(self.n_event_blocks):
            self.event_blocks.append(TransformerBlock(RotaryAttention(**config), **config))
        for _ in range(self.n_shift_blocks):
            self.shift_blocks.append(TransformerBlock(RotaryAttention(**config), **config))

        self.predict_event = nn.Linear(self.d_model, N_EVENTS)
        self.predict_shift = nn.Linear(self.d_model, N_SHIFTS)

    def initialize(self):
        """Re-initialize every block and both output projections."""
        self.initial_block.initialize()
        for b in self.shared_blocks:
            b.initialize()
        for b in self.event_blocks:
            b.initialize()
        for b in self.shift_blocks:
            b.initialize()
        for m in [self.predict_event, self.predict_shift]:
            nn.init.xavier_normal_(m.weight.data)
            nn.init.zeros_(m.bias.data)

    def get_logits(self, events, shifts, chords):
        """Unnormalized scores for events and shifts.

        events, shifts, chords: (T,) or (B, T) long tensors holding event ids,
            shift ids and position within the chord (clamped to max_chord - 1).
        Returns (event_logits, shift_logits), of shape (..., T, N_EVENTS) and
        (..., T, N_SHIFTS). Position i of both scores the token at i + 1. The
        shift query at i is built from event i + 1; the last query is all zeros.
        """
        chords = torch.minimum(chords, torch.tensor(self.max_chord - 1).to(device))
        event_emb = self.event_embedding(events)
        shift_emb = self.shift_embedding(shifts)
        chord_emb = self.chord_embedding(chords)

        # Shared blocks
        x0 = event_emb + shift_emb + chord_emb
        x = self.initial_block(x0, x0, events, shifts)
        for b in self.shared_blocks:
            x = b(x, x, self.rot)
        x0 = x

        # Event branch
        x = x0
        for b in self.event_blocks:
            x = b(x, x, self.rot)
        event_logits = self.predict_event(x)

        # Shift branch
        kv = x0
        q = torch.cat([
            event_emb[..., 1:, :] + chord_emb[..., 1:, :],
            torch.zeros(*event_emb.shape[:-2], 1, event_emb.shape[-1]).to(device)], dim=-2)
        for b in self.shift_blocks:
            q = b(q, kv, self.rot)
        shift_logits = self.predict_shift(q)

        return (event_logits, shift_logits)

    def loss(self, el, sl, events, shifts, masks):
        """Masked negative log-likelihood of the targets.

        el, sl: log-probabilities over events / shifts, (..., T, N_EVENTS) / (..., T, N_SHIFTS).
        events, shifts: (..., T) target ids.
        masks: (..., T) weights; 0 marks padding.
        Returns (event_loss, shift_loss), each summed over positions and divided
        by the number of positions whose mask is > 0.
        """
        # Clamp so a batch with no valid positions gives 0 rather than NaN (0 / 0).
        d = (masks > 0).sum().clamp(min=1)
        Le = -masks * torch.take_along_dim(el, events.unsqueeze(-1), -1).squeeze(-1) / d
        Ls = -masks * torch.take_along_dim(sl, shifts.unsqueeze(-1), -1).squeeze(-1) / d
        return (Le.sum(), Ls.sum())

    def forward(self, events, shifts, chords, masks):
        """Training objective: returns (event_loss, shift_loss) for next-token prediction.

        masks: (..., T), 1 for real tokens and 0 for padding (as built by `collate`).
        """
        el, sl = self.get_logits(events, shifts, chords)
        # Position i predicts the token at i + 1.
        events, el = events[..., 1:], el[..., :-1, :]
        shifts, sl = shifts[..., 1:], sl[..., :-1, :]
        masks = masks[..., 1:]
        el = F.log_softmax(el, dim=-1)
        sl = F.log_softmax(sl, dim=-1)
        return self.loss(el, sl, events, shifts, masks)

# Training

In [None]:
def quadratic_sample(model, filename, t=100, first_event=50, first_shift=7, tempo=140):
    """Autoregressively sample a piece from ``model`` and write it as MIDI.

    "Quadratic" because the whole prefix is re-encoded at every step (two
    full ``get_logits`` calls per sampled pair), so cost grows with ``t**2``.

    Each step:
      1. samples the next event from the event logits at the last position;
      2. appends a placeholder shift of 0 so ``events``/``shifts`` stay aligned;
      3. re-runs the model and samples the shift from position ``-2``, whose
         shift logits are conditioned on the event just sampled at ``-1``;
      4. replaces the placeholder with the sampled shift.

    The model is put in eval mode while sampling (the training loop calls this
    between steps while the model is in train mode). Its previous mode is
    restored afterwards, even if sampling raises.

    Args:
        model: module exposing ``get_logits(events, shifts, chords)``.
        filename: path passed to ``mid.write``.
        t: number of sampling steps after the seed pair.
        first_event, first_shift: seed tokens.
        tempo: forwarded to ``instructions_to_midi``.

    Returns:
        The object built by ``instructions_to_midi`` (already written to disk).
    """
    was_training = model.training
    model.eval()
    try:
        events = torch.tensor([first_event]).to(device)
        shifts = torch.tensor([first_shift]).to(device)

        with torch.no_grad():
            for _ in progress(range(t)):
                # Sample the next event.
                chords = chord_position(events)
                el, _ = model.get_logits(events, shifts, chords)
                e = torch.distributions.Categorical(logits=el[-1]).sample().unsqueeze(0)
                events = torch.cat([events, e])
                shifts = torch.cat([shifts, torch.tensor([0]).to(device)])  # placeholder

                # Sample the shift that accompanies the new event.
                chords = chord_position(events)
                _, sl = model.get_logits(events, shifts, chords)
                s = torch.distributions.Categorical(logits=sl[-2]).sample().unsqueeze(0)
                shifts = torch.cat([shifts[:-1], s])
    finally:
        model.train(was_training)

    mid = instructions_to_midi(events.to("cpu"), shifts.to("cpu"), tempo)
    mid.write(filename)
    return mid

In [None]:
# Model hyperparameters passed to Model(**config). When resuming, the config
# stored in the checkpoint (if any) replaces this one.
codename = "small"
config = {
    "d_model": 128,
    "d_qk": 128,
    "d_v": 256,
    "d_ff": 256,
    "n_heads": 8,
    "n_shared_blocks": 3,
    "n_event_blocks": 2,
    "n_shift_blocks": 2,
    "max_context": 1024
}

# Checkpoint to resume from, or None to start from scratch, e.g.
# "/content/drive/My Drive/model/small_2023-10-28_30000.pt"
load_path = None

if load_path is None:
    model = Model(**config).to(device)
    model.initialize()
    opt = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9)
    iteration = 0
    event_loss_history = []
    shift_loss_history = []
    loss_history = []
    validation_loss_history = []
else:
    # map_location puts tensors on the current device, so a checkpoint written
    # on a GPU runtime can be resumed on a CPU-only runtime (and vice versa).
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects
    # pickled numpy scalars (older checkpoints stored them in
    # validation_loss_history); pass weights_only=False for trusted files if
    # loading fails.
    saved_state = torch.load(load_path, map_location=device)
    config = saved_state.get("config", dict())
    model = Model(**config).to(device)
    model.load_state_dict(saved_state["model_state_dict"])
    opt = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9)
    opt.load_state_dict(saved_state["optimizer_state_dict"])
    iteration = saved_state["iteration"]
    event_loss_history = saved_state["event_loss_history"]
    shift_loss_history = saved_state["shift_loss_history"]
    loss_history = saved_state["loss_history"]
    validation_loss_history = saved_state["validation_loss_history"]

# Training batches. batch_cycle is drawn from with next() for every step below,
# so it presumably cycles the loader indefinitely -- confirm in its definition.
dataset = PickleSet(get_training_set())
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=partial(collate, CONTEXT_WINDOW), drop_last=True)
batches = batch_cycle(dataloader)

# Validation loader is built once and re-iterated at every report, rather than
# being rebuilt inside the loop (which also overwrote `dataset`).
validation_dataset = PickleSet(get_validation_set())
validation = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=partial(collate, CONTEXT_WINDOW), drop_last=True)

update_freq = 100    # report training/validation loss every N steps
sample_freq = 1_000  # write a sampled MIDI file every N steps
save_freq = 5_000    # checkpoint every N steps

model.train()
for i in progress(range(iteration, N_TRAINING_STEPS), update_freq):
    _names, mask, events, shifts, chords = next(batches)
    opt.zero_grad()
    event_loss, shift_loss = model(events, shifts, chords, mask)
    loss = event_loss + shift_loss
    loss.backward()
    opt.step()

    loss_history.append(loss.item())
    event_loss_history.append(event_loss.item())
    shift_loss_history.append(shift_loss.item())

    if (i + 1) % update_freq == 0:
        print(f"{i+1} iterations")

        med_loss = np.median(loss_history[-update_freq:])
        med_event_loss = np.median(event_loss_history[-update_freq:])
        med_shift_loss = np.median(shift_loss_history[-update_freq:])
        print(f"loss: {med_loss:.3f} ~ {med_event_loss:.3f} + {med_shift_loss:.3f}")

        # Validate in eval mode (disables train-only behaviour such as dropout,
        # if the model has any), then return to train mode for the next step.
        # Separate names keep the training-step variables intact.
        model.eval()
        with torch.no_grad():
            v = []
            for _val_names, val_mask, val_events, val_shifts, val_chords in validation:
                val_event_loss, val_shift_loss = model(val_events, val_shifts, val_chords, val_mask)
                v.append((val_event_loss + val_shift_loss).item())
        model.train()

        # Plain float so the saved history contains no numpy objects.
        med_v = float(np.median(v))
        print(f"validation: {med_v:.3f}")
        validation_loss_history.append(med_v)

        print()

    if (i + 1) % sample_freq == 0:
        with torch.no_grad():
            filename = f"/content/drive/My Drive/Colab output/{codename}_{dt.date.today()}_{i+1}.mid"
            quadratic_sample(model, filename)

    if (i + 1) % save_freq == 0:
        save_path = f"/content/drive/My Drive/model/{codename}_{dt.date.today()}_{i+1}.pt"
        torch.save(
            {
                "config": config,
                "iteration": i+1,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": opt.state_dict(),
                "event_loss_history": event_loss_history,
                "shift_loss_history": shift_loss_history,
                "loss_history": loss_history,
                "validation_loss_history": validation_loss_history},
            save_path)