### Next steps:

TODO:
1. Develop a method for measuring wave propagation speed in image coordinates, then find the flattened homography that makes this speed as constant as possible.
1. Create a pipeline that generates animations for each stage of the training process (normalization, clipping, learning), and compiles into a demo video.
1. Get tensorboard metrics logging working with this logger class: https://github.com/PyTorchLightning/pytorch-lightning/issues/1228#issuecomment-622963564


#### First get the basic PyTorch Lightning training loop working

### Now get optuna optimization trials working

TODO:
1. Improve early stopping such that I have control over patience parameter and ensure the _minimum_ validation loss is passed to the optuna study.

The most well-regularized hyperparameters for the simple model (`mse_loss` only) ended up being:
256 hidden features, 3.7995 first_omega_0, 2.9312 hidden_omega_0

With the wavefunc loss, a `squared_slowness` of around 0.5 may be close to the right value, e.g.:

```
Finished trial#26 with value: 0.09781524538993835 with parameters: 
{'first_omega_0': 4.839289222946841, 'hidden_omega_0': 13.756932872278343, 'squared_slowness': 0.27488941275825124, 'wave_loss_scale': 9.252313787089657e-08}
```

In [None]:
import os
import torch
import pytorch_lightning as pl
from surfbreak.waveform_models import WaveformNet
from surfbreak.datasets import WavefrontDatasetTXYC, MaskedWavefrontBatchesNC, CachedDataset
from optuna.integration import PyTorchLightningPruningCallback
from surfbreak.studies import run_waveform_hyperparam_search, MetricsCallback

# Run from the notebook directory so the relative '../data' and '../models' paths resolve.
os.chdir('/home/erik/work/surfbreak/nbs')

# Where wandb writes run data, and where per-trial model checkpoints are stored.
LOGDIR = 'wandb'
MODELDIR = os.path.join(LOGDIR, 'opt_models')

# NOTE! Change this environment variable for each optimization experiment so its runs are grouped properly on wandb
os.environ.update(WANDB_RUN_GROUP="test1")

def objective(trial):
    """Train one WaveformNet with hyperparameters sampled from an optuna trial.

    Each trial logs to its own wandb run (``wfnet_opt_<number>``) and writes its
    checkpoints to its own directory (``MODELDIR/trial_<number>``). Intermediate
    ``val_loss`` values are reported to optuna through
    ``PyTorchLightningPruningCallback`` so unpromising trials can be pruned.

    Args:
        trial: the optuna trial used to sample the WaveformNet hyperparameters.

    Returns:
        float: the minimum ``val_loss`` recorded across this trial's validation runs.
    """
    # Filenames for each trial must be made unique so trials don't overwrite each other's checkpoints.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        os.path.join(MODELDIR, "trial_{}".format(trial.number), "{epoch}"), monitor="val_loss")

    wandb_logger = pl.loggers.wandb.WandbLogger(name="wfnet_opt_{}".format(trial.number), save_dir=LOGDIR,
                                                project='surfbreak', log_model=True)
    #tb_logger = pl.loggers.TensorBoardLogger(LOGDIR+'/', name="opt_v2")
    metrics_callback = MetricsCallback()     # Simple callback that saves metrics from each validation step.

    pl.seed_everything(42)

    training_video = '../data/shirahama_1590387334_SURF-93cm.ts'
    cnn_checkpoint = '../models/simplecnn_shirahama.ckpt'
    start_s = 60
    duration_s = 30
    max_epochs = 10
    wf_net_kwargs = dict(
        hidden_features=trial.suggest_categorical('hidden_features', [256]),
        hidden_layers=trial.suggest_categorical('hidden_layers', [3]),
        first_omega_0=trial.suggest_uniform('first_omega_0', 2.5, 4.5), #2.5,
        hidden_omega_0=trial.suggest_uniform('hidden_omega_0', 10, 15), #11,
        squared_slowness=trial.suggest_uniform('squared_slowness', 1.0, 2.0), #1.0,
        learning_rate=2e-4,
        # NOTE(review): low == high here, so this effectively pins the value at 1e-12.
        wavefunc_loss_scale=trial.suggest_loguniform('wavefunc_loss_scale', 1e-12, 1e-12), #1e-9,
        wavespeed_loss_scale=trial.suggest_loguniform('wavespeed_loss_scale', 1e-10, 1e-9), #1e-12,
        wavespeed_first_omega_0=trial.suggest_uniform('ws_fo0', 1.0, 2.0), #0.5
        wavespeed_hidden_omega_0=trial.suggest_uniform('ws_ho0', 1.0, 3.0), #2.0
        wfloss_growth_scale=trial.suggest_loguniform('wfloss_growth_scale', 1.5, 4),
    )

    # Train on 3-second chunks spaced 4 seconds apart (a 1-second gap between each).
    txy_cache = CachedDataset(WavefrontDatasetTXYC, training_video, timerange=(start_s, start_s+duration_s),
                              time_chunk_duration_s=3, time_chunk_stride_s=4,
                              wavecnn_ckpt=cnn_checkpoint)
    wf_train_dataset = MaskedWavefrontBatchesNC(txy_cache, samples_per_batch=600, included_time_fraction=1.0)

    # Validation uses contiguous 4-second chunks (stride == duration) over the same time range as
    # training, with included_time_fraction=0.5.
    # NOTE(review): the goal was to also validate on the seconds *after* the training range (ability to
    # extrapolate is desirable), but this timerange stops at start_s+duration_s -- confirm intent.
    txy_valid = CachedDataset(WavefrontDatasetTXYC, training_video, timerange=(start_s, start_s+duration_s),
                              time_chunk_duration_s=4, time_chunk_stride_s=4,
                              wavecnn_ckpt=cnn_checkpoint)
    wf_valid_dataset = MaskedWavefrontBatchesNC(txy_valid, samples_per_batch=600, included_time_fraction=0.5)

    # Visualize the last 25s of the waveform, plus the 5 seconds of validation-only data
    viz_inftxy_dataset = WavefrontDatasetTXYC(training_video, timerange=(start_s+duration_s-25, start_s+duration_s+5),
                                              time_chunk_duration_s=30, time_chunk_stride_s=30,
                                              wavecnn_ckpt=cnn_checkpoint)

    wavefunc_model = WaveformNet(train_dataset=wf_train_dataset, valid_dataset=wf_valid_dataset,
                                 viz_dataset=viz_inftxy_dataset, batch_size=100,
                                 **wf_net_kwargs)

    trainer = pl.Trainer(logger=wandb_logger, #limit_val_batches=50,
                         max_epochs=max_epochs,
                         gpus=1 if torch.cuda.is_available() else None,
                         # Must be passed in explicitly; otherwise the per-trial checkpoint
                         # directory configured above is never used.
                         checkpoint_callback=checkpoint_callback,
                         callbacks=[metrics_callback],
                         early_stop_callback=PyTorchLightningPruningCallback(trial, monitor="val_loss"),
                         )

    trainer.fit(wavefunc_model)
    # Report the best (minimum) validation loss seen during training rather than the final one, so
    # trials whose loss rises in late epochs are not penalized. This is consistent with the best
    # checkpoint that ModelCheckpoint(monitor="val_loss") keeps.
    return min(metrics["val_loss"].item() for metrics in metrics_callback.metrics)


# Launch the hyperparameter search over `objective`.
# NOTE(review): trial/timeout/pruner semantics live in surfbreak.studies; presumably the timeout is in
# seconds (12 hours) and n_startup_trials / n_warmup_steps configure an optuna pruner -- confirm there.
study = run_waveform_hyperparam_search(
    objective,
    n_trials=100,
    timeout=12 * 60 * 60,
    model_dir=MODELDIR,
    prune=True,
    n_startup_trials=3,
    n_warmup_steps=5,
)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]



  | Name           | Type  | Params | In sizes     | Out sizes                   
---------------------------------------------------------------------------------------
0 | model          | Siren | 198 K  | [1, 1337, 3] | [[1, 1337, 1], [1, 1337, 3]]
1 | slowness_model | Siren | 8 K    | ?            | ?                           


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

In [None]:
# Hyperparameters of the best trial found by the study.
study.best_trial.params

{'hidden_features': 380,
 'hidden_layers': 3,
 'first_omega_0': 2.235704528549165,
 'hidden_omega_0': 10.007569568983525,
 'squared_slowness': 0.341543868797459,
 'wavefunc_loss_scale': 5.97776850767863e-09,
 'wavespeed_loss_scale': 0.00010504655720688889}

In [None]:
# The five trials with the lowest objective value (val_loss), best first.
sdf = study.trials_dataframe()
top_trials = sdf.sort_values(by='value').head(5)
top_trials

Unnamed: 0,number,value,datetime_start,datetime_complete,duration,params_first_omega_0,params_hidden_features,params_hidden_layers,params_hidden_omega_0,params_squared_slowness,params_wavefunc_loss_scale,params_wavespeed_loss_scale,state
0,0,0.131143,2020-07-04 23:21:25.800355,2020-07-04 23:49:03.413944,00:27:37.613589,2.235705,380,3,10.00757,0.341544,5.977769e-09,0.000105,COMPLETE
11,11,0.131585,2020-07-05 04:15:30.272110,2020-07-05 04:40:23.186604,00:24:52.914494,2.496738,256,3,10.969015,0.204919,5.471891e-09,0.000435,COMPLETE
10,10,0.132419,2020-07-05 03:50:36.195951,2020-07-05 04:15:30.269835,00:24:54.073884,2.469356,256,3,10.737842,0.202324,5.536735e-09,0.000381,COMPLETE
16,16,0.137052,2020-07-05 06:19:26.508430,2020-07-05 06:44:19.426469,00:24:52.918039,1.573309,128,3,11.413517,0.301143,8.719026e-09,0.000751,COMPLETE
14,14,0.137106,2020-07-05 05:29:46.862792,2020-07-05 05:54:38.817953,00:24:51.955161,2.275557,256,3,11.61578,0.688081,9.518309e-09,0.00026,COMPLETE


In [None]:
# Per-column mean, then standard deviation, across the top trials.
for summarize in (top_trials.mean, top_trials.std):
    print(summarize())

number                                            9.5
value                                        0.136229
duration                       0 days 00:25:27.666368
params_first_omega_0                          1.95588
params_hidden_features                            255
params_hidden_layers                                3
params_hidden_omega_0                         11.3476
params_squared_slowness                      0.400915
params_wavefunc_loss_scale                8.04706e-09
params_wavespeed_loss_scale               0.000499383
dtype: object
number                                        5.12696
value                                      0.00405406
duration                       0 days 00:01:14.189584
params_first_omega_0                         0.630964
params_hidden_features                         95.253
params_hidden_layers                                0
params_hidden_omega_0                         0.97345
params_squared_slowness                      0.179524
params_wavefun