In [None]:
%load_ext autoreload
%autoreload 2
import optax, jax, pickle
from jax import numpy as jnp
from flax import nnx
from models import ResNet
from fedflax import train
from data import get_gaze
from matplotlib import pyplot as plt
from collections import defaultdict
from functools import partial
# Print the devices JAX can see
print(jax.devices())

## Setup

In [None]:
# Optimizer
def opt(model):
    """Return an ``nnx.Optimizer`` for ``model`` using AdamW with learning rate 1e-3.

    The optimizer is constructed with ``wrt=nnx.Param``.
    """
    return nnx.Optimizer(model, optax.adamw(learning_rate=1e-3), wrt=nnx.Param)

# # Sharpness-aware minimization optimizer
# opt = lambda model: nnx.Optimizer(
#     model,
#     optax.contrib.sam(
#         optax.adamw(learning_rate=1e-3),
#         optax.chain(optax.contrib.normalize(), optax.adam(1e-2)),
#         sync_period=5
#     ),
#     wrt=nnx.Param
# )

# Loss function with a softmax layer included
def return_ell(omega):
    """Create a loss function whose proximal term is weighted by ``omega``.

    The returned ``ell(model, model_g, x_batch, z_batch, y_batch, train)`` gives
    ``(omega/2 * prox + loss, (prox, loss))``, where ``prox`` is the summed squared
    difference between the leaves of ``model`` and ``model_g`` and ``loss`` is the
    mean softmax cross-entropy of ``model(x_batch, z_batch, train=train)`` against
    ``y_batch``.
    """
    def squared_distance(a, b):
        # Squared Euclidean distance between two matching leaves
        return jnp.sum((a-b)**2)

    def ell(model, model_g, x_batch, z_batch, y_batch, train):
        local_leaves = jax.tree.leaves(nnx.to_tree(model))
        global_leaves = jax.tree.leaves(nnx.to_tree(model_g))
        # Proximal term: total squared distance between local and global model
        prox = sum(jax.tree.map(squared_distance, local_leaves, global_leaves))
        # The model output is fed to the softmax cross-entropy as-is
        logits = model(x_batch, z_batch, train=train)
        loss = optax.softmax_cross_entropy(logits, y_batch).mean()
        objective = omega/2*prox + loss
        return objective, (prox, loss)
    return ell

# Set matplotlib rendering style (a global setting, so it affects the figures created afterwards)
plt.style.use("seaborn-v0_8-pastel")
# plt.rcParams.update({
#     "text.usetex": True,
#     "font.family": "Helvetica"
# })

## Angle and feature skew

In [None]:
# Get updates of first communication round for various heterogeneity levels
cpu = jax.devices("cpu")[0]
updateses = []
# 20 levels from 0 to 1, spaced more densely towards 0
betas = 1 - jnp.log(jnp.linspace(jnp.exp(1), jnp.exp(0.), 20))
for beta in betas:
    ds_train = get_gaze(skew="feature", beta=beta, partition="train")
    ds_val = get_gaze(skew="feature", beta=beta, partition="val")
    updates = train(ResNet, opt, ds_train, ds_val, return_ell(0.), local_epochs=10, max_patience=None)
    # Copy the update arrays to the CPU device before storing them
    updates = jax.tree.map(partial(jax.device_put, device=cpu), updates)
    updateses.append(updates)
# Failsafe: persist the raw updates (context manager makes sure the file is flushed and closed)
with open("updateses_featureskew.pkl", "wb") as f:
    pickle.dump(updateses, f)

# Load weights (unpickling is only acceptable because this notebook wrote the file itself)
with open("updateses_featureskew.pkl", "rb") as f:
    updateses = pickle.load(f)
updateses = jax.tree.map(partial(jax.device_put, device=cpu), updateses)
# Loop over each heterogeneity level
angles = defaultdict(list)
for beta, updates in zip(betas, updateses):
    # Flatten local updates into one row per client
    # (assumes the leading axis of every leaf indexes the client, as the axis-0 mean below does)
    leaves = jax.tree.leaves(updates)
    updates = jnp.concatenate([jnp.reshape(x, (x.shape[0], -1)) for x in leaves], axis=1)
    # Compute global update as mean of local updates
    update_g = jnp.mean(updates, axis=0)
    # Compute angle with global update for each client's update
    for client, update in enumerate(updates):
        cosine = optax.losses.cosine_similarity(update_g, update)
        # Rounding can push the cosine marginally outside [-1, 1], where arccos returns NaN
        angle = jnp.degrees(jnp.arccos(jnp.clip(cosine, -1., 1.)))
        angles[client].append(angle.item())

# Plot as function of beta
fig, ax = plt.subplots(1, dpi=300);
for client, series in sorted(angles.items()):
    ax.plot(betas, series, label=f"Angle with update of client {client}");
ax.legend(loc="lower left");
ax.set_xlabel(r"Feature heterogeneity level ($\beta$)");
ax.set_ylabel("Angle (degrees)");
ax.set_ylim(0, 25);
# NOTE(review): betas[0] is 0, which has no position on a logarithmic axis
ax.set_xscale("log", base=jnp.e);
ax.set_xlim(right=1.001);
ax.set_xticks(jnp.logspace(-2,0,9));
ax.set_xticklabels([f"{x:.2f}" for x in jnp.logspace(-2,0,9)]);
fig.savefig("img/angles_featureskew.png", bbox_inches="tight");

## Ablation: Vary sample overlap
Due to floating-point precision, the client updates still differ slightly even when the overlap is 100%. This is not caused by unequal initialization, unequal random seeds, or unequal data.

In [None]:
# Get updates of first communication round for various overlap levels
cpu = jax.devices("cpu")[0]
updateses = []
# 20 levels from 1 down to 0, spaced more densely towards 1
overlaps = jnp.log(jnp.linspace(jnp.exp(1),jnp.exp(0),20))
for overlap in overlaps:
    ds_train = get_gaze(skew="overlap", beta=overlap, partition="train")
    ds_val = get_gaze(skew="overlap", beta=overlap, partition="val")
    updates = train(ResNet, opt, ds_train, ds_val, return_ell(0.), local_epochs=10, max_patience=None)
    # Copy the update arrays to the CPU device before storing them
    updates = jax.tree.map(partial(jax.device_put, device=cpu), updates)
    updateses.append(updates)
# Persist the raw updates (context manager makes sure the file is flushed and closed)
with open("updateses_overlap.pkl", "wb") as f:
    pickle.dump(updateses, f)

# Load weights (unpickling is only acceptable because this notebook wrote the file itself)
with open("updateses_overlap.pkl", "rb") as f:
    updateses = pickle.load(f)
updateses = jax.tree.map(partial(jax.device_put, device=cpu), updateses)
# Loop over each overlap level
angles = defaultdict(list)
for overlap, updates in zip(overlaps, updateses):
    # Flatten local updates into one row per client
    # (assumes the leading axis of every leaf indexes the client, as the axis-0 mean below does)
    leaves = jax.tree.leaves(updates)
    updates = jnp.concatenate([jnp.reshape(x, (x.shape[0], -1)) for x in leaves], axis=1)
    # Compute global update as mean of local updates
    update_g = jnp.mean(updates, axis=0)
    # Compute angle with global update for each client's update
    for client, update in enumerate(updates):
        cosine = optax.losses.cosine_similarity(update_g, update)
        # Near-identical updates give a cosine of ~1; rounding above 1 would make arccos return NaN
        angle = jnp.degrees(jnp.arccos(jnp.clip(cosine, -1., 1.)))
        angles[client].append(angle.item())

# Plot as function of sample overlap
plt.style.use("seaborn-v0_8-pastel")
fig, ax = plt.subplots(1, dpi=300);
for client, series in sorted(angles.items()):
    ax.plot(overlaps, series, label=f"Angle with update of client {client}");
ax.legend(loc="lower left");
# NOTE(review): the axis values and tick labels are fractions in [0, 1], not percentages
ax.set_xlabel("Sample overlap (%)");
ax.set_ylabel("Angle (degrees)");
ax.set_ylim(0, 25);
ax.set_xlim(-.001,1.001);
# Exponential axis transform, mirroring the logarithmic spacing of the sampled overlaps
ax.set_xscale("function", functions=(jnp.exp, jnp.log));
ax.set_xticks(jnp.linspace(0, 1, 9));
ax.set_xticklabels([f"{x:.2f}" for x in jnp.linspace(0, 1, 9)]);
fig.savefig("img/angles_overlap.png", bbox_inches="tight");

## Ablation: Vary label skew

In [None]:
# Get updates of first communication round for various label skew levels
cpu = jax.devices("cpu")[0]
updateses = []
# 20 levels from 1 down to 0, spaced more densely towards 0
betas = 1 - jnp.log(jnp.linspace(jnp.exp(0),jnp.exp(1),20))
for beta in betas:
    ds_train = get_gaze(skew="label", beta=beta, partition="train")
    ds_val = get_gaze(skew="label", beta=beta, partition="val")
    updates = train(ResNet, opt, ds_train, ds_val, return_ell(0.), local_epochs=10, max_patience=None)
    # Copy the update arrays to the CPU device before storing them
    updates = jax.tree.map(partial(jax.device_put, device=cpu), updates)
    updateses.append(updates)
# Persist the raw updates (context manager makes sure the file is flushed and closed)
with open("updateses_labelskew.pkl", "wb") as f:
    pickle.dump(updateses, f)

# Load weights (unpickling is only acceptable because this notebook wrote the file itself)
with open("updateses_labelskew.pkl", "rb") as f:
    updateses = pickle.load(f)
updateses = jax.tree.map(partial(jax.device_put, device=cpu), updateses)
# Loop over each heterogeneity level
angles = defaultdict(list)
for beta, updates in zip(betas, updateses):
    # Flatten local updates into one row per client
    # (assumes the leading axis of every leaf indexes the client, as the axis-0 mean below does)
    leaves = jax.tree.leaves(updates)
    updates = jnp.concatenate([jnp.reshape(x, (x.shape[0], -1)) for x in leaves], axis=1)
    # Compute global update as mean of local updates
    update_g = jnp.mean(updates, axis=0)
    # Compute angle with global update for each client's update
    for client, update in enumerate(updates):
        cosine = optax.losses.cosine_similarity(update_g, update)
        # Rounding can push the cosine marginally outside [-1, 1], where arccos returns NaN
        angle = jnp.degrees(jnp.arccos(jnp.clip(cosine, -1., 1.)))
        angles[client].append(angle.item())

# Plot as function of beta
plt.style.use("seaborn-v0_8-pastel")
fig, ax = plt.subplots(1, dpi=300);
for client, series in sorted(angles.items()):
    ax.plot(betas, series, label=f"Angle with update of client {client}");
ax.legend(loc="lower left");
ax.set_xlabel(r"Label heterogeneity level ($\beta$)");
ax.set_ylabel("Angle (degrees)");
ax.set_ylim(0, 25);
# NOTE(review): betas[-1] is 0, which has no position on a logarithmic axis
ax.set_xscale("log");
ax.set_xlim(right=1.001);
ax.set_xticks(jnp.logspace(-2,0,9));
ax.set_xticklabels([f"{x:.2f}" for x in jnp.logspace(-2,0,9)]);
fig.savefig("img/angles_labelskew.png", bbox_inches="tight");

## Ablation: Vary local epochs

In [None]:
# Get updates of first communication round for various numbers of local epochs
cpu = jax.devices("cpu")[0]
updateses = []
epochses = jnp.linspace(1, 41, 20, dtype=int)
for epochs in epochses:
    ds_train = get_gaze(partition="train")
    ds_val = get_gaze(partition="val")
    # Pass a plain Python int: iterating a JAX array yields 0-d arrays, which are not hashable
    updates = train(ResNet, opt, ds_train, ds_val, return_ell(0.), local_epochs=int(epochs), max_patience=None)
    # Copy the update arrays to the CPU device before storing them
    updates = jax.tree.map(partial(jax.device_put, device=cpu), updates)
    updateses.append(updates)
# Persist the raw updates (context manager makes sure the file is flushed and closed)
with open("updateses_epochs.pkl", "wb") as f:
    pickle.dump(updateses, f)

# Load weights (unpickling is only acceptable because this notebook wrote the file itself)
with open("updateses_epochs.pkl", "rb") as f:
    updateses = pickle.load(f)
updateses = jax.tree.map(partial(jax.device_put, device=cpu), updateses)
# Loop over each number of local epochs
angles = defaultdict(list)
for epochs, updates in zip(epochses, updateses):
    # Flatten local updates into one row per client
    # (assumes the leading axis of every leaf indexes the client, as the axis-0 mean below does)
    leaves = jax.tree.leaves(updates)
    updates = jnp.concatenate([jnp.reshape(x, (x.shape[0], -1)) for x in leaves], axis=1)
    # Compute global update as mean of local updates
    update_g = jnp.mean(updates, axis=0)
    # Compute angle with global update for each client's update
    for client, update in enumerate(updates):
        cosine = optax.losses.cosine_similarity(update_g, update)
        # Rounding can push the cosine marginally outside [-1, 1], where arccos returns NaN
        angle = jnp.degrees(jnp.arccos(jnp.clip(cosine, -1., 1.)))
        angles[client].append(angle.item())

# Plot as function of the number of local epochs
plt.style.use("seaborn-v0_8-pastel")
fig, ax = plt.subplots(1, dpi=300);
for client, series in sorted(angles.items()):
    ax.plot(epochses, series, label=f"Angle with update of client {client}");
ax.legend(loc="lower left");
ax.set_xlabel("Number of local epochs");
ax.set_ylabel("Angle (degrees)");
ax.set_ylim(0, 25);
ax.set_xlim(1, 41);
ax.set_xticks(jnp.arange(1, 41, 5));
fig.savefig("img/angles_epochs.png", bbox_inches="tight");

## Ablation: Pre-train

In [None]:
# # Get updates of first communication round for various overlap levels
# updateses = []
# roundses = jnp.linspace(1, 20, 10, dtype=int)
# for rounds in roundses:
#     ds_train = create_imagenet(feature_beta=0., sample_overlap=1., n=n, label_beta=0.)
#     ds_val = create_imagenet(path="./data/Data/CLS-LOC/val", feature_beta=0., sample_overlap=1., n=n, label_beta=0.)
#     updates = train(ResNet, opt, ds_train, ds_val, return_ell(0.), local_epochs=10, max_patience=None) # TODO: max_patience is not the same as forcing a certain number of rounds
#     updates = jax.tree.map(partial(jax.device_put, device=jax.devices("cpu")[0]), updates)
#     updateses.append(updates)
# pickle.dump(updateses, open("updateses_pretrain.pkl", "wb"))

## Ablation: Vary n clients

In [None]:
# Get updates of first communication round for various numbers of clients
cpu = jax.devices("cpu")[0]
updateses = []
nums = jnp.linspace(1, 15, 10, dtype=int)
for num in nums:
    # Pass a plain Python int: iterating a JAX array yields 0-d arrays, which are not hashable
    ds_train = get_gaze(n_clients=int(num), partition="train") # TODO: which type of skew?
    ds_val = get_gaze(n_clients=int(num), partition="val")
    updates = train(ResNet, opt, ds_train, ds_val, return_ell(0.), local_epochs=10, max_patience=None)
    # Copy the update arrays to the CPU device before storing them
    updates = jax.tree.map(partial(jax.device_put, device=cpu), updates)
    updateses.append(updates)
# Persist the raw updates (context manager makes sure the file is flushed and closed)
with open("updateses_clients.pkl", "wb") as f:
    pickle.dump(updateses, f)

# Load weights (unpickling is only acceptable because this notebook wrote the file itself)
with open("updateses_clients.pkl", "rb") as f:
    updateses = pickle.load(f)
updateses = jax.tree.map(partial(jax.device_put, device=cpu), updateses)
# Loop over each number of clients, collecting the average client angle per setting
angles = []
for num, updates in zip(nums, updateses):
    # Flatten local updates into one row per client
    # (assumes the leading axis of every leaf indexes the client, as the axis-0 mean below does)
    leaves = jax.tree.leaves(updates)
    updates = jnp.concatenate([jnp.reshape(x, (x.shape[0], -1)) for x in leaves], axis=1)
    # Compute global update as mean of local updates
    update_g = jnp.mean(updates, axis=0)
    # Compute angle with global update for each client's update
    client_angles = []
    for update in updates:
        cosine = optax.losses.cosine_similarity(update_g, update)
        # With a single client the cosine is ~1; rounding above 1 would make arccos return NaN
        client_angles.append(jnp.degrees(jnp.arccos(jnp.clip(cosine, -1., 1.))).item())
    # Append the mean; the previous `angles[i] += ...` on an empty list raised IndexError
    angles.append(sum(client_angles) / len(client_angles))

# Plot as function of the number of clients
plt.style.use("seaborn-v0_8-pastel")
fig, ax = plt.subplots(1, dpi=300);
ax.plot(nums, angles, label=f"Average angle of client updates");
ax.legend(loc="lower left");
ax.set_xlabel("Number of clients");
ax.set_ylabel("Angle (degrees)");
ax.set_ylim(0, 25);
ax.set_xlim(1, 15);
ax.set_xticks(jnp.arange(1, 15, 3));
fig.savefig("img/angles_clients.png", bbox_inches="tight");