# Context and Memory Handling for LLMs

## Overview

Efficient context and memory management is essential for long conversations and complex tasks. This notebook covers:

- **Context Window Management**: Sliding windows and context compression
- **Memory Systems**: Short-term and long-term memory mechanisms
- **Context Retrieval**: Relevant context selection and ranking
- **Memory Optimization**: Efficient storage and retrieval strategies

Let's implement practical context and memory management systems.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from collections import deque, defaultdict
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
from datetime import datetime, timedelta
import json
import heapq

print("Libraries imported successfully!")

## 1. Context Window Management

Score-based management of the context window for long conversations, together with a two-tier (short-term / long-term) memory store:

In [None]:
@dataclass
class ContextItem:
    """One entry in a conversation context window.

    Fields:
        content: Raw text of the entry.
        timestamp: When the entry was added.
        importance: Caller-assigned importance used in retention scoring.
        token_count: Estimated token count of ``content``.
        item_type: Role label, e.g. 'user', 'assistant', 'system'.
        metadata: Optional free-form extra data (None when not supplied).
    """
    content: str
    timestamp: datetime
    importance: float
    token_count: int
    item_type: str  # 'user', 'assistant', 'system'
    metadata: Optional[Dict[str, Any]] = None

class ContextWindowManager:
    """Manages context windows with intelligent truncation and compression"""
    
    def __init__(self, max_tokens=4096, compression_ratio=0.7):
        self.max_tokens = max_tokens
        self.compression_ratio = compression_ratio
        self.context_items = deque()
        self.total_tokens = 0
        
        # Importance scoring weights
        self.importance_weights = {
            'recency': 0.3,
            'relevance': 0.4,
            'user_interaction': 0.2,
            'system_importance': 0.1
        }
    
    def add_context(self, content: str, item_type: str, importance: float = 0.5, metadata: Dict = None):
        """Add new context item"""
        token_count = self.estimate_tokens(content)
        
        item = ContextItem(
            content=content,
            timestamp=datetime.now(),
            importance=importance,
            token_count=token_count,
            item_type=item_type,
            metadata=metadata or {}
        )
        
        self.context_items.append(item)
        self.total_tokens += token_count
        
        # Manage context size
        self._manage_context_size()
    
    def estimate_tokens(self, text: str) -> int:
        """Estimate token count (simplified)"""
        return len(text.split()) * 1.3  # Rough approximation
    
    def _manage_context_size(self):
        """Manage context size through intelligent truncation"""
        if self.total_tokens <= self.max_tokens:
            return
        
        # Calculate target size after compression
        target_tokens = int(self.max_tokens * self.compression_ratio)
        
        # Score all items for retention
        scored_items = []
        for item in self.context_items:
            score = self._calculate_retention_score(item)
            scored_items.append((score, item))
        
        # Sort by score (descending)
        scored_items.sort(key=lambda x: x[0], reverse=True)
        
        # Select items to keep
        kept_items = []
        kept_tokens = 0
        
        for score, item in scored_items:
            if kept_tokens + item.token_count <= target_tokens:
                kept_items.append(item)
                kept_tokens += item.token_count
            elif len(kept_items) == 0:  # Always keep at least one item
                kept_items.append(item)
                kept_tokens += item.token_count
                break
        
        # Update context
        self.context_items = deque(sorted(kept_items, key=lambda x: x.timestamp))
        self.total_tokens = kept_tokens
    
    def _calculate_retention_score(self, item: ContextItem) -> float:
        """Calculate retention score for context item"""
        now = datetime.now()
        
        # Recency score (exponential decay)
        time_diff = (now - item.timestamp).total_seconds() / 3600  # Hours
        recency_score = np.exp(-time_diff / 24)  # Decay over 24 hours
        
        # Type-based importance
        type_importance = {
            'system': 0.9,
            'user': 0.8,
            'assistant': 0.6
        }
        
        # Calculate weighted score
        score = (
            self.importance_weights['recency'] * recency_score +
            self.importance_weights['relevance'] * item.importance +
            self.importance_weights['user_interaction'] * type_importance.get(item.item_type, 0.5) +
            self.importance_weights['system_importance'] * (1.0 if item.item_type == 'system' else 0.0)
        )
        
        return score
    
    def get_context_window(self) -> str:
        """Get current context window as string"""
        context_parts = []
        
        for item in self.context_items:
            prefix = f"[{item.item_type.upper()}]"
            context_parts.append(f"{prefix} {item.content}")
        
        return "\n".join(context_parts)
    
    def get_context_summary(self) -> Dict[str, Any]:
        """Get summary of current context"""
        type_counts = defaultdict(int)
        type_tokens = defaultdict(int)
        
        for item in self.context_items:
            type_counts[item.item_type] += 1
            type_tokens[item.item_type] += item.token_count
        
        return {
            'total_items': len(self.context_items),
            'total_tokens': self.total_tokens,
            'utilization': self.total_tokens / self.max_tokens,
            'type_distribution': dict(type_counts),
            'token_distribution': dict(type_tokens),
            'oldest_item': self.context_items[0].timestamp if self.context_items else None,
            'newest_item': self.context_items[-1].timestamp if self.context_items else None
        }

class MemorySystem:
    """Two-tier memory store for LLM conversations.

    New memories enter a bounded short-term buffer, a ``deque`` with
    ``maxlen=short_term_size``, so the oldest entry is dropped when a new one
    arrives at capacity. Whenever the buffer is full, the top 20% (at least
    one) of its entries by consolidation score are promoted into long-term
    memory without duplicates. Long-term memory is capped at
    ``long_term_size`` entries. Promoted memories are the same dict objects in
    both tiers.
    """

    def __init__(self, short_term_size=50, long_term_size=1000):
        self.short_term_memory = deque(maxlen=short_term_size)
        self.long_term_memory = []
        self.long_term_size = long_term_size

        # Lookup indices: memory ids by keyword and by day ('%Y-%m-%d').
        # They are append-only; forgotten memories are not removed.
        self.keyword_index = defaultdict(list)
        self.temporal_index = defaultdict(list)

        # Memory statistics (not updated by any method of this class)
        self.access_counts = defaultdict(int)
        # Plain dict: defaultdict(datetime) raised TypeError on a missing key
        # because datetime() requires year/month/day arguments.
        self.last_access = {}

        # Monotonic id source. Ids derived from tier sizes repeat once the
        # short-term deque is full.
        self._next_id = 0

    def store_memory(self, content: str, memory_type: str, keywords: Optional[List[str]] = None, importance: float = 0.5):
        """Store a new memory in short-term memory and return its unique id.

        Args:
            content: Text of the memory.
            memory_type: Free-form category stored under the 'type' key.
            keywords: Explicit keywords; when None or empty they are
                extracted from ``content``.
            importance: Caller-assigned importance used when scoring.

        Returns:
            Integer id, unique within this ``MemorySystem`` instance.
        """
        memory_id = self._next_id
        self._next_id += 1

        now = datetime.now()
        memory_item = {
            'id': memory_id,
            'content': content,
            'type': memory_type,
            'keywords': keywords or self.extract_keywords(content),
            'importance': importance,
            'timestamp': now,
            'access_count': 0,
            'last_accessed': now
        }

        self.short_term_memory.append(memory_item)
        self._update_indices(memory_item)
        self._consolidate_memory()

        return memory_id

    def extract_keywords(self, content: str) -> List[str]:
        """Extract up to 10 keywords from *content* (simple heuristic).

        Words are lower-cased and stripped of surrounding punctuation. A word
        is kept if it is longer than 3 characters and not a stop word.
        Duplicates are dropped in first-occurrence order, so the result is
        deterministic. ``list(set(...))`` order varies between runs under
        string hash randomization.
        """
        import string

        stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by', 'is', 'are', 'was', 'were'}
        keywords = []
        for raw_word in content.lower().split():
            word = raw_word.strip(string.punctuation)
            if len(word) > 3 and word not in stop_words:
                keywords.append(word)
        return list(dict.fromkeys(keywords))[:10]

    def _update_indices(self, memory_item):
        """Record *memory_item*'s id under each keyword and under its day."""
        memory_id = memory_item['id']

        for keyword in memory_item['keywords']:
            self.keyword_index[keyword].append(memory_id)

        day_key = memory_item['timestamp'].strftime('%Y-%m-%d')
        self.temporal_index[day_key].append(memory_id)

    def _consolidate_memory(self):
        """Promote top-scoring short-term memories into long-term memory.

        Runs only when short-term memory is at capacity. The top 20% (at least
        one) by consolidation score are appended to long-term memory unless
        already present. Long-term memory is then trimmed to
        ``long_term_size``, keeping the highest-scoring entries.
        """
        if len(self.short_term_memory) < self.short_term_memory.maxlen:
            return

        ranked = sorted(self.short_term_memory, key=self._calculate_consolidation_score, reverse=True)
        consolidation_count = max(1, len(ranked) // 5)

        # Without this check every store at capacity re-appended the same top
        # memories, filling long-term memory with duplicates.
        promoted_ids = {m['id'] for m in self.long_term_memory}
        for memory in ranked[:consolidation_count]:
            if memory['id'] not in promoted_ids:
                self.long_term_memory.append(memory)
                promoted_ids.add(memory['id'])

        if len(self.long_term_memory) > self.long_term_size:
            self.long_term_memory.sort(key=self._calculate_consolidation_score, reverse=True)
            self.long_term_memory = self.long_term_memory[:self.long_term_size]

    def _calculate_consolidation_score(self, memory):
        """Score a memory for promotion/retention.

        The score is 0.5*importance + 0.3*access (access_count/10, capped at
        1) + 0.2*recency (exponential decay with a 168-hour time constant).
        """
        importance_score = memory['importance']
        access_score = min(memory['access_count'] / 10.0, 1.0)

        hours_old = (datetime.now() - memory['timestamp']).total_seconds() / 3600
        recency_score = np.exp(-hours_old / 168)

        return importance_score * 0.5 + access_score * 0.3 + recency_score * 0.2

    def retrieve_memories(self, query: str, max_results: int = 5) -> List[Dict]:
        """Return up to *max_results* memories ranked by keyword relevance.

        Both tiers are searched. A memory present in both is considered only
        once. Each returned memory has its ``access_count`` incremented and
        ``last_accessed`` set to now.

        Returns:
            List of ``{'memory': <memory dict>, 'relevance_score': float}``,
            highest score first; memories with zero keyword overlap are
            excluded.
        """
        query_keywords = self.extract_keywords(query)

        seen_ids = set()
        scored_memories = []
        for memory in list(self.short_term_memory) + self.long_term_memory:
            # Promoted memories live in both tiers; score each only once
            if memory['id'] in seen_ids:
                continue
            seen_ids.add(memory['id'])

            relevance_score = self._calculate_relevance_score(memory, query_keywords)
            if relevance_score > 0:
                scored_memories.append((relevance_score, memory))

        scored_memories.sort(key=lambda x: x[0], reverse=True)

        results = []
        for score, memory in scored_memories[:max_results]:
            memory['access_count'] += 1
            memory['last_accessed'] = datetime.now()
            results.append({
                'memory': memory,
                'relevance_score': score
            })

        return results

    def _calculate_relevance_score(self, memory, query_keywords):
        """Return keyword relevance of *memory* to the query (0 if no overlap).

        The score is Jaccard similarity of the keyword sets, multiplied by
        (1 + importance + access boost). The access boost is
        access_count/5, capped at 1.
        """
        memory_keywords = set(memory['keywords'])
        query_keyword_set = set(query_keywords)

        overlap = len(memory_keywords & query_keyword_set)
        if overlap == 0:
            return 0

        jaccard = overlap / len(memory_keywords | query_keyword_set)

        # Boost by the memory's importance and how often it has been retrieved
        importance_boost = memory['importance']
        access_boost = min(memory['access_count'] / 5.0, 1.0)

        return jaccard * (1 + importance_boost + access_boost)

    def get_memory_statistics(self):
        """Return tier sizes, index sizes and the 10 most-indexed keywords."""
        return {
            'short_term_count': len(self.short_term_memory),
            'long_term_count': len(self.long_term_memory),
            'total_memories': len(self.short_term_memory) + len(self.long_term_memory),
            'keyword_index_size': len(self.keyword_index),
            'temporal_index_size': len(self.temporal_index),
            'most_accessed_keywords': self._get_top_keywords(10)
        }

    def _get_top_keywords(self, top_k):
        """Return ``(keyword, indexed_count)`` pairs for the *top_k* keywords."""
        keyword_counts = {k: len(v) for k, v in self.keyword_index.items()}
        return sorted(keyword_counts.items(), key=lambda x: x[1], reverse=True)[:top_k]

print("Context and memory management systems implemented!")