In [None]:
%load_ext autoreload
%autoreload 2
from fedflax import train
from models import ResNet
from data import fetch_data
from utils import angle_err, opt_create, return_l2, return_ce, top_5_err
import os
import jax, pickle
from jax import numpy as jnp
from flax import nnx
from functools import partial
from matplotlib import pyplot as plt
from copy import deepcopy
# Number of federated clients: passed to `fetch_data`/`train` and used by the per-client loops and plots below
n_clients = 3

In [None]:
key = jax.random.key(42)
os.makedirs("disentanglement", exist_ok=True)  # checkpoints and figures below are written here

# Evaluate a stack of client models (leading axis = client) on one shared input batch
vcall = nnx.jit(nnx.vmap(
    lambda model, x: model(x, train=False),
    in_axes=(0, None)
))


def l2_divergence(params_g, params_l, n_clients):
    """Root-mean-square (over clients) Euclidean distance between client and global params.

    ``params_l`` leaves carry a leading client axis of size ``n_clients``;
    ``params_g`` leaves lack it and broadcast against them. Returns a scalar jax array.
    """
    sq_diff = jax.tree.map(lambda pg, pl: jnp.abs(pg - pl)**2, params_g, params_l)
    # Sum squared differences over every leaf, separately per client -> shape (n_clients,)
    per_client = jax.tree.reduce(
        lambda acc, d: acc + jnp.sum(d.reshape(n_clients, -1), -1), sq_diff, 0.
    )
    return jnp.sqrt(per_client.mean())


def unit_directions(key, params, n_clients):
    """Draw ``n_clients`` Gaussian directions over the param tree, each with global L2 norm 1.

    Every leaf gets its own sub-key: reusing a single key for all leaves (as a
    plain ``jax.tree.map`` with one key would) makes the noise of different
    leaves correlated.
    """
    leaves, treedef = jax.tree.flatten(params)
    leaf_keys = jax.random.split(key, len(leaves))
    noise = [jax.random.normal(k, (n_clients, *p.shape)) for k, p in zip(leaf_keys, leaves)]
    # Norm per client, measured across all leaves at once
    norm = jnp.linalg.norm(
        jnp.concatenate([z.reshape(n_clients, -1) for z in noise], axis=1),
        axis=-1
    )
    unit = [z / norm.reshape(n_clients, *([1] * (z.ndim - 1))) for z in noise]
    return jax.tree.unflatten(treedef, unit)


def cat_batches(chunks, empty_shape, axis=0):
    """Concatenate per-batch arrays in a single call (repeated running concatenation is quadratic).

    The empty seed array keeps the dtype promotion of the incremental version and
    makes an empty dataset yield an array of shape ``empty_shape``.
    """
    return jnp.concatenate([jnp.empty(empty_shape), *chunks], axis=axis)


# One entry per communication round (each `train` call below runs rounds=1)
inducers = list(range(20))
distances = []
logits = {"logits_init": [], "logits_l": [], "logits_g": [],
          "logits_l_rand": [], "logits_g_rand": [], "labels": []}
model_g = ResNet(key, dim_out=100)

# Random baseline: a fixed random global model plus fixed per-client unit directions.
# Neither depends on the round, so both are built once; each round only rescales the
# directions to that round's FL divergence.
# NOTE(review): BatchStat leaves are perturbed as well (as before) — confirm this is intended.
rand_model_g = ResNet(key, dim_out=100)
rstate, rparams_g, rrest = nnx.split(rand_model_g, (nnx.Param, nnx.BatchStat), ...)
rand_dirs = unit_directions(jax.random.fold_in(key, 1), rparams_g, n_clients)

for inducer in inducers:
    # The datasets do not depend on `inducer`; they are re-fetched every round
    # (NOTE(review): presumably so each iterator starts fresh — confirm before hoisting)
    ds_train = fetch_data("label", beta=1., n_clients=n_clients, dataset=1, n_classes=100)
    ds_test = fetch_data("label", partition="test", beta=1., n_clients=1, dataset=1, n_classes=100)
    ds_val = fetch_data("label", partition="val", beta=1., n_clients=n_clients, dataset=1, n_classes=100)
    # One round of federated training, starting from the current global model
    model_init = deepcopy(model_g)
    fl_models, _ = train(model_g, partial(opt_create, learning_rate=1e-3),
                      ds_train, return_ce(0.), ds_val, local_epochs="early",
                      max_patience=2, val_fn=top_5_err, rounds=1, n_clients=n_clients)
    # Aggregate: the new global model is the mean of the client params (and batch stats)
    struct, params_l, rest = nnx.split(fl_models, (nnx.Param, nnx.BatchStat), ...)
    params_g = jax.tree.map(lambda p: p.mean(0), params_l)
    model_g = nnx.merge(struct, params_g, rest)
    # How far the clients drifted from their average this round
    distance = l2_divergence(params_g, params_l, n_clients)
    distances.append(distance.item())
    # Random client models exactly `distance` away from the random global model
    rparams_l = jax.tree.map(lambda p, u: p + u * distance, rparams_g, rand_dirs)
    rand_models = nnx.merge(rstate, rparams_l, rrest)

    # Store test-set logits to calculate entanglement later
    chunks = {name: [] for name in logits}
    for batch in ds_test:
        y, x = batch[0].squeeze(0), batch[1].squeeze(0)
        chunks["labels"].append(y)
        chunks["logits_l"].append(vcall(fl_models, x))
        chunks["logits_init"].append(model_init(x, train=False))
        chunks["logits_g"].append(model_g(x, train=False))
        chunks["logits_l_rand"].append(vcall(rand_models, x))
        chunks["logits_g_rand"].append(rand_model_g(x, train=False))
    # Client-stacked outputs: (n_clients, N, 100); the rest: (N, 100)
    for name in ("logits_l", "logits_l_rand"):
        logits[name].append(cat_batches(chunks[name], (n_clients, 0, 100), axis=1))
    for name in ("logits_init", "logits_g", "logits_g_rand", "labels"):
        logits[name].append(cat_batches(chunks[name], (0, 100)))
    # Checkpoint after every round so an interrupted run keeps the finished rounds
    with open("disentanglement/logits.pkl", "wb") as fh:
        pickle.dump(logits, fh)

## Visualize entanglement

In [None]:
# Functional entanglement per round: mean absolute gap between the client logits
# and the aggregated-model logits (the client axis broadcasts against the global one).
# TODO: models are currently tested on all data, not just their own
fl_entanglement_scores = [
    jnp.abs(logits["logits_l"][r] - logits["logits_g"][r]).mean().item()
    for r in range(len(inducers))
]
rand_entanglement_scores = [
    jnp.abs(logits["logits_l_rand"][r] - logits["logits_g_rand"][r]).mean().item()
    for r in range(len(inducers))
]

# Matplotlib rendering style: pastel palette, LaTeX-rendered serif (Times) text
plt.style.use("seaborn-v0_8-pastel")
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.serif"] = ["Times"]
plt.rcParams["font.sans-serif"] = ["Helvetica"]
plt.rcParams["text.latex.preamble"] = r"""
        \usepackage{amsmath, amssymb}
        \usepackage{mathptmx}  % Safe fallback for Times + math
    """

# Primary x-axis: communication round
fig, ax = plt.subplots()
ax.plot(inducers, fl_entanglement_scores, label="FL models", c="C0")
ax.plot(inducers, rand_entanglement_scores, label="Random models", c="C1")
ax.set_xticks(inducers, labels=[int(r) for r in inducers])
ax.set_xlabel("Communication round")
ax.set_ylabel("Functional entanglement")
ax.legend()
ax.set_xlim(min(inducers), max(inducers))
ax.grid(True, linestyle="--", linewidth=0.5)

# Secondary x-axis, drawn 36pt below the first: the parameter divergence of each round
ax2 = ax.twiny()
ax2.xaxis.set_ticks_position("bottom")
ax2.xaxis.set_label_position("bottom")
ax2.spines.bottom.set_position(("outward", 36))
ax2.set_xlim(min(inducers), max(inducers))
ax2.set_xticks(inducers, labels=[f"{d:.1f}" for d in distances])
ax2.set_xlabel("Euclidean parameter divergence")
fig.savefig("disentanglement/disentanglement.png", dpi=300, bbox_inches="tight");

## Plot function over labels (ImageNet)

In [None]:
# Convert one-hot encoding to integer (identical for each round, given same test set)
labels = logits["labels"][-1].argmax(-1)
n_classes = logits["labels"][-1].shape[-1]
# Classes present in the test set (sorted) and the test-sample indices of each class,
# computed once instead of once per client and per model
classes = jnp.unique(labels)
class_idx = [jnp.nonzero(labels == c)[0] for c in classes]
# All rounds stacked once: (rounds, n_clients, N, n_classes) and (rounds, N, n_classes)
logits_l_all = jnp.stack(logits["logits_l"])
logits_g_all = jnp.stack(logits["logits_g"])
# Best guess out of top 5: the true label if it is among the 5 largest logits, else the top-1 prediction
best_of_5 = lambda label, lgts: jnp.where(
    jnp.any(label==jnp.argsort(lgts, axis=-1)[...,-5:], axis=-1),
    label,
    jnp.argmax(lgts, axis=-1)
)
# Rank of the correct label in the descending-sorted logits (0 = top-1), divided by 10 for visibility
displacement = lambda label, lgts: jnp.argmax(jnp.argsort(lgts, axis=-1, descending=True)==label, axis=-1)/10
# Sparse tick labels: first class, an ellipsis near the middle, last class
tick_labels = [None] * n_classes
tick_labels[0], tick_labels[n_classes // 2 - 1], tick_labels[-1] = 0, "...", n_classes - 1

# Plot
fig, ax = plt.subplots();
for i in range(n_clients):
    # Last round, client i: mean best-of-5 prediction over the samples of each class
    scoreperlabel = [best_of_5(c, logits_l_all[-1, i, idx, :]).mean() for c, idx in zip(classes, class_idx)]
    # All rounds, client i: mean (over rounds and samples) rank displacement of the true
    # label; drawn as the half-width of the shaded band (it is not a standard deviation)
    stdperlabel = [displacement(c, logits_l_all[:, i, idx, :]).mean((0, -1)) for c, idx in zip(classes, class_idx)]
    # One horizontal segment [c, c+1) per class, with the shaded band around it
    ax.hlines(scoreperlabel, classes, classes + 1, colors=f"C{i}", label=f"Client {i}")
    ax.fill_between(classes,
                    jnp.array(scoreperlabel)-jnp.array(stdperlabel),
                    jnp.array(scoreperlabel)+jnp.array(stdperlabel),
                    color=f"C{i}", alpha=0.3, step="post", linewidth=0)
    # Show domain: an equal slice of the label range per client
    ax.axvspan(i*n_classes/n_clients, (i+1)*n_classes/n_clients, color=f"C{i}", alpha=0.05, zorder=0)

# Same for aggregated model
scoreperlabel = [best_of_5(c, logits_g_all[-1, idx, :]).mean() for c, idx in zip(classes, class_idx)]
stdperlabel = [displacement(c, logits_g_all[:, idx, :]).mean((0, -1)) for c, idx in zip(classes, class_idx)]
ax.hlines(scoreperlabel, classes, classes + 1, colors="black", label="Aggregated model")
ax.fill_between(classes,
                jnp.array(scoreperlabel)-jnp.array(stdperlabel),
                jnp.array(scoreperlabel)+jnp.array(stdperlabel),
                color="black", alpha=0.3, step="post", linewidth=0)

fig.legend(framealpha=1);
ax.set_xlabel("Label");
ax.set_ylabel(r"$\mathbb{E}_{x\in\mathcal{D}_{\text{label}}}[\text{argmax}(f(\theta_i, x))]$")
# NOTE(review): the last class's segment spans [n_classes-1, n_classes] and is clipped by this limit
ax.set_xlim(0, n_classes - 1);
ax.set_ylim(0, n_classes - 1);
ax.set_yticks(jnp.arange(0, n_classes), labels=tick_labels)
ax.set_xticks(jnp.arange(0, n_classes), labels=tick_labels)
# ax.grid(True, linestyle="--", linewidth=0.5);
fig.savefig("disentanglement/function.png", dpi=300, bbox_inches="tight");

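## Plot function over labels (MPIIGaze)

Requires logits from a gaze-regression run with 2-D labels; it does not apply to the 100-class logits collected above.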

In [None]:
def moving_average(values, window=100):
    """'valid'-mode moving average of a 1-D array; the result has len(values) - window + 1 entries."""
    return jnp.convolve(values, jnp.ones((window,))/window, mode="valid")


def to_degrees(arr):
    """Angle in degrees of the 2-vectors in the last axis of ``arr``, computed as ``arctan2(*arr.T)``.

    Because of the transpose, the remaining axes come back in reversed order,
    e.g. (n_clients, N, 2) -> (N, n_clients).
    """
    return jnp.rad2deg(jnp.arctan2(*arr.T))


# This plot needs 2-D gaze labels; with classification labels (e.g. the 100-class run
# above) `arctan2(*labels.T)` would receive the wrong number of arguments and crash.
if logits["labels"][0].shape[-1] != 2:
    print(f"Skipping MPIIGaze plot: expected 2-D gaze labels, got label shape {logits['labels'][0].shape}")
else:
    smooth_rows = jnp.vectorize(moving_average, signature="(n)->(m)")
    # Label angles (degrees), sorted, then smoothed so they align with the smoothed outputs
    labels = to_degrees(logits["labels"][0])
    sorter = jnp.argsort(labels)
    labels = moving_average(labels[sorter])
    # Last-round prediction of each client and its std over rounds: (n_clients, N - window + 1)
    outs_l = smooth_rows(to_degrees(logits["logits_l"][-1])[sorter].T)
    std_l = smooth_rows(jnp.stack([to_degrees(arr) for arr in logits["logits_l"]]).std(0)[sorter].T)
    # Same for the aggregated model: (N - window + 1,)
    outs_g = moving_average(to_degrees(logits["logits_g"][-1][sorter]))
    std_g = moving_average(jnp.stack([to_degrees(arr) for arr in logits["logits_g"]]).std(0)[sorter])

    # Label range (degrees) shaded per client; hard-coded for 3 clients
    # (NOTE(review): presumably each client's training domain — confirm against the data split)
    domains = [(-120, 0), (-180, -120), (0, 120)]
    # Plot
    fig, ax = plt.subplots();
    for i in range(n_clients):
        ax.plot(labels, outs_l[i], c="C"+str(i), label=f"Client {i}");
        ax.fill_between(
            labels,
            outs_l[i]-std_l[i],
            outs_l[i]+std_l[i],
            color="C"+str(i),
            alpha=.6,
            linewidth=0.
        );
        if i < len(domains):
            ax.axvspan(*domains[i], alpha=.2, color="C"+str(i), zorder=0);
    ax.plot(labels, outs_g, c="black", label="Aggregated model");
    ax.fill_between(
        labels,
        outs_g-std_g,
        outs_g+std_g,
        color="black",
        alpha=.4,
        linewidth=0.
    );

    ax.legend();
    ax.set_xlabel("Label (degrees)");
    # The curves are smoothed raw predictions (not differences to the aggregated model)
    ax.set_ylabel(r"$\mathbb{E}_{x\in\mathcal{D}_{\text{label}}}[f(\theta_i, x)]$ (degrees)")
    ax.set_xlim(-180, 180);
    ax.grid(True, linestyle="--", linewidth=0.5);
    # Distinct file name: the per-class plot above already writes disentanglement/function.png
    fig.savefig("disentanglement/function_mpiigaze.png", dpi=300, bbox_inches="tight");