In [None]:
import holoviews as hv
hv.extension("bokeh")
import hvplot.pandas
import pandas as pd
from pathlib import Path
import sys
sys.path.append("/proj/gaia-climate/team/kirill/gaia-surrogate")
from gaia.training import load_hparams_file


In [None]:
import tqdm.auto as tqdm

In [None]:
def get_metrics(model_dir, dataset = "spcam"):
    """Load one run's metrics and annotate them with training hyper-parameters.

    Parameters
    ----------
    model_dir : str or pathlib.Path
        Run directory. It is searched (non-recursively) for the first entry
        whose name contains ``dataset``; that file is parsed with
        ``pd.read_json``. The same directory is handed to
        ``load_hparams_file``.
    dataset : str, default "spcam"
        Substring used to pick the metrics file inside ``model_dir``.

    Returns
    -------
    pd.Series
        The metrics read from the file plus the entries ``subsample``,
        ``batch_size``, ``is_finetuned``, ``base_model``, ``samples`` and
        ``lr`` (appended in that order). An empty Series is returned when no
        metrics file is found or it cannot be parsed.
    """
    try:
        metrics_file = next(Path(model_dir).glob(f"*{dataset}*"))
        metrics = pd.read_json(metrics_file).T.squeeze()
    except Exception:
        # Deliberate best-effort: runs without a readable metrics file are
        # represented by an empty Series so the caller can drop them.
        # An explicit dtype avoids the pandas 1.x "default dtype for empty
        # Series" DeprecationWarning and gives the same dtype on every version.
        return pd.Series(dtype=object)

    hparams = load_hparams_file(model_dir)
    train_params = hparams["dataset_params"]["train"]

    # NOTE: the order of these assignments fixes the order of the resulting
    # columns, which downstream code slices positionally.
    metrics["subsample"] = int(train_params.get("subsample", 1))
    metrics["batch_size"] = int(train_params.get("batch_size", 1))
    metrics["is_finetuned"] = int(hparams.get("is_finetuned", False))
    metrics["base_model"] = "random"
    # 24 * 96 * 144 * 54 is presumably the number of samples in the full
    # training set (TODO confirm what each factor stands for); dividing by
    # the subsample factor gives the number of samples actually used.
    metrics["samples"] = 24 * 96 * 144 * 54 // metrics["subsample"]
    metrics["lr"] = hparams["lr"]

    if "base_cam4" in str(model_dir):
        # Runs whose path contains "base_cam4" are always treated as
        # fine-tuned and are pinned to a fixed sample count of 10
        # (NOTE(review): looks like a placeholder x-position — confirm).
        metrics["is_finetuned"] = True
        metrics["samples"] = 10

    if metrics["is_finetuned"]:
        metrics["base_model"] = "cam4"

    return metrics

# Build one single-row frame per run directory under ./lightning_logs and
# stack them; rows containing any missing value are discarded.
run_frames = []
for run_dir in Path("lightning_logs").glob("*"):
    run_frames.append(get_metrics(run_dir).to_frame().T)
data = pd.concat(run_frames).dropna()

# Drop the first column (selected by position, not by name).
data = data.iloc[:, 1:]

# Strip the common metric prefix so columns read e.g. "PTEQ_00".
data.columns = [name.replace("test_skill_ave_trunc_", "") for name in data.columns]





In [None]:
from gaia.plot import levels, levels26


In [None]:
# Display the `levels` values (imported from gaia.plot), rounded to two
# decimals, as a JSON string.
pd.Series(levels).round(2).to_json()

In [None]:
# Keep only the runs whose training batch size is at least 64.
# NOTE: this rebinds `data`; the unfiltered frame is only recoverable by
# re-running the cell that builds it.
data = data.query("batch_size>=64")  # scratch: .subsample.unique().astype(int)

In [None]:
# data.subsample.drop_duplicates().astype(int).to_csv("subsample.csv")

In [None]:
# Per-variable summary columns: the row-wise mean over the per-level columns.
# For PTEQ the first `min_level` matching columns are excluded from the mean.
min_level = 11

pteq_levels = data.filter(like="PTEQ_")
data["PTEQ"] = pteq_levels.iloc[:, min_level:].mean(axis=1)

pttend_levels = data.filter(like="PTTEND_")
data["PTTEND"] = pttend_levels.mean(axis=1)


In [None]:
def plot_one_with_error_bars(base_model, metric):
    """Plot `metric` against `samples` for the runs of one base model.

    Reads the module-level `data` frame, keeps the rows whose `base_model`
    equals the given value, and returns an overlay of

    * a scatter of every individual run, and
    * a curve through the per-`samples` mean of `metric` (log x-axis,
      grid shown, value range fixed to (0, 1)).

    The per-`samples` standard deviation is computed as well but is not
    drawn at the moment (the error-bar layer is disabled).
    """
    runs = data.query(f"base_model=='{base_model}'")
    per_sample_stats = (
        runs.groupby(["samples"])[metric]
        .agg(["mean", "std"])
        .rename(columns={"mean": metric})
    )

    points = hv.Scatter(
        runs, kdims=["samples"], vdims=[metric], label=base_model
    ).opts(size=2)

    mean_line = (
        hv.Curve(
            per_sample_stats.reset_index(),
            kdims=["samples"],
            vdims=[metric],
            label=base_model,
        )
        .opts(logx=True, line_width=.5, show_grid=True)
        .redim.range(**{metric: (0, 1)})
    )

    # Disabled layers kept for reference:
    #   hv.ErrorBars(per_sample_stats.reset_index(), kdims=["samples"], vdims=[metric, "std"], label=base_model).opts(line_width=.5)
    #   hv.Scatter(per_sample_stats.reset_index(), kdims=["samples"], vdims=[metric], label=base_model).opts(size=3)
    return points * mean_line



# One panel per summary metric, each overlaying the "cam4" and "random"
# base models with the legend in the bottom-right corner.
summary_panels = []
for v in ["PRECT", "PTTEND", "PTEQ"]:
    comparison = plot_one_with_error_bars("cam4", v) * plot_one_with_error_bars("random", v)
    summary_panels.append(comparison.opts(legend_position="bottom_right"))

out = hv.Layout(summary_panels)
# hv.save(out,"plot_finetune.html")
out


In [None]:
# Display how many entries `levels` has (the per-level plots below index
# into it with i in range(num_levels)).
len(levels)

In [None]:
# Grid of small panels, one per level column of `var`, comparing the "cam4"
# and "random" base models. The result is written to an HTML file and shown.
var = "PTTEND"
num_levels = 30

level_panels = []
for i in range(num_levels):
    column = f"{var}_{i:02}"
    cam4_plot = plot_one_with_error_bars("cam4", column)
    # The widened value range is applied to the "random" overlay only.
    random_plot = plot_one_with_error_bars("random", column).redim.range(**{column: (-.05, 1.05)})
    level_panels.append(
        (cam4_plot * random_plot).opts(
            legend_position="bottom_right",
            title=f"{i:02}: {levels[i]:.2f}",
            width=250,
            height=250,
        )
    )

out = hv.Layout(level_panels).cols(5)
hv.save(out, f"plot_finetune_{var}.html")
out