In [9]:
import os
import sys

import pandas as pd
import torch

In [10]:
# Add the parent directory of 'src' to sys.path if not already present
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

import src.assimilation.assimilator  # noqa: E402
import src.assimilation.test_configs  # noqa: E402, F401

In [11]:
import matplotlib as mpl

mpl.rcParams["text.usetex"] = True
mpl.rc("text.latex", preamble=r"\usepackage{amsmath}")

In [12]:
SAVE_ROOT = "../ass_results/assimilate_various_obs_dist"

pickle_cache_path = "./plot_various_obs_dist.pkl"

read_cache = 0

try:
    assert read_cache
    results = pd.read_pickle(pickle_cache_path)
except:  # noqa: E722
    entries = []
    for filename in os.listdir(SAVE_ROOT):
        if filename.endswith(".pt"):
            file_path = os.path.join(SAVE_ROOT, filename)
            print(file_path)
            saved_data = torch.load(file_path, weights_only=False)
            entry = pd.DataFrame(
                {
                    "obs_type": saved_data["obs_type"][0],
                    "method": saved_data["method"],
                    "rand_mask_ratio": saved_data["rand_mask_ratio"],
                    "seed_no": saved_data["seedno"],
                    "rmse": saved_data["rmse"],
                    "step_idx": range(len(saved_data["rmse"])),
                    "time_stride": saved_data["time_stride"],
                    "dist_name": saved_data["dist_name"],
                    "avg_rmse": sum(saved_data["rmse"]) / len(saved_data["rmse"]),
                }
            )
            entries.append(entry)
    results = pd.concat(entries)
    results.to_pickle(pickle_cache_path)

../ass_results/assimilate_various_obs_dist/['arctan3x']_soad_seedno3_normal.pt
../ass_results/assimilate_various_obs_dist/['arctan3x']_soad_seedno3_lognormal-0.2.pt
../ass_results/assimilate_various_obs_dist/['arctan3x']_soad_seedno3_lognormal-0.5.pt
../ass_results/assimilate_various_obs_dist/['arctan3x']_soad_seedno4_normal.pt
../ass_results/assimilate_various_obs_dist/['vor2vel']_soad_seedno0_normal.pt
../ass_results/assimilate_various_obs_dist/['vor2vel']_soad_seedno0_laplace.pt
../ass_results/assimilate_various_obs_dist/['vor2vel']_soad_seedno0_uniform.pt
../ass_results/assimilate_various_obs_dist/['vor2vel']_soad_seedno0_lognormal-0.1.pt
../ass_results/assimilate_various_obs_dist/['vor2vel']_soad_seedno0_lognormal-0.2.pt
../ass_results/assimilate_various_obs_dist/['vor2vel']_soad_seedno0_lognormal-0.5.pt
../ass_results/assimilate_various_obs_dist/['vor2vel']_soad_seedno0_lognormal-1.0.pt
../ass_results/assimilate_various_obs_dist/['vor2vel']_soad_seedno1_normal.pt
../ass_results/a

In [13]:
results

Unnamed: 0,obs_type,method,rand_mask_ratio,seed_no,rmse,step_idx,time_stride,dist_name,avg_rmse
0,arctan3x,soad,1.00,3,0.083017,0,1,normal,0.084139
1,arctan3x,soad,1.00,3,0.080437,1,1,normal,0.084139
2,arctan3x,soad,1.00,3,0.079489,2,1,normal,0.084139
3,arctan3x,soad,1.00,3,0.080367,3,1,normal,0.084139
4,arctan3x,soad,1.00,3,0.081497,4,1,normal,0.084139
...,...,...,...,...,...,...,...,...,...
4,vor2vel,soad,0.25,4,0.130568,4,2,lognormal-1.0,0.132894
5,vor2vel,soad,0.25,4,0.136879,5,2,lognormal-1.0,0.132894
6,vor2vel,soad,0.25,4,0.141858,6,2,lognormal-1.0,0.132894
7,vor2vel,soad,0.25,4,0.151983,7,2,lognormal-1.0,0.132894


In [14]:
results.groupby(["time_stride", "obs_type", "dist_name"])["rmse"].mean()

time_stride  obs_type  dist_name    
1            arctan3x  laplace          0.084545
                       lognormal-0.1    0.084564
                       lognormal-0.2    0.084567
                       lognormal-0.5    0.084601
                       lognormal-1.0    0.085292
                       normal           0.084565
                       uniform          0.084490
             sin3x     laplace          0.064407
                       lognormal-0.1    0.064441
                       lognormal-0.2    0.064440
                       lognormal-0.5    0.064445
                       lognormal-1.0    0.065269
                       normal           0.064443
                       uniform          0.064414
             vor2vel   laplace          0.103135
                       lognormal-0.1    0.103044
                       lognormal-0.2    0.103050
                       lognormal-0.5    0.103070
                       lognormal-1.0    0.103645
                       normal   

In [15]:
results.groupby(["time_stride", "obs_type", "dist_name"])["rmse"].std()

time_stride  obs_type  dist_name    
1            arctan3x  laplace          0.004627
                       lognormal-0.1    0.004667
                       lognormal-0.2    0.004669
                       lognormal-0.5    0.004673
                       lognormal-1.0    0.004793
                       normal           0.004664
                       uniform          0.004639
             sin3x     laplace          0.003628
                       lognormal-0.1    0.003639
                       lognormal-0.2    0.003635
                       lognormal-0.5    0.003635
                       lognormal-1.0    0.003745
                       normal           0.003639
                       uniform          0.003605
             vor2vel   laplace          0.004903
                       lognormal-0.1    0.004918
                       lognormal-0.2    0.004914
                       lognormal-0.5    0.004901
                       lognormal-1.0    0.005103
                       normal   