# Subnet specialization with the OSF supernet

In this tutorial we show one of the key contributions of OSF:

**Convert a target pre-trained model into a supernet, then quickly compress it to meet various resource constraints, without retraining**


To run this tutorial, first install our package: **`pip install .`**.

We use a ViT model with the CIFAR-100 dataset as an example and show how to specialize it to meet various resource constraints without retraining.

## Step 1. Import dependency libraries

In [None]:

import numpy as np
import evaluate
from transformers import AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer
from datasets import load_dataset
from ofm import OFM
import torch
import os


Define a utility function that counts the number of parameters in a model:

In [2]:
from torch.nn import Parameter

def calculate_params(model):
    """Count the weight parameters of a model.

    Only tensors registered as a module's `weight` are counted; biases and
    other parameters (e.g. ViT's CLS token and position embeddings) are not.

    Args:
        model: the model to be evaluated
    Returns:
        the number of weight parameters, in millions
    """
    millions = 1_000_000
    total_params = 0
    for _, module in model.named_modules():
        if hasattr(module, "weight") and isinstance(module.weight, Parameter):
            total_params += module.weight.numel()

    return total_params / millions
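`calculate_params` only counts tensors registered as a module's `weight`, so bias vectors are left out. The sketch below mimics that rule in plain Python on a toy list of hypothetical `(name, weight_shape, bias_shape)` entries, so the difference from a full parameter count is visible without loading a model.

```python
from math import prod

# Toy stand-in for model.named_modules(): (module name, weight shape, bias shape).
# These entries are hypothetical; they mirror a Linear(4 -> 3) followed by LayerNorm(3).
toy_modules = [
    ("fc", (3, 4), (3,)),
    ("norm", (3,), (3,)),
]

def weight_only_count(modules):
    """Count weight elements only, like calculate_params (before dividing by 1e6)."""
    return sum(prod(w_shape) for _, w_shape, _ in modules)

def full_count(modules):
    """Count weights and biases, i.e. every parameter of the toy modules."""
    return sum(prod(w_shape) + prod(b_shape) for _, w_shape, b_shape in modules)

print(weight_only_count(toy_modules))  # 12 + 3 = 15
print(full_count(toy_modules))         # 15 + 3 + 3 = 21
```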


## Step 2. Load and preprocess the dataset
First, define the preprocessing functions, following the Hugging Face [image classification tutorial](https://huggingface.co/docs/transformers/tasks/image_classification).

In [3]:

# Load the image processor once, instead of on every call to transform()
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")

def collate_fn(batch):
    # Stack per-example tensors into a single batch for the Trainer
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

def transform(example_batch):
    # Take a list of PIL images and turn them into pixel values
    inputs = processor(list(example_batch['img']), return_tensors='pt')

    # Don't forget to include the labels!
    inputs['labels'] = example_batch['label']
    return inputs
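`collate_fn` turns a list of per-example dicts into one dict of batched columns. The sketch below shows the same regrouping in plain Python, with lists standing in for the tensors that `torch.stack` and `torch.tensor` would build; `collate_as_lists` is a hypothetical helper for illustration only.

```python
def collate_as_lists(batch):
    """Regroup a list of example dicts into a dict of per-key lists.

    Mirrors collate_fn, but keeps plain lists instead of stacking tensors.
    """
    return {
        "pixel_values": [example["pixel_values"] for example in batch],
        "labels": [example["labels"] for example in batch],
    }

# Two fake examples; real pixel_values would be 3 x 224 x 224 tensors.
batch = [
    {"pixel_values": [0.1, 0.2], "labels": 7},
    {"pixel_values": [0.3, 0.4], "labels": 42},
]
collated = collate_as_lists(batch)
print(collated["labels"])  # [7, 42]
```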

Next, load CIFAR-100, hold out part of the training split for validation, and attach the transform so that images are converted to PyTorch tensors on the fly.

In [4]:
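dataset = load_dataset('cifar100')
# Use the fine-grained (100-class) labels as the target
dataset = dataset.rename_column("fine_label", "label")

# Hold out 20% of the original training split as a validation set
train_val = dataset["train"].train_test_split(test_size=0.2, seed=123)

dataset['train'] = train_val["train"]
dataset["validation"] = train_val["test"]

# Apply `transform` lazily whenever examples are accessed
prepared_ds = dataset.with_transform(transform)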
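The split holds out 20% of CIFAR-100's 50,000 training images, which gives the 40,000 / 10,000 train/validation sizes in the output below. A minimal plain-Python sketch of a seeded hold-out split follows; `holdout_split` is a hypothetical helper that illustrates the idea only and will not pick the same indices as `train_test_split` in `datasets`.

```python
import random

def holdout_split(n_examples, test_size=0.2, seed=123):
    """Shuffle example indices with a fixed seed and hold out a test_size fraction."""
    indices = list(range(n_examples))
    random.Random(seed).shuffle(indices)
    n_test = int(round(n_examples * test_size))
    return indices[n_test:], indices[:n_test]  # (train indices, held-out indices)

train_idx, val_idx = holdout_split(50_000)
print(len(train_idx), len(val_idx))  # 40000 10000
```

Fixing the seed makes the split reproducible across runs, which matters when evaluation numbers from different subnets are compared on the same validation set.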


Inspect the resulting splits:

In [5]:

prepared_ds


DatasetDict({
    train: Dataset({
        features: ['img', 'label', 'coarse_label'],
        num_rows: 40000
    })
    test: Dataset({
        features: ['img', 'label', 'coarse_label'],
        num_rows: 10000
    })
    validation: Dataset({
        features: ['img', 'label', 'coarse_label'],
        num_rows: 10000
    })
})

## Step 3. Initialize a pre-trained model

Initialize the model from one of our supernet checkpoints (ckpts); the available checkpoints are listed in [README.MD](./README.MD).

This example uses the ViT-base CIFAR-100 supernet checkpoint, which you can find **[here](https://huggingface.co/yusx-swapp/ofm-vit-base-patch16-224-cifar100)**.


In [6]:
# use our pretrained ckpts: https://huggingface.co/yusx-swapp/ofm-vit-base-patch16-224-cifar100
ckpt_path = 'yusx-swapp/ofm-vit-base-patch16-224-cifar100'
labels = dataset['train'].features['label'].names

model = AutoModelForImageClassification.from_pretrained(
    ckpt_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)
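The `id2label` and `label2id` arguments are inverse lookups between class indices and class names. A tiny sketch with a hypothetical three-class label list shows how the two comprehensions pair up; in the tutorial the list is the 100 CIFAR-100 fine-label names.

```python
# Hypothetical short label list; the tutorial uses dataset['train'].features['label'].names.
toy_labels = ["apple", "aquarium_fish", "baby"]

id2label = {str(i): c for i, c in enumerate(toy_labels)}
label2id = {c: str(i) for i, c in enumerate(toy_labels)}

# Round trip: index -> name -> index gives back the original index.
for idx, name in id2label.items():
    assert label2id[name] == idx

print(id2label["1"], label2id["baby"])  # aquarium_fish 2
```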

## Step 4. Convert the model to an OSF supernet

Now convert the pre-trained foundation model (FM) into an OSF supernet with the high-level `OFM` conversion API, and check the supernet's size:

In [7]:

# Wrap the pre-trained model (moved to CPU) as an OFM supernet
supernet = OFM(model.to("cpu"))
print("Original FM number of parameters:", supernet.total_params)

Original FM number of parameters: 85.62048
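The 85.62M figure can be reproduced by hand if one assumes it counts only `weight` tensors, the same rule `calculate_params` uses, so biases, the CLS token and the position embeddings are excluded. The arithmetic below uses the standard ViT-Base/16 dimensions (hidden size 768, MLP size 3072, 12 layers, 16 x 16 patches) plus a 100-class head; it is a consistency check, not a description of how `OFM` computes `total_params`.

```python
HIDDEN, MLP, LAYERS, PATCH, CHANNELS, CLASSES = 768, 3072, 12, 16, 3, 100

patch_embed = HIDDEN * CHANNELS * PATCH * PATCH  # Conv2d patch projection weight
attention = 4 * HIDDEN * HIDDEN                  # query, key, value and output dense weights
mlp = 2 * HIDDEN * MLP                           # intermediate and output dense weights
layer_norms = 2 * HIDDEN                         # layernorm_before / layernorm_after weights
per_layer = attention + mlp + layer_norms
final_norm = HIDDEN                              # final encoder LayerNorm weight
classifier = CLASSES * HIDDEN                    # Linear(768, 100) classifier weight

total = patch_embed + LAYERS * per_layer + final_norm + classifier
print(total / 1_000_000)  # 85.62048
```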


# Experiment

## Randomly sample a subnet

In [8]:
# Sample a random subnet; only the first returned value (the subnet model) is used here
ds_model, _, _ = supernet.random_resource_aware_model()


## Check the subnet model architecture

OSF models a pre-trained model as a hierarchical representation of subnets instead of binary-masking or zero-masking its weights. Each subnet is extracted as a clean, standalone model, so its size **is truly reduced**.
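To see why extraction, unlike masking, actually shrinks the model, compare the two on a single weight matrix. The sketch below is illustrative only and is not OSF's implementation: it assumes, hypothetically, that a subnet keeps the leading 648 of 768 input channels of a 768 x 768 projection, which matches the `Linear(in_features=648, out_features=768)` query layers printed later in this tutorial.

```python
OUT_FEATURES, IN_FEATURES, KEEP_IN = 768, 768, 648

# Dense weight matrix stored row-major: OUT_FEATURES rows of IN_FEATURES values.
weight = [[1.0] * IN_FEATURES for _ in range(OUT_FEATURES)]

def zero_mask(w, keep_in):
    """Zero out the dropped input columns; the matrix keeps its full shape."""
    return [row[:keep_in] + [0.0] * (len(row) - keep_in) for row in w]

def extract(w, keep_in):
    """Slice off the dropped input columns; the matrix genuinely shrinks."""
    return [row[:keep_in] for row in w]

def num_elements(w):
    return sum(len(row) for row in w)

masked = zero_mask(weight, KEEP_IN)
extracted = extract(weight, KEEP_IN)
print(num_elements(masked))     # 589824 (still 768 x 768)
print(num_elements(extracted))  # 497664 (768 x 648)
```

The masked matrix still stores (and multiplies by) every zero, so memory and compute are unchanged; only the extracted one saves both.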

In [9]:
params = calculate_params(ds_model)
print("scaled model params", params)
# You can also print the model to see the architecture
# print(ds_model)

scaled model params 53.911656


# Evaluate the sampled subnets

In [10]:
# Load the metrics once, rather than on every evaluation call
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")

def compute_metrics(eval_pred):
    """Compute the evaluation metrics.

    Args:
        eval_pred: the EvalPrediction (logits and label ids) that the Trainer
            passes in during evaluation
    Returns:
        A dictionary of metrics; accuracy is reported under the key "metric"
    """
    predictions = np.argmax(eval_pred.predictions, axis=1)

    accuracy = accuracy_metric.compute(
        predictions=predictions,
        references=eval_pred.label_ids,
    )
    f1 = f1_metric.compute(
        predictions=predictions,
        references=eval_pred.label_ids,
        average="weighted",
    )

    return {"metric": accuracy["accuracy"], "f1": f1["f1"]}
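`compute_metrics` reports accuracy and a *weighted* F1 score: per-class F1 scores are averaged with each class weighted by its number of true examples (its support). A dependency-free sketch of both quantities, computed from the same argmax predictions, makes `average="weighted"` concrete; it follows the usual definition (as in scikit-learn's `f1_score`) and is not meant to replace `evaluate`. The logits and labels are made up for illustration.

```python
from collections import Counter

def argmax(row):
    """Index of the largest logit in one row (first one wins on ties, like np.argmax)."""
    return max(range(len(row)), key=lambda i: row[i])

def accuracy(preds, refs):
    return sum(p == r for p, r in zip(preds, refs)) / len(refs)

def weighted_f1(preds, refs):
    """Support-weighted mean of per-class F1 over classes seen in refs or preds."""
    support = Counter(refs)
    total = 0.0
    for c in set(refs) | set(preds):
        tp = sum(p == c and r == c for p, r in zip(preds, refs))
        fp = sum(p == c and r != c for p, r in zip(preds, refs))
        fn = sum(p != c and r == c for p, r in zip(preds, refs))
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom else 0.0
        total += support[c] * f1  # classes absent from refs get weight 0
    return total / len(refs)

logits = [[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.1, 0.2, 0.9], [1.1, 0.4, 0.2]]
refs = [0, 1, 1, 1]
preds = [argmax(row) for row in logits]      # [0, 1, 2, 0]
print(accuracy(preds, refs))                 # 0.5
print(round(weighted_f1(preds, refs), 4))    # 0.5417
```

The example shows that weighted F1 and accuracy generally differ: class 1 dominates the support, so its lower F1 (0.5) pulls the average toward it.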


## Now we use `Trainer` to evaluate the sampled subnet

In [11]:
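# Only trainer.evaluate() is called in this tutorial, so the training-specific
# settings (epochs, learning rate, save/eval/logging steps) have no effect here.
# On newer transformers releases, `evaluation_strategy` is named `eval_strategy`.
training_args = TrainingArguments(
    output_dir="./log/debug",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=256,
    evaluation_strategy="steps",
    num_train_epochs=4,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    # Keep the raw 'img' column so the on-the-fly transform can read it
    remove_unused_columns=False,
    push_to_hub=False,
    report_to='tensorboard',
    load_best_model_at_end=True,
)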

trainer = Trainer(
    model=ds_model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=None,
    eval_dataset=prepared_ds["validation"],
    # tokenizer=processor,
)



In [12]:
metrics = trainer.evaluate()

This subnet (about 53.9M parameters) reaches **98.17%** validation accuracy, **without any further training**!

In [13]:
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** eval metrics *****
  eval_f1                 =     0.9817
  eval_loss               =     0.1817
  eval_metric             =     0.9817
  eval_runtime            = 0:01:03.07
  eval_samples_per_second =    158.536
  eval_steps_per_second   =      0.634
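The throughput figures in this log are consistent with the settings above. Assuming evaluation ran on a single device, the 10,000 validation images at `per_device_eval_batch_size=256` take ceil(10000 / 256) = 40 batches, and dividing by the runtime gives roughly the logged rates.

```python
import math

NUM_EVAL_SAMPLES = 10_000  # validation split size from the DatasetDict above
EVAL_BATCH_SIZE = 256      # per_device_eval_batch_size, assuming a single device
RUNTIME_S = 63.07          # eval_runtime 0:01:03.07 from the log

num_steps = math.ceil(NUM_EVAL_SAMPLES / EVAL_BATCH_SIZE)
print(num_steps)                               # 40
print(round(num_steps / RUNTIME_S, 3))         # 0.634 steps per second
print(round(NUM_EVAL_SAMPLES / RUNTIME_S, 1))  # 158.6 samples per second (log: 158.536)
```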


## Sample another subnet and evaluate it
You can sample as many subnets as you like from the supernet. This time we take the smallest one.

In [14]:
# Extract the smallest subnet the supernet provides
scaled_model, _, _ = supernet.smallest_model()



In [15]:
params = calculate_params(scaled_model)
print("scaled model params", params)
print(scaled_model)

scaled model params 39.396456
ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 648, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=648, out_features=768, bias=True)
              (key): Linear(in_features=648, out_features=768, bias=True)
              (value): Linear(in_features=648, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=648, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
           

In [16]:

trainer = Trainer(
    model=scaled_model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=None,
    eval_dataset=prepared_ds["validation"],
    # tokenizer=processor,
)
metrics = trainer.evaluate()


The smallest subnet (about 39.4M parameters) still reaches **97.31%** accuracy.

In [17]:
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** eval metrics *****
  eval_accuracy           =     0.9731
  eval_loss               =     0.2604
  eval_runtime            = 0:01:00.99
  eval_samples_per_second =    163.953
  eval_steps_per_second   =      0.656


# Not enough? More models!
Sample one more subnet and evaluate it.

In [18]:
scaled_model, _, _ = supernet.random_resource_aware_model()
params = calculate_params(scaled_model)
print("scaled model params", params)
print(scaled_model)

scaled model params 50.510952
ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 648, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0): ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=648, out_features=768, bias=True)
              (key): Linear(in_features=648, out_features=768, bias=True)
              (value): Linear(in_features=648, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=648, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense)

In [19]:

trainer = Trainer(
    model=scaled_model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=None,
    eval_dataset=prepared_ds["validation"],
    # tokenizer=processor,
)
metrics = trainer.evaluate()


In [20]:
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

***** eval metrics *****
  eval_accuracy           =     0.9789
  eval_loss               =     0.2144
  eval_runtime            = 0:01:03.36
  eval_samples_per_second =     157.82
  eval_steps_per_second   =      0.631
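With several subnets evaluated, choosing one for a deployment target comes down to a lookup over (size, accuracy) pairs. The sketch below copies the parameter counts and validation accuracies from the three runs in this tutorial; `pick_subnet` and the budget values are hypothetical and not part of the `OFM` API. It returns the most accurate subnet that fits a parameter budget given in millions.

```python
# (name, weight parameters in millions, validation accuracy) copied from the runs above.
SUBNETS = [
    ("random subnet 1", 53.911656, 0.9817),
    ("smallest subnet", 39.396456, 0.9731),
    ("random subnet 2", 50.510952, 0.9789),
]
SUPERNET_PARAMS_M = 85.62048

def pick_subnet(budget_m, subnets=SUBNETS):
    """Return the most accurate subnet whose size fits within budget_m, or None."""
    fitting = [s for s in subnets if s[1] <= budget_m]
    return max(fitting, key=lambda s: s[2]) if fitting else None

for name, params_m, acc in SUBNETS:
    print(f"{name}: {params_m / SUPERNET_PARAMS_M:.1%} of supernet, accuracy {acc:.2%}")

print(pick_subnet(52))  # ('random subnet 2', 50.510952, 0.9789)
```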
