## Imports

In [7]:

%load_ext autoreload
%autoreload 2

import os
import sys
import torch
import einops
import random
import pickle
import numpy as np
import pprint as pp
import torch.nn as nn
from pathlib import Path
from pyfiglet import Figlet
from torch.utils.data import DataLoader
from src.plots import plot_losses, plot_field_comparison

# Add the project root directory to Python path
top_dir = Path.cwd().parent
sys.path.insert(0, str(top_dir.absolute()))

from src import *

slurm_dir = top_dir / 'slurms'
pickle_dir = top_dir / 'pickles'


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:

# Hyperparameter grid: every combination of the lists below is iterated over.
datasets = [
    "sst",
    "plasma",
    "planetswe_pod",
    "gray_scott_reaction_diffusion_pod",
]
encoders = ["lstm", "vanilla_transformer", "sindy_attention_transformer"]
decoders = ["mlp", "unet"]
# Learning-rate tags in the same format used in checkpoint names, e.g. "lr1.00e-03".
lrs = [f"lr{rate:0.2e}" for rate in (1e-2, 1e-3, 1e-4, 1e-5)]

# Depth tags; these two are meant to be zipped pairwise.
# NOTE(review): encoder_depths has 3 entries but decoder_depths has 2, so a
# plain zip() silently drops "e3" — confirm this is intended.
encoder_depths = [f"e{depth}" for depth in range(1, 4)]
decoder_depths = [f"d{depth}" for depth in range(1, 3)]


In [9]:

# Get list of results
# Collect the per-run result dictionaries stored under `pickle_dir`.
# `early_stop=None` is forwarded to the helper; its exact filtering meaning is
# defined in src.helpers (TODO: document it there).
# NOTE(review): these are pickle files — only load ones you produced yourself.
results = helpers.get_dictionaries_from_pickles(
    pickle_dir,
    early_stop=None,
)


In [10]:

# Print the results / data format
# Show the container type, then the schema of a single run's result dict.
print(type(results))
first_result = results[0]
pp.pprint(list(first_result))               # top-level keys, in insertion order
pp.pprint(first_result['hyperparameters'])  # that run's configuration dict


<class 'list'>
['test_loss',
 'start_epoch',
 'best_val',
 'best_epoch',
 'train_losses',
 'val_losses',
 'sensors',
 'hyperparameters']
{'batch_size': 128,
 'best_checkpoint_path': PosixPath('/home/alexey/Research4/T-SHRED/checkpoints/gru_unet_planetswe_full_e4_d1_lr1.00e-03_p1_model_best.pt'),
 'd_data': 3,
 'd_model': 150,
 'data_cols': 512,
 'data_rows': 256,
 'dataset': 'planetswe_full',
 'decoder': 'unet',
 'decoder_depth': 1,
 'device': 'cuda:1',
 'dim_feedforward': 400,
 'dropout': 0.1,
 'dt': 1.0,
 'early_stop': 10,
 'encoder': 'gru',
 'encoder_depth': 4,
 'epochs': 100,
 'eval_full': False,
 'generate_test_plots': True,
 'generate_training_plots': False,
 'hidden_size': 100,
 'include_sine': False,
 'latest_checkpoint_path': PosixPath('/home/alexey/Research4/T-SHRED/checkpoints/gru_unet_planetswe_full_e4_d1_lr1.00e-03_p1_model_latest.pt'),
 'lr': 0.001,
 'n_heads': 2,
 'n_sensors': 50,
 'n_well_tracks': 10,
 'output_size': 393216,
 'poly_order': 1,
 'save_every_n_epochs': 10,

In [11]:

# Fetch the top-ranked SST run and inspect what it stores.
# Ranking is done by helpers.get_top_N_models_by_loss; each entry is indexed
# with [1] to get the result dict (presumably (loss, result) pairs — TODO confirm).
top_sst = helpers.get_top_N_models_by_loss('sst', pickle_dir, N=1)
best_result = top_sst[0][1]
print(best_result.keys())
print(best_result['hyperparameters'].keys())
print(best_result['sensors'])


dict_keys(['test_loss', 'start_epoch', 'best_val', 'best_epoch', 'train_losses', 'val_losses', 'sensors', 'hyperparameters'])
dict_keys(['batch_size', 'dataset', 'decoder', 'decoder_depth', 'device', 'dropout', 'dt', 'early_stop', 'encoder', 'eval_full', 'encoder_depth', 'epochs', 'hidden_size', 'generate_test_plots', 'generate_training_plots', 'include_sine', 'lr', 'n_heads', 'n_sensors', 'n_well_tracks', 'poly_order', 'save_every_n_epochs', 'sindy_threshold', 'sindy_weight', 'skip_load_checkpoint', 'verbose', 'window_length', 'd_data', 'data_rows', 'data_cols', 'd_model', 'dim_feedforward', 'output_size', 'latest_checkpoint_path', 'best_checkpoint_path'])
[(98, 215), (10, 132), (130, 248), (103, 155), (122, 183), (149, 111), (129, 71), (72, 71), (64, 272), (154, 75), (18, 350), (84, 241), (143, 51), (90, 222), (80, 312), (141, 244), (113, 266), (140, 7), (0, 313), (126, 170), (62, 166), (145, 113), (139, 229), (23, 41), (81, 260), (125, 55), (77, 282), (74, 63), (140, 170), (138, 104

In [14]:

# Save one results scatter plot per dataset (written to figures/<dataset>_scatter.pdf,
# per the cell output). plot_model_results_scatter receives the full `results`
# list plus the dataset name, so no per-dataset model lookup is needed here.
SCATTER_DATASETS = ['planetswe_full', 'sst', 'plasma']
for dataset in SCATTER_DATASETS:
    # NOTE(review): the meaning of the positional `12` is defined by
    # src.plots.plot_model_results_scatter — consider passing it by keyword.
    plots.plot_model_results_scatter(results, dataset, 12, save=True, fname=f"{dataset}_scatter")


Saved /home/alexey/Research4/T-SHRED/figures/planetswe_full_scatter.pdf


Saved /home/alexey/Research4/T-SHRED/figures/sst_scatter.pdf


Saved /home/alexey/Research4/T-SHRED/figures/plasma_scatter.pdf
