### Next steps:

This Catalyst library is a bit of a letdown. Switch to PyTorch Lightning instead: https://github.com/PyTorchLightning/pytorch-lightning

Set up hyperparameter sweeps (running over sequences of parameter values) using Optuna:
https://github.com/optuna/optuna

Here's an example script that integrates the two: https://github.com/optuna/optuna/blob/master/examples/pytorch_lightning_simple.py

#### First get the basic PyTorch Lightning training loop working

In [None]:
import os
import pkg_resources
import shutil

import matplotlib.pyplot as plt

import pytorch_lightning as pl
from pytorch_lightning import Callback
import torch
import torch.nn.functional as F
from torch.optim import Adam
import torch.utils.data

import optuna
from optuna.integration import PyTorchLightningPruningCallback

from surfbreak.loss_functions import wave_pml 
from surfbreak.datasets import WaveformVideoDataset, WaveformChunkDataset
import explore_siren as siren

# --- Training-loop settings shared by the plain run and the Optuna trials ---
BATCHSIZE = 1
EPOCHS = 5
LEARNING_RATE = 1e-4
# Length of the single-tensor resampled training dataset (passed as steps_per_video_chunk).
STEPS_PER_VID_CHUNK = 150

# --- Output locations (resolved against the working directory at import time) ---
DIR = os.path.join(os.getcwd(), 'logs')
MODEL_DIR = os.path.join(DIR, "result")  # per-trial checkpoints live here

        
class LitSirenNet(pl.LightningModule):
    """LightningModule that fits a SIREN network to waveform video data.

    The wrapped ``siren.Siren`` takes 3 input features per point and returns a pair:
    a 1-channel prediction and a second tensor that is unused here. Training regresses
    the prediction against sampled ``wavefront_values`` with MSE. Validation evaluates
    every coordinate of a short video window, reports the MSE as ``val_loss`` and logs
    a predicted-vs-ground-truth figure to TensorBoard.

    Args:
        hidden_features, hidden_layers, first_omega_0, hidden_omega_0, squared_slowness:
            forwarded unchanged to ``siren.Siren``.
        learning_rate: Adam learning rate. ``None`` (the default) uses the module-level
            ``LEARNING_RATE`` constant.
    """

    def __init__(self, hidden_features=128, hidden_layers=3, first_omega_0=3, hidden_omega_0=0.3,
                 squared_slowness=3.0, learning_rate=None):
        super().__init__()
        self.save_hyperparameters()
        # Resolve the default lazily so existing callers keep the module-level learning rate.
        self.learning_rate = LEARNING_RATE if learning_rate is None else learning_rate
        self.model = siren.Siren(in_features=3,
                                 out_features=1,
                                 hidden_features=hidden_features,
                                 hidden_layers=hidden_layers, outermost_linear=True,
                                 first_omega_0=first_omega_0,
                                 hidden_omega_0=hidden_omega_0,
                                 squared_slowness=squared_slowness)

        # Dummy input (1, 1337, 3) that Lightning uses to fill the In/Out sizes
        # columns of the model summary.
        self.example_input_array = torch.ones(1, 1337, 3)

    def forward(self, data):
        """Run the SIREN on ``data`` and return whatever ``siren.Siren`` returns."""
        return self.model(data)

    def training_step(self, batch, batch_nb):
        """Return the MSE between the predicted and sampled ground-truth wavefront values."""
        model_input, ground_truth = batch
        wf_values_out, _ = self.model(model_input['coords'])
        loss = F.mse_loss(wf_values_out, ground_truth['wavefront_values'])
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb):
        """Compute the full-window MSE and log a figure of slices along the first axis.

        Row 0 of the figure shows the model output and row 1 the ground truth, at
        indices 0, 5, 10, 15 and 20 of the first axis of the reshaped tensors.
        """
        model_input, ground_truth = batch
        wf_values_out, _ = self.model(model_input['all_coords'])
        loss = F.mse_loss(wf_values_out, ground_truth['all_wavefront_values'])

        # Reshape the flat values back to the full tensor shape for plotting. detach() makes
        # sure matplotlib can convert them to numpy even if the output still carries grad.
        full_shape = ground_truth['full_tensor_shape']
        wf_gt_txy = ground_truth['all_wavefront_values'].detach().reshape(full_shape).cpu()
        wf_out_txy = wf_values_out.detach().reshape(full_shape).cpu()

        fig, axes = plt.subplots(ncols=5, nrows=2, figsize=(10, 6), sharey=True)
        for col in range(5):
            t_idx = 5 * col
            axes[0][col].imshow(wf_out_txy[t_idx, :, :].T)
            axes[1][col].imshow(wf_gt_txy[t_idx, :, :].T)
        fig.tight_layout()
        # close=True releases the figure after logging so figures do not pile up across epochs.
        self.logger.experiment.add_figure('val_xyslice', fig, self.current_epoch, close=True)
        # Scalar logs returned from individual validation steps are not sent to TensorBoard;
        # they are aggregated and logged in validation_epoch_end() instead.
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        """Average the per-step ``val_loss`` values.

        The result is returned as ``avg_val_loss``, which the Optuna objective and the
        pruning callback read. It is also sent to TensorBoard via the ``'log'`` key.
        """
        avg_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
        tensorboard_logs = {'avg_val_loss': avg_loss}
        return {"avg_val_loss": avg_loss, "log": tensorboard_logs}

    def configure_optimizers(self):
        """Optimize all SIREN parameters with Adam at ``self.learning_rate``."""
        return Adam(self.model.parameters(), lr=self.learning_rate)

    def setup(self, stage):
        """Build the training and validation datasets. ``stage`` is ignored."""
        # Train on a dataset consisting of 30-second chunks offset by 30 seconds
        self.wf_train_video_dataset = WaveformVideoDataset(ydim=120, xrange=(30, 91), timerange=(0, 61), time_chunk_duration_s=30,
                                                           time_chunk_stride_s=30, time_axis_scale=0.5)
        self.wf_train_chunk_dataset = WaveformChunkDataset(self.wf_train_video_dataset, xy_bucket_sidelen=20, samples_per_xy_bucket=100,
                                                           time_sample_interval=5, steps_per_video_chunk=STEPS_PER_VID_CHUNK)
        # Validate on a dataset centered on the gap between the two training video chunks. Evaluate the MSE in this center area.
        # Having the same center timepoint will ensure the centered time representations are aligned between training and validation
        self.wf_valid_video_dataset = WaveformVideoDataset(ydim=120, xrange=(30, 91), timerange=(25, 36), time_chunk_duration_s=10,
                                                           time_chunk_stride_s=10, time_axis_scale=0.5)

    def train_dataloader(self):
        """Shuffled loader over the resampled training chunks."""
        return torch.utils.data.DataLoader(self.wf_train_chunk_dataset, batch_size=BATCHSIZE, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the validation video window, one item per batch."""
        return torch.utils.data.DataLoader(self.wf_valid_video_dataset, batch_size=1, shuffle=False)


In [None]:
# Plain (non-Optuna) training run with default hyperparameters,
# logged to TensorBoard under logs/basic_tests.
use_gpu = torch.cuda.is_available()
tb_logger = pl.loggers.TensorBoardLogger('logs/', name="basic_tests")

trainer = pl.Trainer(
    logger=tb_logger,
    limit_val_batches=1,        # a single validation batch per epoch
    max_epochs=EPOCHS,
    gpus=1 if use_gpu else None,
)

model = LitSirenNet()
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type  | Params | In sizes     | Out sizes                   
------------------------------------------------------------------------------
0 | model | Siren | 50 K   | [1, 1337, 3] | [[1, 1337, 1], [1, 1337, 3]]


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1

### Now get Optuna optimization trials working

In [None]:
class MetricsCallback(Callback):
    """PyTorch Lightning callback that records the trainer's metrics after each validation run.

    ``metrics`` holds one entry per validation run, oldest first. Each entry is a shallow
    copy of ``trainer.callback_metrics`` taken at that moment.
    """

    def __init__(self):
        super().__init__()
        self.metrics = []

    def on_validation_end(self, trainer, pl_module):
        # Store a snapshot rather than a reference. If the trainer updates the same dict
        # in place, every history entry would otherwise alias the latest values.
        self.metrics.append(dict(trainer.callback_metrics))

def objective(trial):
    """Train one ``LitSirenNet`` with Optuna-suggested hyperparameters and return its score.

    Suggested: ``hidden_features`` from {128, 256}, and log-uniform ``first_omega_0`` and
    ``hidden_omega_0`` in [0.01, 100]. ``hidden_layers`` (3) and ``squared_slowness`` (3.0)
    are fixed. A ``PyTorchLightningPruningCallback`` watching ``avg_val_loss`` may stop
    the trial early.

    Args:
        trial: the Optuna trial supplied by ``study.optimize``.

    Returns:
        The ``avg_val_loss`` from the last validation run, as a Python float (lower is better).
    """
    # Filenames for each trial must be made unique in order to access each checkpoint.
    # Checkpoints track the same validation metric the study minimizes.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        os.path.join(MODEL_DIR, "trial_{}".format(trial.number), "{epoch}"), monitor="avg_val_loss"
    )

    # Every trial logs to TensorBoard under logs/optuna_tests.
    tb_logger = pl.loggers.TensorBoardLogger('logs/', name="optuna_tests")

    # Separately from TensorBoard, keep an in-memory record of the metrics after every
    # validation run. The last entry supplies the value returned to Optuna.
    metrics_callback = MetricsCallback()

    trainer = pl.Trainer(
        logger=tb_logger,
        limit_val_batches=1,
        checkpoint_callback=checkpoint_callback,
        max_epochs=EPOCHS,
        gpus=1 if torch.cuda.is_available() else None,
        callbacks=[metrics_callback],
        # Reports avg_val_loss to the trial so the study's pruner can stop weak trials early.
        early_stop_callback=PyTorchLightningPruningCallback(trial, monitor="avg_val_loss"),
    )

    model = LitSirenNet(
        hidden_features=trial.suggest_categorical('hidden_features', [128, 256]),
        hidden_layers=3,
        first_omega_0=trial.suggest_loguniform('first_omega_0', 0.01, 100.),
        hidden_omega_0=trial.suggest_loguniform('hidden_omega_0', 0.01, 100.),
        squared_slowness=3.0,
    )
    trainer.fit(model)

    return metrics_callback.metrics[-1]["avg_val_loss"].item()


### Optimal values from this run:
Setting both omegas (`first_omega_0` and `hidden_omega_0`) to ~0.3, with 128 hidden features, gives the best results.

In [None]:
# Median pruning; use optuna.pruners.NopPruner() instead to turn pruning off.
pruner = optuna.pruners.MedianPruner()

study = optuna.create_study(direction="minimize", pruner=pruner)
study.optimize(objective, n_trials=100, timeout=30 * 60)  # stop after 100 trials or 30 minutes

print("Number of finished trials: {}".format(len(study.trials)))

print("Best trial:")
trial = study.best_trial
print("  Value: {}".format(trial.value))

print("  Params: ")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))

# Per-trial checkpoints are only needed while the study runs. The directory may not
# exist (for example if no checkpoint was ever written), so check before deleting it.
if os.path.isdir(MODEL_DIR):
    shutil.rmtree(MODEL_DIR)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type  | Params | In sizes     | Out sizes                   
------------------------------------------------------------------------------
0 | model | Siren | 198 K  | [1, 1337, 3] | [[1, 1337, 1], [1, 1337, 3]]


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




[I 2020-06-29 22:20:12,096] Finished trial#0 with value: 1.5067410469055176 with parameters: {'hidden_features': 256, 'first_omega_0': 1.7405572714572137, 'hidden_omega_0': 0.1330713578907485}. Best is trial#0 with value: 1.5067410469055176.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type  | Params | In sizes     | Out sizes                   
------------------------------------------------------------------------------
0 | model | Siren | 198 K  | [1, 1337, 3] | [[1, 1337, 1], [1, 1337, 3]]


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




[I 2020-06-29 22:23:40,670] Finished trial#1 with value: 0.2186404913663864 with parameters: {'hidden_features': 256, 'first_omega_0': 0.9549486752596997, 'hidden_omega_0': 0.49118714942649216}. Best is trial#1 with value: 0.2186404913663864.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type  | Params | In sizes     | Out sizes                   
------------------------------------------------------------------------------
0 | model | Siren | 50 K   | [1, 1337, 3] | [[1, 1337, 1], [1, 1337, 3]]


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




[I 2020-06-29 22:27:05,305] Finished trial#2 with value: 0.10818317532539368 with parameters: {'hidden_features': 128, 'first_omega_0': 0.083083956332908, 'hidden_omega_0': 0.11347110156424996}. Best is trial#2 with value: 0.10818317532539368.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type  | Params | In sizes     | Out sizes                   
------------------------------------------------------------------------------
0 | model | Siren | 50 K   | [1, 1337, 3] | [[1, 1337, 1], [1, 1337, 3]]


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




[I 2020-06-29 22:30:29,321] Finished trial#3 with value: 0.10992781072854996 with parameters: {'hidden_features': 128, 'first_omega_0': 0.02085500125394689, 'hidden_omega_0': 20.535140125585162}. Best is trial#2 with value: 0.10818317532539368.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type  | Params | In sizes     | Out sizes                   
------------------------------------------------------------------------------
0 | model | Siren | 198 K  | [1, 1337, 3] | [[1, 1337, 1], [1, 1337, 3]]


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…