# Load and Use Saved Index for Inference

This notebook demonstrates how to:
1. Load a previously saved rule-filtered retrieval index
2. Use it for fast inference on new data
3. Evaluate performance without rebuilding the index

**Note**: The index must be built first (see nb_11_rule_filtered_retrieval.ipynb)

## 1. Setup and Imports

In [1]:
%%writefile rule_filtered_retrieval.py

"""
Rule-Filtered Vector Similarity Retrieval System

This system implements a two-stage retrieval:
1. Filter by matching rules first
2. Retrieve top K similar examples from filtered chunks using vector similarity
"""

import numpy as np
import pandas as pd
import faiss
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
import pickle
import os
from pathlib import Path


class FastEmbedder:
    """Lightweight embedder using Qwen3-Embedding model"""
    def __init__(self, model_name='Qwen/Qwen3-Embedding-0.6B', output_dim=512):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float16
        )
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model.to(self.device)
        self.model.eval()
        self.output_dim = output_dim
        self.max_length = 1024

    def last_token_pool(self, last_hidden_states, attention_mask):
        """Pool using last token (EOS) as recommended for Qwen3"""
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_states.shape[0]
            return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]

    def encode(self, texts, batch_size=32, convert_to_numpy=True, instruction=None):
        """Encode texts to embeddings with optional instruction"""
        all_embeddings = []

        if instruction:
            texts = [f"Instruct: {instruction}\nQuery: {text}" for text in texts]

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i+batch_size]
            encoded = self.tokenizer(
                batch,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=self.max_length
            )
            encoded = {k: v.to(self.device) for k, v in encoded.items()}

            with torch.no_grad():
                outputs = self.model(**encoded)
                embeddings = self.last_token_pool(
                    outputs.last_hidden_state,
                    encoded['attention_mask']
                )
                embeddings = F.normalize(embeddings, p=2, dim=1)

                if self.output_dim and self.output_dim < embeddings.shape[1]:
                    embeddings = embeddings[:, :self.output_dim]
                    embeddings = F.normalize(embeddings, p=2, dim=1)

            all_embeddings.append(embeddings.cpu().float())

        embeddings = torch.cat(all_embeddings, dim=0)

        if convert_to_numpy:
            return embeddings.numpy().astype(np.float32)
        return embeddings


class RuleFilteredRetriever:
    """
    Two-stage retrieval system:
    1. Filter by rule category
    2. Retrieve top K similar examples from filtered set
    """

    def __init__(self, embedder: FastEmbedder, use_gpu: bool = True):
        self.embedder = embedder
        self.use_gpu = use_gpu and faiss.get_num_gpus() > 0

        # Store rule categories and their indices
        self.rule_to_indices = defaultdict(list)
        self.rule_to_faiss_index = {}
        self.rule_to_data = {}

        # Global storage
        self.all_data = None
        self.all_vectors = None

    def build_index(self, df: pd.DataFrame, rule_col: str = 'rule',
                   comment_col: str = 'test_comment',
                   value_col: str = 'value',
                   instruction: Optional[str] = None):
        """
        Build FAISS indices per rule category
        Uses ONLY comment embeddings for similarity search within each rule

        Args:
            df: DataFrame with rules, comments, and values
            rule_col: Column name for rules
            comment_col: Column name for comments
            value_col: Column name for target values
            instruction: Optional instruction for embeddings
        """
        print(f"Building rule-filtered index from {len(df)} examples...")
        print("Strategy: Filter by rule FIRST, then search by COMMENT similarity only")

        # Store data
        df = df.copy()
        self.all_data = df

        # Group by rule
        rule_groups = df.groupby(rule_col)
        print(f"Found {len(rule_groups)} unique rules")

        # Encode ONLY comments for similarity search (not rule+comment combined)
        print("Encoding comments only (not combined with rules)...")
        all_vectors = self.embedder.encode(
            df[comment_col].tolist(),
            instruction=instruction
        )
        faiss.normalize_L2(all_vectors)
        self.all_vectors = all_vectors

        # Build separate FAISS index for each rule
        print("Building per-rule FAISS indices...")
        for rule, group_df in rule_groups:
            indices = group_df.index.tolist()
            rule_vectors = all_vectors[indices]

            # Store mapping
            self.rule_to_indices[rule] = indices
            self.rule_to_data[rule] = group_df

            # Build FAISS index for this rule using ONLY comment vectors
            dimension = rule_vectors.shape[1]
            index = faiss.IndexFlatIP(dimension)

            if self.use_gpu:
                res = faiss.StandardGpuResources()
                index = faiss.index_cpu_to_gpu(res, 0, index)

            index.add(rule_vectors)
            self.rule_to_faiss_index[rule] = index

            print(f"  Rule '{rule[:50]}...' : {len(indices)} examples")

        print(f"\nIndex building complete!")
        print(f"Total rules: {len(self.rule_to_faiss_index)}")
        print(f"Total examples: {len(df)}")

    def retrieve(self, query_rule: str, query_comment: str,
                top_k: int = 10,
                instruction: Optional[str] = None,
                fallback_to_similar_rules: bool = True) -> Dict:
        """
        Two-stage retrieval:
        1. Filter by rule (exact or similar)
        2. Search by comment similarity ONLY within filtered rule

        Args:
            query_rule: The rule to match
            query_comment: The comment to find similar examples for
            top_k: Number of similar examples to retrieve
            instruction: Optional instruction for encoding
            fallback_to_similar_rules: If exact rule not found, use rule similarity

        Returns:
            Dict with retrieved examples and metadata
        """
        # STAGE 1: Filter by rule first
        # Check if exact rule exists
        if query_rule in self.rule_to_faiss_index:
            matched_rule = query_rule
            rule_match_type = "exact"
        elif fallback_to_similar_rules:
            # Find most similar rule using rule text similarity
            matched_rule, rule_match_type = self._find_similar_rule(query_rule, instruction)
        else:
            return {
                'success': False,
                'error': f'Rule not found: {query_rule}',
                'matched_rule': None,
                'examples': []
            }

        # STAGE 2: Encode ONLY the comment for similarity search
        query_vector = self.embedder.encode(
            [query_comment],  # Only comment, no rule!
            instruction=instruction
        )
        faiss.normalize_L2(query_vector)

        # Get FAISS index for matched rule (contains only comment embeddings)
        rule_index = self.rule_to_faiss_index[matched_rule]
        rule_data = self.rule_to_data[matched_rule]

        # Search in filtered index using comment similarity
        similarities, indices = rule_index.search(query_vector, min(top_k, len(rule_data)))

        # Get actual dataframe indices
        actual_indices = [self.rule_to_indices[matched_rule][i] for i in indices[0]]
        retrieved_examples = self.all_data.iloc[actual_indices]

        return {
            'success': True,
            'matched_rule': matched_rule,
            'rule_match_type': rule_match_type,
            'query_rule': query_rule,
            'query_comment': query_comment,
            'similarities': similarities[0].tolist(),
            'examples': retrieved_examples.to_dict('records'),
            'num_examples_in_rule': len(rule_data)
        }

    def _find_similar_rule(self, query_rule: str, instruction: Optional[str] = None) -> Tuple[str, str]:
        """Find the most similar rule to query_rule using embedding similarity"""
        # Encode query rule
        query_rule_vector = self.embedder.encode(
            [f"rule: {query_rule}"],
            instruction=instruction
        )
        faiss.normalize_L2(query_rule_vector)

        # Encode all unique rules
        unique_rules = list(self.rule_to_faiss_index.keys())
        rule_vectors = self.embedder.encode(
            [f"rule: {r}" for r in unique_rules],
            instruction=instruction
        )
        faiss.normalize_L2(rule_vectors)

        # Find most similar
        similarities = np.dot(query_rule_vector, rule_vectors.T)[0]
        best_idx = np.argmax(similarities)
        best_rule = unique_rules[best_idx]

        return best_rule, f"similar (similarity: {similarities[best_idx]:.3f})"

    def batch_retrieve(self, df_query: pd.DataFrame,
                      rule_col: str = 'rule',
                      comment_col: str = 'test_comment',
                      top_k: int = 10,
                      instruction: Optional[str] = None) -> pd.DataFrame:
        """
        Batch retrieval for multiple queries

        Args:
            df_query: DataFrame with query rules and comments
            rule_col: Column name for rules
            comment_col: Column name for comments
            top_k: Number of examples to retrieve per query
            instruction: Optional instruction for encoding

        Returns:
            DataFrame with predictions
        """
        results = []

        for idx, row in df_query.iterrows():
            query_rule = row[rule_col]
            query_comment = row[comment_col]

            # Retrieve examples
            retrieval_result = self.retrieve(
                query_rule=query_rule,
                query_comment=query_comment,
                top_k=top_k,
                instruction=instruction
            )

            if retrieval_result['success']:
                # Calculate weighted sum of similarity scores with values
                examples = retrieval_result['examples']
                similarities = retrieval_result['similarities']
                values = [ex['value'] for ex in examples]

                # Weighted sum: multiply each similarity by its value and sum
                weighted_sum = sum(sim * val for sim, val in zip(similarities, values))

                # Decision based on sign of weighted sum
                decision = 1 if weighted_sum >= 0 else 0
                avg_violation = weighted_sum  # Store weighted sum for inspection

                results.append({
                    'row_id': idx,
                    'rule_violation': avg_violation,
                    'decision': decision,
                    'matched_rule': retrieval_result['matched_rule'],
                    'rule_match_type': retrieval_result['rule_match_type'],
                    'num_retrieved': len(examples)
                })
            else:
                # Fallback: return neutral prediction
                results.append({
                    'row_id': idx,
                    'rule_violation': 0.5,
                    'decision': 0,
                    'matched_rule': None,
                    'rule_match_type': 'not_found',
                    'num_retrieved': 0
                })

        return pd.DataFrame(results)

    def save_index(self, save_dir: str):
        """
        Save the retriever indices and data to disk

        Args:
            save_dir: Directory to save the index files
        """
        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)

        print(f"Saving index to {save_dir}...")

        # Save FAISS indices (convert GPU to CPU first if needed)
        faiss_dir = save_path / "faiss_indices"
        faiss_dir.mkdir(exist_ok=True)

        for rule, index in self.rule_to_faiss_index.items():
            # Convert to CPU index if it's a GPU index
            if self.use_gpu:
                cpu_index = faiss.index_gpu_to_cpu(index)
            else:
                cpu_index = index

            # Create safe filename from rule
            safe_filename = f"index_{hash(rule) % 10**8}.faiss"
            faiss.write_index(cpu_index, str(faiss_dir / safe_filename))

        # Save metadata (rule mappings)
        metadata = {
            'rule_to_indices': dict(self.rule_to_indices),
            'rule_to_filename': {rule: f"index_{hash(rule) % 10**8}.faiss"
                                 for rule in self.rule_to_faiss_index.keys()},
            'use_gpu': self.use_gpu
        }

        with open(save_path / "metadata.pkl", 'wb') as f:
            pickle.dump(metadata, f)

        # Save training data
        if self.all_data is not None:
            self.all_data.to_parquet(save_path / "training_data.parquet", index=True)

        # Save embeddings
        if self.all_vectors is not None:
            np.save(save_path / "embeddings.npy", self.all_vectors)

        # Save rule_to_data mapping
        rule_data_mapping = {rule: df.index.tolist() for rule, df in self.rule_to_data.items()}
        with open(save_path / "rule_data_mapping.pkl", 'wb') as f:
            pickle.dump(rule_data_mapping, f)

        print(f"âœ“ Index saved successfully!")
        print(f"  - {len(self.rule_to_faiss_index)} FAISS indices")
        print(f"  - {len(self.all_data)} training examples")
        print(f"  - Embeddings shape: {self.all_vectors.shape}")

    def load_index(self, load_dir: str):
        """
        Load a previously saved index from disk

        Args:
            load_dir: Directory containing the saved index files
        """
        load_path = Path(load_dir)

        if not load_path.exists():
            raise FileNotFoundError(f"Index directory not found: {load_dir}")

        print(f"Loading index from {load_dir}...")

        # Load metadata
        with open(load_path / "metadata.pkl", 'rb') as f:
            metadata = pickle.load(f)

        self.rule_to_indices = defaultdict(list, metadata['rule_to_indices'])
        rule_to_filename = metadata['rule_to_filename']
        saved_use_gpu = metadata['use_gpu']

        # Load training data
        if (load_path / "training_data.parquet").exists():
            self.all_data = pd.read_parquet(load_path / "training_data.parquet")

        # Load embeddings
        if (load_path / "embeddings.npy").exists():
            self.all_vectors = np.load(load_path / "embeddings.npy")

        # Load rule_to_data mapping
        with open(load_path / "rule_data_mapping.pkl", 'rb') as f:
            rule_data_mapping = pickle.load(f)

        # Reconstruct rule_to_data from mapping
        self.rule_to_data = {}
        for rule, indices in rule_data_mapping.items():
            self.rule_to_data[rule] = self.all_data.loc[indices]

        # Load FAISS indices
        faiss_dir = load_path / "faiss_indices"
        self.rule_to_faiss_index = {}

        for rule, filename in rule_to_filename.items():
            # Load CPU index
            cpu_index = faiss.read_index(str(faiss_dir / filename))

            # Convert to GPU if requested and available
            if self.use_gpu and faiss.get_num_gpus() > 0:
                res = faiss.StandardGpuResources()
                gpu_index = faiss.index_cpu_to_gpu(res, 0, cpu_index)
                self.rule_to_faiss_index[rule] = gpu_index
            else:
                self.rule_to_faiss_index[rule] = cpu_index

        print(f"âœ“ Index loaded successfully!")
        print(f"  - {len(self.rule_to_faiss_index)} FAISS indices")
        print(f"  - {len(self.all_data)} training examples")
        print(f"  - Embeddings shape: {self.all_vectors.shape}")
        print(f"  - Using GPU: {self.use_gpu}")


def main_example():
    """Example usage of RuleFilteredRetriever"""

    # Load data
    print("Loading data...")
    base_path = "./data/final/zothers/"
    df_train = pd.read_csv(f"{base_path}rule_comment.csv")
    df_train = df_train.sample(n=5000, random_state=42).reset_index(drop=True)

    print(f"Training data shape: {df_train.shape}")
    print(f"Unique rules: {df_train['rule'].nunique()}")

    # Initialize embedder
    print("\nInitializing embedder...")
    embedder = FastEmbedder('Qwen/Qwen3-Embedding-0.6B', output_dim=1024)

    # Initialize retriever
    print("\nInitializing retriever...")
    retriever = RuleFilteredRetriever(embedder, use_gpu=True)

    # Build index
    instruction = "Given a rule and comment, retrieve similar training examples"
    retriever.build_index(
        df_train,
        rule_col='rule',
        comment_col='test_comment',
        value_col='value',
        instruction=instruction
    )

    # Test single query
    print("\n" + "="*80)
    print("Testing single query retrieval...")
    print("="*80)

    test_row = df_train.iloc[0]
    result = retriever.retrieve(
        query_rule=test_row['rule'],
        query_comment=test_row['test_comment'],
        top_k=5,
        instruction=instruction
    )

    print(f"\nQuery Rule: {result['query_rule'][:100]}...")
    print(f"Query Comment: {result['query_comment'][:100]}...")
    print(f"\nMatched Rule: {result['matched_rule'][:100]}...")
    print(f"Rule Match Type: {result['rule_match_type']}")
    print(f"Retrieved {len(result['examples'])} examples")
    print(f"\nTop 3 similarities: {result['similarities'][:3]}")

    # Test batch retrieval
    print("\n" + "="*80)
    print("Testing batch retrieval...")
    print("="*80)

    df_test = pd.read_csv("./data/final/df_test_cr_12.csv")
    df_test_sample = df_test.sample(n=100, random_state=42)

    results_df = retriever.batch_retrieve(
        df_test_sample,
        rule_col='rule',
        comment_col='test_comment',
        top_k=10,
        instruction=instruction
    )

    print(f"\nBatch retrieval results:")
    print(results_df.head(10))

    # Evaluate
    if 'violates_rule' in df_test_sample.columns:
        from sklearn.metrics import f1_score

        y_true = df_test_sample["violates_rule"].astype(str).str.strip().str.strip('"').str.strip("'").str.lower().map({"yes": 1, "no": 0})
        y_pred = results_df["decision"].values

        f1 = f1_score(y_true, y_pred)
        print(f"\nF1-Score: {f1:.4f}")

        # Match statistics
        print(f"\nRule matching statistics:")
        print(results_df['rule_match_type'].value_counts())

    return retriever, results_df


if __name__ == "__main__":
    retriever, results = main_example()


Overwriting rule_filtered_retrieval.py


In [2]:
import sys
sys.path.append('..')

import pandas as pd
import numpy as np
from rule_filtered_retrieval import FastEmbedder, RuleFilteredRetriever
from sklearn.metrics import f1_score, classification_report, confusion_matrix
from sklearn.metrics import accuracy_score, precision_score, recall_score
import time

## 2. Initialize Embedder and Retriever

**IMPORTANT**: Use the SAME model and output_dim as when building the index!

In [3]:
# Initialize embedder (must match the one used during index building)
print("Loading Qwen3-Embedding model...")
embedder = FastEmbedder(
    model_name='Qwen/Qwen3-Embedding-0.6B',
    output_dim=1024  # MUST match what was used during build
)

print(f"Embedder initialized on device: {embedder.device}")

Loading Qwen3-Embedding model...
Embedder initialized on device: cuda


In [4]:
# Initialize retriever
print("Initializing retriever...")
retriever = RuleFilteredRetriever(embedder, use_gpu=True)

print("âœ“ Retriever initialized (index not loaded yet)")

Initializing retriever...
âœ“ Retriever initialized (index not loaded yet)


## 3. Load Saved Index

In [5]:
# Path to saved index (choose one)
INDEX_DIR = "./saved_indices/rule_filtered_sample"  # Sample index (10K examples)
# INDEX_DIR = "./saved_indices/rule_filtered_full"  # Full index (250K examples)

print(f"Loading index from {INDEX_DIR}...")
start_time = time.time()

retriever.load_index(INDEX_DIR)

load_time = time.time() - start_time
print(f"\nâœ“ Index loaded in {load_time:.2f} seconds")
print(f"  Total rules: {len(retriever.rule_to_faiss_index)}")
print(f"  Total examples: {len(retriever.all_data)}")

Loading index from ./saved_indices/rule_filtered_sample...
Loading index from ./saved_indices/rule_filtered_sample...
âœ“ Index loaded successfully!
  - 6 FAISS indices
  - 251739 training examples
  - Embeddings shape: (251739, 1024)
  - Using GPU: True

âœ“ Index loaded in 1.32 seconds
  Total rules: 6
  Total examples: 251739


## 4. Test Single Query Retrieval

In [6]:
# Define instruction for retrieval
instruction = "Given a rule and comment, retrieve similar training examples for classification"

# Test with a custom example
test_rule = "No medical advice: do not offer or request specific medical advice."
test_comment = "You should take 500mg of ibuprofen every 4 hours for that headache."

print(f"Query Rule: {test_rule}")
print(f"\nQuery Comment: {test_comment}")

# Retrieve similar examples
result = retriever.retrieve(
    query_rule=test_rule,
    query_comment=test_comment,
    top_k=10,
    instruction=instruction
)

print(f"\n{'='*80}")
print("RETRIEVAL RESULTS")
print(f"{'='*80}")
print(f"Matched Rule Type: {result['rule_match_type']}")
print(f"Number of examples in this rule category: {result['num_examples_in_rule']}")
print(f"Retrieved: {len(result['examples'])} examples")

print(f"\nTop 5 Similarities:")
for i, (sim, ex) in enumerate(zip(result['similarities'][:5], result['examples'][:5])):
    print(f"  {i+1}. Similarity: {sim:.4f}, Value: {ex['value']:>2}, Comment: {ex['test_comment'][:80]}...")

# Calculate prediction
retrieved_values = [ex['value'] for ex in result['examples']]
predicted_value = np.mean(retrieved_values)
predicted_decision = 1 if predicted_value > 0 else 0

print(f"\nPredicted Value: {predicted_value:.3f}")
print(f"Predicted Decision: {'Violation' if predicted_decision == 1 else 'No Violation'}")

Query Rule: No medical advice: do not offer or request specific medical advice.

Query Comment: You should take 500mg of ibuprofen every 4 hours for that headache.

RETRIEVAL RESULTS
Matched Rule Type: similar (similarity: 0.964)
Number of examples in this rule category: 45965
Retrieved: 10 examples

Top 5 Similarities:
  1. Similarity: 0.8947, Value:  1, Comment: You should definitely take 500mg of ibuprofen every 6 hours for that TV-induced ...
  2. Similarity: 0.8942, Value:  1, Comment: You should definitely try taking 500mg of ibuprofen every 6 hours for that persi...
  3. Similarity: 0.8820, Value:  1, Comment: You should definitely start taking 500mg of ibuprofen every 6 hours for that hea...
  4. Similarity: 0.8729, Value:  1, Comment: You should definitely try taking 800mg of ibuprofen every 6 hours for that heada...
  5. Similarity: 0.8723, Value:  1, Comment: Dude you should definitely take 800mg of ibuprofen every 6 hours for that headac...

Predicted Value: 1.000
Predicted

## 5. Load Test Data

In [16]:
# Load test data
df_test = pd.read_csv("./data/final/df_train_0.csv")

print(f"Test set size: {len(df_test):,}")
print(f"Unique rules in test: {df_test['rule'].nunique():,}")

# Show sample
df_test.head()

Test set size: 2,029
Unique rules in test: 2


Unnamed: 0,row_id,body,rule,subreddit,positive_example_1,positive_example_2,negative_example_1,negative_example_2,rule_violation,violates_rule,test_comment
0,0,Banks don't want you to know this! Click here ...,"No Advertising: Spam, referral links, unsolici...",Futurology,If you could tell your younger self something ...,hunt for lady for jack off in neighbourhood ht...,Watch Golden Globe Awards 2017 Live Online in ...,"DOUBLE CEE x BANDS EPPS - ""BIRDS""\n\nDOWNLOAD/...",0,False,Banks don't want you to know this! Click here ...
1,1,SD Stream [ ENG Link 1] (http://www.sportsstre...,"No Advertising: Spam, referral links, unsolici...",soccerstreams,[I wanna kiss you all over! Stunning!](http://...,LOLGA.COM is One of the First Professional Onl...,#Rapper \nðŸš¨Straight Outta Cross Keys SC ðŸš¨YouTu...,[15 Amazing Hidden Features Of Google Search Y...,0,False,SD Stream [ ENG Link 1] (http://www.sportsstre...
2,2,Lol. Try appealing the ban and say you won't d...,No legal advice: Do not offer or request legal...,pcmasterrace,Don't break up with him or call the cops. If ...,It'll be dismissed: https://en.wikipedia.org/w...,Where is there a site that still works where y...,Because this statement of his is true. It isn'...,1,True,Lol. Try appealing the ban and say you won't d...
3,3,she will come your home open her legs with an...,"No Advertising: Spam, referral links, unsolici...",sex,Selling Tyrande codes for 3â‚¬ to paypal. PM. \n...,tight pussy watch for your cock get her at thi...,NSFW(obviously) http://spankbang.com/iy3u/vide...,Good News ::Download WhatsApp 2.16.230 APK for...,1,True,she will come your home open her legs with an...
4,4,code free tyrande --->>> [Imgur](http://i.imgu...,"No Advertising: Spam, referral links, unsolici...",hearthstone,wow!! amazing reminds me of the old days.Well...,seek for lady for sex in around http://p77.pl/...,must be watch movie https://sites.google.com/s...,We're streaming Pokemon Veitnamese Crystal RIG...,1,True,code free tyrande --->>> [Imgur](http://i.imgu...


## 6. Batch Inference on Test Set

In [17]:
# Run batch retrieval
TOP_K = 2

print(f"Running batch retrieval with top_k={TOP_K}...")
start_time = time.time()

results_df = retriever.batch_retrieve(
    df_test,
    rule_col='rule',
    comment_col='test_comment',
    top_k=TOP_K,
    instruction=instruction
)

inference_time = time.time() - start_time
print(f"\nâœ“ Batch inference complete in {inference_time:.2f} seconds")
print(f"  Average time per query: {inference_time/len(df_test)*1000:.2f} ms")

# Show sample results
results_df.head()

Running batch retrieval with top_k=4...

âœ“ Batch inference complete in 56.13 seconds
  Average time per query: 27.66 ms


Unnamed: 0,row_id,rule_violation,decision,matched_rule,rule_match_type,num_retrieved
0,0,-3.999986,0,"No Advertising: Spam, referral links, unsolici...",exact,4
1,1,-3.999986,0,"No Advertising: Spam, referral links, unsolici...",exact,4
2,2,-1.210768,0,No legal advice: Do not offer or request legal...,exact,4
3,3,3.999984,1,"No Advertising: Spam, referral links, unsolici...",exact,4
4,4,3.999987,1,"No Advertising: Spam, referral links, unsolici...",exact,4


In [18]:
#results_df["rule_violation"]=results_df["rule_violation"]/TOP_K
#results_df["rule_violation"]=results_df["rule_violation"]-results_df["rule_violation"].mean()
#results_df["decision"] = np.where(results_df["rule_violation"] > 0, 1, 0)

In [19]:
# Prepare ground truth
#y_true = df_test["violates_rule"].astype(str).str.strip().str.strip('"').str.strip("'").str.lower().map({"yes": 1, "no": 0})
y_true = df_test["rule_violation"].values
y_pred = results_df["decision"].values

print("="*80)
print("EVALUATION METRICS")
print("="*80)

# Calculate metrics
accuracy = accuracy_score(y_true, y_pred)
precision = precision_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred)

print(f"\nAccuracy:  {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall:    {recall:.4f}")
print(f"F1-Score:  {f1:.4f}")

print(f"\n{'='*80}")
print("CLASSIFICATION REPORT")
print(f"{'='*80}")
print(classification_report(y_true, y_pred, target_names=['No Violation', 'Violation']))

print(f"{'='*80}")
print("CONFUSION MATRIX")
print(f"{'='*80}")
cm = confusion_matrix(y_true, y_pred)
print(f"\n                 Predicted")
print(f"                 No  Yes")
print(f"Actual No      {cm[0,0]:4d} {cm[0,1]:4d}")
print(f"       Yes     {cm[1,0]:4d} {cm[1,1]:4d}")

print(f"\n{'='*80}")
print("RULE MATCH STATISTICS")
print(f"{'='*80}")
print(results_df['rule_match_type'].value_counts())

EVALUATION METRICS

Accuracy:  0.9019
Precision: 0.8925
Recall:    0.9176
F1-Score:  0.9048

CLASSIFICATION REPORT
              precision    recall  f1-score   support

No Violation       0.91      0.89      0.90       998
   Violation       0.89      0.92      0.90      1031

    accuracy                           0.90      2029
   macro avg       0.90      0.90      0.90      2029
weighted avg       0.90      0.90      0.90      2029

CONFUSION MATRIX

                 Predicted
                 No  Yes
Actual No       884  114
       Yes       85  946

RULE MATCH STATISTICS
rule_match_type
exact    2029
Name: count, dtype: int64
