In [None]:
%load_ext autoreload
%autoreload 2
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
import optax, jax, json
from jax import numpy as jnp
from flax import nnx
from fedflax import train
from models import LeNet
from data import get_data
from collections import defaultdict
from tqdm.auto import tqdm
from matplotlib import pyplot as plt
from functools import reduce
from utils import return_l2, opt_create, angle_err

## Utils

In [None]:
def sam(model):
    """Build a sharpness-aware minimization (SAM) optimizer for `model`.

    The transformation is `optax.contrib.sam` wrapping two optimizers:
    AdamW (learning rate 1e-3) as the first argument and a gradient-normalizing
    Adam (learning rate 1e-2) chain as the second, with `sync_period=5`.
    Only `nnx.Param` variables are optimized.

    Args:
        model: the nnx model handed to `nnx.Optimizer`.

    Returns:
        An `nnx.Optimizer` bound to `model`.
    """
    outer_tx = optax.adamw(learning_rate=1e-3)
    # Second optimizer given to optax.contrib.sam: normalize the gradient, then Adam.
    adversarial_tx = optax.chain(optax.contrib.normalize(), optax.adam(1e-2))
    sam_tx = optax.contrib.sam(outer_tx, adversarial_tx, sync_period=5)
    return nnx.Optimizer(model, sam_tx, wrt=nnx.Param)
def opt_reg(model):
    """Build the regular (non-SAM) optimizer for `model`.

    Uses AdamW with learning rate 1e-3 and optimizes only `nnx.Param` variables.

    Args:
        model: the nnx model handed to `nnx.Optimizer`.

    Returns:
        An `nnx.Optimizer` bound to `model`.
    """
    adamw_tx = optax.adamw(learning_rate=1e-3)
    return nnx.Optimizer(model, adamw_tx, wrt=nnx.Param)

def partial_aggregate(models, alpha):
    """Move every stacked model a fraction `alpha` towards the mean over axis 0.

    Each leaf of `nnx.to_tree(models)` is averaged over its leading axis
    (presumably the client axis -- confirm against `fedflax.train`) and blended
    with the original leaf as `(1-alpha)*leaf + alpha*mean`; the mean broadcasts
    back over the leading axis.

    Args:
        models: nnx object whose tree leaves are arrays stacked along axis 0.
        alpha: interpolation coefficient; 0 keeps the leaves unchanged, 1 replaces
            every slice along axis 0 by the mean.

    Returns:
        The nnx object rebuilt from the blended leaves via `nnx.from_tree`.
    """
    leaves, treedef = jax.tree.flatten(nnx.to_tree(models))
    blended = []
    for leaf in leaves:
        consensus = jnp.mean(leaf, axis=0)
        blended.append((1-alpha)*leaf + alpha*consensus)
    return nnx.from_tree(jax.tree.unflatten(treedef, blended))

# Calculate the dominant eigenvalue (lambda_max) of an nnx model using the power iteration method
key = jax.random.key(42)
def lambda_max(model, x, y, max_iter=20, key=key):
    """Estimate the dominant Hessian eigenvalue of the cross-entropy loss.

    Runs power iteration with Hessian-vector products (forward-over-reverse:
    `jax.jvp` of the flattened gradient function) and returns the Rayleigh
    quotient `v^T H v` of the final unit vector `v`.

    Args:
        model: model whose leaves (from `jax.tree.flatten`) are the parameters.
        x: inputs passed to the model as `m(x)`.
        y: targets in the format expected by `optax.losses.softmax_cross_entropy`.
        max_iter: number of power-iteration steps.
        key: PRNG key for the random start vector; defaults to the module-level `key`.

    Returns:
        Scalar estimate of the Hessian eigenvalue with the largest magnitude.
    """
    # Convenience: flatten the parameters and remember how to rebuild the model.
    leaves, struct = jax.tree.flatten(model)
    shapes = [p.shape for p in leaves]
    flat_leaves = [jnp.ravel(p) for p in leaves]
    # Split offsets are kept as plain Python ints (not jnp arrays) so that
    # jnp.split / reshape see static values even when this runs under jit or vmap.
    split_points, offset = [], 0
    for flat in flat_leaves[:-1]:
        offset += flat.shape[0]
        split_points.append(offset)
    inflate = lambda th: nnx.from_tree(jax.tree.unflatten(struct,
        [chunk.reshape(s) for chunk, s in zip(jnp.split(th, split_points), shapes)]))
    # Cross entropy
    ce_fn = lambda m: optax.losses.softmax_cross_entropy(m(x), y).mean()
    # Flattened gradient as a function of the flattened parameters.
    # NOTE(review): assumes the gradient leaves come back in the same order as
    # `leaves` above, so that gradient and parameter vectors line up -- confirm.
    flat_grad_fn = lambda th: jnp.concatenate([jnp.ravel(g) for g in jax.tree.leaves(
        nnx.grad(ce_fn)(inflate(th))
    )])
    theta = jnp.concatenate(flat_leaves)
    # Hessian-vector product H @ vec, evaluated at theta.
    hvp = lambda vec: jax.jvp(flat_grad_fn, (theta,), (vec,))[1]
    # Random normalized vector
    v = jax.random.normal(key, theta.shape)
    v = v / (jnp.linalg.norm(v)+1e-9)
    # Power iteration
    for _ in range(max_iter):
        hv = hvp(v)
        v = hv / (jnp.linalg.norm(hv)+1e-9)
    # Return Rayleigh quotient v^T H v (previously v was dotted with the gradient,
    # which is not an eigenvalue estimate).
    return jnp.dot(v, hvp(v))

## Interpolate between client and global models, and measure metrics

In [None]:
# Get models after one-shot for various heterogeneity levels
metrics = defaultdict(lambda: defaultdict(list))
# Build the jitted evaluators once: re-wrapping angle_err on every alpha creates a
# fresh function object each time, which presumably defeats the jit cache.
err_fn_local = nnx.jit(nnx.vmap(angle_err))
err_fn_global = nnx.jit(nnx.vmap(angle_err, in_axes=(0,None,None,None)))
# Make sure the output directory exists before any expensive training starts.
os.makedirs("agg", exist_ok=True)
settings = [(opt_reg, 1., "feature"), (opt_reg, 0., "overlap"), (sam, 1., "feature"), (opt_reg, .01, "feature"), (opt_reg, .5, "feature"), (opt_reg, .5, "overlap"), (opt_reg, .99, "overlap"), (sam, .01, "feature"), (sam, .5, "feature")]
for opt_callable, beta, skew in settings:
    # Key under which this setting's curves are stored (and later plotted)
    setting = f"{skew}_{beta}_{opt_callable.__name__}"
    # Train model
    ds_train = get_data(skew=skew, beta=beta)
    ds_val = get_data(skew=skew, beta=beta, partition="val", batch_size=16)
    ds_test = get_data(skew=skew, beta=beta, partition="test", batch_size=16)
    models, _ = train(LeNet(jax.random.key(42)), opt_callable, ds_train, return_l2(0.), ds_val, local_epochs="early", rounds=1, max_patience=5, val_fn=angle_err)
    # Interpolate
    for alpha in tqdm(jnp.linspace(0,1,30), leave=False):
        # Get the client models between global and local at alpha
        models_agg = partial_aggregate(models, alpha)
        # Average error on local data: every model is evaluated on its own slice of each batch
        errs_l = reduce(lambda acc, b: acc + err_fn_local(models_agg,*b), ds_test, 0.) / len(ds_test)
        metrics[setting]["Avg. loc. err."].append(errs_l.tolist())
        # Average error on global data: every model sees the batch flattened across the leading axis
        errs_g = reduce(lambda acc, b: acc + err_fn_global(models_agg,b[0].reshape(-1,2),b[1].reshape(-1,36,60,1),b[2].reshape(-1,3)), ds_test, 0.) / len(ds_test)
        metrics[setting]["Avg. glob. err."].append(errs_g.tolist())
        # # Average top Hessian eigenvalue (disabled; x_test / y_test are not defined in this notebook)
        # lambda_maxs = nnx.vmap(lambda_max, in_axes=(0,None,None,None,None))(models_agg, x_test, jnp.eye(10)[y_test], 5, key)
        # metrics[beta][r"Avg. $\lambda_\text{max}$"].append(lambda_maxs.tolist())
    # Save intermediate results
    with open("agg/metrics.json", "w") as f:
        json.dump(metrics, f)

## Visualize

In [None]:
# Rendering style (text.usetex makes matplotlib render text through LaTeX)
plt.style.use("seaborn-v0_8-pastel")
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Times"],
    "font.sans-serif": ["Helvetica"],
    "text.latex.preamble": r"""
        \usepackage{amsmath, amssymb}
        \usepackage{mathptmx}  % Safe fallback for Times + math
    """
})

# Plot all settings
# Read the results through a context manager so the file handle is closed.
with open("agg/metrics.json", "r") as f:
    metrics = json.load(f)
for name, setting_metrics in metrics.items():
    # One figure per setting; one y-axis per stored metric (currently two:
    # local and global error). A third axis for lambda_max is sketched below.
    fig, axs = plt.subplots(dpi=300)
    twin1 = axs.twinx()
    # twin2 = axs.twinx();
    # twin2.spines.right.set_position(("axes", 1.2));
    # Plot
    for c, (metric, ax) in enumerate(zip(setting_metrics, [axs, twin1])):#, twin2])):
        data = jnp.array(setting_metrics[metric])
        alphas = jnp.linspace(0,1,len(data))
        # Line: mean over the last axis; band: min/max over the last axis.
        ax.plot(alphas, data.mean(axis=-1), c=f"C{c}", label=metric)
        # NOTE(review): the former guard `metric != "Avg. loc. acc."` never matched a
        # stored key (keys end in "err."), so the band was always drawn. That
        # behavior is kept; confirm whether the local metric should skip the band.
        ax.fill_between(alphas, data.min(axis=-1), data.max(axis=-1), color=f"C{c}", alpha=.2)
        ax.set_ylabel(metric)
        ax.tick_params(axis="y")
    axs.set_xlabel("Interpolation coefficient")
    axs.set_xlim(0,1)
    axs.grid(True, linestyle="--", linewidth=0.5, axis="both")
    fig.legend()
    fig.savefig(f"agg/interpolation_{name}.png")