<a href="https://colab.research.google.com/github/wesslen/llm-experiments/blob/main/notebooks/bootstrap/bootstrap_ir_rag_reference_based.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!uv pip install --system "outlines[openai]"

In [12]:
import json
from collections import defaultdict
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Union, Literal

import httpx
import numpy as np
import openai
import torch
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    CharacterTextSplitter,
    TokenTextSplitter,
    MarkdownHeaderTextSplitter,
)
from openai import AsyncOpenAI, DefaultHttpxClient
from outlines import models, generate
from outlines.models.openai import OpenAIConfig
from pydantic import BaseModel, ConfigDict, Field
from sklearn.metrics.pairwise import cosine_similarity
from tqdm.notebook import tqdm
from transformers import AutoTokenizer, AutoModel

class QuestionType(str, Enum):
    """Closed set of question categories used to label generated queries.

    The ``str`` mix-in makes every member compare equal to its lowercase
    string value, so a raw label such as ``"factual"`` can be turned into a
    member with ``QuestionType("factual")``.
    """
    FACTUAL = "factual"
    PROCEDURAL = "procedural"
    CONCEPTUAL = "conceptual"
    COMPARATIVE = "comparative"
    ANALYTICAL = "analytical"

# NOTE: this is a plain pydantic model and must NOT also be decorated with
# ``@dataclass``. The stdlib decorator would replace pydantic's ``__init__``
# with a generated one that bypasses validation and, under pydantic v2, fails
# on the first attribute assignment because the model's internal state was
# never initialised.
class QueryMetadata(BaseModel):
    """Metadata for a generated query"""
    model_config = ConfigDict(extra='forbid')  # required for openai structured generation

    # The question/answer pair produced for a chunk.
    query: str = Field(..., description="The generated question")
    answer: str = Field(..., description="The answer from the text")
    # Label validated against the QuestionType enum.
    question_type: QuestionType = Field(..., description="Type of question")
    # Model-reported confidence; pydantic enforces the closed range [0, 1].
    confidence_score: float = Field(..., ge=0.0, le=1.0, description="Confidence score between 0 and 1")
    requires_context: bool = Field(..., description="Whether additional context is needed")
    # Non-negative token count as reported (or estimated) by the model.
    tokens_used: int = Field(..., ge=0, description="Number of tokens used")

# NOTE: plain pydantic model; stacking the stdlib ``@dataclass`` decorator on a
# ``BaseModel`` replaces pydantic's ``__init__`` and breaks validation and
# instantiation, so the decorator is intentionally absent.
class ChunkResponse(BaseModel):
    """Response structure for query generation"""
    model_config = ConfigDict(extra='forbid')
    # One entry per generated question/answer pair.
    queries: List[QueryMetadata] = Field(..., description="List of generated queries and answers")

@dataclass
class ChunkMetadata:
    """Enhanced metadata for a document chunk including detailed Q&A pairs"""
    # Position of the chunk in splitter order.
    chunk_id: int
    # Chunk text as returned by the text splitter.
    text: str
    # Q&A pairs generated for this chunk (empty until query generation runs).
    queries: List[QueryMetadata]
    # Character offsets of the chunk within the source document.
    start_char: int
    end_char: int
    # Score derived from the chunk embedding; 0.0 when not computed.
    semantic_complexity: float = 0.0
    # This is a stdlib dataclass, so the default must come from
    # ``dataclasses.field``. ``pydantic.Field(default_factory=list)`` would be
    # stored verbatim, making the default a FieldInfo object instead of a
    # fresh list.
    key_entities: List[str] = field(default_factory=list)

class ChunkingStrategy(str, Enum):
    """Selects which LangChain splitter ``AdvancedRAGGenerator`` builds.

    The ``str`` mix-in makes each member compare equal to its string value.
    """
    RECURSIVE = "recursive"  # RecursiveCharacterTextSplitter
    CHARACTER = "character"  # CharacterTextSplitter
    TOKEN = "token"  # TokenTextSplitter
    MARKDOWN = "markdown"  # MarkdownHeaderTextSplitter




# Pydantic schema for an entity-extraction reply: a JSON object with a single
# ``entities`` array of strings; unknown keys are rejected (extra='forbid').
# NOTE(review): not referenced anywhere else in the visible source - the
# extraction code parses the reply with json.loads instead.
class EntityResponse(BaseModel):
    model_config = ConfigDict(extra='forbid')
    entities: List[str] = Field(..., description="List of key entities extracted from the text")

async def _extract_key_entities(self, text: str) -> List[str]:
    """Ask the chat model for the key entities in ``text``.

    Module-level variant that still takes ``self``: it reads ``self.client``
    and ``self.model``. The reply is requested as a JSON object and its
    ``entities`` value is returned as-is.

    Args:
        self: Object exposing an async OpenAI-style ``client`` and a ``model`` name.
        text: Text to analyse.

    Returns:
        The ``entities`` value from the reply, or ``[]`` if anything fails
        (request error, invalid JSON, missing key); the error is printed.
    """
    try:
        print("Starting entity extraction...")

        # Same prompt text as before, assembled line by line.
        prompt = (
            "Extract key entities (important concepts, terms, or names) from this text.\n"
            "Return ONLY a JSON object with a single field 'entities' containing an array of strings.\n"
            "\n"
            f"Text: {text}\n"
            "\n"
            "Required format example:\n"
            '{"entities": ["entity1", "entity2", "entity3"]}'
        )
        chat_messages = [{"role": "user", "content": prompt}]

        completion = await self.client.chat.completions.create(
            model=self.model,
            messages=chat_messages,
            response_format={"type": "json_object"}
        )

        payload = json.loads(completion.choices[0].message.content)
        found = payload['entities']
        print(f"Extracted entities: {found}")
        return found

    except Exception as exc:
        # Best-effort helper: report the problem and hand back an empty list.
        print(f"Error in entity extraction: {exc}")
        print(f"Full exception: {exc!r}")
        return []


class AdvancedRAGGenerator:
    def __init__(
        self,
        api_key: str,
        embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
        chunk_size: int = 500,
        chunk_overlap: int = 50,
        queries_per_chunk: int = 3,
        temperature: float = 0.7,
        model: str = "gpt-3.5-turbo",
        chunking_strategy: ChunkingStrategy = ChunkingStrategy.RECURSIVE,
        question_types: List[QuestionType] = None,
        min_confidence_threshold: float = 0.7,
        separators: List[str] = None,
        base_url: Optional[str] = None,
        http_client: Optional[httpx.AsyncClient] = None,
    ):
        """
        Set up the LLM client, the outlines model, the embedder and the splitter.

        Args:
            api_key: API key handed to ``AsyncOpenAI``
            embedding_model: Hugging Face model name loaded for embeddings
            chunk_size: Chunk size passed to the text splitter
            chunk_overlap: Chunk overlap passed to the text splitter
            queries_per_chunk: Number of Q&A pairs to request per chunk
            temperature: Sampling temperature; stored and used for the outlines config
            model: Name of the generation model
            chunking_strategy: Strategy for document chunking
            question_types: Question types to generate (all ``QuestionType`` members when empty/None)
            min_confidence_threshold: Minimum confidence score for generated Q&A pairs
            separators: Custom separators for recursive splitting (a built-in list when empty/None)
            base_url: Optional base URL handed to ``AsyncOpenAI``
            http_client: Optional ``httpx.AsyncClient`` handed to ``AsyncOpenAI``

        Raises:
            Exception: Whatever ``models.openai`` raises is printed and re-raised.
        """
        print("Initializing AdvancedRAGGenerator...")  # Debug log

        # Async client for the OpenAI-compatible endpoint
        self.client = AsyncOpenAI(api_key=api_key, base_url=base_url, http_client=http_client)
        print(f"OpenAI client initialized with base_url: {base_url}")  # Debug log

        # Remember which model to address in chat completions
        self.model = model
        print(f"Model name set to: {model}")  # Debug log

        # Outlines wrapper around the same client
        outlines_config = OpenAIConfig(model, temperature=temperature, max_tokens=2000)
        print("OpenAI config created")  # Debug log

        try:
            self.openai_model = models.openai(self.client, outlines_config)
        except Exception as exc:
            print(f"Error initializing outlines model: {exc}")
            raise
        else:
            print("Outlines model successfully initialized")  # Debug log

        # Plain configuration values
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.queries_per_chunk = queries_per_chunk
        self.temperature = temperature
        self.chunking_strategy = chunking_strategy
        self.min_confidence_threshold = min_confidence_threshold

        # Empty/None means "use the defaults"
        self.question_types = question_types if question_types else list(QuestionType)
        self.separators = separators if separators else ["\n\n", "\n", ". ", " ", ""]

        # Tokenizer and encoder used by get_embedding
        self.tokenizer = AutoTokenizer.from_pretrained(embedding_model)
        self.embedding_model = AutoModel.from_pretrained(embedding_model)

        # Build self.text_splitter for the chosen strategy
        self._init_text_splitter()

    def _init_text_splitter(self):
        """Create ``self.text_splitter`` for ``self.chunking_strategy``.

        RECURSIVE, CHARACTER and TOKEN splitters receive ``chunk_size`` and
        ``chunk_overlap``; the MARKDOWN splitter splits on ``#``/``##``/``###``
        headers only and is not given either value.

        Raises:
            ValueError: If the strategy is not a ``ChunkingStrategy`` value.
                Previously this case silently left ``self.text_splitter``
                undefined and only failed later with ``AttributeError``.
        """
        strategy = self.chunking_strategy

        if strategy == ChunkingStrategy.RECURSIVE:
            self.text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap,
                separators=self.separators
            )
        elif strategy == ChunkingStrategy.CHARACTER:
            self.text_splitter = CharacterTextSplitter(
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap,
                separator="\n"
            )
        elif strategy == ChunkingStrategy.TOKEN:
            self.text_splitter = TokenTextSplitter(
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap
            )
        elif strategy == ChunkingStrategy.MARKDOWN:
            headers_to_split_on = [
                ("#", "Header 1"),
                ("##", "Header 2"),
                ("###", "Header 3"),
            ]
            self.text_splitter = MarkdownHeaderTextSplitter(
                headers_to_split_on=headers_to_split_on
            )
        else:
            # Fail fast with a clear message instead of leaving the splitter unset.
            valid = ", ".join(member.value for member in ChunkingStrategy)
            raise ValueError(
                f"Unsupported chunking strategy: {strategy!r} (expected one of: {valid})"
            )

    def get_embedding(self, text: str) -> np.ndarray:
        """Embed ``text`` by mean-pooling the encoder's last hidden state.

        The text is tokenized (truncated to at most 512 tokens), run through
        ``self.embedding_model`` with gradients disabled, and the last hidden
        state is averaged over dimension 1.

        Args:
            text: Text to embed.

        Returns:
            The pooled tensor converted to a NumPy array
            (presumably shape ``(1, hidden_size)`` for a single string - confirm).
        """
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        )
        with torch.no_grad():
            hidden_states = self.embedding_model(**encoded).last_hidden_state
        pooled = hidden_states.mean(dim=1)
        return pooled.numpy()

    async def chunk_document(self, text: str) -> List[ChunkMetadata]:
        """Split ``text`` with the configured splitter and annotate each chunk.

        For every chunk the method awaits ``_extract_key_entities`` and computes
        ``_calculate_semantic_complexity``; ``queries`` starts out empty.

        ``start_char``/``end_char`` are found by locating the chunk text inside
        ``text``. When a chunk cannot be found verbatim, the offset falls back
        to an estimate: the previous chunk's end minus ``chunk_overlap`` (no
        overlap is assumed for the MARKDOWN strategy), clamped to 0.

        Args:
            text: Full document text.

        Returns:
            Chunk records in splitter order.

        Raises:
            Exception: Anything raised while splitting or annotating is printed
                and re-raised unchanged.
        """
        print("Starting chunk_document...")  # Debug log
        chunks = []

        try:
            is_markdown = self.chunking_strategy == ChunkingStrategy.MARKDOWN
            # Overlap assumed by the fallback estimate only.
            assumed_overlap = 0 if is_markdown else self.chunk_overlap
            raw_chunks = self.text_splitter.split_text(text)

            expected_start = 0   # where the next chunk should begin if it cannot be located
            previous_start = -1  # offset of the most recent chunk that was located

            for i, raw_chunk in enumerate(raw_chunks):
                print(f"Processing chunk {i}")  # Debug log
                # The markdown splitter yields objects exposing .page_content,
                # the other splitters yield plain strings.
                chunk_text = raw_chunk.page_content if is_markdown else raw_chunk
                chunk_length = len(chunk_text)

                located = self._find_chunk_start(text, chunk_text, expected_start, previous_start)
                if located == -1:
                    # Chunk text was altered by the splitter: use the estimate.
                    start_char = max(expected_start, 0)
                else:
                    start_char = located
                    previous_start = located

                # Get entities with await
                entities = await self._extract_key_entities(chunk_text)
                print(f"Got entities for chunk {i}: {entities}")  # Debug log

                chunks.append(ChunkMetadata(
                    chunk_id=i,
                    text=chunk_text,
                    queries=[],
                    start_char=start_char,
                    end_char=start_char + chunk_length,
                    semantic_complexity=self._calculate_semantic_complexity(chunk_text),
                    key_entities=entities
                ))
                expected_start = start_char + chunk_length - assumed_overlap

            print(f"Completed chunk_document with {len(chunks)} chunks")  # Debug log
            return chunks

        except Exception as e:
            print(f"Error in chunk_document: {e}")
            print(f"Full exception: {repr(e)}")
            raise

    @staticmethod
    def _find_chunk_start(text: str, chunk_text: str, expected_start: int, previous_start: int) -> int:
        """Return the offset of ``chunk_text`` in ``text``, or -1 when not found.

        Searches first from the expected position, then from just after the
        previous chunk's start, which covers overlaps larger than expected.
        """
        for lower_bound in (expected_start, previous_start + 1):
            position = text.find(chunk_text, max(lower_bound, 0))
            if position != -1:
                return position
        return -1

    def _calculate_semantic_complexity(self, text: str) -> float:
        """Return the variance of the embedding values of ``text`` as a float.

        The score is ``numpy.var`` taken over every element of the array
        returned by ``get_embedding``.
        """
        variance = np.var(self.get_embedding(text))
        return float(variance)

    # Diagnostic helpers that exercise the individual LLM calls
    async def test_entity_extraction(self):
        """Run entity extraction on a fixed sample sentence.

        Returns:
            The list returned by ``_extract_key_entities``, or ``None`` if that
            call raised (the error is printed).
        """
        sample = "This is a test text about banking and transactions."
        try:
            print("Testing entity extraction...")
            extracted = await self._extract_key_entities(sample)
            print(f"Successfully extracted entities: {extracted}")
        except Exception as exc:
            print(f"Error in test: {exc}")
            print(f"Full exception: {exc!r}")
            return None
        return extracted

    async def test_simple_completion(self):
        """Check that the endpoint answers a trivial JSON-mode request.

        Returns:
            ``True`` when the reply parses as JSON and contains a ``"text"``
            key; ``False`` otherwise, including on any exception (printed).
        """
        try:
            print("Testing simple completion...")
            request_messages = [{
                "role": "user",
                "content": 'Return only a JSON object with format: {"text": "hello world"}',
            }]
            completion = await self.client.chat.completions.create(
                model=self.model,
                messages=request_messages,
                response_format={"type": "json_object"}
            )

            raw_reply = completion.choices[0].message.content
            print(f"Response received: {raw_reply}")
            # Parsing doubles as the validity check for the JSON reply
            parsed = json.loads(raw_reply)
            if "text" not in parsed:
                return False
            print(f"Successfully parsed response: {parsed}")
            return True

        except Exception as exc:
            print(f"Simple test failed: {exc}")
            print(f"Full exception: {exc!r}")
            return False

    async def _extract_key_entities(self, text: str) -> List[str]:
        """Extract key entities from a chunk via a JSON-mode chat completion.

        The reply is validated before it is returned, so the result is always
        a list of non-empty strings: a reply that is not a JSON object, or
        whose ``entities`` value is not a list, yields ``[]``; numeric items
        are converted to strings; other item types are dropped.

        Args:
            text: Chunk text to analyse.

        Returns:
            Extracted entities, or ``[]`` when the model reports ``errors``,
            the reply is malformed, or any exception occurs (printed).
        """
        try:
            prompt = f"""You are a helpful assistant that extracts key entities from text.

    Your task: Read the text below and identify the main entities (important terms, concepts, numbers, or names).

    Text to analyze: {text}

    Instructions:
    1. Extract all important entities (nouns, terms, numbers, names)
    2. Return them in a JSON array
    3. Use EXACTLY this format: {{"entities": ["entity1", "entity2", "entity3"]}}

    Example output format:
    {{"entities": ["wire transfer", "$50,000", "dual authorization", "officers"]}}"""

            response = await self.client.chat.completions.create(
                model=self.model,
                messages=[{
                    "role": "system",
                    "content": "You are a precise entity extraction assistant that responds only in JSON format."
                },
                {
                    "role": "user",
                    "content": prompt
                }],
                response_format={"type": "json_object"}
            )

            raw_content = response.choices[0].message.content
            print(f"Entity extraction raw response: {raw_content}")
            result = json.loads(raw_content)

            # JSON mode should yield an object, but guard against arrays/scalars.
            if not isinstance(result, dict):
                print(f"Warning: expected a JSON object, got {type(result).__name__}")
                return []

            # Check if we got an error response
            if "errors" in result:
                print(f"Got error response: {result['errors']}")
                return []

            raw_entities = result.get('entities', [])
            if not isinstance(raw_entities, list):
                # e.g. {"entities": "wire transfer"} - a str would otherwise be
                # treated downstream as a sequence of characters.
                print(f"Warning: 'entities' is not a list (got {type(raw_entities).__name__})")
                return []

            entities = []
            for item in raw_entities:
                # bool is a subclass of int, so exclude it explicitly.
                if isinstance(item, bool) or not isinstance(item, (str, int, float)):
                    continue
                cleaned = str(item).strip()
                if cleaned:
                    entities.append(cleaned)

            if not entities:
                print("Warning: No entities extracted")
            else:
                print(f"Successfully extracted {len(entities)} entities")

            return entities

        except Exception as e:
            print(f"Error in entity extraction: {e}")
            print(f"Full exception: {repr(e)}")
            return []

    async def generate_queries(self, chunk: ChunkMetadata) -> List[QueryMetadata]:
        """Generate question/answer pairs for one chunk via a JSON-mode completion.

        The request uses ``self.temperature`` and the prompt advertises only
        the question types in ``self.question_types`` (all types when that is
        empty). With the default configuration the prompt text is unchanged.

        Each returned item is converted to ``QueryMetadata``. An item is
        skipped when it is malformed, when its type was not requested, or when
        its confidence is below ``self.min_confidence_threshold``.

        Args:
            chunk: Chunk whose ``text`` the questions are generated from.

        Returns:
            Accepted queries; ``[]`` when the model reports ``errors`` or any
            exception occurs (printed, never raised).
        """
        try:
            # Normalise to enum members so plain strings such as "factual" work too.
            allowed_types = [QuestionType(t) for t in (self.question_types or list(QuestionType))]
            allowed_label = ", ".join(t.value for t in allowed_types)
            example_type = allowed_types[0].value

            prompt = f"""You are a helpful assistant that generates questions and answers from text.

    Text to analyze:
    {chunk.text}

    Task: Create {self.queries_per_chunk} diverse questions and answers based on this text.

    Instructions:
    1. Generate questions that test different aspects of the content
    2. Provide accurate answers from the text
    3. Return the response in the exact JSON format shown below

    Required JSON format:
    {{
        "queries": [
            {{
                "query": "Your question here?",
                "answer": "Answer from the text",
                "question_type": "{example_type}",
                "confidence_score": 0.95,
                "requires_context": false,
                "tokens_used": 42
            }}
        ]
    }}

    Notes:
    - question_type must be one of: {allowed_label}
    - confidence_score must be between 0 and 1
    - requires_context should be true if the answer needs more context
    - tokens_used can be estimated"""

            response = await self.client.chat.completions.create(
                model=self.model,
                messages=[{
                    "role": "system",
                    "content": "You are a precise question generation assistant that responds only in JSON format."
                },
                {
                    "role": "user",
                    "content": prompt
                }],
                response_format={"type": "json_object"},
                # The configured temperature was previously stored but never sent.
                temperature=self.temperature
            )

            raw_content = response.choices[0].message.content
            print(f"Query generation raw response: {raw_content}")
            result = json.loads(raw_content)

            # Check if we got an error response
            if "errors" in result:
                print(f"Got error response: {result['errors']}")
                return []

            # Convert the JSON response to QueryMetadata objects
            queries = []
            for q in result.get('queries', []):
                try:
                    question_type = QuestionType(q['question_type'])
                    if question_type not in allowed_types:
                        print(f"Skipping query with unrequested type: {question_type.value}")
                        continue

                    # bool("false") is True, so interpret string flags explicitly.
                    needs_context = q['requires_context']
                    if isinstance(needs_context, str):
                        needs_context = needs_context.strip().lower() in ("true", "1", "yes")

                    query = QueryMetadata(
                        query=q['query'],
                        answer=q['answer'],
                        question_type=question_type,
                        confidence_score=float(q['confidence_score']),
                        requires_context=bool(needs_context),
                        tokens_used=int(q['tokens_used'])
                    )
                    if query.confidence_score >= self.min_confidence_threshold:
                        queries.append(query)
                        print(f"Added query: {query.query}")
                except Exception as e:
                    # One bad item must not discard the rest of the batch.
                    print(f"Error processing query: {e}")
                    continue

            return queries

        except Exception as e:
            print(f"Error generating queries: {e}")
            print(f"Full exception: {repr(e)}")
            return []



    def evaluate_ground_truth(self, chunks: List[ChunkMetadata]) -> Dict:
        """Evaluate the quality of generated ground truth data.

        Args:
            chunks: Chunk records, each carrying its generated ``queries``.

        Returns:
            A dict with these keys:

            * ``total_chunks`` / ``total_queries`` - plain counts.
            * ``question_type_distribution`` - plain dict, question type -> count.
            * ``avg_confidence_score`` / ``requires_context_percentage`` -
              averages over all queries (0 when there are none).
            * ``avg_semantic_complexity`` - mean over chunks; no longer depends
              on whether any queries exist.
            * ``chunk_similarity_matrix`` - cosine similarities between chunk
              embeddings (an empty 0x0 array when ``chunks`` is empty).
            * ``query_similarity_matrix`` - cosine similarities between query
              embeddings, or ``None`` when there are no queries.
            * ``duplicate_queries`` - Python ``int``: number of ordered query
              pairs (i != j) with similarity above 0.95, so each near-duplicate
              pair contributes 2.
            * ``coverage_score`` - mean, over chunks that have queries, of the
              best query-to-chunk similarity. Chunks without queries are
              skipped instead of crashing.
        """
        total_queries = sum(len(chunk.queries) for chunk in chunks)
        metrics = {
            "total_chunks": len(chunks),
            "total_queries": total_queries,
            "question_type_distribution": {},
            "avg_confidence_score": 0,
            "requires_context_percentage": 0,
            "avg_semantic_complexity": 0,
            "chunk_similarity_matrix": np.empty((0, 0)),
            "query_similarity_matrix": None,
            "duplicate_queries": 0,
            "coverage_score": 0
        }

        # Nothing to embed or average; np.vstack([]) would raise.
        if not chunks:
            return metrics

        # Calculate basic metrics
        type_counts = defaultdict(int)
        confidence_sum = 0.0
        context_count = 0
        complexity_sum = 0.0

        for chunk in chunks:
            complexity_sum += chunk.semantic_complexity
            for query in chunk.queries:
                confidence_sum += query.confidence_score
                context_count += int(query.requires_context)
                type_counts[query.question_type] += 1

        # Plain dict: reading a missing key must not insert it.
        metrics["question_type_distribution"] = dict(type_counts)
        metrics["avg_semantic_complexity"] = complexity_sum / len(chunks)
        if total_queries > 0:
            metrics["avg_confidence_score"] = confidence_sum / total_queries
            metrics["requires_context_percentage"] = context_count / total_queries * 100

        # Embed each chunk once and reuse the vectors below.
        chunk_vectors = [self.get_embedding(chunk.text) for chunk in chunks]
        metrics["chunk_similarity_matrix"] = cosine_similarity(np.vstack(chunk_vectors))

        if total_queries == 0:
            return metrics

        # Embed each query once, grouped by chunk.
        query_vectors = [
            [self.get_embedding(q.query) for q in chunk.queries]
            for chunk in chunks
        ]
        flat_query_vectors = [vec for per_chunk in query_vectors for vec in per_chunk]

        query_similarities = cosine_similarity(np.vstack(flat_query_vectors))
        metrics["query_similarity_matrix"] = query_similarities

        # Count highly similar queries (potential duplicates), ignoring the
        # diagonal; int() keeps the value JSON-serialisable.
        duplicate_threshold = 0.95
        off_diagonal = ~np.eye(len(query_similarities), dtype=bool)
        metrics["duplicate_queries"] = int(
            np.count_nonzero((query_similarities > duplicate_threshold) & off_diagonal)
        )

        # Calculate coverage score
        coverage_scores = []
        for chunk_vector, per_chunk in zip(chunk_vectors, query_vectors):
            if not per_chunk:
                continue  # chunk produced no queries; nothing to compare
            similarities = cosine_similarity(np.vstack(per_chunk), chunk_vector)
            coverage_scores.append(np.mean(np.max(similarities, axis=0)))

        metrics["coverage_score"] = float(np.mean(coverage_scores))

        return metrics

    async def process_document(self, document: str) -> List[ChunkMetadata]:
        """Chunk ``document`` and attach generated queries to every chunk.

        Progress is printed throughout. A chunk whose query generation raises
        is reported and left out of the result.

        Args:
            document: Full document text.

        Returns:
            The chunks for which query generation completed, in order.

        Raises:
            Exception: Any error from chunking is printed and re-raised.
        """
        print("\n=== Starting document processing ===")
        print(f"Document length: {len(document)} characters")
        print(f"Using model: {self.model}")
        print(f"Client initialized: {self.client is not None}")
        print(f"Outlines model initialized: {self.openai_model is not None}")

        try:
            print("\n--- Chunking document ---")
            all_chunks = await self.chunk_document(document)
            print(f"Created {len(all_chunks)} chunks")

            # Short preview of every chunk before the expensive step
            for index, piece in enumerate(all_chunks):
                print(f"\nChunk {index} preview:")
                print(f"- Length: {len(piece.text)} characters")
                print(f"- First 100 chars: {piece.text[:100]}...")
                print(f"- Entities: {piece.key_entities}")

            completed = []
            for piece in tqdm(all_chunks, desc="Generating queries"):
                try:
                    print(f"\n--- Processing chunk {piece.chunk_id} ---")
                    generated = await self.generate_queries(piece)
                    print(f"Generated queries: {len(generated)}")
                    for item in generated:
                        print(f"- Q: {item.query}")
                        print(f"  A: {item.answer}")
                    piece.queries = generated
                    completed.append(piece)
                except Exception as chunk_error:
                    # Skip this chunk but keep processing the others
                    print(f"Error processing chunk {piece.chunk_id}: {chunk_error}")
                    print(f"Full exception: {chunk_error!r}")

            return completed

        except Exception as exc:
            print(f"Error in process_document: {exc}")
            print(f"Full exception: {exc!r}")
            raise



    def save_ground_truth(self, chunks: List[ChunkMetadata], output_file: str):
        """Save ground truth data with evaluation metrics as indented JSON.

        Values the JSON encoder cannot handle natively (NumPy arrays and
        scalars, pydantic models, enum members) are converted by
        ``_json_default`` rather than aborting the dump.

        Args:
            chunks: Chunk records to serialise; also passed to
                ``evaluate_ground_truth``.
            output_file: Path of the JSON file to write (overwritten).

        Raises:
            TypeError: If the data contains an object that neither the JSON
                encoder nor ``_json_default`` can convert.
        """
        evaluation = self.evaluate_ground_truth(chunks)

        data = {
            "chunks": [asdict(chunk) for chunk in chunks],
            "evaluation_metrics": evaluation
        }

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, default=self._json_default)

    @staticmethod
    def _json_default(obj):
        """Convert objects ``json.dump`` cannot serialise on its own."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # NumPy scalars such as np.int64 are not Python ints
            return obj.item()
        if isinstance(obj, BaseModel):
            # asdict() copies pydantic models instead of expanding them
            return obj.model_dump(mode="json")
        if isinstance(obj, Enum):
            return obj.value
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")



In [13]:
# Colab-only setup: the API key and endpoint URL are read from the notebook's
# user secrets, so this cell does not run outside Google Colab.
from openai import AsyncOpenAI
from google.colab import userdata
api_key = userdata.get('API_KEY')
base_url = userdata.get('LLAMA_BASE_URL')

import httpx
import asyncio

# Optional custom HTTP client
# NOTE(review): an AsyncClient presumably needs an async transport
# (httpx.AsyncHTTPTransport) - confirm before enabling this snippet.
# http_client = httpx.AsyncClient(
#     proxies="http://my.proxy.example.com",
#     transport=httpx.HTTPTransport(local_address="0.0.0.0"),
# )

# Shared generator instance used by the test cells below. Construction prints
# its own progress messages and loads the default embedding model.
generator = AdvancedRAGGenerator(
    api_key=api_key,
    base_url=base_url,
    # http_client=http_client,  # optional
    chunk_size=500,
    chunk_overlap=50,
    queries_per_chunk=3,
    model="/models/NousResearch/Meta-Llama-3.1-8B-Instruct"  # or your specific model
)

Initializing AdvancedRAGGenerator...
OpenAI client initialized with base_url: https://wesslen--vllm-openai-compatible-serve.modal.run/v1
Model name set to: /models/NousResearch/Meta-Llama-3.1-8B-Instruct
OpenAI config created
Outlines model successfully initialized


In [4]:
# Sample document (banking policies): a small Markdown text with #/##/###
# headers, passed to generator.process_document() in the final cell.
policy_document = '''
# Internal Banking Policy Manual

## Section 1: Transaction Processing

### 1.1 Wire Transfer Authorization
All wire transfers exceeding $50,000 require dual authorization from designated officers.
The primary authorizer must be a department manager or above, while the secondary
authorizer must be from a different department to ensure segregation of duties.

### 1.2 Customer Authentication
Customer identity must be verified through two-factor authentication for all high-risk
transactions. This includes:
- Wire transfers above $10,000
- Changes to account ownership
- Updates to primary contact information

## Section 2: Suspicious Activity Reporting

### 2.1 Reporting Requirements
Staff must report any suspicious transactions through the SAR filing system within 24
hours of detection. Red flags include:
- Structured deposits just below reporting thresholds
- Rapid movement of funds between accounts
- Multiple high-value transactions from dormant accounts
'''


## Generate example

In [14]:
# minimum test
async def test_minimal():
    """Smoke-test the endpoint with a freshly built generator.

    Builds an ``AdvancedRAGGenerator`` from the notebook-level ``api_key`` and
    ``base_url`` and sends one JSON-mode chat completion.

    Returns:
        ``True`` when the request completes, ``False`` on any exception
        (which is printed).
    """
    try:
        # Default settings apart from credentials and model name
        minimal_generator = AdvancedRAGGenerator(
            api_key=api_key,
            base_url=base_url,
            model="/models/NousResearch/Meta-Llama-3.1-8B-Instruct"
        )

        # One JSON-mode chat completion
        hello_request = [{
            "role": "user",
            "content": "Return a simple JSON response with a 'text' field containing 'hello'. Format: {\"text\": \"hello\"}"
        }]
        completion = await minimal_generator.client.chat.completions.create(
            model=minimal_generator.model,
            messages=hello_request,
            response_format={"type": "json_object"}
        )
        print(f"Test result: {completion}")
    except Exception as exc:
        print(f"Minimal test failed: {exc}")
        print(f"Full exception: {exc!r}")
        return False
    else:
        return True

# Run minimal test (top-level await: valid in a notebook cell, not in a plain script)
await test_minimal()

Initializing AdvancedRAGGenerator...
OpenAI client initialized with base_url: https://wesslen--vllm-openai-compatible-serve.modal.run/v1
Model name set to: /models/NousResearch/Meta-Llama-3.1-8B-Instruct
OpenAI config created
Outlines model successfully initialized
Test result: ChatCompletion(id='chat-849b3e7fced04627893e23e3ed59b8c6', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='{ \n "text" \t: \t"hello" \n}', refusal=None, role='assistant', audio=None, function_call=None, tool_calls=[]), stop_reason=None)], created=1736052780, model='/models/NousResearch/Meta-Llama-3.1-8B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, usage=CompletionUsage(completion_tokens=16, prompt_tokens=33, total_tokens=49, completion_tokens_details=None, prompt_tokens_details=None))


True

In [15]:
# For testing the changes
async def test_run():
    """Exercise completion, entity extraction and query generation in turn.

    Uses the notebook-level ``generator``. Stops early (returning ``None``)
    when the simple completion check fails; an empty entity list is reported
    but does not stop the run.

    Returns:
        ``True`` when at least one query was generated, ``False`` when none
        were, ``None`` when the first check failed.
    """
    print("\n=== Testing Individual Components ===")

    sample_text = (
        "Wire transfers exceeding $50,000 require dual authorization from designated officers.\n"
        "The primary authorizer must be a department manager or above."
    )

    # Step 1: the endpoint answers at all
    print("\n--- Test 1: Simple Completion ---")
    simple_result = await generator.test_simple_completion()
    print(f"Simple completion test result: {simple_result}")

    if not simple_result:
        print("Simple completion test failed, stopping tests")
        return

    # Step 2: entity extraction on the sample text
    print("\n--- Test 2: Entity Extraction ---")
    print(f"Testing with text:\n{sample_text}")
    found_entities = await generator._extract_key_entities(sample_text)
    print(f"Extracted entities: {found_entities}")

    if not found_entities:
        print("Entity extraction returned no results, but continuing...")

    # Step 3: query generation for a single hand-built chunk
    print("\n--- Test 3: Query Generation ---")
    sample_chunk = ChunkMetadata(
        chunk_id=0,
        text=sample_text,
        queries=[],
        start_char=0,
        end_char=len(sample_text),
        key_entities=found_entities
    )

    print("Generating queries for test chunk...")
    generated = await generator.generate_queries(sample_chunk)

    if not generated:
        print("No queries were generated")
    else:
        print("\nGenerated Queries:")
        for number, item in enumerate(generated, 1):
            print(f"\nQuery {number}:")
            print(f"Q: {item.query}")
            print(f"A: {item.answer}")
            print(f"Type: {item.question_type}")
            print(f"Confidence: {item.confidence_score}")

    return simple_result and bool(generated)

# Run the test (top-level await works in a notebook cell only).
# `success` is None when test_run() stopped after the first check.
success = await test_run()
print(f"\nOverall test {'succeeded' if success else 'failed'}")

# If tests pass, ask to run full document processing
# if success and input("\nRun full document processing? (y/n): ").lower() == 'y':
#     chunks = await debug_main()


=== Testing Individual Components ===

--- Test 1: Simple Completion ---
Testing simple completion...
Response received: { 	"text" 	: 	"hello world" 	}
Successfully parsed response: {'text': 'hello world'}
Simple completion test result: True

--- Test 2: Entity Extraction ---
Testing with text:
Wire transfers exceeding $50,000 require dual authorization from designated officers.
The primary authorizer must be a department manager or above.
Entity extraction raw response: 

 	{ 
 	"entities" 	: 	[ 
 	               	"Wire transfers" 	               	, 	               	"$50,000" 	               	, 	               	"dual authorization" 	               	, 	               	"officers" 	               	, 	               	"department manager" 	               	] 
 	}
Successfully extracted 5 entities
Extracted entities: ['Wire transfers', '$50,000', 'dual authorization', 'officers', 'department manager']

--- Test 3: Query Generation ---
Generating queries for test chunk...
Query generation ra

In [None]:
# Debug generator initialization: show the settings in use without printing
# the API key itself (only whether one is present).
print("\n=== Testing Generator Configuration ===")
print(f"API Key present: {api_key is not None}")
print(f"Base URL: {base_url}")
print(f"Model: {generator.model}")

# Run the main process with proper async handling
async def debug_main():
    """Run both smoke tests, then process ``policy_document`` end to end.

    Uses the notebook-level ``generator``. Results are written to
    ``ground_truth.json``.

    Returns:
        The processed chunks, or ``None`` when the completion test fails or
        the entity test returns ``None``.

    Raises:
        Exception: Any error is printed and re-raised.
    """
    try:
        print("\n=== Starting Debug Process ===")

        # Gate 1: basic completion must work
        print("\n--- Test 1: Simple Completion ---")
        completion_ok = await generator.test_simple_completion()
        print(f"Simple completion test result: {completion_ok}")
        if not completion_ok:
            return

        # Gate 2: entity extraction must not have raised
        print("\n--- Test 2: Entity Extraction ---")
        extracted = await generator.test_entity_extraction()
        print(f"Entity extraction test result: {extracted}")
        if extracted is None:
            return

        # Full pipeline on the sample policy text
        print("\n--- Main Document Processing ---")
        processed = await generator.process_document(policy_document)

        # Short summary of what was produced
        print("\n=== Results Summary ===")
        print(f"Total chunks processed: {len(processed)}")
        query_count = sum(len(piece.queries) for piece in processed)
        print(f"Total queries generated: {query_count}")

        # Persist chunks together with their evaluation metrics
        generator.save_ground_truth(processed, 'ground_truth.json')
        print("\nProcessing complete - results saved to ground_truth.json")

        return processed

    except Exception as exc:
        print(f"\n!!! Error in debug_main: {exc}")
        print(f"Full exception: {exc!r}")
        raise

# Execute with proper async handling (top-level await, notebook only).
# `chunks` is None when debug_main() stopped at one of its preliminary tests.
chunks = await debug_main()


=== Testing Generator Configuration ===
API Key present: True
Base URL: https://wesslen--vllm-openai-compatible-serve.modal.run/v1
Model: /models/NousResearch/Meta-Llama-3.1-8B-Instruct

=== Starting Debug Process ===

--- Test 1: Simple Completion ---
Testing simple completion...
Response received: { 
 	"text" 	: 	"hello world" 
}
Successfully parsed response: {'text': 'hello world'}
Simple completion test result: True

--- Test 2: Entity Extraction ---
Testing entity extraction...
