# CMIP6 Neural Network Analysis

**The following steps are included in this script:**

1. Load netCDF files
2. Preprocess data
3. Initialize and train model
4. Test performance and plot results

In [None]:
# ========== Packages ==========
import xarray as xr
import pandas as pd
import numpy as np
import os
import dask
import glob

### Functions

#### Open files

In [None]:
# ========= Helper function to open the dataset ========
def open_dataset(filename):
    """Open one netCDF file with xarray and hand back the resulting Dataset."""
    return xr.open_dataset(filename)

# Define a helper function to open and merge datasets
def open_and_merge_datasets(folder, model, experiment_id, variables):
    """Open the regridded netCDF file of every requested variable and merge them.

    Each file is looked up at
    ``../../data/CMIP6/{experiment_id}/{folder}/{var}/CMIP.{model}.{experiment_id}.{var}_regridded.nc``.
    Variables without a matching file are reported on stdout and skipped, so
    the returned Dataset only contains the variables that were found (it is
    empty when none were found).

    Parameters
    ----------
    folder : str
        Sub-folder below the experiment directory (e.g. ``'preprocessed'``).
    model : str
        Model name (source_id) used in the file name.
    experiment_id : str
        Experiment name used in the directory and file name.
    variables : iterable of str
        Variable names to load.

    Returns
    -------
    xarray.Dataset
        All found variables merged with ``xr.merge``.
    """
    filepaths = []
    for var in variables:
        path = f'../../data/CMIP6/{experiment_id}/{folder}/{var}'
        pattern = os.path.join(path, f'CMIP.{model}.{experiment_id}.{var}_regridded.nc')
        # sorted() makes the pick deterministic should the pattern ever match several files
        matches = sorted(glob.glob(pattern))
        if matches:
            filepaths.append(matches[0])
        else:
            # Report exactly what is missing so incomplete model/variable
            # combinations can be spotted.
            print(f"No file found for variable '{var}' in model '{model}' ({pattern}).")

    datasets = [xr.open_dataset(fp) for fp in filepaths]
    return xr.merge(datasets)

#### Helper functions

In [None]:
# ======== Standardize ========
def standardize(ds_dict):
    '''
    Turn every dataset of a dictionary into z-scores along ``time``.

    Each entry is replaced *in place* by ``(ds - mean) / std`` computed over
    the ``time`` dimension. The per-variable attributes and the global
    attributes of the original dataset are copied onto the standardized one.
    The same dictionary object that was passed in is returned.
    '''
    for key in list(ds_dict):
        original = ds_dict[key]

        anomaly = original - original.mean('time')
        z_scores = anomaly / original.std('time')

        # Arithmetic may drop attributes, so copy them back variable by variable
        shared_vars = [v for v in original.variables if v in z_scores.variables]
        for v in shared_vars:
            z_scores[v].attrs = original[v].attrs

        z_scores.attrs = original.attrs
        ds_dict[key] = z_scores

    return ds_dict

In [None]:
def select_period(ds_dict, start_year=None, end_year=None):
    '''
    Helper function to select periods.

    Keeps every time step whose calendar year lies in
    ``[start_year, end_year]`` (both bounds inclusive). Selecting by year
    instead of by exact timestamps works for any calendar used by the
    ``time`` coordinate. It also does not depend on where within a month the
    time stamps are placed.

    Parameters:
    ds_dict (dict): Dictionary with xarray datasets, each with a ``time`` coordinate.
    start_year (int or None): First year of the period; ``None`` means no lower bound.
    end_year (int or None): Last year of the period; ``None`` means no upper bound.

    Returns:
    dict: New dictionary with the same keys and the time-subset datasets.
    '''
    def _subset(ds):
        # Boolean mask over the time axis, narrowed by each bound that is given
        years = ds['time'].dt.year.values
        keep = np.ones(years.shape, dtype=bool)
        if start_year is not None:
            keep &= years >= start_year
        if end_year is not None:
            keep &= years <= end_year
        return ds.isel(time=keep)

    return {name: _subset(ds) for name, ds in ds_dict.items()}

In [None]:
def check_args_and_get_info(ds_dict, variable):
    """Validate plotting arguments and collect metadata from the first dataset.

    Parameters
    ----------
    ds_dict : dict
        Mapping of names to xarray Datasets. The first Dataset (in insertion
        order) supplies all metadata.
    variable : str
        Name of the variable to describe.

    Returns
    -------
    tuple
        ``(var_long_name, period, unit, statistic_dim, statistic,
        experiment_id, titles)``. ``var_long_name`` is the short label for the
        variable's ``long_name`` attribute. If no label is defined, the
        attribute itself is returned. ``titles`` maps statistic keys to plot
        titles.

    Raises
    ------
    TypeError
        If ``ds_dict`` is not a dict of Datasets or ``variable`` is not a string.
    ValueError
        If ``ds_dict`` is empty.
    """
    # Check the validity of input arguments
    if not isinstance(ds_dict, dict):
        raise TypeError("ds_dict must be a dictionary of xarray datasets.")
    if not all(isinstance(ds, xr.Dataset) for ds in ds_dict.values()):
        raise TypeError("All values in ds_dict must be xarray datasets.")
    if not isinstance(variable, str):
        raise TypeError('variable must be a string.')
    if not ds_dict:
        raise ValueError("ds_dict must contain at least one dataset.")

    # Dictionary to store plot titles for each statistic
    titles = {"mean": "Mean", "std": "Standard deviation of yearly means", "min": "Minimum", "max": "Maximum", "median": "Median", "time": "Time", "space": "Space"}

    # Short plot labels for the variables' long_name attributes
    long_name = {
        'Precipitation': 'Precipitation',
        'Total Runoff': 'Total Runoff',
        'Vapor Pressure Deficit': 'Vapor Pressure Deficit',
        'Evaporation Including Sublimation and Transpiration': 'Evapotranspiration',
        'Transpiration': 'Transpiration',
        'Leaf Area Index': 'Leaf Area Index',
        'Carbon Mass Flux out of Atmosphere Due to Gross Primary Production on Land [kgC m-2 s-1]': 'Gross Primary Production',
        'Total Liquid Soil Moisture Content of 1 m Column': '1 m Soil Moisture',
        'Total Liquid Soil Moisture Content of 2 m Column': '2 m Soil Moisture',
    }

    # Data information, all taken from the first dataset
    first_ds = next(iter(ds_dict.values()))
    raw_long_name = first_ds[variable].long_name
    # Fall back to the dataset's own long_name for variables without a short label
    var_long_name = long_name.get(raw_long_name, raw_long_name)
    period = f"{first_ds.attrs['period']}"
    experiment_id = first_ds.experiment_id
    unit = first_ds[variable].units
    statistic_dim = first_ds.statistic_dimension
    statistic = first_ds.attrs['statistic']

    return var_long_name, period, unit, statistic_dim, statistic, experiment_id, titles

In [None]:
def slice_to_regions(ds_dict, regions):
    """Cut every dataset in ``ds_dict`` to each region in ``regions``.

    ``regions`` maps a region name to a dict with ``'lat'`` and ``'lon'``
    entries that are passed to ``Dataset.sel``. The result is nested as
    ``{region: {ds_name: subset}}``, with regions and dataset names in their
    original insertion order.
    """
    return {
        region: {
            ds_name: ds.sel(lat=bounds['lat'], lon=bounds['lon'])
            for ds_name, ds in ds_dict.items()
        }
        for region, bounds in regions.items()
    }

In [None]:
# Merge each model's runs from consecutive periods into one time series
def create_consecutive_ts(ds_dict):
    """Concatenate each model's datasets from consecutive periods along time.

    ``ds_dict`` maps period/experiment names (e.g. ``'historical'``,
    ``'ssp370'``) to ``{model_name: Dataset}`` dictionaries. All periods are
    concatenated in the dict's insertion order, so they must be inserted
    chronologically. Every model of the first period must exist in all other
    periods; otherwise a ``KeyError`` is raised.

    Returns
    -------
    dict
        ``{model_name: Dataset}`` with the merged time series; an empty dict
        when ``ds_dict`` is empty.
    """
    periods = list(ds_dict.values())
    if not periods:
        return {}
    return {
        model: xr.concat([period[model] for period in periods], dim='time')
        for model in periods[0]
    }

In [None]:
def flatten_data(ds, variables):
    """Flatten each requested variable of ``ds`` into one DataFrame column.

    ``ds[name].values`` is flattened in NumPy's default (row-major) order.
    All variables therefore need the same number of elements so that the
    rows of the resulting DataFrame line up. Columns appear in the order of
    ``variables``.
    """
    return pd.DataFrame({name: ds[name].values.flatten() for name in variables})

#### Computing functions

In [None]:
def compute_ens_metric(ds_dict, metric='mean'):
    """Add an ensemble statistic over all member datasets of ``ds_dict``.

    The members are concatenated along a new ``ensemble`` dimension. The
    method named by ``metric`` (e.g. ``'mean'``, ``'median'``, ``'std'``) is
    then applied with ``skipna=True``. The result is stored in ``ds_dict``
    under ``f'Ensemble_{metric}'`` and the same dictionary is returned.

    Entries whose key starts with ``'Ensemble_'`` are not treated as members.
    This keeps results of earlier calls (e.g. from re-running the cell) from
    feeding back into the ensemble.

    Metadata (experiment_id, first/last year, variable attributes) is taken
    from the first member.

    Raises
    ------
    ValueError
        If ``ds_dict`` contains no ensemble members.
    """
    members = {name: ds for name, ds in ds_dict.items()
               if not str(name).startswith('Ensemble_')}
    if not members:
        raise ValueError("ds_dict contains no ensemble members.")
    reference = next(iter(members.values()))

    # Get info (only the first and last time value are needed)
    experiment = reference.experiment_id
    time_values = reference.time.values
    first_year = time_values[0].year
    last_year = time_values[-1].year

    # Combine all members into one larger dataset
    combined = xr.concat(list(members.values()), dim='ensemble')

    # Compute the ensemble metric; getattr calls the reduction by its name
    ensemble_metric = getattr(combined, metric)(dim='ensemble', skipna=True)

    # Preserve variable attributes from the original dataset
    for var in reference.variables:
        if var in ensemble_metric.variables:
            ensemble_metric[var].attrs = reference[var].attrs

    ensemble_metric.attrs = {"period" : [first_year, last_year],
                           "statistic" : metric,
                           "statistic_dimension" : "time",
                           "experiment_id": experiment,
                           "source_id" : f"Ensemble {metric}"}

    ds_dict[f'Ensemble_{metric}'] = ensemble_metric

    return ds_dict

### Load data

In [None]:
# ========= Define period, models and path ==============
variables=['pr', 'vpd', 'evspsbl', 'tran',  'mrro', 'lmrso_2m', 'lai', 'gpp']
experiment_id = ['historical', 'ssp370']
source_id = ['TaiESM1', 'BCC-CSM2-MR',  'CanESM5', 'CNRM-CM6-1', 'CNRM-ESM2-1', 'IPSL-CM6A-LR', 'UKESM1-0-LL', 'MPI-ESM1-2-LR', 'CESM2-WACCM', 'NorESM2-MM']
folder='preprocessed'

# ========= Use Dask to parallelize computations ==========
dask.config.set(scheduler='processes')

ds_dict = {}

for period in experiment_id:
    # dask.compute only schedules *delayed* objects. Calling
    # open_and_merge_datasets directly would open every model sequentially
    # before dask sees the work, so wrap each call in dask.delayed instead.
    tasks = {model: dask.delayed(open_and_merge_datasets)(folder, model, period, variables)
             for model in source_id}
    # dask.compute returns a tuple; its single element is the {model: Dataset} dict
    ds_dict[period] = dask.compute(tasks)[0]

In [None]:
# ============= Have a look into the data ==============
print(ds_dict.keys())
#ds_dict[list(ds_dict.keys())[0]]

### Preprocessing

#### Merge time series to a consecutive one

In [None]:
# Join each model's experiment runs into one continuous series along time
ds_dict = create_consecutive_ts(ds_dict=ds_dict)

#### Compute ensemble mean

In [None]:
# Append the multi-model mean under the key 'Ensemble_mean'
ds_dict = compute_ens_metric(ds_dict, metric='mean')

In [None]:
# Keep only the ensemble mean for the neural-network analysis
ds_dict_ensmean = dict(Ensemble_mean=ds_dict['Ensemble_mean'])

#### Standardize data


In [None]:
# z-score every variable over time (standardize also modifies ds_dict_ensmean in place)
ds_dict_ensmean_stand = standardize(ds_dict=ds_dict_ensmean)

In [None]:
# Pull out the standardized ensemble-mean Dataset that feeds the network.
ds_ensmean = ds_dict_ensmean_stand['Ensemble_mean']

#### Flatten and remove nan

In [None]:
# Flatten every variable into one column for the NN and discard any row
# that contains a NaN in at least one column
flattened_data_df = flatten_data(ds_ensmean, variables).dropna()

In [None]:
# Display the flattened, NaN-free feature table (notebook echo of the last expression).
flattened_data_df

### Set up NN

#### Split data 

In [None]:
target_var = 'tran'

# Target is the target_var column; every remaining column is a feature
y = flattened_data_df[target_var]
X = flattened_data_df.drop(columns=[target_var])

In [None]:
# Sequential split: the first 60 % of rows train the model and the remaining
# 40 % test it (assumes the rows are ordered chronologically)
train_size = int(len(flattened_data_df) * 0.6)

X_train, X_test = X.iloc[:train_size], X.iloc[train_size:]
y_train, y_test = y.iloc[:train_size], y.iloc[train_size:]

#### Create the model

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
# Network with 3 hidden layers (ReLU) and a linear output layer
class Net(nn.Module):
    """Fully connected regression network: n_features -> 50 -> 30 -> 10 -> 1.

    The target is a standardized (z-scored) variable, so it can be negative
    or larger than 1. The output layer is therefore linear; a sigmoid would
    confine every prediction to (0, 1).

    Parameters
    ----------
    n_features : int, optional
        Number of input features. Defaults to ``X_train.shape[1]`` (the
        notebook's global training matrix) so ``Net()`` keeps working.
    """

    def __init__(self, n_features=None):
        super().__init__()
        if n_features is None:
            n_features = X_train.shape[1]
        self.fc1 = nn.Linear(n_features, 50)  # 50 units
        self.fc2 = nn.Linear(50, 30)
        self.fc3 = nn.Linear(30, 10)
        self.fc4 = nn.Linear(10, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        return self.fc4(x)  # linear output: unbounded regression target

In [None]:
# Instantiate the network; its input size is taken from the training features.
model = Net()

#### Define your loss function and optimizer

In [None]:
# Mean squared error between predictions and targets
criterion = nn.MSELoss()
# Adam; weight_decay adds an L2 penalty on the weights
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01)

#### Train the model

In [None]:
X_train_tensor = torch.tensor(X_train.values, dtype=torch.float)
y_train_tensor = torch.tensor(y_train.values, dtype=torch.float).view(-1, 1) # reshape the training target data to match the output size of the model

n_epochs = 150
loss_history = []  # training loss of every epoch, to check convergence

model.train()  # ensure training mode, e.g. when this cell is re-run after evaluation
for epoch in range(n_epochs):  # full-batch: every epoch uses the whole training set
    optimizer.zero_grad()  # zero the gradient buffers
    outputs = model(X_train_tensor)  # forward pass
    loss = criterion(outputs, y_train_tensor)  # compute loss
    loss.backward()  # backward pass
    optimizer.step()  # update weights
    loss_history.append(loss.item())

#### Evaluate your model on the test data

##### Get MSE and MAE and compare them to a baseline model

In [None]:
X_test_tensor = torch.tensor(X_test.values, dtype=torch.float)
y_test_tensor = torch.tensor(y_test.values, dtype=torch.float).view(-1, 1) # reshape the test target data to match the output size of model

mae_criterion = torch.nn.L1Loss()

model.eval()  # set the model to evaluation mode
with torch.no_grad():
    predictions = model(X_test_tensor)
    mse = criterion(predictions, y_test_tensor)
    mae = mae_criterion(predictions, y_test_tensor)

    # Baseline model: always predict the mean of the training targets
    mean_train = y_train_tensor.mean()
    mean_preds = torch.full_like(y_test_tensor, fill_value=mean_train.item())
    mse_baseline = criterion(mean_preds, y_test_tensor)
    mae_baseline = mae_criterion(mean_preds, y_test_tensor)

print(f"The Mean Squared Error of our forecasts is {mse.item()}")
print(f"The Mean Absolute Error of our forecasts is {mae.item()}")
print(f"The Mean Squared Error of the baseline model is {mse_baseline.item()}")
print(f"The Mean Absolute Error of the baseline model is {mae_baseline.item()}")
# If the model's errors are lower than the baseline's, the model is learning something useful from the data.

#### Evaluate multivariate dependencies

##### Compute Permutation Importance
- get a general sense of feature importance
- might not provide clear results when features are correlated

In [None]:
from sklearn.inspection import permutation_importance

def permutation_importance(model, X, y, loss_fn, n_repeats=1, generator=None):
    """Estimate feature importance by shuffling one input column at a time.

    For every column ``i`` of ``X`` the values of that column are permuted
    across rows. This breaks the column's relation to the target while
    keeping its marginal distribution. The loss is then recomputed. The
    importance is the increase in loss over the unpermuted loss, averaged
    over ``n_repeats`` shuffles.

    Note: this definition shadows ``sklearn.inspection.permutation_importance``
    imported in the same cell.

    Parameters
    ----------
    model : callable
        Maps ``X`` to predictions (e.g. a ``torch.nn.Module``).
    X : torch.Tensor
        2-D input of shape (n_samples, n_features). It is not modified.
    y : torch.Tensor
        Targets accepted by ``loss_fn(model(X), y)``.
    loss_fn : callable
        Returns a scalar tensor.
    n_repeats : int, default 1
        Number of shuffles per feature.
    generator : torch.Generator, optional
        Source of randomness for reproducible shuffles.

    Returns
    -------
    list of float
        One importance value per column of ``X``.

    Raises
    ------
    ValueError
        If ``n_repeats`` is smaller than 1.
    """
    if n_repeats < 1:
        raise ValueError("n_repeats must be at least 1.")

    X = X.detach()
    y = y.detach()
    n_samples = X.shape[0]

    importances = []
    # Only forward passes are needed; skip building autograd graphs
    with torch.no_grad():
        original_loss = loss_fn(model(X), y).item()

        for i in range(X.shape[1]):
            total_increase = 0.0
            for _ in range(n_repeats):
                perm = torch.randperm(n_samples, generator=generator).to(X.device)
                X_perm = X.clone()
                X_perm[:, i] = X[perm, i]
                total_increase += loss_fn(model(X_perm), y).item() - original_loss
            importances.append(total_increase / n_repeats)

    return importances

In [None]:
# Permutation importance of every feature on the test set, measured with the MSE criterion
importances = permutation_importance(model, X=X_test_tensor, y=y_test_tensor, loss_fn=criterion)

In [None]:
# Print or plot importances as desired
import matplotlib.pyplot as plt

In [None]:
# Feature labels, in the same column order as the importances
feature_names = X_test.columns

# Horizontal bar chart: one bar per feature
plt.figure(figsize=(10, 6))
plt.barh(feature_names, importances, align='center')
plt.title('Feature Importances')
plt.xlabel('Permutation Importance')
plt.show()

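##### Integrated Gradients
- attribute each prediction to the input features by integrating the model's gradients along a straight path from a baseline (all zeros by default) to the input
- get a sense of which features drive the model output (note: this is not a partial dependence plot; a PDP would vary one feature over a grid while averaging over the others)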

In [None]:
from captum.attr import IntegratedGradients

In [None]:
# Create an IntegratedGradients object
model.eval()
ig = IntegratedGradients(model)

# Compute the attribution scores. internal_batch_size bounds memory; otherwise
# n_steps scaled copies of the whole training set go through one forward pass.
attributions = ig.attribute(X_train_tensor, target=0, n_steps=10, internal_batch_size=4096)

# Convert tensor to numpy for plotting
attributions = attributions.detach().numpy()

# Positive and negative attributions cancel when summed over data points, so
# rank features by their mean absolute attribution instead
mean_abs_attributions = np.abs(attributions).mean(axis=0)

# Plotting
plt.figure(figsize=(10, 6))
plt.barh(X_train.columns, mean_abs_attributions, align='center')
plt.xlabel('Mean |Integrated Gradients attribution|')
plt.title('Feature attributions (Integrated Gradients)')
plt.show()

##### SHAP Values
- how much each feature contributes, positively or negatively, to each individual prediction

In [None]:
import shap

# Representative background sample for the explainer. Use a seeded random
# subset: the rows are presumably ordered in time (see the sequential
# train/test split), so the first rows would cover only the earliest steps.
n_background = min(100, len(X_train_tensor))
background_generator = torch.Generator().manual_seed(0)
background_idx = torch.randperm(len(X_train_tensor), generator=background_generator)[:n_background]
background_data = X_train_tensor[background_idx]

# Create an explainer object through shap's public API (it dispatches to the
# PyTorch implementation for torch models)
model.eval()
explainer = shap.DeepExplainer(model, background_data)

# Calculate SHAP values for a sample of the test set
n_explain = 10
shap_values = explainer.shap_values(X_test_tensor[:n_explain])

# Matching feature rows as a numpy array for the plot (.iloc = positional rows)
X_test_array = X_test.iloc[:n_explain].values if isinstance(X_test, pd.DataFrame) else X_test[:n_explain]

In [None]:
# Column labels for the plot, in feature order
feature_names = X_test.columns

# Summary plot of the SHAP values of the explained test rows
shap.summary_plot(shap_values, features=X_test_array, feature_names=feature_names)