## Data Sources and Import Statements

Data have been downloaded from the Earth System Grid Federation at https://esgf-node.ipsl.upmc.fr/projects/esgf-ipsl/.

Each file has been concatenated to contain ssp119 and ssp126 scenarios and r1-5 ensemble members from 2015 to 2100. Each has also been regridded to 2.5° resolution.

In [27]:
# IMPORT STATEMENTS

# General useful libraries
import math
import os
import re
# Loading in data (netcdf files)
import h5py
# Handling data
import numpy as np
import netCDF4 as nc
# Installing xarray and its dependencies
import xarray as xr
import scipy 
import dask
import bottleneck
# Plotting figures
import matplotlib.pyplot as plt #Main plotting package
import cartopy.crs as ccrs
from cartopy.util import add_cyclic_point
import cartopy.mpl.ticker as cticker

# Machine Learning package
import tensorflow as tf
tf.compat.v1.disable_v2_behavior() 
print(tf.__version__)

# Interpreting neural networks 
import shap

2.18.0


## Loading Data/Checking Shape & Dimension

In [28]:
# Replace this line with the folder in which your clone of the repo is located
os.chdir("/Users/Caroline/Desktop/school/MamalakisResearch")

base_path = os.getcwd()

data_path = base_path + '/data/'

# Move to the data folder
os.chdir("data")

filenames = [
    "CNRM_ESM2-1_ssp119_ssp126_201501_210012_r1-5_2pt5degree.nc",
    "MIROC6_ssp119_ssp126_201501_210012_r1-5_2pt5degree.nc",
    "MPI-ESM1-2-LR_ssp119_ssp126_201501_210012_r1-5_2pt5degree.nc",
    "MRI-ESM2-0_ssp119_ssp126_201501_210012_r1-5_2pt5degree.nc",
    "UKESM1-0-LL_ssp119_ssp126_201501_210012_r1-5_2pt5degree.nc",
]

# We will call upon this later when loading files
files = [os.path.join(data_path, f) for f in filenames]

ds = nc.Dataset(files[0])
print(list(ds.variables.keys()))
ds.close()

['lat', 'lon', 'time', 'data_ssp119', 'data_ssp126']


In [29]:
ds = nc.Dataset(files[0])

print("data_ssp119 shape:", ds["data_ssp119"].shape)
print("data_ssp126 shape:", ds["data_ssp126"].shape)

print("data_ssp119 dims:", ds["data_ssp119"].dimensions)
print("data_ssp126 dims:", ds["data_ssp126"].dimensions)

ds.close()

data_ssp119 shape: (5, 7, 1032, 144, 73)
data_ssp126 shape: (5, 7, 1032, 144, 73)
data_ssp119 dims: ('ensemble', 'variable', 'time', 'lon', 'lat')
data_ssp126 dims: ('ensemble', 'variable', 'time', 'lon', 'lat')


In [30]:
ds = nc.Dataset(files[1])
x = np.array(ds["data_ssp119"][0, :, 0, :, :])  # (7, lon, lat)
for j in range(x.shape[0]):
    print(j, np.nanmin(x[j]), np.nanmax(x[j]), np.nanmean(x[j]))
ds.close()

0 234.47755571890667 309.8531804281488 270.3023092874373
1 236.21079545127867 322.26223186661866 274.1879111830409
2 232.61641495096973 301.92593328821147 266.7963481147974
3 2.125162459204466e-12 0.00026221260865620655 1.8725814006290823e-05
4 98075.16364846689 103840.45692411123 100618.62371195715
5 0.0015542857193698485 12.37257435270109 2.7059935880144472
6 3.1505518403209596 45.331031811593704 28.4120984137752


## Variable Index Map and Units

In [31]:
# Variable order is explicitly given by metadata:
# "tas, tasmax, tasmin, pr, psl, sfcWind, mrsos"
var_to_index = {
    "tas": 0,
    "tasmax": 1,
    "tasmin": 2,
    "pr": 3,
    "psl": 4,
    "sfcWind": 5,
    "mrsos": 6,
}

# Units are also given:
var_units = {
    "tas": "K",
    "tasmax": "K",
    "tasmin": "K",
    "pr": "kg/m2s",
    "psl": "Pa",
    "sfcWind": "m/s",
    "mrsos": "kg/m2",
}

## Time Handling (Monthly Index → Years)

In [32]:
def get_model_name(path: str) -> str:
    # Everything before "_ssp..."
    return os.path.basename(path).split("_ssp")[0]


def months_to_year_month(time_months: np.ndarray, start_year=2015, start_month=1):
    """
    File says time units are 'months' and it spans 2015-2100.
    This creates year + month arrays assuming the first index corresponds to Jan 2015.

    If my time axis is "month count since 2015-01", this is correct.
    If not, it still gives consistent indexing as long as the file starts at 2015-01.
    """
    # time_months is usually 0..1031 or 1..1032 depending on how the file was written but I handle either by shifting to 0-based.
    t = np.array(time_months, dtype=int)
    if t.min() == 1:
        t = t - 1

    # compute year/month
    year = start_year + (start_month - 1 + t) // 12
    month = (start_month - 1 + t) % 12 + 1
    return year, month


def time_mask_for_year_range(ds: nc.Dataset, start_year: int, end_year: int):
    """
    Create a mask over the monthly time axis using year bounds.
    """
    t = ds["time"][:]
    year, month = months_to_year_month(t, start_year=2015, start_month=1)
    return (year >= start_year) & (year <= end_year)

In [33]:
def get_model_name(path: str) -> str:
    # Everything before "_ssp..."
    return os.path.basename(path).split("_ssp")[0]


def months_to_year_month(time_months: np.ndarray, start_year=2015, start_month=1):
    """
    File says time units are 'months' and it spans 2015-2100.
    This creates year + month arrays assuming the first index corresponds to Jan 2015.
    """
    # Cast to numpy array in case it's passed as an xarray DataArray
    t = np.array(time_months, dtype=int)
    
    # Handle 1-based vs 0-based indexing by shifting to 0-based
    if t.min() == 1:
        t = t - 1

    # Compute year/month using floor division and modulo
    year = start_year + (start_month - 1 + t) // 12
    month = (start_month - 1 + t) % 12 + 1
    return year, month


def time_mask_for_year_range(ds: xr.Dataset, start_year: int, end_year: int):
    """
    Create a mask over the monthly time axis using year bounds.
    """
    # Using .values to pull the underlying numpy array from the xarray object
    t = ds["time"].values
    
    # Conversion logic
    year, month = months_to_year_month(t, start_year=2015, start_month=1)
    
    # Returning the boolean mask as a numpy array 
    return (year >= start_year) & (year <= end_year)

## Unit Conversions and Standardization

In [34]:
def convert_units(varname: str, x: np.ndarray) -> tuple[np.ndarray, str]:
    """
    Convert raw units into more interpretable and plottable units.
    - tas/tasmax/tasmin: K to C
    - pr: kg/m2s to mm/day  (1 kg/m2 = 1 mm water; multiply by 86400)
    - psl: Pa to hPa
    - sfcWind: keep m/s
    - mrsos: keep kg/m2 
    """
    if varname in {"tas", "tasmax", "tasmin"}:
        return x - 273.15, "°C"
    if varname == "pr":
        return x * 86400.0, "mm/day"
    if varname == "psl":
        return x / 100.0, "hPa"
    if varname == "sfcWind":
        return x, "m/s"
    if varname == "mrsos":
        return x, "kg/m²"
    return x, "unknown"

## Statistics Functions (Mean, Std, Median, Percentiles)

In [35]:
def compute_stat_over_time(x: np.ndarray, stat: str) -> np.ndarray:
    """
    x: (ens, time, lat, lon) after loading and swapping
    Returns: (ens, lat, lon) after aggregating over time
    Supported stats:
      - mean (default)
      - std
      - median
      - percentile_XX  (ex. percentile_95)
    """
    s = stat.lower().strip()

    if s == "mean":
        return np.nanmean(x, axis=1)
    if s == "std":
        return np.nanstd(x, axis=1)
    if s == "median":
        return np.nanmedian(x, axis=1)

    m = re.match(r"percentile[_\s-]?(\d+)", s)
    if m:
        p = float(m.group(1))
        return np.nanpercentile(x, p, axis=1)

    raise ValueError(f"Unknown stat '{stat}'. Use mean/std/median/percentile_XX.")

## Loading and Aggregating Data for One Model

In [36]:
# def manipulate_dataset(
#     varname: str,
#     scenario: str,
#     model_name: str,
#     period: tuple[int, int] | None = (2015, 2100),
#     region: tuple[int, int, int, int] | None = None
# ): 

#     """
#     Opens and manipulates given dataset for a chosen scenario, time period, and region where the inputs are:
#       varname: one of tas, tasmax, tasmin, pr, psl, sfcWind, mrsos
#       scenario: "ssp119" or "ssp126"
#       model_name: one of 5 models available in data folder
#       period: None (default), (start_year, end_year)
#       region: entire globe (default), lon/lat range
#     """

#     # Changing to the data folder - for creation of this folder and for data downloading, see data_download script on GitHub
#     current_dir = os.getcwd()
#     # Add a check
#     #if current-dir does not end in "/data:""
#         # os.chdir(current_dir + "/data")

#     # Scenario check
#     scenario = scenario.lower()
#     if scenario not in {"ssp119", "ssp126"}:
#         raise ValueError("scenario must be 'ssp119' or 'ssp126'")

#     key = f"data_{scenario}"
#     vidx = var_to_index[varname]

#     # Opening dataset 
#     ds = xr.open_dataset(model_name, engine="netcdf4")


    
#     # Variable manipulation (Selecting from var_to_index)
#     data_array = ds[key].isel(variable=var_to_index[key])

#     # Subset time
#     if period is not None:
#         ya, yb = period
#         data_array = data_array.sel(time=slice(str(ya), str(yb)))
#     else:
#         ya, yb = int(data_array.time[0]), int(data_array.time[-1])
        
#     # Subset regions
#     if region is not None:
#         lat1, lat2, lon1, lon2 = region
#         data_array = data_array.sel(lat=slice(lat1, lat2), lon=slice(lon1, lon2))
    
#     # Cosine logic 
#     cosl = np.cos(np.pi * data_array.lat / 180)
    
#     # Spatial averaging
#     ts = data_array.weighted(cosl).mean(dim=("lat", "lon"))
#     all_ts.append(ts.compute())

#     return ds, ya, yb, lat1, lat2, lon1, lon2

In [37]:
def manipulate_dataset(
    varname: str,
    scenario: str,
    model_name: str,
    period: tuple[int, int] | None = (2015, 2100),
    region: tuple[int, int, int, int] | None = None
): 

    """
    Opens and manipulates given dataset for a chosen scenario, time period, and region.
    """

    # Scenario check
    scenario = scenario.lower()
    if scenario not in {"ssp119", "ssp126"}:
        raise ValueError("scenario must be 'ssp119' or 'ssp126'")

    key = f"data_{scenario}"
    vidx = var_to_index[varname]

    # Opening dataset 
    ds = xr.open_dataset(model_name, engine="netcdf4")

    # Unpacking period tuple and setting up variables
    start_year, end_year = period

    # Subsetting time
    if period is not None:
        mask = time_mask_for_year_range(ds, start_year, end_year) 
        if mask.sum() == 0:
            ds.close()
            raise ValueError(f"No months found between {start_year}-{end_year} in {model_name}")

    # Raw shape: (ensemble, var, time, lon, lat)
    raw = ds[key].isel(variable=vidx).sel(time=mask)

    # Transpose to (ensemble, time, lat, lon) from (ensemble, time, lon, lat) 
    data_array = raw.transpose("ensemble", "time", "lat", "lon")

    # Converting units
    data_array, unit_label = convert_units(varname, data_array)

    # Subsetting regions
    if region is not None:
        lat1, lat2, lon1, lon2 = region
        data_array = data_array.sel(lat=slice(lat1, lat2), lon=slice(lon1, lon2))
    else:
        # Define these for the return statement if no region is provided
        lat1, lat2, lon1, lon2 = ds.lat.min(), ds.lat.max(), ds.lon.min(), ds.lon.max()
    
    # Applying cosine logic, thus weighting the average since grid cells get smaller toward the poles
    cosl = np.cos(np.pi * data_array.lat / 180)
    
    # Spatial averaging
    ts = data_array.weighted(cosl).mean(dim=("lat", "lon"))
    processed_data = ts.compute()

    ds.close()

    # "Repack" tuple
    region = (float(lat1), float(lat2), float(lon1), float(lon2))

    return processed_data, start_year, end_year, region

## Time Series Plot

Produces a plot of a time series from 2015 to 2100 for the selected variable, scenario, and model. Will show all 5 trajectories within a model. 

Inputs: 
- Variable
- Scenario
- Base period (2015-2100 default)
- Region (entire globe default, lon/lat range)
- Statistic
- Number models returned (all 5 default, single model name) 
- File name(s)


In [38]:
def time_series_plot(
    varname: str,
    scenario: str,
    period: tuple[int, int] | None = (2015, 2100),
    region: tuple[int, int, int, int] | None = None,
    stat: str = "mean",
    multimodel: bool = True,
    model_name: list[str] | None = None,
):
    """
    Produces a time series plot where the inputs are:
      varname: one of tas, tasmax, tasmin, pr, psl, sfcWind, mrsos
      scenario: "ssp119" or "ssp126"
      period: None (default), (start_year, end_year)
      region: entire globe (default), lon/lat range
      stat: mean (default), std, median, percentile_XX
      multimodel: True (average across all 5 models) (default)
      model_name: if not None, ignore multimodel and plot only that/those model(s)
    """

    # Variable check
    if varname not in var_to_index:
        raise ValueError(f"varname must be one of {list(var_to_index.keys())}")
    
    # Scenario check
    scenario = scenario.lower()
    if scenario not in {"ssp119", "ssp126"}:
        raise ValueError("Scenario must be 'ssp119' or 'ssp126'.")

    # Model name check
    # If model name(s) given, filters to only available files
    if model_name:
        file_list = []
        for i in range(len(model_name)):
            if model_name[i] in filenames:
                file_list.append
            else: 
                raise ValueError(f"model {model_name[i]} not found. Available: {files}.")
  
    # If multimodel (and no model names given)
    if multimodel:
        pass

    # If a model name or names given (and not multimodel)
    else:
        # Pull variables from manipulate_dataset function
        for i in range(len(model_name)): 

            ds, start_year, end_year, region = manipulate_dataset(varname, scenario, model_name[i], period, region)
            
            # For years ranging from ya to yb, plot ith element of the list of model names with color pink (C6) and alpha 0.3
            plt.plot(np.arange(start_year, end_year), file_list[i], 'C6', alpha=0.3)
            # And plot the chosen stat of all models together
            ##computed_stat = compute_stat_over_time(model_name[i], stat)
            #plt.plot(np.arange(start_year, end_year), computed_stat, 'C6',alpha=0.3)
            
        plt.xlabel('x')
        plt.ylabel('y')
        plt.axis('tight')
        plt.grid(color='0.8')
        plt.legend()
        plt.show()

        ds.close()


In [39]:
if __name__ == "__main__":

   # Example 1: Time-series of tas in ssp119 scenario
    time_series_plot(
        varname="tas",
        scenario="ssp119", 
        period=(2020, 2039), 
        region=(-80, 80, 100, 200),
        stat="mean",
        multimodel=False,
        model_name = ["CNRM_ESM2-1_ssp119_ssp126_201501_210012_r1-5_2pt5degree.nc"]
    )


IndexError: list index out of range