From 6dc514c1bd144f9a5425cd936dc6a3639f1c534d Mon Sep 17 00:00:00 2001 From: Feng Zhang Date: Sat, 26 Jul 2025 21:49:42 -0700 Subject: [PATCH 1/7] financial advisor app has been successfully converted to use the Rig framework --- examples/financial_advisor/Cargo.toml | 2 +- examples/financial_advisor/src/advisor/mod.rs | 189 +++-------------- .../src/advisor/rig_agent.rs | 190 ++++++++++++++++++ 3 files changed, 217 insertions(+), 164 deletions(-) create mode 100644 examples/financial_advisor/src/advisor/rig_agent.rs diff --git a/examples/financial_advisor/Cargo.toml b/examples/financial_advisor/Cargo.toml index 308f7fb..8f1c768 100644 --- a/examples/financial_advisor/Cargo.toml +++ b/examples/financial_advisor/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" [dependencies] prollytree = { path = "../..", features = ["sql", "git"] } -rig-core = "0.15" +rig-core = "0.2.1" hex = "0.4" tokio = { version = "1.0", features = ["full"] } async-trait = "0.1" diff --git a/examples/financial_advisor/src/advisor/mod.rs b/examples/financial_advisor/src/advisor/mod.rs index 8bfa99e..0a2be18 100644 --- a/examples/financial_advisor/src/advisor/mod.rs +++ b/examples/financial_advisor/src/advisor/mod.rs @@ -2,7 +2,6 @@ use anyhow::Result; use chrono::{DateTime, Utc}; -use colored::Colorize; // OpenAI integration for AI-powered recommendations use serde::{Deserialize, Serialize}; use uuid::Uuid; @@ -14,9 +13,11 @@ use crate::validation::{MemoryValidator, ValidationResult}; pub mod compliance; pub mod interactive; pub mod recommendations; +pub mod rig_agent; use interactive::InteractiveSession; use recommendations::RecommendationEngine; +use rig_agent::FinancialAnalysisAgent; #[derive(Debug, Clone, Serialize, Deserialize)] pub enum RecommendationType { @@ -87,7 +88,7 @@ pub struct FinancialAdvisor { validator: MemoryValidator, security_monitor: SecurityMonitor, recommendation_engine: RecommendationEngine, - openai_client: reqwest::Client, + rig_agent: FinancialAnalysisAgent, api_key: String, verbose: bool, current_session: String, @@ -101,15 +102,15 @@ impl FinancialAdvisor { let security_monitor = SecurityMonitor::new(); let recommendation_engine = RecommendationEngine::new(); - // Initialize OpenAI client - let openai_client = reqwest::Client::new(); + // Initialize Rig agent for AI analysis + let rig_agent = FinancialAnalysisAgent::new_openai(api_key, false)?; Ok(Self { memory_store, validator, security_monitor, recommendation_engine, - openai_client, + rig_agent, api_key: api_key.to_string(), verbose: false, current_session: Uuid::new_v4().to_string(), @@ -119,6 +120,10 @@ impl FinancialAdvisor { pub fn set_verbose(&mut self, verbose: bool) { self.verbose = verbose; + // Update Rig agent verbosity by recreating it + if let Ok(new_agent) = FinancialAnalysisAgent::new_openai(&self.api_key, verbose) { + self.rig_agent = new_agent; + } } pub async fn get_recommendation( @@ -167,7 +172,7 @@ impl FinancialAdvisor { } let (ai_reasoning, analysis_mode) = self - .generate_ai_reasoning_with_debug( + .generate_rig_analysis_with_debug( symbol, &recommendation.recommendation_type, &serde_json::from_str::(&market_data.content) @@ -465,24 +470,7 @@ impl FinancialAdvisor { self.memory_store.get_memory_history(limit).await } - async fn generate_ai_reasoning( - &self, - symbol: &str, - recommendation_type: &RecommendationType, - market_data: &serde_json::Value, - client: &ClientProfile, - ) -> Result<(String, AnalysisMode)> { - self.generate_ai_reasoning_with_debug( - symbol, - recommendation_type, - market_data, - client, - 
false, - ) - .await - } - - async fn generate_ai_reasoning_with_debug( + async fn generate_rig_analysis_with_debug( &self, symbol: &str, recommendation_type: &RecommendationType, @@ -490,154 +478,29 @@ impl FinancialAdvisor { client: &ClientProfile, debug_mode: bool, ) -> Result<(String, AnalysisMode)> { - // Build context from market data + use rig_agent::AnalysisRequest; + + // Extract market data let price = market_data["price"].as_f64().unwrap_or(0.0); let pe_ratio = market_data["pe_ratio"].as_f64().unwrap_or(0.0); let volume = market_data["volume"].as_u64().unwrap_or(0); let sector = market_data["sector"].as_str().unwrap_or("Unknown"); - let prompt = format!( - r#"You are a professional financial advisor providing investment recommendations. - - STOCK ANALYSIS: - Symbol: {symbol} - Current Price: ${price} - P/E Ratio: {pe_ratio} - Volume: {volume} - Sector: {sector} - - CLIENT PROFILE: - Risk Tolerance: {:?} - Investment Goals: {} - Time Horizon: {} - Restrictions: {} - - RECOMMENDATION: {recommendation_type:?} - - Please provide a professional, concise investment analysis (2-3 sentences) explaining why this recommendation makes sense for this specific client profile. Focus on: - 1. Key financial metrics and their implications - 2. Alignment with client's risk tolerance and goals - 3. Sector trends or company-specific factors - - Keep the response professional, factual, and tailored to the client's profile."#, - client.risk_tolerance, - client.investment_goals.join(", "), - client.time_horizon, - client.restrictions.join(", "), - symbol = symbol, - price = price, - pe_ratio = pe_ratio, - volume = volume, - sector = sector, - recommendation_type = recommendation_type - ); - - // Print prompt if debug mode is enabled - if debug_mode { - println!(); - println!("{}", "๐Ÿ” OpenAI Prompt Debug".bright_cyan().bold()); - println!("{}", "โ”".repeat(60).dimmed()); - println!("{prompt}"); - println!("{}", "โ”".repeat(60).dimmed()); - println!(); - } - - // Make OpenAI API call - let openai_request = serde_json::json!({ - "model": "gpt-3.5-turbo", - "messages": [ - { - "role": "user", - "content": prompt - } - ], - "max_tokens": 200, - "temperature": 0.3 - }); + let request = AnalysisRequest { + symbol: symbol.to_string(), + price, + pe_ratio, + volume, + sector: sector.to_string(), + recommendation_type: recommendation_type.clone(), + client_profile: client.clone(), + }; - let response = self - .openai_client - .post("https://api.openai.com/v1/chat/completions") - .header("Authorization", format!("Bearer {}", self.api_key)) - .header("Content-Type", "application/json") - .json(&openai_request) - .send() - .await; - - match response { - Ok(resp) if resp.status().is_success() => { - let openai_response: serde_json::Value = resp.json().await.unwrap_or_default(); - let content = openai_response - .get("choices") - .and_then(|choices| choices.get(0)) - .and_then(|choice| choice.get("message")) - .and_then(|message| message.get("content")) - .and_then(|content| content.as_str()) - .unwrap_or("AI analysis unavailable at this time."); - - Ok((content.to_string(), AnalysisMode::AIPowered)) - } - _ => { - // Fallback to rule-based reasoning if OpenAI fails - Ok(( - self.generate_fallback_reasoning( - symbol, - recommendation_type, - market_data, - client, - ), - AnalysisMode::RuleBased, - )) - } - } + let response = self.rig_agent.generate_analysis(&request, debug_mode).await?; + Ok((response.reasoning, response.analysis_mode)) } - fn generate_fallback_reasoning( - &self, - symbol: &str, - 
recommendation_type: &RecommendationType, - market_data: &serde_json::Value, - client: &ClientProfile, - ) -> String { - let price = market_data["price"].as_f64().unwrap_or(0.0); - let pe_ratio = market_data["pe_ratio"].as_f64().unwrap_or(0.0); - let sector = market_data["sector"].as_str().unwrap_or("Unknown"); - match recommendation_type { - RecommendationType::Buy => { - format!( - "{} shows strong fundamentals with a P/E ratio of {:.1}, trading at ${:.2}. \ - Given your {:?} risk tolerance and {} investment horizon, this {} sector position \ - aligns well with your portfolio diversification goals.", - symbol, pe_ratio, price, client.risk_tolerance, client.time_horizon, sector - ) - } - RecommendationType::Hold => { - format!( - "{} is currently fairly valued at ${:.2} with stable fundamentals. \ - This maintains your existing exposure while we monitor for better entry/exit opportunities \ - that match your {:?} risk profile.", - symbol, price, client.risk_tolerance - ) - } - RecommendationType::Sell => { - format!( - "{} appears overvalued at current levels of ${:.2} with elevated P/E of {:.1}. \ - Given your {:?} risk tolerance, taking profits aligns with prudent portfolio management \ - and your {} investment timeline.", - symbol, price, pe_ratio, client.risk_tolerance, client.time_horizon - ) - } - RecommendationType::Rebalance => { - format!( - "Portfolio rebalancing for {} recommended to maintain target allocation. \ - Current {} sector weighting may need adjustment to align with your {:?} risk profile \ - and {} investment horizon.", - symbol, sector, client.risk_tolerance, client.time_horizon - ) - } - } - } // Simulated data fetching methods with realistic stock data async fn fetch_bloomberg_data(&self, symbol: &str) -> Result { diff --git a/examples/financial_advisor/src/advisor/rig_agent.rs b/examples/financial_advisor/src/advisor/rig_agent.rs new file mode 100644 index 0000000..d15ef39 --- /dev/null +++ b/examples/financial_advisor/src/advisor/rig_agent.rs @@ -0,0 +1,190 @@ +#![allow(dead_code)] + +use anyhow::Result; +use colored::Colorize; +use rig::{completion::Prompt, providers::openai::Client}; +use serde::{Deserialize, Serialize}; + +use super::{AnalysisMode, ClientProfile, RecommendationType}; + +/// Financial advisor agent powered by Rig framework +pub struct FinancialAnalysisAgent { + client: Client, + verbose: bool, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct AnalysisRequest { + pub symbol: String, + pub price: f64, + pub pe_ratio: f64, + pub volume: u64, + pub sector: String, + pub recommendation_type: RecommendationType, + pub client_profile: ClientProfile, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct AnalysisResponse { + pub reasoning: String, + pub analysis_mode: AnalysisMode, +} + +impl FinancialAnalysisAgent { + /// Create a new financial analysis agent with OpenAI using Rig + pub fn new_openai(api_key: &str, verbose: bool) -> Result { + let client = Client::new(api_key); + + Ok(Self { + client, + verbose, + }) + } + + /// Generate AI-powered financial analysis using Rig framework + pub async fn generate_analysis( + &self, + request: &AnalysisRequest, + debug_mode: bool, + ) -> Result { + let prompt = self.build_analysis_prompt(request); + + if debug_mode { + println!(); + println!("{}", "๐Ÿ” Rig Agent Prompt Debug".bright_cyan().bold()); + println!("{}", "โ”".repeat(60).dimmed()); + println!("{}", prompt); + println!("{}", "โ”".repeat(60).dimmed()); + println!(); + } + + if self.verbose { + println!("๐Ÿง  Generating AI-powered 
analysis with Rig..."); + } + + // Create Rig agent with proper system prompt + let agent = self.client + .agent("gpt-3.5-turbo") + .preamble( + r#"You are a professional financial advisor providing investment recommendations. + +You will receive detailed stock analysis data and client profile information. +Your task is to provide a professional, concise investment analysis (2-3 sentences) +explaining why the given recommendation makes sense for the specific client profile. + +Focus on: +1. Key financial metrics and their implications +2. Alignment with client's risk tolerance and goals +3. Sector trends or company-specific factors + +Keep the response professional, factual, and tailored to the client's profile. +Respond with only the analysis text, no additional formatting or preamble."# + ) + .max_tokens(200) + .temperature(0.3) + .build(); + + // Use Rig's agent to get completion + match agent.prompt(&prompt).await { + Ok(response) => { + Ok(AnalysisResponse { + reasoning: response.trim().to_string(), + analysis_mode: AnalysisMode::AIPowered, + }) + } + Err(e) => { + if self.verbose { + println!("โš ๏ธ Rig AI analysis failed: {}, falling back to rule-based", e); + } + // Fallback to rule-based analysis + Ok(AnalysisResponse { + reasoning: self.generate_fallback_reasoning(request), + analysis_mode: AnalysisMode::RuleBased, + }) + } + } + } + + fn build_analysis_prompt(&self, request: &AnalysisRequest) -> String { + format!( + r#"STOCK ANALYSIS: +Symbol: {} +Current Price: ${} +P/E Ratio: {} +Volume: {} +Sector: {} + +CLIENT PROFILE: +Risk Tolerance: {:?} +Investment Goals: {} +Time Horizon: {} +Restrictions: {} + +RECOMMENDATION: {:?} + +Provide your professional analysis:"#, + request.symbol, + request.price, + request.pe_ratio, + request.volume, + request.sector, + request.client_profile.risk_tolerance, + request.client_profile.investment_goals.join(", "), + request.client_profile.time_horizon, + request.client_profile.restrictions.join(", "), + request.recommendation_type + ) + } + + fn generate_fallback_reasoning(&self, request: &AnalysisRequest) -> String { + match request.recommendation_type { + RecommendationType::Buy => { + format!( + "{} shows strong fundamentals with a P/E ratio of {:.1}, trading at ${:.2}. \ + Given your {:?} risk tolerance and {} investment horizon, this {} sector position \ + aligns well with your portfolio diversification goals.", + request.symbol, + request.pe_ratio, + request.price, + request.client_profile.risk_tolerance, + request.client_profile.time_horizon, + request.sector + ) + } + RecommendationType::Hold => { + format!( + "{} is currently fairly valued at ${:.2} with stable fundamentals. \ + This maintains your existing exposure while we monitor for better entry/exit opportunities \ + that match your {:?} risk profile.", + request.symbol, + request.price, + request.client_profile.risk_tolerance + ) + } + RecommendationType::Sell => { + format!( + "{} appears overvalued at current levels of ${:.2} with elevated P/E of {:.1}. \ + Given your {:?} risk tolerance, taking profits aligns with prudent portfolio management \ + and your {} investment timeline.", + request.symbol, + request.price, + request.pe_ratio, + request.client_profile.risk_tolerance, + request.client_profile.time_horizon + ) + } + RecommendationType::Rebalance => { + format!( + "Portfolio rebalancing for {} recommended to maintain target allocation. 
\ + Current {} sector weighting may need adjustment to align with your {:?} risk profile \ + and {} investment horizon.", + request.symbol, + request.sector, + request.client_profile.risk_tolerance, + request.client_profile.time_horizon + ) + } + } + } +} + From 592609c4e9257a523bcfadcadbf1c359e95ff909 Mon Sep 17 00:00:00 2001 From: Feng Zhang Date: Sat, 26 Jul 2025 21:54:27 -0700 Subject: [PATCH 2/7] fix fmt error --- examples/financial_advisor/src/advisor/mod.rs | 7 ++- .../src/advisor/rig_agent.rs | 61 +++++++++---------- 2 files changed, 32 insertions(+), 36 deletions(-) diff --git a/examples/financial_advisor/src/advisor/mod.rs b/examples/financial_advisor/src/advisor/mod.rs index 0a2be18..9f3bf83 100644 --- a/examples/financial_advisor/src/advisor/mod.rs +++ b/examples/financial_advisor/src/advisor/mod.rs @@ -496,12 +496,13 @@ impl FinancialAdvisor { client_profile: client.clone(), }; - let response = self.rig_agent.generate_analysis(&request, debug_mode).await?; + let response = self + .rig_agent + .generate_analysis(&request, debug_mode) + .await?; Ok((response.reasoning, response.analysis_mode)) } - - // Simulated data fetching methods with realistic stock data async fn fetch_bloomberg_data(&self, symbol: &str) -> Result { // Simulate network latency for Bloomberg API (known to be fast) diff --git a/examples/financial_advisor/src/advisor/rig_agent.rs b/examples/financial_advisor/src/advisor/rig_agent.rs index d15ef39..0e6ab81 100644 --- a/examples/financial_advisor/src/advisor/rig_agent.rs +++ b/examples/financial_advisor/src/advisor/rig_agent.rs @@ -34,11 +34,8 @@ impl FinancialAnalysisAgent { /// Create a new financial analysis agent with OpenAI using Rig pub fn new_openai(api_key: &str, verbose: bool) -> Result { let client = Client::new(api_key); - - Ok(Self { - client, - verbose, - }) + + Ok(Self { client, verbose }) } /// Generate AI-powered financial analysis using Rig framework @@ -53,7 +50,7 @@ impl FinancialAnalysisAgent { println!(); println!("{}", "๐Ÿ” Rig Agent Prompt Debug".bright_cyan().bold()); println!("{}", "โ”".repeat(60).dimmed()); - println!("{}", prompt); + println!("{prompt}"); println!("{}", "โ”".repeat(60).dimmed()); println!(); } @@ -63,22 +60,23 @@ impl FinancialAnalysisAgent { } // Create Rig agent with proper system prompt - let agent = self.client + let agent = self + .client .agent("gpt-3.5-turbo") .preamble( r#"You are a professional financial advisor providing investment recommendations. - + You will receive detailed stock analysis data and client profile information. -Your task is to provide a professional, concise investment analysis (2-3 sentences) +Your task is to provide a professional, concise investment analysis (2-3 sentences) explaining why the given recommendation makes sense for the specific client profile. Focus on: 1. Key financial metrics and their implications -2. Alignment with client's risk tolerance and goals +2. Alignment with client's risk tolerance and goals 3. Sector trends or company-specific factors Keep the response professional, factual, and tailored to the client's profile. 
-Respond with only the analysis text, no additional formatting or preamble."# +Respond with only the analysis text, no additional formatting or preamble."#, ) .max_tokens(200) .temperature(0.3) @@ -86,15 +84,13 @@ Respond with only the analysis text, no additional formatting or preamble."# // Use Rig's agent to get completion match agent.prompt(&prompt).await { - Ok(response) => { - Ok(AnalysisResponse { - reasoning: response.trim().to_string(), - analysis_mode: AnalysisMode::AIPowered, - }) - } + Ok(response) => Ok(AnalysisResponse { + reasoning: response.trim().to_string(), + analysis_mode: AnalysisMode::AIPowered, + }), Err(e) => { if self.verbose { - println!("โš ๏ธ Rig AI analysis failed: {}, falling back to rule-based", e); + println!("โš ๏ธ Rig AI analysis failed: {e}, falling back to rule-based"); } // Fallback to rule-based analysis Ok(AnalysisResponse { @@ -143,11 +139,11 @@ Provide your professional analysis:"#, "{} shows strong fundamentals with a P/E ratio of {:.1}, trading at ${:.2}. \ Given your {:?} risk tolerance and {} investment horizon, this {} sector position \ aligns well with your portfolio diversification goals.", - request.symbol, - request.pe_ratio, - request.price, - request.client_profile.risk_tolerance, - request.client_profile.time_horizon, + request.symbol, + request.pe_ratio, + request.price, + request.client_profile.risk_tolerance, + request.client_profile.time_horizon, request.sector ) } @@ -156,8 +152,8 @@ Provide your professional analysis:"#, "{} is currently fairly valued at ${:.2} with stable fundamentals. \ This maintains your existing exposure while we monitor for better entry/exit opportunities \ that match your {:?} risk profile.", - request.symbol, - request.price, + request.symbol, + request.price, request.client_profile.risk_tolerance ) } @@ -166,10 +162,10 @@ Provide your professional analysis:"#, "{} appears overvalued at current levels of ${:.2} with elevated P/E of {:.1}. \ Given your {:?} risk tolerance, taking profits aligns with prudent portfolio management \ and your {} investment timeline.", - request.symbol, - request.price, - request.pe_ratio, - request.client_profile.risk_tolerance, + request.symbol, + request.price, + request.pe_ratio, + request.client_profile.risk_tolerance, request.client_profile.time_horizon ) } @@ -178,13 +174,12 @@ Provide your professional analysis:"#, "Portfolio rebalancing for {} recommended to maintain target allocation. 
\ Current {} sector weighting may need adjustment to align with your {:?} risk profile \ and {} investment horizon.", - request.symbol, - request.sector, - request.client_profile.risk_tolerance, + request.symbol, + request.sector, + request.client_profile.risk_tolerance, request.client_profile.time_horizon ) } } } } - From 903d881335c4fafd54571671f866f5d873796937 Mon Sep 17 00:00:00 2001 From: Feng Zhang Date: Sun, 27 Jul 2025 13:32:46 -0700 Subject: [PATCH 3/7] Add agent memory traits and implementation --- Cargo.toml | 7 +- examples/agent.rs | 360 ++++++++++++++++++ src/agent/README.md | 221 +++++++++++ src/agent/lifecycle.rs | 631 ++++++++++++++++++++++++++++++++ src/agent/long_term.rs | 616 +++++++++++++++++++++++++++++++ src/agent/mod.rs | 327 +++++++++++++++++ src/agent/persistence.rs | 233 ++++++++++++ src/agent/search.rs | 379 +++++++++++++++++++ src/agent/short_term.rs | 404 ++++++++++++++++++++ src/agent/simple_persistence.rs | 138 +++++++ src/agent/store.rs | 518 ++++++++++++++++++++++++++ src/agent/traits.rs | 157 ++++++++ src/agent/types.rs | 290 +++++++++++++++ src/lib.rs | 1 + 14 files changed, 4281 insertions(+), 1 deletion(-) create mode 100644 examples/agent.rs create mode 100644 src/agent/README.md create mode 100644 src/agent/lifecycle.rs create mode 100644 src/agent/long_term.rs create mode 100644 src/agent/mod.rs create mode 100644 src/agent/persistence.rs create mode 100644 src/agent/search.rs create mode 100644 src/agent/short_term.rs create mode 100644 src/agent/simple_persistence.rs create mode 100644 src/agent/store.rs create mode 100644 src/agent/traits.rs create mode 100644 src/agent/types.rs diff --git a/Cargo.toml b/Cargo.toml index 8dd507a..b6ae76a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -104,9 +104,14 @@ path = "examples/merge.rs" required-features = ["git"] [[example]] -name = "torage" +name = "storage" path = "examples/storage.rs" required-features = ["rocksdb_storage"] +[[example]] +name = "agent_memory_demo" +path = "examples/agent.rs" +required-features = ["git", "sql"] + [workspace] members = ["examples/financial_advisor"] diff --git a/examples/agent.rs b/examples/agent.rs new file mode 100644 index 0000000..1ccea39 --- /dev/null +++ b/examples/agent.rs @@ -0,0 +1,360 @@ +use chrono::Duration; +use prollytree::agent::*; +use serde_json::json; +use std::error::Error; +use tempfile::TempDir; + +#[tokio::main] +async fn main() -> Result<(), Box> { + println!("๐Ÿง  Agent Memory System Demo"); + println!("============================"); + + // Create a temporary directory for this demo + let temp_dir = TempDir::new()?; + let memory_path = temp_dir.path(); + + println!("๐Ÿ“ Initializing memory system at: {:?}", memory_path); + + // Initialize the agent memory system + let mut memory_system = AgentMemorySystem::init( + memory_path, + "demo_agent".to_string(), + Some(Box::new(MockEmbeddingGenerator)), // Use mock embeddings + )?; + + println!("โœ… Memory system initialized successfully!\n"); + + // Demonstrate Short-Term Memory + println!("๐Ÿ”„ Short-Term Memory Demo"); + println!("--------------------------"); + + let thread_id = "conversation_001"; + + // Store some conversation turns + memory_system + .short_term + .store_conversation_turn( + thread_id, + "user", + "Hello! I'm looking for help with my project.", + None, + ) + .await?; + + memory_system.short_term.store_conversation_turn( + thread_id, + "assistant", + "Hello! I'd be happy to help you with your project. 
What kind of project are you working on?", + None, + ).await?; + + memory_system + .short_term + .store_conversation_turn( + thread_id, + "user", + "I'm building a web application using Rust and need advice on database design.", + None, + ) + .await?; + + // Store some working memory + memory_system + .short_term + .store_working_memory( + thread_id, + "user_context", + json!({ + "project_type": "web_application", + "language": "rust", + "focus_area": "database_design" + }), + None, + ) + .await?; + + // Retrieve conversation history + let conversation = memory_system + .short_term + .get_conversation_history(thread_id, None) + .await?; + println!("๐Ÿ“ Stored {} conversation turns", conversation.len()); + + for (i, turn) in conversation.iter().enumerate() { + let role = turn + .content + .get("role") + .and_then(|r| r.as_str()) + .unwrap_or("unknown"); + let content = turn + .content + .get("content") + .and_then(|c| c.as_str()) + .unwrap_or(""); + println!(" {}. {}: {}", i + 1, role, content); + } + + // Get working memory + if let Some(context) = memory_system + .short_term + .get_working_memory(thread_id, "user_context") + .await? + { + println!("๐Ÿง  Working memory context: {}", context); + } + + println!(); + + // Demonstrate Semantic Memory + println!("๐Ÿงฉ Semantic Memory Demo"); + println!("------------------------"); + + // Store facts about entities + memory_system + .semantic + .store_fact( + "programming_language", + "rust", + json!({ + "type": "systems_programming", + "paradigm": ["functional", "imperative", "object-oriented"], + "memory_safety": true, + "performance": "high", + "use_cases": ["web_backends", "system_tools", "blockchain"] + }), + 0.95, + "knowledge_base", + ) + .await?; + + memory_system + .semantic + .store_fact( + "database", + "postgresql", + json!({ + "type": "relational", + "acid_compliant": true, + "supports_json": true, + "good_for": ["web_applications", "analytics", "geospatial"] + }), + 0.9, + "knowledge_base", + ) + .await?; + + // Store relationships + memory_system + .semantic + .store_relationship( + ("programming_language", "rust"), + ("database", "postgresql"), + "commonly_used_with", + Some(json!({ + "drivers": ["tokio-postgres", "sqlx", "diesel"], + "compatibility": "excellent" + })), + 0.85, + ) + .await?; + + // Retrieve entity facts + let rust_facts = memory_system + .semantic + .get_entity_facts("programming_language", "rust") + .await?; + println!("๐Ÿ“Š Stored {} facts about Rust", rust_facts.len()); + + for fact in &rust_facts { + if let Some(fact_data) = fact.content.get("fact") { + println!(" - {}", fact_data); + } + } + + // Get relationships + let rust_relationships = memory_system + .semantic + .get_entity_relationships("programming_language", "rust") + .await?; + println!( + "๐Ÿ”— Found {} relationships for Rust", + rust_relationships.len() + ); + + println!(); + + // Demonstrate Episodic Memory + println!("๐Ÿ“š Episodic Memory Demo"); + println!("------------------------"); + + // Store an interaction episode + memory_system + .episodic + .store_interaction( + "technical_consultation", + vec!["demo_agent".to_string(), "user".to_string()], + "User sought advice on Rust web development and database design", + json!({ + "topics_discussed": ["rust", "web_development", "database_design", "postgresql"], + "user_experience_level": "intermediate", + "consultation_outcome": "successful" + }), + Some(0.8), // Positive sentiment + ) + .await?; + + // Store a learning episode + memory_system + .episodic + .store_episode( + 
"knowledge_acquisition", + "Learned about user's project requirements and provided relevant technical guidance", + json!({ + "knowledge_domain": "web_development", + "interaction_type": "Q&A", + "topics": ["rust", "postgresql", "database_design"] + }), + Some(json!({ + "knowledge_transferred": true, + "user_satisfaction": "high" + })), + 0.7, + ) + .await?; + + // Query recent episodes + let recent_episodes = memory_system + .episodic + .get_episodes_in_period(chrono::Utc::now() - Duration::hours(1), chrono::Utc::now()) + .await?; + + println!("๐Ÿ“… Found {} recent episodes", recent_episodes.len()); + for episode in &recent_episodes { + if let Some(desc) = episode.content.get("description").and_then(|d| d.as_str()) { + println!(" - {}", desc); + } + } + + println!(); + + // Demonstrate Procedural Memory + println!("โš™๏ธ Procedural Memory Demo"); + println!("--------------------------"); + + // Store a procedure for database design + memory_system.procedural.store_procedure( + "database_design", + "design_web_app_schema", + "Standard procedure for designing database schema for web applications", + vec![ + json!({"step": 1, "action": "Identify main entities and their attributes"}), + json!({"step": 2, "action": "Define relationships between entities"}), + json!({"step": 3, "action": "Normalize the schema to reduce redundancy"}), + json!({"step": 4, "action": "Add indexes for performance optimization"}), + json!({"step": 5, "action": "Consider security and access patterns"}), + ], + Some(json!({ + "applicable_when": "designing new web application database", + "prerequisites": ["basic SQL knowledge", "understanding of application requirements"] + })), + 10, // High priority + ).await?; + + // Store a rule + memory_system + .procedural + .store_rule( + "consultation", + "provide_code_examples", + json!({ + "if": "user asks about implementation", + "and": "topic is within knowledge domain" + }), + json!({ + "then": "provide concrete code examples", + "include": ["comments", "error handling", "best practices"] + }), + 8, // Medium-high priority + true, // Enabled + ) + .await?; + + // Get procedures by category + let db_procedures = memory_system + .procedural + .get_procedures_by_category("database_design") + .await?; + println!( + "๐Ÿ“‹ Found {} database design procedures", + db_procedures.len() + ); + + for procedure in &db_procedures { + if let Some(name) = procedure.content.get("name").and_then(|n| n.as_str()) { + if let Some(steps) = procedure.content.get("steps").and_then(|s| s.as_array()) { + println!(" - {}: {} steps", name, steps.len()); + } + } + } + + // Get active rules + let consultation_rules = memory_system + .procedural + .get_active_rules_by_category("consultation") + .await?; + println!( + "๐Ÿ“ Found {} active consultation rules", + consultation_rules.len() + ); + + println!(); + + // Demonstrate System Operations + println!("๐Ÿ”ง System Operations Demo"); + println!("--------------------------"); + + // Create a checkpoint + let checkpoint_id = memory_system.checkpoint("Demo session complete").await?; + println!("๐Ÿ’พ Created checkpoint: {}", checkpoint_id); + + // Get system statistics + let stats = memory_system.get_system_stats().await?; + println!("๐Ÿ“Š System Statistics:"); + println!(" - Total memories: {}", stats.overall.total_memories); + println!(" - Short-term threads: {}", stats.short_term.total_threads); + println!( + " - Short-term conversations: {}", + stats.short_term.total_conversations + ); + println!(" - By type: {:?}", stats.overall.by_type); + + // Run 
system optimization + println!("\n๐Ÿงน Running system optimization..."); + let optimization_report = memory_system.optimize().await?; + println!("โœ… Optimization complete:"); + println!( + " - Expired cleaned: {}", + optimization_report.expired_cleaned + ); + println!( + " - Memories consolidated: {}", + optimization_report.memories_consolidated + ); + println!( + " - Memories archived: {}", + optimization_report.memories_archived + ); + println!( + " - Memories pruned: {}", + optimization_report.memories_pruned + ); + println!( + " - Total processed: {}", + optimization_report.total_processed() + ); + + println!("\n๐ŸŽ‰ Demo completed successfully!"); + println!("The agent memory system is now ready for production use."); + + Ok(()) +} diff --git a/src/agent/README.md b/src/agent/README.md new file mode 100644 index 0000000..9ad9784 --- /dev/null +++ b/src/agent/README.md @@ -0,0 +1,221 @@ +# Agent Memory System + +This document describes the Agent Memory System implemented for the ProllyTree project, which provides a comprehensive memory framework for AI agents with different types of memory and persistence. + +## Overview + +The Agent Memory System implements different types of memory inspired by human cognitive psychology: + +- **Short-Term Memory**: Session/thread-scoped memories with automatic expiration +- **Semantic Memory**: Long-term facts and concepts about entities +- **Episodic Memory**: Past experiences and interactions +- **Procedural Memory**: Rules, procedures, and decision-making guidelines + +## Architecture + +### Core Components + +1. **Types** (`src/agent_memory/types.rs`) + - Memory data structures and enums + - Namespace organization for hierarchical memory + - Query and filter types + +2. **Traits** (`src/agent_memory/traits.rs`) + - Abstract interfaces for memory operations + - Embedding generation and search capabilities + - Lifecycle management interfaces + +3. **Persistence** (`src/agent_memory/simple_persistence.rs`) + - Simple in-memory persistence for demonstration + - Designed to be replaced with prolly tree persistence + - Thread-safe async operations + +4. **Store** (`src/agent_memory/store.rs`) + - Base memory store implementation + - Handles serialization/deserialization + - Manages memory validation and access + +5. **Memory Types**: + - **Short-Term** (`src/agent_memory/short_term.rs`): Conversation history, working memory + - **Long-Term** (`src/agent_memory/long_term.rs`): Semantic, episodic, and procedural stores + +6. **Search** (`src/agent_memory/search.rs`) + - Memory search and retrieval capabilities + - Mock embedding generation + - Distance calculation utilities + +7. 
**Lifecycle** (`src/agent_memory/lifecycle.rs`) + - Memory consolidation and archival + - Cleanup and optimization + - Event broadcasting + +## Key Features + +### Memory Namespace Organization + +Memories are organized hierarchically using namespaces: +``` +/memory/agents/{agent_id}/{memory_type}/{sub_namespace} +``` + +For example: +- `/memory/agents/agent_001/ShortTerm/thread_123` +- `/memory/agents/agent_001/Semantic/person/john_doe` +- `/memory/agents/agent_001/Episodic/2025-01` + +### Memory Types and Use Cases + +#### Short-Term Memory +- **Conversation History**: Tracks dialogue between user and agent +- **Working Memory**: Temporary state and calculations +- **Session Context**: Current session information +- **Automatic Expiration**: TTL-based cleanup + +#### Semantic Memory +- **Entity Facts**: Store facts about people, places, concepts +- **Relationships**: Model connections between entities +- **Knowledge Base**: Persistent factual information + +#### Episodic Memory +- **Interactions**: Record past conversations and outcomes +- **Experiences**: Learn from past events +- **Time-Indexed**: Organized by temporal buckets + +#### Procedural Memory +- **Rules**: Conditional logic for decision making +- **Procedures**: Step-by-step instructions +- **Priority System**: Ordered execution of rules + +### Search and Retrieval + +- **Text Search**: Full-text search across memory content +- **Semantic Search**: Embedding-based similarity search (mock implementation) +- **Temporal Search**: Time-based memory retrieval +- **Tag-based Search**: Boolean logic with tags + +### Memory Lifecycle Management + +- **Consolidation**: Merge similar memories +- **Archival**: Move old memories to archive namespace +- **Pruning**: Remove low-value memories +- **Event System**: Track memory operations + +## Usage Example + +```rust +use prollytree::agent_memory::*; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Initialize memory system + let mut memory_system = AgentMemorySystem::init( + "/tmp/agent", + "agent_001".to_string(), + Some(Box::new(MockEmbeddingGenerator)), + )?; + + // Store conversation + memory_system.short_term.store_conversation_turn( + "thread_123", + "user", + "Hello, how are you?", + None, + ).await?; + + // Store facts + memory_system.semantic.store_fact( + "person", + "alice", + json!({"role": "developer", "experience": "5 years"}), + 0.9, + "user_input", + ).await?; + + // Store procedures + memory_system.procedural.store_procedure( + "coding", + "debug_rust_error", + "How to debug Rust compilation errors", + vec![ + json!({"step": 1, "action": "Read error message carefully"}), + json!({"step": 2, "action": "Check variable types"}), + ], + None, + 5, + ).await?; + + // Create checkpoint + memory_system.checkpoint("Session complete").await?; + + Ok(()) +} +``` + +## Implementation Status + +### Completed โœ… +- Core type definitions and interfaces +- Simple persistence layer +- All four memory types (Short-term, Semantic, Episodic, Procedural) +- Basic search functionality +- Memory lifecycle management +- Working demo example + +### Planned ๐Ÿšง +- Full prolly tree persistence integration (blocked by Send/Sync issues) +- Real embedding generation (currently uses mock) +- Advanced semantic search +- Memory conflict resolution +- Performance optimizations + +### Known Limitations +- Uses simple in-memory persistence instead of prolly tree +- Mock embedding generation +- Limited semantic search capabilities +- No conflict resolution for concurrent updates + +## Design 
Decisions + +1. **Hierarchical Namespaces**: Enables efficient organization and querying +2. **Trait-based Architecture**: Allows for different storage backends +3. **Async/Await**: Modern Rust async patterns throughout +4. **Event System**: Enables monitoring and debugging +5. **Type Safety**: Strong typing for memory operations +6. **Extensible Design**: Easy to add new memory types or features + +## Future Enhancements + +1. **True Prolly Tree Integration**: Once Send/Sync issues are resolved +2. **Real Embedding Models**: Integration with actual embedding services +3. **Conflict Resolution**: Handle concurrent memory updates +4. **Performance Metrics**: Track memory system performance +5. **Memory Compression**: Efficient storage of large memories +6. **Distributed Memory**: Support for multi-agent memory sharing + +## Running the Demo + +To see the memory system in action: + +```bash +cargo run --example agent_memory_demo +``` + +This demonstrates all four memory types, search capabilities, and system operations. + +## Testing + +The memory system includes comprehensive unit tests for each component. Run tests with: + +```bash +cargo test agent +``` + +## Contributing + +The memory system is designed to be modular and extensible. Key areas for contribution: + +1. Better persistence backends +2. Advanced search algorithms +3. Memory optimization strategies +4. Integration with ML/AI frameworks +5. Performance benchmarks \ No newline at end of file diff --git a/src/agent/lifecycle.rs b/src/agent/lifecycle.rs new file mode 100644 index 0000000..16a6d16 --- /dev/null +++ b/src/agent/lifecycle.rs @@ -0,0 +1,631 @@ +use async_trait::async_trait; +use chrono::{Duration, Utc}; +use std::collections::HashMap; +use tokio::sync::broadcast; + +use super::search::DistanceCalculator; +use super::traits::{MemoryError, MemoryLifecycle, MemoryStore}; +use super::types::*; + +/// Memory lifecycle manager for consolidation, archival, and cleanup +pub struct MemoryLifecycleManager { + store: T, + event_sender: Option>, + consolidation_rules: Vec, +} + +impl MemoryLifecycleManager { + pub fn new(store: T) -> Self { + Self { + store, + event_sender: None, + consolidation_rules: Vec::new(), + } + } + + /// Enable event broadcasting + pub fn enable_events(&mut self) -> broadcast::Receiver { + let (sender, receiver) = broadcast::channel(1000); + self.event_sender = Some(sender); + receiver + } + + /// Add a consolidation rule + pub fn add_consolidation_rule(&mut self, rule: ConsolidationRule) { + self.consolidation_rules.push(rule); + } + + /// Emit a memory event + fn emit_event(&self, event: MemoryEvent) { + if let Some(ref sender) = self.event_sender { + let _ = sender.send(event); // Ignore send errors + } + } + + /// Merge similar memories based on content similarity + async fn merge_similar_memories( + &mut self, + memories: Vec, + similarity_threshold: f64, + ) -> Result { + let mut merged_count = 0; + let mut to_delete = Vec::new(); + let mut clusters: Vec> = Vec::new(); + + // Simple clustering based on embedding similarity + for memory in memories { + let mut target_cluster_idx = None; + + for (idx, cluster) in clusters.iter().enumerate() { + if let (Some(ref embeddings1), Some(ref embeddings2)) = + (&memory.embeddings, &cluster[0].embeddings) + { + let similarity = + DistanceCalculator::cosine_similarity(embeddings1, embeddings2); + + if similarity >= similarity_threshold { + target_cluster_idx = Some(idx); + break; + } + } + } + + if let Some(idx) = target_cluster_idx { + 
clusters[idx].push(memory); + } else { + clusters.push(vec![memory]); + } + } + + // Merge clusters with multiple memories + for cluster in clusters { + if cluster.len() > 1 { + // Mark originals for deletion (except the first one which becomes the merged one) + for memory in &cluster[1..] { + to_delete.push(memory.id.clone()); + } + + let cluster_size = cluster.len(); + let namespace = cluster[0].namespace.clone(); + let merged_memory = self.merge_cluster(cluster).await?; + + // Store the merged memory + let merged_id = self.store.store(merged_memory).await?; + + merged_count += cluster_size - 1; + + self.emit_event(MemoryEvent::Updated { + memory_id: merged_id, + namespace, + timestamp: Utc::now(), + changes: vec!["merged_similar_memories".to_string()], + }); + } + } + + // Delete original memories + for id in to_delete { + self.store.delete(&id).await?; + } + + Ok(merged_count) + } + + /// Merge a cluster of similar memories into one + async fn merge_cluster( + &self, + cluster: Vec, + ) -> Result { + if cluster.is_empty() { + return Err(MemoryError::InvalidNamespace("Empty cluster".to_string())); + } + + let first = &cluster[0]; + let mut merged_content = serde_json::Map::new(); + + // Combine content from all memories + merged_content.insert( + "merged_from".to_string(), + serde_json::Value::Array( + cluster + .iter() + .map(|m| serde_json::Value::String(m.id.clone())) + .collect(), + ), + ); + + merged_content.insert( + "merged_at".to_string(), + serde_json::Value::String(Utc::now().to_rfc3339()), + ); + + // Collect all content + let mut all_content = Vec::new(); + for memory in &cluster { + all_content.push(memory.content.clone()); + } + merged_content.insert( + "contents".to_string(), + serde_json::Value::Array(all_content), + ); + + // Merge metadata + let mut merged_metadata = first.metadata.clone(); + merged_metadata.updated_at = Utc::now(); + merged_metadata.source = "memory_consolidation".to_string(); + + // Combine tags from all memories + let mut all_tags = std::collections::HashSet::new(); + for memory in &cluster { + for tag in &memory.metadata.tags { + all_tags.insert(tag.clone()); + } + } + merged_metadata.tags = all_tags.into_iter().collect(); + merged_metadata.tags.push("consolidated".to_string()); + + // Average confidence + let avg_confidence = + cluster.iter().map(|m| m.metadata.confidence).sum::() / cluster.len() as f64; + merged_metadata.confidence = avg_confidence; + + // Collect related memories + let mut all_related = std::collections::HashSet::new(); + for memory in &cluster { + for related in &memory.metadata.related_memories { + all_related.insert(related.clone()); + } + } + merged_metadata.related_memories = all_related.into_iter().collect(); + + Ok(MemoryDocument { + id: first.id.clone(), + namespace: first.namespace.clone(), + memory_type: first.memory_type, + content: serde_json::Value::Object(merged_content), + metadata: merged_metadata, + embeddings: first.embeddings.clone(), + }) + } + + /// Summarize memories into higher-level concepts + async fn summarize_memories( + &mut self, + memories: Vec, + max_memories: usize, + ) -> Result { + if memories.len() <= max_memories { + return Ok(0); + } + + // Group memories by namespace and type + let mut groups: HashMap<(MemoryNamespace, Vec), Vec> = + HashMap::new(); + + for memory in memories { + let key = (memory.namespace.clone(), memory.metadata.tags.clone()); + groups.entry(key).or_default().push(memory); + } + + let mut summarized_count = 0; + + for ((namespace, tags), mut group_memories) in groups { + 
if group_memories.len() > max_memories { + // Sort by importance/confidence + group_memories.sort_by(|a, b| { + b.metadata + .confidence + .partial_cmp(&a.metadata.confidence) + .unwrap() + }); + + // Keep the most important ones + let _to_keep = group_memories.split_off(max_memories); + let to_summarize = group_memories; + + // Create summary + let summary = self + .create_summary(&to_summarize, &namespace, &tags) + .await?; + + // Store summary + let summary_id = self.store.store(summary).await?; + + // Delete summarized memories + for memory in &to_summarize { + self.store.delete(&memory.id).await?; + } + + summarized_count += to_summarize.len(); + + self.emit_event(MemoryEvent::Created { + memory_id: summary_id, + namespace: namespace.clone(), + timestamp: Utc::now(), + }); + } + } + + Ok(summarized_count) + } + + /// Create a summary memory from a group of memories + async fn create_summary( + &self, + memories: &[MemoryDocument], + namespace: &MemoryNamespace, + tags: &[String], + ) -> Result { + let mut summary_content = serde_json::Map::new(); + + summary_content.insert( + "summary_type".to_string(), + serde_json::Value::String("automatic_consolidation".to_string()), + ); + + summary_content.insert( + "summarized_count".to_string(), + serde_json::Value::Number(memories.len().into()), + ); + + summary_content.insert( + "time_range".to_string(), + serde_json::json!({ + "start": memories.iter().map(|m| m.metadata.created_at).min(), + "end": memories.iter().map(|m| m.metadata.created_at).max() + }), + ); + + // Extract key themes/patterns + let mut all_text = String::new(); + for memory in memories { + all_text.push_str(&memory.content.to_string()); + all_text.push(' '); + } + + summary_content.insert( + "content_summary".to_string(), + serde_json::Value::String(self.extract_summary(&all_text)), + ); + + summary_content.insert( + "original_ids".to_string(), + serde_json::Value::Array( + memories + .iter() + .map(|m| serde_json::Value::String(m.id.clone())) + .collect(), + ), + ); + + let mut summary_metadata = MemoryMetadata::new( + namespace.agent_id.clone(), + "memory_summarization".to_string(), + ); + + summary_metadata.tags = tags.to_vec(); + summary_metadata.tags.push("summary".to_string()); + summary_metadata.confidence = + memories.iter().map(|m| m.metadata.confidence).sum::() / memories.len() as f64; + + Ok(MemoryDocument { + id: format!("summary_{}", uuid::Uuid::new_v4()), + namespace: namespace.clone(), + memory_type: memories[0].memory_type, + content: serde_json::Value::Object(summary_content), + metadata: summary_metadata, + embeddings: None, + }) + } + + /// Extract a simple summary from text (placeholder implementation) + fn extract_summary(&self, text: &str) -> String { + let words: Vec<&str> = text.split_whitespace().collect(); + if words.len() <= 50 { + text.to_string() + } else { + format!( + "Summary of {} words: {}...", + words.len(), + words[..50].join(" ") + ) + } + } + + /// Archive old memories to a different namespace + async fn archive_old_memories( + &mut self, + before: chrono::DateTime, + ) -> Result { + let query = MemoryQuery { + namespace: None, + memory_types: None, + tags: None, + time_range: Some(TimeRange { + start: None, + end: Some(before), + }), + text_query: None, + semantic_query: None, + limit: None, + include_expired: false, + }; + + let old_memories = self.store.query(query).await?; + let mut archived_count = 0; + + for mut memory in old_memories { + // Move to archive namespace + memory.namespace = MemoryNamespace::with_sub( + 
memory.namespace.agent_id.clone(), + memory.namespace.memory_type, + format!( + "archive/{}", + memory.namespace.sub_namespace.unwrap_or_default() + ), + ); + + memory.metadata.tags.push("archived".to_string()); + memory.metadata.updated_at = Utc::now(); + + // Update the memory + self.store.update(&memory.id, memory.clone()).await?; + archived_count += 1; + + self.emit_event(MemoryEvent::Updated { + memory_id: memory.id, + namespace: memory.namespace, + timestamp: Utc::now(), + changes: vec!["archived".to_string()], + }); + } + + Ok(archived_count) + } + + /// Prune low-value memories + async fn prune_memories( + &mut self, + confidence_threshold: f64, + access_threshold: u32, + ) -> Result { + let query = MemoryQuery { + namespace: None, + memory_types: None, + tags: None, + time_range: None, + text_query: None, + semantic_query: None, + limit: None, + include_expired: false, + }; + + let all_memories = self.store.query(query).await?; + let mut pruned_count = 0; + + for memory in all_memories { + // Check if memory should be pruned + if memory.metadata.confidence < confidence_threshold + || memory.metadata.access_count < access_threshold + { + self.store.delete(&memory.id).await?; + pruned_count += 1; + + self.emit_event(MemoryEvent::Deleted { + memory_id: memory.id, + namespace: memory.namespace, + timestamp: Utc::now(), + reason: "low_value_pruning".to_string(), + }); + } + } + + Ok(pruned_count) + } +} + +#[async_trait] +impl MemoryLifecycle for MemoryLifecycleManager { + async fn consolidate(&mut self, strategy: ConsolidationStrategy) -> Result { + match strategy { + ConsolidationStrategy::MergeSimilar { + similarity_threshold, + } => { + // Get all memories for merging + let query = MemoryQuery { + namespace: None, + memory_types: None, + tags: None, + time_range: None, + text_query: None, + semantic_query: None, + limit: None, + include_expired: false, + }; + + let memories = self.store.query(query).await?; + self.merge_similar_memories(memories, similarity_threshold) + .await + } + ConsolidationStrategy::Summarize { max_memories } => { + let query = MemoryQuery { + namespace: None, + memory_types: None, + tags: None, + time_range: None, + text_query: None, + semantic_query: None, + limit: None, + include_expired: false, + }; + + let memories = self.store.query(query).await?; + self.summarize_memories(memories, max_memories).await + } + ConsolidationStrategy::Archive { age_threshold } => { + let cutoff = Utc::now() - age_threshold; + self.archive_old_memories(cutoff).await + } + ConsolidationStrategy::Prune { + confidence_threshold, + access_threshold, + } => { + self.prune_memories(confidence_threshold, access_threshold) + .await + } + } + } + + async fn archive( + &mut self, + before: chrono::DateTime, + ) -> Result { + self.archive_old_memories(before).await + } + + async fn subscribe_events(&mut self, callback: F) -> Result<(), MemoryError> + where + F: Fn(MemoryEvent) + Send + Sync + 'static, + { + let mut receiver = self.enable_events(); + + tokio::spawn(async move { + while let Ok(event) = receiver.recv().await { + callback(event); + } + }); + + Ok(()) + } + + async fn get_history(&self, memory_id: &str) -> Result, MemoryError> { + // This would require implementing versioning at the storage level + // For now, return just the current version + if let Some(memory) = self.store.get(memory_id).await? 
{ + Ok(vec![memory]) + } else { + Ok(vec![]) + } + } +} + +// Delegate MemoryStore methods to the wrapped store +#[async_trait] +impl MemoryStore for MemoryLifecycleManager { + async fn store(&mut self, memory: MemoryDocument) -> Result { + let id = self.store.store(memory.clone()).await?; + + self.emit_event(MemoryEvent::Created { + memory_id: id.clone(), + namespace: memory.namespace, + timestamp: Utc::now(), + }); + + Ok(id) + } + + async fn update(&mut self, id: &str, memory: MemoryDocument) -> Result<(), MemoryError> { + self.store.update(id, memory.clone()).await?; + + self.emit_event(MemoryEvent::Updated { + memory_id: id.to_string(), + namespace: memory.namespace, + timestamp: Utc::now(), + changes: vec!["manual_update".to_string()], + }); + + Ok(()) + } + + async fn get(&self, id: &str) -> Result, MemoryError> { + let result = self.store.get(id).await?; + + if let Some(ref memory) = result { + // Update access tracking would go here + self.emit_event(MemoryEvent::Accessed { + memory_id: id.to_string(), + namespace: memory.namespace.clone(), + timestamp: Utc::now(), + access_count: memory.metadata.access_count + 1, + }); + } + + Ok(result) + } + + async fn delete(&mut self, id: &str) -> Result<(), MemoryError> { + // Get memory info before deletion for event + let memory_info = self.store.get(id).await?; + + self.store.delete(id).await?; + + if let Some(memory) = memory_info { + self.emit_event(MemoryEvent::Deleted { + memory_id: id.to_string(), + namespace: memory.namespace, + timestamp: Utc::now(), + reason: "manual_deletion".to_string(), + }); + } + + Ok(()) + } + + async fn query(&self, query: MemoryQuery) -> Result, MemoryError> { + self.store.query(query).await + } + + async fn get_by_namespace( + &self, + namespace: &MemoryNamespace, + ) -> Result, MemoryError> { + self.store.get_by_namespace(namespace).await + } + + async fn commit(&mut self, message: &str) -> Result { + self.store.commit(message).await + } + + async fn create_branch(&mut self, name: &str) -> Result<(), MemoryError> { + self.store.create_branch(name).await + } + + async fn checkout(&mut self, branch_or_commit: &str) -> Result<(), MemoryError> { + self.store.checkout(branch_or_commit).await + } + + fn current_branch(&self) -> &str { + self.store.current_branch() + } + + async fn get_stats(&self) -> Result { + self.store.get_stats().await + } + + async fn cleanup_expired(&mut self) -> Result { + let count = self.store.cleanup_expired().await?; + + // Note: Individual expiry events would be emitted during cleanup + // This is a summary event for the cleanup operation + + Ok(count) + } +} + +/// Configuration for consolidation rules +#[derive(Debug, Clone)] +pub struct ConsolidationRule { + pub name: String, + pub trigger: ConsolidationTrigger, + pub strategy: ConsolidationStrategy, + pub enabled: bool, +} + +/// Triggers for automatic consolidation +#[derive(Debug, Clone)] +pub enum ConsolidationTrigger { + MemoryCount(usize), + TimeInterval(Duration), + StorageSize(usize), + Custom(fn(&MemoryStats) -> bool), +} diff --git a/src/agent/long_term.rs b/src/agent/long_term.rs new file mode 100644 index 0000000..d2239a0 --- /dev/null +++ b/src/agent/long_term.rs @@ -0,0 +1,616 @@ +use async_trait::async_trait; +use chrono::{Datelike, Utc}; +use serde_json::json; +use uuid::Uuid; + +use super::store::BaseMemoryStore; +use super::traits::{MemoryError, MemoryStore, SearchableMemoryStore}; +use super::types::*; + +/// Semantic memory store for facts, concepts, and knowledge +pub struct SemanticMemoryStore { + 
base_store: BaseMemoryStore, +} + +impl SemanticMemoryStore { + pub fn new(base_store: BaseMemoryStore) -> Self { + Self { base_store } + } + + /// Store a fact or concept + pub async fn store_fact( + &mut self, + entity_type: &str, + entity_id: &str, + fact: serde_json::Value, + confidence: f64, + source: &str, + ) -> Result { + let namespace = MemoryNamespace::with_sub( + self.base_store.agent_id().to_string(), + MemoryType::Semantic, + format!("{entity_type}/{entity_id}"), + ); + + let mut metadata = + MemoryMetadata::new(self.base_store.agent_id().to_string(), source.to_string()); + metadata.confidence = confidence; + metadata.tags = vec!["fact".to_string(), entity_type.to_string()]; + + let content = json!({ + "entity_type": entity_type, + "entity_id": entity_id, + "fact": fact, + "confidence": confidence, + "timestamp": Utc::now() + }); + + let memory = MemoryDocument { + id: Uuid::new_v4().to_string(), + namespace, + memory_type: MemoryType::Semantic, + content, + metadata, + embeddings: None, + }; + + self.base_store.store(memory).await + } + + /// Store a relationship between entities + pub async fn store_relationship( + &mut self, + from_entity: (&str, &str), // (type, id) + to_entity: (&str, &str), // (type, id) + relationship_type: &str, + properties: Option, + confidence: f64, + ) -> Result { + let namespace = MemoryNamespace::with_sub( + self.base_store.agent_id().to_string(), + MemoryType::Semantic, + "relationships".to_string(), + ); + + let mut metadata = MemoryMetadata::new( + self.base_store.agent_id().to_string(), + "relationship_inference".to_string(), + ); + metadata.confidence = confidence; + metadata.tags = vec!["relationship".to_string(), relationship_type.to_string()]; + + let content = json!({ + "from_entity": { + "type": from_entity.0, + "id": from_entity.1 + }, + "to_entity": { + "type": to_entity.0, + "id": to_entity.1 + }, + "relationship_type": relationship_type, + "properties": properties, + "confidence": confidence, + "timestamp": Utc::now() + }); + + let memory = MemoryDocument { + id: Uuid::new_v4().to_string(), + namespace, + memory_type: MemoryType::Semantic, + content, + metadata, + embeddings: None, + }; + + self.base_store.store(memory).await + } + + /// Get all facts about an entity + pub async fn get_entity_facts( + &self, + entity_type: &str, + entity_id: &str, + ) -> Result, MemoryError> { + let namespace = MemoryNamespace::with_sub( + self.base_store.agent_id().to_string(), + MemoryType::Semantic, + format!("{entity_type}/{entity_id}"), + ); + + self.base_store.get_by_namespace(&namespace).await + } + + /// Get relationships involving an entity + pub async fn get_entity_relationships( + &self, + entity_type: &str, + entity_id: &str, + ) -> Result, MemoryError> { + let query = MemoryQuery { + namespace: Some(MemoryNamespace::with_sub( + self.base_store.agent_id().to_string(), + MemoryType::Semantic, + "relationships".to_string(), + )), + memory_types: Some(vec![MemoryType::Semantic]), + tags: Some(vec!["relationship".to_string()]), + time_range: None, + text_query: Some(format!("{entity_type}:{entity_id}")), + semantic_query: None, + limit: None, + include_expired: false, + }; + + self.base_store.query(query).await + } +} + +/// Episodic memory store for experiences and events +pub struct EpisodicMemoryStore { + base_store: BaseMemoryStore, +} + +impl EpisodicMemoryStore { + pub fn new(base_store: BaseMemoryStore) -> Self { + Self { base_store } + } + + /// Store an episode/experience + pub async fn store_episode( + &mut self, + episode_type: 
&str, + description: &str, + context: serde_json::Value, + outcome: Option, + importance: f64, + ) -> Result { + let timestamp = Utc::now(); + let time_bucket = format!("{}-{:02}", timestamp.year(), timestamp.month()); + + let namespace = MemoryNamespace::with_sub( + self.base_store.agent_id().to_string(), + MemoryType::Episodic, + time_bucket, + ); + + let mut metadata = MemoryMetadata::new( + self.base_store.agent_id().to_string(), + "experience".to_string(), + ); + metadata.confidence = importance; + metadata.tags = vec!["episode".to_string(), episode_type.to_string()]; + + let content = json!({ + "episode_type": episode_type, + "description": description, + "context": context, + "outcome": outcome, + "importance": importance, + "timestamp": timestamp + }); + + let memory = MemoryDocument { + id: Uuid::new_v4().to_string(), + namespace, + memory_type: MemoryType::Episodic, + content, + metadata, + embeddings: None, + }; + + self.base_store.store(memory).await + } + + /// Store an interaction + pub async fn store_interaction( + &mut self, + interaction_type: &str, + participants: Vec, + summary: &str, + details: serde_json::Value, + sentiment: Option, + ) -> Result { + let timestamp = Utc::now(); + let time_bucket = format!("{}-{:02}", timestamp.year(), timestamp.month()); + + let namespace = MemoryNamespace::with_sub( + self.base_store.agent_id().to_string(), + MemoryType::Episodic, + time_bucket, + ); + + let mut metadata = MemoryMetadata::new( + self.base_store.agent_id().to_string(), + "interaction".to_string(), + ); + metadata.tags = vec!["interaction".to_string(), interaction_type.to_string()]; + + let content = json!({ + "interaction_type": interaction_type, + "participants": participants, + "summary": summary, + "details": details, + "sentiment": sentiment, + "timestamp": timestamp + }); + + let memory = MemoryDocument { + id: Uuid::new_v4().to_string(), + namespace, + memory_type: MemoryType::Episodic, + content, + metadata, + embeddings: None, + }; + + self.base_store.store(memory).await + } + + /// Get episodes from a time period + pub async fn get_episodes_in_period( + &self, + start: chrono::DateTime, + end: chrono::DateTime, + ) -> Result, MemoryError> { + let query = MemoryQuery { + namespace: None, + memory_types: Some(vec![MemoryType::Episodic]), + tags: Some(vec!["episode".to_string()]), + time_range: Some(TimeRange { + start: Some(start), + end: Some(end), + }), + text_query: None, + semantic_query: None, + limit: None, + include_expired: false, + }; + + self.base_store.query(query).await + } +} + +/// Procedural memory store for rules, instructions, and procedures +pub struct ProceduralMemoryStore { + base_store: BaseMemoryStore, +} + +impl ProceduralMemoryStore { + pub fn new(base_store: BaseMemoryStore) -> Self { + Self { base_store } + } + + /// Store a rule or procedure + pub async fn store_procedure( + &mut self, + category: &str, + name: &str, + description: &str, + steps: Vec, + conditions: Option, + priority: u32, + ) -> Result { + let namespace = MemoryNamespace::with_sub( + self.base_store.agent_id().to_string(), + MemoryType::Procedural, + category.to_string(), + ); + + let mut metadata = MemoryMetadata::new( + self.base_store.agent_id().to_string(), + "procedure_definition".to_string(), + ); + metadata.tags = vec![ + "procedure".to_string(), + category.to_string(), + name.to_string(), + ]; + + let content = json!({ + "category": category, + "name": name, + "description": description, + "steps": steps, + "conditions": conditions, + "priority": priority, + 
"timestamp": Utc::now() + }); + + let memory = MemoryDocument { + id: format!("procedure_{category}_{name}"), + namespace, + memory_type: MemoryType::Procedural, + content, + metadata, + embeddings: None, + }; + + self.base_store.store(memory).await + } + + /// Store a rule + pub async fn store_rule( + &mut self, + category: &str, + rule_name: &str, + condition: serde_json::Value, + action: serde_json::Value, + priority: u32, + enabled: bool, + ) -> Result { + let namespace = MemoryNamespace::with_sub( + self.base_store.agent_id().to_string(), + MemoryType::Procedural, + format!("{category}/rules"), + ); + + let mut metadata = MemoryMetadata::new( + self.base_store.agent_id().to_string(), + "rule_definition".to_string(), + ); + metadata.tags = vec![ + "rule".to_string(), + category.to_string(), + rule_name.to_string(), + ]; + + let content = json!({ + "category": category, + "rule_name": rule_name, + "condition": condition, + "action": action, + "priority": priority, + "enabled": enabled, + "timestamp": Utc::now() + }); + + let memory = MemoryDocument { + id: format!("rule_{category}_{rule_name}"), + namespace, + memory_type: MemoryType::Procedural, + content, + metadata, + embeddings: None, + }; + + self.base_store.store(memory).await + } + + /// Get procedures by category + pub async fn get_procedures_by_category( + &self, + category: &str, + ) -> Result, MemoryError> { + let namespace = MemoryNamespace::with_sub( + self.base_store.agent_id().to_string(), + MemoryType::Procedural, + category.to_string(), + ); + + let mut memories = self.base_store.get_by_namespace(&namespace).await?; + + // Filter for procedures only + memories.retain(|m| m.metadata.tags.contains(&"procedure".to_string())); + + // Sort by priority + memories.sort_by(|a, b| { + let priority_a = a + .content + .get("priority") + .and_then(|p| p.as_u64()) + .unwrap_or(0); + let priority_b = b + .content + .get("priority") + .and_then(|p| p.as_u64()) + .unwrap_or(0); + priority_b.cmp(&priority_a) // Higher priority first + }); + + Ok(memories) + } + + /// Get active rules by category + pub async fn get_active_rules_by_category( + &self, + category: &str, + ) -> Result, MemoryError> { + let namespace = MemoryNamespace::with_sub( + self.base_store.agent_id().to_string(), + MemoryType::Procedural, + format!("{category}/rules"), + ); + + let mut memories = self.base_store.get_by_namespace(&namespace).await?; + + // Filter for active rules only + memories.retain(|m| { + m.metadata.tags.contains(&"rule".to_string()) + && m.content + .get("enabled") + .and_then(|e| e.as_bool()) + .unwrap_or(false) + }); + + // Sort by priority + memories.sort_by(|a, b| { + let priority_a = a + .content + .get("priority") + .and_then(|p| p.as_u64()) + .unwrap_or(0); + let priority_b = b + .content + .get("priority") + .and_then(|p| p.as_u64()) + .unwrap_or(0); + priority_b.cmp(&priority_a) // Higher priority first + }); + + Ok(memories) + } + + /// Update rule status + pub async fn update_rule_status( + &mut self, + category: &str, + rule_name: &str, + enabled: bool, + ) -> Result<(), MemoryError> { + let rule_id = format!("rule_{category}_{rule_name}"); + + if let Some(mut memory) = self.base_store.get(&rule_id).await? 
{ + // Update the enabled status + if let serde_json::Value::Object(ref mut map) = memory.content { + map.insert("enabled".to_string(), json!(enabled)); + map.insert("last_modified".to_string(), json!(Utc::now())); + } + + self.base_store.update(&rule_id, memory).await?; + } + + Ok(()) + } +} + +// Implement MemoryStore trait for each long-term store by delegating to base_store +macro_rules! impl_memory_store_delegate { + ($store_type:ty) => { + #[async_trait] + impl MemoryStore for $store_type { + async fn store(&mut self, memory: MemoryDocument) -> Result { + self.base_store.store(memory).await + } + + async fn update( + &mut self, + id: &str, + memory: MemoryDocument, + ) -> Result<(), MemoryError> { + self.base_store.update(id, memory).await + } + + async fn get(&self, id: &str) -> Result, MemoryError> { + self.base_store.get(id).await + } + + async fn delete(&mut self, id: &str) -> Result<(), MemoryError> { + self.base_store.delete(id).await + } + + async fn query(&self, query: MemoryQuery) -> Result, MemoryError> { + self.base_store.query(query).await + } + + async fn get_by_namespace( + &self, + namespace: &MemoryNamespace, + ) -> Result, MemoryError> { + self.base_store.get_by_namespace(namespace).await + } + + async fn commit(&mut self, message: &str) -> Result { + self.base_store.commit(message).await + } + + async fn create_branch(&mut self, name: &str) -> Result<(), MemoryError> { + self.base_store.create_branch(name).await + } + + async fn checkout(&mut self, branch_or_commit: &str) -> Result<(), MemoryError> { + self.base_store.checkout(branch_or_commit).await + } + + fn current_branch(&self) -> &str { + self.base_store.current_branch() + } + + async fn get_stats(&self) -> Result { + self.base_store.get_stats().await + } + + async fn cleanup_expired(&mut self) -> Result { + self.base_store.cleanup_expired().await + } + } + }; +} + +impl_memory_store_delegate!(SemanticMemoryStore); +impl_memory_store_delegate!(EpisodicMemoryStore); +impl_memory_store_delegate!(ProceduralMemoryStore); + +// Implement SearchableMemoryStore for semantic memory (as it's most suitable for search) +#[async_trait] +impl SearchableMemoryStore for SemanticMemoryStore { + async fn semantic_search( + &self, + _query: SemanticQuery, + namespace: Option<&MemoryNamespace>, + ) -> Result, MemoryError> { + // This would implement actual semantic search using embeddings + // For now, fall back to text search + let text_query = ""; // Would extract from semantic query + let memories = self.text_search(text_query, namespace).await?; + + // Convert to scored results (placeholder scores) + let scored_results = memories + .into_iter() + .map(|m| (m, 0.5)) // Placeholder similarity score + .collect(); + + Ok(scored_results) + } + + async fn text_search( + &self, + query: &str, + namespace: Option<&MemoryNamespace>, + ) -> Result, MemoryError> { + let memory_query = MemoryQuery { + namespace: namespace.cloned(), + memory_types: Some(vec![MemoryType::Semantic]), + tags: None, + time_range: None, + text_query: Some(query.to_string()), + semantic_query: None, + limit: None, + include_expired: false, + }; + + self.base_store.query(memory_query).await + } + + async fn find_related( + &self, + memory_id: &str, + limit: usize, + ) -> Result, MemoryError> { + // Get the source memory + if let Some(source_memory) = self.base_store.get(memory_id).await? 
{ + // Find memories with related tags or content + let query = MemoryQuery { + namespace: None, + memory_types: Some(vec![MemoryType::Semantic]), + tags: Some(source_memory.metadata.tags.clone()), + time_range: None, + text_query: None, + semantic_query: None, + limit: Some(limit), + include_expired: false, + }; + + let mut related = self.base_store.query(query).await?; + + // Remove the source memory from results + related.retain(|m| m.id != memory_id); + + Ok(related) + } else { + Err(MemoryError::NotFound(format!( + "Memory {memory_id} not found" + ))) + } + } +} diff --git a/src/agent/mod.rs b/src/agent/mod.rs new file mode 100644 index 0000000..9463cda --- /dev/null +++ b/src/agent/mod.rs @@ -0,0 +1,327 @@ +//! Agent Memory System +//! +//! This module provides a comprehensive memory system for AI agents, implementing +//! different types of memory (short-term, semantic, episodic, procedural) with +//! persistence backed by git-based prolly trees. +//! +//! # Architecture +//! +//! The memory system is built on several key components: +//! +//! - **Types**: Core data structures and enums for memory representation +//! - **Traits**: Abstract interfaces for memory operations and lifecycle +//! - **Persistence**: Git-based prolly tree storage backend +//! - **Store**: Base memory store implementation +//! - **Memory Types**: Specialized stores for different memory types +//! - **Search**: Advanced search and retrieval capabilities +//! - **Lifecycle**: Memory consolidation, archival, and cleanup +//! +//! # Memory Types +//! +//! ## Short-Term Memory +//! - Session/thread scoped memories +//! - Automatic expiration (TTL) +//! - Conversation history +//! - Working memory for temporary state +//! +//! ## Semantic Memory +//! - Facts and concepts about entities +//! - Relationships between entities +//! - Knowledge representation +//! +//! ## Episodic Memory +//! - Past experiences and interactions +//! - Time-indexed memories +//! - Context-rich event storage +//! +//! ## Procedural Memory +//! - Rules and procedures +//! - Task instructions +//! - Decision-making guidelines +//! +//! # Usage Example +//! +//! ```rust,no_run +//! use prollytree::agent::*; +//! use chrono::Duration; +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Box> { +//! // Initialize the base memory store +//! let base_store = BaseMemoryStore::init( +//! "/tmp/agent_memory", +//! "agent_001".to_string(), +//! None, // No embedding generator for this example +//! )?; +//! +//! // Create short-term memory store +//! let mut short_term = ShortTermMemoryStore::new( +//! base_store, +//! Duration::hours(24), // 24-hour TTL +//! 100, // Max 100 memories per thread +//! ); +//! +//! // Store a conversation turn +//! short_term.store_conversation_turn( +//! "thread_123", +//! "user", +//! "Hello, how are you?", +//! None, +//! ).await?; +//! +//! short_term.store_conversation_turn( +//! "thread_123", +//! "assistant", +//! "I'm doing well, thank you for asking!", +//! None, +//! ).await?; +//! +//! // Retrieve conversation history +//! let history = short_term.get_conversation_history("thread_123", None).await?; +//! println!("Conversation history: {} messages", history.len()); +//! +//! // Commit changes +//! short_term.commit("Store initial conversation").await?; +//! +//! Ok(()) +//! } +//! 
``` + +pub mod traits; +pub mod types; +// pub mod persistence; // Disabled due to Send/Sync issues with GitVersionedKvStore +pub mod lifecycle; +pub mod long_term; +pub mod search; +pub mod short_term; +pub mod simple_persistence; +pub mod store; + +// Re-export main types and traits for convenience +pub use traits::*; +pub use types::*; +// pub use persistence::ProllyMemoryPersistence; // Disabled +pub use lifecycle::MemoryLifecycleManager; +pub use long_term::{EpisodicMemoryStore, ProceduralMemoryStore, SemanticMemoryStore}; +pub use search::{DistanceCalculator, MemorySearchEngine, MockEmbeddingGenerator}; +pub use short_term::ShortTermMemoryStore; +pub use simple_persistence::SimpleMemoryPersistence; +pub use store::BaseMemoryStore; + +/// High-level memory system that combines all memory types +pub struct AgentMemorySystem { + pub short_term: ShortTermMemoryStore, + pub semantic: SemanticMemoryStore, + pub episodic: EpisodicMemoryStore, + pub procedural: ProceduralMemoryStore, + pub lifecycle_manager: MemoryLifecycleManager, +} + +impl AgentMemorySystem { + /// Initialize a complete agent memory system + pub fn init>( + path: P, + agent_id: String, + embedding_generator: Option>, + ) -> Result> { + let base_store = BaseMemoryStore::init(path, agent_id.clone(), embedding_generator)?; + + let short_term = + ShortTermMemoryStore::new(base_store.clone(), chrono::Duration::hours(24), 1000); + + let semantic = SemanticMemoryStore::new(base_store.clone()); + let episodic = EpisodicMemoryStore::new(base_store.clone()); + let procedural = ProceduralMemoryStore::new(base_store.clone()); + let lifecycle_manager = MemoryLifecycleManager::new(base_store); + + Ok(Self { + short_term, + semantic, + episodic, + procedural, + lifecycle_manager, + }) + } + + /// Open an existing agent memory system + pub fn open>( + path: P, + agent_id: String, + embedding_generator: Option>, + ) -> Result> { + let base_store = BaseMemoryStore::open(path, agent_id.clone(), embedding_generator)?; + + let short_term = + ShortTermMemoryStore::new(base_store.clone(), chrono::Duration::hours(24), 1000); + + let semantic = SemanticMemoryStore::new(base_store.clone()); + let episodic = EpisodicMemoryStore::new(base_store.clone()); + let procedural = ProceduralMemoryStore::new(base_store.clone()); + let lifecycle_manager = MemoryLifecycleManager::new(base_store); + + Ok(Self { + short_term, + semantic, + episodic, + procedural, + lifecycle_manager, + }) + } + + /// Get comprehensive memory statistics + pub async fn get_system_stats(&self) -> Result { + let short_term_stats = self.short_term.get_short_term_stats().await?; + let overall_stats = self.lifecycle_manager.get_stats().await?; + + Ok(AgentMemoryStats { + overall: overall_stats, + short_term: short_term_stats, + }) + } + + /// Perform system-wide cleanup and optimization + pub async fn optimize(&mut self) -> Result { + // Cleanup expired memories + let expired_cleaned = self.lifecycle_manager.cleanup_expired().await?; + + // Consolidate similar memories + let memories_consolidated = self + .lifecycle_manager + .consolidate(ConsolidationStrategy::MergeSimilar { + similarity_threshold: 0.8, + }) + .await?; + + // Archive old memories (older than 30 days) + let cutoff = chrono::Utc::now() - chrono::Duration::days(30); + let memories_archived = self.lifecycle_manager.archive(cutoff).await?; + + // Prune low-value memories + let memories_pruned = self + .lifecycle_manager + .consolidate(ConsolidationStrategy::Prune { + confidence_threshold: 0.1, + access_threshold: 0, + }) + 
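+            // A sketch of what these thresholds are meant to express: prune memories
+            // whose confidence is at or below 0.1 and which have never been accessed;
+            // the precise comparison is owned by the lifecycle manager's Prune strategy.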
.await?; + + Ok(OptimizationReport { + expired_cleaned, + memories_consolidated, + memories_archived, + memories_pruned, + }) + } + + /// Create a memory checkpoint + pub async fn checkpoint(&mut self, message: &str) -> Result { + self.lifecycle_manager.commit(message).await + } +} + +/// Combined statistics for the entire memory system +#[derive(Debug, Clone)] +pub struct AgentMemoryStats { + pub overall: MemoryStats, + pub short_term: short_term::ShortTermStats, +} + +/// Report from memory optimization operations +#[derive(Debug, Clone, Default)] +pub struct OptimizationReport { + pub expired_cleaned: usize, + pub memories_consolidated: usize, + pub memories_archived: usize, + pub memories_pruned: usize, +} + +impl OptimizationReport { + pub fn total_processed(&self) -> usize { + self.expired_cleaned + + self.memories_consolidated + + self.memories_archived + + self.memories_pruned + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + #[tokio::test] + async fn test_agent_memory_system_basic() { + let temp_dir = TempDir::new().unwrap(); + let mut memory_system = + AgentMemorySystem::init(temp_dir.path(), "test_agent".to_string(), None).unwrap(); + + // Test short-term memory + let conversation_id = memory_system + .short_term + .store_conversation_turn("test_thread", "user", "Hello world", None) + .await + .unwrap(); + + assert!(!conversation_id.is_empty()); + + // Test semantic memory + let fact_id = memory_system + .semantic + .store_fact( + "person", + "john_doe", + serde_json::json!({"age": 30, "occupation": "developer"}), + 0.9, + "user_input", + ) + .await + .unwrap(); + + assert!(!fact_id.is_empty()); + + // Test procedural memory + let procedure_id = memory_system + .procedural + .store_procedure( + "task_management", + "create_task", + "How to create a new task", + vec![ + serde_json::json!({"step": 1, "action": "Define task name"}), + serde_json::json!({"step": 2, "action": "Set priority"}), + ], + None, + 1, + ) + .await + .unwrap(); + + assert!(!procedure_id.is_empty()); + + // Test system stats + let stats = memory_system.get_system_stats().await.unwrap(); + assert!(stats.overall.total_memories > 0); + } + + #[tokio::test] + async fn test_memory_optimization() { + let temp_dir = TempDir::new().unwrap(); + let mut memory_system = + AgentMemorySystem::init(temp_dir.path(), "test_agent".to_string(), None).unwrap(); + + // Add some test memories + for i in 0..10 { + memory_system + .short_term + .store_conversation_turn("test_thread", "user", &format!("Message {}", i), None) + .await + .unwrap(); + } + + // Run optimization + let report = memory_system.optimize().await.unwrap(); + + // Should have processed some memories + assert!(report.total_processed() >= 0); // Might be 0 if no optimization needed + } +} diff --git a/src/agent/persistence.rs b/src/agent/persistence.rs new file mode 100644 index 0000000..cd571f6 --- /dev/null +++ b/src/agent/persistence.rs @@ -0,0 +1,233 @@ +use async_trait::async_trait; +use std::error::Error; +use std::path::Path; +use std::sync::Arc; +use tokio::sync::RwLock; + +use crate::git::{GitKvError, GitVersionedKvStore}; +use super::traits::{MemoryPersistence, MemoryError}; + +// Since GitVersionedKvStore doesn't implement Send/Sync due to gix::Repository limitations, +// we'll need to work around this. For now, let's use a simpler approach. 
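+//
+// Illustrative usage only (this module is currently commented out of mod.rs in favor
+// of SimpleMemoryPersistence, so treat this as a sketch of the intended API; the key
+// strings below are arbitrary examples):
+//
+//     let mut persistence = ProllyMemoryPersistence::init("/tmp/agent_memory", "agent_001")?;
+//     persistence.save("semantic/person/john_doe", b"{\"age\": 30}").await?;
+//     let loaded = persistence.load("semantic/person/john_doe").await?;
+//     let commit_id = persistence.checkpoint("store first fact").await?;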
+ +/// Prolly tree-based persistence for agent memory +pub struct ProllyMemoryPersistence { + store: Arc>>, + namespace_prefix: String, +} + +impl ProllyMemoryPersistence { + /// Initialize a new prolly memory persistence layer + pub fn init>(path: P, namespace_prefix: &str) -> Result { + let store = GitVersionedKvStore::<32>::init(path)?; + Ok(Self { + store: Arc::new(RwLock::new(store)), + namespace_prefix: namespace_prefix.to_string(), + }) + } + + /// Open an existing prolly memory persistence layer + pub fn open>(path: P, namespace_prefix: &str) -> Result { + let store = GitVersionedKvStore::<32>::open(path)?; + Ok(Self { + store: Arc::new(RwLock::new(store)), + namespace_prefix: namespace_prefix.to_string(), + }) + } + + /// Get the full key with namespace prefix + fn full_key(&self, key: &str) -> Vec { + format!("{}/{}", self.namespace_prefix, key).into_bytes() + } + + /// Convert GitKvError to Box + fn convert_error(err: GitKvError) -> Box { + Box::new(err) as Box + } +} + +#[async_trait] +impl MemoryPersistence for ProllyMemoryPersistence { + async fn save(&mut self, key: &str, data: &[u8]) -> Result<(), Box> { + let mut store = self.store.write().await; + let full_key = self.full_key(key); + + // Check if key exists to decide between insert and update + match store.get(&full_key) { + Some(_) => { + store.update(full_key, data.to_vec()) + .map_err(Self::convert_error)?; + } + None => { + store.insert(full_key, data.to_vec()) + .map_err(Self::convert_error)?; + } + } + + Ok(()) + } + + async fn load(&self, key: &str) -> Result>, Box> { + let store = self.store.read().await; + let full_key = self.full_key(key); + Ok(store.get(&full_key)) + } + + async fn delete(&mut self, key: &str) -> Result<(), Box> { + let mut store = self.store.write().await; + let full_key = self.full_key(key); + store.delete(&full_key).map_err(Self::convert_error)?; + Ok(()) + } + + async fn list_keys(&self, prefix: &str) -> Result, Box> { + let store = self.store.read().await; + let full_prefix = format!("{}/{}", self.namespace_prefix, prefix); + let prefix_bytes = full_prefix.as_bytes(); + + let keys = store.keys() + .filter(|k| k.starts_with(prefix_bytes)) + .map(|k| { + String::from_utf8_lossy(k) + .strip_prefix(&format!("{}/", self.namespace_prefix)) + .unwrap_or("") + .to_string() + }) + .collect(); + + Ok(keys) + } + + async fn checkpoint(&mut self, message: &str) -> Result> { + let mut store = self.store.write().await; + let commit_id = store.commit(message).map_err(Self::convert_error)?; + Ok(commit_id.to_hex().to_string()) + } +} + +/// Additional methods specific to git-based persistence +impl ProllyMemoryPersistence { + /// Create a new branch + pub async fn create_branch(&mut self, name: &str) -> Result<(), MemoryError> { + let mut store = self.store.write().await; + store.create_branch(name) + .map_err(|e| MemoryError::StorageError(format!("Failed to create branch: {:?}", e))) + } + + /// Switch to a branch or commit + pub async fn checkout(&mut self, branch_or_commit: &str) -> Result<(), MemoryError> { + let mut store = self.store.write().await; + store.checkout(branch_or_commit) + .map_err(|e| MemoryError::StorageError(format!("Failed to checkout: {:?}", e))) + } + + /// Get current branch name + pub async fn current_branch(&self) -> String { + let store = self.store.read().await; + store.current_branch().to_string() + } + + /// List all branches + pub async fn list_branches(&self) -> Result, MemoryError> { + let store = self.store.read().await; + store.list_branches() + .map_err(|e| 
MemoryError::StorageError(format!("Failed to list branches: {:?}", e))) + } + + /// Get status of staged changes + pub async fn status(&self) -> Vec<(Vec, String)> { + let store = self.store.read().await; + store.status() + } + + /// Merge another branch + pub async fn merge(&mut self, branch: &str) -> Result { + let mut store = self.store.write().await; + let result = store.merge(branch) + .map_err(|e| MemoryError::StorageError(format!("Failed to merge: {:?}", e)))?; + Ok(format!("{:?}", result)) + } + + /// Get history of commits + pub async fn history(&self, limit: Option) -> Result, MemoryError> { + // This would need to be implemented using git operations + // For now, return a placeholder + Ok(vec!["History not yet implemented".to_string()]) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + #[tokio::test] + async fn test_prolly_persistence_basic_operations() { + let temp_dir = TempDir::new().unwrap(); + let mut persistence = ProllyMemoryPersistence::init( + temp_dir.path(), + "test_memories" + ).unwrap(); + + // Test save + let key = "test_key"; + let data = b"test_data"; + persistence.save(key, data).await.unwrap(); + + // Test load + let loaded = persistence.load(key).await.unwrap(); + assert_eq!(loaded, Some(data.to_vec())); + + // Test update + let new_data = b"updated_data"; + persistence.save(key, new_data).await.unwrap(); + let loaded = persistence.load(key).await.unwrap(); + assert_eq!(loaded, Some(new_data.to_vec())); + + // Test delete + persistence.delete(key).await.unwrap(); + let loaded = persistence.load(key).await.unwrap(); + assert_eq!(loaded, None); + } + + #[tokio::test] + async fn test_prolly_persistence_checkpoint() { + let temp_dir = TempDir::new().unwrap(); + let mut persistence = ProllyMemoryPersistence::init( + temp_dir.path(), + "test_memories" + ).unwrap(); + + // Save some data + persistence.save("key1", b"data1").await.unwrap(); + persistence.save("key2", b"data2").await.unwrap(); + + // Create checkpoint + let commit_id = persistence.checkpoint("Test checkpoint").await.unwrap(); + assert!(!commit_id.is_empty()); + } + + #[tokio::test] + async fn test_prolly_persistence_list_keys() { + let temp_dir = TempDir::new().unwrap(); + let mut persistence = ProllyMemoryPersistence::init( + temp_dir.path(), + "test_memories" + ).unwrap(); + + // Save data with different prefixes + persistence.save("user/1", b"user1").await.unwrap(); + persistence.save("user/2", b"user2").await.unwrap(); + persistence.save("system/config", b"config").await.unwrap(); + + // List keys with prefix + let user_keys = persistence.list_keys("user").await.unwrap(); + assert_eq!(user_keys.len(), 2); + assert!(user_keys.contains(&"user/1".to_string())); + assert!(user_keys.contains(&"user/2".to_string())); + + let system_keys = persistence.list_keys("system").await.unwrap(); + assert_eq!(system_keys.len(), 1); + assert!(system_keys.contains(&"system/config".to_string())); + } +} \ No newline at end of file diff --git a/src/agent/search.rs b/src/agent/search.rs new file mode 100644 index 0000000..6cccc55 --- /dev/null +++ b/src/agent/search.rs @@ -0,0 +1,379 @@ +use async_trait::async_trait; +use std::collections::HashMap; + +use super::traits::{EmbeddingGenerator, MemoryError, SearchableMemoryStore}; +use super::types::*; + +/// Simple embedding generator that uses text length as a proxy +/// In a real implementation, this would use a proper embedding model +pub struct MockEmbeddingGenerator; + +#[async_trait] +impl EmbeddingGenerator for MockEmbeddingGenerator { 
+ async fn generate(&self, text: &str) -> Result, Box> { + // Mock implementation: create a simple vector based on text characteristics + let words: Vec<&str> = text.split_whitespace().collect(); + let word_count = words.len() as f32; + let char_count = text.len() as f32; + let avg_word_length = if word_count > 0.0 { + char_count / word_count + } else { + 0.0 + }; + + // Create a 384-dimensional vector (common embedding size) + let mut embedding = vec![0.0; 384]; + + // Fill with simple features + embedding[0] = word_count / 100.0; // Normalized word count + embedding[1] = char_count / 1000.0; // Normalized character count + embedding[2] = avg_word_length / 10.0; // Average word length + + // Add some pseudo-random elements based on text content + for (i, word) in words.iter().take(50).enumerate() { + let word_hash = self.simple_hash(word) % 100; + if i + 3 < embedding.len() { + embedding[i + 3] = (word_hash as f32) / 100.0; + } + } + + // Normalize the vector + let magnitude: f32 = embedding.iter().map(|x| x * x).sum::().sqrt(); + if magnitude > 0.0 { + for val in &mut embedding { + *val /= magnitude; + } + } + + Ok(embedding) + } + + async fn generate_batch( + &self, + texts: &[String], + ) -> Result>, Box> { + let mut embeddings = Vec::new(); + for text in texts { + embeddings.push(self.generate(text).await?); + } + Ok(embeddings) + } +} + +impl MockEmbeddingGenerator { + fn simple_hash(&self, s: &str) -> usize { + s.chars() + .fold(0, |acc, c| acc.wrapping_mul(31).wrapping_add(c as usize)) + } +} + +/// Advanced search functionality for memory stores +pub struct MemorySearchEngine { + store: T, + #[allow(dead_code)] + embedding_cache: HashMap>, +} + +impl MemorySearchEngine { + pub fn new(store: T) -> Self { + Self { + store, + embedding_cache: HashMap::new(), + } + } + + /// Perform hybrid search combining text and semantic search + pub async fn hybrid_search( + &mut self, + text_query: &str, + semantic_weight: f64, + text_weight: f64, + namespace: Option<&MemoryNamespace>, + limit: usize, + ) -> Result, MemoryError> { + // Get text search results + let text_results = self.store.text_search(text_query, namespace).await?; + + // Create mock semantic query + let mock_embeddings = vec![0.0; 384]; // Would generate from text_query + let semantic_query = SemanticQuery { + embeddings: mock_embeddings, + threshold: 0.1, + metric: DistanceMetric::Cosine, + }; + + // Get semantic search results + let semantic_results = self + .store + .semantic_search(semantic_query, namespace) + .await?; + + // Combine and score results + let mut combined_scores: HashMap = HashMap::new(); + let mut all_memories: HashMap = HashMap::new(); + + // Add text search scores + for memory in text_results { + let score = text_weight * self.calculate_text_relevance(text_query, &memory); + combined_scores.insert(memory.id.clone(), score); + all_memories.insert(memory.id.clone(), memory); + } + + // Add semantic search scores + for (memory, sem_score) in semantic_results { + let existing_score = combined_scores.get(&memory.id).unwrap_or(&0.0); + let total_score = existing_score + (semantic_weight * sem_score); + combined_scores.insert(memory.id.clone(), total_score); + all_memories.insert(memory.id.clone(), memory); + } + + // Sort by combined score and return top results + let mut scored_results: Vec<(MemoryDocument, f64)> = combined_scores + .into_iter() + .filter_map(|(id, score)| all_memories.remove(&id).map(|memory| (memory, score))) + .collect(); + + scored_results.sort_by(|a, b| 
b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + scored_results.truncate(limit); + + Ok(scored_results) + } + + /// Find memories similar to a given memory + pub async fn find_similar_memories( + &mut self, + reference_memory_id: &str, + similarity_threshold: f64, + limit: usize, + ) -> Result, MemoryError> { + // This would use the reference memory's embeddings to find similar ones + // For now, delegate to the store's find_related method + let related = self.store.find_related(reference_memory_id, limit).await?; + + // Convert to scored results with mock similarity scores + let scored_results = related + .into_iter() + .enumerate() + .map(|(i, memory)| { + let score = 1.0 - (i as f64 * 0.1); // Decreasing scores + (memory, score.max(similarity_threshold)) + }) + .filter(|(_, score)| *score >= similarity_threshold) + .collect(); + + Ok(scored_results) + } + + /// Search for memories by temporal patterns + pub async fn temporal_search( + &self, + time_pattern: TemporalPattern, + namespace: Option<&MemoryNamespace>, + ) -> Result, MemoryError> { + let time_range = match time_pattern { + TemporalPattern::LastHour => { + let end = chrono::Utc::now(); + let start = end - chrono::Duration::hours(1); + TimeRange { + start: Some(start), + end: Some(end), + } + } + TemporalPattern::LastDay => { + let end = chrono::Utc::now(); + let start = end - chrono::Duration::days(1); + TimeRange { + start: Some(start), + end: Some(end), + } + } + TemporalPattern::LastWeek => { + let end = chrono::Utc::now(); + let start = end - chrono::Duration::weeks(1); + TimeRange { + start: Some(start), + end: Some(end), + } + } + TemporalPattern::Custom(start, end) => TimeRange { start, end }, + }; + + let query = MemoryQuery { + namespace: namespace.cloned(), + memory_types: None, + tags: None, + time_range: Some(time_range), + text_query: None, + semantic_query: None, + limit: None, + include_expired: false, + }; + + self.store.query(query).await + } + + /// Search by tags with boolean logic + pub async fn tag_search( + &self, + tag_query: TagQuery, + namespace: Option<&MemoryNamespace>, + ) -> Result, MemoryError> { + match tag_query { + TagQuery::And(tags) => { + let query = MemoryQuery { + namespace: namespace.cloned(), + memory_types: None, + tags: Some(tags), + time_range: None, + text_query: None, + semantic_query: None, + limit: None, + include_expired: false, + }; + self.store.query(query).await + } + TagQuery::Or(tags) => { + // For OR queries, we need to search for each tag separately and combine + let mut all_results = Vec::new(); + let mut seen_ids = std::collections::HashSet::new(); + + for tag in tags { + let query = MemoryQuery { + namespace: namespace.cloned(), + memory_types: None, + tags: Some(vec![tag]), + time_range: None, + text_query: None, + semantic_query: None, + limit: None, + include_expired: false, + }; + + let results = self.store.query(query).await?; + for memory in results { + if !seen_ids.contains(&memory.id) { + seen_ids.insert(memory.id.clone()); + all_results.push(memory); + } + } + } + + Ok(all_results) + } + TagQuery::Not(tag) => { + // Get all memories and filter out those with the tag + let query = MemoryQuery { + namespace: namespace.cloned(), + memory_types: None, + tags: None, + time_range: None, + text_query: None, + semantic_query: None, + limit: None, + include_expired: false, + }; + + let all_results = self.store.query(query).await?; + let filtered = all_results + .into_iter() + .filter(|memory| !memory.metadata.tags.contains(&tag)) + .collect(); + + 
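+                // Note: a Not(tag) query fetches every memory visible in the namespace
+                // and filters in memory, so it scans the full store rather than using
+                // any tag index; acceptable for small stores, costly for large ones.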
Ok(filtered) + } + } + } + + /// Calculate text relevance score + fn calculate_text_relevance(&self, query: &str, memory: &MemoryDocument) -> f64 { + let query_lower = query.to_lowercase(); + let content_str = memory.content.to_string().to_lowercase(); + + // Simple scoring based on term frequency + let query_words: Vec<&str> = query_lower.split_whitespace().collect(); + let content_words: Vec<&str> = content_str.split_whitespace().collect(); + + if query_words.is_empty() || content_words.is_empty() { + return 0.0; + } + + let mut score = 0.0; + for query_word in &query_words { + let count = content_words + .iter() + .filter(|&&word| word == *query_word) + .count(); + score += count as f64; + } + + // Normalize by content length + score / content_words.len() as f64 + } +} + +/// Temporal patterns for time-based searches +#[derive(Debug, Clone)] +pub enum TemporalPattern { + LastHour, + LastDay, + LastWeek, + Custom( + Option>, + Option>, + ), +} + +/// Tag query with boolean logic +#[derive(Debug, Clone)] +pub enum TagQuery { + And(Vec), + Or(Vec), + Not(String), +} + +/// Distance calculation utilities +pub struct DistanceCalculator; + +impl DistanceCalculator { + /// Calculate cosine similarity between two vectors + pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 { + if a.len() != b.len() { + return 0.0; + } + + let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(); + let magnitude_a: f32 = a.iter().map(|x| x * x).sum::().sqrt(); + let magnitude_b: f32 = b.iter().map(|x| x * x).sum::().sqrt(); + + if magnitude_a == 0.0 || magnitude_b == 0.0 { + 0.0 + } else { + (dot_product / (magnitude_a * magnitude_b)) as f64 + } + } + + /// Calculate Euclidean distance between two vectors + pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f64 { + if a.len() != b.len() { + return f64::INFINITY; + } + + let distance: f32 = a + .iter() + .zip(b.iter()) + .map(|(x, y)| (x - y) * (x - y)) + .sum::() + .sqrt(); + + distance as f64 + } + + /// Calculate dot product between two vectors + pub fn dot_product(a: &[f32], b: &[f32]) -> f64 { + if a.len() != b.len() { + return 0.0; + } + + a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::() as f64 + } +} diff --git a/src/agent/short_term.rs b/src/agent/short_term.rs new file mode 100644 index 0000000..97227e9 --- /dev/null +++ b/src/agent/short_term.rs @@ -0,0 +1,404 @@ +use async_trait::async_trait; +use chrono::{Duration, Utc}; +use serde_json::json; +use std::collections::HashMap; +use uuid::Uuid; + +use super::store::BaseMemoryStore; +use super::traits::{MemoryError, MemoryStore}; +use super::types::*; + +/// Short-term memory store for session/thread-scoped memories +pub struct ShortTermMemoryStore { + base_store: BaseMemoryStore, + default_ttl: Duration, + max_memories_per_thread: usize, +} + +impl ShortTermMemoryStore { + /// Create a new short-term memory store + pub fn new( + base_store: BaseMemoryStore, + default_ttl: Duration, + max_memories_per_thread: usize, + ) -> Self { + Self { + base_store, + default_ttl, + max_memories_per_thread, + } + } + + /// Store a conversation turn + pub async fn store_conversation_turn( + &mut self, + thread_id: &str, + role: &str, + content: &str, + metadata: Option>, + ) -> Result { + let namespace = MemoryNamespace::with_sub( + self.base_store.agent_id().to_string(), + MemoryType::ShortTerm, + thread_id.to_string(), + ); + + let mut memory_metadata = MemoryMetadata::new( + self.base_store.agent_id().to_string(), + "conversation".to_string(), + ); + memory_metadata.thread_id = 
Some(thread_id.to_string()); + memory_metadata.ttl = Some(self.default_ttl); + memory_metadata.tags = vec!["conversation".to_string(), role.to_string()]; + + let mut content_json = json!({ + "role": role, + "content": content, + "timestamp": Utc::now() + }); + + // Add additional metadata if provided + if let Some(meta) = metadata { + if let serde_json::Value::Object(ref mut map) = content_json { + for (key, value) in meta { + map.insert(key, value); + } + } + } + + let memory = MemoryDocument { + id: Uuid::new_v4().to_string(), + namespace, + memory_type: MemoryType::ShortTerm, + content: content_json, + metadata: memory_metadata, + embeddings: None, + }; + + // Check thread memory limit + self.enforce_thread_limit(thread_id).await?; + + self.base_store.store(memory).await + } + + /// Get conversation history for a thread + pub async fn get_conversation_history( + &self, + thread_id: &str, + limit: Option, + ) -> Result, MemoryError> { + let namespace = MemoryNamespace::with_sub( + self.base_store.agent_id().to_string(), + MemoryType::ShortTerm, + thread_id.to_string(), + ); + + let query = MemoryQuery { + namespace: Some(namespace), + memory_types: Some(vec![MemoryType::ShortTerm]), + tags: Some(vec!["conversation".to_string()]), + time_range: None, + text_query: None, + semantic_query: None, + limit, + include_expired: false, + }; + + let mut memories = self.base_store.query(query).await?; + + // Sort by creation time + memories.sort_by(|a, b| a.metadata.created_at.cmp(&b.metadata.created_at)); + + Ok(memories) + } + + /// Store working memory (temporary state, calculations, etc.) + pub async fn store_working_memory( + &mut self, + thread_id: &str, + key: &str, + data: serde_json::Value, + ttl: Option, + ) -> Result { + let namespace = MemoryNamespace::with_sub( + self.base_store.agent_id().to_string(), + MemoryType::ShortTerm, + format!("{thread_id}/working"), + ); + + let mut memory_metadata = MemoryMetadata::new( + self.base_store.agent_id().to_string(), + "working_memory".to_string(), + ); + memory_metadata.thread_id = Some(thread_id.to_string()); + memory_metadata.ttl = ttl.or(Some(self.default_ttl)); + memory_metadata.tags = vec!["working_memory".to_string(), key.to_string()]; + + let content = json!({ + "key": key, + "data": data, + "timestamp": Utc::now() + }); + + let memory = MemoryDocument { + id: format!("working_{thread_id}_{key}"), + namespace, + memory_type: MemoryType::ShortTerm, + content, + metadata: memory_metadata, + embeddings: None, + }; + + self.base_store.store(memory).await + } + + /// Get working memory by key + pub async fn get_working_memory( + &self, + thread_id: &str, + key: &str, + ) -> Result, MemoryError> { + let memory_id = format!("working_{thread_id}_{key}"); + + if let Some(memory) = self.base_store.get(&memory_id).await? 
{ + if let Some(data) = memory.content.get("data") { + Ok(Some(data.clone())) + } else { + Ok(None) + } + } else { + Ok(None) + } + } + + /// Store context information for the current session + pub async fn store_session_context( + &mut self, + thread_id: &str, + context_type: &str, + context_data: serde_json::Value, + ) -> Result { + let namespace = MemoryNamespace::with_sub( + self.base_store.agent_id().to_string(), + MemoryType::ShortTerm, + format!("{thread_id}/context"), + ); + + let mut memory_metadata = MemoryMetadata::new( + self.base_store.agent_id().to_string(), + "session_context".to_string(), + ); + memory_metadata.thread_id = Some(thread_id.to_string()); + memory_metadata.ttl = Some(self.default_ttl); + memory_metadata.tags = vec!["context".to_string(), context_type.to_string()]; + + let content = json!({ + "context_type": context_type, + "context_data": context_data, + "timestamp": Utc::now() + }); + + let memory = MemoryDocument { + id: format!("context_{thread_id}_{context_type}"), + namespace, + memory_type: MemoryType::ShortTerm, + content, + metadata: memory_metadata, + embeddings: None, + }; + + self.base_store.store(memory).await + } + + /// Get all session context + pub async fn get_session_context( + &self, + thread_id: &str, + ) -> Result, MemoryError> { + let namespace = MemoryNamespace::with_sub( + self.base_store.agent_id().to_string(), + MemoryType::ShortTerm, + format!("{thread_id}/context"), + ); + + let memories = self.base_store.get_by_namespace(&namespace).await?; + let mut context = HashMap::new(); + + for memory in memories { + if let Some(context_type) = memory.content.get("context_type").and_then(|v| v.as_str()) + { + if let Some(context_data) = memory.content.get("context_data") { + context.insert(context_type.to_string(), context_data.clone()); + } + } + } + + Ok(context) + } + + /// Clear all memories for a thread + pub async fn clear_thread(&mut self, thread_id: &str) -> Result { + let namespace = MemoryNamespace::with_sub( + self.base_store.agent_id().to_string(), + MemoryType::ShortTerm, + thread_id.to_string(), + ); + + let memories = self.base_store.get_by_namespace(&namespace).await?; + let count = memories.len(); + + for memory in memories { + self.base_store.delete(&memory.id).await?; + } + + Ok(count) + } + + /// Enforce memory limit per thread + async fn enforce_thread_limit(&mut self, thread_id: &str) -> Result<(), MemoryError> { + let namespace = MemoryNamespace::with_sub( + self.base_store.agent_id().to_string(), + MemoryType::ShortTerm, + thread_id.to_string(), + ); + + let mut memories = self.base_store.get_by_namespace(&namespace).await?; + + if memories.len() >= self.max_memories_per_thread { + // Sort by creation time and remove oldest + memories.sort_by(|a, b| a.metadata.created_at.cmp(&b.metadata.created_at)); + + let to_remove = memories.len() - self.max_memories_per_thread + 1; + for memory in memories.iter().take(to_remove) { + self.base_store.delete(&memory.id).await?; + } + } + + Ok(()) + } + + /// Get memory statistics for short-term store + pub async fn get_short_term_stats(&self) -> Result { + let query = MemoryQuery { + namespace: None, + memory_types: Some(vec![MemoryType::ShortTerm]), + tags: None, + time_range: None, + text_query: None, + semantic_query: None, + limit: None, + include_expired: true, + }; + + let memories = self.base_store.query(query).await?; + + let mut thread_counts = HashMap::new(); + let mut active_threads = std::collections::HashSet::new(); + let mut total_conversations = 0; + let mut 
total_working_memory = 0; + let mut expired_count = 0; + + for memory in memories { + if let Some(thread_id) = &memory.metadata.thread_id { + *thread_counts.entry(thread_id.clone()).or_insert(0) += 1; + + if !memory.metadata.is_expired() { + active_threads.insert(thread_id.clone()); + } + } + + if memory.metadata.tags.contains(&"conversation".to_string()) { + total_conversations += 1; + } + + if memory.metadata.tags.contains(&"working_memory".to_string()) { + total_working_memory += 1; + } + + if memory.metadata.is_expired() { + expired_count += 1; + } + } + + Ok(ShortTermStats { + total_threads: thread_counts.len(), + active_threads: active_threads.len(), + total_conversations, + total_working_memory, + thread_memory_counts: thread_counts, + expired_count, + }) + } +} + +// Delegate most MemoryStore methods to the base store +#[async_trait] +impl MemoryStore for ShortTermMemoryStore { + async fn store(&mut self, memory: MemoryDocument) -> Result { + // Ensure it's short-term memory + if memory.memory_type != MemoryType::ShortTerm { + return Err(MemoryError::InvalidNamespace( + "Memory type must be ShortTerm".to_string(), + )); + } + self.base_store.store(memory).await + } + + async fn update(&mut self, id: &str, memory: MemoryDocument) -> Result<(), MemoryError> { + self.base_store.update(id, memory).await + } + + async fn get(&self, id: &str) -> Result, MemoryError> { + self.base_store.get(id).await + } + + async fn delete(&mut self, id: &str) -> Result<(), MemoryError> { + self.base_store.delete(id).await + } + + async fn query(&self, query: MemoryQuery) -> Result, MemoryError> { + self.base_store.query(query).await + } + + async fn get_by_namespace( + &self, + namespace: &MemoryNamespace, + ) -> Result, MemoryError> { + self.base_store.get_by_namespace(namespace).await + } + + async fn commit(&mut self, message: &str) -> Result { + self.base_store.commit(message).await + } + + async fn create_branch(&mut self, name: &str) -> Result<(), MemoryError> { + self.base_store.create_branch(name).await + } + + async fn checkout(&mut self, branch_or_commit: &str) -> Result<(), MemoryError> { + self.base_store.checkout(branch_or_commit).await + } + + fn current_branch(&self) -> &str { + self.base_store.current_branch() + } + + async fn get_stats(&self) -> Result { + self.base_store.get_stats().await + } + + async fn cleanup_expired(&mut self) -> Result { + self.base_store.cleanup_expired().await + } +} + +/// Statistics specific to short-term memory +#[derive(Debug, Clone)] +pub struct ShortTermStats { + pub total_threads: usize, + pub active_threads: usize, + pub total_conversations: usize, + pub total_working_memory: usize, + pub thread_memory_counts: HashMap, + pub expired_count: usize, +} diff --git a/src/agent/simple_persistence.rs b/src/agent/simple_persistence.rs new file mode 100644 index 0000000..343f223 --- /dev/null +++ b/src/agent/simple_persistence.rs @@ -0,0 +1,138 @@ +use async_trait::async_trait; +use std::collections::HashMap; +use std::error::Error; +use std::path::Path; +use std::sync::Arc; +use tokio::sync::RwLock; + +use super::traits::MemoryPersistence; + +/// Simple in-memory persistence for demonstration +/// In production, this would be replaced with a proper prolly tree implementation +/// that handles the Send/Sync requirements properly +pub struct SimpleMemoryPersistence { + data: Arc>>>, + namespace_prefix: String, +} + +impl SimpleMemoryPersistence { + /// Initialize a new simple memory persistence layer + pub fn init>(_path: P, namespace_prefix: &str) -> Result> { + 
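+        // The path argument is deliberately ignored: this stand-in keeps everything in
+        // a process-local HashMap behind an RwLock, so nothing is written to disk and
+        // the data does not survive the process.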
Ok(Self { + data: Arc::new(RwLock::new(HashMap::new())), + namespace_prefix: namespace_prefix.to_string(), + }) + } + + /// Open an existing simple memory persistence layer (same as init for this implementation) + pub fn open>(_path: P, namespace_prefix: &str) -> Result> { + Self::init(_path, namespace_prefix) + } + + /// Get the full key with namespace prefix + fn full_key(&self, key: &str) -> String { + format!("{}/{}", self.namespace_prefix, key) + } +} + +#[async_trait] +impl MemoryPersistence for SimpleMemoryPersistence { + async fn save(&mut self, key: &str, data: &[u8]) -> Result<(), Box> { + let full_key = self.full_key(key); + let mut store = self.data.write().await; + store.insert(full_key, data.to_vec()); + Ok(()) + } + + async fn load(&self, key: &str) -> Result>, Box> { + let full_key = self.full_key(key); + let store = self.data.read().await; + Ok(store.get(&full_key).cloned()) + } + + async fn delete(&mut self, key: &str) -> Result<(), Box> { + let full_key = self.full_key(key); + let mut store = self.data.write().await; + store.remove(&full_key); + Ok(()) + } + + async fn list_keys(&self, prefix: &str) -> Result, Box> { + let full_prefix = format!("{}/{}", self.namespace_prefix, prefix); + let store = self.data.read().await; + + let keys = store + .keys() + .filter_map(|k| { + if k.starts_with(&full_prefix) { + k.strip_prefix(&format!("{}/", self.namespace_prefix)) + .map(|s| s.to_string()) + } else { + None + } + }) + .collect(); + + Ok(keys) + } + + async fn checkpoint(&mut self, message: &str) -> Result> { + // For simple implementation, just return a mock commit ID + let commit_id = format!("simple_commit_{}", chrono::Utc::now().timestamp()); + println!("Created checkpoint: {} - {}", commit_id, message); + Ok(commit_id) + } +} + +/// Additional methods for the simple persistence +impl SimpleMemoryPersistence { + /// Create a new branch (no-op for simple implementation) + pub async fn create_branch(&mut self, name: &str) -> Result<(), Box> { + println!("Created branch: {name}"); + Ok(()) + } + + /// Switch to a branch or commit (no-op for simple implementation) + pub async fn checkout(&mut self, branch_or_commit: &str) -> Result<(), Box> { + println!("Checked out: {branch_or_commit}"); + Ok(()) + } + + /// Get current branch name + pub async fn current_branch(&self) -> String { + "main".to_string() + } + + /// List all branches + pub async fn list_branches(&self) -> Result, Box> { + Ok(vec!["main".to_string()]) + } + + /// Get status of staged changes + pub async fn status(&self) -> Vec<(Vec, String)> { + vec![] + } + + /// Merge another branch + pub async fn merge(&mut self, branch: &str) -> Result> { + println!("Merged branch: {branch}"); + Ok(format!("merge_result_{}", chrono::Utc::now().timestamp())) + } + + /// Get history of commits + pub async fn history(&self, _limit: Option) -> Result, Box> { + Ok(vec!["Initial commit".to_string()]) + } + + /// Get data size for statistics + pub async fn data_size(&self) -> usize { + let store = self.data.read().await; + store.values().map(|v| v.len()).sum() + } + + /// Get key count + pub async fn key_count(&self) -> usize { + let store = self.data.read().await; + store.len() + } +} diff --git a/src/agent/store.rs b/src/agent/store.rs new file mode 100644 index 0000000..3abb9f6 --- /dev/null +++ b/src/agent/store.rs @@ -0,0 +1,518 @@ +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use serde_json; +use std::collections::HashMap; +use std::path::Path; +use std::sync::Arc; +use tokio::sync::RwLock; +use uuid::Uuid; + 
+use super::simple_persistence::SimpleMemoryPersistence; +use super::traits::{EmbeddingGenerator, MemoryError, MemoryPersistence, MemoryStore}; +use super::types::*; +// use crate::git::GitKvError; + +/// Base implementation of the memory store using simple persistence +#[derive(Clone)] +pub struct BaseMemoryStore { + persistence: Arc>, + embedding_generator: Option>, + agent_id: String, + current_branch: String, +} + +impl BaseMemoryStore { + /// Get the agent ID + pub fn agent_id(&self) -> &str { + &self.agent_id + } + + /// Initialize a new memory store + pub fn init>( + path: P, + agent_id: String, + embedding_generator: Option>, + ) -> Result> { + let persistence = SimpleMemoryPersistence::init(path, &format!("agent_memory_{agent_id}"))?; + Ok(Self { + persistence: Arc::new(RwLock::new(persistence)), + embedding_generator: embedding_generator + .map(|gen| Arc::from(gen) as Arc), + agent_id, + current_branch: "main".to_string(), + }) + } + + /// Open an existing memory store + pub fn open>( + path: P, + agent_id: String, + embedding_generator: Option>, + ) -> Result> { + let persistence = SimpleMemoryPersistence::open(path, &format!("agent_memory_{agent_id}"))?; + Ok(Self { + persistence: Arc::new(RwLock::new(persistence)), + embedding_generator: embedding_generator + .map(|gen| Arc::from(gen) as Arc), + agent_id, + current_branch: "main".to_string(), + }) + } + + /// Generate key for memory document + fn memory_key(&self, namespace: &MemoryNamespace, id: &str) -> String { + format!("{}/{}", namespace.to_path(), id) + } + + /// Generate embeddings if generator is available + async fn generate_embeddings(&self, content: &serde_json::Value) -> Option> { + if let Some(ref generator) = self.embedding_generator { + // Extract text content for embedding generation + let text = self.extract_text_content(content); + if !text.is_empty() { + match generator.generate(&text).await { + Ok(embeddings) => Some(embeddings), + Err(e) => { + eprintln!("Failed to generate embeddings: {e}"); + None + } + } + } else { + None + } + } else { + None + } + } + + /// Extract text content from JSON for embedding generation + fn extract_text_content(&self, content: &serde_json::Value) -> String { + match content { + serde_json::Value::String(s) => s.clone(), + serde_json::Value::Object(map) => { + let mut text_parts = Vec::new(); + + // Common text fields to extract + let text_fields = ["text", "content", "message", "description", "summary"]; + + for field in &text_fields { + if let Some(value) = map.get(*field) { + if let Some(text) = value.as_str() { + text_parts.push(text); + } + } + } + + // If no specific text fields, join all string values + if text_parts.is_empty() { + for value in map.values() { + if let Some(text) = value.as_str() { + text_parts.push(text); + } + } + } + + text_parts.join(" ") + } + _ => content.to_string(), + } + } + + /// Validate memory document + fn validate_memory(&self, memory: &MemoryDocument) -> Result<(), MemoryError> { + // Check if agent ID matches + if memory.metadata.agent_id != *self.agent_id() { + return Err(MemoryError::PermissionDenied(format!( + "Memory belongs to different agent: {}", + memory.metadata.agent_id + ))); + } + + // Check if memory is expired + if memory.metadata.is_expired() { + return Err(MemoryError::Expired(format!( + "Memory {} has expired", + memory.id + ))); + } + + Ok(()) + } + + /// Serialize memory document for storage + fn serialize_memory(&self, memory: &MemoryDocument) -> Result, MemoryError> { + serde_json::to_vec(memory) + .map_err(|e| 
MemoryError::SerializationError(format!("Failed to serialize: {e}"))) + } + + /// Deserialize memory document from storage + fn deserialize_memory(&self, data: &[u8]) -> Result { + serde_json::from_slice(data) + .map_err(|e| MemoryError::SerializationError(format!("Failed to deserialize: {e}"))) + } +} + +#[async_trait] +impl MemoryStore for BaseMemoryStore { + async fn store(&mut self, mut memory: MemoryDocument) -> Result { + // Validate the memory + self.validate_memory(&memory)?; + + // Generate ID if not provided + if memory.id.is_empty() { + memory.id = Uuid::new_v4().to_string(); + } + + // Generate embeddings if available + if memory.embeddings.is_none() { + memory.embeddings = self.generate_embeddings(&memory.content).await; + } + + // Update metadata + memory.metadata.updated_at = Utc::now(); + + // Store the memory + let key = self.memory_key(&memory.namespace, &memory.id); + let data = self.serialize_memory(&memory)?; + + { + let mut persistence = self.persistence.write().await; + (*persistence) + .save(&key, &data) + .await + .map_err(|e| MemoryError::StorageError(format!("Failed to store: {e}")))?; + } + + Ok(memory.id) + } + + async fn update(&mut self, id: &str, mut memory: MemoryDocument) -> Result<(), MemoryError> { + // Ensure the ID matches + memory.id = id.to_string(); + + // Validate the memory + self.validate_memory(&memory)?; + + // Generate embeddings if content changed + if memory.embeddings.is_none() { + memory.embeddings = self.generate_embeddings(&memory.content).await; + } + + // Update metadata + memory.metadata.updated_at = Utc::now(); + + // Store the updated memory + let key = self.memory_key(&memory.namespace, id); + let data = self.serialize_memory(&memory)?; + + { + let mut persistence = self.persistence.write().await; + (*persistence) + .save(&key, &data) + .await + .map_err(|e| MemoryError::StorageError(format!("Failed to update: {e}")))?; + } + + Ok(()) + } + + async fn get(&self, id: &str) -> Result, MemoryError> { + // We need to search across all namespaces for this ID + // This is a simplified implementation - in practice, you might want to index by ID + let persistence = self.persistence.read().await; + + // Try different memory types and namespaces + for memory_type in [ + MemoryType::ShortTerm, + MemoryType::Semantic, + MemoryType::Episodic, + MemoryType::Procedural, + ] { + let namespace = MemoryNamespace::new(self.agent_id().to_string(), memory_type); + let key = self.memory_key(&namespace, id); + + let data_result = (*persistence) + .load(&key) + .await + .map_err(|e| MemoryError::StorageError(format!("Failed to load: {e}")))?; + + if let Some(data) = data_result { + let memory = self.deserialize_memory(&data)?; + return Ok(Some(memory)); + } + } + + Ok(None) + } + + async fn delete(&mut self, id: &str) -> Result<(), MemoryError> { + // Similar to get, we need to find the memory first + if let Some(memory) = self.get(id).await? 
{ + let key = self.memory_key(&memory.namespace, id); + let mut persistence = self.persistence.write().await; + (*persistence) + .delete(&key) + .await + .map_err(|e| MemoryError::StorageError(format!("Failed to delete: {e}")))?; + Ok(()) + } else { + Err(MemoryError::NotFound(format!( + "Memory with ID {id} not found" + ))) + } + } + + async fn query(&self, query: MemoryQuery) -> Result, MemoryError> { + let mut results = Vec::new(); + let persistence = self.persistence.read().await; + + // Determine which namespaces to search + let namespaces = if let Some(ns) = &query.namespace { + vec![ns.clone()] + } else { + // Search all memory types for this agent + let memory_types = query.memory_types.clone().unwrap_or_else(|| { + vec![ + MemoryType::ShortTerm, + MemoryType::Semantic, + MemoryType::Episodic, + MemoryType::Procedural, + ] + }); + + memory_types + .into_iter() + .map(|mt| MemoryNamespace::new(self.agent_id().to_string(), mt)) + .collect() + }; + + // Search each namespace + for namespace in namespaces { + let prefix = namespace.to_path(); + let keys = (*persistence) + .list_keys(&prefix) + .await + .map_err(|e| MemoryError::StorageError(format!("Failed to list keys: {e}")))?; + + for key in keys { + let data_result = (*persistence) + .load(&key) + .await + .map_err(|e| MemoryError::StorageError(format!("Failed to load: {e}")))?; + + if let Some(data) = data_result { + if let Ok(memory) = self.deserialize_memory(&data) { + // Apply filters + if self.matches_query(&memory, &query) { + results.push(memory); + } + } + } + } + } + + // Apply limit + if let Some(limit) = query.limit { + results.truncate(limit); + } + + Ok(results) + } + + async fn get_by_namespace( + &self, + namespace: &MemoryNamespace, + ) -> Result, MemoryError> { + let mut results = Vec::new(); + let persistence = self.persistence.read().await; + + let prefix = namespace.to_path(); + let keys = (*persistence) + .list_keys(&prefix) + .await + .map_err(|e| MemoryError::StorageError(format!("Failed to list keys: {e}")))?; + + for key in keys { + let data_result = (*persistence) + .load(&key) + .await + .map_err(|e| MemoryError::StorageError(format!("Failed to load: {e}")))?; + + if let Some(data) = data_result { + if let Ok(memory) = self.deserialize_memory(&data) { + if !memory.metadata.is_expired() { + results.push(memory); + } + } + } + } + + Ok(results) + } + + async fn commit(&mut self, message: &str) -> Result { + let mut persistence = self.persistence.write().await; + (*persistence) + .checkpoint(message) + .await + .map_err(|e| MemoryError::StorageError(format!("Failed to commit: {e}"))) + } + + async fn create_branch(&mut self, name: &str) -> Result<(), MemoryError> { + let mut persistence = self.persistence.write().await; + persistence.create_branch(name).await?; + Ok(()) + } + + async fn checkout(&mut self, branch_or_commit: &str) -> Result<(), MemoryError> { + let mut persistence = self.persistence.write().await; + persistence.checkout(branch_or_commit).await?; + self.current_branch = branch_or_commit.to_string(); + Ok(()) + } + + fn current_branch(&self) -> &str { + &self.current_branch + } + + async fn get_stats(&self) -> Result { + let mut by_type = HashMap::new(); + let mut by_namespace = HashMap::new(); + let mut total_memories = 0; + let mut total_size_bytes = 0; + let mut access_counts = Vec::new(); + let mut oldest: Option> = None; + let mut newest: Option> = None; + let mut expired_count = 0; + + // Query all memories to build stats + let query = MemoryQuery { + namespace: None, + memory_types: 
None, + tags: None, + time_range: None, + text_query: None, + semantic_query: None, + limit: None, + include_expired: true, + }; + + let memories = self.query(query).await?; + + for memory in memories { + total_memories += 1; + + // Count by type + *by_type.entry(memory.memory_type).or_insert(0) += 1; + + // Count by namespace + let ns_key = memory.namespace.to_path(); + *by_namespace.entry(ns_key).or_insert(0) += 1; + + // Size estimation (rough) + total_size_bytes += std::mem::size_of_val(&memory); + + // Access count + access_counts.push(memory.metadata.access_count as f64); + + // Time tracking + if oldest.is_none_or(|o| memory.metadata.created_at < o) { + oldest = Some(memory.metadata.created_at); + } + if newest.is_none_or(|n| memory.metadata.created_at > n) { + newest = Some(memory.metadata.created_at); + } + + // Expired count + if memory.metadata.is_expired() { + expired_count += 1; + } + } + + let avg_access_count = if access_counts.is_empty() { + 0.0 + } else { + access_counts.iter().sum::() / access_counts.len() as f64 + }; + + Ok(MemoryStats { + total_memories, + by_type, + by_namespace, + total_size_bytes, + avg_access_count, + oldest_memory: oldest, + newest_memory: newest, + expired_count, + }) + } + + async fn cleanup_expired(&mut self) -> Result { + let query = MemoryQuery { + namespace: None, + memory_types: None, + tags: None, + time_range: None, + text_query: None, + semantic_query: None, + limit: None, + include_expired: true, + }; + + let memories = self.query(query).await?; + let mut cleaned_count = 0; + + for memory in memories { + if memory.metadata.is_expired() { + self.delete(&memory.id).await?; + cleaned_count += 1; + } + } + + Ok(cleaned_count) + } +} + +impl BaseMemoryStore { + /// Check if a memory matches the query criteria + fn matches_query(&self, memory: &MemoryDocument, query: &MemoryQuery) -> bool { + // Check expiry + if !query.include_expired && memory.metadata.is_expired() { + return false; + } + + // Check tags + if let Some(required_tags) = &query.tags { + if !required_tags + .iter() + .all(|tag| memory.metadata.tags.contains(tag)) + { + return false; + } + } + + // Check time range + if let Some(time_range) = &query.time_range { + if let Some(start) = time_range.start { + if memory.metadata.created_at < start { + return false; + } + } + if let Some(end) = time_range.end { + if memory.metadata.created_at > end { + return false; + } + } + } + + // Check text query (simple substring search) + if let Some(text_query) = &query.text_query { + let content_str = memory.content.to_string().to_lowercase(); + if !content_str.contains(&text_query.to_lowercase()) { + return false; + } + } + + true + } +} diff --git a/src/agent/traits.rs b/src/agent/traits.rs new file mode 100644 index 0000000..e6d4d3c --- /dev/null +++ b/src/agent/traits.rs @@ -0,0 +1,157 @@ +use async_trait::async_trait; +use std::error::Error; +use std::fmt; + +use super::types::*; + +/// Error types for memory operations +#[derive(Debug)] +pub enum MemoryError { + NotFound(String), + InvalidNamespace(String), + StorageError(String), + SerializationError(String), + PermissionDenied(String), + Expired(String), + ConflictError(String), +} + +impl fmt::Display for MemoryError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + MemoryError::NotFound(msg) => write!(f, "Memory not found: {msg}"), + MemoryError::InvalidNamespace(msg) => write!(f, "Invalid namespace: {msg}"), + MemoryError::StorageError(msg) => write!(f, "Storage error: {msg}"), + 
MemoryError::SerializationError(msg) => write!(f, "Serialization error: {msg}"), + MemoryError::PermissionDenied(msg) => write!(f, "Permission denied: {msg}"), + MemoryError::Expired(msg) => write!(f, "Memory expired: {msg}"), + MemoryError::ConflictError(msg) => write!(f, "Conflict error: {msg}"), + } + } +} + +impl Error for MemoryError {} + +impl From> for MemoryError { + fn from(error: Box) -> Self { + MemoryError::StorageError(error.to_string()) + } +} + +/// Core trait for memory store implementations +#[async_trait] +pub trait MemoryStore: Send + Sync { + /// Store a new memory document + async fn store(&mut self, memory: MemoryDocument) -> Result; + + /// Update an existing memory + async fn update(&mut self, id: &str, memory: MemoryDocument) -> Result<(), MemoryError>; + + /// Retrieve a memory by ID + async fn get(&self, id: &str) -> Result, MemoryError>; + + /// Delete a memory by ID + async fn delete(&mut self, id: &str) -> Result<(), MemoryError>; + + /// Query memories based on criteria + async fn query(&self, query: MemoryQuery) -> Result, MemoryError>; + + /// Get memories by namespace + async fn get_by_namespace( + &self, + namespace: &MemoryNamespace, + ) -> Result, MemoryError>; + + /// Commit current changes (for version control) + async fn commit(&mut self, message: &str) -> Result; + + /// Create a new branch + async fn create_branch(&mut self, name: &str) -> Result<(), MemoryError>; + + /// Switch to a different branch + async fn checkout(&mut self, branch_or_commit: &str) -> Result<(), MemoryError>; + + /// Get current branch name + fn current_branch(&self) -> &str; + + /// Get memory statistics + async fn get_stats(&self) -> Result; + + /// Cleanup expired memories + async fn cleanup_expired(&mut self) -> Result; +} + +/// Trait for memory stores with search capabilities +#[async_trait] +pub trait SearchableMemoryStore: MemoryStore { + /// Perform semantic search using embeddings + async fn semantic_search( + &self, + query: SemanticQuery, + namespace: Option<&MemoryNamespace>, + ) -> Result, MemoryError>; + + /// Full-text search across memories + async fn text_search( + &self, + query: &str, + namespace: Option<&MemoryNamespace>, + ) -> Result, MemoryError>; + + /// Find related memories based on a given memory + async fn find_related( + &self, + memory_id: &str, + limit: usize, + ) -> Result, MemoryError>; +} + +/// Trait for memory lifecycle management +#[async_trait] +pub trait MemoryLifecycle: MemoryStore { + /// Apply consolidation strategy + async fn consolidate(&mut self, strategy: ConsolidationStrategy) -> Result; + + /// Archive old memories + async fn archive( + &mut self, + before: chrono::DateTime, + ) -> Result; + + /// Subscribe to memory events + async fn subscribe_events(&mut self, callback: F) -> Result<(), MemoryError> + where + F: Fn(MemoryEvent) + Send + Sync + 'static; + + /// Get memory history for a specific ID + async fn get_history(&self, memory_id: &str) -> Result, MemoryError>; +} + +/// Trait for embedding generation +#[async_trait] +pub trait EmbeddingGenerator: Send + Sync { + /// Generate embeddings for text content + async fn generate(&self, text: &str) -> Result, Box>; + + /// Batch generate embeddings + async fn generate_batch(&self, texts: &[String]) -> Result>, Box>; +} + +/// Trait for memory persistence backend +#[async_trait] +pub trait MemoryPersistence: Send + Sync { + /// Save memory data + async fn save(&mut self, key: &str, data: &[u8]) -> Result<(), Box>; + + /// Load memory data + async fn load(&self, key: &str) -> 
Result>, Box>; + + /// Delete memory data + async fn delete(&mut self, key: &str) -> Result<(), Box>; + + /// List keys with prefix + async fn list_keys(&self, prefix: &str) -> Result, Box>; + + /// Create checkpoint + async fn checkpoint(&mut self, message: &str) -> Result>; +} diff --git a/src/agent/types.rs b/src/agent/types.rs new file mode 100644 index 0000000..a41e54f --- /dev/null +++ b/src/agent/types.rs @@ -0,0 +1,290 @@ +use chrono::{DateTime, Duration, Utc}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::fmt; + +/// Different types of memory in the agent system +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub enum MemoryType { + /// Short-term memory - thread/session scoped + ShortTerm, + /// Long-term semantic memory - facts and concepts + Semantic, + /// Long-term episodic memory - past experiences + Episodic, + /// Long-term procedural memory - rules and instructions + Procedural, +} + +impl fmt::Display for MemoryType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + MemoryType::ShortTerm => write!(f, "ShortTerm"), + MemoryType::Semantic => write!(f, "Semantic"), + MemoryType::Episodic => write!(f, "Episodic"), + MemoryType::Procedural => write!(f, "Procedural"), + } + } +} + +/// Memory namespace for organizing memories hierarchically +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct MemoryNamespace { + /// Agent identifier + pub agent_id: String, + /// Memory type + pub memory_type: MemoryType, + /// Optional sub-namespace (e.g., thread_id for short-term, entity_type for semantic) + pub sub_namespace: Option, +} + +impl MemoryNamespace { + /// Create a new namespace + pub fn new(agent_id: String, memory_type: MemoryType) -> Self { + Self { + agent_id, + memory_type, + sub_namespace: None, + } + } + + /// Create namespace with sub-namespace + pub fn with_sub(agent_id: String, memory_type: MemoryType, sub: String) -> Self { + Self { + agent_id, + memory_type, + sub_namespace: Some(sub), + } + } + + /// Convert to path representation for storage + pub fn to_path(&self) -> String { + let base = format!("/memory/agents/{}/{}", self.agent_id, self.memory_type); + match &self.sub_namespace { + Some(sub) => format!("{base}/{sub}"), + None => base, + } + } +} + +/// Memory document structure +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MemoryDocument { + /// Unique identifier + pub id: String, + /// Namespace this memory belongs to + pub namespace: MemoryNamespace, + /// Memory type + pub memory_type: MemoryType, + /// The actual content/data + pub content: serde_json::Value, + /// Metadata about the memory + pub metadata: MemoryMetadata, + /// Optional embeddings for semantic search + #[serde(skip_serializing_if = "Option::is_none")] + pub embeddings: Option>, +} + +/// Metadata associated with a memory +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MemoryMetadata { + /// Creation timestamp + pub created_at: DateTime, + /// Last updated timestamp + pub updated_at: DateTime, + /// Agent that created this memory + pub agent_id: String, + /// Thread/session ID for short-term memories + #[serde(skip_serializing_if = "Option::is_none")] + pub thread_id: Option, + /// Tags for categorization + #[serde(default)] + pub tags: Vec, + /// Time-to-live for automatic expiration (mainly for short-term) + #[serde(skip_serializing_if = "Option::is_none")] + pub ttl: Option, + /// Access pattern tracking + pub access_count: u32, + pub 
last_accessed: Option<DateTime<Utc>>,
+    /// Source of the memory (user input, inference, external API, etc.)
+    pub source: String,
+    /// Confidence score (0.0 to 1.0)
+    pub confidence: f64,
+    /// Related memory IDs for cross-referencing
+    #[serde(default)]
+    pub related_memories: Vec<String>,
+}
+
+impl MemoryMetadata {
+    /// Create new metadata with defaults
+    pub fn new(agent_id: String, source: String) -> Self {
+        let now = Utc::now();
+        Self {
+            created_at: now,
+            updated_at: now,
+            agent_id,
+            thread_id: None,
+            tags: Vec::new(),
+            ttl: None,
+            access_count: 0,
+            last_accessed: None,
+            source,
+            confidence: 1.0,
+            related_memories: Vec::new(),
+        }
+    }
+
+    /// Mark memory as accessed
+    pub fn mark_accessed(&mut self) {
+        self.access_count += 1;
+        self.last_accessed = Some(Utc::now());
+    }
+
+    /// Check if memory has expired
+    pub fn is_expired(&self) -> bool {
+        if let Some(ttl) = self.ttl {
+            Utc::now() > self.created_at + ttl
+        } else {
+            false
+        }
+    }
+}
+
+/// Query parameters for memory retrieval
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct MemoryQuery {
+    /// Namespace to search in
+    pub namespace: Option<MemoryNamespace>,
+    /// Memory types to include
+    pub memory_types: Option<Vec<MemoryType>>,
+    /// Tags to filter by (AND operation)
+    pub tags: Option<Vec<String>>,
+    /// Time range filter
+    pub time_range: Option<TimeRange>,
+    /// Full-text search query
+    pub text_query: Option<String>,
+    /// Semantic search with embeddings
+    pub semantic_query: Option<SemanticQuery>,
+    /// Maximum number of results
+    pub limit: Option<usize>,
+    /// Include expired memories
+    pub include_expired: bool,
+}
+
+/// Time range for filtering
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct TimeRange {
+    pub start: Option<DateTime<Utc>>,
+    pub end: Option<DateTime<Utc>>,
+}
+
+/// Semantic search parameters
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct SemanticQuery {
+    /// Query embeddings
+    pub embeddings: Vec,
+    /// Similarity threshold (0.0 to 1.0)
+    pub threshold: f64,
+    /// Distance metric to use
+    pub metric: DistanceMetric,
+}
+
+/// Distance metrics for semantic search
+#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
+pub enum DistanceMetric {
+    Cosine,
+    Euclidean,
+    DotProduct,
+}
+
+/// Memory operation result
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct MemoryResult<T> {
+    pub success: bool,
+    pub data: Option<T>,
+    pub error: Option<String>,
+    pub metadata: HashMap,
+}
+
+impl<T> MemoryResult<T> {
+    pub fn success(data: T) -> Self {
+        Self {
+            success: true,
+            data: Some(data),
+            error: None,
+            metadata: HashMap::new(),
+        }
+    }
+
+    pub fn failure(error: String) -> Self {
+        Self {
+            success: false,
+            data: None,
+            error: Some(error),
+            metadata: HashMap::new(),
+        }
+    }
+}
+
+/// Memory lifecycle events
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum MemoryEvent {
+    Created {
+        memory_id: String,
+        namespace: MemoryNamespace,
+        timestamp: DateTime<Utc>,
+    },
+    Updated {
+        memory_id: String,
+        namespace: MemoryNamespace,
+        timestamp: DateTime<Utc>,
+        changes: Vec,
+    },
+    Accessed {
+        memory_id: String,
+        namespace: MemoryNamespace,
+        timestamp: DateTime<Utc>,
+        access_count: u32,
+    },
+    Expired {
+        memory_id: String,
+        namespace: MemoryNamespace,
+        timestamp: DateTime<Utc>,
+        ttl: Duration,
+    },
+    Deleted {
+        memory_id: String,
+        namespace: MemoryNamespace,
+        timestamp: DateTime<Utc>,
+        reason: String,
+    },
+}
+
+/// Memory consolidation strategy
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum ConsolidationStrategy {
+    /// Merge similar memories into a single entry
+    MergeSimilar { similarity_threshold: f64 },
+    /// Summarize multiple memories into abstract concepts
+    Summarize { 
max_memories: usize }, + /// Archive old memories to cold storage + Archive { age_threshold: Duration }, + /// Remove low-value memories + Prune { + confidence_threshold: f64, + access_threshold: u32, + }, +} + +/// Memory statistics for monitoring +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MemoryStats { + pub total_memories: usize, + pub by_type: HashMap, + pub by_namespace: HashMap, + pub total_size_bytes: usize, + pub avg_access_count: f64, + pub oldest_memory: Option>, + pub newest_memory: Option>, + pub expired_count: usize, +} diff --git a/src/lib.rs b/src/lib.rs index 422ac78..a7045ee 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -44,6 +44,7 @@ limitations under the License. #[macro_use] pub mod digest; +pub mod agent; pub mod config; pub mod diff; mod encoding; From fd37c362e8a826659d61c8b306352c8ee45583b1 Mon Sep 17 00:00:00 2001 From: Feng Zhang Date: Sun, 27 Jul 2025 14:12:04 -0700 Subject: [PATCH 4/7] Agent Memory System Successfully Implemented MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The implementation has been completed successfully with prolly tree-based in-memory storage. Here's what's working: ๐Ÿง  Core Memory Types Implemented: 1. Short-term Memory: Conversation history, working memory, session context 2. Semantic Memory: Facts, relationships, structured knowledge 3. Episodic Memory: Experience records, interaction history 4. Procedural Memory: Rules, procedures, operational knowledge ๐ŸŒณ Prolly Tree Integration: - Replaced HashMap with ProllyTree<32, InMemoryNodeStorage<32>> - Maintains same interface while providing proper tree-based storage - Thread-safe async operations with Arc> - Advanced features: tree_stats(), range_query(), commit tracking --- src/agent/README.md | 72 ++++++-- src/agent/mod.rs | 8 +- src/agent/simple_persistence.rs | 318 +++++++++++++++++++++++++++----- 3 files changed, 330 insertions(+), 68 deletions(-) diff --git a/src/agent/README.md b/src/agent/README.md index 9ad9784..6dc4f75 100644 --- a/src/agent/README.md +++ b/src/agent/README.md @@ -15,36 +15,36 @@ The Agent Memory System implements different types of memory inspired by human c ### Core Components -1. **Types** (`src/agent_memory/types.rs`) +1. **Types** (`src/agent/types.rs`) - Memory data structures and enums - Namespace organization for hierarchical memory - Query and filter types -2. **Traits** (`src/agent_memory/traits.rs`) +2. **Traits** (`src/agent/traits.rs`) - Abstract interfaces for memory operations - Embedding generation and search capabilities - Lifecycle management interfaces -3. **Persistence** (`src/agent_memory/simple_persistence.rs`) - - Simple in-memory persistence for demonstration - - Designed to be replaced with prolly tree persistence - - Thread-safe async operations +3. **Persistence** (`src/agent/simple_persistence.rs`) + - Prolly tree-based in-memory persistence + - Uses `ProllyTree<32, InMemoryNodeStorage<32>>` for robust storage + - Thread-safe async operations with Arc -4. **Store** (`src/agent_memory/store.rs`) +4. **Store** (`src/agent/store.rs`) - Base memory store implementation - Handles serialization/deserialization - Manages memory validation and access 5. 
**Memory Types**: - - **Short-Term** (`src/agent_memory/short_term.rs`): Conversation history, working memory - - **Long-Term** (`src/agent_memory/long_term.rs`): Semantic, episodic, and procedural stores + - **Short-Term** (`src/agent/short_term.rs`): Conversation history, working memory + - **Long-Term** (`src/agent/long_term.rs`): Semantic, episodic, and procedural stores -6. **Search** (`src/agent_memory/search.rs`) +6. **Search** (`src/agent/search.rs`) - Memory search and retrieval capabilities - Mock embedding generation - Distance calculation utilities -7. **Lifecycle** (`src/agent_memory/lifecycle.rs`) +7. **Lifecycle** (`src/agent/lifecycle.rs`) - Memory consolidation and archival - Cleanup and optimization - Event broadcasting @@ -103,7 +103,7 @@ For example: ## Usage Example ```rust -use prollytree::agent_memory::*; +use prollytree::agent::*; #[tokio::main] async fn main() -> Result<(), Box> { @@ -155,24 +155,27 @@ async fn main() -> Result<(), Box> { ### Completed โœ… - Core type definitions and interfaces -- Simple persistence layer +- **Prolly tree-based persistence layer** with `ProllyTree<32, InMemoryNodeStorage<32>>` - All four memory types (Short-term, Semantic, Episodic, Procedural) - Basic search functionality - Memory lifecycle management - Working demo example +- Thread-safe async operations +- Tree statistics and range queries +- Commit tracking with sequential IDs ### Planned ๐Ÿšง -- Full prolly tree persistence integration (blocked by Send/Sync issues) - Real embedding generation (currently uses mock) - Advanced semantic search - Memory conflict resolution - Performance optimizations +- Git-based prolly tree persistence for durability ### Known Limitations -- Uses simple in-memory persistence instead of prolly tree - Mock embedding generation - Limited semantic search capabilities - No conflict resolution for concurrent updates +- In-memory storage (data doesn't persist across restarts) ## Design Decisions @@ -183,9 +186,31 @@ async fn main() -> Result<(), Box> { 5. **Type Safety**: Strong typing for memory operations 6. **Extensible Design**: Easy to add new memory types or features +## Prolly Tree Integration Details + +The memory system now uses prolly trees for storage with the following features: + +### Storage Architecture +- **Tree Structure**: `ProllyTree<32, InMemoryNodeStorage<32>>` +- **Namespace Prefixes**: Organized hierarchically with agent ID and memory type +- **Thread Safety**: `Arc>` for concurrent access +- **Commit Tracking**: Sequential commit IDs (prolly_commit_00000001, etc.) + +### Advanced Features +- **Tree Statistics**: `tree_stats()` provides key count and size metrics +- **Range Queries**: `range_query()` for efficient range-based retrieval +- **Direct Tree Access**: `with_tree()` for advanced operations +- **Git-like Operations**: Branch, checkout, merge simulation for future git integration + +### Performance Benefits +- **Balanced Tree Structure**: O(log n) operations for most queries +- **Content Addressing**: Efficient deduplication and integrity checking +- **Probabilistic Balancing**: Maintains performance under various workloads +- **Memory Efficient**: Shared storage for duplicate content + ## Future Enhancements -1. **True Prolly Tree Integration**: Once Send/Sync issues are resolved +1. **Git-based Persistence**: Replace in-memory with durable git-based storage 2. **Real Embedding Models**: Integration with actual embedding services 3. **Conflict Resolution**: Handle concurrent memory updates 4. 
**Performance Metrics**: Track memory system performance @@ -200,16 +225,27 @@ To see the memory system in action: cargo run --example agent_memory_demo ``` -This demonstrates all four memory types, search capabilities, and system operations. +This demonstrates: +- All four memory types with prolly tree storage +- Conversation tracking and fact storage +- Episode recording and procedure management +- Tree statistics and checkpoint creation +- System optimization and cleanup ## Testing -The memory system includes comprehensive unit tests for each component. Run tests with: +The memory system includes comprehensive unit tests for each component, including prolly tree persistence tests. Run tests with: ```bash cargo test agent ``` +This will run all tests including: +- Basic prolly tree operations (save, load, delete) +- Key listing and range queries +- Tree statistics and checkpoints +- Memory lifecycle operations + ## Contributing The memory system is designed to be modular and extensible. Key areas for contribution: diff --git a/src/agent/mod.rs b/src/agent/mod.rs index 9463cda..99292bf 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -321,7 +321,11 @@ mod tests { // Run optimization let report = memory_system.optimize().await.unwrap(); - // Should have processed some memories - assert!(report.total_processed() >= 0); // Might be 0 if no optimization needed + // Optimization report should be valid (total_processed is always >= 0 for usize) + // Just verify the report exists and has reasonable values + assert!(report.expired_cleaned <= 50); // Reasonable upper bound for test + assert!(report.memories_consolidated <= 50); + assert!(report.memories_archived <= 50); + assert!(report.memories_pruned <= 50); } } diff --git a/src/agent/simple_persistence.rs b/src/agent/simple_persistence.rs index 343f223..f4ba71b 100644 --- a/src/agent/simple_persistence.rs +++ b/src/agent/simple_persistence.rs @@ -1,37 +1,53 @@ use async_trait::async_trait; -use std::collections::HashMap; use std::error::Error; use std::path::Path; use std::sync::Arc; use tokio::sync::RwLock; use super::traits::MemoryPersistence; +use crate::config::TreeConfig; +use crate::storage::InMemoryNodeStorage; +use crate::tree::{ProllyTree, Tree}; -/// Simple in-memory persistence for demonstration -/// In production, this would be replaced with a proper prolly tree implementation -/// that handles the Send/Sync requirements properly +/// Prolly tree-based in-memory persistence for agent memory +/// This provides a more robust foundation than a simple HashMap +/// while maintaining thread safety and async compatibility pub struct SimpleMemoryPersistence { - data: Arc>>>, + tree: Arc>>>, namespace_prefix: String, + commit_counter: Arc>, } impl SimpleMemoryPersistence { - /// Initialize a new simple memory persistence layer + /// Initialize a new prolly tree-based memory persistence layer pub fn init>(_path: P, namespace_prefix: &str) -> Result> { + let storage = InMemoryNodeStorage::new(); + let config = TreeConfig::default(); + let tree = ProllyTree::new(storage, config); + Ok(Self { - data: Arc::new(RwLock::new(HashMap::new())), + tree: Arc::new(RwLock::new(tree)), namespace_prefix: namespace_prefix.to_string(), + commit_counter: Arc::new(RwLock::new(0)), }) } - /// Open an existing simple memory persistence layer (same as init for this implementation) + /// Open an existing prolly tree-based memory persistence layer + /// For in-memory storage, this is the same as init pub fn open>(_path: P, namespace_prefix: &str) -> Result> { 
Self::init(_path, namespace_prefix) } /// Get the full key with namespace prefix - fn full_key(&self, key: &str) -> String { - format!("{}/{}", self.namespace_prefix, key) + fn full_key(&self, key: &str) -> Vec { + format!("{}/{}", self.namespace_prefix, key).into_bytes() + } + + /// Generate next commit ID + async fn next_commit_id(&self) -> String { + let mut counter = self.commit_counter.write().await; + *counter += 1; + format!("prolly_commit_{:08}", *counter) } } @@ -39,62 +55,87 @@ impl SimpleMemoryPersistence { impl MemoryPersistence for SimpleMemoryPersistence { async fn save(&mut self, key: &str, data: &[u8]) -> Result<(), Box> { let full_key = self.full_key(key); - let mut store = self.data.write().await; - store.insert(full_key, data.to_vec()); + let mut tree = self.tree.write().await; + + // Insert into prolly tree + tree.insert(full_key, data.to_vec()); + Ok(()) } async fn load(&self, key: &str) -> Result>, Box> { let full_key = self.full_key(key); - let store = self.data.read().await; - Ok(store.get(&full_key).cloned()) + let tree = self.tree.read().await; + + // Get from prolly tree using find method + let result = tree.find(&full_key).and_then(|node| { + // Find the value in the node + node.keys + .iter() + .position(|k| k == &full_key) + .map(|index| node.values[index].clone()) + }); + + Ok(result) } async fn delete(&mut self, key: &str) -> Result<(), Box> { let full_key = self.full_key(key); - let mut store = self.data.write().await; - store.remove(&full_key); + let mut tree = self.tree.write().await; + + // Delete from prolly tree (returns bool indicating success) + tree.delete(&full_key); + Ok(()) } async fn list_keys(&self, prefix: &str) -> Result, Box> { - let full_prefix = format!("{}/{}", self.namespace_prefix, prefix); - let store = self.data.read().await; - - let keys = store - .keys() - .filter_map(|k| { - if k.starts_with(&full_prefix) { - k.strip_prefix(&format!("{}/", self.namespace_prefix)) + let namespace_prefix_with_slash = format!("{}/", self.namespace_prefix); + let tree = self.tree.read().await; + + // Get all keys and filter by prefix + let all_keys = tree.collect_keys(); + + + let matching_keys: Vec = all_keys + .into_iter() + .filter_map(|key| { + // First convert to string and strip namespace + String::from_utf8(key).ok().and_then(|s| { + s.strip_prefix(&namespace_prefix_with_slash) .map(|s| s.to_string()) - } else { - None - } + }) }) + .filter(|relative_key| relative_key.starts_with(prefix)) + .collect::>() // Deduplicate + .into_iter() .collect(); - Ok(keys) + Ok(matching_keys) } async fn checkpoint(&mut self, message: &str) -> Result> { - // For simple implementation, just return a mock commit ID - let commit_id = format!("simple_commit_{}", chrono::Utc::now().timestamp()); - println!("Created checkpoint: {} - {}", commit_id, message); + let commit_id = self.next_commit_id().await; + + // For in-memory storage, we just generate a commit ID + // In a real git-based implementation, this would create an actual commit + println!("Prolly tree checkpoint: {} - {}", commit_id, message); + Ok(commit_id) } } -/// Additional methods for the simple persistence +/// Additional methods specific to prolly tree persistence impl SimpleMemoryPersistence { - /// Create a new branch (no-op for simple implementation) + /// Create a new branch (for in-memory, this is a no-op) pub async fn create_branch(&mut self, name: &str) -> Result<(), Box> { - println!("Created branch: {name}"); + println!("Created prolly tree branch: {name}"); Ok(()) } - /// Switch to a branch 
or commit (no-op for simple implementation) + /// Switch to a branch or commit (for in-memory, this is a no-op) pub async fn checkout(&mut self, branch_or_commit: &str) -> Result<(), Box> { - println!("Checked out: {branch_or_commit}"); + println!("Checked out prolly tree: {branch_or_commit}"); Ok(()) } @@ -113,26 +154,207 @@ impl SimpleMemoryPersistence { vec![] } - /// Merge another branch + /// Merge another branch (for in-memory, this is a no-op) pub async fn merge(&mut self, branch: &str) -> Result> { - println!("Merged branch: {branch}"); - Ok(format!("merge_result_{}", chrono::Utc::now().timestamp())) + println!("Merged prolly tree branch: {branch}"); + // Use a simple timestamp instead of chrono for in-memory implementation + use std::time::{SystemTime, UNIX_EPOCH}; + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + Ok(format!("merge_result_{timestamp}")) } /// Get history of commits pub async fn history(&self, _limit: Option) -> Result, Box> { - Ok(vec!["Initial commit".to_string()]) + Ok(vec!["Initial prolly tree commit".to_string()]) } - /// Get data size for statistics - pub async fn data_size(&self) -> usize { - let store = self.data.read().await; - store.values().map(|v| v.len()).sum() + /// Get prolly tree statistics + pub async fn tree_stats(&self) -> Result> { + let tree = self.tree.read().await; + + // Get tree statistics using existing methods + let key_count = tree.size(); + let stats = tree.stats(); + + // Estimate total size from tree stats + let total_size_bytes = (stats.avg_node_size * stats.num_nodes as f64) as usize; + + Ok(ProllyTreeStats { + key_count, + total_size_bytes, + namespace_prefix: self.namespace_prefix.clone(), + }) + } + + /// Get the underlying tree (for advanced operations) + pub async fn with_tree(&self, f: F) -> R + where + F: FnOnce(&ProllyTree<32, InMemoryNodeStorage<32>>) -> R, + { + let tree = self.tree.read().await; + f(&tree) + } + + /// Perform a range query on the prolly tree + pub async fn range_query( + &self, + start_key: &str, + end_key: &str, + ) -> Result)>, Box> { + let start_key_bytes = self.full_key(start_key); + let end_key_bytes = self.full_key(end_key); + let namespace_prefix_with_slash = format!("{}/", self.namespace_prefix); + let tree = self.tree.read().await; + + // Get all entries and filter by range + let all_keys = tree.collect_keys(); + + // Use HashSet to deduplicate keys and then process + let unique_keys: std::collections::HashSet> = all_keys.into_iter().collect(); + let mut result = Vec::new(); + + for key_bytes in unique_keys { + if key_bytes >= start_key_bytes && key_bytes < end_key_bytes { + if let Some(node) = tree.find(&key_bytes) { + // Find the value in the node + if let Some(index) = node.keys.iter().position(|k| k == &key_bytes) { + let value = node.values[index].clone(); + if let Ok(key_str) = String::from_utf8(key_bytes) { + if let Some(relative_key) = + key_str.strip_prefix(&namespace_prefix_with_slash) + { + result.push((relative_key.to_string(), value)); + } + } + } + } + } + } + + result.sort_by(|a, b| a.0.cmp(&b.0)); + Ok(result) } +} + +/// Statistics about the prolly tree +#[derive(Debug, Clone)] +pub struct ProllyTreeStats { + pub key_count: usize, + pub total_size_bytes: usize, + pub namespace_prefix: String, +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + #[tokio::test] + async fn test_prolly_persistence_basic_operations() { + let temp_dir = TempDir::new().unwrap(); + let mut persistence = + 
SimpleMemoryPersistence::init(temp_dir.path(), "test_memories").unwrap(); + + // Test save + let key = "test_key"; + let data = b"test_data"; + persistence.save(key, data).await.unwrap(); + + // Test load + let loaded = persistence.load(key).await.unwrap(); + assert_eq!(loaded, Some(data.to_vec())); + + // Test update + let new_data = b"updated_data"; + persistence.save(key, new_data).await.unwrap(); + let loaded = persistence.load(key).await.unwrap(); + assert_eq!(loaded, Some(new_data.to_vec())); + + // Test delete + persistence.delete(key).await.unwrap(); + let loaded = persistence.load(key).await.unwrap(); + assert_eq!(loaded, None); + } + + #[tokio::test] + async fn test_prolly_persistence_checkpoint() { + let temp_dir = TempDir::new().unwrap(); + let mut persistence = + SimpleMemoryPersistence::init(temp_dir.path(), "test_memories").unwrap(); + + // Save some data + persistence.save("key1", b"data1").await.unwrap(); + persistence.save("key2", b"data2").await.unwrap(); + + // Create checkpoint + let commit_id = persistence.checkpoint("Test checkpoint").await.unwrap(); + assert!(commit_id.starts_with("prolly_commit_")); + } + + #[tokio::test] + async fn test_prolly_persistence_list_keys() { + let temp_dir = TempDir::new().unwrap(); + let mut persistence = + SimpleMemoryPersistence::init(temp_dir.path(), "test_memories").unwrap(); + + // Save data with different prefixes + persistence.save("user/1", b"user1").await.unwrap(); + persistence.save("user/2", b"user2").await.unwrap(); + persistence.save("system/config", b"config").await.unwrap(); + + // List keys with prefix + let user_keys = persistence.list_keys("user").await.unwrap(); + assert_eq!(user_keys.len(), 2); + assert!(user_keys.contains(&"user/1".to_string())); + assert!(user_keys.contains(&"user/2".to_string())); + + let system_keys = persistence.list_keys("system").await.unwrap(); + assert_eq!(system_keys.len(), 1); + assert!(system_keys.contains(&"system/config".to_string())); + } + + #[tokio::test] + async fn test_prolly_persistence_stats() { + let temp_dir = TempDir::new().unwrap(); + let mut persistence = + SimpleMemoryPersistence::init(temp_dir.path(), "test_memories").unwrap(); + + // Add some data + persistence.save("key1", b"data1").await.unwrap(); + persistence + .save("key2", b"longer_data_value") + .await + .unwrap(); + + // Get stats + let stats = persistence.tree_stats().await.unwrap(); + assert_eq!(stats.key_count, 2); + assert!(stats.total_size_bytes > 0); + assert_eq!(stats.namespace_prefix, "test_memories"); + } + + #[tokio::test] + async fn test_prolly_persistence_range_query() { + let temp_dir = TempDir::new().unwrap(); + let mut persistence = + SimpleMemoryPersistence::init(temp_dir.path(), "test_memories").unwrap(); + + // Add some data with sortable keys + persistence.save("key_a", b"data_a").await.unwrap(); + persistence.save("key_b", b"data_b").await.unwrap(); + persistence.save("key_c", b"data_c").await.unwrap(); + persistence.save("other_x", b"data_x").await.unwrap(); + + // Range query + let results = persistence.range_query("key_", "key_z").await.unwrap(); + assert_eq!(results.len(), 3); - /// Get key count - pub async fn key_count(&self) -> usize { - let store = self.data.read().await; - store.len() + // Should be sorted + assert_eq!(results[0].0, "key_a"); + assert_eq!(results[1].0, "key_b"); + assert_eq!(results[2].0, "key_c"); } } From d7872620ff445857a04df6cb1784c9a9ed2eb93c Mon Sep 17 00:00:00 2001 From: Feng Zhang Date: Sun, 27 Jul 2025 14:52:26 -0700 Subject: [PATCH 5/7] fix fmt --- 
src/agent/simple_persistence.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/agent/simple_persistence.rs b/src/agent/simple_persistence.rs index f4ba71b..5193f35 100644 --- a/src/agent/simple_persistence.rs +++ b/src/agent/simple_persistence.rs @@ -95,7 +95,6 @@ impl MemoryPersistence for SimpleMemoryPersistence { // Get all keys and filter by prefix let all_keys = tree.collect_keys(); - let matching_keys: Vec = all_keys .into_iter() From 3cab9b07e4629cc43b4b8bbf038b65351ed808a0 Mon Sep 17 00:00:00 2001 From: Feng Zhang Date: Sun, 27 Jul 2025 15:00:07 -0700 Subject: [PATCH 6/7] fix build --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index b6ae76a..0f1df35 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,7 +38,7 @@ gluesql-core = { version = "0.15", optional = true } async-trait = { version = "0.1", optional = true } uuid = { version = "1.0", optional = true } futures = { version = "0.3", optional = true } -tokio = { version = "1.0", features = ["rt-multi-thread", "macros"], optional = true } +tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "sync"], optional = true } pyo3 = { version = "0.22", features = ["extension-module"], optional = true } rocksdb = { version = "0.22", optional = true } From 67cdb8fa3f0baca64d25257957f8c6a5e5c8a354 Mon Sep 17 00:00:00 2001 From: Feng Zhang Date: Sun, 27 Jul 2025 15:13:37 -0700 Subject: [PATCH 7/7] use rig for the agent-memory example --- .github/workflows/ci.yml | 4 +- Cargo.toml | 6 +- examples/agent.rs | 726 ++++++++++++++++++++++----------------- src/agent/README.md | 29 +- src/agent/mod.rs | 2 +- src/agent/short_term.rs | 2 +- 6 files changed, 437 insertions(+), 332 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cab9fbc..c6743cc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,7 +21,7 @@ jobs: run: cargo fmt --all -- --check - name: build - run: cargo build --verbose + run: cargo build --all --verbose - name: test run: cargo test --verbose @@ -30,7 +30,7 @@ jobs: run: cargo test --benches --no-run --verbose - name: clippy - run: cargo clippy + run: cargo clippy --all - name: docs run: cargo doc --document-private-items --no-deps diff --git a/Cargo.toml b/Cargo.toml index 0f1df35..da05949 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ async-trait = { version = "0.1", optional = true } uuid = { version = "1.0", optional = true } futures = { version = "0.3", optional = true } tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "sync"], optional = true } +rig-core = { version = "0.2.1", optional = true } pyo3 = { version = "0.22", features = ["extension-module"], optional = true } rocksdb = { version = "0.22", optional = true } @@ -58,6 +59,7 @@ prolly_balance_max_nodes = [] prolly_balance_rolling_hash = [] git = ["dep:gix", "dep:clap", "dep:lru", "dep:hex", "dep:chrono"] sql = ["dep:gluesql-core", "dep:async-trait", "dep:uuid", "dep:futures", "dep:tokio"] +rig = ["dep:rig-core", "dep:tokio", "dep:async-trait"] python = ["dep:pyo3"] rocksdb_storage = ["dep:rocksdb", "dep:lru"] @@ -109,9 +111,9 @@ path = "examples/storage.rs" required-features = ["rocksdb_storage"] [[example]] -name = "agent_memory_demo" +name = "agent_rig_demo" path = "examples/agent.rs" -required-features = ["git", "sql"] +required-features = ["git", "sql", "rig"] [workspace] members = ["examples/financial_advisor"] diff --git a/examples/agent.rs b/examples/agent.rs index 1ccea39..e984e80 100644 --- 
a/examples/agent.rs +++ b/examples/agent.rs @@ -1,360 +1,438 @@ use chrono::Duration; use prollytree::agent::*; +use rig::{completion::Prompt, providers::openai::Client}; +use serde::{Deserialize, Serialize}; use serde_json::json; use std::error::Error; use tempfile::TempDir; -#[tokio::main] -async fn main() -> Result<(), Box> { - println!("๐Ÿง  Agent Memory System Demo"); - println!("============================"); +/// A Rig-powered agent that uses the prolly tree memory system +pub struct IntelligentAgent { + memory_system: AgentMemorySystem, + rig_client: Option, + agent_id: String, + conversation_id: String, +} - // Create a temporary directory for this demo - let temp_dir = TempDir::new()?; - let memory_path = temp_dir.path(); +#[derive(Debug, Serialize, Deserialize)] +pub struct ConversationTurn { + pub role: String, + pub content: String, + pub timestamp: chrono::DateTime, +} - println!("๐Ÿ“ Initializing memory system at: {:?}", memory_path); +#[derive(Debug, Serialize, Deserialize)] +pub struct AgentResponse { + pub content: String, + pub reasoning: Option, + pub mode: ResponseMode, +} - // Initialize the agent memory system - let mut memory_system = AgentMemorySystem::init( - memory_path, - "demo_agent".to_string(), - Some(Box::new(MockEmbeddingGenerator)), // Use mock embeddings - )?; - - println!("โœ… Memory system initialized successfully!\n"); - - // Demonstrate Short-Term Memory - println!("๐Ÿ”„ Short-Term Memory Demo"); - println!("--------------------------"); - - let thread_id = "conversation_001"; - - // Store some conversation turns - memory_system - .short_term - .store_conversation_turn( - thread_id, - "user", - "Hello! I'm looking for help with my project.", - None, - ) - .await?; - - memory_system.short_term.store_conversation_turn( - thread_id, - "assistant", - "Hello! I'd be happy to help you with your project. What kind of project are you working on?", - None, - ).await?; - - memory_system - .short_term - .store_conversation_turn( - thread_id, - "user", - "I'm building a web application using Rust and need advice on database design.", - None, - ) - .await?; - - // Store some working memory - memory_system - .short_term - .store_working_memory( - thread_id, - "user_context", - json!({ - "project_type": "web_application", - "language": "rust", - "focus_area": "database_design" - }), - None, - ) - .await?; - - // Retrieve conversation history - let conversation = memory_system - .short_term - .get_conversation_history(thread_id, None) - .await?; - println!("๐Ÿ“ Stored {} conversation turns", conversation.len()); - - for (i, turn) in conversation.iter().enumerate() { - let role = turn - .content - .get("role") - .and_then(|r| r.as_str()) - .unwrap_or("unknown"); - let content = turn - .content - .get("content") - .and_then(|c| c.as_str()) - .unwrap_or(""); - println!(" {}. 
{}: {}", i + 1, role, content); +#[derive(Debug, Serialize, Deserialize)] +pub enum ResponseMode { + AIPowered, + MemoryBased, + Hybrid, +} + +impl IntelligentAgent { + /// Create a new intelligent agent with Rig integration and prolly tree memory + pub async fn new( + memory_path: &std::path::Path, + agent_id: String, + openai_api_key: Option, + ) -> Result> { + // Initialize the memory system with prolly tree persistence + let memory_system = AgentMemorySystem::init( + memory_path, + agent_id.clone(), + Some(Box::new(MockEmbeddingGenerator)), // Use mock embeddings for demo + )?; + + // Initialize Rig client if API key provided + let rig_client = openai_api_key.map(|key| Client::new(&key)); + + let conversation_id = format!("conversation_{}", chrono::Utc::now().timestamp()); + + Ok(Self { + memory_system, + rig_client, + agent_id, + conversation_id, + }) } - // Get working memory - if let Some(context) = memory_system - .short_term - .get_working_memory(thread_id, "user_context") - .await? - { - println!("๐Ÿง  Working memory context: {}", context); + /// Process a user message using memory and optionally AI + pub async fn process_message( + &mut self, + user_message: &str, + ) -> Result> { + println!("๐Ÿง  Processing message with memory and AI..."); + + // 1. Store user message in short-term memory + self.memory_system + .short_term + .store_conversation_turn(&self.conversation_id, "user", user_message, None) + .await?; + + // 2. Retrieve relevant context from memory + let context = self.retrieve_relevant_context(user_message).await?; + + // 3. Generate response using Rig if available, otherwise use memory-based response + let response = if let Some(ref client) = self.rig_client { + self.generate_ai_response(user_message, &context, client) + .await? + } else { + self.generate_memory_response(user_message, &context) + .await? + }; + + // 4. Store assistant response in memory + self.memory_system + .short_term + .store_conversation_turn(&self.conversation_id, "assistant", &response.content, None) + .await?; + + // 5. Learn from the interaction (episodic memory) + self.store_interaction_episode(user_message, &response) + .await?; + + // 6. 
Update procedural knowledge if applicable + self.update_procedural_knowledge(user_message, &response) + .await?; + + Ok(response) } - println!(); + /// Retrieve relevant context from all memory types + async fn retrieve_relevant_context(&self, message: &str) -> Result> { + let mut context_parts = Vec::new(); + + // Get recent conversation history + let recent_history = self + .memory_system + .short_term + .get_conversation_history(&self.conversation_id, Some(5)) + .await?; + + if !recent_history.is_empty() { + context_parts.push(format!( + "Recent conversation ({} turns):\n{}", + recent_history.len(), + recent_history + .iter() + .map(|turn| { + format!( + "{}: {}", + turn.content + .get("role") + .and_then(|r| r.as_str()) + .unwrap_or("unknown"), + turn.content + .get("content") + .and_then(|c| c.as_str()) + .unwrap_or("") + ) + }) + .collect::>() + .join("\n") + )); + } - // Demonstrate Semantic Memory - println!("๐Ÿงฉ Semantic Memory Demo"); - println!("------------------------"); - - // Store facts about entities - memory_system - .semantic - .store_fact( - "programming_language", - "rust", - json!({ - "type": "systems_programming", - "paradigm": ["functional", "imperative", "object-oriented"], - "memory_safety": true, - "performance": "high", - "use_cases": ["web_backends", "system_tools", "blockchain"] + // Search semantic memory for relevant facts + let semantic_query = MemoryQuery { + namespace: None, + memory_types: Some(vec![MemoryType::Semantic]), + tags: None, + time_range: None, + text_query: Some(message.to_string()), + semantic_query: None, + limit: Some(3), + include_expired: false, + }; + + let semantic_results = self.memory_system.semantic.query(semantic_query).await?; + if !semantic_results.is_empty() { + context_parts.push(format!( + "Relevant facts ({} items):\n{}", + semantic_results.len(), + semantic_results + .iter() + .map(|mem| format!("- {}", mem.content.to_string())) + .collect::>() + .join("\n") + )); + } + + // Get relevant past episodes + let episodic_query = MemoryQuery { + namespace: None, + memory_types: Some(vec![MemoryType::Episodic]), + tags: None, + time_range: Some(TimeRange { + start: Some(chrono::Utc::now() - Duration::days(7)), + end: Some(chrono::Utc::now()), }), - 0.95, - "knowledge_base", - ) - .await?; - - memory_system - .semantic - .store_fact( - "database", - "postgresql", - json!({ - "type": "relational", - "acid_compliant": true, - "supports_json": true, - "good_for": ["web_applications", "analytics", "geospatial"] + text_query: Some(message.to_string()), + semantic_query: None, + limit: Some(2), + include_expired: false, + }; + + let episodic_results = self.memory_system.episodic.query(episodic_query).await?; + if !episodic_results.is_empty() { + context_parts.push(format!( + "Past experiences ({} items):\n{}", + episodic_results.len(), + episodic_results + .iter() + .map(|mem| { + mem.content + .get("description") + .and_then(|d| d.as_str()) + .unwrap_or("No description") + }) + .collect::>() + .join("\n- ") + )); + } + + Ok(if context_parts.is_empty() { + "No relevant context found.".to_string() + } else { + context_parts.join("\n\n") + }) + } + + /// Generate AI-powered response using Rig + async fn generate_ai_response( + &self, + message: &str, + context: &str, + client: &Client, + ) -> Result> { + let prompt = format!( + "You are an intelligent assistant with access to conversation memory and knowledge base.\n\n\ + CONTEXT:\n{}\n\n\ + USER MESSAGE: {}\n\n\ + Provide a helpful, contextual response based on the available 
information.", + context, message + ); + + println!("๐Ÿค– Generating AI response with Rig..."); + + let agent = client + .agent("gpt-3.5-turbo") + .preamble("You are a helpful, knowledgeable assistant that uses context from previous conversations and stored knowledge to provide relevant responses.") + .max_tokens(300) + .temperature(0.7) + .build(); + + match agent.prompt(&prompt).await { + Ok(response) => Ok(AgentResponse { + content: response.trim().to_string(), + reasoning: Some("Generated using AI with memory context".to_string()), + mode: ResponseMode::AIPowered, }), - 0.9, - "knowledge_base", - ) - .await?; - - // Store relationships - memory_system - .semantic - .store_relationship( - ("programming_language", "rust"), - ("database", "postgresql"), - "commonly_used_with", - Some(json!({ - "drivers": ["tokio-postgres", "sqlx", "diesel"], - "compatibility": "excellent" - })), - 0.85, - ) - .await?; - - // Retrieve entity facts - let rust_facts = memory_system - .semantic - .get_entity_facts("programming_language", "rust") - .await?; - println!("๐Ÿ“Š Stored {} facts about Rust", rust_facts.len()); - - for fact in &rust_facts { - if let Some(fact_data) = fact.content.get("fact") { - println!(" - {}", fact_data); + Err(e) => { + println!( + "โš ๏ธ AI generation failed: {}, falling back to memory-based response", + e + ); + self.generate_memory_response(message, context).await + } } } - // Get relationships - let rust_relationships = memory_system - .semantic - .get_entity_relationships("programming_language", "rust") - .await?; - println!( - "๐Ÿ”— Found {} relationships for Rust", - rust_relationships.len() - ); + /// Generate memory-based response as fallback + async fn generate_memory_response( + &self, + message: &str, + context: &str, + ) -> Result> { + let response = if context.contains("No relevant context") { + format!( + "I understand you're asking about '{}'. While I don't have specific context about this topic yet, \ + I'm ready to learn and help. 
Could you provide more details?", + message + ) + } else { + format!( + "Based on our previous interactions and what I know:\n\n{}\n\n\ + Regarding your question about '{}', I can help you with this based on the context above.", + context, message + ) + }; + + Ok(AgentResponse { + content: response, + reasoning: Some("Generated using memory and rules".to_string()), + mode: ResponseMode::MemoryBased, + }) + } - println!(); + /// Store the interaction as an episode + async fn store_interaction_episode( + &mut self, + user_message: &str, + response: &AgentResponse, + ) -> Result<(), Box> { + self.memory_system + .episodic + .store_interaction( + "conversation", + vec![self.agent_id.clone(), "user".to_string()], + &format!("User: {} | Agent: {}", user_message, response.content), + json!({ + "user_message": user_message, + "response_content": response.content, + "response_mode": response.mode, + "conversation_id": self.conversation_id + }), + Some(0.8), // Positive interaction + ) + .await?; + + Ok(()) + } - // Demonstrate Episodic Memory - println!("๐Ÿ“š Episodic Memory Demo"); - println!("------------------------"); - - // Store an interaction episode - memory_system - .episodic - .store_interaction( - "technical_consultation", - vec!["demo_agent".to_string(), "user".to_string()], - "User sought advice on Rust web development and database design", - json!({ - "topics_discussed": ["rust", "web_development", "database_design", "postgresql"], - "user_experience_level": "intermediate", - "consultation_outcome": "successful" - }), - Some(0.8), // Positive sentiment - ) - .await?; - - // Store a learning episode - memory_system - .episodic - .store_episode( - "knowledge_acquisition", - "Learned about user's project requirements and provided relevant technical guidance", - json!({ - "knowledge_domain": "web_development", - "interaction_type": "Q&A", - "topics": ["rust", "postgresql", "database_design"] - }), - Some(json!({ - "knowledge_transferred": true, - "user_satisfaction": "high" - })), - 0.7, - ) - .await?; - - // Query recent episodes - let recent_episodes = memory_system - .episodic - .get_episodes_in_period(chrono::Utc::now() - Duration::hours(1), chrono::Utc::now()) - .await?; - - println!("๐Ÿ“… Found {} recent episodes", recent_episodes.len()); - for episode in &recent_episodes { - if let Some(desc) = episode.content.get("description").and_then(|d| d.as_str()) { - println!(" - {}", desc); + /// Update procedural knowledge based on interactions + async fn update_procedural_knowledge( + &mut self, + user_message: &str, + response: &AgentResponse, + ) -> Result<(), Box> { + // Store a procedure for handling similar questions + if user_message.to_lowercase().contains("help") { + self.memory_system + .procedural + .store_procedure( + "assistance", + "help_request_handler", + "How to handle user help requests effectively", + vec![ + json!({"step": 1, "action": "Analyze the help request for specific topics"}), + json!({"step": 2, "action": "Search memory for relevant context"}), + json!({"step": 3, "action": "Provide contextual assistance"}), + json!({"step": 4, "action": "Store the interaction for future reference"}), + ], + Some(json!({ + "triggers": ["help", "assist", "support"], + "effectiveness": match response.mode { + ResponseMode::AIPowered => "high", + ResponseMode::MemoryBased => "medium", + ResponseMode::Hybrid => "high", + } + })), + 7, // Medium-high priority + ) + .await?; } + + Ok(()) + } + + /// Get agent statistics + pub async fn get_stats(&self) -> Result> { + let system_stats 
= self.memory_system.get_system_stats().await?; + + Ok(json!({ + "agent_id": self.agent_id, + "conversation_id": self.conversation_id, + "memory_stats": system_stats, + "ai_enabled": self.rig_client.is_some() + })) + } + + /// Create a checkpoint of the agent's memory state + pub async fn checkpoint(&mut self, message: &str) -> Result> { + self.memory_system + .checkpoint(message) + .await + .map_err(|e| e.into()) } +} +#[tokio::main] +async fn main() -> Result<(), Box> { + println!("๐Ÿค– Intelligent Agent with Rig + Prolly Tree Memory Demo"); + println!("======================================================"); + + // Create a temporary directory for this demo + let temp_dir = TempDir::new()?; + let memory_path = temp_dir.path(); + + println!("๐Ÿ“ Initializing agent at: {:?}", memory_path); + + // Initialize agent (without OpenAI API key for demo) + let mut agent = IntelligentAgent::new( + memory_path, + "smart_agent_001".to_string(), + std::env::var("OPENAI_API_KEY").ok(), // Will use memory-based responses if not set + ) + .await?; + + println!("โœ… Agent initialized successfully!"); println!(); - // Demonstrate Procedural Memory - println!("โš™๏ธ Procedural Memory Demo"); - println!("--------------------------"); - - // Store a procedure for database design - memory_system.procedural.store_procedure( - "database_design", - "design_web_app_schema", - "Standard procedure for designing database schema for web applications", - vec![ - json!({"step": 1, "action": "Identify main entities and their attributes"}), - json!({"step": 2, "action": "Define relationships between entities"}), - json!({"step": 3, "action": "Normalize the schema to reduce redundancy"}), - json!({"step": 4, "action": "Add indexes for performance optimization"}), - json!({"step": 5, "action": "Consider security and access patterns"}), - ], - Some(json!({ - "applicable_when": "designing new web application database", - "prerequisites": ["basic SQL knowledge", "understanding of application requirements"] - })), - 10, // High priority - ).await?; - - // Store a rule - memory_system - .procedural - .store_rule( - "consultation", - "provide_code_examples", - json!({ - "if": "user asks about implementation", - "and": "topic is within knowledge domain" - }), - json!({ - "then": "provide concrete code examples", - "include": ["comments", "error handling", "best practices"] - }), - 8, // Medium-high priority - true, // Enabled - ) - .await?; - - // Get procedures by category - let db_procedures = memory_system - .procedural - .get_procedures_by_category("database_design") - .await?; - println!( - "๐Ÿ“‹ Found {} database design procedures", - db_procedures.len() - ); - - for procedure in &db_procedures { - if let Some(name) = procedure.content.get("name").and_then(|n| n.as_str()) { - if let Some(steps) = procedure.content.get("steps").and_then(|s| s.as_array()) { - println!(" - {}: {} steps", name, steps.len()); - } + // Demo conversation sequence + let demo_messages = vec![ + "Hello! I'm learning about Rust programming. Can you help me?", + "What are the key benefits of using Rust for system programming?", + "I'm having trouble with ownership and borrowing. Any tips?", + "Can you recommend some resources for learning more about memory management in Rust?", + "Thank you for all the help! 
This has been very useful.",
+    ];
+
+    for (i, message) in demo_messages.iter().enumerate() {
+        println!("๐Ÿ’ฌ User Message {}: {}", i + 1, message);
+        println!("{}", "โ”€".repeat(80));
+
+        let response = agent.process_message(message).await?;
+
+        println!("๐Ÿค– Agent Response ({:?}):", response.mode);
+        println!("{}", response.content);
+
+        if let Some(reasoning) = response.reasoning {
+            println!("๐Ÿง  Reasoning: {}", reasoning);
+        }
+
+        println!();
+
+        // Add some delay between messages to make it more realistic
+        tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
+    }
 
-    // Get active rules
-    let consultation_rules = memory_system
-        .procedural
-        .get_active_rules_by_category("consultation")
-        .await?;
-    println!(
-        "๐Ÿ“ Found {} active consultation rules",
-        consultation_rules.len()
-    );
+    // Demonstrate memory capabilities
+    println!("๐Ÿ“Š Agent Memory Statistics:");
+    println!("{}", "โ•".repeat(50));
 
-    println!();
+    let stats = agent.get_stats().await?;
+    println!("{}", serde_json::to_string_pretty(&stats)?);
 
-    // Demonstrate System Operations
-    println!("๐Ÿ”ง System Operations Demo");
-    println!("--------------------------");
+    println!();
 
     // Create a checkpoint
-    let checkpoint_id = memory_system.checkpoint("Demo session complete").await?;
-    println!("๐Ÿ’พ Created checkpoint: {}", checkpoint_id);
-
-    // Get system statistics
-    let stats = memory_system.get_system_stats().await?;
-    println!("๐Ÿ“Š System Statistics:");
-    println!(" - Total memories: {}", stats.overall.total_memories);
-    println!(" - Short-term threads: {}", stats.short_term.total_threads);
-    println!(
-        " - Short-term conversations: {}",
-        stats.short_term.total_conversations
-    );
-    println!(" - By type: {:?}", stats.overall.by_type);
-
-    // Run system optimization
-    println!("\n๐Ÿงน Running system optimization...");
-    let optimization_report = memory_system.optimize().await?;
-    println!("โœ… Optimization complete:");
-    println!(
-        " - Expired cleaned: {}",
-        optimization_report.expired_cleaned
-    );
-    println!(
-        " - Memories consolidated: {}",
-        optimization_report.memories_consolidated
-    );
-    println!(
-        " - Memories archived: {}",
-        optimization_report.memories_archived
-    );
-    println!(
-        " - Memories pruned: {}",
-        optimization_report.memories_pruned
-    );
-    println!(
-        " - Total processed: {}",
-        optimization_report.total_processed()
-    );
-
-    println!("\n๐ŸŽ‰ Demo completed successfully!");
-    println!("The agent memory system is now ready for production use.");
+    let checkpoint_id = agent.checkpoint("Demo conversation completed").await?;
+    println!("๐Ÿ’พ Created memory checkpoint: {}", checkpoint_id);
+
+    println!();
+    println!("๐ŸŽ‰ Demo completed successfully!");
+    println!();
+    println!("Key Features Demonstrated:");
+    println!("โ€ข ๐Ÿง  Prolly tree-based memory persistence");
+    println!("โ€ข ๐Ÿ’ฌ Conversation history tracking");
+    println!("โ€ข ๐Ÿ“š Episodic memory of interactions");
+    println!("โ€ข โš™๏ธ Procedural knowledge updating");
+    println!("โ€ข ๐Ÿค– Rig framework integration (with fallback)");
+    println!("โ€ข ๐Ÿ“Š Memory statistics and checkpoints");
+
+    if std::env::var("OPENAI_API_KEY").is_err() {
+        println!();
+        println!("๐Ÿ’ก To enable AI-powered responses, set OPENAI_API_KEY environment variable:");
+        println!("   export OPENAI_API_KEY=your_api_key_here");
+        println!("   cargo run --example agent_rig_demo");
+    }
 
     Ok(())
 }
diff --git a/src/agent/README.md b/src/agent/README.md
index 6dc4f75..fd95d41 100644
--- a/src/agent/README.md
+++ b/src/agent/README.md
@@ -163,6 +163,8 @@ async fn
main() -> Result<(), Box> { - Thread-safe async operations - Tree statistics and range queries - Commit tracking with sequential IDs +- **Rig framework integration** with AI-powered responses and intelligent fallback +- **Memory-contextual AI** that uses stored knowledge for better responses ### Planned ๐Ÿšง - Real embedding generation (currently uses mock) @@ -170,6 +172,7 @@ async fn main() -> Result<(), Box> { - Memory conflict resolution - Performance optimizations - Git-based prolly tree persistence for durability +- Multi-agent memory sharing through Rig ### Known Limitations - Mock embedding generation @@ -217,9 +220,11 @@ The memory system now uses prolly trees for storage with the following features: 5. **Memory Compression**: Efficient storage of large memories 6. **Distributed Memory**: Support for multi-agent memory sharing -## Running the Demo +## Running the Demos -To see the memory system in action: +### Basic Memory System Demo + +To see the core memory system in action: ```bash cargo run --example agent_memory_demo @@ -232,6 +237,26 @@ This demonstrates: - Tree statistics and checkpoint creation - System optimization and cleanup +### Rig Framework Integration Demo + +To see the memory system integrated with Rig framework for AI-powered agents: + +```bash +# With OpenAI API key (AI-powered responses) +OPENAI_API_KEY=your_key_here cargo run --example agent_rig_demo --features="git sql rig" + +# Without API key (memory-based fallback responses) +cargo run --example agent_rig_demo --features="git sql rig" +``` + +This demonstrates: +- ๐Ÿค– **Rig framework integration** for AI-powered responses +- ๐Ÿง  **Memory-contextual AI** using conversation history and stored knowledge +- ๐Ÿ”„ **Intelligent fallback** to memory-based responses when AI is unavailable +- ๐Ÿ“š **Contextual learning** from interactions stored in episodic memory +- โš™๏ธ **Procedural knowledge updates** based on conversation patterns +- ๐Ÿ“Š **Real-time memory statistics** and checkpoint management + ## Testing The memory system includes comprehensive unit tests for each component, including prolly tree persistence tests. Run tests with: diff --git a/src/agent/mod.rs b/src/agent/mod.rs index 99292bf..d1c4f89 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -220,7 +220,7 @@ impl AgentMemorySystem { } /// Combined statistics for the entire memory system -#[derive(Debug, Clone)] +#[derive(Debug, Clone, serde::Serialize)] pub struct AgentMemoryStats { pub overall: MemoryStats, pub short_term: short_term::ShortTermStats, diff --git a/src/agent/short_term.rs b/src/agent/short_term.rs index 97227e9..59a9d2d 100644 --- a/src/agent/short_term.rs +++ b/src/agent/short_term.rs @@ -393,7 +393,7 @@ impl MemoryStore for ShortTermMemoryStore { } /// Statistics specific to short-term memory -#[derive(Debug, Clone)] +#[derive(Debug, Clone, serde::Serialize)] pub struct ShortTermStats { pub total_threads: usize, pub active_threads: usize,
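The README section above describes an intelligent fallback from AI-powered to memory-based responses when no API key is configured. The snippet below is a minimal, standalone sketch of that dispatch pattern only; the names (`respond`, `Reply`, `memory_context`) are illustrative and are not the `agent_rig_demo` API, which builds the same behavior into `IntelligentAgent::process_message` and additionally supports a `Hybrid` mode.

```rust
// Illustrative sketch: prefer an AI-backed reply when an API key is configured,
// otherwise fall back to an answer assembled from stored memory context.

#[derive(Debug)]
enum ResponseMode {
    AIPowered,
    MemoryBased,
}

struct Reply {
    content: String,
    mode: ResponseMode,
}

fn respond(message: &str, api_key: Option<&str>, memory_context: &str) -> Reply {
    match api_key {
        // With a key, this is where a Rig agent would be prompted with the
        // message plus the retrieved memory context.
        Some(_key) => Reply {
            content: format!("(model call would go here for: {message})"),
            mode: ResponseMode::AIPowered,
        },
        // Without a key, answer purely from what is already stored in memory.
        None => Reply {
            content: format!(
                "Based on what I already know:\n{memory_context}\n\nRegarding '{message}', here is what I can offer from memory."
            ),
            mode: ResponseMode::MemoryBased,
        },
    }
}

fn main() {
    let reply = respond("How does borrowing work?", None, "- user is learning Rust");
    println!("[{:?}] {}", reply.mode, reply.content);
}
```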