In [None]:
%load_ext autoreload
%autoreload 2
import os
# Set before jax is imported below so the flag is in place when XLA initializes.
# Note: this overwrites any XLA_FLAGS already present in the environment.
# Requests deterministic GPU ops (presumably for run-to-run reproducibility of the results).
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
import jax, optax, json
from jax import numpy as jnp
from flax import nnx
from models import LeNet
from fedflax import train, aggregate, get_updates, cast
from data import get_data
from functools import reduce, partial
from matplotlib import pyplot as plt
from utils import return_l2, angle_err, opt_create, functional_drift
from collections import defaultdict
n_clients = 4  # number of federated clients; passed to get_data, train(n=...) and cast below

## Setup

In [None]:
# Data: all splits share the same skew settings (beta=1, feature skew, n_clients clients)
data_kwargs = dict(beta=1., skew="feature", n_clients=n_clients)
ds_train = get_data(**data_kwargs)  # default partition (presumably train) and default batch size
ds_val = get_data(partition="val", batch_size=32, **data_kwargs)
ds_test = get_data(partition="test", batch_size=16, **data_kwargs)

# Various settings to find conditions in which the asymmetry helps.
# "weight_decay" is consumed by the optimizer; every other key is forwarded to LeNet.
# Key order matters: str(hypers) is used in the result filenames.
hypers_sweep: tuple[dict, ...] = (
    {"wasym": "densest", "kappa": 1.},
    {"sigma": 1e-4},
    {"orderbias": True},
    {"normweights": True},
    {"wasym": "random", "kappa": 1.},
    {"sigma": 1e-4, "weight_decay": 2e-3},
    {"wasym": "densest", "kappa": 1e-2},
    {"sigma": 1e-3},
    {"wasym": "random", "kappa": 1e-2},
    {"sigma": 1e-3, "weight_decay": 2e-3},
)

## Visualize

In [None]:
# Rendering style: pastel seaborn colors, with all text typeset by LaTeX (Times via mathptmx)
plt.style.use("seaborn-v0_8-pastel")
latex_preamble = r"""
        \usepackage{amsmath, amssymb}
        \usepackage{mathptmx} 
    """
rc_overrides = {
    "text.usetex": True,  # requires a working LaTeX installation
    "font.family": "serif",
    "font.serif": ["Times"],
    "font.sans-serif": ["Helvetica"],
    "text.latex.preamble": latex_preamble,
}
plt.rcParams.update(rc_overrides)

def plot_as_grid(data: dict[int, dict[int, float]], filename: str):
    """Draw a nested metric dict as an annotated heatmap and save it to ``filename``.

    Rows follow the insertion order of the outer keys of ``data`` (conventional rounds);
    columns are the sorted union of all inner keys (asymmetric rounds) plus 0. Missing
    cells are drawn light gray and left unlabeled; every other cell is annotated with its
    value to two decimals.

    Args:
        data: Mapping ``n_regular_rounds -> {n_asym_rounds: value}``.
        filename: Path the figure is saved to; a missing parent directory is created.

    Returns:
        The created ``Figure``. It is not closed, so it still renders inline.

    Raises:
        ValueError: If ``data`` is empty.
    """
    if not data:
        raise ValueError("plot_as_grid: `data` is empty, nothing to plot")
    # Determine grid size. Use the union of all rows' keys (not just the longest row's)
    # so no value is silently dropped; column 0 is always included.
    y_ticks = list(data.keys())
    x_ticks = sorted(set().union(*(row.keys() for row in data.values()), {0}))
    # Populate the grid in one pass, NaN where a value is missing. Exactly one cell per
    # tick pair: an extra row/column would render as an unlabeled gray strip.
    grid = jnp.array([[data[y].get(x, jnp.nan) for x in x_ticks] for y in y_ticks])

    fig, ax = plt.subplots(dpi=500)
    cmap = plt.get_cmap("inferno").with_extremes(bad="lightgray")
    ax.imshow(grid, cmap=cmap, interpolation="nearest")
    # Annotate each populated cell with its value (formatted from the original float)
    for i, y in enumerate(y_ticks):
        for j, x in enumerate(x_ticks):
            value = data[y].get(x, jnp.nan)
            if not jnp.isnan(value):
                ax.text(j, i, f"{value:.2f}", ha="center", va="center", color="w")

    ax.set_xticks(range(len(x_ticks)), labels=[str(x) for x in x_ticks])
    ax.set_yticks(range(len(y_ticks)), labels=[str(y) for y in y_ticks])
    ax.set_ylabel("No. of conventional communication rounds performed in advance")
    ax.set_xlabel("No. of communication rounds subsequently\nperformed with asymmetry applied")
    fig.tight_layout()
    # savefig does not create directories itself
    dirname = os.path.dirname(filename)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    fig.savefig(filename)
    return fig

## Check metrics
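Two techniques for introducing asymmetry are available, dimension expansion and W-Asymmetry; only the latter is implemented here. For every setting in `hypers_sweep`, the model is first trained conventionally for a number of communication rounds (starting from the early-stopping count and decreasing to zero). It is then trained further with asymmetry applied. After each asymmetric round, the test error of the aggregated and local models and the functional client drift are recorded.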

In [None]:
# Sweep over asymmetry settings. For each setting:
# - Conventional FL runs for n_regular_rounds rounds: first the early-stopping count,
#   then fewer and fewer rounds down to 0.
# - The trained weights are then moved into an asymmetric LeNet.
# - That model is trained one round at a time for up to MAX_ASYM_ROUNDS rounds,
#   with evaluation after every round.
MAX_ASYM_ROUNDS = 25  # TODO: not exactly early stopping
RESULTS_DIR = "break"
os.makedirs(RESULTS_DIR, exist_ok=True)  # json/png writes below fail if the directory is missing

# Loop-invariant evaluation functions, built (and jitted) once instead of on every round.
# Aggregated model (broadcast via in_axes=None) evaluated on every client's test batch:
eval_global = nnx.jit(nnx.vmap(angle_err, in_axes=(None, 0, 0, 0)))
# Each client's local model evaluated on its own client's test batch:
eval_local = nnx.jit(nnx.vmap(angle_err))


def _mean_over_batches(eval_fn, model, dataset):
    """Average ``eval_fn(model, *batch)`` over all batches of ``dataset``."""
    return reduce(lambda total, batch: total + eval_fn(model, *batch), dataset, 0.) / len(dataset)


for hypers in hypers_sweep:
    # Track metrics
    hypers_as_str = str(hypers)
    # Pop from a copy: popping from `hypers` itself mutated hypers_sweep, so re-running
    # this cell silently dropped the configured weight decay and used the default instead
    model_hypers = dict(hypers)
    wd = model_hypers.pop("weight_decay", 1e-4)
    # New name instead of rebinding the imported opt_create (which re-wrapped it on every pass)
    opt_create_wd = partial(opt_create, learning_rate=1e-3, weight_decay=wd)
    test_errs_global = defaultdict(dict)
    test_errs_local = defaultdict(dict)
    drift = defaultdict(dict)
    metrics = {"test_errs_global": test_errs_global, "test_errs_local": test_errs_local, "drift": drift}
    # Number of rounds is initialized to number required for early stopping
    n_regular_rounds = "early"
    while n_regular_rounds != -1:
        # Federated training using regular model
        model_g = LeNet(jax.random.key(42))
        if n_regular_rounds != 0:
            # train() also returns the number of rounds it ran, which resolves "early" to an int
            models, n_regular_rounds = train(model_g, opt_create_wd, ds_train, return_l2(0.), ds_val, local_epochs="early", rounds=n_regular_rounds, val_fn=angle_err, max_patience=2, n=n_clients)
        else:
            models = cast(model_g, n_clients)
        # Transfer parameters of the trained regular model (mean over the leading, presumably
        # client, axis) to a new asymmetric model
        params_trained = jax.tree.map(lambda p: jnp.mean(p, axis=0), nnx.split(models, (nnx.Param, nnx.BatchStat), ...)[1])
        asym_struct, _, rest = nnx.split(LeNet(jax.random.key(42), **model_hypers), (nnx.Param, nnx.BatchStat), ...)  # key is irrelevant because params are overwritten
        model_g = nnx.merge(asym_struct, params_trained, rest)
        # Finalization of training using asymmetric model, one communication round at a time
        for n_asym_rounds in range(1, MAX_ASYM_ROUNDS + 1):
            # Discard train()'s round count. Assigning it to n_asym_rounds (as before) reset the
            # counter every round (presumably to 1), so the loop never terminated.
            models, _ = train(model_g, opt_create_wd, ds_train, return_l2(0.), ds_val, local_epochs="early", rounds=1, val_fn=angle_err, max_patience=2, n=n_clients)
            # Aggregate final models (the model passed as first arguments is technically irrelevant)
            updates = get_updates(model_g, models)
            model_g = aggregate(model_g, updates)
            # Test error of the aggregated model on each client's data
            test_errs_global[n_regular_rounds][n_asym_rounds] = _mean_over_batches(eval_global, model_g, ds_test).mean().item()
            # Test error of the local models
            test_errs_local[n_regular_rounds][n_asym_rounds] = _mean_over_batches(eval_local, models, ds_test).mean().item()
            # Client drift measured in function space
            drift[n_regular_rounds][n_asym_rounds] = functional_drift(models, ds_test).mean().item()
            # Save intermediate results; `with` closes (and flushes) each file immediately
            for name, metric in metrics.items():
                with open(f"{RESULTS_DIR}/{name}_{hypers_as_str}.json", "w") as f:
                    json.dump(metric, f)
        # Update number of communication rounds to be performed
        n_regular_rounds -= max(1, n_regular_rounds // 3)
    # Plot results for this asymmetry setting
    for name, metric in metrics.items():
        plot_as_grid(metric, f"{RESULTS_DIR}/{name}_{hypers_as_str}.png")