In [None]:
%load_ext autoreload
%autoreload 2
from models import ResNet
from data import get_data
from utils import top_5_err, opt_create, save_model, load_model
import jax, optax
from flax import nnx
from tqdm.auto import tqdm
from functools import reduce

In [None]:
def _flatten_batches(ds):
    """Wrap ``ds.collate_fn`` so every array in a collated batch has its two
    leading axes merged into one.

    The dataset's *own* collate function is captured here, at wrap time, so
    each split keeps its own collation logic and later rebinding of a
    notebook-level name cannot redirect it.

    Args:
        ds: dataset object exposing a callable ``collate_fn`` attribute.

    Returns:
        The same ``ds`` object, with ``collate_fn`` replaced in place.
    """
    base_collate = ds.collate_fn

    def collate_flat(*args):
        # (a, b, *rest) -> (a*b, *rest) for every leaf of the collated batch.
        # NOTE(review): presumably the leading axis is the client axis
        # (n_clients=1 here) -- confirm against `get_data`.
        return jax.tree.map(
            lambda b: b.reshape(-1, *b.shape[2:]),
            base_collate(*args)
        )

    ds.collate_fn = collate_flat
    return ds


# One dataset per split; same settings apart from `partition`.
ds_train = _flatten_batches(get_data(beta=0, dataset=1, n_clients=1, n_classes=1000))
ds_val = _flatten_batches(get_data(beta=0, dataset=1, partition="val", n_clients=1, n_classes=1000))
ds_test = _flatten_batches(get_data(beta=0, dataset=1, partition="test", n_clients=1, n_classes=1000))

# Model and optimizer; the same ResNet configuration is rebuilt later when
# the saved checkpoint is loaded.
model = ResNet(jax.random.key(42), layers=[3,4,6,3], dim_out=1000)
opt = opt_create(model, learning_rate=1e-3)

def ce(model, y, x):
    """Mean softmax cross-entropy of the model's training-mode outputs.

    Note the argument order: targets ``y`` come before inputs ``x``.

    Args:
        model: callable invoked as ``model(x, train=True)`` to produce logits.
        y: targets passed as the labels argument of
            ``optax.softmax_cross_entropy``.
        x: model inputs.

    Returns:
        The cross-entropy averaged over all elements returned by optax.
    """
    logits = model(x, train=True)
    losses = optax.softmax_cross_entropy(logits, y)
    return losses.mean()
@nnx.jit
def train_step(model, opt, y, x):
    """Run one jitted optimization step and return the batch loss.

    Computes the ``ce`` loss and its gradient for the batch ``(y, x)``, then
    hands the gradient to ``opt.update``.

    Args:
        model: the network being trained.
        opt: optimizer object whose ``update`` method consumes the gradient.
        y: batch targets.
        x: batch inputs.

    Returns:
        The loss value computed for this batch.
    """
    loss_and_grad = nnx.value_and_grad(ce)
    batch_loss, grads = loss_and_grad(model, y, x)
    opt.update(grads)
    return batch_loss

In [None]:
# Train until the validation score fails to improve for `max_patience`
# consecutive epochs, checkpointing the model after every improving epoch.
max_patience = 3
# `patience - 1` is the number of consecutive epochs without improvement.
patience = 1
val_losses = []
epoch = 0
# Jit the metric once so the validation pass is compiled, matching how the
# final test evaluation calls it.
eval_step = nnx.jit(top_5_err)
while patience<=max_patience:
    # Iterate over batches
    for y, x in (bar := tqdm(ds_train, leave=False)):
        # Train step
        loss = train_step(model, opt, y, x)
        # Inform user (`val` is the score from the previous epoch)
        bar.set_description(f"Epoch {epoch} (local validation score: {'N/A' if epoch==0 else val}, local batch loss: {loss.mean():.4f})")
    # Evaluate on validation: average of the per-batch mean scores
    val = reduce(lambda a, batch: a+eval_step(model, *batch).mean(), ds_val, 0.)
    val /= len(ds_val)
    val_losses.append(val)
    # `val_losses[-patience-1]` is the score of the last improving epoch,
    # i.e. the one currently checkpointed: `patience - 1` non-improving
    # epochs plus the current one have been appended since then.
    # Lower is treated as better; a tie counts as no improvement.
    if epoch>=1 and val>=val_losses[-patience-1]:
        patience += 1
    else:
        # New best (always taken on epoch 0): reset patience and checkpoint.
        patience = 1
        save_model(model, "models/imagenet1k_resnet34_centralized.pkl")
    epoch += 1

# Reload the checkpoint written by the training loop and score it on the
# test split.
model = load_model(
    lambda: ResNet(jax.random.key(42), layers=[3,4,6,3], dim_out=1000),
    "models/imagenet1k_resnet34_centralized.pkl"
)
val_fn = nnx.jit(top_5_err)
# Reduce each batch with `.mean()` exactly as the validation pass does, so
# the accumulated value is a scalar whatever shape `top_5_err` returns.
val = reduce(lambda e, batch: e + val_fn(model, *batch).mean(), ds_test, 0.) / len(ds_test)
# NOTE(review): the "%" suffix assumes `top_5_err` already returns a
# percentage rather than a fraction -- confirm in `utils`.
print(f"Final test top-5 error: {float(val):.2f}%")