In [None]:
%load_ext autoreload
%autoreload 2
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
import optax, jax, json
from jax import numpy as jnp
from flax import nnx
from fedflax import train
from models import LeNet
from data import fetch_data
from collections import defaultdict
from tqdm.auto import tqdm
from matplotlib import pyplot as plt
from functools import reduce
from utils import return_l2, opt_create, angle_err, nnx_norm

## Utils

In [None]:
# Sharpness-aware minimization optimizer
def sam(model):
    """Build an nnx optimizer that trains `model`'s ``nnx.Param`` state with SAM.

    The outer (real) optimizer is AdamW with learning rate 1e-3. The adversarial
    perturbation steps normalize the gradient and feed it to Adam with learning
    rate 1e-2. ``sync_period=5`` is passed to ``optax.contrib.sam``; per the optax
    docs the outer optimizer is applied once every ``sync_period`` steps.
    """
    outer_tx = optax.adamw(learning_rate=1e-3)
    adversarial_tx = optax.chain(optax.contrib.normalize(), optax.adam(1e-2))
    sam_tx = optax.contrib.sam(outer_tx, adversarial_tx, sync_period=5)
    return nnx.Optimizer(model, sam_tx, wrt=nnx.Param)
# Regular optimizer
def opt_reg(model):
    """Build a plain AdamW (learning rate 1e-3) nnx optimizer over `model`'s ``nnx.Param`` state."""
    adamw_tx = optax.adamw(learning_rate=1e-3)
    return nnx.Optimizer(model, adamw_tx, wrt=nnx.Param)

def interpolate_toward_mean(leaf, alpha):
    """Move every slice of `leaf` along axis 0 a fraction `alpha` toward the mean over axis 0.

    Returns ``(1 - alpha) * leaf + alpha * mean(leaf, axis=0)`` for floating-point
    leaves. Leaves without a floating dtype (integers, booleans, PRNG keys, or objects
    without a ``dtype``) are returned unchanged: averaging them would silently turn
    them into floats or fail outright.
    """
    dtype = getattr(leaf, "dtype", None)
    if dtype is None or not jnp.issubdtype(dtype, jnp.inexact):
        return leaf
    return (1 - alpha) * leaf + alpha * jnp.mean(leaf, axis=0)

def partial_aggregate(models, alpha):
    """Partially aggregate stacked client models toward their average.

    Args:
        models: nnx model whose state leaves are stacked along axis 0, one entry per
            client (assumed from how the result is consumed by ``nnx.vmap`` in this
            notebook — confirm against ``train``).
        alpha: interpolation coefficient. 0 leaves every client model unchanged;
            1 replaces each client's floating-point state with the client average.

    Returns:
        A new nnx model with the same structure and interpolated floating-point state.
    """
    leaves, treedef = jax.tree.flatten(nnx.to_tree(models))
    mixed = [interpolate_toward_mean(leaf, alpha) for leaf in leaves]
    return nnx.from_tree(jax.tree.unflatten(treedef, mixed))

# Estimate the dominant eigenvalue (lambda_max) of the loss Hessian of an nnx model using power iteration
key = jax.random.key(42)
@nnx.vmap(in_axes=(0,None,None,None,None,None))
def lambda_max(model, x, z, y, max_iter=20, key=key):
    """Estimate the dominant Hessian eigenvalue of the MSE loss, per stacked model.

    Vmapped over the model's leading axis. The data, ``max_iter`` and ``key`` are
    shared across all entries.
    NOTE(review): ``in_axes`` lists all six parameters, including the defaulted ones.
    This relies on ``nnx.vmap`` binding defaults positionally — confirm for the
    installed flax version.

    Args:
        model: nnx model; its ``nnx.Param`` and ``nnx.BatchStat`` state is differentiated.
        x, z: inputs, passed as ``model(x, z, train=True)``.
        y: targets of the mean-squared-error loss.
        max_iter: maximum number of power-iteration steps.
        key: PRNG key for the random starting vector.

    Returns:
        The Rayleigh quotient ``vᵀHv / vᵀv`` at the final power-iteration vector ``v``.
        This estimates the eigenvalue of largest magnitude, which may be negative.
    """
    # Split out the differentiated state (jvp must act on a parameter-only pytree)
    struct, theta, rest = nnx.split(model, (nnx.Param, nnx.BatchStat), ...)
    reconstruct = lambda th: nnx.merge(struct, th, rest)

    # Gradient of the MSE loss on this data, as a function of theta
    grad_fn = nnx.grad(lambda th: jnp.square(reconstruct(th)(x, z, train=True) - y).mean())

    def hvp(vec):
        # Hessian-vector product H @ vec via forward-over-reverse differentiation
        return jax.jvp(grad_fn, (theta,), (vec,))[1]

    def tree_dot(a, b):
        # Inner product of two identically structured pytrees (sum over all leaves)
        return jax.tree.reduce(
            lambda acc, prod: acc + prod,
            jax.tree.map(lambda a_, b_: jnp.sum(a_ * b_), a, b),
        )

    # Random normalized starting vector; each leaf gets its own fresh subkey
    # (a key is never used both for sampling and for further splitting)
    def random_like(arr):
        nonlocal key
        key, subkey = jax.random.split(key)
        return jax.random.normal(subkey, arr.shape)
    v = jax.tree.map(random_like, theta)
    norm = nnx_norm(v)
    v = v_prev = jax.tree.map(lambda v_: v_/norm, v)

    # Power iteration inside lax.while_loop (no Python control flow on traced values)
    def true_fun(val):
        # One power-iteration step: v <- Hv / ||Hv||
        v, _, i = val
        hv = hvp(v)
        norm = nnx_norm(hv)
        v_new = jax.tree.map(lambda v_: v_/norm, hv)
        return v_new, v, i+1
    def cond_fun(val):
        # Continue for at most max_iter steps and while successive iterates still differ.
        # NOTE(review): assumes nnx_norm(a, b) measures the distance between a and b — confirm in utils.
        v, v_prev, i = val
        return jnp.logical_and(i<max_iter, jnp.logical_or(i==0, nnx_norm(v, v_prev)>=1e-3))
    i0 = jnp.array(0)
    v, *_ = jax.lax.while_loop(cond_fun, true_fun, (v, v_prev, i0))

    # Rayleigh quotient vᵀHv / vᵀv. This must use the Hessian-vector product:
    # vᵀ·grad is a directional derivative of the loss, not an eigenvalue.
    return tree_dot(v, hvp(v)) / tree_dot(v, v)

## Interpolate between client and global models and measure metrics

In [None]:
# Get models after one-shot (single-round) training for various heterogeneity levels.
# Each setting is (optimizer factory, heterogeneity parameter beta, skew type).
SETTINGS = [
    (opt_reg, 1., "feature"), (opt_reg, 0., "overlap"), (sam, 1., "feature"),
    (opt_reg, .01, "feature"), (opt_reg, .5, "feature"), (opt_reg, .5, "overlap"),
    (opt_reg, .99, "overlap"), (sam, .01, "feature"), (sam, .5, "feature"),
]
# Interpolation coefficients: 0 = client's own model, 1 = client average (see partial_aggregate)
ALPHAS = jnp.linspace(0, 1, 30)
METRICS_PATH = "agg/metrics.json"

# Error functions, built once so the jitted functions are not re-created (and re-traced) every alpha:
# local  -> each client evaluated on its own slice of the test batch,
# global -> every client evaluated on the same pooled test batch (data not vmapped).
local_err_fn = nnx.jit(nnx.vmap(angle_err))
global_err_fn = nnx.jit(nnx.vmap(angle_err, in_axes=(0,None,None,None)))

def flatten_clients(batch):
    """Merge the leading axes of a test batch's first three arrays into one sample axis.

    Returns ``(batch[0], batch[1], batch[2])`` reshaped to ``(-1, 2)``,
    ``(-1, 36, 60, 1)`` and ``(-1, 3)``; the caller unpacks them as ``(y, x, z)``.
    """
    return batch[0].reshape(-1, 2), batch[1].reshape(-1, 36, 60, 1), batch[2].reshape(-1, 3)

def pool_test_set(ds):
    """Concatenate all client-flattened batches of `ds` into single ``(y, x, z)`` arrays.

    A default-dtype empty array seeds each concatenation. This gives the same dtype
    promotion, and the same empty result for an empty dataset, as concatenating batch
    by batch — but in one pass instead of quadratic re-copying.
    """
    flat = [flatten_clients(b) for b in ds]
    seeds = (jnp.empty((0, 2)), jnp.empty((0, 36, 60, 1)), jnp.empty((0, 3)))
    return tuple(
        jnp.concatenate([seed] + [parts[k] for parts in flat])
        for k, seed in enumerate(seeds)
    )

os.makedirs(os.path.dirname(METRICS_PATH), exist_ok=True)
metrics = defaultdict(lambda: defaultdict(list))
for opt_callable, beta, skew in SETTINGS:
    setting = f"{skew}_{beta}_{opt_callable.__name__}"
    # Train (single round)
    ds_train = fetch_data(skew=skew, beta=beta)
    ds_val = fetch_data(skew=skew, beta=beta, partition="val", batch_size=16)
    ds_test = fetch_data(skew=skew, beta=beta, partition="test", batch_size=16)
    models, _ = train(LeNet(jax.random.key(42)), opt_callable, ds_train, return_l2(0.), ds_val, local_epochs="early", rounds=1, max_patience=5, val_fn=angle_err)
    # The pooled test data does not depend on alpha, so build it once per setting
    y_test, x_test, z_test = pool_test_set(ds_test)
    for alpha in tqdm(ALPHAS, leave=False, desc=setting):
        # Client models moved a fraction alpha of the way toward the client average
        models_agg = partial_aggregate(models, alpha)
        # Average error on each client's own (local) test data
        errs_l = reduce(lambda acc, b: acc + local_err_fn(models_agg, *b), ds_test, 0.) / len(ds_test)
        metrics[setting]["Avg. loc. err."].append(errs_l.tolist())
        # Average error of each client on the pooled (global) test data
        errs_g = reduce(lambda acc, b: acc + global_err_fn(models_agg, *flatten_clients(b)), ds_test, 0.) / len(ds_test)
        metrics[setting]["Avg. glob. err."].append(errs_g.tolist())
        # Dominant Hessian eigenvalue of each client model on the pooled test data
        lambda_maxs = lambda_max(models_agg, x_test, z_test, y_test)
        metrics[setting][r"Avg. $\lambda_\text{max}$"].append(lambda_maxs.tolist())
    # Save intermediate results after every setting
    with open(METRICS_PATH, "w") as f:
        json.dump(metrics, f)

## Visualize

In [None]:
# Rendering style
plt.style.use("seaborn-v0_8-pastel")
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Times"],
    "font.sans-serif": ["Helvetica"],
    "text.latex.preamble": r"""
        \usepackage{amsmath, amssymb}
        \usepackage{mathptmx}  % Safe fallback for Times + math
    """
})

def plot_band(ax, series, color, label):
    """Plot the across-client mean of `series` with a min-max band against the interpolation coefficient.

    `series` is a list with one entry per alpha, each a list of per-client values, as
    stored in metrics.json. The alphas are assumed to be evenly spaced on [0, 1],
    matching the sweep that produced the data.
    """
    values = jnp.array(series)
    alphas = jnp.linspace(0, 1, len(values))
    ax.plot(alphas, values.mean(axis=-1), c=color, label=label)
    ax.fill_between(alphas, values.min(axis=-1), values.max(axis=-1), color=color, alpha=.2, linewidth=0.)

# Plot all settings: each has local-error, global-error and lambda_max series
with open("agg/metrics.json", "r") as f:
    metrics = json.load(f)
for name, setting_metrics in metrics.items():
    fig, mainax = plt.subplots(dpi=300)
    # Left axis: angular errors on local and global test data
    plot_band(mainax, setting_metrics["Avg. loc. err."], "C0", "Error on local data")
    plot_band(mainax, setting_metrics["Avg. glob. err."], "C1", "Error on global data")
    mainax.set_ylabel("Angular error (degrees)")
    # Right axis: dominant Hessian eigenvalue
    twin = mainax.twinx()
    twin.set_ylabel("Top Hessian eigenvalue")
    plot_band(twin, setting_metrics[r"Avg. $\lambda_\text{max}$"], "C2", r"$\lambda_\mathrm{max}$")
    mainax.set_xlabel("Interpolation coefficient")
    mainax.set_xlim(0, 1)
    mainax.grid(True, linestyle="--", linewidth=0.5, axis="both")
    fig.legend()
    fig.savefig(f"agg/interpolation_{name}.png")
    # Display inline, then release the figure so open figures do not accumulate across settings
    plt.show()
    plt.close(fig)