In [None]:
import os
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from catalyst import dl
from catalyst.utils import metrics

from surfbreak.loss_functions import wave_pml 
from surfbreak.datasets import WaveformVideoDataset, WaveformChunkDataset
import explore_siren as siren


# Implicit-representation network with 3 input features (presumably the
# (x, y, t) coordinates produced by the chunk dataset — TODO confirm) and a
# single scalar output.
model = siren.Siren(
    in_features=3,
    out_features=1,
    hidden_features=128,
    hidden_layers=3,
    outermost_linear=True,
    first_omega_0=1,
    dydt=0.65,
)

optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

wf_dataset = WaveformVideoDataset(
    ydim=120,
    xrange=(30, 91),
    timerange=(0, 60),
    time_chunk_duration_s=30,
    time_chunk_stride_s=15,
    time_axis_scale=0.5,
)
chunk_dataset = WaveformChunkDataset(
    wf_dataset,
    xy_bucket_sidelen=20,
    samples_per_xy_bucket=100,
    time_sample_interval=5,
    dataset_length=1000,  # This defines the length of an epoch
)

# Only a training loader is configured; batch_size=1 presumably because each
# dataset item is already a full set of sampled coordinates (TODO confirm
# against WaveformChunkDataset).
loaders = {
    "train": DataLoader(chunk_dataset, batch_size=1),
}

class WaveformChunkRunner(dl.Runner):
    """Catalyst runner that regresses wavefront values from input coordinates.

    Each batch is unpacked as ``(model_input, ground_truth)``, where
    ``model_input`` is indexed by ``'coords'`` and ``ground_truth`` by
    ``'wavefront_values'``. The model is called on the coordinates and is
    expected to return a 2-tuple whose first element is the prediction.
    The mean-squared error is reported as the ``"loss"`` batch metric and,
    on the training loader only, used for one optimizer step.
    """

    def _handle_batch(self, batch):
        inputs, targets = batch
        # Look up both fields before running the model, so a malformed batch
        # fails before any forward pass.
        xyz = inputs['coords']
        target_values = targets['wavefront_values']

        # The second element returned by the model is not used here.
        predicted, _ = self.model(xyz)

        mse = F.mse_loss(predicted, target_values)
        self.batch_metrics = {"loss": mse}

        # Validation/inference loaders only record the metric.
        if not self.is_train_loader:
            return

        self.optimizer.zero_grad()
        mse.backward()
        self.optimizer.step()


runner = WaveformChunkRunner()

# Run training with the loaders defined above. Catalyst writes its logs and
# checkpoints beneath `logdir` and selects the best checkpoint using the
# "loss" metric that `_handle_batch` reports.
runner.train(
    model=model,
    optimizer=optimizer,
    loaders=loaders,
    logdir='logs/catalyst_trials_1',
    main_metric='loss',
    num_epochs=3,
    verbose=False,
)

[2020-06-28 22:38:43,816] 
1/3 * Epoch 1 (train): loss=0.0440
[2020-06-28 22:38:43,816] 
1/3 * Epoch 1 (train): loss=0.0440
[2020-06-28 22:41:10,430] 
2/3 * Epoch 2 (train): loss=0.0246
[2020-06-28 22:41:10,430] 
2/3 * Epoch 2 (train): loss=0.0246
[2020-06-28 22:43:37,155] 
3/3 * Epoch 3 (train): loss=0.0219
[2020-06-28 22:43:37,155] 
3/3 * Epoch 3 (train): loss=0.0219
Top best models:
logs/catalyst_trials_1/checkpoints/train.3.pth	0.0219


### Next steps:

The Catalyst library turned out to be a bit of a letdown. Switch to PyTorch Lightning instead: https://github.com/PyTorchLightning/pytorch-lightning

Use Optuna to run sweeps over hyperparameter values:
https://github.com/optuna/optuna

Here is an example script that integrates the two: https://github.com/optuna/optuna/blob/master/examples/pytorch_lightning_simple.py