# Preprocess global _daily_ SST to extract features to be tracked by `ocetrac-dask`
This example, which uses 40 years of daily output at 0.25° resolution, takes ~5 minutes on 128 cores.

In [1]:
import xarray as xr
import numpy as np
import pandas as pd
import dask
import flox.xarray
import ocetrac_dask
import matplotlib.pyplot as plt

import intake
from tempfile import TemporaryDirectory
from getpass import getuser
from pathlib import Path
from dask.distributed import Client, LocalCluster
import subprocess
import re

import warnings
# Silence every warning category for the whole session.
# NOTE: this also hides deprecation notices from xarray/dask/flox.
warnings.simplefilter('ignore')

In [2]:
# Per-user scratch area (/scratch/<initial>/<user>/mhws); the zarr output is written here.
_user = getuser()
scratch_dir = Path('/scratch', _user[0], _user, 'mhws')

## Start Dask Cluster

In [3]:
# Dask spills temporary data into a per-user scratch directory.
# TemporaryDirectory() requires its parent to exist, so create it first.
cluster_scratch = Path('/scratch') / getuser()[0] / getuser() / 'clients'
cluster_scratch.mkdir(parents=True, exist_ok=True)
dask_tmp_dir = TemporaryDirectory(dir=cluster_scratch)
dask.config.set(temporary_directory=dask_tmp_dir.name)

## Local Cluster
# N.B.: reduce n_workers / threads_per_worker if you have memory problems.
cluster = LocalCluster(n_workers=32, threads_per_worker=4)
client = Client(cluster)

# Print the "<node>:<port>" pair needed to forward the Dask dashboard.
remote_node = subprocess.run(['hostname'], capture_output=True, text=True).stdout.strip().split('.')[0]
port_match = re.search(r':(\d+)/', client.dashboard_link)
port = port_match.group(1) if port_match else None
if port is None:
    # Dashboard link without an explicit ":<port>/" (e.g. behind a proxy).
    print(f"Could not find a port in the dashboard link: {client.dashboard_link}")
else:
    print(f"Forward with Port = {remote_node}:{port}")

client.dashboard_link

Forward with Port = l40034:8787


'http://127.0.0.1:8787/status'

## Import 40 years of daily EERIE ICON data

In [4]:
# Open the EERIE intake catalogue and select the ocean output of the
# 'eerie-control-1950' experiment of ICON-ESM-ER on grid 'gr025'.
cat = intake.open_catalog(
    "https://raw.githubusercontent.com/eerie-project/intake_catalogues/main/eerie.yaml"
)
model, expid, version, gridspec = 'icon-esm-er', 'eerie-control-1950', 'v20231106', 'gr025'

dat = cat['dkrz.disk.model-output'][model][expid][version]['ocean'][gridspec]

In [5]:
## Flox chunking predictor data array
#   Only used to derive a time chunking that suits the later dayofyear std
#   reduction: rechunk_for_cohorts starts a new chunk at every dayofyear == 1
#   (force_new_chunk_at=1) and targets ~100 time steps per chunk.
#   drop_vars() replaces the deprecated DataArray.drop() for removing the
#   scalar 'depth' coordinate left behind by isel(depth=0).
da_predictor = dat['2d_daily_mean'](chunks={}).to_dask().to.isel(depth=0).drop_vars('depth')
da_predictor_rechunk = flox.xarray.rechunk_for_cohorts(
    da_predictor,
    dim='time',
    labels=da_predictor.time.dt.dayofyear,
    force_new_chunk_at=1,
    chunksize=100,
    ignore_old_chunks=True,
)

In [6]:
# Re-open the daily data with the time chunks computed above, take variable `to`
# at depth index 0 as SST, and drop the scalar 'depth' coordinate
# (drop_vars replaces the deprecated DataArray.drop).
sst = dat['2d_daily_mean'](chunks={'time': da_predictor_rechunk.chunks[0]}).to_dask().to.isel(depth=0).drop_vars('depth')
sst

Unnamed: 0,Array,Chunk
Bytes,53.68 GiB,99.98 MiB
Shape,"(13879, 721, 1440)","(100, 182, 1440)"
Dask graph,608 chunks in 3 graph layers,608 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 53.68 GiB 99.98 MiB Shape (13879, 721, 1440) (100, 182, 1440) Dask graph 608 chunks in 3 graph layers Data type float32 numpy.ndarray",1440  721  13879,

Unnamed: 0,Array,Chunk
Bytes,53.68 GiB,99.98 MiB
Shape,"(13879, 721, 1440)","(100, 182, 1440)"
Dask graph,608 chunks in 3 graph layers,608 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


## Reformat Data

In [7]:
# Re-express longitudes from [0, 360) as [-180, 180), then sort so the
# longitude coordinate increases monotonically again.
shifted_lon = np.mod(sst['lon'] + 180, 360) - 180
sst['lon'] = shifted_lon
sst = sst.sortby('lon')

In [8]:
# Calculate decimal year for _daily outputs_
def decimal_year(da):
    time = pd.to_datetime(da.time)
    start_of_year = pd.to_datetime(time.year.astype(str) + '-01-01')
    start_of_next_year = pd.to_datetime((time.year + 1).astype(str) + '-01-01')
    year_elapsed = (time - start_of_year).days
    year_duration = (start_of_next_year - start_of_year).days
    return time.year + year_elapsed / year_duration

# Attach the decimal year to sst as an extra coordinate along 'time'
dyr = decimal_year(sst)
sst = sst.assign_coords({'decimal_year': ('time', dyr)})

In [9]:
# Display sst again, now carrying the decimal_year coordinate added above
sst

Unnamed: 0,Array,Chunk
Bytes,53.68 GiB,99.98 MiB
Shape,"(13879, 721, 1440)","(100, 182, 1440)"
Dask graph,608 chunks in 4 graph layers,608 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 53.68 GiB 99.98 MiB Shape (13879, 721, 1440) (100, 182, 1440) Dask graph 608 chunks in 4 graph layers Data type float32 numpy.ndarray",1440  721  13879,

Unnamed: 0,Array,Chunk
Bytes,53.68 GiB,99.98 MiB
Shape,"(13879, 721, 1440)","(100, 182, 1440)"
Dask graph,608 chunks in 4 graph layers,608 chunks in 4 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


## Remove trend, compute anomalies

In [10]:
# Quantile, as a fraction in [0, 1], of the standardised anomalies used as the
# extreme-event threshold (0.95 -> 95th percentile)
threshold_percentile = 0.95

In [11]:
## De-trending: fit SST(t) ~ sum_c coeff_c * model[c, t] by least squares,
#  i.e. coefficients = pinv(model.T) @ SST = pmodel.T @ SST

# 6-term design: mean, linear trend, annual sin/cos, semi-annual sin/cos
model = np.vstack([
    np.ones(len(dyr)),
    dyr - np.mean(dyr),
    np.sin(2 * np.pi * dyr),
    np.cos(2 * np.pi * dyr),
    np.sin(4 * np.pi * dyr),
    np.cos(4 * np.pi * dyr),
])

# Pseudo-inverse gives the least-squares solution operator
pmodel = np.linalg.pinv(model)

coeff_labels = np.arange(1, 7)
time_values = sst.time.values

# Wrap both matrices as DataArrays; model_da follows sst's time chunking
model_da = xr.DataArray(
    model.T,
    dims=['time', 'coeff'],
    coords={'time': time_values, 'coeff': coeff_labels},
).chunk({'time': sst.chunks[0]})
pmodel_da = xr.DataArray(
    pmodel.T,
    dims=['coeff', 'time'],
    coords={'coeff': coeff_labels, 'time': time_values},
)

# Fitted coefficients per grid point (pmodel_da contracted with sst over 'time')
sst_mod = xr.DataArray(
    pmodel_da.dot(sst),
    dims=['coeff', 'lat', 'lon'],
    coords={'coeff': coeff_labels, 'lat': sst.lat.values, 'lon': sst.lon.values},
)


In [12]:
## Construct mean, trend, and seasonal cycle from the fitted coefficients
# mean/trend: outer product of one model column (time) with its coefficient map
# (lat, lon), giving a (time, lat, lon) field.
mean  = model_da.isel(coeff=0).dot(sst_mod.isel(coeff=0))
trend = model_da.isel(coeff=1).dot(sst_mod.isel(coeff=1))
# Seasonal cycle = annual AND semi-annual sin/cos harmonics (coeff positions 2..5),
# summed over 'coeff' -- not just the annual sine term.
seas  = model_da.isel(coeff=slice(2, 6)).dot(sst_mod.isel(coeff=slice(2, 6)))

In [13]:
## Anomalies: subtract the full fitted model (all 6 coefficients) from SST
fitted_sst = model_da.dot(sst_mod)
ssta_notrend = sst - fitted_sst

## Standardise SSTa by dividing by the _30-day rolling_ standard deviation
This step gives SSTa equal variance at all spatial points.

In [14]:
# Day-of-year standard deviation of the anomalies, using flox's "cohorts" method.
# (A plain ssta_notrend.groupby(...dayofyear).std() was very slow and
#  memory prohibitive here.)
doy_labels = ssta_notrend.time.dt.dayofyear
stdev_day = flox.xarray.xarray_reduce(
    ssta_notrend, doy_labels, dim='time', func='std', isbin=False, method='cohorts'
)

In [15]:
# Compute the rolling 30-day STD: average the variances (std**2) in a centred
# window, wrapping around the year end (mode='wrap'), then take the sqrt.
window = 30
pad = window // 2 + 1  # 16 days of wrap-around on each side
n_days = stdev_day.sizes['dayofyear']  # 366 here; derived instead of hard-coded
stdev_day_wrap = stdev_day.pad(dayofyear=pad, mode='wrap')
stdev_rolling = np.sqrt(
    (stdev_day_wrap**2).rolling(dayofyear=window, center=True).mean()
).isel(dayofyear=slice(pad, pad + n_days))  # This is still a memory high-water mark...

In [16]:
# Standardise: divide each anomaly by the rolling STD for its day of year
by_doy = ssta_notrend.groupby(ssta_notrend.time.dt.dayofyear)
ssta_stn_notrend = by_doy / stdev_rolling

## Use a threshold to find extreme anomalies

In [17]:
# Put the whole time series in one chunk (for the quantile over 'time' below),
# tiling the domain spatially instead
ssta_stn_notrend_rechunk = ssta_stn_notrend.chunk(time=-1, lat=182, lon=45)

In [18]:
# Per-grid-point threshold over time; keep values at/above it, NaN elsewhere
threshold = ssta_stn_notrend_rechunk.quantile(threshold_percentile, dim='time')
is_extreme = ssta_stn_notrend_rechunk >= threshold
features_notrend = ssta_stn_notrend_rechunk.where(is_extreme, other=np.nan)

In [19]:
# Consistent chunking for the zarr output: tchunk time steps per chunk, each
# chunk holding the full lat/lon field
tchunk = 25

features_notrend = features_notrend.chunk(time=tchunk, lat=-1, lon=-1)
# NOTE(review): this rebinds `ssta_notrend` to the *standardised* anomalies
# (ssta_stn_notrend), so the 'ssta_notrend' variable saved below is standardised,
# not the raw anomaly computed earlier -- confirm this is intended.
ssta_notrend = ssta_stn_notrend.chunk(time=tchunk, lat=-1, lon=-1)
stdev_rolling = stdev_rolling.chunk(dayofyear=-1, lat=-1, lon=-1)

# Land/ocean mask: True where the first SST snapshot is finite
mask = np.isfinite(sst.isel(time=0)).chunk(lat=-1, lon=-1)

In [20]:
# xarray Dataset to save: features, anomalies, rolling STD climatology, and land mask
ds_out = xr.Dataset(
    data_vars=dict(
        features_notrend=(['time','lat','lon'], features_notrend.data),
        ssta_notrend=(['time','lat','lon'], ssta_notrend.data),
        stdev=(['dayofyear','lat','lon'], stdev_rolling.data),
        mask=(['lat','lon'], mask.data),
    ),
    coords=dict(
        lon=sst.lon,
        lat=sst.lat,
        time=sst.time,
        dayofyear=stdev_rolling.dayofyear,
    ),
    attrs=dict(description="ICON erc1011 preprocessed for Ocetrac-dask",
              # ':g' avoids int() truncation of float products (int(0.29 * 100) == 28)
              # and keeps non-integer percentiles such as 99.9; 0.95 -> "95th percentile".
              threshold=f"{threshold_percentile * 100:g}th percentile",
              climatology='entire period'),
)

ds_out

Unnamed: 0,Array,Chunk
Bytes,107.36 GiB,198.03 MiB
Shape,"(13879, 721, 1440)","(25, 721, 1440)"
Dask graph,556 chunks in 40 graph layers,556 chunks in 40 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 107.36 GiB 198.03 MiB Shape (13879, 721, 1440) (25, 721, 1440) Dask graph 556 chunks in 40 graph layers Data type float64 numpy.ndarray",1440  721  13879,

Unnamed: 0,Array,Chunk
Bytes,107.36 GiB,198.03 MiB
Shape,"(13879, 721, 1440)","(25, 721, 1440)"
Dask graph,556 chunks in 40 graph layers,556 chunks in 40 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,107.36 GiB,198.03 MiB
Shape,"(13879, 721, 1440)","(25, 721, 1440)"
Dask graph,556 chunks in 25 graph layers,556 chunks in 25 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 107.36 GiB 198.03 MiB Shape (13879, 721, 1440) (25, 721, 1440) Dask graph 556 chunks in 25 graph layers Data type float64 numpy.ndarray",1440  721  13879,

Unnamed: 0,Array,Chunk
Bytes,107.36 GiB,198.03 MiB
Shape,"(13879, 721, 1440)","(25, 721, 1440)"
Dask graph,556 chunks in 25 graph layers,556 chunks in 25 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,2.83 GiB,2.83 GiB
Shape,"(366, 721, 1440)","(366, 721, 1440)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 2.83 GiB 2.83 GiB Shape (366, 721, 1440) (366, 721, 1440) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",1440  721  366,

Unnamed: 0,Array,Chunk
Bytes,2.83 GiB,2.83 GiB
Shape,"(366, 721, 1440)","(366, 721, 1440)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,0.99 MiB,0.99 MiB
Shape,"(721, 1440)","(721, 1440)"
Dask graph,1 chunks in 7 graph layers,1 chunks in 7 graph layers
Data type,bool numpy.ndarray,bool numpy.ndarray
"Array Chunk Bytes 0.99 MiB 0.99 MiB Shape (721, 1440) (721, 1440) Dask graph 1 chunks in 7 graph layers Data type bool numpy.ndarray",1440  721,

Unnamed: 0,Array,Chunk
Bytes,0.99 MiB,0.99 MiB
Shape,"(721, 1440)","(721, 1440)"
Dask graph,1 chunks in 7 graph layers,1 chunks in 7 graph layers
Data type,bool numpy.ndarray,bool numpy.ndarray


## Save data to zarr for more efficient parallel I/O

In [None]:
# Write zarr with compression disabled for every data variable.
# NetCDF alternative: ds_out.to_netcdf(scratch_dir / '01_preprocess_dask.nc', mode='w')
zarr_path = scratch_dir / '01_preprocess_dask.zarr'
encoding = {name: {'compressor': None} for name in ds_out.data_vars}
ds_out.to_zarr(zarr_path, mode='w', encoding=encoding)