# Identify & Track Marine Heatwaves on _Unstructured Grid_ using `spot_the_blOb`

## Processing Steps:
1. Fill spatial holes in the binary data, using `dask_image.ndmorph` -- up to `R_fill` cells in radius.
2. Fill gaps in time -- permitting up to `T_fill` missing time slices, while keeping the same blob ID.
3. Filter out small objects -- area less than the bottom `area_filter_quartile` of the size distribution of objects.
4. Identify objects in the binary data, using `dask_image.ndmeasure`.
5. Connect objects across time, applying the following criteria for splitting, merging, and persistence:
    - Connected Blobs must overlap by at least fraction `overlap_threshold` of the smaller blob.
    - Merged Blobs retain their original ID, but partition the child blob based on the parent of the _nearest-neighbour_ cell. 
6. Cluster and reduce the final object ID graph using `scipy.sparse.csgraph.connected_components`.
7. Map the tracked objects into ID-time space for convenient analysis.

N.B.: Exploits parallelised `dask` operations with optimised chunking using `flox` for memory efficiency and speed \
N.N.B.: This example using 40 years of _daily_ outputs at 5km resolution on an Unstructured Grid (15 million cells) using 32 cores takes 
- Full Split/Merge Thresholding & Merge Tracking:  ~40 minutes

In [1]:
import xarray as xr
import dask
from getpass import getuser
from pathlib import Path

import spot_the_blOb as blob
import spot_the_blOb.helper as hpc

In [None]:
# Start Dask Cluster
client = hpc.StartLocalCluster(n_workers=32, n_threads=2)

In [3]:
# Load Pre-processed Data (cf. `01_preprocess_extremes.ipynb`)

file_name = Path('/scratch') / getuser()[0] / getuser() / 'mhws' / 'extreme_events_binary_unstruct.zarr'
chunk_size = {'time': 4, 'ncells': -1}
ds = xr.open_zarr(str(file_name), chunks={}).isel(time=slice(0,32)).chunk(chunk_size)

In [4]:
# Tracking Parameters

drop_area_quartile = 0.8  # Remove the smallest 80% of the identified blobs
hole_filling_radius = 32  # Fill small holes with radius < 32 elements, i.e. ~100 km
time_gap_fill = 2         # Allow gaps of 2 days and still continue the blob tracking with the same ID
allow_merging = True      # Allow blobs to split/merge. Keeps track of merge events & unique IDs.
overlap_threshold = 0.5   # Overlap threshold for merging blobs. If overlap < threshold, blobs keep independent IDs.
nn_partitioning = True    # Use new NN method to partition merged children blobs. If False, reverts to old method of Di Sun et al. 2023.

In [None]:
# SpOt & Track the Blobs & Merger Events

tracker = blob.Spotter(ds.extreme_events, ds.mask, R_fill=hole_filling_radius, T_fill = time_gap_fill, area_filter_quartile=drop_area_quartile, 
                       allow_merging=allow_merging, overlap_threshold=overlap_threshold, nn_partitioning=nn_partitioning, 
                       xdim='ncells',                 # Need to tell spot_the_blOb the new Unstructured dimension
                       unstructured_grid=True,        # Use Unstructured Grid
                       neighbours=ds.neighbours,      # Connectivity array for the Unstructured Grid Cells
                       cell_areas=ds.cell_areas,      # Cell areas for each Unstructured Grid Cell
                       debug=0,                       # Choose Debugging Level (max=2)
                       verbosity=3)                   # Choose Verbosity Level (0=None, 1=Basic, 2=Timing)

# blobs = tracker.run(return_merges=False)

# blobs

In [6]:
import xarray as xr
import numpy as np
from dask.distributed import wait
from dask_image.ndmeasure import label
from skimage.measure import regionprops_table
from dask_image.ndmorph import binary_closing as binary_closing_dask
from dask_image.ndmorph import binary_opening as binary_opening_dask
from scipy.ndimage import binary_closing, binary_opening
from scipy.sparse import coo_matrix, csr_matrix, eye
from scipy.sparse.csgraph import connected_components
from dask import persist
from dask import delayed
from dask import compute as dask_compute
import dask.array as dsa
from dask.base import is_dask_collection
from numba import jit, njit, prange
import jax.numpy as jnp
import warnings
import logging

In [None]:
# Compute Area of Initial Binary Data
raw_area = tracker.compute_area(tracker.data_bin)  # This is e.g. the initial Hobday area

# Fill Small Holes & Gaps between Objects
data_bin_filled = tracker.fill_holes(tracker.data_bin).persist()
wait(data_bin_filled)
if tracker.verbosity > 0:    print('Finished Filling Spatial Holes')

# Fill Small Time-Gaps between Objects
data_bin_gap_filled = tracker.fill_time_gaps(data_bin_filled).persist()
wait(data_bin_gap_filled)
if tracker.verbosity > 0:    print('Finished Filling Spatio-temporal Holes.')

# Remove Small Objects
data_bin_filtered, area_threshold, blob_areas, N_blobs_prefiltered = tracker.filter_small_blobs(data_bin_gap_filled)
if tracker.verbosity > 0:    print('Finished Filtering Small Blobs.')

# Clean Up & Persist Preprocessing (This helps avoid block-wise task fusion run_spec issues with dask)
data_bin_filtered = data_bin_filtered.persist()
wait(data_bin_filtered)
del data_bin_filled
del data_bin_gap_filled

# Compute Area of Morphologically-Processed & Filtered Data
processed_area = tracker.compute_area(data_bin_filtered)

In [None]:
data_bin = data_bin_filtered

blob_id_field, _ = tracker.identify_blobs(data_bin, time_connectivity=False)
blob_id_field = blob_id_field.persist()
if tracker.verbosity > 0:    print('Finished Blob Identification.')


if tracker.unstructured_grid:
    # Make the blob_id_field unique across time
    cumsum_ids = (blob_id_field.max(dim=tracker.xdim)).cumsum(tracker.timedim).shift({tracker.timedim: 1}, fill_value=0)
    blob_id_field = xr.where(blob_id_field > 0, blob_id_field + cumsum_ids, 0)
    blob_id_field = blob_id_field.persist()
    wait(blob_id_field)
    if tracker.verbosity > 0:    print('Finished Making Blobs Globally Unique.')

# Calculate Properties of each Blob
blob_props = tracker.calculate_blob_properties(blob_id_field, properties=['area', 'centroid'])
blob_props = blob_props.persist()
wait(blob_props)
if tracker.verbosity > 0:    print('Finished Calculating Blob Properties.')

blob_id_field_unique = blob_id_field

In [9]:
def update_blob_field(blob_id_field_unique, id_lookup, updates):
    """Update the blob field with chunk results using xarray operations.
    
    Parameters
    ----------
    blob_id_field_unique : xarray.DataArray
        The full blob field to update
    id_lookup : Dictionary
        Dictionary mapping temporary IDs to new IDs
    updates : xarray.DataArray
        DataArray of Dictionaries containing updates: 'spatial_indices' for each 'new_label'
    
    Returns
    -------
    xarray.DataArray
        Updated blob field
    """
    
    def apply_updates(data, updates):
        """Apply updates to a single chunk of data."""
        result = data.copy()
        for update in updates:
            
            spatial_indices = update['spatial_indices']
            new_label = id_lookup[update['new_label']]
            result[spatial_indices] = new_label
        
        return result

    result = xr.apply_ufunc(apply_updates,
                            blob_id_field_unique,
                            updates,
                            input_core_dims=[[tracker.xdim], []],
                            output_core_dims=[[tracker.xdim]],
                            dask='parallelized',
                            output_dtypes=[blob_id_field_unique.dtype],
                            vectorize=True).persist()
    
    return result

In [27]:


##################################
### Optimised Helper Functions ###
##################################


@jit(nopython=True, parallel=True, fastmath=True)
def wrapped_euclidian_parallel(mask_values, parent_centroids_values, Nx):
    """
    Optimised function for computing wrapped Euclidean distances.
    
    Parameters:
    -----------
    mask_values : np.ndarray
        2D boolean array where True indicates points to calculate distances for
    parent_centroids_values : np.ndarray
        Array of shape (n_parents, 2) containing (y, x) coordinates of parent centroids
    Nx : int
        Size of the x-dimension for wrapping
        
    Returns:
    --------
    distances : np.ndarray
        Array of shape (n_true_points, n_parents) with minimum distances
    """
    n_parents = len(parent_centroids_values)
    half_Nx = Nx / 2
    
    y_indices, x_indices = np.nonzero(mask_values)
    n_true = len(y_indices)
    
    distances = np.empty((n_true, n_parents), dtype=np.float64)
    
    # Precompute for faster access
    parent_y = parent_centroids_values[:, 0]
    parent_x = parent_centroids_values[:, 1]
    
    # Parallel loop over true positions
    for idx in prange(n_true):
        y, x = y_indices[idx], x_indices[idx]
        
        # Pre-compute y differences for all parents
        dy = y - parent_y
        
        # Pre-compute x differences for all parents
        dx = x - parent_x
        
        # Wrapping correction
        dx = np.where(dx > half_Nx, dx - Nx, dx)
        dx = np.where(dx < -half_Nx, dx + Nx, dx)
        
        distances[idx] = np.sqrt(dy * dy + dx * dx)
    
    return distances



@jit(nopython=True, fastmath=True)
def create_grid_index_arrays(points_y, points_x, grid_size, ny, nx):
    """
    Creates a grid-based spatial index using numpy arrays.
    """
    n_grids_y = (ny + grid_size - 1) // grid_size
    n_grids_x = (nx + grid_size - 1) // grid_size
    max_points_per_cell = len(points_y)
    
    grid_points = np.full((n_grids_y, n_grids_x, max_points_per_cell), -1, dtype=np.int32)
    grid_counts = np.zeros((n_grids_y, n_grids_x), dtype=np.int32)
    
    for idx in range(len(points_y)):
        grid_y = min(points_y[idx] // grid_size, n_grids_y - 1)
        grid_x = min(points_x[idx] // grid_size, n_grids_x - 1)
        count = grid_counts[grid_y, grid_x]
        if count < max_points_per_cell:
            grid_points[grid_y, grid_x, count] = idx
            grid_counts[grid_y, grid_x] += 1
    
    return grid_points, grid_counts

@jit(nopython=True, fastmath=True)
def calculate_wrapped_distance(y1, x1, y2, x2, nx, half_nx):
    """
    Calculate distance with periodic boundary conditions in x dimension.
    """
    dy = y1 - y2
    dx = x1 - x2
    
    if dx > half_nx:
        dx -= nx
    elif dx < -half_nx:
        dx += nx
        
    return np.sqrt(dy * dy + dx * dx)

@jit(nopython=True, parallel=True, fastmath=True)
def get_nearest_parent_labels(child_mask, parent_masks, child_ids, parent_centroids, Nx, max_distance=20):
    """
    Assigns labels based on nearest parent blob points.
    This is quite computationally-intensive, so we utilise many optimisations here...
    """
    
    ny, nx = child_mask.shape
    half_Nx = Nx / 2
    n_parents = len(parent_masks)
    grid_size = max(2, max_distance // 4)
    
    y_indices, x_indices = np.nonzero(child_mask)
    n_child_points = len(y_indices)
    
    min_distances = np.full(n_child_points, np.inf)
    parent_assignments = np.zeros(n_child_points, dtype=np.int32)
    found_close = np.zeros(n_child_points, dtype=np.bool_)
    
    for parent_idx in range(n_parents):
        py, px = np.nonzero(parent_masks[parent_idx])
        
        if len(py) == 0:  # Skip empty parents
            continue
            
        # Create grid index for this parent
        n_grids_y = (ny + grid_size - 1) // grid_size
        n_grids_x = (nx + grid_size - 1) // grid_size
        grid_points, grid_counts = create_grid_index_arrays(py, px, grid_size, ny, nx)
        
        # Process child points in parallel
        for child_idx in prange(n_child_points):
            if found_close[child_idx]:  # Skip if we already found an exact match
                continue
                
            child_y, child_x = y_indices[child_idx], x_indices[child_idx]
            grid_y = min(child_y // grid_size, n_grids_y - 1)
            grid_x = min(child_x // grid_size, n_grids_x - 1)
            
            min_dist_to_parent = np.inf
            
            # Check nearby grid cells
            for dy in range(-1, 2):
                grid_y_check = (grid_y + dy) % n_grids_y
                
                for dx in range(-1, 2):
                    grid_x_check = (grid_x + dx) % n_grids_x
                    
                    # Process points in this grid cell
                    n_points = grid_counts[grid_y_check, grid_x_check]
                    
                    for p_idx in range(n_points):
                        point_idx = grid_points[grid_y_check, grid_x_check, p_idx]
                        if point_idx == -1:
                            break
                        
                        dist = calculate_wrapped_distance(
                            child_y, child_x,
                            py[point_idx], px[point_idx],
                            Nx, half_Nx
                        )
                        
                        if dist > max_distance:
                            continue
                        
                        if dist < min_dist_to_parent:
                            min_dist_to_parent = dist
                            
                        if dist < 1e-6:  # Found exact same point (within numerical precision)
                            min_dist_to_parent = dist
                            found_close[child_idx] = True
                            break
                    
                    if found_close[child_idx]:
                        break
                
                if found_close[child_idx]:
                    break
            
            # Update assignment if this parent is closer
            if min_dist_to_parent < min_distances[child_idx]:
                min_distances[child_idx] = min_dist_to_parent
                parent_assignments[child_idx] = parent_idx
    
    # Handle any unassigned points using centroids
    unassigned = min_distances == np.inf
    if np.any(unassigned):
        for child_idx in np.nonzero(unassigned)[0]:
            child_y, child_x = y_indices[child_idx], x_indices[child_idx]
            min_dist = np.inf
            best_parent = 0
            
            for parent_idx in range(n_parents):
                # Calculate distance to centroid with periodic boundary conditions
                dist = calculate_wrapped_distance(
                    child_y, child_x,
                    parent_centroids[parent_idx, 0],
                    parent_centroids[parent_idx, 1],
                    Nx, half_Nx
                )
                
                if dist < min_dist:
                    min_dist = dist
                    best_parent = parent_idx
                    
            parent_assignments[child_idx] = best_parent
    
    # Convert from parent indices to child_ids
    new_labels = child_ids[parent_assignments]
    
    return new_labels


@jit(nopython=True, fastmath=True)
def get_nearest_parent_labels_unstructured(child_mask, parent_masks, child_ids, parent_centroids, neighbours_int, lat, lon, max_distance=20):
    """
    Optimised version of nearest parent label assignment for unstructured grids.
    Uses numpy arrays throughout to ensure Numba compatibility.
    
    Parameters
    ----------
    child_mask : np.ndarray
        1D boolean array where True indicates points in the child blob
    parent_masks : np.ndarray
        2D boolean array of shape (n_parents, n_points) where True indicates points in each parent blob
    child_ids : np.ndarray
        1D array containing the IDs to assign to each partition of the child blob
    parent_centroids : np.ndarray
        Array of shape (n_parents, 2) containing (lat, lon) coordinates of parent centroids in degrees
    neighbours_int : np.ndarray
        2D array of shape (3, n_points) containing indices of neighboring cells for each point
    lat / lon : np.ndarray
        Latitude/Longitude in degrees
    max_distance : int, optional
        Maximum number of edge hops to search for parent points
    
    Returns
    -------
    new_labels : np.ndarray
        1D array containing the assigned child_ids for each True point in child_mask
    """
    
    # Force contiguous arrays in memory for optimal vectorised performance (from indexing)
    child_mask = np.ascontiguousarray(child_mask)
    parent_masks = np.ascontiguousarray(parent_masks)
    
    n_points = len(child_mask)
    n_parents = len(parent_masks)
    
    # Pre-allocate arrays
    distances = np.full(n_points, np.inf, dtype=np.float32)
    parent_assignments = np.full(n_points, -1, dtype=np.int32)
    visited = np.zeros((n_parents, n_points), dtype=np.bool_)
    
    # Initialise with direct overlaps
    for parent_idx in range(n_parents):
        overlap_mask = parent_masks[parent_idx] & child_mask
        if np.any(overlap_mask):
            visited[parent_idx, overlap_mask] = True
            unclaimed_overlap = distances[overlap_mask] == np.inf
            if np.any(unclaimed_overlap):
                overlap_points = np.where(overlap_mask)[0]
                valid_points = overlap_points[unclaimed_overlap]
                distances[valid_points] = 0
                parent_assignments[valid_points] = parent_idx
    
    # Pre-compute trig values
    lat_rad = np.deg2rad(lat)
    lon_rad = np.deg2rad(lon)
    cos_lat = np.cos(lat_rad)
    
    # Graph traversal for remaining points
    current_distance = 0
    any_unassigned = np.any(child_mask & (parent_assignments == -1))
    
    while current_distance < max_distance and any_unassigned:
        current_distance += 1
        updates_made = False
        
        for parent_idx in range(n_parents):
            # Get current frontier points
            frontier_mask = visited[parent_idx]
            if not np.any(frontier_mask):
                continue
            
            # Process neighbors
            for i in range(3):  # For each neighbor direction
                neighbors = neighbours_int[i, frontier_mask]
                valid_neighbors = neighbors >= 0
                if not np.any(valid_neighbors):
                    continue
                    
                valid_points = neighbors[valid_neighbors]
                unvisited = ~visited[parent_idx, valid_points]
                new_points = valid_points[unvisited]
                
                if len(new_points) > 0:
                    visited[parent_idx, new_points] = True
                    update_mask = distances[new_points] > current_distance
                    if np.any(update_mask):
                        points_to_update = new_points[update_mask]
                        distances[points_to_update] = current_distance
                        parent_assignments[points_to_update] = parent_idx
                        updates_made = True
        
        if not updates_made:
            break
            
        any_unassigned = np.any(child_mask & (parent_assignments == -1))
    
    # Handle remaining unassigned points using great circle distances
    unassigned_mask = child_mask & (parent_assignments == -1)
    if np.any(unassigned_mask):
        parent_lat_rad = np.deg2rad(parent_centroids[:, 0])
        parent_lon_rad = np.deg2rad(parent_centroids[:, 1])
        cos_parent_lat = np.cos(parent_lat_rad)
        
        unassigned_points = np.where(unassigned_mask)[0]
        for point in unassigned_points:
            # Vectorised haversine calculation
            dlat = parent_lat_rad - lat_rad[point]
            dlon = parent_lon_rad - lon_rad[point]
            a = np.sin(dlat/2)**2 + cos_lat[point] * cos_parent_lat * np.sin(dlon/2)**2
            dist = 2 * np.arctan2(np.sqrt(a), np.sqrt(1-a))
            parent_assignments[point] = np.argmin(dist)
    
    # Return only the assignments for points in child_mask
    child_points = np.where(child_mask)[0]
    return child_ids[parent_assignments[child_points]]


@jit(nopython=True, parallel=True, fastmath=True)
def unstructured_centroid_partition(child_mask, parent_centroids, child_ids, lat, lon):
    """
    Assigns labels to child cells based on closest parent centroid using great circle distances.
    
    Parameters:
    -----------
    child_mask : np.ndarray
        1D boolean array indicating which cells belong to the child blob
    parent_centroids : np.ndarray
        Array of shape (n_parents, 2) containing (lat, lon) coordinates of parent centroids in degrees
    child_ids : np.ndarray
        Array of IDs to assign to each partition of the child blob
    lat / lon : np.ndarray
        Latitude/Longitude in degrees
        
    Returns:
    --------
    new_labels : np.ndarray
        1D array containing assigned child_ids for cells in child_mask
    """
    n_cells = len(child_mask)
    n_parents = len(parent_centroids)
    
    lat_rad = np.deg2rad(lat)
    lon_rad = np.deg2rad(lon)
    parent_coords_rad = np.deg2rad(parent_centroids)
    
    new_labels = np.zeros(n_cells, dtype=child_ids.dtype)
    
    # Process each child cell in parallel
    for i in prange(n_cells):
        if not child_mask[i]:
            continue
            
        min_dist = np.inf
        closest_parent = 0
        
        # Calculate great circle distance to each parent centroid
        for j in range(n_parents):
            dlat = parent_coords_rad[j, 0] - lat_rad[i]
            dlon = parent_coords_rad[j, 1] - lon_rad[i]
            
            # Use haversine formula for great circle distance
            a = np.sin(dlat/2)**2 + np.cos(lat_rad[i]) * np.cos(parent_coords_rad[j, 0]) * np.sin(dlon/2)**2
            dist = 2 * np.arctan2(np.sqrt(a), np.sqrt(1-a))
            
            if dist < min_dist:
                min_dist = dist
                closest_parent = j
        
        new_labels[i] = child_ids[closest_parent]
    
    return new_labels





## Helper Function for Super Fast Sparse Bool Multiply (*without the scipy+Dask Memory Leak*)
@njit(fastmath=True, parallel=True)
def sparse_bool_power(vec, sp_data, indices, indptr, exponent):
    vec = vec.T
    num_rows = indptr.size - 1
    num_cols = vec.shape[1]
    result = vec.copy()

    for _ in range(exponent):
        temp_result = np.zeros((num_rows, num_cols), dtype=np.bool_)

        for i in prange(num_rows):
            for j in range(indptr[i], indptr[i + 1]):
                if sp_data[j]:
                    for k in range(num_cols):
                        if result[indices[j], k]:
                            temp_result[i, k] = True

        result = temp_result

    return result.T

In [147]:
def process_chunk(chunk_data, chunk_data_m1, chunk_data_p1, merging_blobs, next_id_start, lat, lon, area, neighbours_int):
    """Process a single chunk of merging blobs.
    
    Parameters
    ----------
    chunk_data : numpy.ndarray
        Array of shape (n_time, ncells) for unstructured or (n_time, ny, nx) for structured
    chunk_data_m1 & chunk_data_p1 : numpy.ndarray
        Same as chunk_data but shifted by 1 in time
    merging_blobs : numpy.ndarray
        Array of shape (n_time, max_merges) containing merging blob IDs (0=none)
    next_id_start : numpy.ndarray
        Array of shape (n_time, max_merges) containing ID offsets
    
    Returns
    -------
    dict
        Dictionary containing updates for each timestep
    """
    n_time = chunk_data.shape[0]
    updates_by_time = []
    merge_events_by_time = []
    new_merging_blobs_just_end = []
    id_mappings_by_time = []
    
    # Pre-Convert lat/lon to Cartesian
    x = np.cos(np.radians(lat)) * np.cos(np.radians(lon))
    y = np.cos(np.radians(lat)) * np.sin(np.radians(lon))
    z = np.sin(np.radians(lat))

    # Process each timestep
    for t in range(n_time):
        # Initialise tracking variables
        updates = []
        merge_events = {
            'child_ids': [],
            'parent_ids': [],
            'areas': []
        }
        new_merging_blobs = set()
        id_mapping = {}
        next_new_id = next_id_start[t]  # Use the offset for this timestep
        
        # Get current time slice data
        if t == 0:
            data_m1 = chunk_data_m1[t]
            data_t = chunk_data[t]
        else:
            data_m1 = data_t
            data_t = data_p1
        data_p1 = chunk_data_p1[t]
        
        # Get non-zero merging blobs for this timestep
        merging_blobs_t = merging_blobs[t]
        merging_blobs_t = merging_blobs_t[merging_blobs_t > 0]
        
        if len(merging_blobs_t) == 0:
            # Store empty results for this timestep
            updates_by_time.append(updates)
            merge_events_by_time.append(merge_events)
            id_mappings_by_time.append(id_mapping)
            new_merging_blobs_just_end.append(set())
            continue
        
        # Process each merging blob at this timestep
        blobs_to_process = list(merging_blobs_t)
        
        
        while blobs_to_process:
            child_id = blobs_to_process.pop(0)
            
            # Get child mask and find overlapping parents
            child_mask_2d = (data_t == child_id)
            
            # Find parent blobs that overlap with this child
            parent_masks = []
            parent_centroids = []
            parent_ids = []
            parent_areas = []
            overlap_areas = []
            
            # Find all unique parent IDs that overlap with the child
            potential_parents = np.unique(data_m1[child_mask_2d])
            for parent_id in potential_parents[potential_parents > 0]:
                parent_mask = (data_m1 == parent_id)
                if np.any(parent_mask & child_mask_2d):
                    
                    # Check if overlap area is large enough
                    area_0 = area[parent_mask].sum()  # Parent area
                    area_1 = area[child_mask_2d].sum()
                    min_area = np.minimum(area_0, area_1)
                    overlap_area = area[parent_mask & child_mask_2d].sum()
                    overlap_fraction = overlap_area / min_area
                    
                    if overlap_fraction < tracker.overlap_threshold:
                        continue
                    
                    overlap_areas.append(overlap_area)
                    parent_masks.append(parent_mask)
                    parent_ids.append(parent_id)
                    
                    # Calculate centroid for this parent
                    mask_area = area[parent_mask]
                    weighted_x = np.sum(mask_area * x[parent_mask])
                    weighted_y = np.sum(mask_area * y[parent_mask])
                    weighted_z = np.sum(mask_area * z[parent_mask])
                    
                    norm = np.sqrt(weighted_x**2 + weighted_y**2 + weighted_z**2)
                    
                    # Convert back to lat/lon
                    centroid_lat = np.degrees(np.arcsin(weighted_z/norm))
                    centroid_lon = np.degrees(np.arctan2(weighted_y, weighted_x))
                    
                    # Fix longitude range to [-180, 180]
                    if centroid_lon > 180:
                        centroid_lon -= 360
                    elif centroid_lon < -180:
                        centroid_lon += 360
                    
                    parent_centroids.append([centroid_lat, centroid_lon])
                    parent_areas.append(area_0) 
            
            if len(parent_ids) < 2:  # Need at least 2 parents for merging
                continue
                
            parent_masks = np.array(parent_masks)
            parent_centroids = np.array(parent_centroids, dtype=np.float32)
            parent_ids = np.array(parent_ids)
            parent_areas = np.array(parent_areas)
            overlap_areas = np.array(overlap_areas)
            
            # Create new IDs for each partition
            new_child_ids = np.arange(next_new_id, next_new_id + (len(parent_ids) - 1), dtype=np.int32)
            child_ids = np.concatenate((np.array([child_id], dtype=np.int32), new_child_ids))
            
            # Update ID tracking
            for new_id in child_ids[1:]:
                id_mapping[new_id] = None
            next_new_id += len(parent_ids) - 1
            
            # Get new labels based on partitioning method
            if tracker.nn_partitioning:
                # Estimate max_area from number of cells
                max_area = parent_areas.max() / tracker.mean_cell_area
                max_distance = int(np.sqrt(max_area) * 2.0)
                
                new_labels = get_nearest_parent_labels_unstructured(
                    child_mask_2d,
                    parent_masks,
                    child_ids,
                    parent_centroids,
                    neighbours_int,
                    lat,
                    lon,
                    max_distance=max(max_distance, 20)*2
                )
            else:
                new_labels = unstructured_centroid_partition(
                    child_mask_2d,
                    parent_centroids,
                    child_ids,
                    lat,
                    lon
                )
            
            # Update slice data
            data_t[child_mask_2d] = new_labels
            spatial_indices_all = np.where(child_mask_2d)[0]
            
            for new_id in child_ids[1:]:
                # Get spatial indices where we need to update
                new_id_mask = (new_labels == new_id)
                spatial_indices = spatial_indices_all[new_id_mask]
                
                # Store the updates
                updates.append({
                    'spatial_indices': spatial_indices,
                    'new_label': new_id
                })
            
            # Record merge event
            merge_events['child_ids'].append(child_ids)
            merge_events['parent_ids'].append(parent_ids)
            merge_events['areas'].append(overlap_areas)
            
            # Find all child blobs in the next timestep that overlap with our newly labeled regions
            new_merging = []
            for new_id in child_ids:
                parent_mask = (data_t == new_id)                        
                potential_children = np.unique(data_p1[parent_mask])
                area_0 = area[parent_mask].sum()
                
                for potential_child in potential_children:
                    # Check if overlap area is large enough
                    potential_child_mask = (data_p1==potential_child)
                    area_1 = area[potential_child_mask].sum()
                    min_area = np.minimum(area_0, area_1)
                    overlap_area = area[parent_mask & potential_child_mask].sum()
                    overlap_fraction = overlap_area / min_area
                    
                    if overlap_fraction > tracker.overlap_threshold:
                        new_merging.append(potential_child)                        
            
            # Add to new merging blobs set
            new_merging_blobs.update(new_merging)
            
            # Add to processing queue if not already processed
            for new_blob_id in new_merging:
                if new_blob_id not in blobs_to_process and new_blob_id not in merging_blobs_t:
                    blobs_to_process.append(new_blob_id)
        

        # Store results for this timestep
        updates_by_time.append(updates)
        merge_events_by_time.append(merge_events)
        id_mappings_by_time.append(id_mapping)
        if t < n_time - 1:
            new_merging_blobs_just_end.append(set())
        else:
            new_merging_blobs_just_end.append(new_merging_blobs)
    
    # Outputs need to be an array with dimension n_time:
    
    results_dict = {
        'merge_events': merge_events_by_time,
        'id_mappings': id_mappings_by_time,
        'next_chunk_merge': new_merging_blobs_just_end
    }

    # Create a list to hold dictionaries for each time step
    time_step_dicts = []

    for t in range(n_time):
        time_step_dict = {key: value[t] for key, value in results_dict.items()}
        time_step_dicts.append(time_step_dict)

    time_step_array = np.array(time_step_dicts, dtype=object)
    updates_by_time_array = np.array(updates_by_time, dtype=object)
    
    return time_step_array, updates_by_time_array

In [148]:
# Compile List of Overlapping Blob ID Pairs Across Time
overlap_blobs_list = tracker.find_overlapping_blobs(blob_id_field_unique, blob_props)  # List blob pairs that overlap by at least overlap_threshold percent
if tracker.verbosity > 0:    print('Finished Finding Overlapping Blobs.')

# Find initial merging blobs
unique_children, children_counts = np.unique(overlap_blobs_list[:, 1], return_counts=True)
merging_blobs = set(unique_children[children_counts > 1])


## Process chunks iteratively until no new merging blobs remain
iteration = 0
max_iterations = 10
processed_chunks = set()
global_id_counter = blob_props.ID.max().item() + 1

# Initialise global merge event tracking
all_merge_events = {
    'times': [],
    'child_ids': [],
    'parent_ids': [],
    'areas': []
}

n_time = len(blob_id_field_unique[tracker.timedim])
time_indices = xr.DataArray(np.arange(n_time),dims=[tracker.timedim],coords={tracker.timedim: blob_id_field_unique[tracker.timedim]})



Finished Finding Overlapping Blobs.


In [149]:
# WHILE merging_blobs:::::

if tracker.verbosity > 0:    print(f"Processing Parallel Iteration {iteration + 1} with {len(merging_blobs)} Merging Blobs...")

# Pre-compute the child_time_idx for merging_blobs
time_index_map = tracker.compute_id_time_dict(blob_id_field_unique, list(merging_blobs), global_id_counter)

# Create the uniform merging blobs array
max_merges = max(len([b for b in merging_blobs if time_index_map.get(b, -1) == t]) for t in range(n_time))

uniform_merging_blobs_array = np.array([
    np.pad([b for b in merging_blobs if time_index_map.get(b, -1) == t], (0, max_merges - len([b for b in merging_blobs if time_index_map.get(b, -1) == t])), 'constant')
    for t in range(n_time)
])
merging_blobs_da = xr.DataArray(
    uniform_merging_blobs_array,
    dims=[tracker.timedim, 'merges'],
    coords={tracker.timedim: blob_id_field_unique[tracker.timedim]})

next_id_offsets = np.arange(n_time) * max_merges
next_id_offsets_da = xr.DataArray(next_id_offsets,
                                dims=[tracker.timedim],
                                coords={tracker.timedim: blob_id_field_unique[tracker.timedim]})

blob_id_field_unique_p1 = blob_id_field_unique.shift({tracker.timedim: -1}, fill_value=0)
blob_id_field_unique_m1 = blob_id_field_unique.shift({tracker.timedim: 1}, fill_value=0)


Processing Parallel Iteration 1 with 6 Merging Blobs...




In [150]:
chunk_size = 4  # Or whatever size is appropriate
blob_id_field_unique = blob_id_field_unique.chunk({tracker.timedim: chunk_size})
blob_id_field_unique_m1 = blob_id_field_unique_m1.chunk({tracker.timedim: chunk_size})
blob_id_field_unique_p1 = blob_id_field_unique_p1.chunk({tracker.timedim: chunk_size})
merging_blobs_da = merging_blobs_da.chunk({tracker.timedim: chunk_size})
next_id_offsets_da = next_id_offsets_da.chunk({tracker.timedim: chunk_size})

In [151]:
results, updates = xr.apply_ufunc(process_chunk,
                                 blob_id_field_unique,
                                 blob_id_field_unique_m1,
                                 blob_id_field_unique_p1,
                                 merging_blobs_da,
                                 next_id_offsets_da,
                                 blob_id_field_unique.lat,
                                 blob_id_field_unique.lon,
                                 tracker.cell_area,
                                 tracker.neighbours_int,
                                 input_core_dims=[[tracker.xdim], [tracker.xdim], [tracker.xdim], [], [], [tracker.xdim], [tracker.xdim], [tracker.xdim], ['nv', tracker.xdim]],
                                 output_core_dims=[[], []],
                                 output_dtypes=[object, object],
                                 vectorize=False,
                                 dask='parallelized')

In [152]:
results, updates = persist(results, updates)

In [170]:
resultsc = results.compute()

ValueError: axes don't match array

In [169]:
results.isel(time=1).isel(merges=0)['id_mappings']

KeyError: 'id_mappings'

In [167]:
temp_id_arrays = [
    np.array(list(res['id_mappings'].keys()), dtype=np.int64) 
    for t in range(results.sizes[results.dims[0]])     # time dimension
    for m in range(results.sizes[results.dims[1]])     # merges dimension
    for res in [results.isel({results.dims[0]: t, results.dims[1]: m})]
]

KeyError: 'id_mappings'

In [166]:
temp_id_arrays

[]

In [15]:
# blobs = blobs.compute() 

In [16]:
# Save Tracked Blobs to `zarr` for more efficient parallel I/O

file_name = Path('/scratch') / getuser()[0] / getuser() / 'mhws' / 'MHWs_tracked_unstruct.zarr'
blobs.to_zarr(file_name, mode='w')