# Notebook for Z500, IVT, and MSLP SOM Training

By: Ty Janoski

Updated 1/13/2026

## Setup

### Imports

In [3]:
# Import Statements
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cmweather  # noqa: F401
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scienceplots  # noqa: F401
import xarray as xr
from minisom import MiniSom
from sklearn.manifold import MDS
from sklearn.metrics import pairwise_distances

# Global plot styling: stack the "science", "nature" and "grid" style sheets
# (presumably registered by the `scienceplots` import above) and have
# matplotlib typeset all figure text with LaTeX.
plt.style.use(["science", "nature", "grid"])
plt.rcParams.update({"text.usetex": True})


### Data Loading

In [13]:
# Directory holding the pre-computed ERA5 fields sampled at flash-flood event
# (FFE) times.
path = "/mnt/drive2/SOM_intermediate_files/"


def _load_field(filename, variable=None):
    """Read one NetCDF file from `path` with xarray's eager `load_*` readers.

    When `variable` is given the file is read as a Dataset and that variable is
    returned; otherwise the file is read directly as a single DataArray.
    """
    if variable is None:
        return xr.load_dataarray(f"{path}{filename}")
    return xr.load_dataset(f"{path}{filename}")[variable]


# Each variable is stored three ways: "norm_weighted" (the SOM training input),
# "norm" (plotted further down as standardized anomalies) and the raw field.
# NOTE(review): what the "weighted" step applies is not visible here - confirm
# against the preprocessing notebook.

# Integrated vapor transport (stored as a Dataset; pull out the "ivt" variable)
ivt_norm_weighted_ffe = _load_field("era5_ivt_norm_weighted_ffe.nc", "ivt")
ivt_norm_ffe = _load_field("era5_ivt_norm_ffe.nc", "ivt")
ivt_ffe = _load_field("era5_ivt_ffe.nc", "ivt")

# 500-hPa field
z500_norm_weighted_ffe = _load_field("era5_Z500_norm_weighted_ffe.nc")
z500_norm_ffe = _load_field("era5_Z500_norm_ffe.nc")
z500_ffe = _load_field("era5_Z500_ffe.nc")

# Mean sea-level pressure
mslp_norm_weighted_ffe = _load_field("era5_mslp_norm_weighted_ffe.nc")
mslp_norm_ffe = _load_field("era5_mslp_norm_ffe.nc")
mslp_ffe = _load_field("era5_mslp_ffe.nc")

# Total precipitation at flash-flood event times
tp_ffe = _load_field("era5_tp_ffe.nc")

### Reshape Data

In [14]:
# Collapse the two spatial dimensions of every field into one "features" axis
# so that each event becomes a single row vector. The Z500 file names its
# spatial dimensions lat/lon, whereas the IVT and MSLP files use
# latitude/longitude.
z500_stacked = z500_norm_weighted_ffe.stack(features=["lat", "lon"])
ivt_stacked = ivt_norm_weighted_ffe.stack(features=["latitude", "longitude"])
mslp_stacked = mslp_norm_weighted_ffe.stack(features=["latitude", "longitude"])

# Plain numpy views of the stacked fields, intended shape: (time, feature)
z500_flat = z500_stacked.values
ivt_flat = ivt_stacked.values
mslp_flat = mslp_stacked.values

# Training matrix: the three variables side by side, [Z500 | IVT | MSLP]
X = np.concatenate((z500_flat, ivt_flat, mslp_flat), axis=1)


## SOM Training

We are going to train our SOM with random initialization and online training. We will also use two phases: a "coarse" phase with a larger sigma and learning rate, then a "fine" phase with a smaller learning rate and sigma.

### Set SOM parameters

In [15]:
# SOM lattice size: xdim nodes along x, ydim nodes along y
xdim = 3
ydim = 2

# Training iterations for the coarse phase (n1) and the fine phase (n2)
n1 = 2000
n2 = 8000

# Starting neighbourhood widths: the coarse phase begins at the length of the
# lattice diagonal, the fine phase at 1.5
sig1 = np.sqrt(xdim**2 + ydim**2)
sig2 = 1.5

# Starting learning rates for the coarse and fine phases
lr1 = 0.1
lr2 = 0.01

# Fixed seed so repeated runs give the same map
random_seed = 42


### Train SOM

In [16]:
# Coarse-phase configuration; the lattice shape is passed positionally below.
som_options = {
    "input_len": X.shape[1],  # one weight per column of the training matrix
    "sigma": sig1,
    "learning_rate": lr1,
    "decay_function": "linear_decay_to_zero",
    "sigma_decay_function": "linear_decay_to_one",
    "neighborhood_function": "gaussian",
    "random_seed": random_seed,
}
som = MiniSom(xdim, ydim, **som_options)

# Initialise the node weights from the training data
som.random_weights_init(X)

# Phase 1 ("coarse"): n1 iterations starting from sig1 / lr1
som.train_random(X, n1, verbose=True)

# Phase 2 ("fine"): the starting sigma and learning rate are swapped in by
# overwriting MiniSom's private attributes before training again.
# NOTE(review): this relies on MiniSom internals and presumably restarts both
# decay schedules from these values - confirm against the installed version.
som._sigma = sig2  # type: ignore
som._learning_rate = lr2
som.train_random(X, n2, verbose=True)


 [ 2000 / 2000 ] 100% - 0:00:00 left 
 quantization error: 162.02108337413196
 [ 8000 / 8000 ] 100% - 0:00:00 left 
 quantization error: 159.55887400869557


### Grab important fields

In [17]:
# Total node number
n_nodes = xdim * ydim

# Flattened codebook: one row per node, row index = i * ydim + j
weights = som.get_weights().reshape(n_nodes, -1)

# U-matrix, transposed to (ydim, xdim) so imshow puts the x-index on the
# horizontal axis
u_matrix = som.distance_map().T

# Best-matching unit (i, j) of every training sample
bmus = np.array([som.winner(x) for x in X])

# Hit map: number of samples won by each node, transposed like the U-matrix
hit_map = np.zeros((xdim, ydim))
for i, j in bmus:
    hit_map[i, j] += 1
hit_map = hit_map.T

# 2-D embedding of the nodes from their pairwise weight distances. This uses
# scikit-learn's MDS as a stand-in for a Sammon map.
D = pairwise_distances(weights)
coords = MDS(
    n_components=2,
    metric="precomputed",
    random_state=42,
    n_init=4,
    init="random",  # pyright: ignore[reportCallIssue]
).fit_transform(D)

# Get lats/lons (taken from the IVT grid and reused for all three variables)
# NOTE(review): Z500 is later drawn on these axes too, which assumes its
# lat/lon grid matches the IVT grid in size and ordering - confirm.
lat = ivt_norm_ffe.latitude
lon = ivt_norm_ffe.longitude

# Dimensions of the spatial field
n_lat = lat.size
n_lon = lon.size
n_features = n_lat * n_lon

# The weight vector is sliced below on the assumption that every variable
# contributes exactly n_features columns. Fail early with a readable message
# instead of an opaque reshape error if one of them does not.
_field_widths = {
    "Z500": z500_flat.shape[1],
    "IVT": ivt_flat.shape[1],
    "MSLP": mslp_flat.shape[1],
}
_mismatched = {k: v for k, v in _field_widths.items() if v != n_features}
if _mismatched:
    raise ValueError(
        f"Expected {n_features} features per variable "
        f"({n_lat} x {n_lon} grid), got {_mismatched}"
    )

# Split weights into Z500, IVT, and MSLP components (same order as in X)
z500_weights = weights[:, :n_features]
ivt_weights = weights[:, n_features : n_features * 2]
mslp_weights = weights[:, n_features * 2 :]

# Reshape weights back to spatial dimensions: (x-index, y-index, lat, lon)
z500_nodes = z500_weights.reshape(xdim, ydim, n_lat, n_lon)
ivt_nodes = ivt_weights.reshape(xdim, ydim, n_lat, n_lon)
mslp_nodes = mslp_weights.reshape(xdim, ydim, n_lat, n_lon)

## Plots

### U-matrix, Hit Map, and Sammon Map

In [18]:
fig, axes = plt.subplots(1, 2, layout="constrained", figsize=(6, 3), dpi=600)

# One entry per panel: (field, colormap, title)
panels = (
    (u_matrix, "viridis", "U-Matrix (Mean Inter-Node Distance)"),
    (hit_map, "plasma", "Hit Map (Samples per Node)"),
)

for ax, (field, cmap_name, panel_title) in zip(axes, panels):
    # origin="lower" puts node (0, 0) in the bottom-left corner
    image = ax.imshow(field, cmap=cmap_name, origin="lower")
    ax.set_title(panel_title, fontsize=7)
    fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04, shrink=0.7)

    # one tick per lattice index
    ax.set_xticks(np.arange(xdim))
    ax.set_yticks(np.arange(ydim))
    ax.set_xlabel("X-index", fontsize=6)
    ax.set_ylabel("Y-index", fontsize=6)

fig.savefig("figs/Z500-IVT-mslp-SOM/u_matrix_hit_map.png")
plt.close(fig)

In [19]:
# Undo the display transpose and flatten, so entry k belongs to node
# (k // ydim, k % ydim) - the same ordering as the rows of `coords`
U_flat = u_matrix.T.reshape(-1)
hits_flat = hit_map.T.reshape(-1)

# Marker area grows linearly with hit count: 30 for an empty node up to 280
# for the most frequently hit node
hits_scaled = 30 + 250 * (hits_flat / hits_flat.max())

plt.figure(figsize=(7, 7))

# Nodes at their embedded positions: colour = U-matrix value, size = hit count
sc = plt.scatter(
    coords[:, 0],
    coords[:, 1],
    c=U_flat,
    s=hits_scaled,
    cmap="balance",
    edgecolor="k",
    linewidth=0.5,
    zorder=3,
)

# Lattice edges: link every node to the next node along the y-index and to the
# next node along the x-index, so each edge is drawn exactly once
edge_style = {"lw": 0.6, "alpha": 0.4}
for node in range(xdim * ydim):
    i, j = divmod(node, ydim)

    neighbours = []
    if j + 1 < ydim:
        neighbours.append(node + 1)  # (i, j + 1)
    if i + 1 < xdim:
        neighbours.append(node + ydim)  # (i + 1, j)

    for nbr in neighbours:
        ends = [node, nbr]
        plt.plot(coords[ends, 0], coords[ends, 1], "k-", **edge_style)

# Label every node with its lattice index (i,j)
for idx, position in enumerate(coords):
    ix, iy = divmod(idx, ydim)
    plt.text(
        position[0],
        position[1],
        f"({ix},{iy})",
        fontsize=8,
        ha="center",
        va="center",
        zorder=5,
    )

plt.title("Sammon / MDS Distortion Grid\nU-Matrix (Color) \\& Node Frequency (Size)")
plt.axis("off")
plt.colorbar(sc, label="U-Matrix (Avg. Neighbor Distance)")
plt.savefig("figs/Z500-IVT-mslp-SOM/sammon_mds.png", bbox_inches="tight")
plt.close()

### Node Weights Map

In [27]:
# Levels for the standardized fields (dimensionless). The composite-anomaly
# figure further down reads these same three names.
# Z500 contour levels
levels_Z = np.arange(-1.2, 1.21, 0.2)

# IVT shading levels
levels_ivt = np.arange(-1.8, 1.81, 0.2)

# MSLP contour levels
levels_mslp = np.arange(-1.4, 1.41, 0.2)

# One map per node, laid out like the lattice: x-index across, y-index down
fig, axes = plt.subplots(
    ydim,
    xdim,
    figsize=(6, 3),
    subplot_kw={"projection": ccrs.PlateCarree()},
    constrained_layout=True,
    dpi=600,
)

for i in range(xdim):
    for j in range(ydim):
        ax = axes[j, i]

        # Weight fields for this node
        Z_field = z500_nodes[i, j, :, :]
        ivt_field = ivt_nodes[i, j, :, :]
        mslp_field = mslp_nodes[i, j, :, :]

        # --- IVT shaded ---
        # extend="both" keeps values beyond the outermost levels filled with
        # the end colours instead of leaving blank patches, matching the
        # composite-anomaly figure.
        im = ax.contourf(
            lon,
            lat,
            ivt_field,
            cmap="balance",
            levels=levels_ivt,
            transform=ccrs.PlateCarree(),
            extend="both",
        )

        # Z500 contours
        cn = ax.contour(
            lon,
            lat,
            Z_field,
            colors="black",
            linewidths=0.5,
            levels=levels_Z,
            transform=ccrs.PlateCarree(),
        )

        # MSLP contours
        cn1 = ax.contour(
            lon,
            lat,
            mslp_field,
            colors="green",
            levels=levels_mslp,
            linewidths=0.5,
            transform=ccrs.PlateCarree(),
        )

        # Map features
        ax.add_feature(cfeature.COASTLINE, linewidth=0.5)
        ax.add_feature(cfeature.STATES.with_scale("50m"), linewidth=0.4)
        ax.set_title(f"Node ({i},{j})", fontsize=6)
        ax.set_xticks([])
        ax.set_yticks([])

        # Add inline contour labels
        ax.clabel(cn, inline=True, fontsize=5, fmt="%.1f")
        ax.clabel(cn1, inline=True, fontsize=5, fmt="%.1f")

# Shared colorbar for the IVT shading (all panels use the same levels, so the
# last panel's mappable represents them all)
cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6, pad=0.02)
cbar.set_label("Standardized IVT Anomaly (Shaded)", fontsize=6)

plt.suptitle(
    "Node Weight Patterns\nZ500 (contoured, black) + MSLP (contoured, green)"
    + " + IVT (shaded)",
    fontsize=8,
)
plt.savefig(
    "figs/Z500-IVT-mslp-SOM/combined_node_weights_ivt_shaded.png", bbox_inches="tight"
)
plt.close()


### Anomaly Composite Map

In [33]:
# Per-node composites of the standardized anomalies, NaN-filled so that a node
# without any events stays blank
pattern_shape = (xdim, ydim, n_lat, n_lon)
z500_patterns = np.full(pattern_shape, np.nan)
ivt_patterns = np.full(pattern_shape, np.nan)
mslp_patterns = np.full(pattern_shape, np.nan)

# Number of events assigned to each node
counts = np.zeros((xdim, ydim), dtype=int)

for i, j in np.ndindex(xdim, ydim):
    # positions of the events whose best-matching unit is node (i, j)
    in_node = (bmus[:, 0] == i) & (bmus[:, 1] == j)
    idx = np.where(in_node)[0]
    counts[i, j] = len(idx)

    if len(idx) == 0:
        continue

    # Average the member events along the event dimension. The Z500 file names
    # that dimension "time"; the IVT and MSLP files name it "valid_time".
    z500_patterns[i, j] = z500_norm_ffe.isel(time=idx).mean("time").values
    ivt_patterns[i, j] = ivt_norm_ffe.isel(valid_time=idx).mean("valid_time").values
    mslp_patterns[i, j] = mslp_norm_ffe.isel(valid_time=idx).mean("valid_time").values

In [34]:
# One map per node, laid out like the lattice: x-index across, y-index down
fig, axes = plt.subplots(
    ydim,
    xdim,
    figsize=(6, 3),
    subplot_kw={"projection": ccrs.PlateCarree()},
    constrained_layout=True,
    dpi=600,
)

# Settings shared by both contour overlays
overlay_kwargs = {"linewidths": 0.5, "transform": ccrs.PlateCarree()}

for i, j in np.ndindex(xdim, ydim):
    ax = axes[j, i]

    # Z500, IVT and MSLP composite anomaly fields for this node
    Z_field = z500_patterns[i, j, :, :]
    ivt_field = ivt_patterns[i, j, :, :]
    mslp_field = mslp_patterns[i, j, :, :]

    # IVT as filled contours; out-of-range values take the end colours
    im = ax.contourf(
        lon,
        lat,
        ivt_field,
        cmap="balance",
        levels=levels_ivt,
        transform=ccrs.PlateCarree(),
        extend="both",
    )

    # Z500 (black) and MSLP (green) as line contours on top
    ax.contour(lon, lat, Z_field, colors="black", levels=levels_Z, **overlay_kwargs)
    ax.contour(
        lon, lat, mslp_field, colors="green", levels=levels_mslp, **overlay_kwargs
    )

    # Geographic reference lines
    ax.add_feature(cfeature.COASTLINE, linewidth=0.6)
    ax.add_feature(cfeature.STATES.with_scale("50m"), linewidth=0.4)

    # Panel title: node index and number of events in the composite
    ax.set_title(f"({i},{j})  N={counts[i, j]}", fontsize=6)
    ax.set_xticks([])
    ax.set_yticks([])

# Shared colorbar for the IVT shading (taken from the last panel drawn)
cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6, pad=0.02)
cbar.set_label("Standardized Anomaly", fontsize=6)

plt.suptitle(
    "SOM Composite Anomalies: Z500 (black) + MSLP (green) + IVT (shaded)",
    fontsize=8,
    y=1.04,
)

plt.savefig(
    "figs/Z500-IVT-mslp-SOM/composite_anomalies_ivt_shaded.png",
    bbox_inches="tight",
)
plt.close()


### Composite Mean Map

In [35]:
# Per-node composites of the raw (un-normalized) fields, NaN-filled so that a
# node without any events stays blank
raw_shape = (xdim, ydim, n_lat, n_lon)
z500_patterns_raw = np.full(raw_shape, np.nan)
ivt_patterns_raw = np.full(raw_shape, np.nan)
mslp_patterns_raw = np.full(raw_shape, np.nan)

for i, j in np.ndindex(xdim, ydim):
    # positions of the events whose best-matching unit is node (i, j)
    in_node = (bmus[:, 0] == i) & (bmus[:, 1] == j)
    idx = np.where(in_node)[0]

    if len(idx) == 0:
        continue

    # Average the member events along the event dimension. The Z500 file names
    # that dimension "time"; the IVT and MSLP files name it "valid_time".
    z500_patterns_raw[i, j] = z500_ffe.isel(time=idx).mean("time").values
    ivt_patterns_raw[i, j] = ivt_ffe.isel(valid_time=idx).mean("valid_time").values
    mslp_patterns_raw[i, j] = mslp_ffe.isel(valid_time=idx).mean("valid_time").values

In [37]:
# Sanity check on the magnitude of the raw MSLP composites before choosing
# contour levels. The array is NaN-initialised and nodes without events stay
# NaN, so use the NaN-aware maximum; a plain .max() would return nan as soon
# as one node is empty.
np.nanmax(mslp_patterns_raw)

np.float64(102364.3125)

In [39]:
# One map per node, laid out like the lattice: x-index across, y-index down
fig, axes = plt.subplots(
    ydim,
    xdim,
    figsize=(6, 2.7),
    subplot_kw={"projection": ccrs.PlateCarree()},
    constrained_layout=True,
    dpi=600,
)

# Levels for the raw composites. These get their own *_raw names on purpose:
# the standardized-anomaly figures read levels_Z / levels_ivt / levels_mslp,
# and rebinding those names here would make re-running those cells silently
# draw with raw-unit levels.

# Z500 contour levels (in the units of Z_field / 98.1, see below)
levels_Z_raw = range(552, 595, 3)

# IVT shading levels
levels_ivt_raw = np.arange(0, 701, 100)

# MSLP contour levels (in the units of mslp_field / 100, see below)
levels_mslp_raw = range(1000, 1033, 4)

for i in range(xdim):
    for j in range(ydim):
        ax = axes[j, i]

        # Raw Z500, IVT and MSLP composite fields for this node
        Z_field = z500_patterns_raw[i, j, :, :]
        ivt_field = ivt_patterns_raw[i, j, :, :]
        mslp_field = mslp_patterns_raw[i, j, :, :]

        # --- IVT shaded composite ---
        im = ax.contourf(
            lon,
            lat,
            ivt_field,
            cmap="BuPu",
            levels=levels_ivt_raw,
            transform=ccrs.PlateCarree(),
            extend="max",
        )

        # --- Z500 contour overlay ---
        # NOTE(review): dividing by 98.1 presumably converts geopotential
        # (m^2 s^-2) to geopotential height in decameters - confirm the units
        # of the input file.
        cn = ax.contour(
            lon,
            lat,
            Z_field / 98.1,
            colors="black",
            linewidths=0.5,
            levels=levels_Z_raw,
            transform=ccrs.PlateCarree(),
        )

        # --- MSLP contour overlay ---
        # NOTE(review): dividing by 100 presumably converts Pa to hPa - confirm
        # the units of the input file.
        cn1 = ax.contour(
            lon,
            lat,
            mslp_field / 100,
            colors="lime",
            levels=levels_mslp_raw,
            linewidths=0.5,
            transform=ccrs.PlateCarree(),
        )

        ax.add_feature(cfeature.COASTLINE, linewidth=0.6)
        ax.add_feature(cfeature.STATES.with_scale("50m"), linewidth=0.4)

        # Panel title: node index and number of events in the composite
        ax.set_title(f"({i},{j})  N={counts[i, j]}", fontsize=6)
        ax.set_xticks([])
        ax.set_yticks([])

        # Add inline contour labels
        ax.clabel(cn, inline=True, fontsize=5, fmt="%.0f")
        ax.clabel(cn1, inline=True, fontsize=5, fmt="%.0f")

# --- Shared colorbar for the IVT shading (taken from the last panel drawn) ---
cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6, pad=0.02)
cbar.set_label("IVT (kg m$^{-1}$ s$^{-1}$)", fontsize=6)

plt.suptitle(
    "SOM Composite: IVT (shaded) + Z500 (black) + MSLP (green)",
    fontsize=8,
    y=1.04,
)

plt.savefig(
    "figs/Z500-IVT-mslp-SOM/composite_mean_IVT_shaded.png",
    bbox_inches="tight",
)
plt.close()
