In [None]:
%load_ext autoreload
%autoreload 2
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
import jax, optax, pickle
from jax import numpy as jnp
from flax import nnx
from models import LeNet, ResNet
from fedflax import train, aggregate
from data import get_gaze
from functools import reduce
from matplotlib import pyplot as plt

## Setup

In [None]:
# Optimizer
def opt_create(model, learning_rate=1e-3):
    """Build an AdamW optimizer that updates only the model's ``nnx.Param`` variables.

    Args:
        model: the NNX module to optimize.
        learning_rate: AdamW step size (default 1e-3, the value used so far).

    Returns:
        An ``nnx.Optimizer`` wrapping ``model``.
    """
    return nnx.Optimizer(
        model,
        optax.adamw(learning_rate=learning_rate),
        wrt=nnx.Param,
    )

# Loss: the softmax is applied inside the cross-entropy, so the model must output logits
def ell(model, _, x_batch, z_batch, y_batch, train):
    """Mean softmax cross-entropy of the model's logits against ``y_batch``.

    The second positional argument is ignored; presumably it exists to match
    the loss signature that ``fedflax.train`` calls — TODO confirm.
    ``train`` is forwarded to the model as its ``train=`` flag.
    """
    logits = model(x_batch, z_batch, train=train)
    per_example_loss = optax.softmax_cross_entropy(logits, y_batch)
    return per_example_loss.mean()

# Data: all splits must use the same partitioning settings, so define them once.
GAZE_BETA = .5
GAZE_SKEW = "feature"
ds_train = get_gaze(beta=GAZE_BETA, skew=GAZE_SKEW)
ds_val = get_gaze(partition="val", beta=GAZE_BETA, batch_size=32, skew=GAZE_SKEW)
ds_test = get_gaze(partition="test", beta=GAZE_BETA, batch_size=16, skew=GAZE_SKEW)

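## Check metrics
Two techniques for breaking parameter symmetries are available: dimension expansion and W-Asymmetry. Only W-Asymmetry is implemented here, controlled by `pfix`, the fraction of non-frozen parameters. For each value of `pfix`, we train the clients for one round and average their parameters into a global model. We then record the global model's test error and the angle between the client updates and their mean update.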

In [None]:
def angle(updates):
    """Mean angle, in degrees, between each client's update and the average update.

    Args:
        updates: pytree whose array leaves are stacked per client along axis 0
            (presumably the per-client updates returned by ``train``).

    Returns:
        Python float: the mean over clients of the angle between a client's
        flattened update and the flattened average update. The result is NaN if
        an update (or the average) has zero norm, because the angle is undefined.

    Raises:
        ValueError: if ``updates`` contains no array leaves.
    """
    leaves = jax.tree.leaves(updates)
    if not leaves:
        raise ValueError("updates has no array leaves")
    n_clients = leaves[0].shape[0]
    # One row per client: every leaf is flattened to (n_clients, -1) and the
    # leaves are joined in jax.tree.leaves order, so columns line up across rows.
    per_client = jnp.concatenate(
        [jnp.reshape(leaf, (n_clients, -1)) for leaf in leaves], axis=1
    )
    # Averaging the flat rows equals flattening the per-leaf averages.
    mean_update = per_client.mean(axis=0)
    cos = (per_client @ mean_update) / (
        jnp.linalg.norm(per_client, axis=1) * jnp.linalg.norm(mean_update)
    )
    # Rounding can push the cosine slightly past +/-1, where arccos returns NaN.
    angles = jnp.degrees(jnp.arccos(jnp.clip(cos, -1.0, 1.0)))
    return float(angles.mean())

# Sweep the fraction of non-frozen parameters `pfix`. For each value, train all
# clients for one round from the same seeds, average them into a global model,
# and record its test error and the update angle.
pfixes = [0.5, 0.525, 0.55, 0.575, 0.6, 0.625, 0.65, 0.675, 0.7, 0.725, 0.75, 0.775, 0.8, 0.825, 0.85, 0.875, 0.9, 0.925, 0.95, 0.975, 1.]
# Store the x-values with the results so the cached pickle is self-describing.
metrics = {"angle": [], "err": [], "pfixes": pfixes}

# Accuracy of one shared model on a batch, vmapped over the leading axis of x, z, y
# (presumably one slice per client — confirm against get_gaze). Built once,
# outside the loop, so nnx.jit can reuse its compilation cache across pfix values.
acc_fn = nnx.jit(nnx.vmap(
    lambda m, x, z, y: (m(x, z, train=False).argmax(-1) == y.argmax(-1)).mean(),
    in_axes=(None, 0, 0, 0),
))

for pfix in pfixes:
    # Train
    model = ResNet(nnx.Rngs(42), pfix=pfix, mask_key=jax.random.key(43))
    updates, models = train(model, opt_create, ds_train, ds_val, ell, local_epochs=20, rounds=1)
    # Aggregate: average Params and BatchStats over the stacked client axis;
    # all remaining state is taken from the last client.
    struct, params, rest = nnx.split(models, (nnx.Param, nnx.BatchStat), ...)
    params = jax.tree.map(lambda p: p.mean(axis=0), params)
    rest = jax.tree.map(lambda r: r[-1], rest)
    model_g = nnx.merge(struct, params, rest)
    # Error of the aggregated model on the test split: batch accuracies averaged
    # over the len(ds_test) batches (each batch weighted equally).
    err = 1 - reduce(lambda acc, batch: acc + acc_fn(model_g, *batch), ds_test, 0.) / len(ds_test)
    metrics["err"].append(err)
    # Angle
    metrics["angle"].append(angle(updates))

with open("break_metrics.pkl", "wb") as f:
    pickle.dump(metrics, f)

## Visualize

In [None]:
# Rendering style
plt.style.use("seaborn-v0_8-pastel")
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Times"],
    "font.sans-serif": ["Helvetica"],
    "text.latex.preamble": r"""
        \usepackage{amsmath, amssymb}
        \usepackage{mathptmx} 
    """
})

# Load the cached sweep results (written by the training cell; only load trusted files).
with open("break_metrics.pkl", "rb") as f:
    metrics = pickle.load(f)
# Prefer the x-values stored with the results; fall back to the sweep cell's global.
x_pfix = metrics["pfixes"] if "pfixes" in metrics else pfixes
# There are three metrics with two scales
fig, ax1 = plt.subplots(dpi=300);
ax2 = ax1.twinx();
# Plot: mean error with a min-max band over the second axis, plus the angle
err = jnp.array(metrics["err"])
ax1.plot(x_pfix, err.mean(1), label="Error Rate", c="C1")
ax1.fill_between(x_pfix, err.max(axis=1), err.min(axis=1), alpha=0.3, color="C1")
ax2.plot(x_pfix, metrics["angle"], color="C0", label="Angle (°)")
# Labels ("\%" escapes the percent sign for LaTeX; raw string avoids an invalid Python escape)
ax1.set_xlim(.5, 1.001)
ax1.set_xlabel(r"\% of non-frozen parameters (x100)")
ax1.set_ylabel("Error Rate")
ax2.set_ylabel("Angle (°)")
fig.legend()
fig.savefig("break.png", bbox_inches="tight")