# **Libraries**

In [32]:
import torch
from torch import nn
from transformers import BertForSequenceClassification, DistilBertForSequenceClassification, BertTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AdamW
import os
import numpy as np
import random

# **Part 4: Model Distillation/Quantization**

### **a. Model Distillation/Quantization**: 
Distill/Quantize your best-performing model into a lighter model. Document the process and tools used.

In [15]:
path = os.getcwd()
path_model = os.path.join(path, 'saved_teacher_model')

In [16]:

# Load your fine-tuned teacher model
teacher_model = BertForSequenceClassification.from_pretrained(path_model).eval()
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")


In [17]:
# Load Financial Phrasebank Dataset (assuming you created a custom version)
dataset = load_dataset("financial_phrasebank", "sentences_75agree")
train_dataset = dataset["train"]

# Preprocessing
def tokenize(batch):
    return tokenizer(batch["sentence"], truncation=True, padding="max_length", max_length=128)

tokenized = train_dataset.map(tokenize, batched=True)
tokenized.set_format("torch", columns=["input_ids", "attention_mask", "label"])


In [18]:
student_model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=3)
student_model.train()

train_loader = DataLoader(tokenized, batch_size=16)

# Define loss and optimizer
kd_loss_fn = nn.KLDivLoss(reduction="batchmean")
ce_loss_fn = nn.CrossEntropyLoss()
optimizer = AdamW(student_model.parameters(), lr=5e-5)

# Training loop
for epoch in range(5):
    for batch in train_loader:
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["label"]

        with torch.no_grad():
            teacher_logits = teacher_model(input_ids, attention_mask=attention_mask).logits
            soft_labels = torch.nn.functional.softmax(teacher_logits / 2.0, dim=-1)

        student_logits = student_model(input_ids, attention_mask=attention_mask).logits
        student_log_probs = torch.nn.functional.log_softmax(student_logits / 2.0, dim=-1)

        kd_loss = kd_loss_fn(student_log_probs, soft_labels)
        ce_loss = ce_loss_fn(student_logits, labels)
        loss = 0.5 * kd_loss + 0.5 * ce_loss

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

student_model.save_pretrained("distilled_student_model")


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [19]:

quantized_model = torch.quantization.quantize_dynamic(
    student_model, {torch.nn.Linear}, dtype=torch.qint8
)

torch.save(quantized_model.state_dict(), "quantized_student_model.pt")


### **b. Performance and Speed Comparison**: 
Evaluate the distilled model's performance and inference speed compared to the original. Highlight key findings.

In [37]:

# --- Reproducibility ---
def set_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# --- Device ---
device = torch.device("cpu")


# --- Load Models ---
teacher_model = BertForSequenceClassification.from_pretrained("saved_teacher_model").to(device)
student_model = DistilBertForSequenceClassification.from_pretrained("distilled_student_model").to(device)
quantized_model = torch.quantization.quantize_dynamic(
    student_model, {torch.nn.Linear}, dtype=torch.qint8
).to(device)

# --- Inference Time Benchmarking ---
def benchmark(model, inputs, label):
    model.eval()
    with torch.no_grad():
        start = time.time()
        _ = model(**inputs)
        elapsed = time.time() - start
    print(f"{label} Inference Time: {elapsed:.4f}s")

# Single sentence for timing
timing_input = tokenizer("The company's performance was exceptional this quarter", return_tensors="pt", padding="max_length", truncation=True, max_length=128)
timing_input.pop("token_type_ids", None)
timing_input = {k: v.to(device) for k, v in timing_input.items()}

print("⏱️ Inference Time Benchmarking:")
benchmark(teacher_model, timing_input, "Original (Teacher)")
benchmark(student_model, timing_input, "Distilled (Student)")
benchmark(quantized_model, timing_input, "Quantized (Student)")

# --- Evaluation on Full Dataset ---
def evaluate_model(model, dataloader, label):
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = torch.argmax(outputs.logits, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    acc = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average="weighted")

    print(f"\n📊 {label} Evaluation Metrics:")
    print(f"Accuracy:  {acc:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1-score:  {f1:.4f}")

# Create DataLoader from your tokenized dataset
test_loader = DataLoader(tokenized, batch_size=32)

evaluate_model(teacher_model, test_loader, "Original (Teacher)")
evaluate_model(student_model, test_loader, "Distilled (Student)")
evaluate_model(quantized_model, test_loader, "Quantized (Student)")

# --- Custom Example Sentences (Manual Evaluation) ---
examples = [
    ("The company's performance was exceptional this quarter", 2),  # Positive
    ("The results were below expectations and disappointed investors", 0),  # Negative
    ("The company held a press conference regarding its quarterly report", 1),  # Neutral
]

def predict(model, sentence):
    model.eval()
    inputs = tokenizer(sentence, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
    inputs.pop("token_type_ids", None)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        return torch.argmax(probs, dim=1).item()

print("\n🔍 Custom Sentence Predictions:\n")
for text, true_label in examples:
    print(f"📝 Sentence: {text}")
    for name, model in [("Teacher", teacher_model), ("Student", student_model), ("Quantized", quantized_model)]:
        pred = predict(model, text)
        correct = "✅" if pred == true_label else "❌"
        print(f"  🔹 {name} Prediction: {pred} | {correct}")
    print()


⏱️ Inference Time Benchmarking:
Original (Teacher) Inference Time: 0.1365s
Distilled (Student) Inference Time: 0.0516s
Quantized (Student) Inference Time: 0.0450s


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



📊 Original (Teacher) Evaluation Metrics:
Accuracy:  0.6215
Precision: 0.3862
Recall:    0.6215
F1-score:  0.4764


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



📊 Distilled (Student) Evaluation Metrics:
Accuracy:  0.1216
Precision: 0.0148
Recall:    0.1216
F1-score:  0.0264


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



📊 Quantized (Student) Evaluation Metrics:
Accuracy:  0.1216
Precision: 0.0148
Recall:    0.1216
F1-score:  0.0264

🔍 Custom Sentence Predictions:

📝 Sentence: The company's performance was exceptional this quarter
  🔹 Teacher Prediction: 1 | ❌
  🔹 Student Prediction: 0 | ❌
  🔹 Quantized Prediction: 0 | ❌

📝 Sentence: The results were below expectations and disappointed investors
  🔹 Teacher Prediction: 1 | ❌
  🔹 Student Prediction: 0 | ✅
  🔹 Quantized Prediction: 0 | ✅

📝 Sentence: The company held a press conference regarding its quarterly report
  🔹 Teacher Prediction: 1 | ✅
  🔹 Student Prediction: 0 | ❌
  🔹 Quantized Prediction: 0 | ❌



### **c. Analysis and Improvements**: 
Analyze deficiencies in the student model's learning. Suggest potential improvements or further research directions.

The student and especially the quantized model show significant performance degradation when compared to the original teacher model:
** Observed Deficiencies:**

- **Very Low Accuracy and F1-Score:**
    - The quantized student model has an accuracy of only ~12%, with F1-score near zero, indicating it is nearly guessing or predicting one class consistently.
    - Even the unquantized student model performs poorly on custom examples, suggesting distillation failed to retain key patterns from the teacher.

- **Mode Collapse / Class Bias:**
    - The student and quantized models overpredict class 0 (Negative), regardless of actual sentiment.
    - This indicates poor generalization and possible class imbalance or overfitting to frequent negative samples.

- **Teacher Errors:**
    - Even the teacher incorrectly predicted class 1 (Neutral) for a clearly Positive sentence, suggesting potential label noise or inadequate fine-tuning.



**Suggested Improvements & Research Directions:**

- **Better Distillation Process:**
    - Use soft-labels (logits/softmax outputs) from the teacher instead of just hard labels.
    - Apply temperature scaling during distillation to preserve class probability distributions.

- **Data Augmentation:**
    - The Financial Phrasebank is relatively small. Apply paraphrasing, back-translation, or synonym replacement to create richer training samples for the student.

- **Balance the Dataset:**
    - Analyze the class distribution. If imbalanced, apply class-weighted loss or oversampling for minority classes.

- **Quantization-Aware Training (QAT):**
    - Instead of post-training quantization, train the student model with quantization simulated during training to preserve accuracy.

- **Error Analysis & Curriculum Learning:**
    - Identify hard-to-classify examples and apply focused retraining or curriculum learning to guide the student through easier to harder samples.

- **Layer-Wise Distillation:**
    - Instead of only distilling final logits, distill hidden representations (intermediate features) to better capture the teacher’s knowledge. Improvements & Research Directions:

 