In [None]:
%load_ext autoreload
%autoreload 2
import os
# Ask XLA for deterministic GPU ops. The flag is written to the environment
# before jax is imported on the next line (presumably so that XLA reads it
# when the backend initializes — TODO confirm).
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
import optax, jax, pickle
from jax import numpy as jnp
from flax import nnx
from models import LeNet
from fedflax import train
from data import get_gaze
from matplotlib import pyplot as plt
from collections import defaultdict
from functools import partial
# Show which devices JAX can see in this kernel.
print(jax.devices())

## Setup

In [None]:
# Optimizer
def opt_create(model):
    """Return an ``nnx.Optimizer`` for ``model``.

    The optimizer wraps ``optax.adamw`` with a learning rate of 1e-3 and is
    created with ``wrt=nnx.Param``.
    """
    transform = optax.adamw(learning_rate=1e-3)
    return nnx.Optimizer(model, transform, wrt=nnx.Param)

# Loss: softmax cross-entropy, so the softmax is applied here rather than in the model
def ell(model, _, x_batch, z_batch, y_batch, train):
    """Return the mean softmax cross-entropy of the model output against ``y_batch``.

    Args:
        model: Callable invoked as ``model(x_batch, z_batch, train=train)``.
        _: Ignored positional argument.
        x_batch: First model input.
        z_batch: Second model input.
        y_batch: Targets handed to ``optax.softmax_cross_entropy``
            (presumably one-hot / probability labels — confirm against the data loader).
        train: Forwarded to the model as its ``train`` keyword.

    Returns:
        The cross-entropy averaged over all entries returned by optax.
    """
    logits = model(x_batch, z_batch, train=train)
    per_example = optax.softmax_cross_entropy(logits, y_batch)
    return per_example.mean()

# Matplotlib rendering style: pastel palette, LaTeX text rendering, Times-like serif fonts
plt.style.use("seaborn-v0_8-pastel")
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.serif"] = ["Times"]
plt.rcParams["font.sans-serif"] = ["Helvetica"]
# Packages added to the LaTeX preamble that matplotlib uses for text
plt.rcParams["text.latex.preamble"] = r"""
        \usepackage{amsmath, amssymb}
        \usepackage{mathptmx}  % Safe fallback for Times + math
    """

## Angle and feature skew

In [None]:
# Get updates of first communication round for various heterogeneity levels.
# Create the output folder first, so a missing directory cannot discard the
# (expensive) training results when they are written below.
os.makedirs("angles/MPIIGaze", exist_ok=True)
cpu = jax.devices("cpu")[0]
updateses = []
betas = 1 - jnp.log(jnp.linspace(jnp.e, 1, 20))
for beta in betas:
    ds_train = get_gaze(skew="feature", beta=beta, partition="train")
    ds_val = get_gaze(skew="feature", beta=beta, partition="val", batch_size=16)
    updates, _ = train(LeNet(nnx.Rngs(42)), opt_create, ds_train, ds_val, ell, local_epochs=10, rounds=1)
    # Move the updates to the CPU device before collecting them
    updateses.append(jax.tree.map(partial(jax.device_put, device=cpu), updates))
# Failsafe: persist the raw updates; the context manager guarantees the file is closed
with open("angles/MPIIGaze/updateses_featureskew.pkl", "wb") as file:
    pickle.dump(updateses, file)

# Load weights
with open("angles/MPIIGaze/updateses_featureskew.pkl", "rb") as file:
    updateses = pickle.load(file)
updateses = jax.tree.map(partial(jax.device_put, device=cpu), updateses)
# Loop over each heterogeneity level
angles = defaultdict(list)
for updates in updateses:
    # Flatten every client's update into one row. The client count is read from the
    # leading axis instead of being hard-coded (assumes every leaf is stacked per
    # client along axis 0, as the previous (4, -1) reshape did).
    leaves = jax.tree.leaves(updates)
    n_clients = leaves[0].shape[0]
    local_flat = jnp.concatenate([jnp.reshape(leaf, (n_clients, -1)) for leaf in leaves], axis=1)
    # Global update is the mean of the local updates
    update_g = jnp.mean(local_flat, axis=0)
    # Compute angle with global update for each client's update
    for client, update in enumerate(local_flat):
        # Rounding can push the cosine marginally outside [-1, 1], where arccos returns NaN
        cosine = jnp.clip(optax.losses.cosine_similarity(update_g, update), -1.0, 1.0)
        angles[client].append(jnp.degrees(jnp.arccos(cosine)).item())

# Plot as function of beta
fig, ax = plt.subplots(1, dpi=300);
for client, series in angles.items():
    ax.plot(betas, series, label=f"Angle with update of client {client}");
ax.legend(loc="lower left");
ax.set_xlabel(r"Feature heterogeneity level ($\beta$)");
ax.set_ylabel("Angle (degrees)");
ax.set_ylim(0, 55);
ax.set_xlim(left=-.001, right=1.001);
# Custom x-scale matching the spacing of `betas` (forward and inverse transform)
ax.set_xscale("function", functions=(lambda x: jnp.exp(1 - x), lambda x: 1 - jnp.log(x)));
ax.set_xticks(jnp.linspace(0, 1, 9));
ax.set_xticklabels([f"{x:.2f}" for x in jnp.linspace(0, 1, 9)]);
ax.grid(True, linestyle="--", linewidth=0.5);
fig.savefig("angles/MPIIGaze/angles_featureskew.png", bbox_inches="tight");

## Ablation: Vary sample overlap

In [None]:
# Get updates of first communication round for various overlap levels.
# Create the output folder first, so a missing directory cannot discard the
# (expensive) training results when they are written below.
os.makedirs("angles/MPIIGaze", exist_ok=True)
cpu = jax.devices("cpu")[0]
updateses = []
overlaps = jnp.log(jnp.linspace(jnp.e, 1, 20))
for overlap in overlaps:
    ds_train = get_gaze(skew="overlap", beta=overlap, partition="train")
    ds_val = get_gaze(skew="overlap", beta=overlap, partition="val", batch_size=16)
    updates, _ = train(LeNet(nnx.Rngs(42)), opt_create, ds_train, ds_val, ell, local_epochs=10, rounds=1)
    # Move the updates to the CPU device before collecting them
    updateses.append(jax.tree.map(partial(jax.device_put, device=cpu), updates))
# Persist the raw updates; the context manager guarantees the file is closed
with open("angles/MPIIGaze/updateses_overlap.pkl", "wb") as file:
    pickle.dump(updateses, file)

# Load weights
with open("angles/MPIIGaze/updateses_overlap.pkl", "rb") as file:
    updateses = pickle.load(file)
updateses = jax.tree.map(partial(jax.device_put, device=cpu), updateses)
# Loop over each overlap level
angles = defaultdict(list)
for updates in updateses:
    # Flatten every client's update into one row. The client count is read from the
    # leading axis instead of being hard-coded (assumes every leaf is stacked per
    # client along axis 0, as the previous (4, -1) reshape did).
    leaves = jax.tree.leaves(updates)
    n_clients = leaves[0].shape[0]
    local_flat = jnp.concatenate([jnp.reshape(leaf, (n_clients, -1)) for leaf in leaves], axis=1)
    # Global update is the mean of the local updates
    update_g = jnp.mean(local_flat, axis=0)
    # Compute angle with global update for each client's update
    for client, update in enumerate(local_flat):
        # Rounding can push the cosine marginally outside [-1, 1], where arccos returns NaN
        cosine = jnp.clip(optax.losses.cosine_similarity(update_g, update), -1.0, 1.0)
        angles[client].append(jnp.degrees(jnp.arccos(cosine)).item())

# Plot as function of sample overlap
plt.style.use("seaborn-v0_8-pastel")
fig, ax = plt.subplots(1, dpi=300);
for client, series in angles.items():
    ax.plot(overlaps, series, label=f"Angle with update of client {client}");
ax.legend(loc="lower left");
# Raw string: "\%" is not a valid Python escape sequence (same text reaches LaTeX)
ax.set_xlabel(r"Sample overlap (\%)");
ax.set_ylabel("Angle (degrees)");
ax.set_ylim(0, 55);
ax.set_xlim(-.001,1.001);
# Custom x-scale matching the spacing of `overlaps` (forward and inverse transform)
ax.set_xscale("function", functions=(jnp.exp, jnp.log));
ax.set_xticks(jnp.linspace(0, 1, 9));
# Tick positions are fractions; label them as integer percentages
ax.set_xticklabels(jnp.linspace(0, 100, 9, dtype=jnp.int32).tolist());
ax.grid(True, linestyle="--", linewidth=0.5);
fig.savefig("angles/MPIIGaze/angles_overlap.png", bbox_inches="tight");

## Ablation: Vary label skew

In [None]:
# Get updates of first communication round for various label skew levels.
# Create the output folder first, so a missing directory cannot discard the
# (expensive) training results when they are written below.
os.makedirs("angles/MPIIGaze", exist_ok=True)
cpu = jax.devices("cpu")[0]
updateses = []
betas = 1 - jnp.log(jnp.linspace(jnp.e, 1, 20))
for beta in betas:
    ds_train = get_gaze(skew="label", beta=beta, partition="train")
    ds_val = get_gaze(skew="label", beta=beta, partition="val", batch_size=16)
    updates, _ = train(LeNet(nnx.Rngs(42)), opt_create, ds_train, ds_val, ell, local_epochs=10, rounds=1)
    # Move the updates to the CPU device before collecting them
    updateses.append(jax.tree.map(partial(jax.device_put, device=cpu), updates))
# Persist the raw updates; the context manager guarantees the file is closed
with open("angles/MPIIGaze/updateses_labelskew.pkl", "wb") as file:
    pickle.dump(updateses, file)

# Load weights
with open("angles/MPIIGaze/updateses_labelskew.pkl", "rb") as file:
    updateses = pickle.load(file)
updateses = jax.tree.map(partial(jax.device_put, device=cpu), updateses)
# Loop over each heterogeneity level
angles = defaultdict(list)
for updates in updateses:
    # Flatten every client's update into one row. The client count is read from the
    # leading axis instead of being hard-coded (assumes every leaf is stacked per
    # client along axis 0, as the previous (4, -1) reshape did).
    leaves = jax.tree.leaves(updates)
    n_clients = leaves[0].shape[0]
    local_flat = jnp.concatenate([jnp.reshape(leaf, (n_clients, -1)) for leaf in leaves], axis=1)
    # Global update is the mean of the local updates
    update_g = jnp.mean(local_flat, axis=0)
    # Compute angle with global update for each client's update
    for client, update in enumerate(local_flat):
        # Rounding can push the cosine marginally outside [-1, 1], where arccos returns NaN
        cosine = jnp.clip(optax.losses.cosine_similarity(update_g, update), -1.0, 1.0)
        angles[client].append(jnp.degrees(jnp.arccos(cosine)).item())

# Plot as function of beta
plt.style.use("seaborn-v0_8-pastel")
fig, ax = plt.subplots(1, dpi=300);
for client, series in angles.items():
    ax.plot(betas, series, label=f"Angle with update of client {client}");
ax.legend(loc="lower left");
ax.set_xlabel(r"Label heterogeneity level ($\beta$)");
ax.set_ylabel("Angle (degrees)");
ax.set_ylim(0, 55);
ax.set_xlim(left=-.001, right=1.001);
# Custom x-scale matching the spacing of `betas` (forward and inverse transform)
ax.set_xscale("function", functions=(lambda x: jnp.exp(1 - x), lambda x: 1 - jnp.log(x)));
ax.set_xticks(jnp.linspace(0, 1, 9));
ax.set_xticklabels([f"{x:.2f}" for x in jnp.linspace(0, 1, 9)]);
ax.grid(True, linestyle="--", linewidth=0.5);
fig.savefig("angles/MPIIGaze/angles_labelskew.png", bbox_inches="tight");

## Ablation: Vary local epochs

In [None]:
# Get updates of first communication round for various numbers of local epochs.
# Create the output folder first, so a missing directory cannot discard the
# (expensive) training results when they are written below.
os.makedirs("angles/MPIIGaze", exist_ok=True)
cpu = jax.devices("cpu")[0]
updateses = []
# Plain Python ints (as the client-count ablation does) rather than JAX scalar arrays
epochses = jnp.linspace(1, 41, 20, dtype=jnp.int32).tolist()
for epochs in epochses:
    ds_train = get_gaze(partition="train", beta=.3) # Notice the fixed heterogeneity
    ds_val = get_gaze(partition="val", beta=.3, batch_size=16)
    updates, _ = train(LeNet(nnx.Rngs(42)), opt_create, ds_train, ds_val, ell, local_epochs=epochs, rounds=1)
    # Move the updates to the CPU device before collecting them
    updateses.append(jax.tree.map(partial(jax.device_put, device=cpu), updates))
# Persist the raw updates; the context manager guarantees the file is closed
with open("angles/MPIIGaze/updateses_epochs.pkl", "wb") as file:
    pickle.dump(updateses, file)

# Load weights
with open("angles/MPIIGaze/updateses_epochs.pkl", "rb") as file:
    updateses = pickle.load(file)
updateses = jax.tree.map(partial(jax.device_put, device=cpu), updateses)
# Loop over each number of local epochs
maes = defaultdict(list)
for updates in updateses:
    # Flatten every client's update into one row. The client count is read from the
    # leading axis instead of being hard-coded (assumes every leaf is stacked per
    # client along axis 0, as the previous (4, -1) reshape did).
    leaves = jax.tree.leaves(updates)
    n_clients = leaves[0].shape[0]
    local_flat = jnp.concatenate([jnp.reshape(leaf, (n_clients, -1)) for leaf in leaves], axis=1)
    # Global update is the mean of the local updates
    update_g = jnp.mean(local_flat, axis=0)
    # Mean absolute difference between the global update and each client's update
    for client, update in enumerate(local_flat):
        maes[client].append(jnp.mean(jnp.abs(update_g - update)).item())

# Plot as function of the number of local epochs
plt.style.use("seaborn-v0_8-pastel")
fig, ax = plt.subplots(1, dpi=300);
for client, series in maes.items():
    # The plotted quantity is a distance, not an angle; label it accordingly
    ax.plot(epochses, series, label=f"L1 distance to update of client {client}");
ax.legend(loc="lower left");
ax.set_xlabel("Number of local epochs");
ax.set_ylabel("L1 distance");
ax.set_xlim(1, 41);
ax.set_ylim(0);
ax.grid(True, linestyle="--", linewidth=0.5);
fig.savefig("angles/MPIIGaze/l1_epochs.png", bbox_inches="tight");

## Ablation: Pre-training

In [None]:
# Get updates after increasing amounts of federated pre-training.
# Create the output folder first, so a missing directory cannot discard the
# (expensive) training results when they are written below.
os.makedirs("angles/MPIIGaze", exist_ok=True)
cpu = jax.devices("cpu")[0]
updateses = []
# Cumulative number of communication rounds at which updates are recorded ...
total_rounds = jnp.linspace(1, 20, 10, dtype=jnp.int32).tolist()
# ... and the additional rounds to run per call to move from one checkpoint to the next
roundses = [stop - start for start, stop in zip([0] + total_rounds[:-1], total_rounds)]
model = LeNet(nnx.Rngs(42))
for rounds in roundses:
    ds_train = get_gaze(partition="train", beta=.3) # Notice the fixed heterogeneity
    ds_val = get_gaze(partition="val", beta=.3, batch_size=16)
    updates, model = train(model, opt_create, ds_train, ds_val, ell, local_epochs=10, rounds=rounds)
    # Average the returned model over its leading axis before continuing training
    # (presumably the client axis — confirm against fedflax.train)
    model = nnx.from_tree(jax.tree.map(lambda p: p.mean(0), nnx.to_tree(model)))
    # Move the updates to the CPU device before collecting them
    updateses.append(jax.tree.map(partial(jax.device_put, device=cpu), updates))
# Persist the raw updates; the context manager guarantees the file is closed
with open("angles/MPIIGaze/updateses_pretrain.pkl", "wb") as file:
    pickle.dump(updateses, file)

# Load weights
with open("angles/MPIIGaze/updateses_pretrain.pkl", "rb") as file:
    updateses = pickle.load(file)
updateses = jax.tree.map(partial(jax.device_put, device=cpu), updateses)
# Loop over each pre-training checkpoint
distances = defaultdict(list)
for updates in updateses:
    # Flatten every client's update into one row. The client count is read from the
    # leading axis instead of being hard-coded (assumes every leaf is stacked per
    # client along axis 0, as the previous (4, -1) reshape did).
    leaves = jax.tree.leaves(updates)
    n_clients = leaves[0].shape[0]
    local_flat = jnp.concatenate([jnp.reshape(leaf, (n_clients, -1)) for leaf in leaves], axis=1)
    # Global update is the mean of the local updates
    update_g = jnp.mean(local_flat, axis=0)
    # Mean absolute difference between the global update and each client's update
    for client, update in enumerate(local_flat):
        distances[client].append(jnp.mean(jnp.abs(update_g - update)).item())

# Plot as function of the number of communication rounds
plt.style.use("seaborn-v0_8-pastel")
fig, ax = plt.subplots(1, dpi=300);
for client, series in distances.items():
    # The plotted quantity is a distance, not an angle; label it accordingly
    ax.plot(total_rounds, series, label=f"L1 distance to update of client {client}");
ax.legend(loc="upper right");
ax.set_xlabel("Number of communication rounds");
ax.set_ylabel("L1 distance");
ax.set_xlim(0, 20);
ax.set_ylim(0)
# Ticks sit at the round numbers they are labelled with (previously positions
# 0/9/18 carried the labels 1/10/19, one off from where the data is drawn)
round_ticks = jnp.arange(1, 21, 9).tolist()
ax.set_xticks(round_ticks);
ax.set_xticklabels(round_ticks);
ax.grid(True, linestyle="--", linewidth=0.5);
fig.savefig("angles/MPIIGaze/l1_pretrained.png", bbox_inches="tight");

## Ablation: Vary number of clients

In [None]:
# Get updates of first communication round for various numbers of clients.
# Create the output folder first, so a missing directory cannot discard the
# (expensive) training results when they are written below.
os.makedirs("angles/MPIIGaze", exist_ok=True)
cpu = jax.devices("cpu")[0]
updateses = []
nums = jnp.linspace(1, 15, 10, dtype=jnp.int32).tolist()
for num in nums:
    ds_train = get_gaze(n_clients=num, partition="train", beta=.3) # Notice the fixed heterogeneity
    ds_val = get_gaze(n_clients=num, partition="val", beta=.3, batch_size=16)
    updates, _ = train(LeNet(nnx.Rngs(42)), opt_create, ds_train, ds_val, ell, n=num, local_epochs=10, rounds=1)
    # Move the updates to the CPU device before collecting them
    updateses.append(jax.tree.map(partial(jax.device_put, device=cpu), updates))
# Persist the raw updates; the context manager guarantees the file is closed
with open("angles/MPIIGaze/updateses_clients.pkl", "wb") as file:
    pickle.dump(updateses, file)

# Load weights
with open("angles/MPIIGaze/updateses_clients.pkl", "rb") as file:
    updateses = pickle.load(file)
updateses = jax.tree.map(partial(jax.device_put, device=cpu), updateses)
# Loop over each number of clients, collecting mean / min / max angle across clients
angles, angles_min, angles_max = [], [], []
for num, updates in zip(nums, updateses):
    # Flatten every client's update into one row of a (num, n_params) matrix
    leaves = jax.tree.leaves(updates)
    local_flat = jnp.concatenate([jnp.reshape(leaf, (num, -1)) for leaf in leaves], axis=1)
    # Global update is the mean of the local updates
    update_g = jnp.mean(local_flat, axis=0)
    # Compute angle with global update for each client's update
    client_angles = []
    for update in local_flat:
        # Rounding can push the cosine marginally outside [-1, 1], where arccos returns
        # NaN; this is most likely for a single client, whose update equals the global one
        cosine = jnp.clip(optax.losses.cosine_similarity(update_g, update), -1.0, 1.0)
        client_angles.append(jnp.degrees(jnp.arccos(cosine)).item())
    angles.append(sum(client_angles) / num)
    angles_min.append(min(client_angles))
    angles_max.append(max(client_angles))

# Plot as function of the number of clients
plt.style.use("seaborn-v0_8-pastel")
fig, ax = plt.subplots(1, dpi=300);
ax.plot(nums, angles);
# Shaded band spans the smallest to the largest client angle
ax.fill_between(nums, angles_min, angles_max, alpha=.3);
ax.set_xlabel("Number of clients");
ax.set_ylabel("Angle (degrees)");
ax.set_ylim(0, 55);
ax.set_xlim(1, 15);
ax.grid(True, linestyle="--", linewidth=0.5);
fig.savefig("angles/MPIIGaze/angles_clients.png", bbox_inches="tight");