In [30]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import xesmf as xe
import xcdat as xc
import xsearch as xs
import xskillscore as xscore

from glob import glob 
from typing import List, Tuple, Dict, Union, Optional, Any, Callable, Iterable, Sequence, cast

# Ignore xarray warnings (bad practice)
import warnings
warnings.simplefilter("ignore") 

In [36]:
def ingest_and_process(
    output_grid: xr.Dataset,
    var: str = "tos",
    cmipTable: str = "Omon",
    era: str = "CMIP6",
    testing: bool = False,
    calc_anoms: bool = False,
) -> xr.DataArray:
    """
    Ingest the CMIP piControl simulations of ``var`` and stack them on one grid.

    For every dataset directory returned by ``xs.findPaths`` the files are
    opened, regridded bilinearly onto ``output_grid``, optionally converted to
    monthly departures, truncated to the first 150 years (1800 months) and
    given a synthetic monthly time axis starting at 1850-01 so that all models
    share identical time coordinates.

    Models with fewer than 1800 time steps are skipped. Any exception raised
    while processing a model is printed and that model is skipped as well
    (best-effort ingestion).

    Args:
        output_grid: Target grid passed to ``regridder.horizontal`` (the caller
            builds it with ``xc.create_grid``).
        var: Name of the variable to load, e.g. ``"tos"``.
        cmipTable: CMIP table to search, e.g. ``"Omon"`` or ``"Amon"``.
        era: MIP era passed to the search, e.g. ``"CMIP5"`` or ``"CMIP6"``.
        testing: If True, plot the first time step before/after regridding for
            each processed model and stop early (see the note in the loop).
        calc_anoms: If True, remove the seasonal cycle via
            ``temporal.departures(var, "month")`` before truncating.

    Returns:
        DataArray of ``var`` concatenated along a new ``model`` dimension whose
        coordinate holds the names of the models that were kept.

    Raises:
        ValueError: If no model could be processed successfully.
    """

    # Find all the paths to the data
    dpaths = xs.findPaths(
        experiment="piControl",
        variable=var,
        frequency="mon",
        cmipTable=cmipTable,
        mip_era=era,
        activity="CMIP",
    )
    # NOTE(review): the zip below assumes getGroupValues yields exactly one
    # model name per path, in the same order as the paths — confirm.
    models = xs.getGroupValues(dpaths, 'model')
    dpaths = list(dpaths.keys())

    # Loop-invariant: record length and the common synthetic time axis
    nmonths = 150 * 12
    start = np.datetime64("1850-01")
    time = np.arange(start, start + np.timedelta64(nmonths, 'M'), dtype="datetime64[M]")

    processed = []
    valid_models = []

    for i, (model_path, model) in enumerate(zip(dpaths, models)):
        try:
            print("Starting model: ", model)
            # Load data
            model_ds = xc.open_mfdataset(glob(model_path + "/*.nc"))  # , parallel=True, chunks="auto")

            # Regrid Data
            output = model_ds.regridder.horizontal(var, output_grid, tool='xesmf', method='bilinear')

            # Deseasonalize (no need to detrend piControl data, there should be no drift)
            if calc_anoms:
                output = output.temporal.departures(var, "month")

            # Unpacking three values also rejects variables that are not 3-D
            ntime, _, _ = output[var].shape
            if ntime < nmonths:
                # Record too short to supply 150 years
                continue

            # Keep the first 150 years and relabel time identically for every model
            output = output[var][:nmonths]
            output["time"] = time
            processed.append(output)
            valid_models.append(model)

            if testing:
                _, axes = plt.subplots(ncols=2, figsize=(16, 4))
                model_ds[var].isel(time=0).plot(ax=axes[0])
                axes[0].set_title('Input data')
                output.isel(time=0).plot(ax=axes[1])
                axes[1].set_title('Output data')
                plt.tight_layout()

                # Stop once a model at index >= 1 has been processed. The
                # previous `i == 1` check never fired when that particular
                # model was skipped or failed, so testing mode then ran
                # through every remaining model.
                if i >= 1:
                    break

        except Exception as e:
            print("Model failed: ", model, e)
            continue

    if not processed:
        raise ValueError(
            f"No {era} piControl model could be processed for variable '{var}' "
            f"(table '{cmipTable}')"
        )

    stacked = xr.concat(processed, dim='model')
    stacked.coords['model'] = list(valid_models)

    return stacked


def collect_data(
    eras: Optional[Sequence[str]] = None,
    variables: Optional[Sequence[str]] = None,
    calc_anoms: bool = False,
    output_dir: str = "/data",
) -> None:
    """
    Iterate through all the variables and eras and save the data.

    For each (era, variable) pair the piControl data are ingested with
    ``ingest_and_process`` on a 2.5 x 2.5 degree grid and written to
    ``<output_dir>/<var>_mon_1850-2100[_anoms]_<era>_piControl.nc``.

    Args:
        eras: MIP eras to process. Defaults to ``["CMIP5"]``.
        variables: Variable names to process. Defaults to ``["tos"]``
            (sea surface temperature). ``"tos"`` is read from the ``Omon``
            table, every other variable from ``Amon``.
        calc_anoms: If True, seasonal-cycle departures are computed during
            ingestion and ``_anoms`` is inserted in the file name.
        output_dir: Directory the NetCDF files are written to.
            NOTE(review): the default is the absolute ``/data`` while other
            cells in this notebook read from the relative ``data/`` — confirm
            which one is intended.
    """
    if eras is None:
        eras = ["CMIP5"]
    if variables is None:
        variables = ["tos"]  # tos == sea surface temperature

    # Create output grid (cell centres of a global 2.5 x 2.5 degree grid)
    lat = np.arange(-88.75, 90, 2.5)
    lon = np.arange(1.25, 360, 2.5)
    # Keywords avoid lat/lon being taken as positional axes if the installed
    # xcdat orders them differently — verify against the xcdat version in use.
    output_grid = xc.create_grid(lat=lat, lon=lon)

    suffix = "_anoms" if calc_anoms else ""

    for era in eras:
        print("Starting era: ", era)
        for var in variables:
            print("Starting variable: ", var)

            table = "Omon" if var == "tos" else "Amon"

            ds = ingest_and_process(
                var=var,
                cmipTable=table,
                era=era,
                testing=False,
                output_grid=output_grid,
                # Pass the flag through so the data match the file name
                calc_anoms=calc_anoms,
            )

            print(ds)
            ds.to_netcdf(
                os.path.join(output_dir, f"{var}_mon_1850-2100{suffix}_{era}_piControl.nc")
            )

# Run the ingestion; each era/variable pair is written to its own NetCDF file
collect_data()

Starting era:  CMIP5
Starting variable:  tos
Starting model:  CNRM-CM5
Starting model:  bcc-csm1-1
Starting model:  bcc-csm1-1-m
Starting model:  GFDL-ESM2G
Starting model:  GFDL-ESM2M
Starting model:  GFDL-CM3
Starting model:  EC-EARTH
Starting model:  CESM1-WACCM
Starting model:  CESM1-CAM5
Starting model:  CESM1-FASTCHEM
Starting model:  CESM1-BGC
Starting model:  IPSL-CM5B-LR
Starting model:  IPSL-CM5A-MR
Starting model:  IPSL-CM5A-LR
Starting model:  HadGEM2-ES
Starting model:  HadGEM2-CC
Starting model:  ACCESS1-3
Starting model:  ACCESS1-0
Starting model:  HadGEM2-AO
Starting model:  GISS-E2-H
Starting model:  GISS-E2-R
Starting model:  GISS-E2-R-CC
Starting model:  GISS-E2-H-CC
Starting model:  CSIRO-Mk3-6-0
Starting model:  CanESM2
Starting model:  inmcm4
Starting model:  CMCC-CMS
Starting model:  CMCC-CM
Starting model:  CMCC-CESM
Starting model:  MIROC5
Starting model:  NorESM1-ME
Starting model:  NorESM1-M
Starting model:  MPI-ESM-MR
Starting model:  MPI-ESM-P
Starting mode

In [None]:
def calculate_swcre(
    rsutcs: xr.Dataset,
    rsut: xr.Dataset,
    save: bool = False,
    save_name: str = "",
) -> Tuple[xr.Dataset, xr.Dataset]:
    """
    Compute SWCRE (``rsutcs - rsut``) and its monthly anomalies.

    Only models present in both inputs are used. They are kept in the order
    in which they appear in ``rsutcs`` so that the result (and any saved file)
    is identical from run to run.

    Args:
        rsutcs: Dataset holding the ``rsutcs`` variable with a ``model``
            coordinate.
        rsut: Dataset holding the ``rsut`` variable, a ``model`` coordinate
            and ``time_bnds``.
        save: If True, write both results to ``data/``.
        save_name: Text inserted after ``swcre`` in the output file names
            (e.g. ``"_cmip6"``).

    Returns:
        Tuple ``(swcre, swcre_anoms)``: a dataset with the variable ``swcre``
        and a dataset with its departures grouped by month, stored under
        ``swcre_anoms``.

    Raises:
        ValueError: If the two inputs have no model in common.
    """
    # Select common models, de-duplicated and in a deterministic order
    rsut_models = set(rsut["model"].values)
    common_models = list(
        dict.fromkeys(m for m in rsutcs["model"].values if m in rsut_models)
    )
    if not common_models:
        raise ValueError("rsutcs and rsut have no models in common; cannot compute SWCRE")

    rsutcs = rsutcs.sel(model=common_models)
    rsut = rsut.sel(model=common_models)

    # Calculate SWCRE
    swcre = rsutcs["rsutcs"] - rsut["rsut"]
    swcre = swcre.rename("swcre").to_dataset()

    # Set time bounds (copied from rsut; used by the temporal accessor below)
    swcre["time_bnds"] = rsut["time_bnds"]

    # Calculate SWCRE anomalies
    swcre_anoms = swcre.temporal.departures("swcre", "month")
    swcre_anoms = swcre_anoms.rename({"swcre": "swcre_anoms"})

    if save:
        swcre.to_netcdf(f"data/swcre{save_name}_monthly_1850-2000.nc")
        swcre_anoms.to_netcdf(f"data/swcre{save_name}_anoms_monthly_1850-2000.nc")

    return swcre, swcre_anoms

# Load the all-sky (rsut) and clear-sky (rsutcs) piControl files for both eras
rsut_cmip5 = xc.open_dataset("data/rsut_mon_1850-2100_CMIP5_piControl.nc")
rsut_cmip6 = xc.open_dataset("data/rsut_mon_1850-2100_CMIP6_piControl.nc")
rsutcs_cmip6 = xc.open_dataset("data/rsutcs_mon_1850-2100_CMIP6_piControl.nc")
rsutcs_cmip5 = xc.open_dataset("data/rsutcs_mon_1850-2100_CMIP5_piControl.nc")

# Calculate SWCRE, SWCRE anomalies, and save
swcre_cmip6, swcre_anoms_cmip6 = calculate_swcre(rsutcs_cmip6, rsut_cmip6, save=True, save_name="_cmip6")
swcre_cmi6 = swcre_cmip6  # alias for the previously misspelled name, kept for any code still using it
swcre_cmip5, swcre_anoms_cmip5 = calculate_swcre(rsutcs_cmip5, rsut_cmip5, save=True, save_name="_cmip5")


In [None]:
def calculate_swcf(
    tos: Union[xr.DataArray, xr.Dataset],
    swcre: Union[xr.DataArray, xr.Dataset],
    save: bool = False,
    save_name: str = "",
) -> Union[xr.DataArray, xr.Dataset]:
    """
    Compute SWCF as the linear-regression slope between ``tos`` and ``swcre``.

    Args:
        tos: Sea-surface-temperature field with a ``time`` dimension (the
            notebook passes the ``tos`` anomaly DataArray).
        swcre: SWCRE field with a ``time`` dimension (the notebook passes the
            ``swcre_anoms`` DataArray).
        save: If True, write the result to
            ``data/swcf_<save_name>_monthly_1850-2000.nc``.
        save_name: Text inserted into the output file name (e.g. ``"cmip6"``).

    Returns:
        The slope field returned by ``xskillscore.linslope``, reduced over
        ``time``.
    """
    # Slope of the linear fit along time. This is a regression slope, not a
    # correlation coefficient. Presumably tos is the predictor (first
    # argument) — confirm against the xskillscore documentation.
    swcf = xscore.linslope(tos, swcre, dim='time')

    if save:
        swcf.to_netcdf(f"data/swcf_{save_name}_monthly_1850-2000.nc")
    return swcf

# Load the tos (sea surface temperature) piControl files for both eras
tos_cmip6 = xc.open_dataset("data/tos_mon_1850-2100_CMIP6_piControl.nc")
tos_cmip5 = xc.open_dataset("data/tos_mon_1850-2100_CMIP5_piControl.nc")
# Calculate tos anomalies as departures grouped by "month"
tos_anoms_cmip6 = tos_cmip6.temporal.departures("tos", "month")
tos_anoms_cmip5 = tos_cmip5.temporal.departures("tos", "month")
# Calculate SWCF from the tos and SWCRE anomalies and save it (save=True).
# Requires swcre_anoms_cmip6 / swcre_anoms_cmip5 from the SWCRE cell to exist in the kernel.
swcf_cmip6 = calculate_swcf(tos_anoms_cmip6["tos"], swcre_anoms_cmip6["swcre_anoms"], save=True, save_name="cmip6")
swcf_cmip5 = calculate_swcf(tos_anoms_cmip5["tos"], swcre_anoms_cmip5["swcre_anoms"], save=True, save_name="cmip5")


In [None]:
def calculate_moving_trend():
    """Placeholder with no implementation yet; takes no arguments and returns None."""
    return None

def calculate_west_east_gradient():
    """Placeholder with no implementation yet; takes no arguments and returns None."""
    return None