In [1]:
import xarray as xr
import numpy as np
import pandas as pd

In [2]:
import os
import sys

# Local checkout of the AiBEDO code base, made importable below.
# Set the AIBEDO_PATH environment variable instead of editing this machine-specific default.
path_aibedo = os.environ.get('AIBEDO_PATH', '/Users/shazarika/ProjectSpace/currentProjects/AiBEDO/codebase/11_07_22/aibedo/')
sys.path.append(path_aibedo)

import torch
from typing import *


from aibedo.models import BaseModel
# from aibedo.utilities.wandb_api import reload_checkpoint_from_wandb, get_run_ids_for_hyperparams
import scipy.stats

# from aibedo.utilities.config_utils import get_config_from_hydra_compose_overrides
# from aibedo.utilities.utils import rsetattr, get_logger, get_local_ckpt_path, rhasattr, rgetattr

In [3]:
# from aibedo.interface import reload_model_from_config_and_ckpt
# def load_model(config_path,config_name,ckpt_path,ckpt_name):
#     overrides = [f'datamodule.data_dir={DATA_DIR}', f"++model.use_auxiliary_vars=False"]
    
#     print(overrides)
#     print(config_path,config_name,ckpt_path,ckpt_name)

#     ### Load Hydra config file
#     GlobalHydra.instance().clear() 
#     hydra.initialize(config_path=config_path, version_base=None)
#     config = hydra.compose(config_name=config_name, overrides=overrides)
#     config['ckpt_dir'] = ckpt_path
#     config['callbacks']['model_checkpoint']['dirpath'] = config['ckpt_dir']
    

#     ## Modify config dict
#     if config.model.get('input_transform'):
#         OmegaConf.update(config, f'model.input_transform._target_',
#                          str(rgetattr(config, f'model.input_transform._target_')).replace('aibedo_salva', 'aibedo'))
#     for k in ['model', 'datamodule', 'model.mixer', 'model.input_transform']:
#         if config.get(k):
#             OmegaConf.update(config, f'{k}._target_',
#                              str(rgetattr(config, f'{k}._target_')).replace('aibedo_salva', 'aibedo'))
    
#     ## Load model
#     loadmodel = reload_model_from_config_and_ckpt(config, ckpt_path+ckpt_name, load_datamodule=True)

#     return loadmodel[0], config

def concat_variables_into_channel_dim(data: xr.Dataset, variables: List[str]) -> np.ndarray:
    """Stack 2-D dataset variables along a new trailing channel axis.

    Args:
        data: dataset whose selected variables are each 2-D (e.g. (time, ncells)).
        variables: names of the variables to stack, in channel order; must be non-empty.

    Returns:
        float32 array of shape (*variable_shape, len(variables)).

    Raises:
        AssertionError: if ``variables`` is empty or any selected variable is not 2-D.
    """
    assert len(variables) > 0, "At least one input variable is required"
    for var in variables:
        # Validate every variable, not just the first: a wrong-rank variable later in the list
        # would otherwise surface as a confusing shape error from the stacking below.
        assert len(data[var].shape) == 2, \
            f"Each input data variable must have two dimensions, but {var!r} has shape {data[var].shape}"
    # np.stack(axis=-1) == concatenating each array with a new trailing axis.
    stacked = np.stack([data[var].values for var in variables], axis=-1)
    return stacked.astype(np.float32)

def get_month_of_output_data(output_xarray: xr.Dataset) -> np.ndarray:
    """Month index (0 = January ... 11 = December) of each snapshot, repeated for every grid cell.

    Args:
        output_xarray: dataset with a ``time`` coordinate and an ``ncells`` dimension.

    Returns:
        float32 array of shape (n_time, n_cells, 1); the size-1 last axis is a dummy channel
        so the result can be concatenated onto the input features.
    """
    n_cells = len(output_xarray['ncells'])
    # The `time.month` virtual variable is 1-based; shift to 0-based month indices.
    months = np.array(output_xarray['time.month'], dtype=np.float32) - 1
    n_time = months.shape[0]
    # Time-major repeat: every cell of snapshot t receives months[t].
    per_cell = np.repeat(months, n_cells)
    return per_cell.reshape([n_time, n_cells, 1])

def get_pytorch_model_data(input_xarray: xr.Dataset, output_xarray: xr.Dataset, input_vars: List[str],
                           device: Optional[torch.device] = None) -> torch.Tensor:
    """Build the model input tensor: input variables on the last axis plus a trailing month channel.

    Args:
        input_xarray: dataset holding the 2-D input variables.
        output_xarray: dataset whose ``time`` coordinate provides the month of each snapshot.
        input_vars: names of the variables to stack, in channel order.
        device: torch device for the returned tensor. Defaults to CUDA when available, else CPU.
            This matches the notebook-level ``device`` this function previously read implicitly
            from a later cell.

    Returns:
        float32 tensor of shape (n_time, n_cells, len(input_vars) + 1). The last channel is the
        0-indexed month, needed to denormalize the predictions; the model does not use it as a feature.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Concatenate all variables into the channel/feature dimension (last) of the input tensor
    features = concat_variables_into_channel_dim(input_xarray, input_vars)
    # Month of each snapshot (0-11), appended for convenience only
    month_channel = get_month_of_output_data(output_xarray)
    data_input = np.concatenate([features, month_channel], axis=-1)
    return torch.from_numpy(data_input).float().to(device)

def predict_with_aibedo_model(aibedo_model: BaseModel, input_tensor: torch.Tensor) -> Dict[str, torch.Tensor]:
    """Run inference with an AiBEDO model.

    Args:
        aibedo_model: trained model; it is switched to eval mode before predicting.
        input_tensor: model input, e.g. as built by ``get_pytorch_model_data``.

    Returns:
        Mapping of output-variable name -> prediction tensor. Keys named ``{var}`` (e.g. 'pr') are in
        denormalized scale; keys ``{var}_pre`` / ``{var}_nonorm`` are the raw model outputs.
        For raw predictions only, use ``aibedo_model.raw_predict(input_tensor)``.
    """
    aibedo_model.eval()
    # Inference only: no autograd graph is needed.
    no_grad_ctx = torch.no_grad()
    with no_grad_ctx:
        # return_normalized_outputs=True additionally yields the {var}_nonorm / {var}_pre entries.
        return aibedo_model.predict(input_tensor, return_normalized_outputs=True)

In [4]:
def prediction_to_dataset(inDS,preddict,in_vars):
    """Pack model predictions, plus at most one reference input variable, into an xarray Dataset.

    Args:
        inDS: input dataset; supplies ``time``, ``lat``, ``lon`` and the data of ``in_vars[0]``.
        preddict: mapping of output-variable name -> prediction values laid out as (time, ncells).
        in_vars: input variable names; only the FIRST one (if any) is copied into the result.

    Returns:
        xr.Dataset over dims (time, ncells) with ``lat``/``lon`` as per-cell coordinates. It contains
        ``in_vars[0]`` followed by every entry of ``preddict``; a prediction with the same name wins.
    """
    # NOTE(review): only in_vars[0] is kept (slice [:1]); confirm this truncation is intentional.
    data_vars = {}
    for name in in_vars[:1]:
        data_vars[name] = (['time', 'ncells'], inDS[name].data)
    for name, values in preddict.items():
        data_vars[name] = (['time', 'ncells'], values)
    coords = {
        "time": (['time'], inDS.time.values),
        "lat": (['ncells'], inDS.lat.values),
        "lon": (["ncells"], inDS.lon.values),
    }
    return xr.Dataset(data_vars=data_vars, coords=coords)

In [5]:
def clean_output_dataset(inDS):
    """Return an all-zero template Dataset for the three ``*_nonorm`` outputs, shaped like ``inDS``.

    ``tas_nonorm``, ``pr_nonorm`` and ``ps_nonorm`` are each filled with zeros matching the shape and
    dtype of the same-named variable in ``inDS``. ``time``/``lat``/``lon`` are copied as coordinates.
    """
    zero_vars = {}
    for name in ('tas_nonorm', 'pr_nonorm', 'ps_nonorm'):
        zero_vars[name] = (['time', 'ncells'], np.zeros_like(inDS[name].data))
    return xr.Dataset(
        data_vars=zero_vars,
        coords={
            "time": (['time'], inDS.time.values),
            "lat": (['ncells'], inDS.lat.values),
            "lon": (["ncells"], inDS.lon.values),
        },
    )

In [6]:
def run_perturbation(model, ds_input, ds_output, perturbations, invariables, lons = [0,40], lats = [0,30]):
    """Add constant offsets to selected inputs inside a lat/lon box, then run the model.

    The caller's ``ds_input`` is not modified. The perturbation is applied to a deep copy, so calling
    this repeatedly on the same dataset does not accumulate offsets.

    Args:
        model: AiBEDO model; its ``main_input_vars`` select the input channels.
        ds_input: input dataset with per-cell ``lat``/``lon`` and (time, ncells) variables.
        ds_output: dataset whose ``time`` coordinate supplies the month channel.
        perturbations: mapping of variable name -> value added inside the box.
        invariables: candidate variables; only those also in ``perturbations`` are changed.
        lons: (lon0, lon1) box bounds, both exclusive.
        lats: (lat0, lat1) box bounds, both exclusive.

    Returns:
        The prediction Dataset produced by ``run_aibedomodel`` for the perturbed input.
    """
    lat0, lat1 = lats
    lon0, lon1 = lons
    # Work on a copy: the argument may share memory with a larger dataset it was selected from.
    ds_perturbed = ds_input.copy(deep=True)
    # The box does not depend on the variable, so find its cells once.
    in_box = ((ds_perturbed.lat > lat0) & (ds_perturbed.lat < lat1) &
              (ds_perturbed.lon > lon0) & (ds_perturbed.lon < lon1))
    cells = np.where(in_box)[0]
    for var in invariables:
        if var in perturbations:
            ds_perturbed[var][:, cells] += perturbations[var]
    return run_aibedomodel(model, ds_perturbed, ds_output)

def get_perturbed_data(in_data, perturbations, invariables, lons, lats, verbose=True):
    """Return a copy of ``in_data`` with constant offsets added inside a lat/lon box.

    ``in_data`` itself is not modified. Previously it was mutated in place, which also changed any
    dataset it shared memory with and made repeated calls accumulate offsets.

    Args:
        in_data: dataset with per-cell ``lat``/``lon`` and (time, ncells) variables.
        perturbations: mapping of variable name -> value added inside the box.
        invariables: candidate variables; only those also in ``perturbations`` are changed.
        lons: (lon0, lon1) box bounds, both exclusive.
        lats: (lat0, lat1) box bounds, both exclusive.
        verbose: print the variable lists and each variable visited (the original debug output).

    Returns:
        The perturbed copy of ``in_data``.
    """
    lat0, lat1 = lats
    lon0, lon1 = lons
    if verbose:
        print("invar:", invariables)
        print("perturb:", perturbations)

    perturbed = in_data.copy(deep=True)
    # The box does not depend on the variable, so find its cells once.
    in_box = ((perturbed.lat > lat0) & (perturbed.lat < lat1) &
              (perturbed.lon > lon0) & (perturbed.lon < lon1))
    cells = np.where(in_box)[0]
    for var in invariables:
        if verbose:
            print("#", var)
        if var in perturbations:
            if verbose:
                print("-", var)
            perturbed[var][:, cells] += perturbations[var]
    return perturbed

def run_aibedomodel(model, ds_in, ds_out):
    """Run ``model`` on ``ds_in`` and return the predictions as an xarray Dataset.

    Args:
        model: AiBEDO model; its ``main_input_vars`` select the input channels.
        ds_in: input dataset; also supplies the coordinates and the ``*nonorm*`` variable names
            handed to ``prediction_to_dataset``.
        ds_out: dataset whose ``time`` coordinate provides the month channel.

    Returns:
        The Dataset built by ``prediction_to_dataset`` from the model's prediction dict.
    """
    nonorm_vars = [name for name in ds_in if 'nonorm' in name]
    model_input = get_pytorch_model_data(ds_in, ds_out, input_vars=model.main_input_vars)
    return prediction_to_dataset(ds_in, predict_with_aibedo_model(model, model_input), nonorm_vars)

def reg_avg(ds,var,lats = [0,30],lons = [-150,-110]):
    """Mean of ``ds[var]`` over the grid cells strictly inside a lat/lon box.

    Cells outside the box are masked to NaN and then excluded by the mean. Only the ``ncells``
    dimension is reduced; other dimensions (e.g. time) are kept.
    Default box: lat 0..30, lon -150..-110.
    """
    (lat_lo, lat_hi), (lon_lo, lon_hi) = lats, lons
    in_box = (ds.lat > lat_lo) & (ds.lat < lat_hi) & (ds.lon > lon_lo) & (ds.lon < lon_hi)
    return ds[var].where(in_box).mean('ncells')

In [7]:
# Directory holding the data used for prediction and the cmip6 mean/std statistics.
# Set the AIBEDO_DATA_DIR environment variable instead of editing this machine-specific default.
DATA_DIR = os.environ.get(
    'AIBEDO_DATA_DIR',
    '/Users/shazarika/ProjectSpace/currentProjects/AiBEDO/codebase/aibedo_viz/haruki_notebook_10_27_22/LE_CESM2_data/',
)
# Input data filename (isosph is an order 6 icosahedron, isosph5 of order 5, etc.)
filename_input = "isosph5.CESM2-LE.historical.r11i1p1f1.Input.Exp8.nc"
# Output data filename is inferred from the input filename, do not edit!
# E.g.: "compress.isosph.CESM2.historical.r1i1p1f1.Output.nc"
filename_output = filename_input.replace("Input.Exp8.nc", "Output.nc")
# str.replace returns the name unchanged if the suffix is missing, which would silently load the
# inputs as "ground truth".
assert filename_output != filename_input, f"Unexpected input filename suffix: {filename_input}"

# os.path.join avoids the double slash that f"{DATA_DIR}/..." produced (DATA_DIR ends with "/").
ds_input = xr.open_dataset(os.path.join(DATA_DIR, filename_input))  # Input data
ds_output = xr.open_dataset(os.path.join(DATA_DIR, filename_output))  # Ground truth data
# Get the appropriate device (GPU or CPU) to use
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hydra overrides for the commented-out load_model helper (no live code in view reads them).
overrides = [f'datamodule.data_dir={DATA_DIR}', "++model.use_auxiliary_vars=False"]

In [8]:
# torch.load unpickles a complete model object, so only load checkpoints from a trusted source.
# (On PyTorch >= 2.6, where `weights_only` defaults to True, full-module checkpoints like this one
# also need `weights_only=False`.)
# map_location puts the parameters on `device`, the same device the input tensors are moved to.
# This keeps model and inputs together regardless of where the checkpoint was saved.
model1 = torch.load('/Users/shazarika/ProjectSpace/currentProjects/AiBEDO/codebase/11_07_22/aibedoviz/fullmodel/MLP_aibedo.pt',
                    map_location=device)
model1.eval()

AIBEDO_MLP(
  (val_metrics): ModuleDict(
    (val/mse): MeanSquaredError()
    (val/tas_nonorm/rmse): MeanSquaredError()
    (val/tas/rmse): MeanSquaredError()
    (val/ps_nonorm/rmse): MeanSquaredError()
    (val/ps/rmse): MeanSquaredError()
    (val/pr_nonorm/rmse): MeanSquaredError()
    (val/pr/rmse): MeanSquaredError()
  )
  (mlp): MLP(
    (hidden_layers): ModuleList(
      (0): MLP_Block(
        (layer): Sequential(
          (0): Linear(in_features=71694, out_features=1024, bias=True)
          (1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (2): GELU()
        )
      )
      (1): MLP_Block(
        (layer): Sequential(
          (0): Linear(in_features=1024, out_features=1024, bias=True)
          (1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (2): GELU()
        )
      )
      (2): MLP_Block(
        (layer): Sequential(
          (0): Linear(in_features=1024, out_features=1024, bias=True)
          (1): LayerNorm((1024,), eps=1

In [9]:
## Region perturbation

# Region lat/lon boxes (the perturbation/averaging helpers treat both bounds as exclusive)
regions = {
    'SEP': dict(lats=[-30, 0], lons=[-110, -70]),
    'NEP': dict(lats=[0, 30], lons=[-150, -110]),
    'SEA': dict(lats=[-30, 0], lons=[-25, 15]),
}
# Box for this experiment: fresh copies of the SEP bounds
lats = list(regions['SEP']['lats'])
lons = list(regions['SEP']['lons'])

# Perturbation magnitudes
cres = -10
crelSurf = 0  # NOTE(review): not used anywhere below
# The same `cres` offset is applied to both cres_nonorm and cresSurf_nonorm.
# NOTE(review): confirm cresSurf_nonorm is meant to use `cres`, and whether crelSurf should be included.
perturbations = {'cres_nonorm': cres, 'cresSurf_nonorm': cres}
# Single time step: 1900-01-01
selslice = slice("1900-01-01", "1900-01-01")

invariables = ['crelSurf_nonorm', 'crel_nonorm', 'cresSurf_nonorm', 'cres_nonorm',
               'netTOAcs_nonorm', 'lsMask', 'netSurfcs_nonorm']

# AiBEDO prediction for the unperturbed reference (baseline) time step
d_dspert_clim = run_aibedomodel(model1, ds_input.sel(time=selslice), ds_output.sel(time=selslice))



In [10]:
# Input snapshots with timestamps from 1900-01-01 through 1900-12-01 (presumably 12 monthly steps).
test_in_ds = ds_input.sel(time=slice('1900-01-01','1900-12-01'))

In [11]:
# Mean of crel_nonorm over axis 0 (time) of the 1900 selection, keeping only the first 1000 cells.
test_array = np.mean(test_in_ds['crel_nonorm'],axis=0).data[:1000]

In [12]:
# Show the offsets that get_perturbed_data applies below.
perturbations

{'cres_nonorm': -10, 'cresSurf_nonorm': -10}

In [13]:
# Perturb the 1900 inputs inside the SEP box (lat -30..0, lon -110..-70, bounds exclusive).
lats, lons = regions['SEP']['lats'], regions['SEP']['lons']
perturbed_SEP = get_perturbed_data(
    ds_input.sel(time=slice('1900-01-01', '1900-12-01')),
    perturbations,
    invariables,
    lons=lons,
    lats=lats,
)

invar: ['crelSurf_nonorm', 'crel_nonorm', 'cresSurf_nonorm', 'cres_nonorm', 'netTOAcs_nonorm', 'lsMask', 'netSurfcs_nonorm']
perturb: {'cres_nonorm': -10, 'cresSurf_nonorm': -10}
# crelSurf_nonorm
# crel_nonorm
# cresSurf_nonorm
- cresSurf_nonorm
# cres_nonorm
- cres_nonorm
# netTOAcs_nonorm
# lsMask
# netSurfcs_nonorm


In [14]:
# Unperturbed netTOAcs_nonorm field at the first selected time step (1900-01-01).
original_data = ds_input.sel(time=selslice)['netTOAcs_nonorm'][0].data

In [15]:
# Same field and time step after get_perturbed_data; netTOAcs_nonorm is not in `perturbations`.
modified_data = perturbed_SEP['netTOAcs_nonorm'][0].data

In [16]:
# Element-wise difference; expected to be all zeros for an unperturbed variable.
diff = modified_data - original_data

In [17]:
# Sanity check: expected 0 because netTOAcs_nonorm is not perturbed.
# NOTE(review): a sum can hide cancelling differences; np.abs(diff).max() is a stricter check.
np.sum(diff)

0.0

In [18]:
# Display the perturbed 1900 input dataset.
perturbed_SEP

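### PPCA out-of-distribution test

Fit one PCA per input variable on the unperturbed inputs and project the original and perturbed snapshots onto the first two components. Then use the probabilistic-PCA (PPCA) log-likelihood from `PCA.score_samples` to judge whether the perturbed inputs fall outside the distribution of the unperturbed data.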

In [19]:
%pip install sklearn

Note: you may need to restart the kernel to use updated packages.


In [20]:
from sklearn.decomposition import PCA

In [21]:
# Input variables considered for perturbation (includes lsMask).
invariables

['crelSurf_nonorm',
 'crel_nonorm',
 'cresSurf_nonorm',
 'cres_nonorm',
 'netTOAcs_nonorm',
 'lsMask',
 'netSurfcs_nonorm']

In [22]:
# Fit one PCA per input variable: every entry of `invariables` except lsMask.
# Rows (samples) are time steps; columns (features) are the first 642 grid cells.
# NOTE(review): 642 is a magic cutoff (possibly the cells of a coarser icosahedral level). Confirm it,
# and keep it in sync with the projection cell below.
pca_dict = {}
for v in [name for name in invariables if name != 'lsMask']:
    mv_data = ds_input[v][:,:642].data
    pca = PCA()
    pca.fit(mv_data)
    pca_dict[v] = pca
    

In [23]:
# Project the original and perturbed fields of one variable onto that variable's fitted PCA basis.
vname = 'cresSurf_nonorm'
pca = pca_dict[vname]

# NOTE(review): only the first 642 cells are used (as in the fit); confirm the SEP box overlaps them,
# otherwise the perturbation is invisible here.
mv_data = ds_input[vname][:, :642].data
perturb_data = perturbed_SEP[vname][:, :642].data

transformed_data = pca.transform(mv_data)
transformed_data_perturbed = pca.transform(perturb_data)

# Per-snapshot log-likelihood of the perturbed data under the fitted PCA model, formatted as text labels.
sample_score = pca.score_samples(perturb_data)
sample_labels = [f"ll:{score:.2f}" for score in sample_score]



In [24]:
import matplotlib.pyplot as plt
import plotly.express as px

import plotly.graph_objects as go

In [25]:


# Density of the original snapshots in the PC1/PC2 plane, with the perturbed snapshots overlaid in red.
# Hovering a red marker shows its log-likelihood label.
fig = go.Figure()
fig.add_trace(go.Histogram2dContour(
    x=transformed_data[:, 0],
    y=transformed_data[:, 1],
    colorscale='Blues',
    xaxis='x',
    yaxis='y',
))
fig.add_trace(go.Scatter(
    x=transformed_data[:, 0],
    y=transformed_data[:, 1],
    xaxis='x',
    yaxis='y',
    mode='markers',
    marker=dict(color='rgba(0,0,0,0.1)', size=2),
    name="original",
))
fig.add_trace(go.Scatter(
    x=transformed_data_perturbed[:, 0],
    y=transformed_data_perturbed[:, 1],
    xaxis='x',
    yaxis='y',
    mode='markers',
    marker=dict(color='rgba(255,0,0,0.8)', size=6),
    text=sample_labels,
    name="perturbed",
))

# Title and axis labels so the figure stands on its own when the notebook is skimmed.
fig.update_layout(
    height=600,
    width=600,
    hovermode='closest',
    showlegend=False,
    title=f"{vname}: PCA projection, original (black) vs perturbed (red)",
    xaxis_title="PC 1",
    yaxis_title="PC 2",
)

fig.show()

In [26]:
from dash import Dash, dash_table, dcc, html
from dash.dependencies import Input, Output, State

In [27]:
data=[
        {
            "year": i,
            "montreal": i * 10,
            "toronto": i * 100,
            "ottawa": i * -1,
            "vancouver": i * -10,
            "temp": i * -100,
            "humidity": i * 5,
        }
        for i in range(10)
    ],

In [28]:
data[0][1]

{'year': 1,
 'montreal': 10,
 'toronto': 100,
 'ottawa': -1,
 'vancouver': -10,
 'temp': -100,
 'humidity': 5}

In [29]:
# Display the ground-truth output dataset.
ds_output

In [30]:
# Display the baseline (unperturbed) AiBEDO prediction for the 1900-01-01 time step.
d_dspert_clim

In [31]:
# Candidate regions ("tp" presumably = tipping point; confirm). Each entry gives the lat/lon box
# bounds and the output variable(s) to inspect there.
tp_region_defs = {
    'Sahel': dict(lat=[10, 20], lon=[-15, 35], variable=['pr']),
    'Atlantic Subpolar Gyre': dict(lat=[45, 60], lon=[-50, -20], variable=['tas']),
    'Eurasia Boreal': dict(lat=[60, 80], lon=[65, 170], variable=['tas']),
    'America Boreal': dict(lat=[60, 75], lon=[-160, -60], variable=['tas']),
    'Amazon': dict(lat=[-10, 10], lon=[-65, -45], variable=['pr']),
    'Coral Sea': dict(lat=[-25, -10], lon=[145, 165], variable=['tas']),
    'Barents Sea Ice': dict(lat=[70, 90], lon=[10, 60], variable=['tas']),
}


In [64]:
def get_regional_data(local_ds, lons = [0,40], lats = [0,30],
                      variables=('pr_nonorm', 'ps_nonorm', 'tas_nonorm')):
    """Average variables over every time step and every grid cell strictly inside a lat/lon box.

    Args:
        local_ds: dataset with per-cell ``lat``/``lon`` and (time, ncells) variables.
        lons: (lon0, lon1) box bounds, both exclusive. Note: lons come BEFORE lats.
        lats: (lat0, lat1) box bounds, both exclusive.
        variables: names to average; the default keeps the original three outputs.

    Returns:
        dict of variable name -> scalar mean, in ``variables`` order. An empty box yields NaN
        (numpy emits a RuntimeWarning).
    """
    lat0, lat1 = lats
    lon0, lon1 = lons
    # The box does not depend on the variable, so find its cells once.
    in_box = ((local_ds.lat > lat0) & (local_ds.lat < lat1) &
              (local_ds.lon > lon0) & (local_ds.lon < lon1))
    cells = np.where(in_box)[0]
    return {var: np.mean(local_ds[var][:, cells].data) for var in variables}

In [67]:
# Sahel-box means of pr_nonorm, ps_nonorm and tas_nonorm from the ground-truth dataset.
# Keyword arguments guard against get_regional_data's lons-before-lats parameter order.
test_data = get_regional_data(
    ds_output,
    lons=tp_region_defs['Sahel']['lon'],
    lats=tp_region_defs['Sahel']['lat'],
)

In [69]:
# Sahel-box mean of the ground-truth pr_nonorm.
test_data['pr_nonorm']

2.5521811252247252e-05

In [47]:
# Names of the data variables in the baseline prediction dataset.
d_dspert_clim.keys()

KeysView(<xarray.Dataset>
Dimensions:           (time: 1, ncells: 10242)
Coordinates:
  * time              (time) object 1900-01-01 00:00:00
    lat               (ncells) float64 58.28 -58.28 58.28 ... -1.184 1.184 1.184
    lon               (ncells) float64 -90.0 -90.0 90.0 ... -21.59 -25.66 -21.59
Dimensions without coordinates: ncells
Data variables:
    netSurfcs_nonorm  (time, ncells) float32 -37.05 1.771 7.538 ... 10.5 13.03
    tas               (time, ncells) float32 253.1 278.9 251.8 ... 299.8 299.9
    ps                (time, ncells) float32 1.014e+05 9.912e+04 ... 1.011e+05
    pr                (time, ncells) float32 1.026e-05 2.837e-05 ... 1.969e-05
    tas_nonorm        (time, ncells) float32 0.7442 0.1472 ... 0.03967 0.03982
    ps_nonorm         (time, ncells) float32 -0.01413 0.6817 ... 0.02988 0.04701
    pr_nonorm         (time, ncells) float32 0.1793 -0.0916 ... -0.09402 -0.1244)

In [75]:
# Scratch: (re)initialize a throwaway dict. Note: this cell was executed after In [71]/In [72].
tt_dict = {}

In [71]:
# Scratch: add a test key.
tt_dict['gg'] = 0

In [72]:
# Scratch: inspect the dict.
tt_dict

{'gg': 0}

In [76]:
# Scratch: map each key to itself (note: leaves the loop variable `v` bound to 'oo').
for v in ['dd','ll', 'oo']:
    tt_dict[v] = v

In [77]:
# NOTE(review): the saved output below comes from out-of-order execution (In [75] reset the dict after
# In [71]); running the cells top to bottom also yields the 'gg': 0 entry.
tt_dict

{'dd': 'dd', 'll': 'll', 'oo': 'oo'}