In [1]:
import matplotlib.pyplot as plt
import torch
import numpy as np

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
class DifferentiableKS:
    """Pseudo-spectral solver for the 1D Kuramoto-Sivashinsky equation

        u_t + u u_x + u_xx + u_xxxx = 0

    on a periodic grid, built from torch ops so it can be differentiated through.
    The stiff linear part is integrated exactly with a matrix exponential and the
    nonlinear part with a second-order exponential Runge-Kutta scheme (ETDRK2).

    NOTE(review): ``wavenumbers`` is built from ``fftfreq`` (cycles per unit
    length) without a 2*pi factor, so the spectral derivative equals (1/2pi) d/dx.
    Equivalently, the solver integrates KS in the coordinate y = 2*pi*x, i.e. the
    effective domain length is 2*pi*domain_size. Confirm this convention is intended.

    Args:
        resolution: number of grid points.
        domain_size: length of the periodic domain (see NOTE above).
        dt: timestep.
        dealiasing: if True, apply a 2/3-rule filter to the state at the start
            of every step.

    States passed to ``etrk2``/``dealias`` are real tensors whose last dimension
    is the spatial grid (length ``resolution``); any leading dimensions are
    treated as batch dimensions.
    """

    def __init__(self, resolution, domain_size, dt, dealiasing=False):
        self.resolution = resolution
        self.domain_size = domain_size
        self.dt = dt
        self.dx = domain_size / resolution
        self.dealiasing = dealiasing

        # Spectral derivative factor i*k (k from fftfreq; see class NOTE).
        self.wavenumbers = torch.fft.fftfreq(resolution, self.dx, device=device) * 1j
        # Fourier symbol of the linear operator -d2/dx2 - d4/dx4.
        self.L_mat = -(self.wavenumbers**2) - self.wavenumbers**4
        # Exact propagator of the linear part over one step.
        self.exp_lin = torch.exp(self.L_mat * dt)

        # ETDRK2 coefficients. torch.where evaluates both branches, so the
        # 0/0 at L == 0 is computed but discarded in favour of the analytic
        # limits dt and dt/2.
        self.nonlinear_coef_1 = torch.where(
            self.L_mat == 0, dt, (self.exp_lin - 1) / self.L_mat
        )
        self.nonlinear_coef_2 = torch.where(
            self.L_mat == 0,
            dt / 2,
            (self.exp_lin - 1 - self.L_mat * dt) / (dt * self.L_mat**2),
        )

    # Step function for temporal evolution, using exponential timestepping for
    # linear terms and RK2 for nonlinear terms.
    def etrk2(self, u):
        """Advance ``u`` by one timestep ``dt`` and return the new real state.

        All transforms act on the last dimension only, so a batch of states of
        shape (..., resolution) is stepped independently per sample.
        """
        if self.dealiasing:
            u = self.dealias(u)
        nonlin_current = self.calc_nonlinear(u)
        # Predictor (kept in Fourier space): exact linear step plus the
        # first-order nonlinear contribution.
        u_interm = (
            self.exp_lin * torch.fft.fft(u, dim=-1)
            + nonlin_current * self.nonlinear_coef_1
        )
        # Corrector: second-order term from the change in the nonlinearity.
        u_new = (
            u_interm
            + (self.calc_nonlinear(torch.fft.ifft(u_interm, dim=-1)) - nonlin_current)
            * self.nonlinear_coef_2
        )
        return torch.fft.ifft(u_new, dim=-1).real

    def calc_nonlinear(self, u):
        """Return the Fourier transform of the nonlinear term -u u_x = -0.5 (u^2)_x."""
        return -0.5 * self.wavenumbers * torch.fft.fft(u**2, dim=-1)

    def dealias(self, u):
        """Zero every Fourier mode with |k| >= resolution / 3 (2/3 rule).

        This removes the largest third of the wavenumbers. The mask is symmetric
        in +k / -k, so the filtered spectrum stays Hermitian and the real part
        returned is exactly the band-limited signal.
        """
        u_hat = torch.fft.fft(u, dim=-1)
        n = self.resolution
        idx = torch.arange(n, device=u_hat.device)
        # Integer |k| in FFT ordering: 0, 1, ..., n/2, ..., 2, 1.
        abs_k = torch.minimum(idx, n - idx)
        keep = (3 * abs_k < n).to(u_hat.dtype)
        return torch.fft.ifft(u_hat * keep, dim=-1).real

In [3]:
# Solver initialization: periodic grid with RESOLUTION points on a domain of
# length DOMAIN_SIZE, stepped with TIMESTEP and the dealiasing filter enabled.

RESOLUTION = 48
DOMAIN_SIZE = 7.2
TIMESTEP = 0.5

torchKS = DifferentiableKS(
    resolution=RESOLUTION,
    domain_size=DOMAIN_SIZE,
    dt=TIMESTEP,
    dealiasing=True,
)

# Initial condition on the uniform grid x_j = domain_size * j / resolution.
grid_index = torch.arange(torchKS.resolution, device=device)
x = torchKS.domain_size * grid_index / torchKS.resolution

# NOTE(review): cos(2*pi*x) has period 1, so it spans DOMAIN_SIZE = 7.2 periods
# over the domain - a non-integer count, hence this term is not periodic on the
# grid. Confirm this is intended.
phase = 2 * np.pi * x / torchKS.domain_size
u_init = [
    torch.cos(2 * x * np.pi) + 0.1 * torch.cos(phase) * (1 - 2 * torch.sin(phase)),
]

In [4]:
# Compute a sample trajectory: N_STEPS solver steps from the initial condition.
N_STEPS = 1000

# Start a fresh list instead of aliasing `u_init`: appending to the alias would
# mutate `u_init`, so re-running this cell would extend the old trajectory.
u_traj = [u_init[0]]
u_iter = u_init[0]
for _ in range(N_STEPS):
    u_iter = torchKS.etrk2(u_iter)
    u_traj.append(u_iter)


In [5]:
# NumPy copies for plotting; the trajectory is stacked along a new last axis,
# giving one column per time step.
x_np = x.cpu().numpy()
u_traj_np = torch.stack(u_traj, -1).cpu().numpy()

In [6]:
import plotly.graph_objects as go


# Space-time plot of the trajectory: rows are grid points, columns are time steps.
fig = go.Figure(
    data=go.Heatmap(z=u_traj_np, y=x_np, colorscale="Viridis", zsmooth="best")
)

fig.update_layout(
    autosize=False,
    width=1000,
    height=500,
    margin=dict(l=10, r=10, b=10, t=10),
    xaxis_title="time step",
    yaxis_title="x",
)
fig.show()


## Spectra

### Power spectral density (PSD) of u over time at fixed positions

In [7]:
from scipy.signal import welch

u_traj_np = torch.stack(u_traj, -1).cpu().numpy()
x_np = x.cpu().numpy()

# Number of probe positions, evenly spaced over the spatial grid.
N = 4
n_points = u_traj_np.shape[0]

fig = go.Figure()

for i in range(N):
    # Take the time series and its legend label from the same grid index.
    x_idx = int(i * n_points / N)
    time_series = u_traj_np[x_idx, :]
    # fs is left at its default of 1, so frequencies are in cycles per solver step.
    freqs, psd = welch(time_series)
    fig.add_trace(
        go.Scatter(x=freqs, y=psd, mode="lines", name=f"x={x_np[x_idx]:.3f}")
    )

fig.update_layout(
    autosize=False,
    width=600,
    height=300,
    margin=dict(l=10, r=10, b=10, t=10),
    template="plotly_dark",
    xaxis_title="frequency (cycles per time step)",
    yaxis_title="PSD",
)
fig.update_yaxes(type="log")
fig.show()


### Energy spectrum

In [18]:
from scipy.fft import fft, fftfreq


# Energy spectrum |FFT(u)|^2 of the spatial profile at each of the first N time steps.

u_traj_np = torch.stack(u_traj, -1).cpu().numpy()

N = 10

# Take the grid size from the data instead of re-hard-coding RESOLUTION, so this
# cell stays consistent with the solver if its resolution is changed.
n_points = u_traj_np.shape[0]
freqs = fftfreq(n_points, DOMAIN_SIZE / n_points)
n_half = n_points // 2  # keep the non-negative frequencies only

fig = go.Figure()

for i in range(N):
    space_series = u_traj_np[:, i]
    spectrum = np.abs(fft(space_series)) ** 2
    fig.add_trace(
        go.Scatter(
            x=freqs[:n_half],
            y=spectrum[:n_half],
            mode="lines",
            name=f"t={i*TIMESTEP:.3f}",
        )
    )

fig.update_layout(
    autosize=False,
    width=600,
    height=300,
    margin=dict(l=10, r=10, b=10, t=10),
    template="plotly_dark",
    xaxis_title="wavenumber (cycles per unit length)",
    yaxis_title="|u_hat(k)|^2",
)
fig.update_yaxes(type="log")
fig.show()

## Learning

Train an autoregressive model that predicts the next state of the trajectory from the $p+1$ most recent states:
$$
    u_{t+1} = f(u_t, u_{t-1}, \ldots, u_{t-p})
$$
