In [1]:
import os
import sys

import pandas as pd
import torch

In [2]:
# Add the parent directory of 'src' to sys.path if not already present
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

import src.assimilation.assimilator  # noqa: E402
import src.assimilation.test_configs  # noqa: E402, F401

In [3]:
import matplotlib as mpl

mpl.rcParams["text.usetex"] = True
mpl.rc("text.latex", preamble=r"\usepackage{amsmath}")

In [4]:
SAVE_ROOT = "../ass_results/assimilate_various_obs_dist_no_back"

pickle_cache_path = "./plot_various_obs_dist_no_back.pkl"

read_cache = 0

try:
    assert read_cache
    results = pd.read_pickle(pickle_cache_path)
except:  # noqa: E722
    entries = []
    for filename in os.listdir(SAVE_ROOT):
        if filename.endswith(".pt"):
            file_path = os.path.join(SAVE_ROOT, filename)
            print(file_path)
            saved_data = torch.load(file_path, weights_only=False)
            entry = pd.DataFrame(
                {
                    "obs_type": saved_data["obs_type"][0],
                    "method": saved_data["method"],
                    "rand_mask_ratio": saved_data["rand_mask_ratio"],
                    "seed_no": saved_data["seedno"],
                    "rmse": saved_data["rmse"],
                    "step_idx": range(len(saved_data["rmse"])),
                    "time_stride": saved_data["time_stride"],
                    "dist_name": saved_data["dist_name"],
                    "avg_rmse": sum(saved_data["rmse"]) / len(saved_data["rmse"]),
                }
            )
            entries.append(entry)
    results = pd.concat(entries)
    results.to_pickle(pickle_cache_path)

../ass_results/assimilate_various_obs_dist_no_back/['arctan3x']_soad_seedno3_uniform.pt
../ass_results/assimilate_various_obs_dist_no_back/['arctan3x']_soad_seedno3_lognormal-0.2.pt
../ass_results/assimilate_various_obs_dist_no_back/['arctan3x']_soad_seedno3_lognormal-1.0.pt
../ass_results/assimilate_various_obs_dist_no_back/['arctan3x']_soad_seedno4_laplace.pt
../ass_results/assimilate_various_obs_dist_no_back/['vor2vel']_soad_seedno0_normal.pt
../ass_results/assimilate_various_obs_dist_no_back/['vor2vel']_soad_seedno0_laplace.pt
../ass_results/assimilate_various_obs_dist_no_back/['vor2vel']_soad_seedno0_uniform.pt
../ass_results/assimilate_various_obs_dist_no_back/['vor2vel']_soad_seedno0_lognormal-0.1.pt
../ass_results/assimilate_various_obs_dist_no_back/['vor2vel']_soad_seedno0_lognormal-0.2.pt
../ass_results/assimilate_various_obs_dist_no_back/['vor2vel']_soad_seedno0_lognormal-0.5.pt
../ass_results/assimilate_various_obs_dist_no_back/['vor2vel']_soad_seedno0_lognormal-1.0.pt
../a

In [5]:
results

Unnamed: 0,obs_type,method,rand_mask_ratio,seed_no,rmse,step_idx,time_stride,dist_name,avg_rmse
0,arctan3x,soad,1.00,3,0.148576,0,1,uniform,0.136991
1,arctan3x,soad,1.00,3,0.142846,1,1,uniform,0.136991
2,arctan3x,soad,1.00,3,0.139388,2,1,uniform,0.136991
3,arctan3x,soad,1.00,3,0.136268,3,1,uniform,0.136991
4,arctan3x,soad,1.00,3,0.134005,4,1,uniform,0.136991
...,...,...,...,...,...,...,...,...,...
4,sin3x,soad,0.25,4,0.755156,4,2,lognormal-1.0,0.756805
5,sin3x,soad,0.25,4,0.750577,5,2,lognormal-1.0,0.756805
6,sin3x,soad,0.25,4,0.746437,6,2,lognormal-1.0,0.756805
7,sin3x,soad,0.25,4,0.743420,7,2,lognormal-1.0,0.756805


In [6]:
results.groupby(["time_stride", "obs_type", "dist_name"])["rmse"].mean()

time_stride  obs_type  dist_name    
1            arctan3x  laplace          0.138637
                       lognormal-0.1    0.138213
                       lognormal-0.2    0.138236
                       lognormal-0.5    0.138319
                       lognormal-1.0    0.142043
                       normal           0.138191
                       uniform          0.138434
             sin3x     laplace          0.144401
                       lognormal-0.1    0.101446
                       lognormal-0.2    0.130168
                       lognormal-0.5    0.120055
                       lognormal-1.0    0.140685
                       normal           0.129269
                       uniform          0.159051
             vor2vel   laplace          0.118525
                       lognormal-0.1    0.118502
                       lognormal-0.2    0.118512
                       lognormal-0.5    0.118534
                       lognormal-1.0    0.120064
                       normal   

In [7]:
results.groupby(["time_stride", "obs_type", "dist_name"])["rmse"].std()

time_stride  obs_type  dist_name    
1            arctan3x  laplace          0.005311
                       lognormal-0.1    0.005283
                       lognormal-0.2    0.005286
                       lognormal-0.5    0.005304
                       lognormal-1.0    0.006502
                       normal           0.005280
                       uniform          0.005319
             sin3x     laplace          0.011751
                       lognormal-0.1    0.035030
                       lognormal-0.2    0.049982
                       lognormal-0.5    0.060793
                       lognormal-1.0    0.016260
                       normal           0.031738
                       uniform          0.055866
             vor2vel   laplace          0.007599
                       lognormal-0.1    0.007534
                       lognormal-0.2    0.007538
                       lognormal-0.5    0.007547
                       lognormal-1.0    0.007834
                       normal   