# Explore regridded data

This is a notebook for visualizing regridded CMIP6 data!

In [3]:
import warnings
from pathlib import Path

import cftime
import dask
import hvplot.xarray # noqa
import numpy as np
import pandas as pd
import panel as pn
import panel.widgets as pnw
import xarray as xr
from tqdm import tqdm

# Silence xarray's SerializationWarning for the whole session; per the original
# note it is raised for datasets that carry multiple FillValues.
# (`warnings` is re-imported here although it is already imported above.)
import warnings
warnings.filterwarnings("ignore", category=xr.SerializationWarning)

Set up the data as a dask-backed xarray dataset:

In [208]:
def flatten(l):
    """Concatenate the items of every iterable in ``l`` into one flat list.

    Only a single level of nesting is removed; the relative order of the
    items is preserved.
    """
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat


def fix_hour(ts):
    """Set every timestamp to 12:00:00 unless all of them are already at hour 12.

    Args:
        ts: Datetime-like values that ``pd.Series`` turns into a series
            supporting the ``.dt`` accessor, e.g. a ``numpy.datetime64``
            array or a ``pandas.DatetimeIndex``.

    Returns:
        ``ts`` itself, unchanged, when every timestamp already has hour 12.
        Otherwise a ``pandas.DatetimeIndex`` with the same calendar dates and
        the time of day set to 12:00:00 for all entries.
    """
    s = pd.Series(ts)
    if not np.any(s.dt.hour != 12):
        return ts

    # Truncate to midnight and add twelve hours. This replaces formatting and
    # re-parsing strings, which produced non-zero-padded dates such as
    # "2015-1-5T12:00:00" and depended on lenient date parsing.
    return pd.DatetimeIndex(s.dt.normalize() + pd.Timedelta(hours=12))


def filter_paths_by_timestamp(fps, min_ts, max_ts):
    """Return the files whose time axis lies entirely within [min_ts, max_ts].

    Each file is opened with xarray and its ``time`` values are read. Files
    are dropped when any timestamp falls outside the inclusive window, when
    cftime timestamps cannot be represented by pandas (out of bounds), or
    when the time axis is empty.

    Args:
        fps: Iterable of paths to datasets that have a ``time`` coordinate.
        min_ts: Earliest allowed timestamp (inclusive).
        max_ts: Latest allowed timestamp (inclusive).

    Returns:
        List of the paths from ``fps`` that passed the check, in input order.
    """
    keep_fps = []

    # Suppress xarray's SerializationWarning only while the files are opened;
    # catch_warnings restores the previous filters on exit.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=xr.SerializationWarning)

        for fp in fps:
            with xr.open_dataset(fp, decode_cf=True) as ds:
                # .values materializes the timestamps, so they stay usable
                # after the dataset is closed
                ts = np.atleast_1d(ds.time.values)

            if ts.size == 0:
                # nothing to compare against; treat as out of range
                continue

            # Non-standard calendars (and out-of-range dates) are decoded to
            # cftime objects. Check against the public base class so every
            # cftime calendar is converted, not just 360-day and no-leap.
            if isinstance(ts[0], cftime.datetime):
                try:
                    ts = pd.to_datetime([cf.isoformat() for cf in ts])
                except pd.errors.OutOfBoundsDatetime:
                    # dates pandas cannot represent are certainly outside the window
                    continue

            # some models stamp hour 0 instead of 12; normalize before comparing
            ts = fix_hour(ts)

            if (ts.min() >= min_ts) and (ts.max() <= max_ts):
                keep_fps.append(fp)

    return keep_fps

In [235]:
# Root of the regridded output. The globs below expect the layout
# <model>/<scenario>/Amon/<var_id>/*.nc
final_regrid_dir = Path("/beegfs/CMIP6/arctic-cmip6/regrid/")
var_ids = ["tas", "pr"]
# only directories can be models; stray files in the root are ignored
models = [p.name for p in final_regrid_dir.glob("*") if p.is_dir()]
scenarios = ["ssp245", "ssp585"]
# a file is kept only if all of its timestamps fall inside this window
min_ts = np.datetime64('2015-01-16T12:00:00.000000000')
max_ts = np.datetime64('2100-12-31T12:00:00.000000000')


model_datasets = []
models_with_data = []
for model in tqdm(models):
    scenario_datasets = []
    for scenario in scenarios:
        fps = flatten([
            list(final_regrid_dir.joinpath(model, scenario, "Amon", var_id).glob("*.nc"))
            for var_id in var_ids
        ])
        fps = filter_paths_by_timestamp(fps, min_ts, max_ts)
        if not fps:
            continue

        # add length-1 model / scenario dimensions so datasets can be
        # concatenated along them later
        scenario_ds = xr.open_mfdataset(fps).expand_dims(
            dim={"model": 1, "scenario": 1}
        ).assign_coords({"model": [model], "scenario": [scenario]})

        # these datasets are still annoyingly heterogeneous, hence all the code!
        # drop the height coordinate if it exists
        if "height" in scenario_ds.coords:
            scenario_ds = scenario_ds.drop_vars("height")

        # Replace the time axis with a year-month axis for simplicity
        # (some files use day 14 / 15 of the month instead of the standard 15 / 16).
        times = scenario_ds.time.values
        if isinstance(times[0], cftime.datetime):
            # non-standard calendars are decoded to cftime objects; check the
            # public base class so every cftime calendar is handled
            ym = pd.to_datetime([f"{cf.year}-{cf.month}" for cf in times])
        else:
            time_index = pd.DatetimeIndex(times)
            ym = pd.to_datetime(
                [f"{year}-{month}" for year, month in zip(time_index.year, time_index.month)]
            )

        scenario_ds = scenario_ds.rename(time="year_month").assign(year_month=ym)

        scenario_datasets.append(scenario_ds)

    if not scenario_datasets:
        # No file of this model survived the timestamp filter; skip it rather
        # than feed an empty Dataset into the final concatenation.
        continue

    model_ds = xr.combine_nested(scenario_datasets, concat_dim=["scenario"], combine_attrs="drop_conflicts")
    model_datasets.append(model_ds)
    models_with_data.append(model)

# keep `models` in sync with the "model" coordinate of viz_ds so every
# model offered for selection actually exists in the data
models = models_with_data

with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    viz_ds = xr.combine_nested(model_datasets, concat_dim=["model"], combine_attrs="drop_conflicts")

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:18<00:00,  1.50s/it]


Run the visualization:

In [249]:
# NOTE: a RuntimeWarning about the mean/min/max of an all-NaN slice may show up
# here (e.g. for most indices with interval=1). It is left visible on purpose;
# the blanket warnings.filterwarnings("ignore") was disabled.

# Widgets that drive the interactive selection below
var_select = pnw.Select(name="Variable", options={var_id: var_id for var_id in var_ids})

models_rbg = pnw.Select(
    name="Model", options={model: model for model in models}
)
scenarios_rbg = pn.widgets.RadioButtonGroup(
    name="Scenario", options={scenario: scenario for scenario in scenarios}, button_type="default"
)
# The player steps through positional indices of the year_month axis, so its
# upper bound is taken from the data instead of being hard-coded.
year_month = pnw.Player(
    name="year_month",
    start=0,
    end=viz_ds.sizes["year_month"] - 1,
    loop_policy='loop',
    interval=1,
)

# Pick the variable, then the model / scenario by label and the time step by position
var_da = viz_ds.interactive()[var_select].sel(model=models_rbg, scenario=scenarios_rbg).isel(year_month=year_month)
var_da.hvplot(cmap='coolwarm', kind="image").opts(width=800, height=500)