# Use `ocetrac-unstruct` to identify and track marine heatwaves
This example, using 40 years of daily output at 5 km native grid resolution, takes on the order of minutes on 128 cores.

In [1]:
import xarray as xr
import numpy as np
import dask
import ocetrac_unstruct

from tempfile import TemporaryDirectory
from getpass import getuser
from pathlib import Path
from dask_jobqueue import SLURMCluster
from dask.distributed import Client, LocalCluster
import subprocess
import re

import warnings
warnings.filterwarnings('ignore')

In [2]:
# Per-user working directory: /scratch/<first letter of username>/<username>/mhws
scratch_dir = Path('/scratch').joinpath(getuser()[0], getuser(), 'mhws')

## Start Dask Cluster

In [None]:
# Per-user directory on /scratch under which Dask's temporary directory is created.
cluster_scratch = Path('/scratch') / getuser()[0] / getuser() / 'clients'
# TemporaryDirectory(dir=...) raises FileNotFoundError if the parent directory
# is missing (e.g. on a first run), so make sure it exists.
cluster_scratch.mkdir(parents=True, exist_ok=True)
# Keep the TemporaryDirectory object referenced for the whole session: the
# directory is removed once the object is cleaned up.
dask_tmp_dir = TemporaryDirectory(dir=cluster_scratch)
dask.config.set(temporary_directory=dask_tmp_dir.name)

In [4]:
# ## Local Cluster
# Alternative to the SLURM cluster: uncomment this cell to start a
# LocalCluster on the current node and print the dashboard port to forward.
# cluster = LocalCluster(n_workers=32, threads_per_worker=4)
# client = Client(cluster)

# remote_node = subprocess.run(['hostname'], capture_output=True, text=True).stdout.strip().split('.')[0]
# port = re.search(r':(\d+)/', client.dashboard_link).group(1)
# print(f"Forward with Port = {remote_node}:{port}")

# client.dashboard_link

In [None]:
scale = 256        # number of workers requested through clusterDistributed.scale()
node_memory = 512  # memory class of the target nodes (presumably GB - see the constraint directive)

# NOTE(review): client_memory / constraint_memory are only referenced by the
# commented-out job_extra_directives line; for any other node_memory value
# they stay undefined.
if node_memory == 512:
    client_memory = '500GB'
    constraint_memory = '512'
elif node_memory == 1024:
    client_memory = '1000GB'
    constraint_memory = '1024'

## Distributed Cluster (without GPU)
clusterDistributed = SLURMCluster(
    name='dask-cluster',
    cores=32,
    memory='230GB',
    # NOTE(review): this line was annotated "2 threads", but 64 processes on
    # 32 cores is less than one core per process - confirm the intended split.
    processes=64,
    interface='ib0',
    queue='compute',
    account='bk1377',
    walltime='00:59:00',
    asynchronous=False,
    #job_extra_directives = [f'--constraint={constraint_memory}G --mem=0'],
    # Derive the home directory from the username (same <letter>/<user>
    # layout as the /scratch paths) instead of hard-coding '/home/b/'.
    log_directory=f'/home/{getuser()[0]}/{getuser()}/.log_trash',
    local_directory=dask_tmp_dir.name,
    scheduler_options={'dashboard_address': ':8889'},
)

clusterDistributed.scale(scale)
clientDistributed = Client(clusterDistributed)

# Print the host and port needed to forward the Dask dashboard.
remote_node = subprocess.run(['hostname'], capture_output=True, text=True).stdout.strip().split('.')[0]
port = re.search(r':(\d+)/', clientDistributed.dashboard_link).group(1)
print(f"Forward Port = {remote_node}:{port}")
print(f"localhost:{port}/status")

## Load Pre-processed Data
(cf. `01_preprocess_unstruct.ipynb`)

In [6]:
# Open the pre-processed Zarr store with two time steps per chunk and all
# cells in a single chunk, dropping the helper time variables.
chunk_size = dict(time=2, ncells=-1)
# NOTE(review): slice(0, -1) excludes the final time step - confirm this is intended.
time_subset = slice(0, -1)
ds = (
    xr.open_zarr(str(scratch_dir / '01_preprocess_unstruct.zarr'), chunks=chunk_size)
    .drop_vars({'decimal_year', 'dayofyear'})
    .isel(time=time_subset)
)

In [None]:
# Take the `features_notrend` variable without its 'lat'/'lon' variables and display it.
binary_out = ds['features_notrend'].drop_vars(['lat', 'lon'])
binary_out

In [8]:
# Compute the mask (without 'lat'/'lon') and keep only its underlying array.
mask = (
    ds['mask']
    .drop_vars(['lat', 'lon'])
    .compute()
    .data
)

## Run Tracker

In [9]:
# Tracking Parameters
# min_size_quartile is written as a fraction here (the original note calls it "percent").
# NB: At 5 km (vs ~25 km for regridded data) we find many more very small objects!
threshold_percentile, min_size_quartile = 0.95, 0.85
# Both lengths are given in km.
radius, resolution = 100.0, 5.0

In [10]:
# Build the tracker from the binary feature field, the tracking parameters
# defined above, the neighbour information stored in `ds`, and the mask array.
tracker = ocetrac_unstruct.Tracker(
    binary_out,
    scratch_dir=str(scratch_dir / 'ocetrac_unstruct_scratch'),
    radius=radius,
    resolution=resolution,
    min_size_quartile=min_size_quartile,
    timedim='time',
    xdim='ncells',
    neighbours=ds.neighbours,
    land_mask=mask,
)

In [11]:
# Run the tracker; the result is kept in `blobs` for inspection and saving.
blobs = tracker.track()

In [None]:
# Display the attributes attached to the tracking result.
blobs.attrs

## Save Blobs

In [None]:
# Write the tracked objects to a NetCDF file in the scratch directory
# (mode='w' replaces an existing file of the same name).
blobs.to_netcdf(
    scratch_dir / '02_tracked_unstruct.nc',
    mode='w',
)

In [None]:
# Close the client and then the cluster object itself, so the SLURM worker
# jobs are not left allocated after the notebook has finished.
clientDistributed.close()
clusterDistributed.close()