# Identify & Track Marine Heatwaves on _Unstructured Grid_ using `spot_the_blOb`

## Processing Steps:
1. Fill spatial holes in the binary data, using `dask_image.ndmorph` -- up to `R_fill` cells in radius.
2. Fill gaps in time -- permitting up to `T_fill` missing time slices, while keeping the same blob ID.
3. Filter out small objects -- those whose area falls in the bottom `area_filter_quartile` fraction of the object size distribution.
4. Identify objects in the binary data, using `dask_image.ndmeasure`.
5. Connect objects across time, applying the following criteria for splitting, merging, and persistence:
    - Connected Blobs must overlap by at least fraction `overlap_threshold` of the smaller blob.
    - Merged Blobs retain their original IDs, and the child blob is partitioned by assigning each cell to the parent of its _nearest-neighbour_ cell.
6. Cluster and reduce the final object ID graph using `scipy.sparse.csgraph.connected_components`.
7. Map the tracked objects into ID-time space for convenient analysis.
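
Step 6 can be illustrated on a toy graph. The sketch below is not the internal code of `spot_the_blOb`; it assumes the overlapping blob pairs are available as an `(n_pairs, 2)` integer array (the `pairs` array and its IDs are made up for illustration) and shows how `scipy.sparse.csgraph.connected_components` collapses chains of linked IDs into one label per tracked object.

```python
import numpy as np
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import connected_components

# Hypothetical blob IDs 0..5 and the ID pairs that overlap between consecutive time steps
n_ids = 6
pairs = np.array([[0, 1], [1, 2], [3, 4]])

# Sparse adjacency matrix with one entry per overlapping pair
adjacency = coo_matrix(
    (np.ones(len(pairs)), (pairs[:, 0], pairs[:, 1])),
    shape=(n_ids, n_ids),
)

# Treat links as undirected so that each chain of overlaps becomes one component
n_objects, object_label = connected_components(adjacency, directed=False)

print(n_objects)     # number of tracked objects
print(object_label)  # one object label per original blob ID
```

Here IDs 0, 1 and 2 end up in one object, 3 and 4 in a second, and the unlinked ID 5 forms its own object.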

N.B.: Exploits parallelised `dask` operations with optimised chunking using `flox` for memory efficiency and speed. \
N.N.B.: This example, using 40 years of _daily_ outputs at 5 km resolution on an Unstructured Grid (15 million cells) with 32 cores, takes:
- Full Split/Merge Thresholding & Merge Tracking:  ~40 minutes

In [1]:
import xarray as xr
import dask
from getpass import getuser
from pathlib import Path

import spot_the_blOb as blob
import spot_the_blOb.helper as hpc

In [2]:
# Start Dask Cluster
client = hpc.StartLocalCluster(n_workers=32, n_threads=1)

Memory per Worker: 7.86 GB
Hostname is  l50219
Forward Port = l50219:8787
Dashboard Link: localhost:8787/status


In [3]:
# Load Pre-processed Data (cf. `01_preprocess_extremes.ipynb`)

file_name = Path('/scratch') / getuser()[0] / getuser() / 'mhws' / 'extreme_events_binary_unstruct.zarr'
chunk_size = {'time': 4, 'ncells': -1}
ds = xr.open_zarr(str(file_name), chunks={}).isel(time=slice(0,32)).chunk(chunk_size)

In [4]:
# Tracking Parameters

drop_area_quartile = 0.8  # Remove the smallest 80% of the identified blobs
hole_filling_radius = 32  # Fill small holes with radius < 32 elements
time_gap_fill = 2         # Allow gaps of 2 days and still continue the blob tracking with the same ID
allow_merging = True      # Allow blobs to split/merge. Keeps track of merge events & unique IDs.
overlap_threshold = 0.5   # Overlap threshold for merging blobs. If overlap < threshold, blobs keep independent IDs.
nn_partitioning = True    # Use new NN method to partition merged children blobs. If False, reverts to old method of Di Sun et al. 2023...
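
To make the meaning of `drop_area_quartile` concrete, here is a minimal sketch with made-up blob areas. It assumes the threshold is the corresponding quantile of the area distribution and that blobs at or above it are kept; the exact comparison used inside `spot_the_blOb` may differ.

```python
import numpy as np

# Hypothetical blob areas in arbitrary units (not taken from the dataset)
blob_areas = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
drop_area_quartile = 0.8

# Area below which blobs are discarded: the 80th percentile of the distribution
area_threshold = np.quantile(blob_areas, drop_area_quartile)

# Keep only the largest ~20% of blobs
kept_areas = blob_areas[blob_areas >= area_threshold]

print(area_threshold)
print(kept_areas)
```

With these ten areas only the two largest blobs survive the filter.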

In [5]:
# SpOt & Track the Blobs & Merger Events

tracker = blob.Spotter(ds.extreme_events, ds.mask,
                       R_fill=hole_filling_radius, T_fill=time_gap_fill,
                       area_filter_quartile=drop_area_quartile,
                       allow_merging=allow_merging, overlap_threshold=overlap_threshold,
                       nn_partitioning=nn_partitioning,
                       xdim='ncells',               # Need to tell spot_the_blOb the new Unstructured dimension
                       unstructured_grid=True,      # Use Unstructured Grid
                       neighbours=ds.neighbours,    # Connectivity array for the Unstructured Grid
                       cell_areas=ds.cell_areas)    # Cell areas for each Unstructured Grid cell
# blobs = tracker.run(return_merges=False)

# blobs

Constructing the Sparse Dilation Matrix...


In [6]:
import pyicon as pyic
import numpy as np

----Start loading pyicon.
----Pyicon was loaded successfully.


In [7]:
data_bin_filled = tracker.fill_holes(tracker.data_bin).persist()
data_bin_filled

This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.


| | Array | Chunk |
|---|---|---|
| Bytes | 454.29 MiB | 56.79 MiB |
| Shape | (32, 14886338) | (4, 14886338) |

Dask graph: 8 chunks in 1 graph layer. Data type: bool numpy.ndarray.

Each accompanying 1-D array in the output:

| | Array | Chunk |
|---|---|---|
| Bytes | 113.57 MiB | 113.57 MiB |
| Shape | (14886338,) | (14886338,) |

Dask graph: 1 chunk in 1 graph layer. Data type: float64 numpy.ndarray.


In [8]:
data_bin_gap_filled = tracker.fill_time_gaps(data_bin_filled).persist()
data_bin_gap_filled



| | Array | Chunk |
|---|---|---|
| Bytes | 454.29 MiB | 56.79 MiB |
| Shape | (32, 14886338) | (4, 14886338) |

Dask graph: 8 chunks in 1 graph layer. Data type: bool numpy.ndarray.

Each accompanying 1-D array in the output:

| | Array | Chunk |
|---|---|---|
| Bytes | 113.57 MiB | 113.57 MiB |
| Shape | (14886338,) | (14886338,) |

Dask graph: 1 chunk in 1 graph layer. Data type: float64 numpy.ndarray.
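
The intent of the time-gap fill (allowing up to `T_fill` missing time slices) can be shown for a single cell. This is a reference sketch in plain NumPy, not the routine behind `tracker.fill_time_gaps`, which operates on the full dask-backed array; the helper name `fill_short_gaps` is hypothetical.

```python
import numpy as np

def fill_short_gaps(active, max_gap):
    """Set runs of False no longer than max_gap to True when bounded by True on both sides."""
    active = np.asarray(active, dtype=bool)
    filled = active.copy()
    on = np.flatnonzero(active)
    # Consecutive True indices (a, b) enclose a gap of b - a - 1 False values
    for a, b in zip(on[:-1], on[1:]):
        if 0 < b - a - 1 <= max_gap:
            filled[a + 1:b] = True
    return filled

# One cell over 10 days: a 2-day gap (filled) and a 3-day gap (left open)
series = np.array([1, 1, 0, 0, 1, 0, 0, 0, 1, 0], dtype=bool)
print(fill_short_gaps(series, max_gap=2).astype(int))
```

Gaps at the start or end of the series are never filled, because they are not bounded by an event on both sides.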


In [9]:
data_bin_filtered, area_threshold, blob_areas, N_blobs_unfiltered = tracker.filter_small_blobs(data_bin_gap_filled)

Debugging information
---------------------
old task state: processing
old run_spec: <Task ('original-invert-9fdc7540e2bc9f3c9012d0566315b6b3', 0) _execute_subgraph(...)>
new run_spec: <Task ('original-invert-9fdc7540e2bc9f3c9012d0566315b6b3', 0) _execute_subgraph(...)>
old token: '5ce67f316619954df71be2ecd5aaa91a'
new token: 'c778744811edc5d94a2a030e27fa8743'
old dependencies: set()
new dependencies: set()


  cluster_sizes, unique_cluster_IDs = xr.apply_ufunc(count_cluster_sizes,


In [10]:
data_bin_filtered = data_bin_filtered.persist()
data_bin_filtered

| | Array | Chunk |
|---|---|---|
| Bytes | 454.29 MiB | 56.79 MiB |
| Shape | (32, 14886338) | (4, 14886338) |

Dask graph: 8 chunks in 1 graph layer. Data type: bool numpy.ndarray.

Each accompanying 1-D array in the output:

| | Array | Chunk |
|---|---|---|
| Bytes | 113.57 MiB | 113.57 MiB |
| Shape | (14886338,) | (14886338,) |

Dask graph: 1 chunk in 1 graph layer. Data type: float64 numpy.ndarray.


In [27]:
#### Track_blObs....
###############


blob_id_field, _ = tracker.identify_blobs(data_bin_filtered, time_connectivity=False)

In [28]:
blob_id_field = blob_id_field.persist()
blob_id_field

| | Array | Chunk |
|---|---|---|
| Bytes | 1.77 GiB | 227.15 MiB |
| Shape | (32, 14886338) | (4, 14886338) |

Dask graph: 8 chunks in 1 graph layer. Data type: int32 numpy.ndarray.

Each accompanying 1-D array in the output:

| | Array | Chunk |
|---|---|---|
| Bytes | 113.57 MiB | 113.57 MiB |
| Shape | (14886338,) | (14886338,) |

Dask graph: 1 chunk in 1 graph layer. Data type: float64 numpy.ndarray.


In [29]:
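# Offset each time step's IDs by the cumulative maximum ID of all earlier time steps,
# so that IDs become unique across time (0 remains the background value)
cumsum_ids = (blob_id_field.max(dim=tracker.xdim)).cumsum(tracker.timedim).shift({tracker.timedim: 1}, fill_value=0)
blob_id_field = xr.where(blob_id_field > 0, blob_id_field + cumsum_ids, 0)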
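
The ID offsetting in the cell above can be checked on a small array. The sketch below reproduces the same cumulative-maximum shift in plain NumPy on made-up labels (3 time steps, 5 cells), where each time step has been labelled independently starting from 1.

```python
import numpy as np

# Hypothetical labels: rows are time steps, columns are cells, 0 is background
labels = np.array([
    [1, 1, 0, 2, 0],
    [0, 1, 1, 0, 0],
    [1, 0, 2, 2, 3],
])

# Largest ID used in each time step
max_per_step = labels.max(axis=1)

# Offset for step t = sum of the maxima of all earlier steps (shifted cumulative sum)
offsets = np.concatenate(([0], np.cumsum(max_per_step)[:-1]))

# Apply the offset to labelled cells only, leaving the background at 0
unique_labels = np.where(labels > 0, labels + offsets[:, None], 0)
print(unique_labels)
```

After the shift no ID is shared between time steps, which is what the later overlap and merge bookkeeping relies on.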

In [30]:
blob_props = tracker.calculate_blob_properties(blob_id_field, properties=['area', 'centroid'])

  props_array = xr.apply_ufunc(blob_properties_chunk, blob_id_field,


In [31]:
overlap_blobs_list = tracker.find_overlapping_blobs(blob_id_field)



In [None]:
# split_merged_blob_id_field_unique, merged_blobs_props, split_merged_blobs_list, merge_events = tracker.split_and_merge_blobs(blob_id_field, blob_props, overlap_blobs_list)

In [32]:
import xarray as xr
import numpy as np
from dask_image.ndmeasure import label
from skimage.measure import regionprops_table
from dask_image.ndmorph import binary_closing as binary_closing_dask
from dask_image.ndmorph import binary_opening as binary_opening_dask
from scipy.ndimage import binary_closing, binary_opening
from scipy.sparse import coo_matrix, csr_matrix, eye
from scipy.sparse.csgraph import connected_components
from dask import persist
import dask.array as dsa
from dask.base import is_dask_collection
from numba import jit, njit, int64, int32, prange
import jax.numpy as jnp
import warnings

In [33]:
blob_id_field_unique = blob_id_field

In [34]:
###################################################################
##### Enforce all Blob Pairs overlap by at least 50% (in Area) ####
###################################################################

## Vectorised computation of overlap fractions
areas_0 = blob_props['area'].sel(ID=overlap_blobs_list[:, 0]).values
areas_1 = blob_props['area'].sel(ID=overlap_blobs_list[:, 1]).values
min_areas = np.minimum(areas_0, areas_1)
overlap_fractions = overlap_blobs_list[:, 2].astype(float) / min_areas

## Filter out the overlaps that are too small
overlap_blobs_list = overlap_blobs_list[overlap_fractions >= tracker.overlap_threshold]



#################################
##### Consider Merging Blobs ####
#################################

## Initialize merge tracking lists to build DataArray later
merge_times = []      # When the merge occurred
merge_child_ids = []  # Resulting child ID
merge_parent_ids = [] # List of parent IDs that merged
merge_areas = []      # Areas of overlap
next_new_id = blob_props.ID.max().item() + 1  # Start new IDs after highest existing ID

# Find all the Children (t+1 / RHS) elements that appear multiple times --> Indicates there are 2+ Parent Blobs...
unique_children, children_counts = np.unique(overlap_blobs_list[:, 1], return_counts=True)
merging_blobs = unique_children[children_counts > 1]

# Pre-compute the child_time_idx & 2d_mask_id for each child_blob
time_index_map = tracker.compute_id_time_dict(blob_id_field_unique, merging_blobs, next_new_id)
Nx = blob_id_field_unique[tracker.xdim].size

# Group blobs by time-chunk
# -- Pre-condition: Blob IDs should be monotonically increasing in time...
chunk_boundaries = np.cumsum([0] + list(blob_id_field_unique.chunks[0] ))
blobs_by_chunk = {}
# Ensure that blobs_by_chunk has entry for every key
for chunk_idx in range(len(blob_id_field_unique.chunks[0])):
    blobs_by_chunk.setdefault(chunk_idx, [])

blob_id_field_unique = blob_id_field_unique.persist()

for blob_id in merging_blobs:
    # Find which chunk this time index belongs to
    chunk_idx = np.searchsorted(chunk_boundaries, time_index_map[blob_id], side='right') - 1
    blobs_by_chunk.setdefault(chunk_idx, []).append(blob_id)
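
The overlap filtering, merge detection and chunk lookup in the cell above can be followed on toy numbers. In this sketch the areas, the overlap rows `(parent_id, child_id, overlap_area)` and the chunk sizes are all made up; only the NumPy operations mirror the cell.

```python
import numpy as np

# Hypothetical blob areas keyed by ID, and overlaps as (parent_id, child_id, overlap_area)
areas = {1: 10.0, 2: 4.0, 3: 12.0, 4: 6.0}
overlaps = np.array([
    [1, 3, 6.0],  # 6 / min(10, 12) = 0.60 -> kept
    [2, 3, 3.0],  # 3 / min(4, 12)  = 0.75 -> kept
    [2, 4, 1.0],  # 1 / min(4, 6)   = 0.25 -> dropped
])
overlap_threshold = 0.5

# Overlap as a fraction of the smaller blob in each pair
areas_0 = np.array([areas[int(i)] for i in overlaps[:, 0]])
areas_1 = np.array([areas[int(i)] for i in overlaps[:, 1]])
fractions = overlaps[:, 2] / np.minimum(areas_0, areas_1)
kept = overlaps[fractions >= overlap_threshold]

# A child that appears more than once has 2+ parents, i.e. it is a merge
children, counts = np.unique(kept[:, 1], return_counts=True)
merging_children = children[counts > 1]

# Map a time index to its chunk, assuming two time chunks of size 4
chunk_boundaries = np.cumsum([0, 4, 4])
chunk_idx = np.searchsorted(chunk_boundaries, 5, side='right') - 1

print(merging_children, chunk_idx)
```

Child 3 keeps two parents after filtering and is therefore flagged as a merge, and time index 5 falls in the second chunk (index 1).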


In [35]:


##################################
### Optimised Helper Functions ###
##################################


@jit(nopython=True, parallel=True, fastmath=True)
def wrapped_euclidian_parallel(mask_values, parent_centroids_values, Nx):
    """
    Optimised function for computing wrapped Euclidean distances.
    
    Parameters:
    -----------
    mask_values : np.ndarray
        2D boolean array where True indicates points to calculate distances for
    parent_centroids_values : np.ndarray
        Array of shape (n_parents, 2) containing (y, x) coordinates of parent centroids
    Nx : int
        Size of the x-dimension for wrapping
        
    Returns:
    --------
    distances : np.ndarray
        Array of shape (n_true_points, n_parents) with minimum distances
    """
    n_parents = len(parent_centroids_values)
    half_Nx = Nx / 2
    
    y_indices, x_indices = np.nonzero(mask_values)
    n_true = len(y_indices)
    
    distances = np.empty((n_true, n_parents), dtype=np.float64)
    
    # Precompute for faster access
    parent_y = parent_centroids_values[:, 0]
    parent_x = parent_centroids_values[:, 1]
    
    # Parallel loop over true positions
    for idx in prange(n_true):
        y, x = y_indices[idx], x_indices[idx]
        
        # Pre-compute y differences for all parents
        dy = y - parent_y
        
        # Pre-compute x differences for all parents
        dx = x - parent_x
        
        # Wrapping correction
        dx = np.where(dx > half_Nx, dx - Nx, dx)
        dx = np.where(dx < -half_Nx, dx + Nx, dx)
        
        distances[idx] = np.sqrt(dy * dy + dx * dx)
    
    return distances



@jit(nopython=True, fastmath=True)
def create_grid_index_arrays(points_y, points_x, grid_size, ny, nx):
    """
    Creates a grid-based spatial index using numpy arrays.
    """
    n_grids_y = (ny + grid_size - 1) // grid_size
    n_grids_x = (nx + grid_size - 1) // grid_size
    max_points_per_cell = len(points_y)
    
    grid_points = np.full((n_grids_y, n_grids_x, max_points_per_cell), -1, dtype=np.int32)
    grid_counts = np.zeros((n_grids_y, n_grids_x), dtype=np.int32)
    
    for idx in range(len(points_y)):
        grid_y = min(points_y[idx] // grid_size, n_grids_y - 1)
        grid_x = min(points_x[idx] // grid_size, n_grids_x - 1)
        count = grid_counts[grid_y, grid_x]
        if count < max_points_per_cell:
            grid_points[grid_y, grid_x, count] = idx
            grid_counts[grid_y, grid_x] += 1
    
    return grid_points, grid_counts

@jit(nopython=True, fastmath=True)
def calculate_wrapped_distance(y1, x1, y2, x2, nx, half_nx):
    """
    Calculate distance with periodic boundary conditions in x dimension.
    """
    dy = y1 - y2
    dx = x1 - x2
    
    if dx > half_nx:
        dx -= nx
    elif dx < -half_nx:
        dx += nx
        
    return np.sqrt(dy * dy + dx * dx)

@jit(nopython=True, parallel=True, fastmath=True)
def get_nearest_parent_labels(child_mask, parent_masks, child_ids, parent_centroids, Nx, max_distance=20):
    """
    Assigns labels based on nearest parent blob points.
    This is quite computationally-intensive, so we utilise many optimisations here...
    """
    
    ny, nx = child_mask.shape
    half_Nx = Nx / 2
    n_parents = len(parent_masks)
    grid_size = max(2, max_distance // 4)
    
    y_indices, x_indices = np.nonzero(child_mask)
    n_child_points = len(y_indices)
    
    min_distances = np.full(n_child_points, np.inf)
    parent_assignments = np.zeros(n_child_points, dtype=np.int32)
    found_close = np.zeros(n_child_points, dtype=np.bool_)
    
    for parent_idx in range(n_parents):
        py, px = np.nonzero(parent_masks[parent_idx])
        
        if len(py) == 0:  # Skip empty parents
            continue
            
        # Create grid index for this parent
        n_grids_y = (ny + grid_size - 1) // grid_size
        n_grids_x = (nx + grid_size - 1) // grid_size
        grid_points, grid_counts = create_grid_index_arrays(py, px, grid_size, ny, nx)
        
        # Process child points in parallel
        for child_idx in prange(n_child_points):
            if found_close[child_idx]:  # Skip if we already found an exact match
                continue
                
            child_y, child_x = y_indices[child_idx], x_indices[child_idx]
            grid_y = min(child_y // grid_size, n_grids_y - 1)
            grid_x = min(child_x // grid_size, n_grids_x - 1)
            
            min_dist_to_parent = np.inf
            
            # Check nearby grid cells
            for dy in range(-1, 2):
                grid_y_check = (grid_y + dy) % n_grids_y
                
                for dx in range(-1, 2):
                    grid_x_check = (grid_x + dx) % n_grids_x
                    
                    # Process points in this grid cell
                    n_points = grid_counts[grid_y_check, grid_x_check]
                    
                    for p_idx in range(n_points):
                        point_idx = grid_points[grid_y_check, grid_x_check, p_idx]
                        if point_idx == -1:
                            break
                        
                        dist = calculate_wrapped_distance(
                            child_y, child_x,
                            py[point_idx], px[point_idx],
                            Nx, half_Nx
                        )
                        
                        if dist > max_distance:
                            continue
                        
                        if dist < min_dist_to_parent:
                            min_dist_to_parent = dist
                            
                        if dist < 1e-6:  # Found exact same point (within numerical precision)
                            min_dist_to_parent = dist
                            found_close[child_idx] = True
                            break
                    
                    if found_close[child_idx]:
                        break
                
                if found_close[child_idx]:
                    break
            
            # Update assignment if this parent is closer
            if min_dist_to_parent < min_distances[child_idx]:
                min_distances[child_idx] = min_dist_to_parent
                parent_assignments[child_idx] = parent_idx
    
    # Handle any unassigned points using centroids
    unassigned = min_distances == np.inf
    if np.any(unassigned):
        for child_idx in np.nonzero(unassigned)[0]:
            child_y, child_x = y_indices[child_idx], x_indices[child_idx]
            min_dist = np.inf
            best_parent = 0
            
            for parent_idx in range(n_parents):
                # Calculate distance to centroid with periodic boundary conditions
                dist = calculate_wrapped_distance(
                    child_y, child_x,
                    parent_centroids[parent_idx, 0],
                    parent_centroids[parent_idx, 1],
                    Nx, half_Nx
                )
                
                if dist < min_dist:
                    min_dist = dist
                    best_parent = parent_idx
                    
            parent_assignments[child_idx] = best_parent
    
    # Convert from parent indices to child_ids
    new_labels = child_ids[parent_assignments]
    
    return new_labels

@jit(nopython=True, fastmath=True)
def get_nearest_parent_labels_unstructured(child_mask, parent_masks, child_ids, parent_centroids, neighbours_int, lat, lon, max_distance=20):
    """
    Optimized version of nearest parent label assignment for unstructured grids.
    Uses numpy arrays throughout to ensure Numba compatibility.
    
    Parameters
    ----------
    child_mask : np.ndarray
        1D boolean array where True indicates points in the child blob
    parent_masks : np.ndarray
        2D boolean array of shape (n_parents, n_points) where True indicates points in each parent blob
    child_ids : np.ndarray
        1D array containing the IDs to assign to each partition of the child blob
    parent_centroids : np.ndarray
        Array of shape (n_parents, 2) containing (lat, lon) coordinates of parent centroids in degrees
    neighbours_int : np.ndarray
        2D array of shape (3, n_points) containing indices of neighboring cells for each point (negative values mark missing neighbors)
    lat / lon : np.ndarray
        Latitude/Longitude in degrees
    max_distance : int, optional
        Maximum number of edge hops to search for parent points
    
    Returns
    -------
    new_labels : np.ndarray
        1D array containing the assigned child_ids for each True point in child_mask
    """
    n_points = len(child_mask)
    n_parents = len(parent_masks)
    
    # Pre-allocate arrays
    distances = np.full(n_points, np.inf, dtype=np.float32)
    parent_assignments = np.full(n_points, -1, dtype=np.int32)
    visited = np.zeros((n_parents, n_points), dtype=np.bool_)
    
    # Initialize with direct overlaps
    for parent_idx in range(n_parents):
        overlap_mask = parent_masks[parent_idx] & child_mask
        if np.any(overlap_mask):
            visited[parent_idx, overlap_mask] = True
            unclaimed_overlap = distances[overlap_mask] == np.inf
            if np.any(unclaimed_overlap):
                overlap_points = np.where(overlap_mask)[0]
                valid_points = overlap_points[unclaimed_overlap]
                distances[valid_points] = 0
                parent_assignments[valid_points] = parent_idx
    
    # Pre-compute trig values
    lat_rad = np.deg2rad(lat)
    lon_rad = np.deg2rad(lon)
    cos_lat = np.cos(lat_rad)
    
    # Graph traversal for remaining points
    current_distance = 0
    any_unassigned = np.any(child_mask & (parent_assignments == -1))
    
    while current_distance < max_distance and any_unassigned:
        current_distance += 1
        updates_made = False
        
        for parent_idx in range(n_parents):
            # All points reached so far by this parent act as the frontier
            frontier_mask = visited[parent_idx]
            if not np.any(frontier_mask):
                continue
            
            # Process neighbors
            for i in range(3):  # For each neighbor direction
                neighbors = neighbours_int[i, frontier_mask]
                valid_neighbors = neighbors >= 0
                if not np.any(valid_neighbors):
                    continue
                    
                valid_points = neighbors[valid_neighbors]
                unvisited = ~visited[parent_idx, valid_points]
                new_points = valid_points[unvisited]
                
                if len(new_points) > 0:
                    visited[parent_idx, new_points] = True
                    update_mask = distances[new_points] > current_distance
                    if np.any(update_mask):
                        points_to_update = new_points[update_mask]
                        distances[points_to_update] = current_distance
                        parent_assignments[points_to_update] = parent_idx
                        updates_made = True
        
        if not updates_made:
            break
            
        any_unassigned = np.any(child_mask & (parent_assignments == -1))
    
    # Handle remaining unassigned points using great circle distances
    unassigned_mask = child_mask & (parent_assignments == -1)
    if np.any(unassigned_mask):
        parent_lat_rad = np.deg2rad(parent_centroids[:, 0])
        parent_lon_rad = np.deg2rad(parent_centroids[:, 1])
        cos_parent_lat = np.cos(parent_lat_rad)
        
        unassigned_points = np.where(unassigned_mask)[0]
        for point in unassigned_points:
            # Vectorized haversine calculation
            dlat = parent_lat_rad - lat_rad[point]
            dlon = parent_lon_rad - lon_rad[point]
            a = np.sin(dlat/2)**2 + cos_lat[point] * cos_parent_lat * np.sin(dlon/2)**2
            dist = 2 * np.arctan2(np.sqrt(a), np.sqrt(1-a))
            parent_assignments[point] = np.argmin(dist)
    
    # Return only the assignments for points in child_mask
    child_points = np.where(child_mask)[0]
    return child_ids[parent_assignments[child_points]]


@jit(nopython=True, parallel=True, fastmath=True)
def unstructured_centroid_partition(child_mask, parent_centroids, child_ids, lat, lon):
    """
    Assigns labels to child cells based on closest parent centroid using great circle distances.
    
    Parameters:
    -----------
    child_mask : np.ndarray
        1D boolean array indicating which cells belong to the child blob
    parent_centroids : np.ndarray
        Array of shape (n_parents, 2) containing (lat, lon) coordinates of parent centroids in degrees
    child_ids : np.ndarray
        Array of IDs to assign to each partition of the child blob
    lat / lon : np.ndarray
        Latitude/Longitude in degrees
        
    Returns:
    --------
    new_labels : np.ndarray
        1D array of length n_cells containing the assigned child_ids for cells in child_mask (0 elsewhere)
    """
    n_cells = len(child_mask)
    n_parents = len(parent_centroids)
    
    lat_rad = np.deg2rad(lat)
    lon_rad = np.deg2rad(lon)
    parent_coords_rad = np.deg2rad(parent_centroids)
    
    new_labels = np.zeros(n_cells, dtype=child_ids.dtype)
    
    # Process each child cell in parallel
    for i in prange(n_cells):
        if not child_mask[i]:
            continue
            
        min_dist = np.inf
        closest_parent = 0
        
        # Calculate great circle distance to each parent centroid
        for j in range(n_parents):
            dlat = parent_coords_rad[j, 0] - lat_rad[i]
            dlon = parent_coords_rad[j, 1] - lon_rad[i]
            
            # Use haversine formula for great circle distance
            a = np.sin(dlat/2)**2 + np.cos(lat_rad[i]) * np.cos(parent_coords_rad[j, 0]) * np.sin(dlon/2)**2
            dist = 2 * np.arctan2(np.sqrt(a), np.sqrt(1-a))
            
            if dist < min_dist:
                min_dist = dist
                closest_parent = j
        
        new_labels[i] = child_ids[closest_parent]
    
    return new_labels





## Helper Function for Super Fast Sparse Bool Multiply (*without the scipy+Dask Memory Leak*)
@njit(fastmath=True, parallel=True)
def sparse_bool_power(vec, sp_data, indices, indptr, exponent):
    vec = vec.T
    num_rows = indptr.size - 1
    num_cols = vec.shape[1]
    result = vec.copy()

    for _ in range(exponent):
        temp_result = np.zeros((num_rows, num_cols), dtype=np.bool_)

        for i in prange(num_rows):
            for j in range(indptr[i], indptr[i + 1]):
                if sp_data[j]:
                    for k in range(num_cols):
                        if result[indices[j], k]:
                            temp_result[i, k] = True

        result = temp_result

    return result.T
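
The two distance measures used by the helpers above can be sanity-checked without Numba. The following pure-Python sketch restates the periodic-in-x Euclidean distance and the haversine central angle; the function names `wrapped_distance` and `central_angle` are illustrative and not part of `spot_the_blOb`.

```python
import math

def wrapped_distance(y1, x1, y2, x2, nx):
    """Euclidean distance on a grid that is periodic in x with period nx."""
    dy = y1 - y2
    dx = x1 - x2
    half_nx = nx / 2
    # Take the shorter way around the periodic x-axis
    if dx > half_nx:
        dx -= nx
    elif dx < -half_nx:
        dx += nx
    return math.hypot(dy, dx)

def central_angle(lat1, lon1, lat2, lon2):
    """Haversine central angle in radians between two points given in degrees."""
    p1, p2 = math.radians(lat1), math.radians(lat2)
    dlat = p2 - p1
    dlon = math.radians(lon2 - lon1)
    a = math.sin(dlat / 2) ** 2 + math.cos(p1) * math.cos(p2) * math.sin(dlon / 2) ** 2
    return 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))

# Columns 1 and 99 of a 100-column periodic grid are 2 cells apart across the seam
print(wrapped_distance(0, 1, 0, 99, nx=100))

# Two equatorial points 90 degrees apart span a quarter of a great circle
print(central_angle(0.0, 0.0, 0.0, 90.0))
```

The central angle is enough for choosing the nearest parent centroid; multiplying it by a sphere radius would give a physical distance.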

In [36]:
blobs_by_chunk

{0: [],
 1: [],
 2: [105.0],
 3: [170.0],
 4: [202.0],
 5: [233.0],
 6: [272.0],
 7: [313.0]}

In [37]:
# Keep only the merging blob in chunk 5; clear every other chunk
for i in range(5):
    blobs_by_chunk[i] = []
    
blobs_by_chunk[6] = []
blobs_by_chunk[7] = []
blobs_by_chunk

{0: [], 1: [], 2: [], 3: [], 4: [], 5: [233.0], 6: [], 7: []}

In [None]:

future_chunk_merges = []
updated_chunks = []
for chunk_idx, chunk_blobs in blobs_by_chunk.items(): # Loop over each time-chunk
    # We do this to avoid repeatedly re-computing and injecting tiny changes into the full dask-backed DataArray blob_id_field_unique
    
    ## Extract and Load an entire chunk into memory
    
    chunk_start = sum(blob_id_field_unique.chunks[0][:chunk_idx])
    chunk_end = chunk_start + blob_id_field_unique.chunks[0][chunk_idx] + 1  #  We also want access to the blob_id_time_p1...  But need to remember to remove the last time later
    
    chunk_data = blob_id_field_unique.isel({tracker.timedim: slice(chunk_start, chunk_end)}).compute()
    
    # Create a working queue of blobs to process
    blobs_to_process = chunk_blobs.copy()
    # Append the merges deferred from the end of the previous chunk, skipping any that are already queued
    blobs_to_process = blobs_to_process + [blob_id for blob_id in future_chunk_merges if blob_id not in blobs_to_process]
    future_chunk_merges = []
    
    #for child_id in chunk_blobs: # Process each blob in this chunk
    while blobs_to_process:  # Process until queue is empty
        child_id = blobs_to_process.pop(0)  # Get next blob to process
        
        child_time_idx = time_index_map[child_id]
        relative_time_idx = child_time_idx - chunk_start
        
        print('Processing child_id:', child_id, 'at time index:', child_time_idx)
        
        blob_id_time = chunk_data.isel({tracker.timedim: relative_time_idx})
        try:
            blob_id_time_p1 = chunk_data.isel({tracker.timedim: relative_time_idx+1})
        except IndexError:  # No next time step: this is the final time step of the last chunk
            blob_id_time_p1 = xr.full_like(blob_id_time, 0)
        if relative_time_idx-1 >= 0:
            blob_id_time_m1 = chunk_data.isel({tracker.timedim: relative_time_idx-1})
        elif updated_chunks:  # Get the last time slice from the previous chunk (stored in updated_chunks)
            _, _, last_chunk_data = updated_chunks[-1]
            blob_id_time_m1 = last_chunk_data[-1]
        else:
            blob_id_time_m1 = xr.full_like(blob_id_time, 0)
        
        child_mask_2d  = (blob_id_time == child_id).values
        
        # Find all pairs involving this Child Blob
        child_mask = overlap_blobs_list[:, 1] == child_id
        child_where = np.where(child_mask)[0]  # Row indices, needed for assignment
        merge_group = overlap_blobs_list[child_mask]
        
        # Get all Parents (LHS) Blobs that overlap with this Child Blob -- N.B. This is now generalised for N-parent merging !
        parent_ids = merge_group[:, 0]
        num_parents = len(parent_ids)
        
        # Create one new ID for each additional parent (N parents -> N-1 new child IDs)
        new_blob_id = np.arange(next_new_id, next_new_id + (num_parents - 1), dtype=np.int32)
        next_new_id += num_parents - 1
        
        # Point the 2nd and later overlap pairs for this child at the new child IDs
        overlap_blobs_list[child_where[1:], 1] = new_blob_id
        child_ids = np.concatenate((np.array([child_id]), new_blob_id))
        
        # Record merge event data
        merge_times.append(chunk_data.isel({tracker.timedim: relative_time_idx}).time.values)
        merge_child_ids.append(child_ids)
        merge_parent_ids.append(parent_ids)
        merge_areas.append(overlap_blobs_list[child_mask, 2])
        
        ### Relabel the Original Child Blob ID Field to account for the New ID:
        parent_centroids = blob_props.sel(ID=parent_ids).centroid.values.T  # (y, x), [:,0] are the y's
        
        if tracker.nn_partitioning:
            # --> For every (Original) Child Cell in the ID Field, Find the closest (t-1) Parent _Cell_
            if tracker.unstructured_grid:
                parent_masks = np.zeros((len(parent_ids), blob_id_time.shape[0]), dtype=bool)
            else:
                parent_masks = np.zeros((len(parent_ids), blob_id_time.shape[0], blob_id_time.shape[1]), dtype=bool)
                
            for idx, parent_id in enumerate(parent_ids):
                parent_masks[idx] = (blob_id_time_m1 == parent_id).values
            
            # Calculate typical blob size to set max_distance
            max_area = np.max(blob_props.sel(ID=parent_ids).area.values) / tracker.mean_cell_area
            max_distance = int(np.sqrt(max_area) * 2.0)  # Use 2x the max blob radius
            
            if tracker.unstructured_grid:
                new_labels = get_nearest_parent_labels_unstructured(
                    child_mask_2d,
                    parent_masks,
                    child_ids,
                    parent_centroids,
                    tracker.neighbours_int.values,
                    tracker.data_bin.lat.values,  # Need to pass these as NumPy arrays for JIT compatibility
                    tracker.data_bin.lon.values,
                    max_distance=max(max_distance, 20)*2  # Set minimum threshold, in cells
                )
            else:
                new_labels = get_nearest_parent_labels(
                    child_mask_2d,
                    parent_masks, 
                    child_ids,
                    parent_centroids,
                    Nx,
                    max_distance=max(max_distance, 20)  # Set minimum threshold, in cells
                )
                
        else: 
            # --> For every (Original) Child Cell in the ID Field, Find the closest (t-1) Parent _Centroid_
            if tracker.unstructured_grid:
                new_labels = unstructured_centroid_partition(
                    child_mask_2d,
                    parent_centroids,
                    child_ids,
                    tracker.data_bin.lat.values,
                    tracker.data_bin.lon.values
                )                      
            else:
                distances = wrapped_euclidian_parallel(child_mask_2d, parent_centroids, Nx)  # **Deals with wrapping**

                # Assign the new ID to each cell based on the closest parent
                new_labels = child_ids[np.argmin(distances, axis=1)]
        
        
        ## Update values in child_time_idx and assign the updated slice back to the original DataArray
        temp = np.zeros_like(blob_id_time)
        temp[child_mask_2d] = new_labels
        blob_id_time = blob_id_time.where(~child_mask_2d, temp)
        ## ** Update directly into the chunk
        chunk_data[{tracker.timedim: relative_time_idx}] = blob_id_time
        
        
        ## Add new entries to time_index_map for each of new_blob_id corresponding to the current time index
        time_index_map.update({new_id: child_time_idx for new_id in new_blob_id})
        
        ## Update the Properties of the N Children Blobs
        new_child_props = tracker.calculate_blob_properties(blob_id_time, properties=['area', 'centroid'])
        
        # Update the blob_props DataArray (first checking whether the original child still exists)
        if child_id in new_child_props.ID:  # Update the entry
            blob_props.loc[dict(ID=child_id)] = new_child_props.sel(ID=child_id)
        else:  # Delete child_id: the blob has split/morphed so that no cells remain in this child's partition
            blob_props = blob_props.drop_sel(ID=child_id)  # N.B.: This means that the IDs are no longer contiguous
            print(f"Deleted child_id {child_id} because parents have split/morphed in the meantime...")
        # Add the properties for the other N-1 new child IDs
        new_blob_ids_still = new_child_props.ID.where(new_child_props.ID.isin(new_blob_id), drop=True).ID
        blob_props = xr.concat([blob_props, new_child_props.sel(ID=new_blob_ids_still)], dim='ID')
        missing_ids = set(new_blob_id) - set(new_blob_ids_still.values)
        if len(missing_ids) > 0:
            print(f"Missing newly created child_ids {missing_ids} because parents have split/morphed in the meantime...")

        
        ## Finally, Re-assess all of the Parent IDs (LHS) equal to the (original) child_id
        
        # Look at the overlap IDs between the original child_id and the next time-step, and also the new_blob_id and the next time-step
        new_overlaps = tracker.check_overlap_slice(blob_id_time.values, blob_id_time_p1.values)
        new_child_overlaps_list = new_overlaps[(new_overlaps[:, 0] == child_id) | np.isin(new_overlaps[:, 0], new_blob_id)]
        
        # _Before_ replacing the overlap_blobs_list, we need to re-assess the overlap fractions of just the new_child_overlaps_list
        areas_0 = blob_props['area'].sel(ID=new_child_overlaps_list[:, 0]).values
        areas_1 = blob_props['area'].sel(ID=new_child_overlaps_list[:, 1]).values
        min_areas = np.minimum(areas_0, areas_1)
        overlap_fractions = new_child_overlaps_list[:, 2].astype(float) / min_areas
        new_child_overlaps_list = new_child_overlaps_list[overlap_fractions >= tracker.overlap_threshold]
        
        # Replace the lines in the overlap_blobs_list where (original) child_id is on the LHS, with these new pairs in new_child_overlaps_list
        child_mask_LHS = overlap_blobs_list[:, 0] == child_id
        overlap_blobs_list = np.concatenate([overlap_blobs_list[~child_mask_LHS], new_child_overlaps_list])
        
        
        ## Lastly, check whether any blob at the next time-step now overlaps more than one of the re-labelled children: such a blob is itself a merger and must be (re-)assessed
        new_unique_children, new_children_counts = np.unique(new_child_overlaps_list[:, 1], return_counts=True)
        new_merging_blobs = new_unique_children[new_children_counts > 1]
        if new_merging_blobs.size > 0:
            
            if relative_time_idx + 1 < chunk_data.sizes[tracker.timedim]-1:  # If there is a next time-step in this chunk
                for new_child_id in new_merging_blobs:
                    if new_child_id not in blobs_to_process: # We aren't already going to assess this blob
                        blobs_to_process.insert(0, new_child_id)
            
            else: # This is out of our current jurisdiction: Defer this reassessment to the beginning of the next chunk
                future_chunk_merges.extend(new_merging_blobs)
        
    
    # Store the processed chunk
    updated_chunks.append((chunk_start, chunk_end-1, chunk_data[:(chunk_end-1-chunk_start)]))
    
    if chunk_idx % 10 == 0:
        print(f"Processing splitting and merging in chunk {chunk_idx} of {len(blobs_by_chunk)}")
        
        # Periodically update the main array to prevent memory buildup
        if len(updated_chunks) > 1:  # Keep the last chunk for potential blob_id_time_m1 reference
            for start, end, stored_chunk in updated_chunks[:-1]:  # Separate name: avoid shadowing chunk_data
                blob_id_field_unique[{tracker.timedim: slice(start, end)}] = stored_chunk
            updated_chunks = updated_chunks[-1:]  # Keep only the last chunk
            blob_id_field_unique = blob_id_field_unique.persist() # Persist to collapse the dask graph !

# Final chunk updates
for start, end, stored_chunk in updated_chunks:
    blob_id_field_unique[{tracker.timedim: slice(start, end)}] = stored_chunk
blob_id_field_unique = blob_id_field_unique.persist()



Processing splitting and merging in chunk 0 of 8
Processing child_id: 233.0 at time index: 21

  weighted_x /= norm
  weighted_y /= norm
  weighted_z /= norm

Processing child_id: 242.0 at time index: 22
Processing child_id: 249.0 at time index: 23
Processing child_id: 256.0 at time index: 24
Processing child_id: 263.0 at time index: 25
Processing child_id: 272.0 at time index: 26
Processing child_id: 281.0 at time index: 27
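
The overlap re-assessment in the cell above keeps a blob pair only when the overlap area is at least `overlap_threshold` times the area of the smaller blob. A minimal pure-Python sketch of that criterion follows. The function `filter_overlaps`, the IDs, and the areas are invented for illustration, and the tracker itself works on NumPy arrays rather than lists.

```python
def filter_overlaps(overlap_rows, areas, overlap_threshold):
    """Keep (id_a, id_b, overlap_area) rows whose overlap covers enough of the smaller blob."""
    kept = []
    for id_a, id_b, overlap_area in overlap_rows:
        smaller_area = min(areas[id_a], areas[id_b])
        if overlap_area / smaller_area >= overlap_threshold:
            kept.append((id_a, id_b, overlap_area))
    return kept


# Hypothetical blob areas (in cells) and overlaps between consecutive time steps
areas = {1: 100, 2: 40, 3: 400}
overlap_rows = [
    (1, 2, 30),  # 30 / min(100, 40) = 0.75 -> kept
    (1, 3, 20),  # 20 / min(100, 400) = 0.20 -> dropped
]

kept_rows = filter_overlaps(overlap_rows, areas, overlap_threshold=0.5)
print(kept_rows)  # [(1, 2, 30)]
```

Normalising by the smaller blob means a small blob that is mostly covered by a large one still counts as connected, even though the overlap is a tiny fraction of the large blob.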


In [None]:
blob_id_field_unique.isel(time=20).pyic.plot()

In [None]:
blob_id_field_unique.isel(time=21).pyic.plot()

In [None]:
blob_id_field_unique.isel(time=22).pyic.plot()

In [None]:
blob_id_field_unique.isel(time=23).pyic.plot()

In [None]:
blob_id_field_unique.isel(time=24).pyic.plot()

In [None]:
blob_id_field_unique.isel(time=25).pyic.plot()

In [None]:
blob_id_field_unique.isel(time=26).pyic.plot()

In [None]:
# Save Tracked Blobs to `zarr` for more efficient parallel I/O

file_name = Path('/scratch') / getuser()[0] / getuser() / 'mhws' / 'MHWs_tracked_unstruct.zarr'
blobs.to_zarr(file_name, mode='w')