In [1]:
import os
import sys
import pickle

import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import xarray as xr
import xesmf as xe
import xcdat as xc
import xsearch as xs
import xskillscore as xscore

from glob import glob 
from global_land_mask import globe
from typing import List, Tuple, Dict, Union, Optional, Any, Callable, Iterable, Sequence, cast
from scipy.stats import linregress

# Ignore xarray warnings (bad practice)
import warnings
warnings.simplefilter("ignore") 

In [2]:
os.chdir('/home/espinosa10/tropical_pacific_clouds')

In [17]:
def ingest_and_process(
    output_grid: np.ndarray,
    var: str = "tos", 
    cmipTable: str ="Omon", 
    era: str = "CMIP6",
    testing: bool = False,
    calc_anoms: bool = False
) -> dict:
    """
    Ingests all the data from the CMIP piControl experiment. 
    Optionally remove the seasonal cycle and detrend the data, regrid to a commond 2.5 x 2.5 degree grid and return a new dataset

    Returns:
        xarray datasets with dimensions (model, time, lat, lon)
    """

    # Specify the ensemble member
    if era == "CMIP6": 
        member = "r1i1p1f1"
    else:
        member = "r1i1p1"

    # Find all the paths to the data
    dpaths = xs.findPaths(
        experiment="ssp370",
        variable=var,
        frequency="mon",
        cmipTable=cmipTable,
        mip_era=era, 
        activity="ScenarioMIP",
        member=member
    )
    models =  xs.getGroupValues(dpaths, 'model')
    print("Models: ", len(models), models)
    dpaths = list(dpaths.keys())
    
    # Create an empty Dataset
    ds = []
    valid_models = []

    for i, (model_path, model) in enumerate(zip(dpaths, models)):
        print("Starting model: ", model)
        print("Model path: ", model_path)

        try:
            # Load data
            da = xc.open_mfdataset(glob(model_path + "/*.nc")) #, parallel=True, chunks="auto")
            if (var == "ta"):
                da = da.sel(plev=70000)
            if (var in ["hur", "ua", "va", "zg"]): 
                # da = da.sel(plev=100000)
                da = da.sel(plev=85000)

            # Regrid Data
            output = da.regridder.horizontal(var, output_grid, tool='xesmf', method='bilinear')

            # Deseasonalize (no need to detrend piControl data, there should be no drift)
            if calc_anoms:
                output = output.temporal.departures(var, "month")

            ntime, _, _ = output[var].shape

            # nmonths = 165*12 # historical
            nmonths = (2100-2015)*12 # piControl
            if ntime < nmonths:
                continue
            
            output = output[var][:nmonths]
            time = np.arange(np.datetime64("2015-01"), np.datetime64("2015-01") + np.timedelta64(nmonths, 'M'), dtype="datetime64[M]")
            output["time"] = time
            coords = output.coords.keys()
            print(coords)
            if "height" in coords:
                output = output.drop('height')
            ds.append(output)
            valid_models.append(model)

            # Only load one model if testing
            if testing:
                _, axes = plt.subplots(ncols=2, figsize=(16, 4))
                da[var].isel(time=0).plot(ax=axes[0])
                axes[0].set_title('Input data')
                output.isel(time=0).plot(ax=axes[1])
                axes[1].set_title('Output data')
                plt.tight_layout()

                if i == 1:
                    break
        
        except Exception as e:
            print("Model failed: ", model, e)

    ds = xr.concat(ds, dim='model', coords='minimal', compat='override')
    ds.coords['model'] = list(valid_models)

    return ds


def collect_data():
    """
    Iterate through all the variables and eras and save the data
    """

    # Define Constants
    eras = ["CMIP6"]
    # variables = ["rsutcs", "rsut", "tos"] # Shortwave cloud forcing
    # variables = ["hur", "tas", "psl", "ta"] # EIS 
    # variables = ["rlut", "rlutcs"] # Longwave cloud forcing
    variables = [
        # "hur"
        # "rlds", # surface downwelling longwave flux, all sky
        # "rlus", # surface upwelling longwave flux, all sky
        # "rldscs", # surface downwelling longwave flux, clear sky
        # "rsds", # surface downwelling shortwave flux, all sky
        # "rsdscs", # surface downwelling shortwave flux, clear sky
        # "rsus", # surface upwelling shortwave flux, all sky (probably don't need)
        # "rsuscs", # surface upwelling shortwave flux, clear sky (probably don't need)
        # "hfls", # surface upward latent heat flux
        # "hfss", # surface upward shortwave heat flux
        # "sfcWind", # surface wind speed
        # "uas", # zonal 10 meter wind speed 
        # "vas", # meridional 10 meter wind speed
        # "psl", # Mean sea-level pressure
        # "ua", # zonal wind speed (lowest level bc not all models have 10m wind speed)
        # "va" # meridional wind speed (lowest level bc not all models have 10m wind speed)
        # "zg"
        # "tauu", # zonal wind stress
        # "tauv", # meridional wind stress
        # "tas", # surface air temperature
        # "pr", # precipitation flux (km m^-2 s^-1)
        "siconc",
        # "sic",
    ]
    calc_anoms = False

    # Create output grid
    # lat = np.arange(-88.75, 90, 2.5)
    # lon = np.arange(1.25, 360, 2.5)
    lat = np.arange(-89.5, 90.5, 1)
    lon = np.arange(.5, 360.5, 1)
    output_grid = xc.create_grid(lat, lon)

    for era in eras: 
        print("Starting era: ", era)
        for var in variables:
            if var == "sfcWind" and era == "CMIP6": continue

            print("Starting variable: ", var)

            if var == "tos": 
                table = "Omon"
            elif var in ["siconc", "sic"]:
                if var == "siconc": 
                    era = "CMIP6"
                    table = "SImon"
                if var == "sic": 
                    era  = "CMIP5"
                    table = "OImon"
            else: 
                table = "Amon"

            ds = ingest_and_process(
                var=var,
                cmipTable=table,
                era=era,
                testing=TESTING,
                output_grid=output_grid,
                calc_anoms=False
            )

            print(ds)
            if (var == "hur") or (var == "ta"):
                var = f"{var}-surface"
                
            if calc_anoms:
                ds.to_netcdf(f"{var}_mon_{era}_ssp370.nc")
                ds.close()
            else: 
                ds.to_netcdf(f"{var}_mon_{era}_ssp370.nc")
                ds.close()


TESTING = False
collect_data()

Starting era:  CMIP6
Starting variable:  siconc
Models:  28 ['BCC-CSM2-MR', 'CESM2-WACCM', 'FGOALS-f3-L', 'CAS-ESM2-0', 'FGOALS-g3', 'GFDL-ESM4', 'CanESM5', 'MPI-ESM1-2-HR', 'CAMS-CSM1-0', 'INM-CM5-0', 'INM-CM4-8', 'MPI-ESM-1-2-HAM', 'TaiESM1', 'EC-Earth3-Veg-LR', 'EC-Earth3', 'EC-Earth3-AerChem', 'CMCC-ESM2', 'CMCC-CM2-SR5', 'ACCESS-ESM1-5', 'ACCESS-CM2', 'IPSL-CM6A-LR', 'IPSL-CM5A2-INCA', 'NorESM2-LM', 'NorESM2-MM', 'MPI-ESM1-2-LR', 'EC-Earth3-Veg', 'CanESM5-1', 'E3SM-2-0']
Starting model:  BCC-CSM2-MR
Model path:  /p/css03/esgf_publish/CMIP6/ScenarioMIP/BCC/BCC-CSM2-MR/ssp370/r1i1p1f1/SImon/siconc/gn/v20200219/
KeysView(Coordinates:
  * time     (time) datetime64[ns] 2015-01-01 2015-02-01 ... 2099-12-01
    type     |S7 ...
  * lat      (lat) float64 -89.5 -88.5 -87.5 -86.5 -85.5 ... 86.5 87.5 88.5 89.5
  * lon      (lon) float64 0.5 1.5 2.5 3.5 4.5 ... 355.5 356.5 357.5 358.5 359.5)
Starting model:  CESM2-WACCM
Model path:  /p/css03/esgf_publish/CMIP6/ScenarioMIP/NCAR/CESM2-WACCM/s