# Efficient heterogeneous model deployment

In this tutorial we will show you one of the key contributions of OFM:

**Quickly downsize the target FM to meet various resource constraints, without retraining**


To run this Tutorial, you need to install our package:  **`pip install .`**.

## Step 1: Import dependency libraries

In [1]:

import numpy as np
from datasets import load_metric
from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer
from datasets import load_dataset
from rafm import RAFM, rafm_train
import torch
import os


Define a utility function to calculate the number of parameters in a model

In [2]:
from torch.nn import Parameter

def calculate_params(model):
    """Count the weight parameters of a model, in millions.

    Only tensors exposed as a module's ``weight`` attribute and registered
    as a ``Parameter`` are counted; biases and any other parameters are
    not included in the total.

    Args:
        model: the model to be evaluated; it must provide ``named_modules()``.

    Returns:
        The number of weight elements divided by 1,000,000.
    """
    millions = 1000000
    total_params = sum(
        module.weight.numel()
        for _, module in model.named_modules()
        # Modules without a ``weight`` (or with a non-Parameter one) are skipped.
        if isinstance(getattr(module, "weight", None), Parameter)
    )
    return total_params / millions


## Step 2. Load dataset and pre-process
First, define the dataset processing functions for datasets loaded through the [Hugging Face datasets API](https://huggingface.co/)

In [3]:

def collate_fn(batch):
    """Merge a list of per-sample dicts into a single batch dict.

    The ``'pixel_values'`` tensors of all samples are stacked along a new
    leading dimension, and the ``'labels'`` entries are gathered into one
    tensor.
    """
    stacked_pixels = torch.stack([sample['pixel_values'] for sample in batch])
    label_tensor = torch.tensor([sample['labels'] for sample in batch])
    return {'pixel_values': stacked_pixels, 'labels': label_tensor}
    
def transform(example_batch):
    """Turn a batch of PIL images into pixel-value tensors plus labels.

    Args:
        example_batch: a batch dict holding the images under ``'img'`` and
            the class ids under ``'label'``.

    Returns:
        The image processor's output (PyTorch tensors) with the batch's
        labels added under ``'labels'``.
    """
    # This function runs once per accessed batch, so load the processor a
    # single time and cache it on the function instead of re-instantiating
    # it from the pretrained config on every call.
    processor = getattr(transform, "_processor", None)
    if processor is None:
        processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
        transform._processor = processor

    inputs = processor(list(example_batch['img']), return_tensors='pt')

    # Don't forget to include the labels!
    inputs['labels'] = example_batch['label']
    return inputs

Then load and preprocess the datapoints from images to torch tensors

In [4]:
# Load CIFAR-100 and expose its fine-grained label column under the name
# "label", which is the key `transform` reads.
dataset = load_dataset('cifar100')
dataset = dataset.rename_column("fine_label", "label")

# Split the original training set 80/20 with a fixed seed; the 20% part
# becomes the validation split below.
train_val = dataset["train"].train_test_split(test_size=0.2,seed=123)

dataset['train'] = train_val["train"]
dataset["validation"] = train_val["test"]

# Attach `transform` to every split so images are converted when accessed.
prepared_ds = dataset.with_transform(transform)


Check the dataset status

In [5]:

# Bare expression: the notebook displays the dataset summary as cell output.
prepared_ds


DatasetDict({
    train: Dataset({
        features: ['img', 'label', 'coarse_label'],
        num_rows: 40000
    })
    test: Dataset({
        features: ['img', 'label', 'coarse_label'],
        num_rows: 10000
    })
    validation: Dataset({
        features: ['img', 'label', 'coarse_label'],
        num_rows: 10000
    })
})

## Step 3. Initialize a FM

Initialize the FM with pre-trained ckpts
You can find the ckpts in [README.MD](./README.MD), we provide the Scalable FM checkpoints trained by RaFFM in various FL-edge budget settings.

In this example, we use the scalable FM checkpoints downloaded from **[here](./README.MD/#download-the-scalable-fm-checkpoints)**


In [6]:
# use our pretrained ckpts: https://huggingface.co/yusx-swapp/ofm-vit-base-patch16-224-cifar100
ckpt_path = 'yusx-swapp/ofm-vit-base-patch16-224-cifar100'
# Class names of the (renamed) "label" feature; their count sets num_labels.
labels = dataset['train'].features['label'].names

# Load the checkpoint as an image classifier, with id <-> class-name mappings
# built from `labels` (ids are stored as strings here).
model = ViTForImageClassification.from_pretrained(
    ckpt_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)

## Step 4. Convert it to RAFM scalable network

Next, we take the scalable FM weights trained by RaFFM and convert the model to a RaFFM scalable network

In [7]:

# Move the loaded model to the CPU and wrap it in RAFM, then report the
# wrapper's `total_params` attribute for the full-size model.
raffm_model = RAFM(model.to("cpu"))
print("Original FM number of parameters:",raffm_model.total_params)

Original FM number of parameters: 85.62048


# Experiment

## Randomly sample a downsize model

In [8]:
# The call returns three values; only the first (the sampled model) is kept.
scaled_model,_,_ = raffm_model.random_resource_aware_model()


## Check the downsize model architecture

We do not apply binary masking or zero masking to the weights; we extract a clean downsized model, so the model size **is truly reduced**.

In [9]:
# Report the sampled model's size (as computed by `calculate_params`) and
# print its module structure.
params = calculate_params(scaled_model)
print("scaled model params",params)
print(scaled_model)

scaled model params 59.864832
ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-2): 3 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (

# Evaluate the sampled scaled FM

In [10]:

# Accuracy metric object used by `compute_metrics`.
# NOTE(review): the captured cell output shows a warning for this call
# about `trust_remote_code` becoming mandatory in a future `datasets` release.
metric = load_metric("accuracy")
def compute_metrics(p):
    """Score a prediction output with the module-level accuracy ``metric``.

    The predicted class of each sample is the arg-max of ``p.predictions``
    along axis 1; it is compared against ``p.label_ids``.
    """
    predicted_classes = np.argmax(p.predictions, axis=1)
    return metric.compute(predictions=predicted_classes, references=p.label_ids)


  metric = load_metric("accuracy")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [11]:
# Shared Trainer configuration, reused by every Trainer built in this tutorial.
# The cells shown here only call `trainer.evaluate()`, so of the batch sizes
# only the eval one matters for the reported evaluation runs.
training_args = TrainingArguments(
  output_dir="./log/debug",
  per_device_train_batch_size=16,
  per_device_eval_batch_size=256,
  evaluation_strategy="steps",
  num_train_epochs=4,
  save_steps=100,
  eval_steps=100,
  logging_steps=10,
  learning_rate=2e-4,
  save_total_limit=2,
  # Keep the raw dataset columns: `transform` needs 'img' and 'label'.
  remove_unused_columns=False,
  push_to_hub=False,
  report_to='tensorboard',
  load_best_model_at_end=True,
)

# Wrap the sampled sub-model in a Trainer, using the shared arguments, the
# custom batch collation and the accuracy metric defined earlier.
trainer = Trainer(
    model=scaled_model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    # tokenizer=processor,
)



In [12]:
# Evaluate on the Trainer's eval dataset (the validation split); no training
# call precedes this.
metrics = trainer.evaluate()

We get **98.51%** accuracy from the downsized FM, **without further training**!

In [13]:
# Print the evaluation metrics and save them under the "eval" split name.
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** eval metrics *****
  eval_accuracy           =     0.9851
  eval_loss               =      0.155
  eval_runtime            = 0:01:16.49
  eval_samples_per_second =    130.723
  eval_steps_per_second   =      0.523


## Sample another downsized FM and evaluate it
You can sample as many scaled models as you like from the RaFFM model

In [14]:
# Request the model returned by `smallest_model()`; as before, only the first
# of the three returned values is kept.
scaled_model,_,_ = raffm_model.smallest_model()



In [15]:
# Report the size and module structure of the newly extracted model.
params = calculate_params(scaled_model)
print("scaled model params",params)
print(scaled_model)

scaled model params 39.396456
ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 648, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=648, out_features=768, bias=True)
              (key): Linear(in_features=648, out_features=768, bias=True)
              (value): Linear(in_features=648, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=648, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
           

In [16]:

# Build a fresh Trainer around the new `scaled_model` (same arguments, collator
# and metric as before) and evaluate it on the validation split.
trainer = Trainer(
    model=scaled_model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    # tokenizer=processor,
)
metrics = trainer.evaluate()


In the second scaled FM we get **97.31%** Accuracy.

In [17]:
# Print the evaluation metrics and save them under the "eval" split name.
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** eval metrics *****
  eval_accuracy           =     0.9731
  eval_loss               =     0.2604
  eval_runtime            = 0:01:00.99
  eval_samples_per_second =    163.953
  eval_steps_per_second   =      0.656


# Not enough? More models!
Sample one more downsized FM and evaluate it

In [18]:
# Draw another random sub-model (first of the three returned values), then
# report its size and module structure.
scaled_model,_,_ = raffm_model.random_resource_aware_model()
params = calculate_params(scaled_model)
print("scaled model params",params)
print(scaled_model)

scaled model params 50.510952
ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 648, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0): ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=648, out_features=768, bias=True)
              (key): Linear(in_features=648, out_features=768, bias=True)
              (value): Linear(in_features=648, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=648, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense)

In [19]:

# Build a fresh Trainer around the latest `scaled_model` (same arguments,
# collator and metric as before) and evaluate it on the validation split.
trainer = Trainer(
    model=scaled_model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    # tokenizer=processor,
)
metrics = trainer.evaluate()


In [20]:
# Print the evaluation metrics and save them under the "eval" split name.
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** eval metrics *****
  eval_accuracy           =     0.9789
  eval_loss               =     0.2144
  eval_runtime            = 0:01:03.36
  eval_samples_per_second =     157.82
  eval_steps_per_second   =      0.631
