In [None]:
%load_ext autoreload
%autoreload 2
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
import optax, jax, json
from jax import numpy as jnp
from flax import nnx
from fedflax import train
from models import LeNet
from data import fetch_data
from collections import defaultdict
from tqdm.auto import tqdm
from matplotlib import pyplot as plt
from functools import reduce
from utils import return_l2, angle_err, nnx_norm

## Utils

In [None]:
# Sharpness-aware minimization optimizer
def sam(model):
    """Wrap ``model`` in an nnx optimizer running sharpness-aware minimization (SAM).

    Outer optimizer: SGD with learning rate 1e-3 and momentum 0.9.
    Adversarial optimizer: normalized gradient followed by SGD with step 0.1.
    ``sync_period=2`` is forwarded to ``optax.contrib.sam``.
    Only ``nnx.Param`` variables are optimized.

    Note: the function's ``__name__`` is part of the experiment's setting keys.
    """
    outer_tx = optax.sgd(learning_rate=1e-3, momentum=.9)
    adversarial_tx = optax.chain(optax.contrib.normalize(), optax.sgd(.1))
    sam_tx = optax.contrib.sam(outer_tx, adversarial_tx, sync_period=2)
    return nnx.Optimizer(model, sam_tx, wrt=nnx.Param)
# Regular optimizer
def opt_reg(model):
    """Wrap ``model`` in an nnx AdamW optimizer (learning rate 1e-3) over its ``nnx.Param`` variables.

    Note: the function's ``__name__`` is part of the experiment's setting keys.
    """
    adamw_tx = optax.adamw(learning_rate=1e-3)
    return nnx.Optimizer(model, adamw_tx, wrt=nnx.Param)

def partial_aggregate(models, alpha):
    """Move every client model a fraction ``alpha`` of the way toward the client average.

    Args:
        models: nnx model whose state leaves are stacked along axis 0, one entry per
            client (presumably the vmapped models returned by ``train`` -- confirm).
        alpha: interpolation coefficient. 0 keeps the local models unchanged; 1 replaces
            every client by the element-wise mean over axis 0.

    Returns:
        A new nnx model with the same structure. Each floating-point leaf becomes
        ``(1 - alpha) * local + alpha * mean(local, axis=0)``. Non-floating leaves
        (e.g. integer counters or PRNG key arrays) are returned unchanged: averaging
        them is either meaningless or raises.
    """
    leaves, treedef = jax.tree.flatten(nnx.to_tree(models))

    def interpolate(leaf):
        # Only floating-point state can be meaningfully averaged across clients
        if not (hasattr(leaf, "dtype") and jnp.issubdtype(leaf.dtype, jnp.inexact)):
            return leaf
        # The mean has the per-client shape and broadcasts back over the client axis
        return (1 - alpha) * leaf + alpha * jnp.mean(leaf, axis=0)

    return nnx.from_tree(jax.tree.unflatten(treedef, [interpolate(leaf) for leaf in leaves]))

# Calculate the dominant eigenvalue (lambda_max) of an nnx model using the power iteration method
@nnx.vmap(in_axes=(0,None,None,None,None,None))
def lambda_max(model, x, z, y, key, max_iter=20):
    """Estimate the dominant Hessian eigenvalue of the squared-error loss, per stacked model.

    Runs power iteration using Hessian-vector products computed as ``jax.jvp`` of the
    gradient (forward-over-reverse). The function is vectorized over axis 0 of ``model``.
    ``x``, ``z``, ``y``, ``key`` and ``max_iter`` are shared by all entries, so every
    entry starts from the same random vector.

    Args:
        model: nnx model stacked along axis 0.
        x, z: model inputs, called as ``model(x, z, train=True)``.
        y: regression targets for the mean squared error.
        key: PRNG key for the random starting vector.
        max_iter: maximum number of power-iteration steps; at least 2 steps always run.

    Returns:
        The Rayleigh quotient ``v . Hv`` for the last iterate ``v``.
    """
    # Split off the state to differentiate: jvp must act on a pure pytree, not a Module.
    # NOTE(review): the filter includes BatchStat as well as Param. Presumably BatchStat
    # gets zero curvature when the model runs with train=True -- confirm against LeNet.
    struct, theta, rest = nnx.split(model, (nnx.Param, nnx.BatchStat), ...)
    reconstruct = lambda th: nnx.merge(struct, th, rest)

    # Gradient of the mean squared error on this data, as a function of theta
    grad_fn = nnx.grad(lambda th: jnp.square(reconstruct(th)(x, z, train=True) - y).mean())

    # Random starting vector; the first iteration normalizes it.
    # - Each leaf gets its own subkey, so no key is consumed twice.
    # - Each leaf is drawn in its own dtype, so the jvp tangent and the while_loop carry
    #   match the primal's dtype.
    leaves, treedef = jax.tree.flatten(theta)
    leaf_keys = jax.random.split(key, len(leaves))
    rand = jax.tree.unflatten(
        treedef,
        [jax.random.normal(k, leaf.shape, leaf.dtype) for k, leaf in zip(leaf_keys, leaves)],
    )

    def body_fun(val):
        # One power-iteration step: v = hv_prev / nnx_norm(hv_prev), then hv = H v
        hv_prev, *_, i = val
        norm = nnx_norm(hv_prev)
        v = jax.tree.map(lambda hv_: hv_ / norm, hv_prev)
        # Hessian-vector product: directional derivative of the gradient along v
        hv = jax.jvp(grad_fn, (theta,), (v,))[1]
        return hv, hv_prev, v, i+1

    def cond_fun(val):
        # Always run 2 steps. After that, continue while i < max_iter and
        # nnx_norm(hv, hv_prev) is still >= 1e-3.
        hv, hv_prev, _, i = val
        return jnp.logical_or(i<2, jnp.logical_and(i<max_iter, nnx_norm(hv, hv_prev)>=1e-3))

    # lax.while_loop keeps the loop traced (no Python bools on traced values)
    hv, _, v, _ = jax.lax.while_loop(cond_fun, body_fun, (rand, rand, rand, jnp.array(0)))
    # Rayleigh quotient: sum over leaves of <hv, v>. v is hv_prev scaled by 1/nnx_norm,
    # i.e. unit-norm assuming nnx_norm is the global L2 norm.
    return jax.tree.reduce(
        lambda acc, prod: acc + prod,
        jax.tree.map(lambda hv_, v_: jnp.sum(hv_ * v_), hv, v),
    )

## Interpolate between client and global models, and measure metrics

In [None]:
# Get models after one-shot for various heterogeneity levels
# Layout: setting -> metric -> alpha -> one entry per seed, each a list of per-client values
metrics = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
# Build the jitted evaluators once. Re-wrapping them inside the alpha loop creates a new
# function object every iteration, which misses jit's cache and recompiles each time.
err_fn_local = nnx.jit(nnx.vmap(angle_err))  # client i's model on client i's test batch
err_fn_global = nnx.jit(nnx.vmap(angle_err, in_axes=(0,None,None,None)))  # all client models on the same pooled batch
# The json.dump below raises FileNotFoundError if this directory does not exist
os.makedirs("agg", exist_ok=True)

for opt_fn, beta, skew in [(opt_reg, 1., "feature"), (opt_reg, 0., "overlap"), (sam, 1., "feature")]:
    setting = f"{skew}_{beta}_{opt_fn.__name__}"
    for seed in range(10):
        key = jax.random.key(seed)
        # Train model
        ds_train = fetch_data(skew=skew, beta=beta)
        ds_val = fetch_data(skew=skew, beta=beta, partition="val", batch_size=16)
        ds_test = fetch_data(skew=skew, beta=beta, partition="test", batch_size=16)
        # Build the optimizer from the same model instance that is trained
        model = LeNet(key)
        models, _ = train(model, opt_fn(model), ds_train, return_l2(0.), ds_val, local_epochs="early", rounds=1, max_patience=5, val_fn=angle_err)
        # Pool every client's test data once per seed (it does not depend on alpha),
        # flattening the leading (client, batch) axes into one sample axis
        y_test, x_test, z_test = reduce(lambda acc, batch:
                                        jax.tree.map(lambda a, x: jnp.concatenate([a, x.reshape(-1, *x.shape[2:])]), acc, batch),
                                        ds_test,
                                        (jnp.empty((0,2)), jnp.empty((0,36,60,1)), jnp.empty((0,3))))
        # Interpolate
        for alpha in tqdm(jnp.linspace(0,1,30).tolist(), leave=False):
            # Get the client models between local (alpha=0) and global (alpha=1)
            models_agg = partial_aggregate(models, alpha)
            # Error of each client model on its own local test data
            errs_l = reduce(lambda acc, b: acc + err_fn_local(models_agg, *b), ds_test, 0.) / len(ds_test)
            metrics[setting]["Avg. loc. err."][alpha].append(errs_l.tolist())
            # Error of each client model on the test data of all clients
            errs_g = reduce(lambda acc, batch: acc + err_fn_global(models_agg, *jax.tree.map(lambda x: x.reshape(-1, *x.shape[2:]), batch)), ds_test, 0.) / len(ds_test)
            metrics[setting]["Avg. glob. err."][alpha].append(errs_g.tolist())
            # Top Hessian eigenvalue on the pooled test set
            lambda_maxs = lambda_max(models_agg, x_test, z_test, y_test, key)
            metrics[setting][r"Avg. $\lambda_\text{max}$"][alpha].append(lambda_maxs.tolist())
        # Save intermediate results after every seed
        with open("agg/metrics.json", "w") as f:
            json.dump(metrics, f)

## Visualize

In [None]:
# Rendering style
plt.style.use("seaborn-v0_8-pastel")
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Times"],
    "font.sans-serif": ["Helvetica"],
    "text.latex.preamble": r"""
        \usepackage{amsmath, amssymb}
        \usepackage{mathptmx}  % Safe fallback for Times + math
    """
})

def _plot_band(ax, per_alpha, color, label):
    """Plot one metric against the interpolation coefficient, with a +/-1 std band.

    Args:
        ax: matplotlib axes to draw on.
        per_alpha: mapping from alpha (JSON string key) to a list with one entry per
            seed, each a list of per-client values. Assumes every seed has the same
            number of clients, so the values form a rectangular array.
        color: colour used for both the line and the band.
        label: legend label for the line.

    The line is the mean over seeds and clients. The band is the std over seeds of
    the per-seed client mean.
    """
    # Use the stored alpha values as x instead of regenerating a grid
    alphas = [float(a) for a in per_alpha.keys()]
    values = jnp.array(list(per_alpha.values()))  # (alpha, seed, client)
    mean = values.mean(axis=(-2, -1))
    std = values.mean(-1).std(-1)
    ax.plot(alphas, mean, c=color, label=label)
    ax.fill_between(alphas, mean - std, mean + std, color=color, alpha=.2, linewidth=0.)

# Plot all settings
with open("agg/metrics.json", "r") as f:
    metrics = json.load(f)
for name, setting in metrics.items():
    # Each setting holds three metrics: global error, local error and lambda_max
    fig, mainax = plt.subplots(dpi=300)
    _plot_band(mainax, setting["Avg. glob. err."], "C1", "Error on global data")
    _plot_band(mainax, setting["Avg. loc. err."], "C0", "Error on local data")
    # The eigenvalue gets its own y-axis next to the angular errors
    twin = mainax.twinx()
    _plot_band(twin, setting[r"Avg. $\lambda_\text{max}$"], "C2", r"$\lambda_\mathrm{max}$")

    mainax.set_xlabel("Interpolation coefficient")
    mainax.set_ylabel("Angular error (degrees)")
    twin.set_ylabel("Top Hessian eigenvalue")
    mainax.set_xlim(0, 1)
    mainax.grid(True, linestyle="--", linewidth=0.5, axis="both")
    fig.legend()
    fig.savefig(f"agg/interpolation_{name}.png")
    # Display now and release the figure so open figures do not pile up across the loop
    plt.show()
    plt.close(fig)