# Session 5: Representation Learning for Music

**Agenda:**

- Comparison of music representations
- Case Study: Encodec
- Hands On 1: Using Encodec
- Hands On 2: Encoding MIDI
- Hands On 3: Using HiFi-GAN

## Comparison of some representations

From [Comparing Representations for Audio Synthesis Using Generative Adversarial Networks (Nistal, Lattner, Richard, 2020)](https://arxiv.org/abs/2006.09266).

![](./assets/comparison_representations.png)

## Deep Dive into Encodec

Encodec is a *neural audio compression framework* developed by Meta that
compresses high-fidelity audio into a compact discrete representation.

![](./assets/encodec_diagram.png)

### Understanding Residual Vector Quantization (RVQ)

![](./assets/rvq.png)

A limitation of this method is that quantizing a single vector takes $N_q$
sequential steps, each of which requires a nearest-neighbour lookup over a whole
codebook. Quantizing an entire tensor this way therefore becomes very costly.
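
To make the cost concrete, here is a minimal NumPy sketch of residual vector quantization on a single vector. The toy codebooks (random, with shrinking scale) and the sizes `D`, `K` and `N_Q` are assumptions for illustration, not Encodec's trained codebooks; each toy codebook also contains the zero vector so that adding a stage can never increase the error. Each of the `N_Q` stages searches all `K` codewords for the one closest to the current residual, then subtracts it.

```python
import numpy as np

rng = np.random.default_rng(0)
D, K, N_Q = 8, 16, 4  # vector size, codebook size, number of quantizers (toy values)

# Toy codebooks with shrinking scale: later stages refine smaller residuals.
# Row 0 of each codebook is the zero vector, so a stage can never make the
# reconstruction worse (an assumption of this sketch, not of Encodec).
codebooks = []
for n in range(N_Q):
    cb = rng.normal(scale=0.5 ** n, size=(K, D))
    cb[0] = 0.0
    codebooks.append(cb)

def rvq_encode(x, codebooks):
    residual = x.copy()
    codes = []
    for cb in codebooks:
        # Nearest-neighbour search: K distance computations per stage
        idx = int(np.argmin(np.linalg.norm(cb - residual, axis=1)))
        codes.append(idx)
        residual = residual - cb[idx]  # the next stage quantizes what is left
    return codes

def rvq_decode(codes, codebooks):
    # The reconstruction is the sum of the selected codewords across stages
    return np.sum([cb[i] for i, cb in zip(codes, codebooks)], axis=0)

x = rng.normal(size=D)
codes = rvq_encode(x, codebooks)

# Reconstruction error when keeping only the first n stages (n = 1 .. N_Q)
errors = [float(np.linalg.norm(x - rvq_decode(codes[:n], codebooks[:n])))
          for n in range(1, N_Q + 1)]
print(codes)
print(errors)
```

Encoding is strictly sequential: stage n needs the residual left by stage n - 1, which is why quantizing every frame of a tensor this way is expensive. Decoding, by contrast, is just a sum of codeword lookups.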

### The Encodec speedup: RVQ through Transformers

![](./assets/rvq_transformer.png)

## Hands On 1: Using Encodec to encode and decode audio waveforms

We can directly use the `EncodecModel` provided by the `transformers` package.

In [None]:
import os
os.environ["HF_HUB_CACHE"] = os.path.abspath("../huggingface_hub_cache/")

from transformers import EncodecModel, EncodecFeatureExtractor

model = EncodecModel.from_pretrained("facebook/encodec_24khz")
processor = EncodecFeatureExtractor.from_pretrained("facebook/encodec_24khz")

In [None]:
from torchinfo import summary
import torch

# Again, let's get some info using torchinfo
summary(model, input_data=[torch.randn(1, 1, 320)])

In [None]:
# How do we get this downsampling factor value? We can just multiply all of the
# strides!
downsampling_factor = ...

print(f"Downsampling factor: {downsampling_factor}")
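
As a sketch of the arithmetic, assume the encoder strides reported in the Encodec paper for the 24 kHz model: 2, 4, 5 and 8 (in `transformers` we believe they are stored as `model.config.upsampling_ratios`, which you can compare against). Their product is the number of audio samples summarized by one latent frame, which matches the 320-sample input used in the torchinfo summary above.

```python
import math

# Assumed encoder strides for the 24 kHz Encodec model (from the Encodec paper)
strides = [2, 4, 5, 8]
sampling_rate = 24_000  # Hz

downsampling_factor = math.prod(strides)           # audio samples per latent frame
frame_rate = sampling_rate / downsampling_factor   # latent frames per second

print(f"Downsampling factor: {downsampling_factor}")  # 320
print(f"Frame rate: {frame_rate} Hz")                 # 75.0
```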

In [None]:
import librosa

# Let's load the audio file
y, sr = librosa.load("../session2_setup/assets/stargazing.wav", sr=processor.sampling_rate)

# Let's first run the audio through the processor
inputs = ...

# Let's select the lowest bandwidth
bandwidth = ...
print(f"Target bandwidth: {bandwidth} kbps")

# We can now pass the inputs to the model to get the encoder outputs
encoder_outputs = ...

# Let's get the shape of the encoder outputs
print(encoder_outputs.audio_codes.shape)

# Let's calculate the compression ratio
compression_ratio = ...
print(f"Compression ratio: {compression_ratio:.2f}")
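
The bandwidth, the number of codebooks and the compression ratio are tied together by simple arithmetic. A minimal sketch, assuming 75 latent frames per second, codebooks of 1024 entries (10 bits per index) and supported bandwidths of 1.5, 3, 6, 12 and 24 kbps (compare with `model.config.target_bandwidths` and `model.config.codebook_size`): each codebook adds 75 × 10 = 750 bits per second. The ratio here is measured against 16-bit mono PCM at 24 kHz (384 kbps); this is only one possible definition, and a ratio computed against the float32 array returned by `librosa` would be twice as large.

```python
import math

# Assumed values for facebook/encodec_24khz (check them against model.config)
sampling_rate = 24_000                     # Hz
frame_rate = 75                            # latent frames per second (24_000 / 320)
codebook_size = 1024                       # entries per codebook
bits_per_code = math.log2(codebook_size)   # 10 bits per index

def n_codebooks_for(bandwidth_kbps):
    # Each codebook contributes frame_rate * bits_per_code bits per second
    return int(bandwidth_kbps * 1000 // (frame_rate * bits_per_code))

pcm_kbps = sampling_rate * 16 / 1000  # 16-bit mono PCM bitrate
for bw in [1.5, 3.0, 6.0, 12.0, 24.0]:
    print(f"{bw:>5} kbps -> {n_codebooks_for(bw):>2} codebooks, "
          f"ratio vs 16-bit PCM: {pcm_kbps / bw:.0f}x")
```

Under these assumptions, the lowest bandwidth uses only 2 codebooks per frame, while 24 kbps uses 32.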

In [None]:
from IPython.display import Audio, display

# Let's decode those codes back to audio
audio_values = ...

# Squeeze away the batch and channel dimensions before playback
display(Audio(audio_values.squeeze().detach().numpy(), rate=processor.sampling_rate))

In [None]:
# Let's try it one more time with a higher bandwidth!
bandwidth_high = ...
print(f"Target bandwidth: {bandwidth_high} kbps")

encoder_outputs_high = ...
print(f"Encoder outputs shape: {encoder_outputs_high.audio_codes.shape}")

# Let's decode those codes back to audio
audio_values_high = ...

# Squeeze away the batch and channel dimensions before playback
display(Audio(audio_values_high.squeeze().detach().numpy(), rate=processor.sampling_rate))

In [None]:
# We can print some of the codes
print(encoder_outputs_high.audio_codes)

In [None]:
# Let's look inside the encoder: its output before quantization
pre_quantized = ...
print(f"Pre-quantized shape: {pre_quantized.shape}")
print(f"Hidden size of the encoder: {model.config.hidden_size}")

In [None]:
# Let's compute the codes chosen by the first codebook and print them
cb1 = ...
print(cb1)

In [None]:
# How far off are we? Let's decode the first code and compare it with our first vector
x_hat = ...
print(x_hat.shape)

# Let's calculate the Euclidean distance between the two vectors
dist = ...
print(f"Euclidean distance: {dist:.4f}")

In [None]:
# Is it the best we can do? Let's verify whether `quantize` gave us the best code
distances = ...

argmin = torch.argmin(torch.tensor(distances))

print(f"Minimum distance is {distances[argmin]:.4f} for code {argmin.item()}")

In [None]:
# Let's compute the codes of the SECOND codebook (applied to the residual) and print them
residual = ...

cb2 = ...
print(cb2)

In [None]:
# Let's check whether we improved the distance
x_hat_2 = ...
print(x_hat_2.shape)

# Let's calculate the Euclidean distance between the two vectors
dist = ...
print(f"Euclidean distance: {dist:.4f}")  # hopefully smaller than the previous distance

In [None]:
# Let's calculate the final distance
n_codebooks = ...
all_x_hats = ...

# Let's calculate the Euclidean distance between the two vectors
final_dist = ...
print(f"Euclidean distance: {final_dist:.4f}")  # hopefully smaller than all previous distances

In [None]:
import matplotlib.pyplot as plt

# Let's plot how distance decreases with higher bandwidth / more codebooks
distances = ...

target_codebooks = ...

plt.figure(figsize=(10, 4))

plt.show()

## Hands On 2: Encoding MIDI using the Anticipatory Music Transformer strategy
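
The Anticipatory Music Transformer (Thickstun et al., 2023) uses, as we understand it, an *arrival-time* encoding: each note becomes a triplet of tokens (onset time, duration, note), with times quantized to `TIME_RESOLUTION` ticks per second and instrument and pitch merged into a single note token (`instrument * 128 + pitch`). The three token types live in disjoint ranges of one vocabulary. The sketch below works on a plain list of `(onset_sec, offset_sec, pitch, program)` tuples; the function names `notes_to_tokens`/`tokens_to_notes`, the constants `MAX_TIME` and `MAX_DUR`, and the offsets are hypothetical choices for illustration, not the official AMT vocabulary. To build such a list with `mido`, note that iterating over a `MidiFile` yields messages whose `.time` is the delta in seconds since the previous message, and that a `note_on` with velocity 0 acts as a `note_off`.

```python
TIME_RESOLUTION = 100  # ticks per second, as in the cell below

# Hypothetical vocabulary layout (illustrative, not the official AMT constants)
MAX_TIME = 100 * TIME_RESOLUTION   # onsets up to 100 s
MAX_DUR = 10 * TIME_RESOLUTION     # durations up to 10 s
MAX_PITCH = 128
TIME_OFFSET = 0
DUR_OFFSET = TIME_OFFSET + MAX_TIME
NOTE_OFFSET = DUR_OFFSET + MAX_DUR

def notes_to_tokens(notes):
    """notes: list of (onset_sec, offset_sec, pitch, program) tuples."""
    tokens = []
    for onset, offset, pitch, program in sorted(notes):
        t = min(round(onset * TIME_RESOLUTION), MAX_TIME - 1)                 # absolute onset
        d = min(max(round((offset - onset) * TIME_RESOLUTION), 1), MAX_DUR - 1)
        note = program * MAX_PITCH + pitch  # instrument and pitch share one token
        tokens += [TIME_OFFSET + t, DUR_OFFSET + d, NOTE_OFFSET + note]
    return tokens

def tokens_to_notes(tokens):
    notes = []
    for i in range(0, len(tokens), 3):
        t = tokens[i] - TIME_OFFSET
        d = tokens[i + 1] - DUR_OFFSET
        program, pitch = divmod(tokens[i + 2] - NOTE_OFFSET, MAX_PITCH)
        notes.append((t / TIME_RESOLUTION, (t + d) / TIME_RESOLUTION, pitch, program))
    return notes

# Two made-up notes: middle C, then E, on an acoustic piano (program 0)
notes = [(0.0, 0.5, 60, 0), (0.5, 1.0, 64, 0)]
tokens = notes_to_tokens(notes)
print(tokens)                   # [0, 10050, 11060, 50, 10050, 11064]
print(tokens_to_notes(tokens))  # [(0.0, 0.5, 60, 0), (0.5, 1.0, 64, 0)]
```

Decoding inverts each step. A `tokens_to_midi` function would additionally convert seconds back to MIDI ticks (for example with `mido.bpm2tempo` and `mido.second2tick`) and emit `note_on`/`note_off` messages with delta times.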

In [None]:
import mido
from collections import defaultdict

# This is our time resolution (how many ticks per second)
TIME_RESOLUTION = 100

# Let's first load the MIDI file
midi = mido.MidiFile("../session2_setup/assets/symphony40.mid")

def midi_to_tokens(midi):
    ...
    pass

# We can call our function
tokens = midi_to_tokens(midi)
print(f"Length of sequence: {len(tokens)}")


In [None]:
def tokens_to_midi(tokens, bpm=120):
    ...
    pass

# We can call our function
reconstructed_midi = tokens_to_midi(tokens)
reconstructed_midi.save("assets/reconstructed_midi.mid")

In [None]:
import midi2audio

# We can listen to our reconstructed MIDI
midi2audio_obj = midi2audio.FluidSynth("../session2_setup/assets/soundfont.sf2")
midi2audio_obj.midi_to_audio("assets/reconstructed_midi.mid", "assets/reconstructed_midi.wav")

y, sr = librosa.load("assets/reconstructed_midi.wav")
display(Audio(y, rate=sr))

## Hands On 3: Using HiFi-GAN to synthesize audio from mel spectrograms

In [None]:
# Clone HiFi-GAN repository
!git clone https://github.com/jik876/hifi-gan.git ../repositories/hifi-gan

In [None]:
import huggingface_hub

# Download pretrained model from `lancelotblanchard/hifi_gan_vctk_v3`

model_path = huggingface_hub.snapshot_download(
    repo_id="lancelotblanchard/hifi_gan_vctk_v3",
    cache_dir="../huggingface_hub_cache",
)

In [None]:
# We add the path to the repository to the system path
import sys
sys.path.append("../repositories/hifi-gan")

import os
import json
from env import AttrDict

config_file = os.path.join(model_path, "config.json")
with open(config_file) as f:
    h = AttrDict(json.load(f))

In [None]:
# Let's print the config
print(json.dumps(h, indent=2))
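
HiFi-GAN turns every mel frame into `hop_size` audio samples through a stack of transposed convolutions, so the product of `upsample_rates` in the config should equal `hop_size`. A quick check, with values assumed from the official V3 configuration (the dictionary name `assumed_config` is hypothetical; compare the values with the config printed above):

```python
import math

# Values assumed from HiFi-GAN's V3 config; compare with the printed config
assumed_config = {"upsample_rates": [8, 8, 4], "hop_size": 256, "sampling_rate": 22050}

# Total upsampling factor of the generator = audio samples produced per mel frame
samples_per_frame = math.prod(assumed_config["upsample_rates"])
assert samples_per_frame == assumed_config["hop_size"], "each mel frame should map to hop_size samples"

frames_per_second = assumed_config["sampling_rate"] / assumed_config["hop_size"]
print(f"{samples_per_frame} samples per mel frame, {frames_per_second:.1f} mel frames per second")
```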

In [None]:
from models import Generator
import torch

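generator = Generator(h)
state_dict_g = torch.load(os.path.join(model_path, "generator_v3"), map_location="cpu")["generator"]
generator.load_state_dict(state_dict_g)

# As in HiFi-GAN's own inference script: switch to eval mode and fold the weight
# normalization into the weights, since we only run inference
generator.eval()
generator.remove_weight_norm()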

In [None]:
import librosa
from IPython.display import Audio, display

# Let's grab an audio file, resampled to the rate the vocoder was trained at
y, sr = librosa.load("../session2_setup/assets/stargazing.wav", sr=h["sampling_rate"])
display(Audio(y, rate=sr))

In [None]:
audio = torch.from_numpy(y).reshape(1, -1)

n_fft = h["n_fft"]
num_mels = h["num_mels"]
fmin = h["fmin"]
fmax = h["fmax"]
win_size = h["win_size"]
hop_size = h["hop_size"]

# we use librosa's mel() function to create the mel filterbank
mel = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel = torch.from_numpy(mel).float()

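# We create a Hann window of size win_size.
# The window tapers each frame towards zero at its edges, which reduces spectral
# leakage in the STFT.
hann_window = torch.hann_window(win_size)

# We reflect-pad the audio by (n_fft - hop_size) / 2 on each side so that, with
# center=False, the STFT produces about len(audio) / hop_size frames: one mel frame
# per hop_size samples, which the generator then upsamples back to audio.
# Reflect padding expects a (batch, channels, time) tensor, hence the unsqueeze.
audio = torch.nn.functional.pad(audio.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect")
audio = audio.squeeze(1)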

# We compute the STFT of the audio signal
spec = torch.view_as_real(torch.stft(audio, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window,
                  center=False, pad_mode="reflect", normalized=False, onesided=True, return_complex=True))

# We compute the magnitude of the STFT
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

# Finally, we compute the mel spectrogram by applying the mel filterbank to the
# magnitude of the STFT.
spec = torch.matmul(mel, spec)

# We apply log dynamic-range compression, clamping first to avoid log(0), as in
# HiFi-GAN's training data pipeline
spec = torch.log(torch.clamp(spec, min=1e-5))
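
The padding amount can be checked with a little arithmetic. With `center=False`, an STFT over a signal of length L' produces 1 + ⌊(L' - n_fft) / hop⌋ frames. Padding by (n_fft - hop) / 2 on each side gives L' = L + n_fft - hop, hence ⌊L / hop⌋ frames, so the generator's output of `frames * hop_size` samples lands within one hop of the original length. A pure-Python check, assuming `n_fft = 1024` and `hop_size = 256` (substitute the values from the config); `n_stft_frames` is a hypothetical helper:

```python
def n_stft_frames(length, n_fft, hop):
    # Number of frames of an STFT computed with center=False
    return 1 + (length - n_fft) // hop

n_fft, hop_size = 1024, 256     # assumed values
pad = (n_fft - hop_size) // 2   # reflect padding added on each side

for length in [22050, 44100, 100_000]:
    frames = n_stft_frames(length + 2 * pad, n_fft, hop_size)
    print(f"{length} samples -> {frames} frames -> {frames * hop_size} generated samples")
```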

In [None]:
# Importing HiFi-GAN's modules switches matplotlib to the non-interactive "Agg"
# backend (its utils.py calls matplotlib.use("Agg")), so we switch back to the
# inline backend before plotting
%matplotlib inline

import matplotlib.pyplot as plt


# Let's plot the mel spectrogram
plt.figure(figsize=(10, 4))
plt.imshow(spec[0].detach().numpy(), aspect="auto", origin="lower")
plt.title("Mel spectrogram")
plt.xlabel("Time (frames)")
plt.ylabel("Mel band")
plt.colorbar()
plt.show()

In [None]:
# Run the vocoder on the mel spectrogram without tracking gradients
with torch.no_grad():
    y_g_hat = generator(spec)

audio2 = y_g_hat.squeeze()

display(Audio(audio2.numpy(), rate=h["sampling_rate"]))