In [None]:
%load_ext autoreload
%autoreload 2
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
import jax, optax
from jax import numpy as jnp
from flax import nnx
from models import LeNet
from fedflax import train, get_updates
from data import get_data
from matplotlib import pyplot as plt
from functools import partial, reduce
from utils import opt_create, return_l2, angle_err

## Set up metrics

In [None]:
# Matplotlib look-and-feel: pastel palette, LaTeX-rendered text, Times serif / Helvetica sans-serif
plt.style.use("seaborn-v0_8-pastel")
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.serif"] = ["Times"]
plt.rcParams["font.sans-serif"] = ["Helvetica"]
# LaTeX packages made available to every string rendered through usetex
plt.rcParams["text.latex.preamble"] = r"""
        \usepackage{amsmath, amssymb}
        \usepackage{mathptmx}  % Safe fallback for Times + math
    """

def functional_drift(models, ds_test):
    """Measure how far each model's test outputs deviate from the cross-model mean.

    Every batch of `ds_test` is passed (minus its first element) through all models
    at once via `nnx.vmap`, with `train=False`. The per-batch outputs are joined
    along axis 1, averaged over axis 0 to form a consensus output, and the absolute
    deviation from that consensus is averaged over axes 1 and 2.

    Args:
        models: stacked models as accepted by `nnx.vmap`
            (presumably one model per client along the leading axis -- TODO confirm).
        ds_test: iterable of batches; only `batch[1:]` is fed to the models
            (presumably `batch[0]` is the target -- TODO confirm against `get_data`).

    Returns:
        The mean absolute deviation reduced over axes 1 and 2, i.e. one value per
        entry of the leading axis.
    """
    def forward(model, *inputs):
        # Evaluation-mode forward pass of a single model
        return model(*inputs, train=False)

    batched_forward = nnx.vmap(forward)
    per_batch = [batched_forward(models, *batch[1:]) for batch in ds_test]
    outputs = jnp.concatenate(per_batch, axis=1)
    consensus = outputs.mean(0)
    return jnp.abs(outputs - consensus).mean((1, 2))

## Drift as a result of feature skew

In [None]:
# Get updates' drift at first communication round for various heterogeneity levels
maes = []
angles = []
funcdrift = []
# 20 levels from 0 to 1, evenly spaced under the exp(1 - beta) x-axis transform used below
betas = 1 - jnp.log(jnp.linspace(jnp.e, 1, 20))
for beta in betas.tolist():
    # Data at specified heterogeneity level
    ds_train = get_data(skew="feature", beta=beta, partition="train")
    ds_val = get_data(skew="feature", beta=beta, partition="val")
    ds_test = get_data(skew="feature", beta=beta, partition="test")
    # Train models for one round, always starting from the same initialisation
    model_init = LeNet(jax.random.key(42))
    models, *_ = train(model_init, partial(opt_create, learning_rate=1e-3), ds_train, return_l2(0.), ds_val, local_epochs="early", rounds=1, max_patience=2, val_fn=angle_err)
    # Get local updates, flatten every leaf to (4, -1) and join them into one row per client
    # NOTE(review): the 4 is hard-coded; presumably the number of clients -- confirm against get_data
    updates = get_updates(model_init, models)
    updates = jnp.concatenate([jnp.reshape(leaf, (4, -1)) for leaf in jax.tree.leaves(updates)], axis=1)
    update_g = updates.mean(0)
    # L1 distance to global update for each client's update
    mae = jnp.mean(jnp.abs(update_g - updates), axis=-1)
    maes.append(mae)
    # Angle between global and local update; the cosine is clipped because rounding can
    # push it marginally outside [-1, 1], where arccos returns NaN
    cosine = jnp.clip(optax.losses.cosine_similarity(update_g, updates), -1., 1.)
    angle = jnp.degrees(jnp.arccos(cosine))
    angles.append(angle)
    # Functional drift
    funcdrift.append(functional_drift(models, ds_test))

# Plot as function of beta
os.makedirs("drift/MPIIGaze", exist_ok=True)  # savefig does not create missing directories
fig, ax = plt.subplots(1, dpi=300);
ax2 = ax.twinx();
ax3 = ax.twinx();
ax3.spines.right.set_position(("outward", 60));  # offset so both right-hand scales stay readable
# Twin axes each restart the colour cycle, so colours are assigned explicitly to keep the curves apart
curves = [
    (ax, maes, "L1 distance with global update, averaged over clients", "L1 distance", "C0"),
    (ax2, angles, "Angle with global update, averaged over clients", "Angle (°)", "C1"),
    (ax3, funcdrift, "Functional drift, averaged over clients", "Functional drift", "C2"),
]
for axis, values, label, ylabel, color in curves:
    values = jnp.array(values)
    # Line: mean over clients; band: min-max range over clients
    axis.plot(betas, values.mean(1), label=label, color=color)
    axis.fill_between(betas, values.max(1), values.min(1), alpha=0.3, linewidth=0., color=color)
    axis.set_ylabel(ylabel)
# Details
# One legend covering all three axes, attached to the last-drawn axes so no curve is painted over it
ax3.legend(handles=[line for axis in (ax, ax2, ax3) for line in axis.get_lines()], loc="lower left");
ax.set_xlabel(r"Feature heterogeneity level ($\beta$)");
ax.set_ylim(0);
ax.set_xlim(left=-.001, right=1.001);
ax.set_xscale("function", functions=(lambda x: jnp.exp(1 - x), lambda x: 1 - jnp.log(x)));
ax.set_xticks(jnp.linspace(0, 1, 9));
ax.set_xticklabels([f"{x:.2f}" for x in jnp.linspace(0, 1, 9)]);
ax.grid(True, linestyle="--", linewidth=0.5);
fig.savefig("drift/MPIIGaze/featureskew.png", bbox_inches="tight");

## Ablation: Vary sample overlap

In [None]:
# Get updates of first communication round for various overlap levels
maes = []
angles = []
funcdrift = []
# 20 levels from 1 down to 0, evenly spaced under the exp(overlap) x-axis transform used below
overlaps = jnp.log(jnp.linspace(jnp.e, 1, 20)).tolist()
for overlap in overlaps:
    # Data at specified overlap level
    ds_train = get_data(skew="overlap", beta=overlap, partition="train")
    ds_val = get_data(skew="overlap", beta=overlap, partition="val")
    ds_test = get_data(skew="overlap", beta=overlap, partition="test")
    # Train models for one round, always starting from the same initialisation
    model_init = LeNet(jax.random.key(42))
    models, *_ = train(model_init, partial(opt_create, learning_rate=1e-3), ds_train, return_l2(0.), ds_val, local_epochs="early", rounds=1, max_patience=2, val_fn=angle_err)
    # Get local updates, flatten every leaf to (4, -1) and join them into one row per client
    # NOTE(review): the 4 is hard-coded; presumably the number of clients -- confirm against get_data
    updates = get_updates(model_init, models)
    updates = jnp.concatenate([jnp.reshape(leaf, (4, -1)) for leaf in jax.tree.leaves(updates)], axis=1)
    update_g = updates.mean(0)
    # L1 distance to global update for each client's update
    mae = jnp.mean(jnp.abs(update_g - updates), axis=-1)
    maes.append(mae)
    # Angle between global and local update; the cosine is clipped because rounding can
    # push it marginally outside [-1, 1], where arccos returns NaN
    cosine = jnp.clip(optax.losses.cosine_similarity(update_g, updates), -1., 1.)
    angle = jnp.degrees(jnp.arccos(cosine))
    angles.append(angle)
    # Functional drift
    funcdrift.append(functional_drift(models, ds_test))

# Plot as function of sample overlap. The x-values are this cell's own `overlaps`, so every
# point sits at the level it was measured at and the cell does not rely on a `betas`
# variable left behind by another cell.
os.makedirs("drift/MPIIGaze", exist_ok=True)  # savefig does not create missing directories
fig, ax = plt.subplots(1, dpi=300);
ax2 = ax.twinx();
ax3 = ax.twinx();
ax3.spines.right.set_position(("outward", 60));  # offset so both right-hand scales stay readable
# Twin axes each restart the colour cycle, so colours are assigned explicitly to keep the curves apart
curves = [
    (ax, maes, "L1 distance with global update, averaged over clients", "L1 distance", "C0"),
    (ax2, angles, "Angle with global update, averaged over clients", "Angle (°)", "C1"),
    (ax3, funcdrift, "Functional drift, averaged over clients", "Functional drift", "C2"),
]
for axis, values, label, ylabel, color in curves:
    values = jnp.array(values)
    # Line: mean over clients; band: min-max range over clients
    axis.plot(overlaps, values.mean(1), label=label, color=color)
    axis.fill_between(overlaps, values.max(1), values.min(1), alpha=0.3, linewidth=0., color=color)
    axis.set_ylabel(ylabel)
# Details
# One legend covering all three axes, attached to the last-drawn axes so no curve is painted over it
ax3.legend(handles=[line for axis in (ax, ax2, ax3) for line in axis.get_lines()], loc="lower left");
ax.set_xlabel(r"Sample overlap (\%)");  # raw string: "\%" is not a valid Python escape sequence
ax.set_ylim(0);
ax.set_xlim(-.001,1.001);
ax.set_xscale("function", functions=(jnp.exp, jnp.log));
ax.set_xticks(jnp.linspace(0, 1, 9));
ax.set_xticklabels(jnp.linspace(0, 100, 9, dtype=jnp.int32));
ax.grid(True, linestyle="--", linewidth=0.5);
fig.savefig("drift/MPIIGaze/overlap.png", bbox_inches="tight");

## Ablation: Vary label skew

In [None]:
# Get updates of first communication round for various label skew levels
maes = []
angles = []
funcdrift = []
# 20 levels from 0 to 1, evenly spaced under the exp(1 - beta) x-axis transform used below
betas = 1 - jnp.log(jnp.linspace(jnp.e, 1, 20))
for beta in betas.tolist():
    # Data at specified heterogeneity level
    ds_train = get_data(skew="label", beta=beta, partition="train")
    ds_val = get_data(skew="label", beta=beta, partition="val")
    ds_test = get_data(skew="label", beta=beta, partition="test")
    # Train models for one round, always starting from the same initialisation
    model_init = LeNet(jax.random.key(42))
    models, *_ = train(model_init, partial(opt_create, learning_rate=1e-3), ds_train, return_l2(0.), ds_val, local_epochs="early", rounds=1, max_patience=2, val_fn=angle_err)
    # Get local updates, flatten every leaf to (4, -1) and join them into one row per client
    # NOTE(review): the 4 is hard-coded; presumably the number of clients -- confirm against get_data
    updates = get_updates(model_init, models)
    updates = jnp.concatenate([jnp.reshape(leaf, (4, -1)) for leaf in jax.tree.leaves(updates)], axis=1)
    update_g = updates.mean(0)
    # L1 distance to global update for each client's update
    mae = jnp.mean(jnp.abs(update_g - updates), axis=-1)
    maes.append(mae)
    # Angle between global and local update; the cosine is clipped because rounding can
    # push it marginally outside [-1, 1], where arccos returns NaN
    cosine = jnp.clip(optax.losses.cosine_similarity(update_g, updates), -1., 1.)
    angle = jnp.degrees(jnp.arccos(cosine))
    angles.append(angle)
    # Functional drift
    funcdrift.append(functional_drift(models, ds_test))

# Plot as function of beta
os.makedirs("drift/MPIIGaze", exist_ok=True)  # savefig does not create missing directories
fig, ax = plt.subplots(1, dpi=300);
ax2 = ax.twinx();
ax3 = ax.twinx();
ax3.spines.right.set_position(("outward", 60));  # offset so both right-hand scales stay readable
# Twin axes each restart the colour cycle, so colours are assigned explicitly to keep the curves apart
curves = [
    (ax, maes, "L1 distance with global update, averaged over clients", "L1 distance", "C0"),
    (ax2, angles, "Angle with global update, averaged over clients", "Angle (°)", "C1"),
    (ax3, funcdrift, "Functional drift, averaged over clients", "Functional drift", "C2"),
]
for axis, values, label, ylabel, color in curves:
    values = jnp.array(values)
    # Line: mean over clients; band: min-max range over clients
    axis.plot(betas, values.mean(1), label=label, color=color)
    axis.fill_between(betas, values.max(1), values.min(1), alpha=0.3, linewidth=0., color=color)
    axis.set_ylabel(ylabel)
# Details
# One legend covering all three axes, attached to the last-drawn axes so no curve is painted over it
ax3.legend(handles=[line for axis in (ax, ax2, ax3) for line in axis.get_lines()], loc="lower left");
ax.set_xlabel(r"Label heterogeneity level ($\beta$)");
ax.set_ylim(0);
ax.set_xlim(left=-.001, right=1.001);
ax.set_xscale("function", functions=(lambda x: jnp.exp(1 - x), lambda x: 1 - jnp.log(x)));
ax.set_xticks(jnp.linspace(0, 1, 9));
ax.set_xticklabels([f"{x:.2f}" for x in jnp.linspace(0, 1, 9)]);
ax.grid(True, linestyle="--", linewidth=0.5);
fig.savefig("drift/MPIIGaze/labelskew.png", bbox_inches="tight");

## Ablation: Vary local epochs

In [None]:
# Get updates of first communication round for various numbers of local epochs
maes = []
angles = []
funcdrift = []
# 20 integer epoch counts between 1 and 41
epochses = jnp.linspace(1, 41, 20, dtype=jnp.int32).tolist()
for epochs in epochses:
    # Data
    ds_train = get_data(partition="train", beta=.3) # Notice the fixed heterogeneity
    # Test split loaded here with the same settings, so the functional drift is not computed
    # on whatever `ds_test` an earlier cell happened to leave in the kernel
    ds_test = get_data(partition="test", beta=.3)
    # Train for a fixed number of local epochs, always starting from the same initialisation
    model_init = LeNet(jax.random.key(42))
    models, *_ = train(model_init, partial(opt_create, learning_rate=1e-3), ds_train, ell=return_l2(0.), local_epochs=epochs, rounds=1, max_patience=2)
    # Get local updates, flatten every leaf to (4, -1) and join them into one row per client
    # NOTE(review): the 4 is hard-coded; presumably the number of clients -- confirm against get_data
    updates = get_updates(model_init, models)
    updates = jnp.concatenate([jnp.reshape(leaf, (4, -1)) for leaf in jax.tree.leaves(updates)], axis=1)
    update_g = updates.mean(0)
    # L1 distance to global update for each client's update
    mae = jnp.mean(jnp.abs(update_g - updates), axis=-1)
    maes.append(mae)
    # Angle between global and local update; the cosine is clipped because rounding can
    # push it marginally outside [-1, 1], where arccos returns NaN
    cosine = jnp.clip(optax.losses.cosine_similarity(update_g, updates), -1., 1.)
    angle = jnp.degrees(jnp.arccos(cosine))
    angles.append(angle)
    # Functional drift
    funcdrift.append(functional_drift(models, ds_test))

# Plot as function of the number of local epochs. The x-values are this cell's own
# `epochses`; a leftover `betas` lies in [0, 1] and would fall outside the 1-41 x-range.
os.makedirs("drift/MPIIGaze", exist_ok=True)  # savefig does not create missing directories
fig, ax = plt.subplots(1, dpi=300);
ax2 = ax.twinx();
ax3 = ax.twinx();
ax3.spines.right.set_position(("outward", 60));  # offset so both right-hand scales stay readable
# Twin axes each restart the colour cycle, so colours are assigned explicitly to keep the curves apart
curves = [
    (ax, maes, "L1 distance with global update, averaged over clients", "L1 distance", "C0"),
    (ax2, angles, "Angle with global update, averaged over clients", "Angle (°)", "C1"),
    (ax3, funcdrift, "Functional drift, averaged over clients", "Functional drift", "C2"),
]
for axis, values, label, ylabel, color in curves:
    values = jnp.array(values)
    # Line: mean over clients; band: min-max range over clients
    axis.plot(epochses, values.mean(1), label=label, color=color)
    axis.fill_between(epochses, values.max(1), values.min(1), alpha=0.3, linewidth=0., color=color)
    axis.set_ylabel(ylabel)
# Details
# One legend covering all three axes, attached to the last-drawn axes so no curve is painted over it
ax3.legend(handles=[line for axis in (ax, ax2, ax3) for line in axis.get_lines()], loc="lower left");
ax.set_xlabel("Number of local epochs");
ax.set_xlim(1, 41);
ax.set_ylim(0);
ax.grid(True, linestyle="--", linewidth=0.5);
fig.savefig("drift/MPIIGaze/num_local_epochs.png", bbox_inches="tight");

## Ablation: Pre-train

In [None]:
# Get updates after several communication rounds
maes = []
angles = []
funcdrift = []
# Number of *additional* rounds trained at each step (difference of two integer linspaces)
roundses = jnp.linspace(1, 20, 10, dtype=jnp.int32) - jnp.linspace(0, 17, 10, dtype=jnp.int32)
model = LeNet(jax.random.key(42))
for rounds in roundses.tolist():
    # Data
    ds_train = get_data(partition="train", beta=.3) # Notice the fixed heterogeneity
    ds_val = get_data(partition="val", beta=.3)
    # Test split loaded here with the same settings, so the functional drift is not computed
    # on whatever `ds_test` an earlier cell happened to leave in the kernel
    ds_test = get_data(partition="test", beta=.3)
    # Train, continuing from the model carried over from the previous step
    models, *_ = train(model, partial(opt_create, learning_rate=1e-3), ds_train, return_l2(0.), ds_val, local_epochs="early", rounds=rounds, max_patience=2, val_fn=angle_err)
    # Get local updates, flatten every leaf to (4, -1) and join them into one row per client
    # NOTE(review): the 4 is hard-coded; presumably the number of clients -- confirm against get_data
    updates = get_updates(model, models)
    updates = jnp.concatenate([jnp.reshape(leaf, (4, -1)) for leaf in jax.tree.leaves(updates)], axis=1)
    update_g = updates.mean(0)
    # L1 distance to global update for each client's update
    mae = jnp.mean(jnp.abs(update_g - updates), axis=-1)
    maes.append(mae)
    # Angle between global and local update; the cosine is clipped because rounding can
    # push it marginally outside [-1, 1], where arccos returns NaN
    cosine = jnp.clip(optax.losses.cosine_similarity(update_g, updates), -1., 1.)
    angle = jnp.degrees(jnp.arccos(cosine))
    angles.append(angle)
    # Functional drift
    funcdrift.append(functional_drift(models, ds_test))
    # Update model to last global model
    model = nnx.from_tree(jax.tree.map(lambda p: p.mean(0), nnx.to_tree(models))) # TODO: does not make any sense, should be the penultimate global model

# Plot as function of the total number of communication rounds. Training resumes from the
# carried-over model at every step, so the running sum of `roundses` is the number of rounds
# behind each measurement. A leftover `betas` has a different length and meaning.
rounds_total = jnp.cumsum(roundses)
os.makedirs("drift/MPIIGaze", exist_ok=True)  # savefig does not create missing directories
fig, ax = plt.subplots(1, dpi=300);
ax2 = ax.twinx();
ax3 = ax.twinx();
ax3.spines.right.set_position(("outward", 60));  # offset so both right-hand scales stay readable
# Twin axes each restart the colour cycle, so colours are assigned explicitly to keep the curves apart
curves = [
    (ax, maes, "L1 distance with global update, averaged over clients", "L1 distance", "C0"),
    (ax2, angles, "Angle with global update, averaged over clients", "Angle (°)", "C1"),
    (ax3, funcdrift, "Functional drift, averaged over clients", "Functional drift", "C2"),
]
for axis, values, label, ylabel, color in curves:
    values = jnp.array(values)
    # Line: mean over clients; band: min-max range over clients
    axis.plot(rounds_total, values.mean(1), label=label, color=color)
    axis.fill_between(rounds_total, values.max(1), values.min(1), alpha=0.3, linewidth=0., color=color)
    axis.set_ylabel(ylabel)
# Details
# One legend covering all three axes, attached to the last-drawn axes so no curve is painted over it
ax3.legend(handles=[line for axis in (ax, ax2, ax3) for line in axis.get_lines()], loc="upper right");
ax.set_xlabel("Number of communication rounds");
ax.set_xlim(0, 20);
ax.set_ylim(0)
# Ticks sit at the round counts they are labelled with (1, 10, 19)
ax.set_xticks(jnp.arange(1, 21, 9));
ax.set_xticklabels(jnp.arange(1, 21, 9));
ax.grid(True, linestyle="--", linewidth=0.5);
fig.savefig("drift/MPIIGaze/pretrain.png", bbox_inches="tight");

## Ablation: Vary n clients

In [None]:
# Measure average drift from 'true' update