Installs

In [1]:
!apt install fluidsynth

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  fluid-soundfont-gm libevdev2 libfluidsynth3 libgudev-1.0-0 libinput-bin
  libinput10 libinstpatch-1.0-2 libmd4c0 libmtdev1 libqt5core5a libqt5dbus5
  libqt5gui5 libqt5network5 libqt5svg5 libqt5widgets5 libwacom-bin
  libwacom-common libwacom9 libxcb-icccm4 libxcb-image0 libxcb-keysyms1
  libxcb-render-util0 libxcb-util1 libxcb-xinerama0 libxcb-xinput0 libxcb-xkb1
  libxkbcommon-x11-0 qsynth qt5-gtk-platformtheme qttranslations5-l10n
  timgm6mb-soundfont
Suggested packages:
  fluid-soundfont-gs qt5-image-formats-plugins qtwayland5 jackd
The following NEW packages will be installed:
  fluid-soundfont-gm fluidsynth libevdev2 libfluidsynth3 libgudev-1.0-0
  libinput-bin libinput10 libinstpatch-1.0-2 libmd4c0 libmtdev1 libqt5core5a
  libqt5dbus5 libqt5gui5 libqt5network5 libqt5svg5 libqt5widgets5 libwacom-bin
  libwacom-common libwacom9 libx

In [2]:
!git clone https://github.com/jthickstun/anticipation.git
!pip install ./anticipation
!pip install -r anticipation/requirements.txt

Cloning into 'anticipation'...
remote: Enumerating objects: 1526, done.
remote: Counting objects: 100% (351/351), done.
remote: Compressing objects: 100% (115/115), done.
remote: Total 1526 (delta 286), reused 275 (delta 236), pack-reused 1175 (from 2)
Receiving objects: 100% (1526/1526), 56.24 MiB | 35.35 MiB/s, done.
Resolving deltas: 100% (1009/1009), done.
Processing ./anticipation
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: anticipation
  Building wheel for anticipation (setup.py) ... done
  Created wheel for anticipation: filename=anticipation-1.0-py3-none-any.whl size=18682 sha256=107b29c3224497207393133733355c7f3e0abe4c99a4f77973891305fd87d0a5
  Stored in directory: /tmp/pip-ephem-wheel-cache-n40bfkfe/wheels/00/47/a1/fce9dedfd7d5c624e471dc01096a22fd7c945799cf58510c11
Successfully built anticipation
Installing collected packages: anticipation
Successfully installed anticipation-1.0
Collecting matplotlib==3.

In [3]:
!pip install tokenizers
!pip install "midi2audio==0.1.1"
!pip install "mido==1.2.10"

Collecting midi2audio==0.1.1
  Using cached midi2audio-0.1.1-py2.py3-none-any.whl.metadata (5.7 kB)
Using cached midi2audio-0.1.1-py2.py3-none-any.whl (8.7 kB)
Installing collected packages: midi2audio
Successfully installed midi2audio-0.1.1
Collecting mido==1.2.10
  Using cached mido-1.2.10-py2.py3-none-any.whl.metadata (3.4 kB)
Using cached mido-1.2.10-py2.py3-none-any.whl (51 kB)
Installing collected packages: mido
Successfully installed mido-1.2.10


Set up the runtime environment

In [4]:
import sys,time

import midi2audio
import transformers
from transformers import AutoModelForCausalLM

from IPython.display import Audio

from anticipation import ops
from anticipation.sample import generate
from anticipation.tokenize import extract_instruments
from anticipation.convert import events_to_midi,midi_to_events
from anticipation.visuals import visualize
from anticipation.config import *
from anticipation.vocab import *

In [5]:
SMALL_MODEL = 'stanford-crfm/music-small-800k'     # faster inference, worse sample quality
MEDIUM_MODEL = 'stanford-crfm/music-medium-800k'   # slower inference, better sample quality
LARGE_MODEL = 'stanford-crfm/music-large-800k'     # slowest inference, best sample quality

# load an anticipatory music transformer
model = AutoModelForCausalLM.from_pretrained(SMALL_MODEL).cuda()

# a MIDI synthesizer
fs = midi2audio.FluidSynth('/usr/share/sounds/sf2/FluidR3_GM.sf2')

# the MIDI synthesis script
def synthesize(fs, tokens):
    mid = events_to_midi(tokens)
    mid.save('tmp.mid')
    fs.midi_to_audio('tmp.mid', 'tmp.wav')
    return 'tmp.wav'

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.



In [6]:
import math

import torch
import torch.nn.functional as F

from tqdm import tqdm

from anticipation import ops
from anticipation.config import *
from anticipation.vocab import *

Custom functions

In [7]:
def safe_logits(logits, idx):
    logits[CONTROL_OFFSET:SPECIAL_OFFSET] = -float('inf') # don't generate controls
    logits[SPECIAL_OFFSET:] = -float('inf')               # don't generate special tokens

    # don't generate stuff in the wrong time slot
    if idx % 3 == 0:
        logits[DUR_OFFSET:DUR_OFFSET+MAX_DUR] = -float('inf')
        logits[NOTE_OFFSET:NOTE_OFFSET+MAX_NOTE] = -float('inf')
    elif idx % 3 == 1:
        logits[TIME_OFFSET:TIME_OFFSET+MAX_TIME] = -float('inf')
        logits[NOTE_OFFSET:NOTE_OFFSET+MAX_NOTE] = -float('inf')
    elif idx % 3 == 2: #expecting a note token
        logits[TIME_OFFSET:TIME_OFFSET+MAX_TIME] = -float('inf')
        logits[DUR_OFFSET:DUR_OFFSET+MAX_DUR] = -float('inf')

    return logits


def nucleus(logits, top_p):
    # from HF implementation
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p

        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = -float("inf")

    return logits


def future_logits(logits, curtime):
    """ don't sample events in the past """
    if curtime > 0:
        logits[TIME_OFFSET:TIME_OFFSET+curtime] = -float('inf')

    return logits


def instr_logits_part1(logits, full_history, instruments):
    """ don't sample more than 16 instruments """
    instrs = ops.get_instruments(full_history)
    print("instruments full history", instrs)

    if instruments is not None:
    #ONLY ALLOW SPECIFIED INSTRUMENTS, BE CAREFUL -- which instruments are present in full_history?
        #print("ONLY ALLOW SPECIFIED INSTRUMENTS")
        for instr_id in range(128):
            if instr_id not in instruments:
                #print("block instrument", instr_id)
                logits[NOTE_OFFSET+instr_id*MAX_PITCH:NOTE_OFFSET+(instr_id+1)*MAX_PITCH] = -float('inf')
            else:
                print("allowed instrument", instr_id)

    if len(instrs) < 15: # 16 - 1 to account for the reserved drum track
        return logits

    for instr in range(MAX_INSTR):
        if instr not in instrs: #at the instrument cap: only allow instruments already used in full_history
            logits[NOTE_OFFSET+instr*MAX_PITCH:NOTE_OFFSET+(instr+1)*MAX_PITCH] = -float('inf')


    return logits


def add_token_part1(model, z, tokens, top_p, current_time, instruments, debug=False):
    assert len(tokens) % 3 == 0

    history = tokens.copy()
    lookback = max(len(tokens) - 1017, 0)
    history = history[lookback:] # Markov window
    offset = ops.min_time(history, seconds=False)
    history[::3] = [tok - offset for tok in history[::3]] # relativize time in the history buffer

    new_token = []
    with torch.no_grad():
        for i in range(3):
            input_tokens = torch.tensor(z + history + new_token).unsqueeze(0).to(model.device)
            logits = model(input_tokens).logits[0,-1]

            idx = input_tokens.shape[1]-1
            logits = safe_logits(logits, idx)
            if i == 0:
                logits = future_logits(logits, current_time - offset)
            elif i == 2:
                logits = instr_logits_part1(logits, tokens, instruments) #PASS DOWN THE RESTRICTION HERE
            logits = nucleus(logits, top_p)

            probs = F.softmax(logits, dim=-1)
            token = torch.multinomial(probs, 1)
            new_token.append(int(token))

    new_token[0] += offset # revert to full sequence timing
    if debug:
        print(f'  OFFSET = {offset}, LEN = {len(history)}, TIME = {tokens[::3][-5:]}')
    print("new token: ", new_token[0], new_token[1], new_token[2])

    return new_token


In [8]:
from dataclasses import dataclass, field
from math import inf
from typing import Optional, List

@dataclass
class Beam:
    tokens: List[int]                  #full token history (context + generated)
    score: float = 0.0                 #sum of log-probs for generated tokens
    current_time: float = 0.0          #last generated absolute TIME (after offset)
    control_tokens: List[int] = field(default_factory=list)  #put controls per-beam rather than global in case times don't align
    anticip_time: float = inf          #onset of next anticipatory triple (ATIME - ATIME_OFFSET)
    gen_len: int = 0
    constraint_tracker: Optional[int] = None

In [9]:
def safe_logits(logits, idx):
    logits[CONTROL_OFFSET:SPECIAL_OFFSET] = -float('inf') # don't generate controls
    logits[SPECIAL_OFFSET:] = -float('inf')               # don't generate special tokens

    # don't generate stuff in the wrong time slot
    if idx % 3 == 0:
        logits[DUR_OFFSET:DUR_OFFSET+MAX_DUR] = -float('inf')
        logits[NOTE_OFFSET:NOTE_OFFSET+MAX_NOTE] = -float('inf')
    elif idx % 3 == 1:
        logits[TIME_OFFSET:TIME_OFFSET+MAX_TIME] = -float('inf')
        logits[NOTE_OFFSET:NOTE_OFFSET+MAX_NOTE] = -float('inf')
    elif idx % 3 == 2: #expecting a note token
        logits[TIME_OFFSET:TIME_OFFSET+MAX_TIME] = -float('inf')
        logits[DUR_OFFSET:DUR_OFFSET+MAX_DUR] = -float('inf')

    return logits


In [10]:
def nucleus(logits, top_p):
    # from HF implementation
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p

        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = -float("inf")

    return logits
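
To make the keep/remove rule concrete, here is a minimal plain-Python sketch of the same top-p cut on a four-token distribution. `top_p_keep` is a hypothetical helper written only for illustration (it is not part of `anticipation` or `transformers`), and it returns the kept indices instead of masking a tensor in place.

```python
import math

def top_p_keep(logits, top_p):
    """Return the indices that survive nucleus (top-p) filtering, most likely first."""
    # numerically stable softmax over the raw logits
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_p >= 1.0:
        return order  # same as the notebook: no filtering when top_p is 1.0

    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)          # the token that crosses the threshold is still kept
        cumulative += probs[i]
        if cumulative > top_p:
            break               # everything after it would be set to -inf
    return kept

toy_logits = [math.log(p) for p in (0.5, 0.3, 0.15, 0.05)]
print(top_p_keep(toy_logits, 0.7))  # [0, 1]
```

With `top_p=0.7` the first token alone covers 0.5, so the second token (cumulative 0.8) is kept as the one that crosses the threshold and the remaining two are dropped.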

In [11]:
def future_logits(logits, curtime):
    """ don't sample events in the past """
    if curtime > 0:
        logits[TIME_OFFSET:TIME_OFFSET+curtime+1] = -float('inf')

    return logits
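
Note that this version masks one more slot (`curtime+1`) than the earlier definition in cell 7, so a new event can no longer start at exactly the current time once `curtime > 0`. The toy sketch below shows the difference on a plain list; `mask_past` and its `strict` flag are hypothetical names, and the list stands in for the time slice of the logits with `TIME_OFFSET` assumed to be 0.

```python
NEG_INF = float("-inf")

def mask_past(time_logits, curtime, strict):
    """Mask time slots before curtime; with strict=True also mask curtime itself."""
    out = list(time_logits)
    if curtime > 0:
        cutoff = curtime + 1 if strict else curtime
        for t in range(min(cutoff, len(out))):
            out[t] = NEG_INF
    return out

slots = [0.0] * 5
print(mask_past(slots, 2, strict=False))  # slots 0-1 masked, slot 2 still allowed
print(mask_past(slots, 2, strict=True))   # slots 0-2 masked, first allowed slot is 3
```

In both versions nothing is masked when `curtime` is 0, which is why several candidates at time 0 can appear at the start of generation.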

In [12]:
def instr_logits_part1(logits, full_history, instruments):
    """ don't sample more than 16 instruments """
    instrs = ops.get_instruments(full_history)
    #print("instruments full history", instrs)

    if instruments is not None:
    #ONLY ALLOW SPECIFIED INSTRUMENTS, BE CAREFUL -- which instruments are present in full_history?
        #print("ONLY ALLOW SPECIFIED INSTRUMENTS")
        for instr_id in range(129): #INCLUDE DRUMS AS SOMETHING WHICH CAN BE BLOCKED
            if instr_id not in instruments:
                #print("block instrument", instr_id)
                logits[NOTE_OFFSET+instr_id*MAX_PITCH:NOTE_OFFSET+(instr_id+1)*MAX_PITCH] = -float('inf')
            #else:
                #print("allowed instrument", instr_id)

    if len(instrs) < 15: # 16 - 1 to account for the reserved drum track
        return logits

    for instr in range(MAX_INSTR):
        if instr not in instrs: #at the instrument cap: only allow instruments already used in full_history
            logits[NOTE_OFFSET+instr*MAX_PITCH:NOTE_OFFSET+(instr+1)*MAX_PITCH] = -float('inf')


    return logits

In [13]:
def add_token_part1_modified(model, z, tokens, top_p, current_time, instruments, debug=False):
    assert len(tokens) % 3 == 0

    history = tokens.copy()
    lookback = max(len(tokens) - 1017, 0)
    history = history[lookback:] # Markov window
    offset = ops.min_time(history, seconds=False)
    history[::3] = [tok - offset for tok in history[::3]] # relativize time in the history buffer

    new_token = []
    new_token_score = 0
    with torch.no_grad():
        for i in range(3):
            input_tokens = torch.tensor(z + history + new_token).unsqueeze(0).to(model.device)
            logits = model(input_tokens).logits[0,-1]

            idx = input_tokens.shape[1]-1
            logits = safe_logits(logits, idx)
            if i == 0:
                logits = future_logits(logits, current_time - offset)
            elif i == 2:
                logits = instr_logits_part1(logits, tokens, instruments) #PASS DOWN THE RESTRICTION HERE
            logits = nucleus(logits, top_p)

            probs = F.softmax(logits, dim=-1)
            log_probs = F.log_softmax(logits, dim=-1)
            token = torch.multinomial(probs, 1)
            new_token.append(int(token))
            new_token_score += float(log_probs[int(token)].item())

    new_token[0] += offset # revert to full sequence timing
    if debug:
        print(f'  OFFSET = {offset}, LEN = {len(history)}, TIME = {tokens[::3][-5:]}')
    print("new token: ", new_token[0], new_token[1], new_token[2], "score: ", new_token_score)

    return new_token, new_token_score

In [14]:
def topk_triples(model, z, tokens, current_time, instruments, debug=False, K_time=4, K_dur=2, K_note=2, K_total=8, top_p=None):
    assert len(tokens) % 3 == 0

    device=model.device

    history = tokens.copy()
    lookback = max(len(tokens) - 1017, 0)
    history = history[lookback:] # Markov window
    offset = ops.min_time(history, seconds=False)
    history[::3] = [tok - offset for tok in history[::3]] # relativize time in the history buffer

    def apply_masks(logits, phase_idx, inp_len, tokens): #inp_len corresponds to input_tokens.shape[1] in the original code
        logits = safe_logits(logits, inp_len - 1)
        if phase_idx == 0:
            logits = future_logits(logits, current_time - offset)
        elif phase_idx == 2:
            logits = instr_logits_part1(logits, tokens, instruments)
        if top_p is not None:
            logits = nucleus(logits, top_p)
        return logits

    with torch.no_grad():
        #TIME TOKEN: generate K_time possibilities
        inp0 = torch.tensor(z + history, device=device).unsqueeze(0)
        logits_t = model(inp0).logits[0, -1] #(1, L, V) -> just shape V
        logits_t = apply_masks(logits_t, phase_idx=0, inp_len=inp0.shape[1], tokens=tokens)
        logp_t = torch.log_softmax(logits_t, dim=-1)
        t_vals, t_ids = torch.topk(logp_t, K_time)
        t_ids = t_ids.tolist(); t_vals = t_vals.tolist()

        #DURATION TOKEN (batch over K_time)
        #build batch prefixes: z + history + [t_i]
        batch_time_inputs = [z + history + [t] for t in t_ids]
        inp1 = torch.nn.utils.rnn.pad_sequence(
            [torch.tensor(x, device=device) for x in batch_time_inputs],
            batch_first=True
        )
        logits_d_all = model(inp1).logits[:, -1, :] #[K_time, L, V] -> [K_time, V]
        logp_d_all = []
        d_ids_all  = []
        for row, base_len in zip(logits_d_all, [len(x) for x in batch_time_inputs]):
            row = apply_masks(row, phase_idx=1, inp_len=base_len, tokens=tokens)
            #tokens is the full history from before this step, which is fine here:
            #apply_masks only passes it to instr_logits_part1, which just needs the instrument history
            lp = torch.log_softmax(row, dim=-1)
            d_vals, d_ids = torch.topk(lp, K_dur) #take top K_dur options
            logp_d_all.append(d_vals)
            d_ids_all.append(d_ids)
        #shapes end up being lists of K_time tensors, each [K_dur]

        #NOTE TOKEN (batch over K_time *K_dur)
        td_pairs = []
        td_logps = []
        td_inputs = []
        for i in range(len(t_ids)): #K_time outer loop
            for j in range(K_dur): #K_dur inner loop
                time_token_id = t_ids[i] #token id
                lp_time = t_vals[i] #log prob for that token
                dur_token_id = d_ids_all[i][j].item()
                lp_dur = logp_d_all[i][j].item()
                td_pairs.append((time_token_id, dur_token_id, lp_time, lp_dur))
                td_inputs.append(z + history + [time_token_id, dur_token_id])

        inp2 = torch.nn.utils.rnn.pad_sequence( #batched processing (K_time*K_dur) batches
            [torch.tensor(x, device=device) for x in td_inputs],
            batch_first=True
        )
        logits_n_all = model(inp2).logits[:, -1, :] #(K_time*K_dur, L, V) -> (K_time*K_dur, V)

        candidates = []  # (triple_ids, joint_logp)
        idx = 0
        for i in range(len(t_ids)): #K_time
            for j in range(K_dur): #K_dur
                row = logits_n_all[idx]
                idx += 1 #counts up to K_time*K_dur
                base_len = len(td_inputs[i*K_dur + j])
                row = apply_masks(row, phase_idx=2, inp_len=base_len, tokens=tokens)
                lp = torch.log_softmax(row, dim=-1)
                note_vals, note_ids = torch.topk(lp, K_note) #pick top K_note options
                time_token_id, dur_token_id, lp_time, lp_dur = td_pairs[i*K_dur + j]
                for k in range(K_note):
                    note_token_id = note_ids[k].item()
                    lp_note = note_vals[k].item()
                    joint = lp_time + lp_dur + lp_note
                    candidates.append(([time_token_id, dur_token_id, note_token_id], joint))

        def dedup(candidates):
            unique = {}
            for note_choice, logprob in candidates:
                key = tuple(note_choice)
                if key not in unique or logprob > unique[key]:
                    unique[key] = logprob #tuplify the array of 3
            return [(list(key), unique[key]) for key in unique.keys()]

        candidates = dedup(candidates) #remove duplicates hopefully this helps

        #candidates has list of triples ([time token, dur token, note token], prob)
        joint_logps = torch.tensor([logprob for _, logprob in candidates], device=device)
        #gumbel top k sampling, basically the idea is you add random noise before you take the top k
        u = torch.rand_like(joint_logps)
        g = -torch.log(-torch.log(u))              # Gumbel(0,1)
        tau = 1.0                                  # temperature: 1.0–1.5 = good range
        scores = joint_logps / tau + g             # random jittered scores

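        #choose up to K_total without replacement (highest noised scores);
        #clamp k so torch.topk does not fail when K_total exceeds the number of candidates
        top = torch.topk(scores, min(K_total, len(candidates)))
        best = [candidates[i] for i in top.indices.tolist()]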

        #ALTERNATIVELY, DETERMINISTIC TOP TOTAL_K -> tried this and it led to beam collapse
        #candidates.sort(key=lambda x: x[1], reverse=True) #highest to lowest by joint prob
        #best = candidates[:K_total]

        triples = torch.tensor([ids for ids,_ in best], device=device, dtype=torch.long)  # [K_total, 3]
        logps  = torch.tensor([lp  for _,lp in best], device=device, dtype=torch.float)   # [K_total]

    #if TIME in history was relativized by `offset`, undo for output:
    triples[:, 0] = triples[:, 0] + offset
    print("triple of tokens option from single beam", triples)

    return triples, logps #FORMAT IS TENSOR OF SHAPES [K_total, 3] and [K_total]
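
The Gumbel top-k step used above can be tried in isolation. The following is a small plain-Python sketch of the same idea (add independent Gumbel(0,1) noise to each log-probability, then keep the k largest), not the tensor code from the cell. `gumbel_top_k` is a hypothetical helper; clamping `k` to the list length and guarding against `u == 0` are additions made for the sketch.

```python
import math
import random

def gumbel_top_k(log_probs, k, tau=1.0, rng=random):
    """Sample k distinct indices by perturbing log-probs with Gumbel noise."""
    k = min(k, len(log_probs))  # never ask for more items than exist
    noisy = []
    for i, lp in enumerate(log_probs):
        u = max(rng.random(), 1e-12)      # avoid log(0)
        g = -math.log(-math.log(u))       # Gumbel(0, 1) sample
        noisy.append((lp / tau + g, i))   # higher tau flattens the preferences
    noisy.sort(reverse=True)
    return [i for _, i in noisy[:k]]

random.seed(0)
picked = gumbel_top_k([-0.1, -2.0, -3.0, -5.0], k=2)
print(picked)  # two distinct indices; likely candidates are favoured but not guaranteed
```

Because the noise is redrawn on every call, different beams that see the same candidates can branch differently, which is the property the deterministic top-k variant lacked.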


In [15]:
def generate_beams(model, start_time, end_time, inputs=None, controls=None, top_p=None,
                   debug=False, delta=DELTA*TIME_RESOLUTION, instruments=None,
                   num_beams=10, K_total=10, K_time=5, K_dur=2, K_note=3):
    if inputs is None:
        inputs = []

    if controls is None:
        controls = []

    start_time = int(TIME_RESOLUTION*start_time)
    end_time = int(TIME_RESOLUTION*end_time)

    # prompt is events up to start_time
    prompt = ops.pad(ops.clip(inputs, 0, start_time, clip_duration=False, seconds=False), start_time)

    # treat events beyond start_time as controls
    future = ops.clip(inputs, start_time+1, ops.max_time(inputs, seconds=False), clip_duration=False, seconds=False)
    if debug:
        print('Future')
        ops.print_tokens(future)

    # clip controls that precede the sequence
    controls = ops.clip(controls, DELTA, ops.max_time(controls, seconds=False), clip_duration=False, seconds=False)

    if debug:
        print('Controls')
        ops.print_tokens(controls)

    z = [ANTICIPATE] if len(controls) > 0 or len(future) > 0 else [AUTOREGRESS]
    if debug:
        print('AR Mode' if z[0] == AUTOREGRESS else 'AAR Mode')

    # interleave the controls with the events
    tokens, controls = ops.anticipate(prompt, ops.sort(controls + [CONTROL_OFFSET+token for token in future]))

    if debug:
        print('Prompt')
        ops.print_tokens(tokens)

    current_time = ops.max_time(prompt, seconds=False)

    if debug:
        print('Current time:', current_time)

    #ok now we make a list of beams each initializing tokens with the controls
    beams = []
    for _ in range(num_beams):
        beams.append(Beam(
          tokens=tokens.copy(),
          control_tokens=controls.copy(),
          anticip_time=(controls[0] - ATIME_OFFSET if controls else math.inf),
          score=0.0,
          current_time=current_time,
          gen_len = 0
        ))

    #with tqdm(range(end_time-start_time)) as progress:
    not_done = True
    counter = -1
    phrase = 3

    while not_done:

        counter += 1
        candidates = []
        unique_beams = {}
        not_done = False

        for idx, beam in enumerate(beams): #add a token triplet to every beam that hasn't finished

            #directly add the finished beams
            if beam.current_time >= end_time:
                candidates.append(beam)
                continue

            not_done = True #if at least one beam gets token added, then not done
            #last pass not_done will be False if every beam is done
            print("beam", idx, "has current time", beam.current_time, "tokens", beam.tokens)
            print("anticipated_time", beam.anticip_time, "end_time", end_time)

            #directly mutate control_tokens, anticip_time
            while beam.current_time >= beam.anticip_time - delta:

                if not beam.control_tokens:
                    break

                atime, adur, anote = beam.control_tokens[:3]
                beam.tokens.extend([atime, adur, anote])
                beam.control_tokens = beam.control_tokens[3:]

                if debug:
                    note = anote - ANOTE_OFFSET
                    instr = note//2**7
                    print('A', atime - ATIME_OFFSET, adur - ADUR_OFFSET, instr, note - (2**7)*instr)

                if len(beam.control_tokens) > 0:
                    beam.anticip_time = beam.control_tokens[0] - ATIME_OFFSET
                else:
                    beam.anticip_time = math.inf

            if counter % phrase != 0: #just generate normally rather than branching on this triplet
                new_token, new_token_score = add_token_part1_modified(model, z, beam.tokens, top_p=top_p,
                                            current_time=max(start_time, beam.current_time),
                                            instruments=instruments, debug=True)
                if new_token[0] < end_time: #new token's time
                      possible_beam = Beam(
                          tokens=beam.tokens.copy() + new_token,
                          control_tokens=beam.control_tokens[:],
                          anticip_time = beam.anticip_time,
                          score=beam.score + new_token_score,
                          current_time=new_token[0],
                          gen_len = beam.gen_len+3
                      )
                      candidates.append(possible_beam)
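                else:
                    #sampled onset is past end_time: keep the tokens but mark the beam finished,
                    #mirroring the terminal case in the branching path below
                    candidates.append(Beam(
                        tokens=beam.tokens.copy(),
                        control_tokens=beam.control_tokens[:],
                        anticip_time=beam.anticip_time,
                        score=beam.score,
                        current_time=new_token[0],
                        gen_len=beam.gen_len
                    ))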

                continue #don't do the branching

            new_triples, logps = topk_triples(model, z, beam.tokens, max(start_time,beam.current_time), instruments=instruments, K_total=K_total,
                                              debug=debug, K_time=K_time, K_dur=K_dur, K_note=K_note, top_p=None)
            #also has default parameters (debug=False, K_time=4, K_dur=2, K_note=2, K_total=8, top_p=None)
            #shapes [K_total, 3] and [K_total]

            for row, logp in zip(new_triples, logps):

                  new_time = row[0].item() - TIME_OFFSET
                  if new_time < beam.current_time:
                      continue

                  if new_time < end_time:
                      possible_beam = Beam(
                          tokens=beam.tokens.copy() + [token.item() for token in row],
                          control_tokens=beam.control_tokens[:],
                          anticip_time = beam.anticip_time,
                          score=beam.score + logp.item(),
                          current_time=new_time, #the new time, don't actually mutate new_triples
                          gen_len = beam.gen_len+3
                      )
                  else: #DON'T ACTUALLY APPEND THAT TRIPLE, though anticipation has been mutated
                      possible_beam = Beam(
                          tokens=beam.tokens.copy(),
                          control_tokens=beam.control_tokens[:],
                          anticip_time = beam.anticip_time,
                          score=beam.score,
                          current_time=new_time, #terminal beam candidate, kill it from growing to prevent inf loop
                          gen_len=beam.gen_len
                      )

                  if tuple(possible_beam.tokens) not in unique_beams: #DEDUP
                      unique_beams[tuple(possible_beam.tokens)] = True
                      candidates.append(possible_beam)

            if debug:
                print("print data about the new triples generated?")
                #new_note = new_token[2] - NOTE_OFFSET
                #new_instr = new_note//2**7
                #new_pitch = new_note - (2**7)*new_instr
                #print('C', new_time, new_token[1] - DUR_OFFSET, new_instr, new_pitch)

        def rank(b, alpha=0.5, gamma=0.01, empty_penalty=1e6, start_tick=0):
            #normalize for length
            base = b.score / (max(1, b.gen_len) ** alpha)

            #favor beams that advance forward in time
            #be careful to scale gamma to your tick units / seconds, start_tick is start_time in ticks
            prog = gamma * max(0, b.current_time - start_tick)

            #add penalty for empty generations so that we don't just end up with prompt and nothing else
            if b.gen_len == 0:
                return base + prog - empty_penalty

            return base + prog

        candidates.sort(key=lambda b: rank(b), reverse=True) #highest to lowest by score, in-place sort
        beams = candidates[:num_beams]
        # for beam in beams:
        #     print("current time", beam.current_time, "tokens", beam.tokens)


    #NOW CHOOSE FINAL OUTPUT FROM THE BEAMS LIST
    if beams:
        best_tokens = beams[0].tokens
    else:
        best_tokens = tokens

    print("best_tokens", best_tokens)
    events, _ = ops.split(best_tokens)
    return ops.sort(ops.unpad(events) + future)
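
The ranking rule inside `generate_beams` is easier to reason about on numbers. Here is a standalone sketch of the same formula, taking plain values instead of a `Beam` (the function name `rank_value` and the example scores are made up for illustration):

```python
def rank_value(score, gen_len, current_time,
               alpha=0.5, gamma=0.01, empty_penalty=1e6, start_tick=0):
    """Length-normalised log-prob plus a small bonus for progress in time."""
    base = score / (max(1, gen_len) ** alpha)          # length normalisation
    prog = gamma * max(0, current_time - start_tick)   # reward advancing in ticks
    if gen_len == 0:
        return base + prog - empty_penalty             # push empty beams to the bottom
    return base + prog

# Beam A is less likely but has advanced further; beam B is likelier but barely moved.
a = rank_value(score=-6.0, gen_len=9, current_time=150)
b = rank_value(score=-3.0, gen_len=9, current_time=20)
empty = rank_value(score=0.0, gen_len=0, current_time=0)
print(a, b, empty)
```

With the default `gamma`, 100 ticks of progress are worth one unit of normalised log-probability, so `gamma` is the knob that trades likelihood against getting to `end_time`.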

Testing


In [17]:
sample_tokens= generate_beams(model, start_time=0, end_time=5, top_p=.98, instruments={1}, debug=False,
                              num_beams=10, K_total=10, K_time=5, K_dur=2, K_note=5) #K_total is branch factor
Audio(synthesize(fs, sample_tokens))

beam 0 has current time 0 tokens []
anticipated_time inf end_time 500
triple of tokens option from single beam tensor([[    0, 10000, 27512],
        [    0, 10001, 27512],
        [    0, 10001, 11202],
        [    0, 10001, 11190],
        [    0, 10001, 11200],
        [    0, 10000, 11195],
        [    0, 10001, 11195],
        [    2, 10000, 27512],
        [    1, 10012, 27512],
        [    0, 10000, 11190]], device='cuda:0')
beam 1 has current time 0 tokens []
anticipated_time inf end_time 500
triple of tokens option from single beam tensor([[    0, 10000, 27512],
        [    0, 10001, 27512],
        [    0, 10001, 11190],
        [    0, 10000, 11195],
        [    0, 10000, 11190],
        [    0, 10001, 11195],
        [    0, 10001, 11202],
        [    0, 10000, 11200],
        [    0, 10001, 11200],
        [    0, 10000, 11202]], device='cuda:0')
beam 2 has current time 0 tokens []
anticipated_time inf end_time 500
triple of tokens option from single beam tensor([[  

Constrain the beam search: force a note into the output
- an exact note with a given time and duration
- a pitch within a specified time window
- a pitch anywhere in the sequence

The plan is to add the desired note as a candidate while branching, then prioritize beams that contain the note (after sorting the candidates, reserve a quota of beam slots for them).
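
One way the quota step could look is sketched below. This is an assumption about the selection logic, not code from the notebook: `select_with_quota`, `satisfies`, and `key` are hypothetical names, and candidates are plain dicts so the sketch runs without the model.

```python
def select_with_quota(candidates, num_beams, quota, satisfies, key):
    """Keep num_beams candidates, reserving up to `quota` slots for constrained ones."""
    ranked = sorted(candidates, key=key, reverse=True)
    # best-ranked candidates that already contain the forced note
    reserved = [c for c in ranked if satisfies(c)][:quota]
    reserved_ids = {id(c) for c in reserved}
    # fill the remaining slots by rank, whether or not the constraint is met
    rest = [c for c in ranked if id(c) not in reserved_ids]
    return (reserved + rest)[:num_beams]

cands = [
    {"name": "a", "score": -1.0, "hit": False},
    {"name": "b", "score": -2.0, "hit": False},
    {"name": "c", "score": -3.0, "hit": True},
    {"name": "d", "score": -4.0, "hit": True},
]
kept = select_with_quota(cands, num_beams=2, quota=1,
                         satisfies=lambda c: c["hit"], key=lambda c: c["score"])
print([c["name"] for c in kept])  # ['c', 'a']
```

Reserved beams are placed first on purpose: the generation loop returns `beams[0]`, so this ordering makes a constraint-satisfying beam the default output whenever one exists.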

In the future, this should extend to chord progressions, etc.

- grid beam search (bin beams by how many constraints they satisfy)
- dynamic (fractional) allocation
- alternatively, boost the logit for the forced token by a large amount?
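
As a rough illustration of the first option, the sketch below bins candidates by the number of constraints they satisfy and keeps the best few from each bin, so partially constrained beams are never crowded out by unconstrained ones. It only shows the binning idea, not a full grid beam search; `grid_select`, `count_satisfied`, and the tuple layout are hypothetical.

```python
from collections import defaultdict

def grid_select(candidates, beams_per_bin, count_satisfied, key):
    """Keep the top `beams_per_bin` candidates within each constraint-count bin."""
    bins = defaultdict(list)
    for c in candidates:
        bins[count_satisfied(c)].append(c)

    selected = []
    for n in sorted(bins, reverse=True):  # most constraints satisfied first
        best = sorted(bins[n], key=key, reverse=True)[:beams_per_bin]
        selected.extend(best)
    return selected

# (name, log-prob score, number of constraints satisfied)
cands = [("a", -1.0, 0), ("b", -2.0, 0), ("c", -3.0, 1), ("d", -9.0, 1), ("e", -0.5, 0)]
kept = grid_select(cands, beams_per_bin=1,
                   count_satisfied=lambda c: c[2], key=lambda c: c[1])
print([c[0] for c in kept])  # ['c', 'e']
```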


In [None]:
@dataclass
class PitchWindow:
    pitch: int
    window_start: float
    window_end: float
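
The notebook does not show how a `PitchWindow` is checked against a beam, so the following is only a sketch under stated assumptions: `window_start` and `window_end` are taken to be in seconds, and the vocabulary constants are hard-coded to values consistent with the triples printed earlier (`TIME_OFFSET = 0`, `NOTE_OFFSET = 11000`, 100 ticks per second). In the notebook they would come from `anticipation.vocab` and `anticipation.config`. `window_satisfied` is a hypothetical helper.

```python
from dataclasses import dataclass

# ASSUMED values for illustration; use anticipation.vocab / anticipation.config in practice
TIME_OFFSET = 0
NOTE_OFFSET = 11000
TIME_RESOLUTION = 100        # ticks per second
PITCHES_PER_INSTR = 128
NUM_INSTR = 129              # 128 instruments plus drums

@dataclass
class PitchWindow:
    pitch: int
    window_start: float      # assumed: seconds
    window_end: float        # assumed: seconds

def window_satisfied(tokens, constraint):
    """True if some (time, dur, note) triple plays the pitch inside the window."""
    for i in range(0, len(tokens) - 2, 3):
        time_tok, _dur_tok, note_tok = tokens[i:i + 3]
        note = note_tok - NOTE_OFFSET
        if not 0 <= note < NUM_INSTR * PITCHES_PER_INSTR:
            continue         # skip rests and anticipated control triples
        onset = (time_tok - TIME_OFFSET) / TIME_RESOLUTION
        pitch = note % PITCHES_PER_INSTR
        if pitch == constraint.pitch and constraint.window_start <= onset <= constraint.window_end:
            return True
    return False

# pitch 74 at 0.0 s and pitch 60 at 1.5 s, both on instrument 1
tokens = [0, 10000, 11202, 150, 10050, 11000 + 128 + 60]
print(window_satisfied(tokens, PitchWindow(pitch=60, window_start=1.0, window_end=2.0)))  # True
```

A check like this could serve as the predicate when deciding which beams already contain the forced note.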

In [None]:
def topk_triples_constrained(model, z, tokens, current_time, instruments, forced_note, debug=False, K_time=4, K_dur=2, K_note=2, K_total=8, top_p=None):
    assert len(tokens) % 3 == 0

    device=model.device

    history = tokens.copy()
    lookback = max(len(tokens) - 1017, 0)
    history = history[lookback:] # Markov window
    offset = ops.min_time(history, seconds=False)
    history[::3] = [tok - offset for tok in history[::3]] # relativize time in the history buffer

    def apply_masks(logits, phase_idx, inp_len, tokens): #inp_len corresponds to input_tokens.shape[1] in the original code
        logits = safe_logits(logits, inp_len - 1)
        if phase_idx == 0:
            logits = future_logits(logits, current_time - offset)
        elif phase_idx == 2:
            logits = instr_logits_part1(logits, tokens, instruments)
        if top_p is not None:
            logits = nucleus(logits, top_p)
        return logits

    with torch.no_grad():
        #TIME TOKEN: generate K_time possibilities
        inp0 = torch.tensor(z + history, device=device).unsqueeze(0)
        logits_t = model(inp0).logits[0, -1] #(1, L, V) -> just shape V
        logits_t = apply_masks(logits_t, phase_idx=0, inp_len=inp0.shape[1], tokens=tokens)
        logp_t = torch.log_softmax(logits_t, dim=-1)
        t_vals, t_ids = torch.topk(logp_t, K_time)
        t_ids = t_ids.tolist(); t_vals = t_vals.tolist()

        #DURATION TOKEN (batch over K_time)
        #build batch prefixes: z + history + [t_i]
        batch_time_inputs = [z + history + [t] for t in t_ids]
        inp1 = torch.nn.utils.rnn.pad_sequence(
            [torch.tensor(x, device=device) for x in batch_time_inputs],
            batch_first=True
        )
        logits_d_all = model(inp1).logits[:, -1, :] #[K_time, L, V] -> [K_time, V]
        logp_d_all = []
        d_ids_all  = []
        for row, base_len in zip(logits_d_all, [len(x) for x in batch_time_inputs]):
            row = apply_masks(row, phase_idx=1, inp_len=base_len, tokens=tokens)
            # `tokens` is the full, untruncated history, which is fine here: apply_masks only
            # passes it to instr_logits_part1, which looks at instrument history alone
            lp = torch.log_softmax(row, dim=-1)
            d_vals, d_ids = torch.topk(lp, K_dur) #take top K_dur options
            logp_d_all.append(d_vals)
            d_ids_all.append(d_ids)
        #shapes end up being lists of K_time tensors, each [K_dur]

        #NOTE TOKEN (batch over K_time *K_dur)
        td_pairs = []
        td_logps = []
        td_inputs = []
        for i in range(len(t_ids)): #K_time outer loop
            for j in range(K_dur): #K_dur inner loop
                time_token_id = t_ids[i] #token id
                lp_time = t_vals[i] #log prob for that token
                dur_token_id = d_ids_all[i][j].item()
                lp_dur = logp_d_all[i][j].item()
                td_pairs.append((time_token_id, dur_token_id, lp_time, lp_dur))
                td_inputs.append(z + history + [time_token_id, dur_token_id])

        inp2 = torch.nn.utils.rnn.pad_sequence(  # one batch of K_time*K_dur prefixes
            [torch.tensor(x, device=device) for x in td_inputs],
            batch_first=True
        )
        logits_n_all = model(inp2).logits[:, -1, :] #(K_time*K_dur, L, V) -> (K_time*K_dur, V)

        candidates = []  # (triple_ids, joint_logp)
        idx = 0
        for i in range(len(t_ids)): #K_time
            for j in range(K_dur): #K_dur
                row = logits_n_all[idx]
                idx += 1 #counts up to K_time*K_dur
                base_len = len(td_inputs[i*K_dur + j])
                row = apply_masks(row, phase_idx=2, inp_len=base_len, tokens=tokens)
                lp = torch.log_softmax(row, dim=-1)
                note_vals, note_ids = torch.topk(lp, K_note) #pick top K_note options
                time_token_id, dur_token_id, lp_time, lp_dur = td_pairs[i*K_dur + j]
                for k in range(K_note):
                    note_token_id = note_ids[k].item()
                    lp_note = note_vals[k].item()
                    joint = lp_time + lp_dur + lp_note
                    candidates.append(([time_token_id, dur_token_id, note_token_id], joint))

        def dedup(candidates):
            # keep the highest joint log-prob for each distinct (time, dur, note) triple
            unique = {}
            for note_choice, logprob in candidates:
                key = tuple(note_choice)  # lists are unhashable, so key on a tuple
                if key not in unique or logprob > unique[key]:
                    unique[key] = logprob
            return [(list(key), lp) for key, lp in unique.items()]

        candidates = dedup(candidates)  # drop duplicate triples

        # candidates is a list of ([time token, dur token, note token], joint log-prob) pairs
        joint_logps = torch.tensor([logprob for _, logprob in candidates], device=device)
        # Gumbel top-k sampling: add Gumbel noise to the scores, then take the top k
        u = torch.rand_like(joint_logps).clamp(min=1e-10)  # avoid log(0)
        g = -torch.log(-torch.log(u))              # Gumbel(0,1)
        tau = 1.0                                  # temperature: 1.0–1.5 = good range
        scores = joint_logps / tau + g             # randomly jittered scores

        # choose K_total-1 candidates without replacement (highest noised scores);
        # one slot is reserved for the forced option
        k = min(K_total - 1, len(candidates))      # guard against having fewer candidates than requested
        top = torch.topk(scores, k)
        best = [candidates[i] for i in top.indices.tolist()]

        #ALTERNATIVELY, DETERMINISTIC TOP TOTAL_K -> tried this and it led to beam collapse
        #candidates.sort(key=lambda x: x[1], reverse=True) #highest to lowest by joint prob
        #best = candidates[:K_total]

        triples = torch.tensor([ids for ids,_ in best], device=device, dtype=torch.long)  # [K_total-1, 3]
        logps  = torch.tensor([lp  for _,lp in best], device=device, dtype=torch.float)   # [K_total-1]

    # TIME tokens in the history were relativized by `offset`; undo that for the output
    triples[:, 0] = triples[:, 0] + offset
    if debug:
        print("candidate triples from a single beam:", triples)

    return triples, logps  # tensors of shape [K_total-1, 3] and [K_total-1]; the forced option is not added here
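
The Gumbel top-k step in the function above can be looked at in isolation. The sketch below is a minimal standalone version that uses only the Python standard library rather than torch; `gumbel_top_k` and its `rng` parameter are illustrative names that do not appear in the notebook.

```python
import math
import random


def gumbel_top_k(logps, k, tau=1.0, rng=None):
    """Pick k distinct indices, favoring entries with higher log-probability.

    Hypothetical helper for illustration; not part of the notebook's code.
    """
    rng = rng or random.Random()
    k = min(k, len(logps))  # never ask for more items than exist
    scored = []
    for i, lp in enumerate(logps):
        # Clamp u away from 0 and 1 so both logarithms stay finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        g = -math.log(-math.log(u))  # Gumbel(0, 1) noise
        scored.append((lp / tau + g, i))
    scored.sort(reverse=True)  # highest noised score first
    return [i for _, i in scored[:k]]


# Example: choose 3 of 4 candidates given their joint log-probabilities.
chosen = gumbel_top_k([-0.1, -2.0, -3.0, -5.0], k=3, rng=random.Random(0))
```

Taking the top k noised scores draws k items without replacement, so the result varies from call to call while still favoring likely candidates. A larger `tau` flattens the scores and makes the selection more random.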


Functions that evaluate the musical characteristics of a beam:
- Valence (positive/negative feel) heuristics: major mode, diatonic pitch use, strong authentic cadences, and fewer dissonances → higher valence. Minor mode or mode mixture, frequent chromatic alterations, larger fifths-distance jumps, and deceptive cadences → lower valence.

- Energy heuristics: higher tempo, higher note density, strong accents (velocity), more syncopation, wider ambitus, and more leaps → higher energy. Slower tempo, legato playing (high articulation ratio), low density, and stable harmonic rhythm → lower energy.

- Brightness (symbolic): more high-register usage, open intervals (5ths, 6ths), and triadic purity → “brighter.”

- Tension/release: tonal-centroid distance spikes, a high dissonance rate, and unresolved suspensions → “tense/anxious”; stable tonality plus cadences → “calm/resolved.”
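
A few of the heuristics above can be sketched as simple statistics over a list of notes. This is a rough illustration only: the `(onset_seconds, duration_seconds, midi_pitch)` note format, every function name, and the thresholds are assumptions, not the scoring functions used in this notebook.

```python
# Notes are (onset_seconds, duration_seconds, midi_pitch) tuples.
# This format and all names below are hypothetical, chosen for illustration.

MAJOR_SCALE_PCS = {0, 2, 4, 5, 7, 9, 11}  # major-scale pitch classes relative to the tonic


def note_density(notes):
    """Energy proxy: notes per second over the span of the passage."""
    if not notes:
        return 0.0
    start = min(onset for onset, _, _ in notes)
    end = max(onset + dur for onset, dur, _ in notes)
    return len(notes) / (end - start) if end > start else 0.0


def ambitus(notes):
    """Energy proxy: distance in semitones between lowest and highest pitch."""
    if not notes:
        return 0
    pitches = [p for _, _, p in notes]
    return max(pitches) - min(pitches)


def leap_fraction(notes, leap_threshold=3):
    """Energy proxy: share of successive intervals of at least
    leap_threshold semitones (the threshold of 3 is an assumption)."""
    pitches = [p for _, _, p in sorted(notes)]  # order by onset
    intervals = [abs(b - a) for a, b in zip(pitches, pitches[1:])]
    if not intervals:
        return 0.0
    return sum(iv >= leap_threshold for iv in intervals) / len(intervals)


def diatonic_fraction(notes, tonic_pc=0):
    """Valence proxy: share of notes inside the major scale on tonic_pc."""
    if not notes:
        return 0.0
    in_scale = sum((p - tonic_pc) % 12 in MAJOR_SCALE_PCS for _, _, p in notes)
    return in_scale / len(notes)


def high_register_fraction(notes, threshold=72):
    """Brightness proxy: share of notes at or above a pitch threshold.
    The default of 72 (one octave above middle C, MIDI 60) is an assumption."""
    if not notes:
        return 0.0
    return sum(p >= threshold for _, _, p in notes) / len(notes)


example = [(0.0, 0.5, 60), (0.5, 0.5, 62), (1.0, 0.5, 67), (1.5, 0.5, 73)]
scores = {
    "density": note_density(example),
    "ambitus": ambitus(example),
    "leaps": leap_fraction(example),
    "diatonic": diatonic_fraction(example, tonic_pc=0),
    "high_register": high_register_fraction(example),
}
```

Cadences, syncopation, and tonal-centroid distance need harmonic or metric analysis and are left out of this sketch.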

Use JordanAI model


In [1]:
!pip install -q transformers huggingface_hub
from huggingface_hub import login

login()


In [4]:
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "mitmedialab/JordanAI-pianoTrading-v0.1-pytorch"

model = AutoModelForCausalLM.from_pretrained(model_id, token=True, trust_remote_code=True).to("cuda")
