In [None]:
import numpy as np
import torchaudio
import torch
import matplotlib.pyplot as plt
%matplotlib inline

# Make sure to have torch installed and numpy<2

In [None]:
#run only if on colab
# to download a file from google drive, share it and get the id from the share link (be sure that you can share it with the link)
!pip install encodec
!pip install numpy<2
!pip install torchcodec
!gdown 1Rj84epzYkfj00-B79nRdU4h3fquwUBsS
!ls

In [None]:
# Download one more shared Google Drive file by its id with gdown.
# NOTE(review): the id alone does not say which asset this is — confirm and name it here.
!gdown 1EM4F6Rp2nf7E79toGGHQ68qgprYFa9VC

In [None]:
from encodec import EncodecModel   # `pip install encodec` if unavailable

# Build the 48 kHz EnCodec variant. The object is kept in the global `model`, which
# the following cells use for encoding and for slicing the decoder.
model = EncodecModel.encodec_model_48khz()  # will download and cache the weights when run for the first time
# Bare expression so the notebook displays the model's repr; the
# `model.decoder.model[...]` slices used further down index into this structure.
model

In [None]:
# Use the GPU when CUDA is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Two identical power spectrograms (8192-point FFT, power=2): one left on the CPU,
# one moved to `device` so tensors already on the GPU need no round trip.
stft_transformer = torchaudio.transforms.Spectrogram(n_fft=2 ** 13, power=2)
stft_transformer_device = torchaudio.transforms.Spectrogram(n_fft=2 ** 13, power=2).to(device)
def compute_fp_stft(p):
    """Return the dB power spectrum of `p`, averaged over axis 2 of the spectrogram.

    `p` is converted with `torch.Tensor(p)` and passed through the CPU transform
    `stft_transformer`. The power values are clipped to [1e-10, 1e6] before the
    log so silent bins do not produce -inf. For a [channels, time] input the
    spectrogram is [channels, freq, frames], so the result is one averaged
    spectrum per channel (NumPy array).
    """
    power_spec = stft_transformer(torch.Tensor(p)).numpy()
    power_db = 10 * np.log10(np.clip(power_spec, 1e-10, 1e6))
    return power_db.mean(axis=2)

def compute_stft_full_cpu(p):
    """Return the full dB power spectrogram of `p` as a NumPy array (CPU path).

    `p` is converted with `torch.Tensor(p)` and run through the CPU transform
    `stft_transformer`; power is clamped to [1e-10, 1e6] before the log. No
    averaging is done here — callers reduce over the frame axis themselves.
    """
    power_spec = stft_transformer(torch.Tensor(p))
    clamped = torch.clamp(power_spec, 1e-10, 1e6)
    return (10 * torch.log10(clamped)).numpy()

def compute_stft_full_device(p):
    """Return the full dB power spectrogram of `p` as a torch tensor.

    `p` is fed as-is to `stft_transformer_device`, so it is expected to already
    be a tensor on `device`. Power is clamped to [1e-10, 1e6] before the log.
    The result stays a tensor; callers move it to the CPU / NumPy when needed.
    """
    power_spec = stft_transformer_device(p)
    clamped = torch.clamp(power_spec, 1e-10, 1e6)
    return 10 * torch.log10(clamped)

In [None]:
audio_path = "Fair-L - War Song - Blood and Rose.mp3"

# torchaudio returns [channels, time] by default; add a batch dimension to get
# the [1, channels, time] tensor that is handed to model.encode().
# NOTE(review): the clip is not resampled or re-channelled for the model —
# presumably the file is already 48 kHz stereo; confirm.
audio, sr = torchaudio.load(audio_path)
audio = audio.unsqueeze(0).to(device)

# The model has to live on the same device as its input. Without this, a fresh
# kernel on a CUDA machine fails in encode() with a device mismatch, because the
# model is created on the CPU and is only moved to `device` in a later cell.
model.to(device)

# Spectrogram of the raw input (batch dimension dropped), kept for the first plot panel.
fft_audio_input = compute_stft_full_device(audio[0])

# Inference only: no autograd graph is needed for encoding or partial decoding.
with torch.no_grad():
    frames = model.encode(audio)

    # we ignore the output scaling and linear interpolation between frames (1 sec audio) and directly decode everything
    # Each frame is indexed with [0] to take its codes; they are joined along the last axis.
    codes = torch.concatenate([a[0] for a in frames], -1)
    emb = model.quantizer.decode(codes.transpose(0, 1))
    # Run only the first two decoder modules; the remaining ones are probed later.
    post_lstm = model.decoder.model[:2](emb)

audio_latent = post_lstm

In [None]:
# Display the shape of the latent produced by the first two decoder modules.
audio_latent.shape

In [None]:
# A single code path serves both back-ends: `device` already falls back to the CPU
# when CUDA is unavailable and `stft_transformer_device` lives on `device`, so the
# `.to(device)` / `.cpu()` calls below are simply no-ops on a CPU-only machine.
# (Running on a GPU is much faster: roughly 2 min on CPU vs 5 s on GPU.)
print(f"Using device: {device}")

# Decoder stages to probe. "l" is the exclusive end index of the decoder slice
# `model.decoder.model[2:l]`; "stride" is the upsampling stride named in the title
# and is used later to place the expected-peak markers.
layers = [
    {"l": 4, "stride": 8, "title": "(3): ConvTranspose(stride=8)"},
    {"l": 7, "stride": 5, "title": "(6): ConvTranspose(stride=5)"},
    {"l": 10, "stride": 4, "title": "(9): ConvTranspose(stride=4)"},
    {"l": 13, "stride": 2, "title": "(12): ConvTranspose(stride=2)"},
]

model.to(device)
audio_latent_device = audio_latent.to(device)

# Spectrum of the latent itself (batch dimension dropped), averaged over frames.
input_stft = compute_stft_full_device(audio_latent_device[0]).cpu().numpy()
mean_input_fft = np.mean(input_stft, axis=2)

ffts = []      # frame-averaged dB spectrum after each probed stage
latents = []   # raw activations after each probed stage (kept on `device`)
with torch.no_grad():
    for layer in layers:
        # Each stage is recomputed from the latent up to (not including) index "l".
        latent = model.decoder.model[2:layer["l"]](audio_latent_device)
        print(f"l: {layer['l']}, latent.shape: {latent.shape}")
        latents.append(latent)
        stft_latent = compute_stft_full_device(latent[0]).cpu().numpy()
        ffts.append(np.mean(stft_latent, axis=2))

    # Full remaining decoder -> reconstructed waveform.
    audio_rec = model.decoder.model[2:](audio_latent_device)

output_stft = compute_stft_full_device(audio_rec[0]).cpu().numpy()
output_fft = np.mean(output_stft, axis=2)

print(f"{'GPU' if device.type == 'cuda' else 'CPU'} processing complete!")



In [None]:
def plot_fractal(k=8, N=4097, y_offset=0):
    """ Display where the peaks are supposed to be.

    Draws magenta crosses on the current axes at t/k of the sampling rate for
    t = 0 .. k//2, i.e. up to the Nyquist frequency.

    k        -- upsampling stride of the layer being inspected.
    N        -- number of frequency bins on the x axis (n_fft // 2 + 1).
    y_offset -- vertical position of the markers.
    """
    # With N bins the bin indices run 0 .. N-1, so Nyquist sits at N-1 (4096 for
    # the 8192-point FFT) and the sampling rate corresponds to 2*(N-1) bins.
    # Using N itself would push the markers up to one bin too far right.
    nyquist_bin = N - 1
    x = [t * 2 * nyquist_bin / k for t in range(k // 2 + 1)]
    y = [y_offset] * len(x)
    plt.scatter(x, y, c="magenta", marker="x", linewidth=2, clip_on=False)

# Seven stacked panels: input audio, latent, the four probed decoder stages,
# and the reconstructed audio. Every curve is a dB spectrum averaged over frames
# and then over channels.
plt.figure(figsize=(8, 20))

# Panel 1: spectrum of the input waveform.
plt.subplot(7, 1, 1)
fft_audio_input_mean = fft_audio_input.cpu().numpy().mean(axis=2)
plt.plot(fft_audio_input_mean.mean(axis=0), color="red")
plt.title("input signal spectrum", fontsize=10)
plt.ylabel("amplitude (dB)", fontsize=8)
plt.xticks([0, 2048, 4096], ["0Hz", "12kHz", "24kHz"], fontsize=8)
plt.yticks([], [])

# Panel 2: spectrum of the latent fed to the rest of the decoder.
plt.subplot(7, 1, 2)
plt.plot(mean_input_fft.mean(axis=0), color="red")
plt.title("encoded signal spectrum", fontsize=10)
plt.xticks([], [])
plt.yticks([], [])

# Panels 3-6: one per probed stage, with crosses where that stage's stride
# is expected to produce peaks.
for panel, layer in enumerate(layers, start=3):
    plt.subplot(7, 1, panel)
    base = ffts[panel - 3].mean(axis=0)
    plt.plot(base, label="original")
    plt.title(layer["title"], fontsize=10)
    plt.xticks([], [])
    plt.yticks([], [])
    plot_fractal(k=layer["stride"], y_offset=base.mean() - 8)

# Panel 7: spectrum of the reconstructed waveform.
plt.subplot(7, 1, 7)
plt.plot(output_fft.mean(axis=0), label="output", color="red")
plt.title("generated signal spectrum", fontsize=10)

plt.xticks(fontsize=10)
plt.yticks([], [])
plt.xticks([0, 2048, 4096], ["0Hz", "12kHz", "24kHz"], fontsize=8)
plt.show()

## Audio Manipulations

In [None]:
def time_stretch_spec(spec, rate):
    """Linearly resample a [channels, freq, frames] array along the frame axis.

    The output has round(frames / rate) frames, sampled on an evenly spaced grid
    that spans the first to the last input frame, so rate > 1 shortens the
    spectrogram and rate < 1 lengthens it. The input dtype is preserved.
    """
    n_channels, n_bins, n_frames = spec.shape
    n_out = int(np.round(n_frames / rate))

    source_pos = np.arange(n_frames)
    query_pos = np.linspace(0, n_frames - 1, n_out)

    stretched = np.empty((n_channels, n_bins, n_out), dtype=spec.dtype)
    # np.interp is 1-D only, so each (channel, bin) row is interpolated on its own.
    for ch, fb in np.ndindex(n_channels, n_bins):
        stretched[ch, fb] = np.interp(query_pos, source_pos, spec[ch, fb])
    return stretched

def resample_time_tensor(x, rate):
    """Linearly resample a [C, T] signal along time to round(T / rate) samples.

    Played back at the original sampling rate, the result changes both duration
    and pitch. Torch tensors go through `torch.nn.functional.interpolate`
    (linear, align_corners=False) and come back as tensors; anything else is
    treated as a NumPy array and interpolated row by row with `np.interp` on a
    grid spanning the first to the last sample.

    NOTE(review): the two paths use different sampling grids (the NumPy grid
    pins both end points, the torch one does not), so they are not numerically
    interchangeable.
    """
    n_channels, n_in = x.shape
    n_out = int(np.round(n_in / rate))

    if not isinstance(x, torch.Tensor):
        source_pos = np.arange(n_in)
        query_pos = np.linspace(0, n_in - 1, n_out)
        resampled = np.empty((n_channels, n_out), dtype=x.dtype)
        for row in range(n_channels):
            resampled[row] = np.interp(query_pos, source_pos, x[row])
        return resampled

    # interpolate() expects a batch dimension: [C, T] -> [1, C, T] and back.
    batched = x.unsqueeze(0)
    resized = torch.nn.functional.interpolate(
        batched, size=n_out, mode="linear", align_corners=False
    )
    return resized.squeeze(0)

In [None]:
# TIME-STRETCH comparison
stretch_rate = 1.2

# Activations after the first probed decoder stage, batch dimension dropped.
# (The previous `x if cuda else x` ternary had two identical arms.)
latent0 = latents[0][0]

# Simple algo: linear interpolation in time domain on latent
latent0_simple_ts = resample_time_tensor(latent0, rate=stretch_rate)

# "Real" algo: interpolate the STFT magnitude along the frame axis, then rebuild a
# waveform with Griffin-Lim. The original phase is discarded (only `.abs()` is
# kept), so this is magnitude time-warping with phase re-estimation rather than
# a phase vocoder.
n_fft_latent = 1024
hop_length_latent = 256
win_length_latent = n_fft_latent
window_latent = torch.hann_window(win_length_latent)

# The STFT / Griffin-Lim part runs on the CPU because the stretch goes through NumPy.
latent0_cpu = latent0.detach().cpu()
complex_spec_latent = torch.stft(latent0_cpu, n_fft=n_fft_latent, hop_length=hop_length_latent,
                                 win_length=win_length_latent, window=window_latent, return_complex=True)
mag_latent = complex_spec_latent.abs().numpy()
mag_latent_ts = time_stretch_spec(mag_latent, rate=stretch_rate)
mag_latent_ts_torch = torch.from_numpy(mag_latent_ts)

# One Griffin-Lim reconstruction per channel.
# NOTE: rand_init=True starts from a random phase, so the result varies from run
# to run unless the torch RNG is seeded beforehand.
latent0_real_ts = []
for c in range(mag_latent_ts_torch.shape[0]):
    recon = torchaudio.functional.griffinlim(
        mag_latent_ts_torch[c],
        window=window_latent,
        n_fft=n_fft_latent,
        hop_length=hop_length_latent,
        win_length=win_length_latent,
        power=1.0,  # the input is a magnitude, not a power, spectrogram
        n_iter=32,
        momentum=0.99,
        length=int(latent0_cpu.shape[-1] / stretch_rate),
        rand_init=True
    )
    latent0_real_ts.append(recon)
latent0_real_ts = torch.stack(latent0_real_ts, dim=0)

# Compute FFTs for visualization. `device` is the CPU when CUDA is unavailable,
# so the same calls cover both cases (`.to(device)` / `.cpu()` are then no-ops).
stft0_simple = compute_stft_full_device(latent0_simple_ts.to(device)).cpu().numpy()
stft0_real = compute_stft_full_device(latent0_real_ts.to(device)).cpu().numpy()

# Average over frames -> one spectrum per channel.
fft0_simple_ts = np.mean(stft0_simple, axis=2)
fft0_real_ts = np.mean(stft0_real, axis=2)

In [None]:
# Visualize peak shifting: Original vs Simple vs Real algorithm
# Each curve is the channel-averaged spectrum of the first probed decoder stage.
base = np.mean(ffts[0], 0)
fft_simple_mean = np.mean(fft0_simple_ts, 0)
fft_real_mean = np.mean(fft0_real_ts, 0)

# Number of low-frequency bins shown in the zoomed panel.
zoom_bins = 1000

plt.figure(figsize=(14, 6))

# Left panel: full spectrum with the expected peak positions for the first stride.
plt.subplot(1, 2, 1)
plt.plot(base, label="Original latent", color="#1f77b4", linewidth=2.5, alpha=0.9)
plt.plot(fft_simple_mean, label=f"Simple (linear interp) x{stretch_rate}",
         color="#ff7f0e", linestyle=":", linewidth=2.0, alpha=0.8)
plt.plot(fft_real_mean, label=f"Real (STFT+Griffin-Lim) x{stretch_rate}",
         color="#2ca02c", linestyle="--", linewidth=2.0, alpha=0.8)
plt.title("Time-Stretch: Peak Shifting in Frequency Domain", fontsize=12, fontweight='bold')
plt.xlabel("Frequency Bins", fontsize=10)
plt.ylabel("Amplitude (dB)", fontsize=10)
plt.xticks([0, 2048, 4096], ["0Hz", "12kHz", "24kHz"], fontsize=9)
plt.legend(fontsize=9, loc='upper right')
plt.grid(alpha=0.3, linestyle='--')
plot_fractal(k=layers[0]["stride"], y_offset=np.mean(base)-8)

# Right panel: same curves, zoomed on the lowest bins. The title states the range
# in bins: with 4096 labelled 24kHz on the left panel, bin 1000 is roughly 5.9 kHz,
# not the 2 kHz the title used to claim.
plt.subplot(1, 2, 2)
plt.plot(base, label="Original", color="#1f77b4", linewidth=2.5, alpha=0.9)
plt.plot(fft_simple_mean, label="Simple algorithm", color="#ff7f0e", linestyle=":", linewidth=2.0, alpha=0.8)
plt.plot(fft_real_mean, label="Real algorithm", color="#2ca02c", linestyle="--", linewidth=2.0, alpha=0.8)
plt.title(f"Zoomed View (bins 0-{zoom_bins})", fontsize=12, fontweight='bold')
plt.xlabel("Frequency Bins", fontsize=10)
plt.xlim(0, zoom_bins)
plt.ylim(np.min(base[:zoom_bins])-5, np.max(base[:zoom_bins])+5)
plt.legend(fontsize=9)
plt.grid(alpha=0.3, linestyle='--')

plt.tight_layout()
plt.savefig("time_stretch_peak_shift_comparison.png", dpi=150)
plt.show()

In [None]:
# Listen to time-stretch on REAL AUDIO
from IPython.display import Audio, display

# Decode the waveform from audio_latent (post_lstm) with the rest of the decoder.
# `.to(device)` and `.cpu()` do nothing on a CPU-only machine, so a single
# expression covers both the GPU and the CPU case.
with torch.no_grad():
    audio_original = model.decoder.model[2:](audio_latent.to(device)).cpu()

wave_original = audio_original.squeeze(0)

# SIMPLE time-stretch: linear interpolation of the decoded samples.
wave_simple_ts = resample_time_tensor(wave_original, rate=stretch_rate)

# REAL time-stretch: stretch the STFT magnitude, then Griffin-Lim back to audio.
n_fft_audio = 2048
hop_length_audio = 512
win_length_audio = n_fft_audio
window_audio = torch.hann_window(win_length_audio)

wave_cpu = wave_original.cpu()
complex_spec_audio = torch.stft(
    wave_cpu,
    n_fft=n_fft_audio,
    hop_length=hop_length_audio,
    win_length=win_length_audio,
    window=window_audio,
    return_complex=True,
)
mag_audio = complex_spec_audio.abs().numpy()
mag_audio_ts = time_stretch_spec(mag_audio, rate=stretch_rate)
mag_audio_ts_torch = torch.from_numpy(mag_audio_ts)

# Shared Griffin-Lim settings; each channel is reconstructed independently.
griffinlim_settings = dict(
    window=window_audio,
    n_fft=n_fft_audio,
    hop_length=hop_length_audio,
    win_length=win_length_audio,
    power=1.0,
    n_iter=32,
    momentum=0.99,
    length=int(wave_cpu.shape[-1] / stretch_rate),
    rand_init=True,
)
wave_real_ts = torch.stack(
    [
        torchaudio.functional.griffinlim(channel_mag, **griffinlim_settings)
        for channel_mag in mag_audio_ts_torch
    ],
    dim=0,
)

separator = "="*60
print(separator)
print("AUDIO COMPARISON: Original vs Time-Stretched (Simple vs Real)")
print(separator)

# One caption + audio player per version, in the same order as before.
players = (
    (f"\n🎵 Original Audio", wave_original),
    (f"\n🎵 Time-Stretch x{stretch_rate} - SIMPLE (linear interpolation)", wave_simple_ts),
    (f"\n🎵 Time-Stretch x{stretch_rate} - REAL (STFT + Griffin-Lim)", wave_real_ts),
)
for caption, wave in players:
    print(caption)
    display(Audio(wave.numpy(), rate=sr))

print("\n" + separator)
print("Notice: Real algorithm preserves frequency content better")
print("Simple algorithm may introduce artifacts and frequency distortion")
print(separator)

In [None]:
# PITCH-SHIFT comparison
pitch_rate = 0.8

# Simple algo: linear interpolation (time-domain resampling changes pitch).
# The output has round(T / pitch_rate) samples.
latent0_simple_ps = resample_time_tensor(latent0, rate=pitch_rate)

# Real algo: resampling with anti-aliasing filter.
# Only the ratio of the two frequencies matters. It is chosen so the output length
# scales by 1 / pitch_rate, like the simple path. The earlier 1000 -> 1000*pitch_rate
# ratio scaled the length by pitch_rate instead, so the two curves being compared
# were shifted in opposite directions.
nominal_rate = 1000
latent0_real_ps = torchaudio.functional.resample(
    latent0.cpu(),
    orig_freq=int(round(nominal_rate * pitch_rate)),
    new_freq=nominal_rate
)

# Compute FFTs for visualization. `device` is the CPU when CUDA is unavailable,
# so the same calls cover both cases (`.to(device)` / `.cpu()` are then no-ops).
stft0_simple_ps = compute_stft_full_device(latent0_simple_ps.to(device)).cpu().numpy()
stft0_real_ps = compute_stft_full_device(latent0_real_ps.to(device)).cpu().numpy()

# Average over frames -> one spectrum per channel.
fft0_simple_ps = np.mean(stft0_simple_ps, axis=2)
fft0_real_ps = np.mean(stft0_real_ps, axis=2)

In [None]:
# Visualize peak shifting for pitch-shift: Original vs Simple vs Real
# Each curve is the channel-averaged spectrum of the first probed decoder stage.
base = np.mean(ffts[0], 0)
fft_simple_ps_mean = np.mean(fft0_simple_ps, 0)
fft_real_ps_mean = np.mean(fft0_real_ps, 0)

# Number of low-frequency bins shown in the zoomed panel.
zoom_bins = 1000

plt.figure(figsize=(14, 6))

# Left panel: full spectrum with the expected peak positions for the first stride.
plt.subplot(1, 2, 1)
plt.plot(base, label="Original latent", color="#1f77b4", linewidth=2.5, alpha=0.9)
plt.plot(fft_simple_ps_mean, label=f"Simple (linear interp) x{pitch_rate}",
         color="#d62728", linestyle=":", linewidth=2.0, alpha=0.8)
plt.plot(fft_real_ps_mean, label=f"Real (resample w/ anti-alias) x{pitch_rate}",
         color="#9467bd", linestyle="--", linewidth=2.0, alpha=0.8)
plt.title("Pitch-Shift: Peak Shifting in Frequency Domain", fontsize=12, fontweight='bold')
plt.xlabel("Frequency Bins", fontsize=10)
plt.ylabel("Amplitude (dB)", fontsize=10)
plt.xticks([0, 2048, 4096], ["0Hz", "12kHz", "24kHz"], fontsize=9)
plt.legend(fontsize=9, loc='upper right')
plt.grid(alpha=0.3, linestyle='--')
plot_fractal(k=layers[0]["stride"], y_offset=np.mean(base)-8)

# Right panel: same curves, zoomed on the lowest bins. The title states the range
# in bins: with 4096 labelled 24kHz on the left panel, bin 1000 is roughly 5.9 kHz,
# not the 2 kHz the title used to claim.
plt.subplot(1, 2, 2)
plt.plot(base, label="Original", color="#1f77b4", linewidth=2.5, alpha=0.9)
plt.plot(fft_simple_ps_mean, label="Simple algorithm", color="#d62728", linestyle=":", linewidth=2.0, alpha=0.8)
plt.plot(fft_real_ps_mean, label="Real algorithm", color="#9467bd", linestyle="--", linewidth=2.0, alpha=0.8)
plt.title(f"Zoomed View (bins 0-{zoom_bins}) - Note Frequency Shift", fontsize=12, fontweight='bold')
plt.xlabel("Frequency Bins", fontsize=10)
plt.xlim(0, zoom_bins)
plt.ylim(np.min(base[:zoom_bins])-5, np.max(base[:zoom_bins])+5)
plt.legend(fontsize=9)
plt.grid(alpha=0.3, linestyle='--')

plt.tight_layout()
plt.savefig("pitch_shift_peak_shift_comparison.png", dpi=150)
plt.show()

In [None]:
# Listen to pitch-shift on REAL AUDIO

# Apply pitch-shift SIMPLE: linear interpolation on decoded audio
wave_simple_ps = resample_time_tensor(wave_original, rate=pitch_rate)

# Apply pitch-shift REAL: resample with anti-aliasing on decoded audio
wave_real_ps = torchaudio.functional.resample(
    wave_original,
    orig_freq=int(sr),
    new_freq=int(sr * pitch_rate)
)

print("="*60)
print("AUDIO COMPARISON: Original vs Pitch-Shifted (Simple vs Real)")
print("="*60)
print(f"\n🎵 Original Audio")
display(Audio(wave_original.numpy(), rate=sr))

print(f"\n🎵 Pitch-Shift x{pitch_rate} - SIMPLE (linear interpolation)")
display(Audio(wave_simple_ps.numpy(), rate=sr))

print(f"\n🎵 Pitch-Shift x{pitch_rate} - REAL (resample with anti-aliasing)")
display(Audio(wave_real_ps.numpy(), rate=sr))

print("\n" + "="*60)
print("Notice: Pitch shift changes frequency content")
print("Real algorithm uses anti-aliasing to prevent aliasing artifacts")
print("="*60)