# Identify & Track Marine Heatwaves using `spot_the_blOb`

## Processing Steps:
1. Fill holes in the binary data, using `dask_image.ndmorph` -- up to `R_fill` cells in radius.
2. Filter out small objects -- area less than the `area_filter_quartile` of the distribution of objects.
3. Identify objects in the binary data, using `dask_image.ndmeasure`.
4. Manually connect objects across time, applying Sun et al. 2023 criteria:
    - Connected Blobs must overlap by at least `overlap_threshold=50%` of the smaller blob.
    - Merging Blobs retain their original IDs; the merged blob is split between them based on parent centroid locality.
5. Cluster and reduce the final object ID graph using `scipy.sparse.csgraph.connected_components`.

N.B.: Exploits parallelised `Dask` operations with optimised chunking using `flox` for memory efficiency and speed \
N.N.B.: This example, using 40 years of daily outputs at 0.25° resolution, takes ~6 minutes on 128 total cores.

In [1]:
import xarray as xr
import dask
from getpass import getuser
from pathlib import Path

import spot_the_blOb as blob
import spot_the_blOb.helper as hpc

In [2]:
# Start Dask Cluster
# Uses the package helper to launch a local cluster with 32 workers x 2 threads each;
# whatever the helper returns (presumably a Dask client -- confirm in `spot_the_blOb.helper`) is kept in `client`.
client = hpc.StartLocalCluster(n_workers=32, n_threads=2)

Memory per Worker: 7.86 GB
Hostname is  l20134
Forward Port = l20134:8787
Dashboard Link: localhost:8787/status


In [3]:
# Load Pre-processed Data (cf. `01_preprocess_extremes.ipynb`)

# Path layout: /scratch/<first letter of username>/<username>/mhws/extreme_events_binary.zarr
file_name = Path('/scratch') / getuser()[0] / getuser() / 'mhws' / 'extreme_events_binary.zarr'
# 25 time steps per chunk; -1 keeps the whole lat / lon dimension in a single chunk
chunk_size = {'time': 25, 'lat': -1, 'lon': -1}
ds = xr.open_zarr(str(file_name), chunks=chunk_size)

In [4]:
# Extract Binary Features and Modify Mask

# Only the first 1000 time steps are used in this notebook
extreme_bin = ds.extreme_events.isel(time=slice(0, 1000))
# Keep the stored mask where -90 < lat < 85; everywhere else the mask is forced to False
mask = ds.mask.where((ds.lat<85) & (ds.lat>-90), other=False)

In [5]:
# Tracking Parameters

drop_area_quartile = 0.5       # Passed to Spotter as `area_filter_quartile` (small-object filter)
filling_radius = 8             # Passed to Spotter as `R_fill` (hole-filling radius, in cells)
allow_merging = True           # Passed to Spotter as `allow_merging`
centroid_partitioning = False  # False --> the merge loop below partitions by nearest parent cell rather than by parent centroid


In [6]:
# Spot the Blobs

# Build the tracker from the binary field, the mask and the parameters defined above
tracker = blob.Spotter(extreme_bin, mask, R_fill=filling_radius, area_filter_quartile=drop_area_quartile, 
                       allow_merging=allow_merging, centroid_partitioning=centroid_partitioning)
# The one-shot run is disabled here; the following cells call the tracker's individual steps by hand instead.
# blobs = tracker.run()

# blobs

In [7]:
# Fill Holes in the binary field (step 1 of the processing steps listed at the top)
data_bin_filled = tracker.fill_holes()

# Remove Small Objects
# Returns the filtered field together with the area threshold used, the blob areas and the pre-filter blob count
# (meaning inferred from the variable names -- confirm against `Spotter.filter_small_blobs`)
data_bin_filtered, area_threshold, blob_areas, N_blobs_unfiltered = tracker.filter_small_blobs(data_bin_filled)

In [8]:
import xarray as xr
import numpy as np
from dask_image.ndmeasure import label
from skimage.measure import regionprops_table
from dask_image.ndmorph import binary_closing, binary_opening
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
from dask import persist
from dask.base import is_dask_collection
from numba import jit, prange

In [9]:


@jit(nopython=True, parallel=True, fastmath=True)
def wrapped_euclidian_parallel(mask_values, parent_centroids_values, Nx):
    """
    Distance from every True cell of a mask to every parent centroid, periodic in x.

    Parameters:
    -----------
    mask_values : np.ndarray
        2D boolean array; distances are computed for its True cells only
    parent_centroids_values : np.ndarray
        Array of shape (n_parents, 2) holding the (y, x) coordinates of the parent centroids
    Nx : int
        Length of the x-dimension, used as the wrapping period

    Returns:
    --------
    distances : np.ndarray
        float64 array of shape (n_true_points, n_parents); rows follow the order of
        np.nonzero(mask_values), columns follow the order of the parents
    """
    rows, cols = np.nonzero(mask_values)
    n_points = len(rows)
    n_parents = len(parent_centroids_values)
    half_width = Nx / 2

    # Split the centroid table into its y and x columns once, outside the loop
    centroid_y = parent_centroids_values[:, 0]
    centroid_x = parent_centroids_values[:, 1]

    result = np.empty((n_points, n_parents), dtype=np.float64)

    # Each mask cell is independent, so the rows are filled in parallel
    for k in prange(n_points):
        delta_y = rows[k] - centroid_y
        delta_x = cols[k] - centroid_x

        # Fold x separations larger than half the domain back across the periodic boundary
        delta_x = np.where(delta_x > half_width, delta_x - Nx, delta_x)
        delta_x = np.where(delta_x < -half_width, delta_x + Nx, delta_x)

        result[k] = np.sqrt(delta_y * delta_y + delta_x * delta_x)

    return result



@jit(nopython=True, fastmath=True)
def create_grid_index_arrays(points_y, points_x, grid_size, ny, nx):
    """
    Creates a grid-based spatial index using numpy arrays.

    The (ny, nx) domain is tiled with square cells of side `grid_size`, and every point is
    filed under the cell that contains it.

    Parameters:
    -----------
    points_y, points_x : np.ndarray
        1D integer arrays with the y / x indices of the points to be indexed
    grid_size : int
        Side length of one grid cell, in array cells
    ny, nx : int
        Size of the domain in y and x

    Returns:
    --------
    grid_points : np.ndarray (int32)
        Array of shape (n_grids_y, n_grids_x, max_points_per_cell).
        grid_points[gy, gx, :grid_counts[gy, gx]] holds the indices (into points_y / points_x)
        of the points inside cell (gy, gx), in ascending order; unused slots are -1.
        The last axis is only as long as the fullest cell, not as long as the number of points.
    grid_counts : np.ndarray (int32)
        Array of shape (n_grids_y, n_grids_x) with the number of points filed under each cell
    """
    n_grids_y = (ny + grid_size - 1) // grid_size
    n_grids_x = (nx + grid_size - 1) // grid_size
    n_points = len(points_y)

    grid_counts = np.zeros((n_grids_y, n_grids_x), dtype=np.int32)

    # Pass 1: count the points per cell, so the index array can be sized to the fullest cell.
    # (Sizing the last axis to n_points costs n_cells * n_points memory, which is prohibitive for large blobs.)
    max_points_per_cell = 0
    for idx in range(n_points):
        grid_y = min(points_y[idx] // grid_size, n_grids_y - 1)
        grid_x = min(points_x[idx] // grid_size, n_grids_x - 1)
        grid_counts[grid_y, grid_x] += 1
        if grid_counts[grid_y, grid_x] > max_points_per_cell:
            max_points_per_cell = grid_counts[grid_y, grid_x]

    grid_points = np.full((n_grids_y, n_grids_x, max_points_per_cell), -1, dtype=np.int32)

    # Pass 2: file every point index into the next free slot of its cell
    next_slot = np.zeros((n_grids_y, n_grids_x), dtype=np.int32)
    for idx in range(n_points):
        grid_y = min(points_y[idx] // grid_size, n_grids_y - 1)
        grid_x = min(points_x[idx] // grid_size, n_grids_x - 1)
        slot = next_slot[grid_y, grid_x]
        grid_points[grid_y, grid_x, slot] = idx
        next_slot[grid_y, grid_x] = slot + 1

    return grid_points, grid_counts

@jit(nopython=True, fastmath=True)
def calculate_wrapped_distance(y1, x1, y2, x2, nx, half_nx):
    """
    Euclidean distance between (y1, x1) and (y2, x2) on a domain that is periodic in x.

    `nx` is the period of the x-dimension and `half_nx` is half of it (passed in by the caller
    so it is not recomputed on every call). The y-dimension is not wrapped.
    """
    delta_y = y1 - y2
    delta_x = x1 - x2

    # If the points are more than half a domain apart in x, go around the other way
    if delta_x > half_nx:
        delta_x = delta_x - nx
    elif delta_x < -half_nx:
        delta_x = delta_x + nx

    return np.sqrt(delta_y * delta_y + delta_x * delta_x)

@jit(nopython=True, parallel=True, fastmath=True)
def get_nearest_parent_labels(child_mask, parent_masks, child_ids, parent_centroids, Nx, max_distance=20):
    """
    Assigns labels based on nearest parent blob points.
    This is quite computationally-intensive, so we utilise many optimisations here...

    For every True cell of `child_mask`, the parents are scanned in order and the cell is given to
    the parent that owns the closest cell found within the 3x3 block of grid cells around it
    (x is periodic, y is not). Cells for which no parent cell is found within `max_distance`
    fall back to the parent with the nearest centroid.

    Parameters:
    -----------
    child_mask : np.ndarray
        2D boolean array marking the cells of the child blob to be partitioned
    parent_masks : np.ndarray
        Boolean array of shape (n_parents, ny, nx); one mask per parent blob
    child_ids : np.ndarray
        1D array of labels; entry i is the label handed to cells assigned to parent i
    parent_centroids : np.ndarray
        Array of shape (n_parents, 2) with the (y, x) centroid of each parent
    Nx : int
        Length of the x-dimension, used as the wrapping period
    max_distance : int
        Parent cells farther away than this are ignored; also sets the grid cell size

    Returns:
    --------
    new_labels : np.ndarray
        1D array with one label per True cell of `child_mask`, in np.nonzero(child_mask) order
    """
    ny, nx = child_mask.shape
    half_Nx = Nx / 2
    n_parents = len(parent_masks)
    grid_size = max(2, max_distance // 4)

    # The grid dimensions depend only on the field shape, so they are computed once for all parents
    n_grids_y = (ny + grid_size - 1) // grid_size
    n_grids_x = (nx + grid_size - 1) // grid_size

    y_indices, x_indices = np.nonzero(child_mask)
    n_child_points = len(y_indices)

    min_distances = np.full(n_child_points, np.inf)
    parent_assignments = np.zeros(n_child_points, dtype=np.int32)
    found_close = np.zeros(n_child_points, dtype=np.bool_)

    for parent_idx in range(n_parents):
        py, px = np.nonzero(parent_masks[parent_idx])

        if len(py) == 0:  # Skip empty parents
            continue

        # Create grid index for this parent
        grid_points, grid_counts = create_grid_index_arrays(py, px, grid_size, ny, nx)

        # Process child points in parallel (each iteration only touches its own child_idx entries)
        for child_idx in prange(n_child_points):
            if found_close[child_idx]:  # Skip if we already found a very close match
                continue

            child_y, child_x = y_indices[child_idx], x_indices[child_idx]
            grid_y = min(child_y // grid_size, n_grids_y - 1)
            grid_x = min(child_x // grid_size, n_grids_x - 1)

            min_dist_to_parent = np.inf

            # Check nearby grid cells
            for dy in range(-1, 2):
                grid_y_check = grid_y + dy

                # y is NOT periodic: never look at the grid row on the opposite edge of the domain
                if grid_y_check < 0 or grid_y_check >= n_grids_y:
                    continue

                for dx in range(-1, 2):
                    # x IS periodic: wrap the grid column index
                    grid_x_check = (grid_x + dx) % n_grids_x

                    # Process points in this grid cell
                    n_points = grid_counts[grid_y_check, grid_x_check]

                    for p_idx in range(n_points):
                        point_idx = grid_points[grid_y_check, grid_x_check, p_idx]
                        if point_idx == -1:
                            break

                        dist = calculate_wrapped_distance(
                            child_y, child_x,
                            py[point_idx], px[point_idx],
                            Nx, half_Nx
                        )

                        if dist > max_distance:
                            continue

                        if dist < min_dist_to_parent:
                            min_dist_to_parent = dist

                        if dist < 2:  # Very close match found
                            min_dist_to_parent = dist
                            found_close[child_idx] = True
                            break

                    if found_close[child_idx]:
                        break

                if found_close[child_idx]:
                    break

            # Update assignment if this parent is closer
            if min_dist_to_parent < min_distances[child_idx]:
                min_distances[child_idx] = min_dist_to_parent
                parent_assignments[child_idx] = parent_idx

                if min_dist_to_parent < 2:
                    found_close[child_idx] = True

    # Handle any unassigned points using centroids
    unassigned = min_distances == np.inf
    if np.any(unassigned):
        for child_idx in np.nonzero(unassigned)[0]:
            child_y, child_x = y_indices[child_idx], x_indices[child_idx]
            min_dist = np.inf
            best_parent = 0

            for parent_idx in range(n_parents):
                # Calculate distance to centroid with periodic boundary conditions
                dist = calculate_wrapped_distance(
                    child_y, child_x,
                    parent_centroids[parent_idx, 0],
                    parent_centroids[parent_idx, 1],
                    Nx, half_Nx
                )

                if dist < min_dist:
                    min_dist = dist
                    best_parent = parent_idx

            parent_assignments[child_idx] = best_parent

    # Convert from parent indices to child_ids
    new_labels = child_ids[parent_assignments]

    return new_labels



In [10]:
# Work on the hole-filled, size-filtered binary field from here on
data_bin = data_bin_filtered

In [11]:
# Label the blobs with time_connectivity disabled; the second return value is not needed here
blob_id_field, _ = tracker.identify_blobs(data_bin, time_connectivity=False)
        
# Calculate Properties of each Blob
blob_props = tracker.calculate_blob_properties(blob_id_field, properties=['area', 'centroid'])

# Compile List of Overlapping Blob ID Pairs Across Time
# (used below as an array with columns [ID at t, ID at t+1, overlap area])
overlap_blobs_list = tracker.find_overlapping_blobs(blob_id_field)  # List of overlapping blob pairs

In [12]:
# Same object under the name used by the merging code below (an alias, not a copy)
blob_id_field_unique = blob_id_field

In [None]:
# Overlap fraction of every candidate pair, measured relative to the smaller of the two blobs.
# Columns of overlap_blobs_list used here: [:, 0] and [:, 1] are the two blob IDs, [:, 2] is their overlap area.
areas_0 = blob_props['area'].sel(ID=overlap_blobs_list[:, 0]).values
areas_1 = blob_props['area'].sel(ID=overlap_blobs_list[:, 1]).values
min_areas = np.minimum(areas_0, areas_1)
overlap_fractions = overlap_blobs_list[:, 2].astype(float) / min_areas

## Filter out the overlaps that are too small
overlap_blobs_list = overlap_blobs_list[overlap_fractions >= tracker.overlap_threshold]



#################################
##### Consider Merging Blobs ####
#################################

## Initialize merge tracking lists to build DataArray later
# One entry is appended to each list per processed merge event (same index = same event)
merge_times = []      # When the merge occurred
merge_child_ids = []  # Resulting child ID
merge_parent_ids = [] # List of parent IDs that merged
merge_areas = []      # Areas of overlap
next_new_id = blob_props.ID.max().item() + 1  # Start new IDs after highest existing ID

# Find all the Children (t+1 / RHS) elements that appear multiple times --> Indicates there are 2+ Parent Blobs...
unique_children, children_counts = np.unique(overlap_blobs_list[:, 1], return_counts=True)
merging_blobs = unique_children[children_counts > 1]

# Pre-compute the child_time_idx & 2d_mask_id for each child_blob
# (used below as a mapping: blob ID --> integer index along the time dimension)
time_index_map = tracker.compute_id_time_dict(blob_id_field_unique, merging_blobs, next_new_id)
Nx = blob_id_field_unique[tracker.xdim].size  # Length of the x-dimension, i.e. the wrapping period

# Group blobs by time-chunk
# -- Pre-condition: Blob IDs should be monotonically increasing in time...
# chunk_boundaries[i] is the first time index of chunk i (chunks[0]: time is assumed to be the leading dimension)
chunk_boundaries = np.cumsum([0] + list(blob_id_field_unique.chunks[0] ))
blobs_by_chunk = {}
for blob_id in merging_blobs:
    # Find which chunk this time index belongs to
    chunk_idx = np.searchsorted(chunk_boundaries, time_index_map[blob_id], side='right') - 1
    blobs_by_chunk.setdefault(chunk_idx, []).append(blob_id)


# Resolve merging blobs one time-chunk at a time.
#
# For every child blob (at time t) that overlaps 2+ parents (at time t-1):
#   1. the child is partitioned between its parents (the first parent's part keeps the child ID, the others get new IDs),
#   2. blob_props and overlap_blobs_list are updated for the new partition,
#   3. any blob at t+1 that now overlaps 2+ of the new parts is queued for the same treatment.
# Each merge is recorded in the merge_* lists, which are turned into the `merge_events` Dataset after the loop.

future_chunk_merges = []
time_chunks = blob_id_field_unique.chunks[0]  # Chunk lengths along the leading (time) axis

# Visit every time-chunk in order -- not only the chunks that initially hold merging blobs -- so that merges
# deferred from the end of one chunk are always handled in the chunk that immediately follows it.
for chunk_idx in range(len(time_chunks)):
    chunk_blobs = blobs_by_chunk.get(chunk_idx, [])
    if not chunk_blobs and not future_chunk_merges:
        continue  # Nothing to do in this chunk

    # We do this to avoid repetitively re-computing and injecting tiny changes into the full dask-backed DataArray blob_id_field_unique

    ## Extract and Load an entire chunk into memory

    chunk_len = time_chunks[chunk_idx]
    chunk_start = sum(time_chunks[:chunk_idx])
    chunk_end = chunk_start + chunk_len + 1  #  We also want access to the blob_id_time_p1...  But need to remember to remove the last time later

    # N.B.: For the last chunk the slice is clamped, so there is no extra (look-ahead) time step
    chunk_data = blob_id_field_unique.isel({tracker.timedim: slice(chunk_start, chunk_end)}).compute()

    # Create a working queue of blobs to process
    blobs_to_process = chunk_blobs.copy()
    # Combine only the future_chunk_merges that don't already appear in blobs_to_process
    blobs_to_process = blobs_to_process + [blob_id for blob_id in future_chunk_merges if blob_id not in blobs_to_process]
    future_chunk_merges = []

    while blobs_to_process:  # Process until queue is empty
        child_id = blobs_to_process.pop(0)  # Get next blob to process

        child_time_idx = time_index_map[child_id]
        relative_time_idx = child_time_idx - chunk_start

        # The very last time step of the record has no successor
        has_next_step = relative_time_idx + 1 < chunk_data.sizes[tracker.timedim]

        blob_id_time = chunk_data.isel({tracker.timedim: relative_time_idx})
        if has_next_step:
            blob_id_time_p1 = chunk_data.isel({tracker.timedim: relative_time_idx+1})
        if relative_time_idx-1 >= 0:
            blob_id_time_m1 = chunk_data.isel({tracker.timedim: relative_time_idx-1})
        elif chunk_idx > 0:
            # Previous step lives in the previous chunk: load it once here, rather than once per parent below
            blob_id_time_m1 = blob_id_field_unique.isel({tracker.timedim: chunk_start-1}).compute()
        else:
            blob_id_time_m1 = xr.full_like(blob_id_time, 0)

        child_mask_2d  = (blob_id_time == child_id).values

        # Find all pairs involving this Child Blob
        child_mask = overlap_blobs_list[:, 1] == child_id
        child_where = np.where(overlap_blobs_list[:, 1] == child_id)[0]  # Needed for assignment
        merge_group = overlap_blobs_list[child_mask]

        # Get all Parents (LHS) Blobs that overlap with this Child Blob -- N.B. This is now generalised for N-parent merging !
        parent_ids = merge_group[:, 0]
        num_parents = len(parent_ids)

        # Make a new ID for the other Half of the Child Blob & Record in the Merge Ledger
        new_blob_id = np.arange(next_new_id, next_new_id + (num_parents - 1), dtype=np.int32)
        next_new_id += num_parents - 1

        # Replace the 2nd+ Child in the Overlap Blobs List with the new Child ID
        overlap_blobs_list[child_where[1:], 1] = new_blob_id
        child_ids = np.concatenate((np.array([child_id]), new_blob_id))

        # Record merge event data
        merge_times.append(chunk_data.isel({tracker.timedim: relative_time_idx}).time.values)
        merge_child_ids.append(child_ids)
        merge_parent_ids.append(parent_ids)
        merge_areas.append(overlap_blobs_list[child_mask, 2])

        ### Relabel the Original Child Blob ID Field to account for the New ID:
        parent_centroids = blob_props.sel(ID=parent_ids).centroid.values.T  # (y, x), [:,0] are the y's

        if tracker.centroid_partitioning:
            # --> For every (Original) Child Cell in the ID Field, Measure the Distance to the Centroids of the Parents
            distances = wrapped_euclidian_parallel(child_mask_2d, parent_centroids, Nx)  # **Deals with wrapping**

            # Assign the new ID to each cell based on the closest parent
            new_labels = child_ids[np.argmin(distances, axis=1)]

        else:
            # --> For every (Original) Child Cell in the ID Field, Find the closest parent cell
            parent_masks = np.zeros((len(parent_ids), blob_id_time.shape[0], blob_id_time.shape[1]), dtype=bool)
            for idx, parent_id in enumerate(parent_ids):
                parent_masks[idx] = (blob_id_time_m1 == parent_id).values

            # Calculate typical blob size to set max_distance
            max_area = np.max(blob_props.sel(ID=parent_ids).area.values)
            max_distance = int(np.sqrt(max_area) * 2.0)  # Use 2x the max blob radius
            max_distance = max(max_distance, 20)  # Set minimum threshold...

            new_labels = get_nearest_parent_labels(
                child_mask_2d,
                parent_masks,
                child_ids,
                parent_centroids,
                Nx,
                max_distance=max_distance
            )


        ## Update values in child_time_idx and assign the updated slice back to the original DataArray
        temp = np.zeros_like(blob_id_time)
        temp[child_mask_2d] = new_labels
        blob_id_time = blob_id_time.where(~child_mask_2d, temp)
        ## ** Update directly into the chunk
        chunk_data[{tracker.timedim: relative_time_idx}] = blob_id_time

        ## Add new entries to time_index_map for each of new_blob_id corresponding to the current time index
        time_index_map.update({new_id: child_time_idx for new_id in new_blob_id})

        ## Update the Properties of the N Children Blobs
        new_child_props = tracker.calculate_blob_properties(blob_id_time, properties=['area', 'centroid'])

        # Update the blob_props DataArray:  (but first, check if the original Children still exists)
        if child_id in new_child_props.ID:  # Update the entry
            blob_props.loc[dict(ID=child_id)] = new_child_props.sel(ID=child_id)
        else:  # Delete child_id:  The blob has split/morphed such that it doesn't get a partition of this child...
            blob_props = blob_props.drop_sel(ID=child_id)  # N.B.: This means that the IDs are no longer continuous...
            print(f"Deleted child_id {child_id} because parents have split/morphed in the meantime...")
        # Add the properties for the N-1 other new child ID
        new_blob_ids_still = new_child_props.ID.where(new_child_props.ID.isin(new_blob_id), drop=True).ID
        blob_props = xr.concat([blob_props, new_child_props.sel(ID=new_blob_ids_still)], dim='ID')
        missing_ids = set(new_blob_id) - set(new_blob_ids_still.values)
        if len(missing_ids) > 0:
            print(f"Missing newly created child_ids {missing_ids} because parents have split/morphed in the meantime...")


        ## Finally, Re-assess all of the Parent IDs (LHS) equal to the (original) child_id

        if has_next_step:
            # Look at the overlap IDs between the original child_id and the next time-step, and also the new_blob_id and the next time-step
            new_overlaps = tracker.check_overlap_slice(blob_id_time.values, blob_id_time_p1.values)
            new_child_overlaps_list = new_overlaps[(new_overlaps[:, 0] == child_id) | np.isin(new_overlaps[:, 0], new_blob_id)]

            # _Before_ replacing the overlap_blobs_list, we need to re-assess the overlap fractions of just the new_child_overlaps_list
            areas_0 = blob_props['area'].sel(ID=new_child_overlaps_list[:, 0]).values
            areas_1 = blob_props['area'].sel(ID=new_child_overlaps_list[:, 1]).values
            min_areas = np.minimum(areas_0, areas_1)
            overlap_fractions = new_child_overlaps_list[:, 2].astype(float) / min_areas
            new_child_overlaps_list = new_child_overlaps_list[overlap_fractions >= tracker.overlap_threshold]
        else:
            # No later time step exists, so there are no forward overlaps to re-assess (empty list, same columns & dtype)
            new_child_overlaps_list = overlap_blobs_list[:0]

        # Replace the lines in the overlap_blobs_list where (original) child_id is on the LHS, with these new pairs in new_child_overlaps_list
        child_mask_LHS = overlap_blobs_list[:, 0] == child_id
        overlap_blobs_list = np.concatenate([overlap_blobs_list[~child_mask_LHS], new_child_overlaps_list])


        ## Finally, _FINALLY_, we need to ensure that of the new children blobs we made, they only overlap with their respective parent...
        new_unique_children, new_children_counts = np.unique(new_child_overlaps_list[:, 1], return_counts=True)
        new_merging_blobs = new_unique_children[new_children_counts > 1]
        if new_merging_blobs.size > 0:

            # Compare against the chunk's own length (not the loaded size), because the last chunk has no look-ahead step
            if relative_time_idx + 1 < chunk_len:  # The next time-step belongs to this chunk
                for new_child_id in new_merging_blobs:
                    if new_child_id not in blobs_to_process: # We aren't already going to assess this blob
                        blobs_to_process.insert(0, new_child_id)

            else: # This is out of our current jurisdiction: Defer this reassessment to the beginning of the next chunk
                future_chunk_merges.extend(new_merging_blobs)


    # Update the full dask DataArray with this processed chunk
    blob_id_field_unique[{
        tracker.timedim: slice(chunk_start, chunk_end-1)  # cf. above definition of chunk_end for why we need -1
    }] = chunk_data[:(chunk_end-1-chunk_start)]


### Process the Merge Events
# Built once, after all chunks are done. `default=0` keeps this valid when no merge occurred at all.
max_parents = max((len(ids) for ids in merge_parent_ids), default=0)
max_children = max((len(ids) for ids in merge_child_ids), default=0)

# Convert lists to padded numpy arrays (rows = merge events, unused slots = -1)
parent_ids_array = np.full((len(merge_parent_ids), max_parents), -1, dtype=np.int32)
child_ids_array = np.full((len(merge_child_ids), max_children), -1, dtype=np.int32)
overlap_areas_array = np.full((len(merge_areas), max_parents), -1, dtype=np.int32)

for i, parents in enumerate(merge_parent_ids):
    parent_ids_array[i, :len(parents)] = parents

for i, children in enumerate(merge_child_ids):
    child_ids_array[i, :len(children)] = children

for i, areas in enumerate(merge_areas):
    overlap_areas_array[i, :len(areas)] = areas

# Create merge events dataset
merge_events = xr.Dataset(
    {
        'merge_parent_IDs': (('merge_ID', 'parent_idx'), parent_ids_array),
        'merge_child_IDs': (('merge_ID', 'child_idx'), child_ids_array),
        'merge_overlap_areas': (('merge_ID', 'parent_idx'), overlap_areas_array),
        'merge_time': ('merge_ID', merge_times),
        'merge_n_parents': ('merge_ID', np.array([len(p) for p in merge_parent_ids])),
        'merge_n_children': ('merge_ID', np.array([len(c) for c in merge_child_ids]))
    },
    attrs={
        'fill_value': -1
    }
)
    

  unique_ids_by_time = xr.apply_ufunc(
This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.
This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.
This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.
This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repe

In [None]:
# Save Tracked Blobs

# Output path: /scratch/<first letter of username>/<username>/mhws/MHWs_tracked.nc
file_name = Path('/scratch') / getuser()[0] / getuser() / 'mhws' / 'MHWs_tracked.nc'
# NOTE(review): in the visible cells `blobs` is only assigned in a commented-out line (`# blobs = tracker.run()`),
# so this cell presumably raises NameError unless `blobs` is defined elsewhere -- confirm before running.
# mode='w' overwrites an existing file at that path.
blobs.to_netcdf(file_name, mode='w')