Here I we examine the pattern of warming associated with the largest 30 year trend in SST# for the first 150 years of a PiControl simulation for each CMIP6 model

In [2]:
import os
import sys

import numpy as np
import xarray as xr
import xcdat as xc
import xskillscore as xscore

from glob import glob
from typing import Tuple, Dict, Union, Optional, Any, Callable, Iterable, Sequence, cast

# Ignore xarray warnings (bad practice)
import warnings
warnings.simplefilter("ignore") 


In [3]:
def load_datasets(era: str = "cmip6") -> Tuple[xr.Dataset, Union[xr.Dataset, Dict], Union[xr.Dataset, Dict], Union[xr.Dataset, Dict]]:
    models = []
    for model_path in glob(f"data/sharp/{era}/SWCF_*.nc"):
            model_name = model_path.split("/")[-1].split(".")[0][5:]
            models.append(model_name)

    ################## Load SST Sharp Trends and SST Sharp #######################
    tos_sharp = {}
    for model in models:
        try: 
            tos_sharp[model] = xr.open_dataset(f"data/sharp/{era}/SSTsharp_{model}.nc") 
            models.append(model)
        except Exception as e:
            print(model, e)

    ds_tos_sharp = xr.concat(list(tos_sharp.values()), dim='model')
    # Set the coordinates for the 'model' dimension
    ds_tos_sharp = ds_tos_sharp.assign_coords(model=models)

    return  ds_tos_sharp


In [4]:
ds_tos_sharp_cmip6 = load_datasets(era="cmip6")
ds_tos_sharp_cmip6

KeyboardInterrupt: 

In [None]:
def calculate_rolling_trend(data: xr.Dataset, model: str = "", window: int = 12*30, step: int = 12) -> xr.Dataset:
    """Calculate the rolling gradient of a dataset

    Args:
        data (xr.Dataset): _description_
        window (int, optional): _description_. Defaults to 120*30.
        step (int, optional): _description_. Defaults to 12.

    Returns:
        np.array: _description_
    """

    # Calculate SSTd = SSTsharp (warm) - SSTflat (cold)
    SSTd = data["sharp"] - data["flat"]
    SSTsharp = data["sharp"] 
    SSTflat = data["flat"]
    ntime = SSTd.shape[0]

    SSTd_trend = np.zeros(int(np.ceil(np.ceil((ntime - window)/12))))
    SSTsharp_trend = np.zeros(int(np.ceil(np.ceil((ntime - window)/12))))
    SSTflat_trend = np.zeros(int(np.ceil(np.ceil((ntime - window)/12))))
    SSTd_trend.fill(np.nan)
    SSTsharp_trend.fill(np.nan)
    SSTflat_trend.fill(np.nan)

    # Calculate rolling trend of SSTd
    time_idx = xr.DataArray(np.arange(window), dims=("time"))
    for j, i in enumerate(range(0, ntime - window, step)):
        slice_SSTd = SSTd.isel(time=slice(i, int(i+window)))
        slice_SSTsharp = SSTsharp.isel(time=slice(i, int(i+window)))
        slice_SSTflat = SSTflat.isel(time=slice(i, int(i+window)))
        SSTd_trend[j] = xscore.linslope(time_idx, slice_SSTd, dim='time', skipna=True).values
        SSTsharp_trend[j] = xscore.linslope(time_idx, slice_SSTsharp, dim='time', skipna=True).values
        SSTflat_trend[j] = xscore.linslope(time_idx, slice_SSTflat, dim='time', skipna=True).values
    
    # Convert SSTd trend and SSTd to xr.DataArray
    SSTd_trend = xr.DataArray(SSTd_trend, dims=("time"), coords={"time": np.arange(SSTd_trend.shape[0])})
    SSTsharp_trend = xr.DataArray(SSTsharp_trend, dims=("time"), coords={"time": np.arange(SSTd_trend.shape[0])})
    SSTflat_trend = xr.DataArray(SSTflat_trend, dims=("time"), coords={"time": np.arange(SSTd_trend.shape[0])})
    # Convert SSTd to xr.DataArray
    SSTd = xr.DataArray(SSTd, dims=("time"), coords={"time": np.arange(SSTd.shape[0])})
    SSTsharp= xr.DataArray(SSTsharp, dims=("time"), coords={"time": np.arange(SSTd.shape[0])})
    SSTflat = xr.DataArray(SSTflat, dims=("time"), coords={"time": np.arange(SSTd.shape[0])})
    # Combine SSTd and SSTd_trend into xr.Dataset
    ds = xr.Dataset({"SSTd": SSTd, "SSTd_trend": SSTd_trend, "SSTsharp": SSTsharp, "SSTsharp_trend": SSTsharp_trend, "SSTflat": SSTflat, "SSTflat_trend": SSTflat_trend})

    # Sanity plot of SSTd and SSTd_trend
    # sanity_rolling_trend(window=window, sharp=data["sharp"], flat=data["flat"], trend=SSTd_trend*window, raw=SSTd, model=model) 

    return ds

# Load SST Sharp Trend
def get_rolling_trend(ds_tos_sharp):
    tos_trends_sharp = {}

    for model in ds_tos_sharp.model.values:
        try: 
            tos_trends_sharp[model] = calculate_rolling_trend(ds_tos_sharp.sel(model=model))
        except Exception as e:
            print(model, e)

    ds_tos_trends_sharp = xr.concat(list(tos_trends_sharp.values()), dim='model')
    ds_tos_trends_sharp = ds_tos_trends_sharp.assign_coords(model=list(tos_trends_sharp.keys()))
    return ds_tos_trends_sharp

In [None]:
ds_tos_trends_sharp_cmip6 = get_rolling_trend(ds_tos_sharp_cmip6)
ds_tos_trends_sharp_cmip6