In [None]:
%load_ext autoreload
%autoreload 2
import optax, jax
from jax import numpy as jnp
from flax import nnx
from fedflax import train
from models import ResNet
from data import get_gaze
from collections import defaultdict
from tqdm.auto import tqdm
from matplotlib import pyplot as plt
from functools import reduce

## Utils

In [None]:
# Sharpness-aware minimization optimizer
def sam(model):
    """Create a sharpness-aware minimization (SAM) optimizer for `model`.

    Written as a named function rather than a lambda so that `sam.__name__`
    is "sam" (a lambda reports "<lambda>"), which keeps labels derived from
    the factory name distinct and safe to use in file names.

    Args:
        model: The nnx model whose `nnx.Param` variables are optimized.

    Returns:
        An `nnx.Optimizer` driving `optax.contrib.sam`.
    """
    return nnx.Optimizer(
        model,
        optax.contrib.sam(
            # Outer optimizer
            optax.adamw(learning_rate=1e-3),
            # Adversarial optimizer: normalized gradient followed by Adam
            optax.chain(optax.contrib.normalize(), optax.adam(1e-2)),
            sync_period=5
        ),
        wrt=nnx.Param
    )

# Optimizer
def opt_reg(model):
    """Create a plain AdamW optimizer for `model`.

    Written as a named function rather than a lambda so that
    `opt_reg.__name__` is "opt_reg" (a lambda reports "<lambda>"), which
    keeps labels derived from the factory name distinct and safe to use in
    file names.

    Args:
        model: The nnx model whose `nnx.Param` variables are optimized.

    Returns:
        An `nnx.Optimizer` driving `optax.adamw` with learning rate 1e-3.
    """
    return nnx.Optimizer(
        model,
        optax.adamw(learning_rate=1e-3),
        wrt=nnx.Param
    )

# Cross-entropy loss; the softmax is applied inside the loss itself
def ell(model, _, x_batch, z_batch, y_batch, train):
    """Mean softmax cross-entropy of `model` on one batch.

    Args:
        model: Callable invoked as `model(x_batch, z_batch, train=train)`;
            its output is treated as logits.
        _: Ignored.
        x_batch: First model input.
        z_batch: Second model input.
        y_batch: Targets passed to `optax.softmax_cross_entropy`.
        train: Forwarded to the model as the `train` keyword.

    Returns:
        A tuple `(loss, (0., 0.))`. The second element is a constant pair,
        presumably placeholders for auxiliary terms the training loop
        expects -- confirm against `fedflax.train`.
    """
    logits = model(x_batch, z_batch, train=train)
    per_example = optax.softmax_cross_entropy(logits, y_batch)
    return per_example.mean(), (0., 0.)

def partial_aggregate(models, alpha):
    """Pull every leaf of `models` towards its mean over the leading axis.

    Each leaf `p` is replaced by `(1-alpha)*p + alpha*mean(p, axis=0)`, so
    `alpha=0` leaves the leaves unchanged and `alpha=1` replaces every slice
    along axis 0 by the mean.

    Args:
        models: nnx object(s) converted with `nnx.to_tree`; the leading axis
            of each leaf is presumably the client axis -- confirm with caller.
        alpha: Interpolation weight given to the mean.

    Returns:
        The interpolated models, rebuilt with `nnx.from_tree`.
    """
    leaves, treedef = jax.tree.flatten(nnx.to_tree(models))
    blended = [
        (1-alpha)*leaf + alpha*jnp.mean(leaf, axis=0)
        for leaf in leaves
    ]
    return nnx.from_tree(jax.tree.unflatten(treedef, blended))

# Calculate the dominant eigenvalue (lambda_max) of an nnx model using the power iteration method
key = jax.random.key(42)
def lambda_max(model, x, y, max_iter=20, key=key):
    """Estimate the dominant Hessian eigenvalue of the loss by power iteration.

    The loss is the mean softmax cross-entropy of `model(x)` against `y`; the
    Hessian is taken with respect to the concatenated, flattened leaves of
    `model`. Hessian-vector products are computed as the JVP of the gradient.

    Args:
        model: Model whose leaves are flattened with `jax.tree.flatten` and
            rebuilt with `nnx.from_tree`.
            NOTE(review): the gradient from `nnx.grad` must have the same
            leaves as this flattening for the iteration to be well-formed --
            confirm for models that carry non-parameter state.
        x: Input passed to the model as `model(x)`.
            NOTE(review): the models trained in this notebook are called as
            `model(x, z, train=...)`; confirm the call signature before use.
        y: Targets passed to `optax.losses.softmax_cross_entropy`.
        max_iter: Number of power-iteration steps.
        key: PRNG key for the random start vector; defaults to the
            module-level key built from seed 42.

    Returns:
        The Rayleigh quotient `v . (H v)` for the final unit vector `v`.
    """
    # Convenience: flat leaves plus the structure needed to rebuild the model
    leaves, struct = jax.tree.flatten(model)
    shapes = [p.shape for p in leaves]
    # Boundaries between consecutive leaves in the flat vector, kept as static
    # Python ints so they are valid `jnp.split` / `reshape` arguments
    offsets, end = [], 0
    for p in leaves[:-1]:
        end += p.size
        offsets.append(end)
    inflate = lambda th: nnx.from_tree(jax.tree.unflatten(struct,
        [chunk.reshape(s) for chunk, s in zip(jnp.split(th, offsets), shapes)]))
    # Cross entropy
    ce_fn = lambda m: optax.losses.softmax_cross_entropy(m(x), y).mean()
    # Gradient of the loss as one flat vector
    flat_grad_fn = lambda th: jnp.concatenate([jnp.ravel(g) for g in jax.tree.leaves(
        nnx.grad(ce_fn)(inflate(th))
    )])
    theta = jnp.concatenate([jnp.ravel(p) for p in leaves])
    # Hessian-vector product: directional derivative of the gradient along `vec`
    hvp = lambda vec: jax.jvp(flat_grad_fn, (theta,), (vec,))[1]
    # Random normalized vector
    v = jax.random.normal(key, theta.shape)
    v = v / (jnp.linalg.norm(v)+1e-9)
    # Power iteration
    for _ in range(max_iter):
        hv = hvp(v)
        v = hv / (jnp.linalg.norm(hv)+1e-9)
    # Return Rayleigh quotient v^T H v (the gradient itself must not appear here)
    return jnp.dot(v, hvp(v))

## Interpolate between client and global, and measure metrics

In [None]:
# Get models after one-shot for various heterogeneity levels
metrics = defaultdict(lambda: defaultdict(list))

# Label the optimizer factories explicitly instead of via `__name__`, which is
# "<lambda>" for any lambda-defined factory and would merge distinct runs
optimizers = {"opt_reg": opt_reg, "sam": sam}
configs = [
    ("opt_reg", 0., "feature"), ("opt_reg", .5, "feature"), ("opt_reg", 1., "feature"),
    ("opt_reg", 0., "overlap"), ("opt_reg", .5, "overlap"), ("opt_reg", 1., "overlap"),
    ("sam", 0., "feature"), ("sam", .5, "feature"), ("sam", 1., "feature"),
]

# Built once so the jitted functions are reused across all runs and alphas.
# Accuracy of each client model on its own batch, one scalar per client
acc_local_fn = nnx.jit(nnx.vmap(
    lambda m, x, z, y: (m(x, z, train=False).argmax(-1) == y.argmax(-1)).mean()
))
# Accuracy of each client model on the batches of all clients pooled together
acc_global_fn = nnx.jit(nnx.vmap(
    lambda m, x, z, y: (
        m(x.reshape(-1, 36, 60, 1), z.reshape(-1, 3), train=False).argmax(-1)
        == y.reshape(-1, 16).argmax(-1)
    ).mean(),
    in_axes=(0, None, None, None)
))

def mean_error(acc_fn, models, dataset):
    """Return 1 - accuracy of `models`, averaged over the batches of `dataset`."""
    total = sum((acc_fn(models, *batch) for batch in dataset), 0.)
    return 1 - total / len(dataset)

for opt_name, beta, skew in configs:
    # Train model on the skew type of this configuration
    ds_train = get_gaze(skew=skew, beta=beta)
    ds_val = get_gaze(skew=skew, beta=beta, partition="val", batch_size=16)
    ds_test = get_gaze(skew=skew, beta=beta, partition="test", batch_size=16)
    _, models = train(ResNet, optimizers[opt_name], ds_train, ds_val, ell, local_epochs=50, max_rounds=0)
    run_metrics = metrics[f"{skew}_{beta}_{opt_name}"]
    # Interpolate
    for alpha in tqdm(jnp.linspace(0,1,30), leave=False):
        # Get the client models between global and local at alpha. Always start
        # from the trained models: reassigning `models` here would compound the
        # aggregation across steps and no longer correspond to `alpha`
        models_alpha = partial_aggregate(models, alpha)
        # Average error on local data
        run_metrics["Avg. loc. err."].append(mean_error(acc_local_fn, models_alpha, ds_test))
        # Average error on global data
        run_metrics["Avg. glob. err."].append(mean_error(acc_global_fn, models_alpha, ds_test))
        # # Average top Hessian eigenvalue
        # lambda_maxs = nnx.vmap(lambda_max, in_axes=(0,None,None,None,None))(models, x_test, jnp.eye(10)[y_test], 5, key)
        # metrics[beta][r"Avg. $\lambda_\text{max}$"].append(lambda_maxs)

## Visualize

In [None]:
plt.style.use("seaborn-v0_8-pastel")
for name in metrics.keys():
    fig, axs = plt.subplots(dpi=300)
    # One y-axis per metric; the third one is shifted outwards so its spine
    # does not sit on top of the second
    twin1 = axs.twinx()
    twin2 = axs.twinx()
    twin2.spines.right.set_position(("axes", 1.2))
    for c, (metric, ax) in enumerate(zip(metrics[name], [axs, twin1, twin2])):
        # Metrics are stored as Python lists (one entry per interpolation step),
        # which have no .mean/.min/.max; stack them into an array first
        data = jnp.asarray(metrics[name][metric])
        if data.ndim == 1:
            # One scalar per step: treat it as a single series
            data = data[:, None]
        elif data.ndim > 2:
            # Average any axes beyond (step, series) so the statistics below are 1-D
            data = data.mean(axis=tuple(range(2, data.ndim)))
        alphas = jnp.linspace(0, 1, len(data))
        ax.plot(alphas, data.mean(axis=-1), c=f"C{c}", label=metric)
        # NOTE(review): no metric is recorded under "Avg. loc. acc." in the visible
        # code, so the min-max band is currently drawn for every metric
        if metric != "Avg. loc. acc.":
            ax.fill_between(alphas, data.min(axis=-1), data.max(axis=-1), color=f"C{c}", alpha=.2)
        ax.set_ylabel(metric, color=f"C{c}")
        ax.tick_params(axis="y", colors=f"C{c}")
    axs.set_xlabel("Interpolation coefficient")
    axs.set_xlim(0,1)
    # fig.legend()
    fig.savefig(f"agg/interpolation_{name}.png")