In [None]:
%load_ext autoreload
%autoreload 2
from jax import config
config.update("jax_platforms", "cpu")
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
import jax, optax, json
from jax import numpy as jnp
from flax import nnx
from fedflax import train
from models import LeNet
from data import fetch_data
from utils import opt_create, return_l2, angle_err
from tqdm.auto import tqdm
from functools import reduce
from collections import defaultdict
from matplotlib import pyplot as plt

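## Bézier interpolation setup

The two client models $\theta_1$ and $\theta_2$ are connected by the quadratic Bézier curve $\theta(t) = (1-t)^2\,\theta_1 + 2(1-t)\,t\,w + t^2\,\theta_2$ for $t \in [0, 1]$, where the control point $w$ is the quantity being optimized.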

In [None]:
# Quadratic Bezier curve
def bezier(t):
    """Build the quadratic Bezier interpolant for the curve parameter(s) ``t``.

    ``t`` may be a scalar or an array; it is flattened before use. The
    returned callable takes the end points ``theta1`` and ``theta2`` and the
    control point ``w`` and evaluates

        (1 - t)^2 * theta1 + 2 * (1 - t) * t * w + t^2 * theta2

    for every entry of ``t``. The result has a leading axis that enumerates
    the entries of ``t`` (length 1 for a scalar), followed by the broadcast
    axes of the three inputs.
    """
    def inner(theta1, theta2, w):
        # Flatten t and append one singleton axis per axis of w so that it
        # broadcasts against the parameter arrays
        column_shape = (-1, *([1] * w.ndim))
        t_ = jnp.reshape(t, column_shape)
        s_ = 1 - t_
        start_term = s_ ** 2 * theta1
        control_term = 2 * s_ * t_ * w
        end_term = t_ ** 2 * theta2
        return start_term + control_term + end_term
    return inner

# Loss is calculated using the model resulting from a Bezier interpolation
def ell(m, y, *xs):
    """Mean squared error between ``m(*xs, train=True)`` and the target ``y``."""
    return jnp.square(m(*xs, train=True) - y).mean()


def ell_tilde(theta1, theta2, struct, rest):
    """Build the curve loss for fixed end points ``theta1`` and ``theta2``.

    The returned ``inner(w, y, x, z, key)`` draws one uniform ``t`` per sample
    (``y``, ``x`` and ``z`` are indexed along their leading axis), evaluates the
    Bezier curve with control point ``w`` at each ``t``, merges the resulting
    parameters with ``struct`` and ``rest`` into a model, and returns the mean
    of the per-sample ``ell`` losses.
    """
    def inner(w, y, x, z, key):
        # Each sample gets a different t. `shape` is passed as a tuple, the
        # form documented for jax.random.uniform.
        t = jax.random.uniform(key, (y.shape[0],))
        theta = jax.tree.map(bezier(t), theta1, theta2, w)
        model = nnx.merge(struct, theta, rest)
        # Use vmap to compute the loss for each t/model/sample separately;
        # indexing with [:, None] keeps a size-1 axis on every mapped sample
        per_sample = nnx.vmap(ell)(
            model, y[:, None], x[:, None], z[:, None]
        )
        return per_sample.mean()
    return inner

# One jitted optimization step on the Bezier control point
def train_step(opt, theta1, theta2, struct, rest):
    """Create the jitted update function for the control point ``w``.

    ``opt`` must expose ``update(grad, opt_state, w)`` (e.g. an optax
    optimizer). ``theta1``, ``theta2``, ``struct`` and ``rest`` are forwarded
    to ``ell_tilde`` once and stay fixed for every step.

    The returned ``inner(w, y, x, z, opt_state, key)`` differentiates the curve
    loss with respect to ``w``, applies the optimizer update, and returns
    ``(loss, opt_state, w)`` with the updated optimizer state and control point.
    """
    curve_loss = ell_tilde(theta1, theta2, struct, rest)

    def inner(w, y, x, z, opt_state, key):
        loss, grad = nnx.value_and_grad(curve_loss)(w, y, x, z, key)
        updates, next_opt_state = opt.update(grad, opt_state, w)
        next_w = optax.apply_updates(w, updates)
        return loss, next_opt_state, next_w

    return jax.jit(inner)

## Train

In [None]:
# metrics[seed][t] holds the test-set angle_err of the curve model at Bezier
# parameter t for the run started from that seed
metrics = defaultdict(dict)
for seed in range(10):
    # Train model
    model = LeNet(jax.random.key(seed))
    opt = opt_create(model, learning_rate=1e-3)
    ds_train = fetch_data(skew="feature", beta=1., n_clients=2)
    ds_val = fetch_data(skew="feature", beta=1., partition="val", batch_size=16, n_clients=2)
    models, _ = train(model, opt, ds_train, return_l2(0.), ds_val, local_epochs="early", n_clients=2, max_patience=4, rounds="early", val_fn=angle_err)

    # Get client models: entries 0 and 1 along the leading axis of every leaf
    # (assumed to be the client axis -- TODO confirm against fedflax.train)
    struct, theta, rest = nnx.split(models, (nnx.Param, nnx.BatchStat), ...)
    theta1 = jax.tree.map(lambda p: p[0], theta)
    theta2 = jax.tree.map(lambda p: p[1], theta)

    # Bezier parameter and optimizer. The control point is drawn from a
    # standard normal, one fresh subkey per leaf; `key` is only ever split,
    # never consumed directly.
    key = jax.random.key(seed)
    leaves, treedef = jax.tree.flatten(theta1)
    w_leaves = []
    for leaf in leaves:
        key, subkey = jax.random.split(key)
        w_leaves.append(jax.random.normal(subkey, shape=leaf.shape))
    w = jax.tree.unflatten(treedef, w_leaves)
    opt = optax.adamw(learning_rate=1e-3)
    opt_state = opt.init(w)

    # Optimize Bezier curve
    train_step_fixed = train_step(opt, theta1, theta2, struct, rest)
    for epoch in (bar:=tqdm(range(1000))):
        for batch in ds_train:
            # Concatenate both clients' data
            batch = jax.tree.map(lambda x: jnp.reshape(x, (-1, *x.shape[2:])), batch)
            # Split before use: the subkey is consumed by this step only, and
            # the carried key is never passed to a sampler
            key, step_key = jax.random.split(key)
            loss, opt_state, w = train_step_fixed(w, *batch, opt_state, step_key)
            bar.set_description(f"Bezier optimization (seed {seed}). Loss {loss.item():.4f}")

    # Evaluate along the curve
    ds_test = fetch_data(skew="feature", beta=1., partition="test", batch_size=16, n_clients=2)
    # The model is shared (in_axes None); the batch entries are mapped over axis 0
    vval_fn = nnx.vmap(angle_err, in_axes=(None,0,0,0))
    for t in jnp.linspace(0., 1., 50):
        theta = jax.tree.map(bezier(t), theta1, theta2, w)
        # bezier always adds a leading t-axis; drop it for the single t used here
        theta = jax.tree.map(lambda p: p.squeeze(0), theta)
        model = nnx.merge(struct, theta, rest)
        loss = reduce(lambda acc, batch: acc+vval_fn(model, *batch), ds_test, 0.).mean().item() / len(ds_test)
        metrics[seed][t.item()] = loss

    # Checkpoint the results after every seed; the context manager closes the file
    with open("bezier_interpolation.json", "w") as f:
        json.dump(metrics, f)

## Visualize

In [None]:
# Rendering style
plt.style.use("seaborn-v0_8-pastel")
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Times"],
    "font.sans-serif": ["Helvetica"],
    "text.latex.preamble": r"""
        \usepackage{amsmath, amssymb}
        \usepackage{mathptmx} 
    """
})

# Stack the per-seed curves into a (seeds, t-values) array. A separate name is
# used so that `metrics` stays a dict and this cell can be re-run.
angle_errs = jnp.asarray([list(s.values()) for s in metrics.values()])
std = angle_errs.std(0)
mean = angle_errs.mean(0)
# Evenly spaced grid on [0, 1] with one point per evaluated t
ts = jnp.linspace(0., 1., angle_errs.shape[1])

# Mean curve with a +/- one standard deviation band across seeds
fig, ax = plt.subplots()
ax.plot(ts, mean)
ax.fill_between(ts, mean - std, mean + std, alpha=0.3)
ax.set_xlabel("Bezier parameter $t$")
ax.set_ylabel("Angle error (degrees)")
ax.set_xlim(0., 1.)
ax.set_ylim(0., 90.)
fig.savefig("bezier_interpolation.png", dpi=300, bbox_inches="tight");