## Pairwise Preference Distillation with Flan-T5

#### Objective
- Understand knowledge distillation and its application to recommendation systems.
- Distill a large language model (the teacher, Flan-T5-Large) into a smaller, more efficient model (the student, Flan-T5-Small) for recommendation tasks.
- Evaluate the performance and efficiency of the distilled model compared to the teacher model.


#### Prerequisite: Install Necessary Libraries

In [25]:
# %pip install transformers datasets evaluate sentencepiece accelerate openai

In [26]:
import openai

print(openai.__version__)

1.93.2


#### Step 1: Prepare the Dataset from MovieLens Data

In [27]:
import pandas as pd
import numpy as np
import random
import pandas as pd
import random

def generate_pairwise_prompts_from_movielens(movielens_folder="./../Data/ml-1m/", num_pairs_per_user=5, use_cot=False):
    """
    Generate pairwise comparison prompts from MovieLens data.

    For every user with at least four liked movies (rating >= 4), build
    ``num_pairs_per_user`` prompts. Each prompt lists up to ten liked movies
    as context and asks which of two candidates the user prefers: a held-out
    liked movie (the correct answer) or a negative movie. The negative is a
    movie the user rated <= 2.5 or, when the user has none, a movie the user
    never rated. The position of the correct answer (1 or 2) is random.

    Args:
        movielens_folder (str): Path to the MovieLens dataset folder holding
            ``ratings.dat`` and ``movies.dat``; a trailing separator is optional.
        num_pairs_per_user (int): Number of prompts to generate per user.
        use_cot (bool): When True, insert a "think step-by-step" instruction
            between the context and the question.

    Returns:
        pd.DataFrame: One row per prompt with columns ``prompt`` (str) and
        ``label`` (1 or 2, the position of the held-out liked movie).

    Note:
        Sampling is not seeded here, so repeated calls produce different prompts.
    """
    import os  # local import: only needed to build the two file paths

    def read_dat(filename, column_names):
        # MovieLens 1M files are "::"-separated, header-less and Latin-1 encoded.
        return pd.read_csv(
            os.path.join(movielens_folder, filename),
            sep="::",
            engine="python",
            header=None,
            names=column_names,
            encoding="ISO-8859-1",
        )

    ratings = read_dat("ratings.dat", ["userId", "movieId", "rating", "timestamp"])
    movies = read_dat("movies.dat", ["movieId", "title", "genres"])

    # Attach title/genres to every rating row.
    movie_ratings = pd.merge(ratings, movies, on="movieId")

    cot_instruction = (
        "Please think step-by-step about the genre and the year of each movie when making a decision.\n\n"
        if use_cot else ""
    )

    def describe(movie):
        return f"{movie['title']} ({movie['genres']})"

    def generate_prompt_label_pairs(user_rated_movies, n_pairs):
        """Build up to ``n_pairs`` (prompt, label) tuples for a single user."""
        liked_movies = user_rated_movies[user_rated_movies["rating"] >= 4]
        if len(liked_movies) < 4:
            return []

        # Negatives: explicit dislikes first, otherwise movies the user never rated.
        negative_pool = user_rated_movies[user_rated_movies["rating"] <= 2.5]
        if len(negative_pool) == 0:
            negative_pool = movies[~movies["movieId"].isin(user_rated_movies["movieId"])]
        if len(negative_pool) == 0:
            return []

        sample_size = min(len(liked_movies), 11)
        pairs = []
        for _ in range(n_pairs):
            # sample() already returns rows in random order, so the last row is
            # a random held-out target and the rest are a random context of <= 10.
            liked_sample = liked_movies.sample(sample_size)
            context_sample = liked_sample.iloc[:-1]
            prediction_movie = liked_sample.iloc[-1]
            negative_movie = negative_pool.sample(1).iloc[0]

            context_movies = [
                f"{title} ({genres})"
                for title, genres in zip(context_sample["title"], context_sample["genres"])
            ]

            # Randomize the position of the correct answer to avoid position bias.
            if random.random() < 0.5:
                movie1, movie2 = prediction_movie, negative_movie
                correct_label = 1
            else:
                movie1, movie2 = negative_movie, prediction_movie
                correct_label = 2

            intro = f"The user liked the following movies: {', '.join(context_movies)}.\n\n"
            # Joined line by line so that source-code indentation does not leak
            # into the prompt text.
            question = "\n".join([
                "Which movie is the user more likely to prefer?",
                f"1. {describe(movie1)}",
                f"2. {describe(movie2)}",
                "Please answer with 1 or 2 only.",
            ])

            pairs.append((intro + cot_instruction + question, correct_label))

        return pairs

    all_pairs = []
    for _, user_rated_movies in movie_ratings.groupby("userId"):
        all_pairs.extend(generate_prompt_label_pairs(user_rated_movies, num_pairs_per_user))

    return pd.DataFrame(all_pairs, columns=["prompt", "label"])


#### Step 2: Import the Tokenizer/Model Classes and Define the Comparison Dataset

In [28]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# import mlflow
# mlflow.autolog(disable=True)

import torch
from torch.utils.data import Dataset

class ComparisonDataset(torch.utils.data.Dataset):
    """Torch dataset that tokenizes pairwise-comparison prompts on access.

    Args:
        data: DataFrame with a ``prompt`` column (text) and a ``label`` column
            (convertible to int).
        tokenizer: Callable tokenizer; it is invoked with ``truncation``,
            ``padding``, ``max_length`` and ``return_tensors="pt"`` and must
            return ``input_ids`` and ``attention_mask`` tensors with a leading
            batch dimension of 1.
        max_length (int): Fixed token length every prompt is padded or
            truncated to. Defaults to 128, the previously hard-coded value.

    NOTE(review): the prompts built in this notebook put the question at the
    end of the text, so a prompt longer than ``max_length`` tokens may lose
    the question when truncated -- confirm that 128 is large enough for the
    data in use.
    """

    def __init__(self, data, tokenizer, max_length=128):
        # Reset the index so positional access matches __len__ even when
        # `data` is a filtered or shuffled slice of a larger frame.
        self.examples = data.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` (1-D tensors) and the int ``label``."""
        row = self.examples.iloc[idx]

        inputs = self.tokenizer(
            row["prompt"],
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        return {
            # Drop the batch dimension added by return_tensors="pt".
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "label": int(row["label"]),
        }


#### Step 4: Distill Knowledge From Teacher to Student

In [29]:
import torch

# ============ 5. Define Logits-Based Distillation Loss ============
def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """Temperature-softened KL divergence between teacher and student logits.

    Both logit tensors are divided by ``temperature`` and normalized over the
    last dimension; the teacher becomes a probability distribution and the
    student a log-probability distribution. The KL divergence is reduced with
    ``'batchmean'`` and multiplied by ``temperature ** 2``.

    Args:
        student_logits: Raw student scores; last dimension is the class axis.
        teacher_logits: Raw teacher scores with the same shape.
        temperature (float): Softening factor applied to both sets of logits.

    Returns:
        Scalar tensor holding the scaled KL divergence.
    """
    F = torch.nn.functional

    soft_targets = F.softmax(teacher_logits / temperature, dim=-1)
    soft_predictions = F.log_softmax(student_logits / temperature, dim=-1)

    # kl_div expects log-probabilities as input and probabilities as target.
    divergence = F.kl_div(soft_predictions, soft_targets, reduction='batchmean')
    return divergence * (temperature ** 2)

# ============ 6. Distillation Training Loop ============
def train_distill_model(student, teacher, dataset, tokenizer, device="cpu", loss_type="kl", temperature=2.0,
                        num_epochs=3, learning_rate=5e-5, kl_weight=0.5):
    """
    Train ``student`` to imitate ``teacher`` on ``dataset``, one sample per step.

    For every sample the label ("1" or "2") is tokenized and used as the
    decoder target; teacher and student are both run with the same decoder
    inputs and the student is updated from the resulting loss.

    Args:
        student: Seq2seq model to be trained (updated in place).
        teacher: Seq2seq model providing target logits; never updated.
        dataset: Iterable of dicts with 1-D ``input_ids`` / ``attention_mask``
            tensors and a ``label`` value.
        tokenizer: Tokenizer used to encode the label text.
        device: Device the models and every batch are moved to.
        loss_type: 'kl' for KL divergence only, 'hybrid' for
            cross-entropy + ``kl_weight`` * KL.
        temperature: Softmax temperature passed to ``distillation_loss``.
        num_epochs: Number of passes over ``dataset`` (default 3).
        learning_rate: Adam learning rate (default 5e-5).
        kl_weight: Weight of the KL term in the 'hybrid' loss (default 0.5).

    Raises:
        ValueError: If ``loss_type`` is neither 'kl' nor 'hybrid'. Raised
            before any model work is done.
    """
    # Fail fast on a bad configuration instead of after the first forward passes.
    if loss_type not in ("kl", "hybrid"):
        raise ValueError(f"Unsupported loss_type: {loss_type}")
    use_ce = loss_type == "hybrid"

    # Inputs are moved to `device` below, so the models must live there too.
    student.to(device)
    teacher.to(device)
    student.train()
    teacher.eval()

    optimizer = torch.optim.Adam(student.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        total_loss = 0.0
        for sample in dataset:
            # Dataset items are un-batched; add a batch dimension of 1.
            input_ids = sample['input_ids'].unsqueeze(0).to(device)
            attention_mask = sample['attention_mask'].unsqueeze(0).to(device)

            # The target sequence is the label rendered as text.
            label_text = str(sample["label"])
            label_tokenized = tokenizer(label_text, return_tensors="pt", padding=True, truncation=True)
            labels = label_tokenized["input_ids"].to(device)
            decoder_input_ids = student.prepare_decoder_input_ids_from_labels(labels)

            # Teacher forward pass: no gradients needed.
            with torch.no_grad():
                teacher_logits = teacher(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    decoder_input_ids=decoder_input_ids
                ).logits

            # Student forward pass; labels are supplied only when the
            # cross-entropy term is required.
            student_outputs = student(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                labels=labels if use_ce else None
            )

            loss = distillation_loss(student_outputs.logits, teacher_logits, temperature=temperature)
            if use_ce:
                loss = student_outputs.loss + kl_weight * loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Reported value is the sum of per-sample losses over the epoch.
        print(f"Epoch {epoch+1}: {loss_type.upper()} distillation loss = {total_loss:.4f}")


#### Step 7: Evaluate the Distilled Model
- Evaluate model accuracy of teacher and distilled models.
- Compare inference time of teacher and distilled models.

In [30]:
import yaml
import openai

# Load credentials from a YAML file kept outside this repository tree, so the
# API key is never written into the notebook itself.
with open('./../../../Curify/curify_api.yaml', 'r') as yaml_file:
    data = yaml.safe_load(yaml_file)

# Read the key from the `openai` -> `api_key` entry of the loaded config.
# NOTE(review): if the `openai` section is absent, `.get('openai')` returns
# None and the chained `.get` raises AttributeError.
openai_api_key = data.get('openai').get('api_key')

# Module-level client; filter_hard_samples_with_gpt4o reads this global.
client = openai.OpenAI(api_key=openai_api_key)


def _extract_choice(text):
    """Return "1" or "2" parsed from a model reply, or None if neither is present.

    Only standalone numbers count, so digits inside years such as "1992" are
    ignored. A reply that starts with the choice ("1", "2. Title ...") uses
    that leading number; otherwise the last standalone 1/2 is taken, on the
    assumption that any reasoning precedes the final answer.
    """
    import re

    stripped = text.strip()
    if re.match(r"[12]\b", stripped):
        return stripped[0]
    standalone = re.findall(r"\b[12]\b", stripped)
    return standalone[-1] if standalone else None


def filter_hard_samples_with_gpt4o(df, model="gpt-4o-mini", sleep_time=0.5):
    """
    Filter out evaluation samples that the GPT model answers incorrectly.

    Each prompt is sent to the module-level ``client``; the sample is kept
    when the parsed answer (1 or 2) equals the label and dropped otherwise,
    including when no answer can be parsed. Samples whose request raises are
    kept.

    Args:
        df: DataFrame with ``prompt`` and ``label`` columns.
        model: Chat-completion model name to query.
        sleep_time: Seconds to wait after each successful request.

    Returns:
        tuple: (retained DataFrame with the columns of ``df``, percentage of
        samples removed; 0.0 for an empty ``df``).
    """
    import time

    filtered = []
    retained = []

    for i, row in df.iterrows():
        prompt = row['prompt']
        label = str(row['label']).strip()

        try:
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant that chooses between two movie options."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.2
            )
            gpt_output = response.choices[0].message.content.strip()
        except Exception as e:
            # Deliberately broad: any API/response failure must not lose data.
            print(f"Error on row {i}: {e}")
            retained.append(row)  # Fail-safe: keep sample
            continue

        print(gpt_output)
        # Compare the parsed choice, not a substring: "2" in "1. Heat (1992)"
        # would otherwise count as a correct answer.
        if _extract_choice(gpt_output) == label:
            retained.append(row)
        else:
            filtered.append(row)

        time.sleep(sleep_time)

    # Passing the columns keeps the schema even when nothing is retained.
    retained_df = pd.DataFrame(retained, columns=df.columns)
    filtered_pct = len(filtered) / len(df) * 100 if len(df) else 0.0

    print(f"Filtered {len(filtered)} out of {len(df)} samples ({filtered_pct:.2f}%) using GPT-4o.")

    return retained_df, filtered_pct


def _predicted_option(decoded):
    """Return "1" or "2" parsed from generated text, or None if neither is present.

    Only standalone numbers count, so digits inside years such as "1992" are
    ignored. Text that starts with the choice uses that leading number;
    otherwise the last standalone 1/2 is taken.
    """
    import re

    stripped = decoded.strip()
    if re.match(r"[12]\b", stripped):
        return stripped[0]
    standalone = re.findall(r"\b[12]\b", stripped)
    return standalone[-1] if standalone else None


def evaluate_model(model, dataframe, tokenizer):
    """Generate an answer for every prompt and score it against the label.

    Args:
        model: Model exposing ``eval()`` and ``generate(**inputs, max_new_tokens=...)``.
        dataframe: DataFrame with ``prompt`` and ``label`` columns.
        tokenizer: Tokenizer used to encode prompts and decode generations.

    Returns:
        tuple: (list of dicts with keys ``prompt``, ``label``, ``decoded``;
        accuracy as a float in [0, 1], 0.0 for an empty dataframe).
    """
    model.eval()
    correct = 0
    results = []

    for _, row in dataframe.iterrows():
        prompt = row['prompt']
        label = str(row['label']).strip()
        inputs = tokenizer(prompt, return_tensors='pt')
        outputs = model.generate(**inputs, max_new_tokens=10)
        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

        results.append({
            "prompt": prompt,
            "label": label,
            "decoded": decoded
        })

        # Compare the parsed choice, not a substring: "2" in "1. Heat (1992)"
        # would otherwise be scored as correct.
        if _predicted_option(decoded) == label:
            correct += 1

    accuracy = correct / len(dataframe) if len(dataframe) else 0.0
    return results, accuracy
    

In [31]:
import time
import json
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import Dataset

def run_distillation_experiment(train_data, test_data, use_cot=True, loss_type="kl"):
    """Run one teacher -> student distillation experiment and collect metrics.

    Loads Flan-T5-Large (teacher) and Flan-T5-Small (student), evaluates both
    on ``test_data``, distills the student on ``train_data`` and evaluates
    the student again.

    Args:
        train_data: DataFrame of ``prompt`` / ``label`` rows used for distillation.
        test_data: DataFrame of ``prompt`` / ``label`` rows used for evaluation.
        use_cot: Only used in the log banner here; whether prompts contain the
            CoT instruction is decided when the data is generated.
        loss_type: Passed to ``train_distill_model`` ('kl' or 'hybrid').

    Returns:
        tuple: (metrics dict with accuracies, inference/distillation times in
        seconds and the test-set size; dict of raw per-sample outputs for the
        teacher and for the student before and after distillation).
    """
    print(f"\n========== Running Experiment (use_cot={use_cot}, loss_type='{loss_type}') ==========")

    # One tokenizer serves both models in this experiment.
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
    train_dataset = ComparisonDataset(train_data, tokenizer)

    teacher = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large")
    student = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")

    def timed_evaluation(model):
        # Returns (per-sample results, accuracy, elapsed wall-clock seconds).
        started = time.time()
        results, accuracy = evaluate_model(model, test_data, tokenizer)
        return results, accuracy, time.time() - started

    # Baselines before distillation
    print("\nBefore distillation:")
    results_large, acc_large, inference_time_large = timed_evaluation(teacher)
    results_small, acc_small, inference_time_small = timed_evaluation(student)

    print(f"Flan-T5-Large Accuracy: {acc_large:.2f}")
    print(f"Flan-T5-Small Accuracy: {acc_small:.2f}")

    # Distillation
    print("\nTraining distilled student...")
    start_time = time.time()
    train_distill_model(student, teacher, train_dataset, tokenizer, loss_type=loss_type)
    distillation_time = time.time() - start_time
    print(f"Distillation completed in {distillation_time:.2f} seconds.")

    # Evaluation after distillation, timed as well so teacher and distilled
    # student inference cost can be compared.
    results_distilled, acc_distilled, inference_time_distilled = timed_evaluation(student)
    print(f"Distilled Flan-T5-Small Accuracy: {acc_distilled:.2f}")

    evaluation_final = {
        "acc_large": acc_large,
        "acc_small": acc_small,
        "acc_distilled": acc_distilled,
        "time_large": inference_time_large,
        "time_small": inference_time_small,
        "time_distilled": inference_time_distilled,
        "distillation_time": distillation_time,
        "Inference_volume": len(test_data)
    }
    raw_outputs = {
        "large": results_large,
        "small_before_distillation": results_small,
        "small_after_distillation": results_distilled
    }
    return evaluation_final, raw_outputs


In [32]:
# Global Preparation
from sklearn.model_selection import train_test_split



In [35]:

# Results keyed by configuration name ("cot_<bool>_loss_<type>").
evaluation_res = {}
raw_res = {}

for use_cot in [False, True]:
    # Generate prompts and filter them once per use_cot setting
    prompt_data_full = generate_pairwise_prompts_from_movielens(use_cot=use_cot)  # CoT instruction is included only when use_cot is True

    # Keep only the first 5000 generated prompts; each of them is then sent
    # to the OpenAI API by the filter below.
    prompt_data = prompt_data_full.head(5000)
    filtered_data, pct_filtered = filter_hard_samples_with_gpt4o(prompt_data)
    print(f"Filtered dataset size: {len(filtered_data)} (Filtered out {pct_filtered:.2f}%)")

    # Same train/test split reused for both loss types under this use_cot setting
    train_data, test_data = train_test_split(filtered_data, test_size=0.2, random_state=42)

    for loss_type in ["kl", "hybrid"]:
        config_name = f"cot_{use_cot}_loss_{loss_type}"
        # Each call reloads both models and re-evaluates the baselines.
        evaluation, raw = run_distillation_experiment(
            train_data=train_data,
            test_data=test_data,
            use_cot=use_cot,
            loss_type=loss_type
        )
        evaluation_res[config_name] = evaluation
        raw_res[config_name] = raw

# Save results (`json` is imported in an earlier cell; default=str stringifies
# any value json cannot serialize natively)
with open("all_distillation_evaluation.json", "w") as f:
    json.dump(evaluation_res, f, indent=2, default=str)

with open("all_distillation_raw.json", "w") as f:
    json.dump(raw_res, f, indent=2, default=str)


2
2
1
1
2
1
1
1
1
1
2
1
1
2
2
2
1
1
2
1
1
2
2
2
1
2
2
2
1
1
1
2
2
2
2
1
1
2
1
2
1
2
1
1
2
2
1
2
1
2
1
1
1
1
1
2
1
1
2
2
1
1
1
2
1
1
1
1
2
1
2
2
1
2
1
2
1
1
2
2
1
2
2
2
1
1
2
2
1
2
1
1
1
2
1
2
1
1
1
1
1
1
1
1
1
2
2
1
2
2
1
1
1
1
2
1
1
2
2
1
2
1
1
1
1
2
2
1
1
2
1
1
2
1
1
1
1
1
2
2
1
1
2
1
2
2
1
1
1
2
2
2
1
2
2
1
2
1
1
1
2
1
1
1
1
1
1
2
1
1
1
1
1
1
2
1
1
2
1
2
1
2
1
2
2
1
1
1
1
2
1
1
2
2
2
1
2
2
2
1
1
1
1
2
1
2
1
2
2
1
2
2
1
1
1
2
1
1
2
2
1
2
2
1
2
1
1
2
2
2
1
1
1
1
2
2
2
1
1
1
1
2
2
1
2
2
2
1
2
1
1
1
1
1
1
1
1
2
1
1
2
1
2
1
2
1
2
1
1
1
2
2
1
1
1
1
1
1
2
1
1
1
2
1
1
1
1
2
1
2
1
1
1
1
1
1
1
2
1
1
1
1
2
2
1
2
2
1
2
1
1
2
2
2
1
2
1
2
1
2
1
1
2
2
2
2
1
2
1
2
2
2
2
2
1
1
1
2
1
2
1
2
1
2
2
1
2
1
1
1
1
2
1
1
2
1
1
2
1
2
1
2
1
1
2
1
2
2
2
2
2
1
1
2
2
1
2
1
2
1
2
1
2
1
2
1
2
1
1
1
1
2
1
2
2
2
1
1
1
1
2
2
2
2
1
2
1
1
1
1
1
1
2
2
2
1
1
1
1
2
1
1
1
1
1
1
1
2
2
1
1
2
2
1
1
2
1
2
2
1
1
1
1
2
1
2
1
2
1
1
1
2
2
1
1
1
2
2
1
2
2
2
1
1
2
1
2
2
2
2
1
1
1
1
1
2
2
1
1
2
1
2
1
2
2
2
2
1
2
2
1
1
1
2
1
1
2
2
2
2


In [24]:
print(1)

1
