# IIR Optimisation Example

Finding IIR filter parameters by gradient descent

In [None]:
import math

import IPython.display as ipd
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F

from neuralresonator.dsp import IIRMethod, IIRParameters, apply_iir
from neuralresonator.modal import render_modes, Material, MATERIALS, System
from neuralresonator.utilities import save_and_display_audio


We begin by creating our target signal, here a 0.3-second-long sample from a synthetically
excited ceramic surface.

In [None]:
# Build the target signal that the filter bank will be fitted to.
T = 0.3  # duration of the target in seconds
f_s = 24000  # sample rate (passed to the renderer and to the audio writer)
N = int(T * f_s)  # number of samples in the target signal

# Number of modes kept from the modal model (used for slicing and rendering).
num_modes = 128

# Modal system built from the "ceramic" material preset.
m = MATERIALS["ceramic"]
s = System(m)

# First `num_modes` modal parameters of the system, as torch tensors.
# NOTE(review): the meaning of the argument 100 to get_mode_gains is not
# visible here — confirm against neuralresonator.modal.
target_freqs = torch.from_numpy(s.damped_frequencies[:num_modes])
target_amplitudes = torch.from_numpy(s.get_mode_gains(100)[:num_modes])
target_decays = torch.from_numpy(s.damping_coefficients[:num_modes])

# Render the truncated modal model and cast to float32 for use with torch.
target_signal = torch.from_numpy(s.render(T, f_s, truncate_modes=num_modes)).float()
plt.plot(target_signal)

In [None]:
# Listen to the target; presumably writes "target_signal.wav" and shows an
# audio player (helper lives in neuralresonator.utilities — confirm).
save_and_display_audio(target_signal.numpy(), "target_signal.wav", int(f_s))

We define a simple FFT loss function:

In [None]:
def fft_loss(
    pred_signal: torch.Tensor,
    target_signal: torch.Tensor,
    pred_is_fft: bool = False,
    lin_l1: float = 1.0,
    lin_l2: float = 0.0,
    log_l1: float = 0.0,
    log_l2: float = 0.0,
) -> torch.Tensor:
    """Weighted spectral-magnitude distance between a prediction and a target.

    The target is always a time-domain signal and is transformed with
    ``torch.fft.rfft``. The prediction is transformed the same way unless
    ``pred_is_fft`` is set, in which case it is taken to be a spectrum
    already and only its magnitude is used. No normalisation is applied to
    either spectrum.

    Args:
        pred_signal: Predicted time-domain signal, or a (possibly complex)
            spectrum when ``pred_is_fft`` is true. Its magnitude must be
            broadcastable against the target's one-sided spectrum.
        target_signal: Target time-domain signal.
        pred_is_fft: Treat ``pred_signal`` as a spectrum and skip the FFT.
        lin_l1: Weight of the mean absolute linear-magnitude error.
        lin_l2: Weight of the mean squared linear-magnitude error.
        log_l1: Weight of the mean absolute log-magnitude error.
        log_l2: Weight of the mean squared log-magnitude error.

    Returns:
        Scalar tensor holding the weighted sum of the four error terms.
    """
    # Keeps the logarithm finite for bins whose magnitude is exactly zero.
    eps = 1e-7

    if pred_is_fft:
        pred_mag = pred_signal.abs()
    else:
        pred_mag = torch.fft.rfft(pred_signal).abs()
    target_mag = torch.fft.rfft(target_signal).abs()

    # Each residual is computed once and shared by its L1 and L2 terms.
    lin_err = pred_mag - target_mag
    log_err = torch.log(pred_mag + eps) - torch.log(target_mag + eps)

    return (
        lin_err.abs().mean() * lin_l1
        + lin_err.pow(2).mean() * lin_l2
        + log_err.abs().mean() * log_l1
        + log_err.pow(2).mean() * log_l2
    )

Next we randomly generate some parameters for our hybrid filter bank. We will be optimising
these to match the modes of our target signal.

In [None]:
# Filter-bank layout: the parameter tensors below are all shaped
# (n_parallel, n_biquads).
n_parallel = 16
n_biquads = 8

# Poles: complex values with fixed magnitude atanh(0.999) and a phase drawn
# uniformly from [0, pi).
# NOTE(review): the atanh presumably pairs with a tanh-style squashing inside
# constrain_complex_pole_or_zero so the effective radius starts at 0.999 —
# confirm against neuralresonator.dsp.
poles = torch.atanh(0.999 * torch.ones(n_parallel, n_biquads)) * torch.exp(1j * torch.rand(n_parallel, n_biquads) * math.pi)
# Zeros: the leading 0.0 factor places every zero exactly at the origin; the
# random draws only determine dtype/shape (and advance the RNG).
zeros = 0.0 * torch.rand(n_parallel, n_biquads) * torch.exp(1j * torch.rand(n_parallel, n_biquads) * math.pi)
# Alternative initialisation (disabled): fully random complex poles and zeros.
# poles = torch.randn(n_parallel, n_biquads) + 1j * torch.randn(n_parallel, n_biquads)
# zeros = torch.randn(n_parallel, n_biquads) + 1j * torch.randn(n_parallel, n_biquads)

# Gains: uniform, normalised so that all entries sum to 1.
gains = torch.ones(n_parallel, n_biquads) #/ (n_parallel ** (1 / n_biquads))
gains = gains / (gains.sum())

# Mark the three tensors as the trainable leaves for the optimiser.
poles.requires_grad_(True)
zeros.requires_grad_(True)
gains.requires_grad_(True)

# Plot the initial pole/zero configuration from detached copies.
IIRParameters(poles=poles.detach(), zeros=zeros.detach(), gains=gains.detach(), constrain_zeros=False).pole_zero_plot()

Finally, we run an optimisation loop and find that the differentiable IIR filter bank is
able to fit the target signal very closely. Note that we optimise the approximate frequency
response computed with the `biquad_freqz` function (rather than running the recursive filter), but are still able to synthesise
the signal and visualise the empirical frequency response by pinging the recursive implementation with an impulse.

In [None]:
from neuralresonator.dsp import biquad_freqz, constrain_complex_pole_or_zero, pole_or_zero_to_iir_coeff

# Optimisation hyper-parameters.
steps = 10000
lr = 1e-3

# Unit impulse of length N; fed through the filter to obtain its impulse
# response when rendering audio below.
impulse = torch.zeros(N)
impulse[0] = 1.0

# AdamW over the three parameter tensors; the learning rate is halved when the
# loss plateaus (patience is counted in scheduler steps, i.e. loop iterations).
# NOTE(review): `verbose` is deprecated in recent PyTorch releases — confirm
# against the installed version.
optimizer = torch.optim.AdamW([poles, zeros, gains], lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.5, patience=2500, verbose=True
)

for step in range(steps):
    # Numerator coefficients from the zeros (scaled by the gains) and
    # denominator coefficients from the constrained poles.
    b = pole_or_zero_to_iir_coeff(zeros) * gains[..., None]
    a = pole_or_zero_to_iir_coeff(constrain_complex_pole_or_zero(poles, c=1, d=1)) 
    # Magnitude response of the whole bank: product over one axis, then sum
    # over the next. Presumably the product runs over the cascaded biquads and
    # the sum over the parallel branches — assumes biquad_freqz returns
    # (n_parallel, n_biquads, bins); TODO confirm.
    H = biquad_freqz(b, a, N).prod(dim=-2).sum(dim=-2).abs()

    # Squared linear-magnitude error plus 0.2 x squared log-magnitude error
    # against the target's spectrum; H is already a spectrum.
    loss = fft_loss(
        H,
        target_signal,
        pred_is_fft=True,
        lin_l1=0.0,
        lin_l2=1.0,
        log_l1=0.0,
        log_l2=0.2,
    )

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step(loss)

    # Lightweight progress line, overwritten in place via the carriage return.
    if step % 50 == 0:
        print(f"Step {step}: {loss.item():.8f}", end="\r")

    # Every 1000 steps and on the final step: render audio and plot diagnostics.
    if step % 1000 == 0 or step == steps - 1:
        print(f"Step {step}: {loss.item():.8f}")
        print(f"Rendering audio and plotting results...")
        # Rendering is for inspection only, so no autograd graph is needed.
        with torch.no_grad():
            # NOTE(review): the raw `poles` are passed here, whereas the loss
            # above uses constrain_complex_pole_or_zero(poles, c=1, d=1) —
            # confirm IIRParameters applies the same constraint internally.
            params = IIRParameters(
                poles=poles,
                zeros=zeros,
                gains=gains,
                constrain_zeros=False,
            )
            # Impulse response of the recursive (time-domain) implementation.
            pred_signal = apply_iir(impulse, params, IIRMethod.TDFII)

            print("Target")
            save_and_display_audio(
                target_signal.numpy(), f"target_signal_{step}.wav", int(f_s)
            )
            print("Prediction")
            save_and_display_audio(
                pred_signal.detach().cpu().numpy(), f"pred_signal_{step}.wav", int(f_s)
            )

            # Empirical spectra, scaled by 1/n ("forward" normalisation).
            pred_fft = torch.fft.rfft(pred_signal, norm="forward")
            target_fft = torch.fft.rfft(target_signal, norm="forward")

            fig, ax = plt.subplots(3, 1, figsize=(12, 8))

            # Log-magnitude spectra, plotted as 10 * log10(|X|).
            # NOTE(review): decibels of an amplitude are conventionally
            # 20 * log10(|X|) — confirm the intended scale.
            ax[0].plot(target_fft.abs().detach().cpu().log10().mul(10))
            ax[0].plot(pred_fft.abs().detach().cpu().log10().mul(10))
            ax[0].set_title("Magnitude")

            # Phase spectra; phase does not enter the loss above.
            ax[1].plot(target_fft.angle().detach().cpu())
            ax[1].plot(pred_fft.angle().detach().cpu())
            ax[1].set_title("Phase")

            # Time-domain waveforms.
            ax[2].plot(target_signal.detach().cpu())
            ax[2].plot(pred_signal.detach().cpu())
            ax[2].set_title("Signal")

            params.pole_zero_plot()

            plt.show()
