# Preprocess global _daily_ SST to extract features to be tracked by `ocetrac-unstruct`
This example, which uses 40 years of daily outputs on the 5 km native grid (~15 million cells), takes ~10 minutes on 512 cores.

In [1]:
import xarray as xr
import numpy as np
import pandas as pd
import dask
import flox.xarray
import ocetrac_unstruct

import intake
from tempfile import TemporaryDirectory
from getpass import getuser
from pathlib import Path
from dask_jobqueue import SLURMCluster
from dask.distributed import Client, LocalCluster
import subprocess
import re

import warnings
warnings.filterwarnings('ignore')

In [2]:
# Per-user scratch directory; the preprocessed zarr store is written here at the end of the notebook.
scratch_dir = Path('/scratch', getuser()[0], getuser(), 'mhws')

## Start Dask Cluster

In [3]:
# Per-user parent directory for Dask's temporary/spill files.
cluster_scratch = Path('/scratch') / getuser()[0] / getuser() / 'clients'
# TemporaryDirectory(dir=...) raises FileNotFoundError if the parent does not exist yet.
cluster_scratch.mkdir(parents=True, exist_ok=True)
dask_tmp_dir = TemporaryDirectory(dir=cluster_scratch)
# Make Dask use this directory for its temporary files instead of the default location.
dask.config.set(temporary_directory=dask_tmp_dir.name)

# ## Local Cluster (alternative to the SLURM cluster in the next cell)
# cluster = LocalCluster(n_workers=32, threads_per_worker=2)  ## N.B.:  Reduce this if you have memory problems...
# client = Client(cluster)

# remote_node = subprocess.run(['hostname'], capture_output=True, text=True).stdout.strip().split('.')[0]
# port = re.search(r':(\d+)/', client.dashboard_link).group(1)
# print(f"Forward with Port = {remote_node}:{port}")

# client.dashboard_link

<dask.config.set at 0x155536e7a560>

In [4]:
# Number of Dask workers to request via scale(); with 128 processes per job this is 512 / 128 = 4 jobs.
scale = 512
# Memory per compute node in GB; selects the worker memory and the SLURM node constraint below.
node_memory = 1024


if node_memory == 512:
    client_memory = '500GB'
    constraint_memory = '512'
elif node_memory == 1024:
    client_memory = '1000GB'
    constraint_memory = '1024'
else:
    # Fail early: otherwise client_memory/constraint_memory are undefined and SLURMCluster raises a NameError.
    raise ValueError(f"Unsupported node_memory={node_memory!r}; expected 512 or 1024")

## Distributed Cluster (without GPU)
clusterDistributed = SLURMCluster(name='dask-cluster',
                                  cores=128,
                                  memory=client_memory,
                                  processes=128,  # 128 processes on 128 cores -> 1 thread per worker
                                  interface='ib0',
                                  queue='compute',
                                  account='bk1377',
                                  walltime='00:29:00',
                                  asynchronous=False,
                                  # One list entry per #SBATCH directive.
                                  job_extra_directives=[f'--constraint={constraint_memory}G', '--mem=0'],
                                  # Per-user location (previously hard-coded to a single user's home).
                                  log_directory=str(Path.home() / '.log_trash'),
                                  local_directory=dask_tmp_dir.name,
                                  scheduler_options={'dashboard_address': ':8889'})

clusterDistributed.scale(scale)
clientDistributed = Client(clusterDistributed)

# Print the host/port needed to SSH-forward the Dask dashboard.
remote_node = subprocess.run(['hostname'], capture_output=True, text=True).stdout.strip().split('.')[0]
port = re.search(r':(\d+)/', clientDistributed.dashboard_link).group(1)
print(f"Forward Port = {remote_node}:{port}")
print(f"localhost:{port}/status")

Forward Port = l40046:8889
localhost:8889/status


## Import 40 years of daily EERIE ICON data on the _native_ 5 km unstructured grid

In [5]:
# Open the EERIE intake catalogue (fetched from GitHub) and select the native-grid ocean output
# of the `eerie-control-1950` experiment of ICON-ESM-ER, catalogue version v20231106.
cat = intake.open_catalog("https://raw.githubusercontent.com/eerie-project/intake_catalogues/main/eerie.yaml")
expid = 'eerie-control-1950'
version = 'v20231106'
model = 'icon-esm-er'
gridspec = 'native'

dat = cat['dkrz.disk.model-output'][model][expid][version]['ocean'][gridspec]

In [6]:
## Flox chunking predictor data array
#   Optimise for dayofyear std reduction later...
# Open `to` at depth index 0 lazily (chunks={} -> presumably the store's native chunking), then let flox
# pick time-chunk boundaries aligned with day-of-year groups: a new chunk is forced at dayofyear 1,
# with a target chunk size of 8. Only the resulting chunk layout is reused below (when re-opening `sst`).
da_predictor = dat['2d_daily_mean'](chunks={}).to_dask().to.isel(depth=0)
da_predictor_rechunk = flox.xarray.rechunk_for_cohorts(da_predictor, dim='time', labels=da_predictor.time.dt.dayofyear, force_new_chunk_at=1, chunksize=8, ignore_old_chunks=True)

In [7]:
# Re-open the daily means with the first-dimension (time) chunks computed above; `to` at depth index 0
# is used as SST (presumably ICON ocean temperature at the top level — confirm variable meaning).
sst = dat['2d_daily_mean'](chunks={'time':da_predictor_rechunk.chunks[0]}).to_dask().to.isel(depth=0)
# NOTE: a bare expression only renders when it is the last line of a notebook cell, so this line has no effect.
sst

# 2D grid description; `cell` is renamed to `ncells`, presumably to match the cell dimension of `sst`.
grid2d = dat['2d_grid'](chunks={}).to_dask().rename({'cell':'ncells'})

## Calculate Decimal Year

In [8]:
# Calculate decimal year for _daily outputs_
def decimal_year(da):
    """Convert every timestamp in ``da.time`` to a decimal year.

    Each timestamp maps to ``year + (day_of_year - 1) / days_in_year``, so
    1 January of any year is exactly ``year + 0.0``. Leap years divide by 366.
    Any time-of-day component is ignored: only whole elapsed days count.

    Parameters
    ----------
    da : object
        Anything with a ``time`` attribute that ``pd.to_datetime`` can
        convert (e.g. an xarray DataArray with a ``time`` coordinate).

    Returns
    -------
    pandas Index of float64, one decimal-year value per timestamp.
    """
    stamps = pd.to_datetime(da.time)
    whole_days_elapsed = stamps.dayofyear - 1
    days_in_year = np.where(stamps.is_leap_year, 366, 365)
    return stamps.year + whole_days_elapsed / days_in_year

# Add into the dataset
# `dyr` is attached to `sst` as a non-index coordinate along `time` and is also used below to build the regression model.
dyr = decimal_year(sst)
sst = sst.assign_coords(decimal_year=('time', dyr))

## Remove Trend & Compute Anomalies

In [9]:
## De-trending: least-squares fit SST(t) ~= sum_c design_matrix[c, t] * coeff_c,
## solved as coefficients = pinv(design_matrix).T @ SST  (pmodel = pinv(design_matrix)).

# The 6 coefficient model is composed of the mean, trend, annual sine and cosine harmonics, & semi-annual sine and cosine harmonics.
# Named `design_matrix` rather than `model` so it does not overwrite the catalogue key `model` ('icon-esm-er') set earlier.
design_matrix = np.array([np.ones(len(dyr))] + [dyr - np.mean(dyr)] + [np.sin(2 * np.pi * dyr)] +
                         [np.cos(2 * np.pi * dyr)] + [np.sin(4 * np.pi * dyr)] +
                         [np.cos(4 * np.pi * dyr)], dtype=np.float32)  # shape (6, ntime)

# Take the pseudo-inverse of the design matrix to solve the least-squares problem
pmodel = np.linalg.pinv(design_matrix)  # shape (ntime, 6)

# Convert the design matrix and its pseudo-inverse to xarray DataArrays
model_da = xr.DataArray(design_matrix.T, dims=['time','coeff'], coords={'time':sst.time.values, 'coeff':np.arange(1,7,1)}).chunk({'time':sst.chunks[0]})
pmodel_da = xr.DataArray(pmodel.T, dims=['coeff','time'], coords={'coeff':np.arange(1,7,1), 'time':sst.time.values})

# Resulting coefficients of the model: 6 per cell (contraction over the shared `time` dimension)
sst_mod = xr.DataArray(pmodel_da.dot(sst), dims=['coeff','ncells'], coords={'coeff':np.arange(1,7,1), 'ncells':sst.ncells.values})


In [10]:
## Construct mean, trend, and seasonal cycle (diagnostic reconstructions; not used further in this notebook)

mean  = model_da.isel(coeff=0).dot(sst_mod.isel(coeff=0))
trend = model_da.isel(coeff=1).dot(sst_mod.isel(coeff=1))
# The seasonal cycle is the sum of all four harmonic terms (annual + semi-annual sine/cosine, coeff positions 2-5);
# dot() sums over the shared `coeff` dimension. Position 2 alone would give only the annual sine term.
seas  = model_da.isel(coeff=slice(2, 6)).dot(sst_mod.isel(coeff=slice(2, 6)))

In [11]:
## Compute anomalies by removing all the model coefficients 
# Subtract the full fitted model (mean + trend + annual and semi-annual harmonics) from SST, per cell and time.
ssta_notrend = (sst - model_da.dot(sst_mod))

## Standardise SSTa by dividing by the _30-day rolling_ standard deviation
This step gives SSTa comparable (approximately unit) variance at every spatial point, so a single quantile threshold is meaningful everywhere.

In [12]:
# Compute the daily standard deviation
# One standard deviation (over all years) per dayofyear and cell. method='cohorts' benefits from the
# dayofyear-aligned time chunks chosen by rechunk_for_cohorts when `sst` was opened.
#stdev_day = ssta_notrend.groupby(ssta_notrend.time.dt.dayofyear).std()  # Very slow & memory prohibitive...
stdev_day = flox.xarray.xarray_reduce(ssta_notrend, ssta_notrend.time.dt.dayofyear, dim='time', func='std', isbin=False, method='cohorts')

In [13]:
# Compute the rolling 30-day STD
# Pad 16 days on each side with wrap-around so the centred window is continuous across the year boundary.
# Then average the daily *variances* over a 30-day window, take the square root, and slice the padding back off.
stdev_day_wrap = stdev_day.pad(dayofyear=16, mode='wrap')
stdev_rolling = np.sqrt((stdev_day_wrap**2).rolling(dayofyear=30, center=True).mean()).isel(dayofyear=slice(16,366+16))  # This is still a memory high-water mark...

In [14]:
# Divide by standard deviation
ssta_stn_notrend = ssta_notrend.groupby(ssta_notrend.time.dt.dayofyear) / stdev_rolling

## Use a Threshold to ID Extreme Anomalies

In [15]:
# Quantile passed to .quantile() below (a fraction in [0, 1], not a percent); 0.95 -> 95th percentile.
threshold_percentile = 0.95

In [16]:
# Put the whole time axis in one chunk (quantile along `time` needs it) and split cells into blocks of 4000.
ssta_stn_notrend_rechunk = ssta_stn_notrend.chunk({'time':-1,'ncells':4000})

In [17]:
# Per-cell threshold: the `threshold_percentile` quantile of the standardised anomaly over all times.
threshold = ssta_stn_notrend_rechunk.quantile(threshold_percentile, dim='time')
# Boolean feature mask: True where the standardised anomaly reaches or exceeds the local threshold.
# The comparison already yields a bool DataArray (NaN compares False), so xr.where(..., True, False) is unnecessary.
features_notrend = ssta_stn_notrend_rechunk >= threshold

In [18]:
# Consistently chunk for zarr
# Time-dependent outputs get `tchunk` steps per chunk and the whole `ncells` axis in each chunk.
tchunk = 2

features_notrend = features_notrend.chunk({'time':tchunk, 'ncells':-1})
# NOTE(review): this rebinds `ssta_notrend` to the *standardised* anomaly (ssta_stn_notrend), so the
# `ssta_notrend` variable written to the output holds standardised values, not the raw anomaly — confirm intended.
ssta_notrend = ssta_stn_notrend.chunk({'time':tchunk, 'ncells':-1})
stdev_rolling = stdev_rolling.chunk({'dayofyear':tchunk, 'ncells':-1})

# Add a land/ocean mask
# NOTE(review): relies on positive cell_sea_land_mask values meaning land (ICON convention) — verify.
mask = (grid2d.cell_sea_land_mask.rename({'clat':'lat', 'clon':'lon'}) > 0).chunk({'ncells':-1}) # True for land False for ocean

# Neighbour-cell index table, cast to the platform int type and held in a single chunk.
neighbours = grid2d.neighbor_cell_index.rename({'clat':'lat', 'clon':'lon'}).astype(int).chunk({'ncells':-1})

In [19]:
# xarray Dataset to save
# Built from the underlying (dask) arrays via `.data`, so only the coordinates listed here are attached.
ds_out = xr.Dataset(
    data_vars=dict(
        features_notrend=(['time','ncells'], features_notrend.data),
        ssta_notrend=(['time','ncells'], ssta_notrend.data),
        stdev=(['dayofyear','ncells'], stdev_rolling.data),
        mask=(['ncells'], mask.data),
        neighbours=(['nv', 'ncells'], neighbours.data),
    ),
    coords=dict(
        ncells=sst.ncells,
        time=sst.time,
        dayofyear=stdev_rolling.dayofyear,
        nv=neighbours.nv,
    ),
    attrs=dict(description="ICON erc1011 preprocessed for Ocetrac-unstruct",
              # round() instead of int(): int() truncates float error (int(0.57 * 100) == 56);
              # ':g' keeps fractional percentiles such as 97.5 and prints 95.0 as "95".
              threshold=f"{round(threshold_percentile * 100, 6):g}th percentile",
              climatology='entire period'),
)

ds_out

Unnamed: 0,Array,Chunk
Bytes,113.57 MiB,113.57 MiB
Shape,"(14886338,)","(14886338,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 113.57 MiB 113.57 MiB Shape (14886338,) (14886338,) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",14886338  1,

Unnamed: 0,Array,Chunk
Bytes,113.57 MiB,113.57 MiB
Shape,"(14886338,)","(14886338,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,113.57 MiB,113.57 MiB
Shape,"(14886338,)","(14886338,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 113.57 MiB 113.57 MiB Shape (14886338,) (14886338,) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",14886338  1,

Unnamed: 0,Array,Chunk
Bytes,113.57 MiB,113.57 MiB
Shape,"(14886338,)","(14886338,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,192.42 GiB,28.39 MiB
Shape,"(13879, 14886338)","(2, 14886338)"
Dask graph,6940 chunks in 269 graph layers,6940 chunks in 269 graph layers
Data type,bool numpy.ndarray,bool numpy.ndarray
"Array Chunk Bytes 192.42 GiB 28.39 MiB Shape (13879, 14886338) (2, 14886338) Dask graph 6940 chunks in 269 graph layers Data type bool numpy.ndarray",14886338  13879,

Unnamed: 0,Array,Chunk
Bytes,192.42 GiB,28.39 MiB
Shape,"(13879, 14886338)","(2, 14886338)"
Dask graph,6940 chunks in 269 graph layers,6940 chunks in 269 graph layers
Data type,bool numpy.ndarray,bool numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,769.67 GiB,113.57 MiB
Shape,"(13879, 14886338)","(2, 14886338)"
Dask graph,6940 chunks in 248 graph layers,6940 chunks in 248 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 769.67 GiB 113.57 MiB Shape (13879, 14886338) (2, 14886338) Dask graph 6940 chunks in 248 graph layers Data type float32 numpy.ndarray",14886338  13879,

Unnamed: 0,Array,Chunk
Bytes,769.67 GiB,113.57 MiB
Shape,"(13879, 14886338)","(2, 14886338)"
Dask graph,6940 chunks in 248 graph layers,6940 chunks in 248 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,20.30 GiB,113.57 MiB
Shape,"(366, 14886338)","(2, 14886338)"
Dask graph,183 chunks in 244 graph layers,183 chunks in 244 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 20.30 GiB 113.57 MiB Shape (366, 14886338) (2, 14886338) Dask graph 183 chunks in 244 graph layers Data type float32 numpy.ndarray",14886338  366,

Unnamed: 0,Array,Chunk
Bytes,20.30 GiB,113.57 MiB
Shape,"(366, 14886338)","(2, 14886338)"
Dask graph,183 chunks in 244 graph layers,183 chunks in 244 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,14.20 MiB,14.20 MiB
Shape,"(14886338,)","(14886338,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,bool numpy.ndarray,bool numpy.ndarray
"Array Chunk Bytes 14.20 MiB 14.20 MiB Shape (14886338,) (14886338,) Dask graph 1 chunks in 3 graph layers Data type bool numpy.ndarray",14886338  1,

Unnamed: 0,Array,Chunk
Bytes,14.20 MiB,14.20 MiB
Shape,"(14886338,)","(14886338,)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,bool numpy.ndarray,bool numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,340.72 MiB,340.72 MiB
Shape,"(3, 14886338)","(3, 14886338)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 340.72 MiB 340.72 MiB Shape (3, 14886338) (3, 14886338) Dask graph 1 chunks in 3 graph layers Data type int64 numpy.ndarray",14886338  3,

Unnamed: 0,Array,Chunk
Bytes,340.72 MiB,340.72 MiB
Shape,"(3, 14886338)","(3, 14886338)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray


## Save data to zarr for more efficient parallel I/O

In [20]:
#ds_out.to_netcdf(scratch_dir / '01_preprocess_dask.nc', mode='w')
# Write every data variable uncompressed (compressor=None), presumably to favour fast parallel I/O over disk size.
# NOTE(review): 'compressor' is the zarr-v2 encoding key; check against the installed zarr/xarray versions.
encoding = {var: {'compressor': None} for var in ds_out.data_vars}
# mode='w' overwrites any existing store at this path.
ds_out.to_zarr(scratch_dir / '01_preprocess_unstruct.zarr', mode='w', encoding=encoding)

<xarray.backends.zarr.ZarrStore at 0x1553ca90aa40>

In [21]:
# Shut down both the client and the SLURM cluster; closing only the client leaves the worker jobs allocated.
clientDistributed.close()
clusterDistributed.close()