In [None]:
%load_ext autoreload
%autoreload 2
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
import optax, jax, numpy as np, pickle
from flax import nnx
from jax import numpy as jnp
from vistool import plot_trajectory, compute_surface, get_reconstruct
from sklearn.decomposition import IncrementalPCA
from fedflax import train
from tqdm.auto import tqdm
from data import get_gaze
from models import LeNet
# Number of federated clients. Also used in the Visualizations cell to size the trimmed
# early-stopping tail of each saved trajectory and to tag each saved row via `arange(...) % n`.
n = 4 # number of clients

## Train

In [None]:
# Optimizer factory handed to `train`
def opt(model):
    """Build an AdamW optimizer (learning rate 1e-3) that updates only `model`'s `nnx.Param`s."""
    transform = optax.adamw(learning_rate=1e-3)
    return nnx.Optimizer(model, transform, wrt=nnx.Param)

# Loss includes softmax layer: the model is expected to output logits, and
# `optax.softmax_cross_entropy` applies the softmax internally.
def return_ell(omega):
    """Build the local training loss with proximal strength `omega`.

    Args:
        omega: weight of the proximal penalty ``omega/2 * ||w - w_g||^2``; 0 disables it.

    Returns:
        ``ell(model, model_g, x_batch, z_batch, y_batch, train)`` returning a scalar loss:
        mean softmax cross-entropy of ``model(x_batch, z_batch, train=train)`` against
        ``y_batch`` plus the proximal penalty between ``model`` and ``model_g``
        (presumably the current global model — confirm against `fedflax.train`).
    """
    def ell(model, model_g, x_batch, z_batch, y_batch, train):
        # Penalize only trainable weights (same filter as the optimizer's `wrt=nnx.Param`).
        # Other Variables (e.g. RNG state or batch statistics, if the model has any) are not
        # weights and must not enter the proximal distance.
        params = jax.tree.leaves(nnx.state(model, nnx.Param))
        params_g = jax.tree.leaves(nnx.state(model_g, nnx.Param))
        prox = sum(jax.tree.map(lambda a, b: jnp.sum((a-b)**2), params, params_g))
        ce = optax.softmax_cross_entropy(model(x_batch, z_batch, train=train), y_batch).mean()
        return omega/2*prox + ce
    return ell

# Settings shared by every run. Keep `max_patience` in sync with the patience used when
# trimming saved trajectories in the Visualizations cell.
rounds = 20
max_patience = 5


def train_and_save(omega, beta, local_epochs, ds_train, ds_val):
    """Train a freshly seeded LeNet with proximal strength `omega` and write its output file.

    The output path encodes (omega, beta, local_epochs). It is passed to `fedflax.train`,
    whose output is read back by the Visualizations cell under the same naming scheme.
    """
    path = f"grids/MPIIGaze/params_omega{omega}_beta{beta}_local{local_epochs}.npy"
    train(LeNet(nnx.Rngs(42)), opt, ds_train, ds_val, return_ell(omega), local_epochs, path,
          rounds=rounds, max_patience=max_patience)


# Sweep data heterogeneity (beta) x local epochs with no proximal term (omega = 0)
for beta in [.02,.5,1.]:
    ds_train = get_gaze(beta=beta)
    ds_val = get_gaze(beta=beta, partition="val", batch_size=16)
    for local_epochs in [5,10,20]:
        train_and_save(0., beta, local_epochs, ds_train, ds_val)

# Two additional proximal strengths at beta = .5, local_epochs = 10
# (omega = 0 at that setting is already produced by the sweep above)
for omega in [.025,.1]:
    ds_train = get_gaze(beta=.5)
    ds_val = get_gaze(beta=.5, partition="val", batch_size=16)
    train_and_save(omega, .5, 10, ds_train, ds_val)

## Visualizations

In [None]:
# Project every saved parameter trajectory onto a shared 2-D PCA plane, then plot each
# trajectory over the error surface that vistool.compute_surface evaluates on that plane.
b = 50  # rows per IncrementalPCA.partial_fit call (batch-wise due to memory constraints)
max_patience = 5  # keep in sync with `max_patience` passed to `train` in the Train cell
pca = IncrementalPCA(2, whiten=True, batch_size=b)


def load_trajectory(omega, beta, local_epochs):
    """Memory-map one run's saved parameter trajectory with its patience tail cut off.

    NOTE(review): assumes the file holds n * local_epochs rows per round and that the last
    `max_patience` rounds are the non-improving early-stopping rounds — TODO confirm against
    `fedflax.train` (a run that reached `rounds` without early stopping would lose good rows).
    """
    path = f"grids/MPIIGaze/params_omega{omega}_beta{beta}_local{local_epochs}.npy"
    return np.load(path, mmap_mode="r")[:-max_patience*n*local_epochs]


def partial_fit_in_batches(pca, fp):
    """Feed `fp` to `pca.partial_fit` in chunks of at most ~`b` rows; return the fitted PCA."""
    for batch in tqdm(np.array_split(fp, fp.shape[0]//b+1), leave=False):
        pca = pca.partial_fit(batch)
    return pca


def plot_trajectories_on_surface(trajectories, epochs_list, filenames, ds_test, resolution=20):
    """Plot several runs' PCA-projected trajectories over one shared error surface.

    Axis limits span all given trajectories, so every figure of the group uses the same frame
    and the same `err_grid`. One figure per trajectory is written to the matching filename.
    """
    paramses_trans = [pca.transform(t) for t in trajectories]
    alpha_min, beta_min = min(p[:,0].min() for p in paramses_trans), min(p[:,1].min() for p in paramses_trans)
    alpha_max, beta_max = max(p[:,0].max() for p in paramses_trans), max(p[:,1].max() for p in paramses_trans)
    alpha_grid = jnp.linspace(alpha_min, alpha_max, resolution)
    beta_grid = jnp.linspace(beta_min, beta_max, resolution)
    err_grid = compute_surface(alpha_grid, beta_grid, pca, reconstruct, ds_test)
    for params_trans, epochs, filename in zip(paramses_trans, epochs_list, filenames):
        model_idx = jnp.arange(params_trans.shape[0])%n
        # Pass this run's own trajectory (the one `model_idx` was built from), not the whole
        # list. `epochs` goes by keyword so both plot groups bind arguments identically.
        plot_trajectory(err_grid, model_idx, params_trans, alpha_grid, beta_grid, labels=False,
                        epochs=epochs, filename=filename)


# Fit PCA on the beta/local-epoch sweep (omega = 0) ...
for beta in tqdm([1.,.5,.02]):
    for epochs in tqdm([20,10,5], leave=False):
        pca = partial_fit_in_batches(pca, load_trajectory(0., beta, epochs))
# ... and on the extra omega runs at beta = .5, local_epochs = 10
for omega in tqdm([.025,.1]):
    pca = partial_fit_in_batches(pca, load_trajectory(omega, .5, 10))

# Save PCA model (the context manager closes and flushes the file before it is re-read below)
with open("grids/MPIIGaze/pca.pkl", "wb") as f:
    pickle.dump(pca, f)

# Reconstruct function
# NOTE(review): meaning of the positional `False` depends on vistool.get_reconstruct — TODO name it
reconstruct = get_reconstruct(LeNet(nnx.Rngs(42)), False)

# Reload the saved PCA model. pickle.load can execute arbitrary code: only load a pca.pkl
# written by this notebook.
with open("grids/MPIIGaze/pca.pkl", "rb") as f:
    pca = pickle.load(f)

# Per heterogeneity level beta: one figure per local-epochs setting, sharing axes and surface
for beta in [.02,.5,1.]:
    ds_test = get_gaze(beta=beta, partition="test", batch_size=16)
    local_epochs_list = [5,10,20]
    plot_trajectories_on_surface(
        [load_trajectory(0., beta, epochs) for epochs in local_epochs_list],
        local_epochs_list,
        [f"grids/MPIIGaze/plot_beta{beta}_local{epochs}.png" for epochs in local_epochs_list],
        ds_test,
    )

# Fixed beta = .5 and local_epochs = 10: one figure per omega, sharing axes and surface
ds_test = get_gaze(beta=.5, partition="test", batch_size=16)
omegas = [0.,.025,.1]
plot_trajectories_on_surface(
    [load_trajectory(omega, .5, 10) for omega in omegas],
    [10]*len(omegas),
    [f"grids/MPIIGaze/plot_omega{omega}.png" for omega in omegas],
    ds_test,
)