In [None]:
%load_ext autoreload
%autoreload 2
from fedflax import train
from models import LeNet
from data import fetch_data
from utils import angle_err, opt_create, return_l2
import jax
from jax import numpy as jnp
from flax import nnx
from functools import reduce, partial
from matplotlib import pyplot as plt
from copy import deepcopy
# Number of federated clients. Used for the client partitions requested from
# `fetch_data` and `train`, and as the size of the leading per-client axis of the
# stacked client models and their outputs.
n_clients: int = 3

In [None]:
# Functional entanglement between the global model and the client models: the
# element-wise |f_g(x, z) - f_i(x, z)|, vmapped over the leading client axis of
# `models`, `x` and `z` (the global model is broadcast via in_axes=None).
# Both models are evaluated with train=False, matching how the models are called
# in the per-label plot, so train-mode behaviour cannot leak into the score.
entanglement_fn = nnx.jit(nnx.vmap(
    lambda model_g, models, x, z: jnp.abs(model_g(x, z, train=False) - models(x, z, train=False)),
    in_axes=(None, 0, 0, 0)
))
key = jax.random.key(42)


def mean_entanglement(global_model, client_models, dataset):
    """Mean of `entanglement_fn` over the batches of `dataset`.

    Each batch contributes its own mean, so a short final batch is weighted like a
    full one. Elements batch[1:] are forwarded as (x, z); batch[0] is presumably the
    label and is unused here (TODO confirm against `fetch_data`).
    """
    total = reduce(
        lambda acc, batch: acc + entanglement_fn(global_model, client_models, *batch[1:]).mean().item(),
        dataset,
        0.,
    )
    return total / len(dataset)


# One iteration per communication round: `train` runs a single round starting from
# the current global model, and the averaged result seeds the next iteration.
inducers = list(range(20))
distances = []
fl_entanglement_scores = []
rand_entanglement_scores = []
model_g = LeNet(key)
for inducer in inducers:
    # Client datasets. NOTE(review): these do not depend on `inducer`; the parameter
    # divergence grows with the number of rounds, not through the data.
    ds_train = fetch_data("label", beta=1., n_clients=n_clients)
    ds_test = fetch_data("label", partition="test", beta=1., n_clients=n_clients)
    ds_val = fetch_data("label", partition="val", beta=1., n_clients=n_clients)
    # Snapshot of the global model this round starts from; the per-label plot
    # measures the final round's client updates relative to it.
    model_init = deepcopy(model_g)
    fl_models, _ = train(model_g, partial(opt_create, learning_rate=1e-3),
                         ds_train, return_l2(0.), ds_val, local_epochs="early",
                         max_patience=5, val_fn=angle_err, rounds=1, n_clients=n_clients)
    # Unweighted average of the client parameters (and batch statistics, if any)
    # over the leading client axis gives the new global model.
    struct, params_l, rest = nnx.split(fl_models, (nnx.Param, nnx.BatchStat), ...)
    params_g = jax.tree.map(lambda p: p.mean(0), params_l)
    model_g = nnx.merge(struct, params_g, rest)
    fl_entanglement_scores.append(mean_entanglement(model_g, fl_models, ds_test))
    # Root-mean-square (over clients) Euclidean distance between each client's
    # parameters and the averaged global parameters.
    distance = jax.tree.map(lambda pg, pl: jnp.abs(pg - pl)**2, params_g, params_l)
    distance = jax.tree.reduce(lambda acc, d: acc + jnp.sum(d.reshape(n_clients, -1), -1), distance, 0.)
    distance = jnp.sqrt(distance.mean())
    distances.append(distance.item())
    # Random baseline: perturb a base model along one random direction per client,
    # each rescaled to exactly length `distance`. The base is LeNet(key), presumably
    # the same initialisation the FL run started from, rebuilt identically each round.
    # NOTE(review): the baseline is centred on the untrained initialisation rather
    # than on the current global model `model_g`; confirm this is intended.
    rand_model_g = LeNet(key)
    rgraphdef, rparams_g, rrest = nnx.split(rand_model_g, nnx.Param, ...)
    # Independent noise for every round and every parameter leaf. Reusing a single
    # key would give identical draws to equally-shaped leaves and repeat the same
    # directions in every round.
    leaves, treedef = jax.tree.flatten(rparams_g)
    leaf_keys = jax.random.split(jax.random.fold_in(key, inducer), len(leaves))
    rand = jax.tree.unflatten(treedef, [
        jax.random.normal(leaf_key, (n_clients, *leaf.shape))
        for leaf_key, leaf in zip(leaf_keys, leaves)
    ])
    # Per-client norm of the full noise vector (all leaves concatenated).
    norm = jnp.linalg.norm(
        jnp.concatenate([r.reshape(n_clients, -1) for r in jax.tree.leaves(rand)], axis=1),
        axis=-1
    )
    # Rescale each client's noise to length `distance` and add it to the base params.
    rparams_l = jax.tree.map(
        lambda p, r: p + r / norm.reshape((n_clients,) + (1,) * (r.ndim - 1)) * distance,
        rparams_g,
        rand,
    )
    # NOTE(review): `rrest` carries no client axis but is vmapped with in_axes=0 by
    # `entanglement_fn`; this is only safe if LeNet has no non-Param state.
    rand_models = nnx.merge(rgraphdef, rparams_l, rrest)
    # Now check if this disentanglement is better than random
    rand_entanglement_scores.append(mean_entanglement(rand_model_g, rand_models, ds_test))

## Visualize functional entanglement versus parameter divergence

In [None]:
# Matplotlib rendering style: pastel seaborn palette with LaTeX-rendered serif
# (Times) text. NOTE: "text.usetex" requires a working LaTeX installation.
LATEX_PREAMBLE = r"""
        \usepackage{amsmath, amssymb}
        \usepackage{mathptmx}  % Safe fallback for Times + math
    """
plt.style.use("seaborn-v0_8-pastel")
rc_overrides = {
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Times"],
    "font.sans-serif": ["Helvetica"],
    "text.latex.preamble": LATEX_PREAMBLE,
}
plt.rcParams.update(rc_overrides)

# Entanglement per communication round for the FL models and the random baseline.
# A second x-axis, drawn below the first, labels each round with the parameter
# divergence recorded for it in `distances`.
round_span = (min(inducers), max(inducers))
divergence_labels = [f"{d:.1f}" for d in distances]
curves = (
    ("FL models", fl_entanglement_scores, "C0"),
    ("Random models", rand_entanglement_scores, "C1"),
)

fig, ax = plt.subplots();
for curve_label, scores, colour in curves:
    ax.plot(inducers, scores, label=curve_label, c=colour);
ax.set_xlabel("Communication round");
ax.set_ylabel("Functional entanglement");
ax.legend();
ax.set_xlim(*round_span);
ax.grid(True, linestyle="--", linewidth=0.5);

# Secondary x-axis sharing y, relocated to the bottom and pushed outward.
ax2 = ax.twiny();
ax2.xaxis.set_ticks_position("bottom");
ax2.xaxis.set_label_position("bottom");
ax2.spines.bottom.set_position(("outward", 36));
ax2.set_xlim(*round_span);
ax2.set_xticks(inducers, labels=divergence_labels);
ax2.set_xlabel("Euclidean parameter divergence");
fig.savefig("disentanglement.png", dpi=300, bbox_inches="tight");

## Plot each client's functional change over the label domain

In [None]:
# Pooled test set: a single partition (n_clients=1) covering every label, so each
# client model can be probed over the whole label domain.
ds_test_pooled = fetch_data("label", partition="test", beta=0., n_clients=1)
SMOOTH_WINDOW = 100  # width (in samples) of the moving average applied to the curves
# Evaluate all client models (leading axis of `fl_models`) on the same inputs.
vcall = nnx.jit(nnx.vmap(
    lambda model, x, z: model(x, z, train=False),
    in_axes=(0, None, None)
))


def moving_average(values, window=SMOOTH_WINDOW):
    """Length-`window` moving average of a 1-D array ("valid" mode: len - window + 1 outputs)."""
    return jnp.convolve(values, jnp.ones((window,)) / window, mode="valid")


def wrap_degrees(angles):
    """Map angles in degrees onto [-180, 180), e.g. 358 -> -2."""
    return (angles + 180.) % 360. - 180.


# Collect per-batch outputs and concatenate once (avoids quadratic re-copying).
# The empty seeds fix the result shapes even if the dataset yields no batches.
label_batches = [jnp.empty((0, 2))]
logit_batches = [jnp.empty((n_clients, 0, 2))]
init_logit_batches = [jnp.empty((0, 2))]
for batch in ds_test_pooled:
    # squeeze(0) drops the singleton client axis (n_clients=1)
    x, z = batch[1].squeeze(0), batch[2].squeeze(0)
    label_batches.append(batch[0].squeeze(0))
    logit_batches.append(vcall(fl_models, x, z))
    init_logit_batches.append(model_init(x, z, train=False))
labels = jnp.concatenate(label_batches)
logits = jnp.concatenate(logit_batches, axis=1)
logits_init = jnp.concatenate(init_logit_batches)

# Convert 2-D targets/outputs to angles (degrees), sort by label angle and smooth.
labels = jnp.rad2deg(jnp.arctan2(*labels.T))
sorter = jnp.argsort(labels)
labels = moving_average(labels[sorter])
# Angular change of every client's prediction relative to `model_init`, wrapped so
# that crossing the +/-180 degree seam yields a small change (179 vs -179 -> -2, not 358).
outs_normalized = wrap_degrees(
    jnp.rad2deg(jnp.arctan2(*logits.T)).T - jnp.rad2deg(jnp.arctan2(*logits_init.T))
)
outs_normalized = jnp.vectorize(moving_average, signature="(n)->(m)")(outs_normalized[:, sorter])

# Plot one smoothed curve per client over the label angle.
fig, ax = plt.subplots();
for i in range(n_clients):
    ax.plot(labels, outs_normalized[i, :], c=f"C{i}", label=f"Client {i}");
    # NOTE(review): shading assumes client i holds the i-th equal angular sector
    # starting at -180 degrees in the label partition; confirm against `fetch_data`.
    ax.axvspan(360/n_clients*i-180, 360/n_clients*(i+1)-180, alpha=0.2, color=f"C{i}", zorder=0);
ax.axhline(0., c="black", linestyle="-.", linewidth=0.8, label="Relative initialization");
ax.legend();
ax.set_xlabel("Label (degrees)");
ax.set_ylabel(r"$\mathbb{E}_{(x,y)\in\mathcal{D}}[f_i(x)-f_{init}(x) : y=label]$");
ax.set_xlim(-180, 180);
ax.grid(True, linestyle="--", linewidth=0.5);
fig.savefig("disentanglement_over_domain.png", dpi=300, bbox_inches="tight");