### Next steps:


TODO:
1. Set up the batch dimension to hold one video chunk per batch element, i.e. batch size = `n_chunks`.
1. Develop a method for measuring wave propagation speed in image coordinates, then find the flattened homography that makes this speed as constant as possible.
1. Train again with the wave equation now that the wave speed is more uniform.
1. Infer an accumulated wave speed field (squared slowness field) in the flattened coordinates, for wavefunction normalization with non-uniform propagation speed.
1. Create a pipeline that generates animations for each stage of the training process (normalization, clipping, learning), and compiles into a demo video.
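A minimal sketch of the first half of the wave-speed measurement item above, assuming the video is a NumPy array of shape `(T, H, W)` and that waves travel mainly along one image axis. The helpers `estimate_shift` and `estimate_speed_px_per_frame` are hypothetical, not part of `surfbreak`. Each frame is collapsed to a 1D intensity profile, consecutive profiles are cross-correlated over a window of integer lags, and the median best lag is taken as the speed in pixels per frame. Choosing the homography that makes this estimate uniform across the frame is not shown.

```python
import numpy as np


def estimate_shift(profile_a, profile_b, max_lag):
    """Integer lag s (pixels) such that profile_b[i + s] best matches profile_a[i]."""
    a = profile_a - profile_a.mean()
    b = profile_b - profile_b.mean()
    n = len(a)
    lags = np.arange(-max_lag, max_lag + 1)
    # Only compare the interior [max_lag, n - max_lag) so np.roll never wraps data in
    scores = [np.dot(a[max_lag:n - max_lag], np.roll(b, -lag)[max_lag:n - max_lag])
              for lag in lags]
    return int(lags[int(np.argmax(scores))])


def estimate_speed_px_per_frame(frames, axis=1, max_lag=10):
    """Median frame-to-frame displacement along `axis` (1 = rows, 2 = columns) of a (T, H, W) video."""
    # Average over the other spatial axis to get one 1D profile per frame
    other = 2 if axis == 1 else 1
    profiles = frames.mean(axis=other)
    shifts = [estimate_shift(profiles[t], profiles[t + 1], max_lag)
              for t in range(len(profiles) - 1)]
    # The median is robust to occasional frames where the correlation peak is wrong
    return float(np.median(shifts))
```

Running the estimate on separate strips of the frame (for example near and far from the camera) would show how much the apparent speed varies before flattening; multiplying by the frame rate converts pixels per frame to pixels per second.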

#### First get the basic PyTorch Lightning training loop working

In [None]:
%load_ext autoreload
%autoreload 2
import torch
import pytorch_lightning as pl
from surfbreak.waveform_models import LitSirenNet

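tb_logger = pl.loggers.TensorBoardLogger('logs/', name="basic")  # TensorBoard runs are written under logs/basic/

trainer = pl.Trainer(
    logger=tb_logger,
    limit_val_batches=1,  # run a single validation batch per validation loop
    max_epochs=10,
    gpus=1 if torch.cuda.is_available() else None,  # use one GPU when available
)

# grad_loss_scale=0 and wave_loss_scale=0 put zero weight on the gradient and
# wave-equation loss terms, so this run trains on the MSE loss alone.
model = LitSirenNet(hidden_features=256, hidden_layers=3, first_omega_0=1., hidden_omega_0=5., squared_slowness=3.0,
                    steps_per_vid_chunk=150, learning_rate=1e-4, grad_loss_scale=0, wave_loss_scale=0)
trainer.fit(model)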

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


KeyboardInterrupt: 
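The cell above sets `wave_loss_scale=0`, so the wave-equation term plays no part in that run. As a sketch of what such a term could look like, assume `squared_slowness` is `s2 = 1/c**2` for wave speed `c`, so that a field `u(t, x, y)` obeying the wave equation satisfies `s2 * u_tt - (u_xx + u_yy) = 0`. The code below evaluates that residual with central finite differences on a regular grid and averages its square. This is an assumption about the loss, not the `LitSirenNet` implementation, which presumably differentiates the network output with autograd instead of on a grid; `wave_residual` and `wave_loss` are hypothetical names.

```python
import numpy as np


def wave_residual(u, squared_slowness, dt=1.0, dx=1.0):
    """Residual s2 * u_tt - (u_xx + u_yy) on the interior of a (T, H, W) grid.

    Uses second-order central differences; dx is the grid spacing in both y and x.
    """
    c = u[1:-1, 1:-1, 1:-1]  # interior points
    u_tt = (u[2:, 1:-1, 1:-1] - 2 * c + u[:-2, 1:-1, 1:-1]) / dt**2
    u_yy = (u[1:-1, 2:, 1:-1] - 2 * c + u[1:-1, :-2, 1:-1]) / dx**2
    u_xx = (u[1:-1, 1:-1, 2:] - 2 * c + u[1:-1, 1:-1, :-2]) / dx**2
    return squared_slowness * u_tt - (u_xx + u_yy)


def wave_loss(u, squared_slowness, dt=1.0, dx=1.0):
    """Mean squared wave-equation residual (to be weighted by something like wave_loss_scale)."""
    return float(np.mean(wave_residual(u, squared_slowness, dt, dx) ** 2))
```

As a check, a plane wave `sin(k*y - omega*t)` with `k = 2` and `omega = 1` gives a near-zero loss with `squared_slowness = 4` (that is, `(k/omega)**2`) and a clearly larger loss with, say, `squared_slowness = 1`.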

### Now get Optuna optimization trials working
The most well-regularized hyperparameters for simple training (`mse_loss` only) turned out to be 256 hidden features, `first_omega_0` = 3.7995, and `hidden_omega_0` = 2.9312.
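The model summaries in this notebook report a `Siren` module with 198 K parameters and outputs shaped `[[1, 1337, 1], [1, 1337, 3]]`, which appears to follow the `Siren` class from the SIREN reference notebook (Sitzmann et al., 2020); that class also takes `first_omega_0` and `hidden_omega_0`. In that design each sine layer computes `sin(omega_0 * (x @ W + b))`, so `omega_0` controls how easily the network represents high frequencies. The reference default is 30, so values near 3, as found here, tend to bias the fit toward smoother fields. Below is a NumPy sketch of that architecture and its initialization, assuming `LitSirenNet` wraps it with a linear output layer; `init_siren`, `siren_forward`, and `count_params` are hypothetical helpers.

```python
import numpy as np


def init_siren(in_features=3, hidden_features=256, hidden_layers=3, out_features=1,
               first_omega_0=30.0, hidden_omega_0=30.0, seed=0):
    """Return a list of (W, b, omega_0) tuples; omega_0=None marks the linear output layer."""
    rng = np.random.default_rng(seed)

    def bias(n_in, n_out):
        # Same range as PyTorch's default nn.Linear bias initialization
        bound = 1.0 / np.sqrt(n_in)
        return rng.uniform(-bound, bound, size=n_out)

    layers = []
    # First sine layer: W ~ U(-1/n_in, 1/n_in)
    W = rng.uniform(-1.0 / in_features, 1.0 / in_features, size=(in_features, hidden_features))
    layers.append((W, bias(in_features, hidden_features), first_omega_0))

    # Hidden sine layers and the linear output: W ~ U(-sqrt(6/n_in)/omega_0, sqrt(6/n_in)/omega_0)
    bound = np.sqrt(6.0 / hidden_features) / hidden_omega_0
    for _ in range(hidden_layers):
        W = rng.uniform(-bound, bound, size=(hidden_features, hidden_features))
        layers.append((W, bias(hidden_features, hidden_features), hidden_omega_0))

    W = rng.uniform(-bound, bound, size=(hidden_features, out_features))
    layers.append((W, bias(hidden_features, out_features), None))
    return layers


def siren_forward(layers, coords):
    """coords: (N, in_features) array, e.g. three space-time coordinates scaled to [-1, 1]."""
    h = coords
    for W, b, omega_0 in layers:
        z = h @ W + b
        # Sine layers compute sin(omega_0 * (xW + b)); the final layer stays linear
        h = z if omega_0 is None else np.sin(omega_0 * z)
    return h


def count_params(layers):
    return sum(W.size + b.size for W, b, _ in layers)
```

With 3 input coordinates, 256 hidden features, 3 hidden layers, and 1 output, `count_params(init_siren())` gives 198,657, consistent with the 198 K reported in the model summaries.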

In [None]:
import os
import torch
import pytorch_lightning as pl
from surfbreak.waveform_models import LitSirenNet
from optuna.integration import PyTorchLightningPruningCallback
from surfbreak.studies import run_waveform_hyperparam_search, MetricsCallback
LOGDIR = 'logs'
MODELDIR = os.path.join(LOGDIR, 'opt_models')

def objective(trial):
    checkpoint_callback = pl.callbacks.ModelCheckpoint( # Filenames for each trial must be made unique
        os.path.join(MODELDIR, "trial_{}".format(trial.number), "{epoch}"), monitor="val_loss")
    tb_logger = pl.loggers.TensorBoardLogger(LOGDIR+'/', name="optuna")
    metrics_callback = MetricsCallback()     # Simple callback that saves metrics from each validation step.

    trainer = pl.Trainer(logger=tb_logger,
        limit_val_batches=1,
        checkpoint_callback=checkpoint_callback,
        max_epochs=20,
        gpus=1 if torch.cuda.is_available() else None,
        callbacks=[metrics_callback],
        early_stop_callback=PyTorchLightningPruningCallback(trial, monitor="val_loss"),  # reports val_loss to Optuna, which may prune the trial
        )

    model = LitSirenNet(hidden_features=256, # trial.suggest_categorical('hidden_features', [128, 256]),
                        hidden_layers=3,
                        first_omega_0=trial.suggest_loguniform('first_omega_0', 0.1, 100.),
                        hidden_omega_0=trial.suggest_loguniform('hidden_omega_0', 0.1, 100.), 
                        squared_slowness=trial.suggest_loguniform('squared_slowness',0.03, 30),
                        steps_per_vid_chunk=150, 
                        learning_rate=1e-4,
                        grad_loss_scale=0, 
                        wave_loss_scale=trial.suggest_loguniform('wave_loss_scale', 1e-8, 1e-4),
                        ) 
    trainer.fit(model)
    return metrics_callback.metrics[-1]["val_loss"].item()  # objective: val_loss from the last validation run


study = run_waveform_hyperparam_search(objective, n_trials=100, timeout=8*60*60, model_dir=MODELDIR, prune=True, n_startup_trials=3)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type  | Params | In sizes     | Out sizes                   
------------------------------------------------------------------------------
0 | model | Siren | 198 K  | [1, 1337, 3] | [[1, 1337, 1], [1, 1337, 3]]

The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.



[I 2020-07-01 20:30:39,018] Finished trial#0 with value: 0.12161409854888916 with parameters: {'first_omega_0': 0.31266756217719327, 'hidden_omega_0': 69.10042088702289, 'squared_slowness': 0.07196600943612112, 'wave_loss_scale': 3.175252258600564e-08}. Best is trial#0 with value: 0.12161409854888916.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type  | Params | In sizes     | Out sizes                   
------------------------------------------------------------------------------
0 | model | Siren | 198 K  | [1, 1337, 3] | [[1, 1337, 1], [1, 1337, 3]]


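The search above is run with `prune=True, n_startup_trials=3`, and each trial's `PyTorchLightningPruningCallback` reports `val_loss` to Optuna after validation so that unpromising trials can be stopped early. `run_waveform_hyperparam_search` is not shown here; assuming it uses a median-style pruner such as Optuna's `MedianPruner(n_startup_trials=3)`, the decision works roughly like the simplified rule below: no pruning until `n_startup_trials` trials have finished, then a trial is stopped when its loss at a given step is worse than the median of the finished trials at that step. `should_prune` is a hypothetical illustration, not Optuna's exact implementation (which also supports warm-up steps and other options).

```python
from statistics import median


def should_prune(step, value, completed_curves, n_startup_trials=3):
    """Simplified median-stopping rule for a minimized metric such as val_loss.

    step: index of the current validation step (e.g. epoch)
    value: the running trial's metric at that step
    completed_curves: one {step: metric} dict per finished trial
    """
    # Let the first few trials run to completion to build a baseline
    if len(completed_curves) < n_startup_trials:
        return False
    peers = [curve[step] for curve in completed_curves if step in curve]
    if not peers:
        return False
    # Stop the trial if it is doing worse than the median finished trial at this step
    return value > median(peers)
```

For example, with three finished trials whose losses at step 1 were 0.3, 0.2, and 0.25, a running trial at 0.4 would be pruned (median 0.25), while one at 0.1 would continue.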