# Notebook for SOM Training

By: Ty Janoski

Updated 1/11/2026

## Setup

### Imports

In [4]:
# Import Statements
import glob
import os

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cmweather  # noqa: F401
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scienceplots  # noqa: F401
import xarray as xr
from minisom import MiniSom
from sklearn.manifold import MDS
from sklearn.metrics import pairwise_distances

plt.style.use(["science", "nature", "grid"])
plt.rcParams["text.usetex"] = True


### Data Loading

In [5]:
# Read in Z500 at flash-flood event times
path = "/mnt/drive2/SOM_intermediate_files/"

# Z500
Z500_daily = xr.load_dataarray(f"{path}era5_Z500_daily.nc")
Z500_norm_daily = xr.load_dataarray(f"{path}era5_Z500_norm_daily.nc")
Z500_norm_weighted_daily = xr.load_dataarray(f"{path}era5_Z500_norm_weighted_daily.nc")

# IVT
IVT_daily = xr.load_dataset(f"{path}era5_ivt_daily.nc")["ivt"]
IVT_norm_daily = xr.load_dataset(f"{path}era5_ivt_norm_daily.nc")["ivt"]
IVT_norm_weighted_daily = xr.load_dataset(f"{path}era5_ivt_norm_weighted_daily.nc")[
    "ivt"
]


### Reshape Data

In [6]:
# Flatten the data for SOM training
Z500_flat = Z500_norm_weighted_daily.stack(
    features=["lat", "lon"]
).values  # shape: (time, space)
IVT_flat = IVT_norm_weighted_daily.stack(
    features=["latitude", "longitude"]
).values  # shape: (time, space)

X = np.concatenate([Z500_flat, IVT_flat], axis=1)  # shape: (time, space*2)


In [7]:
# Read flash flood events and filter to unique episodes
df = pd.read_csv("data/storm_data_search_results.csv")
df = df[df["EVENT_ID"].astype(str).str.isdigit()].drop_duplicates(
    subset=["EPISODE_ID"], keep="first"
)

# Parse begin datetime: combine date and time, convert to UTC
df["BEGIN_DATETIME"] = (
    pd.to_datetime(
        df["BEGIN_DATE"]
        + " "
        + df["BEGIN_TIME"].fillna(0).astype(int).astype(str).str.zfill(4),
        format="%m/%d/%Y %H%M",
        errors="coerce",
    )
    .dt.tz_localize("US/Eastern", ambiguous="NaT", nonexistent="NaT")
    .dt.tz_convert("UTC")
)

# Extract unique event days (timezone-naive for xarray compatibility)
event_days = sorted(df["BEGIN_DATETIME"].dt.floor("D").dt.tz_localize(None).unique())


## SOM Training

We are going to train our SOM with random initialization and online training. We will also use two phases: a "coarse" phase with a larger sigma and learning rate, then a "fine" phase with a smaller learning rate and sigma.

### Set SOM parameters

In [8]:
# Set SOM shape
xdim, ydim = 5, 4

# Set number of iterations for each phase
n1, n2 = 20000, 30000

# Set starting sigmas
sig1, sig2 = 0.6 * np.sqrt(xdim**2 + ydim**2), 2.0

# Set starting learning rates
lr1, lr2 = 0.3, 0.1

# Random seed for reproducibility
random_seed = 42


### Train SOM

In [9]:
# Create SOM instance
som = MiniSom(
    xdim,
    ydim,
    input_len=X.shape[1],
    sigma=sig1,
    learning_rate=lr1,
    decay_function="linear_decay_to_zero",
    sigma_decay_function="linear_decay_to_one",
    neighborhood_function="gaussian",
    random_seed=random_seed,
)

# Initialize random weights
som.random_weights_init(X)

# Random training
som.train_random(X, n1, verbose=True)
print(som.topographic_error(X))

# Phase 2
som._sigma = sig2  # type: ignore
som._learning_rate = lr2
som.train_random(X, n2, verbose=True)
print(som.topographic_error(X))


 [ 20000 / 20000 ] 100% - 0:00:00 left 
 quantization error: 115.35613821769022
0.0001874062968515742
 [ 30000 / 30000 ] 100% - 0:00:00 left 
 quantization error: 114.37786864359009
0.0020614692653673165


### Grab important fields

In [10]:
# Total node number
n_nodes = xdim * ydim

# Get flattened weights
weights = som.get_weights().reshape(xdim * ydim, -1)

# u-matrix
u_matrix = som.distance_map().T

# bmus & hit_map
bmus = np.array([som.winner(x) for x in X])

hit_map = np.zeros((xdim, ydim))
for i, j in bmus:
    hit_map[i, j] += 1
hit_map = hit_map.T

# Sammon Coordinates
D = pairwise_distances(weights)
coords = MDS(
    n_components=2, dissimilarity="precomputed", random_state=42, n_init=4
).fit_transform(D)

# Get lats/lons
lat = Z500_norm_weighted_daily.lat
lon = Z500_norm_weighted_daily.lon

# Dimensions of the spatial field
n_lat = lat.size
n_lon = lon.size
n_features = n_lat * n_lon

# Split weights into Z500 and IVT components
z500_weights = weights[:, :n_features]
ivt_weights = weights[:, n_features:]

# Reshape weights back to spatial dimensions
z500_nodes = z500_weights.reshape(xdim, ydim, n_lat, n_lon)
ivt_nodes = ivt_weights.reshape(xdim, ydim, n_lat, n_lon)




## Plots

### U-matrix and Sammon Map

In [12]:
fig, axes = plt.subplots(1, 2, layout="constrained", figsize=(6, 3), dpi=600)

# u-matrix
im0 = axes[0].imshow(u_matrix, cmap="viridis", origin="lower")
axes[0].set_title("U-Matrix (Mean Inter-Node Distance)", fontsize=7)
fig.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04, shrink=0.7)

# hit map
im1 = axes[1].imshow(hit_map, cmap="plasma", origin="lower")
axes[1].set_title("Hit Map (Samples per Node)", fontsize=7)
fig.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04, shrink=0.7)

# axis styling
for ax in axes:
    ax.set_xticks(np.arange(xdim))
    ax.set_yticks(np.arange(ydim))
    ax.set_xlabel("X-index", fontsize=6)
    ax.set_ylabel("Y-index", fontsize=6)

plt.savefig("figs/Z500-IVT-big-SOM//Z500_som_u_matrix_hit_map.png")
plt.close()

In [13]:
# Flatten u-matrix & hit map
U_flat = u_matrix.T.reshape(-1)  # back to (n_nodes,)
hits_flat = hit_map.T.reshape(-1)  # back to (n_nodes,)

# scale hits
hits_scaled = 30 + 250 * (hits_flat / hits_flat.max())

# plot
plt.figure(figsize=(7, 7))

# Scatter: U controls color, hits control bubble size
sc = plt.scatter(
    coords[:, 0],
    coords[:, 1],
    c=U_flat,
    s=hits_scaled,
    cmap="balance",
    edgecolor="k",
    linewidth=0.5,
    zorder=3,
)

# Draw lattice connections (right & down neighbors only)
for i in range(xdim):
    for j in range(ydim):
        node = i * ydim + j

        # right neighbor
        if j + 1 < ydim:
            nbr = i * ydim + (j + 1)
            plt.plot(
                [coords[node, 0], coords[nbr, 0]],
                [coords[node, 1], coords[nbr, 1]],
                "k-",
                lw=0.6,
                alpha=0.4,
            )

        # down neighbor
        if i + 1 < xdim:
            nbr = (i + 1) * ydim + j
            plt.plot(
                [coords[node, 0], coords[nbr, 0]],
                [coords[node, 1], coords[nbr, 1]],
                "k-",
                lw=0.6,
                alpha=0.4,
            )

# Node labels (i,j)
for idx, (x, y) in enumerate(coords):
    ix, iy = divmod(idx, ydim)
    plt.text(x, y, f"({ix},{iy})", fontsize=8, ha="center", va="center", zorder=5)

plt.title("Sammon / MDS Distortion Grid\nU-Matrix (Color) \\& Node Frequency (Size)")
plt.axis("off")
plt.colorbar(sc, label="U-Matrix (Avg. Neighbor Distance)")
plt.savefig("figs/Z500-IVT-big-SOM/Z500_som_sammon_mds.png", bbox_inches="tight")
plt.close()


### Node Weights Map

In [14]:
fig, axes = plt.subplots(
    ydim, xdim,
    figsize=(6, 3.7),
    subplot_kw={'projection': ccrs.PlateCarree()},
    constrained_layout=True,
    dpi=600
)

# Shading levels for Z500
levels_Z = np.arange(-1.4, 1.41, 0.2)

# IVT contour levels
levels_ivt = np.arange(-1.8, 1.81, 0.2)

for i in range(xdim):
    for j in range(ydim):
        ax = axes[j, i]

        # Fields for this node
        Z_field = z500_nodes[i, j, :, :]
        ivt_field = ivt_nodes[i, j, :, :]

        # --- Z500 shaded ---
        im = ax.contourf(
            lon,
            lat,
            Z_field,
            cmap="balance",
            levels=levels_Z,
            transform=ccrs.PlateCarree(),
        )

        # --- IVT contours (black depending on preference) ---
        cn = ax.contour(
            lon,
            lat,
            ivt_field,
            colors="black",
            linewidths=0.5,
            levels=levels_ivt,
            transform=ccrs.PlateCarree(),
        )

        ax.add_feature(cfeature.COASTLINE, linewidth=0.5)
        ax.add_feature(cfeature.STATES.with_scale("50m"), linewidth=0.4)
        ax.set_title(f"Node ({i},{j})", fontsize=6)
        ax.set_xticks([])
        ax.set_yticks([])

# One shared colorbar
cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6, pad=0.02)
cbar.set_ticks(levels_Z)
cbar.set_label("Standardized 500-hPa Anomaly", fontsize=6)

plt.suptitle("Node Weight Patterns", fontsize=8)
plt.savefig("figs/Z500-IVT-big-SOM/node_weights.png", bbox_inches="tight")
plt.close()

### Anomaly Composite Map

In [15]:
som_days = pd.to_datetime(Z500_norm_daily.time.values).tz_localize(None)

event_mask = np.isin(som_days.normalize(), pd.to_datetime(event_days))
event_indices = np.where(event_mask)[0]

# Create empty arrays for standardized anomalies
z500_patterns = np.full((xdim, ydim, n_lat, n_lon), np.nan)
ivt_patterns = np.full((xdim, ydim, n_lat, n_lon), np.nan)

counts = np.zeros((xdim, ydim), dtype=int)
totals = np.zeros((xdim, ydim), dtype=int)

for i in range(xdim):
    for j in range(ydim):
        # All days assigned to this node
        idx_node = np.where((bmus[:, 0] == i) & (bmus[:, 1] == j))[0]
        totals[i, j] = len(idx_node)

        # Flash-flood days within this node
        idx_event = np.intersect1d(idx_node, event_indices)
        counts[i, j] = len(idx_event)

        # Composite over *all* days in the node
        if len(idx_node) > 0:
            z500_patterns[i, j] = (
                Z500_norm_daily.isel(time=idx_node).mean("time").values
            )
            ivt_patterns[i, j] = (
                IVT_norm_daily.isel(valid_time=idx_node).mean("valid_time").values
            )


risk = np.zeros((xdim, ydim))
risk[totals > 0] = counts[totals > 0] / totals[totals > 0]


In [16]:
fig, axes = plt.subplots(
    ydim,
    xdim,
    figsize=(6, 3.7),
    subplot_kw={"projection": ccrs.PlateCarree()},
    constrained_layout=True,
    dpi=600,
)

# Levels for shading (Z500)
levels_Z = np.arange(-2.0, 2.1, 0.25)

# Fewer contour levels for ivt (to avoid clutter)
levels_ivt = np.arange(-2.5, 2.6, 0.5)

for i in range(xdim):
    for j in range(ydim):
        ax = axes[j, i]

        # pull the Z500 & ivt composite fields for this node
        Z_field = z500_patterns[i, j, :, :]
        ivt_field = ivt_patterns[i, j, :, :]

        # --- ivt shaded composite ---
        im = ax.contourf(
            lon,
            lat,
            ivt_field,
            cmap="balance",
            levels=levels_ivt,
            transform=ccrs.PlateCarree(),
            extend="both",
        )

        # --- z500 contour overlay ---
        ax.contour(
            lon,
            lat,
            Z_field,
            colors="black",
            linewidths=0.5,
            levels=levels_Z,
            transform=ccrs.PlateCarree(),
        )

        ax.add_feature(cfeature.COASTLINE, linewidth=0.6)
        ax.add_feature(cfeature.STATES.with_scale("50m"), linewidth=0.4)
        ax.set_title(
            f"({i},{j})  FFE={counts[i, j]}/{totals[i, j]}  ({100 * risk[i, j]:.1f}\\%)",
            fontsize=5,
        )

        ax.set_xticks([])
        ax.set_yticks([])

# one colorbar
cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6, pad=0.02)
cbar.set_label("Standardized Anomaly", fontsize=6)

plt.suptitle(
    "SOM Composite Anomalies: Z500 (contoured) + IVT (shaded)", fontsize=8, y=1.04
)
plt.savefig(
    "figs/Z500-IVT-big-SOM/composite_anomalies.png", bbox_inches="tight"
)
plt.close()


### Composite Mean Map

In [17]:
z500_patterns_raw = np.full((xdim, ydim, n_lat, n_lon), np.nan)
ivt_patterns_raw = np.full((xdim, ydim, n_lat, n_lon), np.nan)
counts = np.zeros((xdim, ydim), dtype=int)
totals = np.zeros((xdim, ydim), dtype=int)

for i in range(xdim):
    for j in range(ydim):
        # All days assigned to this node
        idx_node = np.where((bmus[:, 0] == i) & (bmus[:, 1] == j))[0]
        totals[i, j] = len(idx_node)

        # Flash-flood days within this node
        idx_event = np.intersect1d(idx_node, event_indices)
        counts[i, j] = len(idx_event)

        # Composite over *all* days in the node
        if len(idx_node) > 0:
            z500_patterns_raw[i, j] = Z500_daily.isel(time=idx_node).mean("time").values
            ivt_patterns_raw[i, j] = (
                IVT_daily.isel(valid_time=idx_node).mean("valid_time").values
            )

risk = np.zeros((xdim, ydim))
risk[totals > 0] = counts[totals > 0] / totals[totals > 0]


In [18]:
fig, axes = plt.subplots(
    ydim,
    xdim,
    figsize=(6, 3.7),
    subplot_kw={"projection": ccrs.PlateCarree()},
    constrained_layout=True,
    dpi=600,
)

# Levels for shading (Z500)
levels_Z = range(552, 595, 3)

# IVT levels
levels_ivt = np.arange(0, 701, 100)

for i in range(xdim):
    for j in range(ydim):
        ax = axes[j, i]

        # pull the Z500 & ivt composite fields for this node
        Z_field = z500_patterns_raw[i, j, :, :]
        ivt_field = ivt_patterns_raw[i, j, :, :]

        # --- IVT shaded composite ---
        im = ax.contourf(
            lon,
            lat,
            ivt_field,
            cmap="BuPu",
            levels=levels_ivt,
            transform=ccrs.PlateCarree(),
            extend="max"
        )

        # --- ivt contour overlay ---
        cn = ax.contour(
            lon,
            lat,
            Z_field / 98.1,
            colors="black",
            linewidths=0.5,
            levels=levels_Z,
            transform=ccrs.PlateCarree(),
        )

        ax.add_feature(cfeature.COASTLINE, linewidth=0.6)
        ax.add_feature(cfeature.STATES.with_scale("50m"), linewidth=0.4)
        ax.set_title(
            f"({i},{j})  FFE={counts[i, j]}/{totals[i, j]}  ({100 * risk[i, j]:.1f}\\%)",
            fontsize=5,
        )

        ax.set_xticks([])
        ax.set_yticks([])

        # Add inline labels
        ax.clabel(cn, cn.levels, fontsize=5)

# one colorbar
cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6, pad=0.02)
cbar.set_label("IVT (kg m$^{-1}$ s$^{-1}$)", fontsize=6)

plt.suptitle("SOM Composite: IVT (shaded) + Z500 (contoured)", fontsize=8, y=1.04)
plt.savefig("figs/Z500-IVT-big-SOM/composite_mean_IVT_shaded.png", bbox_inches="tight")
plt.close()


### Maps of Individual Nodes

In [None]:
# Set number of columns
cols = 4
proj = ccrs.PlateCarree()

# Clear out existing _FFE.png files in the indiv-nodes directory
ffe_files = glob.glob("figs/Z500-IVT-big-SOM/indiv-nodes/*_FFE.png")
for file in ffe_files:
    os.remove(file)
print(f"Removed {len(ffe_files)} existing _FFE.png files")

# Iterate through each node
for i in range(xdim):
    for j in range(ydim):
        # All days assigned to this node
        idx_node = np.where((bmus[:, 0] == i) & (bmus[:, 1] == j))[0]

        # Restrict to flash-flood days only
        idx = np.intersect1d(idx_node, event_indices)
        n = len(idx)

        # Skip nodes with no flash-flood days
        if n == 0:
            continue

        # Set number of rows
        rows = int(np.ceil(n / cols))

        # Create a figure with subplots
        fig, axes = plt.subplots(
            rows,
            cols,
            figsize=(3 * cols, 2.5 * rows),
            subplot_kw={"projection": proj},
            layout="constrained",
        )

        # Ensure axes is always iterable
        axes = np.atleast_1d(axes).flatten()

        for k, ax in enumerate(axes):
            if k < n:
                t = idx[k]
                z500_data = Z500_daily.isel(time=t)
                ivt_data = IVT_daily.isel(valid_time=t)

                cn = ax.contour(
                    lon,
                    lat,
                    z500_data / 98.1,
                    colors="black",
                    linewidths=0.5,
                    levels=range(546, 595, 6),
                    transform=ccrs.PlateCarree(),
                )

                im = ax.contourf(
                    lon,
                    lat,
                    ivt_data,
                    cmap="BuPu",
                    levels=levels_ivt,
                    transform=ccrs.PlateCarree(),
                    extend="max",
                )

                ax.add_feature(cfeature.COASTLINE, linewidth=0.5)
                ax.add_feature(cfeature.BORDERS, linewidth=0.3)
                ax.add_feature(cfeature.STATES, linewidth=0.2)

                ax.set_title(
                    pd.to_datetime(z500_data.time.values).strftime("%Y-%m-%d"),
                    fontsize=7,
                )
            else:
                ax.axis("off")

        fig.suptitle(
            f"Node ({i},{j})  Flash-Flood Days = {n}",
            fontsize=8,
            y=1.02,
        )

        plt.savefig(
            f"figs/Z500-IVT-big-SOM/indiv-nodes/node_{i}_{j}_FFE.png",
            dpi=300,
            bbox_inches="tight",
        )
        plt.close(fig)


Removed 20 existing _FFE.png files


### Residual Analysis for Flash Flood Days

Calculate residuals (observed - node centroid) for flash flood events to understand how individual events deviate from the typical pattern captured by each SOM node.

In [29]:
# Calculate residuals, quantization error, and spatial correlation for flash flood days
# Residual = observed (normalized) - node centroid weight
# QE = Euclidean distance between observed vector and BMU weight (MiniSom definition)
# Spatial correlation = Pearson correlation between observed and node weight patterns

# Storage for mean residuals per node (only flash flood days)
z500_residuals = np.full((xdim, ydim, n_lat, n_lon), np.nan)
ivt_residuals = np.full((xdim, ydim, n_lat, n_lon), np.nan)
ff_counts = np.zeros((xdim, ydim), dtype=int)
mean_qe_per_node = np.full((xdim, ydim), np.nan)
mean_corr_per_node = np.full((xdim, ydim), np.nan)  # Combined correlation
mean_corr_z500_per_node = np.full((xdim, ydim), np.nan)
mean_corr_ivt_per_node = np.full((xdim, ydim), np.nan)

# Storage for per-event metrics (for ranking later)
event_metrics_list = []  # Will store (date, node_i, node_j, qe, corr, corr_z500, corr_ivt)

# Get SOM weights in original shape for distance calculation
som_weights = som.get_weights()  # shape: (xdim, ydim, n_features)

for i in range(xdim):
    for j in range(ydim):
        # All days assigned to this node
        idx_node = np.where((bmus[:, 0] == i) & (bmus[:, 1] == j))[0]

        # Flash-flood days within this node
        idx_ff = np.intersect1d(idx_node, event_indices)
        ff_counts[i, j] = len(idx_ff)

        if len(idx_ff) > 0:
            # Get observed normalized values for flash flood days
            z500_obs = Z500_norm_daily.isel(time=idx_ff).values  # (n_ff, lat, lon)
            ivt_obs = IVT_norm_daily.isel(valid_time=idx_ff).values  # (n_ff, lat, lon)

            # Get node centroid (weights are already normalized)
            z500_centroid = z500_nodes[i, j, :, :]  # (lat, lon)
            ivt_centroid = ivt_nodes[i, j, :, :]  # (lat, lon)

            # Calculate residuals: observed - centroid
            z500_resid = z500_obs - z500_centroid[np.newaxis, :, :]
            ivt_resid = ivt_obs - ivt_centroid[np.newaxis, :, :]

            # Store mean residual for this node
            z500_residuals[i, j] = np.mean(z500_resid, axis=0)
            ivt_residuals[i, j] = np.mean(ivt_resid, axis=0)

            # Calculate QE and spatial correlation for each flash flood event
            qe_values = []
            corr_values = []
            corr_z500_values = []
            corr_ivt_values = []

            for k, t in enumerate(idx_ff):
                # QE for single sample = distance to BMU weight
                qe = np.linalg.norm(X[t] - som_weights[i, j])
                qe_values.append(qe)

                # Spatial correlation (Pearson) for combined vector
                corr_combined = np.corrcoef(X[t], som_weights[i, j].flatten())[0, 1]
                corr_values.append(corr_combined)

                # Separate correlations for Z500 and IVT
                z500_obs_flat = z500_obs[k].flatten()
                ivt_obs_flat = ivt_obs[k].flatten()
                z500_node_flat = z500_centroid.flatten()
                ivt_node_flat = ivt_centroid.flatten()

                corr_z500 = np.corrcoef(z500_obs_flat, z500_node_flat)[0, 1]
                corr_ivt = np.corrcoef(ivt_obs_flat, ivt_node_flat)[0, 1]
                corr_z500_values.append(corr_z500)
                corr_ivt_values.append(corr_ivt)

                # Store for ranking
                event_date = pd.to_datetime(Z500_norm_daily.time.values[t])
                event_metrics_list.append(
                    (event_date, i, j, qe, corr_combined, corr_z500, corr_ivt)
                )

            mean_qe_per_node[i, j] = np.mean(qe_values)
            mean_corr_per_node[i, j] = np.mean(corr_values)
            mean_corr_z500_per_node[i, j] = np.mean(corr_z500_values)
            mean_corr_ivt_per_node[i, j] = np.mean(corr_ivt_values)

# Convert to DataFrame for easier manipulation
event_metrics_df = pd.DataFrame(
    event_metrics_list,
    columns=["date", "node_i", "node_j", "qe", "corr", "corr_z500", "corr_ivt"],
)

# Sort by correlation (ascending = worst pattern match first)
event_metrics_df_by_corr = event_metrics_df.sort_values("corr", ascending=True)

# Also keep QE-sorted version
event_qe_df = event_metrics_df.sort_values("qe", ascending=False)

print("Residuals, QE, and spatial correlation calculated for flash flood days:")
print(f"Flash flood counts per node:\n{ff_counts.T}")
print(f"\nMean QE per node:\n{np.round(mean_qe_per_node.T, 1)}")
print(f"\nMean spatial correlation (combined) per node:\n{np.round(mean_corr_per_node.T, 3)}")
print(f"\nMean Z500 correlation per node:\n{np.round(mean_corr_z500_per_node.T, 3)}")
print(f"\nMean IVT correlation per node:\n{np.round(mean_corr_ivt_per_node.T, 3)}")

Residuals, QE, and spatial correlation calculated for flash flood days:
Flash flood counts per node:
[[ 5  4  2  3  9]
 [ 2  1  3  1  5]
 [ 3  4  4  3  6]
 [ 8  8  6 10 28]]

Mean QE per node:
[[122.1 101.6 101.9  95.4 121.3]
 [115.   94.9  89.4 111.  104.3]
 [101.3  95.5  95.3  92.1 101.4]
 [120.1 106.9 111.4 112.2 126.7]]

Mean spatial correlation (combined) per node:
[[0.476 0.551 0.617 0.383 0.688]
 [0.484 0.269 0.093 0.284 0.544]
 [0.423 0.297 0.197 0.252 0.53 ]
 [0.528 0.504 0.495 0.423 0.655]]

Mean Z500 correlation per node:
[[ 0.54   0.592  0.831  0.376  0.736]
 [ 0.348  0.156  0.156  0.399  0.51 ]
 [ 0.243  0.097 -0.064  0.276  0.673]
 [ 0.603  0.588  0.604  0.44   0.729]]

Mean IVT correlation per node:
[[ 0.325  0.399  0.007  0.019  0.556]
 [ 0.399  0.233 -0.033  0.184  0.281]
 [ 0.341  0.171  0.435  0.17   0.38 ]
 [ 0.402  0.37   0.409  0.468  0.534]]


In [31]:
# Visualize mean residuals for flash flood days (with mean QE and correlation in titles)
fig, axes = plt.subplots(
    ydim,
    xdim,
    figsize=(6, 3.7),
    subplot_kw={"projection": ccrs.PlateCarree()},
    constrained_layout=True,
    dpi=600,
)

# Symmetric levels for residuals (centered on zero)
levels_z500_resid = np.arange(-2.0, 2.05, 0.2)
levels_ivt_resid = np.arange(-2.0, 2.05, 0.2)

for i in range(xdim):
    for j in range(ydim):
        ax = axes[j, i]

        z500_resid_field = z500_residuals[i, j, :, :]
        ivt_resid_field = ivt_residuals[i, j, :, :]

        # Skip nodes with no flash flood days
        if np.isnan(z500_resid_field).all():
            ax.set_title(f"({i},{j})  n=0", fontsize=5)
            ax.add_feature(cfeature.COASTLINE, linewidth=0.6)
            ax.add_feature(cfeature.STATES.with_scale("50m"), linewidth=0.4)
            ax.set_xticks([])
            ax.set_yticks([])
            continue

        # IVT residuals shaded
        im = ax.contourf(
            lon,
            lat,
            ivt_resid_field,
            cmap="balance",
            levels=levels_ivt_resid,
            transform=ccrs.PlateCarree(),
            extend="both",
        )

        # Z500 residuals contoured
        ax.contour(
            lon,
            lat,
            z500_resid_field,
            colors="black",
            linewidths=0.5,
            levels=levels_z500_resid,
            transform=ccrs.PlateCarree(),
        )

        ax.add_feature(cfeature.COASTLINE, linewidth=0.6)
        ax.add_feature(cfeature.STATES.with_scale("50m"), linewidth=0.4)
        ax.set_title(
            f"({i},{j}) n={ff_counts[i, j]} r={mean_corr_per_node[i, j]:.2f}",
            fontsize=5,
        )
        ax.set_xticks([])
        ax.set_yticks([])

# Colorbar
cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6, pad=0.02)
cbar.set_label("Standardized Residual", fontsize=6)

plt.suptitle(
    "Mean Residuals (FFE days): Z500 (contoured) + IVT (shaded)", fontsize=8, y=1.04
)
plt.savefig("figs/Z500-IVT-big-SOM/residuals_ffe.png", bbox_inches="tight")
plt.close()

In [32]:
# Rank flash flood events by spatial correlation (worst pattern match = lowest correlation)
print("Flash Flood Events Ranked by Spatial Correlation (Worst Pattern Match First)")
print("=" * 85)
print(f"{'Rank':<6}{'Date':<14}{'Node':<10}{'Corr':<8}{'Z500_r':<8}{'IVT_r':<8}{'QE':<10}")
print("-" * 85)

for rank, (_, row) in enumerate(event_metrics_df_by_corr.iterrows(), start=1):
    date_str = row["date"].strftime("%Y-%m-%d")
    node_str = f"({int(row['node_i'])},{int(row['node_j'])})"
    print(
        f"{rank:<6}{date_str:<14}{node_str:<10}{row['corr']:.3f}   "
        f"{row['corr_z500']:.3f}   {row['corr_ivt']:.3f}   {row['qe']:.1f}"
    )

print("-" * 85)
print(f"\nTotal flash flood events: {len(event_metrics_df)}")
print(f"Mean correlation (all FFE): {event_metrics_df['corr'].mean():.3f}")
print(f"Mean Z500 correlation: {event_metrics_df['corr_z500'].mean():.3f}")
print(f"Mean IVT correlation: {event_metrics_df['corr_ivt'].mean():.3f}")
print(f"Mean QE (all FFE): {event_metrics_df['qe'].mean():.2f}")

Flash Flood Events Ranked by Spatial Correlation (Worst Pattern Match First)
Rank  Date          Node      Corr    Z500_r  IVT_r   QE        
-------------------------------------------------------------------------------------
1     2012-06-22    (3,2)     -0.136   0.014   -0.155   96.6
2     2004-07-02    (2,1)     0.027   -0.023   0.086   92.4
3     2006-06-02    (2,2)     0.100   -0.447   0.530   88.7
4     2013-05-09    (3,3)     0.108   0.205   0.209   102.6
5     2000-09-03    (2,1)     0.125   0.232   -0.157   83.9
6     2000-07-03    (2,1)     0.127   0.260   -0.027   91.8
7     2017-08-02    (4,2)     0.128   0.687   -0.168   106.5
8     2002-06-26    (1,2)     0.163   -0.021   0.102   94.8
9     2006-07-21    (2,2)     0.175   -0.222   0.473   85.1
10    2004-09-18    (2,2)     0.182   0.075   0.350   113.5
11    2021-08-22    (0,1)     0.199   0.143   0.074   138.4
12    2018-06-28    (1,2)     0.229   0.271   0.049   77.4
13    2024-08-06    (2,3)     0.236   0.389   0.130

In [23]:
# Plot top N highest-QE flash flood events to visually inspect for cutoff lows
n_top = 15
top_events = event_qe_df.head(n_top)

cols = 5
rows = int(np.ceil(n_top / cols))

fig, axes = plt.subplots(
    rows,
    cols,
    figsize=(3 * cols, 2.5 * rows),
    subplot_kw={"projection": ccrs.PlateCarree()},
    layout="constrained",
    dpi=300,
)
axes = axes.flatten()

# Levels for Z500 contours (in dam) - extended range for cutoff lows
levels_Z = range(534, 606, 3)
levels_ivt = np.arange(0, 701, 100)

for k, (_, row) in enumerate(top_events.iterrows()):
    ax = axes[k]
    
    # Find the time index for this date
    event_date = row["date"]
    t = np.where(pd.to_datetime(Z500_daily.time.values) == event_date)[0][0]
    
    z500_data = Z500_daily.isel(time=t)
    ivt_data = IVT_daily.isel(valid_time=t)
    
    # IVT shaded
    im = ax.contourf(
        lon,
        lat,
        ivt_data,
        cmap="BuPu",
        levels=levels_ivt,
        transform=ccrs.PlateCarree(),
        extend="max",
    )
    
    # Z500 contoured
    cn = ax.contour(
        lon,
        lat,
        z500_data / 98.1,
        colors="black",
        linewidths=0.6,
        levels=levels_Z,
        transform=ccrs.PlateCarree(),
    )
    
    ax.add_feature(cfeature.COASTLINE, linewidth=0.5)
    ax.add_feature(cfeature.STATES.with_scale("50m"), linewidth=0.3)
    
    ax.set_title(
        f"{event_date.strftime('%Y-%m-%d')}\nNode ({int(row['node_i'])},{int(row['node_j'])})  QE={row['qe']:.1f}",
        fontsize=6,
    )

# Turn off any unused axes
for k in range(n_top, len(axes)):
    axes[k].axis("off")

plt.suptitle(f"Top {n_top} Highest-QE Flash Flood Events", fontsize=10, y=1.02)
plt.savefig("figs/Z500-IVT-big-SOM/top_qe_events.png", bbox_inches="tight")
plt.close()

In [24]:
# Plot bottom N lowest-QE flash flood events (best fit) for comparison
n_bottom = 15
bottom_events = event_qe_df.tail(n_bottom).iloc[::-1]  # Reverse so lowest QE is first

cols = 5
rows = int(np.ceil(n_bottom / cols))

fig, axes = plt.subplots(
    rows,
    cols,
    figsize=(3 * cols, 2.5 * rows),
    subplot_kw={"projection": ccrs.PlateCarree()},
    layout="constrained",
    dpi=300,
)
axes = axes.flatten()

# Levels for Z500 contours (in dam) - extended range
levels_Z = range(534, 606, 3)
levels_ivt = np.arange(0, 701, 100)

for k, (_, row) in enumerate(bottom_events.iterrows()):
    ax = axes[k]
    
    # Find the time index for this date
    event_date = row["date"]
    t = np.where(pd.to_datetime(Z500_daily.time.values) == event_date)[0][0]
    
    z500_data = Z500_daily.isel(time=t)
    ivt_data = IVT_daily.isel(valid_time=t)
    
    # IVT shaded
    im = ax.contourf(
        lon,
        lat,
        ivt_data,
        cmap="BuPu",
        levels=levels_ivt,
        transform=ccrs.PlateCarree(),
        extend="max",
    )
    
    # Z500 contoured
    cn = ax.contour(
        lon,
        lat,
        z500_data / 98.1,
        colors="black",
        linewidths=0.6,
        levels=levels_Z,
        transform=ccrs.PlateCarree(),
    )
    
    ax.add_feature(cfeature.COASTLINE, linewidth=0.5)
    ax.add_feature(cfeature.STATES.with_scale("50m"), linewidth=0.3)
    
    ax.set_title(
        f"{event_date.strftime('%Y-%m-%d')}\nNode ({int(row['node_i'])},{int(row['node_j'])})  QE={row['qe']:.1f}",
        fontsize=6,
    )

# Turn off any unused axes
for k in range(n_bottom, len(axes)):
    axes[k].axis("off")

plt.suptitle(f"Bottom {n_bottom} Lowest-QE Flash Flood Events (Best Fit)", fontsize=10, y=1.02)
plt.savefig("figs/Z500-IVT-big-SOM/bottom_qe_events.png", bbox_inches="tight")
plt.close()

### Cutoff Low Detection

Simple automated detection of cutoff lows in flash flood events. A cutoff low is identified when there is a single local minimum in the Z500 field that is surrounded by higher values (i.e., enclosed by a closed contour using 6-dam spacing).

In [28]:
from scipy import ndimage


def has_cutoff_low(z500_field, contour_spacing=6):
    """
    Detect if a Z500 field contains a cutoff low.

    A cutoff low is identified as a local minimum where the contour
    at (minimum + contour_spacing) forms a closed loop that does not
    touch the domain boundary.

    Parameters
    ----------
    z500_field : array-like
        2D array of Z500 values in dam (decameters)
    contour_spacing : float
        Contour interval in dam. The closed contour is defined at
        (local_minimum + contour_spacing).

    Returns
    -------
    bool
        True if a cutoff low is detected, False otherwise
    """
    z500 = np.asarray(z500_field)

    # Find local minima using minimum filter
    min_filtered = ndimage.minimum_filter(z500, size=5)
    local_minima = z500 == min_filtered

    # Get coordinates of local minima
    min_coords = np.argwhere(local_minima)

    for coord in min_coords:
        i, j = coord
        min_val = z500[i, j]

        # Skip minima too close to the edge
        if i < 2 or i >= z500.shape[0] - 2 or j < 2 or j >= z500.shape[1] - 2:
            continue

        # Define the closed contour level: minimum + contour_spacing
        contour_level = min_val + contour_spacing

        # Find the region enclosed by this contour (values < contour_level)
        below_contour = z500 < contour_level

        # Label connected regions
        labeled, num_features = ndimage.label(below_contour)

        if num_features == 0:
            continue

        # Get the label of the region containing this minimum
        region_label = labeled[i, j]

        if region_label == 0:
            continue

        # Check if this region touches any boundary
        region_mask = labeled == region_label

        touches_boundary = (
            np.any(region_mask[0, :])  # top edge
            or np.any(region_mask[-1, :])  # bottom edge
            or np.any(region_mask[:, 0])  # left edge
            or np.any(region_mask[:, -1])  # right edge
        )

        # If the region doesn't touch any boundary, it's a closed cutoff
        if not touches_boundary:
            return True

    return False


# Test on all flash flood events and track by node
cutoff_by_node = {
    (i, j): {"total": 0, "cutoff": 0} for i in range(xdim) for j in range(ydim)
}
cutoff_events = []  # Store (date, node_i, node_j, has_cutoff)

for idx in event_indices:
    # Get Z500 in dam
    z500_dam = Z500_daily.isel(time=idx).values / 98.1

    # Get BMU for this day
    node_i, node_j = bmus[idx]

    # Check for cutoff low
    has_cutoff = has_cutoff_low(z500_dam, contour_spacing=6)

    # Update counts
    cutoff_by_node[(node_i, node_j)]["total"] += 1
    if has_cutoff:
        cutoff_by_node[(node_i, node_j)]["cutoff"] += 1

    # Store event info
    event_date = pd.to_datetime(Z500_daily.time.values[idx])
    cutoff_events.append((event_date, node_i, node_j, has_cutoff))

# Create arrays for visualization
cutoff_fraction = np.full((xdim, ydim), np.nan)
cutoff_count = np.zeros((xdim, ydim), dtype=int)
total_ff_count = np.zeros((xdim, ydim), dtype=int)

for i in range(xdim):
    for j in range(ydim):
        total = cutoff_by_node[(i, j)]["total"]
        cutoff = cutoff_by_node[(i, j)]["cutoff"]
        total_ff_count[i, j] = total
        cutoff_count[i, j] = cutoff
        if total > 0:
            cutoff_fraction[i, j] = cutoff / total

# Summary
total_cutoffs = sum(1 for e in cutoff_events if e[3])
print(f"Cutoff low detection results:")
print(f"Total flash flood events: {len(cutoff_events)}")
print(
    f"Events with cutoff lows: {total_cutoffs} ({100 * total_cutoffs / len(cutoff_events):.1f}%)"
)
print(f"\nCutoff fraction by node:")
print(f"{np.round(cutoff_fraction.T * 100, 1)}")


Cutoff low detection results:
Total flash flood events: 115
Events with cutoff lows: 5 (4.3%)

Cutoff fraction by node:
[[60.   0.   0.   0.   0. ]
 [ 0.   0.   0.   0.   0. ]
 [ 0.   0.   0.   0.   0. ]
 [ 0.   0.   0.   0.   7.1]]


In [None]:
# Visualize cutoff low fraction by SOM node
fig, ax = plt.subplots(figsize=(5, 4), dpi=150)

# Plot cutoff fraction as heatmap
im = ax.imshow(
    cutoff_fraction.T * 100,
    cmap="YlOrRd",
    origin="lower",
    vmin=0,
    vmax=100,
)

# Add text annotations with fraction and counts
for i in range(xdim):
    for j in range(ydim):
        total = total_ff_count[i, j]
        cutoff = cutoff_count[i, j]
        frac = cutoff_fraction[i, j]
        
        if total > 0:
            text = f"{100 * frac:.0f}\\%\n({cutoff}/{total})"
        else:
            text = "n=0"
        
        # Choose text color based on background
        text_color = "white" if frac > 0.5 else "black"
        ax.text(i, j, text, ha="center", va="center", fontsize=7, color=text_color)

ax.set_xticks(np.arange(xdim))
ax.set_yticks(np.arange(ydim))
ax.set_xlabel("X-index", fontsize=8)
ax.set_ylabel("Y-index", fontsize=8)
ax.set_title("Cutoff Low Frequency in Flash Flood Events by SOM Node", fontsize=9)

cbar = fig.colorbar(im, ax=ax, shrink=0.8)
cbar.set_label("Cutoff Low Frequency (\\%)", fontsize=8)

plt.tight_layout()
plt.savefig("figs/Z500-IVT-big-SOM/cutoff_low_frequency.png", bbox_inches="tight")
plt.close()