# Downsizing a CLIP Model with OSF

In this tutorial, we show how to quickly extract subnets from a super-CLIP model optimized by OSF. The subnets are more efficient, with significantly fewer parameters and FLOPs, while remaining competitive with the original model.

First, import the dependencies and define a utility function that calculates the model size.

Also make sure you have a GPU enabled.

In [None]:

import functools
import warnings

import torch
from datasets import load_dataset
from ofm import OFM
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel

warnings.filterwarnings("ignore")


In [2]:

from torch.nn import Parameter

def calculate_params(model):
    """Count the weight parameters in the model.

    Only tensors exposed as a module's ``weight`` attribute are counted;
    biases and other parameters are not included.

    Args:
        model: the model to be evaluated
    Returns:
        The number of weight parameters, in millions.
    """
    millions = 1_000_000
    total_params = 0
    for name, module in model.named_modules():
        if hasattr(module, "weight") and isinstance(module.weight, Parameter):
            # numel() is the product of the weight tensor's dimensions
            total_params += module.weight.numel()

    return total_params / millions
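
The helper above multiplies out the shape of every tensor named `weight` and reports the sum in millions, so biases are left out of the total. The following dependency-free sketch illustrates that counting rule; the layer names and shapes are hypothetical and are not taken from CLIP.

```python
import math

# Hypothetical layer shapes for illustration only (not CLIP's real layers).
layers = {
    "proj": {"weight": (512, 768), "bias": (512,)},
    "norm": {"weight": (768,), "bias": (768,)},
}

def count_millions(layers, include_bias):
    """Sum the element counts of the selected tensors, in millions."""
    total = 0
    for tensors in layers.values():
        for name, shape in tensors.items():
            if name == "weight" or include_bias:
                total += math.prod(shape)  # elements = product of dimensions
    return total / 1_000_000

weights_only = count_millions(layers, include_bias=False)
all_tensors = count_millions(layers, include_bias=True)
print(f"weights only: {weights_only}M, weights and biases: {all_tensors}M")
```

The gap between the two totals is small here, but it means the figures reported by `calculate_params` are weight counts rather than full parameter counts.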


In [3]:
!nvidia-smi

Thu Mar 28 19:03:11 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 520.61.05    Driver Version: 520.61.05    CUDA Version: 11.8     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  Off  | 00000000:AF:00.0 Off |                    0 |
| N/A   30C    P0    83W / 500W |  12265MiB / 81920MiB |     51%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

Then, we load the super-CLIP model optimized by OSF using the Hugging Face `CLIPModel` API.

You can find our published checkpoints in [README.md](README.md).

The model checkpoint used here is [yusx-swapp/ofm-clip-base-patch32-cifar10](https://huggingface.co/yusx-swapp/ofm-clip-base-patch32-cifar10).

In [4]:

ckpt_name = "yusx-swapp/ofm-clip-base-patch32-cifar10"  # Published checkpoints are listed in README.md
model = CLIPModel.from_pretrained(ckpt_name)
processor = CLIPProcessor.from_pretrained(ckpt_name)


Next, we convert the CLIP model to a supernet with the OSF API, the `OFM` supernet class, in a single line of code.





In [5]:
supernet = OFM(model=model)


To extract a subnet from the supernet, call the `resource_aware_model` API. There are several ways to get a subnet: specifying the target model structure, taking the smallest model within an elastic space, or sampling a random downsized model. More details can be found in `examples/post_training_deployment`.

In this example, we randomly sample a subnet within the search space via the `random_resource_aware_model` API.

In [6]:
ds_model, param, arc_config = supernet.random_resource_aware_model()


Let's compare the model size between the original model and the subnet.

In [7]:
original_model_params = calculate_params(model)
ds_model_params = calculate_params(ds_model)
print(f"Original model has {original_model_params}M parameters")
print(f"Downsized model has {ds_model_params}M parameters")
print(f"Total model size reduction: {original_model_params - ds_model_params}M")

Original model has 151.105536M parameters
Downsized model has 112.832512M parameters
Total model size reduction: 38.27302400000001M
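
To put the absolute numbers in context, the relative reduction can be computed from the two counts printed above. This is a minimal sketch that uses those values as literals; the variable names are illustrative.

```python
# Parameter counts in millions, copied from the output above.
original_m = 151.105536
downsized_m = 112.832512

reduction_m = original_m - downsized_m
reduction_pct = 100 * reduction_m / original_m   # share of the original removed
compression_ratio = original_m / downsized_m     # original size / subnet size

print(f"Removed {reduction_m:.2f}M parameters ({reduction_pct:.1f}% of the original)")
print(f"Compression ratio: {compression_ratio:.2f}x")
```

Because the subnet is sampled at random, these figures will likely differ from run to run.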


Now, let's evaluate the subnet's performance on the CIFAR-10 dataset using **accuracy, F1, precision, and recall**.

First, we load the CIFAR-10 dataset and preprocess it.



In [8]:
dataset_name = "cifar10"
dataset = load_dataset(dataset_name)

labels = dataset["train"].features["label"].names
label_to_text = {i: label for i, label in enumerate(labels)}


def collate_fn(batch):
    """This function is used to collate the data samples into batches.
    It is used to supply the DataLoader with the collate_fn argument.

    Args:
        batch: A list of samples from the dataset
    returns:
        A dictionary of tensors containing the batched samples
    """
    return {
        "pixel_values": torch.stack([x["pixel_values"] for x in batch]),
        "input_ids": torch.stack([x["input_ids"] for x in batch]),
        "labels": torch.tensor([x["labels"] for x in batch]),
    }

def transform_eval(example_batch, processor):
    # Take a list of PIL images and turn them to pixel values
    inputs = processor(
        text=[label_to_text[label] for label in range(10)],
        images=example_batch["img"],
        return_tensors="pt",
        padding=True,
    )
    inputs["labels"] = example_batch["label"]
    return inputs

prepared_test = dataset["test"].with_transform(
    functools.partial(transform_eval, processor=processor)
)
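
The transform above tokenizes the ten class names as text prompts. CLIP-style zero-shot classification then scores each image against every prompt (CLIP's `logits_per_image` are scaled cosine similarities between image and text embeddings) and picks the best match. The following dependency-free sketch shows the idea with made-up 3-dimensional embeddings; the vectors and the `cosine` helper are hypothetical and are not CLIP outputs.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

# Made-up embeddings for illustration; real CLIP embeddings are much larger.
text_embeddings = {
    "airplane": [0.9, 0.1, 0.0],
    "cat": [0.1, 0.9, 0.1],
    "ship": [0.0, 0.2, 0.9],
}
image_embedding = [0.2, 0.8, 0.3]

# One similarity score per class prompt, then take the best-scoring label.
scores = {label: cosine(image_embedding, emb) for label, emb in text_embeddings.items()}
predicted = max(scores, key=scores.get)
print(predicted)
```

Taking the label with the highest score is the same decision that `torch.argmax` makes over `logits_per_image` during evaluation.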

Then, we use the `evaluate` function below to calculate the subnet's performance on the CIFAR-10 dataset.



In [9]:
def evaluate(eval_dataloader):
    from sklearn.metrics import (
        accuracy_score,
        f1_score,
        precision_score,
        recall_score,
    )

    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    ds_model.to(device)
    ds_model.eval()

    true_labels = []
    pred_labels = []

    progress_bar = tqdm(eval_dataloader, desc="Evaluation")

    for batch in progress_bar:
        batch = {k: v.to(device) for k, v in batch.items()}
        images = batch["pixel_values"]
        input_ids = batch["input_ids"]

        labels = batch["labels"]

        with torch.no_grad():
            outputs = ds_model(pixel_values=images, input_ids=input_ids)
            logits = outputs.logits_per_image
            predicted_labels = torch.argmax(logits, dim=1).to("cpu").tolist()

        true_labels.extend(labels.to("cpu").tolist())
        pred_labels.extend(predicted_labels)

        # Calculate intermediate metrics
        accuracy = accuracy_score(true_labels, pred_labels)
        f1 = f1_score(true_labels, pred_labels, average="weighted")
        precision = precision_score(true_labels, pred_labels, average="weighted")
        recall = recall_score(true_labels, pred_labels, average="weighted")

        progress_bar.set_postfix(
            {
                "Accuracy": f"{accuracy:.4f}",
                "F1 Score": f"{f1:.4f}",
                "Precision": f"{precision:.4f}",
                "Recall": f"{recall:.4f}",
            }
        )
    eval_metrics = {
        "accuracy": accuracy,
        "f1": f1,
        "precision": precision,
        "recall": recall,
    }
    return eval_metrics


eval_dataloader = torch.utils.data.DataLoader(
    prepared_test,
    batch_size=256,
    collate_fn=collate_fn,
    shuffle=False,
    num_workers=8,
    # drop_last=True,
)
eval_metrics = evaluate(eval_dataloader)

Evaluation:   0%|          | 0/40 [00:00<?, ?it/s]Unused or unrecognized kwargs: padding.
Evaluation:   2%|▎         | 1/40 [00:06<04:03,  6.24s/it, Accuracy=0.8000, F1 Score=0.8667, Precision=1.0000, Recall=0.8000]Unused or unrecognized kwargs: padding.
Evaluation:   8%|▊         | 3/40 [00:06<01:01,  1.66s/it, Accuracy=0.8000, F1 Score=0.7967, Precision=0.


Finally, we print the subnet's performance on the CIFAR-10 dataset.

In [10]:
print(eval_metrics)

{'accuracy': 0.8375, 'f1': 0.8293750643570631, 'precision': 0.8683762448655351, 'recall': 0.8375}
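
In the output above, recall equals accuracy (both 0.8375). This is expected with `average="weighted"`: each class's recall is weighted by its share of the true labels, so the weighted sum reduces to the overall fraction of correct predictions. The sketch below reproduces that identity on a small made-up label set; the helper name and the labels are illustrative, and the function is a simplified stand-in for the scikit-learn metrics rather than their implementation.

```python
from collections import Counter

def weighted_precision_recall(y_true, y_pred):
    """Support-weighted precision and recall over the classes in y_true."""
    support = Counter(y_true)
    n = len(y_true)
    precision = recall = 0.0
    for cls, count in support.items():
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == cls and p == cls)
        predicted = sum(1 for p in y_pred if p == cls)
        cls_precision = tp / predicted if predicted else 0.0
        cls_recall = tp / count
        # Weight each class by its share of the true labels.
        precision += cls_precision * count / n
        recall += cls_recall * count / n
    return precision, recall

# Made-up labels for illustration only.
y_true = [0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 1, 1, 1, 2]

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
precision, recall = weighted_precision_recall(y_true, y_pred)
print(f"accuracy={accuracy:.4f} precision={precision:.4f} recall={recall:.4f}")
```

Weighted precision does not share this property, which is why it differs from accuracy in the reported metrics.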
