In [1]:
%load_ext autoreload
%autoreload 2
import os
# Ask XLA for deterministic GPU ops. The flag is written to the environment
# before jax is imported on the next line, so keep this ordering.
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
import jax, json, matplotlib as mpl, optax
from jax import numpy as jnp
from flax import nnx
from models import LeNet
from fedflax import train, aggregate, get_updates, cast
from data import fetch_data
from functools import reduce
from matplotlib import pyplot as plt
from utils import return_l2, angle_err, functional_drift, nnx_norm
from collections import defaultdict
from scipy import stats
# Number of federated clients; forwarded as `n_clients` to fetch_data, train,
# nnx_norm and cast in the cells below.
n_clients = 4

## Setup

In [None]:
# Data: all three partitions share the same beta, skew type and client count
_split_kwargs = dict(beta=1., skew="feature", n_clients=n_clients)
ds_train = fetch_data(**_split_kwargs)
ds_val = fetch_data(partition="val", batch_size=32, **_split_kwargs)
ds_test = fetch_data(partition="test", batch_size=16, **_split_kwargs)

# Model settings to sweep, in the order they are run, to find conditions in
# which the asymmetry helps. The empty dict is the setting without any overrides.
# NOTE: str(hypers) is used for result keys and file names in later cells, so
# key order and values must not be altered.
hypers_sweep: tuple[dict, ...] = (
    dict(),
    dict(wasym=True, kappa=1.),
    dict(sigma=1e-4),
    dict(normweights=True),
    dict(dimexp=2),
    dict(wasym=True, kappa=.1),
    dict(sigma=1e-3),
    dict(dimexp=4),
    dict(wasym=True, kappa=5.),
    dict(sigma=5e-2),
    dict(dimexp=6),
)

# Perform a statistical test to assess whether regular round is worse than asym round
def t_test(grid, mu_null=0., two_tailed:bool=True):
    """One-sample t-test on (regular-round gain - asym-round gain) over a result grid.

    Args:
        grid: 2-D array of errors. Stepping along axis 0 is treated as a step in
            the direction of regular rounds, stepping along axis 1 as a step in
            the direction of asym rounds. NaN cells propagate into the
            differences they take part in and are then ignored.
        mu_null: Mean difference under the null hypothesis.
        two_tailed: If True the tail probability is doubled; if False the
            single tail probability is returned.

    Returns:
        Dict with keys "p-value", "df" (number of usable differences minus one)
        and "t-statistic".
    """
    reg_gain = grid[:-1,:] - grid[1:,:] # i.e., error gain of step in direction of regular rounds
    asym_gain = grid[:, :-1] - grid[:, 1:] # i.e., error gain of step in direction of asym rounds
    diffs = reg_gain[:,:-1] - asym_gain[:-1,:] # i.e., difference between gain of reg round and asym round
    diffs = jnp.fill_diagonal(diffs, jnp.nan, inplace=False) # otherwise would skew mu

    # NaN-aware statistics: only the differences that are actually available count
    mu = jnp.nanmean(diffs)
    std = jnp.nanstd(diffs, ddof=1)
    n = jnp.sum(~jnp.isnan(diffs))
    t = (mu - mu_null) / std * jnp.sqrt(n)
    # Use the survival function rather than 1 - cdf: for large |t| the cdf rounds
    # to exactly 1.0 and the subtraction would report a p-value of exactly 0.
    # NOTE(review): with two_tailed=False the p-value is still based on |t|, i.e.
    # it ignores the sign of the effect -- confirm that this is intended.
    p = (int(two_tailed)+1)*stats.t.sf(jnp.abs(t), df=n-1)

    return {"p-value": p, "df": n-1, "t-statistic": t}

## Hyperparameter selection

In [None]:
# For every setting in hypers_sweep and 10 seeds: train one round of local models,
# average them into a global state, and record how far the local models are from
# that average (L1/L2) together with their test angle error.

# The output directory must exist before the first intermediate save
os.makedirs("break", exist_ok=True)
# Evaluator over the stacked local models; built once so every seed reuses the
# same jitted wrapper instead of creating a new one per iteration
vval_fn = nnx.jit(nnx.vmap(angle_err))
e = defaultdict(lambda: defaultdict(list))
for hypers in hypers_sweep:
    hypers_key = str(hypers)
    # In case of syre, add wd
    if "sigma" in hypers:
        ell = lambda m, mg, y, x: return_l2(0.)(m, mg, y, x) + 1e-4*nnx_norm(nnx.state(m, nnx.Param), n_clients=n_clients).mean()
    else: ell = return_l2(0.)
    # Provide bounds on results
    for seed in range(10):
        # Train
        model_init = LeNet(jax.random.key(seed), **hypers)
        opt = nnx.Optimizer(model_init, optax.adam(1e-3), wrt=nnx.Param)
        models, _ = train(model_init, opt, ds_train, ell, ds_val, local_epochs="early", rounds=1, val_fn=angle_err, max_patience=5, n_clients=n_clients)
        # Aggregate: mean over the leading axis of every leaf
        # (presumably the client axis -- confirm against fedflax.train)
        state_l = nnx.state(models, (nnx.Param, nnx.BatchStat))
        state_g = jax.tree.map(lambda p: p.mean(0), state_l)

        # Calculate L1 and L2 distance from global model
        e[hypers_key]["l2"].append(nnx_norm(state_l, state_g, order=2., n_clients=n_clients).mean().item())
        e[hypers_key]["l1"].append(nnx_norm(state_l, state_g, order=1., n_clients=n_clients).mean().item())
        # Sum the per-client error over all test batches, then average over clients and batches
        err_sum = reduce(lambda acc, batch: acc + vval_fn(models, *batch), ds_test, jnp.zeros(n_clients))
        e[hypers_key]["angle_err"].append(err_sum.mean().item() / len(ds_test))

        # Save after every seed so partial results survive an interrupted run;
        # the context manager closes (and flushes) the file each time
        with open("break/asym_hypers.json", "w") as f:
            json.dump(e, f)

## Visualize

In [None]:
# Rendering style
plt.style.use("seaborn-v0_8-pastel")

# LaTeX preamble handed to matplotlib, spelled out line by line so that the
# embedded whitespace is explicit
latex_preamble = (
    "\n"
    r"        \usepackage{amsmath, amssymb}" "\n"
    r"        \usepackage{mathptmx} " "\n"
    "    "
)

# Render all text through LaTeX with serif fonts
style_overrides = {
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Times"],
    "font.sans-serif": ["Helvetica"],
    "text.latex.preamble": latex_preamble,
}
plt.rcParams.update(style_overrides)

def plot_as_grid(data:dict[int, dict[int, float]], filename:str, vmin:float=3.66, vmax:float=15.5):
    """Draw a nested dict of values as an annotated heatmap and save it as a PNG.

    Args:
        data: Mapping from row key (number of regular rounds) to a mapping from
            column key (number of asym rounds) to the value of that cell.
            Rows keep the insertion order of `data`; columns are the union of
            all inner keys, sorted by their integer value.
        filename: Output path; ".png" is appended unless already present.
        vmin: Lower bound of the logarithmic colour scale.
        vmax: Upper bound of the logarithmic colour scale.

    Returns:
        The float32 grid of shape (rows, columns), with NaN for missing cells.

    Raises:
        ValueError: If `data` is empty.
    """
    if not data:
        raise ValueError("plot_as_grid requires at least one row of data")
    fig, ax = plt.subplots(dpi=500)
    # Determine grid size; the union of keys makes sure no column is dropped
    # when rows do not all carry the same set of keys
    y_ticks = list(data.keys())
    x_ticks = sorted(set().union(*(row.keys() for row in data.values())), key=int)
    # Populate grid in one go, NaN where a cell is missing
    rows = [[data[y].get(x, jnp.nan) for x in x_ticks] for y in y_ticks]
    grid = jnp.asarray(rows, dtype=jnp.float32).reshape(len(y_ticks), len(x_ticks))
    # Write each available value into its cell
    for i, row in enumerate(rows):
        for j, value in enumerate(row):
            if not jnp.isnan(value):
                ax.text(j, i, f"{value:.2f}", ha="center", va="center", color="w")

    # Plot grid; NaN cells are drawn in light gray
    cmap = plt.get_cmap("inferno").with_extremes(bad="lightgray")
    ax.imshow(grid, cmap=cmap, norm=mpl.colors.LogNorm(vmin=vmin, vmax=vmax), interpolation="nearest")
    ax.set_xticks(range(len(x_ticks)), labels=[str(x) for x in x_ticks])
    ax.set_yticks(range(len(y_ticks)), labels=[str(y) for y in y_ticks])
    ax.set_ylabel("No. of conventional communication rounds performed in advance")
    ax.set_xlabel("No. of communication rounds subsequently\nperformed with asymmetry applied")
    fig.savefig(filename.removesuffix(".png")+".png", bbox_inches="tight")
    return grid

## Check metrics
Available techniques include dimension expansion and W-Asymmetry; the loop below evaluates every configuration in `hypers_sweep`, which also covers the `sigma` and `normweights` settings.

In [None]:
# For every asymmetry setting, fill a grid of test metrics indexed by
# (regular rounds performed first, asym rounds performed afterwards),
# saving intermediate results after every grid cell and plotting at the end.

def _dump_json(obj, path):
    """Write `obj` as JSON to `path` and close the file handle afterwards."""
    with open(path, "w") as f:
        json.dump(obj, f)

# The output directory must exist before the first intermediate save
os.makedirs("break", exist_ok=True)
# Jitted evaluators, built once: a wrapper created anew inside the loops cannot
# reuse anything cached by the wrapper of the previous grid cell.
vval_global_fn = nnx.jit(nnx.vmap(angle_err, in_axes=(None,0,0,0))) # one shared model, per-client batches
vval_fn = nnx.jit(nnx.vmap(angle_err)) # one local model per client

for hypers in hypers_sweep:
    # Track metrics
    # NOTE: str(hypers) is embedded verbatim in the output file names
    hypers_as_str = str(hypers)
    test_errs_global = defaultdict(dict)
    test_errs_local = defaultdict(dict)
    drift = defaultdict(dict)
    # In case of syre, add wd
    if "sigma" in hypers:
        ell = lambda m, mg, y, x: return_l2(0.)(m, mg, y, x) + 1e-4*nnx_norm(nnx.state(m, nnx.Param), n_clients=n_clients).mean()
    else: ell = return_l2(0.)
    # Initialize regular model
    reg_model = LeNet(jax.random.key(42))
    # Rounds added per grid step: [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5]. Training is
    # cumulative, so the running sums are used as the grid coordinates.
    regular_roundses = [(i+1)//2 for i in range(11)]
    asym_roundses = [(i+1)//2 for i in range(11)]
    for j in range(len(regular_roundses)):
        n_regular = sum(regular_roundses[:j+1])
        # Federated training using regular model (skipped for the first row, which adds 0 rounds)
        if regular_roundses[j]:
            opt = nnx.Optimizer(reg_model, optax.adam(1e-3), wrt=nnx.Param)
            reg_models, _ = train(reg_model, opt, ds_train, ell, ds_val, local_epochs="early", rounds=regular_roundses[j], val_fn=angle_err, max_patience=2, n_clients=n_clients)
            reg_model = aggregate(reg_model, get_updates(reg_model, reg_models))
        # Finalization of training using asymmetric model
        for i in range(len(asym_roundses)):
            n_asym = sum(asym_roundses[:i+1])
            if i==1:
                # Transfer parameters of the trained regular model to an asymmetric architecture
                # NOTE(review): LeNet is built here without the PRNG key that the other
                # LeNet(...) calls pass as first argument -- confirm LeNet has a default key.
                _, fc1reg, reg_params, _ = nnx.split(reg_model, nnx.All(nnx.PathContains("fc1"), (nnx.Param, nnx.BatchStat)), (nnx.Param, nnx.BatchStat), ...)
                asym_struct, fc1asym, _, rest = nnx.split(LeNet(**hypers), nnx.All(nnx.PathContains("fc1"), (nnx.Param, nnx.BatchStat)), (nnx.Param, nnx.BatchStat), ...)
                # With dimexp the freshly built fc1 state is kept instead of the trained one
                fc1 = fc1asym if "dimexp" in hypers else fc1reg
                asym_model = nnx.merge(asym_struct, fc1, reg_params, rest)
            if i==0:
                # First column adds no asym rounds: evaluate the regular models as they are
                asym_model = reg_model
                asym_models = cast(reg_model, n_clients) if j==0 else reg_models
            else:
                # Train asymmetric
                opt = nnx.Optimizer(asym_model, optax.adam(1e-3), wrt=nnx.Param)
                asym_models, _ = train(asym_model, opt, ds_train, ell, ds_val, local_epochs="early", rounds=asym_roundses[i], val_fn=angle_err, max_patience=2, n_clients=n_clients)
                # Aggregate final models (the model passed as first arguments is technically irrelevant)
                updates = get_updates(asym_model, asym_models)
                asym_model = aggregate(asym_model, updates)
            # Accuracy on each client's data of aggregated model
            err_test_global = reduce(lambda acc, batch: acc + vval_global_fn(asym_model, *batch), ds_test, 0.) / len(ds_test)
            test_errs_global[n_regular][n_asym] = err_test_global.mean().item()
            # Accuracy of local models
            err_test_local = reduce(lambda acc, batch: acc + vval_fn(asym_models, *batch), ds_test, 0.) / len(ds_test)
            test_errs_local[n_regular][n_asym] = err_test_local.mean().item()
            # Client drift measured in function space # TODO: doesn't make sense when training for multiple rounds
            drift[n_regular][n_asym] = functional_drift(asym_models, ds_test).mean().item()
            # Save intermediate results
            _dump_json(test_errs_global, f"break/test_errs_global_{hypers_as_str}.json")
            _dump_json(test_errs_local, f"break/test_errs_local_{hypers_as_str}.json")
            _dump_json(drift, f"break/drift_{hypers_as_str}.json")
    # Plot results for this asymmetry setting
    grid = plot_as_grid(test_errs_global, f"break/test_errs_global_{hypers_as_str}.png")
    print(t_test(grid))
    plot_as_grid(test_errs_local, f"break/test_errs_local_{hypers_as_str}.png")
    plot_as_grid(drift, f"break/drift_{hypers_as_str}.png")