In [None]:
%load_ext autoreload
%autoreload 2
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
import jax, optax
from jax import numpy as jnp
from flax import nnx
from models import LeNet, ResNet
from fedflax import train, aggregate
from data import get_gaze
from functools import reduce

## Setup

In [None]:
# Optimizer factory
def opt_create(model):
    """Return an `nnx.Optimizer` for `model` that applies AdamW (lr=1e-3) to its `nnx.Param` state."""
    tx = optax.adamw(learning_rate=1e-3)
    return nnx.Optimizer(model, tx, wrt=nnx.Param)

# Loss: the softmax is applied inside the cross-entropy, so the model output is used as logits
def ell(model, _, x_batch, z_batch, y_batch, train):
    """Mean softmax cross-entropy of `model` on one batch.

    Args:
        model: callable invoked as `model(x_batch, z_batch, train=train)`.
        _: ignored; present only to fit the expected loss signature.
        x_batch, z_batch: the two model inputs for the batch.
        y_batch: targets handed to `optax.softmax_cross_entropy`
            (presumably one-hot / probability vectors — confirm against the data loader).
        train: forwarded to the model as its `train` keyword. Inside this function the
            name shadows the `train` function imported from `fedflax`.

    Returns:
        The cross-entropy averaged over all elements of the batch.
    """
    logits = model(x_batch, z_batch, train=train)
    per_example = optax.softmax_cross_entropy(logits, y_batch)
    return per_example.mean()

# Data: all three datasets share beta=0.5 and feature skew.
# ds_train relies on get_gaze's default partition and batch size; val/test set both explicitly.
ds_train = get_gaze(skew="feature", beta=0.5)
ds_val = get_gaze(skew="feature", beta=0.5, partition="val", batch_size=32)
ds_test = get_gaze(skew="feature", beta=0.5, partition="test", batch_size=16)

## Check metrics
Two techniques are available: dimension expansion and W-Asymmetry. Only W-Asymmetry is implemented here.

In [None]:
def angle(updates):
    """Mean angle, in degrees, between each client's update and the averaged update.

    Every leaf of `updates` is reshaped to `(n_clients, -1)`, where `n_clients` is the
    length of the leaf's leading axis, and the leaves are concatenated so that each
    client is represented by one flat vector. The global update is the mean of those
    vectors. The result is the average, over clients, of the angle between a client's
    vector and the global vector.

    Args:
        updates: pytree of arrays; all leaves must share the same leading-axis length
            (one entry per client).

    Returns:
        float: mean client-to-global angle in degrees.

    Raises:
        ValueError: if `updates` contains no array leaves.
    """
    # Work on the leaves so any pytree container (dict, list, State, ...) is handled
    leaves = jax.tree.leaves(updates)
    if not leaves:
        raise ValueError("angle() requires at least one array leaf in `updates`")
    # Read the client count from the data instead of hard-coding it
    n_clients = leaves[0].shape[0]
    updates_flat = jnp.concatenate(
        [jnp.reshape(leaf, (n_clients, -1)) for leaf in leaves], axis=1
    )
    # Mean over clients of the flat vectors == flattened mean of every leaf
    update_g = jnp.mean(updates_flat, axis=0)
    angles = []
    for update in updates_flat:
        cos = optax.losses.cosine_similarity(update_g, update)
        # Rounding can push the cosine marginally outside [-1, 1], where arccos is NaN
        angles.append(jnp.degrees(jnp.arccos(jnp.clip(cos, -1., 1.))))
    # Aggregate over all clients rather than keeping only the last one
    return jnp.mean(jnp.stack(angles)).item()

# One list entry per `pfix` setting.
# NOTE(review): only the two error lists are appended to in this cell; the angle list stays
# empty and `updates` is never read — confirm whether `angle(updates)` should be recorded.
# NOTE(review): the angle key looks like a mis-encoded degree sign; left unchanged because
# other cells may index by this exact string.
metrics = {"Angle (Â°)": [], "Glob. err.": [], "Avg. loc. err.": []}
# NOTE(review): `pfixes` is not used by the loop below, which iterates a literal list instead.
pfixes = jnp.log(jnp.linspace(jnp.exp(0.5), jnp.exp(1.), 10))
for pfix in [
    0.5, 0.525, 0.55, 0.575, 0.6, 0.625, 0.65,
    0.675, 0.7, 0.725, 0.75, 0.775, 0.8, 0.825,
    0.85, 0.875, 0.9, 0.925, 0.95, 0.975, 1.,
]:
    # Train a fresh LeNet for this setting; seeds are fixed so only `pfix` varies
    model = LeNet(nnx.Rngs(42), pfix=pfix, mask_key=jax.random.key(43))
    updates, models = train(
        model, opt_create, ds_train, ds_val, ell,
        local_epochs=20, rounds=1,
    )

    # Aggregate: average every parameter over its leading axis (presumably the client axis)
    struct, params, rest = nnx.split(models, nnx.Param, ...)
    params = jax.tree.map(lambda leaf: leaf.mean(axis=0), params)
    rest.mask_key = rest.mask_key[0] # TODO: temporary fix
    model_g = nnx.merge(struct, params, rest)

    # Local error: accuracy is computed separately along axis 0 of each batch element
    # (the aggregated model is shared), summed over batches and divided by the batch count
    acc_fn = nnx.jit(nnx.vmap(
        lambda net, imgs, aux, labels: (net(imgs, aux, train=False).argmax(-1) == labels.argmax(-1)).mean(),
        in_axes=(None, 0, 0, 0),
    ))
    err_l = 1 - sum((acc_fn(model_g, *batch) for batch in ds_test), 0.) / len(ds_test)
    metrics["Avg. loc. err."].append(err_l)

    # Global error: the leading axes are merged by the reshapes before scoring.
    # Despite its name, `accs_g` holds an error rate (1 - mean accuracy).
    acc_fn = nnx.jit(
        lambda net, imgs, aux, labels: (
            net(imgs.reshape(-1,36,60,1), aux.reshape(-1,3), train=False).argmax(-1)
            == labels.reshape(-1,16).argmax(-1)
        ).mean()
    )
    accs_g = 1 - sum((acc_fn(model_g, *b) for b in ds_test), 0.) / len(ds_test)
    metrics["Glob. err."].append(accs_g)

  0%|          | 0/20 [00:00<?, ?it/s]