# Identify & Track Marine Heatwaves using `spot_the_blOb`

## Processing Steps:
1. Fill holes in the binary data, using `dask_image.ndmorph` -- up to `R_fill` cells in radius.
2. Filter out small objects -- those whose area falls below the `area_filter_quartile` quantile of the object-area distribution.
3. Identify objects in the binary data, using `dask_image.ndmeasure`.
4. Manually connect objects across time, applying Sun et al. 2023 criteria:
    - Connected Blobs must overlap by at least `overlap_threshold=50%` of the smaller blob.
    - Merged blobs retain their original IDs; the merged blob is split between its parents based on proximity to each parent's centroid.
5. Cluster and reduce the final object ID graph using `scipy.sparse.csgraph.connected_components`.

N.B.: Exploits parallelised `Dask` operations, with optimised chunking using `flox`, for memory efficiency and speed. \
N.N.B.: This example, using 40 years of daily outputs at 0.25° resolution, takes ~6 minutes on 128 total cores.

In [1]:
# Takes ~14 minutes per 1000 time steps on 32 workers (2 threads)

In [2]:
import xarray as xr
import dask
from getpass import getuser
from pathlib import Path

import spot_the_blOb as blob
import spot_the_blOb.helper as hpc

In [3]:
# Start Dask Cluster
# Uses the package helper with 32 workers and n_threads=2 (presumably threads per worker -- TODO confirm).
# The helper prints the memory per worker and the dashboard address (see the cell output below).
client = hpc.StartLocalCluster(n_workers=32, n_threads=2)

Memory per Worker: 15.74 GB
Hostname is  l40200
Forward Port = l40200:8787
Dashboard Link: localhost:8787/status


In [4]:
# Lazily open the pre-processed binary extremes (cf. `01_preprocess_extremes.ipynb`)

# Store location: /scratch/<first letter of username>/<username>/mhws/
file_name = Path('/scratch').joinpath(getuser()[0], getuser(), 'mhws', 'extreme_events_binary.zarr')
# 25 time steps per chunk; -1 keeps each spatial dimension in a single chunk
chunk_size = dict(time=25, lat=-1, lon=-1)
ds = xr.open_zarr(str(file_name), chunks=chunk_size)

In [5]:
# Select the binary features to track and restrict the mask by latitude

# Only the first 500 time steps are used in this run
extreme_bin = ds['extreme_events'].isel(time=slice(500))
# Set the mask to False wherever lat >= 85 or lat <= -90
mask = ds['mask'].where((ds['lat'] > -90) & (ds['lat'] < 85), other=False)

In [6]:
# Tracking Parameters
# These four values are forwarded to `blob.Spotter(...)` when the tracker is built.

drop_area_quartile = 0.5  # passed as `area_filter_quartile`: objects with area below this level of the area distribution are filtered out
filling_radius = 8  # passed as `R_fill`: holes are filled up to this radius (in grid cells)
allow_merging = True  # passed as `allow_merging`; presumably enables the merge/split handling -- TODO confirm
nn_partitioning = True  # passed as `nn_partitioning`; NOTE(review): partitioning semantics are not visible here -- confirm

In [None]:
# Spot the Blobs
# Build the tracker from the binary field and mask, then run the full pipeline.
# Progress messages are printed as each stage finishes (see the cell output below).

tracker = blob.Spotter(extreme_bin, mask, R_fill=filling_radius, area_filter_quartile=drop_area_quartile, 
                       allow_merging=allow_merging, nn_partitioning=nn_partitioning)
blobs = tracker.run()
blobs  # last expression in the cell: displays the result in the notebook

Finished filling holes.
Finished filtering small blobs.
Finished blob identification.
Finished calculating blob properties.
Finished finding overlapping blobs.
Processed chunk 0 of 20


In [None]:
# Write the tracked blobs to NetCDF (mode='w': an existing file at this path is overwritten)

file_name = Path('/scratch').joinpath(getuser()[0], getuser(), 'mhws', 'MHWs_tracked.nc')
blobs.to_netcdf(file_name, mode='w')

In [None]:
# Step-by-step version of the pipeline: call the individual tracker stages by hand.
# Fill Holes
data_bin_filled = tracker.fill_holes()
print('Finished filling holes.')

# Remove Small Objects
# Besides the filtered field, this also returns the area threshold, the blob areas and the unfiltered blob count.
data_bin_filtered, area_threshold, blob_areas, N_blobs_unfiltered = tracker.filter_small_blobs(data_bin_filled)
print('Finished filtering small blobs.')

In [None]:
# Continue with the hole-filled, size-filtered binary field
data_bin = data_bin_filtered

In [None]:
# Label the Blobs in the binary field
# NOTE(review): `time_connectivity=False` presumably labels each time step independently -- confirm.
# The second return value is discarded.
blob_id_field, _ = tracker.identify_blobs(data_bin, time_connectivity=False)
print('Finished blob identification.')

# Calculate Properties of each Blob
blob_props = tracker.calculate_blob_properties(blob_id_field, properties=['area', 'centroid'])
print('Finished calculating blob properties.')

# Compile List of Overlapping Blob ID Pairs Across Time
overlap_blobs_list = tracker.find_overlapping_blobs(blob_id_field)  # List of overlapping blob pairs
print('Finished finding overlapping blobs.')

In [None]:
# Apply the splitting/merging step to the labelled field.
# Outputs (by name): the updated ID field, the updated blob properties,
# the updated overlap list, and the recorded merge events.
(
    split_merged_blob_id_field_unique,
    merged_blobs_props,
    split_merged_blobs_list,
    merge_events,
) = tracker.split_and_merge_blobs(blob_id_field, blob_props, overlap_blobs_list)

In [None]:
import xarray as xr
import numpy as np
from dask_image.ndmeasure import label
from skimage.measure import regionprops_table
from dask_image.ndmorph import binary_closing as binary_closing_dask
from dask_image.ndmorph import binary_opening as binary_opening_dask
from scipy.ndimage import binary_closing, binary_opening
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
from dask import persist
from dask.base import is_dask_collection
from numba import jit, prange
import warnings

In [None]:
# Rebind the split/merge outputs to the shorter names used by the cells below.
# N.B.: this overwrites the earlier `overlap_blobs_list` with the post-split/merge list.
blob_id_field_unique = split_merged_blob_id_field_unique
blobs_props = merged_blobs_props
overlap_blobs_list = split_merged_blobs_list

In [None]:
# Cluster blob IDs that are linked (directly or transitively) by an overlap,
# by solving for the connected components of the overlap graph.

IDs = np.unique(overlap_blobs_list) # 1D sorted unique

# Create a mapping from ID to indices
# (kept for any later cell that wants a dict lookup; the conversion below no longer needs it)
ID_to_index = {ID: index for index, ID in enumerate(IDs)}

# Convert overlap pairs to indices.
# `IDs` is sorted and contains every ID that appears in the pairs, so a binary
# search returns exactly each ID's position in `IDs` -- the same result as a
# per-pair dict lookup, but vectorised instead of a Python-level loop.
overlap_pairs = np.asarray(overlap_blobs_list)[:, :2]
overlap_pairs_indices = np.searchsorted(IDs, overlap_pairs)

# Create a sparse matrix representation of the graph
n = len(IDs)
row_indices, col_indices = overlap_pairs_indices.T
data = np.ones(len(overlap_pairs_indices))
graph = csr_matrix((data, (row_indices, col_indices)), shape=(n, n))

# Solve the graph to determine connected components
# (directed=False: an overlap links both blobs regardless of pair order)
num_components, component_IDs = connected_components(csgraph=graph, directed=False, return_labels=True)

# Group IDs by their component index
ID_clusters = [[] for _ in range(num_components)]
for ID, component_ID in zip(IDs, component_IDs):
    ID_clusters[component_ID].append(ID)

In [None]:
# Dense lookup table: position = old blob ID, value = new (cluster) ID.
# Entries that are never assigned to a cluster keep the int32-minimum sentinel.
min_int32 = np.iinfo(np.int32).min
max_old_ID = blob_id_field_unique.max().compute().data
ID_to_cluster_index_array = np.full(max_old_ID + 1, min_int32, dtype=np.int32)

# Fill the lookup array with cluster indices.
# Labels start at 1 so that ID = 0 is still invalid/no object. Because these
# are the connected IDs, there are many fewer of them than original IDs.
for cluster_label, member_IDs in enumerate(ID_clusters, start=1):
    ID_to_cluster_index_array[member_IDs] = np.int32(cluster_label)

# N.B.: **Need to pass da into apply_ufunc, otherwise it doesn't manage the memory correctly with large shared-mem numpy arrays**
ID_to_cluster_index_da = xr.DataArray(
    ID_to_cluster_index_array,
    dims='ID',
    coords={'ID': np.arange(max_old_ID + 1)},
)

def map_IDs_to_indices(block, ID_to_cluster_index_array):
    """Relabel one block of blob IDs through a lookup table.

    Parameters
    ----------
    block : integer array of blob IDs; values <= 0 are treated as "no object".
    ID_to_cluster_index_array : 1D array indexed by old ID, holding the new ID.

    Returns
    -------
    int32 array with the shape of `block`: the looked-up ID where `block > 0`,
    and 0 everywhere else.
    """
    relabelled = np.zeros_like(block, dtype=np.int32)
    has_blob = block > 0
    # Only positive IDs are looked up, so background cells never index the table
    relabelled[has_blob] = ID_to_cluster_index_array[block[has_blob]]
    return relabelled

# Relabel the whole ID field lazily: `map_IDs_to_indices` is applied to each (ydim, xdim) slice,
# with the full lookup table (core dim 'ID') passed alongside every slice.
# vectorize=True loops the function over the remaining (non-core) dimensions.
split_merged_relabeled_blob_id_field = xr.apply_ufunc(
    map_IDs_to_indices,
    blob_id_field_unique, 
    ID_to_cluster_index_da,
    input_core_dims=[[tracker.ydim, tracker.xdim],['ID']],
    output_core_dims=[[tracker.ydim, tracker.xdim]],
    vectorize=True,
    dask="parallelized",
    output_dtypes=[np.int32]
)

In [None]:
# Start computing the relabelled field now and keep the result in memory,
# because the following cells read it several times.
split_merged_relabeled_blob_id_field = split_merged_relabeled_blob_id_field.persist()

In [None]:
### Relabel the blobs_props to match the new IDs (and add time dimension!)

# Cluster labels were assigned as (component index + 1), so the valid new IDs are
# exactly 1..num_components; 0 stays reserved for "no object".
# (Previously this ran to num_components + 1, creating one phantom ID that is
#  never present in the field but still received start/end times downstream.)
max_new_ID = num_components
new_ids = np.arange(1, max_new_ID + 1, dtype=np.int32)

# New blobs_props DataSet Structure: one entry per (new ID, time step)
blobs_props_extended = xr.Dataset(coords={
    'ID': new_ids,
    tracker.timedim: blob_id_field_unique[tracker.timedim]
})

## Create a mapping from new IDs to the original IDs _at the corresponding time_
# Cells without a valid new ID (<= 0, including the unassigned sentinel) are masked out.
valid_new_ids = (split_merged_relabeled_blob_id_field > 0)
# Flatten the two spatial dimensions into a single dimension 'z'
original_ids = blob_id_field_unique.where(valid_new_ids).stack(z=(tracker.ydim, tracker.xdim), create_index=False)
new_ids_field = split_merged_relabeled_blob_id_field.where(valid_new_ids).stack(z=(tracker.ydim, tracker.xdim), create_index=False)

id_mapping = xr.Dataset({
    'original_id': original_ids,
    'new_id': new_ids_field
})

In [None]:
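# NOTE: The original per-ID loop (kept commented out below for reference) creates a very large number of Dask tasks;
#       it is superseded by the `apply_ufunc` version further down in this cell.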



# transformed_arrays = []
# for new_id in new_ids:
    
#     mask = id_mapping.new_id == new_id
#     mask_time = mask.any('z')
    
#     original_ids = id_mapping.original_id.where(mask, 0).max(dim='z').where(mask_time, 0)
    
#     transformed_arrays.append(original_ids)

# global_id_mapping = xr.concat(transformed_arrays, dim='new_id').assign_coords(new_id=new_ids).rename({'new_id': 'ID'}).astype(np.int32).compute()
# blobs_props_extended['global_ID'] = global_id_mapping
# # N.B.: Now, e.g. global_id_mapping.sel(ID=10) --> Given the new ID (10), returns corresponding original_id at every time



def map_ids_chunk(original_ids, new_ids_field, target_new_id):
    """Find the original ID that corresponds to one new ID within a flattened field.

    Parameters
    ----------
    original_ids : 1D array of original blob IDs.
    new_ids_field : 1D array of new blob IDs, aligned element-wise with `original_ids`.
    target_new_id : the new ID to search for.

    Returns
    -------
    Length-1 int32 array holding the largest original ID found where
    `new_ids_field == target_new_id`, or 0 if the new ID does not occur.
    """
    hits = new_ids_field == target_new_id
    if np.any(hits):
        return np.array([np.max(original_ids[hits])], dtype=np.int32)
    # Target not present: 0 marks "no object"
    return np.zeros(1, dtype=np.int32)

# Build the (time, new ID) -> original ID table with a single apply_ufunc call:
# `map_ids_chunk` reduces over the stacked spatial dimension 'z' for every
# combination of the remaining dimensions and each entry of `new_ids`.
# NOTE(review): `map_ids_chunk` returns a length-1 array while `output_core_dims` declares a scalar output -- confirm this is intended.
result = xr.apply_ufunc(
    map_ids_chunk,
    id_mapping.original_id,
    id_mapping.new_id,
    xr.DataArray(new_ids, dims='new_id'),
    input_core_dims=[['z'], ['z'], []],
    output_core_dims=[[]],
    vectorize=True,
    dask='parallelized',
    output_dtypes=[np.int32],
    dask_gufunc_kwargs={'allow_rechunk': True}
)

# Label the 'new_id' dimension with the new IDs, rename it to 'ID', and load the result into memory.
# N.B.: Afterwards, e.g. global_id_mapping.sel(ID=10) returns the original ID of new ID 10 at every time (0 where absent).
global_id_mapping = (result
    .assign_coords(new_id=new_ids)
    .rename({'new_id': 'ID'})
    .astype(np.int32)
    .compute())

blobs_props_extended['global_ID'] = global_id_mapping



In [None]:

## Transfer and transform all variables from original blobs_props:

# Add a value of ID = 0 to this coordinate ID: an all-NaN row that absent
# (global_ID == 0) entries select in the lookup below.
# Guarded so that re-running this cell does not prepend a duplicate ID = 0,
# which would make the `.sel(ID=...)` lookup fail on a non-unique index.
if 0 not in blobs_props['ID'].values:
    dummy = blobs_props.isel(ID=0) * np.nan
    blobs_props = xr.concat([dummy.assign_coords(ID=0), blobs_props], dim='ID')


# For each property, pick the original blob's value for every (time, new ID) entry
id_lookup = global_id_mapping.rename({'ID': 'new_id'})
for var_name in blobs_props.data_vars:

    temp = (blobs_props[var_name]
                        .sel(ID=id_lookup)
                        .drop_vars('ID').rename({'new_id': 'ID'}))

    if var_name == 'ID':
        temp = temp.astype(np.int32)
    else:
        temp = temp.astype(np.float32)

    blobs_props_extended[var_name] = temp


# Add start and end time indices for each ID
valid_presence = blobs_props_extended['global_ID'] > 0  # Where we have valid data

blobs_props_extended['presence'] = valid_presence

# Use the tracker's time-dimension name throughout (instead of hard-coding `.time`)
# and reverse explicitly along that dimension (instead of relying on it being first).
time_coord = valid_presence[tracker.timedim]
n_time = valid_presence.sizes[tracker.timedim]

# argmax over a boolean array gives the index of the first True along time
first_index = valid_presence.argmax(dim=tracker.timedim)
# The first True of the time-reversed array, counted from the end, is the last True
reversed_presence = valid_presence.isel({tracker.timedim: slice(None, None, -1)})
last_index = (n_time - 1) - reversed_presence.argmax(dim=tracker.timedim)

blobs_props_extended['time_start'] = time_coord[first_index]
blobs_props_extended['time_end'] = time_coord[last_index]

# Combine blobs_props_extended with split_merged_relabeled_blob_id_field
split_merged_relabeled_blobs_ds = xr.merge([split_merged_relabeled_blob_id_field.rename('ID_field'),
                                            blobs_props_extended])

In [None]:
# Start computing the combined dataset (ID field + per-blob properties) and keep it in memory
split_merged_relabeled_blobs_ds = split_merged_relabeled_blobs_ds.persist()