In [None]:
import numpy as np

# ---------- helpers (from your formulas) ----------
def compute_Lgw(C_T, C_S, a, b, gamma):
    """Return the Gromov-Wasserstein (squared-loss) cost matrix for a coupling.

    Entry ``[i, j]`` equals ``sum_{k,l} (C_T[i, k] - C_S[j, l])**2 * gamma[k, l]``.
    It is computed via the usual three-term expansion, so no 4-index tensor is
    materialised.

    Args:
        C_T: cell-cell (tree) distance matrix, ``[n, n]`` in this notebook.
        C_S: spot-spot (spatial) distance matrix, ``[m, m]`` in this notebook.
        a: ``[n]`` weights. The expansion equals the sum above only when
            ``a == gamma.sum(axis=1)``.
        b: ``[m]`` weights. The expansion equals the sum above only when
            ``b == gamma.sum(axis=0)``.
        gamma: ``[n, m]`` coupling between cells and spots.

    Returns:
        ``[n, m]`` cost matrix.
    """
    row_part = np.square(C_T) @ a      # sum_k C_T[i,k]^2 a_k        -> [n]
    col_part = np.square(C_S) @ b      # sum_l C_S[j,l]^2 b_l        -> [m]
    coupled = (C_T @ gamma) @ C_S.T    # sum_{k,l} C_T[i,k] gamma[k,l] C_S[j,l]
    # Broadcast the two marginal terms to [n, m], then subtract the cross term.
    return np.add.outer(row_part, col_part) - 2.0 * coupled

def compute_Lcladegw(C_T, C_S, a, b, gamma, Omega):
    """Return the Omega-masked Gromov-Wasserstein cost matrix.

    Entry ``[i, j]`` equals
    ``sum_{k,l} Omega[i, k] * (C_T[i, k] - C_S[j, l])**2 * gamma[k, l]``.
    Only cell pairs ``(i, k)`` with nonzero ``Omega[i, k]`` contribute to row ``i``.
    The function is linear in ``Omega``, so masks that sum to all-ones give costs
    that sum to the unmasked GW cost.

    Args:
        C_T: cell-cell (tree) distance matrix, ``[n, n]`` in this notebook.
        C_S: spot-spot (spatial) distance matrix, ``[m, m]`` in this notebook.
        a: ``[n]`` weights. The expansion equals the sum above only when
            ``a == gamma.sum(axis=1)``.
        b: not used by this formula. It is accepted presumably for signature
            parity with ``compute_Lgw``.
        gamma: ``[n, m]`` coupling between cells and spots.
        Omega: ``[n, n]`` weight/mask over cell pairs ``(i, k)``.

    Returns:
        ``[n, m]`` cost matrix.
    """
    masked_sq = Omega * np.square(C_T)
    row_part = masked_sq @ a                          # [n]
    col_part = (Omega @ gamma) @ np.square(C_S).T     # [n, m]
    coupled = ((Omega * C_T) @ gamma) @ C_S.T         # [n, m]
    return row_part[:, None] + col_part - 2.0 * coupled

# ---------- toy problem ----------
n = m = 4

# Uniform marginals over the cells (a) and the spots (b).
a = np.full(n, 1.0 / n)
b = np.full(m, 1.0 / m)

# Identity coupling: cell i sends all of its mass a[i] to spot i.
gamma = np.diag(a)  # [4, 4]

# Two clades: cells {0, 1} and cells {2, 3}.
clades = np.array([0, 0, 1, 1])
K = 2
omega = (clades[:, None] == np.arange(K)).astype(float)  # [n, K] one-hot membership
Omega_within = omega @ omega.T    # [n, n]: 1.0 where two cells share a clade, else 0.0
Omega_cross = 1.0 - Omega_within  # [n, n]: 1.0 where two cells are in *different* clades

# Cell-cell tree distances: 1 inside a clade, 3 across clades.
C_T = np.array(
    [
        [0.0, 1.0, 3.0, 3.0],
        [1.0, 0.0, 3.0, 3.0],
        [3.0, 3.0, 0.0, 1.0],
        [3.0, 3.0, 1.0, 0.0],
    ]
)

# Spot-spot spatial distances. The close pairs are {0, 2} and {1, 3}, a permuted
# layout that deliberately does not line up with the clade pairing of C_T.
C_S = np.array(
    [
        [0.0, 3.0, 1.0, 3.0],
        [3.0, 0.0, 3.0, 1.0],
        [1.0, 3.0, 0.0, 3.0],
        [3.0, 1.0, 3.0, 0.0],
    ]
)

# ---------- compute costs ----------
L_gw = compute_Lgw(C_T, C_S, a, b, gamma)
# Clade-restricted cost: only same-clade cell pairs (i, k) contribute.
L_om = compute_Lcladegw(C_T, C_S, a, b, gamma, Omega_within)
# Complementary cost: only cross-clade cell pairs (i, k) contribute.
L_om_cross = compute_Lcladegw(C_T, C_S, a, b, gamma, Omega_cross)

Delta = L_om - L_gw
# The masked cost is linear in Omega and Omega_within + Omega_cross == 1, so
# the within- and cross-clade parts must add back up to the global GW cost.
# Adding them gives no new "clade" cost (it is just L_gw); use the identity
# as a sanity check on the two helpers instead.
Delta_cross = L_om + L_om_cross - L_gw
np.testing.assert_allclose(
    L_om + L_om_cross,
    L_gw,
    err_msg="within-clade + cross-clade costs should recover L_gw",
)

np.set_printoptions(precision=3, suppress=True)
print("Omega (clade mask):\n", Omega_within, "\n")
print("L_gw (global):\n", L_gw, "\n")
# Print L_om here so the displayed L_clade agrees with Delta and the block means below.
print("L_clade (with Omega):\n", L_om, "\n")
print("Difference L_clade - L_gw:\n", Delta, "\n")

# Optional: show block means for intuition
blk1 = np.s_[0:2]
blk2 = np.s_[2:4]


def block_mean(M, rr, cc):
    """Return the mean of the sub-matrix of ``M`` taken at rows ``rr`` and columns ``cc``."""
    return M[rr, :][:, cc].mean()


# For each cost matrix, report the mean cost of every (clade rows, spot columns) block.
for _title, _cost in (("Block means of L_gw:", L_gw), ("\nBlock means of L_clade:", L_om)):
    print(_title)
    for _clade_label, _rows in (("clade1", blk1), ("clade2", blk2)):
        for _spot_label, _cols in (("spots0-1", slice(0, 2)), ("spots2-3", slice(2, 4))):
            print(f"  {_clade_label}→{_spot_label}:", block_mean(_cost, _rows, _cols))


Omega (clade mask):
 [[1. 1. 0. 0.]
 [1. 1. 0. 0.]
 [0. 0. 1. 1.]
 [0. 0. 1. 1.]] 

L_gw (global):
 [[2.  3.5 3.5 4.5]
 [3.5 2.  4.5 3.5]
 [3.5 4.5 2.  3.5]
 [4.5 3.5 3.5 2. ]] 

L_clade (with Omega):
 [[1.   2.5  1.25 2.25]
 [2.5  1.   2.25 1.25]
 [1.25 2.25 1.   2.5 ]
 [2.25 1.25 2.5  1.  ]] 

Difference L_clade - L_gw:
 [[-1.   -1.   -2.25 -2.25]
 [-1.   -1.   -2.25 -2.25]
 [-2.25 -2.25 -1.   -1.  ]
 [-2.25 -2.25 -1.   -1.  ]] 

Block means of L_gw:
  clade1→spots0-1: 2.75
  clade1→spots2-3: 4.0
  clade2→spots0-1: 4.0
  clade2→spots2-3: 2.75

Block means of L_clade:
  clade1→spots0-1: 1.75
  clade1→spots2-3: 1.75
  clade2→spots0-1: 1.75
  clade2→spots2-3: 1.75
