In [None]:
%load_ext autoreload
%autoreload 2
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
import jax, optax
from jax import numpy as jnp
from flax import nnx
from models import LeNet, WAsymLeNet, DimExpLeNet
from fedflax import train
from data import get_gaze

## Setup

In [None]:
# Optimizer factory passed to `train` below: AdamW over the model's nnx.Param
# variables.
def opt_create(model, learning_rate=1e-3):
    """Create the optimizer used for local client training.

    Args:
        model: the flax.nnx model whose parameters are optimised.
        learning_rate: AdamW learning rate. Defaults to 1e-3, the value that
            was previously hard-coded here, so existing callers are unaffected.

    Returns:
        An ``nnx.Optimizer`` wrapping ``optax.adamw`` and restricted to
        ``nnx.Param`` variables (``wrt=nnx.Param``).
    """
    return nnx.Optimizer(
        model,
        optax.adamw(learning_rate=learning_rate),
        wrt=nnx.Param,
    )

def ell(model, _, x_batch, z_batch, y_batch, train):
    """Mean softmax cross-entropy between the model output and ``y_batch``.

    ``optax.softmax_cross_entropy`` applies the softmax itself, so the model's
    raw (unnormalised) outputs are fed to it directly. The second positional
    argument is accepted for interface compatibility but ignored.

    Args:
        model: callable invoked as ``model(x_batch, z_batch, train=train)``.
        x_batch, z_batch: the two model inputs, forwarded unchanged.
        y_batch: targets in the form ``optax.softmax_cross_entropy`` expects
            (presumably one-hot / probability vectors; TODO confirm).
        train: forwarded to the model's ``train`` keyword.

    Returns:
        The cross-entropy averaged over all remaining (batch) elements.
    """
    logits = model(x_batch, z_batch, train=train)
    per_example_loss = optax.softmax_cross_entropy(logits, y_batch)
    return per_example_loss.mean()

# Baseline (unmodified) LeNet.
# NOTE(review): `model` is reassigned by the training cells below before it is
# used, so under Run All this instance is never trained.
model = LeNet(nnx.Rngs(42))

# Data-partition settings shared by the train and val loaders. They are defined
# once so that switching experiments cannot update one loader and not the
# other. The results log below has runs under both "OVERLAP SKEW" and
# "FEATURE SKEW".
GAZE_BETA = .5
GAZE_SKEW = "feature"
VAL_BATCH_SIZE = 32

ds_train = get_gaze(beta=GAZE_BETA, skew=GAZE_SKEW)
ds_val = get_gaze(partition="val", beta=GAZE_BETA, batch_size=VAL_BATCH_SIZE, skew=GAZE_SKEW)

## Train with dimension expansion

In [None]:
# Dimension-expansion LeNet variant, seeded the same way as the baseline.
model = DimExpLeNet(nnx.Rngs(42))

# One round (rounds=1) of 10 local epochs.
# NOTE(review): the W-Asym cell and every entry in the results log use
# local_epochs=20; confirm 10 is intended if these runs are to be compared.
# `updates` / `models` are overwritten by the next training cell.
updates, models = train(
    model,
    opt_create,
    ds_train,
    ds_val,
    ell,
    local_epochs=10,
    rounds=1,
)

## Train with W-Asymmetry

In [None]:
# W-asymmetry LeNet variant. Unlike the other models it takes a raw JAX PRNG
# key instead of nnx.Rngs. `pfix=.975` matches the last W-Asym feature-skew
# entry in the results log.
model = WAsymLeNet(jax.random.key(42), pfix=.975)

updates, models = train(
    model,
    opt_create,
    ds_train,
    ds_val,
    ell,
    local_epochs=20,
    rounds=1,
)

In [None]:
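# Results log.
# Row format: model, local epochs, pfix, beta: <validation metric reported by
# `train` (presumably accuracy; TODO confirm)>, ~<client-to-global update angle
# from the client-drift check below>.
# "(excl conv)" presumably means W-asymmetry was not applied to conv layers;
# TODO confirm.

# OVERLAP SKEW
#Wasym, 20 epochs, p0.75 (excl conv), beta0.5: 0.3691, ~2.7 degrees
#LeNet, 20 epochs, (p1.0), beta0.5: 0.4453125, ~34 degrees
#Wasym, 20 epochs, p0.85, beta0.5: 0.388671875, ~4 degrees
#Wasym, 20 epochs, p0.95, beta0.5: 0.37, ~5.5 degrees

# FEATURE SKEW
#Wasym, 20 epochs, p0.75, beta0.5: 0.2695, ~2.4 degrees
#Wasym, 20 epochs, p0.75 (excl conv), beta0.5: 0.38671875, ~2.5 degrees
#Wasym, 20 epochs, p0.85 (excl conv), beta0.5: 0.4296875, ~4.6 degrees
#Wasym, 20 epochs, p0.975 (excl conv), beta0.5: 0.517578125, ~9.4 degrees
#LeNet, 20 epochs, (p1.0), beta0.5: 0.4921875, ~41 degrees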

## Check for decrease in client drift

In [None]:
# Client drift: the angle between each client's update and the equal-weight
# mean of all client updates. Smaller angles mean the clients moved in more
# similar directions.
# `updates` comes from whichever training cell ran last. Its leaves are assumed
# to stack per-client updates along axis 0 (TODO confirm against fedflax.train).
update_leaves = jax.tree.leaves(updates)
# Read the client count from the data rather than hard-coding it.
n_clients = update_leaves[0].shape[0]
assert all(x.shape[0] == n_clients for x in update_leaves), "update leaves disagree on the client axis"
# One row per client: every parameter flattened and concatenated in leaf order.
updates_flat = jnp.concatenate([jnp.reshape(x, (n_clients, -1)) for x in update_leaves], axis=1)
# Mean over clients, flattened in the same order as the rows above.
update_g = jnp.mean(updates_flat, axis=0)
for i, update in enumerate(updates_flat):
    # Rounding can push the cosine just outside [-1, 1], where arccos returns
    # NaN; clip before converting to an angle.
    cos = jnp.clip(optax.losses.cosine_similarity(update_g, update), -1.0, 1.0)
    angle = jnp.degrees(jnp.arccos(cos)).item()
    print(f"Angle with client {i} to global update: {angle:.2f} degrees")