# Identify & Track Marine Heatwaves on an _Unstructured Grid_ using `spot_the_blOb`

## Processing Steps:
1. Fill spatial holes in the binary data, using `dask_image.ndmorph` -- up to `R_fill` cells in radius.
2. Fill gaps in time -- permitting up to `T_fill` missing time slices, while keeping the same blob ID.
3. Filter out small objects -- those whose area falls in the bottom `area_filter_quartile` fraction (quantile) of the object size distribution.
4. Identify objects in the binary data, using `dask_image.ndmeasure`.
5. Connect objects across time, applying the following criteria for splitting, merging, and persistence:
    - Connected blobs must overlap by at least a fraction `overlap_threshold` of the smaller blob.
    - Merged blobs retain their original IDs, but the child blob is partitioned according to the parent of each cell's _nearest-neighbour_ cell.
6. Cluster and reduce the final object ID graph using `scipy.sparse.csgraph.connected_components`.
7. Map the tracked objects into ID-time space for convenient analysis.

N.B.: Exploits parallelised `dask` operations, with optimised chunking via `flox`, for memory efficiency and speed. \
N.N.B.: This example, using 40 years of _daily_ outputs at 5 km resolution on an Unstructured Grid (15 million cells) with 32 cores, takes:
- Full Split/Merge Thresholding & Merge Tracking: ~40 minutes
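The overlap criterion in step 5 can be illustrated with a minimal NumPy sketch. It assumes each blob is a boolean mask over the grid cells and that overlap is measured in area using per-cell areas (as supplied via `cell_areas` to the tracker below). The helper `passes_overlap` is hypothetical and not part of `spot_the_blOb`'s API; whether the library compares with `>=` or `>` at the threshold is also an assumption.

```python
import numpy as np

def passes_overlap(blob_a, blob_b, cell_areas, overlap_threshold=0.5):
    """True if two boolean blob masks overlap by at least `overlap_threshold`
    of the smaller blob's area (area-weighted, as assumed here)."""
    overlap_area = cell_areas[blob_a & blob_b].sum()
    smaller_area = min(cell_areas[blob_a].sum(), cell_areas[blob_b].sum())
    return bool(smaller_area > 0 and overlap_area / smaller_area >= overlap_threshold)

# Toy example: 6 cells of unit area
areas = np.ones(6)
a = np.array([1, 1, 1, 1, 0, 0], dtype=bool)  # blob at time t (4 cells)
b = np.array([0, 0, 1, 1, 1, 0], dtype=bool)  # blob at time t+1 (3 cells)
print(passes_overlap(a, b, areas))  # overlap 2 / smaller area 3 = 0.67 -> True
```

Normalising by the smaller blob means a small blob that sits almost entirely inside a much larger one still counts as connected.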

In [1]:
import xarray as xr
import dask
from getpass import getuser
from pathlib import Path

import spot_the_blOb as blob
import spot_the_blOb.helper as hpc

In [2]:
# Start Dask Cluster
client = hpc.StartLocalCluster(n_workers=32, n_threads=2)

Memory per Worker: 7.86 GB
Hostname is  l10746
Forward Port = l10746:8787
Dashboard Link: localhost:8787/status


In [3]:
# Load Pre-processed Data (cf. `01_preprocess_extremes.ipynb`)

file_name = Path('/scratch') / getuser()[0] / getuser() / 'mhws' / 'extreme_events_binary_unstruct.zarr'
chunk_size = {'time': 4, 'ncells': -1}
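# N.B.: `.isel(time=slice(0,32))` keeps only the first 32 time steps; remove it to process the full record
ds = xr.open_zarr(str(file_name), chunks={}).isel(time=slice(0,32)).chunk(chunk_size)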

In [4]:
# Tracking Parameters

drop_area_quartile = 0.8  # Remove the smallest 80% of the identified blobs
hole_filling_radius = 32   # Fill small holes with radius < 32 elements
time_gap_fill = 2         # Allow gaps of 2 days and still continue the blob tracking with the same ID
allow_merging = True      # Allow blobs to split/merge. Keeps track of merge events & unique IDs.
overlap_threshold = 0.5   # Overlap threshold for merging blobs. If overlap < threshold, blobs keep independent IDs.
nn_partitioning = True    # Use new NN method to partition merged children blobs. If False, reverts to old method of Di Sun et al. 2023...
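Despite its name, `drop_area_quartile` acts as a quantile of the blob-size distribution rather than a strict quartile. A minimal sketch of how such a threshold could be derived, assuming it is the NumPy quantile of the blob areas and that blobs at or below it are dropped (the helper `area_threshold_from_quantile` is hypothetical):

```python
import numpy as np

def area_threshold_from_quantile(blob_areas, drop_area_quartile=0.8):
    """Area at the `drop_area_quartile` quantile of the blob-size distribution."""
    return np.quantile(blob_areas, drop_area_quartile)

# Ten hypothetical blob areas (arbitrary units)
blob_areas = np.arange(1.0, 11.0)
threshold = area_threshold_from_quantile(blob_areas, 0.8)  # 8.2 with linear interpolation
kept = blob_areas[blob_areas > threshold]                   # the largest 20%: 9.0 and 10.0
```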

In [5]:
# SpOt & Track the Blobs & Merger Events

tracker = blob.Spotter(ds.extreme_events, ds.mask, R_fill=hole_filling_radius, T_fill = time_gap_fill, area_filter_quartile=drop_area_quartile, 
                       allow_merging=allow_merging, overlap_threshold=overlap_threshold, nn_partitioning=nn_partitioning, 
                       xdim='ncells',               # Need to tell spot_the_blOb the new Unstructured dimension
                       unstructured_grid=True,      # Use Unstructured Grid
                       neighbours=ds.neighbours,    # Connectivity array for the Unstructured Grid
                       cell_areas=ds.cell_areas)      # Cell areas for each Unstructured Grid cell
blobs = tracker.run(return_merges=False)

# blobs

Constructing the Sparse Dilation Matrix...


This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.


Finished filling spatio-temporal holes.
Finished filtering small blobs.
Finished blob identification.
Finished calculating blob properties.


Debugging information
---------------------
old task state: processing
old run_spec: <Task ('original-invert-8fd951b7d4123bc599e52d80b0b2ed21', 0) _execute_subgraph(...)>
new run_spec: <Task ('original-invert-8fd951b7d4123bc599e52d80b0b2ed21', 0) _execute_subgraph(...)>
old token: '84b41ef87bd6591e511c33ab99c2f6ae'
new token: '1af5f3b91fe030248cb91418d6e55101'
old dependencies: set()
new dependencies: set()

Finished finding overlapping blobs.
Processing iteration 1 with 6 merging blobs...


Processing iteration 2 with 9 merging blobs...


Processing iteration 3 with 6 merging blobs...


Finished splitting and merging blobs.


Finished clustering and renaming blobs.
Finished tracking blobs.


NotImplementedError: 'item' is not yet a valid method on dask arrays
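The log line `Constructing the Sparse Dilation Matrix...` suggests that, on the unstructured grid, morphological hole filling works through a sparse cell-connectivity matrix rather than a regular-grid structuring element. The sketch below illustrates that idea with SciPy: dilation is one sparse matrix-vector product per neighbour hop, erosion is the complement of the dilated complement, and closing is dilation followed by erosion. The `neighbours` layout (one row per cell, padded with `-1`) and all helper names are assumptions for illustration; this is not `spot_the_blOb`'s implementation.

```python
import numpy as np
from scipy.sparse import csr_matrix, identity

def build_dilation_matrix(neighbours):
    """Sparse (ncells x ncells) matrix with a non-zero entry where cell j is
    cell i itself or one of its neighbours.

    Assumed (hypothetical) layout: `neighbours` has shape (ncells, max_nb),
    with missing neighbours padded as -1.
    """
    ncells, max_nb = neighbours.shape
    rows = np.repeat(np.arange(ncells), max_nb)
    cols = neighbours.ravel()
    valid = cols >= 0  # drop padding entries
    adjacency = csr_matrix(
        (np.ones(valid.sum(), dtype=np.int32), (rows[valid], cols[valid])),
        shape=(ncells, ncells),
    )
    # Include each cell in its own neighbourhood
    return (adjacency + identity(ncells, dtype=np.int32, format="csr")).tocsr()

def dilate(mask, D, n_hops=1):
    """Grow a boolean mask by `n_hops` rings of neighbours."""
    for _ in range(n_hops):
        mask = (D @ mask.astype(np.int32)) > 0
    return mask

def erode(mask, D, n_hops=1):
    """Erosion as the complement of the dilated complement."""
    return ~dilate(~mask, D, n_hops)

def binary_closing_unstructured(mask, D, radius):
    """Closing: dilation followed by erosion, radius counted in neighbour hops."""
    return erode(dilate(mask, D, radius), D, radius)
```

For example, on a 1-D chain of cells, a one-cell gap is closed with `radius=1`, while a three-cell gap is left open.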

In [None]:
import pyicon as pyic
import numpy as np

In [None]:
data_bin_filled = tracker.fill_holes(tracker.data_bin).persist()
data_bin_gap_filled = tracker.fill_time_gaps(data_bin_filled).persist()
data_bin_filtered, area_threshold, blob_areas, N_blobs_unfiltered = tracker.filter_small_blobs(data_bin_gap_filled)
data_bin_filtered = data_bin_filtered.persist()
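`fill_time_gaps` bridges short absences of a blob in time (`T_fill` above). One simple way to express this for a single cell's time series is sketched below, assuming a gap is filled only when it is bounded by presence on both sides and spans at most `T_fill` time steps. The function `fill_time_gaps_1d` is hypothetical and may differ from the library's method.

```python
import numpy as np

def fill_time_gaps_1d(present, T_fill):
    """Fill interior runs of False of length <= T_fill that have True on both sides."""
    present = np.asarray(present, dtype=bool)
    filled = present.copy()
    true_idx = np.flatnonzero(present)
    # Inspect the gap between each pair of consecutive "present" time steps
    for start, end in zip(true_idx[:-1], true_idx[1:]):
        gap = end - start - 1
        if 0 < gap <= T_fill:
            filled[start + 1:end] = True
    return filled

# A 2-step gap is bridged with T_fill=2; the 3-step gap is not
fill_time_gaps_1d([1, 0, 0, 1, 0, 0, 0, 1], T_fill=2)
```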

In [None]:
#### Track_blObs....

blob_id_field, _ = tracker.identify_blobs(data_bin_filtered, time_connectivity=False)
blob_id_field = blob_id_field.persist()

cumsum_ids = (blob_id_field.max(dim=tracker.xdim)).cumsum(tracker.timedim).shift({tracker.timedim: 1}, fill_value=0)
blob_id_field = xr.where(blob_id_field > 0, blob_id_field + cumsum_ids, 0)

blob_props = tracker.calculate_blob_properties(blob_id_field, properties=['area', 'centroid'])
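`identify_blobs(..., time_connectivity=False)` labels each time step independently, so labels restart at 1 at every step. The `cumsum_ids` offset above makes them unique across time by shifting each step's labels by the sum of the maximum labels of all earlier steps. The same arithmetic on a small NumPy array (rows are time steps, columns are cells):

```python
import numpy as np

# Per-timestep labels; 0 = background
labels = np.array([[1, 1, 0, 2],
                   [0, 1, 1, 0],
                   [3, 0, 1, 2]])

max_per_t = labels.max(axis=1)                               # [2, 1, 3]
offsets = np.concatenate([[0], np.cumsum(max_per_t)[:-1]])   # [0, 2, 3], i.e. cumsum shifted by one step
unique_labels = np.where(labels > 0, labels + offsets[:, None], 0)
```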

In [9]:
import xarray as xr
import numpy as np
from dask_image.ndmeasure import label
from skimage.measure import regionprops_table
from dask_image.ndmorph import binary_closing as binary_closing_dask
from dask_image.ndmorph import binary_opening as binary_opening_dask
from scipy.ndimage import binary_closing, binary_opening
from scipy.sparse import coo_matrix, csr_matrix, eye
from scipy.sparse.csgraph import connected_components
from dask import persist
from dask import delayed
from dask import compute as dask_compute
import dask.array as dsa
from dask.base import is_dask_collection
from numba import jit, njit, int64, int32, prange
import jax.numpy as jnp
import warnings

In [None]:
split_merged_blob_id_field_unique, merged_blobs_props, split_merged_blobs_list, merge_events = tracker.split_and_merge_blobs_parallel(blob_id_field, blob_props)

In [35]:
blob_id_field_unique = split_merged_blob_id_field_unique
blobs_props = merged_blobs_props
overlap_blobs_list = split_merged_blobs_list

In [36]:
# Get unique IDs from the overlap pairs
IDs = np.unique(overlap_blobs_list) # 1D sorted unique
        
# Create a mapping from ID to indices
ID_to_index = {ID: index for index, ID in enumerate(IDs)}

# Convert overlap pairs to indices
overlap_pairs_indices = np.array([(ID_to_index[pair[0]], ID_to_index[pair[1]]) for pair in overlap_blobs_list])

# Create a sparse matrix representation of the graph
n = len(IDs)
row_indices, col_indices = overlap_pairs_indices.T
data = np.ones(len(overlap_pairs_indices), dtype=np.bool_)
graph = csr_matrix((data, (row_indices, col_indices)), shape=(n, n), dtype=np.bool_)

# Clear temporary arrays
del row_indices
del col_indices
del data

# Solve the graph to determine connected components
num_components, component_IDs = connected_components(csgraph=graph, directed=False, return_labels=True)

del graph

# Group IDs by their component index
ID_clusters = [[] for _ in range(num_components)]
for ID, component_ID in zip(IDs, component_IDs):
    ID_clusters[component_ID].append(ID)
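The cell above turns overlap pairs into a graph and clusters it with `scipy.sparse.csgraph.connected_components`. A compact, self-contained variant on toy pairs is sketched below; it swaps the Python dictionary for `np.unique(..., return_inverse=True)` to map IDs to graph node indices, which is an alternative choice rather than the notebook's code.

```python
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

# Hypothetical overlap pairs (ID at t, ID at t+1)
overlap_pairs = np.array([[1, 4], [4, 7], [2, 5], [3, 5], [9, 10]])

# Sorted unique IDs and, for each pair entry, its node index
IDs, inverse = np.unique(overlap_pairs, return_inverse=True)
pair_idx = inverse.reshape(overlap_pairs.shape)

n = len(IDs)
graph = csr_matrix(
    (np.ones(len(pair_idx), dtype=bool), (pair_idx[:, 0], pair_idx[:, 1])),
    shape=(n, n),
)
num_components, component_IDs = connected_components(graph, directed=False, return_labels=True)

# Group the original IDs by component
ID_clusters = [IDs[component_IDs == c].tolist() for c in range(num_components)]
```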


In [37]:

## ID_clusters is now a list of lists of equivalent blob IDs that have been tracked across time.
#  Next, replace every ID in blob_id_field_unique that belongs to a cluster with that cluster's index: this is the new/final ID field.

# Create a lookup array mapping old IDs to the new cluster indices (IDs in no cluster keep the min_int32 sentinel)
min_int32 = np.iinfo(np.int32).min
max_old_ID = int(blob_id_field_unique.max().compute())
ID_to_cluster_index_array = np.full(max_old_ID + 1, min_int32, dtype=np.int32)

# Fill the lookup array with cluster indices
for index, cluster in enumerate(ID_clusters):
    for ID in cluster:
        ID_to_cluster_index_array[ID] = np.int32(index+1) # Each cluster groups many connected IDs, so there are far fewer cluster indices
                                                            #  Add 1 so that ID = 0 remains invalid/no object

# N.B.: **Need to pass da into apply_ufunc, otherwise it doesn't manage the memory correctly with large shared-mem numpy arrays**
ID_to_cluster_index_da = xr.DataArray(ID_to_cluster_index_array, dims='ID', coords={'ID': np.arange(max_old_ID + 1)})

def map_IDs_to_indices(block, ID_to_cluster_index_array):
    mask = block > 0
    new_block = np.zeros_like(block, dtype=np.int32)
    new_block[mask] = ID_to_cluster_index_array[block[mask]]
    return new_block

input_dims = [tracker.xdim] if tracker.unstructured_grid else [tracker.ydim, tracker.xdim]
split_merged_relabeled_blob_id_field = xr.apply_ufunc(
    map_IDs_to_indices,
    blob_id_field_unique, 
    ID_to_cluster_index_da,
    input_core_dims=[input_dims,['ID']],
    output_core_dims=[input_dims],
    vectorize=True,
    dask="parallelized",
    output_dtypes=[np.int32]
).persist()
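`map_IDs_to_indices` relabels each time slice with a dense lookup array indexed by the old ID, which avoids a per-cell dictionary lookup. A small NumPy illustration with hypothetical clusters, reusing the notebook's `min_int32` sentinel for IDs that belong to no cluster:

```python
import numpy as np

ID_clusters = [[1, 4, 7], [2, 3, 5]]  # hypothetical clusters of old IDs
max_old_ID = 7

# Sentinel marks old IDs not assigned to any cluster
lookup = np.full(max_old_ID + 1, np.iinfo(np.int32).min, dtype=np.int32)
for new_id, cluster in enumerate(ID_clusters, start=1):
    lookup[cluster] = new_id

old_field = np.array([[0, 1, 2],
                      [4, 0, 5],
                      [7, 3, 0]])
# Only relabel object cells; background (0) stays 0
new_field = np.where(old_field > 0, lookup[old_field], 0)
```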


In [38]:
### Relabel the blobs_props to match the new IDs (and add time dimension!)

max_new_ID = num_components + 1  # Cluster indices run from 1 to num_components; this allocates one spare ID, dropped later via isel(ID=slice(0, -1))
new_ids = np.arange(1, max_new_ID+1, dtype=np.int32)

# New blobs_props DataSet Structure
blobs_props_extended = xr.Dataset(coords={
    'ID': new_ids,
    tracker.timedim: blob_id_field_unique[tracker.timedim]
})

## Create a mapping from new IDs to the original IDs _at the corresponding time_
valid_new_ids = (split_merged_relabeled_blob_id_field > 0)      
original_ids_field = blob_id_field_unique.where(valid_new_ids)
new_ids_field = split_merged_relabeled_blob_id_field.where(valid_new_ids)

if not tracker.unstructured_grid:
    original_ids_field = original_ids_field.stack(z=(tracker.ydim, tracker.xdim), create_index=False)
    new_ids_field = new_ids_field.stack(z=(tracker.ydim, tracker.xdim), create_index=False)

new_id_to_idx = {id_val: idx for idx, id_val in enumerate(new_ids)}

def process_timestep(orig_ids, new_ids_t):
    """Process a single timestep to create ID mapping."""
    result = np.zeros(len(new_id_to_idx), dtype=np.int32)
    
    valid_mask = new_ids_t > 0
    
    # Get valid points for this timestep
    if not valid_mask.any():
        return result
        
    orig_valid = orig_ids[valid_mask]
    new_valid = new_ids_t[valid_mask]
    
    if len(orig_valid) == 0:
        return result
        
    unique_pairs = np.unique(np.column_stack((orig_valid, new_valid)), axis=0)
    
    # Create mapping
    for orig_id, new_id in unique_pairs:
        if new_id in new_id_to_idx:
            result[new_id_to_idx[new_id]] = orig_id
            
    return result

input_dim = [tracker.xdim] if tracker.unstructured_grid else ['z']
global_id_mapping = xr.apply_ufunc(
            process_timestep,
            original_ids_field,
            new_ids_field,
            input_core_dims=[input_dim, input_dim],
            output_core_dims=[['ID']],
            vectorize=True,
            dask='parallelized',
            output_dtypes=[np.int32],
            dask_gufunc_kwargs={'output_sizes': {'ID': len(new_ids)}}
    ).assign_coords(ID=new_ids).compute()


blobs_props_extended['global_ID'] = global_id_mapping
# N.B.: Now, e.g. global_id_mapping.sel(ID=10) --> Given the new ID (10), returns corresponding original_id at every time
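For a single time step, `process_timestep` pairs each new ID with the original ID it replaces. A loop-free NumPy sketch of the same step is below. Like the notebook, it assumes the new IDs are the contiguous integers `1..n_new`. If two original IDs share a new ID at the same time step, only one of them is retained, as with the loop above. The name `new_to_orig_at_t` is hypothetical.

```python
import numpy as np

def new_to_orig_at_t(orig_ids, new_ids_t, n_new):
    """Entry k holds the original ID mapped to new ID k+1 at this time step
    (0 if new ID k+1 is absent)."""
    result = np.zeros(n_new, dtype=np.int32)
    valid = new_ids_t > 0
    if not valid.any():
        return result
    # One row per distinct (original, new) pair
    pairs = np.unique(np.column_stack((orig_ids[valid], new_ids_t[valid])), axis=0)
    result[pairs[:, 1] - 1] = pairs[:, 0]
    return result

orig = np.array([5, 5, 8, 0, 9])
new = np.array([1, 1, 3, 0, 2])
new_to_orig_at_t(orig, new, n_new=3)
```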


In [None]:
global_id_mapping

In [None]:
blobs_props.sel(ID=0).compute()

In [None]:
## Transfer and transform all variables from original blobs_props:


for var_name in blobs_props.data_vars:
    
    temp = (blobs_props[var_name]
                        .sel(ID=global_id_mapping.rename({'ID':'new_id'}))
                        .drop_vars('ID').rename({'new_id':'ID'}))
    
    if var_name == 'ID':
        temp = temp.astype(np.int32)
    else:
        temp = temp.astype(np.float32)
        
    blobs_props_extended[var_name] = temp



In [None]:

## Map the merge_events, indexed by old IDs with dimensions (merge_ID, parent_idx),
#     --> a new merge_ledger with dimensions (time, ID, sibling_ID)
# i.e. for each merge_ID: merge_events.parent_IDs gives the old parent IDs --> map to new IDs using ID_to_cluster_index_da
#                         merge_events.merge_time gives the time of the merger

old_parent_IDs = xr.where(merge_events.parent_IDs>0, merge_events.parent_IDs, 0)
new_IDs_parents = ID_to_cluster_index_da.sel(ID=old_parent_IDs)

# Replace the coordinate merge_ID in new_IDs_parents with merge_time.  merge_events.merge_time gives merge_time for each merge_ID
new_IDs_parents_t = new_IDs_parents.assign_coords({'merge_time': merge_events.merge_time}).drop_vars('ID').swap_dims({'merge_ID': 'merge_time'}).persist()  # this now has coordinate merge_time and ID

# Map new_IDs_parents_t into a new data array with dimensions time, ID, and sibling_ID
merge_ledger = xr.full_like(global_id_mapping, fill_value=-1).chunk({tracker.timedim: split_merged_relabeled_blob_id_field.data.chunksize[0]}).expand_dims({'sibling_ID': new_IDs_parents_t.parent_idx.shape[0]}).copy() # dimensions are time, ID, sibling_ID

# Wrapper for processing/mapping mergers in parallel
def process_time_group(time_block, IDs_data, IDs_coords):
    """Process all mergers for a single block of timesteps."""
    result = xr.full_like(time_block, -1)
    
    # Get unique times in this block
    unique_times = np.unique(time_block[tracker.timedim])
    
    for time_val in unique_times:
        # Get IDs for this time
        time_mask = IDs_coords['merge_time'] == time_val
        if not np.any(time_mask):
            continue
            
        IDs_at_time = IDs_data[time_mask]
        
        # Single merger case
        if IDs_at_time.ndim == 1:
            valid_mask = IDs_at_time > 0
            if np.any(valid_mask):
                # Create expanded array for each sibling_ID dimension
                expanded_IDs = np.broadcast_to(IDs_at_time, (len(time_block.sibling_ID), len(IDs_at_time)))
                result.loc[{tracker.timedim: time_val, 'ID': IDs_at_time[valid_mask]}] = expanded_IDs[:, valid_mask]
        # Multiple mergers case
        else:
            for merger_IDs in IDs_at_time:
                valid_mask = merger_IDs > 0
                if np.any(valid_mask):
                    expanded_IDs = np.broadcast_to(merger_IDs, (len(time_block.sibling_ID), len(merger_IDs)))
                    result.loc[{tracker.timedim: time_val, 'ID': merger_IDs[valid_mask]}] = expanded_IDs[:, valid_mask]
                    
    return result

merge_ledger = xr.map_blocks(
    process_time_group,
    merge_ledger,
    args=(new_IDs_parents_t.values, new_IDs_parents_t.coords),
    template=merge_ledger
)

# Final formatting
merge_ledger = merge_ledger.rename('merge_ledger').transpose(tracker.timedim, 'ID', 'sibling_ID').persist()


In [None]:


## Finish up:
# Add start and end time indices for each ID
valid_presence = blobs_props_extended['global_ID'] > 0  # Where we have valid data

blobs_props_extended['presence'] = valid_presence
blobs_props_extended['time_start'] = valid_presence[tracker.timedim][valid_presence.argmax(dim=tracker.timedim)]
blobs_props_extended['time_end'] = valid_presence[tracker.timedim][(valid_presence.sizes[tracker.timedim] - 1) - valid_presence.isel({tracker.timedim: slice(None, None, -1)}).argmax(dim=tracker.timedim)]
        
# Combine blobs_props_extended with split_merged_relabeled_blob_id_field
split_merged_relabeled_blobs_ds = xr.merge([split_merged_relabeled_blob_id_field.rename('ID_field'), 
                                            blobs_props_extended,
                                            merge_ledger])
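`time_start` and `time_end` come from `argmax` over the boolean presence mask: the first `True` along time gives the start, and the first `True` along the reversed time axis, counted back from the end, gives the end. A NumPy illustration on a toy mask (rows are time steps, columns are IDs):

```python
import numpy as np

# presence[t, k]: True if ID k is present at time step t
presence = np.array([[0, 1, 0],
                     [1, 1, 0],
                     [1, 0, 0],
                     [0, 0, 1]], dtype=bool)

n_t = presence.shape[0]
first = presence.argmax(axis=0)                   # index of first True along time
last = (n_t - 1) - presence[::-1].argmax(axis=0)  # index of last True along time
```

Because `argmax` returns 0 for an all-`False` column, an ID that is never present would get start index 0 and end index `n_t - 1`; masking with `presence.any(axis=0)` avoids this.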


In [45]:
out = split_merged_relabeled_blobs_ds.isel(ID=slice(0, -1))  # Drop the spare last ID (max_new_ID = num_components + 1)

In [None]:
out.ID_field.max().compute().data

In [None]:
# Save Tracked Blobs to `zarr` for more efficient parallel I/O

file_name = Path('/scratch') / getuser()[0] / getuser() / 'mhws' / 'MHWs_tracked_unstruct.zarr'
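blobs.to_zarr(file_name, mode='w')  # N.B.: `blobs` exists only if tracker.run() above completed; otherwise save the manually assembled `out` instead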