In [None]:
# Plotting stack: HoloViews rendered with the Bokeh backend; hvplot.pandas adds the
# .hvplot accessor to pandas objects (used for the density plots further down).
import holoviews as hv
hv.extension("bokeh")
import hvplot.pandas
import pandas as pd
from pathlib import Path
import sys
# Make the local gaia-surrogate checkout importable; must run before the gaia.* imports.
sys.path.append("/proj/gaia-climate/team/kirill/gaia-surrogate")
from gaia.training import load_hparams_file, get_dataset_from_model, get_checkpoint_file, get_levels
from gaia.data import unflatten_tensor, flatten_tensor
from gaia.config import levels
from gaia.plot import lats, lons, get_land_outline
from gaia.models import TrainingModel
# NOTE: `tqdm` is bound to the tqdm.auto module, so progress bars are created as tqdm.tqdm(...).
import tqdm.auto as tqdm
import torch
import numpy as np
import hvplot

In [None]:
# Base directory of the fine-tuning Lightning logs; each model lives in its own sub-directory.
LOG_DIR = "/proj/gaia-climate/team/kirill/gaia-paper/fine-tune/lightning_logs"
model_dir_cam4 = f"{LOG_DIR}/base_cam4"
model_dir_spcam = f"{LOG_DIR}/base_spcam"

# Load both checkpoints onto the CPU and switch them to eval mode; in this notebook they are
# only inspected (hyper-parameters, normalization statistics, dataset config).
model_cam4 = TrainingModel.load_from_checkpoint(get_checkpoint_file(model_dir_cam4), map_location="cpu").eval()
model_spcam = TrainingModel.load_from_checkpoint(get_checkpoint_file(model_dir_spcam), map_location="cpu").eval()


In [None]:
# List the variable names in the CAM4 model's input and output feature indices.
for kind in ("input", "output"):
    index = model_cam4.hparams[f"{kind}_index"]
    print(kind, ",".join(index))

### Plot input normalization statistics (CAM4 vs SPCAM)


In [None]:
# Compare the input normalization statistics (per-feature mean and std) of the CAM4 and SPCAM
# models. For every input variable spanning more than one feature (one feature per level in
# levels["spcam"]), plot x = SPCAM statistic vs y = CAM4 statistic, one point per level coloured
# by its level value, over the identity line y = x: points off the line are levels the two
# models normalize differently. Single-feature variables are skipped.
assert model_cam4.hparams["input_index"] == model_spcam.hparams["input_index"]

plots = {}

cam4_mean = model_cam4.input_normalize.mean.squeeze().numpy()
spcam_mean = model_spcam.input_normalize.mean.squeeze().numpy()

cam4_std = model_cam4.input_normalize.std.squeeze().numpy()
spcam_std = model_spcam.input_normalize.std.squeeze().numpy()

levels_30 = np.array(levels["spcam"])

for k, v in model_cam4.hparams["input_index"].items():
    s, e = v
    if e - s <= 1:
        print(k, "not vector")
        continue
    print(k, v)
    assert e - s == len(levels_30), f"{k}: {e - s} features but {len(levels_30)} levels"

    for stat, cam4_vals, spcam_vals in (
        ("mean", cam4_mean[s:e], spcam_mean[s:e]),
        ("std", cam4_std[s:e], spcam_std[s:e]),
    ):
        x_dim, y_dim = f"spcam_{k}_{stat}", f"cam4_{k}_{stat}"
        # Identity line spanning the values of BOTH models so every point can be judged against it.
        line_lo = min(cam4_vals.min(), spcam_vals.min())
        line_hi = max(cam4_vals.max(), spcam_vals.max())
        identity = hv.Curve(([line_lo, line_hi], [line_lo, line_hi]), [x_dim], [y_dim])
        # Data order must match the dimension labels: x = SPCAM values, y = CAM4 values.
        points = hv.Scatter((spcam_vals, cam4_vals, levels_30), [x_dim], [y_dim, "levels"])
        plots[(k, stat)] = identity * points

stat_comp_plot = (
    hv.HoloMap(plots, kdims=["variable", "stat"])
    .opts(
        hv.opts.Scatter(padding=.1, tools=["hover"], show_grid=True, size=10, color="levels",
                        cmap="PuOr", width=500, height=500, colorbar=True),
        hv.opts.Curve(line_width=.5),
    )
    .opts(axiswise=False)
    .layout(["variable", "stat"])
    .cols(2)
    .opts(title="Input Stats")
)
hv.save(stat_comp_plot, "stat_comp_plot_input_3d_vars.html")
stat_comp_plot



### Plot output normalization statistics (CAM4 vs SPCAM)


In [None]:
### plot output normalization
# Compare the output normalization statistics (per-feature mean and std) of the CAM4 and SPCAM
# models. For every output variable spanning more than one feature (one feature per level in
# levels["spcam"]), plot x = SPCAM statistic vs y = CAM4 statistic, one point per level coloured
# by its level value, over the identity line y = x. Only the std panels are kept in the
# saved/displayed layout (see the .select below).
assert model_cam4.hparams["output_index"] == model_spcam.hparams["output_index"]

plots = {}

cam4_mean = model_cam4.output_normalize.mean.squeeze().numpy()
spcam_mean = model_spcam.output_normalize.mean.squeeze().numpy()

cam4_std = model_cam4.output_normalize.std.squeeze().numpy()
spcam_std = model_spcam.output_normalize.std.squeeze().numpy()

levels_30 = np.array(levels["spcam"])

for k, v in model_cam4.hparams["output_index"].items():
    s, e = v
    if e - s <= 1:
        continue
    print(k, v)
    assert e - s == len(levels_30), f"{k}: {e - s} features but {len(levels_30)} levels"

    for stat, cam4_vals, spcam_vals in (
        ("mean", cam4_mean[s:e], spcam_mean[s:e]),
        ("std", cam4_std[s:e], spcam_std[s:e]),
    ):
        x_dim, y_dim = f"spcam_{k}_{stat}", f"cam4_{k}_{stat}"
        # Identity line spanning the values of BOTH models so every point can be judged against it.
        line_lo = min(cam4_vals.min(), spcam_vals.min())
        line_hi = max(cam4_vals.max(), spcam_vals.max())
        identity = hv.Curve(([line_lo, line_hi], [line_lo, line_hi]), [x_dim], [y_dim])
        # Data order must match the dimension labels: x = SPCAM values, y = CAM4 values.
        points = hv.Scatter((spcam_vals, cam4_vals, levels_30), [x_dim], [y_dim, "levels"])
        plots[(k, stat)] = identity * points

stat_comp_plot = (
    hv.HoloMap(plots, kdims=["variable", "stat"])
    .opts(
        hv.opts.Scatter(padding=.1, tools=["hover"], show_grid=True, size=10, color="levels",
                        cmap="PuOr", width=500, height=500, colorbar=True),
        hv.opts.Curve(line_width=.5),
    )
    .opts(axiswise=False)
    .layout(["variable"])
    .select(stat="std")
    .cols(1)
    .opts(title="Output Stats")
)
hv.save(stat_comp_plot, "stat_comp_plot_output_3d_vars.html")
stat_comp_plot


### Compare variable distributions in the CAM4 and SPCAM test sets


In [None]:
# Test-split datasets for both models; the distribution-comparison loop needs both
# test_dataset_cam4 and test_dataset_spcam.
# The SPCAM set is built from the SPCAM model's own data config.
test_dataset_spcam, test_loader = get_dataset_from_model(model_spcam, split="test")
# The CAM4 set is read directly from the pre-built "fixed" export, so there is no point in also
# building it through get_dataset_from_model (that result would be thrown away).
# NOTE(review): the comparison loop slices these tensors as [:, start:end] (sample, feature);
# the matching *_val.pt file is indexed with more dimensions elsewhere — confirm the layout.
test_dataset_cam4 = torch.load("/ssddg1/gaia/fixed/cam4-famip-30m-timestep_4_test.pt")

In [None]:
# Inspect the "index" entry stored in the CAM4 test file (meaning not documented here — TODO describe).
test_dataset_cam4["index"]

In [None]:
# CAM4 validation file: the *_val.pt counterpart of the *_test.pt file loaded earlier.
data = torch.load("/ssddg1/gaia/fixed/cam4-famip-30m-timestep_4_val.pt")

In [None]:
# Product of the first two dims of the validation inputs, times 10, relative to
# 96 * 144 * 365 * 3. NOTE(review): presumably a 96 x 144 lat/lon grid over 3 years of days —
# confirm, and document what the factor 10 stands for.
10 * data["x"].shape[0] * data["x"].shape[1] / (96 * 144 * 365 * 3)

In [None]:
# Same file as `data` (cam4-famip-30m-timestep_4_val.pt, loaded earlier and never modified);
# reuse it instead of reading the file from disk a second time.
temp = data

In [None]:
# Maximum of temp["x"] at position 0 of dim 1 and index 134 of dim 2.
# NOTE(review): presumably dim 2 is the feature axis and 134 a specific variable — confirm which.
temp["x"][:,0,134:135,...].max()

In [None]:
import xarray as xr

In [None]:
# Raw CAM4 h1 history file starting 1979-01-01 (cam4_upload4 directory), loaded into memory.
temp = xr.load_dataset("/proj/gaia-climate/data/cam4_upload4/rF_AMIP_CN_CAM4--torch-test.cam2.h1.1979-01-01-00000.nc")

In [None]:
# Download the CAM4 h1 history file starting 1981-12-18 from S3 into cam4_upload5.
! ~/aws-cli/bin/aws s3 cp s3://ff350d3a-89fc-11ec-a398-ac1f6baca408/cam4-famip-30m-timestep-third-upload/rF_AMIP_CN_CAM4--torch-test.cam2.h1.1981-12-18-00000.nc /proj/gaia-climate/data/cam4_upload5/

In [None]:
# Load the 1981-12-18 history file downloaded from S3 into cam4_upload5.
temp2 = xr.load_dataset("/proj/gaia-climate/data/cam4_upload5/rF_AMIP_CN_CAM4--torch-test.cam2.h1.1981-12-18-00000.nc")

In [None]:
# Register the .hvplot accessor on xarray objects (used on the netCDF datasets in this section).
# NOTE(review): holoviews and the bokeh extension were already set up in the first cell; the
# re-import and second hv.extension call are redundant.
import hvplot.xarray
import holoviews as hv
hv.extension("bokeh")

In [None]:
# Names of every variable (data variables and coordinates) in the 1981 netCDF file.
print([*temp2.variables])

In [None]:
# Quick image plot of the 1979 netCDF dataset.
# NOTE(review): no variable is selected, so what gets rendered depends on hvplot's defaults —
# consider selecting one explicitly, e.g. temp["FSNS"].hvplot.image().
temp.hvplot.image()

In [None]:
# Time coordinate of the 1979 netCDF dataset.
temp.time

In [None]:
# Maximum of FSNS over the lat/lon grid for each remaining index (presumably time), as a scatter.
temp["FSNS"].max(dim = ["lat","lon"]).hvplot.scatter(size = 100,padding = .1)

In [None]:
# Input variables whose names start with "F", mapped to their (start, end) slice in the
# flattened input feature vector.
{k: (s, e) for k, (s, e) in model_cam4.hparams["input_index"].items() if k.startswith("F")}

In [None]:
import tqdm.auto as tqdm

In [None]:
# Long-format table of sampled raw test values for every input ("x") and output ("y") variable
# of both models, used by the distribution plots that follow. Columns:
#   val, variable, model, type ("input"/"output"), and pressure — the level value from
#   levels["spcam"] for multi-feature variables, NaN (after the concat) for single-feature ones.
# Requires test_dataset_cam4 and test_dataset_spcam to be loaded.
N_SAMPLES = 100_000  # rows sampled per variable (capped at the number of available rows)
SAMPLE_SEED = 0      # fixed so the sampled distributions are identical across re-runs;
                     # for variables sliced from the same tensor the same rows are drawn

test_datasets = {"cam4": test_dataset_cam4, "spcam": test_dataset_spcam}
frames = []

for model_name, test_dataset in test_datasets.items():
    for split_key, type_name in (("x", "input"), ("y", "output")):
        # Both models share the same feature index layout (asserted in the normalization
        # cells), so the CAM4 index is used to slice either dataset.
        index_items = list(model_cam4.hparams[f"{type_name}_index"].items())
        for k, (s, e) in tqdm.tqdm(index_items, desc=f"{model_name} {type_name}"):
            is_profile = e - s > 1
            frame = pd.DataFrame(
                test_dataset[split_key][:, s:e].numpy(),
                columns=levels["spcam"] if is_profile else ["val"],
            )
            frame = frame.sample(n=min(N_SAMPLES, len(frame)), random_state=SAMPLE_SEED)
            if is_profile:
                # One column per level -> one row per (sample, level).
                frame = frame.melt(var_name="pressure", value_name="val")
            frames.append(frame.assign(variable=k, model=model_name, type=type_name))

dfs = pd.concat(frames, ignore_index=True)

In [None]:
# Number of rows without a pressure level (they come from single-feature variables).
dfs["pressure"].isna().sum()

In [None]:
# Per-model density of every (variable, pressure level) pair for the multi-level variables,
# laid out by variable (pressure remains a HoloMap dimension) and saved to HTML.
profile_rows = dfs.loc[dfs["pressure"].notna()]
temp = (
    profile_rows.hvplot.density("val", by="model", groupby=["variable", "pressure"])
    .opts(width=400, height=300, legend_position="top_right")
    .layout("variable")
    .opts(shared_axes=False)
    .cols(2)
)
hv.save(temp, "temp.html")  # , widget_location="top"


In [None]:
# Sanity check: both models share the same input feature layout (also asserted in the
# input-normalization cell).
model_cam4.hparams["input_index"] == model_spcam.hparams["input_index"]

In [None]:
# Names of the single-feature variables, i.e. those whose rows carry no pressure level.
vars_2d = dfs.loc[dfs["pressure"].isna(), "variable"].unique()

In [None]:
# Mean sampled value of each single-feature variable per model, printed as a Markdown table
# (DataFrame.to_markdown needs the optional `tabulate` package).
vars_2d_means = (
    dfs.loc[dfs["variable"].isin(vars_2d)]
    .groupby(["variable", "model"])[["val"]]
    .mean()
)
print(vars_2d_means.to_markdown())

In [None]:
# Per-model density of each single-feature variable, one panel per variable with independent axes.
temp = (
    dfs.loc[dfs["variable"].isin(vars_2d)]
    .hvplot.density("val", by="model", groupby=["variable"])
    .opts(width=400, height=300, legend_position="top_right")
    .layout("variable")
    .opts(shared_axes=False)
    .cols(2)
)
# To export: hv.save(temp, "temp2.html")  # , widget_location="top"
temp

In [None]:
# Upload the CAM4 FSNS page to the cam4_vs_spcam_comparison prefix on S3.
# NOTE(review): FSNS_cam4.html is not written by any cell in this notebook (the HTML files saved
# here are stat_comp_plot_input_3d_vars.html, stat_comp_plot_output_3d_vars.html and temp.html),
# so this relies on a file produced elsewhere — add the cell that creates it or drop this upload.
!~/aws-cli/bin/aws s3 cp FSNS_cam4.html s3://855da60d-505b-4eee-942c-e19fb87dcc5f/gaia/cam4_vs_spcam_comparison/FSNS_cam4.html