## Imports

In [None]:


# IPython magics: automatically reload imported modules (such as the local
# `src` package imported below) before each cell executes, so edits to
# those modules take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

import os
import sys
import pickle
import numpy as np
import pprint as pp
import pysindy as ps
from pathlib import Path
from sklearn.preprocessing import MinMaxScaler

# Silence every warning in the UserWarning category for the rest of the
# session (intended to hide noisy matplotlib warnings).
# NOTE(review): this is broader than "deprecation warnings", and on
# matplotlib >= 3.6 MatplotlibDeprecationWarning derives from
# DeprecationWarning, so this filter may not catch it -- confirm intent.
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

# Seed NumPy's global RNG (np.random.*) so results are reproducible.
np.random.seed(100)

# Make the project root importable so the local `src` package can be found.
# The root is taken to be three directories above the current working
# directory (presumably the notebook's own directory -- confirm if this is
# run from elsewhere).
mypkg_path = str(Path(os.path.abspath('')).parent.parent.parent.absolute())
print(mypkg_path)
# Remove any copies left by earlier runs of this cell so sys.path does not
# keep growing, then put the root first so it takes import precedence.
while mypkg_path in sys.path:
    sys.path.remove(mypkg_path)
sys.path.insert(0, mypkg_path)

from src import helpers, plot_data, global_config, datasets
config = global_config.config
# Keep the untouched top_dir. The data-loading cell overwrites
# config.top_dir but derives the new value from this saved copy, so
# re-running it does not compound the path.
image_dir_og = config.top_dir


## Load the saved PDE surface and refit the SINDy models

In [None]:

# Point the shared config at this dataset's FFT-image directory. The value is
# derived from the saved original top_dir (not the current config.top_dir),
# so re-running this cell does not compound the path.
dataset = "MNIST"
config.top_dir = str(Path(image_dir_og).parent / "fft_images" / dataset)

# Load the saved PDE surface. It holds the shared inputs ("pde_lib",
# "u_train", "dt") plus one hyperparameter entry per model refit below.
# NOTE(review): pickle.load can execute arbitrary code; only load files this
# project produced.
print("pde_surface_dict.pkl:")
pkl_path = Path(config.top_dir) / "pkl" / "pde_surface_dict.pkl"
with open(pkl_path, 'rb') as pkl_file:
    pde_surface_dict = pickle.load(pkl_file)

pp.pprint(pde_surface_dict.keys())

# Shared inputs reused by every model fit below.
pde_lib = pde_surface_dict["pde_lib"]
u_train = pde_surface_dict["u_train"]
print(u_train.shape)
print(type(u_train))
dt = pde_surface_dict["dt"]

# Keys that are shared inputs rather than per-model hyperparameter entries.
SHARED_KEYS = {"pde_lib", "u_train", "dt"}

# Remaining keys are presumably the number of terms in each identified
# PDE -- confirm against the code that wrote the pickle.
for n_terms, surface_entry in pde_surface_dict.items():
    if n_terms in SHARED_KEYS:
        continue

    # Each entry unpacks as (threshold, alpha, <third value, unused here>).
    threshold, alpha, _ = surface_entry
    # Refit with the stored STLSQ hyperparameters and print the identified
    # model.
    optimizer = ps.STLSQ(threshold=threshold, alpha=alpha, normalize_columns=True)
    model = ps.SINDy(feature_library=pde_lib, optimizer=optimizer)
    model.fit(u_train, t=dt)
    model.print()
