In [None]:
%load_ext autoreload
%autoreload 2
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
import jax, optax
from jax import numpy as jnp
from flax import nnx
from models import LeNet
from fedflax import train
from data import get_gaze

## Setup

In [None]:
def opt_create(model):
    """Optimizer factory handed to ``train``.

    Builds AdamW with a learning rate of 1e-3 (all other hyper-parameters at
    optax defaults) and restricts it to the model's ``nnx.Param`` variables.
    """
    tx = optax.adamw(learning_rate=1e-3)
    return nnx.Optimizer(model, tx, wrt=nnx.Param)

def ell(model, _, x_batch, z_batch, y_batch, train):
    """Mean softmax cross-entropy loss for a batch.

    The softmax is part of the loss (``optax.softmax_cross_entropy``), so the
    model is expected to return unnormalized logits.

    Args:
        model: network called as ``model(x_batch, z_batch, train=train)``.
        _: unused positional slot, kept so the signature matches what ``train``
            passes in (presumably; confirm against fedflax).
        x_batch, z_batch: the two model inputs.
        y_batch: targets for ``optax.softmax_cross_entropy``, which expects a
            per-class distribution. NOTE(review): confirm the dataset yields
            one-hot labels rather than integer class ids.
        train: forwarded to the model as ``train=``. Inside this function it
            shadows the imported ``train`` function.

    Returns:
        Scalar: the per-example cross-entropy averaged over the batch.
    """
    logits = model(x_batch, z_batch, train=train)
    per_example = optax.softmax_cross_entropy(logits, y_batch)
    return per_example.mean()

# Base model for the experiments below; nnx.Rngs(42) fixes the seed so construction is reproducible.
# NOTE(review): nnx modules are mutable objects — hand a copy to anything that may train it in place.
model = LeNet(nnx.Rngs(42))

## Train with dimension expansion

In [None]:
def interleave(img):
    """Upsample ``img`` 2x along its first two axes, keeping the original pixels on a sparse grid.

    Every entry is repeated twice along axis 0 and along axis 1. Then every
    even-indexed row and every even-indexed column is overwritten with 0.5, so each
    original pixel survives exactly once, at position (2*i + 1, 2*j + 1).

    Args:
        img: tensor exposing ``repeat_interleave(..., dim=...)`` (e.g. a ``torch.Tensor``).
            NOTE(review): assumes axes 0 and 1 are the spatial axes and a floating
            dtype (0.5 would be truncated for integer tensors); confirm against get_gaze.

    Returns:
        A new tensor, twice as long on axes 0 and 1. ``img`` itself is not modified.
    """
    doubled = img.repeat_interleave(2, dim=0)
    doubled = doubled.repeat_interleave(2, dim=1)
    # Blank out the padding lines; the two fills are independent, so their order does not matter.
    doubled[:, ::2] = .5
    doubled[::2] = .5
    return doubled

# Dimension expansion: every sample is passed through `interleave` (2x upsampling with
# constant 0.5 padding rows/columns) for both the training and validation data.
ds_train = get_gaze(beta=.5, transform=interleave)
ds_val = get_gaze(partition="val", transform=interleave, beta=.5, batch_size=32)

# Train a deep copy so the shared `model` keeps its initial weights even if `train` updates its
# argument in place. Re-running this cell, or running the next experiment, then starts from the
# same initialization instead of from already-trained weights.
updates, models = train(nnx.clone(model), opt_create, ds_train, ds_val, ell, local_epochs=10, rounds=1)

## Train with W-Asymmetry

In [None]:
# W-Asymmetry: untransformed data; the asymmetry is requested from `train` via p_mask=.75.
ds_train = get_gaze(beta=.5)
# NOTE(review): unlike the dimension-expansion run, ds_val here relies on get_gaze's default
# beta and batch_size — confirm the two experiments are meant to validate differently.
ds_val = get_gaze(partition="val")

# Train a deep copy so this run starts from the same initial weights as the other experiment,
# even if `train` updates its model argument in place.
updates, models = train(nnx.clone(model), opt_create, ds_train, ds_val, ell, local_epochs=10, rounds=1, p_mask=.75)

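## Check for decrease in client drift

For each client, compute the angle between its local update and the mean (global) update of the most recent training run; smaller angles indicate less client drift. Every training cell reassigns `updates`, so only the last run executed is inspected.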

In [None]:
def client_update_angles(client_updates):
    """Angle, in degrees, between each client's update and the mean (global) update.

    Args:
        client_updates: pytree of arrays whose leaves all share a leading client axis
            of the same length, e.g. the ``updates`` returned by ``train``.
            NOTE(review): confirm that layout against fedflax.

    Returns:
        List of Python floats, one angle per client, in client order.
    """
    # The global update is the client-axis mean of every leaf.
    global_update = jax.tree.map(lambda leaf: jnp.mean(leaf, axis=0), client_updates)
    # Flatten with jax.tree.leaves rather than iterating the container: iterating a dict or
    # nnx.State yields keys, not arrays. Both flattenings visit the leaves in the same order,
    # because tree.map preserves the tree structure.
    global_flat = jnp.concatenate([jnp.ravel(leaf) for leaf in jax.tree.leaves(global_update)])
    # One row per client; the client count is read from the data rather than hard-coded.
    clients_flat = jnp.concatenate(
        [jnp.reshape(leaf, (leaf.shape[0], -1)) for leaf in jax.tree.leaves(client_updates)],
        axis=1,
    )
    angles = []
    for client_flat in clients_flat:
        cos_sim = optax.losses.cosine_similarity(global_flat, client_flat)
        # Rounding can push |cos| marginally above 1, where arccos would return NaN.
        angles.append(jnp.degrees(jnp.arccos(jnp.clip(cos_sim, -1.0, 1.0))).item())
    return angles


for i, angle in enumerate(client_update_angles(updates)):
    print(f"Angle with client {i} to global update: {angle:.2f} degrees")