In [1]:
import torch
import numpy as np
import plotly.graph_objects as go

from src.ks.kuramoto_sivashinsky import DifferentiableKS
from src.ks.spectra import get_welsh_psd, get_energy_spectrum

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

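## Spectra

Simulate a Kuramoto–Sivashinsky trajectory, then inspect its power spectral density and energy spectrum.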

In [3]:
# Solver initialization.
# Grid resolution, domain length, time step and number of integration steps;
# these are passed straight to DifferentiableKS / get_trajectory below.
RESOLUTION = 96
DOMAIN_SIZE = 7.2
TIMESTEP = 0.5
NUM_STEPS = 2000
# NOTE(review): assumes solver_ks.get_init() may draw random numbers — TODO confirm
# which RNG it uses. Seeding both torch and numpy keeps Restart & Run All reproducible
# either way.
SEED = 0

torch.manual_seed(SEED)
np.random.seed(SEED)

solver_ks = DifferentiableKS(
    resolution=RESOLUTION,
    domain_size=DOMAIN_SIZE,
    dt=TIMESTEP,
    dealiasing=True,
    device=device,
)

x = solver_ks.get_1d_grid()
# .detach() so the conversion also works if the differentiable solver hands back
# tensors attached to an autograd graph (Tensor.numpy() refuses those).
x_np = x.detach().cpu().numpy()

u_init = solver_ks.get_init()
u_traj = solver_ks.get_trajectory(u_init, num_steps=NUM_STEPS)
# Stack the returned sequence of states along a new trailing axis
# (presumably one state per time step).
u_traj_torch = torch.stack(u_traj, -1)
u_traj_np = u_traj_torch.detach().cpu().numpy()

### Power spectral density (PSD)

In [4]:
# Power spectral density of the simulated trajectory over the spatial grid,
# built by the project helper get_welsh_psd and rendered as a figure.
welsh_psd_fig = get_welsh_psd(
    x_np,
    u_traj_np,
    RESOLUTION,
)
welsh_psd_fig.show()

### Energy spectrum

In [5]:
# Energy spectrum of the simulated trajectory; the helper receives the same
# resolution, domain size and time step the solver was configured with.
energy_spectrum_fig = get_energy_spectrum(
    u_traj_np,
    RESOLUTION,
    DOMAIN_SIZE,
    TIMESTEP,
)
energy_spectrum_fig.show()