In [None]:
%load_ext autoreload
%autoreload 2
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
import optax, jax, numpy as np, pickle
from flax import nnx
from jax import numpy as jnp
from sklearn.decomposition import PCA
from fedflax import train
from data import get_gaze
from models import LeNet, ResNet
from functools import reduce
import matplotlib as mpl
from matplotlib import pyplot as plt

In [None]:
# Optimizer factory handed to `train`: AdamW over the model's nnx.Param variables.
def opt_create(model, learning_rate=1e-3):
    """Create the optimizer used for local training.

    Args:
        model: Flax NNX module whose ``nnx.Param`` variables are optimized.
        learning_rate: AdamW step size (default ``1e-3``).

    Returns:
        An ``nnx.Optimizer`` wrapping ``optax.adamw`` for ``model``.
    """
    return nnx.Optimizer(
        model,
        optax.adamw(learning_rate=learning_rate),
        wrt=nnx.Param,
    )

# Loss: the softmax is applied inside the loss, not by the model
def ell(model, _, x_batch, z_batch, y_batch, train):
    """Mean softmax cross-entropy of the model's outputs on one batch.

    ``optax.softmax_cross_entropy`` applies the softmax itself, so ``model``
    is expected to return unnormalised logits.

    Args:
        model: callable invoked as ``model(x_batch, z_batch, train=train)``.
        _: ignored positional argument (presumably part of the loss signature
            that ``train`` expects — confirm against fedflax).
        x_batch: first model input.
        z_batch: second model input.
        y_batch: target distribution / one-hot labels matching the logits.
        train: forwarded to the model's ``train`` keyword.

    Returns:
        ``(loss, (0., 0.))``: the scalar mean loss and a constant pair of
        placeholder auxiliary values.
    """
    logits = model(x_batch, z_batch, train=train)
    loss = optax.softmax_cross_entropy(logits, y_batch).mean()
    placeholder_aux = (0., 0.)
    return loss, placeholder_aux

# Train ResNet clients on the gaze data (beta=.5): 1 round of 50 local epochs
ds_train = get_gaze(beta=.5)
ds_val = get_gaze(beta=.5, partition="val")
train_config = {"local_epochs": 50, "rounds": 1}
_, models = train(ResNet, opt_create, ds_train, ds_val, ell, **train_config)
# Flatten the trained model state into a list of leaf arrays plus the tree
# structure needed to rebuild it (leaves presumably carry a leading client
# axis, as the PCA step below assumes — confirm).
model_tree = nnx.to_tree(models)
paramses, struct = jax.tree.flatten(model_tree)

# Perform PCA on the clients' parameters: one row per client (the leading axis
# of every leaf), with all leaves flattened and concatenated column-wise. The
# client count is taken from the data instead of being hard-coded, and
# np.concatenate is used because np.concat only exists in NumPy >= 2.0.
pca = PCA(n_components=2, whiten=True)
paramses_trans = pca.fit_transform(
    np.concatenate([p.reshape(p.shape[0], -1) for p in paramses], axis=1)
)

# Function to reconstruct a model from a flat parameter vector.
# Per-leaf shapes with the leading (client) axis dropped. The trailing None is
# kept so `shapes` has the same contents as before; it is not used below.
shapes = [p.shape[1:] for p in paramses]+[None]
# Start offset of every leaf inside the flat vector, with the total length as
# the final entry. Computed once rather than on every call. Sizes are cast to
# int because np.prod(()) is the float 1.0, which would otherwise make the
# slice bounds floats for scalar leaves.
_offsets = np.cumsum([0] + [int(np.prod(s)) for s in shapes[:-1]])


def reconstruct(flat_params):
    """Rebuild a model from a flat parameter vector.

    Args:
        flat_params: 1-D array holding every leaf of ``struct`` (without the
            client axis), flattened and concatenated in ``jax.tree.flatten``
            order, e.g. a row returned by ``pca.inverse_transform``.

    Returns:
        The result of ``nnx.from_tree`` on the unflattened leaves.

    Raises:
        ValueError: if ``flat_params`` does not contain exactly the expected
            number of entries.
    """
    expected = int(_offsets[-1])
    if len(flat_params) != expected:
        raise ValueError(
            f"expected {expected} flat parameters, got {len(flat_params)}"
        )
    # Cut the flat vector at the precomputed offsets and restore each shape
    params = [
        jnp.array(flat_params[start:stop]).reshape(shape)
        for start, stop, shape in zip(_offsets[:-1], _offsets[1:], shapes)
    ]
    # Revert to model
    return nnx.from_tree(jax.tree.unflatten(struct, params))

# Set up grid for error surface: `points` x `points` coordinates spanning the
# clients' PCA projections, padded on each side by a quarter of the extreme
# coordinate's magnitude.
points = 30
n_clients = paramses_trans.shape[0]
x_min, x_max = paramses_trans[:, 0].min(), paramses_trans[:, 0].max()
y_min, y_max = paramses_trans[:, 1].min(), paramses_trans[:, 1].max()
alpha_grid = jnp.linspace(x_min-1/4*jnp.abs(x_min), x_max+1/4*jnp.abs(x_max), points)
beta_grid = jnp.linspace(y_min-1/4*jnp.abs(y_min), y_max+1/4*jnp.abs(y_max), points)
# Filled in place with NumPy (a jnp `.at[].set` copies the whole array on every
# grid point) and converted to a JAX array once the loop is done.
errs_buf = np.zeros((points, points, n_clients))
# For sampled points on the 2d plane, compute the per-client test error
ds_test = get_gaze(beta=.5, partition="test")
# Accuracy of one shared model (in_axes None) on a batch, vmapped over the
# leading axis of x, z and y.
acc_fn = nnx.jit(nnx.vmap(lambda m,x,z,y: (m(x,z,train=False).argmax(-1)==y.argmax(-1)).mean(), in_axes=(None,0,0,0)))
for i, alpha in enumerate(alpha_grid):
    for j, beta in enumerate(beta_grid):
        # Map the 2-D point back to a full parameter vector and rebuild a model.
        # Bound to `model` so the trained `models` are not overwritten.
        params = pca.inverse_transform(jnp.array([[alpha, beta]])).reshape(-1)
        model = reconstruct(params)
        # Mean accuracy over test batches, kept separate per client
        acc = sum((acc_fn(model, *batch) for batch in ds_test), 0.) / len(ds_test)
        errs_buf[i, j, :] = np.asarray(1-acc) # do not take mean over clients
errs = jnp.asarray(errs_buf)

# Contour plot: for each client, shade the part of the PCA plane whose test
# error is no worse than the error at the grid point nearest to that client's
# own model, then mark the client models themselves.
n_clients = errs.shape[-1]
colors = ["blue", "red", "yellow", "green"]
# Extend with Matplotlib's default colour cycle when there are more clients
# than hand-picked colours (no-op for four or fewer clients).
colors += [f"C{k}" for k in range(len(colors), n_clients)]
errs_np = np.asarray(errs)
alpha_np, beta_np = np.asarray(alpha_grid), np.asarray(beta_grid)
fig, ax = plt.subplots(dpi=300)
for i in range(n_clients):
    err_i = errs_np[..., i]
    err_min, err_max = float(err_i.min()), float(err_i.max())
    # Error at the grid point nearest to client i's projected parameters
    optimum = float(err_i[
        np.abs(alpha_np - paramses_trans[i, 0]).argmin(),
        np.abs(beta_np - paramses_trans[i, 1]).argmin(),
    ])
    if not optimum > err_min:
        # Nothing on the grid beats the client's own point, so there is no
        # region to shade; contourf would also reject the constant levels.
        continue
    # Make sure the optimum is displayed as a contour: four levels from the
    # grid minimum up to the optimum, evenly spaced in exp(error) (denser
    # towards the optimum). The endpoints are pinned to the exact values so the
    # exp/log round trip cannot push the optimum off the outermost contour.
    levels = np.log(np.linspace(np.exp(err_min), np.exp(optimum), 4))
    levels[0], levels[-1] = err_min, optimum
    # Single-colour colormap whose alpha ramps from 0 at the lowest error to
    # 1/4 at the highest, so opacity increases with error (decreases with
    # accuracy). NOTE(review): if opacity should instead increase with
    # accuracy, reverse the ramp to np.linspace(1/4, 0, 256).
    rgba = np.repeat(mpl.colors.to_rgba_array(colors[i]), 256, axis=0)
    rgba[:, -1] = np.linspace(0, 1/4, 256)
    cmap = mpl.colors.LinearSegmentedColormap.from_list(
        name=f"C{i}_alpha", colors=rgba.tolist()
    )
    # Plot. NOTE(review): LogNorm needs err_min > 0; a grid point with zero
    # test error would make it fail — confirm this cannot occur.
    ax.contourf(
        alpha_np,
        beta_np,
        err_i.T,
        levels=levels,
        cmap=cmap,
        norm=mpl.colors.LogNorm(vmin=err_min, vmax=err_max),
    )
# Plot the optima (each client's model in its colour)
ax.scatter(
    *paramses_trans.T,
    c=colors[:n_clients],
)
ax.set_xticks([]);
ax.set_yticks([]);