In [1]:
# Automatically re-import edited modules (e.g. beatbrain) before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

from tqdm.auto import tqdm
import numpy as np
import librosa
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from nnAudio import Spectrogram

from beatbrain.datasets.audio import AudioClipDataset

import optuna

In [3]:
# --- Data source ---------------------------------------------------------
AUDIO_PATH = "../data/fma/wav"
SR = 22050               # sample rate used by the dataset and both spectrogram transforms
MAX_SEGMENT_LENGTH = 10  # forwarded to AudioClipDataset (presumably seconds; TODO confirm)

# --- Spectrogram geometry: hop length and mel-bin count are both N_FFT / 8 ---
N_FFT = 2048
HOP_LENGTH = N_FFT // 8
N_MELS = N_FFT // 8

# --- Evaluation run ------------------------------------------------------
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
BATCH_SIZE = 128
MAX_ITEMS = 2048         # upper bound on the number of clips sampled for tuning

# Small epsilon constant (not referenced in the cells shown; lowercase name kept as-is)
eps = 1e-12

In [4]:
# Seeding the legacy global NumPy RNG fixes which clips end up in the tuning subset.
np.random.seed(42)
full_dataset = AudioClipDataset(AUDIO_PATH, sample_rate=SR, max_segment_length=MAX_SEGMENT_LENGTH)
n_clips = len(full_dataset)
subset_indices = np.random.choice(n_clips, size=min(n_clips, MAX_ITEMS), replace=False)
dataset = Subset(full_dataset, subset_indices)

# Deterministic order (no shuffling) so every trial sees the same batches.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=True,
    num_workers=0,
)

In [5]:
# Resume the study stored in optuna.db, or create it if it does not exist yet, so
# "Restart & Run All" also works on a fresh checkout without a database.
# The name is the auto-generated one from the original create_study call; keeping it
# reuses the trials already recorded under it.
STUDY_NAME = "no-name-ee92733b-3897-4030-bb07-3c8b4c49937d"
STUDY_STORAGE = "sqlite:///optuna.db"
study = optuna.create_study(
    study_name=STUDY_NAME,
    storage=STUDY_STORAGE,
    load_if_exists=True,
)

In [None]:
def objective(trial):
    """Optuna objective scoring one set of Mel -> STFT inversion hyperparameters.

    Samples SGD and stopping-criterion settings from ``trial``. For every batch in
    ``dataloader``, it inverts the mel spectrogram back to a magnitude STFT with
    ``MelSpectrogram.to_stft``. It then accumulates the mean squared error against
    the magnitude STFT computed directly from the audio.

    Args:
        trial: the Optuna trial being evaluated.

    Returns:
        float: the sum of per-batch MSEs divided by the number of clips in the
        dataset (lower is better).

    Raises:
        optuna.TrialPruned: if ``to_stft`` raises ``OverflowError``, if its final
            optimisation loss is higher than its first, or if Optuna's pruner
            decides to stop the trial early.
    """
    lr = trial.suggest_loguniform("lr", 1e-3, 1e7)
    momentum = trial.suggest_categorical("momentum", [0, 0.9])
    loss_threshold = trial.suggest_loguniform("loss_threshold", 1e-10, 1e-2)
    grad_threshold = trial.suggest_loguniform("grad_threshold", 1e-11, 1e-3)
    random_start = trial.suggest_categorical("random_start", [True, False])

    print(dict(
        lr=lr,
        momentum=momentum,
        loss_threshold=loss_threshold,
        grad_threshold=grad_threshold,
        random_start=random_start,
    ))

    to_stft = Spectrogram.STFT(N_FFT, hop_length=HOP_LENGTH, sr=SR, output_format="Magnitude", device=DEVICE, verbose=False)
    to_mel = Spectrogram.MelSpectrogram(n_fft=N_FFT, hop_length=HOP_LENGTH, n_mels=N_MELS, sr=SR, device=DEVICE, verbose=False)

    loss = 0
    for i, (audio, _sr) in enumerate(tqdm(dataloader, desc=f"trial {trial.number}")):
        audio = audio.to(DEVICE)
        stft = to_stft(audio)
        mel = to_mel(audio)
        try:
            # random_start is a tuned hyperparameter, so it must be forwarded here;
            # otherwise every trial silently runs with to_stft's default start.
            recon_stft, _pred_mel, losses = to_mel.to_stft(
                mel,
                loss_threshold=loss_threshold,
                grad_threshold=grad_threshold,
                random_start=random_start,
                sgd_kwargs=dict(lr=lr, momentum=momentum),
                return_extras=True,
            )
        except OverflowError:
            print("Overflow detected!")
            raise optuna.TrialPruned()
        # An optimisation that ends worse than it began counts as diverged.
        if losses[-1] > losses[0]:
            print("Divergence detected!")
            raise optuna.TrialPruned()
        loss += (recon_stft - stft).pow(2).mean().item()
        # The intermediate value is the running *sum*, so it is only comparable
        # across trials at the same step i (which is how the pruner uses it).
        trial.report(loss, i)
        if trial.should_prune():
            raise optuna.TrialPruned()
    # NOTE: `loss` sums per-batch means, so dividing by the clip count (rather than
    # the batch count) scales the result by roughly 1 / BATCH_SIZE. It is kept as-is
    # so new values stay comparable with trials already stored in the study.
    return loss / len(dataloader.dataset)


# Neither n_trials nor timeout is given, so new trials keep starting until the kernel is interrupted.
study.optimize(func=objective)

In [6]:
# Hyperparameters of the best completed trial (lowest loss, assuming the study uses the default "minimize" direction).
study.best_trial.params

{'grad_threshold': 0.0004775416410927695,
 'loss_threshold': 0.007012997184047137,
 'lr': 101170.83388141652,
 'momentum': 0.9,
 'random_start': True}

In [None]:
# Render each Optuna summary plot for the study, stretched to 1000 px tall for readability.
for plot_fn in (
    optuna.visualization.plot_parallel_coordinate,
    optuna.visualization.plot_optimization_history,
    optuna.visualization.plot_slice,
    optuna.visualization.plot_contour,
):
    plot_fn(study).update_layout(height=1000).show()