In [3]:
import itertools
import json
import time
import pandas as pd
from typing import Dict, List, Any
import numpy as np
from pathlib import Path
import re

# Import your refactored RAG system
from rag_script3 import create_rag_instance, DEFAULT_CONFIG

# TEST CONFIGURATIONS
PARAMETER_GRID = {
    "temperature": [0.1, 0.3, 0.5, 0.7],
    "top_p": [0.7, 0.8, 0.9, 0.95],
    "top_k_llm": [5, 10, 20, 40],  # LLM top_k
    "top_k_retrieval": [2, 3, 5, 8],  # Retrieval top_k
    "min_similarity": [0.2, 0.25, 0.3, 0.35],
    "max_tokens": [200, 300, 500, 800],
    "mirostat_eta": [0.1, 0.3, 0.5, 0.8],  # Learning rate for Mirostat algorithm
    "mirostat_tau": [3.0, 5.0, 7.0, 10.0]  # Target entropy for Mirostat algorithm
}

# TEST QUERIES - Add more diverse queries for better testing
TEST_QUERIES = [
    "What are the key factors to consider when analyzing a potential real estate investment?",
    "How do I calculate cap rates for rental properties?",
    "What are the risks of investing in commercial real estate?",
    "Explain the difference between gross and net rental yields",
    "What is the 1% rule in real estate investing?"
]

class QualityEvaluator:
    """Evaluate response quality across multiple dimensions."""
    
    def __init__(self):
        # Keywords that might indicate hallucination or uncertainty
        self.uncertainty_indicators = [
            "i don't know", "i'm not sure", "uncertain", "unclear",
            "might be", "could be", "possibly", "perhaps", "maybe"
        ]
        
        # Keywords that indicate good source usage
        self.source_indicators = [
            "according to", "based on", "as mentioned", "the document states",
            "from the source", "referenced", "cited", "as shown in"
        ]
    
    def evaluate_relevance(self, query: str, answer: str) -> float:
        """
        Evaluate how relevant the answer is to the query.
        Uses keyword overlap and topic alignment.
        """
        # Extract key terms from query
        query_words = set(re.findall(r'\w+', query.lower()))
        answer_words = set(re.findall(r'\w+', answer.lower()))
        
        # Remove common stop words
        stop_words = {"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by", "from", "up", "about", "into", "through", "during", "before", "after", "above", "below", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had", "do", "does", "did", "will", "would", "could", "should", "may", "might", "must", "shall", "can"}
        query_words -= stop_words
        answer_words -= stop_words
        
        if not query_words:
            return 0.5  # Neutral if no meaningful query words
        
        # Calculate overlap ratio
        overlap = len(query_words.intersection(answer_words))
        relevance_score = min(overlap / len(query_words), 1.0)
        
        # Bonus for topic-specific terms based on real estate context
        real_estate_terms = {"real estate", "property", "investment", "cap rate", "rental", "yield", "roi", "cash flow", "market", "valuation", "appreciation"}
        re_overlap = len(real_estate_terms.intersection(answer_words))
        if re_overlap > 0:
            relevance_score = min(relevance_score + 0.1 * re_overlap, 1.0)
        
        return relevance_score
    
    def evaluate_factual_accuracy(self, answer: str, sources: List[str]) -> float:
        """
        Evaluate factual accuracy based on confidence indicators and source alignment.
        This is a heuristic approach - true accuracy would need ground truth.
        """
        answer_lower = answer.lower()
        
        # Penalty for uncertainty indicators
        uncertainty_count = sum(1 for indicator in self.uncertainty_indicators if indicator in answer_lower)
        uncertainty_penalty = min(uncertainty_count * 0.1, 0.3)
        
        # Bonus for specific numbers/facts (indicates concrete information)
        number_matches = len(re.findall(r'\b\d+\.?\d*%?\b', answer))
        specificity_bonus = min(number_matches * 0.05, 0.2)
        
        # Base score starts neutral
        accuracy_score = 0.7 - uncertainty_penalty + specificity_bonus
        
        return max(0.0, min(accuracy_score, 1.0))
    
    def evaluate_coherence(self, answer: str) -> float:
        """
        Evaluate readability and coherence of the response.
        """
        if not answer or len(answer.strip()) < 10:
            return 0.0
        
        sentences = re.split(r'[.!?]+', answer)
        sentences = [s.strip() for s in sentences if s.strip()]
        
        if not sentences:
            return 0.0
        
        # Average sentence length (ideal range: 15-25 words)
        avg_sentence_length = np.mean([len(s.split()) for s in sentences])
        length_score = 1.0 - abs(avg_sentence_length - 20) / 20
        length_score = max(0.3, min(length_score, 1.0))
        
        # Sentence count (too few or too many can hurt coherence)
        sentence_count = len(sentences)
        if sentence_count < 2:
            count_penalty = 0.2
        elif sentence_count > 10:
            count_penalty = 0.1
        else:
            count_penalty = 0.0
        
        # Check for repetitive patterns
        words = answer.lower().split()
        word_freq = {}
        for word in words:
            if len(word) > 4:  # Only consider longer words
                word_freq[word] = word_freq.get(word, 0) + 1
        
        # Penalty for excessive repetition
        max_freq = max(word_freq.values()) if word_freq else 1
        repetition_penalty = min((max_freq - 3) * 0.05, 0.2) if max_freq > 3 else 0
        
        coherence_score = length_score - count_penalty - repetition_penalty
        return max(0.0, min(coherence_score, 1.0))
    
    def evaluate_source_usage(self, answer: str, sources: List[str]) -> float:
        """
        Evaluate how well the response uses retrieved sources.
        """
        if not sources:
            return 0.0  # No sources available
        
        answer_lower = answer.lower()
        
        # Check for source integration indicators
        source_integration = sum(1 for indicator in self.source_indicators if indicator in answer_lower)
        integration_score = min(source_integration * 0.2, 0.6)
        
        # Length suggests more detailed use of sources
        answer_length = len(answer.split())
        length_factor = min(answer_length / 100, 1.0) * 0.3  # Up to 0.3 bonus for longer answers
        
        # Base score for having sources
        base_score = 0.1
        
        source_score = base_score + integration_score + length_factor
        return min(source_score, 1.0)
    
    def detect_hallucination(self, answer: str, query: str, sources: List[str]) -> float:
        """
        Detect potential hallucinations (lower score = more hallucination detected).
        This is heuristic-based.
        """
        answer_lower = answer.lower()
        
        # Red flags for hallucination
        red_flags = [
            "i remember", "i recall", "i know from experience", "i've seen",
            "in my experience", "personally", "i believe", "i think",
            "as far as i know", "from what i understand"
        ]
        
        hallucination_flags = sum(1 for flag in red_flags if flag in answer_lower)
        
        # Very specific claims without source attribution might be hallucinations
        specific_claims = len(re.findall(r'\b\d{4}\b', answer))  # Years
        specific_claims += len(re.findall(r'\$[\d,]+', answer))  # Dollar amounts
        specific_claims += len(re.findall(r'\b\d+\.?\d*%\b', answer))  # Percentages
        
        # If many specific claims but no source indicators, potential hallucination
        source_attribution = sum(1 for indicator in self.source_indicators if indicator in answer_lower)
        
        if specific_claims > 2 and source_attribution == 0:
            specificity_penalty = 0.3
        else:
            specificity_penalty = 0.0
        
        # Score (1.0 = no hallucination detected, 0.0 = high hallucination)
        hallucination_score = 1.0 - (hallucination_flags * 0.2) - specificity_penalty
        return max(0.0, min(hallucination_score, 1.0))
    
    def evaluate_response(self, query: str, answer: str, sources: List[str]) -> Dict[str, float]:
        """Evaluate all quality dimensions."""
        return {
            "relevance": self.evaluate_relevance(query, answer),
            "factual_accuracy": self.evaluate_factual_accuracy(answer, sources),
            "coherence": self.evaluate_coherence(answer),
            "source_usage": self.evaluate_source_usage(answer, sources),
            "hallucination_resistance": self.detect_hallucination(answer, query, sources),
            "response_length": len(answer.split())
        }

class QualityGridTester:
    """Grid testing focused on response quality metrics."""
    
    def __init__(self, output_dir: str = "quality_grid_results"):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True)
        self.results = []
        self.rag_instance = None
        self.evaluator = QualityEvaluator()
        
    def generate_configurations(self, max_configs: int = 50) -> List[Dict]:
        """Generate parameter combinations."""
        print("🔧 Generating parameter combinations...")
        
        # Get all combinations
        keys = list(PARAMETER_GRID.keys())
        values = list(PARAMETER_GRID.values())
        all_combinations = list(itertools.product(*values))
        
        print(f"Total possible combinations: {len(all_combinations)}")
        
        # Limit if too many
        if len(all_combinations) > max_configs:
            print(f"Limiting to {max_configs} random combinations")
            np.random.seed(42)  # Reproducible
            indices = np.random.choice(len(all_combinations), max_configs, replace=False)
            all_combinations = [all_combinations[i] for i in indices]
        
        # Convert to config dicts
        configs = []
        for combo in all_combinations:
            config = DEFAULT_CONFIG.copy()
            
            # Map parameters to config structure
            config["llm_options"] = config["llm_options"].copy()
            config["llm_options"]["temperature"] = combo[keys.index("temperature")]
            config["llm_options"]["top_p"] = combo[keys.index("top_p")]
            config["llm_options"]["top_k"] = combo[keys.index("top_k_llm")]
            config["llm_options"]["num_predict"] = combo[keys.index("max_tokens")]
            config["llm_options"]["mirostat_eta"] = combo[keys.index("mirostat_eta")]
            config["llm_options"]["mirostat_tau"] = combo[keys.index("mirostat_tau")]
            
            config["top_k"] = combo[keys.index("top_k_retrieval")]
            config["min_similarity"] = combo[keys.index("min_similarity")]
            
            configs.append(config)
        
        print(f"✅ Generated {len(configs)} configurations")
        return configs
    
    def setup_rag(self) -> bool:
        """Setup RAG instance once."""
        print("🚀 Setting up RAG system...")
        self.rag_instance = create_rag_instance()
        self.rag_instance.set_verbose(False)  # Quiet mode for testing
        
        if not self.rag_instance.setup():
            print("❌ RAG setup failed")
            return False
        
        print("✅ RAG system ready")
        return True
    
    def test_configuration(self, config: Dict, config_id: int) -> Dict:
        """Test a single configuration."""
        print(f"Testing config {config_id}... ", end="")
        
        # Update RAG with new config
        self.rag_instance.update_config(config)
        self.rag_instance.reset_stats()
        self.rag_instance.clear_cache()  # Fresh start for each config
        
        results = []
        
        # Test all queries
        for query in TEST_QUERIES:
            try:
                result = self.rag_instance.ask(query, use_cache=False)
                
                # Extract source content for evaluation
                source_texts = [source.get("content", "") for source in result.get("sources", [])]
                
                # Evaluate quality
                quality_scores = self.evaluator.evaluate_response(
                    query, 
                    result["answer"], 
                    source_texts
                )
                
                results.append({
                    "query": query,
                    "answer": result["answer"],
                    "sources_count": len(result.get("sources", [])),
                    "quality_scores": quality_scores,
                    "success": result["error"] is None
                })
                
            except Exception as e:
                results.append({
                    "query": query,
                    "answer": f"Exception: {str(e)}",
                    "sources_count": 0,
                    "quality_scores": {
                        "relevance": 0.0,
                        "factual_accuracy": 0.0,
                        "coherence": 0.0,
                        "source_usage": 0.0,
                        "hallucination_resistance": 0.0,
                        "response_length": 0
                    },
                    "success": False
                })
        
        # Calculate aggregate metrics
        successful_results = [r for r in results if r["success"]]
        if successful_results:
            avg_relevance = np.mean([r["quality_scores"]["relevance"] for r in successful_results])
            avg_factual_accuracy = np.mean([r["quality_scores"]["factual_accuracy"] for r in successful_results])
            avg_coherence = np.mean([r["quality_scores"]["coherence"] for r in successful_results])
            avg_source_usage = np.mean([r["quality_scores"]["source_usage"] for r in successful_results])
            avg_hallucination_resistance = np.mean([r["quality_scores"]["hallucination_resistance"] for r in successful_results])
            avg_response_length = np.mean([r["quality_scores"]["response_length"] for r in successful_results])
            
            # Overall quality score (weighted average)
            overall_quality = (
                avg_relevance * 0.25 +
                avg_factual_accuracy * 0.25 +
                avg_coherence * 0.2 +
                avg_source_usage * 0.15 +
                avg_hallucination_resistance * 0.15
            )
        else:
            avg_relevance = avg_factual_accuracy = avg_coherence = 0.0
            avg_source_usage = avg_hallucination_resistance = avg_response_length = 0.0
            overall_quality = 0.0
        
        config_result = {
            "config_id": config_id,
            "config": config,
            "quality_metrics": {
                "avg_relevance": avg_relevance,
                "avg_factual_accuracy": avg_factual_accuracy,
                "avg_coherence": avg_coherence,
                "avg_source_usage": avg_source_usage,
                "avg_hallucination_resistance": avg_hallucination_resistance,
                "avg_response_length": avg_response_length,
                "overall_quality": overall_quality,
                "successful_queries": len(successful_results)
            },
            "detailed_results": results
        }
        
        print(f"✅ Quality: {overall_quality:.3f}, Relevance: {avg_relevance:.3f}")
        return config_result
    
    def run_grid_test(self, max_configs: int = 20) -> str:
        """Run full grid test."""
        print("🧪 Starting Quality-Focused Grid Test")
        print("="*50)
        
        if not self.setup_rag():
            return ""
        
        configs = self.generate_configurations(max_configs)
        
        print(f"\n🏃 Running {len(configs)} configurations...")
        print(f"📝 Testing {len(TEST_QUERIES)} queries per config")
        print("🎯 Quality Metrics: Relevance, Accuracy, Coherence, Source Usage, Hallucination Resistance")
        print("-"*50)
        
        for i, config in enumerate(configs, 1):
            result = self.test_configuration(config, i)
            self.results.append(result)
            
            # Save intermediate results
            if i % 5 == 0:
                self.save_results(f"intermediate_{i}")
        
        # Final save and analysis
        results_file = self.save_results("final")
        self.analyze_results()
        
        print("\n🎉 Quality grid test completed!")
        return results_file
    
    def save_results(self, suffix: str = "") -> str:
        """Save results to files."""
        timestamp = int(time.time())
        base_name = f"quality_grid_test_{timestamp}_{suffix}"
        
        # Save raw results
        json_file = self.output_dir / f"{base_name}.json"
        with open(json_file, 'w') as f:
            json.dump(self.results, f, indent=2)
        
        # Save metrics summary
        metrics_data = []
        for result in self.results:
            row = {"config_id": result["config_id"]}
            row.update(result["quality_metrics"])
            
            # Add key config parameters
            config = result["config"]
            row.update({
                "temperature": config["llm_options"]["temperature"],
                "top_p": config["llm_options"]["top_p"],
                "llm_top_k": config["llm_options"]["top_k"],
                "retrieval_top_k": config["top_k"],
                "min_similarity": config["min_similarity"],
                "max_tokens": config["llm_options"]["num_predict"],
                "mirostat_eta": config["llm_options"]["mirostat_eta"],
                "mirostat_tau": config["llm_options"]["mirostat_tau"]
            })
            
            metrics_data.append(row)
        
        df = pd.DataFrame(metrics_data)
        csv_file = self.output_dir / f"{base_name}_metrics.csv"
        df.to_csv(csv_file, index=False)
        
        return str(json_file)
    
    def analyze_results(self):
        """Analyze and print results."""
        if not self.results:
            return
        
        print("\n📊 QUALITY ANALYSIS")
        print("="*50)
        
        # Create metrics DataFrame
        metrics_data = []
        for result in self.results:
            row = {"config_id": result["config_id"]}
            row.update(result["quality_metrics"])
            row.update({
                "temperature": result["config"]["llm_options"]["temperature"],
                "top_p": result["config"]["llm_options"]["top_p"],
                "llm_top_k": result["config"]["llm_options"]["top_k"],
                "retrieval_top_k": result["config"]["top_k"],
                "min_similarity": result["config"]["min_similarity"],
                "max_tokens": result["config"]["llm_options"]["num_predict"],
                "mirostat_eta": result["config"]["llm_options"]["mirostat_eta"],
                "mirostat_tau": result["config"]["llm_options"]["mirostat_tau"]
            })
            metrics_data.append(row)
        
        df = pd.DataFrame(metrics_data)
        
        # Overall statistics
        print(f"🔢 Tested {len(self.results)} configurations")
        print(f"🎯 Average overall quality: {df['overall_quality'].mean():.3f}")
        print(f"📊 Average relevance: {df['avg_relevance'].mean():.3f}")
        print(f"✅ Average factual accuracy: {df['avg_factual_accuracy'].mean():.3f}")
        print(f"📝 Average coherence: {df['avg_coherence'].mean():.3f}")
        print(f"📚 Average source usage: {df['avg_source_usage'].mean():.3f}")
        print(f"🛡️  Average hallucination resistance: {df['avg_hallucination_resistance'].mean():.3f}")
        print(f"📏 Average response length: {df['avg_response_length'].mean():.1f} words")
        
        # Best configurations
        print("\n🏆 TOP 5 HIGHEST OVERALL QUALITY:")
        best_overall = df.nlargest(5, 'overall_quality')[['config_id', 'overall_quality', 'avg_relevance', 'avg_factual_accuracy', 'avg_coherence', 'temperature', 'top_p']]
        print(best_overall.to_string(index=False))
        
        print("\n🎯 TOP 5 MOST RELEVANT:")
        best_relevance = df.nlargest(5, 'avg_relevance')[['config_id', 'avg_relevance', 'overall_quality', 'temperature', 'retrieval_top_k', 'min_similarity']]
        print(best_relevance.to_string(index=False))
        
        print("\n✅ TOP 5 MOST FACTUALLY ACCURATE:")
        best_accuracy = df.nlargest(5, 'avg_factual_accuracy')[['config_id', 'avg_factual_accuracy', 'overall_quality', 'temperature', 'mirostat_eta', 'mirostat_tau']]
        print(best_accuracy.to_string(index=False))
        
        print("\n📚 TOP 5 BEST SOURCE USAGE:")
        best_sources = df.nlargest(5, 'avg_source_usage')[['config_id', 'avg_source_usage', 'overall_quality', 'retrieval_top_k', 'min_similarity', 'max_tokens']]
        print(best_sources.to_string(index=False))
        
        # Parameter correlations
        print(f"\n📈 PARAMETER CORRELATIONS WITH OVERALL QUALITY:")
        quality_correlations = df[['overall_quality', 'temperature', 'top_p', 'llm_top_k', 'retrieval_top_k', 'min_similarity', 'max_tokens', 'mirostat_eta', 'mirostat_tau']].corr()['overall_quality'].abs().sort_values(ascending=False)
        for param, corr in quality_correlations.items():
            if param != 'overall_quality' and abs(corr) > 0.1:
                print(f"  {param}: {corr:.3f}")


def main():
    """Run the quality-focused grid test."""
    print("🧪 RAG Quality Configuration Grid Test")
    print("="*50)
    
    # Configuration
    max_configs = 12  # Adjust based on your time budget
    
    # Run test
    tester = QualityGridTester()
    results_file = tester.run_grid_test(max_configs)
    
    print(f"\n📁 Results saved to: {results_file}")
    print("🔍 Check the CSV file for detailed quality metrics")

if __name__ == "__main__":
    main()

🧪 RAG Quality Configuration Grid Test
🧪 Starting Quality-Focused Grid Test
🚀 Setting up RAG system...
✅ RAG system ready
🔧 Generating parameter combinations...
Total possible combinations: 65536
Limiting to 12 random combinations
✅ Generated 12 configurations

🏃 Running 12 configurations...
📝 Testing 5 queries per config
🎯 Quality Metrics: Relevance, Accuracy, Coherence, Source Usage, Hallucination Resistance
--------------------------------------------------
Testing config 1... ✅ Quality: 0.775, Relevance: 0.994
Testing config 2... ✅ Quality: 0.812, Relevance: 1.000
Testing config 3... ✅ Quality: 0.812, Relevance: 0.974
Testing config 4... ✅ Quality: 0.782, Relevance: 0.994
Testing config 5... ✅ Quality: 0.825, Relevance: 1.000
Testing config 6... ✅ Quality: 0.777, Relevance: 1.000
Testing config 7... ✅ Quality: 0.805, Relevance: 0.994
Testing config 8... ✅ Quality: 0.774, Relevance: 0.994
Testing config 9... ✅ Quality: 0.797, Relevance: 0.994
Testing config 10... ✅ Quality: 0.803, Re