In [None]:
import holoviews as hv
hv.extension("bokeh")
import hvplot.pandas
import pandas as pd
from pathlib import Path
import sys
sys.path.append("/proj/gaia-climate/team/kirill/gaia-surrogate")
from gaia.training import load_hparams_file, get_dataset_from_model, get_checkpoint_file, get_levels
from gaia.data import unflatten_tensor, flatten_tensor
from gaia.config import levels
from gaia.plot import lats, lons, get_land_outline
from gaia.models import TrainingModel
import tqdm.auto as tqdm
import torch

In [None]:
# Map longitudes from [0, 360) into (-180, 180] and sort them, keeping the
# permutation `lon_idx` so longitude-indexed data can be reordered to match.
# Start from the raw longitudes rather than the global `lons`: overwriting
# `lons` in place made this cell non-idempotent (a second run sorted an
# already-sorted list, giving an identity `lon_idx`, so maps were silently
# left unsorted).
from gaia.plot import lons as raw_lons

lons_wrapped = torch.tensor([l if l <= 180 else l - 360 for l in raw_lons])
lon_vals, lon_idx = lons_wrapped.sort()
lons = lon_vals.tolist()
outline = get_land_outline()

## Evaluate each model on each dataset

In [None]:
def get_metrics(y, yhat, reduce_dims=[0, 3], y2=None):
    """Error metrics of ``yhat`` against ``y``, reduced over ``reduce_dims``.

    Returns a dict with, in this order:
      - ``rmse``: square root of the mean squared error,
      - ``std``: population (biased) standard deviation of the reference,
        which is ``y`` by default or ``y2`` when it is given,
      - ``skill``: ``1 - mse / var`` clipped below at 0.
    """
    reference = y if y2 is None else y2
    mse = torch.square(y - yhat).mean(dim=reduce_dims)
    var = reference.var(reduce_dims, unbiased=False)
    skill = torch.clamp(1 - mse / var, min=0)
    return {"rmse": torch.sqrt(mse), "std": torch.sqrt(var), "skill": skill}

In [None]:
skill_cmap = "dimgray"  # colormap used for every panel whose label mentions "skill"


def plot_levels_vs_lats(x, z_name):
    """QuadMesh of ``x`` over latitude (x axis) and spcam level (y axis).

    ``z_name`` labels the value dimension; labels containing "skill" use
    ``skill_cmap``, all others use "fire". The y axis is inverted.
    """
    cmap = skill_cmap if "skill" in z_name else "fire"
    mesh = hv.QuadMesh((lats, levels["spcam"], x), ["lats", "levels"], [z_name])
    return mesh.opts(
        invert_yaxis=True,
        colorbar=True,
        tools=["hover"],
        cmap=cmap,
        width=350,
        height=300,
    )

def plot_lats_vs_lons(x, z_name):
    """QuadMesh of a field indexed ``x[lat, lon]`` on a longitude/latitude grid.

    Columns are reordered with the module-level ``lon_idx`` permutation so
    they line up with the sorted ``lons`` axis. Labels containing "skill" use
    ``skill_cmap``; everything else uses "fire".
    """
    cmap = skill_cmap if "skill" in z_name else "fire"
    x_sorted = x[:, lon_idx]
    mesh = hv.QuadMesh((lons, lats, x_sorted), ["lons", "lats"], [z_name])
    return mesh.opts(
        invert_yaxis=False,
        colorbar=True,
        tools=["hover"],
        cmap=cmap,
        width=400,
        height=300,
    )


def plot_lats_vs_metric(x, z_name):
    """Line plot of a per-latitude metric ``x`` labelled ``z_name``."""
    curve = hv.Curve((lats, x), ["lats"], [z_name])
    return curve.opts(tools=["hover"], width=400, height=300)


In [None]:
from collections import OrderedDict

In [None]:
def make_plots(targets, predictions, output_index, true_predictions = None):
    """Build one metric plot per (variable, metric) pair.

    Args:
        targets: reference tensor; ``predictions`` is scored against it.
        predictions: tensor of the same shape as ``targets``.
        output_index: mapping ``variable -> (start, end)`` giving each
            variable's channel slice (dim 1 of the inputs, judging by the
            ``[s:e]`` slicing after dim 0 is reduced — TODO confirm).
        true_predictions: optional tensor passed to ``get_metrics`` as ``y2``;
            its variance then normalises the skill score instead of ``targets``.

    Multi-level variables (``end - start > 1``) get latitude-vs-level maps of
    metrics reduced over dims 0 and 3. Single-level variables
    (``end - start == 1``) get latitude-vs-longitude maps reduced over dim 0.

    Returns:
        OrderedDict mapping ``(variable, metric_name)`` to a HoloViews element,
        with all multi-level variables before the single-level ones.
    """
    def _z_label(var_name, metric_name):
        # Skill panels share the label "skill"; rmse/std are labelled per variable.
        return "skill" if metric_name == "skill" else f"{var_name}_std_units"

    spans = [(k, s, e) for k, (s, e) in output_index.items()]
    multi_level = [span for span in spans if span[2] - span[1] > 1]
    single_level = [span for span in spans if span[2] - span[1] == 1]

    plots = OrderedDict()

    # Each metric set is only computed when some variable actually needs it.
    if multi_level:
        metric_dict = get_metrics(targets, predictions, reduce_dims=[0, 3], y2=true_predictions)
        for k, s, e in multi_level:
            for metric_name, metric_value in metric_dict.items():
                plots[(k, metric_name)] = plot_levels_vs_lats(
                    metric_value[s:e], _z_label(k, metric_name)
                )

    if single_level:
        metric_dict = get_metrics(targets, predictions, reduce_dims=[0], y2=true_predictions)
        for k, s, e in single_level:
            for metric_name, metric_value in metric_dict.items():
                plots[(k, metric_name)] = plot_lats_vs_lons(
                    metric_value[s:e].squeeze(), _z_label(k, metric_name)
                )

    return plots

def compute_metrics(targets, predictions, output_index, true_predictions = None):
    """Per-channel summary metrics, reduced over dims 0, 2 and 3.

    Returns the ``get_metrics`` dict (``rmse``, ``std``, ``skill``, one value
    per channel) plus a ``variable`` list naming every channel:
    ``"<var>_<level:02>"`` for multi-level outputs, ``"<var>"`` otherwise.

    Raises:
        ValueError: if ``output_index`` does not describe exactly as many
            channels as the metric tensors contain.
    """
    metric_dict = get_metrics(targets, predictions, reduce_dims=[0, 2, 3], y2=true_predictions)

    # Name channels in channel order. Sorting the slices by start index keeps
    # the names aligned with the metric values even if `output_index` was not
    # built in channel order.
    spans = sorted(output_index.items(), key=lambda item: item[1][0])
    names = [f"{k}_{l:02}" if e - s > 1 else k for k, (s, e) in spans for l in range(e - s)]

    n_channels = len(metric_dict["rmse"])
    if len(names) != n_channels:
        raise ValueError(
            f"output_index describes {len(names)} channels but metrics have {n_channels}"
        )
    metric_dict["variable"] = names
    return metric_dict

### Evaluate models trained on [cam4, spcam] on [cam4, spcam]

In [None]:
all_plots = OrderedDict()

for model_name in tqdm.tqdm(["cam4", "spcam"]):
    model_dir = f"../fine-tune/lightning_logs/base_{model_name}"
    # The checkpoint depends only on the model, so load it once per model
    # instead of reloading it for every dataset.
    model = TrainingModel.load_from_checkpoint(get_checkpoint_file(model_dir), map_location="cpu").eval()

    for dataset_name in tqdm.tqdm(["cam4", "spcam"]):
        # cam4 is evaluated on "cam4_paper", spcam on "spcam_fixed".
        dataset = f"{dataset_name}_paper" if "cam4" in dataset_name else f"{dataset_name}_fixed"

        predictions = torch.load(model_dir + f"/predictions_{dataset}.pt")
        test_dataset, test_loader = get_dataset_from_model(model, dataset=dataset)
        targets = unflatten_tensor(test_dataset["y"])
        plots = make_plots(targets, predictions, model.hparams.output_index)

        # Prefix each (variable, metric) key with (model, dataset).
        for k, v in plots.items():
            all_plots[(model_name, dataset_name) + k] = v

In [None]:
import holoviews as hv
hv.extension("bokeh")

In [None]:
from bokeh.themes import built_in_themes

# List the available Bokeh themes, then switch the HoloViews Bokeh renderer
# to "dark_minimal".
print(built_in_themes.keys())
bokeh_renderer = hv.renderer('bokeh')
bokeh_renderer.theme = built_in_themes['dark_minimal']

In [None]:
# Inspect the (model, dataset, variable, metric) keys collected in `all_plots`.
all_plots.keys()

In [None]:
hv.renderer('bokeh').theme = 'caliber'
import numpy as np

# Skill maps for the multi-level variables (everything except PREC*), keyed
# (model, dataset, variable) by dropping the trailing metric name from each key.
# The HoloMap does not depend on the colormap, so it is built once.
temp = hv.HoloMap(
    OrderedDict({k[:-1]: v.opts(ylabel="pressure", width=420)
                 for k, v in all_plots.items() if "PREC" not in k[2] and k[3] == "skill"}),
    sort=False,
    kdims=["model", "dataset", "variable"],
)

# for cmap in ["bgyw","dimgray","bmy","fire"]:
for cmap in ["bgyw"]:
    for v in ["PTTEND", "PTEQ"]:
        temp1 = temp[:, :, v].layout(["model", "dataset"]).cols(2).opts(hv.opts.QuadMesh(cmap=cmap, colorbar=False))
        hv.save(temp1, f"levels_vs_lats_{v}_{cmap}.html")
    # The panels are saved without colorbars, so save a stand-alone one for this cmap.
    hv.save(hv.Image(np.array([[0, 1.], [0, 1.]])).opts(colorbar=True, cmap=cmap), f"colorbar_{cmap}.html")

In [None]:
# temp2 = temp.select(model="cam4",variable="PTEQ")
# from bokeh.io import export_png
# export_png(temp2,filename = "temp1.png")
# hv.save(temp2,"temp2.png")
# temp2

In [None]:
# NOTE(review): recursively deletes a hard-coded, user-specific path —
# presumably a stale chromedriver for bokeh's PNG export; confirm before
# running, and prefer a configurable path.
!rm -R /home/kirill.trapeznikov/chromedriver_path/chromedriver

In [None]:
hv.renderer('bokeh').theme = 'caliber'

# PRECT skill maps with the land outline overlaid, one panel per
# (model, dataset). Only the colormap changes per iteration, so the layout
# is built once outside the loop.
prect_skill = hv.HoloMap(
    OrderedDict({k[:-1]: v.opts(colorbar=False)
                 for k, v in all_plots.items() if "PRECT" in k[2] and k[3] == "skill"}),
    sort=False,
    kdims=["model", "dataset", "variable"],
)
prect_layout = (prect_skill * outline.opts(line_color="black", line_width=1)).layout(["model", "dataset"]).cols(2)

for cmap in ["bgyw", "dimgray", "bmy", "fire"]:
    temp = prect_layout.opts(hv.opts.QuadMesh(cmap=cmap))
    hv.save(temp, f"lons_vs_lats_{cmap}.html")

### Compute top-level performance


In [None]:
import pandas as pd
metric_frames = []

for model_name in tqdm.tqdm(["cam4", "spcam"]):
    model_dir = f"../fine-tune/lightning_logs/base_{model_name}"
    # The checkpoint depends only on the model; load it once per model rather
    # than once per (model, dataset) pair.
    model = TrainingModel.load_from_checkpoint(get_checkpoint_file(model_dir), map_location="cpu").eval()

    for dataset_name in tqdm.tqdm(["cam4", "spcam"]):
        # NOTE(review): this table uses the "_fixed" variant for both datasets,
        # whereas the plotting loop uses "cam4_paper" for cam4 — confirm intended.
        dataset = f"{dataset_name}_fixed"

        predictions = torch.load(model_dir + f"/predictions_{dataset}.pt")
        test_dataset, test_loader = get_dataset_from_model(model, dataset=dataset)
        targets = unflatten_tensor(test_dataset["y"])

        # One row per output channel: rmse / std / skill plus the channel name.
        plots = pd.DataFrame(compute_metrics(targets, predictions, model.hparams.output_index))
        plots["model_name"] = model_name
        plots["dataset_name"] = dataset_name
        metric_frames.append(plots)

all_metrics = pd.concat(metric_frames, ignore_index=True)

In [None]:
import hvplot.pandas

In [None]:
# Split channel names such as "PTEQ_05" into the variable ("PTEQ") and its
# level index ("05"); single-level names like "PRECT" keep the full name in
# both parts, which yields a `level` of None.
split_names = all_metrics["variable"].str.split("_")
all_metrics["variable_top"] = split_names.str[0]
all_metrics["level_number"] = split_names.str[-1]
all_metrics["level"] = all_metrics["level_number"].apply(
    lambda suffix: levels["spcam"][int(suffix)] if suffix.isnumeric() else None
)

# NOTE(review): keeping only the text before the first "_" makes column names
# collide (model_name -> model, variable_top -> variable, level_number -> level).
# Later cells select columns by position, and re-running this cell would fail
# because `variable` then refers to two columns.
all_metrics.columns = [name.split("_")[0] for name in all_metrics.columns]

In [None]:
# Positions 0,1,2,4,5,6,8 are rmse, std, skill, model, dataset, variable (the
# top-level name) and level. Selection is by position because the renamed
# column labels are not unique.
level_metric_cols = [0, 1, 2, 4, 5, 6, 8]
temp = all_metrics.iloc[:, level_metric_cols]
is_precip = temp["variable"].str.startswith("PREC")
temp = temp[~is_precip]
temp.to_csv("benchmark_levels_vs_metrics.csv", index=False)

In [None]:
# def temp_func(model, dataset,variable):
#     hv.Curve(temp.query(f"model=='{model}' & dataset =='{dataset}' & variable =='{variable}'"),["level"],["skill"])*\
#     hv.Curve(temp.query(f"model=='{model}' & dataset =='{dataset}' & variable =='{variable}'"),["level"],["skill"])
    
# hv.DynamicMap(lambda  

In [None]:
import pandas as pd
import holoviews as hv
import hvplot.pandas
hv.extension("bokeh")

# Reload the per-level benchmark table so this section can run on its own.
levels_vs_metrics_csv = "benchmark_levels_vs_metrics.csv"
data = pd.read_csv(levels_vs_metrics_csv)

In [None]:
model_colors = {"cam4": "orange", "spcam": "purple"}


def temp_func(model, dataset, variable):
    """Skill-vs-level curve with point markers for one (model, dataset, variable).

    The line is solid when the model is evaluated on its own dataset and
    dashed otherwise; the color comes from ``model_colors``.
    """
    subset = data.query(f"model=='{model}' &  dataset=='{dataset}' &  variable=='{variable}'")
    color = model_colors[model]
    dash = "dashed" if model != dataset else "solid"
    curve = hv.Curve(subset, ["level"], ["skill"], label=f"{model} on {dataset}").opts(
        color=color, line_dash=dash, line_width=1, show_grid=True
    )
    markers = hv.Scatter(subset, ["level"], ["skill"]).opts(color=color, line_dash=dash, size=5)
    return curve * markers


# First four overlays are PTEQ, last four PTTEND (variable is the outer loop).
plots = [
    temp_func(m, d, v)
    for v in ["PTEQ", "PTTEND"]
    for m in ["cam4", "spcam"]
    for d in ["cam4", "spcam"]
]

pteq_panel = hv.Overlay(plots[:4]).opts(width=400, show_legend=False, title="PTEQ")
pttend_panel = hv.Overlay(plots[4:]).opts(
    width=530, title="PTTEND", legend_position="right", legend_opts={"title": "model on dataset"}
)
plots = pteq_panel + pttend_panel
hv.save(plots, "level_vs_skill.html")
plots



In [None]:
# Color by model using the same palette as the skill-vs-level figure
# (`model_colors`); this cell used to color cam4 blue/spcam orange, which
# contradicted that figure. Dash pattern is chosen by dataset.
data["color_field"] = data.model.map(model_colors)
data["line_field"] = data.dataset.apply(lambda a: [0, 1] if a == "cam4" else [1, 1])

In [None]:
# Each line_field entry is a list, and lists are unhashable, so
# Series.unique() raises TypeError on them; compare as tuples instead.
data["line_field"].apply(tuple).unique()

In [None]:
# Skill vs level, one line per (model, dataset), with a widget per variable.
skill_line_opts = dict(
    x="level",
    y="skill",
    by=["model", "dataset"],
    line_width=1,
    color="color_field",
    groupby=["variable"],
    grid=True,
)
data.hvplot.line(**skill_line_opts)


In [None]:
def _as_python_scalar(value):
    # Unwrap single-element torch tensors into plain Python numbers for CSV output.
    return value.item() if torch.is_tensor(value) else value


all_metrics = all_metrics.applymap(_as_python_scalar)
all_metrics.to_csv("benchmarks.csv", index=False)

In [None]:
benchmarks_markdown = all_metrics.to_markdown(index=False)
print(benchmarks_markdown)

### Compare predictions of the cam4- and spcam-trained models on cam4 inputs and on spcam inputs


In [None]:
all_plots = OrderedDict()

cam4_dir = "../fine-tune/lightning_logs/base_cam4"
spcam_dir = "../fine-tune/lightning_logs/base_spcam"

# Only the spcam checkpoint is used (for its output_index), so load it once
# instead of reloading both checkpoints for every dataset.
# NOTE(review): assumes both models share the same output_index layout — confirm.
model = TrainingModel.load_from_checkpoint(get_checkpoint_file(spcam_dir), map_location="cpu").eval()
output_index = model.hparams.output_index

for dataset_name in tqdm.tqdm(["cam4", "spcam"]):
    dataset = f"{dataset_name}_fixed"
    cam4_preds = torch.load(cam4_dir + f"/predictions_{dataset}.pt")
    spcam_preds = torch.load(spcam_dir + f"/predictions_{dataset}.pt")

    # Score each model's predictions with the other's as the reference, both ways.
    for model_name, reference, candidate in [
        ("cam4_vs_spcam", cam4_preds, spcam_preds),
        ("spcam_vs_cam4", spcam_preds, cam4_preds),
    ]:
        plots = make_plots(reference, candidate, output_index)
        for k, v in plots.items():
            all_plots[(model_name, dataset_name) + k] = v

In [None]:
hv.renderer('bokeh').theme = 'caliber'

# Cross-model maps for the multi-level (non-PREC) variables, one panel per
# (model, metric) with dataset/variable left as widgets.
cross_levels = OrderedDict((key, plot) for key, plot in all_plots.items() if "PREC" not in key[2])
temp = hv.HoloMap(cross_levels, sort=False, kdims=["model", "dataset", "variable", "metric"])
temp = temp.layout(["model", "metric"]).cols(3)
hv.save(temp, "levels_vs_lats_cross.html")

In [None]:
# Cross-model maps for the single-level (PREC*) variables with the land outline.
temp = hv.HoloMap(
    OrderedDict({k: v for k, v in all_plots.items() if "PREC" in k[2]}),
    sort=False,
    kdims=["model", "dataset", "variable", "metric"],
)
# Close the overlay before laying it out (this line previously had an
# unbalanced parenthesis, a SyntaxError).
temp = (temp * outline.opts(color="black", line_width=1)).layout(["model", "metric"]).cols(3)
hv.save(temp, "lons_vs_lats_cross.html")
temp

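### Normalize skill by the ground truth

In this section the skill denominator is the variance of the true targets (`true_predictions`), not that of the cam4-model predictions used as the reference.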

In [None]:
all_plots = OrderedDict()

# Ground-truth targets per dataset, kept for later inspection.
dataset_temp = {}

model_dirs = {name: f"../fine-tune/lightning_logs/base_{name}" for name in ["cam4", "spcam"]}
# Checkpoints do not depend on the dataset, so load each one once.
models = {
    name: TrainingModel.load_from_checkpoint(get_checkpoint_file(d), map_location="cpu").eval()
    for name, d in model_dirs.items()
}
model = models["spcam"]

for dataset_name in tqdm.tqdm(["cam4", "spcam"]):
    dataset = f"{dataset_name}_paper" if "cam4" in dataset_name else f"{dataset_name}_fixed"

    # cam4-model predictions are the reference ("targets"); spcam-model
    # predictions are scored against them.
    targets = torch.load(model_dirs["cam4"] + f"/predictions_{dataset}.pt")
    predictions = torch.load(model_dirs["spcam"] + f"/predictions_{dataset}.pt")

    # Truth comes from the model trained on this dataset, loaded with the same
    # `dataset` the prediction files were produced for (previously the bare
    # `dataset_name` was passed, and only the cam4 truth was stored).
    test_dataset, test_loader = get_dataset_from_model(models[dataset_name], dataset=dataset)
    true_predictions = unflatten_tensor(test_dataset["y"])
    dataset_temp[dataset_name] = true_predictions

    plots = make_plots(targets, predictions, model.hparams.output_index, true_predictions=true_predictions)

    model_name = "cam4_vs_spcam"
    for k, v in plots.items():
        all_plots[(model_name, dataset_name) + k] = v

In [None]:
hv.renderer('bokeh').theme = 'caliber'

# Truth-normalised skill maps for the multi-level (non-PREC) variables, keyed
# (model, dataset, variable). The HoloMap does not depend on the colormap, so
# it is built once.
truth_norm_levels = hv.HoloMap(
    OrderedDict({k[:-1]: v.opts(colorbar=False, ylabel="pressure", width=420)
                 for k, v in all_plots.items() if "PREC" not in k[2] and k[3] == "skill"}),
    sort=False,
    kdims=["model", "dataset", "variable"],
)

for cmap in ["bgyw", "dimgray", "bmy", "fire"]:
    # One panel per (model, dataset, variable); "model" was previously listed
    # twice here in place of "variable".
    temp = truth_norm_levels.layout(["model", "dataset", "variable"]).cols(4)
    temp = temp.opts(hv.opts.QuadMesh(cmap=cmap, colorbar=False))
    # Save the layout built in this cell (this previously saved `temp1`, a
    # stale figure left over from an earlier cell).
    hv.save(temp, f"levels_vs_lats_cross_norm_truth_{cmap}.html")

In [None]:
hv.renderer('bokeh').theme = 'caliber'

# Truth-normalised PRECT skill maps with the land outline, one panel per
# (model, dataset). Only the colormap changes per iteration.
prect_truth_norm = hv.HoloMap(
    OrderedDict({k[:-1]: v.opts(colorbar=False)
                 for k, v in all_plots.items() if "PRECT" in k[2] and k[3] == "skill"}),
    sort=False,
    kdims=["model", "dataset", "variable"],
)
prect_layout = (prect_truth_norm * outline.opts(color="black", line_width=1)).layout(["model", "dataset"]).cols(2)

for cmap in ["bgyw", "dimgray", "bmy", "fire"]:
    temp = prect_layout.opts(hv.opts.QuadMesh(cmap=cmap, colorbar=False))
    hv.save(temp, f"lons_vs_lats_cross_norm_truth_{cmap}.html")

In [None]:
### dataset size
import glob

# For every saved dataset file, print its path and the shapes of the "x"
# and "y" tensors it contains (when present).
for path in glob.glob("/ssddg1/gaia/fixed/*.pt"):
    print(path)
    temp = torch.load(path)
    for key in ("x", "y"):
        if key in temp:
            print(temp[key].shape)

In [None]:
### save the top-level performance table (model x dataset test results)

frames = []
for model_name in tqdm.tqdm(["cam4", "spcam"]):
    for dataset_name in tqdm.tqdm(["cam4", "spcam"]):
        model_dir = f"../fine-tune/lightning_logs/base_{model_name}"
        dataset = f"{dataset_name}_paper" if "cam4" in dataset_name else f"{dataset_name}_fixed"
        frame = pd.read_json(f"{model_dir}/test_results_{dataset}.json")
        frame["model"] = model_name
        frame["dataset"] = dataset_name
        frames.append(frame)

out = pd.concat(frames, ignore_index=True)
# Transpose so metrics are rows, and reverse the row order.
summary = out.T.iloc[::-1]
summary.to_csv("top_level_performance.csv")
print(summary.to_markdown())

In [None]:
# `out` is a DataFrame after pd.concat, so `out[-1]` looks up a column named
# -1 and raises KeyError; use positional row access for the last result row.
out.iloc[-1]

In [None]:
# NOTE(review): `temp` here is whatever the dataset-size loop loaded last (the
# final file matched by its glob); this raises KeyError if that file has no "x".
temp["x"].shape