## Imports

In [None]:

%load_ext autoreload
%autoreload 2

import os
import sys
import pickle
import numpy as np
import pprint as pp
import pysindy as ps
from scipy import fft
from pathlib import Path
from scipy.integrate import solve_ivp
from sklearn.preprocessing import MinMaxScaler

# Silence every UserWarning (and subclasses) for the whole kernel session.
# Originally meant to hide matplotlib deprecation noise, but the filter is global.
import warnings
warnings.filterwarnings(action="ignore", category=UserWarning)

# Seed numpy's legacy global RNG so np.random.* calls (e.g. the dummy library data) repeat.
np.random.seed(100)

# Make the project package `src` importable: its root is three directory levels
# above the notebook's current working directory.
repo_root = Path(os.path.abspath('')).parent.parent.parent.absolute()
sys.path.insert(0, str(repo_root))

from src import helpers, plot_data, global_config, datasets
config = global_config.config
# Keep the original top_dir; later cells derive dataset-specific paths from it.
image_dir_og = config.top_dir



## Load data (and smooth it)

In [None]:
# Hyperparameters

# Initial condition parameters
n_avg = 5            # number of curves averaged by the moving average
u_true_cutoff = 150  # exclusive end index of the MNIST propagated wave kept for fitting

# IVP parameters: placeholders only, assigned real values in later cells
dx = x = t = dt = None


In [None]:

# Load the data
dataset = "MNIST"
config.top_dir = str(Path(image_dir_og).parent / "fft_images" / dataset)
pkl_path = Path(config.top_dir) / "pkl" / f"{dataset.lower()}_og_integral.pkl"
# pickle.load can execute arbitrary code: only load trusted, locally generated files.
# Use distinct names for the path and the handle (the original reused `file` for both).
with open(pkl_path, 'rb') as pkl_file:
    d = pickle.load(pkl_file)

# Smooth the integrals with a moving average over n_avg curves along axis 0 (rows = time)
u_total = helpers.moving_average(d['ints'], n=n_avg, axis=0)
u_true = u_total[:u_true_cutoff, :]  # keep the first u_true_cutoff time samples for fitting
x = np.asarray(d['r_x'])
# Integer-valued float time axes, one step per row (dt = 1)
t_total = np.arange(u_total.shape[0], dtype=float)
t_true = np.arange(u_true.shape[0], dtype=float)
dt = 1.


In [None]:

# Plots: full smoothed data, then the cropped window used for fitting
plot_data.plot_surface(
        z=u_total, x=x, y=t_total,
        xaxis_title="x", yaxis_title="time", zaxis_title="u(t,x)",
        title='PDE Input (Full Data)',
        hovertemplate='t: %{y:0.2f}<br>x: %{x:0.2f}<br> u: %{z:0.2f}<extra></extra>',
        colorscale='agsunset'
    )

# The cropped data must be paired with its own time axis (t_true has one entry per
# row of u_true); t_total is longer than u_true and does not match its shape.
plot_data.plot_surface(
        z=u_true, x=x, y=t_true,
        xaxis_title="x", yaxis_title="time", zaxis_title="u(t,x)",
        title='PDE Input (Cropped Data)',
        hovertemplate='t: %{y:0.2f}<br>x: %{x:0.2f}<br> u: %{z:0.2f}<extra></extra>',
        colorscale='agsunset'
    )

# Plot slices of u at the first, an intermediate, and the last cropped time index
slice_idx = 30
xs = [x, x, x]
ys = [u_true[0, :], u_true[slice_idx, :], u_true[-1, :]]
labels = ["u(x,t=0)", f"u(x,t={slice_idx})", f"u(x,t={u_true.shape[0]-1})"]
title = "Time slices of u(x,t)"

plot_data.plot_line(x=xs, y=ys, label=labels, title=title)



## Rescale x, t, and u to [0, 1]

In [None]:

x.min()
x_scaled, x_min, x_max = helpers.min_max_fit(x, 0., 1.)
t_scaled, t_min, t_max = helpers.min_max_fit(t_true, 0., 1.)
u_scaled, u_min, u_max = helpers.min_max_fit(u_true, 0., 1.)

plot_data.plot_surface(
        z=u_scaled, x=x_scaled, y=t_scaled,
        xaxis_title="x", yaxis_title="time", zaxis_title="u(t,x)",
        title='PDE Input (Cropped and Scaled Data)',
        hovertemplate='t: %{y:0.2f}<br>x: %{x:0.2f}<br> u: %{z:0.2f}<extra></extra>',
        colorscale='agsunset'
    )

x_inv = helpers.min_max_fit_inv(x_scaled, x_min, x_max, 0., 1.,)
t_inv = helpers.min_max_fit_inv(t_scaled, t_min, t_max, 0., 1.,)
u_inv = helpers.min_max_fit_inv(u_scaled, u_min, u_max, 0., 1.,)

print(t_inv[-1] - t_inv[-2])

print("||u_inv - u_true||_2", np.linalg.norm(u_inv-u_true))


## Feature Library

In [None]:

# From here on work in the scaled coordinates (x, t, dx, dt are rebound here).
u = u_scaled; x = x_scaled; t = t_scaled
dx = x[1] - x[0]; dt = t[1] - t[0]

# Random data with shape (n_x, n_t, n_features), used only to enumerate library terms.
dummy_u = np.random.randn(x.shape[0], t.shape[0], 1)

# PDE library: polynomial in u up to degree 2 (no bias in the polynomial part),
# spatial derivatives of u up to fourth order, plus a constant (bias) term.
pde_lib = ps.PDELibrary(function_library=ps.PolynomialLibrary(degree=2, include_bias=False),
                        derivative_order=4, spatial_grid=x,
                        include_bias=True, is_uniform=True)

# get_feature_names() requires a fitted library. This fits `pde_lib` itself on the
# dummy data (the original `dummy_pde_lib = pde_lib` was an alias, not a copy).
# NOTE(review): assumes the grid search below refits the library on the real data.
pde_lib.fit([dummy_u])
feature_names = [helpers.modify_pde_sindy_out(feature) for feature in pde_lib.get_feature_names()]
print("Library:")
print(feature_names)


## Grid Search

In [None]:

# Grid search over sparse-regression settings for the PDE library on the scaled data.
# NOTE(review): the meaning of the trailing positional arguments is defined by
# helpers.pysindy_grid_search and is not visible here; confirm there.
search_args = (0.8, 30, 90, 1e-5)
md = helpers.pysindy_grid_search(pde_lib, u_scaled, x_scaled, t_scaled, *search_args)
print()
pp.pprint(md)


# Propagate

## Hyperparameters

In [None]:
# Initial condition parameters
u0_idx = 81  # time index of the snapshot used as the initial condition u0
# (u_true_cutoff is set in the first hyperparameter cell; the identical, unused
#  redefinition that used to be here was removed.)

# FFT and IVP parameters
L = 5    # length of the periodic spatial domain
n = 512  # number of grid points / Fourier modes
x2 = np.linspace(-L/2, L/2, n+1)
x = x2[1:n+1]      # drop the first point (-L/2) so the grid has no duplicated periodic endpoint
dx = x[1] - x[0]   # spacing of the FFT grid (= L/n); previously taken from x_scaled by mistake
k = (2.*np.pi*fft.fftfreq(n)*n/L)  # angular wavenumbers in standard FFT ordering
k2 = fft.fftshift(k)               # centered wavenumbers (not used by the solver)
# Same step as the scaled data; defined here instead of relying on `dt` from an earlier cell.
dt = t_scaled[1] - t_scaled[0]
# Extended time grid: starts at t_scaled[0] and runs to ~3x the last scaled time.
t = np.arange(t_scaled[0], t_scaled[-1]*3+dt, dt)

# PDE parameters: coefficients of the u*u_x and u**2*u_x terms in og_pde_rhs
# (presumably taken from the grid-search result above — confirm).
alpha = -11.91553
beta = 20.48636



## Interpolate onto the FFT grid and fit Gaussians

In [None]:


# Interpolate the scaled data from the x_scaled grid onto the FFT grid x.
# NOTE(review): the trailing 0 is passed straight to helpers.interpolate; its meaning is defined there.
u_interp = helpers.interpolate(u_scaled, x_scaled, x, 0)

# Fit Gaussians to the interpolated data; the one at u0_idx becomes the initial condition.
gaussian_fit = helpers.fit_gaussians(u_interp, x, 0)
means, stdevs, amps, gaussians = gaussian_fit
u0 = gaussians[u0_idx]




In [None]:

# Compare, at time index u0_idx: the scaled data, its interpolation onto the FFT grid,
# and the fitted Gaussian used as the initial condition.
xs = [x_scaled, x, x]
ys = [u_scaled[u0_idx, :], u_interp[u0_idx, :], u0]
labels = ["u_scaled", "u_interp", "u_gaussian"]  # fixed typo: was "u_gassuan"
title = f"Time slices of u(x,t={u0_idx})"

plot_data.plot_line(x=xs, y=ys, label=labels, title=title)



## Solve

In [None]:

# Initial condition transformed to Fourier space: the state vector handed to solve_ivp.
u0_fft_in = fft.fft(u0, axis=-1)

def og_pde_rhs(t, u_fft, k, alpha, beta):
    """Fourier-space right-hand side of u_t = alpha*u*u_x + beta*u**2*u_x.

    Pseudo-spectral evaluation: the derivative is taken spectrally, and the
    nonlinear products are formed in physical space before transforming back.

    Args:
        t: current time; unused, but required by solve_ivp's fun(t, y, *args) signature.
        u_fft: FFT coefficients of u, ordered to match ``k``.
        k: angular wavenumbers in FFT ordering.
        alpha: coefficient of the u*u_x term.
        beta: coefficient of the u**2*u_x term.

    Returns:
        FFT of alpha*u*u_x + beta*u**2*u_x (same length as ``u_fft``).
    """
    # Back to physical space; the imaginary parts are discarded.
    u_phys = fft.ifft(u_fft).real
    # Spectral first derivative: d/dx corresponds to multiplication by i*k.
    du_dx = fft.ifft(1j * k * u_fft).real

    quad_term = fft.fft(u_phys * du_dx)
    cubic_term = fft.fft(u_phys**2 * du_dx)
    return alpha * quad_term + beta * cubic_term

# Solve the IVP in Fourier space from t[u0_idx] to the end of the extended time grid,
# reporting the solution at every grid time from u0_idx onward.
print("t shape:", t.shape)
print("u0_idx:", u0_idx)
sol = solve_ivp(og_pde_rhs, (t[u0_idx], t[-1]), u0_fft_in,
                t_eval=t[u0_idx:], args=(k, alpha, beta))
if not sol.success:
    # On failure solve_ivp returns only the times reached so far; stop rather than
    # continue with a silently truncated solution.
    raise RuntimeError(f"solve_ivp failed: {sol.message}")

# sol.y has shape (n_modes, len(sol.t)): one Fourier-space state per column.
# ifft along axis 0 transforms each column back to physical space; .T gives (time, x).
print("sol.y shape:", sol.y.shape)
print("last t:", sol.t[-1])
print("t range:", t[u0_idx], t[-1])
usol = np.real(fft.ifft(sol.y, axis=0)).T




## Scale back to original units

In [None]:

# Undo the [0, 1] min-max scaling for the FFT grid, the extended time grid, and the
# solution, using the min/max values saved in the Rescale step.
x_inv, t_inv, u_inv = (
    helpers.min_max_fit_inv(arr, lo, hi, 0., 1.)
    for arr, lo, hi in ((x, x_min, x_max), (t, t_min, t_max), (usol, u_min, u_max))
)


## Plot

In [None]:

# Only times from u0_idx onward were integrated, so slice the time axis to match u_inv's rows.
t_solved = t_inv[u0_idx:]
plot_data.plot_surface(
    z=u_inv,
    x=x_inv,
    y=t_solved,
    xaxis_title="x",
    yaxis_title="time",
    zaxis_title="u(t,x)",
    title='PDE Solution',
    hovertemplate='t: %{y:0.2f}<br>x: %{x:0.2f}<br> u: %{z:0.2f}<extra></extra>',
    colorscale='agsunset',
)

