# Global Daily Event Analysis: Marine Heatwave ID & Tracking using `MarEx`

### `MarEx` Processing Pipeline for Unstructured Datasets:

1. **Morphological Pre-Processing**
    - Performs binary morphological closing, using highly threaded binary dilation matrix operations, to fill small spatial holes up to `R_fill` elements in radius
    - Executes binary opening to remove isolated small features of order `R_fill` elements in radius
    - Fills gaps in time to maintain event continuity across interruptions of up to `T_fill` time steps
    - Filters out the smallest objects, i.e. those with area below the `area_filter_quartile` quantile of the object-area distribution

2. **Blob Identification**
    - Labels spatially connected components using a highly efficient Unstructured Union-Find (Disjoint Set Union) Clustering Algorithm
    - Computes blob properties (area, centroid, boundaries)

3. **Temporal Tracking**
    - Identifies blob overlaps between consecutive time frames
    - Connects objects across time, applying the following criteria for splitting, merging, & persistence:
        - Connected objects must overlap by at least fraction `overlap_threshold` of the smaller area
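        - Merging objects retain their original IDs; the merged child area is partitioned among the parents by assigning each child cell to the parent that owns the _nearest-neighbour_ parent cell (or, alternatively, the parent with the nearest centroid)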

4. **Graph Reduction & Finalisation**
    - Constructs the complete temporal graph of object evolution through time
    - Resolves the object connectivity graph using `scipy.sparse.csgraph.connected_components`
    - Creates globally unique IDs for each tracked extreme event
    - Maps objects into an efficient ID-time space for convenient analysis
    - Computes comprehensive statistics about the lifecycle of each event
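Step 1 describes the morphology in terms of binary dilation "matrix operations". One way to realise this on an unstructured grid, assumed here purely for illustration and not taken from the `MarEx` source, is to dilate by sparse matrix-vector products with the cell adjacency matrix, and to obtain erosion, closing and opening by duality. The sketch also bridges short gaps in time, in the spirit of `T_fill`. It assumes a neighbour table of shape `(ncells, k)` with negative entries for missing neighbours, and measures the radius in neighbour hops; all function names are hypothetical.

```python
import numpy as np
import scipy.sparse as sp


def adjacency(neighbours):
    """Symmetric sparse adjacency matrix from an (ncells, k) neighbour table.
    Assumption: missing neighbours are marked with a negative index."""
    ncells, k = neighbours.shape
    rows = np.repeat(np.arange(ncells), k)
    cols = neighbours.ravel()
    ok = cols >= 0
    A = sp.csr_matrix((np.ones(ok.sum()), (rows[ok], cols[ok])), shape=(ncells, ncells))
    return ((A + A.T) > 0).astype(np.int32)


def dilate(x, A, radius):
    """Grow True regions by `radius` neighbour hops using sparse matrix-vector products."""
    x = x.copy()
    for _ in range(radius):
        x |= (A @ x.astype(np.int32)) > 0
    return x


def erode(x, A, radius):
    """Erosion as the dual of dilation: shrink True regions by `radius` hops."""
    return ~dilate(~x, A, radius)


def closing(x, A, radius):
    """Fill holes up to ~radius hops: dilate, then erode."""
    return erode(dilate(x, A, radius), A, radius)


def opening(x, A, radius):
    """Remove features narrower than ~radius hops: erode, then dilate."""
    return dilate(erode(x, A, radius), A, radius)


def fill_time_gaps(x, t_fill):
    """Bridge gaps of up to t_fill steps in a (time, ncells) boolean array."""
    out = x.copy()
    nt = x.shape[0]
    for gap in range(1, t_fill + 1):
        for t in range(nt - gap - 1):
            # True at t and t+gap+1, False for exactly the `gap` steps in between
            bridge = x[t] & x[t + gap + 1] & ~x[t + 1 : t + gap + 1].any(axis=0)
            out[t + 1 : t + gap + 1] |= bridge
    return out


# Usage: a 10-cell chain graph, with -1 for the missing neighbours at the ends
n = 10
idx = np.arange(n)
neighbours = np.stack([idx - 1, np.where(idx + 1 < n, idx + 1, -1)], axis=1)
A = adjacency(neighbours)

closed = closing(np.array([1, 1, 1, 0, 1, 1, 1, 0, 0, 0], bool), A, radius=1)
opened = opening(np.array([0, 0, 1, 0, 0, 1, 1, 1, 1, 0], bool), A, radius=1)
filled = fill_time_gaps(np.array([1, 0, 0, 1, 0, 0, 0, 1], bool)[:, None], t_fill=2)
print(closed.astype(int))        # [1 1 1 1 1 1 1 0 0 0]
print(opened.astype(int))        # [0 0 0 0 0 1 1 1 1 0]
print(filled[:, 0].astype(int))  # [1 1 1 1 0 0 0 1]
```

Closing fills the one-cell hole, opening removes the isolated cell while keeping the four-cell block, and the two-step gap in time is bridged while the three-step gap is not.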

The pipeline leverages **dask** for distributed parallel computation, enabling efficient processing of large datasets. \
A 40-year global daily analysis at 5 km resolution on the _unstructured grid_ (15 million cells) takes ~40 minutes using 240 cores.

#### N.B.: On some systems, the following `dask` configuration may be necessary to avoid communication timeouts:
```python
import dask

dask.config.set({
    'distributed.comm.timeouts.connect': '120s',  # Increase from default
    'distributed.comm.timeouts.tcp': '240s',      # Double the connection timeout
    'distributed.comm.retry.count': 10,           # More retries before giving up
})
```
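If the longer timeouts should only apply to part of a session, `dask.config.set` can also be used as a context manager. The sketch below shows the scoped and the global form and reads a value back with `dask.config.get`. Since these values govern how connections are made, it is probably safest to set them before starting the cluster.

```python
import dask

# Same settings as above, collected in one dict
comm_settings = {
    "distributed.comm.timeouts.connect": "120s",
    "distributed.comm.timeouts.tcp": "240s",
    "distributed.comm.retry.count": 10,
}

# Scoped form: the values apply only inside the `with` block and are restored afterwards
with dask.config.set(comm_settings):
    connect_timeout = dask.config.get("distributed.comm.timeouts.connect")
    print(connect_timeout)  # 120s

# Global form (as in the snippet above): applies for the rest of the session
dask.config.set(comm_settings)
```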

In [None]:
import xarray as xr
import dask
from getpass import getuser
from pathlib import Path

import marEx
import marEx.helper as hpc

In [None]:
# Lustre Scratch Directory
scratch_dir = Path('/scratch') / getuser()[0] / getuser()

In [None]:
# Start a distributed Dask cluster (uncomment on a multi-node HPC system)
# client_cluster = hpc.start_distributed_cluster(n_workers=2048, workers_per_node=128, runtime=59, node_memory=256,
#                                                scratch_dir=scratch_dir / 'clients')  # Temporary scratch directory for dask
# ...or start a local Dask cluster instead
client = hpc.start_local_cluster(n_workers=50, threads_per_worker=1,
                                 scratch_dir=scratch_dir / 'clients')  # Temporary scratch directory for dask

In [None]:
# Choose an optimal chunk size
#   N.B.: This is crucial for dask, not only for performance but also to make the problem tractable.
#         The operations are ultimately global in space, so the spatial dimension must be contiguous (unchunked).
#         The chunk size in time can be adjusted to the available system memory; note, however,
#         that the parallel iterative merging algorithm performs better with larger chunks in time.

chunk_size = {'time': 4, 'ncells': -1}
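To see what this chunk specification does, the sketch below applies it to a small synthetic `DataArray` (a stand-in for `ds.extreme_events`; the sizes are arbitrary) and estimates the raw size of one boolean chunk on the full ~15-million-cell grid. Actual memory use during processing is likely to be considerably higher because of intermediate arrays.

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for the binary extremes field: 16 days x 1000 cells
ntime, ncells = 16, 1000
da = xr.DataArray(np.zeros((ntime, ncells), dtype=bool), dims=("time", "ncells"))

# Same pattern as above: 4 time steps per chunk, all cells in a single chunk (-1)
chunk_size = {"time": 4, "ncells": -1}
da = da.chunk(chunk_size)
print(da.chunks)  # ((4, 4, 4, 4), (1000,))

# Raw size of one chunk on the full 5 km grid (~15 million cells), 1 byte per bool
cells_full = 15_000_000
chunk_bytes = 4 * cells_full * np.dtype(bool).itemsize
print(f"{chunk_bytes / 1e6:.0f} MB per boolean chunk")  # 60 MB per boolean chunk
```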

In [None]:
# Load Pre-processed Data (cf. `01_preprocess_extremes.ipynb`)

file_name = scratch_dir / 'mhws' / 'extremes_binary_unstruct.zarr'
ds = xr.open_zarr(str(file_name), chunks=chunk_size).isel(time=slice(0, 256))  # First 256 days only (drop .isel to process the full record)

In [None]:
# Run ID, Tracking, & Merging

tracker = marEx.tracker(ds.extreme_events, 
                       ds.mask,                                 
                       area_filter_quartile = 0.8,          # Remove the smallest 80% of the identified coherent extreme areas. N.B.: With increasing resolution, the filter quartile should be increased.
                       R_fill = 32,                         # Fill small holes with radius < 32 elements, i.e. ~100 km, 
                       T_fill = 2,                          # Allow gaps of 2 days and still continue the event tracking with the same ID
                       allow_merging = True,                # Allow extreme events to split/merge. Keeps track of merge events & unique IDs.
                       overlap_threshold = 0.5,             # Overlap threshold for merging events. If overlap < threshold, events keep independent IDs.
                       nn_partitioning = True,              # Use new NN method to partition merged children areas. If False, reverts to old method of Di Sun et al. 2023.
                       temp_dir = str(scratch_dir/'mhws'/'TEMP/'), # Temporary Scratch Directory needed for Dask
                       checkpoint = 'save',                 # Make checkpoint of binary pre-processed data
                       verbosity = 1,                       # Choose Verbosity Level (0=None, 1=Basic, 2=Advanced/Timing)
                       # -- Unstructured Grid Options -- 
                       unstructured_grid = True,            # Use Unstructured Grid
                       xdim = 'ncells',                     # Need to tell MarEx the new Unstructured dimension
                       neighbours = ds.neighbours,          # Connectivity array for the Unstructured Grid Cells
                       cell_areas = ds.cell_areas)          # Cell areas for each Unstructured Grid Cell
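The `neighbours` and `cell_areas` arrays passed above define spatial connectivity and object size on the unstructured grid. The sketch below illustrates, under assumptions, how Step 2 (union-find labelling) and an `area_filter_quartile`-style filter could work on such arrays for a single time step. It assumes negative neighbour entries mark missing neighbours and that the area quantile is taken over the objects of one snapshot; the function names are hypothetical. It is a pure-Python illustration, far too slow for 15 million cells, and not the `MarEx` implementation.

```python
import numpy as np


def find(parent, i):
    """Find the root of i, halving the path as we go."""
    while parent[i] != i:
        parent[i] = parent[parent[i]]
        i = parent[i]
    return i


def label_unstructured(x, neighbours):
    """Label connected True cells via union-find over an (ncells, k) neighbour table.
    Assumption: negative entries mark missing neighbours. Background is labelled 0."""
    parent = np.arange(x.size)
    for i in np.flatnonzero(x):
        for j in neighbours[i]:
            if j >= 0 and x[j]:
                ri, rj = find(parent, i), find(parent, j)
                if ri != rj:
                    parent[max(ri, rj)] = min(ri, rj)  # union: attach one root to the other
    roots = np.array([find(parent, i) for i in range(x.size)])
    labels = np.zeros(x.size, dtype=int)
    _, inverse = np.unique(roots[x], return_inverse=True)
    labels[x] = inverse + 1  # consecutive IDs 1..n_objects
    return labels


def filter_by_area(labels, cell_areas, quantile):
    """Drop objects whose total area is below the given quantile of all object areas."""
    ids = np.unique(labels[labels > 0])
    if ids.size == 0:
        return labels
    areas = np.bincount(labels, weights=cell_areas)[ids]
    keep = ids[areas >= np.quantile(areas, quantile)]
    return np.where(np.isin(labels, keep), labels, 0)


# Usage: a 10-cell chain graph with unit cell areas
n = 10
idx = np.arange(n)
neighbours = np.stack([idx - 1, np.where(idx + 1 < n, idx + 1, -1)], axis=1)
x = np.array([1, 1, 0, 1, 0, 1, 1, 1, 0, 1], bool)

labels = label_unstructured(x, neighbours)
kept = filter_by_area(labels, np.ones(n), quantile=0.5)
print(labels)  # [1 1 0 2 0 3 3 3 0 4]
print(kept)    # [1 1 0 0 0 3 3 3 0 0]
```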


## Split the Coherent Area Pre-processing & ID/Tracking/Merging Steps:
- Coherent Area Pre-Processing Requires _Many Workers_
- ID/Tracking Requires _Lots of Memory per Worker_

In [None]:
# Run spatial pre-processing on the cluster started above
data_bin_preprocessed, object_stats = tracker.run_preprocess(checkpoint='save')
# client_cluster.close()  # Release the distributed cluster before starting the high-memory local cluster below

In [None]:
# Start a small local Dask cluster
#  N.B.: Needs ~8 GB of memory per worker (for 5 km data, i.e. ~15 million cells)
client = hpc.start_local_cluster(n_workers=50, threads_per_worker=1,
                                 scratch_dir=scratch_dir / 'clients')  # Temporary scratch directory for dask

In [None]:
extreme_events_ds, merges_ds = tracker.run(return_merges=True, checkpoint='load')  # This first loads the processed data, then tracks the events
extreme_events_ds
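The tracking call above performs Steps 3 and 4 of the pipeline. As a rough illustration of the idea, not the `MarEx` code, the sketch below links labelled objects in consecutive frames when their overlap reaches `overlap_threshold` of the smaller object's area, then resolves the resulting graph with `scipy.sparse.csgraph.connected_components` to obtain one global ID per event. It deliberately ignores split/merge partitioning, so objects that merge would end up sharing a single ID, unlike `MarEx` with `allow_merging=True`. The function names and the per-frame label layout are assumptions.

```python
import numpy as np
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import connected_components


def overlap_edges(lab0, lab1, cell_areas, overlap_threshold):
    """(id0, id1) pairs whose overlap is at least overlap_threshold x the smaller object's area."""
    both = (lab0 > 0) & (lab1 > 0)
    if not both.any():
        return []
    edges = []
    for a, b in np.unique(np.stack([lab0[both], lab1[both]], axis=1), axis=0):
        overlap = cell_areas[(lab0 == a) & (lab1 == b)].sum()
        smaller = min(cell_areas[lab0 == a].sum(), cell_areas[lab1 == b].sum())
        if overlap >= overlap_threshold * smaller:
            edges.append((a, b))
    return edges


def global_ids(frames, cell_areas, overlap_threshold=0.5):
    """Map every (time, local_id) object to a global event ID via connected components."""
    nodes = [(t, i) for t, lab in enumerate(frames) for i in np.unique(lab[lab > 0])]
    index = {node: k for k, node in enumerate(nodes)}
    rows, cols = [], []
    for t in range(len(frames) - 1):
        for a, b in overlap_edges(frames[t], frames[t + 1], cell_areas, overlap_threshold):
            rows.append(index[(t, a)])
            cols.append(index[(t + 1, b)])
    n = len(nodes)
    edges = (np.array(rows, dtype=int), np.array(cols, dtype=int))
    graph = coo_matrix((np.ones(len(rows)), edges), shape=(n, n))
    _, component = connected_components(graph, directed=False)
    return {node: int(component[k]) + 1 for node, k in index.items()}


# Usage: three daily frames on 6 cells with unit areas
frames = [
    np.array([1, 1, 0, 0, 2, 2]),  # two objects
    np.array([1, 1, 1, 0, 0, 0]),  # continues object 1; object 2 has ended
    np.array([0, 0, 0, 0, 1, 1]),  # new object, no overlap with the previous day
]
ids = global_ids(frames, np.ones(6))
for (t, local_id), event in ids.items():
    print(f"day {t}, object {local_id} -> event {event}")
```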

In [None]:
merges_ds

In [None]:
# Save IDed/Tracked/Merged Events to `zarr` for more efficient parallel I/O

file_name = scratch_dir / 'mhws' / 'extreme_events_merged_unstruct.zarr'
extreme_events_ds.to_zarr(file_name, mode='w')

### Use Centroid-based Partitioning Method for Comparison
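The two runs differ only in `nn_partitioning`, i.e. in how the cells of a merged child object are shared out among its parents. The toy sketch below contrasts the two rules as we read the pipeline description: assign each child cell to the parent with the nearest centroid, or to the parent owning the nearest parent cell. It uses planar coordinates (an assumption; on the sphere one would use e.g. 3D Cartesian coordinates) and hypothetical function names. With an elongated parent next to a compact one, the centroid rule can hand a cell to a parent that is not the closest one.

```python
import numpy as np
from scipy.spatial import cKDTree


def partition_centroid(child_xy, parent_xy, parent_ids):
    """Assign each child cell to the parent whose centroid is nearest."""
    ids = np.unique(parent_ids)
    centroids = np.array([parent_xy[parent_ids == p].mean(axis=0) for p in ids])
    dist = np.linalg.norm(child_xy[:, None, :] - centroids[None, :, :], axis=-1)
    return ids[np.argmin(dist, axis=1)]


def partition_nearest_neighbour(child_xy, parent_xy, parent_ids):
    """Assign each child cell to the parent that owns the nearest parent cell."""
    _, nearest = cKDTree(parent_xy).query(child_xy)
    return parent_ids[nearest]


# Parent 1: long thin object along y = 0; parent 2: compact 2x2 block nearby
parent_xy = np.vstack([
    np.column_stack([np.arange(11.0), np.zeros(11)]),
    [[1.0, 2.0], [1.0, 3.0], [2.0, 2.0], [2.0, 3.0]],
])
parent_ids = np.array([1] * 11 + [2] * 4)
child_xy = np.array([[0.0, 1.0], [9.0, 0.5]])

print(partition_nearest_neighbour(child_xy, parent_xy, parent_ids))  # [1 1]
print(partition_centroid(child_xy, parent_xy, parent_ids))           # [2 1]
```

The child cell at `(0, 1)` lies right next to parent 1 but closer to parent 2's centroid, so only the centroid rule assigns it to parent 2.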

In [None]:
# Run ID, Tracking, & Merging

tracker = marEx.tracker(ds.extreme_events, 
                       ds.mask,                                 
                       area_filter_quartile = 0.8,          # Remove the smallest 80% of the identified coherent extreme areas. N.B.: With increasing resolution, the filter quartile should be increased.
                       R_fill = 32,                         # Fill small holes with radius < 32 elements, i.e. ~100 km, 
                       T_fill = 2,                          # Allow gaps of 2 days and still continue the event tracking with the same ID
                       allow_merging = True,                # Allow extreme events to split/merge. Keeps track of merge events & unique IDs.
                       overlap_threshold = 0.5,             # Overlap threshold for merging events. If overlap < threshold, events keep independent IDs.
                       nn_partitioning = False,             # Use old Centroid-based partitioning method (Di Sun et al. 2023).
                       temp_dir = str(scratch_dir/'mhws'/'TEMP/'), # Temporary Scratch Directory for Dask
                       checkpoint = 'save',                 # Make checkpoint of binary pre-processed data
                       verbosity = 1,                       # Choose Verbosity Level (0=None, 1=Basic, 2=Advanced/Timing)
                       # -- Unstructured Grid Options -- 
                       unstructured_grid = True,            # Use Unstructured Grid
                       xdim = 'ncells',                     # Need to tell MarEx the new Unstructured dimension
                       neighbours = ds.neighbours,          # Connectivity array for the Unstructured Grid Cells
                       cell_areas = ds.cell_areas)          # Cell areas for each Unstructured Grid Cell

extreme_events_ds, merges_ds = tracker.run(return_merges=True)
extreme_events_ds

In [None]:
# Save IDed/Tracked/Merged Events to `zarr` for more efficient parallel I/O

file_name = scratch_dir / 'mhws' / 'extreme_events_merged_centroid_unstruct.zarr'
extreme_events_ds.to_zarr(file_name, mode='w')