In [None]:
%load_ext autoreload
%autoreload 2
import optax, jax, pickle
from jax import numpy as jnp
from flax import nnx
from models import ResNet
from fedflax import train
from data import get_gaze
from matplotlib import pyplot as plt
from collections import defaultdict
# Number of federated clients shared by every experiment in this notebook.
n = 4
# Sanity check before long runs: list the devices JAX will execute on.
print(jax.devices())

## Angle between local and global updates vs. feature-skew level ($\beta$)

In [None]:
# Optimizer
opt = lambda model: nnx.Optimizer(
    model,
    optax.adamw(learning_rate=1e-3),
    wrt=nnx.Param
)

# # Sharpness-aware minimization optimizer
# opt = lambda model: nnx.Optimizer(
#     model,
#     optax.contrib.sam(
#         optax.adamw(learning_rate=1e-3),
#         optax.chain(optax.contrib.normalize(), optax.adam(1e-2)),
#         sync_period=5
#     ),
#     wrt=nnx.Param
# )

# Loss function
def return_ell(omega):
    def ell(model, model_g, z_batch, x_batch, y_batch, train):
        prox = sum(jax.tree.map(lambda a, b: jnp.sum((a-b)**2), jax.tree.leaves(nnx.to_tree(model)), jax.tree.leaves(nnx.to_tree(model_g))))
        loss = jnp.sqrt(jnp.square(model(z_batch, x_batch, train=train) - y_batch).mean())
        return (omega/2*prox + loss), (prox, loss)
    return ell

# Get updates of the first communication round for various feature-skew heterogeneity levels.
# `betas` runs from 1 down to 0.01 and is evenly spaced in exp(beta), so it is denser near beta = 1.
updateses = []
betas = jnp.log(jnp.linspace(jnp.exp(1), jnp.exp(0.01), 20))
for beta in betas:
    ds_train = get_gaze(skew="feature", n_clients=n, beta=beta, partition="train")
    ds_val = get_gaze(skew="feature", n_clients=n, beta=beta, partition="val")
    # return_ell(0.) zeroes the proximal term: plain local training
    updates = train(ResNet, opt, ds_train, ds_val, return_ell(0.), local_epochs=10, max_patience=None, n=n)
    updateses.append(updates)

# Failsafe: persist the sweep so it can be reloaded later (e.g. by the visualization cell).
# The `with` block guarantees the file is flushed and closed instead of relying on garbage collection.
with open("updateses_featureskew.pkl", "wb") as f:
    pickle.dump(updateses, f)

## Ablation: Vary sample overlap

In [None]:
# Get updates of the first communication round for various sample-overlap levels.
# `overlaps` runs from 0 up to 1, evenly spaced in exp(overlap).
updateses = []
overlaps = jnp.log(jnp.linspace(jnp.exp(0), jnp.exp(1), 20))
for overlap in overlaps:
    ds_train = get_gaze(skew="overlap", n_clients=n, beta=overlap, partition="train")
    ds_val = get_gaze(skew="overlap", n_clients=n, beta=overlap, partition="val")
    # Pass n=n (as the feature-skew sweep does) so training uses the same client count as the data split
    updates = train(ResNet, opt, ds_train, ds_val, return_ell(0.), local_epochs=10, max_patience=None, n=n)
    updateses.append(updates)
with open("updateses_overlap.pkl", "wb") as f:
    pickle.dump(updateses, f)

## Ablation: Vary label skew

In [None]:
# Get updates of the first communication round for various label-skew levels.
# `betas` runs from 0 up to 1, evenly spaced in exp(beta). Note this reassigns the notebook-level `betas`.
updateses = []
betas = jnp.log(jnp.linspace(jnp.exp(0), jnp.exp(1), 20))
for beta in betas:
    ds_train = get_gaze(skew="label", n_clients=n, beta=beta, partition="train")
    ds_val = get_gaze(skew="label", n_clients=n, beta=beta, partition="val")
    # Pass n=n (as the feature-skew sweep does) so training uses the same client count as the data split
    updates = train(ResNet, opt, ds_train, ds_val, return_ell(0.), local_epochs=10, max_patience=None, n=n)
    updateses.append(updates)
with open("updateses_labelskew.pkl", "wb") as f:
    pickle.dump(updateses, f)

## Ablation: Vary local epochs

In [None]:
# Get updates of the first communication round for various numbers of local epochs.
# jnp.linspace(1, 41, 20, dtype=int) casts a step of 40/19 to integers, so the grid is
# uneven (steps of 2 or 3).
updateses = []
epochses = jnp.linspace(1, 41, 20, dtype=int)
for epochs in epochses:
    ds_train = get_gaze(n_clients=n, partition="train")
    ds_val = get_gaze(n_clients=n, partition="val")
    # Hand `train` a plain Python int rather than a 0-d JAX array as the epoch count.
    # Also pass n=n for consistency with the feature-skew sweep.
    updates = train(ResNet, opt, ds_train, ds_val, return_ell(0.), local_epochs=int(epochs), max_patience=None, n=n)
    updateses.append(updates)
with open("updateses_epochs.pkl", "wb") as f:
    pickle.dump(updateses, f)

## Ablation: Pre-train

In [None]:
# # Get updates of first communication round for various numbers of pre-training rounds
# updateses = []
# roundses = jnp.linspace(1, 20, 10, dtype=int)
# for rounds in roundses:
#     ds_train = create_imagenet(feature_beta=0., sample_overlap=1., n=n, label_beta=0.)
#     ds_val = create_imagenet(path="./data/Data/CLS-LOC/val", feature_beta=0., sample_overlap=1., n=n, label_beta=0.)
#     updates = train(ResNet, opt, ds_train, ds_val, return_ell(0.), local_epochs=10, max_patience=None) # TODO: max_patience is not the same as forcing a certain number of rounds
#     updateses.append(updates)
# pickle.dump(updateses, open("updateses_pretrain.pkl", "wb"))

## Ablation: Vary number of clients

In [None]:
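# # Get updates of first communication round for various numbers of clients
# # NOTE(review): the loop below still passes the global `n`; use `n_clients=num` (and `n=num` in `train`) when re-enabling.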
# updateses = []
# nums = jnp.linspace(1, 20, 10, dtype=int)
# for num in nums:
#     ds_train = get_gaze(n_clients=n, partition="train")
#     ds_val = get_gaze(n_clients=n, partition="val")
#     updates = train(ResNet, opt, ds_train, ds_val, return_ell(0.), local_epochs=10, max_patience=None)
#     updateses.append(updates)
# pickle.dump(updateses, open("updateses_clients.pkl", "wb"))

## Visualizations

In [None]:
# Angle between each client's first-round update and the global (mean) update, per feature-skew level.
# The sweep is reloaded from disk instead of reusing `betas`/`updateses` from the kernel. The
# later ablation cells reassign both names, so after "Restart & Run All" they would hold the
# label-skew betas and the local-epochs updates, which would silently produce a mislabeled plot.
# To plot another sweep (e.g. overlap), point the pickle path and the x grid below at that sweep.
with open("updateses_featureskew.pkl", "rb") as f:  # locally produced, trusted file
    updateses_feature = pickle.load(f)
betas_feature = jnp.log(jnp.linspace(jnp.exp(1), jnp.exp(0.01), 20))  # same grid as the feature-skew cell
assert len(updateses_feature) == len(betas_feature), "saved sweep does not match the beta grid"

angles = defaultdict(list)
for updates in updateses_feature:
    # Flatten every client's update into one row -> (n_clients, n_params).
    # The leading axis of each leaf indexes clients (it is what gets averaged), so reshape
    # with x.shape[0]. jax.tree.leaves accepts any pytree container, not just a list.
    local_flat = jnp.concatenate(
        [jnp.reshape(x, (x.shape[0], -1)) for x in jax.tree.leaves(updates)], axis=1
    )
    # Global update = mean of the local updates (identical to averaging leaf-wise, then flattening)
    global_flat = jnp.mean(local_flat, axis=0)
    for client, update in enumerate(local_flat):
        # Rounding can push the cosine slightly outside [-1, 1], where arccos returns NaN
        cos = jnp.clip(optax.losses.cosine_similarity(global_flat, update), -1.0, 1.0)
        angles[client].append(jnp.degrees(jnp.arccos(cos)).item())

# Plot as a function of beta, one line per client
plt.style.use("seaborn-v0_8-pastel")
fig, ax = plt.subplots(1, dpi=300);
for client in sorted(angles):
    ax.plot(betas_feature, angles[client], label=f"Angle with update of client {client}");
ax.legend(loc="lower left");
ax.set_xlabel(r"Heterogeneity level ($\beta$)");
ax.set_ylabel("Angle (degrees)");
ax.set_ylim(0);
ax.set_xlim(-.005, 1.001);
# forward=exp: the betas (evenly spaced in exp(beta)) appear evenly spaced along the axis
ax.set_xscale("function", functions=(jnp.exp, jnp.log));
ax.set_xticks(jnp.linspace(0.01, 1, 9));
ax.set_xticklabels([f"{x:.2f}" for x in jnp.linspace(0.01, 1, 9)]);
fig.savefig("img/angles_beta.png", bbox_inches="tight");