In [None]:
# Automatically re-import edited modules before each cell runs, so changes to the
# project helpers imported below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
import os
# Request deterministic XLA GPU ops for run-to-run reproducibility of the trained trajectories.
# Set before importing jax so the XLA backend sees the flag when it initialises; if jax was
# already initialised in this kernel, restart the kernel for the flag to take effect.
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
import optax, jax, numpy as np, pickle
from flax import nnx
from jax import numpy as jnp
from vistool import plot_trajectory, compute_surface, get_reconstruct
from sklearn.decomposition import IncrementalPCA
from fedflax import train
from tqdm.auto import tqdm
from data import fetch_data
from models import LeNet
from utils import angle_err, opt_create, return_l2
from functools import partial
# Number of clients; used by both the training and the visualisation cells
# (data partitioning, file-row bookkeeping and per-client colouring).
N_CLIENTS = 4 # number of clients

## Train

In [None]:
# Train every configuration and save its full parameter trajectory under OUT_DIR.
#   Grid 1: omega = 0 for each data-heterogeneity level beta x local-epochs setting.
#   Grid 2: two non-zero omega values at beta = 0.5, local_epochs = 10. The omega = 0 run for
#           that setting comes from grid 1, giving three omega settings for the comparison plot.
OUT_DIR = "grids/MPIIGaze"
PATIENCE = 5  # early-stopping patience in rounds; the visualisation cell trims these rounds off again
os.makedirs(OUT_DIR, exist_ok=True)  # make a fresh checkout runnable; no-op if it already exists


def run_training(ds_train, ds_val, omega, beta, local_epochs):
    """Train a freshly initialised LeNet with fedflax.train and save its trajectory.

    Every run starts from the same seed (42) and learning rate (1e-3) and uses early stopping,
    so runs differ only in omega, the data split (beta) and local_epochs.

    omega: weight passed to return_l2 (presumably an L2 penalty strength — confirm in utils).
    beta: only used to name the output file; the data itself comes from ds_train / ds_val.
    local_epochs: local epochs per round, also part of the output filename.
    """
    train(
        LeNet(jax.random.key(42)),
        partial(opt_create, learning_rate=1e-3),
        ds_train,
        return_l2(omega),
        ds_val,
        local_epochs=local_epochs,
        # omega/beta are floats, so names read e.g. params_omega0.0_beta0.5_local10.npy
        filename=f"{OUT_DIR}/params_omega{omega}_beta{beta}_local{local_epochs}.npy",
        n_clients=N_CLIENTS,
        rounds="early",
        max_patience=PATIENCE,
        val_fn=angle_err,
    )


# Grid 1: omega = 0 for every beta x local_epochs combination.
for beta in [.02, .5, 1.]:
    # Build the feature-skewed client splits once per beta; all local_epochs runs reuse them.
    ds_train = fetch_data(skew="feature", beta=beta, n_clients=N_CLIENTS)
    ds_val = fetch_data(skew="feature", beta=beta, partition="val", batch_size=16, n_clients=N_CLIENTS)
    for local_epochs in [5, 10, 20]:
        run_training(ds_train, ds_val, omega=0., beta=beta, local_epochs=local_epochs)

# Grid 2: non-zero omega at a fixed beta / local_epochs. The data does not depend on omega,
# so it is fetched once (reused across runs exactly like in grid 1).
beta, local_epochs = .5, 10
ds_train = fetch_data(skew="feature", beta=beta, n_clients=N_CLIENTS)
ds_val = fetch_data(skew="feature", beta=beta, partition="val", batch_size=16, n_clients=N_CLIENTS)
for omega in [.025, .1]:
    run_training(ds_train, ds_val, omega=omega, beta=beta, local_epochs=local_epochs)

## Visualizations

In [None]:
# Fit one shared 2-D PCA over all saved trajectories, then plot each trajectory on top of the
# test-error surface evaluated on that PCA plane.
OUT_DIR = "grids/MPIIGaze"
PCA_PATH = f"{OUT_DIR}/pca.pkl"
PATIENCE = 5  # must equal max_patience used in the training cell
PCA_BATCH = 50  # target rows per IncrementalPCA.partial_fit / transform call
GRID_POINTS = 20  # error-surface resolution along each principal component
REFIT_PCA = True  # set False to reuse PCA_PATH from an earlier run instead of refitting


def load_trajectory(omega, beta, local_epochs):
    """Memory-map one saved parameter trajectory, dropping the trailing patience rounds.

    NOTE(review): assumes each federated round stores N_CLIENTS * local_epochs rows, so the
    last PATIENCE rounds are the last PATIENCE * N_CLIENTS * local_epochs rows — confirm
    against fedflax.train.
    """
    params = np.load(f"{OUT_DIR}/params_omega{omega}_beta{beta}_local{local_epochs}.npy", mmap_mode="r")
    return params[:-PATIENCE * N_CLIENTS * local_epochs]


def row_batches(params):
    """Split rows into params.shape[0] // PCA_BATCH + 1 near-equal chunks of at most PCA_BATCH rows.

    Unlike fixed-size slicing, this never leaves a tiny remainder chunk. That matters because
    IncrementalPCA.partial_fit rejects batches with fewer rows than components.
    """
    return np.array_split(params, params.shape[0] // PCA_BATCH + 1)


def project(pca, params):
    """Project trajectory rows onto the PCA plane chunk by chunk.

    Transforming the whole memmap at once would pull the full trajectory into memory, which
    the batch-wise fit is there to avoid. The projection is row-wise, so concatenating the
    per-chunk results gives the same output.
    """
    return np.concatenate([pca.transform(chunk) for chunk in row_batches(params)])


def plot_on_shared_surface(runs, ds_test, pca, reconstruct):
    """Plot several projected trajectories over one test-error surface.

    runs: list of (projected trajectory, local_epochs, output filename). The surface spans the
    union of all trajectories' PCA coordinates, so every plot in the group shares its axes.
    """
    projected = [params_trans for params_trans, _, _ in runs]
    pc1_grid = jnp.linspace(min(p[:, 0].min() for p in projected), max(p[:, 0].max() for p in projected), GRID_POINTS)
    pc2_grid = jnp.linspace(min(p[:, 1].min() for p in projected), max(p[:, 1].max() for p in projected), GRID_POINTS)
    err_grid, pc1_grid, pc2_grid = compute_surface(pc1_grid, pc2_grid, pca, reconstruct, ds_test, val_fn=angle_err)
    for params_trans, local_epochs, filename in runs:
        # Row i is attributed to client i % N_CLIENTS.
        model_idx = jnp.arange(params_trans.shape[0]) % N_CLIENTS
        plot_trajectory(err_grid, model_idx, local_epochs, params_trans, pc1_grid, pc2_grid,
                        labels=False, filename=filename)


# 1) PCA over every trained run: the omega = 0 grid plus the two non-zero omega runs.
if REFIT_PCA:
    # Fit order is fixed deliberately: IncrementalPCA is an incremental approximation, so the
    # fitted components can depend on the order in which batches arrive.
    fit_runs = ([(0., beta, e) for beta in [1., .5, .02] for e in [20, 10, 5]]
                + [(omega, .5, 10) for omega in [.025, .1]])
    pca = IncrementalPCA(2, whiten=True, batch_size=PCA_BATCH)
    for omega, beta, local_epochs in tqdm(fit_runs, desc="PCA fit"):
        for chunk in tqdm(row_batches(load_trajectory(omega, beta, local_epochs)), leave=False):
            pca.partial_fit(chunk)
    with open(PCA_PATH, "wb") as fh:
        pickle.dump(pca, fh)
else:
    # pickle.load can execute arbitrary code: only load a file this notebook wrote itself.
    with open(PCA_PATH, "rb") as fh:
        pca = pickle.load(fh)

# Maps PCA-plane points back to LeNet parameters for compute_surface.
reconstruct = get_reconstruct(LeNet(jax.random.key(42)), False)

# 2) Per heterogeneity level beta: one error surface shared by its three local-epochs runs.
for beta in [.02, .5, 1.]:
    ds_test = fetch_data(beta=beta, partition="test", batch_size=16, n_clients=N_CLIENTS, skew="feature")
    plot_on_shared_surface(
        [(project(pca, load_trajectory(0., beta, e)), e, f"{OUT_DIR}/plot_beta{beta}_local{e}.png")
         for e in [5, 10, 20]],
        ds_test, pca, reconstruct,
    )

# 3) Omega comparison at beta = 0.5, local_epochs = 10 (the omega = 0 run comes from the grid).
ds_test = fetch_data(skew="feature", beta=.5, partition="test", batch_size=16, n_clients=N_CLIENTS)
plot_on_shared_surface(
    [(project(pca, load_trajectory(omega, .5, 10)), 10, f"{OUT_DIR}/plot_omega{omega}.png")
     for omega in [0., .025, .1]],
    ds_test, pca, reconstruct,
)