# Fitting to a single material shape combination

In [None]:
from pathlib import Path

import lightning.pytorch as pl
import torch
import wandb
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.loggers import WandbLogger
from skfem import MeshTri
from torch.utils.data import DataLoader, Dataset

from neuralresonator.data import SingleShapeDataset
from neuralresonator.dsp import (biquad_freqz, constrain_complex_pole_or_zero,
                                 pole_or_zero_to_iir_coeff)
from neuralresonator.models import CoefficientsFC
from neuralresonator.utilities import (FFTLoss, MelScaleLoss, plot_sample,
                                       to_zpk)

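Create a single material-shape combination and fit a network that maps coordinates on the shape to filter coefficients whose magnitude response matches the corresponding audio.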

In [None]:
# Audio settings shared by the dataset and the W&B audio logging below.
audio_length_in_seconds = 0.2
sample_rate = 16000

# Length of one audio clip, in samples.
samples = int(audio_length_in_seconds * sample_rate)

In [None]:

class FitSingleShape(pl.LightningModule):
    """Lightning module that fits ``CoefficientsFC`` to a single shape's audio.

    The network maps each batch's ``coords`` to filter coefficients (split
    into ``b`` and ``a``). It is trained so that the magnitude response
    computed from those coefficients by ``biquad_freqz`` matches the
    magnitude spectrum of the batch's ``audio`` under ``FFTLoss``.

    Args:
        output_folder: Stored as ``self.output_folder``. This class does not
            read from or write to it itself.
    """

    def __init__(
        self,
        output_folder: Path,
    ):
        super().__init__()

        self.model = CoefficientsFC(
            input_size=2,
            hidden_sizes=[1024] * 2,
            layer_norm=False,
        )

        self.output_folder = output_folder

        self.criterion = FFTLoss(
            lin_l1=1.0,
            lin_l2=0.0,
            log_l1=0.01,
            log_l2=0.0,
        )

        # Holds at most ONE entry: the detached output of the most recent
        # training step, consumed by `on_train_epoch_end`.
        self.training_outputs = []

    def forward(self, x):
        """Return the raw coefficient prediction of the wrapped model."""
        return self.model(x)

    def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
        """Compute the spectral loss for one batch.

        Args:
            batch: Mapping with at least the keys ``"coords"`` and ``"audio"``.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The loss tensor to back-propagate.
        """
        coords = batch["coords"]
        audio = batch["audio"]

        # Target: magnitude spectrum of the audio, clamped to [-1, 1].
        mag_ffts = torch.fft.rfft(
            audio.float().clamp(-1, 1),
        ).abs()

        # The first three entries of the last axis are `b`, the rest are `a`
        # (presumably biquad numerator/denominator - confirm against
        # CoefficientsFC).
        ba = self.forward(coords)
        b = ba[..., :3]
        a = ba[..., 3:]

        # Product over one axis, then sum over the next.
        # NOTE(review): looks like cascaded sections combined across parallel
        # branches - confirm against the output layout of `biquad_freqz`.
        p_mag_ffts = (
            biquad_freqz(b, a, audio.shape[-1]).prod(dim=-2).sum(dim=-2).abs()
        )

        loss = self.criterion(
            p_mag_ffts,
            mag_ffts,
        )

        self.log("train/loss", loss)

        # Only the last step of an epoch is ever visualised, so keep just the
        # latest output and detach it. Appending every step's attached
        # tensors would keep each step's autograd graph alive until the
        # epoch ends.
        self.training_outputs.clear()
        self.training_outputs.append(
            dict(
                loss=loss.detach(),
                a=a.detach(),
                b=b.detach(),
                audio=audio,
            )
        )

        return loss

    def on_train_epoch_end(self) -> None:
        """Log the last step's loss, plot and audio to the experiment logger.

        Does nothing when no training step has run during the epoch.
        """
        if not self.training_outputs:
            return

        batch = self.training_outputs[-1]
        self.log("train/loss", batch["loss"], prog_bar=True)

        with torch.no_grad():
            # Visualise only the first item of the batch.
            audio = batch["audio"][0].cpu().numpy()
            fig, pred_signal = plot_sample(
                a=batch["a"][0].cpu().numpy(),
                b=batch["b"][0].cpu().numpy(),
                gt_audio=audio,
            )
            # `self.logger.experiment` is used as a wandb run here, so this
            # module expects to be trained with a WandbLogger.
            self.logger.experiment.log({"train_epoch_end": wandb.Image(fig)})
            # `sample_rate` is the module-level constant defined earlier in
            # the notebook.
            wandb_gt_audio = wandb.Audio(audio, sample_rate=sample_rate)
            wandb_pred_audio = wandb.Audio(pred_signal, sample_rate=sample_rate)
            self.logger.experiment.log(
                {"train_epoch_end/audio": [wandb_gt_audio, wandb_pred_audio]}
            )

        self.training_outputs.clear()

    def configure_optimizers(self) -> dict:
        """Set up AdamW with a plateau scheduler driven by ``train/loss``.

        The scheduler is configured with ``interval="step"`` and
        ``frequency=1``, so ``patience`` is counted in scheduler steps rather
        than epochs.
        """
        optimizer = torch.optim.AdamW(lr=5e-5, params=self.parameters())
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=0.8,
            patience=300,
        )

        lr_scheduler_config = {
            "scheduler": scheduler,
            "monitor": "train/loss",
            "frequency": 1,
            "interval": "step",
        }

        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler_config,
        }


In [None]:
# Create the output folder, including any missing parents. `exist_ok=True`
# makes this a no-op when the directory already exists and avoids the race in
# a separate exists() check. It raises FileExistsError if the path is a file.
output_folder = Path("output")
output_folder.mkdir(parents=True, exist_ok=True)

model = FitSingleShape(
    output_folder=output_folder,
)

# Dataset built from a circular triangle mesh; the audio settings reuse the
# module-level constants defined earlier in the notebook.
dataset = SingleShapeDataset(
    mesh=MeshTri.init_circle(smoothed=True),
    audio_length_in_seconds=audio_length_in_seconds,
    sample_rate=sample_rate,
    n_refinements=3,
)
train_loader = DataLoader(
    dataset=dataset,
    batch_size=8,
    shuffle=True,
    num_workers=8,
)


In [None]:
# Send metrics, figures and audio to the "neuralresonator" W&B project.
logger = WandbLogger(project="neuralresonator")

# Track the learning rate at step granularity.
lr_monitor = LearningRateMonitor(logging_interval="step")

# Train for 200 epochs without writing checkpoints.
trainer = pl.Trainer(
    logger=logger,
    callbacks=[lr_monitor],
    max_epochs=200,
    enable_checkpointing=False,
)

trainer.fit(model, train_dataloaders=train_loader)
