In [1]:
import argparse
import time

import jax
import jax.numpy as jnp
import numpy as np
import soundfile as sf
from IPython.display import Audio
from matplotlib import pyplot as plt
from scipy.io import loadmat, savemat

from vkplatejax.excitations import create_1d_raised_cosine
from vkplatejax.ftm import (
    PlateParameters,
    damping_term_simple,
    stiffness_term,
)
from vkplatejax.num_utils import (
    compute_coupling_matrix_numerical,
    multiresolution_eigendecomposition,
)
from vkplatejax.sv import (
    A_inv_vector,
    B_vector,
    C_vector,
    make_vk_nl_fn,
    solve_sv_vk_jax_scan,
)
from vkplatejax.excitations import create_1d_raised_cosine
from vkplatejax.ftm import (
    PlateParameters,
    damping_term_simple,
    eigenvalues_from_pde,
    evaluate_rectangular_eigenfunctions,
    plate_eigenvalues,
    plate_wavenumbers,
    stiffness_term,
)
from vkplatejax.sv import (
    A_inv_vector,
    B_vector,
    C_vector,
    solve_sv_berger_jax_scan,
    solve_sv_vk_jax_scan,
    make_tm_nl_fn,
)

In [2]:
# Simulation configuration.
# NOTE(review): h, nx, ny, levels, excitation_duration and excitation_amplitude are
# not read by any cell of this notebook (the excitation cell hard-codes its own
# duration and amplitude); they are kept so any external users of these names still work.
sampling_rate = 44100  # samples per second, used for the time step and audio playback
h = 0.004  # grid spacing in the lowest resolution
nx = 50  # number of grid points in the x direction in the lowest resolution
ny = 75  # number of grid points in the y direction in the lowest resolution
levels = 2  # number of grid refinements to perform
excitation_duration = 1.0
excitation_amplitude = 0.3
output_file = None  # None -> a name derived from n_modes is chosen in the next cell
# Limits passed to plate_wavenumbers to build the grid of candidate modes.
n_max_modes_x = 20
n_max_modes_y = 20
# Number of lowest-eigenvalue modes kept for the simulation. This is set once;
# an earlier `n_modes = 100` assignment was dead because this line overwrote it.
n_modes = 10

In [3]:
# Time step of the discrete-time simulation.
sampling_period = 1 / sampling_rate

# Fall back to a file name tagged with the number of modes when none was configured.
output_file = (
    f"benchmark_input_{n_modes:03d}_tm.mat" if output_file is None else output_file
)

# Plate material and geometry (values presumably in SI units — confirm).
# NOTE(review): rho=7850 matches steel, but steel's Young's modulus is usually
# ~2e11 Pa; E=2e12 is an order of magnitude higher — confirm this is intended.
# Ts0 is presumably an initial in-plane tension (zero here) — confirm in vkplatejax.ftm.
plate_config = dict(
    E=2e12,
    nu=0.3,
    rho=7850,
    h=5e-4,
    l1=0.2,
    l2=0.3,
    Ts0=0,
)
params = PlateParameters(**plate_config)

In [4]:
# Build the grid of candidate eigenvalues from the plate wavenumbers.
wnx, wny = plate_wavenumbers(
    n_max_modes_x,
    n_max_modes_y,
    params.l1,
    params.l2,
)
lambda_mu_grid = plate_eigenvalues(wnx, wny)
assert n_modes <= lambda_mu_grid.size, "n_modes exceeds the number of candidate modes"

# Flat indices of the n_modes smallest eigenvalues, in ascending order.
indices = np.argsort(lambda_mu_grid.ravel())[:n_modes]
# Grid rows are taken as ky and columns as kx (i.e. switched relative to the
# (kx, ky) order used below); +1 turns 0-based positions into 1-based mode numbers.
# NOTE(review): this assumes plate_eigenvalues lays out its grid as (ky, kx) — confirm.
ky_indices, kx_indices = np.unravel_index(indices, lambda_mu_grid.shape)
ky_indices, kx_indices = ky_indices + 1, kx_indices + 1
selected_indices = np.stack([kx_indices, ky_indices], axis=-1)

# Gather the selected eigenvalues with the same indices, so lambda_mu[i] always
# belongs to selected_indices[i]. The previous `.reshape(-1).sort()[:n_modes]`
# only worked for jax arrays: NumPy's `.sort()` is in-place and returns None.
lambda_mu = jnp.ravel(lambda_mu_grid)[indices]

omega_mu_squared = stiffness_term(params, lambda_mu)
omega_mu = np.sqrt(omega_mu_squared)
c = damping_term_simple(omega_mu)
print(omega_mu)

[ 860.7548 1655.2977 2648.4766 2979.536  3443.0193 4767.2573 4833.469
 5628.012  6422.555  6621.191 ]


In [5]:
# Per-mode update coefficients from vkplatejax.sv. B and C are pre-multiplied by A_inv.
# NOTE(review): the damping term is doubled before being passed in — confirm this
# matches the convention expected by A_inv_vector / C_vector.
c_doubled = c * 2
A_inv = A_inv_vector(sampling_period, c_doubled)

b_raw = B_vector(sampling_period, omega_mu_squared)
B = b_raw * A_inv

c_raw = C_vector(sampling_period, c_doubled)
C = c_raw * A_inv

In [12]:
# Strike position on the plate (presumably metres in plate coordinates — confirm).
force_position = (0.05, 0.05)
# Value of each selected mode shape at the strike position.
mode_gains_at_pos = evaluate_rectangular_eigenfunctions(
    selected_indices,
    force_position,
    params=params,
)

# Generate a 1-D raised-cosine force pulse (times presumably in seconds).
# The configured `sampling_rate` is used instead of a second hard-coded 44100,
# so the pulse and the simulation cannot disagree on the sample rate.
# NOTE(review): `excitation_duration` / `excitation_amplitude` from the config cell
# are not used; duration=2.0 and amplitude=10.0 are hard-coded here — confirm intent.
rc = create_1d_raised_cosine(
    duration=2.0,
    start_time=0.010,
    end_time=0.012,
    amplitude=10.0,
    sample_rate=sampling_rate,
)

# The modal excitation is divided by params.density and scaled by A_inv.
mode_gains_at_pos_normalised = (mode_gains_at_pos / params.density) * A_inv
# Broadcast to one row per time sample and presumably one column per mode
# (assumes rc is 1-D).
modal_excitation_normalised = rc[:, None] * mode_gains_at_pos_normalised

print(modal_excitation_normalised.max(), modal_excitation_normalised.min())

1.3095380307396015e-09 0.0


In [13]:
# Mode-shape gains at the strike position; as the cell's last expression it uses rich display.
mode_gains_at_pos

[0.35355339 0.61237244 0.5        0.70710678 0.8660254  1.
 0.61237244 0.35355339 0.61237244 0.8660254 ]


In [11]:
# Normalisation constant l1*l2/4. This equals the squared L2 norm of a
# sin*sin mode over the plate, assuming the eigenfunctions are of that form — confirm.
eigenfunction_norm = params.l1 * params.l2 / 4

# Prefactor E*h / (2*l1*l2*(1 - nu^2)), applied per mode to lambda_mu and
# divided by the eigenfunction norm.
tau_prefactor = (params.E * params.h) / (
    2 * params.l1 * params.l2 * (1 - params.nu**2)
)
tau_with_norm = tau_prefactor * lambda_mu / eigenfunction_norm

print(tau_with_norm)

[2.1758388e+14 4.1843055e+14 6.6948883e+14 7.5317493e+14 8.7033553e+14
 1.2050798e+15 1.2218171e+15 1.4226636e+15 1.6235103e+15 1.6737222e+15]


In [9]:
# Build the nonlinearity from the selected eigenvalues and tau coefficients, then jit it.
# NOTE(review): a "tm" nonlinearity is passed to the "vk" solver — confirm the
# solver is generic over nl_fn.
tm_nonlinearity = make_tm_nl_fn(lambda_mu, tau_with_norm)
nl_fn = jax.jit(tm_nonlinearity)

# Run the time-stepping solver. The first returned value is discarded; the second
# holds the modal solution used for the readout.
_discarded, modal_sol = solve_sv_vk_jax_scan(
    A_inv, B, C, modal_excitation_normalised, nl_fn=nl_fn
)

In [10]:
# Listening position on the plate (same coordinate convention as force_position).
readout_position = (0.1, 0.1)

# Use a distinct name so the force-position gains in `mode_gains_at_pos` are not
# overwritten. Re-running the earlier cell that displays them would otherwise
# silently show the readout gains instead.
readout_gains = evaluate_rectangular_eigenfunctions(
    selected_indices,
    readout_position,
    params=params,
)

# Weight each mode's trajectory by its gain at the readout point, summed over
# modes, then divide by the eigenfunction normalisation constant.
out_pos = (modal_sol @ readout_gains) / eigenfunction_norm

# Play back at the configured rate rather than a duplicated 44100 literal.
display(Audio(out_pos, rate=sampling_rate))