<a href="https://colab.research.google.com/github/wesslen/rag-workflow/blob/main/03_unit_tests.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Clone the project repository; it contains the data/ folder holding the evaluation-results JSON read later in this notebook.
!git clone https://github.com/wesslen/rag-workflow.git

Cloning into 'rag-workflow'...
remote: Enumerating objects: 44, done.[K
remote: Counting objects: 100% (44/44), done.[K
remote: Compressing objects: 100% (36/36), done.[K
remote: Total 44 (delta 15), reused 0 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (44/44), 69.86 KiB | 687.00 KiB/s, done.
Resolving deltas: 100% (15/15), done.


In [3]:
import os
from google.colab import userdata

# Copy the Colab secrets API_KEY and BASE_URL into same-named environment variables.
for _secret_name in ('API_KEY', 'BASE_URL'):
    os.environ[_secret_name] = userdata.get(_secret_name)

In [5]:
import pandas as pd
import json
from typing import List, Dict, Any
from datetime import datetime
import time

class TestCaseGenerator:
    def __init__(self, eval_results_path: str):
        self.results_path = eval_results_path
        self.results = self._load_results()
        self.hit_data = self._process_hit_data()

    def _load_results(self) -> Dict[str, Any]:
        with open(self.results_path, 'r') as f:
            return json.load(f)

    def _process_hit_data(self) -> pd.DataFrame:
        return pd.DataFrame({
            'query': self.results['queries'],
            'chunk': self.results['original_chunks'],
            'hit': self.results['hit_at_k'],
            'retrieval_time': self.results['retrieval_times']
        })

    def _generate_test_case(self, idx: int, row: pd.Series) -> str:
        query = row['query'].replace("'", "\\'").replace("\n", " ")
        expected_chunk = row['chunk'].replace("'", "\\'").replace("\n", " ")
        avg_time = row['retrieval_time']

        return f'''
def test_retrieval_{idx}(rag_engine):
    """Test retrieval for query: {query}"""
    query = '{query}'
    expected_chunk = '{expected_chunk}'

    # Time the retrieval
    start_time = time.time()
    results = rag_engine.retrieve(query)
    retrieval_time = time.time() - start_time

    # Convert results to text for comparison
    retrieved_texts = [node.text for node in results]

    # Normalize texts for comparison
    normalized_expected = normalize_text(expected_chunk)
    normalized_retrieved = [normalize_text(text) for text in retrieved_texts]

    # Assert the expected chunk is in the results
    assert any(normalized_expected in text for text in normalized_retrieved), \\
        f"Expected chunk not found in retrieval results"

    # Assert retrieval time is within acceptable range (2x baseline)
    assert retrieval_time < {avg_time * 2}, \\
        f"Retrieval took too long: {{retrieval_time:.2f}}s > {avg_time * 2:.2f}s"
'''

    def _generate_test_file_content(self, test_cases: List[str]) -> str:
        header = f'''"""
Automatically generated test cases for RAG retrieval.
Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
"""
import pytest
import time
from typing import List

def normalize_text(text: str) -> str:
    """Normalize text for comparison."""
    return ' '.join(text.lower().split())

# Performance test constants
TEST_QUERIES = [
    "What are the requirements for medical licenses?",
    "How should employees handle confidential information?",
    "What are the policies on conflicts of interest?"
]

def test_retrieval_performance(rag_engine):
    """Test overall retrieval performance."""
    times = []
    for query in TEST_QUERIES:
        start_time = time.time()
        results = rag_engine.retrieve(query)
        retrieval_time = time.time() - start_time
        times.append(retrieval_time)

        assert len(results) > 0, f"No results returned for query: {{query}}"

    avg_time = sum(times) / len(times)
    assert avg_time < 2.0, f"Average retrieval time too high: {{avg_time:.2f}}s"

{test_cases}
'''
        return header

    def generate_test_file(self, output_path: str = 'test_rag_retrieval.py'):
        successful_cases = self.hit_data[self.hit_data['hit'] == 1]
        test_cases = []

        for idx, row in successful_cases.iterrows():
            test_case = self._generate_test_case(idx, row)
            test_cases.append(test_case)

        test_file_content = self._generate_test_file_content('\n'.join(test_cases))

        with open(output_path, 'w') as f:
            f.write(test_file_content)

        print(f"Generated {len(test_cases)} test cases in {output_path}")
        return test_file_content


def analyze_test_coverage(test_generator: TestCaseGenerator) -> Dict[str, Any]:
    """Summarize query hit coverage and retrieval-latency statistics.

    Reads ``test_generator.hit_data`` and computes:

    * how many queries were hits (``hit == 1``) versus misses;
    * the hit percentage (0 when there are no queries);
    * mean/max/min/std of ``retrieval_time``.

    Args:
        test_generator: Generator whose ``hit_data`` frame is analyzed.

    Returns:
        Dict of the computed figures. The same figures are also printed as a
        human-readable report.
    """
    frame = test_generator.hit_data
    n_total = len(frame)
    n_hits = len(frame[frame['hit'] == 1])
    latencies = frame['retrieval_time']

    if n_total > 0:
        hit_pct = (n_hits / n_total) * 100
    else:
        hit_pct = 0

    coverage_stats = {
        'total_queries': n_total,
        'successful_queries': n_hits,
        'failed_queries': n_total - n_hits,
        'coverage_percentage': hit_pct,
        'avg_retrieval_time': latencies.mean(),
        'max_retrieval_time': latencies.max(),
        'min_retrieval_time': latencies.min(),
        'std_retrieval_time': latencies.std(),
    }

    # (label, stats key, format spec, unit suffix) for each report line.
    report_layout = (
        ("Total Queries", 'total_queries', '', ''),
        ("Successful Queries", 'successful_queries', '', ''),
        ("Failed Queries", 'failed_queries', '', ''),
        ("Coverage Percentage", 'coverage_percentage', '.2f', '%'),
        ("Average Retrieval Time", 'avg_retrieval_time', '.3f', 's'),
        ("Max Retrieval Time", 'max_retrieval_time', '.3f', 's'),
        ("Min Retrieval Time", 'min_retrieval_time', '.3f', 's'),
        ("Std Dev Retrieval Time", 'std_retrieval_time', '.3f', 's'),
    )

    print("\nTest Coverage Analysis:")
    for label, key, spec, unit in report_layout:
        print(f"{label}: {format(coverage_stats[key], spec)}{unit}")

    return coverage_stats

def generate_performance_report(coverage_stats: Dict[str, Any]) -> pd.DataFrame:
    """Tabulate the headline latency and coverage figures as display strings.

    Args:
        coverage_stats: Mapping with ``avg_retrieval_time``,
            ``max_retrieval_time``, ``min_retrieval_time``,
            ``std_retrieval_time`` and ``coverage_percentage`` entries.

    Returns:
        A two-column DataFrame (``Metric``, ``Value``) with one row per
        figure. Times are formatted to 3 decimals with an ``s`` suffix;
        coverage is formatted to 2 decimals with a ``%`` suffix.
    """
    def _seconds(key: str) -> str:
        return f"{coverage_stats[key]:.3f}s"

    rows = [
        ('Average Retrieval Time', _seconds('avg_retrieval_time')),
        ('Maximum Retrieval Time', _seconds('max_retrieval_time')),
        ('Minimum Retrieval Time', _seconds('min_retrieval_time')),
        ('Standard Deviation', _seconds('std_retrieval_time')),
        ('Coverage Percentage', f"{coverage_stats['coverage_percentage']:.2f}%"),
    ]
    return pd.DataFrame(rows, columns=['Metric', 'Value'])

In [6]:
# Build a pytest module from the saved evaluation results of the cloned repo.
EVAL_RESULTS_PATH = '/content/rag-workflow/data/text.eval_results.json'
generator = TestCaseGenerator(EVAL_RESULTS_PATH)
test_content = generator.generate_test_file()
print("Tests generated successfully!")

Generated 174 test cases in test_rag_retrieval.py
Tests generated successfully!


In [7]:
# Import explicitly: `display` is only an implicit IPython builtin, and `HTML`
# is not predefined, so display(HTML(...)) needs this import to run.
from IPython.display import HTML, display

# Analyze coverage (prints a summary and returns the stats dict)
coverage_stats = analyze_test_coverage(generator)

# Generate and display performance report as an index-free HTML table
performance_df = generate_performance_report(coverage_stats)
display(HTML(performance_df.to_html(index=False)))


Test Coverage Analysis:
Total Queries: 220
Successful Queries: 174
Failed Queries: 46
Coverage Percentage: 79.09%
Average Retrieval Time: 0.845s
Max Retrieval Time: 1.055s
Min Retrieval Time: 0.799s
Std Dev Retrieval Time: 0.035s


Metric,Value
Average Retrieval Time,0.845s
Maximum Retrieval Time,1.055s
Minimum Retrieval Time,0.799s
Standard Deviation,0.035s
Coverage Percentage,79.09%


# Ideas to extend tests and RAG evaluation workflow

Based on ideas from Jason Liu's [The RAG Playbook](https://jxnl.co/writing/2024/08/19/rag-flywheel/) and [Systematically Improving your RAG](https://jxnl.co/writing/2024/05/22/systematically-improving-your-rag/).

## Synthetic Data Generation

Generate more synthetic queries using these patterns from the articles above:
- Questions about each text chunk
- System limitations queries
- Edge case queries

Example implementation:
```python
def generate_synthetic_queries(chunk):
    queries = [
        f"What does the article say about {chunk[:50]}?", # Direct
        f"Can you explain the concept of {chunk[:50]}?",  # Conceptual
        f"What are examples of {chunk[:50]}?"  # Examples
    ]
    return queries
```

## Experiment Ideas

### 1. Try Different Chunking Strategies
```python
# Recursive (hierarchical) chunking - nested parent/child chunks at several sizes
from llama_index.core.node_parser import HierarchicalNodeParser
parser = HierarchicalNodeParser.from_defaults(chunk_sizes=[2048, 512, 128])

# Semantic chunking - group by topic similarity
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('all-MiniLM-L6-v2')
```

### 2. Embedding Models
Try:
- all-MiniLM-L6-v2 (fast baseline)
- all-mpnet-base-v2 (better quality)
- BGE-M3 (optimized for RAG)

### 3. Topic Modeling & Clustering

Use [BERTopic](https://maartengr.github.io/BERTopic/index.html) to discover topic clusters in the documents:

```python
from bertopic import BERTopic
topic_model = BERTopic()
topics, _ = topic_model.fit_transform(documents)

# Key topics from article:
# - RAG System Implementation
# - Metrics & Evaluation
# - Real-world Data Analysis
# - Continuous Improvement
```

### 4. Evaluation Pipeline
```python
def evaluate_retrieval(index, test_cases):
    metrics = {
        'precision': [],
        'recall': [],
        'latency': []
    }
    
    for query, expected in test_cases:
        results = index.as_retriever().retrieve(query)  # retrieval only, no LLM answer generation
        # Add precision/recall calculation
        # Add latency tracking
        
    return metrics
```

## Experimental Matrix

Test combinations of:
- Chunk sizes: [128, 256, 512] tokens
- Overlap ratios: [0.1, 0.2, 0.3]  
- Embedding models: [MiniLM, MPNet, BGE]
- Retrieval k: [2, 3, 5]

Track:
- Precision/recall per topic cluster
- Retrieval latency
- Index size
- Memory usage

## User Feedback Integration

Once you start collecting real user feedback, combine it with the synthetic data:

- Where does synthetic data extend coverage beyond what users have actually asked?
- Where does user feedback fill gaps in the synthetic data?

```python
from sklearn.cluster import KMeans

def analyze_user_queries(queries):
    # Cluster similar questions
    embeddings = model.encode(queries)
    clusters = KMeans(n_clusters=5).fit_predict(embeddings)
    
    # Analyze patterns within clusters
    # Map to article's topic categories
    return cluster_analysis
```