# Flash Flood Only SOM Training (Z<sub>500</sub> and |IVT|)

By: Ty Janoski

Updated 1/24/2026

## Setup

### Imports

In [102]:
# Import Statements
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cmweather  # noqa: F401
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scienceplots  # noqa: F401
import xarray as xr
from minisom import MiniSom
from sklearn.manifold import MDS
from sklearn.metrics import pairwise_distances

plt.style.use(["science", "nature", "grid"])
plt.rcParams["text.usetex"] = True


In [103]:
# --- Helper Functions ---
def get_node_indices(bmus, i, j):
    """Get sample indices belonging to node (i,j)."""
    return np.where((bmus[:, 0] == i) & (bmus[:, 1] == j))[0]


def compute_composites(data, bmus, xdim, ydim, time_dim="valid_time"):
    """Compute composite mean for each SOM node."""
    sample_shape = data.isel({time_dim: 0}).shape
    composites = np.full((xdim, ydim) + sample_shape, np.nan)
    counts = np.zeros((xdim, ydim), dtype=int)

    for i in range(xdim):
        for j in range(ydim):
            idx = get_node_indices(bmus, i, j)
            counts[i, j] = len(idx)
            if len(idx) > 0:
                composites[i, j] = data.isel({time_dim: idx}).mean(time_dim).values
    return composites, counts


def create_som_figure(xdim, ydim, figsize=(6, 3.7), dpi=600):
    """Create a standard figure for SOM node plots."""
    fig, axes = plt.subplots(
        ydim,
        xdim,
        figsize=figsize,
        subplot_kw={"projection": ccrs.PlateCarree()},
        constrained_layout=True,
        dpi=dpi,
    )
    return fig, axes


def add_map_features(ax):
    """Add standard map features to an axis."""
    ax.add_feature(cfeature.COASTLINE, linewidth=0.5)
    ax.add_feature(cfeature.STATES.with_scale("50m"), linewidth=0.4)
    ax.set_xticks([])
    ax.set_yticks([])


def plot_node_events(
    data,
    bmus,
    xdim,
    ydim,
    lon,
    lat,
    levels,
    cmap,
    save_pattern,
    scale=1.0,
    cbar_label=None,
    time_dim="valid_time",
    contour=False,
    z500_data=None,
    z500_levels=None,
    z500_scale=1.0,
    z500_time_dim="time",
):
    """Plot individual events for each SOM node.

    Parameters
    ----------
    z500_data : xarray.DataArray, optional
        Z500 data to overlay as contours (same time dimension as data)
    z500_levels : array-like, optional
        Contour levels for Z500 (required if z500_data is provided)
    z500_scale : float, optional
        Scale factor for Z500 values (e.g., 1/98.1 to convert to dam)
    z500_time_dim : str, optional
        Name of time dimension in z500_data (default: "time")
    """
    cols = 5
    proj = ccrs.PlateCarree()

    for i in range(xdim):
        for j in range(ydim):
            idx = get_node_indices(bmus, i, j)
            n = len(idx)
            rows = int(np.ceil(n / cols))

            fig, axes = plt.subplots(
                rows,
                cols,
                figsize=(3 * cols, 2.5 * rows),
                subplot_kw={"projection": proj},
                layout="constrained",
                dpi=300,
            )

            for k, ax in enumerate(axes.flat):
                if k < n:
                    field = data.isel({time_dim: idx[k]})
                    time_val = field[time_dim].values

                    if contour:
                        im = ax.contour(
                            lon,
                            lat,
                            field.values * scale,
                            levels=levels,
                            colors="black",
                            transform=proj,
                            linewidths=0.6,
                        )
                        ax.clabel(im, im.levels, fontsize=5)
                    else:
                        im = ax.contourf(
                            lon,
                            lat,
                            field.values * scale,
                            levels=levels,
                            cmap=cmap,
                            transform=proj,
                            extend="max",
                        )

                        # Overlay Z500 contours if provided
                        if z500_data is not None and z500_levels is not None:
                            z500_field = z500_data.isel({z500_time_dim: idx[k]})
                            cn = ax.contour(
                                lon,
                                lat,
                                z500_field.values * z500_scale,
                                levels=z500_levels,
                                colors="black",
                                linewidths=0.5,
                                transform=proj,
                            )
                            ax.clabel(cn, inline=True, fontsize=5, fmt="%.0f")

                    ax.add_feature(cfeature.COASTLINE, linewidth=0.5)
                    ax.add_feature(cfeature.BORDERS, linewidth=0.3)
                    ax.add_feature(cfeature.STATES, linewidth=0.2)
                    ax.set_title(str(pd.to_datetime(time_val))[:16])
                else:
                    ax.axis("off")

            if cbar_label and not contour:
                cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6, pad=0.02)
                cbar.set_label(cbar_label, fontsize=6)

            fig.suptitle(f"Node ({i},{j})  N={n}", fontsize=8, y=1.02)
            plt.savefig(save_pattern.format(i=i, j=j))
            plt.close(fig)


### Data Loading

In [104]:
# Read in Z500 and IVT at flash-flood event times
path = "/mnt/drive2/SOM_intermediate_files/"
ivt_norm_weighted_ffe = xr.load_dataset(f"{path}era5_ivt_norm_weighted_ffe.nc")["ivt"]
ivt_norm_ffe = xr.load_dataset(f"{path}era5_ivt_norm_ffe.nc")["ivt"]
ivt_ffe = xr.load_dataset(f"{path}era5_ivt_ffe.nc")["ivt"]

z500_norm_weighted_ffe = xr.load_dataarray(f"{path}era5_Z500_norm_weighted_ffe.nc")
z500_norm_ffe = xr.load_dataarray(f"{path}era5_Z500_norm_ffe.nc")
z500_ffe = xr.load_dataarray(f"{path}era5_Z500_ffe.nc")

# Total precipitation and mean sea level pressure at flash-flood event times
tp_ffe = xr.load_dataarray(f"{path}era5_tp_ffe.nc")
mslp_ffe = xr.load_dataarray(f"{path}era5_mslp_ffe.nc")


### Reshape Data

In [105]:
# Flatten the data for SOM training
z500_flat = z500_norm_weighted_ffe.stack(
    features=["lat", "lon"]
).values  # shape: (time, feature)
ivt_flat = ivt_norm_weighted_ffe.stack(
    features=["latitude", "longitude"]
).values  # shape: (time, features)

# Concatenate the data
X = np.concatenate((z500_flat, ivt_flat), axis=1)  # shape: (time, feature)


## SOM Training

We are going to train our SOM with random initialization and online training. We will also use two phases: a "coarse" phase with a larger sigma and learning rate, then a "fine" phase with a smaller learning rate and sigma.

### Set SOM parameters

In [106]:
# Set SOM shape
xdim, ydim = 2, 2

# Set number of iterations for each phase
# Using "literature style": short rough phase, longer fine phase with sigma=1
n1, n2 = 1000, 1000

# Set starting sigmas
# Phase 1: large sigma for global ordering; Phase 2: sigma=1 for localized refinement
sig1, sig2 = np.sqrt(xdim**2 + ydim**2), 1.0

# Set starting learning rates
lr1, lr2 = 0.5, 0.1

# Random seed for reproducibility
random_seed = 42


### Train SOM

In [107]:
# Create SOM instance
som = MiniSom(
    xdim,
    ydim,
    input_len=X.shape[1],
    sigma=sig1,
    learning_rate=lr1,
    decay_function="linear_decay_to_zero",
    sigma_decay_function="linear_decay_to_one",
    neighborhood_function="gaussian",
    random_seed=random_seed,
)

# Initialize random weights
som.random_weights_init(X)

# Random training
som.train_random(X, n1, verbose=True)

# Phase 2
som._sigma = sig2 # type: ignore
som._learning_rate = lr2
som.train_random(X, n2, verbose=True)


 [ 1000 / 1000 ] 100% - 0:00:00 left 
 quantization error: 136.4534791604323
 [ 1000 / 1000 ] 100% - 0:00:00 left 
 quantization error: 134.61211112981474


### Grab important fields

In [108]:
# Total node number
n_nodes = xdim * ydim

# Get flattened weights
weights = som.get_weights().reshape(xdim * ydim, -1)

# u-matrix
u_matrix = som.distance_map().T

# bmus & hit_map
bmus = np.array([som.winner(x) for x in X])

hit_map = np.zeros((xdim, ydim))
for i, j in bmus:
    hit_map[i, j] += 1
hit_map = hit_map.T

# Sammon Coordinates
D = pairwise_distances(weights)
coords = MDS(
    n_components=2,
    metric="precomputed",
    random_state=42,
    n_init=4,
    init="random",  # pyright: ignore[reportCallIssue]
).fit_transform(D)

# Get lats/lons
lat = ivt_norm_ffe.latitude
lon = ivt_norm_ffe.longitude

# Dimensions of the spatial field
n_lat = lat.size
n_lon = lon.size
n_features = n_lat * n_lon

# Split weights into Z500 and IVT components
z500_weights = weights[:, :n_features]
ivt_weights = weights[:, n_features:]

# Reshape weights back to spatial dimensions
z500_nodes = z500_weights.reshape(xdim, ydim, n_lat, n_lon)
ivt_nodes = ivt_weights.reshape(xdim, ydim, n_lat, n_lon)


### Rotate SOM (Optional)

Rotating a trained SOM does **not** violate SOM principles—it preserves the learned topology (neighbors remain neighbors). This is purely a relabeling of node coordinates for visualization purposes.

The rotation below is **counterclockwise by 90°**. For a 2×2 grid, this maps:
- `(0,0) → (0,1)`
- `(0,1) → (1,1)`
- `(1,0) → (0,0)`
- `(1,1) → (1,0)`

In [109]:
# Set to True to apply clockwise rotation, False to skip
ROTATE_SOM = True

if ROTATE_SOM:
    # Rotate node weight patterns clockwise (k=-1 means 90° CW)
    # This must match the BMU transformation below
    # np.rot90 rotates on axes (0, 1), which are our (i, j) node indices
    z500_nodes = np.rot90(z500_nodes, k=-1, axes=(0, 1))
    ivt_nodes = np.rot90(ivt_nodes, k=-1, axes=(0, 1))

    # Transform BMU coordinates: (i, j) → (j, xdim - 1 - i)
    # This transformation is consistent with 90° CW rotation of weights
    bmus_new = np.column_stack([bmus[:, 1], xdim - 1 - bmus[:, 0]])
    bmus = bmus_new

    # Rotate u-matrix and hit_map for consistency
    # Note: these were transposed earlier, so we rotate on the transposed arrays
    u_matrix = np.rot90(u_matrix, k=-1)
    hit_map = np.rot90(hit_map, k=-1)

    # Recompute flattened weights in new order for Sammon map
    weights_rotated = np.zeros_like(weights)
    for i in range(xdim):
        for j in range(ydim):
            old_idx = i * ydim + j
            new_idx = i * ydim + j
            # Rebuild from rotated node arrays
            weights_rotated[new_idx, :n_features] = z500_nodes[i, j].flatten()
            weights_rotated[new_idx, n_features:] = ivt_nodes[i, j].flatten()
    weights = weights_rotated

    # Recompute Sammon/MDS coordinates with rotated weights
    D = pairwise_distances(weights)
    coords = MDS(
        n_components=2,
        metric="precomputed",
        random_state=42,
        n_init=4,
        init="random",  # pyright: ignore[reportCallIssue]
    ).fit_transform(D)

    print("SOM rotated 90° clockwise.")
    print("New BMU mapping:")
    print("  Old (0,0) → New (0,1)")
    print("  Old (0,1) → New (1,1)")
    print("  Old (1,0) → New (0,0)")
    print("  Old (1,1) → New (1,0)")
else:
    print("Rotation skipped (ROTATE_SOM = False)")

SOM rotated 90° clockwise.
New BMU mapping:
  Old (0,0) → New (0,1)
  Old (0,1) → New (1,1)
  Old (1,0) → New (0,0)
  Old (1,1) → New (1,0)


In [110]:
# Save BMUs with timestamps for cross-SOM analysis
bmu_df = pd.DataFrame(
    {
        "timestamp": pd.to_datetime(ivt_ffe.valid_time.values),
        "node_i": bmus[:, 0],
        "node_j": bmus[:, 1],
    }
)
bmu_df.to_csv("data/som_2x2_bmus.csv", index=False)
print(f"Saved {len(bmu_df)} BMU assignments to data/som_2x2_bmus.csv")


Saved 117 BMU assignments to data/som_2x2_bmus.csv


## Plots

### U-matrix and Sammon Map

In [111]:
fig, axes = plt.subplots(1, 2, layout="constrained", figsize=(6, 3), dpi=600)

# u-matrix
im0 = axes[0].imshow(u_matrix, cmap="viridis", origin="lower")
axes[0].set_title("U-Matrix (Mean Inter-Node Distance)", fontsize=7)
fig.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04, shrink=0.7)

# hit map
im1 = axes[1].imshow(hit_map, cmap="plasma", origin="lower")
axes[1].set_title("Hit Map (Samples per Node)", fontsize=7)
fig.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04, shrink=0.7)

# axis styling
for ax in axes:
    ax.set_xticks(np.arange(xdim))
    ax.set_yticks(np.arange(ydim))
    ax.set_xlabel("X-index", fontsize=6)
    ax.set_ylabel("Y-index", fontsize=6)

plt.savefig("figs/Z500-and-ivt-SOM/Z500_ivt_som_u_matrix_hit_map.png")
plt.close()


In [112]:
# Flatten u-matrix & hit map
U_flat = u_matrix.T.reshape(-1)  # back to (n_nodes,)
hits_flat = hit_map.T.reshape(-1)  # back to (n_nodes,)

# scale hits
hits_scaled = 30 + 250 * (hits_flat / hits_flat.max())

# plot
plt.figure(figsize=(7, 7))

# Scatter: U controls color, hits control bubble size
sc = plt.scatter(
    coords[:, 0],
    coords[:, 1],
    c=U_flat,
    s=hits_scaled,
    cmap="balance",
    edgecolor="k",
    linewidth=0.5,
    zorder=3,
)

# Draw lattice connections (right & down neighbors only)
for i in range(xdim):
    for j in range(ydim):
        node = i * ydim + j

        # right neighbor
        if j + 1 < ydim:
            nbr = i * ydim + (j + 1)
            plt.plot(
                [coords[node, 0], coords[nbr, 0]],
                [coords[node, 1], coords[nbr, 1]],
                "k-",
                lw=0.6,
                alpha=0.4,
            )

        # down neighbor
        if i + 1 < xdim:
            nbr = (i + 1) * ydim + j
            plt.plot(
                [coords[node, 0], coords[nbr, 0]],
                [coords[node, 1], coords[nbr, 1]],
                "k-",
                lw=0.6,
                alpha=0.4,
            )

# Node labels (i,j)
for idx, (x, y) in enumerate(coords):
    ix, iy = divmod(idx, ydim)
    plt.text(x, y, f"({ix},{iy})", fontsize=8, ha="center", va="center", zorder=5)

plt.title("Sammon / MDS Distortion Grid\nU-Matrix (Color) \\& Node Frequency (Size)")
plt.axis("off")
plt.colorbar(sc, label="U-Matrix (Avg. Neighbor Distance)")
plt.savefig("figs/Z500-and-ivt-SOM/Z500-ivt_som_sammon_mds.png", bbox_inches="tight")
plt.close()


### Node Weights Map

In [113]:
# Shading levels for standardized anomalies
levels_ivt = np.arange(-1.8, 1.81, 0.2)
levels_Z = np.arange(-1.4, 1.41, 0.2)

fig, axes = create_som_figure(xdim, ydim)

for i in range(xdim):
    for j in range(ydim):
        ax = axes[j, i]

        # IVT shaded
        im = ax.contourf(
            lon,
            lat,
            ivt_nodes[i, j],
            cmap="balance",
            levels=levels_ivt,
            transform=ccrs.PlateCarree(),
        )

        # Z500 contours
        cn = ax.contour(
            lon,
            lat,
            z500_nodes[i, j],
            colors="black",
            linewidths=0.5,
            levels=levels_Z,
            transform=ccrs.PlateCarree(),
        )
        ax.clabel(cn, inline=True, fontsize=5, fmt="%.1f")

        add_map_features(ax)
        ax.set_title(f"Node ({i},{j})", fontsize=6)

cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6, pad=0.02)
cbar.set_label("Standardized IVT Anomaly (Shaded)", fontsize=6)

plt.suptitle(
    "Flash Flood Only SOM: Node Weight Patterns\nZ500 (contoured) + IVT (shaded)",
    fontsize=8,
)
plt.savefig(
    "figs/Z500-and-ivt-SOM/combined_node_weights_ivt_shaded.png", bbox_inches="tight"
)
plt.close()


### Anomaly Composite Map

In [114]:
# Compute standardized anomaly composites
z500_patterns, counts = compute_composites(
    z500_norm_ffe, bmus, xdim, ydim, time_dim="time"
)
ivt_patterns, _ = compute_composites(
    ivt_norm_ffe, bmus, xdim, ydim, time_dim="valid_time"
)


In [115]:
# Levels for standardized anomaly composites
levels_ivt_anom = np.arange(-2.5, 2.6, 0.5)
levels_Z_anom = np.arange(-2.0, 2.1, 0.25)

fig, axes = create_som_figure(xdim, ydim)

for i in range(xdim):
    for j in range(ydim):
        ax = axes[j, i]

        # IVT shaded
        im = ax.contourf(
            lon,
            lat,
            ivt_patterns[i, j],
            cmap="balance",
            levels=levels_ivt_anom,
            transform=ccrs.PlateCarree(),
            extend="both",
        )

        # Z500 contours
        ax.contour(
            lon,
            lat,
            z500_patterns[i, j],
            colors="black",
            linewidths=0.5,
            levels=levels_Z_anom,
            transform=ccrs.PlateCarree(),
        )

        add_map_features(ax)
        ax.set_title(f"({i},{j})  N={counts[i, j]}", fontsize=6)

cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6, pad=0.02)
cbar.set_label("Standardized Anomaly", fontsize=6)

plt.suptitle(
    "Flash Flood Only SOM Composite Anomalies: Z500 (contoured) + IVT (shaded)",
    fontsize=8,
    y=1.04,
)
plt.savefig(
    "figs/Z500-and-ivt-SOM/Z500_and_ivt_SOM_composite_anomalies_ivt_shaded.png",
    bbox_inches="tight",
)
plt.close()


### Node Representativeness

To assess how well each SOM node represents its assigned events, we compute **spatial pattern correlations** between:
1. The node's learned weight vector (prototype)
2. The composite mean of all events assigned to that node

For this multivariate SOM, we compute correlations separately for Z500 and IVT, as well as a combined correlation using the concatenated feature vectors. High correlations (r > 0.8) indicate that the prototype faithfully captures the typical pattern of its assigned events.

**Note:** The weights were trained on latitude-weighted, standardized anomalies, while the composites here use unweighted standardized anomalies. Some differences are expected due to the weighting scheme.

In [118]:
# Compute pattern correlations between node weights and composites
from scipy.stats import pearsonr

# Compute composites of the weighted data for a fair comparison
# (weights were trained on latitude-weighted standardized anomalies)
z500_patterns_weighted, node_counts = compute_composites(
    z500_norm_weighted_ffe, bmus, xdim, ydim, time_dim="time"
)
ivt_patterns_weighted, _ = compute_composites(
    ivt_norm_weighted_ffe, bmus, xdim, ydim, time_dim="valid_time"
)

# Store results
representativeness = []

print("=" * 75)
print("NODE REPRESENTATIVENESS: Pattern Correlations (Weights vs. Composites)")
print("=" * 75)
print(
    f"\n{'Node':<8} {'N':>4} {'r(Z500)':>10} {'r(IVT)':>10} {'r(Combined)':>12} {'Interpretation':<20}"
)
print("-" * 75)

for i in range(xdim):
    for j in range(ydim):
        n_events = int(node_counts[i, j])

        # Flatten the spatial fields for correlation
        z500_weight = z500_nodes[i, j].flatten()
        ivt_weight = ivt_nodes[i, j].flatten()

        z500_composite = z500_patterns_weighted[i, j].flatten()
        ivt_composite = ivt_patterns_weighted[i, j].flatten()

        # Compute correlations (handle potential NaNs)
        if n_events > 0 and not np.any(np.isnan(z500_composite)):
            r_z500, p_z500 = pearsonr(z500_weight, z500_composite)
            r_ivt, p_ivt = pearsonr(ivt_weight, ivt_composite)

            # Combined correlation using concatenated vectors
            combined_weight = np.concatenate([z500_weight, ivt_weight])
            combined_composite = np.concatenate([z500_composite, ivt_composite])
            r_combined, p_combined = pearsonr(combined_weight, combined_composite)
        else:
            r_z500 = r_ivt = r_combined = np.nan
            p_z500 = p_ivt = p_combined = np.nan

        # Interpret the correlation
        if r_combined >= 0.9:
            interp = "Excellent"
        elif r_combined >= 0.8:
            interp = "Good"
        elif r_combined >= 0.7:
            interp = "Moderate"
        else:
            interp = "Poor"

        representativeness.append(
            {
                "node": f"({i},{j})",
                "n": n_events,
                "r_z500": r_z500,
                "r_ivt": r_ivt,
                "r_combined": r_combined,
            }
        )

        print(
            f"({i},{j}){'':<4} {n_events:>4} {r_z500:>10.3f} {r_ivt:>10.3f} {r_combined:>12.3f} {interp:<20}"
        )

print("-" * 75)

# Summary statistics
r_combined_values = [
    r["r_combined"] for r in representativeness if not np.isnan(r["r_combined"])
]
print(
    f"\nSummary: Mean r(Combined) = {np.mean(r_combined_values):.3f}, "
    f"Min = {np.min(r_combined_values):.3f}, Max = {np.max(r_combined_values):.3f}"
)

if np.min(r_combined_values) >= 0.8:
    print("→ All nodes show good-to-excellent representativeness.")
elif np.min(r_combined_values) >= 0.7:
    print(
        "→ Most nodes show adequate representativeness; some may benefit from more events."
    )
else:
    print(
        "→ Some nodes show poor representativeness; consider SOM size or training parameters."
    )


NODE REPRESENTATIVENESS: Pattern Correlations (Weights vs. Composites)

Node        N    r(Z500)     r(IVT)  r(Combined) Interpretation      
---------------------------------------------------------------------------
(0,0)       25      0.357      0.825        0.565 Poor                
(0,1)       33      0.684      0.695        0.633 Poor                
(1,0)       35      0.960      0.974        0.963 Excellent           
(1,1)       24      0.748      0.700        0.737 Moderate            
---------------------------------------------------------------------------

Summary: Mean r(Combined) = 0.725, Min = 0.565, Max = 0.963
→ Some nodes show poor representativeness; consider SOM size or training parameters.


In [117]:
# Visualize representativeness
fig, axes = plt.subplots(1, 2, figsize=(8, 3.5), dpi=600, constrained_layout=True)

# Left panel: Heatmap of combined correlations on SOM grid
ax = axes[0]
r_grid = np.zeros((xdim, ydim))
for i in range(xdim):
    for j in range(ydim):
        idx = i * ydim + j
        r_grid[i, j] = representativeness[idx]["r_combined"]

# Display with origin="lower" to match SOM convention
im = ax.imshow(r_grid.T, cmap="RdYlGn", vmin=0.5, vmax=1.0, origin="lower")
ax.set_xticks(np.arange(xdim))
ax.set_yticks(np.arange(ydim))
ax.set_xlabel("X-index", fontsize=7)
ax.set_ylabel("Y-index", fontsize=7)
ax.set_title("Combined Pattern Correlation\n(Weights vs. Composites)", fontsize=8)

# Add correlation values as text
for i in range(xdim):
    for j in range(ydim):
        r_val = r_grid[i, j]
        color = "white" if r_val < 0.75 else "black"
        ax.text(i, j, f"{r_val:.2f}", ha="center", va="center", fontsize=8, color=color)

cbar = fig.colorbar(im, ax=ax, shrink=0.8)
cbar.set_label("Pearson r", fontsize=7)

# Right panel: Grouped bar chart showing Z500, IVT, and Combined correlations
ax = axes[1]
node_labels = [r["node"] for r in representativeness]
x = np.arange(len(node_labels))
width = 0.25

r_z500 = [r["r_z500"] for r in representativeness]
r_ivt = [r["r_ivt"] for r in representativeness]
r_comb = [r["r_combined"] for r in representativeness]

bars1 = ax.bar(x - width, r_z500, width, label="Z500", color="steelblue", alpha=0.9)
bars2 = ax.bar(x, r_ivt, width, label="IVT", color="coral", alpha=0.9)
bars3 = ax.bar(x + width, r_comb, width, label="Combined", color="seagreen", alpha=0.9)

ax.axhline(0.8, color="gray", linestyle="--", linewidth=0.8, label="r = 0.8 threshold")
ax.set_xlabel("SOM Node", fontsize=7)
ax.set_ylabel("Pattern Correlation (r)", fontsize=7)
ax.set_title("Representativeness by Variable", fontsize=8)
ax.set_xticks(x)
ax.set_xticklabels(node_labels, fontsize=6)
ax.set_ylim(0.5, 1.0)
ax.legend(fontsize=6, loc="lower right")
ax.grid(True, linewidth=0.3, alpha=0.5, axis="y")

plt.savefig(
    "figs/Z500-and-ivt-SOM/Z500_and_ivt_som_representativeness.png", bbox_inches="tight"
)
plt.close()

print("Saved representativeness figure to figs/Z500-and-ivt-SOM/Z500_and_ivt_som_representativeness.png")

Saved representativeness figure to figs/Z500-and-ivt-SOM/Z500_and_ivt_som_representativeness.png


### Composite Mean Map

In [119]:
# Compute raw composites
z500_patterns_raw, _ = compute_composites(z500_ffe, bmus, xdim, ydim, time_dim="time")
ivt_patterns_raw, _ = compute_composites(
    ivt_ffe, bmus, xdim, ydim, time_dim="valid_time"
)


In [120]:
# Levels for raw composites
levels_Z_raw = range(552, 595, 3)
levels_ivt_raw = np.arange(0, 701, 100)

fig, axes = create_som_figure(xdim, ydim)

for i in range(xdim):
    for j in range(ydim):
        ax = axes[j, i]

        # IVT shaded
        im = ax.contourf(
            lon,
            lat,
            ivt_patterns_raw[i, j],
            cmap="BuPu",
            levels=levels_ivt_raw,
            transform=ccrs.PlateCarree(),
            extend="max",
        )

        # Z500 contours
        cn = ax.contour(
            lon,
            lat,
            z500_patterns_raw[i, j] / 98.1,
            colors="black",
            linewidths=0.5,
            levels=levels_Z_raw,
            transform=ccrs.PlateCarree(),
        )
        ax.clabel(cn, inline=True, fontsize=5, fmt="%.0f")

        add_map_features(ax)
        ax.set_title(f"({i},{j})  N={counts[i, j]}", fontsize=6)

cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6, pad=0.02)
cbar.set_label("IVT (kg m$^{-1}$ s$^{-1}$)", fontsize=6)

plt.suptitle(
    "Flash Flood Only SOM Composite: IVT (shaded) + Z500 (contoured)",
    fontsize=8,
    y=1.04,
)
plt.savefig(
    "figs/Z500-and-ivt-SOM/Z500_and_ivt_SOM_composite_mean_IVT_shaded.png",
    bbox_inches="tight",
)
plt.close()


### Maps of Individual Nodes

In [121]:
# Plot individual IVT events per node
plot_node_events(
    ivt_ffe,
    bmus,
    xdim,
    ydim,
    lon,
    lat,
    levels=np.arange(0, 1001, 100),
    cmap="BuPu",
    save_pattern="figs/Z500-and-ivt-SOM/indiv-nodes/node_{i}_{j}.png",
    cbar_label="IVT (kg m$^{-1}$ s$^{-1}$)",
    z500_data=z500_ffe,
    z500_levels=range(552, 595, 3),
    z500_scale=1/98.1,
    z500_time_dim="time",
)

# Plot individual precipitation events per node
plot_node_events(
    tp_ffe,
    bmus,
    xdim,
    ydim,
    lon,
    lat,
    levels=np.arange(0, 28, 3),
    cmap="HomeyerRainbow",
    save_pattern="figs/Z500-and-ivt-SOM/indiv-nodes/node_{i}_{j}_precip.png",
    scale=1000,
    cbar_label="Total Precipitation (mm)",
)

# Plot individual MSLP events per node
plot_node_events(
    mslp_ffe,
    bmus,
    xdim,
    ydim,
    lon,
    lat,
    levels=np.arange(976, 1041, 4),
    cmap=None,
    save_pattern="figs/ivt-SOM/indiv-nodes/node_{i}_{j}_mslp.png",
    scale=0.01,
    contour=True,
)


### Maps of Composite MSLP for each SOM node

In [122]:
# Compute MSLP composites
mslp_patterns, _ = compute_composites(mslp_ffe, bmus, xdim, ydim, time_dim="valid_time")


In [123]:
fig, axes = create_som_figure(xdim, ydim, figsize=(6, 3.7))

levels = np.arange(1008, 1023, 2)

for i in range(xdim):
    for j in range(ydim):
        ax = axes[j, i]

        cn = ax.contour(
            lon,
            lat,
            mslp_patterns[i, j] / 100,
            levels=levels,
            transform=ccrs.PlateCarree(),
            cmap="managua_r",
        )
        add_map_features(ax)

        # Add inline labels
        ax.clabel(cn, cn.levels, fontsize=5, inline=True)
        ax.set_title(f"({i},{j})  N={counts[i, j]}", fontsize=6)

# cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6, pad=0.02)
# cbar.set_label("MSLP (hPa)", fontsize=6)

plt.suptitle("SOM Composite MSLP", fontsize=8, y=1.04)
plt.savefig(
    "figs/Z500-and-ivt-SOM/z500_and_ivt_som_composite_mslp.png", bbox_inches="tight"
)
plt.close()


### Month Histograms

In [124]:
# Compute monthly event counts per node
months = pd.to_datetime(ivt_ffe.valid_time).month.to_numpy()
month_counts = {}

for i in range(xdim):
    for j in range(ydim):
        idx = get_node_indices(bmus, i, j)
        node_months = months[idx]
        month_counts[(i, j)] = np.bincount(node_months, minlength=13)[1:]


In [125]:
fig, axes = plt.subplots(ydim, xdim, figsize=(6, 3.7), constrained_layout=True, dpi=600)

# Warm-season labels
month_labels = ["May", "Jun", "Jul", "Aug", "Sep", "Oct"]

for i in range(xdim):
    for j in range(ydim):
        ax = axes[j, i]

        # Extract May–Oct counts (months 5–10 → indices 4:10)
        counts = month_counts[(i, j)][4:10]

        ax.bar(month_labels, counts, color="teal", alpha=0.9, width=0.8)

        # Title matches your SOM composite style
        ax.set_title(f"({i},{j})  N={counts.sum()}", fontsize=6)

        # Remove ticks entirely (categorical labels don’t need them)
        ax.tick_params(axis="x", bottom=False, labelsize=5)

        # Shared, fixed y-axis across all panels
        ax.set_ylim(0, 18)
        ax.set_yticks(np.arange(0, 17, 2))

        # Light grid for readability
        ax.grid(True, linewidth=0.3, alpha=0.5, axis="y")

# Overall title
plt.suptitle(
    "Warm-Season (May–Oct) Event Distribution per SOM Node", fontsize=8, y=1.04
)
plt.savefig(
    "figs/Z500-and-ivt-SOM/Z500_and_ivt_som_monthly_counts.png", bbox_inches="tight"
)
plt.close()


In [126]:
# Statistical test: Are the monthly distributions different across SOM nodes?
from scipy.stats import chi2_contingency, fisher_exact
from itertools import combinations

# Build contingency table: rows = nodes, columns = months (May-Oct)
contingency = np.array(
    [month_counts[(i, j)][4:10] for i in range(xdim) for j in range(ydim)]
)

node_labels = [f"({i},{j})" for i in range(xdim) for j in range(ydim)]
month_labels = ["May", "Jun", "Jul", "Aug", "Sep", "Oct"]

# Display the contingency table
print("Contingency Table (Node × Month):")
print("-" * 55)
print(f"{'Node':<10}", end="")
for m in month_labels:
    print(f"{m:>7}", end="")
print(f"{'Total':>8}")
print("-" * 55)
for k, label in enumerate(node_labels):
    print(f"{label:<10}", end="")
    for val in contingency[k]:
        print(f"{val:>7}", end="")
    print(f"{contingency[k].sum():>8}")
print("-" * 55)
print(f"{'Total':<10}", end="")
for val in contingency.sum(axis=0):
    print(f"{val:>7}", end="")
print(f"{contingency.sum():>8}")
print()

# Chi-square test for independence (overall)
chi2, p_value, dof, expected = chi2_contingency(contingency)

print("=" * 55)
print("OVERALL CHI-SQUARE TEST FOR INDEPENDENCE")
print("=" * 55)
print(f"H₀: Monthly distributions are the same across all nodes")
print(f"H₁: At least one node has a different monthly distribution")
print()
print(f"Chi-square statistic: {chi2:.3f}")
print(f"Degrees of freedom:   {dof}")
print(f"p-value:              {p_value:.4f}")
print()

# Check expected cell count assumption
min_expected = expected.min()
low_expected = (expected < 5).sum()
print(f"Minimum expected count: {min_expected:.2f}")
print(f"Cells with expected < 5: {low_expected}/{expected.size}")

if low_expected > 0.2 * expected.size:
    print("⚠ Warning: >20% of cells have expected counts < 5; interpret with caution")
print()

if p_value < 0.05:
    print("→ Result: REJECT H₀ at α=0.05. Monthly distributions differ across nodes.")
else:
    print(
        "→ Result: FAIL TO REJECT H₀. No significant difference in monthly distributions."
    )
print()

# Pairwise comparisons with Bonferroni correction
print("=" * 55)
print("PAIRWISE CHI-SQUARE TESTS (Bonferroni-corrected)")
print("=" * 55)

pairs = list(combinations(range(n_nodes), 2))
n_comparisons = len(pairs)
alpha_corrected = 0.05 / n_comparisons

print(f"Number of comparisons: {n_comparisons}")
print(f"Bonferroni-corrected α: {alpha_corrected:.4f}")
print()

pairwise_results = []
for idx1, idx2 in pairs:
    pair_table = contingency[[idx1, idx2], :]

    # Use chi-square if expected counts are reasonable, otherwise note limitation
    chi2_pair, p_pair, dof_pair, exp_pair = chi2_contingency(pair_table)

    sig = "***" if p_pair < alpha_corrected else ""
    pairwise_results.append(
        {
            "pair": f"{node_labels[idx1]} vs {node_labels[idx2]}",
            "chi2": chi2_pair,
            "p": p_pair,
            "sig": p_pair < alpha_corrected,
        }
    )

    print(
        f"{node_labels[idx1]} vs {node_labels[idx2]}: χ²={chi2_pair:.2f}, p={p_pair:.4f} {sig}"
    )

print()
significant_pairs = [r for r in pairwise_results if r["sig"]]
if significant_pairs:
    print(f"Significantly different pairs (p < {alpha_corrected:.4f}):")
    for r in significant_pairs:
        print(f"  • {r['pair']}")
else:
    print("No pairwise comparisons are significant after Bonferroni correction.")


Contingency Table (Node × Month):
-------------------------------------------------------
Node          May    Jun    Jul    Aug    Sep    Oct   Total
-------------------------------------------------------
(0,0)           1      5      6     10      3      0      25
(0,1)           3      7      6      9      5      3      33
(1,0)           3      5     16      5      4      2      35
(1,1)           0      2      6     10      4      2      24
-------------------------------------------------------
Total           7     19     34     34     16      7     117

OVERALL CHI-SQUARE TEST FOR INDEPENDENCE
H₀: Monthly distributions are the same across all nodes
H₁: At least one node has a different monthly distribution

Chi-square statistic: 16.944
Degrees of freedom:   15
p-value:              0.3222

Minimum expected count: 1.44
Cells with expected < 5: 14/24

→ Result: FAIL TO REJECT H₀. No significant difference in monthly distributions.

PAIRWISE CHI-SQUARE TESTS (Bonferroni-corrected

In [127]:
node_totals = np.array(
    [month_counts[(i, j)].sum() for i in range(xdim) for j in range(ydim)]
)

total_events = node_totals.sum()

P_node = node_totals / total_events

# Total events per month (all nodes combined)
all_month_counts = np.bincount(months, minlength=13)[1:]  # 1–12

# Only warm season (May–Oct)
month_idx = np.arange(4, 10)  # indices for May–Oct

n_nodes = xdim * ydim
heatmap = np.zeros((n_nodes, len(month_idx)))

# Flatten node indices consistently
node_labels = []

k = 0
for i in range(xdim):
    for j in range(ydim):
        counts = month_counts[(i, j)][month_idx]
        totals = all_month_counts[month_idx]

        # P(Node | Month)
        heatmap[k, :] = counts / totals
        node_labels.append(f"({i},{j})")

        k += 1

relative_heatmap = np.zeros_like(heatmap)

for k in range(n_nodes):
    relative_heatmap[k, :] = heatmap[k, :] / P_node[k]


In [128]:
fig, ax = plt.subplots(figsize=(6, 3.7), dpi=600)

im = ax.imshow(relative_heatmap, aspect="auto", cmap="RdBu_r", vmin=0, vmax=2)

# Axes labels
ax.set_xticks(np.arange(len(month_idx)))
ax.set_xticklabels(["May", "Jun", "Jul", "Aug", "Sep", "Oct"], fontsize=7)

ax.set_yticks(np.arange(n_nodes))
ax.set_yticklabels(node_labels, fontsize=6)

ax.set_xlabel("Month")
ax.set_ylabel("SOM Node")

cbar = plt.colorbar(im, ax=ax)
cbar.set_label("Relative Likelihood", fontsize=7)

plt.title(
    "Monthly Relative Likelihood of SOM Nodes\n"
    "(Normalized by Seasonal Event Frequency)",
    fontsize=8,
)

plt.tight_layout()
plt.savefig(
    "figs/Z500-and-ivt-SOM/Z500_and_ivt_som_monthly_relative_heatmap.png",
    bbox_inches="tight",
)
plt.close()


## Node Statistics

### Maximum Hourly Rainfall by SOM Node

In [129]:
# Load hourly precipitation data from four NYC-area sites
precip_path = "precip_data_and_tc_association_code/"
sites = {
    "JFK": np.load(f"{precip_path}jfk7_14_25.npy", allow_pickle=True),
    "LGA": np.load(f"{precip_path}lga7_14_25.npy", allow_pickle=True),
    "Central Park": np.load(f"{precip_path}cp7_14_25.npy", allow_pickle=True),
    "EWR": np.load(f"{precip_path}zewr_14_25.npy", allow_pickle=True),
}

# Convert to DataFrames with datetime index for easy lookup
precip_dfs = {}
for name, arr in sites.items():
    df = pd.DataFrame(arr, columns=["precip", "time"])
    df["time"] = pd.to_datetime(df["time"])
    df = df.set_index("time").sort_index()
    df["precip"] = pd.to_numeric(df["precip"], errors="coerce")
    precip_dfs[name] = df

print(f"Loaded precipitation data for {len(precip_dfs)} sites")
for name, df in precip_dfs.items():
    print(f"  {name}: {df.index.min()} to {df.index.max()}, {len(df)} records")


Loaded precipitation data for 4 sites
  JFK: 1949-01-05 08:00:00 to 2025-07-14 23:00:00, 670792 records
  LGA: 1948-07-05 22:00:00 to 2025-07-14 22:00:00, 675193 records
  Central Park: 1948-05-03 00:00:00 to 2025-07-14 22:00:00, 676727 records
  EWR: 1948-05-02 23:00:00 to 2025-07-14 23:00:00, 676729 records


In [130]:
# Load BMU assignments (timestamps are in UTC)
bmu_df = pd.read_csv("data/som_2x2_bmus.csv")
bmu_df["timestamp"] = pd.to_datetime(bmu_df["timestamp"])

# Convert UTC timestamps to local time for matching with precip data
# Precip data is in local time (EST/EDT)
bmu_df["timestamp_local"] = (
    bmu_df["timestamp"].dt.tz_localize("UTC").dt.tz_convert("EST").dt.tz_localize(None)
)

# For each event, find the max hourly rainfall across all four sites
# Look at a 12-hour window centered on the event time to capture peak rainfall
window_hours = 6  # hours before and after

max_precip = []
for _, row in bmu_df.iterrows():
    event_time = row["timestamp_local"]
    start = event_time - pd.Timedelta(hours=window_hours)
    end = event_time + pd.Timedelta(hours=window_hours)

    # Get max precip at each site within the window
    site_maxes = []
    for name, df in precip_dfs.items():
        window_data = df.loc[start:end, "precip"]
        if len(window_data) > 0:
            site_maxes.append(window_data.max())

    # Take max across all sites
    max_precip.append(np.nanmax(site_maxes) if site_maxes else np.nan)

bmu_df["max_precip_in"] = max_precip

# Report coverage
valid = bmu_df["max_precip_in"].notna().sum()
print(f"Found precipitation data for {valid}/{len(bmu_df)} events")
print(
    f"Max hourly rainfall range: {bmu_df['max_precip_in'].min():.2f} - {bmu_df['max_precip_in'].max():.2f} inches"
)


Found precipitation data for 116/117 events
Max hourly rainfall range: 0.06 - 3.62 inches


  max_precip.append(np.nanmax(site_maxes) if site_maxes else np.nan)


In [131]:
# Check counts for each node pair (i,j)
node_precip_stats = {}
for i in range(xdim):
    for j in range(ydim):
        node_data = bmu_df[(bmu_df["node_i"] == i) & (bmu_df["node_j"] == j)]
        precip_values = node_data["max_precip_in"].dropna().values

        if len(precip_values) > 0:
            mean_precip = np.mean(precip_values)
            median_precip = np.median(precip_values)
            std_precip = np.std(precip_values)
        else:
            mean_precip = median_precip = std_precip = np.nan

        node_precip_stats[(i, j)] = {
            "count": len(precip_values),
            "mean": mean_precip,
            "median": median_precip,
            "std": std_precip,
        }
# Print summary table
print("\nNYC-Area Max Hourly Precipitation Statistics by SOM Node:")
print("Node (i,j) | Count | Mean (in) | Median (in) | Std Dev (in)")
for i in range(xdim):
    for j in range(ydim):
        stats = node_precip_stats[(i, j)]
        print(
            f"   ({i},{j})   |  {stats['count']:3d}  |  {stats['mean']:.2f}   |   {stats['median']:.2f}   |   {stats['std']:.2f}"
        )



NYC-Area Max Hourly Precipitation Statistics by SOM Node:
Node (i,j) | Count | Mean (in) | Median (in) | Std Dev (in)
   (0,0)   |   25  |  1.09   |   1.01   |   0.48
   (0,1)   |   32  |  0.88   |   0.80   |   0.56
   (1,0)   |   35  |  1.06   |   1.03   |   0.48
   (1,1)   |   24  |  0.88   |   0.78   |   0.71


In [132]:
# Create histogram of max hourly rainfall for each SOM node
fig, axes = plt.subplots(ydim, xdim, figsize=(6, 4), constrained_layout=True, dpi=600)

# Define consistent bins for all histograms (in inches)
bins = np.arange(0, 3.76, 0.25)

# Colors for each node (matching other plots)
colors = ["steelblue", "darkorange", "seagreen", "firebrick"]

for i in range(xdim):
    for j in range(ydim):
        ax = axes[j, i]

        # Filter events for this node
        node_data = bmu_df[(bmu_df["node_i"] == i) & (bmu_df["node_j"] == j)][
            "max_precip_in"
        ].dropna()

        # Plot histogram
        ax.hist(
            node_data,
            bins=bins,
            color="teal",
            alpha=0.9,
            edgecolor="white",
            linewidth=0.5,
        )

        # Add statistics
        n = len(node_data)
        median = node_data.median() if n > 0 else np.nan
        mean = node_data.mean() if n > 0 else np.nan

        # Add vertical line for median
        if n > 0:
            ax.axvline(
                median,
                color="red",
                linestyle="--",
                linewidth=1,
                label=f'Median: {median:.2f}"',
            )

        ax.set_title(f"({i},{j})  N={n}", fontsize=6)
        ax.set_xlim(0, 3.75)
        ax.set_ylim(0, 12)
        ax.set_yticks(np.arange(0, 13, 2))
        ax.tick_params(axis="both", labelsize=5)
        ax.grid(True, linewidth=0.3, alpha=0.5, axis="y")

        # Only add x-label on bottom row
        if j == ydim - 1:
            ax.set_xlabel("Max Hourly Precip (in)", fontsize=5)

        # Only add y-label on left column
        if i == 0:
            ax.set_ylabel("Count", fontsize=5)

        # Add legend with median
        if n > 0:
            ax.legend(fontsize=4, loc="upper right")

plt.suptitle(
    "Maximum Hourly Rainfall ($\\pm$6 hr window) by SOM Node",
    fontsize=8,
    y=1.02,
)
plt.savefig(
    "figs/Z500-and-ivt-SOM/Z500_and_ivt_som_max_precip_histograms.png",
    bbox_inches="tight",
)
plt.close()

print(
    "Saved histogram to figs/Z500-and-ivt-SOM/Z500_and_ivt_som_max_precip_histograms.png"
)


Saved histogram to figs/Z500-and-ivt-SOM/Z500_and_ivt_som_max_precip_histograms.png


### Tropical Cyclone Association by SOM Node

We cross-reference flash flood event times with IBTrACS to identify events where a tropical system was present within the analysis domain (30–54°N, 100–60°W). This helps determine whether certain SOM patterns are preferentially associated with tropical cyclone activity.

In [133]:
# Load IBTrACS data
ibtracs_path = "precip_data_and_tc_association_code/ibtracs.NA.list.v04r01.processed_6hrly.statslp3.csv"
ibtracs = pd.read_csv(ibtracs_path)
ibtracs["ISO_TIME"] = pd.to_datetime(ibtracs["ISO_TIME"])

# Define domain bounds (same as IVT/Z500 data)
lat_min, lat_max = 30.0, 54.0
lon_min, lon_max = -100.0, -60.0

# Filter IBTrACS to our domain
ibtracs_domain = ibtracs[
    (ibtracs["LAT"] >= lat_min)
    & (ibtracs["LAT"] <= lat_max)
    & (ibtracs["LON"] >= lon_min)
    & (ibtracs["LON"] <= lon_max)
].copy()

print(f"Total IBTrACS records: {len(ibtracs):,}")
print(f"Records within domain: {len(ibtracs_domain):,}")
print(f"Unique storms in domain: {ibtracs_domain['SID'].nunique()}")


Total IBTrACS records: 64,320
Records within domain: 12,035
Unique storms in domain: 1185


In [134]:
# Cross-reference flash flood events with tropical cyclones
# Look for TCs within the domain within ±12 hours of each event

time_window_hours = 6

tc_associations = []
for _, row in bmu_df.iterrows():
    event_time = row["timestamp"]
    node_i, node_j = row["node_i"], row["node_j"]

    # Find TCs within time window
    time_mask = (
        ibtracs_domain["ISO_TIME"] >= event_time - pd.Timedelta(hours=time_window_hours)
    ) & (
        ibtracs_domain["ISO_TIME"] <= event_time + pd.Timedelta(hours=time_window_hours)
    )
    matching_tcs = ibtracs_domain[time_mask]

    if len(matching_tcs) > 0:
        # Get unique storm IDs and names
        storm_ids = matching_tcs["SID"].unique()
        tc_associations.append(
            {
                "timestamp": event_time,
                "node_i": node_i,
                "node_j": node_j,
                "tc_present": True,
                "n_storms": len(storm_ids),
                "storm_ids": ", ".join(storm_ids),
                "storm_status": matching_tcs["STAT"].mode().iloc[0]
                if len(matching_tcs["STAT"].dropna()) > 0
                else "Unknown",
            }
        )
    else:
        tc_associations.append(
            {
                "timestamp": event_time,
                "node_i": node_i,
                "node_j": node_j,
                "tc_present": False,
                "n_storms": 0,
                "storm_ids": "",
                "storm_status": "",
            }
        )

tc_df = pd.DataFrame(tc_associations)

# Summary statistics
n_tc_events = tc_df["tc_present"].sum()
print(
    f"Flash flood events associated with TCs: {n_tc_events}/{len(tc_df)} ({100 * n_tc_events / len(tc_df):.1f}%)"
)
print(f"\nTC-associated events by SOM node:")
for i in range(xdim):
    for j in range(ydim):
        node_data = tc_df[(tc_df["node_i"] == i) & (tc_df["node_j"] == j)]
        tc_count = node_data["tc_present"].sum()
        total = len(node_data)
        pct = 100 * tc_count / total if total > 0 else 0
        print(f"  Node ({i},{j}): {tc_count}/{total} ({pct:.1f}%)")


Flash flood events associated with TCs: 22/117 (18.8%)

TC-associated events by SOM node:
  Node (0,0): 5/25 (20.0%)
  Node (0,1): 7/33 (21.2%)
  Node (1,0): 7/35 (20.0%)
  Node (1,1): 3/24 (12.5%)


In [135]:
# List TC-associated flash flood events
tc_events = tc_df[tc_df["tc_present"]].copy()
tc_events = tc_events.sort_values("timestamp")

print("TC-Associated Flash Flood Events:")
print("-" * 80)
for _, row in tc_events.iterrows():
    print(
        f"{row['timestamp'].strftime('%Y-%m-%d %H:%M')} | "
        f"Node ({row['node_i']},{row['node_j']}) | "
        f"Status: {row['storm_status']:>3} | "
        f"{row['storm_ids']}"
    )

TC-Associated Flash Flood Events:
--------------------------------------------------------------------------------
1996-07-13 13:00 | Node (1,0) | Status:  TS | 1996187N10326
1996-09-08 20:00 | Node (1,1) | Status:  EX | 1996237N14339
1997-07-16 00:00 | Node (0,1) | Status:  TS | 1997194N31286
1999-09-16 16:00 | Node (1,0) | Status:  HU | 1999251N15314
2001-06-17 15:00 | Node (0,0) | Status:  SS | 2001157N28265
2002-09-02 13:00 | Node (0,1) | Status:  TS | 2002245N29281
2004-09-08 10:00 | Node (0,0) | Status:  TD | 2004238N11325
2004-09-18 13:00 | Node (0,1) | Status:  EX | 2004247N10332
2004-09-28 21:00 | Node (0,1) | Status:  EX | 2004258N16300
2005-07-06 23:00 | Node (1,1) | Status:  TD | 2005185N18273
2005-10-14 21:00 | Node (0,1) | Status:  EX | 2005281N26303
2006-07-21 21:00 | Node (0,1) | Status:  EX | 2006200N32287
2007-06-04 13:00 | Node (1,0) | Status:  EX | 2007151N18273
2008-09-06 22:00 | Node (1,0) | Status:  TS | 2008241N19303
2014-07-03 00:00 | Node (1,0) | Status:  HU |

In [136]:
# Create bar chart showing TC vs non-TC events by SOM node
fig, axes = plt.subplots(1, 2, figsize=(8, 3.5), dpi=600, constrained_layout=True)

# Compute counts per node
tc_counts = np.zeros((xdim, ydim))
non_tc_counts = np.zeros((xdim, ydim))

for i in range(xdim):
    for j in range(ydim):
        node_data = tc_df[(tc_df["node_i"] == i) & (tc_df["node_j"] == j)]
        tc_counts[i, j] = node_data["tc_present"].sum()
        non_tc_counts[i, j] = (~node_data["tc_present"]).sum()

# Left panel: Stacked bar chart
ax = axes[0]
node_labels = [f"({i},{j})" for i in range(xdim) for j in range(ydim)]
x = np.arange(len(node_labels))
width = 0.6

tc_flat = tc_counts.flatten()
non_tc_flat = non_tc_counts.flatten()

bars1 = ax.bar(x, non_tc_flat, width, label="Non-TC", color="steelblue", alpha=0.9)
bars2 = ax.bar(x, tc_flat, width, bottom=non_tc_flat, label="TC-Associated", color="coral", alpha=0.9)

ax.set_xlabel("SOM Node", fontsize=7)
ax.set_ylabel("Number of Events", fontsize=7)
ax.set_title("Flash Flood Events by TC Association", fontsize=8)
ax.set_xticks(x)
ax.set_xticklabels(node_labels, fontsize=6)
ax.legend(fontsize=6, loc="upper right")
ax.grid(True, linewidth=0.3, alpha=0.5, axis="y")

# Right panel: TC percentage by node
ax = axes[1]
totals = tc_flat + non_tc_flat
tc_pct = 100 * tc_flat / totals

bars = ax.bar(x, tc_pct, width, color="coral", alpha=0.9, edgecolor="white")

# Add percentage labels on bars
for bar, pct, tc, total in zip(bars, tc_pct, tc_flat, totals):
    ax.text(
        bar.get_x() + bar.get_width() / 2,
        bar.get_height() + 1,
        f"{int(tc)}/{int(total)}",
        ha="center",
        va="bottom",
        fontsize=6,
    )

# Add overall average line
overall_pct = 100 * tc_flat.sum() / totals.sum()
ax.axhline(overall_pct, color="red", linestyle="--", linewidth=1, label=f"Overall: {overall_pct:.1f}\\%")

ax.set_xlabel("SOM Node", fontsize=7)
ax.set_ylabel("TC-Associated Events (\\%)", fontsize=7)
ax.set_title("Percentage of Events with TC Influence", fontsize=8)
ax.set_xticks(x)
ax.set_xticklabels(node_labels, fontsize=6)
ax.set_ylim(0, max(tc_pct) + 15)
ax.legend(fontsize=6, loc="upper right")
ax.grid(True, linewidth=0.3, alpha=0.5, axis="y")

plt.savefig("figs/Z500-and-ivt-SOM/Z500_and_ivt_som_tc_association.png", bbox_inches="tight")
plt.close()

print("Saved TC association figure to figs/Z500-and-ivt-SOM/Z500_and_ivt_som_tc_association.png")

Saved TC association figure to figs/Z500-and-ivt-SOM/Z500_and_ivt_som_tc_association.png


In [137]:
# Statistical test: Is TC association different across SOM nodes?
from scipy.stats import chi2_contingency, fisher_exact

# Build 2xN contingency table: rows = TC/non-TC, columns = nodes
contingency_tc = np.array([tc_flat, non_tc_flat])

print("=" * 60)
print("CHI-SQUARE TEST: TC Association vs SOM Node")
print("=" * 60)
print("\nContingency Table:")
print(f"{'':>12}", end="")
for i in range(xdim):
    for j in range(ydim):
        print(f"({i},{j}):>8", end="")
print(f"{'Total':>8}")

print(f"{'TC':>12}", end="")
for val in tc_flat:
    print(f"{int(val):>8}", end="")
print(f"{int(tc_flat.sum()):>8}")

print(f"{'Non-TC':>12}", end="")
for val in non_tc_flat:
    print(f"{int(val):>8}", end="")
print(f"{int(non_tc_flat.sum()):>8}")

print(f"{'Total':>12}", end="")
for val in totals:
    print(f"{int(val):>8}", end="")
print(f"{int(totals.sum()):>8}")

# Chi-square test
chi2, p_value, dof, expected = chi2_contingency(contingency_tc)

print(f"\nChi-square statistic: {chi2:.3f}")
print(f"Degrees of freedom:   {dof}")
print(f"p-value:              {p_value:.4f}")

# Check expected cell counts
min_expected = expected.min()
print(f"\nMinimum expected count: {min_expected:.2f}")

if min_expected < 5:
    print("⚠ Warning: Some expected counts < 5; interpret with caution")

if p_value < 0.05:
    print("\n→ Result: REJECT H₀ at α=0.05. TC association differs across SOM nodes.")
else:
    print(
        "\n→ Result: FAIL TO REJECT H₀. No significant difference in TC association across nodes."
    )

# Pairwise Fisher's exact tests
print("\n" + "=" * 60)
print("PAIRWISE FISHER'S EXACT TESTS (Bonferroni-corrected)")
print("=" * 60)

pairs = list(combinations(range(n_nodes), 2))
n_comparisons = len(pairs)
alpha_corrected = 0.05 / n_comparisons

print(f"Number of comparisons: {n_comparisons}")
print(f"Bonferroni-corrected α: {alpha_corrected:.4f}\n")

for idx1, idx2 in pairs:
    # 2x2 table for this pair
    table = np.array(
        [[tc_flat[idx1], tc_flat[idx2]], [non_tc_flat[idx1], non_tc_flat[idx2]]]
    )
    odds_ratio, p = fisher_exact(table)
    sig = "***" if p < alpha_corrected else ""
    print(
        f"{node_labels[idx1]} vs {node_labels[idx2]}: OR={odds_ratio:.2f}, p={p:.4f} {sig}"
    )


CHI-SQUARE TEST: TC Association vs SOM Node

Contingency Table:
            (0,0):>8(0,1):>8(1,0):>8(1,1):>8   Total
          TC       5       7       7       3      22
      Non-TC      20      26      28      21      95
       Total      25      33      35      24     117

Chi-square statistic: 0.806
Degrees of freedom:   3
p-value:              0.8480

Minimum expected count: 4.51

→ Result: FAIL TO REJECT H₀. No significant difference in TC association across nodes.

PAIRWISE FISHER'S EXACT TESTS (Bonferroni-corrected)
Number of comparisons: 6
Bonferroni-corrected α: 0.0083

(0,0) vs (0,1): OR=0.93, p=1.0000 
(0,0) vs (1,0): OR=1.00, p=1.0000 
(0,0) vs (1,1): OR=1.75, p=0.7019 
(0,1) vs (1,0): OR=1.08, p=1.0000 
(0,1) vs (1,1): OR=1.88, p=0.4939 
(1,0) vs (1,1): OR=1.75, p=0.5059 


In [138]:
# Map showing TC positions during flash flood events, colored by SOM node
fig, ax = plt.subplots(
    figsize=(8, 5),
    subplot_kw={"projection": ccrs.PlateCarree()},
    dpi=600,
)

# Define colors for each node
node_colors = {
    (0, 0): "tab:blue",
    (0, 1): "tab:orange",
    (1, 0): "tab:green",
    (1, 1): "tab:red",
}

# Plot domain box
ax.plot(
    [lon_min, lon_max, lon_max, lon_min, lon_min],
    [lat_min, lat_min, lat_max, lat_max, lat_min],
    "k--",
    linewidth=1,
    transform=ccrs.PlateCarree(),
    label="Analysis Domain",
)

# For each TC-associated event, plot the TC track segment
for _, row in tc_events.iterrows():
    event_time = row["timestamp"]
    node_i, node_j = int(row["node_i"]), int(row["node_j"])
    storm_ids = row["storm_ids"].split(", ")

    for sid in storm_ids:
        # Get track for this storm around event time
        storm_data = ibtracs[ibtracs["SID"] == sid].copy()
        storm_data = storm_data[
            (storm_data["ISO_TIME"] >= event_time - pd.Timedelta(hours=48))
            & (storm_data["ISO_TIME"] <= event_time + pd.Timedelta(hours=48))
        ]

        if len(storm_data) > 1:
            ax.plot(
                storm_data["LON"],
                storm_data["LAT"],
                color=node_colors[(node_i, node_j)],
                linewidth=1.5,
                alpha=0.7,
                transform=ccrs.PlateCarree(),
            )

        # Mark position at event time (closest 6-hourly fix)
        closest_idx = (storm_data["ISO_TIME"] - event_time).abs().idxmin()
        closest = storm_data.loc[closest_idx]
        ax.scatter(
            closest["LON"],
            closest["LAT"],
            color=node_colors[(node_i, node_j)],
            s=40,
            marker="o",
            edgecolor="black",
            linewidth=0.5,
            transform=ccrs.PlateCarree(),
            zorder=5,
        )

# Add NYC marker
ax.scatter(
    -74.0,
    40.7,
    color="black",
    s=100,
    marker="*",
    zorder=10,
    transform=ccrs.PlateCarree(),
)
ax.text(-73.5, 40.7, "NYC", fontsize=7, transform=ccrs.PlateCarree())

# Map features
ax.add_feature(cfeature.COASTLINE, linewidth=0.5)
ax.add_feature(cfeature.STATES.with_scale("50m"), linewidth=0.3)
ax.add_feature(cfeature.BORDERS, linewidth=0.3)
ax.set_extent([lon_min - 5, lon_max + 5, lat_min - 5, lat_max + 5])

# Legend
legend_elements = [
    plt.Line2D(
        [0], [0], color=node_colors[(i, j)], linewidth=2, label=f"Node ({i},{j})"
    )
    for i in range(xdim)
    for j in range(ydim)
]
ax.legend(handles=legend_elements, loc="lower right", fontsize=6)

ax.set_title(
    "TC Tracks During Flash Flood Events\n(±48 hr track segments, markers at event time)",
    fontsize=8,
)

plt.savefig("figs/Z500-and-ivt-SOM/Z500_and_ivt_som_tc_tracks.png", bbox_inches="tight")
plt.close()

print("Saved TC tracks map to figs/Z500-and-ivt-SOM/Z500_and_ivt_som_tc_tracks.png")


Saved TC tracks map to figs/Z500-and-ivt-SOM/Z500_and_ivt_som_tc_tracks.png


In [139]:
# Compare precipitation intensity for TC vs non-TC events
# Merge TC info with precipitation data
bmu_df_with_tc = bmu_df.merge(
    tc_df[["timestamp", "tc_present", "storm_ids"]],
    on="timestamp",
    how="left"
)

print("Precipitation Statistics: TC vs Non-TC Events")
print("=" * 65)
print(f"{'Category':<20} {'N':>6} {'Mean (in)':>12} {'Median (in)':>12} {'Std (in)':>10}")
print("-" * 65)

# Overall comparison
tc_precip = bmu_df_with_tc[bmu_df_with_tc["tc_present"]]["max_precip_in"].dropna()
non_tc_precip = bmu_df_with_tc[~bmu_df_with_tc["tc_present"]]["max_precip_in"].dropna()

print(f"{'TC-Associated':<20} {len(tc_precip):>6} {tc_precip.mean():>12.2f} {tc_precip.median():>12.2f} {tc_precip.std():>10.2f}")
print(f"{'Non-TC':<20} {len(non_tc_precip):>6} {non_tc_precip.mean():>12.2f} {non_tc_precip.median():>12.2f} {non_tc_precip.std():>10.2f}")
print("-" * 65)

# Mann-Whitney U test for difference
from scipy.stats import mannwhitneyu

if len(tc_precip) > 0 and len(non_tc_precip) > 0:
    stat, p = mannwhitneyu(tc_precip, non_tc_precip, alternative="two-sided")
    print(f"\nMann-Whitney U test: U={stat:.1f}, p={p:.4f}")
    if p < 0.05:
        print("→ Significant difference in precipitation intensity between TC and non-TC events")
    else:
        print("→ No significant difference in precipitation intensity")

Precipitation Statistics: TC vs Non-TC Events
Category                  N    Mean (in)  Median (in)   Std (in)
-----------------------------------------------------------------
TC-Associated            22         1.17         1.02       0.77
Non-TC                   94         0.94         0.89       0.50
-----------------------------------------------------------------

Mann-Whitney U test: U=1197.5, p=0.2510
→ No significant difference in precipitation intensity


In [140]:
# Save TC association data for future reference
tc_df.to_csv("data/som_2x2_tc_associations.csv", index=False)
print(f"Saved TC association data to data/som_2x2_tc_associations.csv")

# Summary
print("\n" + "=" * 60)
print("SUMMARY: Tropical Cyclone Association with SOM Nodes")
print("=" * 60)
print(f"\nTotal flash flood events: {len(tc_df)}")
print(f"TC-associated events: {n_tc_events} ({100*n_tc_events/len(tc_df):.1f}%)")
print(f"\nTC association by node:")
for i in range(xdim):
    for j in range(ydim):
        node_data = tc_df[(tc_df["node_i"] == i) & (tc_df["node_j"] == j)]
        tc_count = node_data["tc_present"].sum()
        total = len(node_data)
        pct = 100 * tc_count / total if total > 0 else 0
        print(f"  Node ({i},{j}): {tc_count:2d}/{total:2d} ({pct:5.1f}%)")

Saved TC association data to data/som_2x2_tc_associations.csv

SUMMARY: Tropical Cyclone Association with SOM Nodes

Total flash flood events: 117
TC-associated events: 22 (18.8%)

TC association by node:
  Node (0,0):  5/25 ( 20.0%)
  Node (0,1):  7/33 ( 21.2%)
  Node (1,0):  7/35 ( 20.0%)
  Node (1,1):  3/24 ( 12.5%)
