In [None]:
%load_ext autoreload
%autoreload 2
from fedflax import train
from models import LeNet
from data import fetch_data
from utils import angle_err, opt_create, return_l2
import jax
from jax import numpy as jnp
from flax import nnx
from tqdm.auto import tqdm
from functools import reduce, partial

In [None]:
# Functional interference between one client's update and the global model:
# elementwise |model_g(x, z) - model_i(x, z)| for every client i.
# `model_g` is shared by every client (in_axes=None); the stacked client models and
# the per-client inputs `x`, `z` are mapped over their leading axis (in_axes=0).
def _abs_output_gap(model_g, models, x, z):
    """Absolute difference between the global model's and one client's outputs on (x, z)."""
    gap = model_g(x, z) - models(x, z)
    return jnp.abs(gap)


entanglement_fn = nnx.jit(
    nnx.vmap(_abs_output_gap, in_axes=(None, 0, 0, 0))
)
key = jax.random.key(42)  # single PRNG seed for the experiment

# For different distances, compute entanglement scores.
# Each inducer is a number of local epochs; more local training should push the
# client models further from their average.
# TODO: does local epochs as an inducer test what you desire?
inducers = [1, 6, 15, 25, 37]
# One entry per inducer, appended by the loop that follows.
distances, fl_entanglement_scores, rand_entanglement_scores = [], [], []
def mean_entanglement(model_g, models, ds):
    """Average `entanglement_fn(model_g, models, *batch[1:])` over the batches of `ds`.

    Each batch is sliced as `batch[1:]` to obtain the model inputs (presumably
    `batch[0]` is the target -- TODO confirm against `fetch_data`). Every batch
    contributes the mean absolute output gap, and batches are weighted equally.
    """
    total = reduce(
        lambda acc, batch: acc + entanglement_fn(model_g, models, *batch[1:]).mean(),
        ds,
        0.,
    )
    return total / len(ds)


def mae_distance(params_g, params_l):
    """Mean absolute deviation between the global params and every client's params.

    Leaves of `params_g` are broadcast against the leaves of `params_l`, which
    carry an extra leading client axis. The mean runs over all clients and all
    parameter entries.
    """
    abs_dev = jax.tree.map(lambda pg, pl: jnp.abs(pg - pl), params_g, params_l)
    total = jax.tree.reduce(lambda acc, d: acc + jnp.sum(d), abs_dev, 0.)
    count = jax.tree.reduce(lambda acc, p: acc + p.size, params_l, 0.)
    return total / count


def random_client_params(params_g, distance, n_clients, noise_key):
    """Perturb `params_g` into `n_clients` random "client" copies at MAE `distance`.

    Gaussian noise is drawn with an independent key per leaf. It is centred over
    the client axis, so the mean of the copies is exactly `params_g`, mirroring
    how the FL global model is the client mean. It is then rescaled by one global
    factor so the mean absolute deviation over all clients and entries equals
    `distance` -- the same statistic `mae_distance` reports for the FL clients.
    """
    leaves, treedef = jax.tree.flatten(params_g)
    # One independent key per leaf: reusing a key would give identical noise to
    # every leaf of the same shape.
    leaf_keys = jax.random.split(noise_key, len(leaves))
    noise = [
        jax.random.normal(k, (n_clients, *p.shape))
        for k, p in zip(leaf_keys, leaves)
    ]
    noise = [r - r.mean(0, keepdims=True) for r in noise]
    noise_mae = sum(jnp.abs(r).sum() for r in noise) / sum(r.size for r in noise)
    # With a single client the centred noise is all zeros; perturb nothing then
    # instead of producing NaNs.
    scale = jnp.where(noise_mae > 0, distance / noise_mae, 0.)
    return jax.tree.unflatten(treedef, [p + scale * r for p, r in zip(leaves, noise)])


for i, inducer in enumerate(tqdm(inducers, desc="inducers")):
    # Get datasets which induce this distance.
    # NOTE(review): these calls do not depend on `inducer`; hoist them above the
    # loop if `fetch_data` is deterministic and its datasets are re-iterable.
    ds_train = fetch_data("feature", beta=1.)
    ds_test = fetch_data("feature", partition="test", beta=1.)
    ds_val = fetch_data("feature", partition="val", beta=1.)
    # Get federated models which are this far apart
    model_init = LeNet(key)
    fl_models, _ = train(model_init, partial(opt_create, learning_rate=1e-3),
                         ds_train, return_l2(0.), ds_val, local_epochs=inducer,
                         max_patience=5, val_fn=angle_err, rounds=1)
    # Check entanglement wrt the global (client-averaged) model
    struct, params_l, rest = nnx.split(fl_models, (nnx.Param, nnx.BatchStat), ...)
    params_g = jax.tree.map(lambda p: p.mean(0), params_l)
    # NOTE(review): `rest` keeps any leading client axis while `model_g` is passed
    # with in_axes=None to entanglement_fn -- confirm LeNet has no such extra state.
    model_g = nnx.merge(struct, params_g, rest)
    fl_entanglement_scores.append(mean_entanglement(model_g, fl_models, ds_test))
    # Get MAE between the global model and the clients
    distance = mae_distance(params_g, params_l)
    distances.append(distance)
    # Get models which are (measured across all params) this far apart at random.
    # NOTE(review): the FL distance also counts BatchStat deviations, but only
    # Params are perturbed here. `rrest` has no client axis, yet entanglement_fn
    # maps `rand_models` over axis 0 -- confirm LeNet carries no non-Param state.
    n_clients = jax.tree.leaves(params_l)[0].shape[0]
    rand_model_g = LeNet(key)
    rgraphdef, rparams_g, rrest = nnx.split(rand_model_g, nnx.Param, ...)
    rparams_l = random_client_params(
        rparams_g, distance, n_clients, jax.random.fold_in(key, i)
    )
    rand_models = nnx.merge(rgraphdef, rparams_l, rrest)
    # Now check if this disentanglement is better than random
    rand_entanglement_scores.append(
        mean_entanglement(rand_model_g, rand_models, ds_test)
    )