# Session 7: Autoregressive Music Generation (Part 2)

Agenda
- Understanding MusicGen
- Hands On: Using MusicGen to generate audio

## Understanding MusicGen

MusicGen is described in the paper [Simple and Controllable Music Generation (Copet et al., 2023)](https://arxiv.org/abs/2306.05284).

![](./assets/musicgen.png)

## Hands On: Using MusicGen to generate audio

In [None]:
import os
os.environ["HF_HUB_CACHE"] = os.path.abspath("../huggingface_hub_cache")

from transformers import MusicgenMelodyForConditionalGeneration, MusicgenMelodyProcessor
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

processor = MusicgenMelodyProcessor.from_pretrained("facebook/musicgen-melody")
model = MusicgenMelodyForConditionalGeneration.from_pretrained("facebook/musicgen-melody")
model = model.to(device)

In [None]:
from torchinfo import summary

summary(model)

In [None]:
# First, we generate music unconditionally (no text or melody prompt)
unconditional_inputs = ...

unconditional_audio_values = ...

In [None]:
from IPython.display import Audio, display

# Let's listen to our audio

display(Audio(unconditional_audio_values.squeeze(0, 1).cpu(), rate=model.config.sampling_rate))

In [None]:
# We can also generate a piece of music conditionally, with a given text prompt

text_conditioned_inputs = ...

text_conditioned_audio_values = ...

display(Audio(text_conditioned_audio_values.squeeze(0, 1).cpu(), rate=model.config.sampling_rate))

In [None]:
import librosa
from IPython.display import Audio, display

# And we can also generate with a melody condition, passed as an audio array

y, sr = librosa.load("bolero_ravel.mp3", sr=model.config.sampling_rate)

display(Audio(y, rate=sr))

melody_conditioned_inputs = ...

melody_conditioned_audio_values = ...

display(Audio(melody_conditioned_audio_values.squeeze(0, 1).cpu(), rate=model.config.sampling_rate))

In [None]:
# Let's take a look at how this model actually generates music

# ############### #
# 0. CONDITIONING #
# ############### #

# First, we get our text conditioning (token ids and an attention mask)
text_prompt = ...
inputs_tensor = text_prompt["input_ids"].to(model.device)
attention_mask = text_prompt["attention_mask"].to(model.device)

print(inputs_tensor)
print(attention_mask)

# Then, we get our melody conditioning (a chroma spectrogram)
melody_prompt = ...
input_features = melody_prompt["input_features"].to(model.device)

print(melody_prompt["input_features"].shape)

In [None]:
import copy

# ################# #
# 1. PREPARE CONFIG #
# ################# #

generation_config = copy.deepcopy(model.generation_config)
model._prepare_special_tokens(generation_config, device=model.device)

In [None]:
import math

# #################### #
# 2. TEXT CONDITIONING #
# #################### #

encoder = model.get_text_encoder()
with torch.no_grad():
    encoder_hidden_states = ...

# project encoder_hidden_states
encoder_hidden_states = ...

# for classifier free guidance we need to add a 'null' input to our encoder hidden states
encoder_hidden_states = ...
encoder_attention_mask = ...
encoder_hidden_states = encoder_hidden_states * encoder_attention_mask[..., None]
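
The "null" input used for classifier-free guidance can be pictured as an all-zero copy of the conditioning, stacked along the batch dimension with an all-zero attention mask. The sketch below shows that idea on plain Python lists so the shapes stay visible. `add_null_condition` is a hypothetical helper written for this illustration, not part of `transformers`, and the all-zero mask for the null half is an assumption.

```python
def add_null_condition(hidden_states, attention_mask):
    """Append an all-zero 'null' copy of the batch for classifier-free guidance.

    hidden_states: nested list shaped (batch, seq_len, hidden_size)
    attention_mask: nested list shaped (batch, seq_len), 1 = real token, 0 = padding
    """
    # Assumption: the null condition is all zeros, with an all-zero mask.
    null_states = [[[0.0] * len(vec) for vec in seq] for seq in hidden_states]
    null_mask = [[0] * len(row) for row in attention_mask]

    # Stack along the batch dimension: (batch, ...) -> (2 * batch, ...)
    states = hidden_states + null_states
    mask = attention_mask + null_mask

    # Zero every masked position (same effect as states * mask[..., None]).
    states = [
        [[value * m for value in vec] for vec, m in zip(seq, row)]
        for seq, row in zip(states, mask)
    ]
    return states, mask


# One sequence of two tokens; the second token is padding.
toy_states = [[[1.0, 2.0], [3.0, 4.0]]]
toy_mask = [[1, 0]]
cfg_states, cfg_mask = add_null_condition(toy_states, toy_mask)
print(len(cfg_states), cfg_mask)
```

The first half of the batch carries the real prompt and the second half carries the null prompt, which is what lets a later step compare conditional and unconditional predictions.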

In [None]:
# ##################### #
# 3. AUDIO CONDITIONING #
# ##################### #

null_audio_hidden_states = ...

# for classifier free guidance we need to add a 'null' input to our audio hidden states
audio_hidden_states = torch.concatenate([input_features, null_audio_hidden_states], dim=0)

# project audio_hidden_states:
# (batch_size, seq_len, num_chroma) -> (batch_size, seq_len, hidden_size)
audio_hidden_states = ...

# pad or truncate to config.chroma_length
n_repeat = ...
audio_hidden_states = ...

audio_hidden_states = ...

encoder_hidden_states = torch.cat([audio_hidden_states, encoder_hidden_states], dim=1)
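
The "pad or truncate" step brings the melody conditioning to a fixed number of frames. A minimal sketch of one way to do that follows, assuming (as the name `n_repeat` suggests) that a short sequence is tiled until it is long enough and then cut. `fit_to_length` is a hypothetical helper that works on plain lists rather than tensors.

```python
import math


def fit_to_length(frames, target_length):
    """Tile a sequence until it covers target_length frames, then cut it to size."""
    if not frames:
        raise ValueError("need at least one frame")
    if len(frames) < target_length:
        # Number of whole copies needed to reach at least target_length frames.
        n_repeat = math.ceil(target_length / len(frames))
        frames = frames * n_repeat
    # Truncate: a no-op when the sequence already has exactly target_length frames.
    return frames[:target_length]


short = fit_to_length(["a", "b", "c"], 7)   # shorter than the target: tiled, then cut
long = fit_to_length(list(range(10)), 4)    # longer than the target: cut only
print(short)
print(long)
```

With tensors, the same two operations would be a repeat along the sequence dimension followed by a slice.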

In [None]:
# ##################################### #
# 4. PREPARE AUTO-REGRESSIVE GENERATION #
# ##################################### #

input_ids = ...

In [None]:
# ###################### #
# 5. BUILD DELAY PATTERN #
# ###################### #

max_length = 513

input_ids, decoder_delay_pattern_mask = ...
print(decoder_delay_pattern_mask.shape)
print(decoder_delay_pattern_mask)
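
To make the printed mask easier to read, here is a small framework-free sketch of a delay pattern in which codebook `k` lags `k` steps behind codebook 0. `build_delay_pattern` is a hypothetical helper, and the padding id `2048` is only an illustrative value. The layout is modeled on the idea of the delay pattern, not copied from the library implementation, so details may differ.

```python
def build_delay_pattern(num_codebooks, max_length, pad_id):
    """Return a (num_codebooks, max_length) grid.

    A cell holds pad_id where the slot is forced to padding,
    and -1 where the model is free to predict a token.
    """
    n_frames = max_length - num_codebooks  # audio frames that fit in the grid
    if n_frames < 1:
        raise ValueError("max_length must exceed num_codebooks")
    grid = []
    for k in range(num_codebooks):
        row = []
        for step in range(max_length):
            # Codebook k starts one step after its k-step delay and
            # stays free for exactly n_frames steps.
            is_free = k < step <= k + n_frames
            row.append(-1 if is_free else pad_id)
        grid.append(row)
    return grid


pattern = build_delay_pattern(num_codebooks=4, max_length=8, pad_id=2048)
for row in pattern:
    print(row)
```

Each row has the same number of free slots, shifted one step to the right per codebook, so at any decoding step the model predicts a different frame for each codebook.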

In [None]:
from transformers import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList, TopKLogitsWarper

# ########################### #
# 6. PREPARE LOGITS PROCESSOR #
# ########################### #

guidance_scale = 3

logits_processor = ...
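
As a rough picture of what the two processors imported above do to the scores, the sketch below applies classifier-free guidance and then top-k filtering to toy logits. Both functions are hypothetical stand-ins written on plain lists. They illustrate the arithmetic only and are not the `transformers` classes.

```python
def classifier_free_guidance(cond_logits, uncond_logits, guidance_scale):
    """Push the unconditional logits towards the conditional ones.

    guidance_scale = 1 returns the conditional logits unchanged;
    larger values exaggerate the effect of the prompt.
    """
    return [u + guidance_scale * (c - u) for c, u in zip(cond_logits, uncond_logits)]


def top_k_filter(logits, k):
    """Keep the k largest logits and set the rest to -inf (ties may keep more than k)."""
    if not 1 <= k <= len(logits):
        raise ValueError("k must be between 1 and the vocabulary size")
    threshold = sorted(logits, reverse=True)[k - 1]
    return [x if x >= threshold else float("-inf") for x in logits]


cond = [2.0, 1.0, 0.0, -1.0]    # logits given the prompt
uncond = [1.0, 1.0, 1.0, 1.0]   # logits given the 'null' prompt
guided = classifier_free_guidance(cond, uncond, guidance_scale=3)
filtered = top_k_filter(guided, k=2)
print(guided)
print(filtered)
```

Tokens set to `-inf` get zero probability after the softmax, so sampling can only pick among the top-k candidates.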

In [None]:
from transformers import StoppingCriteriaList, MaxLengthCriteria

# ############################ #
# 7. PREPARE STOPPING CRITERIA #
# ############################ #

stopping_criteria = ...

In [None]:
# #################### #
# 8. RUN SAMPLING LOOP #
# #################### #

with torch.no_grad():
    outputs = ...
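
Stripped of all model details, a sampling loop repeats three steps until a stopping criterion fires: compute logits for the current prefix, turn them into probabilities, and draw the next token. The toy loop below shows that structure. `toy_logits` is a hypothetical stand-in for the decoder forward pass, and the vocabulary size and seed are arbitrary.

```python
import math
import random


def toy_logits(prefix, vocab_size):
    """Hypothetical 'model': favours the token that follows the last one."""
    favoured = (prefix[-1] + 1) % vocab_size
    return [2.0 if token == favoured else 0.0 for token in range(vocab_size)]


def sample_sequence(start_token, max_length, vocab_size, seed=0):
    rng = random.Random(seed)  # seeded so that runs are repeatable
    tokens = [start_token]
    while len(tokens) < max_length:  # stopping criterion: maximum length
        logits = toy_logits(tokens, vocab_size)
        peak = max(logits)
        # Unnormalised softmax; random.choices normalises the weights itself.
        weights = [math.exp(x - peak) for x in logits]
        next_token = rng.choices(range(vocab_size), weights=weights, k=1)[0]
        tokens.append(next_token)
    return tokens


sequence = sample_sequence(start_token=0, max_length=8, vocab_size=5)
print(sequence)
```

In the real loop, the logits processor from step 6 would run between the forward pass and the sampling step, and the delay pattern mask from step 5 would overwrite the forced padding slots.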

In [None]:
# ################ #
# 9. DECODE OUTPUT #
# ################ #

# apply the pattern mask to the final ids
output_ids = ...

# revert the delay pattern mask by filtering out the pad token id
output_ids = ...

# append the frame dimension back to the audio codes
output_ids = ...

with torch.no_grad():
    output_values = ...
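
Filtering out the pad token undoes the delay because every codebook row then holds only its real codes, in frame order. A minimal sketch on plain lists follows. `revert_delay_pattern` is a hypothetical helper and the padding id `2048` is again just an illustrative value.

```python
def revert_delay_pattern(delayed, pad_id):
    """Drop pad_id entries from each codebook row so the rows line up frame by frame."""
    rows = [[token for token in row if token != pad_id] for row in delayed]
    if len({len(row) for row in rows}) != 1:
        raise ValueError("codebooks hold different numbers of real tokens")
    return rows


PAD = 2048
delayed = [
    [PAD, 10, 11, 12, PAD, PAD],   # codebook 0
    [PAD, PAD, 20, 21, 22, PAD],   # codebook 1, one step later
    [PAD, PAD, PAD, 30, 31, 32],   # codebook 2, two steps later
]
aligned = revert_delay_pattern(delayed, PAD)
# Regroup by frame: each frame now has one code per codebook.
frames = [list(frame) for frame in zip(*aligned)]
print(aligned)
print(frames)
```

Once the rows are aligned, each column is a complete set of codes for one audio frame, which is the layout an audio decoder expects.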

In [None]:
# Do we get a similar output?

display(Audio(output_values.cpu().squeeze(0, 1), rate=model.config.sampling_rate))