In [None]:
%load_ext autoreload
%autoreload 2
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
import jax, optax, pickle, argparse
from jax import numpy as jnp
from flax import nnx
from models import LeNet, ResNet
from fedflax import train, aggregate
from data import get_gaze
from functools import reduce, partial
from matplotlib import pyplot as plt
from utils import return_l2, angle_err, opt_create

## Setup
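Load the gaze datasets and define the sweep of settings to evaluate. All settings are hard-coded in the cell below; no command-line parameters are parsed in this notebook.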

In [None]:
# Data: the three partitions share the same heterogeneity settings
gaze_kwargs = dict(beta=1., skew="feature", discrete=False)
ds_train = get_gaze(**gaze_kwargs)
ds_val = get_gaze(partition="val", batch_size=32, **gaze_kwargs)
ds_test = get_gaze(partition="test", batch_size=16, **gaze_kwargs)

# Value grids from which the sweep below is generated
KAPPAS = (5., 1., .1, 1e-2, 0.)
SIGMAS = (1e-4, 1e-3, 5e-3, 1e-2, 5e-2)
WASYMS = ("densest", "random")
# An empty dict leaves "model_class" unset; a ResNet entry selects it explicitly
ARCHS = ({}, {"model_class":ResNet})
DECAYS = ({}, {"weight_decay":2e-3})

# All various settings to find conditions in which the asymmetry helps.
# Every entry is a freshly built dict, so entries never share state.
params_sweep:tuple[dict,...] = (
    # Regular
    *({**arch} for arch in ARCHS),

    # W-Asymmetry: architecture (outer) x selection strategy x kappa (inner)
    *(
        {**arch, "wasym":wasym, "kappa":kappa}
        for arch in ARCHS
        for wasym in WASYMS
        for kappa in KAPPAS
    ),

    # SyRe: weight decay (outer) x architecture x sigma (inner)
    *(
        {**arch, "sigma":sigma, **decay}
        for decay in DECAYS
        for arch in ARCHS
        for sigma in SIGMAS
    ),
)

## Check metrics
Available techniques are dimension expansion and W-Asymmetry, of which the latter is implemented here.

In [None]:
def angle(updates):
    """Mean angle, in degrees, between each client's update and the averaged update.

    Args:
        updates: sequence of arrays (the leaves of the update pytree). Axis 0 of
            every array indexes the client; all remaining axes are flattened.

    Returns:
        A Python float: the per-client angle to the client-averaged update,
        averaged over clients.

    Raises:
        ValueError: if `updates` contains no arrays.
    """
    leaves = [jnp.asarray(leaf) for leaf in updates]
    if not leaves:
        raise ValueError("angle() needs at least one update array")
    # One row per client. The client count is read from the data instead of being fixed at 4.
    updates_flat = jnp.concatenate(
        [jnp.reshape(leaf, (leaf.shape[0], -1)) for leaf in leaves], axis=1
    )
    # Averaged ("global") update; same element order as the per-client rows
    update_g = jnp.mean(updates_flat, axis=0)
    # Cosine similarity of every client row with the averaged update
    dots = updates_flat @ update_g
    norms = jnp.linalg.norm(updates_flat, axis=1) * jnp.linalg.norm(update_g)
    # Clip: round-off can push the cosine marginally outside [-1, 1], where arccos returns NaN
    cosines = jnp.clip(dots / norms, -1., 1.)
    # Average over all clients (previously only the last client's angle was returned)
    return jnp.mean(jnp.degrees(jnp.arccos(cosines))).item()

# Run every configuration of the sweep: train, aggregate, evaluate, and record metrics.
metrics = {"angle":[], "test_err":[], "val_err":[]}
# Error function mapped over axis 0 of each of the three batch elements; the model is shared.
# Built once, since it does not depend on the configuration.
vval_fn = nnx.jit(nnx.vmap(angle_err, in_axes=(None,0,0,0)))
for config in params_sweep:
    # Work on a copy: the pops below would otherwise strip "model_class" and "weight_decay"
    # from params_sweep itself, so re-running this cell would silently train other settings.
    params = dict(config)
    # Train
    # Bound to a new name so the imported `opt_create` is not overwritten
    opt_fn = partial(opt_create, learning_rate=1e-3, weight_decay=params.pop("weight_decay", 1e-4))
    model_cls = params.pop("model_class", LeNet)
    # Whatever remains in `params` is forwarded to the model constructor
    model = model_cls(jax.random.key(42), **params, dim_out=2)
    updates, _ = train(model, opt_fn, ds_train, ds_val, return_l2(0.), local_epochs=20, rounds=1, val_fn=angle_err)
    # Aggregate
    model_g = aggregate(model, updates)
    # Average error of the aggregated model over the batches of each split
    err_test = reduce(lambda e, batch: e + vval_fn(model_g,*batch), ds_test, 0.) / len(ds_test)
    err_val = reduce(lambda e, batch: e + vval_fn(model_g,*batch), ds_val, 0.) / len(ds_val)
    metrics["test_err"].append(err_test)
    metrics["val_err"].append(err_val)
    # Angle
    metrics["angle"].append(angle(jax.tree.leaves(updates)))

# Persist the metrics of the whole sweep once, after the loop.
# NOTE(review): the file name is derived from what is left of the LAST configuration's params.
os.makedirs("break", exist_ok=True)
with open(f"break/break_metrics_{str(params)}.pkl", "wb") as f:
    pickle.dump(metrics, f)

## Visualize

In [None]:
# Rendering style
plt.style.use("seaborn-v0_8-pastel")
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Times"],
    "font.sans-serif": ["Helvetica"],
    "text.latex.preamble": r"""
        \usepackage{amsmath, amssymb}
        \usepackage{mathptmx} 
    """
})

# Reload the metrics written by the sweep cell; the path expression matches the one used there
# and relies on `params` still holding the value left by that cell.
# The file is produced by this notebook itself; do not point this at untrusted pickles.
with open(f"break/break_metrics_{str(params)}.pkl", "rb") as f:
    metrics = pickle.load(f)

# x-axis: position of each configuration within the sweep
sweep_idx = list(range(len(metrics["angle"])))

# Scatter angles
fig, ax1 = plt.subplots(dpi=300)
ax1.scatter(sweep_idx, metrics["angle"], color="C0", label="Client Drift", marker="D", s=20)
ax1.plot(sweep_idx, metrics["angle"], color="C0", linewidth=1)
# Degree sign written in LaTeX because text.usetex is enabled
ax1.set_ylabel(r"Client Drift ($^\circ$)")

# New axis for the error band
# assumes every "test_err" entry is a 1-D array with the same length, giving (configs, clients) — TODO confirm
err = jnp.array(metrics["test_err"])
ax2 = ax1.twinx()
ax2.scatter(sweep_idx, err.mean(axis=1), color="C1", label="Error", marker="s", s=20)
ax2.plot(sweep_idx, err.mean(axis=1), color="C1", linewidth=1)
# Shaded band spans the minimum and maximum error within each configuration
ax2.fill_between(sweep_idx, err.min(1), err.max(1), color="C1", alpha=0.5, step="mid", linewidth=0)
ax2.set_ylabel("Error")

# Final touches
ax1.set_xlim(min(sweep_idx), max(sweep_idx))
ax1.set_xticks(sweep_idx[::3])
ax1.set_xlabel("Sweep configuration index")
fig.legend()
fig.savefig(f"break/break_{str(params)}.png", bbox_inches="tight")