# Identify & Track Marine Heatwaves using `spot_the_blOb`

## Processing Steps:
1. Fill holes in the binary data, using `dask_image.ndmorph` -- up to `R_fill` cells in radius.
2. Filter out small objects -- those whose area falls below the `area_filter_quartile` quantile of the object-area distribution.
3. Identify objects in the binary data, using `dask_image.ndmeasure`.
4. Manually connect objects across time, applying Sun et al. 2023 criteria:
    - Connected Blobs must overlap by at least `overlap_threshold=50%` of the smaller blob.
    - Merging Blobs are split: each cell of the child blob is assigned to the nearest parent centroid; one piece keeps the original child ID and the other pieces receive new IDs.
5. Cluster and reduce the final object ID graph using `scipy.sparse.csgraph.connected_components`.

N.B.: Exploits parallelised `Dask` operations with optimised chunking using `flox` for memory efficiency and speed \
N.N.B.: This example, using 40 years of daily outputs at 0.25° resolution, takes ~6 minutes on 128 total cores.

In [1]:
import xarray as xr
import dask
from getpass import getuser
from pathlib import Path

import spot_the_blOb as blob
import spot_the_blOb.helper as hpc

In [2]:
# Launch a local Dask cluster: 32 workers, each running 2 threads.
client = hpc.StartLocalCluster(
    n_workers=32,
    n_threads=2,
)

Memory per Worker: 15.74 GB
Hostname is  l40128
Forward Port = l40128:8787
Dashboard Link: localhost:8787/status


In [3]:
# Load the pre-processed binary extremes (produced by `01_preprocess_extremes.ipynb`),
# chunked as 25 time steps per chunk with the full lat/lon extent in each chunk.
file_name = Path('/scratch', getuser()[0], getuser(), 'mhws', 'extreme_events_binary.zarr')
chunk_size = dict(time=25, lat=-1, lon=-1)
ds = xr.open_zarr(str(file_name), chunks=chunk_size)

In [4]:
# Binary extreme-event field, and a copy of the mask restricted to latitudes
# strictly between -90 and 85 (cells outside that band are set to False).

extreme_bin = ds['extreme_events']
mask = ds['mask'].where((ds['lat'] > -90) & (ds['lat'] < 85), other=False)

In [5]:
# Tracking parameters (see the processing steps at the top of the notebook)

drop_area_quartile = 0.5   # blobs smaller than this quantile of the area distribution are dropped
filling_radius = 8         # hole-filling radius, in grid cells (`R_fill`)
allow_merging = True       # enable the split/merge treatment of merging blobs

In [6]:
# Configure the blob tracker with the binary field, the latitude-restricted mask and
# the tracking parameters. The one-shot pipeline (`tracker.run()`) is not called here;
# the cells below run each processing step individually.
tracker = blob.Spotter(
    extreme_bin,
    mask,
    R_fill=filling_radius,
    area_filter_quartile=drop_area_quartile,
    allow_merging=allow_merging,
)

In [7]:
# Step 1: fill holes in the binary extremes (radius given by `R_fill` on the tracker).
data_bin_filled = tracker.fill_holes()

In [None]:
# Step 2: drop blobs below the `area_filter_quartile` area threshold.
# By their names, the extra return values are the area threshold, the blob areas
# and the blob count before filtering.
# NOTE(review): this cell emitted a Dask "may cause some slowdown" warning (see output);
# presumably a large in-memory array is being embedded in the task graph — confirm.
data_bin_filtered, area_threshold, blob_areas, N_blobs_unfiltered = tracker.filter_small_blobs(data_bin_filled)

This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.


In [None]:
# Use the size-filtered field as the input to blob identification.
data_bin = data_bin_filtered

In [None]:
# Step 3: label blobs with `time_connectivity=False`, so blobs are not linked across time
# here; cross-time links are built manually from overlaps in the cells below.
# The second return value is not used.
blob_id_field, _ = tracker.identify_blobs(data_bin, time_connectivity=False)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Calculate the area and centroid of every blob ID in the label field.
# `blob_props` is indexed by `ID`; later cells read `blob_props['area']` and
# `blob_props.centroid` (which carries a `component` dimension).
blob_props = tracker.calculate_blob_properties(blob_id_field, properties=['area', 'centroid'])


In [None]:

# Compile the overlapping blob ID pairs between consecutive time steps.
# Later cells index the result as an array with one row per pair:
#   column 0 = earlier (parent, t) ID, column 1 = later (child, t+1) ID,
#   column 2 = overlap size (divided by the smaller blob area to get the overlap fraction).
overlap_blobs_list = tracker.find_overlapping_blobs(blob_id_field)  # array of overlapping blob pairs (see columns above)


In [None]:
# The library split/merge step is disabled here. The cells below re-implement it by hand,
# updating `blob_id_field_unique`, `blob_props`, `overlap_blobs_list` and `merge_ledger`.
# split_merged_blob_id_field_unique, merged_blobs_props, split_merged_blobs_list, merged_blobs_ledger = tracker.split_and_merge_blobs(blob_id_field, blob_props, overlap_blobs_list)


In [None]:
# Deep working copy of the label field: the split/merge loop below writes into it,
# so the original `blob_id_field` stays untouched.
blob_id_field_unique = blob_id_field.copy(deep=True)

In [None]:
# --- Filter overlaps by the fraction of the smaller blob they cover ---
# Rows of `overlap_blobs_list`: [parent ID (t), child ID (t+1), overlap size]
areas_0, areas_1 = (
    blob_props['area'].sel(ID=overlap_blobs_list[:, col]).values for col in (0, 1)
)
min_areas = np.minimum(areas_0, areas_1)
overlap_fractions = overlap_blobs_list[:, 2].astype(float) / min_areas

# Keep only pairs that reach the tracker's overlap threshold
overlap_blobs_list = overlap_blobs_list[overlap_fractions >= tracker.overlap_threshold]



##### Consider Merging Blobs

# Bookkeeping for the merge loop
merge_ledger = []                                  # one array of parent IDs per merging child
next_new_id = blob_props.ID.max().item() + 1       # fresh IDs start above the highest existing ID

# A child (t+1 / RHS) ID listed in two or more pairs has 2+ parents, i.e. it is a merge
unique_children, children_counts = np.unique(overlap_blobs_list[:, 1], return_counts=True)
merging_blobs = unique_children[children_counts >= 2]

In [None]:
### Pre-compute the time index of every merging child blob
# The merge loop needs, for each child ID, the time slice that contains it. This cell
# looks all of them up in one pass.
#
# Presence is evaluated slice by slice with `np.isin`, giving a small (time, search_id)
# boolean array. Broadcasting the 3-D field against every search ID
# (`blob_id_field_unique == search_ids`) instead builds a (time, y, x, search_id)
# intermediate whose per-chunk size grows with the number of merging blobs.

def _ids_present(label_slices, ids):
    """Flag which of `ids` occur in each 2-D slice of `label_slices`.

    Parameters
    ----------
    label_slices : np.ndarray
        Label array of shape (..., y, x).
    ids : np.ndarray
        1-D array of IDs to look for.

    Returns
    -------
    np.ndarray
        Boolean array of shape (..., len(ids)).
    """
    lead_shape = label_slices.shape[:-2]
    # Explicit shapes (no -1) so zero-size blocks reshape without ambiguity
    n_rows = int(np.prod(lead_shape, dtype=np.int64))
    rows = label_slices.reshape(n_rows, label_slices.shape[-2] * label_slices.shape[-1])
    present = np.zeros((n_rows, len(ids)), dtype=bool)
    for i, row in enumerate(rows):
        present[i] = np.isin(ids, row)
    return present.reshape(lead_shape + (len(ids),))


presence_by_time = xr.apply_ufunc(
    _ids_present,
    blob_id_field_unique,
    kwargs={'ids': merging_blobs},
    input_core_dims=[[tracker.ydim, tracker.xdim]],
    output_core_dims=[['search_id']],
    dask='parallelized',
    output_dtypes=[bool],
    dask_gufunc_kwargs={
        'output_sizes': {'search_id': len(merging_blobs)},
        'allow_rechunk': True,
    },
).assign_coords(search_id=merging_blobs)

# First time index at which each ID is present
time_indices = presence_by_time.argmax(dim=tracker.timedim).compute()

# Plain-int dictionary for fast lookup: {child_id: time_index}.
# (`DataArray` has no `.items()`, so pair the coordinate and values explicitly.)
time_index_map = dict(
    zip(time_indices['search_id'].values.tolist(), time_indices.values.tolist())
)

In [None]:
import time  # used only for the optional per-child timing diagnostics

# Set to True to print per-child progress and phase timings. Off by default because,
# with many merging children, unconditional prints flood the cell output.
show_timings = False


def _log(message):
    """Print `message` only when `show_timings` is enabled."""
    if show_timings:
        print(message)


n_time = blob_id_field_unique.sizes[tracker.timedim]

for child_id in merging_blobs:
    start_total = time.perf_counter()
    _log(f'Merging Child ID: {child_id}')

    # Find all (parent, child, overlap) rows that point at this child
    t0 = time.perf_counter()
    child_mask = overlap_blobs_list[:, 1] == child_id
    child_where = np.flatnonzero(child_mask)  # row indices, needed for the assignment below
    parent_ids = overlap_blobs_list[child_mask, 0]
    num_parents = len(parent_ids)
    _log(f'Parent IDs: {parent_ids}')
    _log(f'Time for finding pairs: {time.perf_counter() - t0:.4f}s')

    # The first parent's piece keeps `child_id`; each additional parent gets a fresh ID.
    # Record the parents in the merge ledger and relabel the extra overlap rows.
    t0 = time.perf_counter()
    new_blob_id = np.arange(next_new_id, next_new_id + (num_parents - 1), dtype=np.int32)
    next_new_id += num_parents - 1
    merge_ledger.append(parent_ids)
    overlap_blobs_list[child_where[1:], 1] = new_blob_id
    child_ids = np.concatenate((np.array([child_id]), new_blob_id))
    _log(f'Time for ID assignment: {time.perf_counter() - t0:.4f}s')

    # Split the child: each child cell takes the label of the nearest parent centroid
    t0 = time.perf_counter()
    parent_centroids = blob_props.sel(ID=parent_ids).centroid.isel(component=[1, 0]).values.T
    # NOTE(review): the previous cell builds `time_index_map` for exactly this lookup;
    # `time_index_map[int(child_id)]` would avoid this per-child full-field scan.
    child_time_idx = (blob_id_field_unique == child_id).any(dim=[tracker.ydim, tracker.xdim]).argmax().compute().item()
    # Materialise the child's time slice once. The mask, cell coordinates and
    # relabelling below then work in memory instead of re-evaluating the lazy field.
    tslice_child = blob_id_field_unique.isel({tracker.timedim: child_time_idx}).compute()
    child_mask_2d = tslice_child == child_id
    child_coords = np.stack(np.where(child_mask_2d.values), axis=1)
    distances = np.linalg.norm(child_coords[:, None] - parent_centroids, axis=2)
    new_labels = child_ids[np.argmin(distances, axis=1)]
    _log(f'Time for splitting the child blob: {time.perf_counter() - t0:.4f}s')

    # Write the relabelled slice back into the field
    t0 = time.perf_counter()
    temp = np.zeros_like(tslice_child.values)
    temp[child_mask_2d.values] = new_labels
    relabelled_slice = tslice_child.where(~child_mask_2d, temp)
    blob_id_field_unique[{tracker.timedim: child_time_idx}] = relabelled_slice
    _log(f'Time for field update: {time.perf_counter() - t0:.4f}s')

    # Update properties of the (now smaller) child and append the new pieces
    t0 = time.perf_counter()
    new_child_props = tracker.calculate_blob_properties(
        blob_id_field_unique.isel({tracker.timedim: child_time_idx}),
        properties=['area', 'centroid'],
    )
    blob_props.loc[dict(ID=child_id)] = new_child_props.sel(ID=child_id)
    blob_props = xr.concat([blob_props, new_child_props.sel(ID=new_blob_id)], dim='ID')
    _log(f'Time for properties update: {time.perf_counter() - t0:.4f}s')

    # Re-assess overlaps of the split pieces with the next time step. A child in the
    # final time step has no next slice to overlap with, so skip it instead of
    # indexing past the end of the time axis.
    t0 = time.perf_counter()
    if child_time_idx + 1 < n_time:
        new_overlaps = tracker.check_overlap_slice(
            relabelled_slice.values,
            blob_id_field_unique.isel({tracker.timedim: child_time_idx + 1}).values,
        )
        from_split = (new_overlaps[:, 0] == child_id) | np.isin(new_overlaps[:, 0], new_blob_id)
        new_child_overlaps_list = new_overlaps[from_split]

        # Keep only overlaps covering at least `overlap_threshold` of the smaller blob
        areas_0 = blob_props['area'].sel(ID=new_child_overlaps_list[:, 0]).values
        areas_1 = blob_props['area'].sel(ID=new_child_overlaps_list[:, 1]).values
        min_areas = np.minimum(areas_0, areas_1)
        overlap_fractions = new_child_overlaps_list[:, 2].astype(float) / min_areas
        new_child_overlaps_list = new_child_overlaps_list[overlap_fractions >= tracker.overlap_threshold]

        # Replace the child's old outgoing overlaps with those of its split pieces
        child_mask_LHS = overlap_blobs_list[:, 0] == child_id
        overlap_blobs_list = np.concatenate([overlap_blobs_list[~child_mask_LHS], new_child_overlaps_list])
    _log(f'Time for overlap reassessment: {time.perf_counter() - t0:.4f}s')

    _log(f'Total time for this child_id: {time.perf_counter() - start_total:.4f}s\n')

In [None]:
# Cluster Blobs List to Determine Globally Unique IDs & Update Blob ID Field.
# Uses the outputs of the manual split/merge loop above (`blob_id_field_unique`,
# `blob_props`, `overlap_blobs_list`). The `split_merged_*` names are only defined by
# the commented-out `tracker.split_and_merge_blobs` call, so they would raise NameError.
split_merged_blobs_ds = tracker.cluster_rename_blobs_and_props(blob_id_field_unique, blob_props, overlap_blobs_list)

In [None]:
# Add the merge ledger built by the manual merge loop (one array of parent IDs per
# merging child) to the clustered dataset. `merged_blobs_ledger` is only defined by the
# commented-out `tracker.split_and_merge_blobs` call, so it is not used here.
# NOTE(review): netCDF attributes must be scalars, strings or 1-D arrays; a list of
# per-merge parent-ID arrays will likely need encoding (e.g. flattening) before
# `to_netcdf` — confirm.
split_merged_blobs_ds.attrs['merge_ledger'] = merge_ledger

# Count Number of Blobs (This may have increased due to splitting); `.item()` yields a plain scalar
N_blobs = split_merged_blobs_ds.ID_field.max().compute().item()

In [None]:
# Inspect the attributes of the tracked dataset (`blobs` is never defined: `tracker.run()` is not called)
split_merged_blobs_ds.attrs

In [None]:
# Save Tracked Blobs.
# `blobs` is never defined in this notebook (`tracker.run()` is not called), so save the
# clustered dataset produced by `cluster_rename_blobs_and_props` instead. A dedicated
# name avoids overwriting `file_name`, which holds the input path.
output_file = Path('/scratch') / getuser()[0] / getuser() / 'mhws' / 'MHWs_tracked.nc'
split_merged_blobs_ds.to_netcdf(output_file, mode='w')