# CMIP6 Statistics and Plots

**The following steps are included in this notebook:**

1. Load netCDF files
2. Compute statistics
3. Perform hierarchical clustering
4. Plot the clustering results

In [None]:
# ========== Packages ==========
import xarray as xr
import pandas as pd
import numpy as np
import os
import seaborn as sns
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import dask
from matplotlib import rcParams
import math
import multiprocessing as mp
from cftime import DatetimeNoLeap
import glob


# IPython magic: draw matplotlib figures inline in the notebook
%matplotlib inline

# Typeset $...$ mathtext in the same font as regular text
rcParams["mathtext.default"] = 'regular'

### Functions

In [None]:
def compute_statistic_single(ds, statistic, dimension, yearly_mean=True):
    """Compute a statistic of one xarray Dataset over time or over space.

    Args:
        ds (xarray.Dataset): Input data. Must have a ``time`` coordinate; for
            ``dimension='space'`` it also needs ``lat`` (in degrees) and ``lon``.
        statistic (str): Name of the xarray reduction to apply, e.g. 'mean',
            'std', 'var', 'min' or 'max'. For ``dimension='space'``, 'min' and
            'max' are taken unweighted; every other name is looked up on the
            cos(lat)-weighted object (which offers e.g. mean/std/var but no
            median).
        dimension (str): 'time' reduces along ``time``; 'space' reduces over
            ``lon`` and ``lat`` with cos(latitude) area weights.
        yearly_mean (bool, optional): Only used for 'space'. If True, the data
            are first averaged per calendar year, so the result keeps a
            ``year`` dimension. Defaults to True.

    Returns:
        xarray.Dataset: The reduced data with the attributes ``period`` (first
        and last year of ``ds`` as strings), ``statistic`` and
        ``statistic_dimension`` added, plus ``mean='yearly mean'`` when yearly
        means were taken. ``ds`` itself is not modified.

    Raises:
        ValueError: If ``dimension`` is neither 'time' nor 'space'.
    """
    # Without this check an unknown dimension left `stat_ds` unbound.
    if dimension not in ("time", "space"):
        raise ValueError(f"Invalid dimension '{dimension}' specified.")

    # Record the covered years before any reduction removes the time axis.
    new_attrs = {'period': [str(ds.time.dt.year[0].values), str(ds.time.dt.year[-1].values)]}

    if dimension == "time":
        stat_ds = getattr(ds, statistic)("time", keep_attrs=True, skipna=True)
    else:
        if yearly_mean:
            ds = ds.groupby('time.year').mean('time', keep_attrs=True, skipna=True)
            new_attrs['mean'] = 'yearly mean'

        if statistic in ("min", "max"):
            # Positive area weights cannot change an extreme value, and
            # xarray's weighted reductions do not offer min/max.
            stat_ds = getattr(ds, statistic)(("lon", "lat"), keep_attrs=True, skipna=True)
        else:
            # Grid-cell area on a regular lat/lon grid is proportional to cos(lat).
            weights = np.cos(np.deg2rad(ds.lat))
            weights.name = "weights"
            ds_weighted = ds.weighted(weights)
            stat_ds = getattr(ds_weighted, statistic)(("lon", "lat"), keep_attrs=True, skipna=True)

    new_attrs['statistic'] = statistic
    new_attrs['statistic_dimension'] = dimension

    # assign_attrs returns a new object with its own attrs dict, so nothing is
    # written back into the caller's dataset.
    return stat_ds.assign_attrs(new_attrs)

In [None]:
def compute_statistic(ds_dict, statistic, dimension, start_year=None, end_year=None, yearly_mean=True,
                      parallel=True):
    """Compute a statistic for every dataset in a dictionary.

    Args:
        ds_dict (dict): Mapping of a name (e.g. the model) to an xarray Dataset.
        statistic (str): One of 'mean', 'std', 'min', 'max', 'var' or 'median'
            ('median' only together with ``dimension='time'``).
        dimension (str): 'time', or 'space' (cos(lat)-weighted over lon/lat);
            see compute_statistic_single.
        start_year (int or str, optional): First year (inclusive) to keep.
        end_year (int or str, optional): Last year (inclusive) to keep. Either
            bound may be given on its own; the other side is then left open.
        yearly_mean (bool, optional): Passed on to compute_statistic_single;
            for 'space' each calendar year is averaged first. Defaults to True.
        parallel (bool, optional): Compute the datasets in a multiprocessing
            pool (default). Set to False to run serially, for example when the
            'spawn' start method cannot import functions defined in an
            interactive session.

    Returns:
        dict: Same keys as ``ds_dict``; values are the computed statistics.

    Raises:
        TypeError: If ``ds_dict`` is not a dict of xarray Datasets.
        ValueError: For an unknown ``statistic`` or ``dimension``, or for
            'median' combined with 'space'.
    """
    # Check the validity of input arguments
    if not isinstance(ds_dict, dict):
        raise TypeError("ds_dict must be a dictionary of xarray datasets.")
    if not all(isinstance(ds, xr.Dataset) for ds in ds_dict.values()):
        raise TypeError("All values in ds_dict must be xarray datasets.")
    if statistic not in ["mean", "std", "min", "max", "var", "median"]:
        raise ValueError(f"Invalid statistic '{statistic}' specified.")
    if dimension not in ["time", "space"]:
        raise ValueError(f"Invalid dimension '{dimension}' specified.")
    if dimension == "space" and statistic == "median":
        # xarray's weighted reductions have no median.
        raise ValueError("Statistic 'median' is not supported for dimension 'space'.")

    if start_year is not None or end_year is not None:
        # Partial date strings ('1985') select whole years on CFTimeIndex and
        # pandas time axes for any calendar. Fixed DatetimeNoLeap bounds on
        # 16 Jan / 16 Dec 00:00 would drop a final December stamped later that
        # day and cannot be compared with non-noleap calendars.
        start = None if start_year is None else f"{int(start_year):04d}"
        end = None if end_year is None else f"{int(end_year):04d}"
        ds_dict = {k: v.sel(time=slice(start, end)) for k, v in ds_dict.items()}

    tasks = [(ds, statistic, dimension, yearly_mean) for ds in ds_dict.values()]

    if parallel and len(tasks) > 1:
        # One worker task per dataset
        with mp.Pool() as pool:
            results = pool.starmap(compute_statistic_single, tasks)
    else:
        results = [compute_statistic_single(*task) for task in tasks]

    return dict(zip(ds_dict.keys(), results))

### 1. Load netCDF files

In [None]:
# ========= Variables, experiment, models and input folder ==============
# Variables to load for every model
variable = [
    'evspsbl', 'gpp', 'huss', 'lai', 'mrro', 'pr', 'tran', 'lmrso_1m', 'lmrso_2m',
]
experiment_id = 'historical'
# CMIP6 models (source_id) to include; this order is reused for ds_dict
source_id = [
    'TaiESM1', 'AWI-ESM-1-1-LR', 'BCC-CSM2-MR', 'BCC-ESM1', 'CanESM5',
    'CNRM-CM6-1', 'CNRM-CM6-1-HR', 'CNRM-ESM2-1', 'IPSL-CM6A-LR',
    'UKESM1-0-LL', 'MPI-ESM1-2-LR', 'CESM2', 'CESM2-FV2', 'CESM2-WACCM',
    'CESM2-WACCM-FV2', 'NorESM2-MM',
]
# Sub-folder below ../../data/CMIP6/<experiment_id>/ holding the regridded files
folder = 'preprocessed'

# ========= Dask: use the multiprocessing scheduler for dask computations ==========
dask.config.set(scheduler='processes')

# ========= Create a helper function to open the dataset ========
def open_dataset(filename):
    """Open one netCDF file with xarray and return the resulting Dataset."""
    return xr.open_dataset(filename)

# Define a helper function to open and merge datasets
def open_and_merge_datasets(folder, model, experiment_id, variables, base_dir='../../data/CMIP6'):
    """Open the regridded files of one model and merge them into one Dataset.

    For every variable, the file
    ``<base_dir>/<experiment_id>/<folder>/<var>/CMIP.<model>.<experiment_id>.<var>_regridded.nc``
    is opened. Variables without such a file are reported and skipped.

    Args:
        folder (str): Sub-folder name, e.g. 'preprocessed'.
        model (str): CMIP6 source_id of the model.
        experiment_id (str): CMIP6 experiment, e.g. 'historical'.
        variables (list of str): Variable names to load.
        base_dir (str, optional): Root of the CMIP6 data tree.
            Defaults to '../../data/CMIP6'.

    Returns:
        xarray.Dataset: All found variables merged with ``xr.merge``.
    """
    filepaths = []
    for var in variables:
        path = os.path.join(base_dir, experiment_id, folder, var)
        matches = glob.glob(os.path.join(path, f'CMIP.{model}.{experiment_id}.{var}_regridded.nc'))
        if matches:
            filepaths.append(matches[0])
        else:
            # Name the missing model/variable so gaps in the data set are visible
            # (printing the empty match list gave no information).
            print(f"No file found for variable '{var}' in model '{model}' (searched {path}).")

    datasets = [xr.open_dataset(fp) for fp in filepaths]
    return xr.merge(datasets)

# Open and merge every model's files into a {model: Dataset} dictionary.
# The comprehension calls open_and_merge_datasets for one model after another;
# dask.compute returns a 1-tuple holding the evaluated dictionary.
# NOTE(review): the files are therefore opened sequentially; dask.compute only
# evaluates dask-backed arrays, if any — confirm whether parallel loading is
# actually intended/happening here.
(ds_dict,) = dask.compute(
    {model: open_and_merge_datasets(folder, model, experiment_id, variable) for model in source_id}
)

In [None]:
# ============= Inspect the loaded data ==============
# List the loaded models, then display the merged dataset of the first one
print(ds_dict.keys())
ds_dict[[*ds_dict][0]]

In [None]:
# Restrict every dataset to the years start_year..end_year (inclusive).
# Partial date strings ('1930', '2014') select whole years for any calendar.
# The last December is therefore kept even when its time stamp falls after
# 00:00 on the 16th. Models whose calendar is not noleap also work, whereas
# DatetimeNoLeap bounds cannot be compared with their time axis.
start_year = 1930
end_year = 2014

ds_dict = {k: v.sel(time=slice(str(start_year), str(end_year))) for k, v in ds_dict.items()}

### 2. Compute statistics

In [None]:
# ========= Time-mean statistic per model ===============
# Mean over time for 1985-2014 (yearly_mean only matters for 'space' statistics)
ds_stat = compute_statistic(
    ds_dict,
    statistic='mean',
    dimension='time',
    start_year=1985,
    end_year=2014,
)

### 3. Perform hierarchical clustering

In [None]:
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.impute import SimpleImputer
from scipy.cluster.hierarchy import dendrogram, linkage, fcluster

In [None]:
# Variables used as clustering features ('mrro' from the loaded set is left out)
variable = [
    'evspsbl', 'gpp', 'huss', 'lai', 'pr', 'tran', 'lmrso_1m', 'lmrso_2m',
]
# Alternative feature subsets that have been tried:
#   ['evspsbl', 'pr', 'huss']
#   ['lai', 'mrro', 'gpp', 'tran', 'lmrso_1m', 'lmrso_2m']

In [None]:
# Build the (model x variable) feature table used for clustering. It holds one
# value per model and variable: the mean over all valid (non-NaN) grid cells
# of the 1985-2014 time mean in ds_stat. ds_stat is used instead of ds_dict,
# which still holds the full 1930-2014 monthly series, so the features cover
# the period named in the plot titles.
# NOTE(review): this spatial mean is not area-weighted, unlike the 'space'
# statistics of compute_statistic_single — confirm which is intended.
data = {}

for var in variable:
    model_means = []
    for model, ds in ds_stat.items():
        if var not in ds:
            # No file was loaded for this model/variable; reported below.
            model_means.append(np.nan)
            continue
        values = ds[var].values.ravel()
        values = values[~np.isnan(values)]
        # A field that is NaN everywhere also yields NaN
        model_means.append(values.mean() if values.size else np.nan)
    data[var] = model_means

# Rows: models (ds_stat keeps the source_id order); columns: variables
df = pd.DataFrame(data, index=list(ds_stat.keys()))

# Ward linkage cannot handle missing values, so name the gaps explicitly
# instead of failing later inside linkage.
missing = [f"{model}/{var}" for var in df.columns for model in df.index[df[var].isna()]]
if missing:
    raise ValueError(
        "No valid data for model/variable: " + ", ".join(missing)
        + ". Drop or impute these entries before clustering."
    )

# Raw feature matrix as a NumPy array
data_matrix = df.to_numpy()

# Standardize every variable (zero mean, unit variance) so that variables with
# large magnitudes do not dominate the Euclidean distances
scaler = StandardScaler()
df_standardized = pd.DataFrame(scaler.fit_transform(df), columns=df.columns, index=df.index)

from scipy.cluster.hierarchy import linkage

# Hierarchical/agglomerative clustering with Ward's method: each merge joins
# the pair of clusters with the smallest increase in within-cluster variance
# (Euclidean distances on the standardized features).
linked = linkage(df_standardized, method='ward')

In [None]:
# Cut the Ward tree into at most num_clusters flat clusters
num_clusters = 3
cluster_assignments = fcluster(linked, num_clusters, criterion='maxclust')

# Fraction of models in each cluster. fcluster numbers clusters from 1, so
# bincount's bin 0 is always empty and is dropped: entry i is the weight of
# cluster i + 1.
cluster_weights = np.bincount(cluster_assignments)[1:] / len(cluster_assignments)

print("Cluster Weights:", cluster_weights)

In [None]:
from sklearn.cluster import KMeans
import warnings

# Silence UserWarnings raised from matplotlib in the plots below
warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")

# Elbow method: within-cluster sum of squares (WCSS = KMeans.inertia_) of
# K-means on the standardized model features, for 1..9 clusters.
wcss = []
num_clusters_range = range(1, 10)

# The loop variable is `k`, not `num_clusters`, so this cell does not
# overwrite the num_clusters used with fcluster.
for k in num_clusters_range:
    # Cluster the standardized features (one row per model). The linkage
    # matrix `linked` is the merge history of the hierarchical clustering,
    # not a set of observations, so it is not a valid K-means input.
    # A fixed random_state and explicit n_init make the curve reproducible.
    kmeans = KMeans(n_clusters=k, n_init=10, random_state=0)
    kmeans.fit(df_standardized)
    wcss.append(kmeans.inertia_)

# Plot the WCSS values
plt.plot(num_clusters_range, wcss)
plt.xlabel('Number of Clusters')
plt.ylabel('WCSS')
plt.title('Elbow Method')
plt.show()

In [None]:
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import matplotlib.pyplot as plt

# Silhouette analysis: mean silhouette coefficient of K-means clusterings of
# the standardized features for 2..9 clusters (the score needs >= 2 clusters).
silhouette_scores = []
num_clusters_range = range(2, 10)

# Loop variable `k` and result name `cluster_labels` avoid overwriting the
# global num_clusters (used with fcluster) and labels (used by the dendrograms).
for k in num_clusters_range:
    # Fixed random_state and explicit n_init make the scores reproducible
    kmeans = KMeans(n_clusters=k, n_init=10, random_state=0)
    cluster_labels = kmeans.fit_predict(df_standardized)
    silhouette_scores.append(silhouette_score(df_standardized, cluster_labels))

# Plot the silhouette scores
plt.plot(num_clusters_range, silhouette_scores)
plt.xlabel('Number of Clusters')
plt.ylabel('Silhouette Score')
plt.title('Silhouette Analysis')
plt.show()

In [None]:
# Display the standardized (model x variable) feature table used for clustering
df_standardized

In [None]:
from scipy.cluster.hierarchy import linkage

# Ward agglomerative clustering of the standardized features, recomputed right
# before plotting; each merge joins the cluster pair with the smallest
# increase in within-cluster variance.
linked = linkage(df_standardized, method='ward')

### 4. Plot the clustering results

In [None]:
# Define labels for the plots.
# Each model's (atmosphere model, land-surface model) pair, in source_id order.
# The dendrograms match labels to the rows of the clustering input by
# position, so this order must follow the model order used there.
model_components = {
    'TaiESM1': ('TaiAM1', 'CLM4.0'),
    'AWI-ESM-1-1-LR': ('ECHAM6.3.04p1', 'JSBACH 3.20 + dyn veg'),
    'BCC-CSM2-MR': ('BCC_AGCM3_MR', 'BCC_AVIM2'),
    'BCC-ESM1': ('BCC_AGCM3_LR', 'BCC_AVIM2'),
    'CanESM5': ('CanAM5', 'CLASS3.6/CTEM1.2'),
    'CNRM-CM6-1': ('Arpege 6.3', 'Surfex 8.0c'),
    'CNRM-CM6-1-HR': ('Arpege 6.3', 'Surfex 8.0c'),
    'CNRM-ESM2-1': ('Arpege 6.3', 'Surfex 8.0c'),
    'IPSL-CM6A-LR': ('LMDZ', 'ORCHIDEE'),
    'UKESM1-0-LL': ('MetUM-HadGEM3-GA7.1', 'JULES-ES-1.0'),
    'MPI-ESM1-2-LR': ('ECHAM6.3', 'JSBACH4.20'),
    'CESM2': ('CAM6', 'CLM5'),
    'CESM2-FV2': ('CAM6', 'CLM5'),
    'CESM2-WACCM': ('WACCM6', 'CLM5'),
    'CESM2-WACCM-FV2': ('WACCM6', 'CLM5'),
    'NorESM2-MM': ('CAM_OSLO5', 'CLM5'),
}

# "model - (atmosphere // land)"
labels = [f'{model} - ({atm} // {land})' for model, (atm, land) in model_components.items()]
# "model - (land)"
labels_ls = [f'{model} - ({land})' for model, (_atm, land) in model_components.items()]
# "model - (atmosphere)"
labels_atm = [f'{model} - ({atm})' for model, (atm, _land) in model_components.items()]

Radial tree plot

In [None]:
import radialtree as rt

# Compute only the dendrogram layout (no_plot=True); radialtree then draws it
# as a circular tree using the combined "model - (... // ...)" labels.
Z2 = dendrogram(linked,labels=labels,no_plot=True)

# plot a circular dendrogram
rt.plot(Z2)
# Add a title
# NOTE(review): if rt.plot calls plt.show() itself (check its `show` argument
# in the installed radialtree version), this title would land on a new, empty
# figure instead of the tree — confirm.
plt.title('Dendrogram of Historical CMIP6 Data (1985-2014)', y=1.75)

Dendrogram

In [None]:
from scipy.cluster.hierarchy import dendrogram
import matplotlib.pyplot as plt

# Switch seaborn's (and thereby matplotlib's) colour palette to 'hsv'
sns.set_palette('hsv')

# Horizontal dendrogram (leaves on the left) with the full model labels.
# Optional extras: distance_sort='descending', show_leaf_counts=True.
plt.figure(figsize=(10, 7), frameon=False)
dendrogram(linked, orientation='left', labels=labels)

# Hide the axes frame
plt.box(False)

plt.title('Dendrogram of Historical CMIP6 Data (1985-2014)')
plt.show()

Clustermap

In [None]:
# Heatmap of the standardized features with dendrograms on both axes.
# Rows reuse the Ward linkage `linked`, so the model tree matches the
# dendrogram plots; otherwise seaborn would re-cluster the rows with its
# default average linkage. Columns (variables) are still clustered by seaborn.
clustermap = sns.clustermap(df_standardized, row_linkage=linked, cmap='viridis')
clustermap.ax_heatmap.set_title("Cluster Map of Historical CMIP6 Data (1985-2014)", y=1.25)
plt.show()