<a href="https://colab.research.google.com/github/vaishnavi-gith/Knowledge-Distillation-for-RecSys-Efficiency/blob/main/KD_Recsys_clean.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install required libraries
!pip install torch-snippets

# Imports
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm.notebook import tqdm

# Pick the compute device once: the GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# --- Configuration Constants ---
# Sizes of the raw id spaces the synthetic interaction log is drawn from.
NUM_USERS, NUM_ITEMS = 1000, 500

# Latent widths: the Teacher gets twice the embedding capacity of the Student.
EMBED_SIZE = 64
EMBED_SIZE_STUDENT = 32

# --- Data Simulation Function ---
def generate_synthetic_data(num_users, num_items, num_interactions=50000, seed=None):
    """Generate synthetic, de-duplicated (user, item) interaction pairs.

    Users and items are drawn uniformly at random, duplicate pairs are dropped,
    and the ids that actually occur are re-mapped onto contiguous ranges
    ``0..U-1`` / ``0..I-1``. Ids that were never drawn are squeezed out, so
    ``U <= num_users`` and ``I <= num_items``.

    Args:
        num_users: Exclusive upper bound for raw user ids.
        num_items: Exclusive upper bound for raw item ids.
        num_interactions: Number of raw (user, item) draws before de-duplication.
        seed: Optional seed. When given, a private ``numpy.random.Generator`` is
            used, so the output is reproducible and the global NumPy RNG state is
            left untouched. When ``None`` (default) the global ``np.random`` state
            is used, exactly as before.

    Returns:
        ``(interactions, num_mapped_users, num_mapped_items)`` where
        ``interactions`` is a list of unique ``(user_id, item_id)`` int tuples.
    """
    if seed is None:
        users = np.random.randint(0, num_users, num_interactions)
        items = np.random.randint(0, num_items, num_interactions)
    else:
        rng = np.random.default_rng(seed)
        users = rng.integers(0, num_users, num_interactions)
        items = rng.integers(0, num_items, num_interactions)

    # Create a set of unique positive interactions for BPR sampling
    interactions = list(set(zip(users, items)))

    # Map users and items to contiguous IDs (0 to N-1)
    user_map = {uid: i for i, uid in enumerate(np.unique(users))}
    item_map = {iid: i for i, iid in enumerate(np.unique(items))}

    mapped_interactions = [(user_map[u], item_map[i]) for u, i in interactions]

    print(f"Generated {len(mapped_interactions)} unique positive interactions.")
    return mapped_interactions, len(user_map), len(item_map)

# --- BPR Dataset Class ---
class BPRDataset(Dataset):
    """
    Dataset for Bayesian Personalized Ranking (BPR).

    Each index corresponds to one observed (user, positive_item) pair. Indexing
    returns that pair plus a freshly sampled negative item, as a
    (user, positive_item, negative_item) triplet of ``torch.long`` scalars.
    """

    def __init__(self, interactions, num_items):
        """
        Args:
            interactions: Sequence of (user_id, item_id) positive pairs.
            num_items: Size of the item catalogue; negatives are drawn uniformly
                from ``range(num_items)``.

        Raises:
            ValueError: If some user has interacted with every item in
                ``range(num_items)``. No negative could ever be drawn for them,
                and the rejection sampling in ``__getitem__`` would loop forever.
        """
        self.interactions = interactions  # List of (user_id, item_id)
        self.num_items = num_items
        self.user_to_items = {}
        for u, i in interactions:
            self.user_to_items.setdefault(u, set()).add(i)

        # Public attribute kept for backward compatibility; unused internally.
        self.users = list(self.user_to_items.keys())

        # Fail fast instead of hanging later inside a DataLoader worker.
        for u, items in self.user_to_items.items():
            in_catalogue = sum(1 for i in items if 0 <= i < num_items)
            if in_catalogue >= num_items:
                raise ValueError(
                    f"User {u} has interacted with all {num_items} items; "
                    "cannot sample a negative item."
                )

    def __len__(self):
        # One sample per observed positive interaction.
        return len(self.interactions)

    def __getitem__(self, idx):
        # The index picks a specific observed positive pair (not a random user).
        u, i_pos = self.interactions[idx]
        seen = self.user_to_items.get(u, set())

        # Rejection-sample a negative item the user has not interacted with.
        # Termination is guaranteed by the saturation check in __init__.
        i_neg = np.random.randint(0, self.num_items)
        while i_neg in seen:
            i_neg = np.random.randint(0, self.num_items)

        return torch.tensor(u, dtype=torch.long), \
               torch.tensor(i_pos, dtype=torch.long), \
               torch.tensor(i_neg, dtype=torch.long)

# Build the synthetic interaction log, wrap it as a BPR triplet dataset, and
# serve shuffled mini-batches of 256 triplets from two worker processes.
interactions, actual_num_users, actual_num_items = generate_synthetic_data(
    NUM_USERS, NUM_ITEMS
)
train_dataset = BPRDataset(interactions, actual_num_items)
train_loader = DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True,
    num_workers=2,
)

print(f"Dataset Size: {len(train_dataset)} triplets.")

In [None]:
# --- Model Architectures ---

class MatrixFactorization(nn.Module):
    """
    Biased Matrix Factorization (MF) scorer.

    score(u, i) = <p_u, q_i> + b_u + b_i, where p_u / q_i are ``embed_size``-
    dimensional user / item embeddings and b_u / b_i are learned scalar biases.
    Model capacity is controlled by ``embed_size``.
    """

    def __init__(self, num_users, num_items, embed_size):
        super().__init__()
        self.user_embeddings = nn.Embedding(num_users, embed_size)
        self.item_embeddings = nn.Embedding(num_items, embed_size)
        self.user_bias = nn.Embedding(num_users, 1)
        self.item_bias = nn.Embedding(num_items, 1)

        # Small random embeddings, zero biases.
        nn.init.normal_(self.user_embeddings.weight, std=0.01)
        nn.init.normal_(self.item_embeddings.weight, std=0.01)
        nn.init.zeros_(self.user_bias.weight)
        nn.init.zeros_(self.item_bias.weight)

    def forward(self, user_ids, item_ids):
        """Return one raw (pre-sigmoid) score per (user, item) pair.

        ``user_ids`` and ``item_ids`` are integer index tensors of the same
        shape, e.g. ``(B,)`` or ``(B, N)``; the result has that same shape.
        Reducing over the last (embedding) axis and squeezing only the trailing
        bias axis keeps shapes correct for any number of leading dimensions,
        including batches of size 1.
        """
        u_embed = self.user_embeddings(user_ids)
        i_embed = self.item_embeddings(item_ids)

        u_bias = self.user_bias(user_ids).squeeze(-1)
        i_bias = self.item_bias(item_ids).squeeze(-1)

        # Output logit score (pre-activation, required for BPR/Distillation)
        logit_scores = (u_embed * i_embed).sum(dim=-1) + u_bias + i_bias
        return logit_scores

# Instantiate models: identical architecture, different embedding widths.
teacher_model = MatrixFactorization(
    actual_num_users, actual_num_items, EMBED_SIZE
).to(device)
student_model = MatrixFactorization(
    actual_num_users, actual_num_items, EMBED_SIZE_STUDENT
).to(device)

print(f"Teacher Params: {sum(map(torch.numel, teacher_model.parameters())):,}")
print(f"Student Params: {sum(map(torch.numel, student_model.parameters())):,}")

In [None]:
# --- Distillation Loss Function ---

def distillation_loss_bpr(teacher_logits_diff, student_logits_diff, T=2.0, alpha=0.5,
                          soft_loss_type="mse"):
    """
    Combined BPR hard loss and teacher-matching soft loss for distillation.

    Both inputs are pairwise margins ``score_pos - score_neg`` computed on the
    same (user, positive, negative) triplets.

    Args:
        teacher_logits_diff (tensor): Teacher's score_pos - score_neg. It is
            detached, so no gradient ever flows into the teacher.
        student_logits_diff (tensor): Student's score_pos - score_neg.
        T (float): Temperature.
            With ``"mse"``, T only acts as a ``T*T`` weight on the soft term.
            An MSE between margins is unchanged by dividing both by T once the
            customary ``T*T`` factor is applied, so there is nothing to smooth.
            With ``"kl"``, both margins are divided by T before the sigmoid.
            This softens the teacher's pairwise preference probabilities, and
            the divergence is then rescaled by ``T*T``.
        alpha (float): Weight of the soft loss; the hard loss gets ``1 - alpha``.
        soft_loss_type (str): ``"mse"`` (default, the original behaviour) or
            ``"kl"``. The latter is the binary KL divergence between
            Bernoulli(sigmoid(teacher/T)) and Bernoulli(sigmoid(student/T)).

    Returns:
        tensor: The scalar combined distillation loss.

    Raises:
        ValueError: If ``soft_loss_type`` is unknown, or ``T <= 0`` with ``"kl"``.
    """
    # 1. Hard Loss (Standard BPR for the Student)
    # BPR Loss = -log(sigmoid(score_pos - score_neg)); pushes each positive item
    # above its sampled negative.
    hard_loss = -F.logsigmoid(student_logits_diff).mean()

    # 2. Soft Loss (Knowledge Distillation Term) - the teacher is a fixed target.
    teacher_diff = teacher_logits_diff.detach()

    if soft_loss_type == "mse":
        # Regress the student's margin onto the teacher's margin.
        soft_term = F.mse_loss(student_logits_diff, teacher_diff) * (T * T)
    elif soft_loss_type == "kl":
        if T <= 0:
            raise ValueError(f"Temperature T must be positive for 'kl', got {T}")
        teacher_prob = torch.sigmoid(teacher_diff / T)
        # KL(p_t || p_s) = CE(p_t, p_s) - H(p_t); subtracting the teacher's
        # entropy makes the term zero when the student matches the teacher.
        cross_entropy = F.binary_cross_entropy_with_logits(student_logits_diff / T, teacher_prob)
        teacher_entropy = F.binary_cross_entropy_with_logits(teacher_diff / T, teacher_prob)
        soft_term = (cross_entropy - teacher_entropy) * (T * T)
    else:
        raise ValueError(f"Unknown soft_loss_type {soft_loss_type!r}; expected 'mse' or 'kl'")

    # 3. Combined Loss
    return (1.0 - alpha) * hard_loss + alpha * soft_term

In [None]:
# --- Training Configuration ---
# The students get twice as many epochs as the teacher; all runs share one LR.
EPOCHS_TEACHER, EPOCHS_STUDENT = 5, 10
LR_TEACHER = LR_STUDENT = 1e-3

# Distillation hyper-parameters.
T = 3.0      # Distillation temperature
ALPHA = 0.7  # Soft-loss weight: 0.7 -> 70% soft loss, 30% hard BPR loss

# --- Training Function ---
def train_model(model, loader, lr, epochs, optimizer_class=optim.Adam):
    """Train ``model`` with the plain BPR objective (no teacher signal).

    Args:
        model: Scorer called as ``model(user_ids, item_ids)`` that returns one
            logit per pair.
        loader: Iterable of ``(user, positive_item, negative_item)`` batches.
        lr: Learning rate passed to the optimizer.
        epochs: Number of passes over ``loader``.
        optimizer_class: Optimizer constructor, called as
            ``optimizer_class(model.parameters(), lr=lr)``.

    Returns:
        The same ``model`` object, trained in place.
    """
    optimizer = optimizer_class(model.parameters(), lr=lr)
    model.train()

    print(f"\n--- Training {model.__class__.__name__} (Hard Loss Only) ---")
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        pbar = tqdm(loader, desc=f"E {epoch+1}/{epochs}")
        # Count batches explicitly: tqdm only refreshes ``pbar.n`` every
        # ``mininterval`` seconds, so it cannot serve as a running batch count.
        for num_batches, (u, i_pos, i_neg) in enumerate(pbar, start=1):
            u, i_pos, i_neg = u.to(device), i_pos.to(device), i_neg.to(device)

            optimizer.zero_grad()

            # Standard BPR: score_pos vs score_neg
            pos_scores = model(u, i_pos)
            neg_scores = model(u, i_neg)

            # BPR Loss = -log(sigmoid(score_pos - score_neg))
            diff = pos_scores - neg_scores
            hard_loss = -F.logsigmoid(diff).mean()

            hard_loss.backward()
            optimizer.step()
            total_loss += hard_loss.item()
            pbar.set_postfix(loss=total_loss / num_batches)

        # max(..., 1) guards against an empty loader.
        print(f"Epoch {epoch+1} Average Loss: {total_loss / max(num_batches, 1):.4f}")
    return model

# --- Distillation Training Function ---
def train_distilled_student(teacher, student, loader, lr, epochs, T, alpha, optimizer_class=optim.Adam):
    """Train ``student`` against a frozen ``teacher`` using ``distillation_loss_bpr``.

    Args:
        teacher: Trained scorer ``teacher(user_ids, item_ids) -> logits``. It is
            put in eval mode and only queried under ``torch.no_grad()``. Its
            parameters are not given to the optimizer.
        student: Scorer being trained, with the same call signature.
        loader: Iterable of ``(user, positive_item, negative_item)`` batches.
        lr: Learning rate for the student's optimizer.
        epochs: Number of passes over ``loader``.
        T: Temperature forwarded to ``distillation_loss_bpr``.
        alpha: Soft-loss weight forwarded to ``distillation_loss_bpr``.
        optimizer_class: Optimizer constructor for the student's parameters.

    Returns:
        The same ``student`` object, trained in place.
    """
    teacher.eval()  # Teacher is FIXED
    student.train()
    optimizer = optimizer_class(student.parameters(), lr=lr)

    print(f"\n--- Distilling Knowledge into Student (T={T}, alpha={alpha}) ---")

    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        pbar = tqdm(loader, desc=f"E {epoch+1}/{epochs}")
        # Count batches explicitly: tqdm only refreshes ``pbar.n`` every
        # ``mininterval`` seconds, so it cannot serve as a running batch count.
        for num_batches, (u, i_pos, i_neg) in enumerate(pbar, start=1):
            u, i_pos, i_neg = u.to(device), i_pos.to(device), i_neg.to(device)

            optimizer.zero_grad()

            # 1. Teacher's knowledge: its BPR margin, computed without autograd.
            with torch.no_grad():
                t_pos = teacher(u, i_pos)
                t_neg = teacher(u, i_neg)
                teacher_diff = t_pos - t_neg

            # 2. Student's BPR margin (tracked by autograd).
            s_pos = student(u, i_pos)
            s_neg = student(u, i_neg)
            student_diff = s_pos - s_neg

            # 3. Calculate Distillation Loss
            loss = distillation_loss_bpr(teacher_diff, student_diff, T=T, alpha=alpha)

            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            pbar.set_postfix(loss=total_loss / num_batches)

        # max(..., 1) guards against an empty loader.
        print(f"Epoch {epoch+1} Average Loss: {total_loss / max(num_batches, 1):.4f}")
    return student

# --- EXECUTION ---
# Three models are produced: a wide teacher, a narrow student trained on BPR
# alone (baseline), and a narrow student trained with the teacher's margins.

# Stage 1: high-capacity teacher, plain BPR.
trained_teacher = train_model(teacher_model, train_loader, LR_TEACHER, EPOCHS_TEACHER)

# Stage 2: baseline student (student width, no distillation).
standard_student = MatrixFactorization(
    actual_num_users, actual_num_items, EMBED_SIZE_STUDENT
).to(device)
trained_standard_student = train_model(
    standard_student, train_loader, LR_STUDENT, EPOCHS_STUDENT
)

# Stage 3: distilled student (student width, frozen teacher as extra signal).
distilled_student = MatrixFactorization(
    actual_num_users, actual_num_items, EMBED_SIZE_STUDENT
).to(device)
trained_distilled_student = train_distilled_student(
    trained_teacher,
    distilled_student,
    train_loader,
    LR_STUDENT,
    EPOCHS_STUDENT,
    T=T,
    alpha=ALPHA,
)

print("\n--- Training Complete ---")
print(
    "You now have three models to compare empirically (Recall@K, Latency):",
    "- Teacher: High Performance, Slow",
    "- Standard Student: Low Performance, Fast",
    "- Distilled Student: High Performance, Fast (The desired outcome)",
    sep="\n",
)

In [None]:
# --- Evaluation Configuration ---
# K: ranking cutoff for Recall@K / NDCG@K.
# NUM_LATENCY_TESTS: timed forward passes averaged per model.
K, NUM_LATENCY_TESTS = 10, 100

# --- Evaluation Function ---
def evaluate_model(model, interaction_data, num_users, num_items, k, seed=0):
    """Estimate Recall@k and NDCG@k with one sampled target item per user.

    For every user in ``interaction_data``, one of their positive items is
    picked as the target. The model then scores all ``num_items`` items for
    that user, and the target's rank determines the hit and its DCG.

    NOTE: targets come from the same interactions the models were trained on,
    and the user's other positives are not masked out. The numbers therefore
    measure fit to the training data, not held-out generalization.

    Args:
        model: Scorer called as ``model(user_ids, item_ids)`` returning one
            score per pair.
        interaction_data: Iterable of ``(user_id, item_id)`` positive pairs.
        num_users: Unused; kept for interface compatibility.
        num_items: Number of candidate items (ids ``0..num_items-1``).
        k: Ranking cutoff.
        seed: Seed for choosing each user's target. The default fixed seed
            gives every evaluated model the same targets, so comparisons are
            fair. Pass ``None`` to draw from NumPy's global RNG instead.

    Returns:
        ``(recall_at_k, ndcg_at_k)``, averaged over evaluated users (0 if none).
    """
    model.eval()

    # Run on whatever device the model's parameters live on.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    # Map interactions for testing: user -> set of positive items
    user_to_pos_items = {}
    for u, i in interaction_data:
        user_to_pos_items.setdefault(u, set()).add(i)

    test_users = list(user_to_pos_items.keys())
    rng = np.random.default_rng(seed) if seed is not None else None

    # Every user is scored against the full catalogue; this tensor never changes.
    all_item_tensor = torch.arange(num_items, dtype=torch.long, device=model_device)

    total_recall = 0
    total_ndcg = 0.0
    total_users = 0

    with torch.no_grad():
        for user_id in tqdm(test_users, desc="Evaluating Recall & NDCG"):
            positive_items = list(user_to_pos_items[user_id])
            if not positive_items:
                continue

            if rng is not None:
                pick = rng.integers(len(positive_items))
            else:
                pick = np.random.randint(len(positive_items))
            target_item = positive_items[pick]

            user_tensor = torch.full((num_items,), user_id, dtype=torch.long, device=model_device)
            scores = model(user_tensor, all_item_tensor)
            target_score = scores[target_item]

            # 0-based pessimistic rank: items tied with the target count as
            # ranked above it, so a model that scores every item identically
            # cannot score hits. One rank drives both metrics, so they agree.
            rank = int((scores >= target_score).sum().item()) - 1

            if rank < k and not torch.isnan(target_score):
                total_recall += 1
                # With a 0-based rank, DCG = 1 / log2(rank + 2); IDCG = 1
                # because there is a single relevant item.
                total_ndcg += 1.0 / np.log2(rank + 2)

            total_users += 1

    recall_at_k = total_recall / total_users if total_users > 0 else 0
    ndcg_at_k = total_ndcg / total_users if total_users > 0 else 0

    return recall_at_k, ndcg_at_k

# --- Latency Measurement Function ---
def measure_latency(model, num_users, num_items, num_tests, num_warmup=10):
    """Average wall-clock time, in milliseconds, of one full-catalogue forward pass.

    Each pass scores user 0 against all ``num_items`` items. Passes run under
    ``torch.no_grad()`` (inference, so autograd bookkeeping is not timed). On
    CUDA the device is synchronized before reading the clock, because kernels
    launch asynchronously; without that only launch overhead would be timed.

    Args:
        model: Scorer called as ``model(user_ids, item_ids)``.
        num_users: Unused; kept for interface compatibility.
        num_items: Number of items scored per pass.
        num_tests: Number of timed passes (must be positive).
        num_warmup: Untimed passes run first.

    Returns:
        Mean latency per pass in milliseconds.

    Raises:
        ValueError: If ``num_tests`` is not positive.
    """
    if num_tests <= 0:
        raise ValueError(f"num_tests must be positive, got {num_tests}")

    model.eval()

    # Build inputs on the model's own device.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")
    sync_cuda = model_device.type == "cuda"

    user_id = 0  # Use the first user for consistency
    user_tensor = torch.full((num_items,), user_id, dtype=torch.long, device=model_device)
    all_item_tensor = torch.arange(num_items, dtype=torch.long, device=model_device)

    times = []
    with torch.no_grad():
        # Warm-up runs (allocator, kernel selection, caches).
        for _ in range(num_warmup):
            model(user_tensor, all_item_tensor)
        if sync_cuda:
            torch.cuda.synchronize(model_device)

        # Timed runs; perf_counter is a monotonic, high-resolution clock.
        for _ in range(num_tests):
            start_time = time.perf_counter()
            model(user_tensor, all_item_tensor)
            if sync_cuda:
                torch.cuda.synchronize(model_device)
            times.append(time.perf_counter() - start_time)

    avg_latency_ms = (sum(times) / num_tests) * 1000
    return avg_latency_ms

In [None]:
import time  # Ensure time module is imported

results = []

models = {
    "Teacher (Embed=64)": trained_teacher,
    "Standard Student (Embed=32)": trained_standard_student,
    "Distilled Student (Embed=32)": trained_distilled_student,
}

print("\n--- Starting Empirical Comparison ---")

for name, model in models.items():
    # 1. Performance Metrics
    recall, ndcg = evaluate_model(model, interactions, actual_num_users, actual_num_items, K)

    # 2. Efficiency Metrics
    latency = measure_latency(model, actual_num_users, actual_num_items, NUM_LATENCY_TESTS)
    params = sum(p.numel() for p in model.parameters())

    results.append({
        "Model": name,
        "Recall@K": recall,
        "NDCG@K": ndcg,
        "Latency (ms)": latency,
        "Params": f"{params:,}"
    })

# --- Final Comparison Table ---
# Format the metric labels before padding, so the column widths hold.
# Every row uses the same per-column widths as the header, and the horizontal
# rules are sized to the header, so the table stays aligned.
recall_label = f"Recall@{K}"
ndcg_label = f"NDCG@{K}"
header = (
    f"| {'Model':<30} | {recall_label:<10} | {ndcg_label:<10} "
    f"| {'Latency (ms)':<15} | {'Parameters':<10} |"
)
table_width = len(header)

print("\n" + "=" * table_width)
print(f"| FINAL KNOWLEDGE DISTILLATION RESULTS (K={K}) |")
print("=" * table_width)
print(header)
print("-" * table_width)

for r in results:
    print(
        f"| {r['Model']:<30} | {r['Recall@K']:<10.4f} | {r['NDCG@K']:<10.4f} "
        f"| {r['Latency (ms)']:<15.3f} | {r['Params']:<10} |"
    )

print("=" * table_width)

# Critical Insight
print("\n🔥 **CRITICAL INSIGHT for Paper:**")
print("The Distilled Student must show Recall@K close to the Teacher, AND Latency close to the Standard Student.")
print("This proves the successful transfer of high-fidelity knowledge without the computational overhead.")