# RaFFM – Foundation Model Scaling for Post-training Deployment
This tutorial shows how to generate high-performance scaled FMs (foundation models) for local clients with heterogeneous resources.

In this tutorial, we show how to generate a scaled ViT for image classification using our pre-trained RaFFM checkpoints.

**[Note from author] I ran these experiments on a MacBook Air with an M2 chip.**

In [1]:
import numpy as np
from datasets import load_metric
from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer
from datasets import load_dataset
from raffm import RaFFM
import torch




## Download RaFFM checkpoints
You can download the checkpoints from here:


or you can train your own.
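
Before moving on, it can help to confirm that the checkpoint directory looks like something `from_pretrained` can load (this tutorial later reads it from `ckpt/cifar10`). The sketch below is only a sanity check under assumptions: `looks_like_hf_checkpoint` is a hypothetical helper, not part of RaFFM or `transformers`, and it assumes the usual single-file Hugging Face layout (`config.json` plus `model.safetensors` or `pytorch_model.bin`).

```python
from pathlib import Path

# Hypothetical helper, not part of RaFFM or transformers.
# Assumes a single-file checkpoint; sharded checkpoints are not handled.
WEIGHT_FILES = ("model.safetensors", "pytorch_model.bin")

def looks_like_hf_checkpoint(ckpt_dir):
    """Return True if ckpt_dir has a config.json and at least one weight file."""
    ckpt = Path(ckpt_dir)
    has_config = (ckpt / "config.json").is_file()
    has_weights = any((ckpt / name).is_file() for name in WEIGHT_FILES)
    return has_config and has_weights
```

Calling `looks_like_hf_checkpoint("ckpt/cifar10")` before loading gives an early, readable failure if the download is incomplete.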

## Load dataset and pre-process

In [2]:
dataset = load_dataset('cifar10')

In [3]:
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }
    
def transform(example_batch):
    # Take a list of PIL images and turn them into pixel values
    inputs = processor([x for x in example_batch['img']], return_tensors='pt')

    # Don't forget to include the labels!
    inputs['labels'] = example_batch['label']
    return inputs

train_val = dataset["train"].train_test_split(test_size=0.2)
train_val

dataset['train'] = train_val["train"]
dataset["validation"] = train_val["test"]

prepared_ds = dataset.with_transform(transform)
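
The `train_test_split(test_size=0.2)` call shuffles the 50,000 CIFAR-10 training images and holds out 20% of them for validation, which is where the 40,000/10,000 row counts come from. Because no seed is passed, the split may differ between runs; `train_test_split` accepts a `seed` argument if you want it to be reproducible. The pure-Python sketch below only illustrates the idea of a seeded hold-out split on indices. `split_indices` is a hypothetical helper, not the `datasets` implementation.

```python
import random

def split_indices(n, test_size=0.2, seed=0):
    """Shuffle range(n) with a fixed seed and hold out a test_size fraction."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    n_test = int(round(n * test_size))
    # First n_test shuffled indices become the held-out set
    return indices[n_test:], indices[:n_test]

train_idx, val_idx = split_indices(50_000, test_size=0.2, seed=42)
print(len(train_idx), len(val_idx))  # 40000 10000
```

Running the function twice with the same seed returns the same two index lists, which is the property you want when comparing several scaled FMs on the same validation set.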


In [9]:
prepared_ds

DatasetDict({
    train: Dataset({
        features: ['img', 'label'],
        num_rows: 40000
    })
    test: Dataset({
        features: ['img', 'label'],
        num_rows: 10000
    })
    validation: Dataset({
        features: ['img', 'label'],
        num_rows: 10000
    })
})

## Load scalable FM from RaFFM checkpoint

In [4]:
ckpt_path = 'ckpt/cifar10'
labels = dataset['train'].features['label'].names

model = ViTForImageClassification.from_pretrained(
    ckpt_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)
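
To make the `id2label` and `label2id` arguments concrete, here is what the two dictionaries look like for CIFAR-10. The class names are written out by hand so the snippet runs without downloading the dataset; in the tutorial they come from `dataset['train'].features['label'].names`.

```python
# CIFAR-10 class names, hard-coded here for illustration only
labels = ["airplane", "automobile", "bird", "cat", "deer",
          "dog", "frog", "horse", "ship", "truck"]

# Same comprehensions as in the from_pretrained call
id2label = {str(i): c for i, c in enumerate(labels)}
label2id = {c: str(i) for i, c in enumerate(labels)}

print(id2label["3"])     # cat
print(label2id["ship"])  # 8
```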

In [5]:
raffm_model = RaFFM(model.to("cpu"))
print("Original FM number of parameters:",raffm_model.total_params)

Original FM number of parameters: 85.55136


## Sample a scaled FM 

In [6]:
submodel,params,arc_config = raffm_model.random_resource_aware_model()

print("subnetwork params",params)

subnetwork params 72.673536
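
To put the two printed numbers in perspective, the following sketch computes how much smaller the sampled sub-network is than the original FM. It assumes both values are reported in millions of parameters, which is consistent with the roughly 86M parameters of ViT-Base but is not stated explicitly by the outputs. `reduction_percent` is a helper name introduced here for illustration.

```python
full_params = 85.55136   # printed by raffm_model.total_params
sub_params = 72.673536   # printed for the sampled sub-network

def reduction_percent(full, sub):
    """Percentage of parameters removed relative to the full model."""
    return 100.0 * (1.0 - sub / full)

print(f"{reduction_percent(full_params, sub_params):.2f}% fewer parameters")
# 15.05% fewer parameters
```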


## Evaluate the high-performance scaled FM

In [7]:

metric = load_metric("accuracy")
def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)
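
`load_metric` is deprecated in newer releases of `datasets` in favor of the separate `evaluate` library. If it is unavailable in your environment, one possible substitute is to compute accuracy directly with NumPy, as sketched below. `compute_metrics_np` is a name introduced here, and the `SimpleNamespace` object is only a stand-in for the prediction object that `Trainer` passes in.

```python
import numpy as np
from types import SimpleNamespace

def compute_metrics_np(p):
    """Accuracy from logits, using the same argmax logic as compute_metrics."""
    preds = np.argmax(p.predictions, axis=1)
    return {"accuracy": float((preds == p.label_ids).mean())}

# Tiny stand-in for an evaluation result: 3 samples, 2 classes
fake = SimpleNamespace(
    predictions=np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]),
    label_ids=np.array([1, 0, 0]),
)
print(compute_metrics_np(fake))  # two of the three predictions are correct
```

The function returns a dictionary with an `accuracy` key, so it can be passed as `compute_metrics=compute_metrics_np` in place of the version above.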




In [8]:
training_args = TrainingArguments(
  output_dir="./log/debug",
  per_device_train_batch_size=16,
  evaluation_strategy="steps",
  num_train_epochs=4,
  save_steps=100,
  eval_steps=100,
  logging_steps=10,
  learning_rate=2e-4,
  save_total_limit=2,
  remove_unused_columns=False,
  push_to_hub=False,
  report_to='tensorboard',
  load_best_model_at_end=True,
)

trainer = Trainer(
    model=submodel,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    # tokenizer=processor,
)



In [10]:
metrics = trainer.evaluate()


100%|██████████| 1250/1250 [04:48<00:00,  4.33it/s]


The scaled FM reaches **97.52%** accuracy on the validation split!

In [11]:
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** eval metrics *****
  eval_accuracy           =     0.9752
  eval_loss               =     0.0863
  eval_runtime            = 0:04:49.50
  eval_samples_per_second =     34.542
  eval_steps_per_second   =      4.318
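
The throughput figures in the log can be cross-checked from the runtime, the 10,000 validation rows, and the 1250 steps shown in the progress bar. The sketch below does that arithmetic; `parse_runtime` is a small helper introduced here. The implied batch size of 8 is consistent with the `TrainingArguments` default for `per_device_eval_batch_size`, which the configuration above does not override.

```python
def parse_runtime(hms):
    """Convert an 'H:MM:SS.ss' string into seconds."""
    h, m, s = hms.split(":")
    return int(h) * 3600 + int(m) * 60 + float(s)

num_samples = 10_000   # validation rows
num_steps = 1250       # from the progress bar
runtime = parse_runtime("0:04:49.50")

print(num_samples // num_steps)          # 8 samples per evaluation step
print(round(num_samples / runtime, 3))   # 34.542
print(round(num_steps / runtime, 3))     # 4.318
```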


## Sample another scaled FM and evaluate it
You can sample as many scaled FMs as you like from the RaFFM model.

In [12]:
submodel,params,arc_config = raffm_model.random_resource_aware_model()

print("subnetwork params",params)

subnetwork params 75.72096
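
Since each call returns a sub-network of a different size, one way to target a specific client is to keep sampling until the parameter count fits that client's budget. The sketch below shows this idea under assumptions: `sample_within_budget` is a hypothetical helper, not part of RaFFM, and `fake_sampler` is a stand-in that mimics the `(submodel, params, arc_config)` return shape of `raffm_model.random_resource_aware_model` so the code runs without RaFFM installed.

```python
import random

def sample_within_budget(sampler, max_params, max_tries=20):
    """Call sampler() until a sub-network with params <= max_params appears.

    sampler must return (submodel, params, arc_config).
    Returns the first matching tuple, or None if no sample fits.
    """
    for _ in range(max_tries):
        submodel, params, arc_config = sampler()
        if params <= max_params:
            return submodel, params, arc_config
    return None

# Stand-in sampler: random sizes in place of real sub-networks
rng = random.Random(0)

def fake_sampler():
    params = rng.uniform(60.0, 85.55136)
    return object(), params, {"note": "placeholder config"}

result = sample_within_budget(fake_sampler, max_params=75.0)
if result is not None:
    print(f"picked a sub-network with {result[1]:.2f}M params")
```

With the real model you would pass `raffm_model.random_resource_aware_model` as `sampler`; the budget of 75.0 is an arbitrary value chosen for the example.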


In [13]:

trainer = Trainer(
    model=submodel,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    # tokenizer=processor,
)
metrics = trainer.evaluate()


100%|██████████| 1250/1250 [04:26<00:00,  4.69it/s]


The second scaled FM reaches **97.6%** accuracy.

In [14]:
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** eval metrics *****
  eval_accuracy           =      0.976
  eval_loss               =      0.085
  eval_runtime            = 0:04:26.80
  eval_samples_per_second =     37.481
  eval_steps_per_second   =      4.685
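
To compare the two sampled sub-networks side by side, the sketch below converts each accuracy into a count of misclassified validation images. The numbers are copied from the outputs in this tutorial; `misclassified` is a helper introduced here, and the parameter counts are assumed to be in millions.

```python
def misclassified(accuracy, num_eval):
    """Number of wrong predictions implied by an accuracy on num_eval samples."""
    return round((1 - accuracy) * num_eval)

results = [
    # (params, eval accuracy) copied from the two evaluation runs
    (72.673536, 0.9752),
    (75.72096, 0.976),
]
num_eval = 10_000  # validation rows

for params, acc in results:
    errors = misclassified(acc, num_eval)
    print(f"{params:.2f}M params -> {acc:.2%} accuracy, {errors} misclassified")
# 72.67M params -> 97.52% accuracy, 248 misclassified
# 75.72M params -> 97.60% accuracy, 240 misclassified
```

The two sub-networks differ by 8 images out of 10,000, a gap small enough that it may fall within run-to-run noise, so a single evaluation is probably not enough to rank them.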
