In [None]:
%load_ext autoreload
%autoreload 2
import os
import shutil
import warnings
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
import jax, optax, json
from jax import numpy as jnp
from flax import nnx
from models import LeNet
from fedflax import train, aggregate, get_updates, cast
from data import get_data
from functools import reduce, partial
from matplotlib import pyplot as plt
from utils import return_l2, angle_err, opt_create
from collections import defaultdict
n_clients: int = 4  # number of federated clients; passed to get_data, train and cast below

## Setup
The data partitions and the sweep of asymmetry settings are defined in the cell below; all values are hard-coded there.

In [None]:
# Data: feature-skewed, continuous data split across `n_clients` clients. The three splits share
# these settings; the training split uses get_data's default partition and batch size.
data_kwargs = dict(beta=1., skew="feature", discrete=False, n_clients=n_clients)
ds_train = get_data(**data_kwargs)
ds_val = get_data(partition="val", batch_size=32, **data_kwargs)
ds_test = get_data(partition="test", batch_size=16, **data_kwargs)

# Asymmetry settings to sweep, looking for conditions under which the asymmetry helps.
# Order and key order matter: str() of each dict ends up in output file names.
hypers_sweep: tuple[dict, ...] = (
    # W-Asymmetry: every combination of mask type and kappa
    *({"wasym": wasym, "kappa": kappa} for wasym in ("densest", "random") for kappa in (1., 1e-2)),
    # SyRe: each sigma once with the training loop's default weight decay, then with 2e-3
    *({"sigma": sigma} for sigma in (1e-4, 1e-3)),
    *({"sigma": sigma, "weight_decay": 2e-3} for sigma in (1e-4, 1e-3)),
    # Bias ordering
    {"orderbias": True},
    # Kernel normalization
    {"normweights": True},
)

## Plotting setup and helpers

In [None]:
# Rendering style: pastel seaborn palette with serif (Times) text.
# LaTeX text rendering (text.usetex) relies on external tools: a `latex` binary and, for raster
# output such as PNG, `dvipng`. If they are missing, matplotlib fails when drawing/saving, which
# would break a fresh Run All; so usetex is only enabled when both are on PATH.
plt.style.use("seaborn-v0_8-pastel")
use_tex = all(shutil.which(tool) is not None for tool in ("latex", "dvipng"))
if not use_tex:
    warnings.warn("latex/dvipng not found on PATH; rendering figure text without LaTeX.")
plt.rcParams.update({
    "text.usetex": use_tex,
    "font.family": "serif",
    "font.serif": ["Times"],
    "font.sans-serif": ["Helvetica"],
    # Only consulted when text.usetex is enabled.
    "text.latex.preamble": r"""
        \usepackage{amsmath, amssymb}
        \usepackage{mathptmx} 
    """
})

def plot_as_grid(data:dict[int:dict[int:float]], filename:str):
    fig, ax = plt.subplots()
    # Determine grid size
    y_ticks = list(data.keys())
    x_ticks = max([set(subdict.keys()) for subdict in data.values()], key=len)
    x_ticks = sorted(x_ticks.union({0}))  # in case data[0] is `y_ticks`, ensure 0 is included
    grid = jnp.full((len(y_ticks)+1, len(x_ticks)+1), jnp.nan)
    # Populate grid
    for i, n_asym_rounds in enumerate(y_ticks):
        for j, n_regular_rounds in enumerate(x_ticks):
            # Set value, or NaN if missing
            value = data[n_asym_rounds].get(n_regular_rounds, jnp.nan)
            grid = grid.at[i, j].set(value)
            # Simultaneously, plot the value as text
            text = f"{value:.2f}" if not jnp.isnan(value) else ""
            ax.text(j, i, text, ha="center", va="center", color="w")

    # Plot grid
    cmap = plt.get_cmap("inferno").with_extremes(bad="lightgray")
    ax.imshow(grid, cmap=cmap, interpolation="nearest")
    ax.set_xticks(x_ticks, labels=[str(x) for x in x_ticks])
    ax.set_yticks(y_ticks, labels=[str(y) for y in y_ticks])
    ax.set_xlabel("No. of conventional communication rounds performed in advance")
    ax.set_ylabel("No. of communication rounds subsequently\nperformed with asymmetry applied")
    fig.savefig(filename)
    return fig

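## Check metrics
Available techniques are dimension expansion and W-Asymmetry. Dimension expansion is not evaluated here; the sweep covers W-Asymmetry, together with SyRe, bias ordering and kernel normalization (see `hypers_sweep`). For each setting, clients are first trained conventionally for a number of communication rounds. The client average is then loaded into the asymmetric model and trained for further rounds. The test angle error of the aggregated (global) model and of the local models is recorded for each pair of round counts, with both counts starting from the early-stopping value and decreasing step by step.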

In [None]:
for hypers in hypers_sweep:
    # Track metrics
    hypers_as_str = str(hypers)
    wd = hypers.pop("weight_decay", 1e-4)
    test_errs_global = defaultdict(dict)
    test_errs_local = defaultdict(dict)
    # Number of rounds is initialized to number required for early stopping
    n_asym_rounds = "early"
    n_regular_rounds = "early"
    while pre:=(n_regular_rounds!=0):
        # Federated training using regular model
        opt_create = partial(opt_create, learning_rate=1e-3, weight_decay=wd)
        model_init = LeNet(jax.random.key(42))
        if pre:
            models, n_regular_rounds = train(model_init, opt_create, ds_train, return_l2(0.), ds_val, local_epochs="early", rounds=n_regular_rounds, val_fn=angle_err, max_patience=2, n=n_clients)
        else:
            models = cast(model_init, n_clients)
        # Transfer parameters of the trained regular model to a new asymmetric model
        if post:=(n_asym_rounds!=0):
            params_trained = jax.tree.map(lambda p: jnp.mean(p, axis=0), nnx.split(models, (nnx.Param, nnx.BatchStat), ...)[1])
            asym_model = LeNet(jax.random.key(42), **hypers) # key is irrelevant because params are overwritten
            asym_struct, _, rest = nnx.split(asym_model, (nnx.Param, nnx.BatchStat), ...)
            asym_model = nnx.merge(asym_struct, params_trained, rest)
            # Finalization of training using asymmetric model
            models, n_asym_rounds = train(asym_model, opt_create, ds_train, return_l2(0.), ds_val, local_epochs="early", rounds=n_asym_rounds, val_fn=angle_err, max_patience=2, n=n_clients)
        # Aggregate final models (the model passed as first arguments is technically irrelevant)
        updates = get_updates(asym_model, models)
        model_g = aggregate(asym_model, updates)
        # Accuracy on each client's data of aggregated model
        vval_fn = nnx.jit(nnx.vmap(angle_err, in_axes=(None,0,0,0)))
        err_test_global = reduce(lambda e, batch: e + vval_fn(model_g, *batch), ds_test, 0.) / len(ds_test)
        test_errs_global[n_regular_rounds][n_asym_rounds] = err_test_global.mean().item()
        # Accuracy of local models
        vval_fn = nnx.jit(nnx.vmap(angle_err))
        err_test_local = reduce(lambda e, batch: e + vval_fn(models, *batch), ds_test, 0.) / len(ds_test)
        test_errs_local[n_regular_rounds][n_asym_rounds] = err_test_local.mean().item()
        # Update number of communication rounds to be performed
        if not post:
            n_regular_rounds -= max(1, n_regular_rounds//3)
            n_asym_rounds = "early"
        else:
            n_asym_rounds -= max(1, n_asym_rounds//3)
    # Plot results for this asymmetry setting
    plot_as_grid(test_errs_global, f"break/test_errs_global_{hypers_as_str}.png")
    plot_as_grid(test_errs_local, f"break/test_errs_local_{hypers_as_str}.png")