# Session 6: Autoregressive Music Generation

Agenda
- Overview of the Transformer model
- Understanding Anticipatory Music Transformers
- Understanding MusicGen
- Hands On 1: Using AMT to generate MIDI data
- Hands On 2: Using MusicGen to generate audio

## The Transformer architecture

From the paper [Attention Is All You Need (Vaswani et al., 2017)](https://arxiv.org/abs/1706.03762).

![](./assets/transformer.png)
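
Autoregressive generation relies on the decoder's masked self-attention: each position may only attend to itself and to earlier positions. The block below is a minimal pure-Python sketch of causal scaled dot-product attention for a single head, with no learned projections. The function names and the toy input are illustrative assumptions, not code from the paper or from any library.

```python
import math

def softmax(xs):
    # Subtract the maximum for numerical stability before exponentiating.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def causal_attention(queries, keys, values):
    """Scaled dot-product attention where position i only sees positions <= i."""
    d_k = len(keys[0])
    dim = len(values[0])
    outputs = []
    for i, q in enumerate(queries):
        # Scoring only the visible prefix is equivalent to applying a causal mask.
        scores = [
            sum(a * b for a, b in zip(q, keys[j])) / math.sqrt(d_k)
            for j in range(i + 1)
        ]
        weights = softmax(scores)
        outputs.append(
            [sum(w * values[j][c] for j, w in enumerate(weights)) for c in range(dim)]
        )
    return outputs

# Toy sequence of three 2-dimensional vectors used as queries, keys and values.
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out = causal_attention(x, x, x)
print(out)
```

Because of the mask, changing the last vector of the sequence leaves the outputs at the first two positions untouched, which is what lets the model generate one token at a time.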

## Understanding Anticipatory Music Transformers

From the papers [Music Transformer (Huang et al., 2018)](https://arxiv.org/abs/1809.04281) and [Anticipatory Music Transformers (Thickstun et al., 2023)](https://arxiv.org/abs/2306.08620).

![](./assets/music_transformer.png)

## Understanding MusicGen

From the paper [Simple and Controllable Music Generation (Copet et al., 2023)](https://arxiv.org/abs/2306.05284).

![](./assets/musicgen.png)

## Hands On 1: Using AMT to generate MIDI data

In [None]:
# Clone the anticipation repository
!git clone https://github.com/lancelotblanchard/anticipation.git ../repositories/anticipation

In [None]:
import os

# Set the Hugging Face Hub cache folder before importing transformers
os.environ["HF_HUB_CACHE"] = os.path.abspath("../huggingface_hub_cache/")

from transformers import GPT2LMHeadModel

# Load the pretrained model
model = GPT2LMHeadModel.from_pretrained("stanford-crfm/music-small-800k", attn_implementation="eager")

In [None]:
# Add our repository to the Python path

import sys
sys.path.append('../repositories/anticipation')

In [None]:
from anticipation import ops
from anticipation.config import MAX_INSTR, MAX_PITCH
from anticipation.vocab import ANTICIPATE, CONTROL_OFFSET, DUR_OFFSET, NOTE_OFFSET, TIME_OFFSET
from anticipation.sample import nucleus
import torch

# Let's generate some tokens unconditionally. Before we can do so,
# we need to build an inference function:

def generate_note(model, tokens, current_time, active_instruments, top_p=0.98, history_length=340, monophony=False):
    assert len(tokens) % 3 == 0 # we need to have a valid sequence

    new_tokens = []

    ...

    return new_tokens

In [None]:
# Let's try to generate a few notes

tokens = []
current_time = 0
...

In [None]:
from anticipation.convert import events_to_midi
import midi2audio
import librosa
from IPython.display import Audio

# Let's listen to our sequence of tokens

# We first need to convert our tokens to a MIDI file
events_to_midi(tokens).save("assets/generation.mid")

midi2audio_obj = midi2audio.FluidSynth("../session2_setup/assets/soundfont.sf2")
midi2audio_obj.midi_to_audio("assets/generation.mid", "assets/generation.wav")

y, sr = librosa.load("assets/generation.wav", sr=44100)

display(Audio(y, rate=sr))

In [None]:
import matplotlib.pyplot as plt

# Let's take a look at what the logits look like
tokens_subset = tokens[:3*10]
with torch.no_grad():
    input_sequence = ...
    logits = ...

print(logits.shape)

plt.figure(figsize=(20, 5))

...

plt.show()

In [None]:
from anticipation.config import MAX_TIME

# Let's look at the same distribution after nucleus (top-p) filtering

print(f"Last time of the sequence is {tokens_subset[-3]}")

# Safety filtering for time tokens
...

# Get the probability distribution of the new token
nucleus_logits = ...
probs = ...
new_token = ...

print(f"Sampled new token: {new_token}")

plt.figure(figsize=(20, 5))

...

plt.show()
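
Nucleus (top-p) sampling keeps only the smallest set of tokens whose cumulative probability reaches `top_p`, renormalises, and samples from that set. The block below is a minimal pure-Python sketch of the idea on a toy distribution. It is not the implementation of `anticipation.sample.nucleus`, and the names `top_p_filter`, `sample_token` and `toy_logits` are hypothetical.

```python
import math
import random

def top_p_filter(logits, top_p):
    """Return renormalised probabilities restricted to the top-p nucleus."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Visit tokens from most to least likely until the cumulative mass reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Zero out everything outside the nucleus and renormalise what is left.
    kept_mass = sum(probs[i] for i in kept)
    filtered = [0.0] * len(probs)
    for i in kept:
        filtered[i] = probs[i] / kept_mass
    return filtered

def sample_token(probs, rng):
    # Draw one token index according to the filtered distribution.
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

toy_logits = [2.0, 1.0, 0.0, -4.0]
toy_filtered = top_p_filter(toy_logits, top_p=0.9)
toy_token = sample_token(toy_filtered, random.Random(0))
print(toy_filtered, toy_token)
```

With these toy logits the two most likely tokens already cover more than 90% of the mass, so the other two receive probability zero and can never be sampled.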

In [None]:
from anticipation.convert import midi_to_events

# Let's now work with a conditioning signal! We'll convert our MIDI file to tokens

symphony40 = midi_to_events("../session2_setup/assets/symphony40.mid")
print(f"Number of tokens: {len(symphony40)}, number of notes: {len(symphony40)//3}")
print(f"First tokens: {symphony40[:6]}")

In [None]:
# We can calculate the number of instruments by looking at the tokens

instruments = set()
...

print(f"Number of instruments: {len(instruments)}")
print(f"Instruments: {instruments}")
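
Each note is stored as a triplet of tokens (time, duration, note), which is why the sequence length is always a multiple of three. The sketch below shows how such triplets could be decoded, assuming the note token packs instrument and pitch as `instrument * MAX_PITCH + pitch` after an offset. The `TOY_*` constants and helper names are assumptions for illustration; in the notebook the real values should come from `anticipation.vocab` and `anticipation.config`.

```python
# Assumed vocabulary layout, for illustration only.
TOY_DUR_OFFSET = 10000   # hypothetical start of the duration tokens
TOY_NOTE_OFFSET = 11000  # hypothetical start of the note tokens
TOY_MAX_PITCH = 128      # MIDI pitches 0..127

def encode_event(time, duration, instrument, pitch):
    """Pack one note into a (time, duration, note) token triplet."""
    return [
        time,
        TOY_DUR_OFFSET + duration,
        TOY_NOTE_OFFSET + instrument * TOY_MAX_PITCH + pitch,
    ]

def decode_events(tokens):
    """Unpack a flat token list into (time, duration, instrument, pitch) tuples."""
    assert len(tokens) % 3 == 0  # tokens must form complete triplets
    events = []
    for time, dur, note in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        note -= TOY_NOTE_OFFSET
        events.append(
            (time, dur - TOY_DUR_OFFSET, note // TOY_MAX_PITCH, note % TOY_MAX_PITCH)
        )
    return events

toy_tokens = encode_event(0, 50, 40, 67) + encode_event(25, 100, 42, 55)
toy_instruments = {instrument for _, _, instrument, _ in decode_events(toy_tokens)}
print(decode_events(toy_tokens), toy_instruments)
```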

In [None]:
# Let's filter the first 238 notes of the sequence and only keep instrument 42
# We will also keep the notes of instrument 40 as a ground truth

control_tokens = []
ground_truth = []
...

events_to_midi(control_tokens).save("assets/control.mid")
midi2audio_obj.midi_to_audio("assets/control.mid", "assets/control.wav")

y, sr = librosa.load("assets/control.wav", sr=44100)

display(Audio(y, rate=sr))

In [None]:
from anticipation.config import DELTA
import math

# We will use this sequence as a conditioning signal for our generation
# Let's generate an accompaniment conditioned on the control signal (instrument 42)

# First, let's add CONTROL_OFFSET to the control tokens
anticipated_control_tokens = ...

# We select the first control token and leave the rest for later
atime, adur, anote = ...
atokens = ...
# This is the time of the first control
anticipated_time = ...

# We will generate until end time
end_time = max(control_tokens[::3])
current_time = 0
conditioned_tokens = []

# Generation loop
while current_time < end_time:
    # Insert the upcoming control tokens once we are within DELTA of them
    while current_time >= anticipated_time - DELTA:
        ...

    new_tokens = ...
    print(new_tokens)
    ...

# We remove the control tokens and add them without CONTROL_OFFSET
conditioned_tokens, _ = ops.split(conditioned_tokens)
conditioned_tokens = ops.sort(conditioned_tokens + control_tokens)
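
The loop condition `current_time >= anticipated_time - DELTA` means a control is placed in the sequence as soon as generation gets within `DELTA` of its time, so the model sees it ahead of the moment it sounds. The block below is a toy sketch of that interleaving using plain timestamps instead of real tokens. The function `interleave`, its tuple format and the example times are hypothetical, and it is not the repository's implementation.

```python
def interleave(event_times, control_times, delta):
    """Order events and controls so each control appears `delta` ahead of its time."""
    controls = sorted(control_times)
    sequence = []
    next_control = 0
    current_time = 0
    for t in sorted(event_times):
        # Anticipate every control that lies within `delta` of the current time.
        while next_control < len(controls) and current_time >= controls[next_control] - delta:
            sequence.append(("control", controls[next_control]))
            next_control += 1
        sequence.append(("event", t))
        current_time = t
    # Controls that were never reached are appended at the end.
    sequence.extend(("control", s) for s in controls[next_control:])
    return sequence

toy_sequence = interleave(event_times=[0, 2, 4, 6, 8], control_times=[7, 9], delta=5)
print(toy_sequence)
```

In this toy run the control at time 7 is inserted right after the event at time 2, five time units before it occurs.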

In [None]:
# We can now listen to our result

events_to_midi(conditioned_tokens).save("assets/conditioned_generation.mid")

midi2audio_obj = midi2audio.FluidSynth("../session2_setup/assets/soundfont.sf2")
midi2audio_obj.midi_to_audio("assets/conditioned_generation.mid", "assets/conditioned_generation.wav")

y, sr = librosa.load("assets/conditioned_generation.wav", sr=44100)

display(Audio(y, rate=sr))

## Hands On 2: Using MusicGen to generate audio

In [None]:
import os
os.environ["HF_HUB_CACHE"] = os.path.abspath("../huggingface_hub_cache")

from transformers import MusicgenMelodyForConditionalGeneration, MusicgenMelodyProcessor
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = MusicgenMelodyProcessor.from_pretrained("facebook/musicgen-melody")
model = MusicgenMelodyForConditionalGeneration.from_pretrained("facebook/musicgen-melody")
model = model.to(device)

In [None]:
from torchinfo import summary

summary(model)

In [None]:
# We can first generate unconditional music
unconditional_inputs = ...

unconditional_audio_values = ...

In [None]:
from IPython.display import Audio, display

# Let's listen to our audio

display(Audio(unconditional_audio_values.squeeze(0, 1).cpu(), rate=model.config.sampling_rate))

In [None]:
# We can also generate a piece of music conditionally, with a given text prompt

text_conditioned_inputs = ...

text_conditioned_audio_values = ...

display(Audio(text_conditioned_audio_values.squeeze(0, 1).cpu(), rate=model.config.sampling_rate))

In [None]:
import librosa
from IPython.display import Audio

# And we can also generate with a melody condition, passed as an audio array

y, sr = librosa.load("bolero_ravel.mp3", sr=model.config.sampling_rate)

display(Audio(y, rate=sr))

melody_conditioned_inputs = ...

melody_conditioned_audio_values = ...

display(Audio(melody_conditioned_audio_values.squeeze(0, 1).cpu(), rate=model.config.sampling_rate))

In [None]:
# Let's take a look at how this model actually generates music

# ############### #
# 0. CONDITIONING #
# ############### #

# First, we tokenize our text prompt
text_prompt = ...
inputs_tensor = text_prompt["input_ids"].to(model.device)
attention_mask = text_prompt["attention_mask"].to(model.device)

print(inputs_tensor)
print(attention_mask)

# Then, we get our melody conditioning (a chroma spectrogram)
melody_prompt = ...
input_features = melody_prompt["input_features"].to(model.device)

print(melody_prompt["input_features"].shape)

In [None]:
import copy

# ################# #
# 1. PREPARE CONFIG #
# ################# #

generation_config = copy.deepcopy(model.generation_config)
model._prepare_special_tokens(generation_config, device=model.device)

In [None]:
import math

# #################### #
# 2. TEXT CONDITIONING #
# #################### #

encoder = model.get_text_encoder()
with torch.no_grad():
    encoder_hidden_states = ...

# project encoder_hidden_states
encoder_hidden_states = ...

# for classifier free guidance we need to add a 'null' input to our encoder hidden states
encoder_hidden_states = ...
encoder_attention_mask = ...
encoder_hidden_states = encoder_hidden_states * encoder_attention_mask[..., None]

In [None]:
# ##################### #
# 3. AUDIO CONDITIONING #
# ##################### #

null_audio_hidden_states = ...

# for classifier free guidance we need to add a 'null' input to our audio hidden states
audio_hidden_states = torch.cat([input_features, null_audio_hidden_states], dim=0)

# project audio_hidden_states ->
# (batch_size, seq_len, num_chroma) -> (batch_size, seq_len, hidden_size)
audio_hidden_states = ...

# pad or truncate to config.chroma_length
n_repeat = ...
audio_hidden_states = ...

audio_hidden_states = ...

encoder_hidden_states = torch.cat([audio_hidden_states, encoder_hidden_states], dim=1)
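
The "pad or truncate" step has to turn a conditioning sequence of arbitrary length into one of fixed length. The variable name `n_repeat` suggests that short sequences are tiled rather than zero-padded. The block below is a minimal sketch under that assumption, using a plain Python list in place of a tensor. `fit_to_length` is a hypothetical helper, not a function of the `transformers` library.

```python
import math

def fit_to_length(frames, target_length):
    """Tile `frames` until `target_length` is covered, then cut to exactly that length."""
    if len(frames) < target_length:
        # Number of full copies needed to reach at least target_length frames.
        n_repeat = math.ceil(target_length / len(frames))
        frames = frames * n_repeat
    return frames[:target_length]

print(fit_to_length([1, 2, 3], 7))        # short input is tiled, then trimmed
print(fit_to_length([1, 2, 3, 4, 5], 2))  # long input is truncated
```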

In [None]:
# ##################################### #
# 4. PREPARE AUTO-REGRESSIVE GENERATION #
# ##################################### #

input_ids = ...

In [None]:
# ###################### #
# 5. BUILD DELAY PATTERN #
# ###################### #

max_length = 513

input_ids, decoder_delay_pattern_mask = ...
print(decoder_delay_pattern_mask.shape)
print(decoder_delay_pattern_mask)
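
MusicGen predicts several codebooks per audio frame, and the delay pattern offsets codebook `k` by `k` steps so that all codebooks can be predicted in a single autoregressive pass. The block below is a simplified illustration of applying and reverting that shift on nested Python lists. It is not the `transformers` implementation, and `PAD`, the helper names and the toy code values are hypothetical.

```python
PAD = -1  # hypothetical placeholder for positions that hold no real code

def apply_delay_pattern(codes):
    """Shift codebook k right by k steps; codes[k][t] is codebook k at frame t."""
    num_codebooks, num_frames = len(codes), len(codes[0])
    total_steps = num_frames + num_codebooks - 1
    delayed = [[PAD] * total_steps for _ in range(num_codebooks)]
    for k in range(num_codebooks):
        for t in range(num_frames):
            delayed[k][t + k] = codes[k][t]
    return delayed

def revert_delay_pattern(delayed):
    """Undo the shift by reading codebook k starting at step k."""
    num_codebooks = len(delayed)
    num_frames = len(delayed[0]) - num_codebooks + 1
    return [row[k:k + num_frames] for k, row in enumerate(delayed)]

# Four codebooks and three frames, with made-up code values.
toy_codes = [[11, 12, 13], [21, 22, 23], [31, 32, 33], [41, 42, 43]]
toy_delayed = apply_delay_pattern(toy_codes)
for row in toy_delayed:
    print(row)
```

Reverting the pattern amounts to dropping the padded positions of each row, which mirrors the "filter the pad token id" step in the decoding cell.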

In [None]:
from transformers import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList, TopKLogitsWarper

# ########################### #
# 6. PREPARE LOGITS PROCESSOR #
# ########################### #

guidance_scale = 3

logits_processor = ...
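
Classifier-free guidance combines the logits of a conditioned pass and an unconditioned ("null") pass, pushing the prediction further in the direction of the conditioning. The block below sketches the standard combination rule on plain lists. How `ClassifierFreeGuidanceLogitsProcessor` arranges the two halves of the batch is not shown, and the function name and toy logits are assumptions.

```python
def classifier_free_guidance(cond_logits, uncond_logits, guidance_scale):
    """Move from the unconditional logits towards the conditional ones, scaled."""
    return [
        u + guidance_scale * (c - u)
        for c, u in zip(cond_logits, uncond_logits)
    ]

cond = [2.0, 0.5, -1.0]    # toy logits from the conditioned pass
uncond = [1.0, 1.0, 0.0]   # toy logits from the null-conditioned pass
guided = classifier_free_guidance(cond, uncond, guidance_scale=3)
print(guided)
```

A scale of 1 returns the conditional logits unchanged, while larger values exaggerate the difference between the two passes.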

In [None]:
from transformers import StoppingCriteriaList, MaxLengthCriteria

# ############################ #
# 7. PREPARE STOPPING CRITERIA #
# ############################ #

stopping_criteria = ...

In [None]:
# #################### #
# 8. RUN SAMPLING LOOP #
# #################### #

with torch.no_grad():
    outputs = ...

In [None]:
# ################ #
# 9. DECODE OUTPUT #
# ################ #

# apply the pattern mask to the final ids
output_ids = ...

# revert the pattern delay mask by filtering the pad token id
output_ids = ...

# append the frame dimension back to the audio codes
output_ids = ...

with torch.no_grad():
    output_values = ...

In [None]:
# Do we get a similar output?

display(Audio(output_values.cpu().squeeze(0, 1), rate=model.config.sampling_rate))