![Redis](https://redis.io/wp-content/uploads/2024/04/Logotype.svg?auto=webp&quality=85,75&width=120)

# Optimizing for Production: Context Engineering at Scale

## Welcome to Section 5: Context Optimization

In Section 4, you built a sophisticated multi-tool agent with semantic routing. Now you'll optimize it for production use with:
- Context compression and pruning strategies
- Token usage optimization and cost management
- Performance monitoring and analytics
- Scalable architecture patterns

This is where your educational project becomes a production-ready system.

## Learning Objectives

By the end of this notebook, you will:
1. Implement context compression and relevance-based pruning
2. Add token usage tracking and cost optimization
3. Build performance monitoring and analytics
4. Create scalable caching and batching strategies
5. Deploy optimization techniques for production workloads

## The Production Challenge

Your multi-tool agent works great in development, but production brings new challenges:

### Scale Challenges:
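- **Cost**: Every token sent to the model is billed, so oversized contexts raise the cost of each request
- **Latency**: Larger contexts take longer to process, which slows responses
- **Memory**: Conversation history grows with every turn unless it is compressed or pruned
- **Concurrency**: Multiple simultaneous users must share caches and session state efficiently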
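
To make the cost concern concrete, here is a rough back-of-the-envelope sketch. The price mirrors the simplified $0.02 per 1K tokens rate used later in this notebook. The traffic figure and the two context sizes are hypothetical numbers chosen only for illustration.

```python
# Illustrative figures only: real prices and traffic will differ.
PRICE_PER_1K_TOKENS = 0.02   # USD, the simplified rate used later in this notebook
QUERIES_PER_DAY = 10_000     # assumed traffic


def daily_cost(tokens_per_query: int) -> float:
    """Estimated daily spend when every query sends a context of this size."""
    return tokens_per_query / 1000 * PRICE_PER_1K_TOKENS * QUERIES_PER_DAY


unoptimized = daily_cost(4000)  # assumed full context per query
optimized = daily_cost(2800)    # assumed context after pruning to 70%

print(f"Unoptimized: ${unoptimized:,.2f}/day")
print(f"Optimized:   ${optimized:,.2f}/day")
print(f"Savings:     ${unoptimized - optimized:,.2f}/day")
```

Under these assumptions, trimming each context by 30% reduces the daily bill by the same 30%, which is why token budgeting is usually the first optimization to consider.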

### Cross-Reference: Optimization Concepts

This section builds on optimization patterns introduced in earlier notebooks and used in production systems:
- Context window management and token budgeting
- Memory compression and summarization strategies
- Performance monitoring and cost tracking
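
As a minimal sketch of what token budgeting can look like, the snippet below splits an overall token limit across context sections in proportion to fixed weights. The function name `allocate_token_budget` and the weights are hypothetical and not part of the course code. The section names match the context parts used later in this notebook.

```python
from typing import Dict


def allocate_token_budget(max_tokens: int, weights: Dict[str, float]) -> Dict[str, int]:
    """Split max_tokens across context sections in proportion to their weights."""
    total_weight = sum(weights.values())
    if total_weight <= 0:
        raise ValueError("weights must sum to a positive number")
    # Round down so the allocations never exceed the overall budget
    return {name: int(max_tokens * w / total_weight) for name, w in weights.items()}


# Hypothetical weights: history and retrieved courses get the largest shares
budget = allocate_token_budget(
    4000,
    {
        "student_profile": 1,
        "conversation_history": 3,
        "retrieved_courses": 3,
        "loaded_memories": 1,
    },
)
for section, tokens in budget.items():
    print(f"{section}: {tokens} tokens")
```

Fixed weights are the simplest policy. A production system might instead adjust the shares per query, for example giving retrieved courses more room when the conversation is short.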

**Development vs Production:**
```
Development: "Does it work?"
Production: "Does it work efficiently at scale with acceptable cost?"
```

## Step 1: Load Your Multi-Tool Agent

First, let's load the multi-tool agent you built in Section 4 as our optimization target.

In [None]:
# Environment setup
import os
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

# Verify required environment variables are set
if not os.getenv("OPENAI_API_KEY"):
    raise ValueError(
        "OPENAI_API_KEY not found. Please create a .env file with your OpenAI API key. "
        "Get your key from: https://platform.openai.com/api-keys"
    )

print("✅ Environment variables loaded")
print(f"   REDIS_URL: {os.getenv('REDIS_URL', 'redis://localhost:6379')}")
print(f"   OPENAI_API_KEY: {'✓ Set' if os.getenv('OPENAI_API_KEY') else '✗ Not set'}")

# Import components from previous sections
import sys
import time
import json
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from collections import defaultdict

# Add reference agent to path
sys.path.append('../../../reference-agent')

from redis_context_course.models import (
    Course, StudentProfile, DifficultyLevel, 
    CourseFormat, Semester
)
from redis_context_course.course_manager import CourseManager

print("Foundation components loaded for optimization")

## Step 2: Build Context Optimizer

Let's create a context optimizer that compresses older conversation history and prunes low-relevance context so the final prompt stays within a token budget.

In [None]:
class ProductionContextOptimizer:
    """Context optimizer for production workloads"""
    
    def __init__(self, max_tokens: int = 4000, compression_ratio: float = 0.7):
        self.max_tokens = max_tokens
        self.compression_ratio = compression_ratio
        self.token_usage_stats = defaultdict(int)
        self.optimization_stats = defaultdict(int)
    
    def estimate_tokens(self, text: str) -> int:
        """Estimate token count (simplified - real implementation would use tiktoken)"""
        # Rough estimation: ~4 characters per token
        return len(text) // 4
    
    def compress_conversation_history(self, conversation: List[Dict]) -> List[Dict]:
        """Compress conversation history by summarizing older messages"""
        if len(conversation) <= 6:  # Short conversations are returned unchanged
            return conversation
        
        # Keep last 4 messages, summarize the rest
        recent_messages = conversation[-4:]
        older_messages = conversation[:-4]
        
        # Create summary of older messages
        summary_content = self._summarize_messages(older_messages)
        
        summary_message = {
            "role": "system",
            "content": f"[Conversation Summary: {summary_content}]",
            "timestamp": datetime.now().isoformat(),
            "type": "summary"
        }
        
        self.optimization_stats["conversations_compressed"] += 1
        return [summary_message] + recent_messages
    
    def _summarize_messages(self, messages: List[Dict]) -> str:
        """Create a summary of conversation messages"""
        topics = set()
        user_intents = []
        
        for msg in messages:
            content = msg.get("content", "").lower()
            
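            # Extract topics ("ml" must be a whole word, so "html" or "seamlessly" don't match)
            words = {w.strip(".,!?;:()") for w in content.split()}
            if "machine learning" in content or "ml" in words: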
                topics.add("machine learning")
            if "course" in content:
                topics.add("courses")
            if "recommend" in content or "suggest" in content:
                topics.add("recommendations")
            
            # Extract user intents
            if msg.get("role") == "user":
                if "what" in content and "course" in content:
                    user_intents.append("course inquiry")
                elif "can i" in content or "eligible" in content:
                    user_intents.append("eligibility check")
        
        summary_parts = []
        if topics:
            # Sort so the summary text is deterministic (set order is not)
            summary_parts.append(f"Topics: {', '.join(sorted(topics))}")
        if user_intents:
            summary_parts.append(f"User asked about: {', '.join(sorted(set(user_intents)))}")
        
        return "; ".join(summary_parts) if summary_parts else "General conversation about courses"
    
    def prune_context_by_relevance(self, context_parts: List[Tuple[str, str]], query: str) -> List[Tuple[str, str]]:
        """Prune context parts based on relevance to current query"""
        if len(context_parts) <= 3:  # Don't prune if already small
            return context_parts
        
        # Score relevance of each context part
        scored_parts = []
        query_words = set(query.lower().split())
        
        for part_type, content in context_parts:
            content_words = set(content.lower().split())
            overlap = len(query_words.intersection(content_words))
            
            # Boost score for certain context types
            relevance_score = overlap
            if part_type in ["student_profile", "current_query"]:
                relevance_score += 10  # Always keep these
            elif part_type == "conversation_history":
                relevance_score += 5   # High priority
            
            scored_parts.append((relevance_score, part_type, content))
        
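        # Sort by relevance so the highest-scoring parts claim the budget first
        scored_parts.sort(key=lambda x: x[0], reverse=True)
        
        # Keep parts that fit within the token budget; the student profile and
        # current query are kept regardless, matching their "always include" role
        always_keep = {"student_profile", "current_query"}
        token_budget = self.max_tokens * self.compression_ratio
        selected_parts = []
        total_tokens = 0
        
        for score, part_type, content in scored_parts:
            part_tokens = self.estimate_tokens(content)
            if part_type in always_keep or total_tokens + part_tokens <= token_budget:
                selected_parts.append((part_type, content))
                total_tokens += part_tokens
            else:
                self.optimization_stats["context_parts_pruned"] += 1
        
        # Restore the original ordering so the assembled prompt keeps its layout
        original_order = {part_type: i for i, (part_type, _) in enumerate(context_parts)}
        selected_parts.sort(key=lambda part: original_order[part[0]])
        return selected_parts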
    
    def optimize_context(self, context_data: Dict[str, Any], query: str) -> Tuple[str, Dict[str, Any]]:
        """Main optimization method that combines all strategies"""
        start_time = time.time()
        
        # Extract context parts
        context_parts = []
        
        # Student profile (always include)
        if "student_profile" in context_data:
            profile_text = self._format_student_profile(context_data["student_profile"])
            context_parts.append(("student_profile", profile_text))
        
        # Conversation history (compress if needed)
        if "conversation_history" in context_data:
            compressed_history = self.compress_conversation_history(context_data["conversation_history"])
            history_text = self._format_conversation_history(compressed_history)
            context_parts.append(("conversation_history", history_text))
        
        # Retrieved courses (limit to most relevant)
        if "retrieved_courses" in context_data:
            courses_text = self._format_courses(context_data["retrieved_courses"][:3])  # Limit to top 3
            context_parts.append(("retrieved_courses", courses_text))
        
        # Memory context (summarize if long)
        if "loaded_memories" in context_data:
            memory_text = self._format_memories(context_data["loaded_memories"][:5])  # Limit to top 5
            context_parts.append(("loaded_memories", memory_text))
        
        # Current query (always include)
        context_parts.append(("current_query", f"CURRENT QUERY: {query}"))
        
        # Prune by relevance
        optimized_parts = self.prune_context_by_relevance(context_parts, query)
        
        # Assemble final context
        final_context = "\n\n".join([content for _, content in optimized_parts])
        
        # Calculate metrics
        optimization_time = time.time() - start_time
        final_tokens = self.estimate_tokens(final_context)
        
        metrics = {
            "original_parts": len(context_parts),
            "optimized_parts": len(optimized_parts),
            "final_tokens": final_tokens,
            "optimization_time_ms": int(optimization_time * 1000),
            "compression_achieved": len(context_parts) > len(optimized_parts)
        }
        
        # Update stats
        self.token_usage_stats["total_tokens"] += final_tokens
        self.optimization_stats["contexts_optimized"] += 1
        
        return final_context, metrics
    
    def _format_student_profile(self, profile: Dict) -> str:
        """Format student profile concisely"""
        return f"""STUDENT: {profile.get('name', 'Unknown')}
Major: {profile.get('major', 'Unknown')}, Year: {profile.get('year', 'Unknown')}
Completed: {', '.join(profile.get('completed_courses', []))}
Interests: {', '.join(profile.get('interests', []))}
Preferences: {profile.get('preferred_format', 'Unknown')}, {profile.get('preferred_difficulty', 'Unknown')} level"""
    
    def _format_conversation_history(self, history: List[Dict]) -> str:
        """Format conversation history concisely"""
        if not history:
            return ""
        
        formatted = "CONVERSATION:\n"
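        # Keep the summary message (if any) plus the last 4 regular messages,
        # so the summary produced by compress_conversation_history() is not lost
        summaries = [m for m in history if m.get("type") == "summary"]
        recent = [m for m in history if m.get("type") != "summary"][-4:]
        for msg in summaries + recent:
            role = msg["role"].title()
            content = msg["content"]
            # Truncate long regular messages; summaries are already compact
            if msg.get("type") != "summary" and len(content) > 100:
                content = content[:100] + "..."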
            formatted += f"{role}: {content}\n"
        
        return formatted.strip()
    
    def _format_courses(self, courses: List[Dict]) -> str:
        """Format course information concisely"""
        if not courses:
            return ""
        
        formatted = "RELEVANT COURSES:\n"
        for i, course in enumerate(courses, 1):
            formatted += f"{i}. {course.get('course_code', 'Unknown')}: {course.get('title', 'Unknown')}\n"
            formatted += f"   Level: {course.get('level', 'Unknown')}, Credits: {course.get('credits', 'Unknown')}\n"
        
        return formatted.strip()
    
    def _format_memories(self, memories: List[Dict]) -> str:
        """Format memory information concisely"""
        if not memories:
            return ""
        
        formatted = "RELEVANT MEMORIES:\n"
        for memory in memories:
            if isinstance(memory, dict) and "content" in memory:
                content = memory["content"][:80] + "..." if len(memory["content"]) > 80 else memory["content"]
                formatted += f"- {content}\n"
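            else:
                # Only add an ellipsis when the text was actually truncated
                text = str(memory)
                content = text[:80] + "..." if len(text) > 80 else text
                formatted += f"- {content}\n"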
        
        return formatted.strip()
    
    def get_optimization_stats(self) -> Dict[str, Any]:
        """Get optimization performance statistics"""
        return {
            "token_usage": dict(self.token_usage_stats),
            "optimization_stats": dict(self.optimization_stats),
            "average_tokens_per_context": (
                self.token_usage_stats["total_tokens"] / max(1, self.optimization_stats["contexts_optimized"])
            )
        }

# Initialize the context optimizer
context_optimizer = ProductionContextOptimizer(max_tokens=4000, compression_ratio=0.7)

print("Production context optimizer initialized")
print(f"Max tokens: {context_optimizer.max_tokens}")
print(f"Compression ratio: {context_optimizer.compression_ratio}")
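
The `estimate_tokens` method above uses a rough four-characters-per-token heuristic. As a sketch of the tiktoken-based alternative its docstring mentions, the helper below counts tokens with tiktoken when it is installed and falls back to the same heuristic otherwise. The name `make_token_counter` is hypothetical and not part of the course code, and the `cl100k_base` encoding is an assumption: the right encoding depends on the model you call.

```python
from typing import Callable


def make_token_counter(encoding_name: str = "cl100k_base") -> Callable[[str], int]:
    """Return a token-counting function, preferring tiktoken when it is usable."""
    try:
        import tiktoken  # optional dependency

        encoding = tiktoken.get_encoding(encoding_name)
        return lambda text: len(encoding.encode(text))
    except Exception:
        # tiktoken is missing or its encoding data is unavailable:
        # fall back to the ~4 characters per token heuristic
        return lambda text: len(text) // 4


count_tokens = make_token_counter()
sample = "What machine learning courses can I take next semester?"
print(f"Token count: {count_tokens(sample)}")
```

The exact count depends on which path is taken, so treat the heuristic as a budget estimate rather than a billing figure.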

## Step 3: Build Production-Ready Agent

Let's create an optimized version of your multi-tool agent that uses the context optimizer.

In [None]:
class OptimizedProductionAgent:
    """Production-optimized agent with context compression and monitoring"""
    
    def __init__(self, context_optimizer: ProductionContextOptimizer):
        self.context_optimizer = context_optimizer
        self.course_manager = CourseManager()
        
        # Performance monitoring
        self.performance_metrics = defaultdict(list)
        self.cost_tracking = defaultdict(float)
        
        # Caching for efficiency
        self.query_cache = {}  # Simple in-memory cache
        self.cache_hits = 0
        self.cache_misses = 0
        
        # Session management
        self.active_sessions = {}
        self.session_stats = defaultdict(int)
    
    def start_optimized_session(self, student: StudentProfile) -> str:
        """Start an optimized session with efficient memory management"""
        session_id = f"{student.email}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        
        # Create lightweight session context
        session_context = {
            "student_profile": {
                "name": student.name,
                "email": student.email,
                "major": student.major,
                "year": student.year,
                "completed_courses": student.completed_courses,
                "interests": student.interests[:3],  # Limit to top 3 interests
                "preferred_format": student.preferred_format.value,
                "preferred_difficulty": student.preferred_difficulty.value
            },
            "conversation_history": [],
            "loaded_memories": [],  # Would load from Redis in real system
            "session_start_time": time.time(),
            "query_count": 0
        }
        
        self.active_sessions[session_id] = session_context
        self.session_stats["sessions_started"] += 1
        
        print(f"Started optimized session {session_id} for {student.name}")
        return session_id
    
    def _check_cache(self, query: str, student_email: str) -> Optional[str]:
        """Check if we have a cached response for this query"""
        cache_key = f"{student_email}:{query.lower().strip()}"
        
        if cache_key in self.query_cache:
            cache_entry = self.query_cache[cache_key]
            # Check if cache entry is still fresh (within 1 hour)
            if time.time() - cache_entry["timestamp"] < 3600:
                self.cache_hits += 1
                return cache_entry["response"]
            else:
                # Remove stale cache entry
                del self.query_cache[cache_key]
        
        self.cache_misses += 1
        return None
    
    def _cache_response(self, query: str, student_email: str, response: str):
        """Cache a response for future use"""
        cache_key = f"{student_email}:{query.lower().strip()}"
        self.query_cache[cache_key] = {
            "response": response,
            "timestamp": time.time()
        }
        
        # Limit cache size to prevent memory bloat
        if len(self.query_cache) > 1000:
            # Remove oldest entries
            oldest_keys = sorted(self.query_cache.keys(), 
                               key=lambda k: self.query_cache[k]["timestamp"])[:100]
            for key in oldest_keys:
                del self.query_cache[key]
    
    def optimized_chat(self, session_id: str, query: str) -> Dict[str, Any]:
        """Optimized chat method with performance monitoring"""
        start_time = time.time()
        
        if session_id not in self.active_sessions:
            return {"error": "Invalid session ID", "response": "Please start a session first."}
        
        session_context = self.active_sessions[session_id]
        student_email = session_context["student_profile"]["email"]
        
        # Check cache first
        cached_response = self._check_cache(query, student_email)
        if cached_response is not None:
            return {
                "response": cached_response,
                "cached": True,
                "processing_time_ms": int((time.time() - start_time) * 1000)
            }
        
        # Add query to conversation history
        session_context["conversation_history"].append({
            "role": "user",
            "content": query,
            "timestamp": datetime.now().isoformat()
        })
        session_context["query_count"] += 1
        
        # Simulate course retrieval (would use real search in production)
        retrieved_courses = self._simulate_course_search(query)
        
        # Prepare context data for optimization
        context_data = {
            "student_profile": session_context["student_profile"],
            "conversation_history": session_context["conversation_history"],
            "retrieved_courses": retrieved_courses,
            "loaded_memories": session_context["loaded_memories"]
        }
        
        # Optimize context
        optimized_context, optimization_metrics = self.context_optimizer.optimize_context(context_data, query)
        
        # Generate response (simplified: a production version would send
        # optimized_context to the LLM instead of filling in this template)
        response = self._generate_optimized_response(query, retrieved_courses, session_context)
        
        # Add response to conversation history
        session_context["conversation_history"].append({
            "role": "assistant",
            "content": response,
            "timestamp": datetime.now().isoformat()
        })
        
        # Cache the response
        self._cache_response(query, student_email, response)
        
        # Calculate performance metrics
        total_time = time.time() - start_time
        
        # Track costs (simplified calculation)
        estimated_cost = optimization_metrics["final_tokens"] * 0.00002  # $0.02 per 1K tokens
        self.cost_tracking["total_cost"] += estimated_cost
        self.cost_tracking["total_tokens"] += optimization_metrics["final_tokens"]
        
        # Record performance metrics
        self.performance_metrics["response_times"].append(total_time)
        self.performance_metrics["token_counts"].append(optimization_metrics["final_tokens"])
        self.performance_metrics["optimization_times"].append(optimization_metrics["optimization_time_ms"])
        
        return {
            "response": response,
            "cached": False,
            "processing_time_ms": int(total_time * 1000),
            "optimization_metrics": optimization_metrics,
            "estimated_cost": estimated_cost,
            "session_query_count": session_context["query_count"]
        }
    
    def _simulate_course_search(self, query: str) -> List[Dict]:
        """Simulate course search (would use real CourseManager in production)"""
        # Simplified course data for demonstration
        all_courses = [
            {"course_code": "RU101", "title": "Introduction to Redis", "level": "beginner", "credits": 3},
            {"course_code": "RU201", "title": "Redis for Python", "level": "intermediate", "credits": 4},
            {"course_code": "RU301", "title": "Vector Similarity Search", "level": "advanced", "credits": 4},
            {"course_code": "RU302", "title": "Redis for Machine Learning", "level": "advanced", "credits": 4}
        ]
        
        # Simple keyword matching
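        query_lower = query.lower()
        # Match "ml" as a whole word so it doesn't fire on e.g. "html" or "seamlessly"
        query_words = {w.strip(".,!?;:()") for w in query_lower.split()}
        wants_ml = (
            "machine learning" in query_lower
            or "vector" in query_lower
            or "ml" in query_words
        )
        relevant_courses = []
        
        for course in all_courses:
            if wants_ml: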
                if "machine learning" in course["title"].lower() or "vector" in course["title"].lower():
                    relevant_courses.append(course)
            elif "python" in query_lower:
                if "python" in course["title"].lower():
                    relevant_courses.append(course)
            elif "beginner" in query_lower or "introduction" in query_lower:
                if course["level"] == "beginner":
                    relevant_courses.append(course)
        
        return relevant_courses[:3]  # Return top 3 matches
    
    def _generate_optimized_response(self, query: str, courses: List[Dict], session_context: Dict) -> str:
        """Generate optimized response (simplified - would use LLM in production)"""
        if not courses:
            return "I couldn't find specific courses matching your query. Could you provide more details about what you're looking for?"
        
        student_name = session_context["student_profile"]["name"]
        interests = session_context["student_profile"]["interests"]
        
        response = f"Hi {student_name}! Based on your interests in {', '.join(interests)}, I found these relevant courses:\n\n"
        
        for course in courses:
            response += f"• **{course['course_code']}: {course['title']}**\n"
            response += f"  Level: {course['level'].title()}, Credits: {course['credits']}\n\n"
        
        response += "Would you like more details about any of these courses?"
        
        return response
    
    def get_performance_analytics(self) -> Dict[str, Any]:
        """Get comprehensive performance analytics"""
        response_times = self.performance_metrics["response_times"]
        token_counts = self.performance_metrics["token_counts"]
        
        analytics = {
            "performance": {
                "total_queries": len(response_times),
                "avg_response_time_ms": int(sum(response_times) / len(response_times) * 1000) if response_times else 0,
                "max_response_time_ms": int(max(response_times) * 1000) if response_times else 0,
                "min_response_time_ms": int(min(response_times) * 1000) if response_times else 0
            },
            "token_usage": {
                "total_tokens": sum(token_counts),
                "avg_tokens_per_query": int(sum(token_counts) / len(token_counts)) if token_counts else 0,
                "max_tokens_per_query": max(token_counts) if token_counts else 0
            },
            "caching": {
                "cache_hits": self.cache_hits,
                "cache_misses": self.cache_misses,
                "cache_hit_rate": self.cache_hits / (self.cache_hits + self.cache_misses) if (self.cache_hits + self.cache_misses) > 0 else 0,
                "cache_size": len(self.query_cache)
            },
            "costs": {
                "total_estimated_cost": round(self.cost_tracking["total_cost"], 4),
                "total_tokens_processed": int(self.cost_tracking["total_tokens"]),
                "avg_cost_per_query": round(self.cost_tracking["total_cost"] / len(response_times), 4) if response_times else 0
            },
            "sessions": dict(self.session_stats),
            "optimization": self.context_optimizer.get_optimization_stats()
        }
        
        return analytics

# Initialize the optimized production agent
production_agent = OptimizedProductionAgent(context_optimizer)

print("Optimized production agent initialized")
print("Features: Context optimization, caching, performance monitoring, cost tracking")
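
The agent's query cache above expires entries after one hour and, once it grows past 1,000 entries, sorts every key by timestamp to evict the oldest. One possible alternative, sketched here on the assumption that a single-process in-memory cache is still acceptable, keeps entries in insertion order with `collections.OrderedDict` so eviction does not need a sort. The `TTLCache` class and its `clock` parameter are illustrative names, not part of the course code.

```python
import time
from collections import OrderedDict
from typing import Callable, Optional


class TTLCache:
    """Bounded in-memory cache whose entries expire after ttl_seconds."""

    def __init__(self, max_size: int = 1000, ttl_seconds: float = 3600,
                 clock: Callable[[], float] = time.time):
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        self.clock = clock  # injectable so expiry can be tested without sleeping
        self._entries = OrderedDict()  # key -> (value, stored_at)

    def get(self, key: str) -> Optional[str]:
        entry = self._entries.get(key)
        if entry is None:
            return None
        value, stored_at = entry
        if self.clock() - stored_at >= self.ttl_seconds:
            del self._entries[key]  # stale: drop it and report a miss
            return None
        return value

    def set(self, key: str, value: str) -> None:
        self._entries.pop(key, None)  # re-inserting moves the key to the end
        self._entries[key] = (value, self.clock())
        while len(self._entries) > self.max_size:
            self._entries.popitem(last=False)  # evict the oldest insertion


# Usage with a fake clock so the expiry behavior is easy to follow
now = [0.0]
cache = TTLCache(max_size=2, ttl_seconds=60, clock=lambda: now[0])
cache.set("a", "first")
cache.set("b", "second")
cache.set("c", "third")          # exceeds max_size, so "a" is evicted
print(cache.get("a"))            # None
print(cache.get("b"))            # second
now[0] = 61.0                    # advance past the TTL
print(cache.get("b"))            # None
```

In a multi-process deployment, a shared store such as Redis with key expiry would be the more likely choice, since an in-process dictionary is not visible to other workers.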