In [None]:
!pip install -q datasets

In [None]:
import copy
import time

import torch
from datasets import load_dataset
from torch.nn.utils import prune
from transformers import BertForSequenceClassification, BertTokenizer, pipeline

# Load the model and tokenizer
SEED = 42
NUM_EVAL_SAMPLES = 500  # each example is classified one at a time; raise for a fuller run

# bert-base-uncased ships no sequence-classification head, so from_pretrained
# initialises it randomly. Seed torch so reruns produce the same head.
torch.manual_seed(SEED)
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
original_model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
# NOTE: the head is untrained, so expect roughly chance-level accuracy.
# The comparison below is between variants of the same (untrained) model.

# Load a small, shuffled sample of the IMDB test split. The split is ordered by
# label (all negative reviews first), so a prefix slice such as "test[:50%]"
# would contain only negatives. Shuffling first gives a roughly balanced sample.
dataset = (
    load_dataset("imdb", split="test")
    .shuffle(seed=SEED)
    .select(range(NUM_EVAL_SAMPLES))
)
texts = dataset["text"]
labels = dataset["label"]

# Check for cpu type if arm or x86
import os
import platform
import subprocess

def get_cpu_type():
    """Classify the host CPU architecture as "arm" or "x86".

    Uses ``platform.machine()``, which is available on every OS; ``os.uname()``
    does not exist on Windows. Linux reports "aarch64", macOS on Apple Silicon
    reports "arm64", and 32-bit ARM reports e.g. "armv7l". All of these are
    treated as ARM. Anything else (e.g. "x86_64", "AMD64") is reported as "x86".

    Returns:
        "arm" or "x86".
    """
    machine = platform.machine().lower()
    if machine.startswith(("arm", "aarch")):
        return "arm"
    return "x86"
    
cpu_type = get_cpu_type()
if cpu_type == "arm":
    # ARM uses the QNNPACK backend for quantized kernels. On x86 the PyTorch
    # default engine is left in place, so nothing is set there.
    # Assigning an engine this build does not support raises, so check first.
    if "qnnpack" in torch.backends.quantized.supported_engines:
        torch.backends.quantized.engine = "qnnpack"
    else:
        print(
            "Warning: 'qnnpack' is not supported by this PyTorch build; "
            f"keeping quantized engine {torch.backends.quantized.engine!r}"
        )

# Functions to evaluate a model
def _prediction_to_id(prediction, label2id):
    """Convert a pipeline label name (e.g. "LABEL_1", "POSITIVE") to a class id.

    Args:
        prediction: Label name returned by the text-classification pipeline.
        label2id: The model config's name -> id mapping (may be empty).

    Returns:
        The integer class id.

    Raises:
        ValueError: if the name is neither in ``label2id`` nor one of the
            literal "POSITIVE"/"NEGATIVE" fallbacks.
    """
    if prediction in label2id:
        return int(label2id[prediction])
    fallback = {"NEGATIVE": 0, "POSITIVE": 1}
    key = prediction.upper()
    if key in fallback:
        return fallback[key]
    raise ValueError(f"Cannot map predicted label {prediction!r} to a class id")


def evaluate_model(model, texts, labels, tokenizer):
    """Measure accuracy and mean per-example latency of a sequence classifier.

    Each text is classified individually through a ``sentiment-analysis``
    pipeline. ``avg_time`` is therefore per-example latency (tokenization
    included), not batched throughput.

    Args:
        model: A sequence-classification model usable with ``pipeline``.
        texts: Iterable of input strings.
        labels: Iterable of integer gold labels, parallel to ``texts``.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        ``(accuracy, avg_time)``: the fraction of correct predictions and the
        mean seconds per example. Both are 0.0 when ``texts`` is empty.

    Raises:
        ValueError: if ``texts`` and ``labels`` differ in length, or a
            predicted label name cannot be mapped to a class id.
    """
    texts = list(texts)
    labels = list(labels)
    if len(texts) != len(labels):
        raise ValueError(
            f"texts and labels must have the same length ({len(texts)} != {len(labels)})"
        )
    if not texts:
        return 0.0, 0.0

    # Create a pipeline for sentiment analysis
    classifier = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, truncation=True, padding=True)

    # The pipeline reports label *names* taken from the model config. A freshly
    # initialised head is named "LABEL_0"/"LABEL_1", never "POSITIVE". Map names
    # back to ids through label2id instead of comparing against a fixed string.
    config = getattr(model, "config", None)
    label2id = dict(getattr(config, "label2id", None) or {})

    correct = 0
    total_time = 0.0
    for text, label in zip(texts, labels):
        start_time = time.perf_counter()  # monotonic, high-resolution clock for latency
        prediction = classifier(text)[0]["label"]
        total_time += time.perf_counter() - start_time

        if _prediction_to_id(prediction, label2id) == label:
            correct += 1

    accuracy = correct / len(texts)
    avg_time = total_time / len(texts)
    return accuracy, avg_time

In [None]:
# Baseline: score the unmodified model so the pruned and quantized variants
# below have a reference point for accuracy and latency.
original_model.eval()
print("Evaluating Original Model...")
baseline_metrics = evaluate_model(
    model=original_model,
    texts=texts,
    labels=labels,
    tokenizer=tokenizer,
)
orig_acc, orig_time = baseline_metrics
baseline_summary = f"Original Model Accuracy: {orig_acc:.2f}, Average Inference Time: {orig_time:.4f} seconds"
print(baseline_summary)

In [None]:
# Prune the model
print("\nApplying Pruning...")
PRUNE_AMOUNT = 0.2  # fraction of weights zeroed in each pruned Linear layer

# Start from a copy of original_model rather than reloading the checkpoint.
# Every from_pretrained call re-initialises the untrained classification head
# randomly, so a reload would compare two different heads instead of
# pruned vs. unpruned versions of the same model.
pruned_model = copy.deepcopy(original_model)

# Only the Linear layers inside the first encoder layer's self-attention are
# pruned (in BertSelfAttention these are the query/key/value projections).
for module in pruned_model.bert.encoder.layer[0].attention.self.modules():
    if isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name="weight", amount=PRUNE_AMOUNT)
        # Make the pruning permanent: fold weight_orig * weight_mask into a
        # plain `weight` and drop the forward pre-hook. Otherwise the timing
        # would include re-applying the mask on every forward pass.
        prune.remove(module, "weight")
# NOTE: unstructured sparsity keeps tensors dense, so no latency gain is expected.
pruned_model.eval()

print("Evaluating Pruned Model...")
pruned_acc, pruned_time = evaluate_model(pruned_model, texts, labels, tokenizer)
print(f"Pruned Model Accuracy: {pruned_acc:.2f}, Average Inference Time: {pruned_time:.4f} seconds")

In [None]:
# Show which quantized-kernel backends this PyTorch build provides.
available_engines = torch.backends.quantized.supported_engines
print(available_engines)

In [None]:
# Quantize the model
print("\nApplying Quantization...")
# Dynamic int8 quantization of every nn.Linear. inplace=False (the default)
# returns a converted copy and leaves original_model untouched.
linear_layers_to_quantize = {torch.nn.Linear}
quantized_model = torch.quantization.quantize_dynamic(
    original_model,
    qconfig_spec=linear_layers_to_quantize,
    dtype=torch.qint8,
    inplace=False,
)
quantized_model.eval()

print("Evaluating Quantized Model...")
quant_results = evaluate_model(quantized_model, texts, labels, tokenizer)
quant_acc, quant_time = quant_results
print(
    f"Quantized Model Accuracy: {quant_acc:.2f}, "
    f"Average Inference Time: {quant_time:.4f} seconds"
)