In [14]:
import os
import sys
import time
import random
import yaml
import re  # Added for parsing tags
import tempfile
import numpy as np
from typing import List, Tuple, Dict, Any
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer

# ----------------------------------------------------------------
# [1] Semantic IOU Evaluator (Metric Calculation Class)
# ----------------------------------------------------------------
class SemanticIOUEvaluator:
    def __init__(self, embedding_model_name='all-MiniLM-L6-v2', threshold=0.65):
        print(f"Loading Evaluator Model ({embedding_model_name})...")
        self.encoder = SentenceTransformer(embedding_model_name)
        self.threshold = threshold

    def _split_into_sentences(self, text: str) -> List[str]:
        # Simple sentence splitting
        if not text: return []
        return [s.strip() for s in text.replace('\n', '.').split('.') if len(s.strip()) > 10]

    def calculate_iou(self, reference_text: str, predicted_text: str) -> float:
        ref_sentences = self._split_into_sentences(reference_text)
        pred_sentences = self._split_into_sentences(predicted_text)

        if not ref_sentences:
            return 0.0
        if not pred_sentences:
            # If model didn't cite anything, IOU is 0
            return 0.0

        ref_emb = self.encoder.encode(ref_sentences)
        pred_emb = self.encoder.encode(pred_sentences)

        # Cosine Similarity Matrix
        sim_matrix = cosine_similarity(ref_emb, pred_emb)

        # Soft Matching
        max_sim_ref = np.max(sim_matrix, axis=1)   # Best match in Pred for each Ref sentence
        max_sim_pred = np.max(sim_matrix, axis=0)  # Best match in Ref for each Pred sentence

        intersection_ref = np.sum(max_sim_ref > self.threshold)
        intersection_pred = np.sum(max_sim_pred > self.threshold)
        
        # Intersection size (average)
        intersection = (intersection_ref + intersection_pred) / 2.0
        
        # Union size
        union = len(ref_sentences) + len(pred_sentences) - intersection
        
        return intersection / union if union > 0 else 0.0

# ----------------------------------------------------------------
# [2] Integration Logic (Prompt Injection & Parsing)
# ----------------------------------------------------------------
def run_citation_iou_test(data_path: str):
    # 0. Load Configuration
    try:
        from src.core.orchestrator import WorkflowOrchestrator
    except ImportError:
        print("[Error] 'src' folder not found. Please run from the project root.")
        return

    # Test configuration
    config = {
        'data': {'chunk_size': 4000, 'overlap': 1000, 'batch_size': 32},
        'logging': {'base_path': 'test_results'},
        'experiment_name': 'citation_iou_test',
        'strategy': {'scan_top_n': 10, 'final_top_k': 3} 
    }

    # 1. Load Data
    with open(data_path, 'r', encoding='utf-8') as f:
        full_text = f.read()

    print(f"\n{'='*60}")
    print(f"[TEST START] Data Length: {len(full_text)} chars")
    print(f"{'='*60}")

    # 2. [Oracle Process] Simulate Ground Truth
    print("\n[1] Oracle: Generating Ground Truth Query & Reference...")
    
    if len(full_text) < 1000:
        start_idx = 0
        end_idx = len(full_text)
    else:
        start_idx = random.randint(0, len(full_text) - 1000)
        end_idx = start_idx + random.randint(300, 800)
    
    oracle_reference_text = full_text[start_idx:end_idx]
    
    # Generate Base Query
    snippet = oracle_reference_text[:50].replace('\n', ' ')
    base_query = f"Please explain the following content from this document in detail: '{snippet}...'"
    
    # [CRITICAL] Prompt Injection for Explicit Citation
    citation_instruction = (
        "\n\n[IMPORTANT INSTRUCTION]\n"
        "After answering, you MUST provide the exact text segments from the document "
        "that you used to derive your answer.\n"
        "Wrap the cited text inside <reference> and </reference> tags.\n"
        "Example:\n"
        "Answer: ...\n"
        "Reference: <reference>The cited text from document...</reference>"
    )
    
    final_query = base_query + citation_instruction
    
    print(f" -> Base Query: {base_query}")
    print(f" -> Oracle Reference Length: {len(oracle_reference_text)} chars")

    # 3. [Student Process] Run Orchestrator Pipeline
    print("\n[2] Student: Running Orchestrator Pipeline (with Prompt Injection)...")
    
    cited_text = ""
    final_answer = ""
    
    try:
        # Initialize Orchestrator
        orchestrator = WorkflowOrchestrator(config)
        
        # Run Pipeline (No interceptor needed, we rely on model output)
        final_answer = orchestrator.run_pipeline([data_path], final_query)
        
        # 4. Parse the Output (Explicit Citation Extraction)
        # Find all content between <reference> tags
        citations = re.findall(r"<reference>(.*?)</reference>", final_answer, re.DOTALL)
        
        if citations:
            cited_text = " ".join(citations).strip()
            print(f" -> [Success] Parsed {len(citations)} citation segments.")
        else:
            print(" -> [Warning] No <reference> tags found in the model output.")
            print("    (The model may have failed to follow instructions or hallucinated without context.)")
            
    except Exception as e:
        print(f" [Error] Pipeline Execution Failed: {e}")
        import traceback
        traceback.print_exc()
        return

    print(f" -> Model Answer (Preview): {final_answer[:100]}...")
    print(f" -> Cited Text Length: {len(cited_text)} chars")

    # 5. [Evaluation] Calculate Semantic IOU
    print("\n[3] Evaluation: Calculating Semantic IOU (Oracle vs. Cited Text)...")
    evaluator = SemanticIOUEvaluator(threshold=0.6)
    
    iou_score = evaluator.calculate_iou(oracle_reference_text, cited_text)

    # 6. Results Report
    print(f"\n{'+'*60}")
    print(f" RESULTS REPORT (Explicit Citation Method)")
    print(f"{'+'*60}")
    print(f" | Query Type       : Prompt Injection (<reference> enforcement)")
    print(f" | Oracle Ref Length: {len(oracle_reference_text)}")
    print(f" | Cited Text Len   : {len(cited_text)}")
    print(f" | ----------------------------------")
    print(f" | Language IOU     : {iou_score:.4f}  (0.0 ~ 1.0)")
    print(f"{'+'*60}")
    
    if iou_score > 0.6:
        print(" => EXCELLENT: The model correctly cited the evidence used.")
    elif iou_score > 0.3:
        print(" => GOOD: The model cited relevant parts, but with some noise or missing info.")
    elif iou_score > 0.0:
        print(" => POOR: The citations are barely relevant to the ground truth.")
    else:
        print(" => FAIL: No intersection found or no citations provided.")

# ----------------------------------------------------------------
# [4] Main Entry Point
# ----------------------------------------------------------------
if __name__ == "__main__":
    # Create dummy data for testing
    dummy_filename = "./data/aiact.txt"
    
    os.makedirs(os.path.dirname(dummy_filename), exist_ok=True)
    
    if not os.path.exists(dummy_filename):
        print("Creating dummy data for testing...")
        with open(dummy_filename, "w", encoding='utf-8') as f:
            texts = []
            full_content = "\n\n".join(texts * 20) 
            f.write(full_content)

    # Execution
    run_citation_iou_test(dummy_filename)


[TEST START] Data Length: 356513 chars

[1] Oracle: Generating Ground Truth Query & Reference...
 -> Base Query: Please explain the following content from this document in detail: 'tains specific rules for AI systems that create a ...'
 -> Oracle Reference Length: 657 chars

[2] Student: Running Orchestrator Pipeline (with Prompt Injection)...

[RUN] Starting Pipeline Execution
Query: Please explain the following content from this document in detail: 'tains specific rules for AI systems that create a ...'

[IMPORTANT INSTRUCTION]
After answering, you MUST provide the exact text segments from the document that you used to derive your answer.
Wrap the cited text inside <reference> and </reference> tags.
Example:
Answer: ...
Reference: <reference>The cited text from document...</reference>

[STEP 1/8] Preprocessing data and checking cache...
[DONE] Preprocessing completed (0.00s)

[STEP 2/8] Generating embeddings and formulating adaptive strategy...
    >>> [Cache Found] Loading previous

Scouting: 100%|██████████| 11/11 [00:00<00:00, 31.73it/s]


[DONE] Detailed scouting completed (0.40s)

[STEP 5/8] Selecting high-value information and packing evidence...
    >>> Included top 1 pieces of evidence for final synthesis.
[DONE] Packaging completed (0.00s)

[STEP 6/8] Synthesizing technical report (Mistral-Small)...
[DONE] Synthesis completed (3.09s)

[STEP 7/8] Refining user-friendly response...
[DONE] Refinement completed (5.66s)

[STEP 8/8] Processing finished. Logging results...

[SUCCESS] Total processing time: 12.00s

 -> [Success] Parsed 1 citation segments.
 -> Model Answer (Preview): ### Detailed Explanation of the Proposed Regulatory Framework for AI

The European Commission's prop...
 -> Cited Text Length: 1320 chars

[3] Evaluation: Calculating Semantic IOU (Oracle vs. Cited Text)...
Loading Evaluator Model (all-MiniLM-L6-v2)...

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 RESULTS REPORT (Explicit Citation Method)
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | Query Type       : Prompt

In [None]:
import os
import re
import time
import random
import numpy as np
from typing import List
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer

try:
    from mistralai import Mistral
except ImportError:
    print("[!] Mistral SDK not found. Install via 'pip install mistralai'")
    Mistral = None

# ==============================================================================
# [1] Semantic IOU Evaluator
# ==============================================================================
class SemanticIOUEvaluator:
    def __init__(self, embedding_model_name='all-MiniLM-L6-v2', threshold=0.65):
        print(f"Loading Evaluator Model ({embedding_model_name})...")
        self.encoder = SentenceTransformer(embedding_model_name)
        self.threshold = threshold

    def _split_into_sentences(self, text: str) -> List[str]:
        if not text: return []
        
        return [s.strip() for s in text.replace('\n', '.').split('.') if len(s.strip()) > 10]

    def calculate_iou(self, reference_text: str, predicted_text: str) -> float:
        ref_sentences = self._split_into_sentences(reference_text)
        pred_sentences = self._split_into_sentences(predicted_text)

        if not ref_sentences: return 0.0
        if not pred_sentences: return 0.0

        ref_emb = self.encoder.encode(ref_sentences)
        pred_emb = self.encoder.encode(pred_sentences)

        # Cosine Similarity Matrix
        sim_matrix = cosine_similarity(ref_emb, pred_emb)

        max_sim_ref = np.max(sim_matrix, axis=1)
        max_sim_pred = np.max(sim_matrix, axis=0)

        intersection_ref = np.sum(max_sim_ref > self.threshold)
        intersection_pred = np.sum(max_sim_pred > self.threshold)
        
        intersection = (intersection_ref + intersection_pred) / 2.0
        union = len(ref_sentences) + len(pred_sentences) - intersection
        
        return intersection / union if union > 0 else 0.0

# ==============================================================================
# [2] Mistral Client Wrapper
# ==============================================================================
class MistralGenerator:
    def __init__(self, model_name="mistral-small-2506"):
        api_key = "xkqaUYgXKP8lYmONgEZUGLoQTGlDAcRg"
        if not api_key:
            raise ValueError("MISTRAL_API_KEY environment variable is not set.")
        
        self.client = Mistral(api_key=api_key)
        self.model = model_name
        print(f"[*] Initialized Mistral Client with model: {self.model}")

    def generate_response(self, context: str, query: str) -> str:
        system_instruction = (
            "You are a precise research assistant.\n"
            "1. Answer the user's question based ONLY on the provided Context.\n"
            "2. CRITICAL: You MUST cite the exact text segments from the Context used to derive your answer.\n"
            "3. Format your citations by wrapping the exact text inside <reference> and </reference> tags.\n"
            "   Example: ...answer... Reference: <reference>exact text from context</reference>"
        )

        user_message = (
            f"Context:\n{context}\n\n"
            f"Question: {query}"
        )

        try:
            response = self.client.chat.complete(
                model=self.model,
                messages=[
                    {"role": "system", "content": system_instruction},
                    {"role": "user", "content": user_message}
                ],
                temperature=0.0
            )
            return response.choices[0].message.content
        except Exception as e:
            print(f"[!] API Error: {e}")
            return ""

# ==============================================================================
# [3] Test Logic (Mistral Specific)
# ==============================================================================
def run_mistral_iou_test(data_path: str):
    evaluator = SemanticIOUEvaluator(threshold=0.7) #
    try:
        generator = MistralGenerator(model_name="mistral-small-2506")
    except ValueError as e:
        print(e)
        return

    with open(data_path, 'r', encoding='utf-8') as f:
        full_text = f.read()

    print(f"\n{'='*60}")
    print(f"[TEST START] Mistral-Small Language IOU Test")
    print(f"{'='*60}")

    ctx_len = 4000
    if len(full_text) > ctx_len:
        start_ctx = random.randint(0, len(full_text) - ctx_len)
        context_text = full_text[start_ctx : start_ctx + ctx_len]
    else:
        context_text = full_text

    target_len = random.randint(200, 500)
    target_start = random.randint(0, len(context_text) - target_len)
    oracle_reference_text = context_text[target_start : target_start + target_len]

    snippet = oracle_reference_text[:80].replace('\n', ' ')
    query = f"Please explain the details regarding this segment: '{snippet}...'"

    print(f"\n[1] Setup Scenario")
    print(f" -> Context Length : {len(context_text)} chars")
    print(f" -> Oracle Ref Len : {len(oracle_reference_text)} chars")
    print(f" -> Query          : {query}")

    print("\n[2] Calling Mistral API...")
    start_time = time.time()
    
    final_answer = generator.generate_response(context_text, query)
    
    elapsed = time.time() - start_time
    print(f" -> Generation Time: {elapsed:.2f}s")

    print("\n[3] Parsing Citations...")
    citations = re.findall(r"<reference>(.*?)</reference>", final_answer, re.DOTALL)
    cited_text = " ".join(citations).strip()

    if cited_text:
        print(f" -> Found {len(citations)} citation segments.")
        preview = cited_text[:100] + "..." if len(cited_text) > 100 else cited_text
        print(f" -> Cited Preview: {preview}")
    else:
        print(" -> [Warning] No <reference> tags found. (Model Failed to Follow Instructions)")

    print("\n[4] Calculating Semantic IOU...")
    iou_score = evaluator.calculate_iou(oracle_reference_text, cited_text)

    print(f"\n{'+'*60}")
    print(f" MISTRAL-SMALL EVALUATION REPORT")
    print(f"{'+'*60}")
    print(f" | Model            : mistral-small-2506")
    print(f" | Oracle Ref Len   : {len(oracle_reference_text)}")
    print(f" | Student Cited Len: {len(cited_text)}")
    print(f" | ----------------------------------")
    print(f" | Language IOU     : {iou_score:.4f}  (Range: 0.0 - 1.0)")
    print(f"{'+'*60}")

    if iou_score > 0.6:
        print(" => PASS: Mistral accurately cited the source text.")
    elif iou_score > 0.1:
        print(" => PARTIAL: Citations are relevant but not precise.")
    else:
        print(" => FAIL: Hallucination or failure to cite.")

# ==============================================================================
# [4] Entry Point
# ==============================================================================
if __name__ == "__main__":
    dummy_file = "data/aiact.txt"
    run_mistral_iou_test(dummy_file)

Loading Evaluator Model (all-MiniLM-L6-v2)...
[*] Initialized Mistral Client with model: mistral-small-2506

[TEST START] Mistral-Small Language IOU Test

[1] Setup Scenario
 -> Context Length : 4000 chars
 -> Oracle Ref Len : 220 chars
 -> Query          : Please explain the details regarding this segment: 'cles (OJ L 60, 2.3.2013, p. 52).\n\n<!-- page-start-marker-25 -->\n\nthe  Europe...'

[2] Calling Mistral API...
 -> Generation Time: 5.80s

[3] Parsing Citations...
 -> Found 4 citation segments.
 -> Cited Preview: Regulation (EU) No 168/2013 of the European Parliament and of the Council of 15 January 2013 on the ...

[4] Calculating Semantic IOU...

++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 MISTRAL-SMALL EVALUATION REPORT
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | Model            : mistral-small-2506
 | Oracle Ref Len   : 220
 | Student Cited Len: 1849
 | ----------------------------------
 | Language IOU     : 0.1429  (Range: 0.0 - 1.0)
