In [8]:
import numpy as np
import torch
from scipy.io import loadmat

# DeepMoD stuff
from deepymod_torch import DeepMoD
from deepymod_torch.model.func_approx import NN
from deepymod_torch.model.library import Library1D
from deepymod_torch.model.constraint import LeastSquares
from deepymod_torch.model.sparse_estimators import  Threshold, PDEFIND
from deepymod_torch.training import train_split_full
from deepymod_torch.training.sparsity_scheduler import TrainTestPeriodic

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Reproducibility: fix the global torch and numpy seeds, then ask cuDNN for
# deterministic kernels and switch off its benchmark mode.
torch.manual_seed(0)
np.random.seed(42)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [9]:
# --- Data preparation -------------------------------------------------------
# Read the fields 'tt', 'x' and 'uu' from the MATLAB file.
data = loadmat('kuramoto_sivishinky.mat')
t, x, u = data['tt'], data['x'], data['uu']

# Coordinate grids in 'ij' layout: x varies along axis 0, t along axis 1.
x_grid, t_grid = np.meshgrid(x, t, indexing='ij')

# Keep only the first 100 entries along axis 1 (the t axis of the grids).
# NOTE(review): assumes u is laid out like the grids, i.e. (x, t) -- confirm.
x_grid, t_grid, u = x_grid[:, :100], t_grid[:, :100], u[:, :100]

# One row per grid point: coordinates as (t, x) columns, target as a column vector.
X = np.column_stack((t_grid.ravel(), x_grid.ravel()))
y = u.reshape(-1, 1)

# Additive Gaussian noise, scaled to a fraction of the standard deviation of y.
noise_level = 0.05
number_of_samples = 25000
y_noisy = y + noise_level * np.std(y) * np.random.randn(y.shape[0], 1)

# Shuffle all rows, keep the first `number_of_samples` of them, and hand the
# subset to torch as float32 tensors on the selected device.
idx = np.random.permutation(y.shape[0])
sample_idx = idx[:number_of_samples]
X = torch.tensor(X[sample_idx], dtype=torch.float32, device=device)
y = torch.tensor(y_noisy[sample_idx], dtype=torch.float32, device=device)

In [10]:
# --- Model assembly ---------------------------------------------------------
# Function approximator.
# NOTE(review): presumably 2 inputs (t, x), seven hidden layers of 30 units
# and 1 output -- confirm against the NN signature.
hidden_layers = [30] * 7
network = NN(2, hidden_layers, 1)

# Library function, built with poly_order=1 and diff_order=4.
library = Library1D(poly_order=1, diff_order=4)

# Sparse estimator.
estimator = PDEFIND(lam=1e-4)

# How to constrain.
constraint = LeastSquares()

# Put the four components into one DeepMoD model and move it to the device.
model = DeepMoD(network, library, estimator, constraint)
model = model.to(device)

In [11]:
# Sparsity scheduler; its settings are expressed in terms of write iterations.
sparsity_scheduler = TrainTestPeriodic(
    periodicity=50,
    patience=8,
    delta=1e-5,
)

# Adam (AMSGrad variant) over all model parameters.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=2e-3,
    betas=(0.99, 0.999),
    amsgrad=True,
)

In [None]:
# Start training; logs are written under `log_dir`.
# NOTE(review): split=0.8 is presumably the train fraction of a train/test
# split -- confirm against train_split_full.
train_split_full(
    model,
    X,
    y,
    optimizer,
    sparsity_scheduler,
    log_dir='runs/noisy_0.05_bigger_net/',
    split=0.8,
    test='full',
    write_iterations=25,
    max_iterations=100000,
    delta=1e-5,
    patience=20,
)

| Iteration | Progress | Time remaining |     Loss |      MSE |      Reg |    L1 norm |
      12025     12.03%            7536s   3.16e-02   3.04e-02   1.20e-03   2.22e+00 