## Imports

In [1]:

%load_ext autoreload
%autoreload 2

import os
import sys
import torch
import einops
import random
import pickle
import numpy as np
import pprint as pp
import torch.nn as nn
from pathlib import Path
from pyfiglet import Figlet
from torch.utils.data import DataLoader

# Add the project root directory to Python path
top_dir = Path.cwd().parent

sys.path.insert(0, str(top_dir.absolute()))

from src import *

slurm_dir = top_dir / 'slurms'
pickle_dir = top_dir / 'pickles'


In [2]:

# Get list of results
results = helpers.get_dictionaries_from_pickles(pickle_dir)
print("Num. results:", len(results))


Num. results: 1920


In [5]:

f = Figlet(font='ogre')
for dataset in ['planetswe', 'sst', 'plasma']:
    best_result = helpers.get_top_N_models_by_loss(dataset, pickle_dir, N=1)
    print(best_result)
    plots.plot_model_results_scatter(results, dataset, 12, save=True, fname=f"{dataset}_worst_scatter", title_fontsize=20, axes_fontsize=20, legend_fontsize=20, reverse=False)



[('sindy_loss_gru_mlp_planetswe_e1_d1_lr1.00e-03_p1_s0.pkl', {'test_loss': 0.0019218050583731383, 'start_epoch': 56, 'best_val': 0.0022574458736926315, 'best_epoch': 56, 'train_losses': [0.13539191006372372, 0.03678547590970993, 0.035333632367352645, 0.034432734735310076, 0.03365748537083467, 0.03304062448441982, 0.0325704620530208, 0.03215357108662526, 0.03184492898484071, 0.03160551277299722, 0.031397384808709225, 0.03116783043369651, 0.03099450245499611, 0.030849991738796233, 0.030722183175385, 0.030597152343640726, 0.030499516831090052, 0.03040029527619481, 0.030295406592388947, 0.030201291510214407, 0.030113864131271838, 0.03004052077109615, 0.029974634852260353, 0.029910281642029683, 0.029859598726034164, 0.029782726801931857, 0.029745392066737018, 0.02966735794519385, 0.02962634861469269, 0.02958958751211564, 0.029521808245529732, 0.02949525071308017, 0.029457247847070298, 0.029433266104509432, 0.02943354342132807, 0.029333130301286776, 0.029261330297837656, 0.029290406312793494

[('gru_mlp_sst_e2_d1_lr1.00e-03_p1_s2.pkl', {'test_loss': 0.0012291716411709785, 'start_epoch': 65, 'best_val': 0.0010706607718020678, 'best_epoch': 65, 'train_losses': [0.1672745860285229, 0.07512702296177547, 0.033827256618274584, 0.02701966195470757, 0.024518782272934914, 0.023421030905511644, 0.02293417975306511, 0.022633249974913068, 0.022451146816213925, 0.022347463708784845, 0.022249660765131313, 0.0220568956186374, 0.02203246599270238, 0.021900197491049767, 0.02179942528406779, 0.021742651652958658, 0.021610755680335894, 0.02158963017993503, 0.0214728984153933, 0.021355981214178935, 0.021307408188780148, 0.021159285265538428, 0.021031749124328297, 0.020922080510192446, 0.020751559279031224, 0.020597263135843806, 0.020497622796230845, 0.020336969445149105, 0.020248877712421946, 0.020165388989779685, 0.020051089632842276, 0.019991698778337903, 0.019941633567214012, 0.019897494465112686, 0.019804369658231735, 0.019817862866653338, 0.01979681808087561, 0.019716526485151716, 0.01977

[('gru_unet_plasma_e3_d1_lr1.00e-03_p1_s0.pkl', {'test_loss': 0.0002057709571090527, 'start_epoch': 99, 'best_val': 0.000275115598924458, 'best_epoch': 99, 'train_losses': [0.0011198220398420324, 0.0005293651080976885, 0.0004928591593992538, 0.00048037336763137806, 0.00047406657204891625, 0.0004659415168974262, 0.0004641128202470449, 0.0004568589233363477, 0.00045339087284026813, 0.00044137748549334134, 0.00043388788113728736, 0.0004344144970393525, 0.00042302966288004356, 0.0004215712363545138, 0.00040928386330891115, 0.0004069045736776808, 0.00039577374431806116, 0.00038884406183989573, 0.0003836328396573663, 0.00037981295519365143, 0.00037373176802737784, 0.00036592404313313845, 0.00035850295814900444, 0.0003501121245790273, 0.000350170121796859, 0.00034553655244123476, 0.00033958648250868113, 0.0003346683941065119, 0.00033003602696296113, 0.00032643248022605595, 0.00032015335343133373, 0.00031651978721269046, 0.000307417884379482, 0.0003081003848749858, 0.0003026697063782754, 0.000

In [6]:

f = Figlet(font='ogre')
for dataset in ['planetswe', 'sst', 'plasma']:
    best_result = helpers.get_top_N_models_by_loss(dataset, pickle_dir, N=1)
    print(best_result)
    plots.plot_model_results_scatter(results, dataset, 12, save=True, fname=f"{dataset}_scatter", title_fontsize=20, axes_fontsize=20, legend_fontsize=20, reverse=True)


[('sindy_loss_gru_mlp_planetswe_e1_d1_lr1.00e-03_p1_s0.pkl', {'test_loss': 0.0019218050583731383, 'start_epoch': 56, 'best_val': 0.0022574458736926315, 'best_epoch': 56, 'train_losses': [0.13539191006372372, 0.03678547590970993, 0.035333632367352645, 0.034432734735310076, 0.03365748537083467, 0.03304062448441982, 0.0325704620530208, 0.03215357108662526, 0.03184492898484071, 0.03160551277299722, 0.031397384808709225, 0.03116783043369651, 0.03099450245499611, 0.030849991738796233, 0.030722183175385, 0.030597152343640726, 0.030499516831090052, 0.03040029527619481, 0.030295406592388947, 0.030201291510214407, 0.030113864131271838, 0.03004052077109615, 0.029974634852260353, 0.029910281642029683, 0.029859598726034164, 0.029782726801931857, 0.029745392066737018, 0.02966735794519385, 0.02962634861469269, 0.02958958751211564, 0.029521808245529732, 0.02949525071308017, 0.029457247847070298, 0.029433266104509432, 0.02943354342132807, 0.029333130301286776, 0.029261330297837656, 0.029290406312793494

[('gru_mlp_sst_e2_d1_lr1.00e-03_p1_s2.pkl', {'test_loss': 0.0012291716411709785, 'start_epoch': 65, 'best_val': 0.0010706607718020678, 'best_epoch': 65, 'train_losses': [0.1672745860285229, 0.07512702296177547, 0.033827256618274584, 0.02701966195470757, 0.024518782272934914, 0.023421030905511644, 0.02293417975306511, 0.022633249974913068, 0.022451146816213925, 0.022347463708784845, 0.022249660765131313, 0.0220568956186374, 0.02203246599270238, 0.021900197491049767, 0.02179942528406779, 0.021742651652958658, 0.021610755680335894, 0.02158963017993503, 0.0214728984153933, 0.021355981214178935, 0.021307408188780148, 0.021159285265538428, 0.021031749124328297, 0.020922080510192446, 0.020751559279031224, 0.020597263135843806, 0.020497622796230845, 0.020336969445149105, 0.020248877712421946, 0.020165388989779685, 0.020051089632842276, 0.019991698778337903, 0.019941633567214012, 0.019897494465112686, 0.019804369658231735, 0.019817862866653338, 0.01979681808087561, 0.019716526485151716, 0.01977

[('gru_unet_plasma_e3_d1_lr1.00e-03_p1_s0.pkl', {'test_loss': 0.0002057709571090527, 'start_epoch': 99, 'best_val': 0.000275115598924458, 'best_epoch': 99, 'train_losses': [0.0011198220398420324, 0.0005293651080976885, 0.0004928591593992538, 0.00048037336763137806, 0.00047406657204891625, 0.0004659415168974262, 0.0004641128202470449, 0.0004568589233363477, 0.00045339087284026813, 0.00044137748549334134, 0.00043388788113728736, 0.0004344144970393525, 0.00042302966288004356, 0.0004215712363545138, 0.00040928386330891115, 0.0004069045736776808, 0.00039577374431806116, 0.00038884406183989573, 0.0003836328396573663, 0.00037981295519365143, 0.00037373176802737784, 0.00036592404313313845, 0.00035850295814900444, 0.0003501121245790273, 0.000350170121796859, 0.00034553655244123476, 0.00033958648250868113, 0.0003346683941065119, 0.00033003602696296113, 0.00032643248022605595, 0.00032015335343133373, 0.00031651978721269046, 0.000307417884379482, 0.0003081003848749858, 0.0003026697063782754, 0.000