# Identify & Track Marine Heatwaves using `spot_the_blOb`

## Processing Steps:
1. Fill holes in the binary data, using `dask_image.ndmorph` -- up to `R_fill` cells in radius.
2. Filter out small objects -- those whose area falls below the `area_filter_quartile` quantile of the object-area distribution.
3. Identify objects in the binary data, using `dask_image.ndmeasure`.
4. Manually connect objects across time, applying Sun et al. 2023 criteria:
    - Connected Blobs must overlap by at least `overlap_threshold=50%` of the smaller blob.
    - Merging Blobs retain their original IDs; the merged (child) blob is split between them, with each cell assigned to the parent whose centroid is nearest.
5. Cluster and reduce the final object ID graph using `scipy.sparse.csgraph.connected_components`.

N.B.: Exploits parallelised `Dask` operations with optimised chunking using `flox` for memory efficiency and speed \
N.N.B.: This example, using 40 years of daily outputs at 0.25° resolution, takes ~6 minutes on 128 total cores.

In [1]:
import xarray as xr
import dask
from getpass import getuser
from pathlib import Path

import spot_the_blOb as blob
import spot_the_blOb.helper as hpc

In [2]:
# Start Dask Cluster
# Requests 32 workers with 2 threads each through the project helper module.
client = hpc.StartLocalCluster(n_workers=32, n_threads=2)

Memory per Worker: 15.74 GB
Hostname is  l40287
Forward Port = l40287:8787
Dashboard Link: localhost:8787/status


In [3]:
# Load Pre-processed Data (cf. `01_preprocess_extremes.ipynb`)

# Scratch layout is /scratch/<first letter of user>/<user>/mhws/...
user_name = getuser()
file_name = Path('/scratch', user_name[0], user_name, 'mhws', 'extreme_events_binary.zarr')

# 25 time steps per chunk; -1 leaves lat/lon as a single chunk each.
chunk_size = dict(time=25, lat=-1, lon=-1)
ds = xr.open_zarr(str(file_name), chunks=chunk_size)

In [4]:
# Extract Binary Features and Modify Mask

extreme_bin = ds['extreme_events']

# Keep the mask only where -90 < lat < 85; everywhere else it is set to False.
lat_in_range = (ds.lat > -90) & (ds.lat < 85)
mask = ds.mask.where(lat_in_range, other=False)

In [5]:
# Tracking Parameters

drop_area_quartile = 0.5   # Passed to Spotter as `area_filter_quartile` (size filter for small objects)
filling_radius = 8         # Passed to Spotter as `R_fill` (hole-filling radius, in cells)
allow_merging = True       # Passed to Spotter as `allow_merging`

In [6]:
# Spot the Blobs

# Build the tracker only; the full pipeline call is left commented out below
# because the following cells invoke the individual stages by hand.
tracker = blob.Spotter(
    extreme_bin,
    mask,
    R_fill=filling_radius,
    area_filter_quartile=drop_area_quartile,
    allow_merging=allow_merging,
)
#blobs = tracker.run()

#blobs

In [7]:
# Stage 1: fill holes in the binary field (radius was set via `R_fill` when the tracker was built).
data_bin_filled = tracker.fill_holes()

In [8]:
# Stage 2: remove small blobs from the hole-filled field; also returns the area
# threshold, the blob areas and the blob count before filtering.
data_bin_filtered, area_threshold, blob_areas, N_blobs_unfiltered = tracker.filter_small_blobs(data_bin_filled)

This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.


In [9]:
# Alias the filtered field under the name used by the identification step below.
data_bin = data_bin_filtered

In [10]:
# Stage 3: label blobs with time connectivity disabled; the second return value is discarded.
blob_id_field, _ = tracker.identify_blobs(data_bin, time_connectivity=False)

In [11]:
import numpy as np
import matplotlib.pyplot as plt

In [12]:
# Calculate Properties of each Blob
# Only area and centroid are requested; both are read by the overlap filter and the split/merge step below.
blob_props = tracker.calculate_blob_properties(blob_id_field, properties=['area', 'centroid'])


In [None]:

# Compile List of Overlapping Blob ID Pairs Across Time
# Later cells read the columns as: 0 = parent ID (time t), 1 = child ID (time t+1), 2 = overlap size.
overlap_blobs_list = tracker.find_overlapping_blobs(blob_id_field)  # List of overlapping blob pairs


In [None]:
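# NOTE: the cells below appear to step through the split/merge stage by hand (for debugging);
# the single-call form is kept here, commented out, for reference:
# split_merged_blob_id_field_unique, merged_blobs_props, split_merged_blobs_list, merged_blobs_ledger = tracker.split_and_merge_blobs(blob_id_field, blob_props, overlap_blobs_list)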


In [None]:
# Work on a copy so the relabelling in the split/merge loop does not rebind `blob_id_field`.
blob_id_field_unique = blob_id_field.copy()

In [None]:

# --- Keep only overlaps that are large enough relative to the smaller blob ---
# Overlap fraction = overlap size (column 2) / area of the smaller of the two blobs.
areas_0 = blob_props['area'].sel(ID=overlap_blobs_list[:, 0]).values
areas_1 = blob_props['area'].sel(ID=overlap_blobs_list[:, 1]).values
min_areas = np.minimum(areas_0, areas_1)
overlap_fractions = overlap_blobs_list[:, 2].astype(float) / min_areas

sufficient_overlap = overlap_fractions >= tracker.overlap_threshold
overlap_blobs_list = overlap_blobs_list[sufficient_overlap]


##### Consider Merging Blobs

# Bookkeeping for the merge step
merge_ledger = []                                  # One entry per merge: the IDs of the merging parents
next_new_id = blob_props.ID.max().item() + 1       # Fresh IDs start just above the largest existing ID

# A child (t+1, column 1) listed more than once has 2+ parents, i.e. it is a merge
unique_children, children_counts = np.unique(overlap_blobs_list[:, 1], return_counts=True)
is_merge = children_counts > 1
merging_blobs = unique_children[is_merge]

# Look up the time index of every merging child, and the x-dimension size (used for wrapping)
time_index_map = tracker.compute_id_time_dict_mask(blob_id_field_unique, merging_blobs, next_new_id)
Nx = blob_id_field_unique[tracker.xdim].size

# Bucket the merging children by the time-chunk that contains them
# -- Pre-condition: Blob IDs should be monotonically increasing in time...
chunk_boundaries = np.cumsum([0] + list(blob_id_field_unique.chunks[0]))  # Start index of every chunk, plus the total length
blobs_by_chunk = {}
for blob_id in merging_blobs:
    # Right-sided search minus one gives the chunk whose [start, end) range holds this time index
    chunk_idx = np.searchsorted(chunk_boundaries, time_index_map[blob_id], side='right') - 1
    if chunk_idx not in blobs_by_chunk:
        blobs_by_chunk[chunk_idx] = []
    blobs_by_chunk[chunk_idx].append(blob_id)



In [None]:
from numba import jit, prange

In [None]:
@jit(nopython=True, parallel=True, fastmath=True)
def wrapped_euclidian_parallel(mask_values, parent_centroids_values, Nx):
    """
    Euclidean distance from every flagged cell to every parent centroid,
    treating the x-axis as periodic with period ``Nx``.

    Parameters
    ----------
    mask_values : np.ndarray
        2D array; distances are computed for its non-zero (True) cells.
    parent_centroids_values : np.ndarray
        Array of shape (n_parents, 2): column 0 holds the y coordinate and
        column 1 the x coordinate of each parent centroid.
    Nx : int
        Length of the x-dimension, used for the wrap-around correction.

    Returns
    -------
    np.ndarray
        float64 array of shape (n_true_cells, n_parents). Row order follows
        ``np.nonzero(mask_values)``; entry [i, j] is the distance from cell i
        to parent j, using the shorter way around in x.
    """
    n_centroids = parent_centroids_values.shape[0]
    wrap_limit = Nx / 2

    # Coordinates of the cells that need a distance
    rows, cols = np.nonzero(mask_values)
    n_cells = rows.shape[0]

    # Split the centroid table into its two columns once, outside the loop
    centroid_rows = parent_centroids_values[:, 0]
    centroid_cols = parent_centroids_values[:, 1]

    out = np.empty((n_cells, n_centroids), dtype=np.float64)

    # Each flagged cell is independent, so the loop is run in parallel
    for k in prange(n_cells):
        delta_row = rows[k] - centroid_rows
        delta_col = cols[k] - centroid_cols

        # Offsets longer than half the domain are shorter going the other way round
        delta_col = np.where(delta_col > wrap_limit, delta_col - Nx, delta_col)
        delta_col = np.where(delta_col < -wrap_limit, delta_col + Nx, delta_col)

        out[k] = np.sqrt(delta_row * delta_row + delta_col * delta_col)

    return out

In [None]:
import time

# Split every merging child blob between its parents, one time-chunk at a time.
#
# For each child with 2+ parents:
#   1. allocate (num_parents - 1) new IDs and record the parents in `merge_ledger`;
#   2. relabel each cell of the child with the ID belonging to its nearest parent centroid;
#   3. recompute area/centroid for the resulting pieces and update `blob_props`;
#   4. recompute the overlaps of those pieces with the next time step and splice them
#      into `overlap_blobs_list`.
#
# Reads/updates the notebook globals: blob_id_field_unique, blob_props,
# overlap_blobs_list, merge_ledger, next_new_id.

chunk_start_time = time.time()
time_chunk_sizes = blob_id_field_unique.chunks[0]  # Chunk lengths along axis 0 (treated as time, as in the grouping step)

for chunk_idx, chunk_blobs in blobs_by_chunk.items(): # Loop over each time-chunk
    print(f"\nProcessing chunk {chunk_idx}...")
    chunk_iter_start = time.time()
    
    chunk_start = sum(time_chunk_sizes[:chunk_idx])
    chunk_len = time_chunk_sizes[chunk_idx]
    chunk_end = chunk_start + chunk_len  # Exclusive end of this chunk
    
    # Load the entire chunk into memory, plus one look-ahead step (t+1 of the chunk's last step).
    # The slice is clipped at the end of the time axis, so the final chunk has no look-ahead step.
    chunk_data = blob_id_field_unique.isel({tracker.timedim: slice(chunk_start, chunk_end + 1)}).compute()
    n_loaded = chunk_data.sizes[tracker.timedim]
    
    # Process each blob in this chunk
    for child_id in chunk_blobs:
        
        child_time_idx = time_index_map[child_id]
        relative_time_idx = child_time_idx - chunk_start
        
        blob_id_time = chunk_data.isel({tracker.timedim: relative_time_idx})
        
        child_mask_2d  = (blob_id_time == child_id).values  #child_mask_2d_all.sel(child_id=child_id).values
        
        # Find all pairs involving this Child Blob
        child_mask = overlap_blobs_list[:, 1] == child_id
        child_where = np.where(child_mask)[0]  # Row indices, needed for in-place assignment
        merge_group = overlap_blobs_list[child_mask]
        
        # Get all Parents (LHS) Blobs that overlap with this Child Blob -- N.B. This is now generalised for N-parent merging !
        parent_ids = merge_group[:, 0]
        num_parents = len(parent_ids)
        
        # Make new IDs for the other pieces of the Child Blob & Record in the Merge Ledger
        new_blob_id = np.arange(next_new_id, next_new_id + (num_parents - 1), dtype=np.int32)
        next_new_id += num_parents - 1
        merge_ledger.append(parent_ids)
        
        # Replace the 2nd+ Child in the Overlap Blobs List with the new Child ID
        overlap_blobs_list[child_where[1:], 1] = new_blob_id
        child_ids = np.concatenate((np.array([child_id]), new_blob_id))  # child_ids[i] is the piece belonging to parent i
        
        ## Relabel the Original Child Blob ID Field to account for the New ID:
        # --> For every (Original) Child Cell in the ID Field, Measure the Distance to the Centroids of the Parents
        # --> Assign the ID for each Cell corresponding to the closest Parent
        
        parent_centroids = blob_props.sel(ID=parent_ids).centroid.values.T  # (y, x), [:,0] are the y's
        distances = wrapped_euclidian_parallel(child_mask_2d, parent_centroids, Nx)  # **Deals with wrapping**
        
        # Assign the new ID to each cell based on the closest parent
        new_labels = child_ids[np.argmin(distances, axis=1)]
        
        # Write the new labels into this time slice and store it back in the in-memory chunk
        temp = np.zeros_like(blob_id_time)
        temp[child_mask_2d] = new_labels
        blob_id_time = blob_id_time.where(~child_mask_2d, temp)
        chunk_data[{tracker.timedim: relative_time_idx}] = blob_id_time
        
        ## Update the Properties of the N Children Blobs
        new_child_props = tracker.calculate_blob_properties(blob_id_time, properties=['area', 'centroid'])
        
        # Update the properties for the original child ID
        blob_props.loc[dict(ID=child_id)] = new_child_props.sel(ID=child_id)
        
        # Add the properties for the N-1 other new child ID
        blob_props = xr.concat([blob_props, new_child_props.sel(ID=new_blob_id)], dim='ID')
    
        ## Finally, we need to re-assess all of the Parent IDs (LHS) equal to the (original) child_id
        
        if relative_time_idx + 1 < n_loaded:
            blob_id_time_p1 = chunk_data.isel({tracker.timedim: relative_time_idx + 1})
            
            # Overlaps between the pieces of the child (original ID + new IDs) and the next time-step
            new_overlaps = tracker.check_overlap_slice(blob_id_time.values, blob_id_time_p1.values)
            new_child_overlaps_list = new_overlaps[(new_overlaps[:, 0] == child_id) | np.isin(new_overlaps[:, 0], new_blob_id)]
            
            # _Before_ replacing the overlap_blobs_list, re-assess the overlap fractions of just the new pairs
            areas_0 = blob_props['area'].sel(ID=new_child_overlaps_list[:, 0]).values
            areas_1 = blob_props['area'].sel(ID=new_child_overlaps_list[:, 1]).values
            min_areas = np.minimum(areas_0, areas_1)
            overlap_fractions = new_child_overlaps_list[:, 2].astype(float) / min_areas
            new_child_overlaps_list = new_child_overlaps_list[overlap_fractions >= tracker.overlap_threshold]
        else:
            # The child sits on the last available time step: there is no t+1 slice, hence no onward overlaps
            new_child_overlaps_list = overlap_blobs_list[:0]
        
        # Replace the lines in the overlap_blobs_list where (original) child_id is on the LHS, with the new pairs
        child_mask_LHS = overlap_blobs_list[:, 0] == child_id
        overlap_blobs_list = np.concatenate([overlap_blobs_list[~child_mask_LHS], new_child_overlaps_list])
    
    # Update the full dask DataArray with this processed chunk.
    # Only the chunk's own steps are written back; the look-ahead step (if one was loaded) is dropped.
    blob_id_field_unique[{
        tracker.timedim: slice(chunk_start, chunk_end)
    }] = chunk_data.isel({tracker.timedim: slice(0, chunk_len)})
    
    # Print timing for this chunk
    chunk_elapsed = time.time() - chunk_iter_start
    print(f"\nCompleted chunk {chunk_idx} in {chunk_elapsed:.2f} seconds")
    print(f"Processed {len(chunk_blobs)} blobs in this chunk")
    print(f"Average time per blob: {chunk_elapsed/len(chunk_blobs):.2f} seconds")

# Print total execution time
total_elapsed = time.time() - chunk_start_time
print(f"\nTotal execution time: {total_elapsed:.2f} seconds")

In [None]:
# Debug: centroids of the parents of the most recent child blob handled by the merge loop.
parent_centroids

In [None]:
# Debug: parent IDs of the most recent child blob handled by the merge loop.
parent_ids

In [None]:
# Debug: original child ID followed by the newly allocated IDs for that child.
child_ids

In [None]:
# Debug: ID field one time step before the child's time step (where its parents are).
plt.pcolor(blob_id_field_unique.isel({tracker.timedim: child_time_idx-1}))

In [None]:
# Debug: footprint of the child blob before it was relabelled.
plt.pcolor(child_mask_2d)

In [None]:
# Debug: parent centroids again (same variable as inspected above).
parent_centroids

In [None]:
# Debug: distance table returned by wrapped_euclidian_parallel for the last child (cells x parents).
distances

In [None]:
# Debug: number of child cells that received ID 149898.
# NOTE(review): hard-coded ID from one interactive session; it will not be meaningful on a fresh run.
(new_labels == 149898).sum()

In [None]:
# Debug: ID of the most recent child blob handled by the merge loop.
child_id

In [None]:
# Debug: parent centroids again (same variable as inspected above).
parent_centroids

In [None]:
# Debug: ID field at the child's time step; the colour scale is capped at 10.
plt.pcolor(blob_id_field_unique.isel({tracker.timedim: child_time_idx}), vmax=10)
plt.colorbar()

In [None]:
# Debug: footprint of the child blob before it was relabelled (repeat of an earlier cell).
plt.pcolor(child_mask_2d)

In [None]:
# Debug: footprint of blob ID 998 at the time step before the child.
# NOTE(review): hard-coded ID from one interactive session.
plt.pcolor(blob_id_field_unique.isel({tracker.timedim: child_time_idx-1})==998)

In [None]:
# Debug: boolean mask of blob ID 998 at the time step before the child.
# NOTE(review): hard-coded ID from one interactive session.
binary_mask = blob_id_field_unique.isel({tracker.timedim: child_time_idx-1})==998

In [None]:
# Debug: display the mask built in the previous cell.
binary_mask

In [None]:
# Debug: call the tracker's centroid routine directly on the mask above.
# NOTE(review): the second argument (123, 123123) looks like a placeholder — confirm what calculate_centroid expects.
tracker.calculate_centroid(binary_mask.values, (123, 123123))

In [None]:
# Debug: parent IDs again (same variable as inspected above).
parent_ids

In [None]:
# Debug: parent centroids again (same variable as inspected above).
parent_centroids

In [None]:
# Debug: recomputed properties of the newly created child piece(s).
new_child_props.sel(ID=new_blob_id)

In [None]:
# Debug: properties at positional index 1011 along ID (isel is positional, so this is not necessarily blob ID 1011).
# NOTE(review): hard-coded index from one interactive session.
blob_props.isel(ID=1011)

In [None]:
# Cluster Blobs List to Determine Globally Unique IDs & Update Blob ID Field
# The inputs are the results of the manual split/merge cells above (relabelled ID field,
# updated blob properties, updated overlap list); the names previously used here were only
# assigned by the commented-out `tracker.split_and_merge_blobs(...)` call.
# NOTE(review): assumes cluster_rename_blobs_and_props expects (parent, child) ID pairs only,
# so the overlap-size column is dropped — confirm against the tracker API.
split_merged_blobs_ds = tracker.cluster_rename_blobs_and_props(blob_id_field_unique, blob_props, overlap_blobs_list[:, :2])

In [None]:
# Add Merge Ledger to split_merged_blobs_ds
# `merge_ledger` is the list built by the split/merge loop (one array of parent IDs per merge).
split_merged_blobs_ds.attrs['merge_ledger'] = merge_ledger

# Count Number of Blobs (This may have increased due to splitting)
N_blobs = split_merged_blobs_ds.ID_field.max().compute().data

In [None]:
# Inspect the attributes of the tracked dataset.
# `blobs` is only defined by the commented-out `tracker.run()` call, so use the dataset built above.
split_merged_blobs_ds.attrs

In [None]:
# Save Tracked Blobs
# `blobs` is only defined by the commented-out `tracker.run()` call, so save the dataset
# produced by the clustering step instead.

user_name = getuser()
file_name = Path('/scratch') / user_name[0] / user_name / 'mhws' / 'MHWs_tracked.nc'
split_merged_blobs_ds.to_netcdf(file_name, mode='w')