# Fit to arbitrary shapes

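This notebook shows how to overfit a neural resonator to a small set of arbitrary shapes: it generates a random dataset, defines a Lightning module that predicts biquad filter coefficients from a shape mask and coordinates, and trains it while logging to Weights & Biases.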

In [None]:
import os
from pathlib import Path

import lightning.pytorch as pl
import torch
import wandb
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.loggers import WandbLogger
from torchvision.models.efficientnet import (EfficientNet_B0_Weights,
                                             efficientnet_b0)

from neuralresonator.data import (MultiShapeMultiMaterialDataModule,
                                  generate_random_dataset)
from neuralresonator.dsp import biquad_freqz
from neuralresonator.modal import MATERIALS
from neuralresonator.models import CoefficientsFC
from neuralresonator.utilities import FFTLoss, MelScaleLoss, plot_sample

# Expose only GPU 0 to CUDA for this process.
# NOTE(review): this only takes effect if it runs before CUDA is initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

Generate a random dataset

In [None]:
# Dataset configuration. These are module-level names because later cells
# (the Lightning module and the datamodule setup) read some of them.
n_shapes = 5
n_materials = 1
n_refinements = 3
n_vertices = 13
sample_rate = 16000
audio_length_in_seconds = 0.2
# Number of audio samples per example (16000 * 0.2 = 3200).
samples = int(sample_rate * audio_length_in_seconds)

# Create the data directory. ``exist_ok=True`` replaces the separate
# exists()/mkdir() pair: it is a no-op when the directory is already there,
# has no check-then-act race, and raises if "data" exists as a regular file.
data_dir = Path("data")
data_dir.mkdir(exist_ok=True)

# Seed the random number generators (including dataloader workers) so the
# generated dataset and the training run are reproducible.
pl.seed_everything(3407, workers=True)

# NOTE(review): no output directory is passed here; later cells read
# ``data_dir / "index_map.csv"``, so the generator presumably writes under
# "data" by default - confirm against ``generate_random_dataset``.
generate_random_dataset(
    n_shapes=n_shapes,
    n_materials=n_materials,
    materials=[MATERIALS['polycarbonate']],
    n_vertices=n_vertices,
    n_modes=32,
    n_refinements=n_refinements,
)

Lightning module class

In [None]:
class FitShapes(pl.LightningModule):
    """Lightning module that fits biquad filter coefficients to shapes.

    A pretrained EfficientNet-B0 encodes the shape mask. The encoder output is
    averaged over its last dimension, concatenated with the query coordinates
    and fed to a ``CoefficientsFC`` network that predicts the ``b`` and ``a``
    coefficients. Training minimises an ``FFTLoss`` between the magnitude
    response of the predicted filters and the magnitude spectrum of the
    target audio.

    Args:
        output_folder: Stored on the instance; not otherwise used by this
            class.
        lr: Learning rate passed to the AdamW optimizer.
        n_parallel: Stored on the instance.
        n_biquads: Stored on the instance.

    NOTE(review): ``n_parallel`` and ``n_biquads`` are not forwarded to
    ``CoefficientsFC`` here, so that model uses its own defaults - confirm
    that they agree.
    """

    def __init__(
        self,
        output_folder: Path,
        lr: float = 1e-4,
        n_parallel: int = 32,
        n_biquads: int = 2,
    ):
        super().__init__()

        self.output_folder = output_folder
        self.n_parallel = n_parallel
        self.n_biquads = n_biquads

        # Image encoder initialised with the default pretrained weights.
        self.encoder = efficientnet_b0(
            weights=EfficientNet_B0_Weights.DEFAULT,
        )

        self.lr = lr

        # input_size=3: one averaged encoder feature plus the coordinates
        # (assumes ``coords`` has two entries per sample - TODO confirm).
        self.model = CoefficientsFC(
            input_size=3,
            hidden_sizes=[1024] * 2,
            layer_norm=False,
        )

        self.criterion = FFTLoss(
            lin_l1=1.0,
            lin_l2=0.0,
            log_l1=0.2,
            log_l2=0.0,
        )

        # Holds at most one entry: the detached outputs of the most recent
        # training step, consumed by ``on_train_epoch_end``.
        self.training_outputs = []

    def forward(self, x):
        """Return the output of the coefficient network for input ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the spectral loss for one batch.

        Args:
            batch: Mapping with the keys ``"mask"``, ``"coords"`` and
                ``"audio"``.
            batch_idx: Index of the batch (unused).

        Returns:
            The loss returned by ``self.criterion``.
        """
        mask = batch["mask"]
        coords = batch["coords"]
        audio = batch["audio"]

        # Target: magnitude spectrum of the audio, clamped to [-1, 1].
        mag_ffts = torch.fft.rfft(
            audio.float().clamp(-1, 1),
        ).abs()

        # Tile the mask three times along dim 1 before encoding (presumably
        # to turn a single-channel mask into the 3-channel input the encoder
        # expects - assumes mask is (batch, 1, H, W); TODO confirm).
        features = self.encoder(mask.repeat(1, 3, 1, 1).float())

        # Predict biquad coefficients
        # using the mean of the features makes the model overfit faster to the shapes
        ba = self.forward(torch.cat([features.mean(-1, keepdim=True), coords], dim=-1))
        b = ba[..., :3]
        a = ba[..., 3:]

        # Multiply the responses along dim -2, then sum along the resulting
        # dim -2 (presumably cascaded biquads within each parallel branch,
        # then the sum over branches - TODO confirm), and take the magnitude.
        p_mag_ffts = biquad_freqz(b, a, audio.shape[-1]).prod(dim=-2).sum(dim=-2).abs()

        loss = self.criterion(
            p_mag_ffts,
            mag_ffts,
        )

        self.log("train/loss", loss, on_step=True, on_epoch=False)

        # Only the last step of an epoch is visualised, so keep just the most
        # recent outputs and detach them. Appending every step would grow the
        # list for the whole epoch and keep autograd references alive.
        self.training_outputs.clear()
        self.training_outputs.append(
            dict(
                loss=loss.detach(),
                a=a.detach(),
                b=b.detach(),
                audio=audio.detach(),
            )
        )

        return loss

    def on_train_epoch_end(self) -> None:
        """Log a plot and audio for the first sample of the last batch.

        Uses the first element of the most recent training batch to log a
        figure plus ground-truth and predicted audio to the logger's
        experiment. Does nothing if no training step ran during the epoch.
        """
        if not self.training_outputs:
            # No training step was recorded; nothing to visualise.
            return

        batch = self.training_outputs[-1]

        with torch.no_grad():
            audio = batch["audio"][0].cpu().numpy()
            fig, pred_signal = plot_sample(
                a=batch["a"][0].cpu().numpy(),
                b=batch["b"][0].cpu().numpy(),
                gt_audio=audio,
            )
            # The two log calls are kept separate, as in the original code.
            self.logger.experiment.log({"train_epoch_end": wandb.Image(fig)})
            # ``sample_rate`` is the module-level value defined in the
            # dataset configuration cell.
            wandb_gt_audio = wandb.Audio(audio, sample_rate=sample_rate)
            wandb_pred_audio = wandb.Audio(pred_signal, sample_rate=sample_rate)
            self.logger.experiment.log(
                {"train_epoch_end/audio": [wandb_gt_audio, wandb_pred_audio]}
            )

        self.training_outputs.clear()

    def configure_optimizers(self):
        """Set up AdamW with a plateau scheduler driven by ``train/loss``.

        Returns:
            A dict with the optimizer and a scheduler config that steps the
            ``ReduceLROnPlateau`` scheduler every training step, monitoring
            the logged ``"train/loss"`` value.
        """
        optimizer = torch.optim.AdamW(lr=self.lr, params=self.parameters())
        # Multiply the learning rate by 0.8 after 600 scheduler steps
        # without improvement of the monitored loss.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=0.8,
            patience=600,
        )

        lr_scheduler_config = {
            "scheduler": scheduler,
            "monitor": "train/loss",
            "frequency": 1,
            "interval": "step",
        }

        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler_config,
        }


In [None]:
# Create the output directory. ``exist_ok=True`` replaces the separate
# exists()/mkdir() pair: it is a no-op when the directory is already there,
# has no check-then-act race, and raises if "output" exists as a regular file.
output_folder = Path("output")
output_folder.mkdir(exist_ok=True)

# The same index map is used for the train, validation and test splits; the
# notebook's stated goal is to overfit, so no held-out data is configured.
datamodule = MultiShapeMultiMaterialDataModule(
    train_index_map_path=data_dir / "index_map.csv",
    val_index_map_path=data_dir / "index_map.csv",
    test_index_map_path=data_dir / "index_map.csv",
    batch_size=16,
    num_workers=8,
    audio_length_in_seconds=audio_length_in_seconds,
)

# Remaining hyperparameters (lr, n_parallel, n_biquads) use their defaults.
model = FitShapes(
    output_folder=output_folder,
)

In [None]:
# Log the run to the "neuralresonator" Weights & Biases project and record
# the learning rate at every step.
logger = WandbLogger(project="neuralresonator")
lr_monitor = LearningRateMonitor(logging_interval="step")

# Train for 300 epochs on the datamodule configured above.
trainer = pl.Trainer(max_epochs=300, logger=logger, callbacks=[lr_monitor])
trainer.fit(model=model, datamodule=datamodule)
