# Demonstration of accessing PyFLEXTRKR MCS tracking outputs

This notebook is based on Zhe Feng's [Notebook](https://github.com/FlexTRKR/PyFLEXTRKR/blob/main/Notebooks/demo_mcs_track_stats_healpix.ipynb) and modified for UM model and tracking data. It demonstrates how to access and perform simple analysis on the MCS tracking outputs from PyFLEXTRKR. It also gives example code of the basic methods of loading the data using `xarray`, for both the tracks data and the HEALPix pixel-level data, as well as linking the two together. Some simple analysis is done.

This example largely follows the one developed by Mark Muetzelfeldt, with updates for the pixel-level data in the HEALPix-format catalog.

* Authors:
- Torsten Auerswald (t.auerswald@reading.ac.uk)
- Zhe Feng (zhe.feng@pnnl.gov)
- Mark Muetzelfeldt (mark.muetzelfeldt@reading.ac.uk)

### Track statistics files

The MCS track statistics datasets are NetCDF4 files. These are laid out as 1D, 2D and 3D arrays of data, with the main coordinate being `tracks`, and variable-length data stored in a fixed-length array (using a maximum duration of 650, equivalent to an MCS that lasts for 650 hr). A consequence of this is that all the data beyond the duration of a given MCS will be NaNs or similar, and so the compression ratio of files on disk is very high. Hence, even though a file for a given year is only approximately 500 MB on disk, its size in memory will be far larger. Fully loading one year using `xr.load_dataset` uses approximately 16 GB of memory. It is therefore sensible to use `xr.open_dataset` or `xr.open_mfdataset`, which access the dataset's metadata but do not load its data until they are needed.
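
The fixed-length layout described above can be illustrated with a small `numpy` sketch. The array sizes, durations and area values here are made up for illustration (the real files use 650 time steps); only the NaN-padding idea is taken from the text.

```python
import numpy as np

max_len = 8                      # illustrative only; the real files use 650
durations = np.array([3, 5, 2])  # hypothetical track lifetimes in time steps

# One row per track; entries past each track's duration stay NaN.
area = np.full((len(durations), max_len), np.nan, dtype=np.float32)
for i, d in enumerate(durations):
    area[i, :d] = 1000.0 * (i + 1)

# The in-memory size depends on the fixed length, not on the real durations.
print(f"bytes in memory: {area.nbytes}")

# Slice by duration to recover only the valid part of one track.
valid = area[1, :durations[1]]

# NaN-aware reductions ignore the padding.
mean_per_track = np.nanmean(area, axis=1)
print(mean_per_track)
```

Because every track occupies a full row regardless of its lifetime, the memory footprint scales with the number of tracks times the maximum duration, which is why lazy loading is the safer default.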

Note that the compression settings of a field can be read from its `encoding` attribute on the `xarray.Dataset`, e.g. `dstracks.area.encoding`. The tracks dataset is compressed using compression level 4.
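
To get a feel for why NaN-padded arrays compress so well, here is a rough sketch that uses Python's `zlib` at level 4 as a stand-in for the NetCDF4 deflate filter. The number of tracks, the durations and the area values are all assumptions, so the ratio it prints is only indicative and will not match the real files.

```python
import zlib
import numpy as np

rng = np.random.default_rng(0)
n_tracks, max_len = 1000, 650  # 650 is the fixed length quoted above; n_tracks is made up
durations = rng.integers(4, 60, size=n_tracks)  # hypothetical lifetimes in time steps

# Real values for the first `duration` entries of each row, NaN afterwards.
area = np.full((n_tracks, max_len), np.nan, dtype=np.float32)
for i, d in enumerate(durations):
    area[i, :d] = rng.uniform(1e4, 1e6, size=d)

raw_bytes = area.nbytes
compressed_bytes = len(zlib.compress(area.tobytes(), 4))
nan_fraction = float(np.isnan(area).mean())

print(f"NaN fraction: {nan_fraction:.2f}")
print(f"raw: {raw_bytes} B, compressed: {compressed_bytes} B, "
      f"ratio: {raw_bytes / compressed_bytes:.1f}")
```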

### Pixel-level files

The pixel-level files are Zarr stores in HEALPix format. The OLR (`rlut`) and precipitation (`pr`) data are part of the model catalog, while the MCS mask data (`mcs_mask_hp(x)_`) are remapped from the lat/lon grid used in PyFLEXTRKR during tracking. The MCS mask data are being added to the catalog.

### Using `xarray` to access data

`xarray` is a convenient way of loading NetCDF or Zarr files in Python. Fields can generally be manipulated using `xarray` methods or by loading the values as `numpy` arrays and manipulating those. This notebook requires the following Python packages to be correctly installed (using e.g. `conda`).

### Installing pyflextrkr ###

This notebook uses a pyflextrkr function to smooth MCS trajectories. Since pyflextrkr is not included in the standard hackathon environment, install it manually if you want to use that function. In the Jupyter interface, open a new tab and choose the terminal. In the terminal, activate your hackathon environment, then install the package with:

`conda install pyflextrkr`

If you already created a notebook kernel from your hackathon environment, pyflextrkr should be available in your notebooks now. If you have not installed the kernel from your hackathon environment, install the kernel with:

`python -m ipykernel install --user --name=name-of-environment`


In [None]:
# from pathlib import Path
import cftime
import cartopy.crs as ccrs
import cartopy.geodesic
import cartopy.feature as cf
import dask
from IPython.display import clear_output
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.animation
import numpy as np
import pandas as pd
import xarray as xr
import easygems.healpix as egh
import intake
from pyflextrkr.smooth_trajectory import smooth_trajectory

import requests
from io import BytesIO

mpl.rcParams['figure.dpi'] = 72  # Set figure resolution to reduce file size
# Use for .mp4 video:
plt.rcParams["animation.html"] = "html5"
# Use for javascript animation:
# plt.rcParams["animation.html"] = "jshtml"

### Load the catalog
We have one data catalog for the global hackathon, listing our datasets.
But as we have multiple hosting sites, which have some datasets available locally and can access other datasets remotely, the best way to access data may be dependent on the location **where analysis code is executed**.
To solve this issue, we have one sub-catalog per hackathon node (the site where analysis code is executed), and an additional `online` catalog, which is available from the public internet. Here's how you can see our currently available sub-catalogs:

In [None]:
list(intake.open_catalog("https://digital-earths-global-hackathon.github.io/catalog/catalog.yaml"))

In [None]:
# Load the UK catalog
current_location = "UK"
cat = intake.open_catalog("https://digital-earths-global-hackathon.github.io/catalog/catalog.yaml")[current_location]

In [None]:
list(cat)

### Pick a Data Set
Use `.describe()` on a dataset to see the other parameter options (we use `pandas` just for concise output formatting).

In [None]:
pd.DataFrame(cat["um_glm_n2560_RAL3p3"].describe()["user_parameters"])

### Load HEALPix Data into a DataSet
Two simulations have been tracked. You can choose the N2560 with `zoom` level 10 [(~6km)](https://easy.gems.dkrz.de/Processing/healpix/index.html#healpix-spatial-resolution) or the N1280 with `zoom` level 9. The datasets contain hourly OLR and precipitation, which were used in the MCS tracking.
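
As a rough check on the resolutions quoted above, the grid spacing can be estimated from the zoom level. This sketch assumes the usual HEALPix relations (`nside = 2**zoom`, `12 * nside**2` equal-area cells) and a mean Earth radius of 6371 km. The function names are illustrative and are not part of `easygems`.

```python
import math

EARTH_RADIUS_KM = 6371.0  # mean Earth radius; an assumption for this estimate


def healpix_npix(zoom: int) -> int:
    """Number of HEALPix cells at a given zoom level."""
    nside = 2 ** zoom
    return 12 * nside ** 2


def healpix_resolution_km(zoom: int) -> float:
    """Approximate cell size: square root of the area of one equal-area cell."""
    sphere_area = 4.0 * math.pi * EARTH_RADIUS_KM ** 2
    return math.sqrt(sphere_area / healpix_npix(zoom))


for zoom in (9, 10):
    print(zoom, healpix_npix(zoom), round(healpix_resolution_km(zoom), 1))
```

Each increase of one zoom level quadruples the number of cells and halves the approximate cell size.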

In [None]:
#N2560
#runid='um_glm_n2560_RAL3p3'
runid='um_glm_n1280_CoMA9_TBv1p2'

dshp = cat[runid](zoom=9).to_dask() 
dshp


In [None]:
#Get timestep of the data in hours
dt=dshp.time.values[1].astype('datetime64[h]')-dshp.time.values[0].astype('datetime64[h]')
dt

### Load MCS Tracking Data

In [None]:
# MCS data

if runid=='um_glm_n2560_RAL3p3':
  maskdir = "https://hackathon-o.s3-ext.jc.rl.ac.uk/sim-data/analysis/PyFLEXTRKR/um_glm_n2560_RAL3p3/mcstracking/mcs_mask_hp10_20200201.0000_20210301.0000.zarr"
  statsdir = "https://hackathon-o.s3-ext.jc.rl.ac.uk/sim-data/analysis/PyFLEXTRKR/um_glm_n2560_RAL3p3/stats/"
elif runid=='um_glm_n1280_CoMA9_TBv1p2':
  maskdir = "https://hackathon-o.s3-ext.jc.rl.ac.uk/sim-data/analysis/PyFLEXTRKR/um_glm_n1280_CoMA9_TBv1p2_catalog_par/mcstracking/mcs_mask_hp9_20200201.0000_20210301.0000.zarr"
  statsdir = "https://hackathon-o.s3-ext.jc.rl.ac.uk/sim-data/analysis/PyFLEXTRKR/um_glm_n1280_CoMA9_TBv1p2_catalog_par/stats/"
else:
  raise ValueError(f"Dataset {runid} not available.")

# MCS track statistics file
stats_file = f"{statsdir}mcs_tracks_final_20200201.0000_20210301.0000.nc"
stats_file

In [None]:
response = requests.get(stats_file, stream=True)
response.raise_for_status()  # Stop here if the download failed

# chunks_tracks = {'tracks': 1000, 'times': 400}
chunks_tracks = None

# Open the dataset
dstracks = xr.open_dataset(
    BytesIO(response.content), 
    chunks=chunks_tracks,
    mask_and_scale=True)

# # Each separate file for each year defines its own index for tracks. Re-index with a global index.
# dstracks["tracks"] = np.arange(0, dstracks.dims["tracks"], 1, dtype=int)

In [None]:
dstracks.chunks

In [None]:
# # Check if the dataset uses dask
# print(f"Dataset uses dask: {dstracks.chunks is not None}")

# # Check if variables are dask arrays
# for var_name in list(dstracks.data_vars)[:3]:  # First 3 variables as example
#     print(f"Variable {var_name} is dask array: {isinstance(dstracks[var_name].data, dask.array.Array)}")

In [None]:
# The times have a small offset from the exact times -- e.g. 34500 ns off. Correct this.
# This mostly applies to the satellite data
def round_times_to_nearest_second(dstracks, fields):
    def remove_time_inaccuracy(t):
        # To make this an array operation, you have to use the ns version of datetime64, like so:
        return (np.round(t.astype(int) / 1e9) * 1e9).astype("datetime64[ns]")

    for field in fields:
        dstracks[field].load()
        tmask = ~np.isnan(dstracks[field].values)
        dstracks[field].values[tmask] = remove_time_inaccuracy(
            dstracks[field].values[tmask]
        )

In [None]:
round_times_to_nearest_second(dstracks, ['base_time', 'start_basetime', 'end_basetime'])

In [None]:
dstracks

In [None]:
# Note, 'NaT' means nan (track does not exist at those times)
dstracks.base_time

In [None]:
# Values can be accessed as a `numpy` array:
dstracks.ccs_area.values

In [None]:
# Compression level info:
dstracks.ccs_area.encoding

### Selecting tracks

In [None]:
# A single track can be selected from its track number:
track = dstracks.sel(tracks=20)
track

In [None]:
# You might want to select tracks based on e.g. the time at which they are active:
datetime = pd.Timestamp('2020-02-01 12:00').to_numpy()
# `isel` selects on index. The expression on the RHS collapses the 2D field of `base_time` into a 
# 1D boolean field that is true if *any* base_time for a given track matches `datetime`.
dstracks_at_time = dstracks.isel(
     tracks=(dstracks.base_time.values == datetime).any(axis=1)
)
dstracks_at_time

In [None]:
# Or access tracks based on their location:
# N.B. force a load of meanlat.
dstracks.meanlat.load()
# This suppresses a warning about chunk sizes.
with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    dstracks_tropical = dstracks.isel(
        tracks=((dstracks.meanlat.values > -20) & (dstracks.meanlat.values < 20)).any(axis=1)
    )
dstracks_tropical

In [None]:
# Each track can then be looped over using e.g.:
for track_id in dstracks_at_time.tracks.values[:10]:  # Just get first 10.
    track = dstracks_at_time.sel(tracks=track_id)
    print(f"Track ID: {track.tracks.values.item()}, duration: {track.track_duration.values.item()}")

### Individual track properties

In [None]:
# Select a track.
track = dstracks.sel(tracks=50)
track

In [None]:
# Access some of the track's properties:
# Scalars can be accessed using:
duration = track.track_duration.values.item()  # `.values.item()` unpacks the value into an int (in this case).
print(f'track_duration: {duration*dt}')
# For times, it is useful to convert from a np.datetime64[ns] to a pandas Timestamp object or native datetime object:
start_basetime = pd.Timestamp(track.start_basetime.values.item()).to_pydatetime()
print(f'start_basetime: {start_basetime}')
end_basetime = pd.Timestamp(track.end_basetime.values.item()).to_pydatetime()
print(f'end_basetime: {end_basetime}')

# Lifetime values can be accessed:
# Note, e.g. area values are nan after the duration of the track:
print(f'area (full): {track.area.values}')
# Slice based on duration:
print(f'area (sliced): {track.area.values[:duration]}')

In [None]:
# Plot multiple variables for this track

fig, ax1 = plt.subplots(figsize=(8, 4))
track.ccs_area.plot.line(ax=ax1, label='CCS Area')

ax2 = ax1.twinx()
track.corecold_mintb.plot(ax=ax2, color='green', label='Min Tb')
ax2.spines['right'].set_position(('outward', 60))  # Offset this y-axis to the right so it does not overlap the others
ax2.invert_yaxis()

ax3 = ax1.twinx()
track.pf_maxrainrate.isel(nmaxpf=0).plot(ax=ax3, color='orange', label='Max Rain Rate')
ax3.set_title('')

ax4 = ax1.twinx()
track.total_rain.plot(ax=ax4, color='red', label='Total Rain')
ax4.spines['right'].set_position(('outward', 120))  # Offset this y-axis further to the right

fig.legend()

In [None]:
# Merger data are 2D fields (with -9999 indicating no value for merger number N):
track.merge_cloudnumber.values[:duration, :]

In [None]:
# Similarly for PF data (with nan indicating no value):
track.pf_rainrate.values[:duration, :]

In [None]:
# A simple plot of the track's position can be made using:
plt.scatter(track.meanlon.values[0], track.meanlat.values[0], marker='o', color='k')  # Start point.
plt.scatter(track.meanlon.values[duration - 1], track.meanlat.values[duration - 1], marker='x', color='darkorange')  # End point.
# Plot the track's path:
plt.plot(track.meanlon.values, track.meanlat.values, label='Original Track Path')
# Smooth the trajectory using the `smooth_trajectory` function:
time_resolution = dstracks.attrs['time_resolution_hour']
max_speed_kmh = 100
lon_s, lat_s = smooth_trajectory(track.meanlon.values, track.meanlat.values, max_speed_kmh=max_speed_kmh, time_step_h=time_resolution)
plt.plot(lon_s, lat_s, color='red', label='Smoothed Track Path')
plt.legend()

In [None]:
def add_ccs_area_swath(ax, track, n_points=20):
    """Adds an area swath of the cold cloud system area, treating each CCS as a circle."""
    try:
        # N.B. these are optional dependencies.
        import shapely.geometry
        import shapely.ops
    except ImportError:
        print('shapely not installed')
        return
    duration = track.track_duration.values.item()
    time_indices = range(duration)
    
    # geoms will contain all the circles.
    geoms = []
    for i in time_indices:
        lon = track.meanlon.values[i]
        lat = track.meanlat.values[i]
        radius = np.sqrt(track.ccs_area.values[i] / np.pi) * 1e3
        circle_points = cartopy.geodesic.Geodesic().circle(
            lon=lon, lat=lat, radius=radius, n_samples=n_points, endpoint=False
        )
        geom = shapely.geometry.Polygon(circle_points)
        geoms.append(geom)
    # Combine all the circles into a CCS swath.
    full_geom = shapely.ops.unary_union(geoms)
    ax.add_geometries(
        (full_geom,),
        crs=cartopy.crs.PlateCarree(),
        facecolor="none",
        edgecolor="royalblue",
        linewidth=2,
    )

In [None]:
extent = track.meanlon.min().item() - 10, track.meanlon.max().item() + 10, track.meanlat.min().item() - 10, track.meanlat.max().item() + 10

In [None]:
# Nicer figure using cartopy projections and showing a circle based on the CCS area.
fig, ax = plt.subplots(subplot_kw={'projection': ccrs.PlateCarree()})
fig.set_size_inches((8, 4))
ax.coastlines()
ax.add_feature(cf.BORDERS, linewidth=0.4)
ax.scatter(track.meanlon.values[0], track.meanlat.values[0], marker='o', color='k')
ax.scatter(track.meanlon.values[duration - 1], track.meanlat.values[duration - 1], marker='x', color='darkorange')
ax.plot(track.meanlon.values, track.meanlat.values)
ax.plot(lon_s, lat_s, color='red')
add_ccs_area_swath(ax, track)
buffer = 10
extent = track.meanlon.min().item() - buffer, track.meanlon.max().item() + buffer, track.meanlat.min().item() - buffer, track.meanlat.max().item() + buffer
ax.set_extent(extent)

### Group properties of tracks

Group properties for lots of tracks can be easily calculated by accessing the fields on an `xarray.Dataset` that contains many tracks.

In [None]:
# mean track duration:
dstracks.track_duration.values.mean()

In [None]:
# Same thing using xarray:
dstracks.track_duration.mean().values.item()

In [None]:
# Tropical duration:
dstracks_tropical.track_duration.values.mean()

In [None]:
# Mean area for each track:
# `np.nanX` functions are useful for naturally dealing with the missing data.
mean_areas = np.nanmean(dstracks.ccs_area.values, axis=1)

In [None]:
plt.hist(mean_areas, bins=np.linspace(0, 1e6, 101))

In [None]:
# Same thing using xarray:
xr_mean_areas = dstracks.ccs_area.mean(dim='times', skipna=True).values

In [None]:
(mean_areas == xr_mean_areas).all()

## Accessing pixel-level data

Pixel-level data are stored separately from the MCS tracks data. These files contain the `mcs_mask` field, which marks the area covered by each MCS and can be linked to a given track (see below).

In [None]:
# Read MCS mask data:
#dsmask = xr.open_dataset(
dsmask = xr.open_zarr(
    maskdir,
    chunks={},
    # mask_and_scale=False,
)
dsmask

In [None]:
pd.Timestamp(dshp.time.values[0]).year

In [None]:
pd.Timestamp(dsmask.time.values[0]).year

In [None]:
# Convert time components from datasets to deal with different calendars
hp_times = [(pd.Timestamp(t).year, pd.Timestamp(t).month, 
               pd.Timestamp(t).day, pd.Timestamp(t).hour) for t in dshp.time.values]
mask_times = [(pd.Timestamp(t).year, pd.Timestamp(t).month, 
               pd.Timestamp(t).day, pd.Timestamp(t).hour) 
              for t in dsmask.time.values]

# Find matching indices (build the lookup set once rather than on every iteration)
mask_times_set = set(mask_times)
matching_indices = [i for i, t in enumerate(hp_times) if t in mask_times_set]
# Select using indices
dshp_subset = dshp.isel(time=matching_indices)
# Verify number of timestamps
print(f"Original HEALPix time count: {dshp.time.size}")
print(f"Matching times found: {len(matching_indices)}")
print(f"dsmask time count: {dsmask.time.size}")

In [None]:
# Create a new dsmask with the time coordinate from dshp_subset
dsmask_ = dsmask.copy()

# Replace the time coordinate with the matching times from dshp_subset
# This must maintain the same order as in the original dsmask
dsmask_ = dsmask_.assign_coords(time=dshp_subset.time.values)

# Combine the datasets
ds = xr.merge([dshp_subset, dsmask_], combine_attrs='drop_conflicts')

# Rechunk the dataset to match the mask
ds = ds.chunk(dict(dsmask_.chunks))
ds

In [None]:
# Quick plot of global OLR
im = egh.healpix_show(ds.isel(time=0).rlut, vmin=80, vmax=300, cmap='Greys')
plt.colorbar(im, label=f"{ds.rlut.attrs['long_name']} ({ds.rlut.attrs['units']})", orientation='horizontal')
time_str = ds.time.isel(time=0).dt.strftime("%Y-%m-%d %H:%M").values.item()
plt.title(time_str)

In [None]:
# Quick plot of global precipitation
levels_pr = [0.5,1,2,4,8,16,32,64]
cmap = mpl.colormaps.get_cmap('jet')
norm = mpl.colors.BoundaryNorm(boundaries=levels_pr, ncolors=cmap.N)
# Precipitation
pcp_convert=3600 
_pr = ds.isel(time=0).pr * pcp_convert # Convert unit from kg/m^2/s to mm/hr, if you use a different model check the units of your precip output 
_pr = _pr.where(_pr >= np.min(levels_pr), np.nan)

im = egh.healpix_show(_pr, norm=norm, cmap=cmap)
plt.colorbar(im, label=f"{ds.pr.attrs['long_name']} (mm/h)", orientation='horizontal')
time_str = ds.time.isel(time=0).dt.strftime("%Y-%m-%d %H:%M").values.item()
plt.title(time_str)

## Linking tracks to pixel-level data

Every track has a corresponding track number in the pixel-level mask data. This can be used to link each track to its pixel-level data, using the timestamp and the track number. The timestamp can be used to determine which time step data to load, and the track number references the equivalent field in the HEALPix dataset (as shown in the figure above).
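
The track-number lookup can be sketched on a toy mask. The mask values and the track index below are made up, and the `+1` offset between the stats-file index and the mask numbering follows the convention used in the code later in this notebook.

```python
import numpy as np

# Hypothetical pixel-level mask for one time step: 0 = no MCS, N = track number.
mcs_mask = np.array([0, 0, 3, 3, 0, 7, 7, 7, 3, 0])

track_index = 2                 # index of the track in the stats file (made up)
track_number = track_index + 1  # mask numbering is offset by +1 from the stats index

# Boolean footprint of this track on the pixel grid, and the matching cell indices.
footprint = mcs_mask == track_number
cell_ids = np.flatnonzero(footprint)
print(cell_ids)
```

In the real data the same comparison is applied to the `mcs_mask` field at the time step that matches the track's `base_time`.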

In [None]:
itrack = 6


# Select a tropical track from Feb 2020 that lasted more than 20 and fewer than 40 time steps.
dates = pd.DatetimeIndex(dstracks_tropical.start_basetime.values)  # These containers make it easy to select on year, month...
track = dstracks_tropical.isel(tracks=(
    (dates.year == 2020) & 
    (dates.month == 2) & 
    (dstracks_tropical.track_duration > 20) & 
    (dstracks_tropical.track_duration < 40)
)).isel(tracks=itrack)  # just select one track that meets criteria.
duration = track.track_duration.values.item()
duration

In [None]:
# Create a figure with 1 row and 2 columns
fig = plt.figure(figsize=(16, 8))

# Create left subplot (regular axis)
ax1 = fig.add_subplot(1, 2, 1)
track.ccs_area.plot(ax=ax1, label='CCS Area')
ax1b = ax1.twinx()
track.pf_area.isel(nmaxpf=0).plot(ax=ax1b, color='green', label='PF Area')
# ax1b.spines['right'].set_position(('outward', 60))  # Move the third y-axis to the right
# ax1b.invert_yaxis()
ax1b.set_title('')
ax1.legend()

# Create right subplot (with Cartopy projection)
# Use PlateCarree projection (equirectangular projection)
ax2 = fig.add_subplot(1, 2, 2, projection=ccrs.PlateCarree())

ax2.coastlines()
ax2.add_feature(cf.BORDERS, linewidth=0.4)

ax2.scatter(track.meanlon.values[0], track.meanlat.values[0], marker='o', color='k')
ax2.scatter(track.meanlon.values[duration - 1], track.meanlat.values[duration - 1], marker='x', color='darkorange')
ax2.plot(track.meanlon.values, track.meanlat.values)
add_ccs_area_swath(ax2, track)
buffer = 8
extent = track.meanlon.min().item()-buffer, track.meanlon.max().item()+buffer, track.meanlat.min().item() - buffer, track.meanlat.max().item() + buffer
ax2.set_extent(extent)

fig.tight_layout()
# fig.tight_layout(rect=[0, 0.03, 1, 0.95])  # Make room for suptitle

In [None]:
def convert_to_cftime(datetime, calendar):
    """
    Convert a pandas.Timestamp object to a cftime object based on the calendar type.

    Args:
        datetime: pandas.Timestamp
            Timestamp object to convert.
        calendar: str
            Calendar type.

    Returns:
        cftime object.
    """
    if calendar == 'noleap':
        return cftime.DatetimeNoLeap(datetime.year, datetime.month, datetime.day, datetime.hour, datetime.minute)
    elif calendar == 'gregorian':
        return cftime.DatetimeGregorian(datetime.year, datetime.month, datetime.day, datetime.hour, datetime.minute)
    elif calendar == 'proleptic_gregorian':
        return cftime.DatetimeProlepticGregorian(datetime.year, datetime.month, datetime.day, datetime.hour, datetime.minute)
    elif calendar == 'standard':
        return cftime.DatetimeGregorian(datetime.year, datetime.month, datetime.day, datetime.hour, datetime.minute)
    elif calendar == '360_day':
        return cftime.Datetime360Day(datetime.year, datetime.month, datetime.day, datetime.hour, datetime.minute)
    else:
        raise ValueError(f"Unsupported calendar type: {calendar}")

In [None]:
# The track's times can be used to work out which pixel-level data to load.
calendar = ds['time'].dt.calendar

base_times = track.base_time.values[:duration]
track_dates = [convert_to_cftime(pd.Timestamp(d).to_pydatetime(), calendar) for d in base_times]
print(len(track_dates))
track_dates

In [None]:
# Track number in the mask file is +1 offset to the track index in the stats file.
track_number = track.tracks + 1

In [None]:
# Increase the animation embed limit [MB]
mpl.rcParams['animation.embed_limit'] = 30

In [None]:
%%capture
# Prev line ensures figure not shown until animation.
# Set up a figure to use for the animation below.

margin_degree = 10
minlon = track.meanlon.values[:duration].min()
maxlon = track.meanlon.values[:duration].max()
minlat = track.meanlat.values[:duration].min()
maxlat = track.meanlat.values[:duration].max()
aspect = (maxlon - minlon + 2 * margin_degree) / (maxlat - minlat + 2 * margin_degree)
fig, ax = plt.subplots(subplot_kw={'projection': ccrs.PlateCarree()})
fig.set_size_inches((16, 16 / aspect))
fig.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.95)

In [None]:
# Create an animation of the track, MCS, OLR, and precip.
def plot_track_link_pixel(i):
    try:
        print(f'{i + 1}/{duration}')
        # print(f'{i}/{duration}')
        ax.clear()

        ax.set_extent((minlon - margin_degree, maxlon + margin_degree, minlat - margin_degree, maxlat + margin_degree))
        ax.coastlines(zorder=3)
        ax.add_feature(cf.BORDERS, linewidth=0.4, zorder=3)

        date = track_dates[i]
        ax.set_title(date)
        # Subset the HEALPix dataset to the current date & time
        _ds = ds.sel(time=track_dates[i], method='nearest')

        # Colormaps
        cmap_pr = mpl.colormaps.get_cmap('jet')
        cmap_rlut = mpl.colormaps.get_cmap('Greys')
        levels_pr = [0.5,1,2,4,8,16,32,64]
        norm = mpl.colors.BoundaryNorm(boundaries=levels_pr, ncolors=cmap_pr.N)
        # Precipitation
        _pr = _ds.pr * pcp_convert  # Convert unit to mm/hr
        _pr = _pr.where(_pr >= np.min(levels_pr), np.nan)
        # Mask
        _mask = _ds.mcs_mask.load()
        _mask = xr.where(_mask == track_number, 1, 0)  # 1 inside this track's footprint, 0 elsewhere
        Zm_mask = np.ma.masked_where(_mask == 0, _mask)
        # Colorfill OLR and precipitation
        im_rlut = egh.healpix_show(_ds.rlut, ax=ax, vmin=80, vmax=260, cmap=cmap_rlut, alpha=0.7, zorder=1)
        im_pr = egh.healpix_show(_pr, ax=ax, cmap=cmap_pr, norm=norm, alpha=0.9, zorder=1)
        im_mask = egh.healpix_show(Zm_mask, ax=ax, cmap='Reds', vmin=1, vmax=1.01, alpha=0.5, zorder=2)
        # Contour MCS mask (boundary)
        # im_mask = egh.healpix_contour(_mask, ax=ax, levels=[0.5], colors=['r'], linewidths=1, alpha=1)
        buffer=8
        extent = track.meanlon.min().item()-buffer, track.meanlon.max().item()+buffer, track.meanlat.min().item() - buffer, track.meanlat.max().item() + buffer
        ax.set_extent(extent)
        
        # Display track path.
        ax.scatter(track.meanlon.values[0], track.meanlat.values[0], marker='^', c='maroon', zorder=3)
        ax.scatter(track.meanlon.values[duration - 1], track.meanlat.values[duration - 1], marker='x', c='maroon', zorder=3)
        ax.scatter(track.meanlon.values[i], track.meanlat.values[i], marker='o', c='firebrick', zorder=3)
        ax.plot(track.meanlon.values, track.meanlat.values, 'r-', zorder=2)
        # clear_output(wait=True)

        return ax
    except Exception as e:
        print(f"Error in frame {i}: {str(e)}")

In [None]:
# Create animation with specific settings
anim = matplotlib.animation.FuncAnimation(
    fig, 
    plot_track_link_pixel, 
    frames=duration,  # Number of frames
    # frames=2,
    interval=500,
    blit=False,  # Set to False for complex plots with multiple elements
    cache_frame_data=False  # Disable caching to avoid memory issues
)

anim

* Track: red line with circle showing its centroid.
* MCS mask: red shading.
* OLR: grey filled.
* precipitation: color filled contours.