# Identify & Track Marine Heatwaves using `spot_the_blOb`

## Processing Steps:
1. Fill holes in the binary data, using `dask_image.ndmorph` -- up to `R_fill` cells in radius.
2. Filter out small objects -- those whose area falls within the bottom `area_filter_quartile` fraction of the object-size distribution.
3. Identify objects in the binary data, using `dask_image.ndmeasure`.
4. Connect objects across time, applying the following criteria for splitting, merging, and persistence:
    - Connected Blobs must overlap by at least fraction `overlap_threshold` of the smaller blob.
    - Merged Blobs retain their original IDs; the child blob is partitioned between them, with each cell assigned according to the parent of its _nearest-neighbour_ cell.
5. Cluster and reduce the final object ID graph using `scipy.sparse.csgraph.connected_components`.
6. Map the tracked objects into ID-time space for convenient analysis.

N.B.: Exploits parallelised `dask` operations with optimised chunking using `flox` for memory efficiency and speed \
N.N.B.: This example, using 40 years of _daily_ outputs at 0.25° resolution on 32 cores, takes:
- Standard (i.e. Scannell et al., which involves no merge/split criteria or tracking):  ~2 minutes
- Full Split/Merge Thresholding & Merge Tracking:  ~1 hour

In [1]:
import xarray as xr
import dask
from getpass import getuser
from pathlib import Path

import spot_the_blOb as blob
import spot_the_blOb.helper as hpc

In [2]:
# Start Dask Cluster
# Uses the `spot_the_blOb.helper` module (imported as `hpc`) to start a local cluster,
# requesting 32 workers with 2 threads each.
client = hpc.StartLocalCluster(n_workers=32, n_threads=2)

Memory per Worker: 15.74 GB
Hostname is  l40095
Forward Port = l40095:8787
Dashboard Link: localhost:8787/status


In [3]:
# Load Pre-processed Data (cf. `01_preprocess_extremes.ipynb`)

# Input store lives under /scratch/<first letter of user name>/<user name>/mhws/
file_name = Path('/scratch').joinpath(getuser()[0], getuser(), 'mhws', 'extreme_events_binary.zarr')

# Chunk along time in blocks of 25 steps; -1 keeps each spatial dimension as one chunk
chunk_size = dict(time=25, lat=-1, lon=-1)
ds = xr.open_zarr(str(file_name), chunks=chunk_size)

In [4]:
# Extract Binary Features and Modify Mask

# Binary extreme-event field from the dataset. The commented-out `isel` would
# restrict it to the first 2000 time steps (handy for quick tests).
extreme_bin = ds.extreme_events#.isel(time=slice(0, 2000))
# Keep the dataset mask only where -90 < lat < 85; all other latitudes are set to False.
mask = ds.mask.where((ds.lat<85) & (ds.lat>-90), other=False)

In [5]:
# Tracking Parameters
# These are passed to `blob.Spotter` as `area_filter_quartile`, `R_fill`, `allow_merging`,
# `overlap_threshold` and `nn_partitioning` respectively.

drop_area_quartile = 0.5  # Drop objects whose area lies in the bottom fraction of the object-size distribution.
filling_radius = 8        # Fill holes in the binary data up to this many cells in radius.
allow_merging = True     # Allow blobs to split/merge. Keeps track of merge events & unique IDs.
overlap_threshold = 0.5  # Overlap threshold for merging blobs. If overlap < threshold, blobs keep independent IDs.
nn_partitioning = True   # Use new NN method to partition merged children blobs. If False, reverts to old method of Di Sun et al. 2023...

In [None]:
# Spot the Blobs

# Configure the tracker with the binary field, the mask and the parameters chosen above
tracker = blob.Spotter(
    extreme_bin,
    mask,
    R_fill=filling_radius,
    area_filter_quartile=drop_area_quartile,
    allow_merging=allow_merging,
    overlap_threshold=overlap_threshold,
    nn_partitioning=nn_partitioning,
)

# Run the full pipeline; merge events are not requested here
blobs = tracker.run(return_merges=False)

blobs

Finished filling holes.
Finished filtering small blobs.
Finished blob identification.
Finished calculating blob properties.
Finished finding overlapping blobs.


  result = blockwise(


Processing splitting and merging in chunk 0 of 556
Processing splitting and merging in chunk 25 of 556
Missing newly created child_ids {151981} because parents have split/morphed in the meantime...
Processing splitting and merging in chunk 50 of 556
Missing newly created child_ids {153469} because parents have split/morphed in the meantime...
Processing splitting and merging in chunk 75 of 556
Missing newly created child_ids {154794} because parents have split/morphed in the meantime...


In [None]:
# Save Tracked Blobs

# NOTE: this rebinds `file_name` (previously the input Zarr path) to the NetCDF output path.
file_name = Path('/scratch') / getuser()[0] / getuser() / 'mhws' / 'MHWs_tracked.nc'
# mode='w' overwrites any existing file at that path.
blobs.to_netcdf(file_name, mode='w')

In [None]:
# --- The cells from here on call the individual `tracker` steps by hand, one at a time ---

# Fill Holes
data_bin_filled = tracker.fill_holes()
print('Finished filling holes.')

# Remove Small Objects
# Returns the filtered field plus the area threshold, blob areas and unfiltered blob count.
data_bin_filtered, area_threshold, blob_areas, N_blobs_unfiltered = tracker.filter_small_blobs(data_bin_filled)
print('Finished filtering small blobs.')

In [None]:
# Continue with the hole-filled, size-filtered binary field
data_bin = data_bin_filtered

In [None]:
# Label blobs in the binary field; the second return value is discarded.
# NOTE(review): `time_connectivity=False` presumably labels each time step independently -- confirm against `Spotter.identify_blobs`.
blob_id_field, _ = tracker.identify_blobs(data_bin, time_connectivity=False)
print('Finished blob identification.')

# Calculate Properties of each Blob
blob_props = tracker.calculate_blob_properties(blob_id_field, properties=['area', 'centroid'])
print('Finished calculating blob properties.')

# Compile List of Overlapping Blob ID Pairs Across Time
overlap_blobs_list = tracker.find_overlapping_blobs(blob_id_field)  # List of overlapping blob pairs
print('Finished finding overlapping blobs.')


# Apply Splitting & Merging Logic to `overlap_blobs`
#   N.B. This is the longest step due to loop-wise dependencies... but many sub-steps are highly threaded so we're okay-ish in the end
#   Returns the updated ID field, updated blob properties, updated overlap-pair list and the merge events.
split_merged_blob_id_field_unique, merged_blobs_props, split_merged_blobs_list, merge_events = tracker.split_and_merge_blobs(blob_id_field, blob_props, overlap_blobs_list)
print('Finished splitting and merging blobs.')



In [None]:
# Rebind the split/merge outputs to the shorter names used by the following cells.
# NOTE: `overlap_blobs_list` is overwritten here with the post-split/merge pair list.
blob_id_field_unique = split_merged_blob_id_field_unique
blobs_props = merged_blobs_props
overlap_blobs_list = split_merged_blobs_list

In [None]:
import xarray as xr
import numpy as np
from dask_image.ndmeasure import label
from skimage.measure import regionprops_table
from dask_image.ndmorph import binary_closing as binary_closing_dask
from dask_image.ndmorph import binary_opening as binary_opening_dask
from scipy.ndimage import binary_closing, binary_opening
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
from dask import persist
from dask.base import is_dask_collection
from numba import jit, prange
import warnings

In [None]:
# Cluster blob IDs that are linked (directly or transitively) by an overlap pair.

IDs = np.unique(overlap_blobs_list)  # 1D sorted unique

# Create a mapping from ID to indices (kept for inspection; no longer needed for the conversion below)
ID_to_index = {ID: index for index, ID in enumerate(IDs)}

# Convert overlap pairs to indices.
# `IDs` is sorted and contains every ID that occurs in the pairs, so a binary search
# returns exactly the position of each ID -- same result as a per-pair dict lookup,
# but done in a single vectorised call instead of a Python loop over all pairs.
overlap_pairs_indices = np.searchsorted(IDs, np.asarray(overlap_blobs_list))

# Create a sparse matrix representation of the graph
n = len(IDs)
row_indices, col_indices = overlap_pairs_indices.T
data = np.ones(len(overlap_pairs_indices))
graph = csr_matrix((data, (row_indices, col_indices)), shape=(n, n))

# Solve the graph to determine connected components (edges are treated as undirected)
num_components, component_IDs = connected_components(csgraph=graph, directed=False, return_labels=True)

# Group IDs by their component index
ID_clusters = [[] for _ in range(num_components)]
for ID, component_ID in zip(IDs, component_IDs):
    ID_clusters[component_ID].append(ID)

In [None]:
# Build a dense lookup table: position = old blob ID, value = new (cluster) ID.
# Every slot starts at the most negative int32, so old IDs that belong to no cluster
# map to a negative value rather than to a valid (positive) new ID.
min_int32 = np.iinfo(np.int32).min
max_old_ID = blob_id_field_unique.max().compute().data  # triggers a computation over the ID field
ID_to_cluster_index_array = np.full(max_old_ID + 1, min_int32, dtype=np.int32)

# Fill the lookup array with cluster indices
for index, cluster in enumerate(ID_clusters):
    for ID in cluster:
        ID_to_cluster_index_array[ID] = np.int32(index+1) # Because these are the connected IDs, there are many fewer!
                                                            #  Add 1 so that ID = 0 is still invalid/no object

# N.B.: **Need to pass da into apply_ufunc, otherwise it doesn't manage the memory correctly with large shared-mem numpy arrays**
# The 'ID' coordinate holds the old IDs 0..max_old_ID, so the table can also be queried with `.sel(ID=...)`.
ID_to_cluster_index_da = xr.DataArray(ID_to_cluster_index_array, dims='ID', coords={'ID': np.arange(max_old_ID + 1)})

def map_IDs_to_indices(block, ID_to_cluster_index_array):
    """Relabel one block of old blob IDs through a lookup table.

    Parameters
    ----------
    block : integer array of old IDs; entries <= 0 are treated as "no object".
    ID_to_cluster_index_array : 1D array indexed by old ID, giving the new ID.

    Returns
    -------
    int32 array shaped like `block`, holding the looked-up value wherever
    `block > 0` and 0 everywhere else.
    """
    relabelled = np.zeros_like(block, dtype=np.int32)

    # Only positive entries are looked up; the rest keep the 0 fill value
    has_object = block > 0
    old_ids = block[has_object]
    relabelled[has_object] = ID_to_cluster_index_array[old_ids]

    return relabelled

# Relabel the whole ID field (old IDs -> cluster IDs) with `map_IDs_to_indices`.
# The spatial dims are core dims of the field; the lookup table is passed as a second
# argument with core dim 'ID'. `dask="parallelized"` applies the function per dask chunk.
# NOTE(review): `map_IDs_to_indices` uses only array-wide indexing, so `vectorize=True`
# (a per-slice loop over the non-core dims) may be unnecessary -- confirm before removing.
split_merged_relabeled_blob_id_field = xr.apply_ufunc(
    map_IDs_to_indices,
    blob_id_field_unique, 
    ID_to_cluster_index_da,
    input_core_dims=[[tracker.ydim, tracker.xdim],['ID']],
    output_core_dims=[[tracker.ydim, tracker.xdim]],
    vectorize=True,
    dask="parallelized",
    output_dtypes=[np.int32]
)

In [None]:
# Start computing the relabelled field and keep the result in cluster memory,
# since several of the following cells read it.
split_merged_relabeled_blob_id_field = split_merged_relabeled_blob_id_field.persist()

In [None]:
# Cluster labels were assigned as `index + 1` over the `num_components` clusters,
# so the valid new IDs are exactly 1..num_components (0 still means "no object").
# (Previously this used `num_components + 1`, which added one ID that can never occur.)
max_new_ID = num_components
new_ids = np.arange(1, max_new_ID + 1, dtype=np.int32)

# New blobs_props DataSet Structure: one entry per (new ID, time step)
blobs_props_extended = xr.Dataset(coords={
    'ID': new_ids,
    tracker.timedim: blob_id_field_unique[tracker.timedim]
})

In [None]:
# Boolean field: True where a cell carries a valid (positive) new ID.
# Cells with 0 (no object) or a negative value from the lookup table are False.
valid_new_ids = (split_merged_relabeled_blob_id_field > 0)  

In [None]:
# Mask both ID fields to the valid cells and flatten the two spatial dims into a single
# dim 'z' (no MultiIndex is built), so each time step becomes a 1D vector of cells.
# NOTE(review): `.where` fills masked cells with NaN, which presumably promotes the
# integer IDs to float -- confirm if the dtype matters downstream.
original_ids_field = blob_id_field_unique.where(valid_new_ids).stack(z=(tracker.ydim, tracker.xdim), create_index=False)
new_ids_field = split_merged_relabeled_blob_id_field.where(valid_new_ids).stack(z=(tracker.ydim, tracker.xdim), create_index=False)

In [None]:
# Create lookup dictionary for new IDs: new ID -> its position in `new_ids`.
# `process_timestep` reads this module-level dict to place results along the 'ID' axis.
new_id_to_idx = {id_val: idx for idx, id_val in enumerate(new_ids)}

def process_timestep(orig_ids, new_ids_t):
    """Build the new-ID -> original-ID mapping for a single time step.

    Parameters
    ----------
    orig_ids : 1D array of original blob IDs, one entry per cell.
    new_ids_t : 1D array of new (cluster) IDs for the same cells; entries that
        are not > 0 (including NaN) are ignored.

    Returns
    -------
    int32 array with one slot per entry of the module-level `new_id_to_idx`.
    Slot `new_id_to_idx[new_id]` holds an original ID found together with
    `new_id` at this time step, or 0 if `new_id` does not occur. When several
    original IDs share a new ID, the pairs are visited in sorted order, so the
    largest original ID is the one that remains.
    """
    id_map = np.zeros(len(new_id_to_idx), dtype=np.int32)

    # Cells that carry a valid new ID
    has_object = new_ids_t > 0
    if has_object.any():
        # Distinct (original ID, new ID) combinations present at this time step
        id_pairs = np.column_stack((orig_ids[has_object], new_ids_t[has_object]))
        for old_id, cluster_id in np.unique(id_pairs, axis=0):
            slot = new_id_to_idx.get(cluster_id)
            if slot is not None:  # new IDs unknown to the lookup are skipped
                id_map[slot] = old_id

    return id_map


# Use apply_ufunc to parallelize the computation
# Applies `process_timestep` to each time step's 'z' vector, producing one value per new ID.
# NOTE: this builds the same call as the `global_id_mapping` cell that follows, but without
# `.assign_coords(...)`/`.compute()`; `result` is assigned again in a later cell.
result = xr.apply_ufunc(
    process_timestep,
    original_ids_field,
    new_ids_field,
    input_core_dims=[['z'], ['z']],
    output_core_dims=[['ID']],
    vectorize=True,
    dask='parallelized',
    output_dtypes=[np.int32],
    dask_gufunc_kwargs={'output_sizes': {'ID': len(new_ids)}}
)

In [None]:
# Map every (time step, new ID) to an original blob ID (0 where the new ID is absent at
# that time step): run `process_timestep` over each time step's 'z' vector, label the new
# 'ID' axis with `new_ids`, and load the result into memory.
global_id_mapping = xr.apply_ufunc(
                    process_timestep,
                    original_ids_field,
                    new_ids_field,
                    input_core_dims=[['z'], ['z']],
                    output_core_dims=[['ID']],
                    vectorize=True,
                    dask='parallelized',
                    output_dtypes=[np.int32],
                    dask_gufunc_kwargs={'output_sizes': {'ID': len(new_ids)}}
            ).assign_coords(ID=new_ids).compute()
        

# Store the mapping alongside the (still empty) extended properties dataset
blobs_props_extended['global_ID'] = global_id_mapping

In [None]:


# Alternative way of building `global_id_mapping`, which overwrites the one computed
# with `process_timestep`.
# NOTE(review): `first_match_idx` is not defined anywhere in the visible notebook, so this
# cell raises NameError unless it is defined elsewhere -- define it or remove this cell.
result = xr.where(new_ids_field.isel(z=first_match_idx) == new_ids, 
                    original_ids_field.isel(z=first_match_idx), 0)

# Label the result with the new IDs under the name 'ID', cast to int32 and load into memory
global_id_mapping = (result
    .assign_coords(new_id=new_ids)
    .rename({'new_id': 'ID'})
    .astype(np.int32)
    .compute())

blobs_props_extended['global_ID'] = global_id_mapping

In [None]:
# Prepend an all-NaN entry with ID = 0, so that selecting with `global_ID == 0`
# (new ID absent at that time step) returns NaN properties.
# NOTE: `blobs_props` is rebound here, so re-running this cell prepends another ID = 0 entry.
dummy = blobs_props.isel(ID=0) * np.nan
blobs_props = xr.concat([dummy.assign_coords(ID=0), blobs_props], dim='ID')


# Re-index every property from original IDs onto the (time, new ID) grid
for var_name in blobs_props.data_vars:
    
    # Select by original ID through the mapping; the mapping's own 'ID' dim is renamed to
    # 'new_id' during the lookup, then the old 'ID' coordinate is dropped and the name restored.
    temp = (blobs_props[var_name]
                        .sel(ID=global_id_mapping.rename({'ID':'new_id'}))
                        .drop_vars('ID').rename({'new_id':'ID'}))
    
    # A variable named 'ID' is stored as int32, everything else as float32
    if var_name == 'ID':
        temp = temp.astype(np.int32)
    else:
        temp = temp.astype(np.float32)
        
    blobs_props_extended[var_name] = temp

In [None]:
# Build the merge ledger: for every (time, new ID) that takes part in a merge, the new IDs
# of all blobs involved in that merge ("siblings"); -1 everywhere else.

# Parent IDs (old numbering) of each merge event; non-positive entries are replaced by 0
old_parent_IDs = xr.where(merge_events.parent_IDs>0, merge_events.parent_IDs, 0)
# Translate old parent IDs to new (cluster) IDs through the lookup table
new_IDs_parents = ID_to_cluster_index_da.sel(ID=old_parent_IDs)

# Replace the coordinate merge_ID in new_IDs_parents with merge_time.  merge_events.merge_time gives merge_time for each merge_ID
new_IDs_parents_t = new_IDs_parents.assign_coords({'merge_time': merge_events.merge_time}).drop_vars('ID').swap_dims({'merge_ID': 'merge_time'})  # this now has coordinate merge_time and ID

# Map new_IDs_parents_t into a new data array with dimensions time, ID, and sibling_ID
merge_ledger = xr.full_like(global_id_mapping, fill_value=-1).expand_dims({'sibling_ID': new_IDs_parents_t.parent_idx.shape[0]}).copy() # dimensions are time, ID, sibling_ID

# Visit each distinct merge time once. `.sel` already returns *all* mergers at a given time,
# so looping over duplicated time values would only repeat identical assignments.
# Loop names are chosen so the notebook-level `IDs` array from the clustering step is not overwritten.
for time_val in np.unique(new_IDs_parents_t.merge_time.values):
    siblings_at_t = new_IDs_parents_t.sel({'merge_time': time_val})

    if siblings_at_t.ndim == 1:
        # Exactly one merger at this time
        sibling_groups = [siblings_at_t.values]
    else:
        # There were multiple mergers at this time: one group of siblings per merger
        sibling_groups = [siblings_at_t.isel(merge_time=merge_num).values
                          for merge_num in range(siblings_at_t.sizes['merge_time'])]

    for sibling_IDs in sibling_groups:
        for sibling_ID in sibling_IDs:
            if sibling_ID > 0:  # skip padding / IDs without a valid new label
                merge_ledger.loc[{tracker.timedim: time_val, 'ID': sibling_ID}] = sibling_IDs

# Name the array, order the dims and chunk along time like the relabelled ID field
merge_ledger = merge_ledger.rename('merge_ledger').transpose(tracker.timedim, 'ID', 'sibling_ID').chunk({tracker.timedim: split_merged_relabeled_blob_id_field.data.chunksize[0]})
