# Topology heatmaps

Visualize each output's surviving input connections as an 8x8 heatmap (one pixel per input).
Each pixel shows how many parallel edges (0-3) from that input remain for that output.


In [None]:
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

from topology_io import load_topology_npz


In [None]:
# Topology to visualize. Exactly one `topo_path` line must be active (otherwise
# the next cell fails with NameError); switch topologies by (un)commenting.

topo_path = Path("scikit_digit/topology/digits_8x8_dense_io_x1.npz")
# topo_path = Path("scikit_digit/topology/digits_8x8_dense_io_x1_pruned_vglt0p6_epoch20_run20260110-235531.npz")
# topo_path = Path("scikit_digit/topology/digits_8x8_dense_io_x1_pruned_vglt0p75_epoch20_run20260110-235531.npz")
# topo_path = Path("scikit_digit/topology/digits_8x8_dense_io_x1_autoprune_epoch38_run20260111-000650.npz")


In [None]:
topo = load_topology_npz(topo_path)

# The heatmaps below assume the inputs form an 8x8 pixel grid.
if topo.Nin != 64:
    raise ValueError(f"Expected 64 inputs for 8x8, got {topo.Nin}")

# Node id -> dense position: input nodes index columns, output nodes index rows.
input_index = {int(node): col for col, node in enumerate(topo.input_nodes.tolist())}
out_index = {int(node): row for row, node in enumerate(topo.out_nodes.tolist())}

# counts[row, col] = number of parallel edges linking output `row` and input `col`.
# NOTE(review): edges_S is resolved through out_index and edges_D through
# input_index; confirm this orientation matches topology_io's edge convention.
counts = np.zeros((topo.K, topo.Nin), dtype=int)
edge_pairs = zip(topo.edges_D.tolist(), topo.edges_S.tolist())
for d, s in edge_pairs:
    row = out_index[int(s)]
    col = input_index[int(d)]
    counts[row, col] += 1

# One 8x8 grid per output, reshaped row-major from the input ordering.
counts_8x8 = counts.reshape(topo.K, 8, 8)
counts_8x8.shape


In [None]:
# Lay out one panel per output in a grid with a fixed number of columns and as
# many rows as needed, so any number of outputs (not just 10) can be drawn.
n_cols = 5
n_rows = max(1, (topo.K + n_cols - 1) // n_cols)  # ceil(K / n_cols), at least one row
fig, axes = plt.subplots(
    n_rows,
    n_cols,
    figsize=(12, 2.5 * n_rows),
    constrained_layout=True,
    squeeze=False,  # always a 2-D array, even for a single row/column
)
axes = axes.ravel()

# Shared colour scale so panels are comparable. 3 is the expected maximum
# number of parallel edges; grow the scale instead of clipping if exceeded.
vmin = 0
vmax = max(3, int(counts_8x8.max(initial=0)))

im = None
for k in range(topo.K):
    ax = axes[k]
    im = ax.imshow(counts_8x8[k], vmin=vmin, vmax=vmax, cmap="viridis")
    ax.set_title(f"Output {k}")
    ax.set_xticks([])
    ax.set_yticks([])

# Hide the unused panels left over in the last row.
for ax in axes[topo.K:]:
    ax.axis("off")

# Only draw a colour bar if at least one heatmap exists; counts are integers,
# so label each integer level.
if im is not None:
    fig.colorbar(
        im,
        ax=axes.tolist(),
        shrink=0.8,
        label="connections per pixel",
        ticks=list(range(vmin, vmax + 1)),
    )
plt.show()
