In [None]:
%load_ext autoreload
%autoreload 2
from jax import config
config.update("jax_platforms", "cpu")
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
import jax, optax, pickle
from jax import numpy as jnp
from flax import nnx
from fedflax import train
from models import LeNet
from data import fetch_data
from utils import opt_create, return_l2
from tqdm.auto import tqdm

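## Bezier interpolation setup

Given two endpoint parameter sets $\theta_1, \theta_2$ and a trainable control point $w$, the quadratic Bezier curve is $\phi_w(t) = (1-t)^2\,\theta_1 + 2t(1-t)\,w + t^2\,\theta_2$ for $t \in [0, 1]$. We train $w$ by minimising the cross-entropy of the model with parameters $\phi_w(t)$, with $t \sim U[0, 1)$ resampled at every step.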

In [None]:
def bezier(t):
    """Build an evaluator for the quadratic Bezier curve at parameter ``t``.

    The returned callable maps ``(theta1, theta2, w)`` to
    ``(1 - t)**2 * theta1 + 2 * (1 - t) * t * w + t**2 * theta2``: the curve
    with endpoints ``theta1`` (t = 0) and ``theta2`` (t = 1) and control point
    ``w``. It is applied leaf-by-leaf via ``jax.tree.map``.
    """
    s = 1 - t

    def point_on_curve(theta1, theta2, w):
        return s ** 2 * theta1 + 2 * s * t * w + t ** 2 * theta2

    return point_on_curve

def ell_tilde(w, x, y, key):
    """Mean softmax cross-entropy of the model at a random point on the curve.

    A single ``t ~ U[0, 1)`` is drawn from ``key`` for the whole batch. The
    parameters at ``t`` are built leaf-wise from the global endpoints
    ``theta1``/``theta2`` and the control point ``w``. The model is then
    reassembled with the globals ``struct``/``rest`` and evaluated on ``x``
    against the targets ``y``.

    TODO: ``t`` should actually be sampled independently per example.
    """
    t = jax.random.uniform(key, ())
    params_at_t = jax.tree.map(bezier(t), theta1, theta2, w)
    logits = nnx.merge(struct, params_at_t, rest)(x)
    losses = optax.softmax_cross_entropy(logits, y)
    return losses.mean()

@jax.jit
def train_step(w, x, y, opt_state, key):
    """Apply one optimizer update to the Bezier control point ``w``.

    Returns ``(loss, opt_state, w)``: the loss evaluated at the incoming
    ``w``, followed by the updated optimizer state and control point.

    NOTE: the optimizer ``opt`` is read from the module globals, and so are
    the curve endpoints and model structure (via ``ell_tilde``). ``jax.jit``
    captures globals as constants when it traces, and it does not retrace
    when they are reassigned. Clear the jit caches (e.g. ``jax.clear_caches()``)
    whenever those globals change.
    """
    loss_and_grad = nnx.value_and_grad(ell_tilde)
    loss, grads = loss_and_grad(w, x, y, key)
    updates, new_opt_state = opt.update(grads, opt_state, w)
    new_w = optax.apply_updates(w, updates)
    return loss, new_opt_state, new_w

## Train client models and fit a Bezier curve between them

In [None]:
# Number of independent repetitions, and Bezier-optimization epochs per repetition.
N_SEEDS = 10
BEZIER_EPOCHS = 20

for seed in range(N_SEEDS):
    # Train the client models with the federated `train` routine.
    model = LeNet(jax.random.key(seed))
    opt = opt_create(model, learning_rate=1e-3)
    ds_train = fetch_data(skew="feature", beta=1., n_clients=2)
    ds_val = fetch_data(skew="feature", beta=1., partition="val", batch_size=16, n_clients=2)
    ds_test = fetch_data(skew="feature", beta=1., partition="test", batch_size=16, n_clients=2)
    _, models = train(model, opt, ds_train, return_l2(0.), ds_val, local_epochs="early", n_clients=2, max_patience=4, rounds="early")

    # Split off Param/BatchStat state. Index 0 and 1 of each leaf's leading
    # axis are the two client models, which serve as the curve endpoints.
    struct, theta, rest = nnx.split(models, (nnx.Param, nnx.BatchStat), ...)
    theta1 = jax.tree.map(lambda p: p[0], theta)
    theta2 = jax.tree.map(lambda p: p[1], theta)

    # Bezier control point: an independent standard-normal draw for every leaf
    # of theta1, each with its own key (no hidden global key mutation).
    # NOTE(review): theta also holds any nnx.BatchStat leaves, which receive the
    # same random init -- confirm that is intended.
    key = jax.random.key(seed)
    leaves, treedef = jax.tree.flatten(theta1)
    key, *leaf_keys = jax.random.split(key, len(leaves) + 1)
    w = jax.tree.unflatten(
        treedef,
        [jax.random.normal(k, shape=leaf.shape) for k, leaf in zip(leaf_keys, leaves)],
    )
    opt = optax.adamw(learning_rate=1e-3)  # `train_step` reads this global
    opt_state = opt.init(w)

    # `train_step` is jitted and reads theta1/theta2/struct/rest/opt as globals.
    # jit bakes them in at trace time and does not retrace when only their
    # values change, so without this every seed after the first would silently
    # optimise the curve between seed 0's endpoints.
    jax.clear_caches()

    # Optimize Bezier curve.
    # NOTE(review): assumes `ds_train` yields batches covering both clients and
    # can be re-iterated once per epoch -- confirm against `fetch_data`.
    for epoch in tqdm(range(BEZIER_EPOCHS)):
        for x_batch, y_batch in ds_train:
            # Fresh subkey per step; never reuse a key that is also split later.
            key, step_key = jax.random.split(key)
            loss, opt_state, w = train_step(w, x_batch, y_batch, opt_state, step_key)

    # Context manager guarantees the file is flushed and closed.
    with open(f"bezier_model_seed{seed}.pkl", "wb") as f:
        pickle.dump((theta1, theta2, w), f)