In [None]:
%load_ext autoreload
%autoreload 2
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
import jax, numpy as np, pickle
from flax import nnx
from jax import numpy as jnp
from vistool import plot_trajectory, compute_surface, get_reconstruct
from sklearn.decomposition import IncrementalPCA
from fedflax import train
from tqdm.auto import tqdm
from data import fetch_data
from models import LeNet
from utils import angle_err, opt_create, return_l2
from functools import partial
from itertools import product
N_CLIENTS = 4 # number of clients

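## Train

Train one federated LeNet per `(beta, skew)` setting and save each run's parameter trajectory under `grids/MPIIGaze/`.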

In [None]:
# Every (beta, skew) combination to train and later visualize. Materialised as a list:
# `itertools.product` returns a one-shot iterator, which this loop would exhaust, leaving
# the visualization loop (`for beta, skew in tqdm(settings)`) with nothing to iterate.
settings = list(product([0., .5, 1.], ["feature", "overlap"]))
for beta, skew in settings:
    # Per-client training and validation data for this setting
    ds_train = fetch_data(skew=skew, beta=beta, n_clients=N_CLIENTS)
    ds_val = fetch_data(skew=skew, beta=beta, partition="val", batch_size=16, n_clients=N_CLIENTS)
    # Federated training from a fixed initialization (key 42). The file at `filename` is the
    # parameter trajectory that the visualization cell loads (presumably written by `train`).
    # NOTE(review): "early" is passed both positionally and as `rounds="early"`; if the 6th
    # positional parameter of fedflax.train is `rounds`, this raises TypeError — confirm.
    train(LeNet(jax.random.key(42)), partial(opt_create, learning_rate=1e-3),
          ds_train, return_l2(0.), ds_val, "early", filename=f"grids/MPIIGaze/params_{skew}{beta}.npy",
          n_clients=N_CLIENTS, rounds="early", max_patience=5, val_fn=angle_err)

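## Visualizations

Fit a 2-component PCA to all saved parameter trajectories, then plot each setting's projected trajectory on top of the test-error surface evaluated over a PCA grid shared by all settings.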

In [None]:
# Shared location of the saved trajectories, the fitted PCA and the output plots
GRID_DIR = "grids/MPIIGaze"
# Trailing trajectory rows recorded during early-stopping patience, which are cut before
# fitting/plotting: patience x clients x max local epochs (as in the original slicing).
PATIENCE = 5  # must match `max_patience` passed to `train`
MAX_LOCAL_EPOCHS = ...  # TODO: set to the maximum number of local epochs used by `train`
if MAX_LOCAL_EPOCHS is Ellipsis:
    raise ValueError("Set MAX_LOCAL_EPOCHS to the maximum number of local epochs used by `train`.")
N_TRIM = PATIENCE * N_CLIENTS * MAX_LOCAL_EPOCHS
PCA_BATCH = 50     # rows per IncrementalPCA batch (fit is batch-wise due to memory constraints)
GRID_RES = 30      # grid points per principal axis
GRID_MARGIN = 0.1  # fraction of the projected range added on each side of the grid


def load_trajectory(path, n_trim=N_TRIM):
    """Memory-map a saved parameter trajectory and drop its last ``n_trim`` rows.

    An explicit stop index is used because ``arr[:-0]`` would return an empty array.
    Trimming more rows than exist yields an empty array.
    """
    params = np.load(path, mmap_mode="r")
    return params[:max(params.shape[0] - n_trim, 0)]


def row_batches(arr, batch_size=PCA_BATCH):
    """Split ``arr`` along axis 0 into non-empty chunks of at most ~``batch_size`` rows."""
    return [chunk for chunk in np.array_split(arr, arr.shape[0] // batch_size + 1) if len(chunk)]


def project_batched(pca, arr):
    """Apply ``pca.transform`` chunk-wise; returns an ``(n_rows, n_components)`` array."""
    chunks = [pca.transform(chunk) for chunk in row_batches(arr)]
    return np.concatenate(chunks) if chunks else np.empty((0, pca.n_components_))


# Fit a 2-component whitened PCA to all training trajectories. Files are sorted so the
# incremental fit sees its batches in a reproducible order (os.listdir order is arbitrary).
traj_files = sorted(f for f in os.listdir(GRID_DIR) if f.endswith(".npy"))
pca = IncrementalPCA(2, whiten=True, batch_size=PCA_BATCH)
for fname in tqdm(traj_files, desc="fitting PCA"):
    for chunk in row_batches(load_trajectory(f"{GRID_DIR}/{fname}")):
        pca.partial_fit(chunk)
# Save PCA model; `with` guarantees the file is flushed and closed
with open(f"{GRID_DIR}/pca.pkl", "wb") as fh:
    pickle.dump(pca, fh)
# To skip refitting: `with open(f"{GRID_DIR}/pca.pkl", "rb") as fh: pca = pickle.load(fh)`

# Project every trajectory once. Grid limits come from these projected coordinates, so they
# live in the same whitened space as `pca.transform` (the previous feature-wise min/max
# projection skipped the whitening) and component signs no longer matter. The same limits
# are shared by all settings so the axes are comparable across plots.
projected = {fname: project_batched(pca, load_trajectory(f"{GRID_DIR}/{fname}")) for fname in traj_files}
all_coords = np.concatenate(list(projected.values()))
lo, hi = all_coords.min(axis=0), all_coords.max(axis=0)
pad = GRID_MARGIN * (hi - lo)
lo, hi = lo - pad, hi + pad
alpha_grid = jnp.linspace(lo[0], hi[0], GRID_RES)
beta_grid = jnp.linspace(lo[1], hi[1], GRID_RES)

# Reconstruct function (used by compute_surface together with `pca`)
reconstruct = get_reconstruct(LeNet(jax.random.key(42)), False)
# For each setting, plot its PCA training trajectory over the test-error surface
for beta, skew in tqdm(settings):
    # Recreate test data for this setting
    ds_test = fetch_data(beta=beta, partition="test", batch_size=16, n_clients=N_CLIENTS, skew=skew)
    params_trans = projected[f"params_{skew}{beta}.npy"]
    # Distinct names for the returned grids: the shared `alpha_grid`/`beta_grid` must not be
    # overwritten, or the next iteration would feed compute_surface its own output.
    err_grid, alpha_mesh, beta_mesh = compute_surface(alpha_grid, beta_grid, pca, reconstruct, ds_test, val_fn=angle_err)
    # Client index of each trajectory row (assumes rows are interleaved by client)
    model_idx = jnp.arange(params_trans.shape[0]) % N_CLIENTS
    # NOTE(review): `...` is forwarded unchanged as a placeholder argument — confirm what
    # plot_trajectory expects in this position.
    plot_trajectory(err_grid, model_idx, ..., params_trans, alpha_mesh, beta_mesh, labels=False,
                    filename=f"{GRID_DIR}/plot_{skew}{beta}.png")