In [None]:
%load_ext autoreload
%autoreload 2
from fedflax import train
from models import ResNet, ResNetAutoEncoder, ViTAutoEncoder
from data import fetch_data
from utils import load_model, return_ce, mean_iou_err
import jax, optax, pickle
from flax import nnx

# Number of federated clients used for data partitioning and training.
n_clients = 3
# Keyword arguments forwarded to the autoencoder constructors.
# `sigma` presumably sets the asymmetry strength — TODO confirm; use {} for no asymmetries.
asymkwargs = dict(sigma=1e-3)

## Fetch foundation model
ResNets trained with the ImageNet training script.

In [None]:
# Reload the ImageNet-pretrained ResNet-34 and use it as the encoder.
backbones = load_model(
    lambda: ResNet(layers=[3,4,6,3], dim_out=1000),
    "models/resnet34_central_imagenet1000.pkl"
)
# Average Params and BatchStats over axis 0 to obtain a single backbone.
# NOTE(review): assumes the checkpoint stacks model copies on a leading axis — confirm.
struct, params, rest = nnx.split(backbones, (nnx.Param, nnx.BatchStat), ...)
backbone = nnx.merge(
    struct,
    jax.tree.map(lambda p: p.mean(0), params),
    rest
)
# Plug into the autoencoder. The asymmetry setting comes from `asymkwargs` in the
# configuration cell; change it there (e.g. to {}) rather than reassigning it here,
# because a reassignment leaks into every later cell that reads `asymkwargs`.
ae = ResNetAutoEncoder(backboneencoder=backbone, key=jax.random.key(43), **asymkwargs)


def is_backbone_path(path, _leaf) -> bool:
    """Return True if a pytree leaf lives under the `backboneencoder` submodule.

    `path` is a tuple of JAX key entries (DictKey, GetAttrKey, ...), not strings,
    so `"backboneencoder" in path` never matches; compare each entry's key/name.
    """
    return any(
        getattr(entry, "key", getattr(entry, "name", None)) == "backboneencoder"
        for entry in path
    )


# Head schedule: warmup 1e-4 -> 0.1 over 4000 steps, then exponential decay by 0.9
# per 1000 steps (continuous, staircase=False), floored at 1e-5.
lr = optax.warmup_exponential_decay_schedule(
    init_value=1e-4,
    peak_value=.1,
    warmup_steps=4000,
    transition_steps=1000,
    decay_rate=.9,
    end_value=1e-5,
)
# Pretrained backbone uses a constant, lower lr (1e-4); all other Params follow `lr`.
# optax.masked passes masked-out updates through unchanged, so chaining the two
# complementary masks applies exactly one AdamW to each leaf.
opt = nnx.Optimizer(
    ae,
    optax.chain(
        optax.masked(
            optax.adamw(1e-4),
            lambda ptree: jax.tree.map_with_path(is_backbone_path, ptree),
        ),
        optax.masked(
            optax.adamw(lr),
            lambda ptree: jax.tree.map_with_path(
                lambda path, leaf: not is_backbone_path(path, leaf), ptree
            ),
        ),
    ),
    wrt=nnx.Param,
)

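## Alternatively, use Google's ViT as backbone
Running the next cell replaces the `ae` and `opt` built by the ResNet cell above, so run only one of the two.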

In [None]:
# Loading the pretrained ViT weights is delegated to ViTAutoEncoder itself
# (no checkpoint path is passed here).
ae = ViTAutoEncoder(**asymkwargs)
# Warmup 1e-4 -> 0.1 over 4000 steps, then decay by 0.9 per 1000 steps, floor 1e-5.
lr = optax.warmup_exponential_decay_schedule(
    init_value=1e-4,
    peak_value=.1,
    warmup_steps=4000,
    transition_steps=1000,
    decay_rate=.9,
    end_value=1e-5,
)
# Optimize nnx.Param leaves only; the linen backbone's weights are presumably held
# in a different variable type and so stay untouched — confirm in ViTAutoEncoder.
opt = nnx.Optimizer(ae, optax.adamw(lr), wrt=nnx.Param)

## Finetune using asymmetries
Compare fine-tuning with and without asymmetries in the ResNetAutoEncoder; the variant is selected by the `asymkwargs` passed to the autoencoder constructor.

In [None]:
# Cityscapes data (dataset=2), partitioned over n_clients with beta=1.
# (beta is presumably a heterogeneity parameter — TODO confirm in data.fetch_data.)
ds_train = fetch_data(beta=1., dataset=2, n_clients=n_clients, batch_size=16)
ds_val = fetch_data(beta=1., dataset=2, partition="val", n_clients=n_clients, batch_size=16)

# Federated fine-tuning. "early" for local_epochs/rounds presumably means early
# stopping on val_fn with `max_patience` — TODO confirm in fedflax.train.
aes, rounds = train(
    ae,
    opt,
    ds_train,
    return_ce(0.),
    ds_val,
    local_epochs="early",
    n_clients=n_clients,
    max_patience=3,
    rounds="early",
    val_fn=mean_iou_err
)

# Save the decoder only: drop everything under `backboneencoder` and any "params"
# path (presumably the linen ViT weights).
# NOTE(review): the ResNet optimizer also updates backbone Params, but they are not
# saved here, so a reload pairs this decoder with the original pretrained backbone.
state = nnx.state(aes, nnx.Not((nnx.PathContains("backboneencoder"), nnx.PathContains("params"))))
# Write to the path the Reload cell reads for this architecture, so training the ViT
# variant no longer overwrites the ResNet checkpoint.
# NOTE(review): confirm `load_model` accepts this state format for the ViT file;
# runs with and without asymmetries still share a path and overwrite each other.
ckpt_path = "models/cs_vit_decoder.pkl" if isinstance(ae, ViTAutoEncoder) else "models/cs_rn34_syre.pkl"
# `with` guarantees the file is flushed and closed even if pickling fails.
with open(ckpt_path, "wb") as f:
    pickle.dump(state, f)

## Reload

In [None]:
# Which fine-tuned model to rebuild: "resnet" or "vit". Previously both options ran
# back to back (the second overwriting the first and requiring both checkpoints).
RELOAD_BACKBONE = "resnet"

if RELOAD_BACKBONE == "resnet":
    ### OPTION 1: ResNet AutoEncoder ###
    load_fn = lambda: ResNetAutoEncoder(
        backboneencoder=load_model(
            lambda: ResNet(layers=[3,4,6,3], dim_out=1000),
            "models/resnet34_central_imagenet1000.pkl"
        ),
        **asymkwargs
    )
    # Build concretely: nnx.eval_shape would turn every leaf, including the loaded
    # pretrained backbone, into a ShapeDtypeStruct and lose the backbone weights.
    aes = load_fn()
    # Overwrite the freshly initialised decoder with the fine-tuned, backbone-free
    # state written by the training cell. Only unpickle trusted files.
    with open("models/cs_rn34_syre.pkl", "rb") as f:
        decoder_state = pickle.load(f)
    nnx.update(aes, decoder_state)
elif RELOAD_BACKBONE == "vit":
    ### OPTION 2: ViT AutoEncoder ###
    load_fn = lambda: ViTAutoEncoder(**asymkwargs)
    aes = load_model(load_fn, "models/cs_vit_decoder.pkl")
else:
    raise ValueError(f"Unknown RELOAD_BACKBONE: {RELOAD_BACKBONE!r}")