## Imports

In [6]:

%load_ext autoreload
%autoreload 2

import os
import sys
import torch
import einops
import random
import pickle
import numpy as np
import pprint as pp
import torch.nn as nn
from pathlib import Path
from pyfiglet import Figlet
from torch.utils.data import DataLoader
from src.plots import plot_losses, plot_field_comparison

# Add the project root directory to Python path
top_dir = Path.cwd().parent
sys.path.insert(0, str(top_dir.absolute()))

from src import *

slurm_dir = top_dir / 'slurms'
pickle_dir = Path("/") / "data" / "alexey" / "T-SHRED_backup" / 'pickles_3'
print(pickle_dir)


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
/data/alexey/T-SHRED_backup/pickles_3


In [7]:

# We will iterate through every combination of these
datasets = ["sst", "plasma", "planetswe_pod", "gray_scott_reaction_diffusion_pod"]
encoders = ["lstm", "vanilla_transformer", "sindy_attention_transformer"]
decoders = ["mlp", "unet"]
lrs = [f"lr{i:0.2e}" for i in [1e-2, 1e-3, 1e-4, 1e-5]]

# These two will be zipped pairwise
encoder_depths = [f"e{i}" for i in[1, 2, 3]]
decoder_depths = [f"d{i}" for i in[1, 2]]


In [8]:

# Get list of results
results = helpers.get_dictionaries_from_pickles(pickle_dir, early_stop=None)


In [9]:

# Print the results / data format
print(type(results))
pp.pprint(list(results[0].keys()))
pp.pprint(results[0]['hyperparameters'])


<class 'list'>
['test_loss',
 'start_epoch',
 'best_val',
 'best_epoch',
 'train_losses',
 'val_losses',
 'sensors',
 'hyperparameters']
{'batch_size': 128,
 'best_checkpoint_path': PosixPath('/home/alexey/Research4/T-SHRED/checkpoints/sindy_loss_transformer_mlp_planetswe_e1_d1_lr1.00e-03_p1_model_best.pt'),
 'd_data_in': 3,
 'd_data_out': 3,
 'd_model': 150,
 'data_cols_in': 512,
 'data_cols_out': 512,
 'data_rows_in': 256,
 'data_rows_out': 256,
 'dataset': 'planetswe',
 'decoder': 'mlp',
 'decoder_depth': 1,
 'device': 'cuda:1',
 'dim_feedforward': 400,
 'dropout': 0.1,
 'dt': 1.0,
 'early_stop': 10,
 'encoder': 'sindy_loss_transformer',
 'encoder_depth': 1,
 'epochs': 100,
 'eval_full': False,
 'generate_test_plots': True,
 'generate_training_plots': False,
 'hidden_size': 100,
 'include_sine': False,
 'latest_checkpoint_path': PosixPath('/home/alexey/Research4/T-SHRED/checkpoints/sindy_loss_transformer_mlp_planetswe_e1_d1_lr1.00e-03_p1_model_latest.pt'),
 'lr': 0.001,
 'n_heads': 

In [10]:

best_result = helpers.get_top_N_models_by_loss('sst', pickle_dir, N=1)[0][1]
print(best_result.keys())
print(best_result['hyperparameters'].keys())
print(best_result['sensors'])


dict_keys(['test_loss', 'start_epoch', 'best_val', 'best_epoch', 'train_losses', 'val_losses', 'sensors', 'hyperparameters'])
dict_keys(['batch_size', 'dataset', 'decoder', 'decoder_depth', 'device', 'dropout', 'dt', 'early_stop', 'encoder', 'eval_full', 'encoder_depth', 'epochs', 'hidden_size', 'generate_test_plots', 'generate_training_plots', 'include_sine', 'lr', 'n_heads', 'n_sensors', 'n_well_tracks', 'poly_order', 'save_every_n_epochs', 'sindy_attention_threshold', 'sindy_attention_threshold_epoch', 'sindy_attention_weight', 'sindy_loss_threshold', 'sindy_loss_weight', 'skip_load_checkpoint', 'verbose', 'window_length', 'd_data_in', 'data_rows_in', 'data_cols_in', 'd_data_out', 'data_rows_out', 'data_cols_out', 'd_model', 'dim_feedforward', 'output_size', 'latest_checkpoint_path', 'best_checkpoint_path'])
[(98, 215), (10, 132), (130, 248), (103, 155), (122, 183), (149, 111), (129, 71), (72, 71), (64, 272), (154, 75), (18, 350), (84, 241), (143, 51), (90, 222), (80, 312), (141, 24

In [13]:

f = Figlet(font='ogre')
for dataset in ['planetswe', 'sst', 'plasma']:
    best_result = helpers.get_top_N_models_by_loss(dataset, pickle_dir, N=1)
    print(best_result)
    plots.plot_model_results_scatter(results, dataset, 12, save=True, fname=f"{dataset}_scatter",
                                     title_fontsize=20, axes_fontsize=20, legend_fontsize=20)



[('sindy_loss_gru_mlp_planetswe_e1_d1_lr1.00e-03_p1.pkl', {'test_loss': 0.0019565847178455442, 'start_epoch': 56, 'best_val': 0.0022861711913719773, 'best_epoch': 56, 'train_losses': [0.13539415560662746, 0.036787479060391585, 0.0353351520995299, 0.03443405789633592, 0.03365889924267928, 0.03304231601456801, 0.032571984082460405, 0.03215464521199465, 0.03184547889977694, 0.031606756461163364, 0.03139839951569835, 0.03116901076088349, 0.030996071423093477, 0.03085127261777719, 0.030723219271749258, 0.030598820342371862, 0.03050073590129614, 0.030401754068831603, 0.030296519802262386, 0.03020246752227346, 0.03011476279546817, 0.03004136439412832, 0.02997541936735312, 0.029911178319404524, 0.029860481278349955, 0.029783278082807858, 0.029745476817091308, 0.029668194769571225, 0.029627272703995306, 0.029590167012065648, 0.029522737363974254, 0.029495889476190012, 0.02945811888203025, 0.02943574149782459, 0.02943315894032518, 0.02933460703740517, 0.029261857084929944, 0.029289048320303362, 

[('gru_unet_sst_e1_d1_lr1.00e-02_p1.pkl', {'test_loss': 0.0014724985230714083, 'start_epoch': 23, 'best_val': 0.0013933521695435047, 'best_epoch': 23, 'train_losses': [0.07939040722946326, 0.02424876681632466, 0.021546482625934813, 0.020280032729109127, 0.019665385286013286, 0.01936081527835793, 0.019292056146595214, 0.01946179796424177, 0.01919062787459956, 0.019263348231712978, 0.01913987286388874, 0.019032497372892167, 0.01906306089626418, 0.01902498449716303, 0.01895106438961294, 0.018926063345538244, 0.01896102498802874, 0.018952007095019024, 0.018926526937219832, 0.01895444778104623, 0.018980390081803005, 0.018903647859891255, 0.018908305714527767], 'val_losses': [0.009746433235704899, 0.006941520143300295, 0.0025536769535392523, 0.0030938012059777975, 0.004076869692653418, 0.002561084693297744, 0.001281240489333868, 0.0019356503617018461, 0.003393482184037566, 0.0017464838456362486, 0.001683334819972515, 0.0032039738725870848, 0.0022324572782963514, 0.0019462195923551917, 0.0023

[('lstm_unet_plasma_e3_d1_lr1.00e-02_p1.pkl', {'test_loss': 0.0004636579687939957, 'start_epoch': 97, 'best_val': 0.0005385390541050583, 'best_epoch': 97, 'train_losses': [0.0074916467464600615, 0.0009504193126653823, 0.0006071071167333195, 0.0005124679256158953, 0.0004871508505087919, 0.0004825130476652143, 0.00047481054878936935, 0.000478632918272454, 0.00047608839052442747, 0.00047263418449662055, 0.0004700831554006212, 0.0004714863921085802, 0.00047339061319899675, 0.00047511756509685744, 0.0004733269384954698, 0.00047332813175251853, 0.0004703080160722423, 0.000473937512232134, 0.0004707753810530099, 0.0004717185354540841, 0.0004739120575742653, 0.0004749729123432189, 0.0004710283090109722, 0.00047281225516389194, 0.00047534062356974644, 0.0004755888590947367, 0.00047235457056488556, 0.0004736566110155903, 0.00047241051932080434, 0.0004715589300478594, 0.0004715936889764495, 0.0004741584192603253, 0.0004760962284098451, 0.000470003733286061, 0.00047491695687103155, 0.0004724907675

In [12]:

f = Figlet(font='ogre')
for dataset in ['planetswe', 'sst', 'plasma']:
    best_result = helpers.get_top_N_models_by_loss(dataset, pickle_dir, N=1)
    print(best_result)
    plots.plot_model_results_scatter(results, dataset, 12, save=True, fname=f"{dataset}_scatter",
                                     title_fontsize=20, axes_fontsize=20, legend_fontsize=20)


[('sindy_loss_gru_mlp_planetswe_e1_d1_lr1.00e-03_p1.pkl', {'test_loss': 0.0019565847178455442, 'start_epoch': 56, 'best_val': 0.0022861711913719773, 'best_epoch': 56, 'train_losses': [0.13539415560662746, 0.036787479060391585, 0.0353351520995299, 0.03443405789633592, 0.03365889924267928, 0.03304231601456801, 0.032571984082460405, 0.03215464521199465, 0.03184547889977694, 0.031606756461163364, 0.03139839951569835, 0.03116901076088349, 0.030996071423093477, 0.03085127261777719, 0.030723219271749258, 0.030598820342371862, 0.03050073590129614, 0.030401754068831603, 0.030296519802262386, 0.03020246752227346, 0.03011476279546817, 0.03004136439412832, 0.02997541936735312, 0.029911178319404524, 0.029860481278349955, 0.029783278082807858, 0.029745476817091308, 0.029668194769571225, 0.029627272703995306, 0.029590167012065648, 0.029522737363974254, 0.029495889476190012, 0.02945811888203025, 0.02943574149782459, 0.02943315894032518, 0.02933460703740517, 0.029261857084929944, 0.029289048320303362, 

[('gru_unet_sst_e1_d1_lr1.00e-02_p1.pkl', {'test_loss': 0.0014724985230714083, 'start_epoch': 23, 'best_val': 0.0013933521695435047, 'best_epoch': 23, 'train_losses': [0.07939040722946326, 0.02424876681632466, 0.021546482625934813, 0.020280032729109127, 0.019665385286013286, 0.01936081527835793, 0.019292056146595214, 0.01946179796424177, 0.01919062787459956, 0.019263348231712978, 0.01913987286388874, 0.019032497372892167, 0.01906306089626418, 0.01902498449716303, 0.01895106438961294, 0.018926063345538244, 0.01896102498802874, 0.018952007095019024, 0.018926526937219832, 0.01895444778104623, 0.018980390081803005, 0.018903647859891255, 0.018908305714527767], 'val_losses': [0.009746433235704899, 0.006941520143300295, 0.0025536769535392523, 0.0030938012059777975, 0.004076869692653418, 0.002561084693297744, 0.001281240489333868, 0.0019356503617018461, 0.003393482184037566, 0.0017464838456362486, 0.001683334819972515, 0.0032039738725870848, 0.0022324572782963514, 0.0019462195923551917, 0.0023

Saved /home/alexey/Research4/T-SHRED/figures/sst_scatter.pdf


[('lstm_unet_plasma_e3_d1_lr1.00e-02_p1.pkl', {'test_loss': 0.0004636579687939957, 'start_epoch': 97, 'best_val': 0.0005385390541050583, 'best_epoch': 97, 'train_losses': [0.0074916467464600615, 0.0009504193126653823, 0.0006071071167333195, 0.0005124679256158953, 0.0004871508505087919, 0.0004825130476652143, 0.00047481054878936935, 0.000478632918272454, 0.00047608839052442747, 0.00047263418449662055, 0.0004700831554006212, 0.0004714863921085802, 0.00047339061319899675, 0.00047511756509685744, 0.0004733269384954698, 0.00047332813175251853, 0.0004703080160722423, 0.000473937512232134, 0.0004707753810530099, 0.0004717185354540841, 0.0004739120575742653, 0.0004749729123432189, 0.0004710283090109722, 0.00047281225516389194, 0.00047534062356974644, 0.0004755888590947367, 0.00047235457056488556, 0.0004736566110155903, 0.00047241051932080434, 0.0004715589300478594, 0.0004715936889764495, 0.0004741584192603253, 0.0004760962284098451, 0.000470003733286061, 0.00047491695687103155, 0.0004724907675