# Post-federated learning scalable foundation model deployment
In this tutorial we will show you one of the key contributions of RaFFM:

**Quickly get a resource-aware FM from RaFFM without additional training**


This matters in FL: when a new client with resource constraints joins, RaFFM can quickly deploy a high-performance model to it without additional training.

This tutorial covers how to use the RaFFM APIs. If you want to go straight to the experiments, jump to [Experiment Results](post_training_deployment.ipynb#experiment) at the end of the notebook.



## Step 1: Import dependency libraries

In [1]:
import os

import numpy as np
import torch
from datasets import load_dataset
from raffm import RaFFM
from transformers import (
    BertForSequenceClassification,
    BertTokenizerFast,
    Trainer,
    TrainingArguments,
)


Define a utility function that counts the weight parameters in a model, reported in millions

In [2]:
from torch.nn import Parameter

def calculate_params(model):
    """Count the weight parameters in a model.

    Only module ``weight`` tensors are counted; biases are not included.

    Args:
        model: the model to be evaluated
    Returns:
        The number of weight parameters, in millions.
    """

    millions = 1000000
    total_params = 0
    for name, module in model.named_modules():
        # Count each module's weight tensor (Linear, Embedding, LayerNorm, ...)
        if hasattr(module, "weight") and isinstance(module.weight, Parameter):
            total_params += module.weight.numel()

    return total_params / millions


## Step 2. Load dataset and pre-process
First, define the preprocessing function for datasets loaded through the [Hugging Face Datasets API](https://huggingface.co/)

In [3]:
def tokenize_function(examples, tokenizer):
    if "sentence" in examples.keys() and "question" not in examples.keys():
        return tokenizer(
            examples["sentence"],
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
    elif "premise" in examples.keys() and "hypothesis" in examples.keys():
        return tokenizer(
            examples["premise"],
            examples["hypothesis"],
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
    elif "question1" in examples.keys() and "question2" in examples.keys():
        return tokenizer(
            examples["question1"],
            examples["question2"],
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
    elif "question" in examples.keys() and "sentence" in examples.keys():
        return tokenizer(
            examples["question"],
            examples["sentence"],
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
    elif "sentence1" in examples.keys() and "sentence2" in examples.keys():
        return tokenizer(
            examples["sentence1"],
            examples["sentence2"],
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
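
The branches in `tokenize_function` all follow one pattern: work out which text column(s) the task provides, then pass them to the tokenizer. A minimal sketch of that dispatch on its own, with no tokenizer involved. `TEXT_COLUMN_PAIRS` and `pick_text_columns` are hypothetical names introduced here for illustration; they are not part of RaFFM or Hugging Face.

```python
# Hypothetical helper for illustration only; not part of RaFFM or datasets.
# Column pairs handled by tokenize_function. Pairs are checked before the
# single-sentence case so that ("question", "sentence") wins over "sentence".
TEXT_COLUMN_PAIRS = [
    ("premise", "hypothesis"),
    ("question1", "question2"),
    ("question", "sentence"),
    ("sentence1", "sentence2"),
]


def pick_text_columns(column_names):
    """Return the (first, second) text columns to tokenize; second may be None."""
    names = set(column_names)
    for first, second in TEXT_COLUMN_PAIRS:
        if first in names and second in names:
            return first, second
    if "sentence" in names:
        return "sentence", None
    raise ValueError(f"No known text columns in {sorted(names)}")


print(pick_text_columns(["sentence", "label", "idx"]))        # single-sentence layout
print(pick_text_columns(["premise", "hypothesis", "label"]))  # sentence-pair layout
```

Unlike `tokenize_function`, which silently returns `None` when no branch matches, this sketch raises an error for an unknown column layout. That is a design choice of the sketch, not the behavior of the code above.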


Then load the SST-2 dataset and tokenize the validation split

In [4]:
dataset = load_dataset("glue", "sst2")
train_dataset = dataset["train"]
val_dataset = dataset["validation"]
tokenize_val_dataset = val_dataset.map(
    lambda examples: tokenize_function(examples, tokenizer),
    batched=True,
)


Check the dataset status

In [5]:

tokenize_val_dataset


Dataset({
    features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 872
})

## Step 3. Initialize an FM
Here we first initialize a BERT model (`bert-base-uncased`) with a two-class classification head; later we will load the scalable BERT (FM) checkpoint trained by RaFFM

In [6]:
model_name = "bert-base-uncased"
tokenizer = BertTokenizerFast.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(
    model_name, num_labels=2
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Step 4. Download scalable FM checkpoints
You can find the checkpoints in [README.MD](./README.MD). We provide scalable FM checkpoints trained by RaFFM under various FL edge-budget settings.

This example uses the scalable FM checkpoint downloaded from **[here](./README.MD/#download-the-scalable-fm-checkpoints)**


## Step 5. Load scalable FM checkpoint

Next, load the scalable FM weights trained by RaFFM and convert the model to a RaFFM scalable network

In [7]:
ckpt_path = "ckpts/bert_base_sst2_small_budget"

# The elastic space definition is stored next to the checkpoint weights
elastic_config = os.path.join(ckpt_path, "elastic_space.json")
model = model.from_pretrained(ckpt_path)

# Wrap the loaded model as a RaFFM scalable network (on CPU)
raffm_model = RaFFM(model.to("cpu"), elastic_config)
print("Original FM number of parameters:", raffm_model.total_params)

Original FM number of parameters: 109.380864
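
The figure above is slightly lower than the roughly 110M usually quoted for BERT-base. It is consistent with counting weight tensors only, as `calculate_params` does. A back-of-the-envelope check, assuming the standard `bert-base-uncased` configuration (vocabulary 30522, hidden size 768, 12 layers, intermediate size 3072, 512 positions, 2 token types) and a 2-class head. The variable names are illustrative.

```python
# Assumed bert-base-uncased configuration values.
vocab_size, hidden, layers = 30522, 768, 12
intermediate, max_positions, type_vocab, num_labels = 3072, 512, 2, 2

# Embedding tables plus the embedding LayerNorm weight.
embeddings = (vocab_size + max_positions + type_vocab) * hidden + hidden

# One encoder layer: Q, K, V and attention-output projections,
# the two feed-forward matrices, and two LayerNorm weights.
per_layer = 4 * hidden * hidden + 2 * hidden * intermediate + 2 * hidden

pooler = hidden * hidden
classifier = hidden * num_labels

# Weight tensors only; no bias terms are included.
weight_params = embeddings + layers * per_layer + pooler + classifier
print(weight_params / 1_000_000)  # 109.380864
```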


# Experiment

## Sample a scaled FM 

In [8]:
# Randomly sample a scaled sub-model; the sample (and its size) can differ between runs
scaled_model, _, _ = raffm_model.random_resource_aware_model()
params = calculate_params(scaled_model)
print("scaled model params", params)

scaled model params 81.364224


## Evaluate the sampled scaled FM

In [9]:
from sklearn.metrics import accuracy_score


def compute_metrics(eval_pred):
    # eval_pred is a (logits, labels) pair supplied by the Trainer
    predictions, labels = eval_pred
    predictions = predictions.argmax(-1)  # predicted class per example
    return {"accuracy": accuracy_score(labels, predictions)}
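
`compute_metrics` takes the class with the largest logit as the prediction and compares it with the label. The same computation in plain Python on made-up logits; the values and the name `accuracy_from_logits` are illustrative only.

```python
def accuracy_from_logits(logits, labels):
    """Fraction of rows whose arg-max class equals the label."""
    # Index of the largest logit in each row, i.e. the predicted class.
    predictions = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return correct / len(labels)


# Four made-up examples with two classes each.
logits = [[2.0, -1.0], [0.1, 0.3], [-0.5, 0.2], [1.2, 0.4]]
labels = [0, 1, 0, 0]

print(accuracy_from_logits(logits, labels))  # 0.75 (third example is misclassified)
```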


In [10]:

training_args = TrainingArguments(
    output_dir="./log/debug",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    evaluation_strategy="no",
    save_strategy="no",
    learning_rate=2e-4,
    num_train_epochs=1,
    weight_decay=0.01,
    report_to="none",
)

trainer = Trainer(
    model=scaled_model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=tokenize_val_dataset,
    # tokenizer=processor,
)



In [11]:
metrics = trainer.evaluate()

We get **81.08%** accuracy from the scaled FM, **without further training**!

In [12]:
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** eval metrics *****
  eval_accuracy           =     0.8108
  eval_loss               =     0.5055
  eval_runtime            = 0:00:06.35
  eval_samples_per_second =    137.299
  eval_steps_per_second   =      4.409
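
The logged numbers can be cross-checked against the dataset size and batch size. A small sketch, assuming evaluation ran on a single device so that `per_device_eval_batch_size=32` is the effective batch size:

```python
import math

num_examples = 872      # rows in the tokenized validation split
eval_batch_size = 32    # per_device_eval_batch_size, single device assumed
eval_accuracy = 0.8108  # value from the log above

# Number of correctly classified validation examples implied by the accuracy.
num_correct = round(eval_accuracy * num_examples)

# Number of evaluation batches; the last batch is only partly filled.
num_batches = math.ceil(num_examples / eval_batch_size)

print(num_correct, num_batches)  # 707 28
```

Under that assumption, 28 batches in about 6.35 seconds gives roughly 4.41 steps per second, in line with the logged `eval_steps_per_second`.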


## Sample another scaled FM and evaluate it
You can sample as many scaled FMs as you like from the RaFFM model

In [13]:
# Sample the smallest sub-model from the RaFFM scalable network
scaled_model, _, _ = raffm_model.sample_smallest_model()

params = calculate_params(scaled_model)
print("scaled model params", params)

scaled model params 66.618624


In [14]:

trainer = Trainer(
    model=scaled_model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=tokenize_val_dataset,
    # tokenizer=processor,
)
metrics = trainer.evaluate()


The second scaled FM reaches **81.31%** accuracy.

In [15]:
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** eval metrics *****
  eval_accuracy           =     0.8131
  eval_loss               =     0.4954
  eval_runtime            = 0:00:04.85
  eval_samples_per_second =    179.589
  eval_steps_per_second   =      5.767
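
To put the two samples side by side, the sketch below computes each scaled model's size relative to the full FM, using the parameter counts and accuracies printed in this notebook. The variable names are illustrative.

```python
full_params = 109.380864  # millions, full scalable FM

# (label, parameters in millions, eval accuracy) as printed in this notebook
samples = [
    ("random sample", 81.364224, 0.8108),
    ("smallest model", 66.618624, 0.8131),
]

relative_sizes = {}
for label, params, accuracy in samples:
    # Fraction of the full FM's weight parameters kept by this sub-model.
    relative_sizes[label] = params / full_params
    print(
        f"{label}: {params:.2f}M params "
        f"({relative_sizes[label]:.1%} of full FM), accuracy {accuracy:.2%}"
    )
```

These figures come from a single run on the SST-2 validation split; a randomly sampled sub-model can differ in size and accuracy from one run to the next.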
