In [None]:
%load_ext autoreload
%autoreload 2
from fedflax import train
from models import LeNet
from data import fetch_data
from utils import angle_err, opt_create, return_l2
import jax
from jax import numpy as jnp
from flax import nnx
from tqdm.auto import tqdm
from functools import reduce, partial
from matplotlib import pyplot as plt

In [None]:
def _abs_output_gap(model_g, model_k, x, z):
    """Element-wise |f_g(x, z) - f_k(x, z)| between the global model and one client model."""
    return jnp.abs(model_g(x, z) - model_k(x, z))


# Measurement of functional interference between one client's update vs all updates:
# the single global model is broadcast (in_axes None) while the stacked client models
# and their inputs x, z are mapped together along their leading axis.
entanglement_fn = nnx.jit(nnx.vmap(_abs_output_gap, in_axes=(None, 0, 0, 0)))
# Fixed PRNG key: every sweep step initialises LeNet from the same weights.
key = jax.random.key(42)
# Separate seed for the random-baseline directions, so the noise is neither drawn
# from the weight-init key nor shared between parameter leaves.
NOISE_SEED = 43


def _mean_entanglement(model_g, models, dataset):
    """Mean functional entanglement of stacked `models` w.r.t. `model_g` over `dataset`.

    Sums the batch-mean of `entanglement_fn` over all batches and divides by the
    number of batches. Each batch is splatted as `batch[1:]` into the models' `(x, z)`
    inputs; element 0 is not used here (presumably the target — TODO confirm against
    `fetch_data`).
    """
    total = 0.
    for batch in dataset:
        total += entanglement_fn(model_g, models, *batch[1:]).mean().item()
    return total / len(dataset)


def _client_rms_distance(params_g, params_l):
    """Root-mean-square over clients of the L2 distance between client and global params.

    `params_l` holds the client parameters stacked along a leading axis (one slice per
    client, as produced by `p.mean(0)` being the global model); `params_g` is the
    matching un-stacked tree.
    """
    num_clients = jax.tree.leaves(params_l)[0].shape[0]
    sq_per_leaf = jax.tree.map(
        lambda pg, pl: jnp.sum(((pl - pg) ** 2).reshape(num_clients, -1), axis=-1),
        params_g, params_l,
    )
    sq_dist = sum(jax.tree.leaves(sq_per_leaf))  # shape (num_clients,)
    return jnp.sqrt(sq_dist.mean())


def _random_perturbations(params_g, distance, num_clients, rng):
    """Stack `num_clients` copies of `params_g`, each displaced by exactly `distance`.

    Copy k is `params_g + distance * u_k`, where u_k is a Gaussian direction normalised
    to unit L2 norm across *all* leaves. Every leaf draws from its own split key, so the
    noise is independent between leaves (reusing one key would give identical noise to
    same-shaped leaves).
    """
    leaves, treedef = jax.tree.flatten(params_g)
    leaf_keys = jax.random.split(rng, len(leaves))
    noise = [
        jax.random.normal(k, (num_clients, *p.shape))
        for k, p in zip(leaf_keys, leaves)
    ]
    norm = jnp.linalg.norm(
        jnp.concatenate([n.reshape(num_clients, -1) for n in noise], axis=1),
        axis=-1,
    )
    scale = distance / norm  # shape (num_clients,)
    perturbed = [
        p + n * scale.reshape(-1, *([1] * p.ndim))
        for p, n in zip(leaves, noise)
    ]
    return jax.tree.unflatten(treedef, perturbed)


# For different distances, compute entanglement scores. The first inducers are
# hand-picked small values; the rest are 7 points evenly spaced in sqrt-space
# between 0.032 and 1.
inducers = [0., 0.00782, 0.01564, 0.024] + (jnp.linspace(jnp.sqrt(0.032), 1, 7)**2).tolist()
distances = []
fl_entanglement_scores = []
rand_entanglement_scores = []
# Same key each step -> the random baseline uses the same directions, only rescaled.
noise_key = jax.random.key(NOISE_SEED)
for inducer in tqdm(inducers, desc="inducers"):
    # Get datasets which induce this distance
    ds_train = fetch_data("feature", beta=inducer)
    ds_test = fetch_data("feature", partition="test", beta=inducer)
    ds_val = fetch_data("feature", partition="val", beta=inducer)
    # Get federated models which are this far apart
    model_init = LeNet(key)
    fl_models, _ = train(model_init, partial(opt_create, learning_rate=1e-3),
                         ds_train, return_l2(0.), ds_val, local_epochs="early",
                         max_patience=5, val_fn=angle_err, rounds=1)
    # Global model = client-wise mean of parameters and batch statistics
    struct, params_l, rest = nnx.split(fl_models, (nnx.Param, nnx.BatchStat), ...)
    params_g = jax.tree.map(lambda p: p.mean(0), params_l)
    model_g = nnx.merge(struct, params_g, rest)
    fl_entanglement_scores.append(_mean_entanglement(model_g, fl_models, ds_test))
    # RMS (over clients) L2 distance between client and global parameters
    distance = _client_rms_distance(params_g, params_l)
    distances.append(distance.item())
    # Random baseline: clients displaced from a reference model by exactly `distance`
    # along random directions (measured across all params).
    # NOTE(review): the reference is a freshly initialised LeNet(key), not the trained
    # `model_g`, and only nnx.Param is perturbed whereas `distance` also counts any
    # nnx.BatchStat — confirm this is the intended comparison.
    rand_model_g = LeNet(key)
    rgraphdef, rparams_g, rrest = nnx.split(rand_model_g, nnx.Param, ...)
    num_clients = jax.tree.leaves(params_l)[0].shape[0]
    rparams_l = _random_perturbations(rparams_g, distance, num_clients, noise_key)
    rand_models = nnx.merge(rgraphdef, rparams_l, rrest)
    # Now check if this disentanglement is better than random
    rand_entanglement_scores.append(_mean_entanglement(rand_model_g, rand_models, ds_test))

## Visualize: functional entanglement vs. parameter divergence

In [None]:
# Matplotlib rendering style
plt.style.use("seaborn-v0_8-pastel")
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Times"],
    "font.sans-serif": ["Helvetica"],
    "text.latex.preamble": r"""
        \usepackage{amsmath, amssymb}
        \usepackage{mathptmx}  % Safe fallback for Times + math
    """
})

# Plot entanglement vs. measured parameter divergence, with the inducer (beta) that
# produced each point on a second x-axis below the first. All series are sorted by
# distance so the curves run left to right (re-running this cell is a no-op re-sort).
distances, inducers, fl_entanglement_scores, rand_entanglement_scores = zip(
    *sorted(zip(distances, inducers, fl_entanglement_scores, rand_entanglement_scores))
)
x_range = (min(distances), max(distances))

fig, ax = plt.subplots()
ax.plot(distances, fl_entanglement_scores, label="FL models", c="C0")
ax.plot(distances, rand_entanglement_scores, label="Random models", c="C1")
ax.set_xlabel("Euclidean Parameter Divergence")
ax.set_ylabel("Functional Entanglement")
ax.legend()
ax.set_xlim(*x_range)
ax.grid(True, linestyle="--", linewidth=0.5)

# Secondary x-axis, pushed 36pt below the primary one, labelling each measured
# distance with its inducer. Some inducers come from float32 arithmetic
# (e.g. 0.03200000151991844), so format them compactly instead of using their repr.
ax2 = ax.twiny()
ax2.xaxis.set_ticks_position("bottom")
ax2.xaxis.set_label_position("bottom")
ax2.spines.bottom.set_position(("outward", 36))
ax2.set_xlim(*x_range)
ax2.set_xticks(distances, labels=[f"{b:.4g}" for b in inducers])
ax2.set_xlabel("Inducer of divergence")
fig.savefig("disentanglement.png", dpi=300, bbox_inches="tight")