# Identify & Track Marine Heatwaves on _Unstructured Grid_ using `spot_the_blOb`

## Processing Steps:
1. Fill spatial holes in the binary data, using `dask_image.ndmorph` -- up to `R_fill` cells in radius.
2. Fill gaps in time -- permitting up to `T_fill` missing time slices, while keeping the same blob ID.
3. Filter out small objects -- those whose area falls within the bottom `area_filter_quartile` fraction (a quantile, e.g. 0.8 = the smallest 80%) of the object-size distribution.
4. Identify objects in the binary data, using `dask_image.ndmeasure`.
5. Connect objects across time, applying the following criteria for splitting, merging, and persistence:
    - Connected Blobs must overlap by at least fraction `overlap_threshold` of the smaller blob.
    - Merged Blobs retain their original IDs: the child blob is partitioned by assigning each of its cells to the parent blob that owns the _nearest-neighbour_ parent cell.
6. Cluster and reduce the final object ID graph using `scipy.sparse.csgraph.connected_components`.
7. Map the tracked objects into ID-time space for convenient analysis.

N.B.: Exploits parallelised `dask` operations with optimised chunking using `flox` for memory efficiency and speed \
N.N.B.: Running this example on 40 years of _daily_ outputs at 5 km resolution on an Unstructured Grid (15 million cells) with 32 cores takes:
- Full Split/Merge Thresholding & Merge Tracking:  ~40 minutes

In [1]:
import xarray as xr
import dask
from getpass import getuser
from pathlib import Path

import spot_the_blOb as blob
import spot_the_blOb.helper as hpc

In [2]:
# Start Dask Cluster
# Local cluster from the spot_the_blOb helper module: 50 workers, n_threads=1
# (presumably threads per worker, i.e. single-threaded workers -- confirm in hpc.StartLocalCluster).
#  N.B.: Need ~ 8 GB per worker (for 5km data // 15 million points)
client = hpc.StartLocalCluster(n_workers=50, n_threads=1)

Memory per Worker: 10.07 GB
Hostname is  l40235
Forward Port = l40235:8787
Dashboard Link: localhost:8787/status


In [3]:
# Load Pre-processed Data (cf. `01_preprocess_extremes.ipynb`)
# Location: /scratch/<first letter of user name>/<user name>/mhws/extreme_events_binary_unstruct.zarr

file_name = Path('/scratch') / getuser()[0] / getuser() / 'mhws' / 'extreme_events_binary_unstruct.zarr'
# 4 time steps per chunk; 'ncells': -1 puts all grid cells of a time step into a single chunk.
chunk_size = {'time': 4, 'ncells': -1}
# NOTE: .isel(time=slice(0, 128)) keeps only the first 128 time steps (a test subset, not
# the full multi-decade record described above) -- remove it for a full run.
ds = xr.open_zarr(str(file_name), chunks={}).isel(time=slice(0, 128)).chunk(chunk_size)

In [4]:
# Tracking Parameters
# (passed to blob.Spotter below as area_filter_quartile, R_fill, T_fill, allow_merging,
#  overlap_threshold and nn_partitioning respectively)

drop_area_quartile = 0.8  # Remove the smallest 80% of the identified blobs (a quantile fraction, despite the name)
hole_filling_radius = 32  # Fill small holes with radius < 32 elements, i.e. ~100 km
# NOTE(review): at the ~5 km cell spacing quoted above, 32 elements is closer to ~160 km -- confirm the ~100 km figure.
time_gap_fill = 2         # Allow gaps of up to 2 missing time steps (2 days for daily data) and still continue the blob tracking with the same ID
allow_merging = True      # Allow blobs to split/merge. Keeps track of merge events & unique IDs.
overlap_threshold = 0.5   # Overlap threshold for merging blobs. If overlap < threshold, blobs keep independent IDs.
nn_partitioning = True    # Use new NN method to partition merged children blobs. If False, reverts to old method of Di Sun et al. 2023.

In [5]:
# SpOt & Track the Blobs & Merger Events
# Only the Spotter is constructed here; the `tracker.run(...)` call is commented out and the
# individual pipeline steps (fill_holes, fill_time_gaps, filter_small_blobs, identify_blobs, ...)
# are executed manually in the following cells instead.

tracker = blob.Spotter(ds.extreme_events, ds.mask, R_fill=hole_filling_radius, T_fill = time_gap_fill, area_filter_quartile=drop_area_quartile, 
                       allow_merging=allow_merging, overlap_threshold=overlap_threshold, nn_partitioning=nn_partitioning, 
                       xdim='ncells',                 # Need to tell spot_the_blOb the new Unstructured dimension
                       unstructured_grid=True,        # Use Unstructured Grid
                       neighbours=ds.neighbours,      # Connectivity array for the Unstructured Grid Cells
                       cell_areas=ds.cell_areas,      # Cell areas for each Unstructured Grid Cell
                       verbosity=1)                   # Choose Verbosity Level (0=None, 1=Basic, 2=Advanced/Timing)

# blobs = tracker.run(return_merges=False)

# blobs

Finished Constructing the Sparse Dilation Matrix.


In [6]:
# Drop the notebook's handle on the Dataset. The tracker keeps what it needs
# (e.g. tracker.data_bin, used below) -- presumably references to the arrays passed in above.
del ds

In [7]:
import xarray as xr
import numpy as np
from dask.distributed import wait
from dask_image.ndmeasure import label
from skimage.measure import regionprops_table
from dask_image.ndmorph import binary_closing as binary_closing_dask
from dask_image.ndmorph import binary_opening as binary_opening_dask
from scipy.ndimage import binary_closing, binary_opening
from scipy.sparse import coo_matrix, csr_matrix, eye
from scipy.sparse.csgraph import connected_components
from dask import persist
from dask import delayed
from dask import compute as dask_compute
import dask.array as dsa
from dask.base import is_dask_collection
from numba import jit, njit, prange
import jax.numpy as jnp
from collections import defaultdict
import warnings
import logging
import gc

In [8]:
# Morphological pre-processing, run step by step instead of via tracker.run().
# Statement order matters here: intermediate results are persisted and the previous stage
# is deleted immediately afterwards to keep cluster memory bounded.

# Compute Area of Initial Binary Data
raw_area = tracker.compute_area(tracker.data_bin)  # This is e.g. the initial Hobday area

# Fill Small Holes & Gaps between Objects
data_bin_filled = tracker.fill_holes(tracker.data_bin)
# Delete the original binary data to free up memory
# NOTE(review): this removes the `data_bin` attribute from the tracker itself, so any later
# tracker method that reads `self.data_bin` will raise AttributeError.
del tracker.data_bin
if tracker.verbosity > 0:    print('Finished Filling Spatial Holes')

# Fill Small Time-Gaps between Objects
data_bin_filled = tracker.fill_time_gaps(data_bin_filled).persist()
if tracker.verbosity > 0:    print('Finished Filling Spatio-temporal Holes.')

# Remove Small Objects
# Returns the filtered field plus (judging by the names) the area cut-off, the per-blob areas
# and the number of blobs before filtering.
data_bin_filtered, area_threshold, blob_areas, N_blobs_prefiltered = tracker.filter_small_blobs(data_bin_filled)
del data_bin_filled
if tracker.verbosity > 0:    print('Finished Filtering Small Blobs.')

# Clean Up & Persist Preprocessing (This helps avoid block-wise task fusion run_spec issues with dask)
data_bin_filtered = data_bin_filtered.persist()
# Block until the persisted chunks are materialised on the workers.
wait(data_bin_filtered)
        
# Compute Area of Morphologically-Processed & Filtered Data
processed_area = tracker.compute_area(data_bin_filtered)

Finished Filling Spatial Holes
Finished Filling Spatio-temporal Holes.
Finished Filtering Small Blobs.


In [9]:
# Rebind the filtered field under the name used by the identification step;
# `del` only removes the old name (both names refer to the same object).
data_bin = data_bin_filtered
del data_bin_filtered

In [10]:
# Cluster & ID Binary Data at each Time Step
blob_id_field, _ = tracker.identify_blobs(data_bin, time_connectivity=False)
blob_id_field = blob_id_field.persist()
del data_bin
if tracker.verbosity > 0:    print('Finished Blob Identification.')

if tracker.unstructured_grid:
    # Make the blob_id_field unique across time
    cumsum_ids = (blob_id_field.max(dim=tracker.xdim)).cumsum(tracker.timedim).shift({tracker.timedim: 1}, fill_value=0)
    blob_id_field = xr.where(blob_id_field > 0, blob_id_field + cumsum_ids, 0)
    blob_id_field = blob_id_field.persist()
    if tracker.verbosity > 0:    print('Finished Making Blobs Globally Unique.')
    del cumsum_ids

# Calculate Properties of each Blob
blob_props = tracker.calculate_blob_properties(blob_id_field, properties=['area', 'centroid'])
blob_props = blob_props.persist()
wait(blob_props)
if tracker.verbosity > 0:    print('Finished Calculating Blob Properties.')

Finished Blob Identification.
Finished Making Blobs Globally Unique.
Finished Calculating Blob Properties.


In [11]:


##################################
### Optimised Helper Functions ###
##################################


@jit(nopython=True, parallel=True, fastmath=True)
def wrapped_euclidian_parallel(mask_values, parent_centroids_values, Nx):
    """
    Euclidean distance from every True cell of a 2D mask to every parent centroid,
    with periodic wrapping in the x (column) direction.

    Parameters:
    -----------
    mask_values : np.ndarray
        2D boolean array; distances are computed for each True cell.
    parent_centroids_values : np.ndarray
        (n_parents, 2) array of (y, x) centroid coordinates.
    Nx : int
        Period of the x dimension; x separations larger than Nx / 2 are wrapped.

    Returns:
    --------
    distances : np.ndarray
        float64 array of shape (n_true_points, n_parents). Row i holds the distances
        from the i-th True cell (in np.nonzero order) to every parent centroid --
        all of them, not just the minimum.
    """
    rows, cols = np.nonzero(mask_values)
    n_points = len(rows)
    n_centroids = len(parent_centroids_values)
    half_width = Nx / 2

    centroid_rows = parent_centroids_values[:, 0]
    centroid_cols = parent_centroids_values[:, 1]

    out = np.empty((n_points, n_centroids), dtype=np.float64)

    # One parallel task per True cell; each fills its own output row.
    for p in prange(n_points):
        r = rows[p]
        c = cols[p]
        for q in range(n_centroids):
            d_row = r - centroid_rows[q]
            d_col = c - centroid_cols[q]
            # Fold the column separation into (-Nx/2, Nx/2] across the periodic seam.
            if d_col > half_width:
                d_col = d_col - Nx
            if d_col < -half_width:
                d_col = d_col + Nx
            out[p, q] = np.sqrt(d_row * d_row + d_col * d_col)

    return out



@jit(nopython=True, fastmath=True)
def create_grid_index_arrays(points_y, points_x, grid_size, ny, nx):
    """
    Bucket points into a coarse spatial index of ``grid_size`` x ``grid_size`` cells.

    Parameters
    ----------
    points_y, points_x : np.ndarray
        Integer row / column coordinates of the points to index.
    grid_size : int
        Edge length (in grid points) of each bucket.
    ny, nx : int
        Extent of the underlying grid; points beyond the last full bucket are
        clamped into the last bucket.

    Returns
    -------
    grid_points : np.ndarray
        int32 array of shape (n_grids_y, n_grids_x, max_bucket_size). Entry
        [gy, gx, k] is the index (into points_y / points_x) of the k-th point in
        bucket (gy, gx), in ascending index order; unused slots are -1.
    grid_counts : np.ndarray
        int32 array of shape (n_grids_y, n_grids_x) with the number of points per bucket.
    """
    n_grids_y = (ny + grid_size - 1) // grid_size
    n_grids_x = (nx + grid_size - 1) // grid_size
    n_points = len(points_y)

    # Pass 1: count the points in each bucket. The padded third axis is then sized by
    # the fullest bucket rather than by the total number of points, which previously
    # allocated n_grids_y * n_grids_x * n_points entries (mostly -1 padding).
    grid_counts = np.zeros((n_grids_y, n_grids_x), dtype=np.int32)
    for idx in range(n_points):
        grid_y = min(points_y[idx] // grid_size, n_grids_y - 1)
        grid_x = min(points_x[idx] // grid_size, n_grids_x - 1)
        grid_counts[grid_y, grid_x] += 1

    max_points_per_cell = 0
    if n_points > 0:
        max_points_per_cell = grid_counts.max()

    # Pass 2: scatter the point indices into their buckets (ascending order per bucket).
    grid_points = np.full((n_grids_y, n_grids_x, max_points_per_cell), -1, dtype=np.int32)
    fill_level = np.zeros((n_grids_y, n_grids_x), dtype=np.int32)
    for idx in range(n_points):
        grid_y = min(points_y[idx] // grid_size, n_grids_y - 1)
        grid_x = min(points_x[idx] // grid_size, n_grids_x - 1)
        slot = fill_level[grid_y, grid_x]
        grid_points[grid_y, grid_x, slot] = idx
        fill_level[grid_y, grid_x] = slot + 1

    return grid_points, grid_counts

@jit(nopython=True, fastmath=True)
def calculate_wrapped_distance(y1, x1, y2, x2, nx, half_nx):
    """
    Euclidean distance between (y1, x1) and (y2, x2), periodic in x.

    The x separation is shifted by one period ``nx`` when its magnitude exceeds
    ``half_nx``; the y separation is used as-is.
    """
    delta_x = x1 - x2
    if delta_x > half_nx:
        delta_x -= nx
    elif delta_x < -half_nx:
        delta_x += nx

    delta_y = y1 - y2
    return np.sqrt(delta_x * delta_x + delta_y * delta_y)

@jit(nopython=True, parallel=True, fastmath={'nsz', 'arcp', 'contract', 'afn', 'reassoc'})
def partition_nn_grid(child_mask, parent_masks, child_ids, parent_centroids, Nx, max_distance=20):
    """
    Assign each point of a child blob (structured grid) to its nearest parent blob point.
    This is quite computationally-intensive, so a bucket-grid spatial index is used.

    Parameters
    ----------
    child_mask : np.ndarray
        2D boolean array (ny, nx); True marks the points of the child blob.
    parent_masks : np.ndarray
        One 2D boolean mask per parent blob (indexed by parent position).
    child_ids : np.ndarray
        ID to give the child points assigned to each parent (indexed by parent position).
    parent_centroids : np.ndarray
        (n_parents, 2) array of (y, x) parent centroids. Used as a fallback for child
        points that found no parent point in the searched neighbourhood.
    Nx : int
        Period of the x dimension (distances wrap around in x).
    max_distance : int, optional
        Parent points farther than this are ignored; also sets the bucket size
        (``max(2, max_distance // 4)``).

    Returns
    -------
    new_labels : np.ndarray
        One ``child_ids`` value per True point of ``child_mask``, in ``np.nonzero`` order.
    """
    # NB: `fastmath` deliberately omits 'nnan' / 'ninf'. This kernel uses np.inf as an
    # "unassigned" sentinel and compares against it (e.g. `min_distances == np.inf`);
    # full fastmath lets LLVM assume no operand is ever inf and fold those tests away,
    # silently skipping the centroid fallback.
    ny, nx = child_mask.shape
    half_Nx = Nx / 2
    n_parents = len(parent_masks)
    grid_size = max(2, max_distance // 4)

    # Bucket-grid extent depends only on the grid shape, so compute it once.
    n_grids_y = (ny + grid_size - 1) // grid_size
    n_grids_x = (nx + grid_size - 1) // grid_size

    y_indices, x_indices = np.nonzero(child_mask)
    n_child_points = len(y_indices)

    min_distances = np.full(n_child_points, np.inf)
    parent_assignments = np.zeros(n_child_points, dtype=np.int32)
    found_close = np.zeros(n_child_points, dtype=np.bool_)

    for parent_idx in range(n_parents):
        py, px = np.nonzero(parent_masks[parent_idx])

        if len(py) == 0:  # Skip empty parents
            continue

        # Spatial index of this parent's points
        grid_points, grid_counts = create_grid_index_arrays(py, px, grid_size, ny, nx)

        # Each child point is independent -> parallel over child points
        for child_idx in prange(n_child_points):
            if found_close[child_idx]:  # Already exactly on a parent point
                continue

            child_y, child_x = y_indices[child_idx], x_indices[child_idx]
            grid_y = min(child_y // grid_size, n_grids_y - 1)
            grid_x = min(child_x // grid_size, n_grids_x - 1)

            min_dist_to_parent = np.inf

            # NOTE(review): only the 3x3 block of buckets around the child point is scanned,
            # so parent points more than ~grid_size (at most ~2*grid_size) away are never
            # seen and the `dist > max_distance` cut-off rarely triggers with the default
            # grid_size = max_distance // 4. Confirm whether a wider search was intended.
            for dy in range(-1, 2):
                grid_y_check = (grid_y + dy) % n_grids_y

                for dx in range(-1, 2):
                    grid_x_check = (grid_x + dx) % n_grids_x

                    n_points = grid_counts[grid_y_check, grid_x_check]

                    for p_idx in range(n_points):
                        point_idx = grid_points[grid_y_check, grid_x_check, p_idx]
                        if point_idx == -1:
                            break

                        dist = calculate_wrapped_distance(
                            child_y, child_x,
                            py[point_idx], px[point_idx],
                            Nx, half_Nx
                        )

                        if dist > max_distance:
                            continue

                        if dist < min_dist_to_parent:
                            min_dist_to_parent = dist

                        if dist < 1e-6:  # Exact overlap (within precision): cannot do better
                            found_close[child_idx] = True
                            break

                    if found_close[child_idx]:
                        break

                if found_close[child_idx]:
                    break

            # Keep this parent if it is closer than any parent seen so far
            if min_dist_to_parent < min_distances[child_idx]:
                min_distances[child_idx] = min_dist_to_parent
                parent_assignments[child_idx] = parent_idx

    # Fallback: child points with no parent point found -> nearest centroid (periodic in x)
    unassigned = min_distances == np.inf
    if np.any(unassigned):
        for child_idx in np.nonzero(unassigned)[0]:
            child_y, child_x = y_indices[child_idx], x_indices[child_idx]
            min_dist = np.inf
            best_parent = 0

            for parent_idx in range(n_parents):
                dist = calculate_wrapped_distance(
                    child_y, child_x,
                    parent_centroids[parent_idx, 0],
                    parent_centroids[parent_idx, 1],
                    Nx, half_Nx
                )

                if dist < min_dist:
                    min_dist = dist
                    best_parent = parent_idx

            parent_assignments[child_idx] = best_parent

    # Convert from parent positions to the requested IDs
    new_labels = child_ids[parent_assignments]

    return new_labels


@jit(nopython=True, fastmath={'nsz', 'arcp', 'contract', 'afn', 'reassoc'})
def partition_nn_unstructured(child_mask, parent_masks, child_ids, parent_centroids, neighbours_int, lat, lon, max_distance=20):
    """
    Nearest-parent label assignment for unstructured grids, via breadth-first search
    over the cell-neighbour graph.
    Uses numpy arrays throughout to ensure Numba compatibility.

    The search is seeded with the cells where each parent overlaps the child; the first
    parent (lowest index) wins where several parents overlap the same cell or reach a
    cell at the same hop distance. Child cells not reached within ``max_distance`` hops
    fall back to the nearest parent centroid (great-circle distance).

    Parameters
    ----------
    child_mask : np.ndarray
        1D boolean array where True indicates points in the child blob
    parent_masks : np.ndarray
        2D boolean array of shape (n_parents, n_points) where True indicates points in each parent blob
    child_ids : np.ndarray
        1D array containing the IDs to assign to each partition of the child blob
    parent_centroids : np.ndarray
        Array of shape (n_parents, 2) containing (lat, lon) coordinates of parent centroids in degrees
    neighbours_int : np.ndarray
        2D array of shape (3, n_points) containing indices of neighboring cells for each point
        (negative entries mark missing neighbours)
    lat / lon : np.ndarray
        Latitude/Longitude in degrees
    max_distance : int, optional
        Maximum number of edge hops to search for parent points

    Returns
    -------
    new_labels : np.ndarray
        1D array containing the assigned child_ids for each True point in child_mask
    """
    # NB: `fastmath` deliberately omits 'nnan' / 'ninf': np.inf is used as an "unclaimed"
    # sentinel in `distances` and compared against (`== np.inf`); full fastmath lets LLVM
    # assume no operand is inf and fold those comparisons away.

    # Force contiguous arrays in memory for optimal vectorised performance (from indexing)
    child_mask = np.ascontiguousarray(child_mask)
    parent_masks = np.ascontiguousarray(parent_masks)

    n_points = len(child_mask)
    n_parents = len(parent_masks)

    # Pre-allocate arrays
    distances = np.full(n_points, np.inf, dtype=np.float32)
    parent_assignments = np.full(n_points, -1, dtype=np.int32)
    visited = np.zeros((n_parents, n_points), dtype=np.bool_)

    # Initialise with direct overlaps (first parent wins where parents overlap each other)
    for parent_idx in range(n_parents):
        overlap_mask = parent_masks[parent_idx] & child_mask
        if np.any(overlap_mask):
            visited[parent_idx, overlap_mask] = True
            unclaimed_overlap = distances[overlap_mask] == np.inf
            if np.any(unclaimed_overlap):
                overlap_points = np.where(overlap_mask)[0]
                valid_points = overlap_points[unclaimed_overlap]
                distances[valid_points] = 0
                parent_assignments[valid_points] = parent_idx

    # Pre-compute trig values
    lat_rad = np.deg2rad(lat)
    lon_rad = np.deg2rad(lon)
    cos_lat = np.cos(lat_rad)

    # Graph traversal for remaining points: one hop per outer iteration
    current_distance = 0
    any_unassigned = np.any(child_mask & (parent_assignments == -1))

    while current_distance < max_distance and any_unassigned:
        current_distance += 1
        updates_made = False

        for parent_idx in range(n_parents):
            # Snapshot this parent's visited set BEFORE expanding it. `visited[parent_idx]`
            # is a view: without the copy, points added for neighbour direction 0 would be
            # expanded again for directions 1 and 2 within the same step, labelling 2- and
            # 3-hop cells with the current hop distance.
            frontier_mask = visited[parent_idx].copy()
            if not np.any(frontier_mask):
                continue

            # Process neighbours
            for i in range(3):  # For each neighbour direction
                neighbors = neighbours_int[i, frontier_mask]
                valid_neighbors = neighbors >= 0
                if not np.any(valid_neighbors):
                    continue

                valid_points = neighbors[valid_neighbors]
                unvisited = ~visited[parent_idx, valid_points]
                new_points = valid_points[unvisited]

                if len(new_points) > 0:
                    visited[parent_idx, new_points] = True
                    # Claim only cells not already reached at a shorter (or equal, by an
                    # earlier parent) hop distance
                    update_mask = distances[new_points] > current_distance
                    if np.any(update_mask):
                        points_to_update = new_points[update_mask]
                        distances[points_to_update] = current_distance
                        parent_assignments[points_to_update] = parent_idx
                        updates_made = True

        if not updates_made:
            break

        any_unassigned = np.any(child_mask & (parent_assignments == -1))

    # Handle remaining unassigned points using great circle distances
    unassigned_mask = child_mask & (parent_assignments == -1)
    if np.any(unassigned_mask):
        parent_lat_rad = np.deg2rad(parent_centroids[:, 0])
        parent_lon_rad = np.deg2rad(parent_centroids[:, 1])
        cos_parent_lat = np.cos(parent_lat_rad)

        unassigned_points = np.where(unassigned_mask)[0]
        for point in unassigned_points:
            # Vectorised haversine calculation
            dlat = parent_lat_rad - lat_rad[point]
            dlon = parent_lon_rad - lon_rad[point]
            a_raw = np.sin(dlat/2)**2 + cos_lat[point] * cos_parent_lat * np.sin(dlon/2)**2
            # Clamp rounding overshoot (a slightly > 1 near antipodes) so sqrt(1 - a) is not NaN
            a = np.minimum(a_raw, 1.0)
            dist = 2 * np.arctan2(np.sqrt(a), np.sqrt(1-a))
            parent_assignments[point] = np.argmin(dist)

    # Return only the assignments for points in child_mask
    child_points = np.where(child_mask)[0]
    return child_ids[parent_assignments[child_points]]



@jit(nopython=True, fastmath=True)
def partition_nn_unstructured_optimised(child_mask, parent_frontiers, parent_centroids, neighbours_int, lat, lon, max_distance=20):
    """
    Memory-optimised nearest-parent label assignment for unstructured grids.
    Uses numpy arrays throughout to ensure Numba compatibility.

    Labels are grown outwards from the already-labelled cells of ``parent_frontiers``
    one graph hop per iteration (lower parent index wins ties). Child cells still
    unlabelled after ``max_distance`` hops, or once a hop reaches no new child cell,
    fall back to the nearest parent centroid (great-circle distance).

    Parameters
    ----------
    child_mask : np.ndarray
        1D boolean array where True indicates points in the child blob
    parent_frontiers : np.ndarray
        1D uint8 array with parent indices (255 for unvisited points). Not modified.
    parent_centroids : np.ndarray
        Array of shape (n_parents, 2) containing (lat, lon) coordinates in degrees
    neighbours_int : np.ndarray
        2D array of shape (3, n_points) containing indices of neighboring cells
        (negative entries mark missing neighbours)
    lat / lon : np.ndarray
        1D arrays of latitude/longitude in degrees
    max_distance : int
        Maximum number of edge hops to search for parent points

    Returns
    -------
    result : np.ndarray
        1D uint8 array containing the assigned parent indices for points in child_mask
    """
    # Work on a copy so the caller's label array is not modified in place.
    parent_frontiers_working = parent_frontiers.copy()

    # Number of parents = highest label present + 1 (guarded: np.max of an empty
    # selection would raise when no cell is labelled yet).
    labelled = parent_frontiers_working < 255
    if np.any(labelled):
        n_parents = np.int64(np.max(parent_frontiers_working[labelled])) + 1
    else:
        n_parents = 0

    # Graph traversal: one hop per outer iteration
    current_distance = 0
    any_unassigned = np.any(child_mask & (parent_frontiers_working == 255))

    while current_distance < max_distance and any_unassigned:
        current_distance += 1
        updates_made = False

        for parent_idx in range(n_parents):
            # Snapshot this parent's cells ONCE per hop. Re-evaluating the mask inside the
            # direction loop let cells added for direction 0 be expanded again for
            # directions 1 and 2, so lower-index parents advanced up to 3 hops per step
            # and captured more of the child than the nearest-parent rule allows.
            frontier = parent_frontiers_working == parent_idx
            if not np.any(frontier):
                continue

            # Process neighbours for current parent's frontier
            for i in range(3):
                neighbors = neighbours_int[i, frontier]
                valid_neighbors = neighbors >= 0

                if not np.any(valid_neighbors):
                    continue

                valid_points = neighbors[valid_neighbors]
                unvisited = parent_frontiers_working[valid_points] == 255

                if not np.any(unvisited):
                    continue

                # Claim the newly reached cells for this parent
                new_points = valid_points[unvisited]
                parent_frontiers_working[new_points] = parent_idx

                if np.any(child_mask[new_points]):
                    updates_made = True

        if not updates_made:
            break

        any_unassigned = np.any(child_mask & (parent_frontiers_working == 255))

    # Handle remaining unassigned points using great circle distances
    unassigned_mask = child_mask & (parent_frontiers_working == 255)
    if np.any(unassigned_mask):
        # Pre-compute parent coordinates in radians
        parent_lat_rad = np.deg2rad(parent_centroids[:, 0])
        parent_lon_rad = np.deg2rad(parent_centroids[:, 1])
        cos_parent_lat = np.cos(parent_lat_rad)

        for point in np.where(unassigned_mask)[0]:
            point_lat_rad = np.deg2rad(lat[point])
            point_lon_rad = np.deg2rad(lon[point])
            dlat = parent_lat_rad - point_lat_rad
            dlon = parent_lon_rad - point_lon_rad

            a_raw = np.sin(dlat/2)**2 + np.cos(point_lat_rad) * cos_parent_lat * np.sin(dlon/2)**2
            # Clamp rounding overshoot (a slightly > 1 near antipodes) so sqrt(1 - a) stays real
            a = np.minimum(a_raw, 1.0)
            dist = 2 * np.arctan2(np.sqrt(a), np.sqrt(1-a))

            parent_frontiers_working[point] = np.argmin(dist)

    # Boolean indexing already returns a new array holding only the child points
    return parent_frontiers_working[child_mask]



@jit(nopython=True, parallel=True, fastmath={'nsz', 'arcp', 'contract', 'afn', 'reassoc'})
def partition_centroid_unstructured(child_mask, parent_centroids, child_ids, lat, lon):
    """
    Assigns labels to child cells based on closest parent centroid using great circle distances.

    Parameters:
    -----------
    child_mask : np.ndarray
        1D boolean array indicating which cells belong to the child blob
    parent_centroids : np.ndarray
        Array of shape (n_parents, 2) containing (lat, lon) coordinates of parent centroids in degrees
    child_ids : np.ndarray
        Array of IDs to assign to each partition of the child blob (indexed by parent position)
    lat / lon : np.ndarray
        Latitude/Longitude in degrees

    Returns:
    --------
    new_labels : np.ndarray
        1D array of length len(child_mask) and dtype of child_ids: the child_ids value of
        the closest parent centroid at cells where child_mask is True, and 0 elsewhere.
    """
    # NB: `fastmath` deliberately omits 'nnan' / 'ninf': the running minimum starts at
    # np.inf, and full fastmath lets LLVM assume no operand is inf, making `dist < min_dist`
    # undefined on the first comparison.
    n_cells = len(child_mask)
    n_parents = len(parent_centroids)

    lat_rad = np.deg2rad(lat)
    lon_rad = np.deg2rad(lon)
    parent_coords_rad = np.deg2rad(parent_centroids)
    # Parent-latitude cosines are the same for every cell -> compute once
    cos_parent_lat = np.cos(parent_coords_rad[:, 0])

    new_labels = np.zeros(n_cells, dtype=child_ids.dtype)

    # Process each child cell in parallel
    for i in prange(n_cells):
        if not child_mask[i]:
            continue

        cos_lat_i = np.cos(lat_rad[i])
        min_dist = np.inf
        closest_parent = 0

        # Calculate great circle distance to each parent centroid (haversine formula)
        for j in range(n_parents):
            dlat = parent_coords_rad[j, 0] - lat_rad[i]
            dlon = parent_coords_rad[j, 1] - lon_rad[i]

            hav = np.sin(dlat/2)**2 + cos_lat_i * cos_parent_lat[j] * np.sin(dlon/2)**2
            # Clamp rounding overshoot (hav slightly > 1 near antipodes) so sqrt(1 - a) is not NaN
            a = min(hav, 1.0)
            dist = 2 * np.arctan2(np.sqrt(a), np.sqrt(1-a))

            if dist < min_dist:
                min_dist = dist
                closest_parent = j

        new_labels[i] = child_ids[closest_parent]

    return new_labels





## Helper Function for Super Fast Sparse Bool Multiply (*without the scipy+Dask Memory Leak*)
@njit(fastmath=True, parallel=True)
def sparse_bool_power(vec, sp_data, indices, indptr, exponent):
    """
    Apply a boolean CSR matrix ``exponent`` times to a batch of boolean vectors.

    One application computes ``out[r, c] = OR_j (A[r, j] AND state[j, c])`` where the
    matrix ``A`` is given in CSR form by ``sp_data`` (stored values), ``indices``
    (column index of each stored value) and ``indptr`` (row start offsets).

    Parameters
    ----------
    vec : np.ndarray
        2D boolean array of shape (n_vectors, n_nodes); it is transposed internally so
        that rows of the working state are indexed by the matrix column indices.
    sp_data, indices, indptr : np.ndarray
        CSR arrays of a square boolean matrix.
    exponent : int
        Number of times the matrix is applied (0 returns ``vec`` unchanged).

    Returns
    -------
    np.ndarray
        Boolean result in the same (n_vectors, n_nodes) orientation as ``vec``.
    """
    state = vec.T.copy()
    n_rows = indptr.size - 1
    n_vecs = state.shape[1]

    step = 0
    while step < exponent:
        nxt = np.zeros((n_rows, n_vecs), dtype=np.bool_)

        # Rows are independent -> parallel over matrix rows
        for row in prange(n_rows):
            lo = indptr[row]
            hi = indptr[row + 1]
            for nz in range(lo, hi):
                if not sp_data[nz]:
                    continue
                col = indices[nz]
                for v in range(n_vecs):
                    if state[col, v]:
                        nxt[row, v] = True

        state = nxt
        step += 1

    return state.T

In [17]:
# Save blob_id_field as zarr: a temporary checkpoint under the same per-user scratch tree
# (/scratch/<first letter of user>/<user>/mhws) that the input data is loaded from,
# instead of a path hard-coded to one user.
# NOTE(review): the recorded run of this cell failed inside dask's store_chunk with
# "AttributeError: 'tuple' object has no attribute 'size'", i.e. at least one chunk of the
# persisted array was a tuple rather than an ndarray. Presumably this originates in how
# `tracker.identify_blobs` / the unique-ID step build their dask graph -- TODO investigate.
temp_dir = Path('/scratch') / getuser()[0] / getuser() / 'mhws' / 'TEMP'
blob_id_field.to_zarr(str(temp_dir / 'blob_id_field.zarr'), mode='w')

2025-02-26 13:39:29,237 - distributed.worker - ERROR - Compute Failed
Key:       ('store-map-60013b7b198538c59b20948822264310', 5, 0)
State:     executing
Task:  <Task ('store-map-60013b7b198538c59b20948822264310', 5, 0) store_chunk(...)>
Exception: 'AttributeError("\'tuple\' object has no attribute \'size\'")'
Traceback: '  File "/home/b/b382615/opt/anaconda3/lib/python3.10/site-packages/dask/array/core.py", line 4629, in store_chunk\n    return load_store_chunk(x, out, index, lock, return_stored, False)\n  File "/home/b/b382615/opt/anaconda3/lib/python3.10/site-packages/dask/array/core.py", line 4609, in load_store_chunk\n    if x is not None and x.size != 0:\n'

2025-02-26 13:39:29,237 - distributed.worker - ERROR - Compute Failed
Key:       ('store-map-60013b7b198538c59b20948822264310', 7, 0)
State:     executing
Task:  <Task ('store-map-60013b7b198538c59b20948822264310', 7, 0) store_chunk(...)>
Exception: 'AttributeError("\'tuple\' object has no attribute \'size\'")'
Traceback: ' 

AttributeError: 'tuple' object has no attribute 'size'

In [None]:
# Load from zarr: reload the checkpoint written by `blob_id_field.to_zarr(...)`.
# `DataArray.to_zarr` stores a single data variable, so `open_dataarray` returns a DataArray
# again; `xr.open_zarr` would return a Dataset, unlike the DataArray used in the cells above.
# The path is derived from the current user, like the input data path.
blob_id_field = xr.open_dataarray(
    str(Path('/scratch') / getuser()[0] / getuser() / 'mhws' / 'TEMP' / 'blob_id_field.zarr'),
    engine='zarr',
    chunks={},
)

In [None]:
# blob_id_field_unique = blob_id_field.copy()

In [None]:
# Fixed capacities used by `process_chunk` to pre-allocate its merge-event arrays,
# which have shape (n_time, MAX_MERGES, MAX_PARENTS).
MAX_MERGES = 20 # per timestep
MAX_PARENTS = 10 # per merge
# Children per merge share the parent capacity.
MAX_CHILDREN = MAX_PARENTS
        
def process_chunk(chunk_data_m1_full, chunk_data_p1_full, merging_blobs, next_id_start, lat, lon, area, neighbours_int):
    """Process a single chunk of merging blobs."""
    
    ## Fix Broadcasted dimensions of inputs
    chunk_data_m1 = chunk_data_m1_full.squeeze()[0].astype(np.int32).copy()
    chunk_data    = chunk_data_m1_full.squeeze()[1].astype(np.int32).copy()
    del chunk_data_m1_full # Immediately release t+1 data !
    chunk_data_p1 = chunk_data_p1_full.squeeze().astype(np.int32).copy()
    del chunk_data_p1_full
    
    lat = lat.squeeze().astype(np.float32)
    lon = lon.squeeze().astype(np.float32)
    area = area.squeeze().astype(np.float32)
    next_id_start = next_id_start.squeeze()
    
    # Handle neighbours_int with correct dimensions (nv, ncells)
    neighbours_int = neighbours_int.squeeze()
    if neighbours_int.shape[1] != lat.shape[0]:
        neighbours_int = neighbours_int.T
    
    # Handle multiple merging blobs:
    merging_blobs = merging_blobs.squeeze()
    if merging_blobs.ndim == 1:
        # Add additional (last) dimension for max_merges
        merging_blobs = merging_blobs[:, None]
    
    # Pre-Convert lat/lon to Cartesian
    x = (np.cos(np.radians(lat)) * np.cos(np.radians(lon))).astype(np.float32)
    y = (np.cos(np.radians(lat)) * np.sin(np.radians(lon))).astype(np.float32)
    z = np.sin(np.radians(lat)).astype(np.float32)
    
    # Pre-allocate arrays for merge events
    n_time = chunk_data_p1.shape[0]
    n_points = chunk_data_p1.shape[1]
    
    merge_child_ids = np.full((n_time, MAX_MERGES, MAX_PARENTS), -1, dtype=np.int32)
    merge_parent_ids = np.full((n_time, MAX_MERGES, MAX_PARENTS), -1, dtype=np.int32)
    merge_areas = np.full((n_time, MAX_MERGES, MAX_PARENTS), -1, dtype=np.float32)
    merge_counts = np.zeros(n_time, dtype=np.int16)  # Track number of merges per timestep

    updates_array = np.full((n_time, n_points), 255, dtype=np.uint8)
    updates_ids   = np.full((n_time, 255), -1, dtype=np.int32)
    has_merge = np.zeros(n_time, dtype=np.bool_)
    
    # Process each timestep
    merging_blobs_list = [list(merging_blobs[i][merging_blobs[i]>0]) for i in range(merging_blobs.shape[0])]
    final_merging_blobs = np.full((n_time, MAX_MERGES), -1, dtype=np.int32)
    final_merge_count = 0
    
    for t in range(n_time):
        
        next_new_id = next_id_start[t]  # Use the offset for this timestep
        
        # Get current time slice data
        if t == 0:
            data_m1 = chunk_data_m1
            data_t = chunk_data
            del chunk_data_m1, chunk_data  # Clean up original references
        else:
            data_m1 = data_t
            data_t = data_p1
        data_p1 = chunk_data_p1[t]
        
        # Process each merging blob at this timestep
        while merging_blobs_list[t]:
            child_id = merging_blobs_list[t].pop(0)
            
            # Get child mask and find overlapping parents
            child_mask = (data_t == child_id)
            
            # Find parent blobs that overlap with this child
            potential_parents = np.unique(data_m1[child_mask])
            parent_iterator = 0
            parent_masks_uint = np.full(n_points, 255, dtype=np.uint8)
            parent_centroids = np.full((MAX_PARENTS, 2), -1.e10, dtype=np.float32)
            parent_ids = np.full(MAX_PARENTS, -1, dtype=np.int32)
            parent_areas = np.zeros(MAX_PARENTS, dtype=np.float32)
            overlap_areas = np.zeros(MAX_PARENTS, dtype=np.float32)
            n_parents = 0
            
            # Find all unique parent IDs that overlap with the child
            for parent_id in potential_parents[potential_parents > 0]:
                if n_parents >= MAX_PARENTS:
                    raise RuntimeError(f"Reached maximum number of parents ({MAX_PARENTS}) for child {child_id} at timestep {t}")
                    
                parent_mask = (data_m1 == parent_id)
                if np.any(parent_mask & child_mask):
                    
                    # Check if overlap area is large enough
                    area_0 = area[parent_mask].sum()  # Parent area
                    area_1 = area[child_mask].sum()
                    min_area = np.minimum(area_0, area_1)
                    overlap_area = area[parent_mask & child_mask].sum()
                    
                    if overlap_area / min_area < tracker.overlap_threshold:
                        continue
                    
                    parent_masks_uint[parent_mask] = parent_iterator
                    parent_ids[n_parents] = parent_id
                    overlap_areas[n_parents] = overlap_area
                    
                    # Calculate centroid for this parent
                    mask_area = area[parent_mask]
                    weighted_coords = np.array([
                        np.sum(mask_area * x[parent_mask]),
                        np.sum(mask_area * y[parent_mask]),
                        np.sum(mask_area * z[parent_mask])
                    ], dtype=np.float32)
                    
                    norm = np.sqrt(np.sum(weighted_coords * weighted_coords))
                                
                    # Convert back to lat/lon
                    parent_centroids[n_parents, 0] = np.degrees(np.arcsin(weighted_coords[2]/norm))
                    parent_centroids[n_parents, 1] = np.degrees(np.arctan2(weighted_coords[1], weighted_coords[0]))
                    
                    # Fix longitude range to [-180, 180]
                    if parent_centroids[n_parents, 1] > 180:
                        parent_centroids[n_parents, 1] -= 360
                    elif parent_centroids[n_parents, 1] < -180:
                        parent_centroids[n_parents, 1] += 360
                    
                    parent_areas[n_parents] = area_0
                    parent_iterator += 1
                    n_parents += 1
            
            if n_parents < 2:  # Need at least 2 parents for merging
                continue
            
            # Create new IDs for each partition
            new_child_ids = np.arange(next_new_id, next_new_id + (n_parents - 1), dtype=np.int32)
            child_ids = np.concatenate((np.array([child_id]), new_child_ids))
            
            # Record merge event
            curr_merge_idx = merge_counts[t]
            if curr_merge_idx > MAX_MERGES:
                raise RuntimeError(f"Reached maximum number of merges ({MAX_MERGES}) at timestep {t}")
            
            merge_child_ids[t, curr_merge_idx, :n_parents] = child_ids[:n_parents]
            merge_parent_ids[t, curr_merge_idx, :n_parents] = parent_ids[:n_parents]
            merge_areas[t, curr_merge_idx, :n_parents] = overlap_areas[:n_parents]
            merge_counts[t] += 1
            has_merge[t] = True
            
            # Get new labels based on partitioning method
            if tracker.nn_partitioning:
                # Estimate max_area from number of cells
                max_area = parent_areas.max() / tracker.mean_cell_area
                max_distance = int(np.sqrt(max_area) * 2.0)
                
                new_labels_uint = partition_nn_unstructured_optimised(
                    child_mask.copy(),
                    parent_masks_uint.copy(),
                    parent_centroids,
                    neighbours_int.copy(),
                    lat,
                    lon,
                    max_distance=max(max_distance, 20)*2
                )
                # Returned 'new_labels_uint' is just the index of the child_ids
                new_labels = child_ids[new_labels_uint]
                
                # Force Number JIT Cleanup
                new_labels_uint = None
                
            else:
                new_labels = partition_centroid_unstructured(
                    child_mask,
                    parent_centroids,
                    child_ids,
                    lat,
                    lon
                )
            
            # Update slice data for subsequent merging in process_chunk
            data_t[child_mask] = new_labels
            
            # Record Updates
            spatial_indices_all = np.where(child_mask)[0]
            child_mask = None
            gc.collect()
            
            for new_id in child_ids[1:]:
                update_idx = np.where(updates_ids[t] == -1)[0][0]  # Find next non-negative index in updates_ids
                updates_ids[t, update_idx] = new_id
                updates_array[t, spatial_indices_all[new_labels == new_id]] = update_idx
            
            next_new_id += n_parents - 1
            
            
            # Find all child blobs in the next timestep that overlap with our newly labeled regions
            new_merging_list = []
            for new_id in child_ids:
                parent_mask = (data_t == new_id)
                if np.any(parent_mask):
                    area_0 = area[parent_mask].sum()
                    potential_children = np.unique(data_p1[parent_mask])
                    
                    for potential_child in potential_children[potential_children > 0]:
                        potential_child_mask = (data_p1 == potential_child)
                        area_1 = area[potential_child_mask].sum()
                        min_area = min(area_0, area_1)
                        overlap_area = area[parent_mask & potential_child_mask].sum()
                        
                        if overlap_area / min_area > tracker.overlap_threshold:
                            new_merging_list.append(potential_child)
            
            # Add to processing queue if not already processed
            if t < n_time - 1:
                for new_blob_id in new_merging_list:
                    if new_blob_id not in merging_blobs_list[t+1]:
                        merging_blobs_list[t+1].append(new_blob_id)
            else:
                for new_blob_id in new_merging_list:
                    if final_merge_count > MAX_MERGES:
                        raise RuntimeError(f"Reached maximum number of merges ({MAX_MERGES}) at timestep {t}")
                    
                    # if new_blob_id is not in final_merging_blobs[t]
                    if not np.any(final_merging_blobs[t][:final_merge_count] == new_blob_id):
                        final_merging_blobs[t][final_merge_count] = new_blob_id
                        final_merge_count += 1

    # Explicitly clean up large arrays before returning
    del x, y, z, merging_blobs_list
    
    result = (merge_child_ids, merge_parent_ids, merge_areas, merge_counts, 
                has_merge, updates_array, updates_ids, 
                final_merging_blobs)
    
    return result


def update_blob_field_inplace(blob_id_field, id_lookup, updates_array, updates_ids, has_merge):
    """Apply the per-timestep relabelling produced by ``process_chunk``.

    For each timestep, ``updates_ids[slot]`` holds a temporary blob ID (-1 for
    an unused slot) and ``updates_array[cell]`` holds the slot index written
    for that cell (255 where the cell is not updated). Every updated cell
    receives the final ID ``id_lookup[temporary_id]``; all other cells keep
    their current value.

    Despite the name, nothing is modified in place: a new, persisted object is
    returned.

    Parameters
    ----------
    blob_id_field : xarray object
        Current blob IDs with core dimension ``tracker.xdim``.
    id_lookup : dict
        Maps temporary ID to final ID.
    updates_array : xarray object
        uint8 slot codes with core dimension ``tracker.xdim``.
    updates_ids : xarray object
        Temporary IDs per slot, with core dimension ``'update_idx'``.
    has_merge : unused
        Kept for interface compatibility.

    Returns
    -------
    The relabelled field (declared as int32 to dask), persisted.

    Raises
    ------
    KeyError
        Raised at compute time if an update references a temporary ID that
        has no entry in ``id_lookup``.
    """

    def update_timeslice(data, updates, update_ids, lookup_values, lookup_offset):
        """Relabel a single time slice (``np.vectorize`` calls this once per timestep)."""
        valid_ids = update_ids[update_ids > -1]
        if valid_ids.size == 0:
            return data

        # Translate temporary IDs into final IDs through the offset lookup table
        positions = valid_ids.astype(np.int64) - lookup_offset
        in_range = (positions >= 0) & (positions < lookup_values.size)
        if not in_range.all():
            raise KeyError(f"Temporary blob IDs missing from id_lookup: {valid_ids[~in_range].tolist()}")
        final_ids = lookup_values[positions]
        if (final_ids < 0).any():
            raise KeyError(f"Temporary blob IDs missing from id_lookup: {valid_ids[final_ids < 0].tolist()}")

        # Slot k refers to valid_ids[k]. Code 255 ("no update") and any code
        # beyond the used slots fall outside [0, len(valid_ids)) and are left untouched.
        hit = updates < valid_ids.size
        result = data.copy()
        result[hit] = final_ids[updates[hit]]
        return result

    # Dense lookup table covering only [min temp ID, max temp ID]. Temporary IDs
    # start far above 0, so indexing by the raw ID would ship a much larger
    # array with every task.
    if id_lookup:
        lookup_offset = int(min(id_lookup))
        lookup_array = np.full(int(max(id_lookup)) - lookup_offset + 1, -1, dtype=np.int32)
        for temp_id, new_id in id_lookup.items():
            lookup_array[int(temp_id) - lookup_offset] = new_id
    else:
        lookup_offset = 0
        lookup_array = np.full(1, -1, dtype=np.int32)

    result = xr.apply_ufunc(
        update_timeslice,
        blob_id_field,
        updates_array,
        updates_ids,
        kwargs={'lookup_values': lookup_array, 'lookup_offset': lookup_offset},
        input_core_dims=[[tracker.xdim],
                        [tracker.xdim],
                        ['update_idx']],
        output_core_dims=[[tracker.xdim]],
        vectorize=True,
        dask='parallelized',
        output_dtypes=[np.int32]
    )

    # Clear references before persisting
    del lookup_array
    gc.collect()

    result = result.persist(optimize_graph=True)

    return result


def merge_blobs_parallel_iteration(blob_id_field_unique, merging_blobs, global_id_counter, iteration=None):
    """Perform a single iteration of the parallel merging process.

    Every time chunk is processed independently by ``process_chunk`` through
    ``xr.apply_ufunc``. Temporary IDs created in the chunks are then
    compacted into final IDs, the field is relabelled, and merges that reached
    the end of a chunk are returned for the next iteration.

    Parameters
    ----------
    blob_id_field_unique : xarray object
        Blob-ID field with dimensions ``tracker.timedim`` and ``tracker.xdim``.
    merging_blobs : iterable of int
        Child blob IDs that must be split among their parents.
    global_id_counter : int
        Next unused final blob ID.
    iteration : int, optional
        Iteration number used only in the progress message. When omitted, the
        notebook-level global ``iteration`` is used if it exists (the previous
        behaviour), otherwise 0.

    Returns
    -------
    blob_id_field_unique
        The relabelled field, rechunked along time.
    merge_data : tuple
        ``(child_ids_iter, parent_ids_iter, merge_areas_iter, merge_counts_iter)``.
        numpy arrays indexed ``[t, merge, slot]`` (counts: ``[t]``) holding
        final IDs, padded with -1.
    new_merging_blobs : set
        Blob IDs to process in the next iteration.
    global_id_counter : int
        Advanced past all newly assigned IDs.
    """
    if iteration is None:
        iteration = globals().get('iteration', 0)

    n_time = len(blob_id_field_unique[tracker.timedim])

    child_ids_iter = np.full((n_time, MAX_MERGES, MAX_CHILDREN), -1, dtype=np.int32)
    parent_ids_iter = np.full((n_time, MAX_MERGES, MAX_PARENTS), -1, dtype=np.int32)
    merge_areas_iter = np.full((n_time, MAX_MERGES, MAX_PARENTS), -1, dtype=np.float32)
    merge_counts_iter = np.zeros(n_time, dtype=np.int32)

    neighbours_int = tracker.neighbours_int.chunk({tracker.xdim: -1, 'nv':-1})

    if tracker.verbosity > 0:    print(f"Processing Parallel Iteration {iteration + 1} with {len(merging_blobs)} Merging Blobs...")

    # Pre-compute the child_time_idx for merging_blobs
    time_index_map = tracker.compute_id_time_dict(blob_id_field_unique, list(merging_blobs), global_id_counter)
    if tracker.verbosity > 1:    print('  Finished Mapping Children to Time Indices.')

    # Group the merging blobs by timestep in a single pass (blobs without a time index are dropped)
    blobs_by_time = {}
    for blob in merging_blobs:
        blobs_by_time.setdefault(time_index_map.get(blob, -1), []).append(blob)

    # Create the uniform (zero-padded) merging blobs array
    max_merges = max((len(blobs_by_time.get(t, ())) for t in range(n_time)), default=0)
    uniform_merging_blobs_array = np.zeros((n_time, max_merges), dtype=np.int64)
    for t in range(n_time):
        blobs_at_t = blobs_by_time.get(t)
        if blobs_at_t:  # Only fill if there are blobs at this time
            uniform_merging_blobs_array[t, :len(blobs_at_t)] = np.array(blobs_at_t, dtype=np.int64)

    merging_blobs_da = xr.DataArray(
        uniform_merging_blobs_array,
        dims=[tracker.timedim, 'merges'],
        coords={tracker.timedim: blob_id_field_unique[tracker.timedim]})

    # Reserve a disjoint block of temporary IDs for every timestep. process_chunk
    # records at most MAX_MERGES merges per timestep and each merge consumes at
    # most MAX_PARENTS - 1 new IDs, so one timestep can never spill into the next
    # timestep's block. A budget based on the initial number of merges can be
    # exceeded by propagated merges and produce colliding temporary IDs.
    ids_per_timestep = MAX_MERGES * (MAX_PARENTS - 1)
    next_id_offsets = np.arange(n_time) * ids_per_timestep + global_id_counter
    next_id_offsets_da = xr.DataArray(next_id_offsets,
                                    dims=[tracker.timedim],
                                    coords={tracker.timedim: blob_id_field_unique[tracker.timedim]})

    blob_id_field_unique_p1 = blob_id_field_unique.shift({tracker.timedim: -1}, fill_value=0)
    blob_id_field_unique_m1 = blob_id_field_unique.shift({tracker.timedim: 1}, fill_value=0)

    # Align chunks
    blob_id_field_unique_m1 = blob_id_field_unique_m1.chunk({tracker.timedim: tracker.timechunks})
    blob_id_field_unique_p1 = blob_id_field_unique_p1.chunk({tracker.timedim: tracker.timechunks})
    merging_blobs_da = merging_blobs_da.chunk({tracker.timedim: tracker.timechunks})
    next_id_offsets_da = next_id_offsets_da.chunk({tracker.timedim: tracker.timechunks})

    results = xr.apply_ufunc(process_chunk,
                            blob_id_field_unique_m1,
                            blob_id_field_unique_p1,
                            merging_blobs_da,
                            next_id_offsets_da,
                            blob_id_field_unique_p1.lat.astype(np.float32),
                            blob_id_field_unique_p1.lon.astype(np.float32),
                            tracker.cell_area.astype(np.float32),
                            neighbours_int,
                            input_core_dims=[[tracker.xdim], [tracker.xdim], ['merges'], [], [tracker.xdim], [tracker.xdim], [tracker.xdim], ['nv', tracker.xdim]],
                            output_core_dims=[['merge', 'parent'], ['merge', 'parent'],
                                            ['merge', 'parent'], [],
                                            [], [tracker.xdim], ['update_idx'], ['merge']],
                            output_dtypes=[np.int32, np.int32, np.float32, np.int16,
                                        np.bool_, np.uint8, np.int32, np.int32],
                            dask_gufunc_kwargs={'output_sizes': {
                                                                'merge': MAX_MERGES,
                                                                'parent': MAX_PARENTS,
                                                                'update_idx': 255
                                                            }},
                            vectorize=False,
                            dask='parallelized')

    # Clean up inputs after use
    del neighbours_int, uniform_merging_blobs_array, next_id_offsets
    del blob_id_field_unique_m1, blob_id_field_unique_p1, merging_blobs_da, next_id_offsets_da
    gc.collect()

    # Unpack & persist results
    (merge_child_ids, merge_parent_ids, merge_areas, merge_counts,
        has_merge, updates_array, updates_ids, final_merging_blobs) = results

    merge_child_ids, merge_parent_ids, merge_areas, merge_counts, \
        has_merge, updates_array, updates_ids, final_merging_blobs = persist(
            merge_child_ids, merge_parent_ids, merge_areas, merge_counts,
            has_merge, updates_array, updates_ids, final_merging_blobs
        )

    has_merge = has_merge.compute()
    time_indices = np.where(has_merge)[0]

    gc.collect()
    if tracker.verbosity > 1:    print('  Finished Batch Processing Step.')


    ### Global Consolidation of Data ###

    # 1:  Collect all temporary IDs and create a compact global mapping
    all_temp_ids = np.unique(merge_child_ids.where(merge_child_ids >= global_id_counter, other=0).compute().values)
    all_temp_ids = all_temp_ids[all_temp_ids>0] # Remove the 0...
    if not all_temp_ids.size:  # If no temporary IDs exist
        id_lookup = {}
    else:
        id_lookup = {temp_id: np.int32(new_id) for temp_id, new_id in zip(
            all_temp_ids,
            range(global_id_counter, global_id_counter + len(all_temp_ids))
        )}
        global_id_counter += len(all_temp_ids)

    if tracker.verbosity > 1:    print('  Finished Consolidation Step 1: Temporary ID Mapping')

    # 2:  Update Field with new IDs
    blob_id_field_unique = update_blob_field_inplace(blob_id_field_unique, id_lookup, updates_array, updates_ids, has_merge)
    blob_id_field_unique = blob_id_field_unique.chunk({tracker.timedim: tracker.timechunks}) # Rechunk to avoid accumulating chunks...

    # Clean up large arrays but keep id_lookup until we've processed merges
    del updates_array, updates_ids
    gc.collect()

    if tracker.verbosity > 1:    print('  Finished Consolidation Step 2: Data Field Update.')

    # 3:  Update Merge Events. Pull only the timesteps that contain merges into
    #     memory, with one compute per array instead of one per (timestep, merge).
    #     apply_ufunc places broadcast dims first, so rows are (time, merge, parent).
    new_merging_blobs = set()
    if time_indices.size:
        selection = {tracker.timedim: time_indices}
        counts_np = merge_counts.isel(selection).values
        child_np = merge_child_ids.isel(selection).values
        parent_np = merge_parent_ids.isel(selection).values
        areas_np = merge_areas.isel(selection).values

        for row, t in enumerate(time_indices):
            count = int(counts_np[row])
            if count <= 0:
                continue
            merge_counts_iter[t] = count

            for merge_idx in range(count):
                # Extract valid IDs and areas (-1 marks padding)
                child_ids = child_np[row, merge_idx]
                child_ids = child_ids[child_ids >= 0]

                parent_ids = parent_np[row, merge_idx]
                valid_mask = parent_ids >= 0
                parent_ids = parent_ids[valid_mask]
                areas = areas_np[row, merge_idx][valid_mask]

                # Map temporary IDs to their final IDs
                mapped_child_ids = [id_lookup.get(id_.item(), id_.item()) for id_ in child_ids]
                mapped_parent_ids = [id_lookup.get(id_.item(), id_.item()) for id_ in parent_ids]

                # Store in pre-allocated arrays
                child_ids_iter[t, merge_idx, :len(mapped_child_ids)] = mapped_child_ids
                parent_ids_iter[t, merge_idx, :len(mapped_parent_ids)] = mapped_parent_ids
                merge_areas_iter[t, merge_idx, :len(areas)] = areas

    final_merging_blobs = final_merging_blobs.compute().values
    final_merging_blobs = final_merging_blobs[final_merging_blobs > 0]
    # Make sure we preserve original IDs if they're not in the lookup
    mapped_final_blobs = []
    for id_ in final_merging_blobs:
        mapped_id = id_lookup.get(id_, id_)
        if mapped_id > 0:  # Skip any invalid IDs
            mapped_final_blobs.append(mapped_id)

    # Add to the set of new merging blobs
    new_merging_blobs.update(mapped_final_blobs)

    if tracker.verbosity > 1:
        print('  Finished Consolidation Step 3: Merge List Dictionary Consolidation.')
        print(f'  Found {len(mapped_final_blobs)} potential new merging blobs from final_merging_blobs')

    # Clean up memory for large dask arrays
    del merge_child_ids, merge_parent_ids, merge_areas, merge_counts, has_merge, final_merging_blobs
    gc.collect()

    return (blob_id_field_unique,
            (child_ids_iter, parent_ids_iter, merge_areas_iter, merge_counts_iter),
            new_merging_blobs, global_id_counter)

In [None]:
overlap_blobs_list = tracker.find_overlapping_blobs(blob_id_field_unique, blob_props)  # List blob pairs that overlap by at least an overlap_threshold fraction
if tracker.verbosity > 0:    print('Finished Finding Overlapping Blobs.')

# Find initial merging blobs
# Column 1 of each overlap pair is treated as the child (presumably (parent, child) pairs);
# a child appearing in more than one pair overlaps several parents, i.e. it is a merge.
unique_children, children_counts = np.unique(overlap_blobs_list[:, 1], return_counts=True)
merging_blobs = set(unique_children[children_counts > 1].astype(np.int32))
del overlap_blobs_list
gc.collect()

## Process chunks iteratively until no new merging blobs remain
iteration = 0
max_iterations = 20  # Each iteration carries merges at most one time chunk further; assuming 4-step daily chunks this is ~80 days (maximum event duration...)
processed_chunks = set()  # NB: holds blob IDs already queued for processing (not chunk indices)
global_id_counter = blob_props.ID.max().item() + 1  # First ID not used by any existing blob

# Initialise global merge event tracking (one list entry per recorded merge event)
global_child_ids = []
global_parent_ids = []
global_merge_areas = []
global_merge_tidx = []

In [None]:
##### Iterate until no merging blobs remain (bounded by max_iterations)
# Each pass splits the merges it can resolve inside every time chunk; merges that
# reach the end of a chunk come back as new_merging_blobs for the next pass.

while merging_blobs and iteration < max_iterations:

    blob_id_field_new, merge_data_iter, new_merging_blobs, global_id_counter = merge_blobs_parallel_iteration(blob_id_field_unique, merging_blobs, global_id_counter)

    # Carefully manage memory - explicitly delete the old field before assigning the new one
    del blob_id_field_unique
    blob_id_field_unique = blob_id_field_new.persist(optimize_graph=True)
    del blob_id_field_new
    gc.collect()

    child_ids_iter, parent_ids_iter, merge_areas_iter, merge_counts_iter = merge_data_iter

    # Append this iteration's merge events (dropping the -1 padding) to the global lists
    for t, count in enumerate(merge_counts_iter):
        for merge_idx in range(count):
            children = child_ids_iter[t, merge_idx]
            children = children[children >= 0]

            parents = parent_ids_iter[t, merge_idx]
            valid_mask = parents >= 0
            parents = parents[valid_mask]
            areas = merge_areas_iter[t, merge_idx][valid_mask]

            if len(children) > 0 and len(parents) > 0:
                global_child_ids.append(children)
                global_parent_ids.append(parents)
                global_merge_areas.append(areas)
                global_merge_tidx.append(t)

    # Clean up per-iteration arrays
    del child_ids_iter, parent_ids_iter, merge_areas_iter, merge_counts_iter, merge_data_iter
    gc.collect()

    # Prepare for next iteration: only blobs that were not already queued in an earlier pass
    merging_blobs = new_merging_blobs - processed_chunks
    processed_chunks.update(new_merging_blobs)

    # Debug info to help diagnose issues
    if tracker.verbosity > 0:
        print(f"  Iteration {iteration}: found {len(new_merging_blobs)} new blobs, {len(merging_blobs)} to process next")

    del new_merging_blobs
    gc.collect()

    iteration += 1

if merging_blobs:
    print(f"WARNING: Stopped after {iteration} iterations (max_iterations={max_iterations}) "
          f"with {len(merging_blobs)} merging blobs still unprocessed.")

In [None]:
# TODO: Update every chunk during the vectorised update, regardless of whether it contains merges; then delete intermediate variables inside the function.
# TODO: Identify which data from earlier steps of the algorithm is still referenced and release it (delete everything that is no longer needed).

In [None]:
# Is the new field still backed by a dask task graph? xarray objects define
# `__dask_graph__` even for in-memory data (it then returns None), so `hasattr`
# alone is always True -- test the returned graph instead.
blob_id_field_new.__dask_graph__() is not None

In [None]:
import sys
# Inspect the task graph behind the new field
graph = blob_id_field_new.__dask_graph__()
print(f"- Number of keys: {len(graph)}")
# NB: sys.getsizeof is shallow -- it measures only the graph container object,
# not the tasks or data it references, so this number understates the real size.
print(f"- Graph size: {sys.getsizeof(graph)} bytes")

In [None]:
# Persist the new field and continue working with it under the main name
blob_id_field_unique = blob_id_field_new.persist()

In [None]:
# Load the coordinates into memory (so they no longer carry dask graphs), then persist the data
blob_id_field_unique['lat'] = blob_id_field_unique.lat.compute()
blob_id_field_unique['lon'] = blob_id_field_unique.lon.compute()
blob_id_field_unique[tracker.timedim] = blob_id_field_unique[tracker.timedim].compute()
blob_id_field_unique = blob_id_field_unique.persist()

In [None]:
# Inspect the task graph that still backs the field
blob_id_field_unique.__dask_graph__()

In [None]:
# Load blob_areas into memory
blob_areas = blob_areas.compute()

In [None]:
# Drop references to intermediates from earlier cells so their data and task graphs can be freed
del blob_id_field, data_bin_filtered, cumsum_ids, ds, blob_props, blob_areas

In [None]:
# Drop the reference to the intermediate field
del blob_id_field_new

In [None]:
# Run a garbage-collection pass on the client
gc.collect()

In [None]:
### ^^^^^^^^^^ Does all of the above fully release the old task graph?
## Check whether leftover local variables still hold references that keep the old graph alive.

In [None]:
# Drop the per-iteration merge-event arrays
del child_ids_iter, parent_ids_iter, merge_areas_iter, merge_counts_iter

In [None]:
# Run a garbage-collection pass on the client
gc.collect()

In [None]:
# Sanity check: a full reduction forces the whole field to be evaluated
blob_id_field_unique.mean().compute().values

In [None]:
# VVVV This also works well -- but it uses too much memory and takes too long...

In [None]:
def rebuild_clean_dask_array(blob_id_field_unique, tracker):
    """Return a copy of ``blob_id_field_unique`` backed by a fresh, minimal dask graph.

    The steps are:

    1. Load the ``lat``, ``lon`` and ``tracker.timedim`` coordinates into
       memory and persist the field.
    2. Compute the field on the cluster in ``tracker.timechunks``-long time
       blocks and gather those blocks to the client.
    3. Concatenate the blocks and re-chunk the result.

    The whole field therefore passes through client memory. Coordinates are
    replaced with ``assign_coords``, so the caller's object is left unchanged.
    """
    from dask.distributed import get_client
    client = get_client()

    # assign_coords returns a new object; the caller's array keeps its own coordinates
    blob_id_field_unique = blob_id_field_unique.assign_coords({
        'lat': blob_id_field_unique.lat.compute(),
        'lon': blob_id_field_unique.lon.compute(),
        tracker.timedim: blob_id_field_unique[tracker.timedim].compute(),
    }).persist()

    n_time = len(blob_id_field_unique[tracker.timedim])
    chunk_size = tracker.timechunks

    # Compute the time blocks in parallel on the cluster
    futures = [
        client.compute(blob_id_field_unique.isel({tracker.timedim: slice(i, i + chunk_size)}))
        for i in range(0, n_time, chunk_size)
    ]
    chunks = client.gather(futures)

    # Release the persisted input and the futures before the workers collect garbage
    del blob_id_field_unique, futures
    client.run(gc.collect)

    # Create a fresh dask array from the in-memory blocks
    new_array = xr.concat(chunks, dim=tracker.timedim).chunk({tracker.timedim: chunk_size})

    del chunks
    gc.collect()

    return new_array

In [None]:
# Persist the new field, then drop the previous one and collect garbage
blob_id_field_new = blob_id_field_new.persist()
del blob_id_field_unique
gc.collect()

In [None]:
# Rebuild the field with a fresh task graph (uses whichever rebuild_clean_dask_array definition ran last)
blob_id_field_unique = rebuild_clean_dask_array(blob_id_field_new, tracker)

In [None]:
# Load the coordinates into memory (so they no longer carry dask graphs), then persist the data.
# Uses tracker.timedim instead of a hard-coded 'time', consistent with the other cells.
blob_id_field_unique['lat'] = blob_id_field_unique.lat.compute()
blob_id_field_unique['lon'] = blob_id_field_unique.lon.compute()
blob_id_field_unique[tracker.timedim] = blob_id_field_unique[tracker.timedim].compute()
blob_id_field_unique = blob_id_field_unique.persist()

In [None]:
# Ask the scheduler to release the tasks/results behind the old field (no longer needed after the rebuild)
client.cancel(blob_id_field_new)
gc.collect()

In [None]:
##### VVVV   THIS WORKS

In [None]:
def rebuild_clean_dask_array(blob_id_field_unique, tracker):
    """Rebuild ``blob_id_field_unique`` on top of a fresh task graph.

    Each ``tracker.timechunks``-long time block is computed on the cluster and
    gathered to the client. The blocks are then concatenated along
    ``tracker.timedim`` and re-chunked, so the returned object no longer
    references the original dask graph.
    """
    from dask.distributed import get_client

    client = get_client()
    step = tracker.timechunks
    n_steps = len(blob_id_field_unique[tracker.timedim])

    # Launch one computation per time block; they run in parallel on the cluster
    pending = [
        client.compute(blob_id_field_unique.isel({tracker.timedim: slice(start, start + step)}))
        for start in range(0, n_steps, step)
    ]
    pieces = client.gather(pending)

    # Drop the local reference to the input and ask every worker to collect garbage
    del blob_id_field_unique
    client.run(gc.collect)

    rebuilt = xr.concat(pieces, dim=tracker.timedim).chunk({tracker.timedim: step})

    del pieces
    del pending
    gc.collect()

    return rebuilt

In [None]:
# Persist the current field under a new name before rebuilding it
blob_id_field_new = blob_id_field_unique.persist()

In [None]:
# Load the coordinates into memory (so they no longer carry dask graphs), then persist the data.
# Uses tracker.timedim instead of a hard-coded 'time', consistent with the other cells.
blob_id_field_new['lat'] = blob_id_field_new.lat.compute()
blob_id_field_new['lon'] = blob_id_field_new.lon.compute()
blob_id_field_new[tracker.timedim] = blob_id_field_new[tracker.timedim].compute()
blob_id_field_new = blob_id_field_new.persist()

In [None]:
# Drop the old reference so only blob_id_field_new holds the field
del blob_id_field_unique
gc.collect()

In [None]:
# Rebuild the field with a fresh task graph (uses whichever rebuild_clean_dask_array definition ran last)
blob_id_field_unique = rebuild_clean_dask_array(blob_id_field_new, tracker)

In [None]:
# Load the coordinates into memory (so they no longer carry dask graphs), then persist the data.
# Uses tracker.timedim instead of a hard-coded 'time', consistent with the other cells.
blob_id_field_unique['lat'] = blob_id_field_unique.lat.compute()
blob_id_field_unique['lon'] = blob_id_field_unique.lon.compute()
blob_id_field_unique[tracker.timedim] = blob_id_field_unique[tracker.timedim].compute()
blob_id_field_unique = blob_id_field_unique.persist()

In [None]:
# Ask the scheduler to release the tasks/results behind the old field (no longer needed after the rebuild)
client.cancel(blob_id_field_new)
gc.collect()

In [None]:
# Serial alternative to rebuild_clean_dask_array: compute the field one time block at a
# time into client memory, then rebuild a fresh dask-backed array from the in-memory pieces.
times = blob_id_field_unique[tracker.timedim]
chunk_size = tracker.timechunks
chunks = []

for i in range(0, len(times), chunk_size):
    chunk = blob_id_field_unique.isel({tracker.timedim: slice(i, i + chunk_size)}).compute()
    chunks.append(chunk)
del blob_id_field_unique


# Create fresh dask array from computed chunks
blob_id_field_unique = xr.concat(chunks, dim=tracker.timedim)
blob_id_field_unique = blob_id_field_unique.chunk({tracker.timedim: tracker.timechunks})

# Clear any references to old dask tasks
# NOTE(review): raises NameError if an earlier cell has already deleted blob_id_field_new
del blob_id_field_new
del chunks
gc.collect()

In [None]:
# Load the tracked blobs into memory before writing them out
blobs = blobs.compute() 

In [None]:
# Save Tracked Blobs to `zarr` for more efficient parallel I/O
# (output path built from the current user name, like the input path; mode='w' overwrites an existing store)

file_name = Path('/scratch') / getuser()[0] / getuser() / 'mhws' / 'MHWs_tracked_unstruct.zarr'
blobs.to_zarr(file_name, mode='w')