# Session 5: Representation Learning for Music

**Agenda:**

- Comparison of music representations
- Case Study: Encodec
- Hands On 1: Using Encodec
- Hands On 2: Encoding MIDI
- Hands On 3: Using HiFi-GAN

## Comparison of some representations

From [Comparing Representations for Audio Synthesis Using Generative Adversarial Networks (Nistal, Lattner, Richard, 2020)](https://arxiv.org/abs/2006.09266).

![](./assets/comparison_representations.png)

## Deep Dive into Encodec

Encodec is a *neural audio compression framework* developed by Meta that
enables the efficient compression of high-fidelity audio into a compact discrete
representation.

![](./assets/encodec_diagram.png)

### Understanding Residual Vector Quantization (RVQ)

![](./assets/rvq.png)

A limitation of this method is that it takes $N_q$ sequential steps (each with
many codebook lookups!) to quantize a single vector, so quantizing an entire
tensor becomes very costly.

### The Encodec speedup: RVQ through Transformers

![](./assets/rvq_transformer.png)

## Hands On 1: Using Encodec to encode and decode audio waveforms

We can directly use the `EncodecModel` provided by the `transformers` package.

In [None]:
import os
# Point the Hugging Face Hub cache at a folder next to the notebook. The
# variable is set before `transformers` is imported below.
os.environ["HF_HUB_CACHE"] = os.path.abspath("../huggingface_hub_cache/")

from transformers import EncodecModel, EncodecFeatureExtractor

# Load the pretrained `facebook/encodec_24khz` checkpoint together with its
# matching feature extractor (used later to prepare raw audio for the model).
model = EncodecModel.from_pretrained("facebook/encodec_24khz")
processor = EncodecFeatureExtractor.from_pretrained("facebook/encodec_24khz")

In [None]:
from torchinfo import summary
import torch

# Again, let's get some info using torchinfo.
# We feed a random tensor of shape (1, 1, 320) so torchinfo can trace the layers
# (presumably (batch, channels, samples) — TODO confirm).
summary(model, input_data=[torch.randn(1, 1, 320)])

In [None]:
# Where does the downsampling factor come from? It is the product of the
# strides of every encoder layer that exposes a `stride` attribute.
encoder_strides = [
    layer.stride.item()
    for layer in model.encoder.layers
    if hasattr(layer, "stride")
]

downsampling_factor = 1
for stride in encoder_strides:
    downsampling_factor *= stride

print(f"Downsampling factor: {downsampling_factor}")

In [None]:
import librosa

# Let's load the audio file, resampled to the sampling rate the processor expects
y, sr = librosa.load("../session2_setup/assets/stargazing.wav", sr=processor.sampling_rate)

# Let's first go through the processor. The result is indexed below by
# "input_values" and "padding_mask".
inputs=processor(raw_audio=y, sampling_rate=sr,return_tensors="pt")

# Let's select the lowest bandwidth (first entry of the configured targets)
bandwidth = model.config.target_bandwidths[0]
print(f"Target bandwidth: {bandwidth} kbps")

# We can now pass the inputs to the model to get the encoder outputs
encoder_outputs = model.encode(inputs["input_values"], inputs["padding_mask"], bandwidth)

# Let's get the shape of the encoder outputs
print(encoder_outputs.audio_codes.shape)

# Let's calculate the compression ratio.
# NOTE(review): this divides the number of input samples by the size of the
# last axis of `audio_codes` (presumably the number of code frames — confirm),
# i.e. it measures samples per frame and does not account for how many
# codebooks are stored per frame.
compression_ratio = y.shape[0] / encoder_outputs.audio_codes.shape[-1]
print(f"Compression ratio: {compression_ratio:.2f}")

In [None]:
from IPython.display import Audio, display

# Let's decode those codes back to audio. We take the first element of the
# decoder output and squeeze away axes 0 and 1 to get a flat waveform
# (assumes both axes have size 1 here — TODO confirm).
audio_values = model.decode(encoder_outputs.audio_codes, encoder_outputs.audio_scales, inputs["padding_mask"])[0].squeeze(0, 1)

# `.detach()` is required because the tensor was produced with autograd enabled
display(Audio(audio_values.detach().numpy(), rate=processor.sampling_rate))

In [None]:
# Let's try it one more time with a higher bandwidth!
# We pick the last entry of the configured target bandwidths.
bandwidth_high = model.config.target_bandwidths[-1]
print(f"Target bandwidth: {bandwidth_high} kbps")

# Same inputs as before, only the target bandwidth changes
encoder_outputs_high = model.encode(inputs["input_values"], inputs["padding_mask"], bandwidth_high)
print(f"Encoder outputs shape: {encoder_outputs_high.audio_codes.shape}")

# Let's decode those codes back to audio
audio_values_high = model.decode(encoder_outputs_high.audio_codes, encoder_outputs_high.audio_scales, inputs["padding_mask"])[0].squeeze(0,1)

display(Audio(audio_values_high.detach().numpy(), rate=processor.sampling_rate))

In [None]:
# We can print the codes obtained at the highest bandwidth
# (PyTorch abbreviates large tensors when printing)
print(encoder_outputs_high.audio_codes)

In [None]:
# Let's dive a bit through the encoder (before the quantization).
# The input is multiplied by the padding mask first; `unsqueeze(1)` adds an axis
# to the mask so it can broadcast against the input
# (presumably the channel axis — TODO confirm).
pre_quantized = model.encoder(inputs["input_values"]*inputs["padding_mask"].unsqueeze(1))

print(f"Pre-quantized shape: {pre_quantized.shape}")
print(f"Hidden size of the encoder: {model.config.hidden_size}")

print(pre_quantized)

In [None]:
# Let's calculate the outputs of the first codebook and print them.
# `permute(0, 2, 1)` swaps the last two axes of the encoder output and the
# reshape flattens it into rows of `hidden_size` values, so that each row is
# one vector handed to the codebook's `quantize`.
cb1 = model.quantizer.layers[0].codebook.quantize(pre_quantized.permute(0,2,1).reshape(-1, model.config.hidden_size))
print(cb1)
# Tip: print `.shape` of the permuted and reshaped tensors to follow the
# layout changes step by step.

In [None]:
# How far away are we? Let's decode the first code and compare to our first vector.
# Decoding a code is a plain table lookup into the first codebook's `embed` matrix.
x_hat = torch.nn.functional.embedding(cb1[0], model.quantizer.layers[0].codebook.embed)
print(x_hat.shape)

# Let's calculate the Euclidean distance between the two vectors
# (`pre_quantized[0, :, 0]` selects index 0 on the first and last axes)
dist = torch.norm(pre_quantized[0, :, 0] - x_hat)
print(f"Euclidean distance: {dist:.4f}")

In [None]:
# Is it the best we can do? Let's verify whether `quantize` gave us the best code
# by brute force: measure the distance between our first vector and every entry
# of the first codebook, then keep the smallest one.
#
# NOTE: the loop deliberately uses its own names (`candidate`,
# `candidate_dist`) instead of `x_hat` / `dist`. Those are notebook globals
# holding the vector selected by `quantize`; reassigning them here would
# leave them pointing at the LAST codebook entry for any cell that reads them
# afterwards.
first_codebook = model.quantizer.layers[0].codebook.embed
target_vector = pre_quantized[0, :, 0]

distances = []
for code in range(model.config.codebook_size):
    candidate = torch.nn.functional.embedding(torch.tensor(code), first_codebook)
    candidate_dist = torch.norm(target_vector - candidate)
    distances.append(candidate_dist.item())

argmin = torch.argmin(torch.tensor(distances))

print(f"Minimum distance is {distances[argmin]:.4f} for code {argmin.item()}")

In [None]:
# Let's calculate the outputs of the SECOND codebook, and print them.
# The second codebook quantizes the residual: what is left of the encoder output
# once the vectors selected by the first codebook have been subtracted.
# (`cb1.unsqueeze(0)` adds a leading axis and `permute(0, 2, 1)` swaps the last
# two axes so the decoded vectors line up with `pre_quantized`.)
residual = pre_quantized - (torch.nn.functional.embedding(cb1.unsqueeze(0), model.quantizer.layers[0].codebook.embed).permute(0,2,1))

cb2 = model.quantizer.layers[1].codebook.quantize(residual.permute(0,2,1).reshape(-1,model.config.hidden_size))
print(cb2)

In [None]:
# Let's check whether we improved the distance.
# We decode the first frame's code from BOTH codebooks in this cell. The
# first-codebook vector is recomputed locally (as `x_hat1`) instead of relying
# on the notebook-global `x_hat`, which other cells may have reassigned.
x_hat1 = torch.nn.functional.embedding(cb1[0], model.quantizer.layers[0].codebook.embed)
x_hat2 = torch.nn.functional.embedding(cb2[0], model.quantizer.layers[1].codebook.embed)
print(x_hat2.shape)


# Let's calculate the Euclidean distance between our first vector and the sum of
# the two decoded vectors (first-codebook approximation + quantized residual)
dist = torch.norm(pre_quantized[0,:,0] - (x_hat1 + x_hat2))
print(f"Euclidean distance: {dist:.4f}") # hope that this is smaller than the previous one

In [None]:
# Final reconstruction error at the highest bandwidth: decode the first frame's
# code from every codebook in use, then sum the decoded vectors.
n_codebooks = model.quantizer.get_num_quantizers_for_bandwidth(bandwidth_high)

# One row per codebook, each row holding that codebook's decoded vector
all_xhats = torch.zeros(n_codebooks, model.config.hidden_size)

for i in range(n_codebooks):
    code = encoder_outputs_high.audio_codes[0,0,i,0]
    codebook = model.quantizer.layers[i].codebook.embed
    all_xhats[i] = torch.nn.functional.embedding(code, codebook)

# Euclidean distance between the encoder output and the summed approximation
reconstruction = torch.sum(all_xhats, dim=0)
final_dist = torch.norm(pre_quantized[0,:,0] - reconstruction)
print(f"Euclidean distance: {final_dist:.4f}") # hope that this is smaller than the previous one

In [None]:
import matplotlib.pyplot as plt

# Let's plot how distance decreases with higher bandwidth / more codebooks.
# Entry k is the distance obtained when summing the first k+1 decoded vectors.
distances = [
    torch.norm(pre_quantized[0,:,0] - torch.sum(all_xhats[:k+1], dim=0)).item()
    for k in range(n_codebooks)
]

# Number of codebooks used at each configured target bandwidth (dashed markers)
target_codebooks = [model.quantizer.get_num_quantizers_for_bandwidth(n) for n in model.config.target_bandwidths]

plt.figure(figsize=(10, 4))
plt.plot(range(1, n_codebooks + 1), distances)
plt.vlines(target_codebooks, ymin=0, ymax=max(distances), color="grey", linestyles="dashed")
plt.title("Euclidean Distances Based on Number of Codebooks")
plt.xlabel("Number of Codebooks")
plt.ylabel("Euclidean Distance")

plt.show()

## Hands On 2: Encoding MIDI using the Anticipatory Music Transformer strategy

In [None]:
import mido
from collections import defaultdict

# This is our time resolution (how many ticks per second).
# Onset times and durations are multiplied by this value and rounded to integers.
TIME_RESOLUTION = 100

# Let's first load the MIDI file
midi = mido.MidiFile("../session2_setup/assets/symphony40.mid")

def midi_to_tokens(midi, time_resolution=None):
    """Flatten a MIDI message stream into a list of integer note tokens.

    Every note contributes five consecutive integers:
    ``(onset, duration, note, instrument, velocity)``, with ``onset`` and
    ``duration`` expressed in ticks of ``1 / time_resolution`` seconds.

    Args:
        midi: Iterable of messages exposing ``type`` and ``time`` (delta time
            added to a running clock), plus ``channel`` / ``note`` /
            ``velocity`` / ``program`` on the message types that carry them,
            e.g. a ``mido.MidiFile``.
        time_resolution: Ticks per second. Defaults to the module-level
            ``TIME_RESOLUTION``.

    Returns:
        A flat list of ints whose length is five times the number of notes.
    """
    if time_resolution is None:
        time_resolution = TIME_RESOLUTION

    clock = 0
    tokens = []
    note_idx = 0
    # (channel, note) -> FIFO list of (index of the note in `tokens`, onset time).
    # The program is deliberately NOT part of the key: a program change between
    # a Note On and its Note Off must not prevent them from being matched.
    open_notes = defaultdict(list)
    # channel -> current program. A channel that never received a program change
    # plays program 0, the MIDI default. (An int default also keeps the token
    # list purely numeric.)
    instruments = defaultdict(int)

    for message in midi:
        # Advance the global clock by the message's delta time
        clock += message.time

        if message.type == "program_change":
            instruments[message.channel] = message.program

        elif message.type == "note_on" and message.velocity > 0:
            # Add (time, duration, note, instrument, velocity) to tokens
            tokens.append(round(time_resolution * clock))
            tokens.append(-1)  # placeholder, filled in when the note is closed
            tokens.append(message.note)
            tokens.append(instruments[message.channel])
            tokens.append(message.velocity)

            open_notes[(message.channel, message.note)].append((note_idx, clock))
            note_idx += 1

        elif message.type in ("note_on", "note_off"):
            # Either an explicit Note Off or a Note On with velocity 0, which
            # MIDI defines as equivalent to a Note Off.
            pending = open_notes[(message.channel, message.note)]
            if not pending:
                print("WARNING: Note off before note on!")
                continue
            open_idx, onset_time = pending.pop(0)
            tokens[5 * open_idx + 1] = round(time_resolution * (clock - onset_time))

    # Notes that never received a Note Off are closed at the end of the stream
    # so that no -1 placeholder leaks into the returned tokens.
    unclosed = [entry for pending in open_notes.values() for entry in pending]
    if unclosed:
        print(f"WARNING: {len(unclosed)} unclosed notes")
        for open_idx, onset_time in unclosed:
            tokens[5 * open_idx + 1] = round(time_resolution * (clock - onset_time))

    return tokens

# We can call our function on the MIDI file loaded above.
# The returned list holds five integers per note.
tokens = midi_to_tokens(midi)
print(f"Length of sequence: {len(tokens)}")


In [None]:
# Inspect the raw token sequence: (time, duration, note, instrument, velocity) repeated
print(tokens)

In [None]:
def tokens_to_midi(tokens, bpm=120, time_resolution=None):
    """Rebuild a ``mido.MidiFile`` from a flat note-token sequence.

    Args:
        tokens: Flat list of ints, five per note:
            ``(onset, duration, note, instrument, velocity)``, with onset and
            duration in ticks of ``1 / time_resolution`` seconds.
        bpm: Tempo used to derive ``ticks_per_beat``.
        time_resolution: Ticks per second of the tokens. Defaults to the
            module-level ``TIME_RESOLUTION``.

    Returns:
        A ``mido.MidiFile`` with one track per instrument.
    """
    if time_resolution is None:
        time_resolution = TIME_RESOLUTION

    midi = mido.MidiFile()
    midi.ticks_per_beat = (60 * time_resolution) // bpm
    # Tempo in microseconds per beat, derived from the (integer) ticks_per_beat so
    # that one MIDI tick lasts exactly 1 / time_resolution seconds. Without an
    # explicit tempo event players assume 120 bpm, which is only right when
    # `bpm` is 120.
    tempo = round(1_000_000 * midi.ticks_per_beat / time_resolution)

    # Channels usable for pitched instruments: 0-15 except 9, which General MIDI
    # reserves for percussion.
    melodic_channels = [channel for channel in range(16) if channel != 9]

    # Map (time_in_ticks, event_type) to a list of (note, instrument, velocity),
    # with event_type=0 for Note On and event_type=1 for Note Off.
    time_idx = defaultdict(list)
    for (time_in_ticks, duration_in_ticks, note, instrument, velocity) in zip(
        tokens[0::5], tokens[1::5], tokens[2::5], tokens[3::5], tokens[4::5]
    ):
        # A negative duration (e.g. a -1 "unknown" placeholder) would put the
        # Note Off before its Note On and yield a negative delta time.
        duration_in_ticks = max(duration_in_ticks, 0)
        time_idx[(time_in_ticks, 0)].append((note, instrument, velocity))
        time_idx[(time_in_ticks + duration_in_ticks, 1)].append((note, instrument, 0))

    # track_idx maps instruments to (track, previous_time, channel)
    track_idx = {}
    num_tracks = 0

    # Sorting the keys orders events by time, Note On before Note Off on ties
    for time_in_ticks, event_type in sorted(time_idx):
        for (note, instrument, velocity) in time_idx[(time_in_ticks, event_type)]:
            if event_type == 0:
                if instrument not in track_idx:
                    # First note of this instrument: give it its own track
                    track = mido.MidiTrack()
                    midi.tracks.append(track)

                    if num_tracks == 0:
                        # Tempo events belong at the start of the first track
                        track.append(mido.MetaMessage("set_tempo", tempo=tempo, time=0))
                    if num_tracks == len(melodic_channels):
                        print(f"WARNING: more than {len(melodic_channels)} instruments, channels will be shared")

                    channel = melodic_channels[num_tracks % len(melodic_channels)]
                    track.append(mido.Message("program_change", channel=channel, program=instrument))
                    track_idx[instrument] = (track, 0, channel)
                    num_tracks += 1

                track, previous_time, channel = track_idx[instrument]
                track.append(mido.Message("note_on", note=note, channel=channel, velocity=velocity, time=time_in_ticks - previous_time))

            else:
                if instrument not in track_idx:
                    # No Note On was ever emitted for this instrument
                    print(f"WARNING: Note Off for note {note} and instrument {instrument} before Note On")
                    continue

                track, previous_time, channel = track_idx[instrument]
                track.append(mido.Message("note_off", note=note, channel=channel, time=time_in_ticks - previous_time))

            # Delta times are relative to the previous event on the same track
            track_idx[instrument] = (track, time_in_ticks, channel)

    return midi

# We can call our function (with the default bpm) and write the result to disk.
# NOTE(review): the path is relative to the working directory; presumably the
# `assets/` folder already exists — confirm.
reconstructed_midi = tokens_to_midi(tokens)
reconstructed_midi.save("assets/reconstructed_midi.mid")

In [None]:
import midi2audio

# We can listen to our reconstructed MIDI.
# First render the MIDI file to a WAV file with FluidSynth and the given soundfont.
midi2audio_obj = midi2audio.FluidSynth("../session2_setup/assets/soundfont.sf2")
midi2audio_obj.midi_to_audio("assets/reconstructed_midi.mid", "assets/reconstructed_midi.wav")

# Then load the rendered audio and play it. `librosa`, `display` and `Audio`
# come from imports made in earlier cells. Note that this reassigns the
# notebook-level `y` and `sr`.
y, sr = librosa.load("assets/reconstructed_midi.wav")
display(Audio(y, rate=sr))

## Hands On 3: Using HiFi-GAN to synthesize audio from mel spectrograms

In [None]:
# Clone HiFi-GAN repository into ../repositories/hifi-gan (IPython shell command).
# NOTE(review): re-running this cell once the folder exists will presumably make
# `git clone` report an error; that is harmless if the first clone succeeded.
!git clone https://github.com/jik876/hifi-gan.git ../repositories/hifi-gan

In [None]:
import huggingface_hub

# Download pretrained model from `lancelotblanchard/hifi_gan_vctk_v3`.
# `model_path` is the returned local path; the config and generator checkpoint
# are read from it in the next cells.

model_path = huggingface_hub.snapshot_download(
    repo_id="lancelotblanchard/hifi_gan_vctk_v3",
    cache_dir="../huggingface_hub_cache",
)

In [None]:
# We add the path to the repository to the system path so that its modules
# (`env` here, `models` later) can be imported directly.
import sys
sys.path.append("../repositories/hifi-gan")

import os
import json
from env import AttrDict

# Load the model hyper-parameters shipped with the checkpoint and wrap them in
# the repository's `AttrDict`
config_file = os.path.join(model_path, "config.json")
with open(config_file) as f:
    h = AttrDict(json.load(f))

In [None]:
# Let's print the config (pretty-printed as JSON)
print(json.dumps(h, indent=2))

In [None]:
from models import Generator
import torch

# Build the generator from the config, then load the pretrained weights.
# The checkpoint file is a dict whose "generator" entry is the state dict;
# `map_location="cpu"` loads the tensors onto the CPU.
generator = Generator(h)
state_dict_g = torch.load(os.path.join(model_path, "generator_v3"), map_location="cpu")["generator"]
generator.load_state_dict(state_dict_g)

In [None]:
import librosa
from IPython.display import Audio, display

# Let's grab an audio file. No `sr` is passed, so librosa uses its default
# sampling rate and returns it as `sr`.
y, sr = librosa.load("../session2_setup/assets/stargazing.wav")
display(Audio(y, rate=sr))

In [None]:
# Waveform as a (1, samples) tensor
audio = torch.from_numpy(y).reshape(1, -1)

# Analysis parameters come from the model config
n_fft = h["n_fft"]
num_mels = h["num_mels"]
fmin = h["fmin"]
fmax = h["fmax"]
win_size = h["win_size"]
hop_size = h["hop_size"]

# we use librosa's mel() function to create the mel filterbank
mel = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel = torch.from_numpy(mel).float()

# We create a Hann window of size win_size.
# The Hann window is used to smooth the signal before applying the STFT.
hann_window = torch.hann_window(win_size)

# The STFT below runs with center=False, so we reflect-pad the signal ourselves
# by (n_fft - hop_size) / 2 samples on each side. A channel axis is added for
# the padding call (reflect padding of the last axis is defined for
# (batch, channel, samples) input) and removed right after.
pad_amount = (n_fft - hop_size) // 2
audio = torch.nn.functional.pad(audio.unsqueeze(1), (pad_amount, pad_amount), mode="reflect")
audio = audio.squeeze(1)

# We compute the STFT of the audio signal (real and imaginary parts on the last axis)
spec = torch.view_as_real(torch.stft(audio, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window,
                  center=False, pad_mode="reflect", normalized=False, onesided=True, return_complex=True))

# We compute the magnitude of the STFT (the small constant keeps sqrt away from 0)
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

# Finally, we compute the mel spectrogram by applying the mel filterbank to the
# magnitude of the STFT.
spec = torch.matmul(mel, spec)

# We compress the dynamic range by taking the log of the mel spectrogram.
# Values are clamped to at least 1e-5 first so the log is always finite
# (the result is NOT restricted to [0, 1]).
spec = torch.log(torch.clamp(spec, min=1e-5) * 1)

In [None]:
# We need to reset the matplotlib backend to use the inline backend
# (presumably because an earlier import switched it to a non-interactive
# backend — TODO confirm). The module is reloaded and the backend set again.
import matplotlib
from importlib import reload
matplotlib = reload(matplotlib)
matplotlib.use("inline")

import matplotlib.pyplot as plt


# Let's plot the mel spectrogram.
# `spec[0]` selects the first (only) item of the batch; `origin="lower"` puts
# the first mel band at the bottom of the image.
plt.figure(figsize=(10, 4))
plt.imshow(spec[0].detach().numpy(), aspect="auto", origin="lower")
plt.title("Mel spectrogram")
plt.xlabel("Time")
plt.ylabel("Frequency")
plt.show()

In [None]:
# Run the generator on the mel spectrogram to synthesize a waveform.
# This is pure inference, so autograd is switched off to avoid building (and
# keeping in memory) a computation graph for the whole signal.
with torch.no_grad():
    y_g_hat = generator(spec)

# Drop the singleton axes to get a flat waveform
audio2 = y_g_hat.squeeze()

# Play back at `sr`, the sampling rate used to load the audio and to build the
# mel filterbank, rather than a hard-coded number.
display(Audio(audio2.detach().numpy(), rate=sr))