In [None]:
import os
import tempfile
import time

import torch
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer


In [2]:
# Directory holding the fine-tuned DistilBERT checkpoint. Built with
# os.path.join so the same code resolves on Windows and POSIX (a literal
# backslash is only a path separator on Windows).
MODEL_DIR = os.path.join("models", "distilbert_model")

# Tokenizer and classifier are loaded from the same local checkpoint;
# num_labels=2 configures the classification head for two classes.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR, num_labels=2)

In [None]:

# Force CPU for quantization: both the FP32 baseline and the INT8 model
# are benchmarked on the same device so the latencies are comparable.
device = torch.device("cpu")
model.to(device)

# =====================
# Load dataset
# =====================
imdb = load_dataset("imdb")

# Number of test samples used for the quick benchmark.
NUM_BENCHMARK_SAMPLES = 200

# NOTE(review): the IMDB test split appears to be stored grouped by label,
# so taking the first N rows would benchmark a single class only. Draw a
# seeded shuffled subset instead so both classes are represented and the
# run stays reproducible.
test_subset = imdb["test"].shuffle(seed=42).select(range(NUM_BENCHMARK_SAMPLES))
test_texts = list(test_subset["text"])
test_labels = list(test_subset["label"])

# =====================
# Evaluation function
# =====================
def evaluate(model, tokenizer, texts, labels):
    """Measure classification accuracy and mean per-sample latency.

    Every text is tokenized and pushed through ``model`` on its own (batch
    size 1) with gradient tracking disabled. The measured time therefore
    covers tokenization plus one forward pass per sample.

    Args:
        model: A ``torch.nn.Module`` that accepts the tokenizer output as
            keyword arguments and returns an object exposing ``.logits``.
        tokenizer: Callable that turns one text into a batch object which
            supports ``.to(device)`` and ``**`` unpacking.
        texts: Sized sequence of input strings.
        labels: Sized sequence of integer class ids, aligned with ``texts``.

    Returns:
        A ``(accuracy, latency)`` tuple: the fraction of correct
        predictions and the average seconds spent per sample.

    Raises:
        ValueError: If ``texts`` is empty or its length differs from
            ``labels``.
    """
    num_samples = len(texts)
    if num_samples != len(labels):
        # zip() would silently drop the surplus and skew the accuracy.
        raise ValueError(
            f"texts and labels must have the same length "
            f"(got {num_samples} and {len(labels)})"
        )
    if num_samples == 0:
        raise ValueError("cannot evaluate on an empty dataset")

    # Send inputs to wherever the model lives instead of depending on a
    # module-level ``device`` variable; fall back to CPU for models that
    # expose no parameters.
    first_param = next(model.parameters(), None)
    target_device = (
        first_param.device if first_param is not None else torch.device("cpu")
    )

    model.eval()
    correct = 0
    # perf_counter is monotonic and high resolution, unlike time.time().
    start = time.perf_counter()
    with torch.no_grad():
        for text, label in zip(texts, labels):
            inputs = tokenizer(
                text, return_tensors="pt", truncation=True, padding=True
            ).to(target_device)
            logits = model(**inputs).logits
            pred = torch.argmax(logits, dim=1).item()
            correct += int(pred == label)
    elapsed = time.perf_counter() - start

    accuracy = correct / num_samples
    latency = elapsed / num_samples
    return accuracy, latency

# =====================
# Baseline: benchmark the unquantized (FP32) model
# =====================
acc_fp32, latency_fp32 = evaluate(
    model,
    tokenizer,
    texts=test_texts,
    labels=test_labels,
)
print(f"[FP32] Accuracy: {acc_fp32:.4f}, Latency: {latency_fp32:.4f}s/sample")

# =====================
# Dynamic quantization (INT8)
# =====================
# Only torch.nn.Linear modules are targeted, with qint8 as the quantized dtype.
quantized_model = torch.quantization.quantize_dynamic(
    model,
    qconfig_spec={torch.nn.Linear},
    dtype=torch.qint8,
)

# =====================
# Benchmark the quantized (INT8) model on the same samples
# =====================
acc_int8, latency_int8 = evaluate(
    quantized_model,
    tokenizer,
    texts=test_texts,
    labels=test_labels,
)
print(f"[INT8] Accuracy: {acc_int8:.4f}, Latency: {latency_int8:.4f}s/sample")

# =====================
# Compare memory footprint
# =====================
def get_size(model):
    """Return the on-disk size of ``model``'s state dict in megabytes.

    The state dict is serialized with ``torch.save`` into a private
    temporary directory, so no file in the current working directory is
    overwritten or deleted, and the scratch file is removed even if saving
    fails.

    Args:
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).

    Returns:
        Size of the serialized state dict in MB, where 1 MB = 1e6 bytes.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        checkpoint_path = os.path.join(scratch_dir, "temp.p")
        torch.save(model.state_dict(), checkpoint_path)
        return os.path.getsize(checkpoint_path) / 1e6

# Serialize each model once, then report both footprints side by side.
size_fp32_mb = get_size(model)
size_int8_mb = get_size(quantized_model)
print(f"Model size (FP32): {size_fp32_mb:.2f} MB")
print(f"Model size (INT8): {size_int8_mb:.2f} MB")


('./model\\tokenizer_config.json',
 './model\\special_tokens_map.json',
 './model\\vocab.txt',
 './model\\added_tokens.json',
 './model\\tokenizer.json')