## Imports

In [1]:

%load_ext autoreload
%autoreload 2

import os
import sys
import torch
import einops
import random
import pickle
import numpy as np
import pprint as pp
import torch.nn as nn
from pathlib import Path
from pyfiglet import Figlet
from torch.utils.data import DataLoader
from src.plots import plot_losses, plot_field_comparison

# Add the project root directory to Python path
top_dir = Path.cwd().parent
sys.path.insert(0, str(top_dir.absolute()))

from src import *

slurm_dir = top_dir / 'slurms'
pickle_dir = top_dir / 'pickles'


In [2]:

# Hyperparameter sweep axes: every combination of these four lists is run.
datasets = ["sst", "plasma", "planetswe_pod", "gray_scott_reaction_diffusion_pod"]
encoders = ["lstm", "vanilla_transformer", "sindy_attention_transformer"]
decoders = ["mlp", "unet"]
lrs = ["lr{:0.2e}".format(rate) for rate in (1e-2, 1e-3, 1e-4, 1e-5)]

# Depth labels are paired element-wise (zipped) rather than crossed.
# NOTE(review): there are 3 encoder depths but only 2 decoder depths, so a plain
# zip() would silently drop "e3". Confirm this is intended.
encoder_depths = ["e" + str(depth) for depth in (1, 2, 3)]
decoder_depths = ["d" + str(depth) for depth in (1, 2)]


In [3]:

# Get list of results: one dictionary per saved run, read from the files under `pickle_dir`
# (the next cell prints `<class 'list'>` for this value).
# NOTE(review): the helper presumably unpickles these files, and unpickling can execute
# arbitrary code, so only point `pickle_dir` at trusted output.
results = helpers.get_dictionaries_from_pickles(pickle_dir)


In [4]:

# Show the container type, then the schema of a single run record:
# its top-level keys followed by its hyperparameter dictionary.
print(type(results))
first_run = results[0]
pp.pprint(list(first_run))
pp.pprint(first_run['hyperparameters'])


<class 'list'>
['test_loss',
 'start_epoch',
 'best_val',
 'best_epoch',
 'train_losses',
 'val_losses',
 'sensors',
 'hyperparameters']
{'batch_size': 128,
 'best_checkpoint_path': PosixPath('/app/code/checkpoints/lstm_unet_sst_e4_d2_lr1.00e-03_model_best.pt'),
 'd_data': 1,
 'd_model': 50,
 'data_cols': 360,
 'data_rows': 180,
 'dataset': 'sst',
 'decoder': 'unet',
 'decoder_depth': 2,
 'device': 'cuda:0',
 'dim_feedforward': 32,
 'dropout': 0.1,
 'encoder': 'lstm',
 'encoder_depth': 4,
 'epochs': 200,
 'eval_full': False,
 'hidden_size': 8,
 'include_sine': False,
 'latest_checkpoint_path': PosixPath('/app/code/checkpoints/lstm_unet_sst_e4_d2_lr1.00e-03_model_latest.pt'),
 'lr': 0.001,
 'n_heads': 2,
 'n_sensors': 50,
 'output_size': 64800,
 'poly_order': 2,
 'save_every_n_epochs': 10,
 'skip_load_checkpoint': False,
 'verbose': True,
 'window_length': 50}


In [5]:

# Pull the top-ranked 'sst' run and show which fields it records.
# NOTE(review): the [0][1] indexing suggests the helper returns a ranked list of
# pairs whose second item is the run dict. Confirm against the helper.
sst_ranking = helpers.get_top_N_models_by_loss('sst', pickle_dir, N=1)
best_result = sst_ranking[0][1]
print(best_result.keys())
run_hparams = best_result['hyperparameters']
print(run_hparams.keys())
print(best_result['sensors'])


dict_keys(['test_loss', 'start_epoch', 'best_val', 'best_epoch', 'train_losses', 'val_losses', 'sensors', 'hyperparameters'])
dict_keys(['batch_size', 'dataset', 'decoder', 'decoder_depth', 'device', 'dropout', 'encoder', 'eval_full', 'encoder_depth', 'epochs', 'hidden_size', 'include_sine', 'lr', 'n_heads', 'n_sensors', 'poly_order', 'save_every_n_epochs', 'skip_load_checkpoint', 'verbose', 'window_length', 'd_data', 'data_rows', 'data_cols', 'd_model', 'dim_feedforward', 'output_size', 'latest_checkpoint_path', 'best_checkpoint_path'])
[(98, 215), (10, 132), (130, 248), (103, 155), (122, 183), (149, 111), (129, 71), (72, 71), (24, 316), (64, 272), (154, 75), (79, 50), (18, 350), (84, 241), (143, 51), (90, 222), (80, 312), (163, 104), (141, 244), (113, 266), (66, 31), (140, 7), (23, 204), (171, 320), (0, 313), (126, 170), (62, 166), (16, 97), (145, 113), (61, 72), (139, 229), (23, 41), (81, 260), (125, 55), (77, 282), (74, 63), (140, 170), (138, 104), (154, 280), (150, 147), (113, 46)

In [6]:

# For each dataset: print an ASCII-art banner, save a scatter plot built from all
# loaded runs, and save the train/val loss curves of that dataset's top-ranked run.
# Third positional argument of plot_model_results_scatter.
# NOTE(review): presumably the number of models to show; confirm.
N_SCATTER_MODELS = 12

f = Figlet(font='ogre')
for dataset in ['planetswe_pod', 'sst', 'plasma', 'gray_scott_reaction_diffusion_pod']:
    print(f.renderText(dataset))
    best_result = helpers.get_top_N_models_by_loss(dataset, pickle_dir, N=1)[0][1]
    plots.plot_model_results_scatter(results, dataset, N_SCATTER_MODELS, save=True, fname=f"{dataset}_scatter.pdf")
    # Only plot losses when the run record actually contains a loss history.
    if 'train_losses' in best_result:
        plots.plot_losses(
            best_result['train_losses'],
            best_result['val_losses'],
            best_result['best_epoch'],
            save=True,
            fname=f"{dataset}_losses.pdf",
        )


       _                  _                                         _ 
 _ __ | | __ _ _ __   ___| |_ _____      _____      _ __   ___   __| |
| '_ \| |/ _` | '_ \ / _ \ __/ __\ \ /\ / / _ \    | '_ \ / _ \ / _` |
| |_) | | (_| | | | |  __/ |_\__ \\ V  V /  __/    | |_) | (_) | (_| |
| .__/|_|\__,_|_| |_|\___|\__|___/ \_/\_/ \___|____| .__/ \___/ \__,_|
|_|                                          |_____|_|                

         _   
 ___ ___| |_ 
/ __/ __| __|
\__ \__ \ |_ 
|___/___/\__|
             

       _                           
 _ __ | | __ _ ___ _ __ ___   __ _ 
| '_ \| |/ _` / __| '_ ` _ \ / _` |
| |_) | | (_| \__ \ | | | | | (_| |
| .__/|_|\__,_|___/_| |_| |_|\__,_|
|_|                                

                                         _   _                           _   _ 
  __ _ _ __ __ _ _   _     ___  ___ ___ | |_| |_     _ __ ___  __ _  ___| |_(_)
 / _` | '__/ _` | | | |   / __|/ __/ _ \| __| __|   | '__/ _ \/ _` |/ __| __| |
| (_| | | | (_| | |_| |   \__ \