# Pre-Process Fakeprints for Training

This notebook generates the fakeprints for the training dataset. It loads each audio file, resamples it to the model's sampling rate if necessary, and computes the fakeprints using the `get_fakeprints` function. The resulting fakeprints are saved as `.npz` files in the specified output directory.

In [None]:
import os
import torch
import numpy as np

from nnAudio.features import CQT

# Treat the directory two levels above the current working directory as the
# project root and switch into it, so the relative paths below resolve there.
# NOTE: the root is derived from os.getcwd(), so re-running this cell after the
# chdir climbs a further two levels.
project_root = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
os.chdir(project_root)

# Output folders for the AI-generated and the human fakeprints.
ai_dir = "src/checkpoints/fp/ai"
human_dir = "src/checkpoints/fp/human"

for out_dir in (ai_dir, human_dir):
    os.makedirs(os.path.join(project_root, out_dir), exist_ok=True)

# Use CUDA when it is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

# Analysis settings shared by the cells below.
N_FFT = 2 ** 14  # 16384
SR = 48000
nyquist = SR / 2
F_MIN = 32.7  # C1 note frequency
BINS_PER_OCTAVE = 96
F_RANGE = [200, 6000]

## 1. Pre-Process Function

In [None]:
import soxr
import torchaudio

from tqdm.notebook import tqdm
from src.models.utils import get_freqs, get_cqt, get_fakeprints

def preprocess(
    file_paths,
    n_fft=16384,
    sampling_rate=48000,
    bins_per_octave=96,
    freq_range=None,
    db_range=None,
    f_min=32.7, # C1 note frequency
    device=torch.device("cpu"),
):
    """Compute one fakeprint per audio file and stack the results.

    Each file is loaded with torchaudio, resampled with soxr when its sample
    rate differs from ``sampling_rate``, averaged across channels, passed
    through an nnAudio ``CQT`` layer via ``get_cqt``, averaged over the last
    axis, cropped with the mask returned by ``get_freqs`` and finally handed
    to ``get_fakeprints``.

    Files that fail to load are reported with ``print`` and skipped, so the
    rows of the result do not necessarily line up one-to-one with
    ``file_paths``.

    Args:
        file_paths: Iterable of audio file paths.
        n_fft: Only used to derive the CQT hop length (``n_fft // 2``).
        sampling_rate: Sample rate every file is brought to before the CQT.
        bins_per_octave: Number of CQT bins per octave.
        freq_range: Forwarded unchanged to ``get_freqs``. ``None`` (the
            default) means ``[1000, 22000]``.
        db_range: Forwarded unchanged to ``get_fakeprints``. ``None`` (the
            default) means ``[-80, 5]``.
        f_min: Lowest CQT frequency.
        device: ``torch.device`` or device string on which the CQT runs.

    Returns:
        A ``(freqs, fakeprints)`` tuple: ``freqs`` exactly as returned by
        ``get_freqs`` and ``fakeprints`` the per-file results stacked along a
        new leading dimension.

    Raises:
        AssertionError: If ``device`` is an MPS device.
        RuntimeError: If ``file_paths`` is empty or none of the files could be
            loaded.
    """
    # Accept plain strings such as "cuda" in addition to torch.device objects.
    device = torch.device(device)
    assert device.type != "mps", "MPS device is not supported for this preprocessing pipeline. Please use CPU or CUDA."

    # Build fresh lists per call instead of sharing mutable default arguments.
    if freq_range is None:
        freq_range = [1000, 22000]
    if db_range is None:
        db_range = [-80, 5]

    # Fail early with a readable message rather than letting torch.stack
    # complain about an empty list after the (costly) CQT layer is built.
    file_paths = list(file_paths)
    if not file_paths:
        raise RuntimeError(
            "No audio files were provided to preprocess(); check the dataset path or glob pattern."
        )

    hop_length = n_fft // 2
    nyquist = sampling_rate / 2  # Maximum frequency that can be represented
    n_octaves = np.log2(nyquist / f_min) - 0.1  # Subtract a small margin to ensure we don't exceed Nyquist
    nbins = int(n_octaves * bins_per_octave)  # Total number of CQT bins to cover the desired frequency range

    cqt_layer = CQT(
        sr=sampling_rate,
        hop_length=hop_length,
        fmin=f_min,
        n_bins=nbins,
        bins_per_octave=bins_per_octave,
        output_format="Magnitude",
        verbose=False,
    ).to(device)

    freqs, mask = get_freqs(
        n_bins=nbins,
        sr=sampling_rate,
        bins_per_octave=bins_per_octave,
        freq_range=freq_range,
        f_min=f_min
    )

    fakeprints = []
    for path in tqdm(file_paths, desc="Processing audio files"):
        try:
            waveform, sr = torchaudio.load(path, channels_first=True)
        except Exception as e:
            # Best effort: report the unreadable file and move on.
            print(f"Error loading {path}: {e}")
            continue
        if sr != sampling_rate:
            print(f"Resampling {path} from {sr} Hz to {sampling_rate} Hz")
            # The waveform is transposed before and after the soxr call.
            waveform = soxr.resample(waveform.T, sr, sampling_rate).T
            waveform = torch.from_numpy(waveform).to(device)

        waveform = waveform.mean(dim=0, keepdim=True).to(device)  # Convert to mono
        cqt = get_cqt(cqt_layer, waveform) # presumably (1, n_bins, T') - confirm against get_cqt
        spec = cqt.mean(dim=-1).squeeze(0)  # average over the last axis

        spec_crop = spec[mask]
        fp = get_fakeprints(spec_crop, freqs, db_range=db_range)

        fakeprints.append(fp)

    if not fakeprints:
        raise RuntimeError(
            f"None of the {len(file_paths)} audio files could be loaded; no fakeprints were computed."
        )

    fakeprints = torch.stack(fakeprints, dim=0)  # one row per successfully loaded file
    return freqs, fakeprints

## 2. Preprocess Generated Dataset

In [None]:
import glob

# Upper bound on how many files from this dataset are fingerprinted.
num_samples = 500

data_dir = "/path/to/datasets/suno_v5_500"

# glob returns paths in arbitrary, filesystem-dependent order; sort them so the
# first `num_samples` entries are the same subset on every run and machine.
file_paths = sorted(glob.glob(f"{data_dir}/*.mp3"))
file_paths = file_paths[:num_samples]

# f_min is passed explicitly so it stays tied to the F_MIN constant that the
# visualization cell also feeds to get_freqs.
freqs, ai_fp = preprocess(
    file_paths,
    n_fft=N_FFT,
    sampling_rate=SR,
    bins_per_octave=BINS_PER_OCTAVE,
    freq_range=F_RANGE,
    f_min=F_MIN,
    device=DEVICE,
)

### Save Fakeprints

In [None]:
# Copy the AI fakeprints to host memory and store them under the "fakeprints" key.
np.savez(f"{ai_dir}/fakeprints_01.npz", fakeprints=ai_fp.cpu().numpy())

## 3. Preprocess Human Dataset

In [None]:
import glob

# Upper bound on how many files from this dataset are fingerprinted.
num_samples = 500

data_dir = "/path/to/datasets/fma_small"

# glob returns paths in arbitrary, filesystem-dependent order; sort them so the
# first `num_samples` entries are the same subset on every run and machine.
file_paths = sorted(glob.glob(f"{data_dir}/**/*.mp3", recursive=True))
file_paths = file_paths[:num_samples]

# f_min is passed explicitly so it stays tied to the F_MIN constant that the
# visualization cell also feeds to get_freqs.
freqs, human_fp = preprocess(
    file_paths,
    n_fft=N_FFT,
    sampling_rate=SR,
    bins_per_octave=BINS_PER_OCTAVE,
    freq_range=F_RANGE,
    f_min=F_MIN,
    device=DEVICE,
)

### Save Fakeprints

In [None]:
# Copy the human fakeprints to host memory and store them under the "fakeprints" key.
np.savez(f"{human_dir}/fakeprints_01.npz", fakeprints=human_fp.cpu().numpy())

## 4. Load Fakeprints for Visualization

In [None]:
from src.models.utils import get_freqs

def load_fp(file_path):
    """Return the array stored under the ``"fakeprints"`` key of an ``.npz`` file.

    Args:
        file_path: Path of the archive to read.

    Returns:
        The NumPy array saved under ``"fakeprints"``.

    Raises:
        KeyError: If the archive has no ``"fakeprints"`` entry.
    """
    # np.load keeps the archive's file handle open until the returned object
    # is closed; the context manager releases it once the array has been read.
    with np.load(file_path) as archive:
        return archive["fakeprints"]

# Recompute the CQT bin count with the same formula preprocess() uses.
nyquist = SR / 2  # highest frequency representable at this sample rate
n_octaves = np.log2(nyquist / F_MIN) - 0.1  # keep a 0.1-octave margin below Nyquist
nbins = int(n_octaves * BINS_PER_OCTAVE)  # total number of CQT bins

# Only the frequency values are needed here; the second return value is discarded.
freqs, _ = get_freqs(
    n_bins=nbins,
    sr=SR,
    bins_per_octave=BINS_PER_OCTAVE,
    freq_range=F_RANGE,
    f_min=F_MIN,
)

# Read back the AI fakeprints written earlier in this notebook.
fakeprints = load_fp(f"{ai_dir}/fakeprints_01.npz")

In [None]:
import matplotlib.pyplot as plt

# Number of fakeprints to overlay in the plot (used as fakeprints[:N] below).
N = 10

def plot_fp(freqs, fakeprints, log_scale=True):
    """Draw every fakeprint as a curve over ``freqs`` in a single figure.

    Args:
        freqs: X-axis values shared by all curves (the axis is labelled in Hz).
        fakeprints: Iterable of curves; each one is plotted against ``freqs``.
        log_scale: When true, the x-axis uses a logarithmic scale.
    """
    _, axis = plt.subplots(figsize=(12, 5))

    # One labelled, slightly transparent line per fakeprint.
    for index, curve in enumerate(fakeprints):
        axis.plot(freqs, curve, label=f"Fakeprint {index}", alpha=0.8)

    if log_scale:
        axis.set_xscale('log')

    axis.set_xlabel('Frequency (Hz)')
    axis.set_ylabel('Normalized Residue')
    axis.set_title('Fakeprint Comparison')
    axis.legend()
    axis.grid()
    plt.show()

# Overlay the first N loaded fakeprints on a logarithmic frequency axis.
plot_fp(freqs, fakeprints[:N], log_scale=True)