# Post-federated learning scalable foundation model deployment
In this tutorial we show one of the key contributions of RaFFM:

**Quickly get a resource-aware FM from RaFFM without additional training**


This matters in federated learning (FL): when a new client with limited resources joins, RaFFM can quickly deploy a high-performance model sized for that client, without additional training.

This tutorial covers how to use the RaFFM APIs. To go straight to the experiments, jump to [Experiment Results](post_training_deployment.ipynb#experiment) at the end of the notebook.

## Step 1. Import dependency libraries

In [None]:
import numpy as np
from datasets import load_metric
from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer
from datasets import load_dataset
from raffm import RaFFM
import torch
import os


Define a utility function that counts the weight parameters of a model (in millions).

In [2]:
from torch.nn import Parameter

def calculate_params(model):
    """Count the parameters held in the `weight` tensors of the model.

    Only each module's `weight` is counted; biases and standalone parameters
    (e.g. ViT's CLS token and position embeddings) are excluded.

    Args:
        model: the model to be evaluated
    Returns:
        the number of weight parameters, in millions
    """
    millions = 1_000_000
    total_params = 0
    for _, module in model.named_modules():
        if hasattr(module, "weight") and isinstance(module.weight, Parameter):
            # numel() is the product of the tensor's dimensions
            total_params += module.weight.numel()

    return total_params / millions


## Step 2. Load and preprocess the dataset
First, define the preprocessing functions for datasets loaded through the [Hugging Face datasets API](https://huggingface.co/).

In [3]:

# Load the image processor once, instead of on every transform call
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

def collate_fn(batch):
    # Stack per-example tensors into one batch for the Trainer
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

def transform(example_batch):
    # Take a list of PIL images and turn them into pixel values
    inputs = processor([x for x in example_batch['img']], return_tensors='pt')

    # Don't forget to include the labels!
    inputs['labels'] = example_batch['label']
    return inputs
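To see the batch layout that `collate_fn` hands to the `Trainer` without downloading CIFAR-10, here is a NumPy analog: a sketch in which `np.stack`/`np.array` stand in for `torch.stack`/`torch.tensor`. The fake 3x224x224 examples assume the processor's default 224-pixel resize; `collate_numpy` and `fake_batch` are illustrative names, not part of RaFFM.

```python
import numpy as np

def collate_numpy(batch):
    """NumPy analog of collate_fn: stack per-example pixel arrays and labels into batch arrays."""
    return {
        "pixel_values": np.stack([x["pixel_values"] for x in batch]),
        "labels": np.array([x["labels"] for x in batch]),
    }

# Two fake preprocessed examples shaped like ViT inputs (channels, height, width).
fake_batch = [
    {"pixel_values": np.zeros((3, 224, 224), dtype=np.float32), "labels": 3},
    {"pixel_values": np.ones((3, 224, 224), dtype=np.float32), "labels": 7},
]
out = collate_numpy(fake_batch)
print(out["pixel_values"].shape, out["labels"].tolist())  # (2, 3, 224, 224) [3, 7]
```

The batch dimension is added in front, so the model sees `(batch, channels, height, width)`.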

Then load CIFAR-10, hold out 20% of the training split as a validation set, and attach the transform so images are converted to torch tensors on the fly.

In [4]:
dataset = load_dataset('cifar10')
train_val = dataset["train"].train_test_split(test_size=0.2,seed=123)

dataset['train'] = train_val["train"]
dataset["validation"] = train_val["test"]

prepared_ds = dataset.with_transform(transform)
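The split sizes follow directly from CIFAR-10's 50,000-image training split: holding out 20% leaves 40,000 training and 10,000 validation rows, as the `DatasetDict` output confirms. The sketch below reproduces only the sizes with a seeded shuffle; it is not Hugging Face's shuffling, so the indices will not match `train_test_split(seed=123)`. `split_indices` is a hypothetical helper.

```python
import random

def split_indices(n, test_size, seed):
    """Shuffle indices 0..n-1 with a fixed seed and cut off a test_size fraction."""
    rng = random.Random(seed)
    idx = list(range(n))
    rng.shuffle(idx)
    n_test = int(round(n * test_size))
    return idx[n_test:], idx[:n_test]

train_idx, val_idx = split_indices(50_000, test_size=0.2, seed=123)
print(len(train_idx), len(val_idx))  # 40000 10000
```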


Check the resulting splits

In [5]:

prepared_ds


DatasetDict({
    train: Dataset({
        features: ['img', 'label'],
        num_rows: 40000
    })
    test: Dataset({
        features: ['img', 'label'],
        num_rows: 10000
    })
    validation: Dataset({
        features: ['img', 'label'],
        num_rows: 10000
    })
})

## Step 3. Initialize an FM
Here we first initialize a Vision Transformer (ViT); later we load the scalable ViT (FM) checkpoint trained by RaFFM. The warning about newly initialized classifier weights is expected at this point.

In [6]:
ckpt_path = 'google/vit-base-patch16-224-in21k'
labels = dataset['train'].features['label'].names

model = ViTForImageClassification.from_pretrained(
    ckpt_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)
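The `id2label`/`label2id` arguments are ordinary dictionaries keyed by stringified class indices. This standalone sketch builds them for CIFAR-10, assuming the class order is the one exposed by `dataset['train'].features['label'].names`, and checks the round trip from index to name and back.

```python
# CIFAR-10 class names (assumed to match the Hugging Face "cifar10" label order).
labels = ["airplane", "automobile", "bird", "cat", "deer",
          "dog", "frog", "horse", "ship", "truck"]

id2label = {str(i): c for i, c in enumerate(labels)}
label2id = {c: str(i) for i, c in enumerate(labels)}

# A predicted class index maps back to its name through the string key.
pred = 3
print(id2label[str(pred)])       # cat
print(repr(label2id["cat"]))     # '3'
```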

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Step 4. Download the scalable FM checkpoints
You can find the checkpoints in the [README.MD](https://github.com/yusx-swapp/RaFFM/blob/47708fe75cdb432e805982fc73954fde1d6aa960/experiments/post_training_deployment/README.MD); we provide scalable FM checkpoints trained by RaFFM under various FL edge-budget settings.

In this example, we use the scalable FM checkpoints downloaded from **[here](https://github.com/yusx-swapp/RaFFM/blob/47708fe75cdb432e805982fc73954fde1d6aa960/experiments/post_training_deployment/README.MD#download-the-scalable-fm-checkpoints)**.
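After downloading, a quick check of the checkpoint directory can catch an incomplete download before `RaFFM` tries to read it. This is a hypothetical helper (`check_scalable_ckpt` is not a RaFFM API); it assumes the directory contains the Hugging Face `config.json` that `from_pretrained` needs plus the `elastic_space.json` that the notebook passes to `RaFFM`, and it only verifies that the latter parses as JSON, not its schema.

```python
import json
import os

def check_scalable_ckpt(ckpt_dir):
    """Hypothetical sanity check for a downloaded scalable-FM checkpoint directory."""
    # Assumed required files: HF config.json plus RaFFM's elastic_space.json.
    missing = [f for f in ("config.json", "elastic_space.json")
               if not os.path.isfile(os.path.join(ckpt_dir, f))]
    if missing:
        raise FileNotFoundError(f"{ckpt_dir} is missing: {', '.join(missing)}")
    # Fail early on a truncated or corrupted elastic_space.json.
    with open(os.path.join(ckpt_dir, "elastic_space.json"), encoding="utf-8") as fh:
        json.load(fh)
    return True
```

For example, call `check_scalable_ckpt("ckpt/cifar10/vit_base_small_budget")` before building `raffm_model`.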


## Step 5. Load the scalable FM checkpoint

Next, load the scalable FM weights trained by RaFFM and wrap the model as a RaFFM scalable network.

In [7]:
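# ckpt_path = "the_downloaded_ckpt_path"
ckpt_path = "ckpt/cifar10/vit_base_small_budget"
# Elastic configuration shipped with the checkpoint
elastic_config = os.path.join(ckpt_path, "elastic_space.json")
# from_pretrained is a classmethod, so call it on the class
model = ViTForImageClassification.from_pretrained(ckpt_path)

# Wrap the FM as a RaFFM scalable network
raffm_model = RaFFM(model.to("cpu"), elastic_config)
print("Original FM number of parameters:", raffm_model.total_params)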

Original FM number of parameters: 85.55136
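The printed 85.55136 can be checked by hand. Assuming ViT-Base/16 dimensions (hidden size 768, 12 encoder layers, MLP width 3072, 16x16 patches over 3 input channels) and a 10-class head, and counting only `weight` tensors as `calculate_params` does (no biases, CLS token or position embeddings), the arithmetic lands exactly on 85,551,360 weights. That this matches `raffm_model.total_params` suggests RaFFM counts parameters the same way; that is an inference from the numbers, not documented behavior.

```python
# Assumed ViT-Base/16 dimensions and a 10-class CIFAR-10 head.
hidden, layers, mlp, patch, channels, num_classes = 768, 12, 3072, 16, 3, 10

per_layer = (
    4 * hidden * hidden      # query, key, value and attention output projections
    + 2 * hidden * mlp       # MLP up- and down-projection
    + 2 * hidden             # the two LayerNorm weight vectors
)
total = (
    layers * per_layer
    + hidden * channels * patch * patch   # patch-embedding Conv2d kernel
    + hidden                              # final LayerNorm weight
    + num_classes * hidden                # classifier weight
)
print(total / 1_000_000)  # 85.55136
```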


# Experiment

## Sample a scaled FM 

In [8]:
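# Randomly sample a resource-aware subnetwork; only the first return value (the model) is used here
scaled_model, _, _ = raffm_model.random_resource_aware_model()
params = calculate_params(scaled_model)
print("scaled model params", params)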

scaled model params 75.032832


## Evaluate the sampled scaled FM

In [9]:

metric = load_metric("accuracy")

def compute_metrics(p):
    # p.predictions holds the logits; the predicted class is the argmax over classes
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)
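`compute_metrics` receives the model's logits (`p.predictions`, one row per image) and the reference labels (`p.label_ids`); accuracy is the fraction of rows whose argmax matches the label. Here is a NumPy-only sketch of that computation with made-up logits for three examples (`accuracy_from_logits` is an illustrative name, not the `accuracy` metric's implementation):

```python
import numpy as np

def accuracy_from_logits(logits, label_ids):
    """Fraction of rows whose argmax logit equals the reference label."""
    preds = np.argmax(logits, axis=1)
    return float((preds == np.asarray(label_ids)).mean())

logits = np.array([[2.0, 0.1, -1.0],   # predicts class 0
                   [0.0, 3.0, 0.5],    # predicts class 1
                   [1.0, 0.2, 0.9]])   # predicts class 0
print(accuracy_from_logits(logits, [0, 1, 2]))  # two of the three rows are correct
```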




In [10]:
# Only trainer.evaluate() is called below, so the training-related arguments are unused
training_args = TrainingArguments(
    output_dir="./log/debug",
    per_device_train_batch_size=16,
    evaluation_strategy="steps",
    num_train_epochs=4,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=False,
    report_to='tensorboard',
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=scaled_model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    # tokenizer=processor,
)



In [11]:
metrics = trainer.evaluate()

100%|██████████| 1250/1250 [07:51<00:00,  2.65it/s]
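The 1,250 iterations in the progress bar come from the evaluation batch size, not `per_device_train_batch_size=16`: `TrainingArguments` defaults `per_device_eval_batch_size` to 8, and 10,000 validation images / 8 = 1,250 batches. The arithmetic below assumes a single-device run and also cross-checks that the throughput reported in the eval metrics, multiplied by the runtime, recovers the validation-set size.

```python
import math

val_size = 10_000     # rows in prepared_ds["validation"]
eval_batch_size = 8   # TrainingArguments default for per_device_eval_batch_size
num_devices = 1       # assumption: single CPU/GPU run

steps = math.ceil(val_size / (eval_batch_size * num_devices))
print(steps)  # 1250

# eval_samples_per_second (21.171) times eval_runtime (0:07:52.35 = 472.35 s)
print(round(21.171 * 472.35))  # 10000
```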


We get **97.60%** validation accuracy from the scaled FM **without further training**!

In [12]:
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** eval metrics *****
  eval_accuracy           =      0.976
  eval_loss               =     0.0787
  eval_runtime            = 0:07:52.35
  eval_samples_per_second =     21.171
  eval_steps_per_second   =      2.646


## Sample and evaluate another scaled FM
You can sample as many scaled FMs from the RaFFM model as you need.
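Since each call to `random_resource_aware_model()` returns a different subnetwork, one way to serve a newly joined client is to keep sampling until a candidate fits that client's budget. The notebook does not show whether `random_resource_aware_model()` accepts a budget itself, so this sketch wraps it in a hypothetical rejection-sampling helper, `sample_within_budget`; the demo uses a stand-in sampler whose "models" are just their sizes in millions. With RaFFM you would pass something like `lambda: raffm_model.random_resource_aware_model()[0]` as `sample_fn` and `calculate_params` as `count_params`.

```python
import random

def sample_within_budget(sample_fn, count_params, budget_m, max_tries=20):
    """Hypothetical helper: draw scaled models until one fits a parameter budget (in millions)."""
    for _ in range(max_tries):
        candidate = sample_fn()
        if count_params(candidate) <= budget_m:
            return candidate
    raise RuntimeError(f"no sampled model fit within {budget_m}M params after {max_tries} tries")

# Stand-in sampler: each fake "model" is simply its size in millions.
rng = random.Random(0)

def fake_sample():
    return rng.uniform(70.0, 85.0)

chosen = sample_within_budget(fake_sample, count_params=lambda m: m, budget_m=76.0)
print(f"deployable model: {chosen:.2f}M params")
```

Rejection sampling keeps the helper independent of how RaFFM draws subnetworks; the cost is extra draws when the budget is tight, hence the `max_tries` cap.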

In [13]:
scaled_model,_,_ = raffm_model.random_resource_aware_model()

params = calculate_params(scaled_model)
print("scaled model params",params)

scaled model params 75.917568


In [14]:

trainer = Trainer(
    model=scaled_model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    # tokenizer=processor,
)
metrics = trainer.evaluate()


100%|██████████| 1250/1250 [07:32<00:00,  2.76it/s]


The second scaled FM reaches **97.58%** accuracy.

In [15]:
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** eval metrics *****
  eval_accuracy           =     0.9758
  eval_loss               =     0.0788
  eval_runtime            = 0:07:33.11
  eval_samples_per_second =     22.069
  eval_steps_per_second   =      2.759
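To compare the two samples side by side, the sketch below tabulates the sizes and accuracies printed above relative to the 85.55136M-weight full FM. The numbers are copied from this run's outputs; because the models are sampled randomly, your runs will differ.

```python
# Values printed above for the two randomly sampled scaled FMs (sizes in millions of weights).
results = [
    {"name": "scaled FM #1", "params_m": 75.032832, "eval_accuracy": 0.976},
    {"name": "scaled FM #2", "params_m": 75.917568, "eval_accuracy": 0.9758},
]
full_params_m = 85.55136

for r in results:
    print(f"{r['name']}: {r['params_m']:.2f}M params "
          f"({r['params_m'] / full_params_m:.1%} of the full FM), "
          f"accuracy {r['eval_accuracy']:.2%}")
```

In this run the two scaled FMs differ by under 0.9M weights and by 0.02 percentage points of accuracy.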
