In [None]:
%load_ext autoreload
%autoreload 2
from fedflax import train
from models import ResNet
from data import fetch_data
from utils import return_ce, top_5_err, nnx_norm
import jax, json, optax
from jax import numpy as jnp
from flax import nnx
from matplotlib import pyplot as plt
from npy_append_array import NpyAppendArray
n_clients = 3  # number of federated clients (each evaluated on its own data partition)

# Matplotlib rendering style: pastel seaborn palette with LaTeX-typeset serif text
plt.style.use("seaborn-v0_8-pastel")
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.serif"] = ["Times"]
plt.rcParams["font.sans-serif"] = ["Helvetica"]
# amsmath/amssymb provide \text and \mathbb used in axis labels; mathptmx sets Times for math
plt.rcParams["text.latex.preamble"] = r"""
        \usepackage{amsmath, amssymb}
        \usepackage{mathptmx}
    """

In [None]:
key = jax.random.key(42)
# Evaluate a stack of client models (leading model axis) on one shared input batch.
vcall = nnx.jit(nnx.vmap(
    lambda model, x: model(x, train=False),
    in_axes=(0, None)
))


def sam(model):
    """Sharpness-aware minimization (SAM) optimizer over ``model``'s ``nnx.Param`` leaves.

    Currently unused (see the commented-out alternative where ``opt`` is built).
    NOTE: reads the global learning-rate schedule ``lr`` at call time, so only call it
    after ``lr`` has been defined below.
    """
    return nnx.Optimizer(
        model,
        optax.contrib.sam(
            optax.sgd(learning_rate=lr, momentum=.9),
            optax.chain(optax.contrib.normalize(), optax.sgd(.1)),
            sync_period=2
        ),
        wrt=nnx.Param
    )


# Induce different distances via elongated FL training: one communication round per inducer
inducers = jnp.arange(20).tolist()
# Out-of-core stores for the entanglement scores; every test batch appends along axis 0.
# Stored axes (the client-model axis comes AFTER the partition axis, see `out_shape` below):
#   labels.npy, logits_g.npy, logits_g_rand.npy : (n_inducers*n_samples, n_partitions, n_classes)
#   logits_l.npy, logits_l_rand.npy             : (n_inducers*n_samples, n_partitions, n_clients, n_classes)
jnp.save("disentanglement/labels.npy", jnp.empty((0,n_clients,100)))
jnp.save("disentanglement/logits_l.npy", jnp.empty((0,n_clients,n_clients,100)))
jnp.save("disentanglement/logits_g.npy", jnp.empty((0,n_clients,100)))
jnp.save("disentanglement/logits_l_rand.npy", jnp.empty((0,n_clients,n_clients,100)))
jnp.save("disentanglement/logits_g_rand.npy", jnp.empty((0,n_clients,100)))
distances = []
# Initialize global model
model_g = ResNet(key, dim_out=100)
lr = optax.warmup_exponential_decay_schedule(1e-4, 1., 200, 100, .9, end_value=1e-5)
# NOTE(review): `opt` is built for the initial model_g, but `train` receives a freshly merged
# model_g every round -- confirm that `train` (re)binds the optimizer state accordingly.
opt = nnx.Optimizer( # sam(model_g)
    model_g,
    optax.adam(learning_rate=lr),
    wrt=nnx.Param
)
# Fixed seed for the random perturbation directions (the same directions every round)
noise_key = jax.random.key(42)
# vcall returns (models, partitions*batch, classes). Assuming ds_test yields x as
# (n_partitions, batch, ...) -- TODO confirm -- this views it as (models, partitions, batch, classes),
# and swapaxes(0, -2) then yields (batch, partitions, models, classes).
out_shape = (n_clients, n_clients, -1, 100)
for inducer in inducers:
    # Get datasets which induce this distance (re-fetched every round, presumably because the
    # loaders are single-pass iterators -- TODO confirm)
    ds_train = fetch_data("label", beta=1., n_clients=n_clients, dataset=1, n_classes=100)
    ds_test = fetch_data("label", partition="test", beta=1., batch_size=64, n_clients=n_clients, dataset=1, n_classes=100)
    ds_val = fetch_data("label", partition="val", beta=1., batch_size=64, n_clients=n_clients, dataset=1, n_classes=100)
    # Get federated models which are this far apart
    fl_models, _ = train(model_g, opt,
                      ds_train, return_ce(0.), ds_val, local_epochs="early",
                      max_patience=3, val_fn=top_5_err, rounds=1, n_clients=n_clients)
    # Aggregate: unweighted mean of params and batch statistics over the client axis
    struct, params_l, batchstat_l, rest = nnx.split(fl_models, nnx.Param, nnx.BatchStat, ...)
    params_g = jax.tree.map(lambda p: p.mean(0), params_l)
    batchstat_g = jax.tree.map(lambda b: b.mean(0), batchstat_l)
    model_g = nnx.merge(struct, params_g, batchstat_g, rest)
    # Get L2 distance (mean over clients)
    distance = nnx_norm(params_l, params_g, order=2., n_clients=n_clients).mean()
    # Get models which are (measured across all params) this far apart at random
    # TODO: questions 1) should the random global model be further randomized; 2) which batchstats should the random local models receive?
    rand_model_g = ResNet(key, dim_out=100)
    rstate, rparams_g, *_ = nnx.split(rand_model_g, nnx.Param, nnx.BatchStat, ...)
    # One independent subkey per parameter leaf: the parent key is only split, never sampled from
    leaves, treedef = jax.tree.flatten(rparams_g)
    leaf_keys = jax.random.split(noise_key, len(leaves))
    rand = jax.tree.unflatten(treedef, [
        jax.random.normal(k, (n_clients, *p.shape)) for k, p in zip(leaf_keys, leaves)
    ])
    # Rescale each client's direction to length `distance`
    # (presumably nnx_norm returns one norm per client, shape (n_clients,) -- confirm)
    norm = nnx_norm(rand, n_clients=n_clients, order=2.)
    rparams_l = jax.tree.map(lambda p, r: p+(r.T/norm).T*distance, rparams_g, rand)
    rand_models = nnx.merge(rstate, rparams_l, batchstat_l, rest)

    # Now store logits to calculate entanglement; keep the append files open for the whole test pass
    with (NpyAppendArray("disentanglement/labels.npy") as labels,
          NpyAppendArray("disentanglement/logits_l.npy") as logits_l,
          NpyAppendArray("disentanglement/logits_g.npy") as logits_g,
          NpyAppendArray("disentanglement/logits_l_rand.npy") as logits_l_rand,
          NpyAppendArray("disentanglement/logits_g_rand.npy") as logits_g_rand):
        for vy, vx in ds_test:
            # Merge (partition, batch) so that every client model sees the samples of all partitions
            flatx = vx.reshape(-1, *vx.shape[2:])
            labels.append(vy.swapaxes(0,-2).__array__())
            logits_l.append(vcall(fl_models, flatx).reshape(out_shape).swapaxes(0,-2).__array__())
            logits_g.append(model_g(flatx, train=False).reshape(out_shape[1:]).swapaxes(0,-2).__array__())
            logits_l_rand.append(vcall(rand_models, flatx).reshape(out_shape).swapaxes(0,-2).__array__())
            logits_g_rand.append(rand_model_g(flatx, train=False).reshape(out_shape[1:]).swapaxes(0,-2).__array__())

    # Record the distance only after this round's logits are on disk, so distances.json always
    # has exactly one entry per complete round in the .npy stores
    distances.append(distance.item())
    with open("disentanglement/distances.json", "w") as f:
        json.dump(distances, f)

## Visualize entanglement
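Entanglement score (averaged over the $|\mathcal{C}|$ clients): $\frac{1}{|\mathcal{C}|}\sum_{i=1}^{|\mathcal{C}|} \mathbb{E}_{x\sim\mathcal{D}_i} [\text{dist}( f(\theta_i, x), f(\bar{\theta}, x) )]$, where $\text{dist}$ is half the $L_1$ distance between the two logit vectors after shifting them to be non-negative and normalizing them to sum to one.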

In [None]:
# Load in memory. distances.json holds one entry per completed round, so its length gives the
# number of rounds without relying on `inducers` from the (expensive) training cell.
with open("disentanglement/distances.json", "r") as f:
    distances = json.load(f)
n_rounds = len(distances)
# Axes: (round, sample, partition, client model, class) and (round, sample, partition, class)
logits_l = jnp.load("disentanglement/logits_l.npy").reshape(n_rounds, -1, n_clients, n_clients, 100)
logits_g = jnp.load("disentanglement/logits_g.npy").reshape(n_rounds, -1, n_clients, 100)
logits_l_rand = jnp.load("disentanglement/logits_l_rand.npy").reshape(n_rounds, -1, n_clients, n_clients, 100)
logits_g_rand = jnp.load("disentanglement/logits_g_rand.npy").reshape(n_rounds, -1, n_clients, 100)


def entfn(ll, lg):
    """Distance between two logit arrays, averaged over all but the last (class) axis.

    Each array is shifted by its global minimum (over the whole array, not per sample) and every
    row is normalized to sum to one; the result is half the L1 distance between the rows.
    """
    p_l = (ll - ll.min()) / jnp.sum(ll - ll.min(), -1, keepdims=True)
    p_g = (lg - lg.min()) / jnp.sum(lg - lg.min(), -1, keepdims=True)
    return jnp.abs(p_l - p_g).sum(-1).mean().item() / 2


def entanglement_per_round(lgts_l, lgts_g):
    """Per-round entanglement: mean over clients c of `entfn` between client model c and the
    global model, both evaluated on partition c."""
    return [
        sum(entfn(lgts_l[i, :, c, c, :], lgts_g[i, :, c, :]) for c in range(n_clients)) / n_clients
        for i in range(lgts_l.shape[0])
    ]


# Calculate entanglement per round
fl_entanglement_scores = entanglement_per_round(logits_l, logits_g)
rand_entanglement_scores = entanglement_per_round(logits_l_rand, logits_g_rand)

# Plot entanglement vs distance for the first `n_shown` rounds
n_shown = 8
rounds_shown = list(range(n_rounds))[:n_shown]
fig, ax = plt.subplots();

ax.plot(rounds_shown, fl_entanglement_scores[:n_shown], label="FL models", c="C0");
ax.plot(rounds_shown, rand_entanglement_scores[:n_shown], label="Random models", c="C1");
ax.set_xticks(rounds_shown, labels=[f"{d:.1f}" for d in distances[:n_shown]], rotation=45);
ax.set_xlabel("L2 client drift");
ax.set_xlim(min(rounds_shown), max(rounds_shown));
ax.grid(True, linestyle="--", linewidth=0.5, axis="x");
ax.set_ylim(0.);

# Secondary x-axis below the first: the communication round behind each drift value
ax2 = ax.twiny();
ax2.spines.top.set_visible(False);
ax2.xaxis.set_ticks_position("bottom");
ax2.xaxis.set_label_position("bottom");
ax2.spines.bottom.set_position(("outward", 50));
ax2.set_xlim(min(rounds_shown), max(rounds_shown));
ax2.set_xticks(rounds_shown, labels=[int(r) for r in rounds_shown]);
ax2.set_xlabel("Communication round");

fig.legend();
fig.supylabel("Functional entanglement score");
fig.savefig("disentanglement/disentanglement_fedavg.png", dpi=300, bbox_inches="tight");

## Plot function over labels (ImageNet)

In [None]:
# Load logits and merge partitions into the sample axis.
# distances.json holds one entry per completed round, which gives the round count.
with open("disentanglement/distances.json", "r") as f:
    n_rounds = len(json.load(f))
# Stored logits_l axes are (sample, partition, client model, class), so merging sample and
# partition is a plain reshape and the client-model axis stays third. (Swapping axes 1 and 2
# first would merge the client-model axis instead and misalign rows with `labels`.)
logits_l = jnp.load("disentanglement/logits_l.npy").reshape(n_rounds, -1, n_clients, 100)
logits_g = jnp.load("disentanglement/logits_g.npy").reshape(n_rounds, -1, 100)
labels = jnp.load("disentanglement/labels.npy").argmax(-1).reshape(n_rounds, -1)
# jnp.unique already returns the labels sorted
ulabels = jnp.unique(labels)


def best_of_5(label, lgts):
    """Return `label` if it is among the 5 largest entries of the 1-D `lgts`, else the argmax."""
    top5 = jnp.argsort(lgts, axis=-1)[-5:]
    return label if (top5 == label).any() else jnp.argmax(lgts, axis=-1)


def displacement(label, lgts):
    """Intended as the amount by which the correct label is displaced in the sorted logits.

    Averages the descending-argsort class indices over samples (axis -2) and returns, per leading
    index, the sorted position whose averaged class index is closest to `label`.
    """
    ranked = jnp.argsort(lgts, axis=-1, descending=True).mean(-2)
    return jnp.argmin(jnp.abs(ranked - label), axis=-1)


def plot_function(ax, lgts, color, name, alpha):
    """Draw one model's per-label prediction as step segments, shaded by its mean displacement.

    lgts: (round, sample, class) logits of a single model. The segment for label L spans [L-1, L].
    """
    # NOTE(review): the last round's label masks are reused for every round, which assumes
    # ds_test yields its samples in the same order each round -- confirm.
    masks = [labels[-1] == lab for lab in ulabels]
    # Last round, samples belonging to label, best prediction of the sample-mean logits
    scoreperlabel = [best_of_5(lab, lgts[-1, m].mean(0)) for lab, m in zip(ulabels, masks)]
    # All rounds, samples belonging to label, mean displacement over rounds (trailing 0 closes the step)
    disperlabel = [displacement(lab, lgts[:, m]).mean(0) for lab, m in zip(ulabels, masks)] + [0]
    ax.hlines(scoreperlabel, ulabels-1, ulabels, colors=color, label=name)
    # Shaded variation (normalized by 2 for polarity and 5 for visibility). The step edges mirror
    # the hlines segments and have one more point than labels, matching the padded arrays.
    edges = jnp.append(ulabels - 1, ulabels[-1])
    center = jnp.array(scoreperlabel + [0])
    spread = jnp.array(disperlabel) / 2 / 5
    ax.fill_between(edges, center - spread, center + spread,
                    color=color, alpha=alpha, step="post", linewidth=0)


# Plot
fig, ax = plt.subplots();
for i in range(n_clients):
    plot_function(ax, logits_l[:, :, i, :], "C"+str(i), f"Client {i}", .5)
    # Show domain
    ax.axvspan(i*len(ulabels)/n_clients, (i+1)*len(ulabels)/n_clients, color="C"+str(i), alpha=.18, zorder=0)

# Same for aggregated model
plot_function(ax, logits_g, "black", "Aggregated model", .3)

fig.legend(framealpha=1);
ax.set_xlabel("Label");
ax.set_ylabel(r"$\text{argmax}\mathbb{E}_{x\in\mathcal{D}_{\text{label}}}[f(\theta_i, x)]$")
ax.set_xlim(0, 99);
ax.set_ylim(0, 99);
ax.set_xticks(jnp.arange(0,101,10).at[-1].set(99), labels=[0]+[None]*4+["..."]+[None]*4+[99], minor=False)
ax.set_yticks(jnp.arange(0,101,10).at[-1].set(99), labels=[0]+[None]*4+["..."]+[None]*4+[99], minor=False, rotation=90)
ax.set_yticks(jnp.arange(0, 100), minor=True)
ax.set_xticks(jnp.arange(0, 100), minor=True)
ax.grid(True, linestyle="--", linewidth=0.5, which="major");
fig.savefig("disentanglement/function_fedavg.png", dpi=300, bbox_inches="tight");