# Visualise the tracked marine heatwaves (MHWs)

In [None]:
import xarray as xr
import numpy as np
import dask
from getpass import getuser
from pathlib import Path

import spot_the_blOb.helper as hpc

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import matplotlib.colors as mcolors
import cartopy.crs as ccrs
import cartopy.feature as cfeature

In [None]:
# Spin up a local Dask cluster via the project helper: 32 workers with 4 threads each.
client = hpc.StartLocalCluster(
    n_workers=32,
    n_threads=4,
)

In [None]:
# Import Tracked MHW DataSet

file_name = Path('/scratch') / getuser()[0] / getuser() / 'mhws' / 'MHWs_tracked.nc'
blobs_ds = xr.open_dataset(file_name, chunks={'time': 50, 'lat': -1, 'lon': -1})
blobs_ds

In [None]:
# Pull out the per-grid-cell MHW label field (ID_field) as a DataArray.

blobs = blobs_ds['ID_field']
blobs

## Plot monthly snapshots of the MHW blobs (2020)

In [None]:
# Monthly snapshots (first non-NaN value of each month) for calendar year 2020.
# The end date matches the plotting cell below; ending at '2021-01-01' would add
# a 13th (January 2021) snapshot that is never plotted but still skews min/max.
blob_subset = blobs.sel(time=slice('2020-01-01', '2020-12-31')).resample(time='MS').first()
# Evaluate both reductions in one dask pass rather than two separate .compute() calls.
minl, maxl = (reduced.item() for reduced in dask.compute(blob_subset.min(), blob_subset.max()))

In [None]:
# One random RGB colour per integer label in the inclusive range [minl, maxl],
# hence the +1. A fixed seed keeps label colours identical between runs.
cm = ListedColormap(np.random.default_rng(0).random(size=(int(maxl - minl) + 1, 3)))

In [None]:
# First (non-NaN) value of each month in 2020, one facet per month, 3 per row,
# coloured with the random label colormap.
blobs_first_day = (
    blobs
    .sel(time=slice('2020-01-01', '2020-12-31'))
    .resample(time='MS')
    .first()
)
blobs_first_day.plot(cmap=cm, col='time', col_wrap=3)

## Global MHW Frequency

In [None]:
# Fraction of time steps at which each cell carries a non-zero, non-NaN label.
# Cells with ID 0 or NaN count as "no MHW" (contribute 0 to the mean).
mhw_frequency = xr.where((blobs != 0) & ~np.isnan(blobs), 1.0, 0.0).mean('time')
mhw_frequency.plot(cmap='hot_r')

## Find the longest MHWs

In [None]:
# Build a (time, label) boolean array: True where the label appears anywhere in
# that time step's lat/lon field.
# NOTE(review): labels run 0..N-1, so label 0 is included. The frequency cell
# treats ID 0 as "no MHW", so label 0 is presumably background — confirm.
final_objects_tracked = blobs.attrs['final objects tracked']
labels = np.arange(final_objects_tracked)

occurrence_array = xr.apply_ufunc(
    # np.isin flattens its second argument, so the 2-D field can be passed as-is.
    lambda blobs_data, labels: np.isin(labels, blobs_data),
    blobs,
    input_core_dims=[['lat', 'lon']],
    output_core_dims=[['label']],
    vectorize=True,  # call the function once per time step
    dask='parallelized',
    output_dtypes=[bool],
    # output_sizes as a direct apply_ufunc argument is deprecated; it belongs here.
    dask_gufunc_kwargs={'output_sizes': {'label': final_objects_tracked}},
    kwargs={'labels': labels},
)

In [None]:
# Count of time steps in which each label is present (True summed over time), in memory.
label_occurrence = occurrence_array.sum('time').compute()

In [None]:
# Labels ordered from most to fewest time steps present. The labels are
# np.arange(...) starting at 0, so argsort positions equal label values.
longest_mhws = label_occurrence.argsort()[::-1]
# Drop label 0. The frequency cell treats ID 0 as "no MHW" (background), so it
# likely spans nearly every time step and would otherwise top this ranking.
# The filter is harmless if 0 never occurs.
longest_mhws = longest_mhws[(longest_mhws != 0).values]

In [None]:
# Report the ten labels present for the most time steps.
# NOTE(review): "days" assumes one time step per day — confirm the data frequency.
for label in longest_mhws.values[:10]:
    n_present = label_occurrence.sel(label=label).item()
    print(f"Label: {label}, Time: {n_present} days")

## Plot a few long MHWs

In [None]:
# Unused vectorised alternative to the per-label loop in the next cell. It would
# compare `blobs` against the top nine labels at once (broadcasting along the
# 'label' dimension) and sum over time, giving one map per label.
# mhw_intensity = xr.where(blobs == longest_mhws[:9], 1, 0).sum(dim='time')

In [None]:
# For each of the nine longest-lived labels, map how many time steps every grid
# cell belongs to that label.
# NOTE(review): the colorbar says "days", which assumes daily time steps — confirm.
fig, axes = plt.subplots(3, 3, figsize=(12, 12))
axes = axes.flatten()

# Plain ints: iterating the DataArray would yield 0-d DataArrays, whose f-string
# form is a multi-line repr rather than the label number.
top_labels = [int(v) for v in longest_mhws[:9].values]

for ax, label in zip(axes, top_labels):
    # Per-cell count of time steps carrying this label; computed eagerly so
    # matplotlib receives an in-memory array.
    mhw_intensity = (blobs == label).sum(dim='time').compute()
    # pcolormesh renders large regular grids far faster than pcolor.
    c = ax.pcolormesh(mhw_intensity, cmap='hot_r')
    cbar = fig.colorbar(c, ax=ax, orientation='vertical')
    cbar.set_label('Duration (days)')
    ax.set_title(f'Label: {label}')

# Hide panels left unused when fewer than nine labels exist.
for ax in axes[len(top_labels):]:
    ax.set_visible(False)

# Adjust layout
plt.tight_layout()
plt.show()