In [None]:
import holoviews as hv
hv.extension("bokeh")
import hvplot.pandas
import pandas as pd
from pathlib import Path
import sys
sys.path.append("/proj/gaia-climate/team/kirill/gaia-surrogate")
from gaia.training import load_hparams_file, get_dataset_from_model, get_checkpoint_file, get_levels
from gaia.data import unflatten_tensor, flatten_tensor
from gaia.config import levels
from gaia.plot import lats, lons
from gaia.models import TrainingModel
import tqdm.auto as tqdm
import torch

## model evaluated on each dataset

In [None]:
def get_metrics(y, yhat, reduce_dims=(0, 3)):
    """Compute RMSE, target standard deviation and skill of ``yhat`` against ``y``.

    Args:
        y: target tensor.
        yhat: prediction tensor; it is subtracted from ``y`` element-wise.
        reduce_dims: dimension (or sequence of dimensions) that the statistics
            are reduced over. The returned tensors keep the remaining
            dimensions. Defaults to ``(0, 3)``; a list such as ``[0, 3]`` is
            accepted as well.

    Returns:
        dict with three tensors:
            ``"rmse"``  - square root of the mean squared error,
            ``"std"``   - square root of the biased (population) variance of ``y``,
            ``"skill"`` - ``1 - mse / var`` with negative values clipped to 0.

    Entries where the variance of ``y`` is zero are not special-cased: the
    ratio ``mse / var`` is evaluated as-is by torch.
    """
    mse = (y - yhat).square().mean(dim=reduce_dims)
    # unbiased=False so that var uses the same 1/N normalisation as the mse above.
    var = y.var(reduce_dims, unbiased=False)
    # A prediction worse than "always predict the mean" is reported as 0 skill.
    skill = (1 - mse / var).clip(min=0)
    return dict(rmse=mse.sqrt(), std=var.sqrt(), skill=skill)


    
# Example: metrics = get_metrics(targets, predictions)  # -> dict with keys "rmse", "std", "skill"

In [None]:
def plot_levels_vs_lats(x,z_name):
    """Draw ``x`` as a QuadMesh with latitude on the x-axis and level on the y-axis.

    ``z_name`` labels the value dimension. When it contains "skill" the
    "Greens" colormap is used, otherwise "Oranges".
    """
    cmap = "Greens" if "skill" in z_name else "Oranges"

    # The y coordinates are always levels["spcam"], whatever data ``x`` holds.
    mesh = hv.QuadMesh((lats, levels["spcam"], x), ["lats","levels"], [z_name])

    style = dict(
        invert_yaxis = True,
        colorbar = True,
        tools = ["hover"],
        cmap = cmap,
        width = 350,
        height = 300,
    )
    return mesh.opts(**style)

def plot_lats_vs_lons(x, z_name):
    """Draw ``x`` as a QuadMesh with longitude on the x-axis and latitude on the y-axis.

    ``z_name`` labels the value dimension. When it contains "skill" the
    "Greens" colormap is used, otherwise "Oranges".
    """
    cmap = "Greens" if "skill" in z_name else "Oranges"

    mesh = hv.QuadMesh((lons, lats, x), ["lons","lats"], [z_name])

    style = dict(
        # NOTE(review): inverting the y-axis flips the latitude axis of the map;
        # this matches the level plot's option - confirm it is intended here.
        invert_yaxis = True,
        colorbar = True,
        tools = ["hover"],
        cmap = cmap,
        width = 400,
        height = 300,
    )
    return mesh.opts(**style)

def plot_lats_vs_metric(x, z_name):
    """Draw ``x`` as a curve over latitude, with ``z_name`` as the value dimension."""
    curve = hv.Curve((lats, x), ["lats"], [z_name])
    return curve.opts(tools = ["hover"], width = 400, height = 300)


In [None]:
from collections import OrderedDict

In [None]:
def make_plots(targets, predictions, output_index):
    """Build one metric plot per (output variable, metric) pair.

    Variables that span more than one channel in ``output_index`` get a
    level-vs-latitude plot of metrics reduced over dims 0 and 3. Variables
    that span exactly one channel get a latitude-vs-longitude plot of metrics
    reduced over dim 0 only.

    Args:
        targets: target tensor passed to ``get_metrics``.
            Assumed to be laid out as (sample, channel, lat, lon) - TODO confirm.
        predictions: prediction tensor with the same layout as ``targets``.
        output_index: mapping of variable name -> ``(start, end)`` channel slice.

    Returns:
        OrderedDict keyed by ``(variable, metric_name)``. All multi-channel
        variables come first, followed by the single-channel ones, each group
        in ``output_index`` order.
    """

    def _z_label(var_name, metric_name):
        # Skill is shared across variables; rmse/std are labelled per variable.
        return "skill" if metric_name == "skill" else f"{var_name}_std_units"

    plots = OrderedDict()

    # Multi-channel variables: level vs latitude.
    profile_metrics = get_metrics(targets, predictions, reduce_dims = [0,3])
    for k, (s, e) in output_index.items():
        if e - s > 1:
            for metric_name, metric_value in profile_metrics.items():
                plots[(k, metric_name)] = plot_levels_vs_lats(
                    metric_value[s:e], _z_label(k, metric_name)
                )

    # Single-channel variables: latitude vs longitude.
    map_metrics = get_metrics(targets, predictions, reduce_dims = [0])
    for k, (s, e) in output_index.items():
        if e - s == 1:
            for metric_name, metric_value in map_metrics.items():
                # Drop only the length-1 channel axis; a bare squeeze() would
                # also drop any other axis that happens to have length 1.
                plots[(k, metric_name)] = plot_lats_vs_lons(
                    metric_value[s:e].squeeze(0), _z_label(k, metric_name)
                )

    return plots
        
        
    
    

### Evaluate models trained on [cam4, spcam] on each of the [cam4, spcam] datasets

In [None]:
# Collect every plot under the key (model, dataset, variable, metric).
all_plots = OrderedDict()

for model_name in tqdm.tqdm(["cam4","spcam"]):
    # The log directory depends only on the model, so build it once per model.
    model_dir = f"../fine-tune/lightning_logs/base_{model_name}"

    for dataset_name in tqdm.tqdm(["cam4","spcam"]):
        dataset = f"{dataset_name}_fixed"

        model = TrainingModel.load_from_checkpoint(get_checkpoint_file(model_dir), map_location="cpu").eval()

        # Load onto the CPU like the model above, so predictions saved from a
        # GPU run can be read on a CPU-only machine.
        # NOTE: torch.load unpickles the file - only use it on trusted files.
        predictions = torch.load(model_dir+f"/predictions_{dataset}.pt", map_location="cpu")

        # Only the dataset is used below; the loader is not needed here.
        test_dataset, test_loader = get_dataset_from_model(model, dataset = dataset)
        targets = unflatten_tensor(test_dataset["y"])
        plots = make_plots(targets, predictions, model.hparams.output_index)

        # Prefix each (variable, metric) key with the model/dataset pair.
        for k,v in plots.items():
            all_plots[(model_name, dataset_name) + k] = v
            
        
        

In [None]:
# NOTE(review): repeats the holoviews import and bokeh extension call from the first cell.
import holoviews as hv
hv.extension("bokeh")

In [None]:
from bokeh.themes import built_in_themes
# Show the names of the themes that ship with bokeh.
print(built_in_themes.keys())
# Apply bokeh's built-in 'dark_minimal' theme to the holoviews bokeh renderer.
# NOTE(review): the following cell sets the theme to 'caliber', which replaces this setting.
hv.renderer('bokeh').theme = built_in_themes['dark_minimal']

In [None]:
# Use the 'caliber' theme for the exported figures.
hv.renderer('bokeh').theme = 'caliber'


# Keep the plots whose variable name (third key element) does not contain "PREC".
# Presumably these are the level-vs-latitude plots - verify against output_index.
non_prec_plots = OrderedDict(
    (key, plot) for key, plot in all_plots.items() if "PREC" not in key[2]
)
temp = hv.HoloMap(non_prec_plots, sort = False, kdims = ["model","dataset","variable","metric"])
# Lay out one panel per (model, metric) pair, three panels per row.
temp = temp.layout(["model","metric"]).cols(3)
hv.save(temp,"levels_vs_lats.html")

In [None]:
# Keep the plots whose variable name (third key element) contains "PREC".
# Presumably these are the latitude-vs-longitude plots - verify against output_index.
prec_plots = OrderedDict(
    (key, plot) for key, plot in all_plots.items() if "PREC" in key[2]
)
temp = hv.HoloMap(prec_plots, sort = False, kdims = ["model","dataset","variable","metric"])
# Lay out one panel per (model, metric) pair, three panels per row.
temp = temp.layout(["model","metric"]).cols(3)
hv.save(temp,"lons_vs_lats.html")
temp