# Rule-Filtered Vector Similarity Retrieval

This notebook demonstrates a two-stage retrieval approach:
1. **Filter by rule**: Match the query rule to a rule category
2. **Retrieve top K**: Get most similar examples from that rule's examples using vector similarity

## 1. Setup and Imports

In [1]:
%%writefile rule_filtered_retrieval.py
"""
Rule-Filtered Vector Similarity Retrieval System with Reranking using VLLM

This system implements a three-stage retrieval:
1. Filter by matching rules first
2. Retrieve top K similar examples from filtered chunks using vector similarity
3. Rerank retrieved examples using cross-encoder for better relevance
"""

import numpy as np
import pandas as pd
import faiss
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
import pickle
import os
from pathlib import Path
import multiprocessing as mp
mp.set_start_method('spawn', force=True)
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

from vllm import LLM
from vllm.distributed.parallel_state import destroy_model_parallel


class Qwen3RerankerVLLM:
    """Reranker using Qwen3-Reranker model with VLLM for cross-encoder scoring"""
    def __init__(self, model_name='Qwen/Qwen3-Reranker-0.6B', max_length=1024):
        # Initialize vLLM model for reranking
        number_of_gpu = torch.cuda.device_count()
        self.model = LLM(
            model=model_name,
            task="score",  # Use score task for reranking
            tensor_parallel_size=number_of_gpu,
            enable_prefix_caching=True,
            gpu_memory_utilization=0.3,  # Lower GPU memory for reranker
            max_model_len=3072,  # Reduce from default 32768 to fit in available memory
            trust_remote_code=True,
            dtype='half'  # Use float16 for memory efficiency
        )
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.max_length = max_length

    def rerank(self, query: str, documents: List[str], top_k: int = 10, batch_size: int = 32) -> List[Dict]:
        """
        Rerank documents based on query using cross-encoder scoring with VLLM

        Args:
            query: The query text
            documents: List of document texts to rerank
            top_k: Number of top results to return
            batch_size: Batch size for processing (note: VLLM handles batching internally)

        Returns:
            List of dicts with 'index' and 'score' sorted by relevance
        """
        # Get scores from VLLM - pass query as text_1 and documents as text_2
        # VLLM score() expects: score(text_1, text_2) where text_1 is query and text_2 is list of docs
        outputs = self.model.score(query, documents)
        
        # Extract scores from outputs
        scores = [o.outputs.score for o in outputs]
        
        # Create results with original indices
        results = [{'index': idx, 'score': score} for idx, score in enumerate(scores)]

        # Sort by score descending and return top_k
        results = sorted(results, key=lambda x: x['score'], reverse=True)[:top_k]

        return results
    
    def __del__(self):
        """Cleanup vLLM resources"""
        try:
            destroy_model_parallel()
        except:
            pass


class Qwen3Reranker:
    """Reranker using Qwen3-Reranker model with transformers for cross-encoder scoring"""
    def __init__(self, model_name='Qwen/Qwen3-Reranker-0.6B', max_length=1024):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float16
        )
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model.to(self.device)
        self.model.eval()
        self.max_length = max_length

    def rerank(self, query: str, documents: List[str], top_k: int = 10, batch_size: int = 32) -> List[Dict]:
        """
        Rerank documents based on query using cross-encoder scoring

        Args:
            query: The query text
            documents: List of document texts to rerank
            top_k: Number of top results to return
            batch_size: Batch size for processing

        Returns:
            List of dicts with 'index' and 'score' sorted by relevance
        """
        all_scores = []

        # Process in batches
        for i in range(0, len(documents), batch_size):
            batch_docs = documents[i:i+batch_size]

            # Create query-document pairs
            pairs = [[query, doc] for doc in batch_docs]

            # Tokenize pairs
            encoded = self.tokenizer(
                pairs,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=self.max_length
            )
            encoded = {k: v.to(self.device) for k, v in encoded.items()}

            with torch.no_grad():
                outputs = self.model(**encoded, return_dict=True)
                
                # For Qwen3-Reranker, use last hidden state to compute relevance scores
                # Get [CLS] token embeddings (first token) and project to scores
                last_hidden_state = outputs.last_hidden_state
                cls_embeddings = last_hidden_state[:, 0, :]  # Shape: [batch_size, hidden_dim]
                
                # Simple scoring: use mean of CLS embedding as relevance score
                # You can also train a linear layer on top if you have labeled data
                scores = cls_embeddings.mean(dim=-1).cpu().float().numpy()

            all_scores.extend(scores.tolist() if isinstance(scores, np.ndarray) else [scores])

        # Create results with original indices
        results = [{'index': idx, 'score': score} for idx, score in enumerate(all_scores)]

        # Sort by score descending and return top_k
        results = sorted(results, key=lambda x: x['score'], reverse=True)[:top_k]

        return results


class FastEmbedder:
    """Lightweight embedder using Qwen3-Embedding model with vLLM for faster GPU inference"""
    def __init__(self, model_name='Qwen/Qwen3-Embedding-0.6B', output_dim=512):
        # Initialize vLLM model with multi-GPU support
        number_of_gpu = torch.cuda.device_count()
        self.model = LLM(
            model=model_name,
            task="embed",
            tensor_parallel_size=number_of_gpu,
            enable_prefix_caching=True,
            gpu_memory_utilization=0.3,
            max_model_len=3072,  # Reduce from default 32768 to fit in available memory
            trust_remote_code=True,
            dtype='half'  # Use float16 for memory efficiency
        )
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.output_dim = output_dim
        self.max_length = 1024

    def get_detailed_instruct(self, task_description: str, query: str) -> str:
        """Format instruction with query"""
        return f'Instruct: {task_description}\nQuery: {query}'

    def encode(self, texts, batch_size=32, convert_to_numpy=True, instruction=None):
        """Encode texts to embeddings with optional instruction"""
        if instruction:
            texts = [self.get_detailed_instruct(instruction, text) for text in texts]

        # Use vLLM embed method
        outputs = self.model.embed(texts)
        
        # Extract embeddings from outputs - VLLM returns list of EmbeddingRequestOutput
        embeddings_list = [o.outputs.embedding for o in outputs]
        
        # Convert to tensor
        embeddings = torch.tensor(embeddings_list, dtype=torch.float32)

        # Apply output dimension reduction if needed
        if self.output_dim and self.output_dim < embeddings.shape[1]:
            embeddings = embeddings[:, :self.output_dim]
            embeddings = F.normalize(embeddings, p=2, dim=1)

        if convert_to_numpy:
            return embeddings.cpu().numpy().astype(np.float32)
        return embeddings
    
    def __del__(self):
        """Cleanup vLLM resources"""
        try:
            destroy_model_parallel()
        except:
            pass


class RuleFilteredRetriever:
    """
    Three-stage retrieval system:
    1. Filter by rule category
    2. Retrieve top K similar examples from filtered set
    3. Rerank using cross-encoder
    """

    def __init__(self, embedder: FastEmbedder, reranker: Optional[Qwen3Reranker] = None, use_gpu: bool = True):
        self.embedder = embedder
        self.reranker = reranker
        self.use_gpu = use_gpu and faiss.get_num_gpus() > 0

        # Store rule categories and their indices
        self.rule_to_indices = defaultdict(list)
        self.rule_to_faiss_index = {}
        self.rule_to_data = {}

        # Global storage
        self.all_data = None
        self.all_vectors = None

    def build_index(self, df: pd.DataFrame, rule_col: str = 'rule',
                   comment_col: str = 'test_comment',
                   value_col: str = 'value',
                   instruction: Optional[str] = None):
        """
        Build FAISS indices per rule category
        Uses ONLY comment embeddings for similarity search within each rule

        Args:
            df: DataFrame with rules, comments, and values
            rule_col: Column name for rules
            comment_col: Column name for comments
            value_col: Column name for target values
            instruction: Optional instruction for embeddings
        """
        print(f"Building rule-filtered index from {len(df)} examples...")
        print("Strategy: Filter by rule FIRST, then search by COMMENT similarity only")

        # Store data
        df = df.copy()
        self.all_data = df

        # Group by rule
        rule_groups = df.groupby(rule_col)
        print(f"Found {len(rule_groups)} unique rules")

        # Encode ONLY comments for similarity search (not rule+comment combined)
        print("Encoding comments only (not combined with rules)...")
        all_vectors = self.embedder.encode(
            df[comment_col].tolist(),
            instruction=instruction
        )
        faiss.normalize_L2(all_vectors)
        self.all_vectors = all_vectors

        # Build separate FAISS index for each rule
        print("Building per-rule FAISS indices...")
        for rule, group_df in rule_groups:
            indices = group_df.index.tolist()
            rule_vectors = all_vectors[indices]

            # Store mapping
            self.rule_to_indices[rule] = indices
            self.rule_to_data[rule] = group_df

            # Build FAISS index for this rule using ONLY comment vectors
            dimension = rule_vectors.shape[1]
            index = faiss.IndexFlatIP(dimension)

            if self.use_gpu:
                res = faiss.StandardGpuResources()
                index = faiss.index_cpu_to_gpu(res, 0, index)

            index.add(rule_vectors)
            self.rule_to_faiss_index[rule] = index

            print(f"  Rule '{rule[:50]}...' : {len(indices)} examples")

        print(f"\nIndex building complete!")
        print(f"Total rules: {len(self.rule_to_faiss_index)}")
        print(f"Total examples: {len(df)}")

    def retrieve(self, query_rule: str, query_comment: str,
                top_k: int = 10,
                retrieval_k: int = 100,
                instruction: Optional[str] = None,
                fallback_to_similar_rules: bool = True,
                use_reranking: bool = True) -> Dict:
        """
        Three-stage retrieval:
        1. Filter by rule (exact or similar)
        2. Retrieve top retrieval_k by comment similarity
        3. Rerank to get final top_k (if reranker available)

        Args:
            query_rule: The rule to match
            query_comment: The comment to find similar examples for
            top_k: Final number of examples after reranking (default: 10)
            retrieval_k: Number of examples to retrieve before reranking (default: 100)
            instruction: Optional instruction for encoding
            fallback_to_similar_rules: If exact rule not found, use rule similarity
            use_reranking: Whether to use reranker (if available)

        Returns:
            Dict with retrieved examples and metadata including both FAISS and rerank scores
        """
        # STAGE 1: Filter by rule first
        if query_rule in self.rule_to_faiss_index:
            matched_rule = query_rule
            rule_match_type = "exact"
        elif fallback_to_similar_rules:
            matched_rule, rule_match_type = self._find_similar_rule(query_rule, instruction)
        else:
            return {
                'success': False,
                'error': f'Rule not found: {query_rule}',
                'matched_rule': None,
                'examples': []
            }

        # STAGE 2: Encode ONLY the comment for similarity search
        query_vector = self.embedder.encode(
            [query_comment],
            instruction=instruction
        )
        faiss.normalize_L2(query_vector)

        # Get FAISS index for matched rule
        rule_index = self.rule_to_faiss_index[matched_rule]
        rule_data = self.rule_to_data[matched_rule]

        # Retrieve top retrieval_k candidates
        similarities, indices = rule_index.search(query_vector, min(retrieval_k, len(rule_data)))

        # Get actual dataframe indices
        actual_indices = [self.rule_to_indices[matched_rule][i] for i in indices[0]]
        retrieved_examples = self.all_data.iloc[actual_indices]

        # STAGE 3: Rerank if reranker is available and enabled
        if use_reranking and self.reranker is not None and len(retrieved_examples) > 0:
            # Get document texts for reranking
            documents = retrieved_examples['body'].tolist()
            
            # Rerank using cross-encoder
            rerank_results = self.reranker.rerank(
                query=query_comment,
                documents=documents,
                top_k=min(top_k, len(documents))
            )
            
            # Reorder examples based on reranking scores
            reranked_indices = [r['index'] for r in rerank_results]
            reranked_scores = [r['score'] for r in rerank_results]
            
            # Get corresponding FAISS scores for the reranked examples
            faiss_scores_reranked = [similarities[0][idx] for idx in reranked_indices]
            
            # Get reranked examples
            final_examples = retrieved_examples.iloc[reranked_indices].reset_index(drop=True)
            
            return {
                'success': True,
                'matched_rule': matched_rule,
                'rule_match_type': rule_match_type,
                'query_rule': query_rule,
                'query_comment': query_comment,
                'faiss_scores': faiss_scores_reranked,
                'rerank_scores': reranked_scores,
                'examples': final_examples.to_dict('records'),
                'num_examples_in_rule': len(rule_data),
                'reranked': True,
                'retrieval_k': retrieval_k,
                'final_k': len(final_examples)
            }
        else:
            # No reranking - return top_k from initial retrieval
            final_examples = retrieved_examples.iloc[:top_k]
            final_scores = similarities[0][:top_k].tolist()
            
            return {
                'success': True,
                'matched_rule': matched_rule,
                'rule_match_type': rule_match_type,
                'query_rule': query_rule,
                'query_comment': query_comment,
                'faiss_scores': final_scores,
                'rerank_scores': None,
                'examples': final_examples.to_dict('records'),
                'num_examples_in_rule': len(rule_data),
                'reranked': False,
                'retrieval_k': retrieval_k,
                'final_k': len(final_examples)
            }

    def _find_similar_rule(self, query_rule: str, instruction: Optional[str] = None) -> Tuple[str, str]:
        """Find the most similar rule to query_rule using embedding similarity"""
        # Encode query rule
        query_rule_vector = self.embedder.encode(
            [f"rule: {query_rule}"],
            instruction=instruction
        )
        faiss.normalize_L2(query_rule_vector)

        # Encode all unique rules
        unique_rules = list(self.rule_to_faiss_index.keys())
        rule_vectors = self.embedder.encode(
            [f"rule: {r}" for r in unique_rules],
            instruction=instruction
        )
        faiss.normalize_L2(rule_vectors)

        # Find most similar
        similarities = np.dot(query_rule_vector, rule_vectors.T)[0]
        best_idx = np.argmax(similarities)
        best_rule = unique_rules[best_idx]

        return best_rule, f"similar (similarity: {similarities[best_idx]:.3f})"

    def batch_retrieve(self, df_query: pd.DataFrame,
                      rule_col: str = 'rule',
                      comment_col: str = 'test_comment',
                      top_k: int = 10,
                      retrieval_k: int = 100,
                      instruction: Optional[str] = None,
                      use_reranking: bool = True,
                      faiss_weight: float = 0.5,
                      rerank_weight: float = 0.5) -> pd.DataFrame:
        """
        Batch retrieval for multiple queries with weighted score combination

        Args:
            df_query: DataFrame with query rules and comments
            rule_col: Column name for rules
            comment_col: Column name for comments
            top_k: Final number of examples after reranking (default: 10)
            retrieval_k: Number to retrieve before reranking (default: 100)
            instruction: Optional instruction for encoding
            use_reranking: Whether to use reranker
            faiss_weight: Weight for FAISS scores (default: 0.5)
            rerank_weight: Weight for rerank scores (default: 0.5)

        Returns:
            DataFrame with predictions
        """
        results = []

        for idx, row in df_query.iterrows():
            query_rule = row[rule_col]
            query_comment = row[comment_col]

            # Retrieve and rerank examples
            retrieval_result = self.retrieve(
                query_rule=query_rule,
                query_comment=query_comment,
                top_k=top_k,
                retrieval_k=retrieval_k,
                instruction=instruction,
                use_reranking=use_reranking
            )

            if retrieval_result['success']:
                examples = retrieval_result['examples']
                values = [ex['value'] for ex in examples]

                # Calculate weighted score combination
                if retrieval_result['reranked'] and retrieval_result['rerank_scores'] is not None:
                    # Weighted combination: (faiss_weight * faiss_score + rerank_weight * rerank_score) * value
                    faiss_scores = retrieval_result['faiss_scores']
                    rerank_scores = retrieval_result['rerank_scores']
                    
                    # Normalize scores to similar ranges if needed
                    # FAISS scores are cosine similarities (0-1 range typically)
                    # Rerank scores may need normalization depending on model output
                    
                    weighted_sum = sum(
                        (faiss_weight * faiss_score + rerank_weight * rerank_score) * val
                        for faiss_score, rerank_score, val in zip(faiss_scores, rerank_scores, values)
                    )
                else:
                    # No reranking - use only FAISS scores
                    faiss_scores = retrieval_result['faiss_scores']
                    weighted_sum = sum(score * val for score, val in zip(faiss_scores, values))

                # Decision based on sign of weighted sum
                decision = 1 if weighted_sum >= 0 else 0
                avg_violation = weighted_sum  # Store weighted sum for inspection

                results.append({
                    'row_id': idx,
                    'rule_violation': avg_violation,
                    'decision': decision,
                    'matched_rule': retrieval_result['matched_rule'],
                    'rule_match_type': retrieval_result['rule_match_type'],
                    'num_retrieved': len(examples),
                    'reranked': retrieval_result.get('reranked', False)
                })
            else:
                # Fallback: return neutral prediction
                results.append({
                    'row_id': idx,
                    'rule_violation': 0.5,
                    'decision': 0,
                    'matched_rule': None,
                    'rule_match_type': 'not_found',
                    'num_retrieved': 0,
                    'reranked': False
                })

        return pd.DataFrame(results)

    def save_index(self, save_dir: str):
        """
        Save the retriever indices and data to disk

        Args:
            save_dir: Directory to save the index files
        """
        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)

        print(f"Saving index to {save_dir}...")

        # Save FAISS indices (convert GPU to CPU first if needed)
        faiss_dir = save_path / "faiss_indices"
        faiss_dir.mkdir(exist_ok=True)

        for rule, index in self.rule_to_faiss_index.items():
            # Convert to CPU index if it's a GPU index
            if self.use_gpu:
                cpu_index = faiss.index_gpu_to_cpu(index)
            else:
                cpu_index = index

            # Create safe filename from rule
            safe_filename = f"index_{hash(rule) % 10**8}.faiss"
            faiss.write_index(cpu_index, str(faiss_dir / safe_filename))

        # Save metadata (rule mappings)
        metadata = {
            'rule_to_indices': dict(self.rule_to_indices),
            'rule_to_filename': {rule: f"index_{hash(rule) % 10**8}.faiss"
                                 for rule in self.rule_to_faiss_index.keys()},
            'use_gpu': self.use_gpu
        }

        with open(save_path / "metadata.pkl", 'wb') as f:
            pickle.dump(metadata, f)

        # Save training data
        if self.all_data is not None:
            self.all_data.to_parquet(save_path / "training_data.parquet", index=True)

        # Save embeddings
        if self.all_vectors is not None:
            np.save(save_path / "embeddings.npy", self.all_vectors)

        # Save rule_to_data mapping
        rule_data_mapping = {rule: df.index.tolist() for rule, df in self.rule_to_data.items()}
        with open(save_path / "rule_data_mapping.pkl", 'wb') as f:
            pickle.dump(rule_data_mapping, f)

        print(f"✓ Index saved successfully!")
        print(f"  - {len(self.rule_to_faiss_index)} FAISS indices")
        print(f"  - {len(self.all_data)} training examples")
        print(f"  - Embeddings shape: {self.all_vectors.shape}")

    def load_index(self, load_dir: str):
        """
        Load a previously saved index from disk

        Args:
            load_dir: Directory containing the saved index files
        """
        load_path = Path(load_dir)

        if not load_path.exists():
            raise FileNotFoundError(f"Index directory not found: {load_dir}")

        print(f"Loading index from {load_dir}...")

        # Load metadata
        with open(load_path / "metadata.pkl", 'rb') as f:
            metadata = pickle.load(f)

        self.rule_to_indices = defaultdict(list, metadata['rule_to_indices'])
        rule_to_filename = metadata['rule_to_filename']
        saved_use_gpu = metadata['use_gpu']

        # Load training data
        if (load_path / "training_data.parquet").exists():
            self.all_data = pd.read_parquet(load_path / "training_data.parquet")

        # Load embeddings
        if (load_path / "embeddings.npy").exists():
            self.all_vectors = np.load(load_path / "embeddings.npy")

        # Load rule_to_data mapping
        with open(load_path / "rule_data_mapping.pkl", 'rb') as f:
            rule_data_mapping = pickle.load(f)

        # Reconstruct rule_to_data from mapping
        self.rule_to_data = {}
        for rule, indices in rule_data_mapping.items():
            self.rule_to_data[rule] = self.all_data.loc[indices]

        # Load FAISS indices
        faiss_dir = load_path / "faiss_indices"
        self.rule_to_faiss_index = {}

        for rule, filename in rule_to_filename.items():
            # Load CPU index
            cpu_index = faiss.read_index(str(faiss_dir / filename))

            # Convert to GPU if requested and available
            if self.use_gpu and faiss.get_num_gpus() > 0:
                res = faiss.StandardGpuResources()
                gpu_index = faiss.index_cpu_to_gpu(res, 0, cpu_index)
                self.rule_to_faiss_index[rule] = gpu_index
            else:
                self.rule_to_faiss_index[rule] = cpu_index

        print(f"✓ Index loaded successfully!")
        print(f"  - {len(self.rule_to_faiss_index)} FAISS indices")
        print(f"  - {len(self.all_data)} training examples")
        print(f"  - Embeddings shape: {self.all_vectors.shape}")
        print(f"  - Using GPU: {self.use_gpu}")

Overwriting rule_filtered_retrieval.py


## 2. Load Training/Test Data & format it

In [2]:
%%writefile load_train_test_data.py

# Load training data
import pandas as pd
import numpy as np
from cleantext import clean

def cleaner(text):
    """Clean text by removing URLs, emails, phone numbers"""
    return clean(
        text,
        fix_unicode=True,
        to_ascii=True,
        lower=False,
        no_line_breaks=False,
        no_urls=True,
        no_emails=True,
        no_phone_numbers=True,
        no_numbers=False,
        no_digits=False,
        no_currency_symbols=False,
        no_punct=False,
        replace_with_url="<URL>",
        replace_with_email="<EMAIL>",
        replace_with_phone_number="<PHONE>",
        lang="en",
    )

def build_prompt(row):
    """Build prompt with subreddit context"""
    return f"""r/{row["subreddit"]}\nComment: {row["body"]}"""

base_path = "./data/"
df_train = pd.read_csv(f"{base_path}train.csv")
df_train['value']=df_train['rule_violation'].copy()
df_train['value'] = df_train['value'].replace(0, -1)
df_train = df_train.sample(frac=1).reset_index(drop=True)

print(f"Total training examples: {len(df_train):,}")
print(f"Unique rules: {df_train['rule'].nunique():,}")
print(f"\nColumns: {df_train.columns.tolist()}")

# Show rule distribution
rule_counts = df_train['rule'].value_counts()
print(f"\nRule distribution statistics:")
print(f"  Mean examples per rule: {rule_counts.mean():.1f}")
print(f"  Median examples per rule: {rule_counts.median():.1f}")
print(f"  Min examples per rule: {rule_counts.min()}")
print(f"  Max examples per rule: {rule_counts.max()}")

df_train.head()
print(df_train["value"].unique())

### format data
# Step 1: Keep only the main row
df_main = df_train[['rule', 'subreddit', 'body', 'rule_violation']].copy()

# Step 2: Positive examples
df_pos1 = df_train[['rule', 'subreddit', 'positive_example_1']].copy()
df_pos1['rule_violation'] = 1
df_pos1.rename(columns={'positive_example_1': 'body'}, inplace=True)

df_pos2 = df_train[['rule', 'subreddit', 'positive_example_2']].copy()
df_pos2['rule_violation'] = 1
df_pos2.rename(columns={'positive_example_2': 'body'}, inplace=True)

# Step 3: Negative examples
df_neg1 = df_train[['rule', 'subreddit', 'negative_example_1']].copy()
df_neg1['rule_violation'] = 0
df_neg1.rename(columns={'negative_example_1': 'body'}, inplace=True)

df_neg2 = df_train[['rule', 'subreddit', 'negative_example_2']].copy()
df_neg2['rule_violation'] = 0
df_neg2.rename(columns={'negative_example_2': 'body'}, inplace=True)

# Step 4: Concatenate all
df_train_vs = pd.concat([df_main, df_pos1, df_pos2, df_neg1, df_neg2], ignore_index=True)
df_train_vs["value"]=df_train_vs["rule_violation"].copy()

# Step 5: Convert rule_violation to binary value
df_train_vs['value'] = df_train_vs['value'].replace(0, -1)

print(df_train_vs.shape)
print(df_train_vs["value"].unique())
# Show rule distribution
rule_counts = df_train_vs['rule'].value_counts()
print(f"\nRule distribution statistics:")
print(f"  Mean examples per rule: {rule_counts.mean():.1f}")
print(f"  Median examples per rule: {rule_counts.median():.1f}")
print(f"  Min examples per rule: {rule_counts.min()}")
print(f"  Max examples per rule: {rule_counts.max()}")

df_train_vs.head()

#load test data - sample 60% for training augmentation
df_test = pd.read_csv(f"{base_path}test.csv")
df_test_sample = df_test.sample(frac=0.7, random_state=21).reset_index(drop=True)

# Initialize new columns
df_test_sample['rule_violation'] = 100
df_test_sample['value'] = 100

# Shuffle the dataframe
df_test_sample = df_test_sample.sample(frac=1).reset_index(drop=True)

# Show rule distribution
rule_counts = df_test_sample['rule'].value_counts()
print(f"\nTest sample rule distribution statistics:")
print(f"  Mean examples per rule: {rule_counts.mean():.1f}")
print(f"  Median examples per rule: {rule_counts.median():.1f}")
print(f"  Min examples per rule: {rule_counts.min()}")
print(f"  Max examples per rule: {rule_counts.max()}")

print(df_test_sample["value"].unique())
print(df_test_sample.shape)
df_test_sample.head(2)

# Step 1: Keep only the main row
# Skip

# Step 2: Positive examples
df_pos1 = df_test_sample[['rule', 'subreddit', 'positive_example_1']].copy()
df_pos1['rule_violation'] = 1
df_pos1.rename(columns={'positive_example_1': 'body'}, inplace=True)

df_pos2 = df_test_sample[['rule', 'subreddit', 'positive_example_2']].copy()
df_pos2['rule_violation'] = 1
df_pos2.rename(columns={'positive_example_2': 'body'}, inplace=True)

# Step 3: Negative examples
df_neg1 = df_test_sample[['rule', 'subreddit', 'negative_example_1']].copy()
df_neg1['rule_violation'] = 0
df_neg1.rename(columns={'negative_example_1': 'body'}, inplace=True)

df_neg2 = df_test_sample[['rule', 'subreddit', 'negative_example_2']].copy()
df_neg2['rule_violation'] = 0
df_neg2.rename(columns={'negative_example_2': 'body'}, inplace=True)

# Step 4: Concatenate all
df_test_vs = pd.concat([df_pos1, df_pos2, df_neg1, df_neg2], ignore_index=True)
df_test_vs["value"]=df_test_vs["rule_violation"].copy()

# Step 5: Convert rule_violation to binary value
df_test_vs['value'] = df_test_vs['value'].replace(0, -1)
df_test_vs = df_test_vs[df_test_vs['value'].isin([1, -1])]

print(df_test_vs.shape)
print(df_test_vs["value"].unique())
print(df_test_vs["value"].value_counts())

# Show rule distribution
rule_counts = df_test_vs['rule'].value_counts()
print(f"\nRule distribution statistics:")
print(f"  Mean examples per rule: {rule_counts.mean():.1f}")
print(f"  Median examples per rule: {rule_counts.median():.1f}")
print(f"  Min examples per rule: {rule_counts.min()}")
print(f"  Max examples per rule: {rule_counts.max()}")

df_test_vs.head()

#combine train/test data - add test examples to training set
df_train_vs = pd.concat([df_test_vs, df_train_vs], ignore_index=True)
print(f"\nCombined train shape (train + 60% test): {df_train_vs.shape}")

# Apply prompt formatting with subreddit context
print("Applying prompt formatting with subreddit context...")
df_train_vs['prompt'] = df_train_vs.apply(build_prompt, axis=1)

# Apply text cleaning
print("Cleaning text...")
df_train_vs['prompt'] = df_train_vs['prompt'].apply(cleaner)

# Update body column with cleaned prompt
df_train_vs['body'] = df_train_vs['prompt']

# Save
df_train_vs[['rule', 'body', 'rule_violation', 'value']].to_csv("df_train_vs.csv", index=False)
print(f"Final shape: {df_train_vs.shape}")
print("df train vs saved success.!")

Overwriting load_train_test_data.py


In [3]:
%%writefile run_similarity_search.py
import sys
sys.path.append('..')

import pandas as pd
import numpy as np
from rule_filtered_retrieval import FastEmbedder, Qwen3RerankerVLLM, RuleFilteredRetriever
from sklearn.metrics import f1_score, classification_report
from cleantext import clean
import gc
import torch

def cleaner(text):
    """Clean text by removing URLs, emails, phone numbers"""
    return clean(
        text,
        fix_unicode=True,
        to_ascii=True,
        lower=False,
        no_line_breaks=False,
        no_urls=True,
        no_emails=True,
        no_phone_numbers=True,
        no_numbers=False,
        no_digits=False,
        no_currency_symbols=False,
        no_punct=False,
        replace_with_url="<URL>",
        replace_with_email="<EMAIL>",
        replace_with_phone_number="<PHONE>",
        lang="en",
    )

def build_prompt(row):
    """Build prompt with subreddit context"""
    return f"""r/{row["subreddit"]}\nComment: {row["body"]}"""

def clear_gpu_memory():
    """Clear GPU cache and run garbage collection"""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        
        # Print memory stats
        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024**3
            reserved = torch.cuda.memory_reserved(i) / 1024**3
            total = torch.cuda.get_device_properties(i).total_memory / 1024**3
            free = total - allocated
            print(f"GPU {i}: {free:.2f} GB free / {total:.2f} GB total (allocated: {allocated:.2f} GB)")

if __name__ == '__main__':
        
    ###load here---------------------------
    df_train_vs=pd.read_csv("df_train_vs.csv") # index-train data
    df_test = pd.read_csv("./data/train.csv") # test data

    # Apply same preprocessing to test data
    print("Preprocessing test data...")
    df_test['prompt'] = df_test.apply(build_prompt, axis=1)
    df_test['prompt'] = df_test['prompt'].apply(cleaner)
    df_test['body'] = df_test['prompt']

    RETRIEVAL_K = 100  # Retrieve top 100 candidates
    RERANK_TOP_K = 4  # Rerank to get top 10
    USE_VLLM_RERANKER = True  # Set to True to use VLLM-based reranker
    ###------------------------------------

    print(df_train_vs.shape)

    # Initialize embedder with VLLM
    print("Loading Qwen3-Embedding model with VLLM...")
    embedder = FastEmbedder(
        model_name='Qwen/Qwen3-Embedding-0.6B',
        output_dim=1024  # Can experiment with 512, 1024, or full dimension
    )
    print(f"Embedder initialized on device: {embedder.device}")
    

    # Initialize reranker with VLLM
    if USE_VLLM_RERANKER:
        print("\nLoading Qwen3-Reranker model with VLLM...")
        reranker = Qwen3RerankerVLLM(
            model_name='Qwen/Qwen3-Reranker-0.6B',
            max_length=1024
        )
        print(f"Reranker (VLLM) initialized on device: {reranker.device}")
    else:
        print("Reranker disabled for this run")
        reranker = None


    # Initialize retriever and build index
    print("Initializing retriever with reranker...")
    retriever = RuleFilteredRetriever(embedder, reranker=reranker, use_gpu=True)

    # Define instruction for better retrieval
    instruction = "Given a rule and comment, retrieve similar training examples for classification"

    # Build index (this creates separate FAISS index for each rule)
    print("\nBuilding rule-filtered indices...")
    retriever.build_index(
        df_train_vs,
        rule_col='rule',
        comment_col='body',
        value_col='value',
        instruction=instruction
    )

    print("\n✓ Index building complete!")


    # Load test data & predict

    # Batch retrieval with reranking
    print(f"\nEvaluating with retrieval_k={RETRIEVAL_K}, rerank_top_k={RERANK_TOP_K}...")

    results_df = retriever.batch_retrieve(
        df_test,
        rule_col='rule',
        comment_col='body',
        top_k=RERANK_TOP_K,
        retrieval_k=RETRIEVAL_K,
        instruction=instruction,
        use_reranking=USE_VLLM_RERANKER,
        faiss_weight= 0.999,
        rerank_weight= 0.001
    )

    # Calculate metrics
    y_true = df_test["rule_violation"]
    y_pred = results_df["decision"].values
    f1 = f1_score(y_true, y_pred)
    print(f"  F1-Score: {f1:.4f}")
    print("num_retrieved unique:", results_df["num_retrieved"].unique())
    print(f"  Rule match types:")
    print(results_df['rule_match_type'].value_counts().to_string())
    print(f"  Reranked:")
    print(results_df['reranked'].value_counts().to_string())
    print("\n✓ Batch evaluation complete!")


    df_test["rule_violation"]=results_df["decision"].values
    df_test[["row_id","rule_violation"]].to_csv("submission.csv",index=False)
    print(df_test[["row_id","rule_violation"]].head(2))
    print(results_df.head(5))

Overwriting run_similarity_search.py


In [4]:
!python load_train_test_data.py

Total training examples: 2,029
Unique rules: 2

Columns: ['row_id', 'body', 'rule', 'subreddit', 'positive_example_1', 'positive_example_2', 'negative_example_1', 'negative_example_2', 'rule_violation', 'value']

Rule distribution statistics:
  Mean examples per rule: 1014.5
  Median examples per rule: 1014.5
  Min examples per rule: 1012
  Max examples per rule: 1017
[-1  1]
(10145, 5)
[-1  1]

Rule distribution statistics:
  Mean examples per rule: 5072.5
  Median examples per rule: 5072.5
  Min examples per rule: 5060
  Max examples per rule: 5085

Test sample rule distribution statistics:
  Mean examples per rule: 3.5
  Median examples per rule: 3.5
  Min examples per rule: 1
  Max examples per rule: 6
[100]
(7, 10)
(28, 5)
[ 1 -1]
value
 1    14
-1    14
Name: count, dtype: int64

Rule distribution statistics:
  Mean examples per rule: 14.0
  Median examples per rule: 14.0
  Min examples per rule: 4
  Max examples per rule: 24

Combined train shape (train + 60% test): (10173, 5)
A

In [5]:
!python run_similarity_search.py

INFO 10-19 23:49:29 [__init__.py:235] Automatically detected platform cuda.
Preprocessing test data...
(10173, 4)
Loading Qwen3-Embedding model with VLLM...
INFO 10-19 23:49:37 [config.py:538] Found sentence-transformers modules configuration.
INFO 10-19 23:49:37 [config.py:558] Found pooling configuration.
INFO 10-19 23:49:37 [config.py:1604] Using max model len 3072
INFO 10-19 23:49:37 [arg_utils.py:1551] (Enabling) chunked prefill by default
INFO 10-19 23:49:38 [config.py:2434] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 10-19 23:49:43 [__init__.py:235] Automatically detected platform cuda.
INFO 10-19 23:49:44 [core.py:572] Waiting for init message from front-end.
INFO 10-19 23:49:44 [core.py:71] Initializing a V1 LLM engine (v0.10.0) with config: model='Qwen/Qwen3-Embedding-0.6B', speculative_config=None, tokenizer='Qwen/Qwen3-Embedding-0.6B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, tru