## Imports

In [11]:

%load_ext autoreload
%autoreload 2

import os
import sys
import torch
import einops
import random
import pickle
import numpy as np
import pprint as pp
import torch.nn as nn
from pathlib import Path
from pyfiglet import Figlet
from torch.utils.data import DataLoader
from src.plots import plot_losses, plot_field_comparison

# Add the project root directory to Python path
top_dir = Path.cwd().parent
sys.path.insert(0, str(top_dir.absolute()))

from src import *

slurm_dir = top_dir / 'slurms'
pickle_dir = top_dir / 'pickles'
print(pickle_dir)


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
/Users/alexey/Git/T-SHRED_Repo/pickles


In [12]:

# We will iterate through every combination of these
datasets = ["sst", "plasma", "planetswe_pod", "gray_scott_reaction_diffusion_pod"]
encoders = ["lstm", "vanilla_transformer", "sindy_attention_transformer"]
decoders = ["mlp", "unet"]
lrs = [f"lr{i:0.2e}" for i in [1e-2, 1e-3, 1e-4, 1e-5]]

# These two will be zipped pairwise
encoder_depths = [f"e{i}" for i in[1, 2, 3]]
decoder_depths = [f"d{i}" for i in[1, 2]]


In [13]:

# Get list of results
results = helpers.get_dictionaries_from_pickles(pickle_dir, early_stop=None)


In [14]:

# Print the results / data format
print(type(results))
pp.pprint(list(results[0].keys()))
pp.pprint(results[0]['hyperparameters'])


<class 'list'>
['test_loss',
 'start_epoch',
 'best_val',
 'best_epoch',
 'train_losses',
 'val_losses',
 'sensors',
 'hyperparameters']
{'batch_size': 128,
 'best_checkpoint_path': PosixPath('/app/code/checkpoints/sindy_loss_transformer_mlp_plasma_e1_d1_lr1.00e-02_p1_model_best.pt'),
 'd_data': 1,
 'd_model': 50,
 'data_cols': 2000,
 'data_rows': 1,
 'dataset': 'plasma',
 'decoder': 'mlp',
 'decoder_depth': 1,
 'device': 'cuda:0',
 'dim_feedforward': 16,
 'dropout': 0.1,
 'dt': 1.0,
 'early_stop': 10,
 'encoder': 'sindy_loss_transformer',
 'encoder_depth': 1,
 'epochs': 100,
 'eval_full': False,
 'generate_test_plots': True,
 'generate_training_plots': False,
 'hidden_size': 4,
 'include_sine': False,
 'latest_checkpoint_path': PosixPath('/app/code/checkpoints/sindy_loss_transformer_mlp_plasma_e1_d1_lr1.00e-02_p1_model_latest.pt'),
 'lr': 0.01,
 'n_heads': 2,
 'n_sensors': 50,
 'n_well_tracks': 10,
 'output_size': 2000,
 'poly_order': 1,
 'save_every_n_epochs': 10,
 'sindy_threshold':

In [15]:

best_result = helpers.get_top_N_models_by_loss('sst', pickle_dir, N=1)[0][1]
print(best_result.keys())
print(best_result['hyperparameters'].keys())
print(best_result['sensors'])


dict_keys(['test_loss', 'start_epoch', 'best_val', 'best_epoch', 'train_losses', 'val_losses', 'sensors', 'hyperparameters'])
dict_keys(['batch_size', 'dataset', 'decoder', 'decoder_depth', 'device', 'dropout', 'dt', 'early_stop', 'encoder', 'eval_full', 'encoder_depth', 'epochs', 'hidden_size', 'generate_test_plots', 'generate_training_plots', 'include_sine', 'lr', 'n_heads', 'n_sensors', 'n_well_tracks', 'poly_order', 'save_every_n_epochs', 'sindy_threshold', 'sindy_weight', 'skip_load_checkpoint', 'verbose', 'window_length', 'd_data', 'data_rows', 'data_cols', 'd_model', 'dim_feedforward', 'output_size', 'latest_checkpoint_path', 'best_checkpoint_path'])
[(98, 215), (10, 132), (130, 248), (103, 155), (122, 183), (149, 111), (129, 71), (72, 71), (64, 272), (154, 75), (18, 350), (84, 241), (143, 51), (90, 222), (80, 312), (141, 244), (113, 266), (140, 7), (0, 313), (126, 170), (62, 166), (145, 113), (139, 229), (23, 41), (81, 260), (125, 55), (77, 282), (74, 63), (140, 170), (138, 104

In [16]:

f = Figlet(font='ogre')
for dataset in ['planetswe_full', 'sst', 'plasma']:
    best_result = helpers.get_top_N_models_by_loss(dataset, pickle_dir, N=1)
    print(best_result)
    plots.plot_model_results_scatter(results, dataset, 12, save=True, fname=f"{dataset}_scatter")


[('sindy_attention_transformer_mlp_planetswe_full_e4_d1_lr1.00e-03_p1.pkl', {'test_loss': 2666.7405700683594, 'start_epoch': 79, 'best_val': 2658.951220703125, 'best_epoch': 79, 'train_losses': [29668.244759114583, 28685.5783203125, 27504.205631510416, 26180.196809895835, 24712.47685546875, 23193.797591145834, 21716.901139322916, 20345.537760416668, 19008.105110677083, 17796.333626302083, 16667.783740234376, 15632.181168619793, 14701.136263020833, 13829.835091145833, 13008.491471354168, 12217.68310546875, 11470.368375651042, 10809.981998697916, 10184.480289713541, 9593.36513671875, 9112.484195963541, 8642.800577799479, 8253.454296875, 7894.79892578125, 7560.525455729166, 7293.146240234375, 7064.3615234375, 6840.81259765625, 6651.67568359375, 6522.373046875, 6371.6815999348955, 6245.40419921875, 6154.406486002604, 6066.626896158854, 5980.98408203125, 5909.398746744791, 5862.173356119792, 5801.894555664063, 5738.948494466146, 5692.903206380209, 5647.671053059896, 5594.2410074869795, 5546

[('gru_mlp_sst_e1_d1_lr1.00e-02_p1.pkl', {'test_loss': 0.0015706232516095042, 'start_epoch': 40, 'best_val': 0.0013682936551049352, 'best_epoch': 40, 'train_losses': [0.08763495542936856, 0.025611266493797302, 0.027644022471374936, 0.02358312987618976, 0.021068058701025114, 0.02010872173640463, 0.01963767098883788, 0.019427185464236472, 0.019307324041922886, 0.019355880717436474, 0.01930902298125956, 0.019288397083679836, 0.01920843082997534, 0.019163688230845664, 0.019118323715196714, 0.019129213152660265, 0.01916152321630054, 0.01910323691036966, 0.019068804052140977, 0.019076125903262034, 0.01909270489381419, 0.01901308385034402, 0.01902760937809944, 0.019004565767116018, 0.01899071844915549, 0.018994363231791392, 0.01894232899778419, 0.01899667725794845, 0.01891983983417352, 0.018967037192649312, 0.01891509298649099, 0.0188889242708683, 0.018915969050592847, 0.01888340525329113, 0.018926629175742466, 0.018891462435324986, 0.018938654826747045, 0.018841072503063414, 0.01887480521367

[('lstm_mlp_plasma_e3_d1_lr1.00e-02_p1.pkl', {'test_loss': 0.0009910279180845604, 'start_epoch': 96, 'best_val': 0.002528695808275818, 'best_epoch': 96, 'train_losses': [0.06751685664782849, 0.043637260876215285, 0.04292129389851053, 0.04288506996421338, 0.042856693023095166, 0.04286523630112917, 0.042869424361309576, 0.0428516896681774, 0.04285840807061126, 0.042822744797036016, 0.042872795903117114, 0.04285112790165156, 0.04284719654207102, 0.042865893193985136, 0.04283975475352176, 0.04286431604565785, 0.04286354470209484, 0.04289540296778482, 0.04287197526970339, 0.042851343977552844, 0.042877176645100844, 0.04283410056047776, 0.042851163142353946, 0.042840250489050456, 0.042885759369517766, 0.042860649625822866, 0.042860323168935566, 0.04281570184114786, 0.042874809708038385, 0.04279780977036251, 0.04288944399886178, 0.04280651641733165, 0.0428446507729463, 0.042813631232347514, 0.04284077153117407, 0.04286007224208247, 0.04283048436408205, 0.042865650016823534, 0.0428439745829053