# Training

> PyTorch Lightning modules for training

In [None]:
# | default_exp training

In [None]:
# | export

from typing import Any, Optional, Type

import lightning as pl
import torch
import wandb
from torch import nn, optim
from torchmetrics.regression.mae import MeanAbsoluteError
from torchmetrics.regression.mse import MeanSquaredError
from torchvision.models.efficientnet import (EfficientNet_B0_Weights,
                                             efficientnet_b0)

from neuralresonator.dsp import biquad_freqz
from neuralresonator.utilities import FFTLoss, plot_sample

In [None]:
# | export


class MultiShapeMultiMaterialLitModule(pl.LightningModule):
    """Lightning module that predicts biquad filter coefficients.

    A pretrained EfficientNet-B0 encodes the input mask. The encoder output is
    concatenated with the coordinates and material parameters and passed to
    `model`, whose output is interpreted as biquad coefficients. The magnitude
    response of the predicted filters is compared against the magnitude
    spectrum of the ground-truth audio with `criterion`.

    Args:
        model: Network mapping the concatenated features to coefficients.
        optimizer: Callable that builds an optimizer from the module's
            parameters (e.g. a partially-applied optimizer class).
        scheduler: Callable that builds a learning-rate scheduler from the
            optimizer.
        criterion: Loss applied to (predicted, target) magnitude spectra.
            When omitted, a fresh `FFTLoss` is created for this instance.
        **kwargs: Extra keyword arguments; accepted and not used by this
            constructor.
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: Type[optim.Optimizer],
        scheduler: Type[optim.lr_scheduler.LRScheduler],
        criterion: Optional[nn.Module] = None,
        **kwargs,
    ):
        super().__init__()

        self.save_hyperparameters(logger=False)

        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        # Build the default loss per instance: a default created in the
        # signature would be evaluated once and shared by every module.
        self.criterion = FFTLoss() if criterion is None else criterion

        self.encoder = efficientnet_b0(
            weights=EfficientNet_B0_Weights.DEFAULT,
        )

        self.mse = MeanSquaredError()
        self.mae = MeanAbsoluteError()

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Run the wrapped `model` on `x` and return its output."""
        return self.model(x)

    def step(self, batch: Any) -> dict:
        """Shared forward pass and loss computation for all stages.

        Args:
            batch: Mapping with the keys "mask", "coords", "audio" and
                "material_params".

        Returns:
            Dict with the loss ("loss"), the predicted coefficients ("a",
            "b"), the predicted and target magnitude spectra ("p_mag_ffts",
            "mag_ffts") and the input audio ("audio").
        """
        mask = batch["mask"]
        coords = batch["coords"]
        audio = batch["audio"]
        material_params = batch["material_params"]

        # Target: magnitude spectrum of the audio clamped to [-1, 1].
        mag_ffts = torch.fft.rfft(
            audio.float().clamp(-1, 1),
        ).abs()

        # Repeat the mask three times along dim 1 before encoding
        # (presumably a single-channel mask fed to an RGB-pretrained encoder
        # -- TODO confirm against the data module).
        features = self.encoder(mask.repeat(1, 3, 1, 1).float())

        # Predict biquad coefficients: the first three values of the last
        # dimension are `b`, the remaining ones are `a`.
        ba = self.forward(torch.cat([features, coords, material_params], dim=-1))
        b = ba[..., :3]
        a = ba[..., 3:]

        # Reduce the per-filter responses with a product over dim -2, then a
        # sum over the next dim (presumably cascaded biquads inside parallel
        # branches -- verify against `biquad_freqz`).
        p_mag_ffts = biquad_freqz(b, a, audio.shape[-1]).prod(dim=-2).sum(dim=-2).abs()

        loss = self.criterion(
            p_mag_ffts,
            mag_ffts,
        )

        return dict(
            loss=loss,
            a=a,
            b=b,
            p_mag_ffts=p_mag_ffts,
            mag_ffts=mag_ffts,
            audio=audio,
        )

    def get_first_and_plot(
        self,
        batch: Any,
        name: str,
        sample_rate: int = 32000,
    ) -> None:
        """Log a plot and audio for the first sample of a step output.

        Args:
            batch: Output of `step`; only "a", "b" and "audio" are read.
            name: Prefix of the Weights & Biases keys ("{name}/plot" and
                "{name}/audio").
            sample_rate: Sample rate passed to `wandb.Audio` for both clips.
        """
        with torch.no_grad():
            # Get the first sample from the batch
            a = batch["a"][0].detach().cpu().numpy()
            b = batch["b"][0].detach().cpu().numpy()
            audio = batch["audio"][0].detach().cpu().numpy()
            fig, pred_signal = plot_sample(
                a=a,
                b=b,
                gt_audio=audio,
            )
            wandb.log({f"{name}/plot": wandb.Image(fig)})
            wandb_gt_audio = wandb.Audio(audio, sample_rate=sample_rate)
            wandb_pred_audio = wandb.Audio(pred_signal, sample_rate=sample_rate)
            wandb.log({f"{name}/audio": [wandb_gt_audio, wandb_pred_audio]})

    def _should_plot(self, batch_idx: int) -> bool:
        """Return True when `batch_idx` falls on the validation interval."""
        interval = self.trainer.val_check_interval
        if isinstance(interval, float):
            # A float interval is a fraction of the epoch. Taking
            # `batch_idx % interval` directly is 0 for every batch, so
            # convert the fraction into a number of batches first.
            n_batches = self.trainer.num_training_batches
            if n_batches == float("inf"):
                period = n_batches
            else:
                period = max(1, int(n_batches * interval))
        else:
            period = interval
        return bool(period) and batch_idx % period == 0

    def training_step(
        self,
        batch: Any,
        batch_idx: int,
    ):
        """Compute and log the training loss; periodically log a sample."""
        batch_output: dict = self.step(batch)
        self.log("train/loss", batch_output["loss"], on_step=True, on_epoch=False)
        if self.logger and self._should_plot(batch_idx):
            self.get_first_and_plot(
                batch=batch_output,
                name="train_epoch_end",
            )

        return batch_output["loss"]

    def validation_step(
        self,
        batch: Any,
        batch_idx: int,
    ):
        """Compute the loss on a validation batch and log it as "val/loss"."""
        batch_output: dict = self.step(batch)
        self.log("val/loss", batch_output["loss"])
        return None

    def test_step(
        self,
        batch: Any,
        batch_idx: int,
    ):
        """Update the MAE/MSE metrics with log-magnitude spectra."""
        batch_output: dict = self.step(batch)

        # The small offset keeps the logarithm finite for zero magnitudes.
        log_pred = torch.log(batch_output["p_mag_ffts"] + 1e-10)
        log_target = torch.log(batch_output["mag_ffts"] + 1e-10)

        self.mae(log_pred, log_target)
        self.mse(log_pred, log_target)
        return None

    def on_test_epoch_end(
        self,
    ):
        """Log the accumulated test metrics and clear their state."""
        self.log("test/mae", self.mae.compute())
        self.log("test/mse", self.mse.compute())
        # Only computed values are logged (not the metric objects), so the
        # accumulated state has to be reset by hand before the next test run.
        self.mae.reset()
        self.mse.reset()
        return None

    def configure_optimizers(
        self,
    ):
        """Build the optimizer and a scheduler that is stepped every batch."""
        optimizer = self.optimizer(self.parameters())
        lr_scheduler_config = {
            "scheduler": self.scheduler(optimizer),
            "monitor": "train/loss",
            "frequency": 1,
            "interval": "step",
        }

        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler_config,
        }


Smoke test: run a single training batch and a single validation batch end to end.

In [None]:
# | eval: false

from dataclasses import dataclass

from hydra.utils import instantiate
from lightning.pytorch import loggers
from omegaconf import DictConfig, OmegaConf

from neuralresonator.data import MultiShapeMultiMaterialDataModule
from neuralresonator.models import FC

# Placeholder for dataset keyword arguments; nothing below reads it.
dataset_args = {}

# The same index map is used for the train, validation and test splits.
datamodule = MultiShapeMultiMaterialDataModule(
    train_index_map_path="data/index_map.csv",
    val_index_map_path="data/index_map.csv",
    test_index_map_path="data/index_map.csv",
)

# Hydra-style config: `_target_` names the class to build, `_partial_` makes
# hydra return a partially-applied constructor instead of an instance.
cfg = OmegaConf.create(
    dict(
        _target_="neuralresonator.training.MultiShapeMultiMaterialLitModule",
        model=dict(
            _target_="neuralresonator.models.CoefficientsFC",
            input_size=1007,
            hidden_sizes=[1024] * 6,
            n_parallel=32,
            n_biquads=2,
        ),
        criterion=dict(
            _target_="neuralresonator.utilities.FFTLoss",
        ),
        optimizer=dict(
            _target_="torch.optim.Adam",
            _partial_=True,
            lr=0.0001,
        ),
        scheduler=dict(
            _target_="torch.optim.lr_scheduler.ExponentialLR",
            _partial_=True,
            gamma=0.999,
            verbose=True,
        ),
    )
)

model = instantiate(cfg)
logger = loggers.WandbLogger(project="neuralresonator")

# One epoch with a single training batch and a single validation batch.
trainer = pl.Trainer(
    max_epochs=1,
    limit_train_batches=1,
    limit_val_batches=1,
    logger=logger,
)

trainer.fit(model=model, datamodule=datamodule)


  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type              | Params
------------------------------------------------
0 | model     | CoefficientsFC    | 7.7 M 
1 | criterion | FFTLoss           | 0     
2 | encoder   | EfficientNet      | 5.3 M 
3 | mse       | MeanSquaredError  | 0     
4 | mae       | MeanAbsoluteError | 0     
------------------------------------------------
12.9 M    Trainable params
0         Non-trainable params
12.9 M    Total params
51.785    Total estimated model params size (MB)


Adjusting learning rate of group 0 to 1.0000e-04.


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


In [None]:
# | eval: false

# Round-trip the trained module through a checkpoint file.
print(f"Model hparams: {model.hparams}")

# Write the trainer state to disk ...
checkpoint_path = "checkpoint.ckpt"
trainer.save_checkpoint(checkpoint_path)

# ... and rebuild the module from the file that was just written.
model = MultiShapeMultiMaterialLitModule.load_from_checkpoint(checkpoint_path)


Model hparams: "criterion": FFTLoss()
"model":     CoefficientsFC(
  (fc): FC(
    (activation): LeakyReLU(negative_slope=0.2, inplace=True)
    (network): Sequential(
      (0): FCBlock(
        (activation): LeakyReLU(negative_slope=0.2, inplace=True)
        (linear): Linear(in_features=1007, out_features=1024, bias=True)
        (ln): Identity()
      )
      (1): FCBlock(
        (activation): LeakyReLU(negative_slope=0.2, inplace=True)
        (linear): Linear(in_features=1024, out_features=1024, bias=True)
        (ln): Identity()
      )
      (2): FCBlock(
        (activation): LeakyReLU(negative_slope=0.2, inplace=True)
        (linear): Linear(in_features=1024, out_features=1024, bias=True)
        (ln): Identity()
      )
      (3): FCBlock(
        (activation): LeakyReLU(negative_slope=0.2, inplace=True)
        (linear): Linear(in_features=1024, out_features=1024, bias=True)
        (ln): Identity()
      )
      (4): FCBlock(
        (activation): LeakyReLU(negative_slo

  rank_zero_warn(
  rank_zero_warn(
