# Identify & Track Marine Heatwaves on an _Unstructured Grid_ using `MarEx`

## Processing Steps:
1. Fill spatial holes in the binary data, using highly-threaded binary dilation matrix operations -- up to `R_fill` elements in radius.
2. Fill gaps in time -- permitting up to `T_fill` missing time slices, while keeping the same event ID.
3. Filter out small objects -- those whose area falls within the bottom `area_filter_quartile` fraction of the object size distribution.
4. Identify objects in the binary data, using a highly efficient Unstructured Union-Find (Disjoint Set Union) Clustering Algorithm.
5. Connect objects across time, applying the following criteria for splitting, merging, and persistence:
    - Connected Blobs must overlap by at least fraction `overlap_threshold` of the smaller area.
    - Merged Blobs retain their original IDs, but the child area is partitioned by assigning each cell to the parent of its _nearest-neighbour_ cell.
6. Cluster and reduce the final object ID graph using `scipy.sparse.csgraph.connected_components`.
7. Map the tracked objects into ID-time space for convenient analysis.
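
Step 4 can be illustrated with a minimal, pure-Python sketch of union-find clustering over a neighbour table. This is not MarEx's implementation: the function name `label_components`, the `-1` convention for a missing neighbour, and the toy grid are all assumptions made for illustration.

```python
def label_components(is_extreme, neighbours):
    """Label connected groups of flagged cells on an unstructured grid.

    is_extreme : list of bool, one entry per cell
    neighbours : list of lists; neighbours[i] holds the indices of the cells
                 adjacent to cell i (-1 marks a missing neighbour; assumed convention)
    Returns a list of integer labels (0 = background, 1..N = objects).
    """
    parent = list(range(len(is_extreme)))

    def find(i):
        # Follow parent links to the root, shortening the path on the way
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    def union(i, j):
        # Merge the two sets, keeping the smaller root index as representative
        ri, rj = find(i), find(j)
        if ri != rj:
            parent[max(ri, rj)] = min(ri, rj)

    # Join every flagged cell with each of its flagged neighbours
    for cell, flagged in enumerate(is_extreme):
        if not flagged:
            continue
        for nb in neighbours[cell]:
            if nb >= 0 and is_extreme[nb]:
                union(cell, nb)

    # Relabel the roots as consecutive object IDs starting from 1
    labels = [0] * len(is_extreme)
    root_to_label = {}
    for cell, flagged in enumerate(is_extreme):
        if flagged:
            root = find(cell)
            labels[cell] = root_to_label.setdefault(root, len(root_to_label) + 1)
    return labels


# Toy grid: six cells in a chain, 0-1-2-3-4-5
neighbours = [[-1, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, -1]]
is_extreme = [True, True, False, True, False, True]
labels = label_components(is_extreme, neighbours)
print(labels)  # [1, 1, 0, 2, 0, 3]
```

Because connectivity comes from the neighbour table rather than from array adjacency, the same logic applies to any mesh for which a table such as `ds.neighbours` is available.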

N.B.: Exploits parallelised `dask` operations with optimised chunking using `flox` for memory efficiency and speed \
N.N.B.: This example, using 40 years of _daily_ outputs at 5 km resolution on an _Unstructured Grid_ (15 million cells) with 240 cores, takes ~40 minutes.

#### N.B.: The following `dask` config may be necessary on particular systems:
```python
import dask

dask.config.set({
    'distributed.comm.timeouts.connect': '120s',  # Increase from default
    'distributed.comm.timeouts.tcp': '240s',      # Double the connection timeout
    'distributed.comm.retry.count': 10,           # More retries before giving up
})
```

In [None]:
import xarray as xr
import dask
from getpass import getuser
from pathlib import Path

import spot_the_blOb as blob
import spot_the_blOb.helper as hpc

In [None]:
# Start Dask Cluster
#  N.B.: Need ~ 8 GB per worker (for 5km data // 15 million points)
client = hpc.StartLocalCluster(n_workers=80, n_threads=1)

In [None]:
# Load Pre-processed Data (cf. `01_preprocess_extremes.ipynb`)

scratch_dir = Path('/scratch') / getuser()[0] / getuser() / 'mhws'

file_name = scratch_dir / 'extremes_binary_unstruct.zarr'
chunk_size = {'time': 4, 'ncells': -1}  # Adjust chunk size depending on system memory (chunks that are too small make the parallel iterative algorithm very slow)
ds = xr.open_zarr(str(file_name), chunks={}).isel(time=slice(0, 1825)).chunk(chunk_size)
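
As a rough guide when choosing `chunk_size`, the in-memory size of one chunk can be estimated from the chunk shape and the width of the data type. The helper below is a back-of-envelope sketch: the name `chunk_megabytes` and the assumption of 1 byte per element for the binary field are illustrative, not part of MarEx.

```python
def chunk_megabytes(chunk_shape, bytes_per_element=1):
    """Estimate the in-memory size of one dask chunk in megabytes (10**6 bytes)."""
    n_elements = 1
    for length in chunk_shape:
        n_elements *= length
    return n_elements * bytes_per_element / 1e6


# 4 time steps x ~15 million cells, boolean data (1 byte per element; assumed dtype)
bool_chunk_mb = chunk_megabytes((4, 15_000_000))
print(bool_chunk_mb)   # 60.0

# The same chunk held as float64 (8 bytes per element)
float_chunk_mb = chunk_megabytes((4, 15_000_000), bytes_per_element=8)
print(float_chunk_mb)  # 480.0
```

Intermediate arrays created during processing can be several times larger than a single input chunk, so this figure is best read as a lower bound on the memory each worker needs.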

In [None]:
# Run ID, Tracking, & Merging

tracker = blob.Spotter(ds.extreme_events, 
                       ds.mask,                                 
                       area_filter_quartile = 0.8,          # Remove the smallest 80% of the identified coherent extreme areas. N.B.: With increasing resolution, the filter quartile should be increased.
                       R_fill = 32,                         # Fill small holes with radius < 32 elements, i.e. ~100 km
                       T_fill = 2,                          # Allow gaps of 2 days and still continue the event tracking with the same ID
                       allow_merging = True,                # Allow extreme events to split/merge. Keeps track of merge events & unique IDs.
                       overlap_threshold = 0.5,             # Overlap threshold for merging events. If overlap < threshold, events keep independent IDs.
                       nn_partitioning = True,              # Use new NN method to partition merged children areas. If False, reverts to old method of Di Sun et al. 2023.
                       temp_dir = str(scratch_dir/'TEMP/'), # Temporary Scratch Directory for Dask
                       checkpoint = 'load',                 # Load binary pre-processed data
                       verbosity = 1,                       # Choose Verbosity Level (0=None, 1=Basic, 2=Advanced/Timing)
                       # Unstructured Grid Options -- 
                       unstructured_grid = True,            # Use Unstructured Grid
                       xdim = 'ncells',                     # Need to tell MarEx the new Unstructured dimension
                       neighbours = ds.neighbours,          # Connectivity array for the Unstructured Grid Cells
                       cell_areas = ds.cell_areas)          # Cell areas for each Unstructured Grid Cell

extreme_events_ds = tracker.run(return_merges=True)
extreme_events_ds
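
The `overlap_threshold` criterion can be sketched as follows: two objects in consecutive time steps are linked only if the area they share is at least `overlap_threshold` times the area of the smaller object. This is a simplified illustration rather than MarEx's internal code; the function `overlap_fraction` and the toy cell sets are hypothetical.

```python
def overlap_fraction(cells_a, cells_b, cell_areas):
    """Shared area of two objects as a fraction of the smaller object's area.

    cells_a, cells_b : sets of cell indices belonging to each object
    cell_areas       : sequence giving the area of every grid cell
    """
    area_a = sum(cell_areas[c] for c in cells_a)
    area_b = sum(cell_areas[c] for c in cells_b)
    shared = sum(cell_areas[c] for c in cells_a & cells_b)
    return shared / min(area_a, area_b)


cell_areas = [25.0] * 10            # ten cells of 25 km^2 each (toy values)
previous = {0, 1, 2, 3}             # object at time t
current = {2, 3, 4, 5, 6, 7}        # object at time t+1

fraction = overlap_fraction(previous, current, cell_areas)
print(fraction)          # 0.5
print(fraction >= 0.5)   # True: the objects are linked at overlap_threshold = 0.5
```

Normalising by the smaller area means a small object that sits mostly inside a large one still counts as connected, even though the shared area is a small fraction of the large object.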

## Split the Processing & Tracking Steps:
- Processing Requires Many Workers
- Tracking Requires Lots of Memory per Worker

In [None]:
# Processing: Use a Distributed Dask Cluster
#client_cluster = hpc.StartDistributedCluster(n_workers=120, workers_per_node=40, node_memory=256, runtime=19)
data_bin_preprocessed, blob_stats = tracker.run_preprocess(checkpoint='save')
#client_cluster.close()

### N.B.: 480 time steps take ~17 minutes with 40 workers (512 GB)

In [None]:
# Tracking: Use Local Cluster
extreme_events_ds = tracker.run(return_merges=False)  # This first loads the processed data, then tracks the blobs

### N.B.: 480 time steps take ~40 minutes with 40 workers (512 GB)

In [None]:
# Save IDed/Tracked/Merged Events to `zarr` for more efficient parallel I/O
file_name = scratch_dir / 'extreme_events_merged_unstruct.zarr'
extreme_events_ds.to_zarr(file_name, mode='w')