# Session 6: Autoregressive Music Generation (Part 1)

Agenda
- Overview of the Transformer model
- Understanding Anticipatory Music Transformers
- Hands On: Using AMT to generate MIDI data

## The Transformer architecture

From the paper [Attention Is All You Need (Vaswani et al., 2017)](https://arxiv.org/abs/1706.03762).

![](./assets/transformer.png)

## Understanding Anticipatory Music Transformers

From the papers [Music Transformer (Huang et al., 2018)](https://arxiv.org/abs/1809.04281) and [Anticipatory Music Transformers (Thickstun et al., 2023)](https://arxiv.org/abs/2306.08620).

![](./assets/music_transformer.png)

## Hands On: Using AMT to generate MIDI data

In [None]:
# Clone the anticipation repository
!git clone https://github.com/lancelotblanchard/anticipation.git ../repositories/anticipation

In [None]:
import os
# Set the Hugging Face Hub cache folder. This is done before importing
# transformers - presumably so the setting is picked up when the library
# is first loaded (TODO confirm).
os.environ["HF_HUB_CACHE"] = os.path.abspath("../huggingface_hub_cache/")

from transformers import GPT2LMHeadModel

# Load the pretrained "stanford-crfm/music-small-800k" checkpoint as a GPT-2 LM-head model
model = GPT2LMHeadModel.from_pretrained("stanford-crfm/music-small-800k", attn_implementation="eager")

In [None]:
from torchinfo import summary
# Summarize the loaded model (layers and parameter counts) with torchinfo
summary(model)

In [None]:
# Add the cloned anticipation repository to the Python path so that the
# `anticipation` package can be imported in the following cells

import sys
sys.path.append('../repositories/anticipation')

In [None]:
from anticipation import ops
from anticipation.config import MAX_INSTR, MAX_PITCH
from anticipation.vocab import ANTICIPATE, CONTROL_OFFSET, DUR_OFFSET, NOTE_OFFSET, TIME_OFFSET
from anticipation.sample import nucleus
import torch

# Let's look at generating some tokens unconditionally. Before we can do so,
# we need to look at building an inference function:

def generate_note(model, tokens, current_time, active_instruments, top_p=0.98, history_length=340, monophony=False):
    """Sample one new note, i.e. three tokens (time, duration, note), from the model.

    Args:
        model: Causal language model, called as ``model(input_ids).logits``.
        tokens: Flat list of the tokens generated so far, three per event.
        current_time: Time of the most recent event, on the same absolute
            scale as the time tokens stored in ``tokens``.
        active_instruments: Instrument indices the new note is allowed to
            use; note tokens of every other instrument are masked out.
        top_p: Threshold forwarded to ``nucleus``.
        history_length: Maximum number of trailing events (3 tokens each)
            handed to the model as context.
        monophony: If True (and ``current_time > 0``), the new note must
            start strictly after ``current_time``.

    Returns:
        A list ``[time, duration, note]`` of sampled tokens; the time token
        is on the absolute scale of ``tokens``.
    """
    assert len(tokens) % 3 == 0, "tokens must be made of complete (time, duration, note) triples"

    # Sorted once up front; used below to mask the note ranges of all other instruments
    instruments = sorted(active_instruments)

    # Keep only the most recent `history_length` events as context
    lookback = max(len(tokens) - 3 * history_length, 0)
    history = list(tokens[lookback:])

    # Shift the time tokens of the window by ops.min_time(...) so that the
    # window handed to the model is expressed relative to that offset
    offset = ops.min_time(history, seconds=False)
    history[::3] = [tok - offset for tok in history[::3]]

    # The model sees times relative to `offset`, so the "not in the past"
    # bound has to be shifted by the same amount. Using the absolute
    # current_time here would mask too many (eventually all) time tokens
    # once the window no longer starts at time 0.
    relative_time = max(current_time - offset, 0)

    # 3 tokens per note - time, duration, note
    new_tokens = []

    with torch.no_grad():  # inference only, no gradients needed
        for i in range(3):
            input_sequence = torch.tensor([ANTICIPATE] + history + new_tokens).unsqueeze(0).to(model.device)
            logits = model(input_sequence).logits[0, -1]

            # Never generate control tokens
            logits[CONTROL_OFFSET:] = -float('inf')

            if i == 0:
                # Time position: rule out duration and note tokens
                logits[DUR_OFFSET:CONTROL_OFFSET] = -float('inf')

                # Do not step back in time; with monophony, also forbid
                # starting at exactly the current time
                if current_time > 0:
                    blocked = relative_time + (1 if monophony else 0)
                    logits[TIME_OFFSET:TIME_OFFSET + blocked] = -float('inf')

            elif i == 1:
                # Duration position: rule out time and note tokens
                logits[TIME_OFFSET:DUR_OFFSET] = -float('inf')
                logits[NOTE_OFFSET:CONTROL_OFFSET] = -float('inf')

            elif i == 2:
                # Note position: rule out time and duration tokens...
                logits[TIME_OFFSET:NOTE_OFFSET] = -float('inf')

                # ...and every note range that is not owned by an active
                # instrument (each instrument owns MAX_PITCH consecutive tokens)
                logits[NOTE_OFFSET:NOTE_OFFSET + instruments[0] * MAX_PITCH] = -float('inf')

                for j in range(len(instruments) - 1):
                    gap_start = NOTE_OFFSET + (instruments[j] + 1) * MAX_PITCH
                    gap_end = NOTE_OFFSET + instruments[j + 1] * MAX_PITCH
                    logits[gap_start:gap_end] = -float('inf')

                logits[NOTE_OFFSET + (instruments[-1] + 1) * MAX_PITCH:CONTROL_OFFSET] = -float('inf')

            # Sampling
            logits = nucleus(logits, top_p=top_p)
            probs = torch.nn.functional.softmax(logits, dim=-1)
            new_token = torch.multinomial(probs, 1).item()
            new_tokens.append(new_token)

    # Undo the window-relative shift so the returned time is absolute again
    new_tokens[0] += offset

    return new_tokens
      

In [None]:
# Let's try to generate a few notes: 40 notes are sampled one at a time, and
# each call receives all previously generated tokens as context

tokens = []
current_time = 0
for i in range(40):
    # Only instrument 0 is allowed for the new note
    new_tokens = generate_note(model, tokens, current_time, active_instruments=[0])
    print(new_tokens)
    tokens+=new_tokens
    # The first token of the new triple is its time token; it becomes the current time
    current_time = new_tokens[0]
...

In [None]:
from anticipation.convert import events_to_midi
import midi2audio
import librosa
from IPython.display import Audio

# Let's listen to our sequence of tokens

# We first need to convert our tokens to a MIDI file
events_to_midi(tokens).save("assets/generation.mid")

# Then we render the MIDI file to a WAV file using FluidSynth and a soundfont
midi2audio_obj = midi2audio.FluidSynth("../session2_setup/assets/soundfont.sf2")
midi2audio_obj.midi_to_audio("assets/generation.mid", "assets/generation.wav")

# Finally we load the rendered audio at 44.1 kHz...
y, sr = librosa.load("assets/generation.wav", sr=44100)

# ...and play it inside the notebook
display(Audio(y, rate=sr))

In [None]:
import matplotlib.pyplot as plt

# Let's take a look at what logits look like
# We keep the first 10 notes (3 tokens each) as context
tokens_subset = tokens[:3*10]
with torch.no_grad():
    input_sequence = torch.tensor([ANTICIPATE] + tokens_subset).unsqueeze(0).to(model.device)
    # Logits at the last position of the (single) sequence in the batch,
    # i.e. the scores for the token that would come next
    logits = model(input_sequence).logits[0,-1]

print(logits.shape)

plt.figure(figsize=(20, 5))

# Show the logits vector as a one-row heatmap
plt.imshow(logits.cpu().unsqueeze(0).numpy(), aspect="auto", interpolation="nearest")
plt.xlabel("Token")
plt.yticks([])
plt.title("Logits for token 31")
plt.colorbar()

plt.show()

In [None]:
from anticipation.config import MAX_TIME

# Let's do the same thing after nucleus sampling processing

print(f"Last time of the sequence is {tokens_subset[-3]}")

# Safety filtering for time tokens: the next token is a time token, so we
# rule out everything from DUR_OFFSET onwards...
logits[DUR_OFFSET:] = -float('inf')

# ...as well as every time earlier than the last note of the sequence
logits[TIME_OFFSET:TIME_OFFSET+tokens_subset[-3]] = -float('inf')

# Get the probability distribution of the new token
nucleus_logits = nucleus(logits, top_p=0.98)
probs = torch.nn.functional.softmax(nucleus_logits, dim=-1)
new_token = torch.multinomial(probs, 1).item()

print(f"New token would be sampled for {new_token}")

# Plot the probabilities of the first MAX_TIME tokens
plt.figure(figsize=(20, 5))
plt.plot(probs[:MAX_TIME].cpu().numpy())
plt.xlabel("Token")
plt.ylabel("Probability")
plt.title("Probability distribution for sampling token 31")

plt.show()

In [None]:
from anticipation.convert import midi_to_events

# Let's now work with a conditioning signal! We'll convert our MIDI file to tokens

symphony40 = midi_to_events("../session2_setup/assets/symphony40.mid")
# Tokens come in triples (time, duration, note), hence one note per 3 tokens
print(f"Number of tokens: {len(symphony40)}, number of notes: {len(symphony40)//3}")
# Show the tokens of the first two notes
print(f"First tokens: {symphony40[:6]}")

In [None]:
# The instrument of a note can be read off its note token: after removing
# NOTE_OFFSET, the instrument index is the integer quotient by MAX_PITCH.
# Collecting these over every third token gives the set of instruments used.
instruments = {(note_token - NOTE_OFFSET) // MAX_PITCH for note_token in symphony40[2::3]}


print(f"Number of instruments: {len(instruments)}")
print(f"Instruments: {instruments}")

In [None]:
# Restrict the piece to its first 238 notes and split it by instrument:
#   - the notes of instrument 42 become the control signal
#   - the notes of instruments 42 and 40 together form the ground truth

control_tokens = []
ground_truth = []
excerpt = symphony40[:238*3]
for t, d, n in zip(excerpt[0::3], excerpt[1::3], excerpt[2::3]):
    instrument = (n - NOTE_OFFSET) // MAX_PITCH
    if instrument == 42:
        control_tokens.extend([t, d, n])
    if instrument in (40, 42):
        ground_truth.extend([t, d, n])

# Write both sequences to MIDI, then render each MIDI file to audio
events_to_midi(control_tokens).save("assets/control.mid")
events_to_midi(ground_truth).save("assets/ground_truth.mid")
midi2audio_obj.midi_to_audio("assets/control.mid", "assets/control.wav")
midi2audio_obj.midi_to_audio("assets/ground_truth.mid", "assets/ground_truth.wav")

# Listen to the control signal alone...
y, sr = librosa.load("assets/control.wav", sr=44100)
display(Audio(y, rate=sr))

# ...and to the ground truth
y, sr = librosa.load("assets/ground_truth.wav", sr=44100)
display(Audio(y, rate=sr))

In [None]:
from anticipation.config import DELTA, TIME_RESOLUTION
import math

# We will use this sequence as a conditioning signal for our generation
# Let's generate instrument 40 while instrument 42 is given as the control signal

# First, let's add CONTROL_OFFSET to the control tokens
anticipated_control_tokens = [CONTROL_OFFSET + t for t in control_tokens]

# We select the first control token and leave the rest for later
atime, adur, anote = anticipated_control_tokens[0:3]
atokens = anticipated_control_tokens[3:]

# This is the time of the first control
anticipated_time = atime - CONTROL_OFFSET

# In anticipation.config, DELTA is the anticipation interval in seconds,
# whereas time tokens count ticks (TIME_RESOLUTION ticks per second).
# Convert once so that both sides of the comparison below use ticks, as the
# library's own sampler does (delta = DELTA * TIME_RESOLUTION).
delta_ticks = DELTA * TIME_RESOLUTION

# We will generate until end time
end_time = max(control_tokens[::3])
current_time = 0
conditioned_tokens = []

# Generation loop
while current_time < end_time:
    # Insert every control that lies within the anticipation interval of the
    # current time, so that the model sees it before generating further
    while current_time >= anticipated_time - delta_ticks:
        conditioned_tokens.extend([atime, adur, anote])

        if len(atokens) > 0:
            atime, adur, anote = atokens[0:3]
            atokens = atokens[3:]
            anticipated_time = atime - CONTROL_OFFSET
        else:
            # No controls left: make sure this inner loop never fires again
            anticipated_time = math.inf

    new_tokens = generate_note(model, conditioned_tokens, current_time, active_instruments=[40], monophony=True)
    print(new_tokens)
    conditioned_tokens += new_tokens
    current_time = new_tokens[0]

# We remove the control tokens and add them without CONTROL_OFFSET
conditioned_tokens, _ = ops.split(conditioned_tokens)
conditioned_tokens = ops.sort(conditioned_tokens + control_tokens)

In [None]:
# We can now listen to our result

# Convert the merged sequence (generated notes + control notes) to a MIDI file
events_to_midi(conditioned_tokens).save("assets/conditioned_generation.mid")

# Render the MIDI file to a WAV file using FluidSynth and a soundfont
midi2audio_obj = midi2audio.FluidSynth("../session2_setup/assets/soundfont.sf2")
midi2audio_obj.midi_to_audio("assets/conditioned_generation.mid", "assets/conditioned_generation.wav")

# Load the rendered audio at 44.1 kHz
y, sr = librosa.load("assets/conditioned_generation.wav", sr=44100)

# Play it inside the notebook
display(Audio(y, rate=sr))