# Downsize a CLIP Model Using OFM

In this tutorial, we will show you how to quickly extract a downsized model from a CLIP model optimized by OFM. The downsized model is more efficient, with significantly fewer parameters and FLOPs, while maintaining performance competitive with the original model.

First, import the dependency packages and define a utility function for calculating the model size.

In [None]:
import os
import torch
import numpy as np
from datasets import load_dataset
import functools
import evaluate
from arguments import arguments
from ofm import OFM
from ofm.trainer import TrainingArguments
from ofm.trainer import CLIPTrainer as Trainer

import functools
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel
import torch

In [None]:

from torch.nn import Parameter
def calculate_params(model):
    """Count the weight parameters of a model, in millions.

    Only tensors exposed as a sub-module's ``weight`` attribute (and that are
    ``Parameter`` instances) are counted. Biases and parameters stored under
    any other attribute name are not included in the total.

    Args:
        model: the model to be evaluated; it must provide ``named_modules()``
    Returns:
        the number of counted weight elements divided by 1,000,000
    """

    millions = 1000000
    # `numel()` gives the element count directly, without building a
    # temporary tensor from the shape just to multiply its entries.
    total_params = sum(
        module.weight.numel()
        for _, module in model.named_modules()
        if isinstance(getattr(module, "weight", None), Parameter)
    )

    return total_params / millions


Then, we load the CLIP model optimized by OFM (we call it super-FM) using the Hugging Face `CLIPModel` API.

You can find our published checkpoints in [README.md](README.md).

In [None]:

# Placeholder: replace with the name/path of the checkpoint to load.
ckpt_name = "ckpt_name_here" #You can find our published checkpoint at README.md
# The model and its processor are both loaded from the same checkpoint name.
model = CLIPModel.from_pretrained(ckpt_name)
processor = CLIPProcessor.from_pretrained(ckpt_name)


Next, we convert the CLIP model to a supernet via the OFM supernet class, with only one line of code.





In [None]:
# Wrap the loaded CLIP model in the OFM supernet class.
supernet = OFM(model=model)


To extract a downsized model from the supernet, we can simply call the `resource_aware_model` API. There are multiple ways to get a downsized model, such as specifying the target model structure, getting the smallest model within an elastic space, or getting a random downsized model. More details can be found in `examples/post_training_deployment`.

In this example, we extract the smallest model within an elastic space via the `smallest_model` API.

In [None]:
# `smallest_model()` returns three values; `ds_model` is the downsized model
# used in the cells below.
# NOTE(review): `param` and `arc_config` look like the parameter count and the
# architecture configuration of the extracted model — confirm against OFM.
ds_model, param, arc_config = supernet.smallest_model()


Let's compare the model size between the original model and the downsized model.

In [None]:
# Both counts come from `calculate_params`, so they are expressed in millions.
original_model_params = calculate_params(model)
ds_model_params = calculate_params(ds_model)
print(f"Original model has {original_model_params}M parameters")
print(f"Downsized model has {ds_model_params}M parameters")
# Absolute difference between the two counts (also in millions).
print(f"Total model size reduction: {original_model_params - ds_model_params}M")

Now, let's evaluate the downsized model's performance on the CIFAR-10 dataset using the following metrics: **accuracy, F1, precision, and recall**.

First, we load the CIFAR-10 dataset and preprocess it.



In [None]:
# Load the "cifar10" dataset through `datasets.load_dataset`.
# NOTE(review): `trust_remote_code=True` presumably allows the dataset's own
# loading script to run — confirm this is intended for your environment.
dataset = load_dataset("cifar10", trust_remote_code=True)

# Integer class label -> class name used as the text prompt; the names are
# listed in label-index order (0..9).
label_to_text = dict(
    enumerate(
        (
            "airplane",
            "automobile",
            "bird",
            "cat",
            "deer",
            "dog",
            "frog",
            "horse",
            "ship",
            "truck",
        )
    )
)

def collate_fn(batch):
    """Merge a list of per-sample dictionaries into one batched dictionary.

    Intended to be passed as the ``collate_fn`` argument of a DataLoader.

    Args:
        batch: A list of samples, each providing "pixel_values", "input_ids"
            and "labels"
    returns:
        A dictionary with the stacked "pixel_values" and "input_ids" tensors
        and a "labels" tensor built from the per-sample labels
    """
    pixel_values = torch.stack([sample["pixel_values"] for sample in batch])
    input_ids = torch.stack([sample["input_ids"] for sample in batch])
    labels = torch.tensor([sample["labels"] for sample in batch])

    return {
        "pixel_values": pixel_values,
        "input_ids": input_ids,
        "labels": labels,
    }

def transform_eval(example_batch, processor):
    """Run a batch of examples through ``processor`` and attach the labels.

    Args:
        example_batch: a batch providing the images under "img" and the
            class labels under "label"
        processor: callable that receives the text prompts and the images
    Returns:
        the processor output with an added "labels" entry
    """
    # The text side is always the ten class names, independent of the batch.
    class_prompts = [label_to_text[class_idx] for class_idx in range(10)]
    encoded = processor(
        text=class_prompts,
        images=example_batch["img"],
        return_tensors="pt",
        padding=True,
    )
    encoded["labels"] = example_batch["label"]
    return encoded

# Register `transform_eval` as the transform of the test split;
# `functools.partial` binds the `processor` argument so the transform only
# needs to receive the example batch.
prepared_test = dataset["test"].with_transform(
    functools.partial(transform_eval, processor=processor)
)

Then, we use the `evaluate` function below to calculate the downsized model's performance on the CIFAR-10 dataset.



In [None]:
def evaluate(eval_dataloader, model=None):
    """Evaluate a model on a dataloader of image/text batches.

    For every batch the model is called with ``pixel_values`` and
    ``input_ids``; the predicted class is the argmax over
    ``logits_per_image``. Accuracy and weighted F1/precision/recall are
    recomputed over all samples seen so far after each batch and shown on
    the progress bar.

    NOTE: this function shares its name with the ``evaluate`` package
    imported at the top of the notebook and therefore shadows it once
    defined.

    Args:
        eval_dataloader: iterable of dictionaries of tensors containing
            "pixel_values", "input_ids" and "labels"
        model: the model to evaluate; defaults to the notebook-level
            ``ds_model`` when omitted
    Returns:
        dict with the keys "accuracy", "f1", "precision" and "recall",
        computed over the whole dataloader
    Raises:
        ValueError: if ``eval_dataloader`` yields no batches
    """
    from sklearn.metrics import (
        accuracy_score,
        f1_score,
        precision_score,
        recall_score,
    )
    # `import tqdm` only binds the module, which is not callable; the
    # progress-bar class has to be imported from it.
    from tqdm import tqdm

    if model is None:
        # Backward-compatible default: the downsized model defined earlier.
        model = ds_model

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The batches are moved to `device` below, so the model must be there too.
    model.to(device)
    model.eval()

    true_labels = []
    pred_labels = []
    eval_metrics = None

    progress_bar = tqdm(eval_dataloader, desc="Evaluation")

    for batch in progress_bar:
        batch = {k: v.to(device) for k, v in batch.items()}

        with torch.no_grad():
            outputs = model(
                pixel_values=batch["pixel_values"],
                input_ids=batch["input_ids"],
            )
            predicted = torch.argmax(outputs.logits_per_image, dim=1)

        # Store plain Python ints rather than 0-d tensors for the metrics.
        true_labels.extend(batch["labels"].tolist())
        pred_labels.extend(predicted.tolist())

        # Running metrics over everything seen so far
        eval_metrics = {
            "accuracy": accuracy_score(true_labels, pred_labels),
            "f1": f1_score(true_labels, pred_labels, average="weighted"),
            "precision": precision_score(
                true_labels, pred_labels, average="weighted"
            ),
            "recall": recall_score(true_labels, pred_labels, average="weighted"),
        }

        progress_bar.set_postfix(
            {
                "Accuracy": f"{eval_metrics['accuracy']:.4f}",
                "F1 Score": f"{eval_metrics['f1']:.4f}",
                "Precision": f"{eval_metrics['precision']:.4f}",
                "Recall": f"{eval_metrics['recall']:.4f}",
            }
        )

    if eval_metrics is None:
        raise ValueError("eval_dataloader yielded no batches; nothing to evaluate")
    return eval_metrics


# Batches of 32 samples from the prepared test split, assembled by `collate_fn`.
eval_dataloader = torch.utils.data.DataLoader(prepared_test, collate_fn=collate_fn, batch_size=32)
# `evaluate` here is the function defined in this notebook, not the
# `evaluate` package imported in the first cell (the function shadows it).
eval_metrics = evaluate(eval_dataloader)


Finally, we print out the downsized model's performance on the CIFAR-10 dataset.

In [None]:
# Show the metrics dictionary returned by `evaluate`.
print(eval_metrics)