In [1]:
import holoviews as hv
hv.extension("bokeh")
import hvplot.pandas
import pandas as pd
from pathlib import Path
import sys
sys.path.append("/proj/gaia-climate/team/kirill/gaia-surrogate")
from gaia.training import load_hparams_file
import tqdm.auto as tqdm
from gaia.config import Config, levels
from gaia.training import main


2024-05-10 23:50:33,222 - gaia.gaia.config - INFO - no dataset provided ... you must be loading it from an existing model


In [2]:
def get_metrics(model_dir, dataset = "spcam"):

    try:
        metrics = pd.read_json(next(Path(model_dir).glob(f"*{dataset}*"))).T.squeeze()
    except Exception:
        return pd.Series()
    hparams = load_hparams_file(model_dir)
    metrics["subsample"] = int(hparams["dataset_params"]["train"].get("subsample",1))
    metrics["batch_size"] = int(hparams["dataset_params"]["train"].get("batch_size",1))

    metrics["is_finetuned"]  = int(hparams.get("is_finetuned",False))
    metrics["base_model"] = "random"
    metrics["samples"] = 24 * 96 * 144 * 54 // metrics["subsample"]
    metrics["lr"] = hparams["lr"]
    
    if "base_cam4" in str(model_dir):
        # return pd.Series()
        metrics["is_finetuned"] = True
        metrics["samples"] = 10

    # if "base_spacm" in str(model_dir):
    #     metrics["base_model"] = "cam4" 
    #     # metrics["base_model"] = "cam4"
    #     # metrics["is_finetuned"] = True
    #     # metrics["subsample"] = 

    
    if metrics["is_finetuned"]:
        metrics["base_model"] = "cam4"        
        
    metrics["model_dir"] = model_dir
    metrics["seed"] = hparams.get("seed")
    return metrics

data = pd.concat([get_metrics(f).to_frame().T for f in tqdm.tqdm(Path("../fine-tune/lightning_logs").glob("*"))]).dropna()
data = data.iloc[:,1:]
data.columns = [c.replace("test_skill_ave_trunc_","") for c in data.columns]





0it [00:00, ?it/s]

In [3]:
avg = data.query("base_model=='cam4'").groupby(["samples"])["PRECT"].mean().to_frame()

var = "PRECT"
for s,v in avg[var].items():
    temp = data.query(f"samples=={s} & base_model=='cam4'")
    t = (temp[var] - v).abs()
    avg.loc[s,"model_dir"] = temp.iloc[t.argmin()]["model_dir"]
    avg.loc[s,var+"_closest"] = temp.iloc[t.argmin()][var]
avg.head()

  avg.loc[s,"model_dir"] = temp.iloc[t.argmin()]["model_dir"]


Unnamed: 0_level_0,PRECT,model_dir,PRECT_closest
samples,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
10.0,0.667111,../fine-tune/lightning_logs/base_cam4,0.667111
68.0,0.76637,../fine-tune/lightning_logs/version_223,0.777028
136.0,0.808655,../fine-tune/lightning_logs/version_269,0.808867
273.0,0.826144,../fine-tune/lightning_logs/version_267,0.826813
546.0,0.841926,../fine-tune/lightning_logs/version_347,0.84166


### predict

In [None]:
gpu = 0
samples = [136,1093,8748,279936,17915904]
# samples = [17915904]
for s,model_dir in avg["model_dir"].items():
    if s not in samples:
        continue
    print(s)
    print(model_dir)
    config = Config(
            {
                "mode": "predict",
                "dataset_params": {
                    "dataset": "spcam_fixed",
                    # "train": {"subsample": 1, "batch_size": max([64, (24 * 96 * 144) // 1])},
                    # "val": {"subsample": 1}
                },  # "subsample" : 16, "batch_size": 8 * 96 * 144},
                "trainer_params": {"gpus": [gpu], "max_epochs": 100},
                "model_params": {
                    "ckpt": str(model_dir)
                },
            }
        )

    model_dir = main(**config.config)
    # break

### Plot PRECT distribition

In [None]:
import torch
from gaia.plot import lats, lons
lats = torch.tensor(lats)
mask = lats.abs()<=20
samples = [136,1093,8748,279936,17915904]
datas = {}
var_index = 0
time_index = 1
datas["spcam_dataset"] = torch.load("/ssddg1/gaia/fixed/spcamclbm-nx-16-20m-timestep_4_test.pt")["y"][:,time_index,var_index,mask,:].ravel().numpy()
datas["ft_0"] = torch.load("../fine-tune/lightning_logs/base_cam4/predictions_spcam_fixed.pt")[:,0,mask,:].ravel().numpy()
for s in samples:
    data_file = str(avg.loc[s,"model_dir"] / "predictions_spcam_fixed.pt")
    datas[f"ft_{s:.1E}"] = torch.load(data_file)[:,0,mask,:].ravel().numpy()

In [None]:
datas = pd.DataFrame(datas)
# datas

In [None]:
from holoviews.operation.stats import univariate_kde
import hvplot.pandas

#### log mode

In [None]:
len(datas)

In [None]:
datas.columns.name = "models"
datas.name = "PRECT"
datas_sub = datas.sample(1000000).stack().to_frame("PRECT").reset_index()
datas_sub.iloc[:,-1]= np.log10(datas_sub.iloc[:,-1].values)
datas_sub

In [None]:
def get_kde(x,label):
    dist = hv.Distribution(x,kdims=["PRECT"],label = label)
    kde = univariate_kde(dist,n_samples = 1000,bandwidth=.1, filled = True)
    return kde

# dmap = hv.DynamicMap(lambda a: get_kde(datas[a].numpy()), kdims = ["samples"]).redim.values(samples = samples)
# dmap.overlay("samples")*get_kde(datas["ground_truth"]).opts(label="ground truth")

plots = hv.Overlay([get_kde(v,k).opts(logy=True, width = 800,height = 500, fill_alpha = 0, line_width = 3, line_dash = "dashed" if "spcam" in k else "solid") for k,v in tqdm.tqdm(datas_sub.groupby("models"))])
plots = plots.opts(hv.opts.Area(line_color = hv.Cycle(cycle = "Set1"), fill_color = hv.Cycle(cycle = "Set1")))
plots = plots.redim.range(PRECT=(-10.,-6.),PRECT_density = (10**-2.1,10**.1)).opts(legend_position = "right", xticks = [(i,f"{10**i:1.0e}") for i in range(-10,-5)])
plots

In [None]:
hv.save(plots,"prect_loglog_models.html")

#### not log mode

In [None]:
datas.columns.name = "models"
datas.name = "PRECT"
datas_sub = datas.sample(2000000).stack().to_frame("PRECT").reset_index()
# datas_sub.iloc[:,-1]= np.log10(datas_sub.iloc[:,-1].values)
datas_sub

In [None]:
def get_kde(x,label):
    dist = hv.Distribution(x,kdims=["PRECT"],label = label)
    kde = univariate_kde(dist,n_samples = 1000, filled = False)
    return kde

# dmap = hv.DynamicMap(lambda a: get_kde(datas[a].numpy()), kdims = ["samples"]).redim.values(samples = samples)
# dmap.overlay("samples")*get_kde(datas["ground_truth"]).opts(label="ground truth")

plots = hv.Overlay([get_kde(v,k).opts(logy=True, width = 800,height = 500,  line_width = 3, line_dash = "dashed" if "spcam" in k else "solid") for k,v in tqdm.tqdm(datas_sub.groupby("models",sort = False))])
# plots = plots.opts(hv.opts.Area(line_color = hv.Cycle(cycle = "Set1"), fill_color = hv.Cycle(cycle = "Set1")))
plots

In [None]:
plots = plots.opts(hv.opts.Curve(line_color = hv.Cycle(cycle = "Set1"), line_width = 2)).opts(legend_position = "bottom_left", show_grid = True)
# plots = plots.redim.range(PRECT_density = (10**0.1,10**6),PRECT=(2e-7,2e-5))
hv.save(plots,"prect_log_models.html")
plots

### Plot average precipitation over lats and time

In [None]:
import torch
from gaia.plot import lats, lons
import numpy as np
lats = torch.tensor(lats)
# mask = lats.abs()<=20
mask = torch.ones_like(lats,dtype = bool)
lons = np.array(lons)

In [None]:
samples = [136,1093,8748,279936,17915904]

datas = {}
var_index = 0
time_index = 1
full_dataset = torch.load("/ssddg1/gaia/fixed/spcamclbm-nx-16-20m-timestep_4_test.pt")
datas["spcam_dataset"] = full_dataset["y"][:,time_index,var_index,mask,:].mean(-2).numpy()
datas["ft_0"] = torch.load("../fine-tune/lightning_logs/base_cam4/predictions_spcam_fixed.pt")[:,0,mask,:].mean(-2).numpy()

In [None]:
# v = full_dataset["y"][:,time_index,var_index,:,:].mean(0).numpy()
# hv.Image((lons_shifted,lats,v[:,new_index]),["lons","lats"],["prec"]) + hv.Image((lons,lats,v),["lons","lats"],["prec"])

In [None]:
for s in tqdm.tqdm(samples):
    data_file = str(avg.loc[s,"model_dir"] / "predictions_spcam_fixed.pt")
    datas[f"ft_{s:.1E}"] = torch.load(data_file)[:,0,mask,:].mean(-2).numpy()

In [None]:
lons_shifted = lons.copy()
N = len(lons_shifted)
lons_shifted[N//2:] = lons_shifted[N//2:] - 360
new_index = list(range(N//2,N))+list(range(N//2))
lons_shifted = lons_shifted[new_index]
lons_shifted

In [None]:
for k,temp in list(datas.items()):
    start = pd.Timestamp(year = 2023,month=1,day=1,hour = 0,unit = "H")
    index = pd.date_range(start = start,freq = "12H",periods = len(temp),unit = "ns")
    index.name = "time"
    temp = pd.DataFrame(temp[:,new_index],columns = lons_shifted,index = index)
    temp.columns.name = "lons"
    datas[k] = temp
# temp

In [None]:
import holoviews as hv
hv.extension("bokeh")
import hvplot.pandas

In [None]:
from bokeh.models import BasicTickFormatter

In [None]:
def plot_one(model_name):
    temp = datas[model_name]
    out =  hv.Image((temp.columns.values,temp.index.values, temp.values),kdims = ["lons","time"],vdims = ["PRECT"]).opts(colorbar = True, invert_yaxis=True, logz=False, colorbar_opts = {"formatter":BasicTickFormatter(precision=1)}, colorbar_position = "bottom", height = 800,width = 200,cmap = "bgyw", yticks=30)
    if model_name !="spcam_dataset":
        return out.opts(yaxis=None)
    return out.opts(width = 260)
out = hv.DynamicMap(plot_one,kdims = ["model"]).redim.values(model = list(datas.keys())).layout("model").cols(len(datas))
hv.save(out,"prect_vs_time_lons_vs_model.html")
out

### PRECT vs FLNT

In [None]:
import torch
from gaia.plot import lats, lons
lats = torch.tensor(lats)
# mask = lats.abs()<=20
mask = torch.ones_like(lats,dtype = bool)

samples = [136,1093,8748,279936,17915904]
datas = {}
var_index = 0
time_index = 1
temp = torch.load("/ssddg1/gaia/fixed/spcamclbm-nx-16-20m-timestep_4_test.pt")
datas["spcam_dataset"] = temp["y"][:,time_index,var_index,mask,:].ravel().numpy()
FLNT = temp["x"][:,0,157,mask,:].ravel().numpy()
datas["ft_0"] = torch.load("../fine-tune/lightning_logs/base_cam4/predictions_spcam_fixed.pt")[:,0,mask,:].ravel().numpy()
for s in tqdm.tqdm(samples):
    data_file = str(avg.loc[s,"model_dir"] / "predictions_spcam_fixed.pt")
    datas[f"ft_{s:.1E}"] = torch.load(data_file)[:,0,mask,:].ravel().numpy()

In [None]:
import hvplot.pandas
import numpy as np

In [None]:
datas_df = pd.DataFrame(datas)

In [None]:
source_range = (FLNT.min(),FLNT.max())
target_range = (1e-7,2.5e-6)

def scale(x,sr,tr):
    return (x - sr[0])/(sr[1] - sr[0]) * (tr[1] -  tr[0]) + tr[0]
# hv.Distribution(scale(FLNT, source_range, target_range))
                                                            

In [None]:
datas_df["FLNT"] = scale(FLNT, source_range, target_range)
# datas_df["FLNT"] = FLNT

In [None]:
datas_df.sample(1000).iloc[:,-1].hvplot.kde()

In [None]:
# datas.columns.name = "models"
# datas.name = "PRECT"
datas_sub = datas_df.sample(100000)#.stack().to_frame("PRECT").reset_index()
# datas_sub.iloc[:,:-1]= np.log10(datas_sub.iloc[:,:-1].values)
datas_sub

In [None]:
from colorcet import bgyw

In [None]:
# bgyw

In [None]:
sorted(datas_sub.columns[:-1])

In [None]:
from holoviews.operation.stats import bivariate_kde
from bokeh.models import BasicTickFormatter
def plot_one(model):
    dist = hv.Bivariate((datas_sub.loc[:,model], datas_sub.iloc[:,-1]),kdims = ["PRECT","FLNT"])
    dist = bivariate_kde(dist, filled=True, levels = 25,bandwidth = .05,n_samples=100).opts(width = 350, height = 300, colorbar_opts = {"formatter":BasicTickFormatter(precision=1)}, xformatter="%.1e",  yformatter="%.1e", xlim=(0,1e-7), line_width = .01,  cmap = "bgyw_r",alpha = 1., bgcolor = None, colorbar = True)
    # dist = dist.redim.range(PRECT=(-10.,-6.),FLNT = (-10.,-6)).opts( xticks = [(i,f"{10**i:1.0e}") for i in range(-10,-5)], yticks = [(i,f"{scale(i,target_range,source_range):.1f}") for i in range(-10,-5)])

    return dist


out = hv.DynamicMap(plot_one, kdims = ["model"]).redim.values(model = datas_sub.columns[:-1].tolist()).layout("model").cols(4)
hv.save(out,"FLNT_vs_PRECT_vs_model.html")
out

In [None]:
# pip install dask[dataframe] datashader dask-expr -U

In [None]:
import holoviews.operation.datashader as hd


In [None]:
from holoviews.operation.stats import bivariate_kde
def plot_one(model):
    dist = hv.Bivariate((datas_sub.loc[:,model], datas_sub.iloc[:,-1]),kdims = ["PRECT","FLNT"])
    dist = bivariate_kde(dist, filled=True, levels = 25,bandwidth = .05,n_samples=100).opts(width = 600, height = 500, colorbar_opts = {"formatter":BasicTickFormatter(precision=1)}, xformatter="%.1e",  yformatter="%.1e", xlim=(0,1e-7), line_width = .01,  cmap = "bgyw_r",alpha = 1., bgcolor = None, colorbar = True)
    dist = dist.redim.range(PRECT=(-10.,-6.),FLNT = (-10.,-6)).opts( xticks = [(i,f"{10**i:1.0e}") for i in range(-10,-5)], yticks = [(i,f"{scale(i,target_range,source_range):.1f}") for i in range(-10,-5)])

    return dist


out = hv.DynamicMap(plot_one, kdims = ["model"]).redim.values(model = datas_sub.columns[:-1].tolist()[:1]).layout("model").cols(3)
hv.save(out,"FLNT_vs_PRECT_vs_model.html")
out

In [None]:
points_dmap = hv.DynamicMap(lambda a: hv.Points((datas_df.loc[:,a].values,datas_df.iloc[:,-1].values), ["precipitation","lw radiation"]),kdims=["model"])\
.redim.values(model = datas_df.columns[:-1].tolist())

points_dmap_datashape = hd.datashade(points_dmap, cmap = "fire",cnorm="eq_hist",x_range = (0,2e-6)).opts(xformatter="%.1e").layout("model").cols(4)
hv.save(points_dmap_datashape,"dist_flnt_precip.html")
points_dmap_datashape

### look at specific location

In [4]:
import torch
from gaia.plot import lats, lons
lat_c = 0
lon_c = 175

lats = torch.tensor(lats)
value,index = (torch.tensor(lats) - lat_c).abs().sort()
lat_mask = torch.zeros_like(lats).to(bool)
lat_mask[index[:2]] = True
print(lats[lat_mask])

lons = torch.tensor(lons)
value,index = (torch.tensor(lons) - lon_c).abs().sort()
lon_mask = torch.zeros_like(lons).to(bool)
lon_mask[index[:1]] = True
print(lons[lon_mask])




tensor([-0.9474,  0.9474])
tensor([175.])


  value,index = (torch.tensor(lats) - lat_c).abs().sort()
  value,index = (torch.tensor(lons) - lon_c).abs().sort()


In [None]:
temp["y"][:,time_index,var_index,lat_mask,lon_mask].mean(dim=-1)

In [5]:
samples = [136,1093,8748,279936,17915904]
datas = {}
var_index = 0
time_index = 1
temp = torch.load("/ssddg1/gaia/fixed/spcamclbm-nx-16-20m-timestep_4_test.pt")
datas["spcam_dataset"] = temp["y"][:,time_index,var_index,lat_mask,lon_mask].mean(dim=-1).numpy()
# FLNT = temp["x"][:,0,157,mask,:].ravel().numpy()
datas["ft_0"] = torch.load("../fine-tune/lightning_logs/base_cam4/predictions_spcam_fixed.pt")[:,0,lat_mask,lon_mask].mean(dim=-1).numpy()
for s in tqdm.tqdm(samples):
    data_file = str(avg.loc[s,"model_dir"] / "predictions_spcam_fixed.pt")
    datas[f"ft_{s:.1E}"] = torch.load(data_file)[:,0,lat_mask,lon_mask].mean(dim=-1).numpy()

  0%|          | 0/5 [00:00<?, ?it/s]

In [9]:
datas_df = pd.DataFrame(datas)
start = pd.Timestamp(year = 2023,month=1,day=1,hour = 0,unit = "H")
index = pd.date_range(start = start,freq = "12H",periods = len(datas_df),unit = "ns")
index.name = "time"
datas_df.index = index
datas_df = datas_df.rolling("D").mean()

In [None]:
# datas_df

In [None]:
plots = hv.DynamicMap(lambda a: hv.Curve((datas_df.index.values,datas_df.loc[:,a].values),["time"],["precipitation"],label = a), kdims = ["model"])
plots = plots.redim.values(model = datas_df.columns[1:].tolist())
plots = plots.layout("model")
plots = hv.Area((datas_df.index.values,datas_df.iloc[:,0].values),["time"],["precipitation"],label =datas_df.columns[0] )*plots
plots = plots.opts(hv.opts.Curve(width = 600, line_width = 1, xticks = 20,line_color ="blue",yformatter="%.1e",title=""))
plots = plots.opts(hv.opts.Area(line_color = None,fill_color = "orange"))
plots = plots.redim.range(precipitation=(0,1.5e-7)).cols(2)
hv.save(plots,"N0_E175_prec_vs_models.html")
plots
# plot = datas_df.hvplot.line(line_width = 2,y=datas_df.columns[-1],legend = True)*\
# datas_df.hvplot.line(line_width =1 ,alpha =.9,y=datas_df.columns[1:],subplots = True,legend = True)
# plot = plot.opts(hv.opts.Curve(width = 1000, xticks = 20,line_color = hv.Cycle(cycle = "Set1"))).cols(1)
# plot


### make a movie

In [6]:
import xarray


In [25]:
lats

tensor([-90.0000, -88.1053, -86.2105, -84.3158, -82.4211, -80.5263, -78.6316,
        -76.7368, -74.8421, -72.9474, -71.0526, -69.1579, -67.2632, -65.3684,
        -63.4737, -61.5789, -59.6842, -57.7895, -55.8947, -54.0000, -52.1053,
        -50.2105, -48.3158, -46.4211, -44.5263, -42.6316, -40.7368, -38.8421,
        -36.9474, -35.0526, -33.1579, -31.2632, -29.3684, -27.4737, -25.5789,
        -23.6842, -21.7895, -19.8947, -18.0000, -16.1053, -14.2105, -12.3158,
        -10.4211,  -8.5263,  -6.6316,  -4.7368,  -2.8421,  -0.9474,   0.9474,
          2.8421,   4.7368,   6.6316,   8.5263,  10.4211,  12.3158,  14.2105,
         16.1053,  18.0000,  19.8947,  21.7895,  23.6842,  25.5789,  27.4737,
         29.3684,  31.2632,  33.1579,  35.0526,  36.9474,  38.8421,  40.7368,
         42.6316,  44.5263,  46.4211,  48.3158,  50.2105,  52.1053,  54.0000,
         55.8947,  57.7895,  59.6842,  61.5789,  63.4737,  65.3684,  67.2632,
         69.1579,  71.0526,  72.9474,  74.8421,  76.7368,  78.63

In [21]:
from gaia.plot import get_land_outline

In [24]:
lons_shifted = lons.numpy().copy()
N = len(lons_shifted)
lons_shifted[N//2:] = lons_shifted[N//2:] - 360
new_index = list(range(N//2,N))+list(range(N//2))
lons_shifted = lons_shifted[new_index]
lons_shifted

array([-180. , -177.5, -175. , -172.5, -170. , -167.5, -165. , -162.5,
       -160. , -157.5, -155. , -152.5, -150. , -147.5, -145. , -142.5,
       -140. , -137.5, -135. , -132.5, -130. , -127.5, -125. , -122.5,
       -120. , -117.5, -115. , -112.5, -110. , -107.5, -105. , -102.5,
       -100. ,  -97.5,  -95. ,  -92.5,  -90. ,  -87.5,  -85. ,  -82.5,
        -80. ,  -77.5,  -75. ,  -72.5,  -70. ,  -67.5,  -65. ,  -62.5,
        -60. ,  -57.5,  -55. ,  -52.5,  -50. ,  -47.5,  -45. ,  -42.5,
        -40. ,  -37.5,  -35. ,  -32.5,  -30. ,  -27.5,  -25. ,  -22.5,
        -20. ,  -17.5,  -15. ,  -12.5,  -10. ,   -7.5,   -5. ,   -2.5,
          0. ,    2.5,    5. ,    7.5,   10. ,   12.5,   15. ,   17.5,
         20. ,   22.5,   25. ,   27.5,   30. ,   32.5,   35. ,   37.5,
         40. ,   42.5,   45. ,   47.5,   50. ,   52.5,   55. ,   57.5,
         60. ,   62.5,   65. ,   67.5,   70. ,   72.5,   75. ,   77.5,
         80. ,   82.5,   85. ,   87.5,   90. ,   92.5,   95. ,   97.5,
      

In [26]:
arr = xarray.DataArray(temp["y"][:,1,0,:,new_index].numpy(),
                       dims = ["time","lats","lons"],
                       coords=[index,lats.numpy(),lons_shifted],
                       attrs=dict(description = "precipitation"))
arr.name = "precipitation"

In [11]:
import hvplot.xarray

In [12]:
# !pip install geoviews

In [31]:
land = get_land_outline()

In [48]:
anim = arr.hvplot.image(colormap = "bmy",clim = (0,1e-6))
anim*land.opts(line_color = "white")