In [None]:
%load_ext autoreload
%autoreload 2
from fedflax import train
from models import ResNet, ResNetAutoEncoder
from data import fetch_data
from utils import load_model, return_ce, mean_iou_err
import jax, optax, pickle
from jax import numpy as jnp
from flax import nnx

## Fetch foundation model
The backbone was trained using the ImageNet script.

In [None]:
# Restore the pretrained ResNet-18 backbone from disk; it is used as the encoder further down
models = load_model(lambda: ResNet(layers=[2, 2, 2, 2], dim_out=100), "models/resnet18_central_imagenet100.pkl")
# Separate parameters and batch statistics from the remaining state of the loaded object
struct, params, rest = nnx.split(models, (nnx.Param, nnx.BatchStat), ...)
# Average every parameter / batch-stat leaf over axis 0 and rebuild a single model from the result
# NOTE(review): axis 0 is presumably the per-client (stacked models) axis of the checkpoint - confirm
model = nnx.merge(struct, jax.tree.map(lambda leaf: leaf.mean(axis=0), params), rest)

## Alternatively, fetch ViT-224 foundation model
Requires installing https://github.com/google-research/vision_transformer.

The weights are available at https://console.cloud.google.com/storage/browser/vit_models/imagenet21k. Any checkpoint variant should work, as long as the config below is changed to match it.


In [None]:
from vit_jax.models_vit import VisionTransformer
from functools import partial
from ml_collections.config_dict import ConfigDict
from packaging import version
import flax, sys
# Workaround so that `vit_jax.checkpoint` can be imported without installing tensorflow.
# Both assignments must run BEFORE the import on the last line of this group:
#   1) make `flax.io` reachable under the attribute name `gfile` on itself, and
#   2) register `flax.io` in `sys.modules` under the name "tensorflow.io",
# so that a lookup of `tensorflow.io` (and of its `gfile` attribute) resolves to `flax.io`.
# NOTE(review): presumably `vit_jax.checkpoint` imports `gfile` from `tensorflow.io` - confirm against the installed version.
flax.io.gfile = flax.io
sys.modules["tensorflow.io"] = flax.io
# Only the building blocks of `load_pretrained` are imported; a reduced version of it is written out below
from vit_jax.checkpoint import load, inspect_params, _fix_groupnorm

# Config copied from the ViT-B_16 at https://github.com/google-research/vision_transformer/blob/main/vit_jax/configs/models.py#L113
config = ConfigDict({
    "num_classes": 0, # No classification head
    "patches": ConfigDict({"size": (16, 16)}),
    "model_name": "ViT-B_16",
    "transformer": ConfigDict(
        {"mlp_dim": 3072, "num_heads": 12, "num_layers": 12, "attention_dropout_rate": 0.0, "dropout_rate": 0.0}
    ),
    "classifier": "token",
    "representation_size": None,
    "hidden_size": 768
})
model = VisionTransformer(**config)
# Initialise once on a dummy 224x224 RGB batch; the resulting tree is only used as the
# reference structure that the checkpoint is inspected against
reference_params = model.init(jax.random.key(42), jnp.ones((1,224,224,3)), train=False)["params"]
# Dumbed down version of `load_pretrained`
# Head is explicitly removed
# Case where posemb_new.shape!=posemb.shape is not handled
# Extra / missing keys are tolerated here (both fail_if_* flags are off)
params = inspect_params(
    params=load("models/ViT-B_16.npz"),
    expected=reference_params,
    fail_if_extra=False,
    fail_if_missing=False)
# Without a representation layer in the config, any stored `pre_logits` weights are dropped
if config.get("representation_size") is None and "pre_logits" in params:
    params["pre_logits"] = {}
# The groupnorm fix is applied exactly once and only under the flax version guard
if version.parse(flax.__version__) >= version.parse("0.3.6"):
    params = _fix_groupnorm(params)
# Drop the classification head if the checkpoint has one; the default keeps head-less
# checkpoints from raising a KeyError
params.pop("head", None)
params = flax.core.freeze(params)
# Define inference function
infer_fn = jax.jit(partial(model.apply, train=False))

## Finetune using asymmetries
Compare with and without asymmetries in the ResNetAutoEncoder.

In [None]:
# Cityscapes data: one training and one validation dataset, each spread over `n_clients` clients
n_clients = 3
ds_train = fetch_data(dataset=2, beta=1., n_clients=n_clients, batch_size=32)
ds_val = fetch_data(dataset=2, beta=1., n_clients=n_clients, batch_size=16, partition="val")

# Autoencoder model for segmentation via image reconstruction, built around the pretrained backbone.
# `asymkwargs` holds the optional asymmetry settings; left empty here
asymkwargs = dict()
ae = ResNetAutoEncoder(backboneencoder=model, key=jax.random.key(43), **asymkwargs)

# Learning-rate schedule for the optimizer defined below (the pretrained backbone is meant to get a lower lr):
# warm-up followed by exponential decay, with the settings spelled out by keyword
lr = optax.warmup_exponential_decay_schedule(
    init_value=1e-4,
    peak_value=.1,
    warmup_steps=4000,
    transition_steps=1000,
    decay_rate=.9,
    end_value=1e-5,
)
def backbone_mask(ptree, invert=False):
    """Build a boolean mask with the structure of `ptree` that marks the backbone leaves.

    A leaf is marked True when one of the entries of its key path is named
    "backboneencoder". Key paths handed out by JAX consist of key-entry objects
    (e.g. `DictKey`, `GetAttrKey`), not plain strings, so a direct
    `"backboneencoder" in path` test never matches; the entry names are compared instead.

    Args:
        ptree: the parameter pytree the mask is built for.
        invert: when True, mark every leaf that is NOT below "backboneencoder".

    Returns:
        A pytree shaped like `ptree` whose leaves are Python bools.
    """
    def entry_name(entry):
        # Plain strings are accepted as-is; key-entry objects expose their name as `.key` or `.name`
        if isinstance(entry, str):
            return entry
        return getattr(entry, "key", getattr(entry, "name", None))

    def is_selected(path, _leaf):
        in_backbone = any(entry_name(entry) == "backboneencoder" for entry in path)
        return in_backbone != invert

    return jax.tree_util.tree_map_with_path(is_selected, ptree)


def opt_create(ae: "ResNetAutoEncoder"):
    """Create the optimizer for `ae` with separate learning rates for backbone and the rest.

    Two masked AdamW transformations are chained: the first one only touches the
    leaves below "backboneencoder" and uses a constant learning rate of 1e-4, the
    second one touches all remaining leaves and uses the notebook-level schedule `lr`.

    Args:
        ae: the autoencoder whose parameters are optimized.

    Returns:
        An `nnx.Optimizer` wrapping `ae`.
    """
    return nnx.Optimizer(
        ae,
        optax.chain(
            optax.masked(optax.adamw(1e-4), backbone_mask),
            optax.masked(optax.adamw(lr), lambda ptree: backbone_mask(ptree, invert=True))
        )
    )

# Train the autoencoder on the training data, validating on `ds_val` with `mean_iou_err`.
# NOTE(review): the "early" settings together with `max_patience` presumably enable early stopping - confirm in `fedflax.train`
aes, rounds = train(
    ae,
    opt_create,
    ds_train,
    return_ce(0.), 
    ds_val,
    local_epochs="early",
    n_clients=n_clients,
    max_patience=3,
    rounds="early",
    val_fn=mean_iou_err
)

# Save decoder: everything in the trained state whose path does not contain "backboneencoder"
state = nnx.state(aes, nnx.Not(nnx.PathContains("backboneencoder")))
# Context manager so the file is flushed and closed deterministically
with open("models/cs_rn18_decoder.pkl", "wb") as decoder_file:
    pickle.dump(state, decoder_file)

# Reload and aggregate it
# The backbone is rebuilt from its own checkpoint inside the factory
load_fn = lambda: ResNetAutoEncoder(
    backboneencoder=load_model(
        lambda: ResNet(layers=[2,2,2,2], dim_out=100), 
        "models/resnet18_central_imagenet100.pkl"
    ),
    **asymkwargs
)
# NOTE(review): the state above is written to "models/cs_rn18_decoder.pkl" but the reload reads
# "models/cs_rn18_autoencoder.pkl" - confirm which file is intended; this cell does not create the latter
aes = load_model(load_fn, "models/cs_rn18_autoencoder.pkl")