## Minimal FedMosaic-style toy (2 clients, hetero dims)
- Two clients with different hidden dims (4 vs 6) and rank r=2.
- Freeze LoRA A/B; train only PQ and gating $\beta$.
- Relevance via EMA + noise + subsampling of last-layer grads from a frozen small model.
- Relevance-weighted aggregation builds per-client global PQ.

In [4]:
import numpy as np
np.set_printoptions(precision=3, suppress=True)
rng = np.random.default_rng(0)

def softmax_rows(x, tau=1.0):
    z = x / tau
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def cosine(a, b):
    a_flat, b_flat = a.ravel(), b.ravel()
    denom = (np.linalg.norm(a_flat) * np.linalg.norm(b_flat) + 1e-9)
    return float(np.dot(a_flat, b_flat) / denom)

### Toy data and heterogeneous models

In [6]:
clients = 2
hidden_dims = [4, 6]  # hetero input/hidden sizes per client
output_dim = 2        # keep output dim shared
rank = 2              # low-rank size for LoRA/PQ

def make_data(mean, n=64):
    d = len(mean)
    x = rng.normal(mean, 0.3, size=(n, d))
    y = np.tile(mean[:output_dim], (n, 1))  # simple target depends on first two dims
    return x, y

data = [
    make_data(np.array([1.0, 0.0, 0.5, -0.3])),
    make_data(np.array([0.0, 1.0, -0.2, 0.4, 0.3, -0.1]))
]

# Frozen base weights and LoRA A/B (remain fixed)
Wp = [rng.normal(0, 0.5, size=(output_dim, h)) for h in hidden_dims]
A = [rng.normal(0, 0.3, size=(rank, h)) for h in hidden_dims]
B = [rng.normal(0, 0.3, size=(output_dim, rank)) for _ in hidden_dims]

# Trainable PQ and gating beta (one layer per client)
P_local = [np.zeros((rank, rank)) for _ in hidden_dims]
Q_local = [np.zeros(rank) for _ in hidden_dims]
beta = [0.5 for _ in hidden_dims]

# Server-side aggregated PQ for each client (start at zeros)
P_global = [np.zeros_like(P_local[i]) for i in range(clients)]
Q_global = [np.zeros_like(Q_local[i]) for i in range(clients)]

### PQ-LoRA forward and local update
- $h_L = B (P A h_I + Q)$ and $h_G$ uses aggregated $P, Q$.
- $h_O = W_p h_I + (1-\beta) h_L + \beta h_G$.
- Update only $P, Q, \beta$ with MSE loss.

In [7]:
def pq_forward(x, Wp_i, A_i, B_i, P_i, Q_i):
    # x: [batch, d_in]
    hidden_proj = A_i @ x.T                       # [r, batch]
    low_rank = P_i @ hidden_proj + Q_i[:, None]   # [r, batch]
    h = (B_i @ low_rank).T                        # [batch, d_out]
    base = (Wp_i @ x.T).T                         # [batch, d_out]
    return base, h

def local_step(idx, x, y, lr=0.3):
    Wp_i, A_i, B_i = Wp[idx], A[idx], B[idx]
    P_i, Q_i = P_local[idx], Q_local[idx]
    P_g, Q_g = P_global[idx], Q_global[idx]
    beta_i = beta[idx]

    base, h_L = pq_forward(x, Wp_i, A_i, B_i, P_i, Q_i)
    _, h_G = pq_forward(x, Wp_i, A_i, B_i, P_g, Q_g)
    h_O = base + (1 - beta_i) * h_L + beta_i * h_G

    err = h_O - y
    bsz = len(x)
    grad_out = (2.0 / bsz) * err                   # dL/dh_O

    # Grad beta: derivative of mix between h_G and h_L
    grad_beta = np.sum(grad_out * (h_G - h_L))

    # Grad P/Q via chain rule through low-rank path
    err_proj = grad_out @ B_i                      # [batch, r]
    hidden_proj = (A_i @ x.T).T                    # [batch, r]
    grad_P = err_proj.T @ hidden_proj              # [r, r]
    grad_Q = err_proj.sum(axis=0)                  # [r]

    # Parameter updates (A/B frozen)
    P_local[idx] = P_i - lr * grad_P
    Q_local[idx] = Q_i - lr * grad_Q
    beta[idx] = beta_i - lr * grad_beta

def batch_from_client(idx, batch_size=8):
    x, y = data[idx]
    sel = rng.choice(len(x), size=batch_size, replace=False)
    return x[sel], y[sel]

### Relevance via EMA + noise + subsampling
Use a tiny frozen model $W_s$ (2Ã—2) on the first two features to compute last-layer gradients.

In [8]:
W_s = rng.normal(0, 1.0, size=(2, 2))  # frozen small model
g_ema = [np.zeros_like(W_s) for _ in range(clients)]

def last_layer_grad(x, y):
    x2 = x[:, :2]  # use first two dims
    y2 = y[:, :2]
    pred = (W_s @ x2.T).T  # [batch, 2]
    err = pred - y2
    bsz = len(x2)
    grad = (2.0 / bsz) * (err.T @ x2)  # [2,2]
    return grad

def sanitize_grad(g, noise=0.05, subsample=0.6):
    noisy = g + rng.normal(0, noise, size=g.shape)
    flat = noisy.ravel()
    k = max(1, int(len(flat) * subsample))
    mask_idx = rng.choice(len(flat), size=k, replace=False)
    mask = np.zeros_like(flat)
    mask[mask_idx] = 1
    return (flat * mask).reshape(g.shape)

### Run a few federated rounds

In [9]:
rounds = 3
local_steps = 4
alpha = 0.6  # EMA decay
tau = 0.5    # softmax temperature for relevance

for r in range(rounds):
    # Client-side local training
    for i in range(clients):
        for _ in range(local_steps):
            x_b, y_b = batch_from_client(i)
            local_step(i, x_b, y_b, lr=0.1)

        # Gradient for relevance from frozen small model
        x_g, y_g = batch_from_client(i)
        g_now = last_layer_grad(x_g, y_g)
        g_ema[i] = (1 - alpha) * g_ema[i] + alpha * g_now

    # Server-side relevance + aggregation
    sanitized = [sanitize_grad(g) for g in g_ema]
    S = np.zeros((clients, clients))
    for i in range(clients):
        for j in range(clients):
            S[i, j] = cosine(sanitized[i], sanitized[j])

    weights = softmax_rows(S, tau=tau)

    # Build customized global PQ per client
    P_new, Q_new = [], []
    for i in range(clients):
        P_i = sum(weights[i, j] * P_local[j] for j in range(clients))
        Q_i = sum(weights[i, j] * Q_local[j] for j in range(clients))
        P_new.append(P_i)
        Q_new.append(Q_i)

    P_global, Q_global = P_new, Q_new

    print(f"Round {r+1}")
    print("Relevance S:\n", np.round(S, 3))
    print("Weights per client (rows):\n", np.round(weights, 3))
    print("Beta values:", [round(b, 3) for b in beta])
    print("P norms (local):", [round(np.linalg.norm(P_local[i]), 3) for i in range(clients)])
    print("P norms (global):", [round(np.linalg.norm(P_global[i]), 3) for i in range(clients)])
    print('-' * 60)

# Quick sanity: forward one batch with final globals
x0, y0 = batch_from_client(0)
base0, hL0 = pq_forward(x0, Wp[0], A[0], B[0], P_local[0], Q_local[0])
_, hG0 = pq_forward(x0, Wp[0], A[0], B[0], P_global[0], Q_global[0])
hO0 = base0 + (1 - beta[0]) * hL0 + beta[0] * hG0
print("Sample output client 0 (first row):", np.round(hO0[0], 3))
print("Target (first row):", np.round(y0[0], 3))

Round 1
Relevance S:
 [[1.    0.919]
 [0.919 1.   ]]
Weights per client (rows):
 [[0.54 0.46]
 [0.46 0.54]]
Beta values: [np.float64(0.423), np.float64(0.414)]
P norms (local): [np.float64(0.185), np.float64(0.211)]
P norms (global): [np.float64(0.164), np.float64(0.167)]
------------------------------------------------------------
Round 2
Relevance S:
 [[1. 0.]
 [0. 1.]]
Weights per client (rows):
 [[0.881 0.119]
 [0.119 0.881]]
Beta values: [np.float64(0.235), np.float64(0.197)]
P norms (local): [np.float64(0.332), np.float64(0.381)]
P norms (global): [np.float64(0.314), np.float64(0.353)]
------------------------------------------------------------
Round 3
Relevance S:
 [[1.    0.823]
 [0.823 1.   ]]
Weights per client (rows):
 [[0.587 0.413]
 [0.413 0.587]]
Beta values: [np.float64(0.178), np.float64(0.092)]
P norms (local): [np.float64(0.425), np.float64(0.54)]
P norms (global): [np.float64(0.393), np.float64(0.417)]
------------------------------------------------------------
Sam

The printouts show the relevance matrix, softmax weights, learned $\beta$, and the norms of local/global $P$ after each round. Adjust `rounds`, `local_steps`, or noise/subsample ratios to explore behavior.