# Notebook for Z500 + IVT SOM Training

By: Ty Janoski

Updated 1/22/2026

## Setup

### Imports

In [30]:
# Import Statements
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cmweather  # noqa: F401
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scienceplots  # noqa: F401
import xarray as xr
from minisom import MiniSom
from sklearn.manifold import MDS
from sklearn.metrics import pairwise_distances

# Use the SciencePlots "science" + "nature" + "grid" styles for all figures
plt.style.use(["science", "nature", "grid"])
# Render all text through LaTeX; requires a working LaTeX installation, and
# LaTeX-special characters in titles/labels (e.g. "&") must be escaped
plt.rcParams["text.usetex"] = True


In [31]:
# --- Helper Functions ---

def get_node_indices(bmus, i, j):
    """Return the sample positions whose best-matching unit is node (i, j).

    ``bmus`` holds one (i, j) node coordinate per sample, i.e. column 0 is the
    node's i index and column 1 its j index.
    """
    in_node = np.logical_and(bmus[:, 0] == i, bmus[:, 1] == j)
    return np.flatnonzero(in_node)


def compute_composites(data, bmus, xdim, ydim, time_dim="valid_time"):
    """Average ``data`` over the samples assigned to each SOM node.

    Parameters
    ----------
    data : xarray.DataArray
        Field with a sample dimension named ``time_dim``; its samples are in
        the same order as the rows of ``bmus``.
    bmus : ndarray
        Best-matching-unit (i, j) coordinates, one row per sample.
    xdim, ydim : int
        SOM grid shape.
    time_dim : str
        Name of the sample dimension in ``data``.

    Returns
    -------
    composites : ndarray
        Shape (xdim, ydim, *field_shape); NaN for nodes with no samples.
    counts : ndarray of int
        Shape (xdim, ydim): number of samples assigned to each node.
    """
    field_shape = data.isel({time_dim: 0}).shape
    composites = np.full((xdim, ydim, *field_shape), np.nan)
    counts = np.zeros((xdim, ydim), dtype=int)

    for i, j in np.ndindex(xdim, ydim):
        members = get_node_indices(bmus, i, j)
        counts[i, j] = members.size
        if members.size:
            composites[i, j] = data.isel({time_dim: members}).mean(time_dim).values
    return composites, counts


def create_som_figure(xdim, ydim, figsize=(6, 4), dpi=600):
    """Create a map figure with one PlateCarree panel per SOM node.

    The grid has ``ydim`` rows and ``xdim`` columns, so node (i, j) is drawn
    on ``axes[j, i]``. ``squeeze=False`` guarantees ``axes`` is always a 2-D
    array, so ``axes[j, i]`` indexing also works for 1xN or Nx1 SOMs.

    Parameters
    ----------
    xdim, ydim : int
        SOM grid shape.
    figsize : tuple
        Figure size in inches.
    dpi : int
        Figure resolution.

    Returns
    -------
    fig : matplotlib.figure.Figure
    axes : ndarray of axes, shape (ydim, xdim)
    """
    fig, axes = plt.subplots(
        ydim, xdim, figsize=figsize,
        subplot_kw={"projection": ccrs.PlateCarree()},
        constrained_layout=True, dpi=dpi,
        squeeze=False,
    )
    return fig, axes


def add_map_features(ax):
    """Draw coastlines and 50m state borders on ``ax``, then hide its ticks."""
    overlays = (
        (cfeature.COASTLINE, 0.5),
        (cfeature.STATES.with_scale("50m"), 0.4),
    )
    for overlay, width in overlays:
        ax.add_feature(overlay, linewidth=width)
    ax.set_xticks([])
    ax.set_yticks([])


def plot_node_events(data, bmus, xdim, ydim, lon, lat, levels, cmap,
                     save_pattern, scale=1.0, cbar_label=None,
                     time_dim="valid_time", contour=False):
    """Save one multi-panel figure per SOM node showing each member event.

    Parameters
    ----------
    data : xarray.DataArray
        Field with a sample dimension ``time_dim`` whose order matches ``bmus``.
    bmus : ndarray
        Best-matching-unit (i, j) coordinates, one row per sample.
    xdim, ydim : int
        SOM grid shape.
    lon, lat : array-like
        Plot coordinates passed to ``contour``/``contourf``.
    levels : array-like
        Contour levels, in the units of ``data * scale``.
    cmap : str, Colormap or None
        Colormap for filled contours; unused when ``contour=True``.
    save_pattern : str
        Output path containing ``{i}`` and ``{j}`` placeholders.
    scale : float
        Multiplier applied to the field before plotting (e.g. unit conversion).
    cbar_label : str or None
        Colorbar label; no colorbar is drawn when None or when ``contour=True``.
    time_dim : str
        Name of the sample dimension in ``data``.
    contour : bool
        Draw labeled black line contours instead of filled contours.

    Nodes with no assigned samples are skipped and no file is written for
    them, because a figure with zero rows cannot be created.
    """
    cols = 5
    proj = ccrs.PlateCarree()

    for i in range(xdim):
        for j in range(ydim):
            idx = get_node_indices(bmus, i, j)
            n = len(idx)
            if n == 0:
                # plt.subplots(0, cols) raises ValueError; nothing to draw anyway
                continue
            rows = int(np.ceil(n / cols))

            # squeeze=False keeps `axes` 2-D even when one row suffices
            fig, axes = plt.subplots(
                rows, cols, figsize=(3 * cols, 2.5 * rows),
                subplot_kw={"projection": proj},
                layout="constrained", dpi=300, squeeze=False,
            )

            for k, ax in enumerate(axes.flat):
                if k >= n:
                    # Unused trailing panels of the last row
                    ax.axis("off")
                    continue

                field = data.isel({time_dim: idx[k]})
                values = field.values * scale

                if contour:
                    im = ax.contour(
                        lon, lat, values,
                        levels=levels, colors="black",
                        transform=proj, linewidths=0.6,
                    )
                    ax.clabel(im, im.levels, fontsize=5)
                else:
                    im = ax.contourf(
                        lon, lat, values,
                        levels=levels, cmap=cmap,
                        transform=proj, extend="max",
                    )
                ax.add_feature(cfeature.COASTLINE, linewidth=0.5)
                ax.add_feature(cfeature.BORDERS, linewidth=0.3)
                ax.add_feature(cfeature.STATES, linewidth=0.2)
                # First 16 characters of the timestamp: "YYYY-MM-DD HH:MM"
                ax.set_title(str(pd.to_datetime(field[time_dim].values))[:16])

            if cbar_label and not contour:
                cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6, pad=0.02)
                cbar.set_label(cbar_label, fontsize=6)

            fig.suptitle(f"Node ({i},{j})  N={n}", fontsize=8, y=1.02)
            # The suptitle sits above the figure (y > 1); a tight bbox keeps it
            # from being clipped, as in the notebook's other savefig calls
            fig.savefig(save_pattern.format(i=i, j=j), bbox_inches="tight")
            plt.close(fig)

### Data Loading

In [32]:
# Directory holding the ERA5 fields sampled at flash-flood event (FFE) times
path = "/mnt/drive2/SOM_intermediate_files/"

# IVT files are Datasets: select the "ivt" variable from each
ivt_norm_weighted_ffe, ivt_norm_ffe, ivt_ffe = (
    xr.load_dataset(path + fname)["ivt"]
    for fname in (
        "era5_ivt_norm_weighted_ffe.nc",
        "era5_ivt_norm_ffe.nc",
        "era5_ivt_ffe.nc",
    )
)

# Z500, total precipitation and mean sea level pressure are single-DataArray files
z500_norm_weighted_ffe, z500_norm_ffe, z500_ffe, tp_ffe, mslp_ffe = (
    xr.load_dataarray(path + fname)
    for fname in (
        "era5_Z500_norm_weighted_ffe.nc",
        "era5_Z500_norm_ffe.nc",
        "era5_Z500_ffe.nc",
        "era5_tp_ffe.nc",
        "era5_mslp_ffe.nc",
    )
)


### Reshape Data

In [33]:
# Flatten each field for SOM training. stack() puts the new "features"
# dimension last, so the remaining event axis comes first.
z500_stacked = z500_norm_weighted_ffe.stack(features=["lat", "lon"])
ivt_stacked = ivt_norm_weighted_ffe.stack(features=["latitude", "longitude"])
z500_flat = z500_stacked.values  # shape: (time, features)
ivt_flat = ivt_stacked.values  # shape: (time, features)

# Rows of X pair Z500 and IVT purely by position, and the two fields name
# their time dimension differently, so verify they list the same events in
# the same order before concatenating.
z500_event_times = z500_stacked[z500_stacked.dims[0]].values
ivt_event_times = ivt_stacked[ivt_stacked.dims[0]].values
np.testing.assert_array_equal(
    z500_event_times, ivt_event_times,
    err_msg="Z500 and IVT event times are not aligned",
)

# The node weights are later split into equal Z500 and IVT halves, so both
# grids must contribute the same number of features
assert z500_flat.shape == ivt_flat.shape, (
    f"Z500 {z500_flat.shape} and IVT {ivt_flat.shape} inputs differ in shape"
)

# Concatenate the data: each row is [Z500 features | IVT features]
X = np.concatenate((z500_flat, ivt_flat), axis=1)  # shape: (time, 2 * features)


## SOM Training

We train the SOM with random weight initialization and online (sample-by-sample) training, drawing samples in random order. Training runs in two phases: a short "coarse" ordering phase with a large neighborhood radius (sigma) and learning rate, followed by a longer "fine" phase with a smaller sigma and learning rate.

### Set SOM parameters

In [34]:
# SOM grid shape (xdim x ydim nodes)
xdim, ydim = 2, 2

# Random seed for reproducibility
random_seed = 42

# Two-phase "literature style" schedule: a short rough ordering phase
# followed by a longer fine phase with sigma = 1
n1, n2 = 500, 2000  # iterations per phase

# Starting neighborhood radii: phase 1 spans the grid diagonal for global
# ordering; phase 2 uses sigma = 1 for localized refinement
sig1 = np.sqrt(xdim**2 + ydim**2)
sig2 = 1.0

# Starting learning rates for phase 1 and phase 2
lr1 = 0.5
lr2 = 0.1

### Train SOM

In [35]:
# Create SOM instance with the phase-1 ("coarse") settings.
# Per the decay-function names, the learning rate decays linearly toward zero
# and sigma linearly toward one — presumably over the iterations of each
# train_random call (TODO confirm against the installed MiniSom version).
som = MiniSom(
    xdim,
    ydim,
    input_len=X.shape[1],
    sigma=sig1,
    learning_rate=lr1,
    decay_function="linear_decay_to_zero",
    sigma_decay_function="linear_decay_to_one",
    neighborhood_function="gaussian",
    random_seed=random_seed,
)

# Initialize random weights (drawn from the samples in X)
som.random_weights_init(X)

# Phase 1 ("coarse"): n1 online updates with samples picked in random order
som.train_random(X, n1, verbose=True)

# Phase 2 ("fine"): restart the schedules from the smaller sigma and
# learning rate, then train for n2 more iterations.
# NOTE(review): this overwrites private MiniSom attributes; re-check after
# MiniSom upgrades.
som._sigma = sig2 # type: ignore
som._learning_rate = lr2
som.train_random(X, n2, verbose=True)


 [ 500 / 500 ] 100% - 0:00:00 left 
 quantization error: 137.66825300968708
 [ 2000 / 2000 ] 100% - 0:00:00 left 
 quantization error: 135.04204269972294


### Grab important fields

In [36]:
# Total node number
n_nodes = xdim * ydim

# Flattened weights: one row per node in row-major (i, j) order, k = i * ydim + j
weights = som.get_weights().reshape(n_nodes, -1)

# U-matrix from MiniSom.distance_map(), transposed to (ydim, xdim) for imshow
u_matrix = som.distance_map().T

# Best-matching unit (i, j) of every sample, one row per row of X
bmus = np.array([som.winner(x) for x in X])

# Hit map: samples per node, also transposed to (ydim, xdim)
hit_map = np.zeros((xdim, ydim))
np.add.at(hit_map, (bmus[:, 0], bmus[:, 1]), 1)
hit_map = hit_map.T

# 2-D embedding of the node weights for the "Sammon" distortion plot.
# This is a metric MDS embedding of the pairwise weight distances, used as an
# approximation of a Sammon map rather than a true Sammon mapping.
D = pairwise_distances(weights)
coords = MDS(
    n_components=2,
    metric="precomputed",
    random_state=42,
    n_init=4,
    init="random",  # pyright: ignore[reportCallIssue]
).fit_transform(D)

# Lat/lon of the IVT grid; the Z500 grid is assumed to have the same size
lat = ivt_norm_ffe.latitude
lon = ivt_norm_ffe.longitude

# Dimensions of the spatial field
n_lat = lat.size
n_lon = lon.size
n_features = n_lat * n_lon

# Each weight vector is [Z500 features | IVT features]; fail loudly if the
# halves do not match this grid instead of mis-splitting them
assert weights.shape[1] == 2 * n_features, (
    f"expected {2 * n_features} weight features, got {weights.shape[1]}"
)
z500_weights = weights[:, :n_features]
ivt_weights = weights[:, n_features:]

# Reshape weights back to spatial dimensions: (xdim, ydim, n_lat, n_lon)
z500_nodes = z500_weights.reshape(xdim, ydim, n_lat, n_lon)
ivt_nodes = ivt_weights.reshape(xdim, ydim, n_lat, n_lon)


In [55]:
# Save BMUs with timestamps for cross-SOM analysis.
# Rows of `bmus` follow the sample order of X, which is assumed to match the
# order of ivt_ffe.valid_time; at least check that the lengths agree.
timestamps = pd.to_datetime(ivt_ffe.valid_time.values)
assert len(timestamps) == len(bmus), "BMUs and event timestamps differ in length"

# Name the file after the actual grid shape so another SOM size does not
# produce a misleadingly named (or overwritten) file
bmu_path = f"data/som_{xdim}x{ydim}_bmus.csv"
bmu_df = pd.DataFrame({
    "timestamp": timestamps,
    "node_i": bmus[:, 0],
    "node_j": bmus[:, 1],
})
bmu_df.to_csv(bmu_path, index=False)
print(f"Saved {len(bmu_df)} BMU assignments to {bmu_path}")

Saved 117 BMU assignments to data/som_2x2_bmus.csv


## Plots

### U-Matrix, Hit Map, and Sammon/MDS Map

In [37]:
fig, axes = plt.subplots(1, 2, layout="constrained", figsize=(6, 3), dpi=600)

# Left panel: U-matrix; right panel: hit map. Both grids are stored as
# (ydim, xdim), so with origin="lower" the x-axis is i and the y-axis is j.
panel_specs = [
    (u_matrix, "viridis", "U-Matrix (Mean Inter-Node Distance)"),
    (hit_map, "plasma", "Hit Map (Samples per Node)"),
]
for ax, (grid, colormap, title) in zip(axes, panel_specs):
    shown = ax.imshow(grid, cmap=colormap, origin="lower")
    ax.set_title(title, fontsize=7)
    fig.colorbar(shown, ax=ax, fraction=0.046, pad=0.04, shrink=0.7)

    ax.set_xticks(np.arange(xdim))
    ax.set_yticks(np.arange(ydim))
    ax.set_xlabel("X-index", fontsize=6)
    ax.set_ylabel("Y-index", fontsize=6)

fig.savefig("figs/Z500-and-ivt-SOM/Z500_ivt_som_u_matrix_hit_map.png")
plt.close(fig)

In [38]:
# Back to node order k = i * ydim + j (undo the transpose used for imshow)
U_flat = u_matrix.T.ravel()
hits_flat = hit_map.T.ravel()

# Marker area grows linearly with hit count: 30 (no hits) up to 280 (busiest node)
hits_scaled = 30 + 250 * (hits_flat / hits_flat.max())

fig, ax = plt.subplots(figsize=(7, 7))

# Nodes at their MDS coordinates: color = U-matrix value, size = hit count
sc = ax.scatter(
    coords[:, 0], coords[:, 1],
    c=U_flat, s=hits_scaled,
    cmap="balance", edgecolor="k", linewidth=0.5,
    zorder=3,
)

# SOM lattice edges in MDS space: each node to its (i, j+1) and (i+1, j) neighbors
for i in range(xdim):
    for j in range(ydim):
        node = i * ydim + j
        neighbors = []
        if j + 1 < ydim:
            neighbors.append(node + 1)
        if i + 1 < xdim:
            neighbors.append(node + ydim)
        for nbr in neighbors:
            ax.plot(
                coords[[node, nbr], 0], coords[[node, nbr], 1],
                "k-", lw=0.6, alpha=0.4,
            )

# Label every node with its (i, j) grid position
for node, (x, y) in enumerate(coords):
    ix, iy = divmod(node, ydim)
    ax.text(x, y, f"({ix},{iy})", fontsize=8, ha="center", va="center", zorder=5)

ax.set_title("Sammon / MDS Distortion Grid\nU-Matrix (Color) \\& Node Frequency (Size)")
ax.axis("off")
fig.colorbar(sc, ax=ax, label="U-Matrix (Avg. Neighbor Distance)")
fig.savefig("figs/Z500-and-ivt-SOM/Z500-ivt_som_sammon_mds.png", bbox_inches="tight")
plt.close(fig)

### Node Weights Map

In [39]:
# Shading/contour levels for the standardized node weights
levels_ivt = np.arange(-1.8, 1.81, 0.2)
levels_Z = np.arange(-1.4, 1.41, 0.2)

fig, axes = create_som_figure(xdim, ydim)

for i in range(xdim):
    for j in range(ydim):
        ax = axes[j, i]  # node (i, j) -> column i, row j

        # IVT shaded; extend="both" fills weights beyond +/-1.8 with the end
        # colors instead of leaving those regions blank (matching the
        # anomaly composite map)
        im = ax.contourf(
            lon, lat, ivt_nodes[i, j],
            cmap="balance", levels=levels_ivt,
            transform=ccrs.PlateCarree(), extend="both",
        )

        # Z500 contours, labeled
        cn = ax.contour(
            lon, lat, z500_nodes[i, j],
            colors="black", linewidths=0.5,
            levels=levels_Z, transform=ccrs.PlateCarree(),
        )
        ax.clabel(cn, inline=True, fontsize=5, fmt="%.1f")

        add_map_features(ax)
        ax.set_title(f"Node ({i},{j})", fontsize=6)

cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6, pad=0.02)
cbar.set_label("Standardized IVT Anomaly (Shaded)", fontsize=6)

fig.suptitle("Node Weight Patterns\nZ500 (contoured) + IVT (shaded)", fontsize=8)
fig.savefig("figs/Z500-and-ivt-SOM/combined_node_weights_ivt_shaded.png", bbox_inches="tight")
plt.close(fig)

### Anomaly Composite Map

In [40]:
# Standardized-anomaly composites per node. Z500 and IVT name their sample
# dimension differently ("time" vs "valid_time"). `counts` (samples per node)
# comes from the Z500 call and is reused in later plot titles.
z500_patterns, counts = compute_composites(
    z500_norm_ffe, bmus, xdim, ydim, time_dim="time"
)
ivt_patterns = compute_composites(
    ivt_norm_ffe, bmus, xdim, ydim, time_dim="valid_time"
)[0]

In [41]:
# Shading (IVT) and contour (Z500) levels for the standardized anomaly composites
levels_ivt_anom = np.arange(-2.5, 2.6, 0.5)
levels_Z_anom = np.arange(-2.0, 2.1, 0.25)

fig, axes = create_som_figure(xdim, ydim)
data_crs = ccrs.PlateCarree()

for i, j in np.ndindex(xdim, ydim):
    ax = axes[j, i]  # node (i, j) -> column i, row j

    # Composite IVT anomaly, shaded
    im = ax.contourf(
        lon, lat, ivt_patterns[i, j],
        cmap="balance", levels=levels_ivt_anom,
        transform=data_crs, extend="both",
    )

    # Composite Z500 anomaly, unlabeled line contours
    ax.contour(
        lon, lat, z500_patterns[i, j],
        colors="black", linewidths=0.5,
        levels=levels_Z_anom, transform=data_crs,
    )

    add_map_features(ax)
    ax.set_title(f"({i},{j})  N={counts[i, j]}", fontsize=6)

cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6, pad=0.02)
cbar.set_label("Standardized Anomaly", fontsize=6)

fig.suptitle("SOM Composite Anomalies: Z500 (contoured) + IVT (shaded)", fontsize=8, y=1.04)
fig.savefig("figs/Z500-and-ivt-SOM/Z500_and_ivt_SOM_composite_anomalies_ivt_shaded.png", bbox_inches="tight")
plt.close(fig)

### Composite Mean Map

In [42]:
# Composites of the raw (non-standardized) fields per node; the node counts
# are identical to the anomaly composites, so they are not kept here
z500_patterns_raw = compute_composites(
    z500_ffe, bmus, xdim, ydim, time_dim="time"
)[0]
ivt_patterns_raw = compute_composites(
    ivt_ffe, bmus, xdim, ydim, time_dim="valid_time"
)[0]

In [43]:
# Z500 contours 552–594 every 3; IVT shading 0–700 every 100
levels_Z_raw = range(552, 595, 3)
levels_ivt_raw = np.arange(0, 701, 100)

fig, axes = create_som_figure(xdim, ydim)
data_crs = ccrs.PlateCarree()

for i, j in np.ndindex(xdim, ydim):
    ax = axes[j, i]  # node (i, j) -> column i, row j

    # Mean IVT, shaded
    im = ax.contourf(
        lon, lat, ivt_patterns_raw[i, j],
        cmap="BuPu", levels=levels_ivt_raw,
        transform=data_crs, extend="max",
    )

    # Mean Z500, labeled contours.
    # NOTE(review): assumes z500_ffe is geopotential in m^2 s^-2, so dividing
    # by 98.1 (= g * 10) gives decameters — confirm units
    z500_scaled = z500_patterns_raw[i, j] / 98.1
    cn = ax.contour(
        lon, lat, z500_scaled,
        colors="black", linewidths=0.5,
        levels=levels_Z_raw, transform=data_crs,
    )
    ax.clabel(cn, inline=True, fontsize=5, fmt="%.0f")

    add_map_features(ax)
    ax.set_title(f"({i},{j})  N={counts[i, j]}", fontsize=6)

cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6, pad=0.02)
cbar.set_label("IVT (kg m$^{-1}$ s$^{-1}$)", fontsize=6)

fig.suptitle("SOM Composite: IVT (shaded) + Z500 (contoured)", fontsize=8, y=1.04)
fig.savefig("figs/Z500-and-ivt-SOM/Z500_and_ivt_SOM_composite_mean_IVT_shaded.png", bbox_inches="tight")
plt.close(fig)

### Maps of Individual Events in Each Node

In [45]:
# Plot individual IVT events per node
plot_node_events(
    ivt_ffe, bmus, xdim, ydim, lon, lat,
    levels=np.arange(0, 1001, 100), cmap="BuPu",
    save_pattern="figs/Z500-and-ivt-SOM/indiv-nodes/node_{i}_{j}.png",
    cbar_label="IVT (kg m$^{-1}$ s$^{-1}$)",
)

# Plot individual precipitation events per node
# (scale=1000 converts to mm, assuming tp_ffe is in meters — confirm)
plot_node_events(
    tp_ffe, bmus, xdim, ydim, lon, lat,
    levels=np.arange(0, 28, 3), cmap="HomeyerRainbow",
    save_pattern="figs/Z500-and-ivt-SOM/indiv-nodes/node_{i}_{j}_precip.png",
    scale=1000, cbar_label="Total Precipitation (mm)",
)

# Plot individual MSLP events per node as labeled line contours
# (scale=0.01 converts to hPa, assuming mslp_ffe is in Pa — confirm).
# Saved with this Z500 + IVT SOM's other per-node panels, not under
# figs/ivt-SOM/, which belongs to a different SOM.
plot_node_events(
    mslp_ffe, bmus, xdim, ydim, lon, lat,
    levels=np.arange(976, 1041, 4), cmap=None,
    save_pattern="figs/Z500-and-ivt-SOM/indiv-nodes/node_{i}_{j}_mslp.png",
    scale=0.01, contour=True,
)

### Maps of Composite MSLP for each SOM node

In [46]:
# Mean MSLP per node (node counts were already computed with the anomaly composites)
mslp_patterns = compute_composites(
    mslp_ffe, bmus, xdim, ydim, time_dim="valid_time"
)[0]

In [47]:
fig, axes = create_som_figure(xdim, ydim, figsize=(6, 2.7))

# MSLP shading levels in hPa. extend="both" colors composite values outside
# 1006–1024 hPa with the end colors instead of leaving those regions blank.
levels_mslp = np.arange(1006, 1025, 2)

for i in range(xdim):
    for j in range(ydim):
        ax = axes[j, i]  # node (i, j) -> column i, row j

        # /100 converts to hPa (assumes mslp_ffe is in Pa, as the label implies)
        im = ax.contourf(
            lon, lat, mslp_patterns[i, j] / 100,
            cmap="HomeyerRainbow", levels=levels_mslp,
            extend="both", transform=ccrs.PlateCarree(),
        )
        add_map_features(ax)
        ax.set_title(f"({i},{j})  N={counts[i, j]}", fontsize=6)

cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6, pad=0.02)
cbar.set_label("MSLP (hPa)", fontsize=6)

fig.suptitle("SOM Composite MSLP", fontsize=8, y=1.04)
fig.savefig("figs/Z500-and-ivt-SOM/z500_and_ivt_som_composite_mslp.png", bbox_inches="tight")
plt.close(fig)

### Maps of Probability-Matched Mean (PMM) Precipitation for Each SOM Node

In [48]:
def process_dim(
    da: xr.DataArray,
    ens_dim: str = "valid_time",
    spatial_dims: tuple = ("lat", "lon"),
) -> xr.DataArray:
    """Probability-matched mean (PMM) of ``da`` over ``ens_dim``.

    The spatial pattern comes from the mean over ``ens_dim``. Its values are
    then replaced, rank for rank, by a subsample of the pooled and sorted
    values of all members. The result keeps the mean's spatial ranking but
    takes its amplitudes from the members' pooled distribution.

    Parameters
    ----------
    da : xr.DataArray
        Field with dimension ``ens_dim`` plus the dimensions in ``spatial_dims``.
    ens_dim : str
        Dimension treated as ensemble members (here: the events of one node).
    spatial_dims : tuple of str
        Spatial dimensions, stacked into a single point axis internally.

    Returns
    -------
    xr.DataArray
        PMM field over ``spatial_dims``.
    """
    # Mean field flattened to a single "i" axis, reordered from lowest to highest value
    ranked_mean = da.mean(dim=ens_dim).stack(i=spatial_dims)
    ranked_mean = ranked_mean.sortby(ranked_mean)

    # Every value of every member, pooled and sorted ascending
    pooled = np.sort(da.stack(z=(ens_dim, *spatial_dims)).values)

    # Keep every `stride`-th pooled value, starting at the smallest, so there is
    # exactly one value per grid point.
    # NOTE(review): starting at index 0 takes the minimum of each stride-sized
    # group, which biases the PMM low; an offset of stride // 2 would be centered.
    n_points = ranked_mean.size
    stride = len(pooled) // n_points
    ranked_mean.values = pooled[::stride][:n_points]

    # Unstack so each assigned value returns to its own grid point
    return ranked_mean.unstack("i")


def pmm(
    da: xr.DataArray,
    ens_dim: str = "valid_time",
    spatial_dims: tuple = ("lat", "lon"),
) -> xr.DataArray:
    """Probability-matched mean of ``da`` over ``ens_dim``.

    Public entry point; forwards all arguments unchanged to ``process_dim``.
    """
    result = process_dim(da, ens_dim, spatial_dims)
    return result


In [49]:
# PMM precipitation per node: (xdim, ydim, n_lat, n_lon), NaN for empty nodes.
# NOTE(review): n_lat/n_lon come from the IVT grid; tp_ffe is assumed to share it.
patterns = np.full((xdim, ydim, n_lat, n_lon), np.nan)

for i, j in np.ndindex(xdim, ydim):
    members = get_node_indices(bmus, i, j)
    if members.size == 0:
        continue
    node_pmm = pmm(
        tp_ffe.isel(valid_time=members),
        ens_dim="valid_time",
        spatial_dims=("latitude", "longitude"),
    )
    patterns[i, j] = node_pmm.values

In [50]:
# Shading levels for PMM precipitation: 0–14 every 2
levels = np.arange(0, 16, 2)

fig, axes = create_som_figure(xdim, ydim, figsize=(6, 2.7))
data_crs = ccrs.PlateCarree()

for i, j in np.ndindex(xdim, ydim):
    ax = axes[j, i]  # node (i, j) -> column i, row j

    # NOTE(review): *1000 assumes tp_ffe is in meters (-> mm); confirm units
    tp_mm = patterns[i, j] * 1000
    im = ax.contourf(
        lon, lat, tp_mm,
        cmap="viridis", levels=levels,
        extend="max", transform=data_crs,
    )
    add_map_features(ax)
    ax.set_title(f"({i},{j})  N={counts[i, j]}", fontsize=6)

cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6, pad=0.02)
cbar.set_label("Total Prec. (mm)", fontsize=6)

fig.suptitle("SOM Composite PMM Prec.", fontsize=8, y=1.04)
fig.savefig("figs/Z500-and-ivt-SOM/Z500_and_ivt_som_composite_tp.png", bbox_inches="tight")
plt.close(fig)

### Month Histograms

In [51]:
# Calendar month (1–12) of every event, in sample order
months = pd.to_datetime(ivt_ffe.valid_time).month.to_numpy()

# Per-node event counts for January..December (array index 0 = January)
month_counts = {
    (i, j): np.bincount(months[get_node_indices(bmus, i, j)], minlength=13)[1:]
    for i in range(xdim)
    for j in range(ydim)
}

In [52]:
# squeeze=False keeps `axes` 2-D so axes[j, i] also works for 1xN / Nx1 SOMs
fig, axes = plt.subplots(
    ydim, xdim, figsize=(6, 2.7), constrained_layout=True, dpi=600, squeeze=False,
)

# Warm-season labels and the matching slice of the Jan..Dec count arrays
month_labels = ["May", "Jun", "Jul", "Aug", "Sep", "Oct"]
warm_season = slice(4, 10)  # months 5–10 -> indices 4:10

# Shared y-axis for all panels: 0–14, expanded if any node-month exceeds 14
y_max = max([14, *(int(c[warm_season].max()) for c in month_counts.values())])

for i in range(xdim):
    for j in range(ydim):
        ax = axes[j, i]  # node (i, j) -> column i, row j

        # Distinct name so the global node-count array `counts` (used for the
        # "N=" titles of the composite maps) is not overwritten
        warm_counts = month_counts[(i, j)][warm_season]

        ax.bar(month_labels, warm_counts, color="teal", alpha=0.9, width=0.8)

        # Title matches the SOM composite style; N counts warm-season events only
        ax.set_title(f"({i},{j})  N={warm_counts.sum()}", fontsize=6)

        # Hide x tick marks (categorical labels don't need them)
        ax.tick_params(axis="x", bottom=False, labelsize=5)

        ax.set_ylim(0, y_max)
        ax.set_yticks(np.arange(0, y_max + 1, 2))

        # Light grid for readability
        ax.grid(True, linewidth=0.3, alpha=0.5, axis="y")

# Overall title
fig.suptitle(
    "Warm-Season (May–Oct) Event Distribution per SOM Node", fontsize=8, y=1.04
)
fig.savefig(
    "figs/Z500-and-ivt-SOM/Z500_and_ivt_som_monthly_counts.png", bbox_inches="tight"
)
plt.close(fig)


In [53]:
# Node order for every row below: k = i * ydim + j, matching the (i, j) loops elsewhere
node_keys = [(i, j) for i in range(xdim) for j in range(ydim)]
node_labels = [f"({i},{j})" for i, j in node_keys]

# (n_nodes, 12) event counts per node and calendar month
node_month_counts = np.array([month_counts[key] for key in node_keys])

# Overall node probability P(Node) across all events
node_totals = node_month_counts.sum(axis=1)
total_events = node_totals.sum()
P_node = node_totals / total_events

# Total events per month (all nodes combined)
all_month_counts = np.bincount(months, minlength=13)[1:]  # 1–12

# Only warm season (May–Oct)
month_idx = np.arange(4, 10)  # indices for May–Oct

n_nodes = xdim * ydim

# P(Node | Month). Months without events give NaN instead of a
# divide-by-zero warning. Local names avoid clobbering the global `counts`.
warm_counts = node_month_counts[:, month_idx]
month_totals = all_month_counts[month_idx]
heatmap = np.divide(
    warm_counts, month_totals,
    out=np.full(warm_counts.shape, np.nan), where=month_totals > 0,
)

# Relative likelihood P(Node | Month) / P(Node); NaN for nodes with no events
relative_heatmap = np.divide(
    heatmap, P_node[:, None],
    out=np.full(heatmap.shape, np.nan), where=P_node[:, None] > 0,
)



In [54]:
fig, ax = plt.subplots(figsize=(6, 4), dpi=600)

# Diverging colormap with limits 0–2, so a ratio of 1 (no seasonal
# preference) sits at the center
im = ax.imshow(relative_heatmap, aspect="auto", cmap="RdBu_r", vmin=0, vmax=2)

# Ticks: one column per warm-season month, one row per SOM node
ax.set(
    xticks=np.arange(len(month_idx)),
    yticks=np.arange(n_nodes),
    xlabel="Month",
    ylabel="SOM Node",
)
ax.set_xticklabels(["May", "Jun", "Jul", "Aug", "Sep", "Oct"], fontsize=7)
ax.set_yticklabels(node_labels, fontsize=6)

cbar = fig.colorbar(im, ax=ax)
cbar.set_label("Relative Likelihood", fontsize=7)

ax.set_title(
    "Monthly Relative Likelihood of SOM Nodes\n"
    "(Normalized by Seasonal Event Frequency)",
    fontsize=8,
)

fig.tight_layout()
fig.savefig(
    "figs/Z500-and-ivt-SOM/Z500_and_ivt_som_monthly_relative_heatmap.png",
    bbox_inches="tight",
)
plt.close(fig)
