In [None]:
import holoviews as hv
hv.extension("bokeh")
import hvplot.pandas
import pandas as pd
from pathlib import Path
import sys
sys.path.append("/proj/gaia-climate/team/kirill/gaia-surrogate")
from gaia.training import load_hparams_file, get_dataset_from_model, get_checkpoint_file, get_levels
from gaia.data import unflatten_tensor, flatten_tensor
from gaia.config import levels
from gaia.plot import lats, lons, get_land_outline
from gaia.models import TrainingModel
import tqdm.auto as tqdm
import torch

In [None]:
# Re-import the raw longitude grid under a separate name so this cell is
# idempotent: the original rebound `lons` in place, so running the cell a
# second time sorted an already-sorted list and turned `lon_idx` into the
# identity permutation (data columns would then no longer be re-ordered to
# match the sorted longitude axis).
from gaia.plot import lons as lons_raw

# Wrap longitudes greater than 180 by subtracting 360, then sort ascending.
lons_wrapped = torch.tensor([l if l<=180 else l-360 for l in lons_raw])
# lon_idx is the permutation that sorts the wrapped longitudes; it is used to
# re-order the longitude axis of data before plotting against `lons`.
lon_vals,lon_idx =  lons_wrapped.sort()
lons = lon_vals.tolist()
outline = get_land_outline()

## model evaluated on each dataset

In [None]:
def get_metrics(y,yhat,reduce_dims = [0,3], y2  = None):
    """Compute RMSE, standard deviation and skill, reduced over ``reduce_dims``.

    Args:
        y: reference tensor.
        yhat: tensor compared against ``y``.
        reduce_dims: dimensions over which the squared error is averaged and
            the variance is taken.
        y2: optional tensor whose variance is used instead of the variance of
            ``y``.

    Returns:
        dict with keys ``rmse``, ``std`` and ``skill`` (in that order), where
        ``skill = 1 - mse / var`` clipped below at 0 and the variance is the
        biased (population) variance.
    """
    # Variance comes from y2 when it is supplied, otherwise from y itself.
    variance_source = y if y2 is None else y2
    variance = variance_source.var(reduce_dims, unbiased = False)

    squared_error = (y - yhat).square().mean(dim = reduce_dims)
    skill = (1 - squared_error / variance).clip(min = 0)

    return dict(rmse = squared_error.sqrt(), std = variance.sqrt(), skill = skill)


    
# mse, var, skill = get_2d_metrics(targets, predictions)

In [None]:
def plot_levels_vs_lats(x,z_name):
    """Return an hv.QuadMesh of ``x`` over (lats, levels["spcam"]).

    ``z_name`` labels the value dimension; names containing "skill" use the
    "Greens" colormap, everything else "Oranges". The y axis is inverted.
    """
    cmap = "Greens" if "skill" in z_name else "Oranges"
    mesh = hv.QuadMesh((lats, levels["spcam"], x),["lats","levels"],[z_name])
    return mesh.opts(invert_yaxis = True, colorbar = True, tools = ["hover"], cmap = cmap, width = 350, height = 300)

def plot_lats_vs_lons(x, z_name):
    """Return an hv.QuadMesh of ``x`` over (lons, lats).

    The second axis of ``x`` is re-ordered with the notebook-level ``lon_idx``
    before plotting against ``lons``. ``z_name`` labels the value dimension;
    names containing "skill" use the "Greens" colormap, everything else
    "Oranges".
    """
    cmap = "Greens" if "skill" in z_name else "Oranges"
    # assumes x is 2-D with longitude on the last axis — TODO confirm
    reordered = x[:,lon_idx]
    mesh = hv.QuadMesh((lons, lats, reordered),["lons","lats"],[z_name])
    return mesh.opts(invert_yaxis = False, colorbar = True, tools = ["hover"], cmap = cmap, width = 400, height = 300)


def plot_lats_vs_metric(x, z_name):
    """Return an hv.Curve of ``x`` against ``lats``, with ``z_name`` as the value dimension."""
    curve = hv.Curve((lats, x),["lats"],[z_name])
    return curve.opts( tools = ["hover"],  width = 400, height = 300)


In [None]:
from collections import OrderedDict

In [None]:
def make_plots(targets, predictions, output_index, true_predictions = None):
    """Build one plot per (variable, metric) pair from targets and predictions.

    Args:
        targets: reference tensor passed to ``get_metrics`` as ``y``.
        predictions: tensor passed to ``get_metrics`` as ``yhat``.
        output_index: mapping of variable name -> ``(start, end)`` range used
            to slice the first axis of each reduced metric tensor.
        true_predictions: optional tensor forwarded to ``get_metrics`` as
            ``y2`` (the variance is then taken from it instead of ``targets``).

    Returns:
        OrderedDict keyed by ``(variable, metric_name)``. Variables whose range
        spans more than one index come first (plotted with
        ``plot_levels_vs_lats`` from metrics reduced over dims 0 and 3),
        followed by variables spanning exactly one index (plotted with
        ``plot_lats_vs_lons`` from metrics reduced over dim 0 only).
    """
    # assumes targets/predictions are 4-D with the variable/level channels on
    # axis 1 and longitude on axis 3 — TODO confirm against unflatten_tensor

    def z_label(variable, metric_name):
        # Skill is unitless; the other metrics are labelled per variable.
        return "skill" if metric_name == "skill" else f"{variable}_std_units"

    plots = OrderedDict()

    # Multi-index variables: metrics reduced over dims 0 and 3.
    metric_dict = get_metrics(targets, predictions, reduce_dims = [0,3], y2 = true_predictions)
    for k, (s, e) in output_index.items():
        if e - s > 1:
            for metric_name, metric_value in metric_dict.items():
                plots[(k, metric_name)] = plot_levels_vs_lats(metric_value[s:e], z_label(k, metric_name))

    # Single-index variables: metrics reduced over dim 0 only.
    metric_dict = get_metrics(targets, predictions, reduce_dims = [0], y2 = true_predictions)
    for k, (s, e) in output_index.items():
        if e - s == 1:
            for metric_name, metric_value in metric_dict.items():
                # squeeze(0) drops only the length-1 sliced axis; a bare
                # squeeze() would also drop any other singleton axis.
                field = metric_value[s:e].squeeze(0)
                plots[(k, metric_name)] = plot_lats_vs_lons(field, z_label(k, metric_name))

    return plots
        
        
    
    

### Evaluate models trained on [cam4, spcam] on [cam4 ,spcam]

In [None]:
# Evaluate each trained model against the targets of each dataset and collect
# the plots under (model, dataset, variable, metric) keys.
all_plots = OrderedDict()

for model_name in tqdm.tqdm(["cam4","spcam"]):
    model_dir = f"../fine-tune/lightning_logs/base_{model_name}"
    # The checkpoint depends only on model_name, so load it once per model
    # instead of once per (model, dataset) pair.
    # NOTE(review): assumes get_dataset_from_model does not modify `model` in a
    # way that matters across datasets — confirm.
    model = TrainingModel.load_from_checkpoint(get_checkpoint_file(model_dir), map_location="cpu").eval()

    for dataset_name in tqdm.tqdm(["cam4","spcam"]):
        dataset = f"{dataset_name}_fixed"

        # Saved predictions of this model on this dataset.
        predictions = torch.load(model_dir+f"/predictions_{dataset}.pt")
        test_dataset, test_loader  = get_dataset_from_model(model, dataset =dataset )
        targets =  unflatten_tensor(test_dataset["y"])
        plots = make_plots(targets, predictions, model.hparams.output_index)

        for k,v in plots.items():
            new_key = (model_name, dataset_name) + k
            all_plots[new_key] = v
            
        
        

In [None]:
# Repeats the holoviews import and bokeh extension call from the first cell.
import holoviews as hv
hv.extension("bokeh")

In [None]:
from bokeh.themes import built_in_themes
# Show the names of the themes bokeh ships with.
print(built_in_themes.keys())
# Select the 'dark_minimal' theme for the bokeh renderer; the export cells
# further down reassign the theme to 'caliber'.
hv.renderer('bokeh').theme = built_in_themes['dark_minimal']

In [None]:
hv.renderer('bokeh').theme = 'caliber'

# Keep the plots whose variable name (third key element) does not contain
# "PREC", lay them out by model and metric in 3 columns, and export to HTML.
temp = hv.HoloMap(
    OrderedDict((k, v) for k, v in all_plots.items() if "PREC" not in k[2]),
    sort = False,
    kdims = ["model","dataset","variable","metric"],
).layout(["model","metric"]).cols(3)
hv.save(temp,"levels_vs_lats.html")

In [None]:
hv.renderer('bokeh').theme = 'caliber'

# Keep the plots whose variable name (third key element) contains "PREC",
# overlay the land outline, lay out by model and metric in 3 columns, export.
temp = (
    hv.HoloMap(
        OrderedDict((k, v) for k, v in all_plots.items() if "PREC" in k[2]),
        sort = False,
        kdims = ["model","dataset","variable","metric"],
    )
    * outline.opts(color = "black", line_width = 1)
).layout(["model","metric"]).cols(3)
hv.save(temp,"lons_vs_lats.html")
temp

In [None]:
# Same call as the hv.save at the end of the previous cell (same object name,
# same output file).
hv.save(temp,"lons_vs_lats.html")


### Compare predictions of the cam4-trained and spcam-trained models on both cam4 and spcam inputs


In [None]:
# Compare the saved predictions of the cam4-trained and spcam-trained models
# against each other on each dataset, in both directions.
all_plots = OrderedDict()

cam4_dir = "../fine-tune/lightning_logs/base_cam4"
spcam_dir = "../fine-tune/lightning_logs/base_spcam"

# Only `model.hparams.output_index` is needed below, and the original cell
# always took it from the spcam checkpoint; load that checkpoint once instead
# of loading both checkpoints on every iteration (the cam4 model was never used).
model = TrainingModel.load_from_checkpoint(get_checkpoint_file(spcam_dir), map_location="cpu").eval()

for dataset_name in tqdm.tqdm(["cam4","spcam"]):
    dataset = f"{dataset_name}_fixed"

    # cam4-model predictions are bound to `targets`, spcam-model predictions
    # to `predictions`, as in the original cell.
    targets = torch.load(cam4_dir+f"/predictions_{dataset}.pt")
    predictions = torch.load(spcam_dir+f"/predictions_{dataset}.pt")

    # (label, reference, compared) — the first tensor is passed to make_plots
    # as `targets`, the second as `predictions`.
    comparisons = [
        ("cam4_vs_spcam", targets, predictions),
        ("spcam_vs_cam4", predictions, targets),
    ]

    for model_name, reference, compared in comparisons:
        plots = make_plots(reference, compared, model.hparams.output_index)

        for k,v in plots.items():
            new_key = (model_name, dataset_name) + k
            all_plots[new_key] = v
            

In [None]:
hv.renderer('bokeh').theme = 'caliber'

# Keep the plots whose variable name (third key element) does not contain
# "PREC", lay them out by model and metric in 3 columns, and export to HTML.
temp = hv.HoloMap(
    OrderedDict((k, v) for k, v in all_plots.items() if "PREC" not in k[2]),
    sort = False,
    kdims = ["model","dataset","variable","metric"],
).layout(["model","metric"]).cols(3)
hv.save(temp,"levels_vs_lats_cross.html")

In [None]:
# Keep the plots whose variable name (third key element) contains "PREC".
temp = hv.HoloMap(OrderedDict({k:v for k,v in all_plots.items() if "PREC" in k[2]}),sort = False, kdims = ["model","dataset","variable","metric"])
# Overlay the land outline, then lay out by model and metric in 3 columns.
# (The original line was missing the parenthesis that closes the overlay
# expression, which made the cell a SyntaxError.)
temp = (temp*outline.opts(color = "black", line_width = 1)).layout(["model","metric"]).cols(3)
hv.save(temp,"lons_vs_lats_cross.html")
temp

### Normalize by the Truth

In [None]:
# Compare cam4-model predictions against spcam-model predictions on each
# dataset, passing the dataset's own "y" tensor to make_plots as
# `true_predictions` (forwarded to get_metrics as y2, where it supplies the
# variance).
all_plots = OrderedDict()

for dataset_name in tqdm.tqdm(["cam4","spcam"]):
    
        model_dir = f"../fine-tune/lightning_logs/base_cam4"
        dataset = f"{dataset_name}_fixed"
        model = TrainingModel.load_from_checkpoint(get_checkpoint_file(model_dir), map_location="cpu").eval()
        
        # Predictions of the cam4-trained model are bound to `targets`.
        targets = torch.load(model_dir+f"/predictions_{dataset}.pt")
        
        # Substring test on the directory path: true for dataset_name == "cam4" here.
        # NOTE(review): this passes `dataset_name`, while the first evaluation
        # cell passes the "_fixed" name (`dataset`) to get_dataset_from_model —
        # confirm which one is intended.
        if dataset_name in model_dir:
              test_dataset, test_loader  = get_dataset_from_model(model, dataset =dataset_name )
              true_predictions =  unflatten_tensor(test_dataset["y"])
        
        # `model` is rebound to the spcam checkpoint from here on; its
        # hparams.output_index is what make_plots receives below.
        model_dir = f"../fine-tune/lightning_logs/base_spcam"
        dataset = f"{dataset_name}_fixed"
        model = TrainingModel.load_from_checkpoint(get_checkpoint_file(model_dir), map_location="cpu").eval()
        
        predictions = torch.load(model_dir+f"/predictions_{dataset}.pt")

        # Substring test on the directory path: true for dataset_name == "spcam" here,
        # so exactly one of the two branches sets true_predictions per iteration.
        if dataset_name in model_dir:
              test_dataset, test_loader  = get_dataset_from_model(model, dataset =dataset_name )
              true_predictions =  unflatten_tensor(test_dataset["y"])
       
        
        plots = make_plots(targets, predictions, model.hparams.output_index, true_predictions=true_predictions)
        
        model_name = "cam4_vs_spcam"
        
        # Prefix each (variable, metric) key with (model, dataset).
        for k,v in plots.items():
            new_key = (model_name, dataset_name) + k
            all_plots[new_key] = v
            
            
        # plots = make_plots(predictions, targets, model.hparams.output_index)
        
        # model_name = "spcam_vs_cam4_on_spcam"
        
        # for k,v in plots.items():
        #     new_key = (model_name, dataset_name) + k
        #     all_plots[new_key] = v
            

In [None]:
hv.renderer('bokeh').theme = 'caliber'

# Keep the plots whose variable name (third key element) does not contain
# "PREC", lay them out by model and metric in 3 columns, and export to HTML.
temp = hv.HoloMap(
    OrderedDict((k, v) for k, v in all_plots.items() if "PREC" not in k[2]),
    sort = False,
    kdims = ["model","dataset","variable","metric"],
).layout(["model","metric"]).cols(3)
hv.save(temp,"levels_vs_lats_cross_norm_truth.html")
temp

In [None]:
hv.renderer('bokeh').theme = 'caliber'

# Keep the plots whose variable name (third key element) contains "PREC",
# overlay the land outline, lay out by model and metric in 3 columns, export.
temp = (
    hv.HoloMap(
        OrderedDict((k, v) for k, v in all_plots.items() if "PREC" in k[2]),
        sort = False,
        kdims = ["model","dataset","variable","metric"],
    )
    * outline.opts(color = "black", line_width = 1)
).layout(["model","metric"]).cols(3)
hv.save(temp,"lons_vs_lats_cross_norm_truth.html")
temp

In [None]:
### dataset size
# Print the path of every .pt file in the fixed-dataset directory, followed by
# the shapes of its "x" and "y" entries when present.
import glob
for f in glob.glob("/ssddg1/gaia/fixed/*.pt"):
    print(f)
    temp = torch.load(f)
    for key in ("x", "y"):
        if key in temp:
            print(temp[key].shape)



# torch.load("/ssddg1/gaia/fixed/")

In [None]:
# Relies on `temp` left over from the last iteration of the dataset-size loop
# in the previous cell; fails if that object has no "x" entry.
temp["x"].shape