From 5d1c9ce0d48f17f8e40888267de625f549846062 Mon Sep 17 00:00:00 2001 From: nuwangeek Date: Thu, 11 Sep 2025 13:16:27 +0530 Subject: [PATCH 1/3] partialy completes prompt refiner --- .dockerignore | 106 +++++++ API_README.md | 136 +++++++++ Dockerfile.llm_orchestration_service | 78 +++++ build-llm-service.sh | 57 ++++ docker-compose.llm-dev.yml | 33 ++ docker-compose.yml | 32 ++ pyproject.toml | 2 + run_api.py | 43 +++ src/__init__.py | 1 + src/llm_orchestration_service.py | 184 +++++++++++ src/llm_orchestration_service_api.py | 120 ++++++++ src/models/__init__.py | 1 + src/models/request_models.py | 55 ++++ src/prompt_refiner_module/prompt_refiner.py | 207 +++++++++++++ test_api.py | 89 ++++++ test_integration.py | 57 ++++ test_prompt_refiner_schema.py | 72 +++++ tests/test_prompt_refiner.py | 322 ++++++++++++++++++++ uv.lock | 44 +++ 19 files changed, 1639 insertions(+) create mode 100644 .dockerignore create mode 100644 API_README.md create mode 100644 Dockerfile.llm_orchestration_service create mode 100644 build-llm-service.sh create mode 100644 docker-compose.llm-dev.yml create mode 100644 run_api.py create mode 100644 src/__init__.py create mode 100644 src/llm_orchestration_service.py create mode 100644 src/llm_orchestration_service_api.py create mode 100644 src/models/__init__.py create mode 100644 src/models/request_models.py create mode 100644 src/prompt_refiner_module/prompt_refiner.py create mode 100644 test_api.py create mode 100644 test_integration.py create mode 100644 test_prompt_refiner_schema.py create mode 100644 tests/test_prompt_refiner.py diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..d25f099 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,106 @@ +# Docker ignore file for LLM Orchestration Service +# Exclude unnecessary files from Docker build context + +# Git +.git +.gitignore + +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# Virtual environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# OS +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# Logs +*.log +logs/ +*.log.* + +# Testing +.pytest_cache/ +.coverage +htmlcov/ +.tox/ +coverage.xml + +# Documentation +docs/ +*.md +!README.md + +# Config files (will be mounted) +.env.local +.env.development +.env.test + +# Cache directories +.ruff_cache/ +.mypy_cache/ +.pyright_cache/ + +# Test files +test_*.py +*_test.py +tests/ + +# Development scripts +run_*.py +test_*.py + +# Temporary files +*.tmp +*.temp +.temporary + +# Node modules (if any) +node_modules/ + +# Docker files (except the specific one being built) +Dockerfile* +!Dockerfile.llm_orchestration_service +docker-compose*.yml + +# Grafana configs (not needed for this service) +grafana-configs/ diff --git a/API_README.md b/API_README.md new file mode 100644 index 0000000..2f67761 --- /dev/null +++ b/API_README.md @@ -0,0 +1,136 @@ +# LLM Orchestration Service API + +A FastAPI-based service for orchestrating LLM requests with configuration management and proper validation. + +## API Endpoints + +### POST /orchestrate +Processes LLM orchestration requests. 
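+
+For example, once the service is running (see "Running the API" below), the endpoint can be called with curl. This is a minimal sketch that assumes the service is reachable on localhost:8100 and that the request body shown below has been saved to a file named `request.json` (the filename is illustrative):
+
+```bash
+curl -X POST http://localhost:8100/orchestrate \
+  -H "Content-Type: application/json" \
+  -d @request.json
+```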
+ +**Request Body:** +```json +{ + "chatId": "chat-12345", + "message": "I need help with my electricity bill.", + "authorId": "12345", + "conversationHistory": [ + { + "authorRole": "user", + "message": "Hi, I have a billing issue", + "timestamp": "2025-04-29T09:00:00Z" + }, + { + "authorRole": "bot", + "message": "Sure, can you tell me more about the issue?", + "timestamp": "2025-04-29T09:00:05Z" + } + ], + "url": "id.ee", + "environment": "production|test|development", + "connection_id": "optional-connection-id" +} +``` + +**Response:** +```json +{ + "chatId": "chat-12345", + "llmServiceActive": true, + "questionOutOfLLMScope": false, + "inputGuardFailed": false, + "content": "This is a random answer payload.\n\nwith citations.\n\nReferences\n- https://gov.ee/sample1,\n- https://gov.ee/sample2" +} +``` + +### GET /health +Health check endpoint. + +**Response:** +```json +{ + "status": "healthy", + "service": "llm-orchestration-service" +} +``` + +## Running the API + +### Local Development: +```bash +uv run uvicorn src.llm_orchestration_service_api:app --host 0.0.0.0 --port 8100 --reload +``` + +### Docker (Standalone): +```bash +# Build and run with custom script +.\build-llm-service.bat run # Windows +./build-llm-service.sh run # Linux/Mac + +# Or manually +docker build -f Dockerfile.llm_orchestration_service -t llm-orchestration-service . +docker run -p 8100:8100 --env-file .env llm-orchestration-service +``` + +### Docker Compose (Production): +```bash +docker-compose up llm-orchestration-service +``` + +### Docker Compose (Development with hot reload): +```bash +docker-compose -f docker-compose.yml -f docker-compose.llm-dev.yml up llm-orchestration-service +``` + +### Test the API: +```bash +uv run python test_api.py +``` + +## Features + +- ✅ FastAPI with automatic OpenAPI documentation +- ✅ Pydantic validation for requests/responses +- ✅ Proper error handling and logging with Loguru +- ✅ Integration with existing LLM config module +- ✅ Type-safe implementation +- ✅ Health check endpoint +- 🔄 Hardcoded responses (TODO: Implement actual LLM pipeline) + +## Documentation + +When the server is running, visit: +- API docs: http://localhost:8100/docs +- ReDoc: http://localhost:8100/redoc + +## Architecture + +``` +┌─────────────────────────────────────────────────────────────┐ +│ FastAPI Application │ +│ (llm_orchestration_service_api.py) │ +└─────────────────────┬───────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Business Logic Service │ +│ (llm_orchestration_service.py) │ +└─────────────────────┬───────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ LLM Config Module │ +│ (llm_manager.py) │ +└─────────────────────────────────────────────────────────────┘ +``` + +## TODO Items + +- [ ] Implement actual LLM processing pipeline +- [ ] Add input validation and guard checks +- [ ] Implement question scope validation +- [ ] Add proper citation generation +- [ ] Handle multi-tenant scenarios with connection_id +- [ ] Add authentication/authorization +- [ ] Add comprehensive error handling +- [ ] Add request/response logging +- [ ] Add metrics and monitoring diff --git a/Dockerfile.llm_orchestration_service b/Dockerfile.llm_orchestration_service new file mode 100644 index 0000000..7966747 --- /dev/null +++ b/Dockerfile.llm_orchestration_service @@ -0,0 +1,78 @@ +# Dockerfile for LLM Orchestration Service +# Multi-stage build for optimized production image + +# Stage 
1: Build environment with uv +FROM python:3.12-slim AS builder + +# Set environment variables for uv +ENV UV_CACHE_DIR=/opt/uv-cache \ + UV_LINK_MODE=copy \ + UV_COMPILE_BYTECODE=1 \ + UV_PYTHON_DOWNLOADS=never + +# Install system dependencies for building +RUN apt-get update && apt-get install -y \ + build-essential \ + curl \ + && rm -rf /var/lib/apt/lists/* + +# Install uv using the official installer (as per CONTRIBUTING.md) +RUN curl -LsSf https://astral.sh/uv/install.sh | sh + +# Add uv to PATH +ENV PATH="/root/.cargo/bin:$PATH" + +# Set working directory +WORKDIR /app + +# Copy uv configuration files +COPY pyproject.toml uv.lock ./ + +# Install dependencies using uv +RUN uv sync --frozen --no-dev + +# Stage 2: Runtime environment +FROM python:3.12-slim AS runtime + +# Set environment variables +ENV PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + PATH="/app/.venv/bin:$PATH" \ + PYTHONPATH="/app/src" + +# Install runtime system dependencies +RUN apt-get update && apt-get install -y \ + curl \ + && rm -rf /var/lib/apt/lists/* \ + && apt-get clean + +# Create non-root user for security +RUN groupadd -r appuser && useradd -r -g appuser appuser + +# Set working directory +WORKDIR /app + +# Copy virtual environment from builder stage +COPY --from=builder /app/.venv /app/.venv + +# Copy source code +COPY src/ src/ + +# Copy configuration files (will be mounted as volumes in production) +COPY src/llm_config_module/config/llm_config.yaml src/llm_config_module/config/ + +# Create logs directory +RUN mkdir -p logs && chown -R appuser:appuser /app + +# Switch to non-root user +USER appuser + +# Expose the application port +EXPOSE 8100 + +# Health check using the FastAPI health endpoint +HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:8100/health || exit 1 + +# Default command to run the LLM orchestration service +CMD ["uvicorn", "src.llm_orchestration_service_api:app", "--host", "0.0.0.0", "--port", "8100"] diff --git a/build-llm-service.sh b/build-llm-service.sh new file mode 100644 index 0000000..4a918dd --- /dev/null +++ b/build-llm-service.sh @@ -0,0 +1,57 @@ +#!/bin/bash + +# Build and run script for LLM Orchestration Service Docker container + +set -e + +echo "🐳 Building LLM Orchestration Service Docker container..." + +# Build the Docker image +docker build -f Dockerfile.llm_orchestration_service -t llm-orchestration-service:latest . + +echo "✅ Docker image built successfully!" + +# Check if we should run the container +if [ "$1" = "run" ]; then + echo "🚀 Starting LLM Orchestration Service container..." + + # Stop and remove existing container if it exists + docker stop llm-orchestration-service 2>/dev/null || true + docker rm llm-orchestration-service 2>/dev/null || true + + # Run the container + docker run -d \ + --name llm-orchestration-service \ + --network bykstack \ + -p 8100:8100 \ + --env-file .env \ + -e ENVIRONMENT=development \ + -v "$(pwd)/src/llm_config_module/config:/app/src/llm_config_module/config:ro" \ + -v llm_orchestration_logs:/app/logs \ + llm-orchestration-service:latest + + echo "✅ LLM Orchestration Service is running!" + echo "🌐 API available at: http://localhost:8100" + echo "🔍 Health check: http://localhost:8100/health" + echo "📊 API docs: http://localhost:8100/docs" + + # Show logs + echo "" + echo "📋 Container logs (Ctrl+C to stop viewing logs):" + docker logs -f llm-orchestration-service + +elif [ "$1" = "compose" ]; then + echo "🚀 Starting with Docker Compose..." 
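+    # Build the image if needed and start only the llm-orchestration-service defined in docker-compose.yml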
+ docker-compose up --build llm-orchestration-service + +else + echo "" + echo "📖 Usage:" + echo " $0 - Build the Docker image only" + echo " $0 run - Build and run the container standalone" + echo " $0 compose - Build and run with docker-compose" + echo "" + echo "🌐 Once running, the API will be available at:" + echo " Health check: http://localhost:8100/health" + echo " API docs: http://localhost:8100/docs" +fi diff --git a/docker-compose.llm-dev.yml b/docker-compose.llm-dev.yml new file mode 100644 index 0000000..8224ac5 --- /dev/null +++ b/docker-compose.llm-dev.yml @@ -0,0 +1,33 @@ +# Docker Compose override for LLM Orchestration Service development +# Use: docker-compose -f docker-compose.yml -f docker-compose.llm-dev.yml up + +version: '3.8' + +services: + llm-orchestration-service: + build: + context: . + dockerfile: Dockerfile.llm_orchestration_service + target: runtime + environment: + - ENVIRONMENT=development + - PYTHONPATH=/app/src + volumes: + # Mount source code for development (hot reload if needed) + - ./src:/app/src + # Mount configuration files + - ./src/llm_config_module/config:/app/src/llm_config_module/config:ro + # Mount logs for easier debugging + - ./logs:/app/logs + command: > + uvicorn src.llm_orchestration_service_api:app + --host 0.0.0.0 + --port 8100 + --reload + --reload-dir /app/src + ports: + - "8100:8100" + depends_on: + - vault + networks: + - bykstack diff --git a/docker-compose.yml b/docker-compose.yml index bc71344..1aace95 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -258,11 +258,43 @@ services: timeout: 5s retries: 5 + # LLM Orchestration Service + llm-orchestration-service: + build: + context: . + dockerfile: Dockerfile.llm_orchestration_service + container_name: llm-orchestration-service + restart: unless-stopped + ports: + - "8100:8100" + env_file: + - .env + environment: + - ENVIRONMENT=production + - PYTHONPATH=/app/src + volumes: + # Mount configuration files + - ./src/llm_config_module/config:/app/src/llm_config_module/config:ro + # Mount logs directory for persistence + - llm_orchestration_logs:/app/logs + networks: + - bykstack + depends_on: + - vault + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8100/health"] + interval: 30s + timeout: 10s + start_period: 40s + retries: 3 + volumes: loki-data: name: loki-data grafana-data: name: grafana-data + llm_orchestration_logs: + name: llm_orchestration_logs qdrant_data: name: qdrant_data rag-search-db: diff --git a/pyproject.toml b/pyproject.toml index 7533f6c..680aa3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,8 @@ dependencies = [ "pydantic>=2.11.7", "testcontainers>=4.13.0", "hvac>=2.3.0", + "fastapi>=0.116.1", + "uvicorn>=0.35.0", ] [tool.pyright] diff --git a/run_api.py b/run_api.py new file mode 100644 index 0000000..5585b97 --- /dev/null +++ b/run_api.py @@ -0,0 +1,43 @@ +"""Run script for LLM Orchestration Service API.""" + +import sys +import os +from pathlib import Path + +# Add src directory to Python path +src_path = Path(__file__).parent / "src" +sys.path.insert(0, str(src_path)) + +if __name__ == "__main__": + try: + import uvicorn # type: ignore[import-untyped] + + print("Starting LLM Orchestration Service API on port 8100...") + print(f"Source path: {src_path}") + + # Change to src directory and run + os.chdir(src_path) + + uvicorn.run( # type: ignore[attr-defined] + "llm_orchestration_service_api:app", + host="0.0.0.0", + port=8100, + reload=True, + log_level="info", + ) + + except ImportError: + print("uvicorn not installed. 
Please install dependencies first.") + print("Commands to run the API:") + print("1. From project root:") + print( + " cd src && uv run uvicorn llm_orchestration_service_api:app --host 0.0.0.0 --port 8100 --reload" + ) + print("2. Or use this script:") + print(" uv run python run_api.py") + except Exception as e: + print(f"Error starting server: {e}") + print("\nAlternative commands to try:") + print( + "cd src && uv run uvicorn llm_orchestration_service_api:app --host 0.0.0.0 --port 8100 --reload" + ) diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..060e4ea --- /dev/null +++ b/src/__init__.py @@ -0,0 +1 @@ +"""Source package for RAG Module.""" diff --git a/src/llm_orchestration_service.py b/src/llm_orchestration_service.py new file mode 100644 index 0000000..cc08995 --- /dev/null +++ b/src/llm_orchestration_service.py @@ -0,0 +1,184 @@ +"""LLM Orchestration Service - Business logic for LLM orchestration.""" + +from typing import Optional, List, Dict +import json +from loguru import logger + +from llm_config_module.llm_manager import LLMManager +from models.request_models import ( + OrchestrationRequest, + OrchestrationResponse, + ConversationItem, + PromptRefinerOutput, +) +from prompt_refiner_module.prompt_refiner import PromptRefinerAgent + + +class LLMOrchestrationService: + """Service class for handling LLM orchestration business logic.""" + + def __init__(self) -> None: + """Initialize the orchestration service.""" + self.llm_manager: Optional[LLMManager] = None + + def process_orchestration_request( + self, request: OrchestrationRequest + ) -> OrchestrationResponse: + """ + Process an orchestration request and return response. + + Args: + request: The orchestration request containing user message and context + + Returns: + OrchestrationResponse: Response with LLM output and status flags + + Raises: + Exception: For any processing errors + """ + try: + logger.info( + f"Processing orchestration request for chatId: {request.chatId}, " + f"authorId: {request.authorId}, environment: {request.environment}" + ) + + # Initialize LLM Manager with configuration + # TODO: Remove hardcoded config path when proper configuration management is implemented + self._initialize_llm_manager( + environment=request.environment, connection_id=request.connection_id + ) + + # Step 2: Refine user prompt using loaded configuration + self._refine_user_prompt( + original_message=request.message, + conversation_history=request.conversationHistory, + ) + + # TODO: Implement actual LLM processing pipeline + # This will include: + # 1. Input validation and guard checks + # 2. Context preparation from conversation history + # 3. LLM provider selection based on configuration + # 4. Question scope validation + # 5. LLM inference execution + # 6. Response post-processing + # 7. Citation generation + + # For now, return hardcoded response + response = self._generate_hardcoded_response(request.chatId) + + logger.info(f"Successfully processed request for chatId: {request.chatId}") + return response + + except Exception as e: + logger.error( + f"Error processing orchestration request for chatId: {request.chatId}, " + f"error: {str(e)}" + ) + # Return error response + return OrchestrationResponse( + chatId=request.chatId, + llmServiceActive=False, + questionOutOfLLMScope=False, + inputGuardFailed=True, + content="An error occurred while processing your request. 
Please try again later.", + ) + + def _initialize_llm_manager( + self, environment: str, connection_id: Optional[str] + ) -> None: + """ + Initialize LLM Manager with proper configuration. + + Args: + environment: Environment context (production/test/development) + connection_id: Optional connection identifier + """ + try: + # TODO: Implement proper config path resolution based on environment + # TODO: Handle connection_id for multi-tenant scenarios + logger.info(f"Initializing LLM Manager for environment: {environment}") + + self.llm_manager = LLMManager( + environment=environment, connection_id=connection_id + ) + + logger.info("LLM Manager initialized successfully") + + except Exception as e: + logger.error(f"Failed to initialize LLM Manager: {str(e)}") + raise + + def _refine_user_prompt( + self, original_message: str, conversation_history: List[ConversationItem] + ) -> None: + """ + Refine user prompt using loaded LLM configuration and log all variants. + + Args: + original_message: The original user message to refine + conversation_history: Previous conversation context + """ + try: + logger.info("Starting prompt refinement process") + + if self.llm_manager is None: + logger.error("LLM Manager not initialized, cannot refine prompts") + return + + # Convert conversation history to DSPy format + history: List[Dict[str, str]] = [] + for item in conversation_history: + # Map 'bot' to 'assistant' for consistency with standard chat formats + role = "assistant" if item.authorRole == "bot" else item.authorRole + history.append({"role": role, "content": item.message}) + + # Create prompt refiner using the same LLM manager instance + refiner = PromptRefinerAgent(llm_manager=self.llm_manager) + + # Generate structured prompt refinement output + refinement_result = refiner.forward_structured( + history=history, question=original_message + ) + + # Validate the output schema using Pydantic + validated_output = PromptRefinerOutput(**refinement_result) + + # Log the complete structured output as JSON + output_json = validated_output.model_dump() + logger.info( + f"Prompt refinement output: {json.dumps(output_json, indent=2)}" + ) + + logger.info("Prompt refinement completed successfully") + + except Exception as e: + logger.error(f"Prompt refinement failed: {str(e)}") + logger.info(f"Continuing with original message: {original_message}") + # Don't raise exception - continue with original message + + def _generate_hardcoded_response(self, chat_id: str) -> OrchestrationResponse: + """ + Generate hardcoded response for testing purposes. + + Args: + chat_id: Chat session identifier + + Returns: + OrchestrationResponse with hardcoded values + """ + hardcoded_content = """This is a random answer payload. + +with citations. 
+ +References +- https://gov.ee/sample1, +- https://gov.ee/sample2""" + + return OrchestrationResponse( + chatId=chat_id, + llmServiceActive=True, + questionOutOfLLMScope=False, + inputGuardFailed=False, + content=hardcoded_content, + ) diff --git a/src/llm_orchestration_service_api.py b/src/llm_orchestration_service_api.py new file mode 100644 index 0000000..93cf727 --- /dev/null +++ b/src/llm_orchestration_service_api.py @@ -0,0 +1,120 @@ +"""LLM Orchestration Service API - FastAPI application.""" + +from contextlib import asynccontextmanager +from typing import AsyncGenerator + +from fastapi import FastAPI, HTTPException, status +from fastapi.responses import JSONResponse +from loguru import logger + +from llm_orchestration_service import LLMOrchestrationService +from models.request_models import OrchestrationRequest, OrchestrationResponse + + +# Global service instance +orchestration_service: LLMOrchestrationService | None = None + + +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: + """Application lifespan manager.""" + # Startup + logger.info("Starting LLM Orchestration Service API") + global orchestration_service + orchestration_service = LLMOrchestrationService() + logger.info("LLM Orchestration Service initialized") + + yield + + # Shutdown + logger.info("Shutting down LLM Orchestration Service API") + + +# Create FastAPI application +app = FastAPI( + title="LLM Orchestration Service API", + description="API for orchestrating LLM requests with configuration management", + version="1.0.0", + lifespan=lifespan, +) + + +@app.get("/health") +async def health_check() -> dict[str, str]: + """Health check endpoint.""" + return {"status": "healthy", "service": "llm-orchestration-service"} + + +@app.post( + "/orchestrate", + response_model=OrchestrationResponse, + status_code=status.HTTP_200_OK, + summary="Process LLM orchestration request", + description="Processes a user message through the LLM orchestration pipeline", +) +async def orchestrate_llm_request( + request: OrchestrationRequest, +) -> OrchestrationResponse: + """ + Process LLM orchestration request. 
+ + Args: + request: OrchestrationRequest containing user message and context + + Returns: + OrchestrationResponse: Response with LLM output and status flags + + Raises: + HTTPException: For processing errors + """ + try: + logger.info(f"Received orchestration request for chatId: {request.chatId}") + + if orchestration_service is None: + logger.error("Orchestration service not initialized") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Service not initialized", + ) + + # Process the request + response = orchestration_service.process_orchestration_request(request) + + logger.info(f"Successfully processed request for chatId: {request.chatId}") + return response + + except HTTPException: + raise + except Exception as e: + logger.error(f"Unexpected error processing request: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Internal server error occurred", + ) + + +@app.exception_handler(Exception) +async def global_exception_handler(request: object, exc: Exception) -> JSONResponse: + """Global exception handler.""" + logger.error(f"Unhandled exception: {str(exc)}") + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={"detail": "Internal server error"}, + ) + + +if __name__ == "__main__": + try: + import uvicorn # type: ignore[import-untyped] + except ImportError: + logger.error("uvicorn not installed. Please install with: pip install uvicorn") + raise + + logger.info("Starting LLM Orchestration Service API server on port 8100") + uvicorn.run( # type: ignore[attr-defined] + "llm_orchestration_service_api:app", + host="0.0.0.0", + port=8100, + reload=True, + log_level="info", + ) diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000..169789b --- /dev/null +++ b/src/models/__init__.py @@ -0,0 +1 @@ +"""Models package for API request/response schemas.""" diff --git a/src/models/request_models.py b/src/models/request_models.py new file mode 100644 index 0000000..38a8545 --- /dev/null +++ b/src/models/request_models.py @@ -0,0 +1,55 @@ +"""Pydantic models for API requests and responses.""" + +from typing import List, Literal, Optional +from pydantic import BaseModel, Field + + +class ConversationItem(BaseModel): + """Model for conversation history item.""" + + authorRole: Literal["user", "bot"] = Field( + ..., description="Role of the message author" + ) + message: str = Field(..., description="Content of the message") + timestamp: str = Field(..., description="Timestamp in ISO format") + + +class PromptRefinerOutput(BaseModel): + """Model for prompt refiner output.""" + + original_question: str = Field(..., description="The original user question") + refined_questions: List[str] = Field( + ..., description="List of refined question variants" + ) + + +class OrchestrationRequest(BaseModel): + """Model for LLM orchestration request.""" + + chatId: str = Field(..., description="Unique identifier for the chat session") + message: str = Field(..., description="User's message/query") + authorId: str = Field(..., description="Unique identifier for the user") + conversationHistory: List[ConversationItem] = Field( + ..., description="Previous conversation history" + ) + url: str = Field(..., description="Source URL context") + environment: Literal["production", "test", "development"] = Field( + ..., description="Environment context" + ) + connection_id: Optional[str] = Field( + None, description="Optional connection identifier" + ) + + +class 
OrchestrationResponse(BaseModel): + """Model for LLM orchestration response.""" + + chatId: str = Field(..., description="Chat session identifier from request") + llmServiceActive: bool = Field(..., description="Whether LLM service is active") + questionOutOfLLMScope: bool = Field( + ..., description="Whether question is out of LLM scope" + ) + inputGuardFailed: bool = Field( + ..., description="Whether input guard validation failed" + ) + content: str = Field(..., description="Response content with citations") diff --git a/src/prompt_refiner_module/prompt_refiner.py b/src/prompt_refiner_module/prompt_refiner.py new file mode 100644 index 0000000..80354b2 --- /dev/null +++ b/src/prompt_refiner_module/prompt_refiner.py @@ -0,0 +1,207 @@ +from __future__ import annotations + +from typing import Any, Iterable, List, Mapping, Sequence, Optional, Dict + +import logging +import dspy # type: ignore + +from llm_config_module import LLMManager, LLMProvider + + +LOGGER = logging.getLogger(__name__) + + +class PromptRefineSig(dspy.Signature): + """Produce N distinct, concise rewrites of the user's question using chat history. + + Constraints: + - Preserve the original intent; don't inject unsupported constraints. + - Resolve pronouns with context when safe; avoid changing semantics. + - Prefer explicit, searchable phrasing (entities, dates, units). + - Make each rewrite meaningfully distinct. + - Return exactly N items. + """ + + history = dspy.InputField(desc="Recent conversation history (turns).") # type: ignore + question = dspy.InputField(desc="The user's latest question to refine.") # type: ignore + n = dspy.InputField(desc="Number of rewrites to produce (N).") # type: ignore + + rewrites: List[str] = dspy.OutputField( # type: ignore + desc="Exactly N refined variations of the question, each a single sentence." 
+ ) + + +def _coerce_to_list(value: Any) -> list[str]: + """Coerce model output into a list[str] safely.""" + if isinstance(value, list): + # Ensure elements are strings + return [str(x).strip() for x in value if str(x).strip()] # type: ignore + if isinstance(value, str): + lines = [ln.strip() for ln in value.splitlines() if ln.strip()] + cleaned: list[str] = [] + for ln in lines: + s = ln.lstrip("•*-—-").strip() + while s and (s[0].isdigit() or s[0] in ".)]"): + s = s[1:].lstrip() + if s: + cleaned.append(s) + return cleaned + return [] + + +def _dedupe_keep_order(items: Iterable[str], limit: int) -> list[str]: + """Deduplicate case-insensitively, keep order, truncate to limit.""" + seen: set[str] = set() + out: list[str] = [] + for it in items: + key = it.strip().rstrip(".").lower() + if key and key not in seen: + seen.add(key) + out.append(it.strip().rstrip(".")) + if len(out) >= limit: + break + return out + + +def _validate_inputs(question: str, n: int) -> None: + """Validate inputs with clear errors (Sonar: no magic, explicit checks).""" + if not isinstance(question, str) or not question.strip(): # type: ignore + raise ValueError("`question` must be a non-empty string.") + if not isinstance(n, int) or n <= 0: # type: ignore + raise ValueError("`n` must be a positive integer.") + + +def _is_history_like(history: Any) -> bool: + """Accept dspy.History or list[{'role': str, 'content': str}] to stay flexible.""" + + if hasattr(history, "messages"): # likely a dspy.History + return True + if isinstance(history, Sequence): + return all( + isinstance(m, Mapping) + and "role" in m + and "content" in m + and isinstance(m["role"], str) + and isinstance(m["content"], str) + for m in history # type: ignore[assignment] + ) + return False + + +class PromptRefinerAgent(dspy.Module): + """Config-driven Prompt Refiner that emits N rewrites from history + question. + + This module uses the LLMManager to access configured providers and configures + DSPy globally via the manager's configure_dspy method. + + Parameters + ---------- + config_path : str, optional + Path to the YAML configuration file. If None, uses default config. + provider : LLLProvider, optional + Specific provider to use. If None, uses default provider from config. + default_n : int + Fallback number of rewrites when `n` not provided in `forward`. + llm_manager : LLMManager, optional + Existing LLMManager instance to reuse. If provided, config_path is ignored. 
+ """ + + def __init__( + self, + config_path: Optional[str] = None, + provider: Optional[LLMProvider] = None, + default_n: int = 5, + llm_manager: Optional[LLMManager] = None, + ) -> None: + super().__init__() # type: ignore + if default_n <= 0: + raise ValueError("`default_n` must be a positive integer.") + + self._default_n = int(default_n) + + # Use existing LLMManager if provided, otherwise create new one + if llm_manager is not None: + self._manager = llm_manager + LOGGER.debug("PromptRefinerAgent using provided LLMManager instance.") + else: + self._manager = LLMManager(config_path) + LOGGER.debug("PromptRefinerAgent created new LLMManager instance.") + + self._manager.configure_dspy(provider) + + provider_info = self._manager.get_provider_info(provider) + LOGGER.debug( + "PromptRefinerAgent configured with provider '%s'.", + provider_info.get("provider", "unknown"), + ) + + # Use ChainOfThought for better reasoning before output fields + self._predictor = dspy.ChainOfThought(PromptRefineSig) + + def forward( + self, + history: Sequence[Mapping[str, str]] | Any, + question: str, + n: int | None = None, + ) -> list[str]: + """Return up to N refined variants (exactly N when possible). + + `history` can be a DSPy History or a list of {role, content}. + """ + k = int(n) if n is not None else self._default_n + _validate_inputs(question, k) + + if not _is_history_like(history): + raise ValueError( + "`history` must be a dspy.History or a sequence of {'role','content'}." + ) + + # Primary prediction + result = self._predictor(history=history, question=question, n=k) + rewrites = _coerce_to_list(getattr(result, "rewrites", [])) + deduped = _dedupe_keep_order(rewrites, k) + + if len(deduped) == k: + return deduped + + # If short, ask for a few more variants to top up + missing = k - len(deduped) + if missing > 0: + follow = self._predictor( + history=history, + question=f"Create {missing} additional, *new* paraphrases of: {question}", + n=missing, + ) + extra = _coerce_to_list(getattr(follow, "rewrites", [])) + combined = _dedupe_keep_order(deduped + extra, k) + return combined + + return deduped + + def forward_structured( + self, + history: Sequence[Mapping[str, str]] | Any, + question: str, + n: int | None = None, + ) -> Dict[str, Any]: + """Return structured output with original question and refined variants. + + Returns dictionary in format: + { + "original_question": "original question text", + "refined_questions": ["variant1", "variant2", ...] 
+ } + + Args: + history: Conversation history (DSPy History or list of {role, content}) + question: Original user question to refine + n: Number of variants to generate (uses default_n if None) + + Returns: + Dictionary with original_question and refined_questions + """ + # Get refined variants using existing forward method + refined_variants = self.forward(history, question, n) + + # Return structured format + return {"original_question": question, "refined_questions": refined_variants} diff --git a/test_api.py b/test_api.py new file mode 100644 index 0000000..a950f3f --- /dev/null +++ b/test_api.py @@ -0,0 +1,89 @@ +"""Test script for the LLM Orchestration Service API.""" + +import json +import requests + + +def test_api(): + """Test the orchestration API endpoint.""" + # API endpoint + url = "http://localhost:8100/orchestrate" + + # Test request payload + test_payload = { + "chatId": "chat-12345", + "message": "I need help with my electricity bill.", + "authorId": "12345", + "conversationHistory": [ + { + "authorRole": "user", + "message": "Hi, I have a billing issue", + "timestamp": "2025-04-29T09:00:00Z", + }, + { + "authorRole": "bot", + "message": "Sure, can you tell me more about the issue?", + "timestamp": "2025-04-29T09:00:05Z", + }, + ], + "url": "id.ee", + "environment": "development", + "connection_id": "test-connection-123", + } + + try: + print("Testing /orchestrate endpoint...") + print(f"Request payload: {json.dumps(test_payload, indent=2)}") + + # Make the request + response = requests.post(url, json=test_payload, timeout=30) + + print(f"\nResponse Status: {response.status_code}") + print(f"Response Headers: {dict(response.headers)}") + + if response.status_code == 200: + response_data = response.json() + print(f"Response Body: {json.dumps(response_data, indent=2)}") + print("✅ API test successful!") + else: + print(f"❌ API test failed with status: {response.status_code}") + print(f"Error: {response.text}") + + except requests.exceptions.ConnectionError: + print( + "❌ Could not connect to API. 
Make sure the server is running on port 8100" + ) + print( + "Run: uv run uvicorn src.llm_orchestration_service_api:app --host 0.0.0.0 --port 8100" + ) + except Exception as e: + print(f"❌ Error during API test: {str(e)}") + + +def test_health_check(): + """Test the health check endpoint.""" + try: + print("\nTesting /health endpoint...") + response = requests.get("http://localhost:8100/health", timeout=10) + + if response.status_code == 200: + print(f"Health check response: {response.json()}") + print("✅ Health check successful!") + else: + print(f"❌ Health check failed: {response.status_code}") + + except requests.exceptions.ConnectionError: + print("❌ Could not connect to health endpoint") + except Exception as e: + print(f"❌ Health check error: {str(e)}") + + +if __name__ == "__main__": + print("LLM Orchestration Service API Test") + print("=" * 50) + + test_health_check() + test_api() + + print("\n" + "=" * 50) + print("Test completed!") diff --git a/test_integration.py b/test_integration.py new file mode 100644 index 0000000..1ed4baf --- /dev/null +++ b/test_integration.py @@ -0,0 +1,57 @@ +"""Test script for the prompt refiner integration.""" + +import sys +from pathlib import Path + +# Add src directory to Python path +src_path = Path(__file__).parent / "src" +sys.path.insert(0, str(src_path)) + +# Import after path setup +from models.request_models import OrchestrationRequest, ConversationItem # type: ignore[import-untyped] +from llm_orchestration_service import LLMOrchestrationService # type: ignore[import-untyped] + + +def test_integration(): + """Test the orchestration service with prompt refiner integration.""" + print("Testing LLM Orchestration Service with Prompt Refiner...") + + # Create test request + test_request = OrchestrationRequest( + chatId="test-chat-123", + message="I need help with my electricity bill payment.", + authorId="test-user", + conversationHistory=[ + ConversationItem( + authorRole="user", + message="Hello, I have a question about my bill", + timestamp="2025-09-11T10:00:00Z", + ), + ConversationItem( + authorRole="bot", + message="I'm here to help with your billing questions. 
What specific issue do you have?", + timestamp="2025-09-11T10:00:30Z", + ), + ], + url="gov.ee", + environment="development", + connection_id="test-conn-123", + ) + + try: + # Test the orchestration service + service = LLMOrchestrationService() + response = service.process_orchestration_request(test_request) + + print("✅ Integration test successful!") + print(f"Response: {response}") + + except Exception as e: + print(f"❌ Integration test failed: {str(e)}") + import traceback + + print(traceback.format_exc()) + + +if __name__ == "__main__": + test_integration() diff --git a/test_prompt_refiner_schema.py b/test_prompt_refiner_schema.py new file mode 100644 index 0000000..b6504ee --- /dev/null +++ b/test_prompt_refiner_schema.py @@ -0,0 +1,72 @@ +"""Test script to validate prompt refiner output schema.""" + +import sys +import json +from pathlib import Path + +# Add src directory to Python path +src_path = Path(__file__).parent / "src" +sys.path.insert(0, str(src_path)) + + +def test_prompt_refiner_schema(): + """Test the PromptRefinerOutput schema validation.""" + print("Testing PromptRefinerOutput Schema Validation...") + + try: + # Import after path setup + from models.request_models import PromptRefinerOutput # type: ignore[import-untyped] + + # Test valid data that matches your required format + valid_data = PromptRefinerOutput( + original_question="How do I configure Azure embeddings?", + refined_questions=[ + "Configure Azure OpenAI embedding endpoint", + "Set Azure embedding deployment name", + "Azure OpenAI embeddings API version requirements", + "Provide API key for Azure embedding generator", + "Azure OpenAI embedding configuration steps", + ], + ) + + print("✅ Schema validation successful!") + print(f"Original question: {valid_data.original_question}") + print(f"Number of refined questions: {len(valid_data.refined_questions)}") + print("\nRefined questions:") + for i, question in enumerate(valid_data.refined_questions, 1): + print(f" {i}. {question}") + + # Test JSON serialization + json_output = valid_data.model_dump() + print("\n✅ JSON serialization successful!") + print(f"JSON output:\n{json.dumps(json_output, indent=2)}") + + # Verify the exact format you requested + expected_keys = {"original_question", "refined_questions"} + actual_keys = set(json_output.keys()) + + if expected_keys == actual_keys: + print("✅ Output format matches exactly with required schema!") + else: + print(f"❌ Schema mismatch. 
Expected: {expected_keys}, Got: {actual_keys}") + return False + + return True + + except Exception as e: + print(f"❌ Schema validation failed: {str(e)}") + import traceback + + print(traceback.format_exc()) + return False + + +if __name__ == "__main__": + print("Prompt Refiner Output Schema Validation Test") + print("=" * 50) + success = test_prompt_refiner_schema() + print("\n" + "=" * 50) + if success: + print("✅ Schema validation test passed!") + else: + print("❌ Schema validation test failed!") diff --git a/tests/test_prompt_refiner.py b/tests/test_prompt_refiner.py new file mode 100644 index 0000000..dcdcf18 --- /dev/null +++ b/tests/test_prompt_refiner.py @@ -0,0 +1,322 @@ +import os +from pathlib import Path +import pytest +from typing import Dict, List + +from llm_config_module.llm_manager import LLMManager +from llm_config_module.types import LLMProvider +from prompt_refiner_module.prompt_refiner import PromptRefinerAgent + + +class TestPromptRefinerAgent: + """Test suite for PromptRefinerAgent functionality.""" + + @pytest.fixture + def config_path(self) -> str: + """Get path to llm_config.yaml.""" + cfg_path = ( + Path(__file__).parent.parent + / "src" + / "llm_config_module" + / "config" + / "llm_config.yaml" + ) + assert cfg_path.exists(), f"llm_config.yaml not found at {cfg_path}" + return str(cfg_path) + + @pytest.fixture + def sample_history(self) -> List[Dict[str, str]]: + """Sample conversation history for testing.""" + return [ + { + "role": "user", + "content": "What government services are available for healthcare?", + }, + { + "role": "assistant", + "content": "Government healthcare services include public hospitals, subsidized medical treatments, and health insurance programs like Medicaid and Medicare.", + }, + {"role": "user", "content": "Can you provide more details about Medicaid?"}, + ] + + @pytest.fixture + def empty_history(self) -> List[Dict[str, str]]: + """Empty conversation history for testing.""" + return [] + + def test_prompt_refiner_initialization_default(self, config_path: str) -> None: + """Test PromptRefinerAgent initialization with default settings.""" + agent = PromptRefinerAgent(config_path=config_path) + assert agent._default_n == 5 # type: ignore + assert agent._manager is not None # type: ignore + assert agent._predictor is not None # type: ignore + + def test_prompt_refiner_initialization_custom_n(self, config_path: str) -> None: + """Test PromptRefinerAgent initialization with custom default_n.""" + agent = PromptRefinerAgent(config_path=config_path, default_n=3) + assert agent._default_n == 3 # type: ignore + + def test_prompt_refiner_initialization_invalid_n(self, config_path: str) -> None: + """Test PromptRefinerAgent initialization with invalid default_n.""" + with pytest.raises(ValueError, match="`default_n` must be a positive integer"): + PromptRefinerAgent(config_path=config_path, default_n=0) + + with pytest.raises(ValueError, match="`default_n` must be a positive integer"): + PromptRefinerAgent(config_path=config_path, default_n=-1) + + def test_validation_empty_question( + self, config_path: str, sample_history: List[Dict[str, str]] + ) -> None: + """Test validation with empty question.""" + agent = PromptRefinerAgent(config_path=config_path) + + with pytest.raises(ValueError, match="`question` must be a non-empty string"): + agent.forward(sample_history, "", 3) + + with pytest.raises(ValueError, match="`question` must be a non-empty string"): + agent.forward(sample_history, " ", 3) + + def test_validation_invalid_n( + self, 
config_path: str, sample_history: List[Dict[str, str]] + ) -> None: + """Test validation with invalid n parameter.""" + agent = PromptRefinerAgent(config_path=config_path) + + with pytest.raises(ValueError, match="`n` must be a positive integer"): + agent.forward( + sample_history, + "What are the benefits of government housing programs?", + 0, + ) + + with pytest.raises(ValueError, match="`n` must be a positive integer"): + agent.forward( + sample_history, + "What are the benefits of government housing programs?", + -1, + ) + + def test_validation_invalid_history(self, config_path: str) -> None: + """Test validation with invalid history format.""" + agent = PromptRefinerAgent(config_path=config_path) + + with pytest.raises( + ValueError, match="`history` must be a dspy.History or a sequence" + ): + agent.forward("invalid_history", "What is AI?", 3) # type: ignore + + with pytest.raises( + ValueError, match="`history` must be a dspy.History or a sequence" + ): + agent.forward({"invalid": "format"}, "What is AI?", 3) # type: ignore + + @pytest.mark.skipif( + not any( + os.getenv(var) for var in ["AWS_ACCESS_KEY_ID", "AZURE_OPENAI_API_KEY"] + ), + reason="No LLM provider environment variables set", + ) + def test_prompt_refiner_with_history( + self, config_path: str, sample_history: List[Dict[str, str]] + ) -> None: + """Test prompt refiner with conversation history.""" + manager = LLMManager(config_path) + + # Find available provider + available_providers = manager.get_available_providers() + if not available_providers: + pytest.skip("No LLM providers available for testing") + + provider = next(iter(available_providers.keys())) + print(f"\n🔧 Testing with provider: {provider.value}") + + agent = PromptRefinerAgent( + config_path=config_path, provider=provider, default_n=3 + ) + + question = "How can I apply for unemployment benefits?" + rewrites = agent.forward(sample_history, question, n=3) + + # Validate output + assert isinstance(rewrites, list), "Output should be a list" + assert len(rewrites) <= 3, "Should return at most 3 rewrites" + assert len(rewrites) > 0, "Should return at least 1 rewrite" + + for rewrite in rewrites: + assert isinstance(rewrite, str), "Each rewrite should be a string" + assert len(rewrite.strip()) > 0, "Each rewrite should be non-empty" + + print(f"🤖 Original question: {question}") + print(f"📝 Generated {len(rewrites)} rewrites:") + for i, rewrite in enumerate(rewrites, 1): + print(f" {i}. {rewrite}") + + @pytest.mark.skipif( + not any( + os.getenv(var) for var in ["AWS_ACCESS_KEY_ID", "AZURE_OPENAI_API_KEY"] + ), + reason="No LLM provider environment variables set", + ) + def test_prompt_refiner_without_history( + self, config_path: str, empty_history: List[Dict[str, str]] + ) -> None: + """Test prompt refiner without conversation history.""" + manager = LLMManager(config_path) + + # Find available provider + available_providers = manager.get_available_providers() + if not available_providers: + pytest.skip("No LLM providers available for testing") + + provider = next(iter(available_providers.keys())) + + agent = PromptRefinerAgent( + config_path=config_path, provider=provider, default_n=2 + ) + + question = "What are the eligibility criteria for food assistance programs?" 
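+        # No prior turns are provided, so the refiner must work from the question text alone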
+ rewrites = agent.forward(empty_history, question, n=2) + + # Validate output + assert isinstance(rewrites, list), "Output should be a list" + assert len(rewrites) <= 2, "Should return at most 2 rewrites" + assert len(rewrites) > 0, "Should return at least 1 rewrite" + + for rewrite in rewrites: + assert isinstance(rewrite, str), "Each rewrite should be a string" + assert len(rewrite.strip()) > 0, "Each rewrite should be non-empty" + + print(f"🤖 Original question: {question}") + print(f"📝 Generated {len(rewrites)} rewrites (no history):") + for i, rewrite in enumerate(rewrites, 1): + print(f" {i}. {rewrite}") + + @pytest.mark.skipif( + not any( + os.getenv(var) for var in ["AWS_ACCESS_KEY_ID", "AZURE_OPENAI_API_KEY"] + ), + reason="No LLM provider environment variables set", + ) + def test_prompt_refiner_default_n( + self, config_path: str, sample_history: List[Dict[str, str]] + ) -> None: + """Test prompt refiner using default n value.""" + manager = LLMManager(config_path) + + # Find available provider + available_providers = manager.get_available_providers() + if not available_providers: + pytest.skip("No LLM providers available for testing") + + provider = next(iter(available_providers.keys())) + + agent = PromptRefinerAgent( + config_path=config_path, provider=provider, default_n=4 + ) + + question = "How does this technology impact society?" + # Don't specify n, should use default_n=4 + rewrites = agent.forward(sample_history, question) + + # Validate output + assert isinstance(rewrites, list), "Output should be a list" + assert len(rewrites) <= 4, "Should return at most 4 rewrites (default_n)" + assert len(rewrites) > 0, "Should return at least 1 rewrite" + + print(f"🤖 Original question: {question}") + print(f"📝 Generated {len(rewrites)} rewrites (using default_n=4):") + for i, rewrite in enumerate(rewrites, 1): + print(f" {i}. {rewrite}") + + @pytest.mark.skipif( + not any( + os.getenv(var) for var in ["AWS_ACCESS_KEY_ID", "AZURE_OPENAI_API_KEY"] + ), + reason="No LLM provider environment variables set", + ) + def test_prompt_refiner_single_rewrite( + self, config_path: str, sample_history: List[Dict[str, str]] + ) -> None: + """Test prompt refiner with n=1.""" + manager = LLMManager(config_path) + + # Find available provider + available_providers = manager.get_available_providers() + if not available_providers: + pytest.skip("No LLM providers available for testing") + + provider = next(iter(available_providers.keys())) + + agent = PromptRefinerAgent(config_path=config_path, provider=provider) + + question = "Tell me about deep learning." 
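+        # n=1 exercises the lower bound: exactly one refined variant is expected back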
+ rewrites = agent.forward(sample_history, question, n=1) + + # Validate output + assert isinstance(rewrites, list), "Output should be a list" + assert len(rewrites) == 1, "Should return exactly 1 rewrite" + assert isinstance(rewrites[0], str), "Rewrite should be a string" + assert len(rewrites[0].strip()) > 0, "Rewrite should be non-empty" + + print(f"🤖 Original question: {question}") + print(f"📝 Single rewrite: {rewrites[0]}") + + def test_prompt_refiner_with_specific_provider_aws( + self, config_path: str, sample_history: List[Dict[str, str]] + ) -> None: + """Test prompt refiner with specific AWS provider.""" + if not all( + os.getenv(v) + for v in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION"] + ): + pytest.skip("AWS environment variables not set") + + manager = LLMManager(config_path) + if not manager.is_provider_available(LLMProvider.AWS_BEDROCK): + pytest.skip("AWS Bedrock provider not available") + + agent = PromptRefinerAgent( + config_path=config_path, provider=LLMProvider.AWS_BEDROCK, default_n=2 + ) + + question = "What are neural networks?" + rewrites = agent.forward(sample_history, question, n=2) + + assert isinstance(rewrites, list), "Output should be a list" + assert len(rewrites) <= 2, "Should return at most 2 rewrites" + assert len(rewrites) > 0, "Should return at least 1 rewrite" + + print(f"🤖 AWS Bedrock - Original: {question}") + print(f"📝 AWS Bedrock - Rewrites: {rewrites}") + + def test_prompt_refiner_with_specific_provider_azure( + self, config_path: str, sample_history: List[Dict[str, str]] + ) -> None: + """Test prompt refiner with specific Azure provider.""" + if not all( + os.getenv(v) + for v in [ + "AZURE_OPENAI_API_KEY", + "AZURE_OPENAI_ENDPOINT", + "AZURE_OPENAI_DEPLOYMENT_NAME", + ] + ): + pytest.skip("Azure environment variables not set") + + manager = LLMManager(config_path) + if not manager.is_provider_available(LLMProvider.AZURE_OPENAI): + pytest.skip("Azure OpenAI provider not available") + + agent = PromptRefinerAgent( + config_path=config_path, provider=LLMProvider.AZURE_OPENAI, default_n=3 + ) + + question = "Explain computer vision applications." 
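+        # The agent was constructed with provider=LLMProvider.AZURE_OPENAI, so this call exercises the Azure path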
+ rewrites = agent.forward(sample_history, question, n=3) + + assert isinstance(rewrites, list), "Output should be a list" + assert len(rewrites) <= 3, "Should return at most 3 rewrites" + assert len(rewrites) > 0, "Should return at least 1 rewrite" + + print(f"🤖 Azure OpenAI - Original: {question}") + print(f"📝 Azure OpenAI - Rewrites: {rewrites}") diff --git a/uv.lock b/uv.lock index e3f1c7d..c909f56 100644 --- a/uv.lock +++ b/uv.lock @@ -403,6 +403,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e3/4f/58e7dce7985b35f98fcaba7b366de5baaf4637bc0811be66df4025c1885f/dspy-3.0.3-py3-none-any.whl", hash = "sha256:d19cc38ab3ec7edcb3db56a3463a606268dd2e83280595062b052bcfe0cfd24f", size = 261742, upload-time = "2025-08-31T18:49:30.129Z" }, ] +[[package]] +name = "fastapi" +version = "0.116.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "starlette" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/78/d7/6c8b3bfe33eeffa208183ec037fee0cce9f7f024089ab1c5d12ef04bd27c/fastapi-0.116.1.tar.gz", hash = "sha256:ed52cbf946abfd70c5a0dccb24673f0670deeb517a88b3544d03c2a6bf283143", size = 296485, upload-time = "2025-07-11T16:22:32.057Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/47/d63c60f59a59467fda0f93f46335c9d18526d7071f025cb5b89d5353ea42/fastapi-0.116.1-py3-none-any.whl", hash = "sha256:c46ac7c312df840f0c9e220f7964bada936781bc4e2e6eb71f1c4d7553786565", size = 95631, upload-time = "2025-07-11T16:22:30.485Z" }, +] + [[package]] name = "fastuuid" version = "0.12.0" @@ -1154,6 +1168,7 @@ dependencies = [ { name = "azure-identity" }, { name = "boto3" }, { name = "dspy" }, + { name = "fastapi" }, { name = "hvac" }, { name = "loguru" }, { name = "numpy" }, @@ -1167,6 +1182,7 @@ dependencies = [ { name = "requests" }, { name = "ruff" }, { name = "testcontainers" }, + { name = "uvicorn" }, ] [package.metadata] @@ -1174,6 +1190,7 @@ requires-dist = [ { name = "azure-identity", specifier = ">=1.24.0" }, { name = "boto3", specifier = ">=1.40.25" }, { name = "dspy", specifier = ">=3.0.3" }, + { name = "fastapi", specifier = ">=0.116.1" }, { name = "hvac", specifier = ">=2.3.0" }, { name = "loguru", specifier = ">=0.7.3" }, { name = "numpy", specifier = ">=2.3.2" }, @@ -1187,6 +1204,7 @@ requires-dist = [ { name = "requests", specifier = ">=2.32.5" }, { name = "ruff", specifier = ">=0.12.12" }, { name = "testcontainers", specifier = ">=4.13.0" }, + { name = "uvicorn", specifier = ">=0.35.0" }, ] [[package]] @@ -1353,6 +1371,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b8/d9/13bdde6521f322861fab67473cec4b1cc8999f3871953531cf61945fad92/sqlalchemy-2.0.43-py3-none-any.whl", hash = "sha256:1681c21dd2ccee222c2fe0bef671d1aef7c504087c9c4e800371cfcc8ac966fc", size = 1924759, upload-time = "2025-08-11T15:39:53.024Z" }, ] +[[package]] +name = "starlette" +version = "0.47.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/15/b9/cc3017f9a9c9b6e27c5106cc10cc7904653c3eec0729793aec10479dd669/starlette-0.47.3.tar.gz", hash = "sha256:6bc94f839cc176c4858894f1f8908f0ab79dfec1a6b8402f6da9be26ebea52e9", size = 2584144, upload-time = "2025-08-24T13:36:42.122Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ce/fd/901cfa59aaa5b30a99e16876f11abe38b59a1a2c51ffb3d7142bb6089069/starlette-0.47.3-py3-none-any.whl", hash = 
"sha256:89c0778ca62a76b826101e7c709e70680a1699ca7da6b44d38eb0a7e61fe4b51", size = 72991, upload-time = "2025-08-24T13:36:40.887Z" }, +] + [[package]] name = "tenacity" version = "9.1.2" @@ -1463,6 +1494,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" }, ] +[[package]] +name = "uvicorn" +version = "0.35.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5e/42/e0e305207bb88c6b8d3061399c6a961ffe5fbb7e2aa63c9234df7259e9cd/uvicorn-0.35.0.tar.gz", hash = "sha256:bc662f087f7cf2ce11a1d7fd70b90c9f98ef2e2831556dd078d131b96cc94a01", size = 78473, upload-time = "2025-06-28T16:15:46.058Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/e2/dc81b1bd1dcfe91735810265e9d26bc8ec5da45b4c0f6237e286819194c3/uvicorn-0.35.0-py3-none-any.whl", hash = "sha256:197535216b25ff9b785e29a0b79199f55222193d47f820816e7da751e9bc8d4a", size = 66406, upload-time = "2025-06-28T16:15:44.816Z" }, +] + [[package]] name = "virtualenv" version = "20.34.0" From e7382d101fc9716f45f611bb1e87d5c9b72624bc Mon Sep 17 00:00:00 2001 From: nuwangeek Date: Mon, 15 Sep 2025 12:34:39 +0530 Subject: [PATCH 2/3] integrate prompt refiner with llm_config_module --- API_README.md | 136 -------- Dockerfile.llm_orchestration_service | 74 +--- LLM_ORCHESTRATION_SERVICE_API_README.md | 241 +++++++++++++ build-llm-service.sh | 57 ---- docker-compose.llm-dev.yml | 33 -- docker-compose.yml | 8 +- run_api.py | 43 --- src/llm_config_module/config/llm_config.yaml | 4 +- src/llm_config_module/config/loader.py | 26 +- src/llm_config_module/llm_manager.py | 5 +- .../providers/aws_bedrock.py | 36 +- .../providers/azure_openai.py | 36 +- src/llm_config_module/providers/base.py | 28 +- src/llm_orchestration_service.py | 54 +-- src/llm_orchestration_service_api.py | 9 +- src/prompt_refiner_module/prompt_refiner.py | 62 ++-- test_api.py | 89 ----- test_integration.py | 57 ---- test_prompt_refiner_schema.py | 72 ---- tests/conftest.py | 34 +- tests/test_aws.py | 8 +- tests/test_azure.py | 8 +- tests/test_integration_vault_llm_config.py | 13 +- tests/test_llm_vault_integration.py | 2 +- tests/test_prompt_refiner.py | 322 ------------------ 25 files changed, 383 insertions(+), 1074 deletions(-) delete mode 100644 API_README.md create mode 100644 LLM_ORCHESTRATION_SERVICE_API_README.md delete mode 100644 build-llm-service.sh delete mode 100644 docker-compose.llm-dev.yml delete mode 100644 run_api.py delete mode 100644 test_api.py delete mode 100644 test_integration.py delete mode 100644 test_prompt_refiner_schema.py delete mode 100644 tests/test_prompt_refiner.py diff --git a/API_README.md b/API_README.md deleted file mode 100644 index 2f67761..0000000 --- a/API_README.md +++ /dev/null @@ -1,136 +0,0 @@ -# LLM Orchestration Service API - -A FastAPI-based service for orchestrating LLM requests with configuration management and proper validation. - -## API Endpoints - -### POST /orchestrate -Processes LLM orchestration requests. 
- -**Request Body:** -```json -{ - "chatId": "chat-12345", - "message": "I need help with my electricity bill.", - "authorId": "12345", - "conversationHistory": [ - { - "authorRole": "user", - "message": "Hi, I have a billing issue", - "timestamp": "2025-04-29T09:00:00Z" - }, - { - "authorRole": "bot", - "message": "Sure, can you tell me more about the issue?", - "timestamp": "2025-04-29T09:00:05Z" - } - ], - "url": "id.ee", - "environment": "production|test|development", - "connection_id": "optional-connection-id" -} -``` - -**Response:** -```json -{ - "chatId": "chat-12345", - "llmServiceActive": true, - "questionOutOfLLMScope": false, - "inputGuardFailed": false, - "content": "This is a random answer payload.\n\nwith citations.\n\nReferences\n- https://gov.ee/sample1,\n- https://gov.ee/sample2" -} -``` - -### GET /health -Health check endpoint. - -**Response:** -```json -{ - "status": "healthy", - "service": "llm-orchestration-service" -} -``` - -## Running the API - -### Local Development: -```bash -uv run uvicorn src.llm_orchestration_service_api:app --host 0.0.0.0 --port 8100 --reload -``` - -### Docker (Standalone): -```bash -# Build and run with custom script -.\build-llm-service.bat run # Windows -./build-llm-service.sh run # Linux/Mac - -# Or manually -docker build -f Dockerfile.llm_orchestration_service -t llm-orchestration-service . -docker run -p 8100:8100 --env-file .env llm-orchestration-service -``` - -### Docker Compose (Production): -```bash -docker-compose up llm-orchestration-service -``` - -### Docker Compose (Development with hot reload): -```bash -docker-compose -f docker-compose.yml -f docker-compose.llm-dev.yml up llm-orchestration-service -``` - -### Test the API: -```bash -uv run python test_api.py -``` - -## Features - -- ✅ FastAPI with automatic OpenAPI documentation -- ✅ Pydantic validation for requests/responses -- ✅ Proper error handling and logging with Loguru -- ✅ Integration with existing LLM config module -- ✅ Type-safe implementation -- ✅ Health check endpoint -- 🔄 Hardcoded responses (TODO: Implement actual LLM pipeline) - -## Documentation - -When the server is running, visit: -- API docs: http://localhost:8100/docs -- ReDoc: http://localhost:8100/redoc - -## Architecture - -``` -┌─────────────────────────────────────────────────────────────┐ -│ FastAPI Application │ -│ (llm_orchestration_service_api.py) │ -└─────────────────────┬───────────────────────────────────────┘ - │ - ▼ -┌─────────────────────────────────────────────────────────────┐ -│ Business Logic Service │ -│ (llm_orchestration_service.py) │ -└─────────────────────┬───────────────────────────────────────┘ - │ - ▼ -┌─────────────────────────────────────────────────────────────┐ -│ LLM Config Module │ -│ (llm_manager.py) │ -└─────────────────────────────────────────────────────────────┘ -``` - -## TODO Items - -- [ ] Implement actual LLM processing pipeline -- [ ] Add input validation and guard checks -- [ ] Implement question scope validation -- [ ] Add proper citation generation -- [ ] Handle multi-tenant scenarios with connection_id -- [ ] Add authentication/authorization -- [ ] Add comprehensive error handling -- [ ] Add request/response logging -- [ ] Add metrics and monitoring diff --git a/Dockerfile.llm_orchestration_service b/Dockerfile.llm_orchestration_service index 7966747..5b65cfe 100644 --- a/Dockerfile.llm_orchestration_service +++ b/Dockerfile.llm_orchestration_service @@ -1,78 +1,22 @@ -# Dockerfile for LLM Orchestration Service -# Multi-stage build for optimized production 
image +FROM python:3.12-slim -# Stage 1: Build environment with uv -FROM python:3.12-slim AS builder - -# Set environment variables for uv -ENV UV_CACHE_DIR=/opt/uv-cache \ - UV_LINK_MODE=copy \ - UV_COMPILE_BYTECODE=1 \ - UV_PYTHON_DOWNLOADS=never - -# Install system dependencies for building -RUN apt-get update && apt-get install -y \ - build-essential \ - curl \ - && rm -rf /var/lib/apt/lists/* - -# Install uv using the official installer (as per CONTRIBUTING.md) -RUN curl -LsSf https://astral.sh/uv/install.sh | sh - -# Add uv to PATH -ENV PATH="/root/.cargo/bin:$PATH" - -# Set working directory -WORKDIR /app - -# Copy uv configuration files -COPY pyproject.toml uv.lock ./ - -# Install dependencies using uv -RUN uv sync --frozen --no-dev - -# Stage 2: Runtime environment -FROM python:3.12-slim AS runtime - -# Set environment variables -ENV PYTHONUNBUFFERED=1 \ - PYTHONDONTWRITEBYTECODE=1 \ - PATH="/app/.venv/bin:$PATH" \ - PYTHONPATH="/app/src" - -# Install runtime system dependencies RUN apt-get update && apt-get install -y \ curl \ && rm -rf /var/lib/apt/lists/* \ && apt-get clean -# Create non-root user for security -RUN groupadd -r appuser && useradd -r -g appuser appuser +COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/ -# Set working directory -WORKDIR /app +COPY . /app -# Copy virtual environment from builder stage -COPY --from=builder /app/.venv /app/.venv - -# Copy source code -COPY src/ src/ - -# Copy configuration files (will be mounted as volumes in production) -COPY src/llm_config_module/config/llm_config.yaml src/llm_config_module/config/ +WORKDIR /app -# Create logs directory -RUN mkdir -p logs && chown -R appuser:appuser /app +# Set Python path to include src directory +ENV PYTHONPATH="/app/src:$PYTHONPATH" -# Switch to non-root user -USER appuser +RUN uv sync --locked -# Expose the application port EXPOSE 8100 -# Health check using the FastAPI health endpoint -HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ - CMD curl -f http://localhost:8100/health || exit 1 - -# Default command to run the LLM orchestration service -CMD ["uvicorn", "src.llm_orchestration_service_api:app", "--host", "0.0.0.0", "--port", "8100"] +# Run the FastAPI app via uvicorn +CMD ["uv","run","uvicorn", "src.llm_orchestration_service_api:app", "--host", "0.0.0.0", "--port", "8100"] diff --git a/LLM_ORCHESTRATION_SERVICE_API_README.md b/LLM_ORCHESTRATION_SERVICE_API_README.md new file mode 100644 index 0000000..98e78b9 --- /dev/null +++ b/LLM_ORCHESTRATION_SERVICE_API_README.md @@ -0,0 +1,241 @@ +# LLM Orchestration Service API + +A FastAPI-based service for orchestrating LLM requests with configuration management, prompt refinement, and proper validation. + +## Overview + +The LLM Orchestration Service provides a unified API for processing user queries through a sophisticated pipeline that includes configuration management, prompt refinement, and LLM interaction. The service integrates multiple components to deliver intelligent responses with proper validation and error handling. 
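+
+For quick orientation, a minimal request against a locally running instance looks roughly like the sketch below (assuming the default port 8100 used throughout this document; field values are illustrative and the full request and response schemas are documented under API Endpoints further down):
+
+```bash
+curl -X POST http://localhost:8100/orchestrate \
+  -H "Content-Type: application/json" \
+  -d '{
+    "chatId": "chat-12345",
+    "message": "I need help with my electricity bill.",
+    "authorId": "12345",
+    "conversationHistory": [
+      {
+        "authorRole": "user",
+        "message": "Hi, I have a billing issue",
+        "timestamp": "2025-04-29T09:00:00Z"
+      }
+    ],
+    "url": "id.ee",
+    "environment": "development",
+    "connection_id": "optional-connection-id"
+  }'
+```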
+ +## Architecture & Data Flow + +``` +┌─────────────────────────────────────────────────────────────────────────────────────┐ +│ Client Request │ +│ POST /orchestrate │ +└─────────────────────────┬───────────────────────────────────────────────────────────┘ + │ OrchestrationRequest + ▼ +┌─────────────────────────────────────────────────────────────────────────────────────┐ +│ FastAPI Application │ +│ (llm_orchestration_service_api.py) │ +│ • Request validation with Pydantic │ +│ • Lifespan management │ +│ • Error handling & logging │ +└─────────────────────────┬───────────────────────────────────────────────────────────┘ + │ Validated Request + ▼ +┌─────────────────────────────────────────────────────────────────────────────────────┐ +│ Business Logic Service │ +│ (llm_orchestration_service.py) │ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 1: LLM Configuration Management │ │ +│ │ • Initialize LLMManager with environment context │ │ +│ │ • Load configuration from Vault (via llm_config_module) │ │ +│ │ • Select appropriate LLM provider (Azure OpenAI, AWS Bedrock, etc.) │ │ +│ └─────────────────────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 2: Prompt Refinement │ │ +│ │ • Create PromptRefinerAgent with LLMManager instance │ │ +│ │ • Convert conversation history to DSPy format │ │ +│ │ • Generate N distinct refined question variants │ │ +│ │ • Validate output with PromptRefinerOutput schema │ │ +│ └─────────────────────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 3: LLM Processing Pipeline (TODO) │ │ +│ │ • Input validation and guard checks │ │ +│ │ • Context preparation from conversation history │ │ +│ │ • Question scope validation │ │ +│ │ • LLM inference execution │ │ +│ │ • Response post-processing │ │ +│ │ • Citation generation │ │ +│ └─────────────────────────────────────────────────────────────────────────────┘ │ +└─────────────────────────┬───────────────────────────────────────────────────────────┘ + │ OrchestrationResponse + ▼ +┌─────────────────────────────────────────────────────────────────────────────────────┐ +│ Client Response │ +│ JSON with status flags │ +└─────────────────────────────────────────────────────────────────────────────────────┘ +``` + +## Component Integration + +### 1. LLM Configuration Module Reuse + +The `llm_config_module` is the core configuration management system that's reused throughout the orchestration flow: + +```python +# Initialization in orchestration service +self.llm_manager = LLMManager( + environment=environment, # production/test/development + connection_id=connection_id # tenant/client identifier +) +``` + +**Configuration Flow:** +1. **Vault Integration**: LLMManager connects to HashiCorp Vault using `rag_config_manager.vault.client` +2. **Schema Validation**: Configuration is validated against `llm_config_module.config.schema` +3. **Provider Selection**: Based on config, appropriate provider is selected (Azure OpenAI, AWS Bedrock) +4. **LLM Instance Creation**: Provider-specific LLM instances are created and cached + +### 2. 
Prompt Refiner Integration
+
+The prompt refiner reuses the same LLMManager instance for consistency:
+
+```python
+# Create refiner with shared configuration
+refiner = PromptRefinerAgent(llm_manager=self.llm_manager)
+
+# Generate structured refinement output
+refinement_result = refiner.forward_structured(
+    history=conversation_history,
+    question=original_message
+)
+```
+
+## API Endpoints
+
+### POST /orchestrate
+
+Processes LLM orchestration requests through the complete pipeline.
+
+**Input Schema** (`OrchestrationRequest`):
+```json
+{
+  "chatId": "string - Unique chat session identifier",
+  "message": "string - User's input message",
+  "authorId": "string - User/author identifier",
+  "conversationHistory": [
+    {
+      "authorRole": "user|bot|assistant",
+      "message": "string - Message content",
+      "timestamp": "ISO 8601 datetime string"
+    }
+  ],
+  "url": "string - Context URL (e.g., 'id.ee')",
+  "environment": "production|test|development",
+  "connection_id": "string (optional) - Tenant/connection identifier"
+}
+```
+
+**Output Schema** (`OrchestrationResponse`):
+```json
+{
+  "chatId": "string - Same as input",
+  "llmServiceActive": "boolean - Whether LLM processing succeeded",
+  "questionOutOfLLMScope": "boolean - Whether question is out of scope",
+  "inputGuardFailed": "boolean - Whether input validation failed",
+  "content": "string - Response content with citations"
+}
+```
+
+**Prompt Refiner Output Schema** (`PromptRefinerOutput`):
+```json
+{
+  "original_question": "string - The original user question",
+  "refined_questions": [
+    "string - Refined variant 1",
+    "string - Refined variant 2",
+    "string - Refined variant N"
+  ]
+}
+```
+
+### GET /health
+Health check endpoint for monitoring service availability.
+
+**Response:**
+```json
+{
+  "status": "healthy",
+  "service": "llm-orchestration-service"
+}
+```
+
+## Running the API
+
+### Local Development:
+```bash
+uv run uvicorn src.llm_orchestration_service_api:app --host 0.0.0.0 --port 8100 --reload
+```
+
+### Docker (Standalone):
+```bash
+# Build and run with custom script
+.\build-llm-service.bat run   # Windows
+./build-llm-service.sh run    # Linux/Mac
+
+# Or manually
+docker build -f Dockerfile.llm_orchestration_service -t llm-orchestration-service . 
+docker run -p 8100:8100 --env-file .env llm-orchestration-service +``` + +### Docker Compose (Production): +```bash +docker-compose up llm-orchestration-service +``` + +### Docker Compose (Development with hot reload): +```bash +docker-compose -f docker-compose.yml -f docker-compose.llm-dev.yml up llm-orchestration-service +``` + +### Test the API: +```bash +uv run python test_api.py +``` + +## Features + +- ✅ FastAPI with automatic OpenAPI documentation +- ✅ Pydantic validation for requests/responses +- ✅ Proper error handling and logging with Loguru +- ✅ Integration with existing LLM config module +- ✅ Type-safe implementation +- ✅ Health check endpoint +- 🔄 Hardcoded responses (TODO: Implement actual LLM pipeline) + +## Documentation + +When the server is running, visit: +- API docs: http://localhost:8100/docs +- ReDoc: http://localhost:8100/redoc + +## Architecture + +``` +┌─────────────────────────────────────────────────────────────┐ +│ FastAPI Application │ +│ (llm_orchestration_service_api.py) │ +└─────────────────────┬───────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Business Logic Service │ +│ (llm_orchestration_service.py) │ +└─────────────────────┬───────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ LLM Config Module │ +│ (llm_manager.py) │ +└─────────────────────────────────────────────────────────────┘ +``` + +## TODO Items + +- [ ] Implement actual LLM processing pipeline +- [ ] Add input validation and guard checks +- [ ] Implement question scope validation +- [ ] Add proper citation generation +- [ ] Handle multi-tenant scenarios with connection_id +- [ ] Add authentication/authorization +- [ ] Add comprehensive error handling +- [ ] Add request/response logging +- [ ] Add metrics and monitoring diff --git a/build-llm-service.sh b/build-llm-service.sh deleted file mode 100644 index 4a918dd..0000000 --- a/build-llm-service.sh +++ /dev/null @@ -1,57 +0,0 @@ -#!/bin/bash - -# Build and run script for LLM Orchestration Service Docker container - -set -e - -echo "🐳 Building LLM Orchestration Service Docker container..." - -# Build the Docker image -docker build -f Dockerfile.llm_orchestration_service -t llm-orchestration-service:latest . - -echo "✅ Docker image built successfully!" - -# Check if we should run the container -if [ "$1" = "run" ]; then - echo "🚀 Starting LLM Orchestration Service container..." - - # Stop and remove existing container if it exists - docker stop llm-orchestration-service 2>/dev/null || true - docker rm llm-orchestration-service 2>/dev/null || true - - # Run the container - docker run -d \ - --name llm-orchestration-service \ - --network bykstack \ - -p 8100:8100 \ - --env-file .env \ - -e ENVIRONMENT=development \ - -v "$(pwd)/src/llm_config_module/config:/app/src/llm_config_module/config:ro" \ - -v llm_orchestration_logs:/app/logs \ - llm-orchestration-service:latest - - echo "✅ LLM Orchestration Service is running!" - echo "🌐 API available at: http://localhost:8100" - echo "🔍 Health check: http://localhost:8100/health" - echo "📊 API docs: http://localhost:8100/docs" - - # Show logs - echo "" - echo "📋 Container logs (Ctrl+C to stop viewing logs):" - docker logs -f llm-orchestration-service - -elif [ "$1" = "compose" ]; then - echo "🚀 Starting with Docker Compose..." 
- docker-compose up --build llm-orchestration-service - -else - echo "" - echo "📖 Usage:" - echo " $0 - Build the Docker image only" - echo " $0 run - Build and run the container standalone" - echo " $0 compose - Build and run with docker-compose" - echo "" - echo "🌐 Once running, the API will be available at:" - echo " Health check: http://localhost:8100/health" - echo " API docs: http://localhost:8100/docs" -fi diff --git a/docker-compose.llm-dev.yml b/docker-compose.llm-dev.yml deleted file mode 100644 index 8224ac5..0000000 --- a/docker-compose.llm-dev.yml +++ /dev/null @@ -1,33 +0,0 @@ -# Docker Compose override for LLM Orchestration Service development -# Use: docker-compose -f docker-compose.yml -f docker-compose.llm-dev.yml up - -version: '3.8' - -services: - llm-orchestration-service: - build: - context: . - dockerfile: Dockerfile.llm_orchestration_service - target: runtime - environment: - - ENVIRONMENT=development - - PYTHONPATH=/app/src - volumes: - # Mount source code for development (hot reload if needed) - - ./src:/app/src - # Mount configuration files - - ./src/llm_config_module/config:/app/src/llm_config_module/config:ro - # Mount logs for easier debugging - - ./logs:/app/logs - command: > - uvicorn src.llm_orchestration_service_api:app - --host 0.0.0.0 - --port 8100 - --reload - --reload-dir /app/src - ports: - - "8100:8100" - depends_on: - - vault - networks: - - bykstack diff --git a/docker-compose.yml b/docker-compose.yml index 1aace95..7f74068 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -244,7 +244,7 @@ services: - "8200:8200" # UI & API environment: - VAULT_ADDR=http://0.0.0.0:8200 - - VAULT_API_ADDR=http://localhost:8200 + - VAULT_API_ADDR=http://vault:8200 - VAULT_DEV_ROOT_TOKEN_ID=myroot - VAULT_DEV_LISTEN_ADDRESS=0.0.0.0:8200 volumes: @@ -264,14 +264,14 @@ services: context: . dockerfile: Dockerfile.llm_orchestration_service container_name: llm-orchestration-service - restart: unless-stopped + restart: always ports: - "8100:8100" env_file: - .env environment: - ENVIRONMENT=production - - PYTHONPATH=/app/src + - VAULT_ADDR=http://vault:8200 volumes: # Mount configuration files - ./src/llm_config_module/config:/app/src/llm_config_module/config:ro @@ -282,7 +282,7 @@ services: depends_on: - vault healthcheck: - test: ["CMD", "curl", "-f", "http://localhost:8100/health"] + test: ["CMD", "curl", "-f", "http://llm-orchestration-service:8100/health"] interval: 30s timeout: 10s start_period: 40s diff --git a/run_api.py b/run_api.py deleted file mode 100644 index 5585b97..0000000 --- a/run_api.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Run script for LLM Orchestration Service API.""" - -import sys -import os -from pathlib import Path - -# Add src directory to Python path -src_path = Path(__file__).parent / "src" -sys.path.insert(0, str(src_path)) - -if __name__ == "__main__": - try: - import uvicorn # type: ignore[import-untyped] - - print("Starting LLM Orchestration Service API on port 8100...") - print(f"Source path: {src_path}") - - # Change to src directory and run - os.chdir(src_path) - - uvicorn.run( # type: ignore[attr-defined] - "llm_orchestration_service_api:app", - host="0.0.0.0", - port=8100, - reload=True, - log_level="info", - ) - - except ImportError: - print("uvicorn not installed. Please install dependencies first.") - print("Commands to run the API:") - print("1. From project root:") - print( - " cd src && uv run uvicorn llm_orchestration_service_api:app --host 0.0.0.0 --port 8100 --reload" - ) - print("2. 
Or use this script:") - print(" uv run python run_api.py") - except Exception as e: - print(f"Error starting server: {e}") - print("\nAlternative commands to try:") - print( - "cd src && uv run uvicorn llm_orchestration_service_api:app --host 0.0.0.0 --port 8100 --reload" - ) diff --git a/src/llm_config_module/config/llm_config.yaml b/src/llm_config_module/config/llm_config.yaml index 949230d..250a150 100644 --- a/src/llm_config_module/config/llm_config.yaml +++ b/src/llm_config_module/config/llm_config.yaml @@ -1,7 +1,7 @@ llm: # Vault Configuration vault: - url: "${VAULT_ADDR:http://localhost:8200}" + url: "${VAULT_ADDR:http://vault:8200}" token: "${VAULT_TOKEN}" enabled: true @@ -24,7 +24,7 @@ llm: # AWS Bedrock Configuration aws_bedrock: # enabled: true # Enable AWS Bedrock for testing - model: "anthropic.claude-3-haiku-20240307-v1:0" + model: "eu.anthropic.claude-3-haiku-20240307-v1:0" max_tokens: 4096 temperature: 0.7 cache: true # Keep caching enabled (DSPY default) diff --git a/src/llm_config_module/config/loader.py b/src/llm_config_module/config/loader.py index 0b88a63..6046863 100644 --- a/src/llm_config_module/config/loader.py +++ b/src/llm_config_module/config/loader.py @@ -3,7 +3,7 @@ import os import re from pathlib import Path -from typing import Any, Dict, List, Optional, cast +from typing import Any, Dict, List, Optional, Union, cast import yaml from dotenv import load_dotenv @@ -23,6 +23,9 @@ # Constants DEFAULT_CONFIG_FILENAME = "llm_config.yaml" +# Type alias for configuration values that can be processed +ConfigValue = Union[str, Dict[str, "ConfigValue"], List["ConfigValue"], int, float, bool, None] + class ConfigurationLoader: """Loads and processes LLM configuration from YAML files with environment variable support.""" @@ -368,7 +371,7 @@ def _process_environment_variables(self, config: Dict[str, Any]) -> Dict[str, An Configuration with environment variables substituted. """ - def substitute_env_vars(obj: Any) -> Any: + def substitute_env_vars(obj: ConfigValue) -> ConfigValue: if isinstance(obj, str): # Pattern to match ${VAR_NAME} or ${VAR_NAME:default_value} pattern = r"\$\{([^}:]+)(?::([^}]*))?\}" @@ -380,19 +383,26 @@ def replace_env_var(match: re.Match[str]) -> str: return re.sub(pattern, replace_env_var, obj) elif isinstance(obj, dict): - result: Dict[str, Any] = {} - for key, value in obj.items(): # type: ignore[misc] - result[str(key)] = substitute_env_vars(value) # type: ignore[arg-type] + result: Dict[str, ConfigValue] = {} + for key, value in obj.items(): + result[str(key)] = substitute_env_vars(value) return result elif isinstance(obj, list): - result_list: List[Any] = [] - for item in obj: # type: ignore[misc] + result_list: List[ConfigValue] = [] + for item in obj: result_list.append(substitute_env_vars(item)) return result_list else: return obj - return substitute_env_vars(config) + result = substitute_env_vars(config) + # Since we know config is a Dict[str, Any] and substitute_env_vars preserves structure, + # the result should also be a Dict[str, Any] + if isinstance(result, dict): + return cast(Dict[str, Any], result) + else: + # This should never happen given our input type, but provide a fallback + raise ConfigurationError("Environment variable substitution resulted in non-dictionary type") def _parse_configuration(self, config: Dict[str, Any]) -> LLMConfiguration: """Parse the processed configuration into structured objects. 
diff --git a/src/llm_config_module/llm_manager.py b/src/llm_config_module/llm_manager.py index 462e532..bd3ec52 100644 --- a/src/llm_config_module/llm_manager.py +++ b/src/llm_config_module/llm_manager.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional -import dspy # type: ignore[import-untyped] +import dspy from llm_config_module.llm_factory import LLMFactory from llm_config_module.config.loader import ConfigurationLoader @@ -163,8 +163,7 @@ def configure_dspy(self, provider: Optional[LLMProvider] = None) -> None: provider: Optional specific provider to configure DSPY with. """ dspy_client = self.get_dspy_client(provider) - dspy.configure(lm=dspy_client) # type: ignore[attr-defined] - + dspy.configure(lm=dspy_client) def get_available_providers(self) -> Dict[LLMProvider, str]: """Get information about available providers. diff --git a/src/llm_config_module/providers/aws_bedrock.py b/src/llm_config_module/providers/aws_bedrock.py index 52ec7eb..642fab9 100644 --- a/src/llm_config_module/providers/aws_bedrock.py +++ b/src/llm_config_module/providers/aws_bedrock.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List -import dspy # type: ignore[import-untyped] +import dspy from llm_config_module.providers.base import BaseLLMProvider from llm_config_module.exceptions import ProviderInitializationError @@ -60,40 +60,6 @@ def initialize(self) -> None: f"Failed to initialize {self.provider_name} provider: {e}" ) from e - def generate(self, prompt: str, **kwargs: Any) -> str: - """Generate response from AWS Bedrock. - - Args: - prompt: The input prompt for the LLM. - **kwargs: Additional generation parameters. - - Returns: - Generated response text. - - Raises: - RuntimeError: If the provider is not initialized. - Exception: If generation fails. - """ - self._ensure_initialized() - - if self._client is None: - raise RuntimeError("Client is not initialized") - - try: - # Use DSPY's generate method - response = self._client.generate(prompt, **kwargs) # type: ignore[attr-defined] - - # Simple response handling - convert to string regardless of format - if isinstance(response, str): - return response - elif isinstance(response, list) and len(response) > 0: # type: ignore[arg-type] - return str(response[0]) # type: ignore[return-value] - else: - return str(response) # type: ignore[arg-type] - - except Exception as e: - raise RuntimeError(f"Failed to generate response: {e}") from e - def get_dspy_client(self) -> dspy.LM: """Return DSPY-compatible client. diff --git a/src/llm_config_module/providers/azure_openai.py b/src/llm_config_module/providers/azure_openai.py index 9fe0007..a27a1bc 100644 --- a/src/llm_config_module/providers/azure_openai.py +++ b/src/llm_config_module/providers/azure_openai.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List -import dspy # type: ignore[import-untyped] +import dspy from llm_config_module.providers.base import BaseLLMProvider from llm_config_module.exceptions import ProviderInitializationError @@ -63,40 +63,6 @@ def initialize(self) -> None: f"Failed to initialize {self.provider_name} provider: {e}" ) from e - def generate(self, prompt: str, **kwargs: Any) -> str: - """Generate response from Azure OpenAI. - - Args: - prompt: The input prompt for the LLM. - **kwargs: Additional generation parameters. - - Returns: - Generated response text. - - Raises: - RuntimeError: If the provider is not initialized. - Exception: If generation fails. 
- """ - self._ensure_initialized() - - if self._client is None: - raise RuntimeError("Client is not initialized") - - try: - # Use DSPY's generate method - response = self._client.generate(prompt, **kwargs) # type: ignore[attr-defined] - - # Simple response handling - convert to string regardless of format - if isinstance(response, str): - return response - elif isinstance(response, list) and len(response) > 0: # type: ignore[arg-type] - return str(response[0]) # type: ignore[return-value] - else: - return str(response) # type: ignore[arg-type] - - except Exception as e: - raise RuntimeError(f"Failed to generate response: {e}") from e - def get_dspy_client(self) -> dspy.LM: """Return DSPY-compatible client. diff --git a/src/llm_config_module/providers/base.py b/src/llm_config_module/providers/base.py index c6d4326..2a7d951 100644 --- a/src/llm_config_module/providers/base.py +++ b/src/llm_config_module/providers/base.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional -import dspy # type: ignore[import-untyped] +import dspy from llm_config_module.exceptions import InvalidConfigurationError @@ -37,23 +37,6 @@ def initialize(self) -> None: """ pass - @abstractmethod - def generate(self, prompt: str, **kwargs: Any) -> str: - """Generate response from the LLM. - - Args: - prompt: The input prompt for the LLM. - **kwargs: Additional generation parameters. - - Returns: - Generated response text. - - Raises: - RuntimeError: If the provider is not initialized. - Exception: If generation fails. - """ - pass - @abstractmethod def get_dspy_client(self) -> dspy.LM: """Return DSPY-compatible client. @@ -76,15 +59,6 @@ def provider_name(self) -> str: """ pass - @property - def is_initialized(self) -> bool: - """Check if the provider is initialized. - - Returns: - True if the provider is initialized, False otherwise. - """ - return self._initialized - def validate_config(self) -> None: """Validate provider configuration. diff --git a/src/llm_orchestration_service.py b/src/llm_orchestration_service.py index cc08995..55eaeba 100644 --- a/src/llm_orchestration_service.py +++ b/src/llm_orchestration_service.py @@ -43,7 +43,6 @@ def process_orchestration_request( ) # Initialize LLM Manager with configuration - # TODO: Remove hardcoded config path when proper configuration management is implemented self._initialize_llm_manager( environment=request.environment, connection_id=request.connection_id ) @@ -55,14 +54,10 @@ def process_orchestration_request( ) # TODO: Implement actual LLM processing pipeline - # This will include: - # 1. Input validation and guard checks - # 2. Context preparation from conversation history - # 3. LLM provider selection based on configuration - # 4. Question scope validation - # 5. LLM inference execution - # 6. Response post-processing - # 7. Citation generation + # 3. Chunk retriever + # 4. Re-ranker + # 5. Response Generator + # 6. 
Output Validator # For now, return hardcoded response response = self._generate_hardcoded_response(request.chatId) @@ -95,8 +90,6 @@ def _initialize_llm_manager( connection_id: Optional connection identifier """ try: - # TODO: Implement proper config path resolution based on environment - # TODO: Handle connection_id for multi-tenant scenarios logger.info(f"Initializing LLM Manager for environment: {environment}") self.llm_manager = LLMManager( @@ -118,18 +111,24 @@ def _refine_user_prompt( Args: original_message: The original user message to refine conversation_history: Previous conversation context + + Raises: + ValueError: When LLM Manager is not initialized + ValidationError: When prompt refinement output validation fails + Exception: For other prompt refinement failures """ - try: - logger.info("Starting prompt refinement process") + logger.info("Starting prompt refinement process") - if self.llm_manager is None: - logger.error("LLM Manager not initialized, cannot refine prompts") - return + # Check if LLM Manager is initialized + if self.llm_manager is None: + error_msg = "LLM Manager not initialized, cannot refine prompts" + logger.error(error_msg) + raise ValueError(error_msg) + try: # Convert conversation history to DSPy format history: List[Dict[str, str]] = [] for item in conversation_history: - # Map 'bot' to 'assistant' for consistency with standard chat formats role = "assistant" if item.authorRole == "bot" else item.authorRole history.append({"role": role, "content": item.message}) @@ -141,10 +140,19 @@ def _refine_user_prompt( history=history, question=original_message ) - # Validate the output schema using Pydantic - validated_output = PromptRefinerOutput(**refinement_result) + # Validate the output schema using Pydantic - this will raise ValidationError if invalid + try: + validated_output = PromptRefinerOutput(**refinement_result) + except Exception as validation_error: + logger.error( + f"Prompt refinement output validation failed: {str(validation_error)}" + ) + logger.error(f"Invalid refinement result: {refinement_result}") + raise ValueError( + f"Prompt refinement validation failed: {str(validation_error)}" + ) from validation_error + - # Log the complete structured output as JSON output_json = validated_output.model_dump() logger.info( f"Prompt refinement output: {json.dumps(output_json, indent=2)}" @@ -152,10 +160,12 @@ def _refine_user_prompt( logger.info("Prompt refinement completed successfully") + except ValueError: + raise except Exception as e: logger.error(f"Prompt refinement failed: {str(e)}") - logger.info(f"Continuing with original message: {original_message}") - # Don't raise exception - continue with original message + logger.error(f"Failed to refine message: {original_message}") + raise RuntimeError(f"Prompt refinement process failed: {str(e)}") from e def _generate_hardcoded_response(self, chat_id: str) -> OrchestrationResponse: """ diff --git a/src/llm_orchestration_service_api.py b/src/llm_orchestration_service_api.py index 93cf727..db8efdd 100644 --- a/src/llm_orchestration_service_api.py +++ b/src/llm_orchestration_service_api.py @@ -6,6 +6,7 @@ from fastapi import FastAPI, HTTPException, status from fastapi.responses import JSONResponse from loguru import logger +import uvicorn from llm_orchestration_service import LLMOrchestrationService from models.request_models import OrchestrationRequest, OrchestrationResponse @@ -104,14 +105,8 @@ async def global_exception_handler(request: object, exc: Exception) -> JSONRespo if __name__ == "__main__": - 
try: - import uvicorn # type: ignore[import-untyped] - except ImportError: - logger.error("uvicorn not installed. Please install with: pip install uvicorn") - raise - logger.info("Starting LLM Orchestration Service API server on port 8100") - uvicorn.run( # type: ignore[attr-defined] + uvicorn.run( "llm_orchestration_service_api:app", host="0.0.0.0", port=8100, diff --git a/src/prompt_refiner_module/prompt_refiner.py b/src/prompt_refiner_module/prompt_refiner.py index 80354b2..e8c4894 100644 --- a/src/prompt_refiner_module/prompt_refiner.py +++ b/src/prompt_refiner_module/prompt_refiner.py @@ -1,16 +1,25 @@ from __future__ import annotations -from typing import Any, Iterable, List, Mapping, Sequence, Optional, Dict +from typing import Any, Iterable, List, Mapping, Sequence, Optional, Dict, Union, Protocol import logging -import dspy # type: ignore +import dspy from llm_config_module import LLMManager, LLMProvider LOGGER = logging.getLogger(__name__) +# Protocol for DSPy History objects +class DSPyHistoryProtocol(Protocol): + messages: Any +DSPyOutput = Union[str, Sequence[str], Sequence[Any], None] +HistoryList = Sequence[Mapping[str, str]] +# Use Protocol for DSPy History objects instead of Any +HistoryLike = Union[HistoryList, DSPyHistoryProtocol] + +# 1. SIGNATURE: Defines the interface for the DSPy module class PromptRefineSig(dspy.Signature): """Produce N distinct, concise rewrites of the user's question using chat history. @@ -22,20 +31,19 @@ class PromptRefineSig(dspy.Signature): - Return exactly N items. """ - history = dspy.InputField(desc="Recent conversation history (turns).") # type: ignore - question = dspy.InputField(desc="The user's latest question to refine.") # type: ignore - n = dspy.InputField(desc="Number of rewrites to produce (N).") # type: ignore + history = dspy.InputField(desc="Recent conversation history (turns).") + question = dspy.InputField(desc="The user's latest question to refine.") + n = dspy.InputField(desc="Number of rewrites to produce (N).") - rewrites: List[str] = dspy.OutputField( # type: ignore + rewrites: List[str] = dspy.OutputField( desc="Exactly N refined variations of the question, each a single sentence." 
) - -def _coerce_to_list(value: Any) -> list[str]: +def _coerce_to_list(value: DSPyOutput) -> list[str]: """Coerce model output into a list[str] safely.""" - if isinstance(value, list): + if isinstance(value, (list, tuple)): # Handle sequences # Ensure elements are strings - return [str(x).strip() for x in value if str(x).strip()] # type: ignore + return [str(x).strip() for x in value if str(x).strip()] if isinstance(value, str): lines = [ln.strip() for ln in value.splitlines() if ln.strip()] cleaned: list[str] = [] @@ -65,29 +73,37 @@ def _dedupe_keep_order(items: Iterable[str], limit: int) -> list[str]: def _validate_inputs(question: str, n: int) -> None: """Validate inputs with clear errors (Sonar: no magic, explicit checks).""" - if not isinstance(question, str) or not question.strip(): # type: ignore + if not question.strip(): raise ValueError("`question` must be a non-empty string.") - if not isinstance(n, int) or n <= 0: # type: ignore + if n <= 0: raise ValueError("`n` must be a positive integer.") -def _is_history_like(history: Any) -> bool: +def _is_history_like(history: HistoryLike) -> bool: """Accept dspy.History or list[{'role': str, 'content': str}] to stay flexible.""" - if hasattr(history, "messages"): # likely a dspy.History + # Case 1: Object with `messages` attribute (e.g., dspy.History) + if hasattr(history, "messages"): return True - if isinstance(history, Sequence): - return all( - isinstance(m, Mapping) - and "role" in m - and "content" in m - and isinstance(m["role"], str) - and isinstance(m["content"], str) - for m in history # type: ignore[assignment] - ) + + # Case 2: Sequence of dict-like items + if isinstance(history, Sequence) and not isinstance(history, str): + return _validate_history_sequence(history) + return False +def _validate_history_sequence(history: Sequence[Mapping[str, str]]) -> bool: + """Helper function to validate history sequence structure.""" + try: + for item in history: + # Check if required keys exist + if "role" not in item or "content" not in item: + return False + return True + except (KeyError, TypeError): + return False +# 3. MODULE: Uses the signature + adds logic class PromptRefinerAgent(dspy.Module): """Config-driven Prompt Refiner that emits N rewrites from history + question. 
diff --git a/test_api.py b/test_api.py deleted file mode 100644 index a950f3f..0000000 --- a/test_api.py +++ /dev/null @@ -1,89 +0,0 @@ -"""Test script for the LLM Orchestration Service API.""" - -import json -import requests - - -def test_api(): - """Test the orchestration API endpoint.""" - # API endpoint - url = "http://localhost:8100/orchestrate" - - # Test request payload - test_payload = { - "chatId": "chat-12345", - "message": "I need help with my electricity bill.", - "authorId": "12345", - "conversationHistory": [ - { - "authorRole": "user", - "message": "Hi, I have a billing issue", - "timestamp": "2025-04-29T09:00:00Z", - }, - { - "authorRole": "bot", - "message": "Sure, can you tell me more about the issue?", - "timestamp": "2025-04-29T09:00:05Z", - }, - ], - "url": "id.ee", - "environment": "development", - "connection_id": "test-connection-123", - } - - try: - print("Testing /orchestrate endpoint...") - print(f"Request payload: {json.dumps(test_payload, indent=2)}") - - # Make the request - response = requests.post(url, json=test_payload, timeout=30) - - print(f"\nResponse Status: {response.status_code}") - print(f"Response Headers: {dict(response.headers)}") - - if response.status_code == 200: - response_data = response.json() - print(f"Response Body: {json.dumps(response_data, indent=2)}") - print("✅ API test successful!") - else: - print(f"❌ API test failed with status: {response.status_code}") - print(f"Error: {response.text}") - - except requests.exceptions.ConnectionError: - print( - "❌ Could not connect to API. Make sure the server is running on port 8100" - ) - print( - "Run: uv run uvicorn src.llm_orchestration_service_api:app --host 0.0.0.0 --port 8100" - ) - except Exception as e: - print(f"❌ Error during API test: {str(e)}") - - -def test_health_check(): - """Test the health check endpoint.""" - try: - print("\nTesting /health endpoint...") - response = requests.get("http://localhost:8100/health", timeout=10) - - if response.status_code == 200: - print(f"Health check response: {response.json()}") - print("✅ Health check successful!") - else: - print(f"❌ Health check failed: {response.status_code}") - - except requests.exceptions.ConnectionError: - print("❌ Could not connect to health endpoint") - except Exception as e: - print(f"❌ Health check error: {str(e)}") - - -if __name__ == "__main__": - print("LLM Orchestration Service API Test") - print("=" * 50) - - test_health_check() - test_api() - - print("\n" + "=" * 50) - print("Test completed!") diff --git a/test_integration.py b/test_integration.py deleted file mode 100644 index 1ed4baf..0000000 --- a/test_integration.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Test script for the prompt refiner integration.""" - -import sys -from pathlib import Path - -# Add src directory to Python path -src_path = Path(__file__).parent / "src" -sys.path.insert(0, str(src_path)) - -# Import after path setup -from models.request_models import OrchestrationRequest, ConversationItem # type: ignore[import-untyped] -from llm_orchestration_service import LLMOrchestrationService # type: ignore[import-untyped] - - -def test_integration(): - """Test the orchestration service with prompt refiner integration.""" - print("Testing LLM Orchestration Service with Prompt Refiner...") - - # Create test request - test_request = OrchestrationRequest( - chatId="test-chat-123", - message="I need help with my electricity bill payment.", - authorId="test-user", - conversationHistory=[ - ConversationItem( - authorRole="user", - message="Hello, I have a 
question about my bill", - timestamp="2025-09-11T10:00:00Z", - ), - ConversationItem( - authorRole="bot", - message="I'm here to help with your billing questions. What specific issue do you have?", - timestamp="2025-09-11T10:00:30Z", - ), - ], - url="gov.ee", - environment="development", - connection_id="test-conn-123", - ) - - try: - # Test the orchestration service - service = LLMOrchestrationService() - response = service.process_orchestration_request(test_request) - - print("✅ Integration test successful!") - print(f"Response: {response}") - - except Exception as e: - print(f"❌ Integration test failed: {str(e)}") - import traceback - - print(traceback.format_exc()) - - -if __name__ == "__main__": - test_integration() diff --git a/test_prompt_refiner_schema.py b/test_prompt_refiner_schema.py deleted file mode 100644 index b6504ee..0000000 --- a/test_prompt_refiner_schema.py +++ /dev/null @@ -1,72 +0,0 @@ -"""Test script to validate prompt refiner output schema.""" - -import sys -import json -from pathlib import Path - -# Add src directory to Python path -src_path = Path(__file__).parent / "src" -sys.path.insert(0, str(src_path)) - - -def test_prompt_refiner_schema(): - """Test the PromptRefinerOutput schema validation.""" - print("Testing PromptRefinerOutput Schema Validation...") - - try: - # Import after path setup - from models.request_models import PromptRefinerOutput # type: ignore[import-untyped] - - # Test valid data that matches your required format - valid_data = PromptRefinerOutput( - original_question="How do I configure Azure embeddings?", - refined_questions=[ - "Configure Azure OpenAI embedding endpoint", - "Set Azure embedding deployment name", - "Azure OpenAI embeddings API version requirements", - "Provide API key for Azure embedding generator", - "Azure OpenAI embedding configuration steps", - ], - ) - - print("✅ Schema validation successful!") - print(f"Original question: {valid_data.original_question}") - print(f"Number of refined questions: {len(valid_data.refined_questions)}") - print("\nRefined questions:") - for i, question in enumerate(valid_data.refined_questions, 1): - print(f" {i}. {question}") - - # Test JSON serialization - json_output = valid_data.model_dump() - print("\n✅ JSON serialization successful!") - print(f"JSON output:\n{json.dumps(json_output, indent=2)}") - - # Verify the exact format you requested - expected_keys = {"original_question", "refined_questions"} - actual_keys = set(json_output.keys()) - - if expected_keys == actual_keys: - print("✅ Output format matches exactly with required schema!") - else: - print(f"❌ Schema mismatch. 
Expected: {expected_keys}, Got: {actual_keys}") - return False - - return True - - except Exception as e: - print(f"❌ Schema validation failed: {str(e)}") - import traceback - - print(traceback.format_exc()) - return False - - -if __name__ == "__main__": - print("Prompt Refiner Output Schema Validation Test") - print("=" * 50) - success = test_prompt_refiner_schema() - print("\n" + "=" * 50) - if success: - print("✅ Schema validation test passed!") - else: - print("❌ Schema validation test failed!") diff --git a/tests/conftest.py b/tests/conftest.py index a806261..4b16978 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,10 +5,10 @@ import pytest from pathlib import Path from typing import Dict, Generator -from testcontainers.vault import VaultContainer # type: ignore -from testcontainers.core.wait_strategies import LogMessageWaitStrategy # type: ignore +from testcontainers.vault import VaultContainer +from testcontainers.core.wait_strategies import LogMessageWaitStrategy from loguru import logger -import hvac # type: ignore +import hvac # Add src directory to Python path @@ -38,7 +38,7 @@ def vault_container() -> Generator[VaultContainer, None, None]: def vault_client(vault_container: VaultContainer) -> hvac.Client: """Get the Vault client.""" vault_url = vault_container.get_connection_url() - return hvac.Client(url=vault_url, token=vault_container.root_token) # type: ignore + return hvac.Client(url=vault_url, token=vault_container.root_token) @pytest.fixture(scope="session") @@ -97,7 +97,7 @@ def populated_vault(vault_client: hvac.Client) -> None: for path, data in test_data.items(): try: - vault_client.secrets.kv.v2.create_or_update_secret( # type: ignore + vault_client.secrets.kv.v2.create_or_update_secret( path=path, secret=data ) logger.debug(f"Created test secret at {path}") @@ -112,9 +112,9 @@ def vault_env_vars( ) -> Generator[Dict[str, str], None, None]: """Set environment variables for Vault access.""" env_vars: Dict[str, str] = { - "VAULT_ADDR": vault_container.get_connection_url(), # type: ignore - "VAULT_URL": vault_container.get_connection_url(), # type: ignore - "VAULT_TOKEN": vault_container.root_token, # type: ignore + "VAULT_ADDR": vault_container.get_connection_url(), + "VAULT_URL": vault_container.get_connection_url(), + "VAULT_TOKEN": vault_container.root_token, "ENVIRONMENT": "production", } @@ -133,17 +133,17 @@ def reset_singletons() -> Generator[None, None, None]: """Reset singleton instances between tests.""" # Reset LLMManager - from llm_config_module.llm_manager import LLMManager + from src.llm_config_module.llm_manager import LLMManager if hasattr(LLMManager, "_instance"): - LLMManager._instance = None # type: ignore + LLMManager._instance = None # Reset VaultConnectionManager if available try: - from rag_config_manager.vault.connection_manager import VaultConnectionManager # type: ignore + from src.rag_config_manager.vault.connection_manager import ConnectionManager as VaultConnectionManager - if hasattr(VaultConnectionManager, "_instance"): # type: ignore - VaultConnectionManager._instance = None # type: ignore + if hasattr(VaultConnectionManager, "_instance"): + VaultConnectionManager._instance = None except ImportError: pass @@ -151,11 +151,11 @@ def reset_singletons() -> Generator[None, None, None]: # Clean up again after test if hasattr(LLMManager, "_instance"): - LLMManager._instance = None # type: ignore + LLMManager._instance = None try: - from rag_config_manager.vault.connection_manager import VaultConnectionManager # type: ignore + from 
src.rag_config_manager.vault.connection_manager import ConnectionManager as VaultConnectionManager - if hasattr(VaultConnectionManager, "_instance"): # type: ignore - VaultConnectionManager._instance = None # type: ignore + if hasattr(VaultConnectionManager, "_instance"): + VaultConnectionManager._instance = None except ImportError: pass diff --git a/tests/test_aws.py b/tests/test_aws.py index 00bcd41..c7b787f 100644 --- a/tests/test_aws.py +++ b/tests/test_aws.py @@ -1,5 +1,5 @@ import pytest -import dspy # type: ignore +import dspy from typing import Any, Dict from pathlib import Path from src.llm_config_module.llm_manager import LLMManager @@ -43,8 +43,8 @@ def test_aws_llm_inference(vault_env_vars: Dict[str, str]) -> None: class QA(dspy.Signature): """Short factual answer""" - question = dspy.InputField() # type: ignore - answer = dspy.OutputField() # type: ignore + question = dspy.InputField() + answer = dspy.OutputField() qa = dspy.Predict(QA) out = qa( @@ -54,7 +54,7 @@ class QA(dspy.Signature): print( "Question: If this pass through the AWS Bedrock provider, say 'AWS DSPY Configuration Successful'" ) - print(f"Answer: {out.answer}") # type: ignore + print(f"Answer: {out.answer}") # Type-safe assertions answer: Any = getattr(out, "answer", None) diff --git a/tests/test_azure.py b/tests/test_azure.py index 9869439..7174582 100644 --- a/tests/test_azure.py +++ b/tests/test_azure.py @@ -1,5 +1,5 @@ import pytest -import dspy # type: ignore +import dspy from typing import Any, Dict from pathlib import Path from src.llm_config_module.llm_manager import LLMManager @@ -43,8 +43,8 @@ def test_azure_llm_inference(vault_env_vars: Dict[str, str]) -> None: class QA(dspy.Signature): """Short factual answer""" - question = dspy.InputField() # type: ignore - answer = dspy.OutputField() # type: ignore + question = dspy.InputField() + answer = dspy.OutputField() qa = dspy.Predict(QA) out = qa( @@ -54,7 +54,7 @@ class QA(dspy.Signature): print( "Question: If this pass through the Azure OpenAI provider, say 'Azure DSPY Configuration Successful'" ) - print(f"Answer: {out.answer}") # type: ignore + print(f"Answer: {out.answer}") # Type-safe assertions answer: Any = getattr(out, "answer", None) diff --git a/tests/test_integration_vault_llm_config.py b/tests/test_integration_vault_llm_config.py index acdd592..20b581b 100644 --- a/tests/test_integration_vault_llm_config.py +++ b/tests/test_integration_vault_llm_config.py @@ -4,8 +4,8 @@ import pytest from pathlib import Path from typing import Dict -from llm_config_module.llm_manager import LLMManager -from llm_config_module.exceptions import ConfigurationError +from src.llm_config_module.llm_manager import LLMManager +from src.llm_config_module.exceptions import ConfigurationError class TestVaultIntegration: @@ -44,7 +44,7 @@ def test_development_environment_requires_connection_id( self, vault_env_vars: Dict[str, str] ) -> None: """Test that development environment requires connection_id.""" - with pytest.raises(ConfigurationError, match="connection_id is required"): + with pytest.raises(ConfigurationError, match=r".*connection_id is required.*development"): LLMManager( config_path=str(self.cfg_path), environment="development", @@ -81,7 +81,7 @@ def test_valid_connection_id_works(self, vault_env_vars: Dict[str, str]) -> None def test_invalid_connection_id_fails(self, vault_env_vars: Dict[str, str]) -> None: """Test that invalid connection_id causes failure.""" - with pytest.raises(ConfigurationError): + with pytest.raises(ConfigurationError, 
match=r".*(Connection not found|Failed to discover providers)"): LLMManager( config_path=str(self.cfg_path), environment="development", @@ -180,9 +180,6 @@ def test_vault_unavailable_fallback() -> None: original_values[var] = os.environ.get(var) del os.environ[var] - # Reset any singletons that might be carrying state from other tests - from llm_config_module.llm_manager import LLMManager - LLMManager.reset_instance() try: @@ -193,7 +190,7 @@ def test_vault_unavailable_fallback() -> None: # This should fail since vault is unreachable and token is empty with pytest.raises( ConfigurationError, - match="Vault URL and token must be provided|Failed to load LLM configuration|No production connections found|Connection refused|Failed to connect", + match=r".*(Vault URL and token must be provided|Failed to load LLM configuration|No production connections found|Connection refused|Failed to connect|must be provided.*configuration.*environment)", ): LLMManager(config_path=str(cfg_path), environment="production") diff --git a/tests/test_llm_vault_integration.py b/tests/test_llm_vault_integration.py index 5874810..bb2387a 100644 --- a/tests/test_llm_vault_integration.py +++ b/tests/test_llm_vault_integration.py @@ -10,7 +10,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent / "src")) from loguru import logger -from llm_config_module.llm_manager import LLMManager +from src.llm_config_module.llm_manager import LLMManager # Configure loguru diff --git a/tests/test_prompt_refiner.py b/tests/test_prompt_refiner.py deleted file mode 100644 index dcdcf18..0000000 --- a/tests/test_prompt_refiner.py +++ /dev/null @@ -1,322 +0,0 @@ -import os -from pathlib import Path -import pytest -from typing import Dict, List - -from llm_config_module.llm_manager import LLMManager -from llm_config_module.types import LLMProvider -from prompt_refiner_module.prompt_refiner import PromptRefinerAgent - - -class TestPromptRefinerAgent: - """Test suite for PromptRefinerAgent functionality.""" - - @pytest.fixture - def config_path(self) -> str: - """Get path to llm_config.yaml.""" - cfg_path = ( - Path(__file__).parent.parent - / "src" - / "llm_config_module" - / "config" - / "llm_config.yaml" - ) - assert cfg_path.exists(), f"llm_config.yaml not found at {cfg_path}" - return str(cfg_path) - - @pytest.fixture - def sample_history(self) -> List[Dict[str, str]]: - """Sample conversation history for testing.""" - return [ - { - "role": "user", - "content": "What government services are available for healthcare?", - }, - { - "role": "assistant", - "content": "Government healthcare services include public hospitals, subsidized medical treatments, and health insurance programs like Medicaid and Medicare.", - }, - {"role": "user", "content": "Can you provide more details about Medicaid?"}, - ] - - @pytest.fixture - def empty_history(self) -> List[Dict[str, str]]: - """Empty conversation history for testing.""" - return [] - - def test_prompt_refiner_initialization_default(self, config_path: str) -> None: - """Test PromptRefinerAgent initialization with default settings.""" - agent = PromptRefinerAgent(config_path=config_path) - assert agent._default_n == 5 # type: ignore - assert agent._manager is not None # type: ignore - assert agent._predictor is not None # type: ignore - - def test_prompt_refiner_initialization_custom_n(self, config_path: str) -> None: - """Test PromptRefinerAgent initialization with custom default_n.""" - agent = PromptRefinerAgent(config_path=config_path, default_n=3) - assert agent._default_n == 3 # type: 
ignore - - def test_prompt_refiner_initialization_invalid_n(self, config_path: str) -> None: - """Test PromptRefinerAgent initialization with invalid default_n.""" - with pytest.raises(ValueError, match="`default_n` must be a positive integer"): - PromptRefinerAgent(config_path=config_path, default_n=0) - - with pytest.raises(ValueError, match="`default_n` must be a positive integer"): - PromptRefinerAgent(config_path=config_path, default_n=-1) - - def test_validation_empty_question( - self, config_path: str, sample_history: List[Dict[str, str]] - ) -> None: - """Test validation with empty question.""" - agent = PromptRefinerAgent(config_path=config_path) - - with pytest.raises(ValueError, match="`question` must be a non-empty string"): - agent.forward(sample_history, "", 3) - - with pytest.raises(ValueError, match="`question` must be a non-empty string"): - agent.forward(sample_history, " ", 3) - - def test_validation_invalid_n( - self, config_path: str, sample_history: List[Dict[str, str]] - ) -> None: - """Test validation with invalid n parameter.""" - agent = PromptRefinerAgent(config_path=config_path) - - with pytest.raises(ValueError, match="`n` must be a positive integer"): - agent.forward( - sample_history, - "What are the benefits of government housing programs?", - 0, - ) - - with pytest.raises(ValueError, match="`n` must be a positive integer"): - agent.forward( - sample_history, - "What are the benefits of government housing programs?", - -1, - ) - - def test_validation_invalid_history(self, config_path: str) -> None: - """Test validation with invalid history format.""" - agent = PromptRefinerAgent(config_path=config_path) - - with pytest.raises( - ValueError, match="`history` must be a dspy.History or a sequence" - ): - agent.forward("invalid_history", "What is AI?", 3) # type: ignore - - with pytest.raises( - ValueError, match="`history` must be a dspy.History or a sequence" - ): - agent.forward({"invalid": "format"}, "What is AI?", 3) # type: ignore - - @pytest.mark.skipif( - not any( - os.getenv(var) for var in ["AWS_ACCESS_KEY_ID", "AZURE_OPENAI_API_KEY"] - ), - reason="No LLM provider environment variables set", - ) - def test_prompt_refiner_with_history( - self, config_path: str, sample_history: List[Dict[str, str]] - ) -> None: - """Test prompt refiner with conversation history.""" - manager = LLMManager(config_path) - - # Find available provider - available_providers = manager.get_available_providers() - if not available_providers: - pytest.skip("No LLM providers available for testing") - - provider = next(iter(available_providers.keys())) - print(f"\n🔧 Testing with provider: {provider.value}") - - agent = PromptRefinerAgent( - config_path=config_path, provider=provider, default_n=3 - ) - - question = "How can I apply for unemployment benefits?" - rewrites = agent.forward(sample_history, question, n=3) - - # Validate output - assert isinstance(rewrites, list), "Output should be a list" - assert len(rewrites) <= 3, "Should return at most 3 rewrites" - assert len(rewrites) > 0, "Should return at least 1 rewrite" - - for rewrite in rewrites: - assert isinstance(rewrite, str), "Each rewrite should be a string" - assert len(rewrite.strip()) > 0, "Each rewrite should be non-empty" - - print(f"🤖 Original question: {question}") - print(f"📝 Generated {len(rewrites)} rewrites:") - for i, rewrite in enumerate(rewrites, 1): - print(f" {i}. 
{rewrite}") - - @pytest.mark.skipif( - not any( - os.getenv(var) for var in ["AWS_ACCESS_KEY_ID", "AZURE_OPENAI_API_KEY"] - ), - reason="No LLM provider environment variables set", - ) - def test_prompt_refiner_without_history( - self, config_path: str, empty_history: List[Dict[str, str]] - ) -> None: - """Test prompt refiner without conversation history.""" - manager = LLMManager(config_path) - - # Find available provider - available_providers = manager.get_available_providers() - if not available_providers: - pytest.skip("No LLM providers available for testing") - - provider = next(iter(available_providers.keys())) - - agent = PromptRefinerAgent( - config_path=config_path, provider=provider, default_n=2 - ) - - question = "What are the eligibility criteria for food assistance programs?" - rewrites = agent.forward(empty_history, question, n=2) - - # Validate output - assert isinstance(rewrites, list), "Output should be a list" - assert len(rewrites) <= 2, "Should return at most 2 rewrites" - assert len(rewrites) > 0, "Should return at least 1 rewrite" - - for rewrite in rewrites: - assert isinstance(rewrite, str), "Each rewrite should be a string" - assert len(rewrite.strip()) > 0, "Each rewrite should be non-empty" - - print(f"🤖 Original question: {question}") - print(f"📝 Generated {len(rewrites)} rewrites (no history):") - for i, rewrite in enumerate(rewrites, 1): - print(f" {i}. {rewrite}") - - @pytest.mark.skipif( - not any( - os.getenv(var) for var in ["AWS_ACCESS_KEY_ID", "AZURE_OPENAI_API_KEY"] - ), - reason="No LLM provider environment variables set", - ) - def test_prompt_refiner_default_n( - self, config_path: str, sample_history: List[Dict[str, str]] - ) -> None: - """Test prompt refiner using default n value.""" - manager = LLMManager(config_path) - - # Find available provider - available_providers = manager.get_available_providers() - if not available_providers: - pytest.skip("No LLM providers available for testing") - - provider = next(iter(available_providers.keys())) - - agent = PromptRefinerAgent( - config_path=config_path, provider=provider, default_n=4 - ) - - question = "How does this technology impact society?" - # Don't specify n, should use default_n=4 - rewrites = agent.forward(sample_history, question) - - # Validate output - assert isinstance(rewrites, list), "Output should be a list" - assert len(rewrites) <= 4, "Should return at most 4 rewrites (default_n)" - assert len(rewrites) > 0, "Should return at least 1 rewrite" - - print(f"🤖 Original question: {question}") - print(f"📝 Generated {len(rewrites)} rewrites (using default_n=4):") - for i, rewrite in enumerate(rewrites, 1): - print(f" {i}. {rewrite}") - - @pytest.mark.skipif( - not any( - os.getenv(var) for var in ["AWS_ACCESS_KEY_ID", "AZURE_OPENAI_API_KEY"] - ), - reason="No LLM provider environment variables set", - ) - def test_prompt_refiner_single_rewrite( - self, config_path: str, sample_history: List[Dict[str, str]] - ) -> None: - """Test prompt refiner with n=1.""" - manager = LLMManager(config_path) - - # Find available provider - available_providers = manager.get_available_providers() - if not available_providers: - pytest.skip("No LLM providers available for testing") - - provider = next(iter(available_providers.keys())) - - agent = PromptRefinerAgent(config_path=config_path, provider=provider) - - question = "Tell me about deep learning." 
- rewrites = agent.forward(sample_history, question, n=1) - - # Validate output - assert isinstance(rewrites, list), "Output should be a list" - assert len(rewrites) == 1, "Should return exactly 1 rewrite" - assert isinstance(rewrites[0], str), "Rewrite should be a string" - assert len(rewrites[0].strip()) > 0, "Rewrite should be non-empty" - - print(f"🤖 Original question: {question}") - print(f"📝 Single rewrite: {rewrites[0]}") - - def test_prompt_refiner_with_specific_provider_aws( - self, config_path: str, sample_history: List[Dict[str, str]] - ) -> None: - """Test prompt refiner with specific AWS provider.""" - if not all( - os.getenv(v) - for v in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION"] - ): - pytest.skip("AWS environment variables not set") - - manager = LLMManager(config_path) - if not manager.is_provider_available(LLMProvider.AWS_BEDROCK): - pytest.skip("AWS Bedrock provider not available") - - agent = PromptRefinerAgent( - config_path=config_path, provider=LLMProvider.AWS_BEDROCK, default_n=2 - ) - - question = "What are neural networks?" - rewrites = agent.forward(sample_history, question, n=2) - - assert isinstance(rewrites, list), "Output should be a list" - assert len(rewrites) <= 2, "Should return at most 2 rewrites" - assert len(rewrites) > 0, "Should return at least 1 rewrite" - - print(f"🤖 AWS Bedrock - Original: {question}") - print(f"📝 AWS Bedrock - Rewrites: {rewrites}") - - def test_prompt_refiner_with_specific_provider_azure( - self, config_path: str, sample_history: List[Dict[str, str]] - ) -> None: - """Test prompt refiner with specific Azure provider.""" - if not all( - os.getenv(v) - for v in [ - "AZURE_OPENAI_API_KEY", - "AZURE_OPENAI_ENDPOINT", - "AZURE_OPENAI_DEPLOYMENT_NAME", - ] - ): - pytest.skip("Azure environment variables not set") - - manager = LLMManager(config_path) - if not manager.is_provider_available(LLMProvider.AZURE_OPENAI): - pytest.skip("Azure OpenAI provider not available") - - agent = PromptRefinerAgent( - config_path=config_path, provider=LLMProvider.AZURE_OPENAI, default_n=3 - ) - - question = "Explain computer vision applications." 
- rewrites = agent.forward(sample_history, question, n=3) - - assert isinstance(rewrites, list), "Output should be a list" - assert len(rewrites) <= 3, "Should return at most 3 rewrites" - assert len(rewrites) > 0, "Should return at least 1 rewrite" - - print(f"🤖 Azure OpenAI - Original: {question}") - print(f"📝 Azure OpenAI - Rewrites: {rewrites}") From a7a23032d4f5a1e86b703107a19d8ca9ddbfc872 Mon Sep 17 00:00:00 2001 From: nuwangeek Date: Mon, 15 Sep 2025 12:35:27 +0530 Subject: [PATCH 3/3] fixed ruff lint issues --- src/llm_config_module/config/loader.py | 8 ++++++-- src/llm_config_module/llm_manager.py | 1 + src/llm_orchestration_service.py | 1 - src/prompt_refiner_module/prompt_refiner.py | 20 ++++++++++++++++++-- tests/conftest.py | 12 +++++++----- tests/test_integration_vault_llm_config.py | 9 +++++++-- 6 files changed, 39 insertions(+), 12 deletions(-) diff --git a/src/llm_config_module/config/loader.py b/src/llm_config_module/config/loader.py index 6046863..0645371 100644 --- a/src/llm_config_module/config/loader.py +++ b/src/llm_config_module/config/loader.py @@ -24,7 +24,9 @@ DEFAULT_CONFIG_FILENAME = "llm_config.yaml" # Type alias for configuration values that can be processed -ConfigValue = Union[str, Dict[str, "ConfigValue"], List["ConfigValue"], int, float, bool, None] +ConfigValue = Union[ + str, Dict[str, "ConfigValue"], List["ConfigValue"], int, float, bool, None +] class ConfigurationLoader: @@ -402,7 +404,9 @@ def replace_env_var(match: re.Match[str]) -> str: return cast(Dict[str, Any], result) else: # This should never happen given our input type, but provide a fallback - raise ConfigurationError("Environment variable substitution resulted in non-dictionary type") + raise ConfigurationError( + "Environment variable substitution resulted in non-dictionary type" + ) def _parse_configuration(self, config: Dict[str, Any]) -> LLMConfiguration: """Parse the processed configuration into structured objects. diff --git a/src/llm_config_module/llm_manager.py b/src/llm_config_module/llm_manager.py index bd3ec52..0a9097a 100644 --- a/src/llm_config_module/llm_manager.py +++ b/src/llm_config_module/llm_manager.py @@ -164,6 +164,7 @@ def configure_dspy(self, provider: Optional[LLMProvider] = None) -> None: """ dspy_client = self.get_dspy_client(provider) dspy.configure(lm=dspy_client) + def get_available_providers(self) -> Dict[LLMProvider, str]: """Get information about available providers. 
diff --git a/src/llm_orchestration_service.py b/src/llm_orchestration_service.py index 55eaeba..3686a59 100644 --- a/src/llm_orchestration_service.py +++ b/src/llm_orchestration_service.py @@ -152,7 +152,6 @@ def _refine_user_prompt( f"Prompt refinement validation failed: {str(validation_error)}" ) from validation_error - output_json = validated_output.model_dump() logger.info( f"Prompt refinement output: {json.dumps(output_json, indent=2)}" diff --git a/src/prompt_refiner_module/prompt_refiner.py b/src/prompt_refiner_module/prompt_refiner.py index e8c4894..8406609 100644 --- a/src/prompt_refiner_module/prompt_refiner.py +++ b/src/prompt_refiner_module/prompt_refiner.py @@ -1,6 +1,16 @@ from __future__ import annotations -from typing import Any, Iterable, List, Mapping, Sequence, Optional, Dict, Union, Protocol +from typing import ( + Any, + Iterable, + List, + Mapping, + Sequence, + Optional, + Dict, + Union, + Protocol, +) import logging import dspy @@ -10,15 +20,18 @@ LOGGER = logging.getLogger(__name__) + # Protocol for DSPy History objects class DSPyHistoryProtocol(Protocol): messages: Any + DSPyOutput = Union[str, Sequence[str], Sequence[Any], None] HistoryList = Sequence[Mapping[str, str]] # Use Protocol for DSPy History objects instead of Any HistoryLike = Union[HistoryList, DSPyHistoryProtocol] + # 1. SIGNATURE: Defines the interface for the DSPy module class PromptRefineSig(dspy.Signature): """Produce N distinct, concise rewrites of the user's question using chat history. @@ -39,6 +52,7 @@ class PromptRefineSig(dspy.Signature): desc="Exactly N refined variations of the question, each a single sentence." ) + def _coerce_to_list(value: DSPyOutput) -> list[str]: """Coerce model output into a list[str] safely.""" if isinstance(value, (list, tuple)): # Handle sequences @@ -92,6 +106,7 @@ def _is_history_like(history: HistoryLike) -> bool: return False + def _validate_history_sequence(history: Sequence[Mapping[str, str]]) -> bool: """Helper function to validate history sequence structure.""" try: @@ -103,7 +118,8 @@ def _validate_history_sequence(history: Sequence[Mapping[str, str]]) -> bool: except (KeyError, TypeError): return False -# 3. MODULE: Uses the signature + adds logic + +# 3. MODULE: Uses the signature + adds logic class PromptRefinerAgent(dspy.Module): """Config-driven Prompt Refiner that emits N rewrites from history + question. 
diff --git a/tests/conftest.py b/tests/conftest.py index 4b16978..4991e8a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -97,9 +97,7 @@ def populated_vault(vault_client: hvac.Client) -> None: for path, data in test_data.items(): try: - vault_client.secrets.kv.v2.create_or_update_secret( - path=path, secret=data - ) + vault_client.secrets.kv.v2.create_or_update_secret(path=path, secret=data) logger.debug(f"Created test secret at {path}") except Exception as e: logger.error(f"Failed to create secret at {path}: {e}") @@ -140,7 +138,9 @@ def reset_singletons() -> Generator[None, None, None]: # Reset VaultConnectionManager if available try: - from src.rag_config_manager.vault.connection_manager import ConnectionManager as VaultConnectionManager + from src.rag_config_manager.vault.connection_manager import ( + ConnectionManager as VaultConnectionManager, + ) if hasattr(VaultConnectionManager, "_instance"): VaultConnectionManager._instance = None @@ -153,7 +153,9 @@ def reset_singletons() -> Generator[None, None, None]: if hasattr(LLMManager, "_instance"): LLMManager._instance = None try: - from src.rag_config_manager.vault.connection_manager import ConnectionManager as VaultConnectionManager + from src.rag_config_manager.vault.connection_manager import ( + ConnectionManager as VaultConnectionManager, + ) if hasattr(VaultConnectionManager, "_instance"): VaultConnectionManager._instance = None diff --git a/tests/test_integration_vault_llm_config.py b/tests/test_integration_vault_llm_config.py index 20b581b..9dab72a 100644 --- a/tests/test_integration_vault_llm_config.py +++ b/tests/test_integration_vault_llm_config.py @@ -44,7 +44,9 @@ def test_development_environment_requires_connection_id( self, vault_env_vars: Dict[str, str] ) -> None: """Test that development environment requires connection_id.""" - with pytest.raises(ConfigurationError, match=r".*connection_id is required.*development"): + with pytest.raises( + ConfigurationError, match=r".*connection_id is required.*development" + ): LLMManager( config_path=str(self.cfg_path), environment="development", @@ -81,7 +83,10 @@ def test_valid_connection_id_works(self, vault_env_vars: Dict[str, str]) -> None def test_invalid_connection_id_fails(self, vault_env_vars: Dict[str, str]) -> None: """Test that invalid connection_id causes failure.""" - with pytest.raises(ConfigurationError, match=r".*(Connection not found|Failed to discover providers)"): + with pytest.raises( + ConfigurationError, + match=r".*(Connection not found|Failed to discover providers)", + ): LLMManager( config_path=str(self.cfg_path), environment="development",