# Salient Parameter Prioritization (SPP)
## Experiment Goal

This experiment shows:

- **The novelty and necessity of RaFFM's specialized SPP**
- **That RaFFM SPP preserves the pre-trained knowledge in FMs**
- **A comparison with standard pruning-based weight ranking**

## Step 1. Import dependencies

In [7]:
import numpy as np
from datasets import load_metric
from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer
from datasets import load_dataset
import torch
import os

## Step 2. Define Parameter Prioritization functions
In this tutorial we use the Vision Transformer (ViT) as an example.

First, import RaFFM's specialized SPP components from our library.

In [8]:
from raffm.param_prioritization import l1_norm,vit_spp_handler

Next, define our baseline: standard pruning-based salient parameter prioritization functions.

In [9]:
def standard_l1_rank_metrics(query, key):
    """
    Rank the rows of the query and key matrices by their L1 norm.

    Args:
    - query (torch.Tensor): The query weight matrix of an attention layer.
    - key (torch.Tensor): The key weight matrix of an attention layer.

    Returns:
    - query_ranked_indices (torch.Tensor): Row indices of query, sorted by descending L1 norm.
    - key_ranked_indices (torch.Tensor): Row indices of key, sorted by descending L1 norm.
    """
    # Validate input sizes
    if query.size(0) != key.size(0) or query.size(1) != key.size(1):
        raise ValueError("The query and key matrices must have the same dimensions.")

    # Calculate the L1 norm of each row in both matrices
    query_norms = query.norm(p=1, dim=1)
    key_norms = key.norm(p=1, dim=1)

    # Sort the rows by norm in descending order and keep only the indices.
    # Note: the two matrices are ranked independently of each other.
    _, query_ranked_indices = torch.sort(query_norms, descending=True)
    _, key_ranked_indices = torch.sort(key_norms, descending=True)

    return query_ranked_indices, key_ranked_indices

def standard_spp_handler(model):
    """Reorder the query/key rows of every ViT self-attention layer by L1 norm (baseline)."""
    for name, module in model.named_modules():
        # Check if the module is a ViTSelfAttention layer
        if "ViTSelfAttention" in str(type(module)):
            # Get permutation using the metric function
            query_rank, key_rank = standard_l1_rank_metrics(
                module.query.weight.data,
                module.key.weight.data
            )

            # Ensure the permutation is in the correct format
            assert isinstance(
                query_rank, torch.Tensor
            ), "The metric function must return a torch.Tensor."
            assert (
                query_rank.shape[0] == module.query.weight.shape[0]
            ), "Invalid permutation size."

            # Permute the query weights
            module.query.weight.data = module.query.weight.data[query_rank, :]
            if module.query.bias is not None:
                module.query.bias.data = module.query.bias.data[query_rank]

            # Permute the key weights
            module.key.weight.data = module.key.weight.data[key_rank, :]
            if module.key.bias is not None:
                module.key.bias.data = module.key.bias.data[key_rank]
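
To make the ranking step concrete, here is a minimal NumPy sketch of what `standard_l1_rank_metrics` computes for one matrix. The 4x3 matrix and its values are made up for illustration; the real query and key weights of the ViT are much larger.

```python
import numpy as np

# Toy "query" weight matrix (illustrative values only, not real ViT weights).
query = np.array([
    [0.1, -0.2, 0.1],   # L1 norm 0.4
    [1.0, -1.0, 0.5],   # L1 norm 2.5
    [0.0,  0.3, 0.0],   # L1 norm 0.3
    [-0.6, 0.6, 0.2],   # L1 norm 1.4
])

# L1 norm of each row: sum of absolute values along the row.
row_norms = np.abs(query).sum(axis=1)

# Row indices sorted by descending norm (negate to sort descending).
ranked = np.argsort(-row_norms, kind="stable")
print(ranked)  # [1 3 0 2]

# Reordering the rows with these indices is what the handler applies to the weights.
reordered = query[ranked]
print(reordered[0])  # the row with the largest L1 norm comes first
```

The same indexing is applied to the bias vector, so each output feature keeps its own bias after the reordering.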


## Step 3. Define Vision Transformer and Evaluation Dataset
First, define the evaluation dataset and the preprocessing functions.

In [10]:

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }
    
# Load the image processor once instead of re-creating it for every batch
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

def transform(example_batch):
    # Take a list of PIL images and turn them into pixel values
    inputs = processor([x for x in example_batch['img']], return_tensors='pt')

    # Don't forget to include the labels!
    inputs['labels'] = example_batch['label']
    return inputs

dataset = load_dataset('cifar10')
train_val = dataset["train"].train_test_split(test_size=0.1,seed=123)

dataset['train'] = train_val["train"]
dataset["validation"] = train_val["test"]

# Define a tiny training set
train_val = dataset["train"].train_test_split(test_size=0.2,seed=123)
dataset['train'] = train_val["test"]

prepared_ds = dataset.with_transform(transform)
prepared_ds

DatasetDict({
    train: Dataset({
        features: ['img', 'label'],
        num_rows: 9000
    })
    test: Dataset({
        features: ['img', 'label'],
        num_rows: 10000
    })
    validation: Dataset({
        features: ['img', 'label'],
        num_rows: 5000
    })
})

Then initialize the FM (ViT).

In [17]:
ckpt_path = 'google/vit-base-patch16-224-in21k'
labels = dataset['train'].features['label'].names

model = ViTForImageClassification.from_pretrained(
    ckpt_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Step 4. Salient Parameter Prioritization
We prioritize one copy of the model with **standard pruning-based SPP** and another with **RaFFM specialized SPP**, and keep the **original FM** unchanged as a reference.

In [18]:
import copy

# Standard pruning based SPP
prune_spp_model = copy.deepcopy(model)
standard_spp_handler(model=prune_spp_model)

# RaFFM specialized SPP
raffm_spp_model = copy.deepcopy(model)
vit_spp_handler(raffm_spp_model, l1_norm)

## Experiment 

**Objective**: Train the three models above on the target dataset and compare their validation accuracy.

### Train the original ViT
First, let's **train the Original ViT**


In [13]:

metric = load_metric("accuracy")
def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

training_args = TrainingArguments(
  output_dir="./log/debug",
  per_device_train_batch_size=16,
  evaluation_strategy="no",
  num_train_epochs=1,
  save_strategy="no",
  # save_steps=100,
  # eval_steps=100,
  logging_steps=100,
  learning_rate=2e-4,
  save_total_limit=2,
  remove_unused_columns=False,
  push_to_hub=False,
  report_to='tensorboard',
  load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    # tokenizer=processor,
)
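
For reference, the accuracy reported by `compute_metrics` is just the fraction of samples whose arg-max logit matches the label. A minimal NumPy sketch with made-up logits and labels (not taken from the experiment):

```python
import numpy as np

# Made-up logits for 4 samples and 3 classes (illustrative only).
logits = np.array([
    [2.0, 0.1, -1.0],
    [0.2, 0.3, 0.1],
    [-0.5, 0.0, 1.5],
    [1.0, 0.9, 0.8],
])
label_ids = np.array([0, 1, 2, 1])

# Same reduction as in compute_metrics: pick the highest-scoring class per sample.
predictions = np.argmax(logits, axis=1)

# Accuracy is the mean of the element-wise matches.
accuracy = float((predictions == label_ids).mean())
print(predictions, accuracy)  # [0 1 2 0] 0.75
```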

In [14]:
trainer.train()

Step,Training Loss
100,0.9196
200,0.3353
300,0.2802
400,0.1917
500,0.1545


TrainOutput(global_step=563, training_loss=0.3467688492729228, metrics={'train_runtime': 203.3682, 'train_samples_per_second': 44.255, 'train_steps_per_second': 2.768, 'total_flos': 6.97477913137152e+17, 'train_loss': 0.3467688492729228, 'epoch': 1.0})
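
The `global_step=563` in the output follows from the size of the tiny training set and the batch size. A quick check, assuming a single device and no gradient accumulation (neither is stated explicitly above):

```python
import math

num_train_examples = 9000      # size of the tiny training split shown earlier
per_device_batch_size = 16     # per_device_train_batch_size in TrainingArguments
num_devices = 1                # assumption: single GPU/CPU
num_epochs = 1                 # num_train_epochs in TrainingArguments

# The last, partially filled batch still counts as one optimizer step.
steps_per_epoch = math.ceil(num_train_examples / (per_device_batch_size * num_devices))
total_steps = steps_per_epoch * num_epochs
print(total_steps)  # 563
```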

### Evaluate the original FMs
After training, the original FM reaches a validation accuracy of **96.7%**.

In [15]:
trainer.evaluate()

{'eval_loss': 0.1344059556722641,
 'eval_accuracy': 0.967,
 'eval_runtime': 88.2851,
 'eval_samples_per_second': 56.635,
 'eval_steps_per_second': 7.079,
 'epoch': 1.0}

### Train prioritized ViT by RaFFM SPP
After training, the ViT prioritized by RaFFM SPP reaches a validation accuracy of **96.78%**, which is **slightly higher than the original ViT** (96.7%). The prioritization therefore does not cost any of the pre-trained knowledge.

In [19]:
trainer = Trainer(
    model=raffm_spp_model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    # tokenizer=processor,
)
trainer.train()
trainer.evaluate()

Step,Training Loss
100,0.9345
200,0.3609
300,0.2451
400,0.1877
500,0.1721


{'eval_loss': 0.1350841522216797,
 'eval_accuracy': 0.9678,
 'eval_runtime': 91.1864,
 'eval_samples_per_second': 54.833,
 'eval_steps_per_second': 6.854,
 'epoch': 1.0}

### Train prioritized ViT by standard SPP
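After training, the ViT prioritized by standard pruning-based SPP reaches a validation accuracy of only **48.82%**, far lower than both the original ViT and the RaFFM-prioritized ViT.

The reason is that standard SPP does not take the attention mechanism of transformers into account. It reorders the rows of the query and key matrices independently, so the query and key features that were paired in the pre-trained attention scores no longer line up, and the pre-trained knowledge is destroyed.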
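
The effect can be reproduced on a toy single-head example. The sketch below uses small hand-picked matrices (illustrative values, not ViT weights) and compares the attention scores before and after reordering. The shared-permutation case is included only as a contrast; it is not a description of how `vit_spp_handler` is implemented internally.

```python
import numpy as np

# Two tokens with two input features each (illustrative values).
x = np.array([[1.0, 2.0],
              [3.0, -1.0]])

# Query/key weights with rows as output features, as in nn.Linear.weight.
w_q = np.array([[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]])  # row L1 norms: 1, 2, 6
w_k = np.array([[4.0, 1.0], [0.0, 1.0], [1.0, 1.0]])  # row L1 norms: 5, 1, 2

def attention_scores(w_q, w_k):
    # Unnormalized scores Q K^T for the toy tokens.
    return (x @ w_q.T) @ (x @ w_k.T).T

base = attention_scores(w_q, w_k)

# Rank each matrix on its own, as the standard baseline does.
q_rank = np.argsort(-np.abs(w_q).sum(axis=1), kind="stable")  # [2, 1, 0]
k_rank = np.argsort(-np.abs(w_k).sum(axis=1), kind="stable")  # [0, 2, 1]

# Independent reordering pairs query feature i with a different key feature.
independent = attention_scores(w_q[q_rank], w_k[k_rank])

# Applying one shared permutation to both keeps every pair intact.
shared = attention_scores(w_q[q_rank], w_k[q_rank])

print(np.allclose(base, independent))  # False: scores changed
print(np.allclose(base, shared))       # True: scores preserved
```

Because the dot product sums over paired query and key features, any reordering that keeps the pairs together leaves the scores unchanged, while independent reordering changes them.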

In [20]:
trainer = Trainer(
    model=prune_spp_model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    # tokenizer=processor,
)
trainer.train()
trainer.evaluate()

Step,Training Loss
100,2.1087
200,1.9118
300,1.7474
400,1.6076
500,1.5561


{'eval_loss': 1.4469162225723267,
 'eval_accuracy': 0.4882,
 'eval_runtime': 90.3728,
 'eval_samples_per_second': 55.326,
 'eval_steps_per_second': 6.916,
 'epoch': 1.0}
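
The three validation accuracies reported above can be put side by side with a few lines of Python. The numbers are copied from the evaluation outputs in this notebook; the dictionary keys are just labels chosen for this summary.

```python
# Validation accuracies copied from the trainer.evaluate() outputs above.
results = {
    "original ViT": 0.967,
    "RaFFM SPP": 0.9678,
    "standard SPP": 0.4882,
}

baseline = results["original ViT"]

# Difference to the original model, in percentage points.
deltas = {name: round((acc - baseline) * 100, 2) for name, acc in results.items()}

for name, acc in results.items():
    print(f"{name:>14}: {acc:.2%} ({deltas[name]:+.2f} points vs. original)")
```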