# Session 7: Autoregressive Music Generation (Part 2)

Agenda
- Understanding MusicGen
- Hands On: Using MusicGen to generate audio

## Understanding MusicGen

From the paper [Simple and Controllable Music Generation (Copet et al., 2023)](https://arxiv.org/abs/2306.05284).

![](./assets/musicgen.png)

## Hands On: Using MusicGen to generate audio

In [None]:
import os

# Point the Hugging Face hub cache at a directory next to the notebook.
# This is set before `transformers` is imported so the library sees it.
os.environ["HF_HUB_CACHE"] = os.path.abspath("../huggingface_hub_cache")

from transformers import MusicgenMelodyForConditionalGeneration, MusicgenMelodyProcessor
import torch

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The processor and the model weights come from the same checkpoint.
processor = MusicgenMelodyProcessor.from_pretrained("facebook/musicgen-melody")
model = MusicgenMelodyForConditionalGeneration.from_pretrained("facebook/musicgen-melody").to(device)

In [None]:
from torchinfo import summary

# Show torchinfo's summary of the loaded model. As the last expression in the
# cell, the object returned by `summary` is what the notebook displays.
summary(model)

In [None]:
# Start with unconditional generation: ask the processor for a single set of
# "empty" conditioning inputs and move them onto the model's device.
unconditional_inputs = processor.get_unconditional_inputs(num_samples=1)
unconditional_inputs = unconditional_inputs.to(model.device)

# Generate 512 new tokens from those inputs.
unconditional_audio_values = model.generate(
    **unconditional_inputs,
    max_new_tokens=512,
)

In [None]:
from IPython.display import Audio, display

# Listen to the result: drop the two leading singleton dimensions, bring the
# samples back to the CPU, and play them at the model's sampling rate.
unconditional_waveform = unconditional_audio_values.squeeze(0, 1).cpu()
display(Audio(unconditional_waveform, rate=model.config.sampling_rate))

In [None]:
# We can also generate a piece of music conditionally, with a given text prompt

# Tokenize the prompt and move the resulting tensors to the model's device,
# exactly as is done for the unconditional inputs.
text_conditioned_inputs = processor(
    text=['90s grunge track with gritty guitar strumming'],
    padding=True,
    return_tensors="pt"
).to(model.device)

# Generate with classifier-free guidance. This call must run: the display
# below reads `text_conditioned_audio_values`, which is only defined here.
# max_new_tokens=512 keeps the clip length in line with the other examples.
text_conditioned_audio_values = model.generate(
    **text_conditioned_inputs,
    guidance_scale=3,
    max_new_tokens=512,
)

display(Audio(text_conditioned_audio_values.squeeze(0, 1).cpu(), rate=model.config.sampling_rate))

In [None]:
import librosa
from IPython.display import Audio, display

# And we can also generate with a melody condition, passed as an audio array

# Load the reference melody, resampled to the model's sampling rate.
y, sr = librosa.load("assets/bolero_ravel.mp3", sr=model.config.sampling_rate)

display(Audio(y, rate=sr))

# Build the joint (melody + text) conditioning and move it to the model's
# device so it matches the model parameters.
melody_conditioned_inputs = processor(
    audio=y,
    sampling_rate=model.config.sampling_rate,
    text=['bluegrass americana guitar, fiddle, bass, percussion'],
    padding=True,
    return_tensors='pt'
).to(model.device)

# Generate and play three clips from the same conditioning so they can be
# compared by ear. `melody_conditioned_audio_values` keeps the last clip.
for _ in range(3):
    melody_conditioned_audio_values = model.generate(
        **melody_conditioned_inputs,
        guidance_scale=3,
        max_new_tokens=512,
    )
    display(Audio(melody_conditioned_audio_values.squeeze(0, 1).cpu(), rate=model.config.sampling_rate))

In [None]:
# Let's take a look at how this model actually generates music

# ############### #
# 0. CONDITIONING #
# ############### #

# First, tokenize the text prompt and move its ids and attention mask to the
# model's device.
prompt_text = "1970s rock/reggae fusion like the band The Police"
text_prompt = processor.tokenizer(prompt_text, return_tensors="pt")
inputs_tensor = text_prompt["input_ids"].to(model.device)
attention_mask = text_prompt["attention_mask"].to(model.device)

print(inputs_tensor, attention_mask, sep="\n")

# Then, we get our melody conditioning (a chroma spectrogram) from the audio
# array `y` loaded earlier in the notebook.
melody_prompt = processor.feature_extractor(
    y,
    sampling_rate=model.config.sampling_rate,
    return_tensors="pt",
)
input_features = melody_prompt["input_features"].to(model.device)

# Moving to another device does not change the shape, so print it from the
# tensor we keep.
print(input_features.shape)

In [None]:
import copy

# ################# #
# 1. PREPARE CONFIG #
# ################# #

# Work on a private copy so the model's own generation config is untouched.
generation_config = copy.deepcopy(model.generation_config)

# Let the model prepare the special tokens on its device. Later cells read
# `_decoder_start_token_tensor` and `_pad_token_tensor` from this config —
# presumably populated by this call; confirm against the transformers source.
target_device = model.device
model._prepare_special_tokens(generation_config, device=target_device)

In [None]:
import math

# #################### #
# 2. TEXT CONDITIONING #
# #################### #

encoder = model.get_text_encoder()

# Inference only, so both the encoder pass and the projection run without
# gradient tracking.
with torch.no_grad():
    # get last hidden state to pass to decoder for cross attention
    encoder_hidden_states = encoder(
        input_ids=inputs_tensor,
        attention_mask=attention_mask,
        output_attentions=generation_config.output_attentions,
        output_hidden_states=generation_config.output_hidden_states,
    ).last_hidden_state
    print(encoder_hidden_states.shape)

    # project encoder_hidden_states
    encoder_hidden_states = model.enc_to_dec_proj(encoder_hidden_states)
    print(encoder_hidden_states.shape)

# For classifier-free guidance the batch is doubled: the first half carries
# the text condition, the second half is a 'null' (all-zero) condition.
# The result is assigned back to `encoder_hidden_states` so the doubled
# tensor is the one used from here on.
encoder_hidden_states = torch.concatenate(
    [encoder_hidden_states, torch.zeros_like(encoder_hidden_states)]
)
encoder_attention_mask = torch.concatenate(
    [attention_mask, torch.zeros_like(attention_mask)]
)

# Zero out every position the (doubled) attention mask marks as masked.
encoder_hidden_states = encoder_hidden_states * encoder_attention_mask[..., None]

# now has two batches, one with and one without the conditioning
print(encoder_hidden_states.shape)

In [None]:
# ##################### #
# 3. AUDIO CONDITIONING #
# ##################### #

import math  # imported here as well so this cell does not rely on another cell

print(input_features.shape)

# For classifier-free guidance we need to add a 'null' input to our audio
# hidden states: all zeros, with index 0 of the last axis set to 1.
null_audio_hidden_states = torch.zeros_like(input_features)
null_audio_hidden_states[:, :, 0] = 1

# Stack the real and the null melody along the batch dimension (done once).
audio_hidden_states = torch.concatenate([input_features, null_audio_hidden_states], dim=0)
print(audio_hidden_states.shape)

# project audio_hidden_states:
# (batch_size, seq_len, num_chroma) -> (batch_size, seq_len, hidden_size)
# Inference only, so no gradient tracking is needed.
with torch.no_grad():
    audio_hidden_states = model.audio_enc_to_dec_proj(audio_hidden_states)
print(audio_hidden_states.shape)

# pad or truncate to config.chroma_length: tile the sequence often enough to
# reach chroma_length, then cut it to exactly that length.
n_repeat = int(math.ceil(model.config.chroma_length / audio_hidden_states.shape[1]))
audio_hidden_states = audio_hidden_states.repeat(1, n_repeat, 1)
print(audio_hidden_states.shape)

audio_hidden_states = audio_hidden_states[:, :model.config.chroma_length]
print(audio_hidden_states.shape)

# Prepend the audio conditioning to the text conditioning along the sequence
# dimension.
# NOTE: this overwrites `encoder_hidden_states`, so re-running this cell
# without re-running the text-conditioning cell prepends the audio part again.
encoder_hidden_states = torch.cat([audio_hidden_states, encoder_hidden_states], dim=1)
print(encoder_hidden_states.shape)

In [None]:
# ##################################### #
# 4. PREPARE AUTO-REGRESSIVE GENERATION #
# ##################################### #

# Four rows holding a single token id (2048) each, created on the model's
# device. Presumably one row per codebook, with 2048 as the decoder start
# token — confirm against model.decoder.num_codebooks and the generation config.
input_ids = torch.full((4, 1), 2048, dtype=torch.long, device=model.device)



In [None]:
# ###################### #
# 5. BUILD DELAY PATTERN #
# ###################### #

# Total sequence length to generate up to; also used by the stopping criteria.
# Presumably the single start token plus 512 new tokens, mirroring
# max_new_tokens=512 in the earlier cells — confirm.
max_length = 513

# The decoder returns the (possibly updated) input ids together with the
# delay pattern mask; the start-token tensor is passed as the pad token id.
delay_pattern_result = model.decoder.build_delay_pattern_mask(
    input_ids,
    pad_token_id=generation_config._decoder_start_token_tensor,
    max_length=max_length,
)
input_ids, decoder_delay_pattern_mask = delay_pattern_result

print(decoder_delay_pattern_mask.shape, decoder_delay_pattern_mask, sep="\n")

In [None]:
from transformers import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList, TopKLogitsWarper

# ########################### #
# 6. PREPARE LOGITS PROCESSOR #
# ########################### #

guidance_scale = 3

# Processors are applied in list order: classifier-free guidance first, then
# top-k filtering with the k taken from the generation config.
logits_processor = LogitsProcessorList(
    [
        ClassifierFreeGuidanceLogitsProcessor(guidance_scale),
        TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=1),
    ]
)

In [None]:
from transformers import StoppingCriteriaList, MaxLengthCriteria

# ############################ #
# 7. PREPARE STOPPING CRITERIA #
# ############################ #

# The only criterion is the length limit `max_length` set when the delay
# pattern was built.
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])

In [None]:
# #################### #
# 8. RUN SAMPLING LOOP #
# #################### #

# Run the model's sampling routine with everything prepared in the previous
# steps. The conditioning is passed as `encoder_hidden_states`; that keyword
# has to be spelled exactly so, otherwise the prepared text and melody
# conditioning does not reach the model under the name it expects.
with torch.no_grad():
    outputs = model._sample(
        input_ids,
        logits_processor=logits_processor,
        stopping_criteria=stopping_criteria,
        generation_config=generation_config,
        use_cache=True,
        guidance_scale=guidance_scale,
        input_features=input_features,
        encoder_hidden_states=encoder_hidden_states,
        decoder_delay_pattern_mask=decoder_delay_pattern_mask,
        synced_gpus=None,
        streamer=None,
    )

In [None]:
# ################ #
# 9. DECODE OUTPUT #
# ################ #

# Apply the delay pattern mask to the sampled ids.
delay_masked_ids = model.decoder.apply_delay_pattern_mask(outputs, decoder_delay_pattern_mask)

# Undo the delay pattern: keep only entries that are not the pad token and
# lay them out as (1, num_codebooks, sequence).
is_real_token = delay_masked_ids != generation_config._pad_token_tensor
audio_codes = delay_masked_ids[is_real_token].reshape(1, model.decoder.num_codebooks, -1)

# Add the leading frame dimension back to the audio codes.
output_ids = audio_codes.unsqueeze(0)

# Turn the codes back into a waveform with the model's audio encoder.
with torch.no_grad():
    decoded = model.audio_encoder.decode(output_ids, audio_scales=[None])
output_values = decoded.audio_values

In [None]:
# Do we get a similar output?

# Bring the decoded samples to the CPU, drop the two leading singleton
# dimensions, and play them at the model's sampling rate.
decoded_waveform = output_values.cpu().squeeze(0, 1)
display(Audio(decoded_waveform, rate=model.config.sampling_rate))