Skip to content
Merged

Mcp #18

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 250 additions & 24 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ toml = "0.8"
dirs = "6.0"
uuid = { version = "1", features = ["v4", "serde"] }
chrono = { version = "0.4", features = ["serde"] }
rmcp = { version = "1.6.0", default-features = false, features = ["client", "transport-child-process", "transport-streamable-http-client-reqwest"] }
Comment thread
themartto marked this conversation as resolved.

[dev-dependencies]
tempfile = "3"
3 changes: 2 additions & 1 deletion src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ pub async fn start_api_server(
tracing::info!(" WS /ws");

let llm_client: Arc<dyn LlmClient> = create_client(&config, &client);
let tool_executor: Arc<dyn ToolExecutor> = Arc::new(SystemToolExecutor::new());
let tool_executor: Arc<dyn ToolExecutor> =
Arc::new(SystemToolExecutor::build(&app_config.mcp_servers).await);
let rag_context = RagContext::new()?;

let server = HttpServer::new(move || {
Expand Down
9 changes: 5 additions & 4 deletions src/cli/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ struct SessionContext {
}

impl SessionContext {
fn new(
async fn new(
client: &Client,
agent_config: &AgentConfig,
app_config: &AppConfig,
Expand All @@ -88,7 +88,8 @@ impl SessionContext {
create_client(agent_config, client),
agent_config,
)?;
let tool_executor: Arc<dyn ToolExecutor> = Arc::new(SystemToolExecutor::new());
let tool_executor: Arc<dyn ToolExecutor> =
Arc::new(SystemToolExecutor::build(&app_config.mcp_servers).await);
let rag = RagContext::new()?;
let (conversation, prompt_builder) = rag.prepare(
chat_id,
Expand Down Expand Up @@ -177,7 +178,7 @@ pub async fn run_agent_mode(
let chat_id = resolve_chat_id(chat_id, continue_last)?;
let mut session = SessionContext::new(
client, config, app_config, model_name, max_iterations, chat_id, skill_names,
)?;
).await?;

let chat_id_short = &session.conversation.meta.id.to_string()[..8];
let chat_status = if session.conversation.messages.is_empty() {
Expand Down Expand Up @@ -290,7 +291,7 @@ pub async fn run_single_prompt(

let mut session = SessionContext::new(
client, config, app_config, model_name, max_iterations, chat_id, skill_names,
)?;
).await?;

session.conversation.messages.push(Message::user(prompt.clone()));

Expand Down
1 change: 1 addition & 0 deletions src/config/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ mod tests {
default_provider: "openai".into(),
max_iterations: 10,
providers,
mcp_servers: BTreeMap::new(),
}
}

Expand Down
13 changes: 13 additions & 0 deletions src/config/config.toml.default
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,16 @@ env_var = "ANTHROPIC_API_KEY"
# api_base = "http://localhost:11434/v1"
# default_model = "llama2"
# models = ["llama2", "mistral", "codellama", "mixtral"]

# --- MCP Server Configuration (Model Context Protocol) ---
# Tools are exposed to the LLM as "{server_name}__{tool_name}".
#
# Stdio transport — spawn a local process:
# [mcp_servers.filesystem]
# command = "npx"
# args = ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
# env = { }
#
# Streamable HTTP transport — connect to a running server:
# [mcp_servers.remote-tools]
# url = "http://localhost:8080/mcp"
2 changes: 1 addition & 1 deletion src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ mod resolve;
mod types;

pub use client::{build_http_client, create_client, resolve_client_and_config};
pub use types::{AgentConfig, AppConfig, ProviderConfig};
pub use types::{AgentConfig, AppConfig, McpServerConfig, ProviderConfig};

use std::path::PathBuf;

Expand Down
3 changes: 3 additions & 0 deletions src/config/resolve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ mod tests {
default_provider: "openai".into(),
max_iterations: 5,
providers,
mcp_servers: BTreeMap::new(),
}
}

Expand Down Expand Up @@ -161,6 +162,7 @@ mod tests {
default_provider: "nonexistent".into(),
max_iterations: 10,
providers: BTreeMap::new(),
mcp_servers: BTreeMap::new(),
};
let err = config.resolve(None).unwrap_err();
assert!(err.to_string().contains("nonexistent"));
Expand All @@ -183,6 +185,7 @@ mod tests {
default_provider: "openai".into(),
max_iterations: 10,
providers: BTreeMap::new(),
mcp_servers: BTreeMap::new(),
};
let err = config.list_models().unwrap_err();
assert!(err.to_string().contains("No providers configured"));
Expand Down
20 changes: 19 additions & 1 deletion src/config/types.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;

/// Top-level configuration loaded from ~/.openheim/config.toml
Expand All @@ -10,6 +10,24 @@ pub struct AppConfig {
pub max_iterations: usize,
#[serde(default)]
pub providers: BTreeMap<String, ProviderConfig>,
#[serde(default)]
pub mcp_servers: BTreeMap<String, McpServerConfig>,
}

/// Configuration for a single MCP server connection.
/// The map key in `[mcp_servers.<name>]` is used as the server name and tool-name prefix.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpServerConfig {
/// Binary to spawn for stdio transport (e.g. `"npx"`, `"uvx"`).
pub command: Option<String>,
/// Arguments passed to `command`.
#[serde(default)]
pub args: Vec<String>,
/// Extra environment variables for the spawned process.
#[serde(default)]
pub env: HashMap<String, String>,
/// Base URL for Streamable HTTP transport (e.g. `"http://localhost:8080/mcp"`).
pub url: Option<String>,
}

fn default_max_iterations() -> usize {
Expand Down
8 changes: 6 additions & 2 deletions src/core/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::core::llm::LlmClient;
use crate::core::models::*;
use crate::error::Result;
use crate::rag::PromptBuilder;
use crate::tools::{get_available_tools, ToolExecutor};
use crate::tools::ToolExecutor;

async fn call_llm(
llm: &Arc<dyn LlmClient>,
Expand All @@ -31,7 +31,7 @@ async fn run_agent_loop(
verbose: bool,
mut callback: Option<&mut dyn FnMut(StreamEvent)>,
) -> Result<AgentResult> {
let tools = get_available_tools();
let tools = tool_executor.list_tools();
let mut steps = Vec::new();
let mut final_response = String::new();

Expand Down Expand Up @@ -284,6 +284,10 @@ mod tests {

#[async_trait]
impl ToolExecutor for MockToolExecutor {
fn list_tools(&self) -> Vec<Tool> {
vec![]
}

async fn execute(&self, name: &str, args_json: &str) -> Result<String> {
self.calls.lock().unwrap().push((name.into(), args_json.into()));
Ok(self.result.clone())
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub mod cli;
pub mod config;
pub mod core;
pub mod error;
pub mod mcp;
pub mod rag;
pub mod tools;

Expand All @@ -11,6 +12,5 @@ pub use core::{agent, llm, models};
pub use error::{Error, Result};
pub use models::*;

pub use tools::{execute_tool, get_available_tools};
pub use llm::{LlmClient, OpenAiClient, OpenAiCompatibleClient, AnthropicClient, GeminiClient};
pub use rag::{RagContext, PromptBuilder};
99 changes: 99 additions & 0 deletions src/mcp/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
use rmcp::{
ServiceExt,
model::{CallToolRequestParams, Content, RawContent, ResourceContents, Tool},
service::{RoleClient, RunningService},
transport::{TokioChildProcess, streamable_http_client::StreamableHttpClientTransport},
};

use crate::{
config::McpServerConfig,
error::{Error, Result},
};

pub struct McpClient {
service: RunningService<RoleClient, ()>,
pub server_name: String,
}

impl McpClient {
pub async fn connect(name: &str, config: &McpServerConfig) -> Result<Self> {
if let Some(ref url) = config.url {
let transport = StreamableHttpClientTransport::from_uri(url.as_str());
let service = ()
.serve(transport)
.await
.map_err(|e| Error::Other(format!("MCP HTTP connect to '{}' failed: {}", name, e)))?;
Ok(Self { service, server_name: name.to_string() })
} else if let Some(ref command) = config.command {
let mut cmd = tokio::process::Command::new(command);
cmd.args(&config.args);
for (k, v) in &config.env {
cmd.env(k, v);
}
let transport = TokioChildProcess::new(cmd)
.map_err(|e| Error::Other(format!("MCP spawn '{}' failed: {}", name, e)))?;
let service = ()
.serve(transport)
.await
.map_err(|e| Error::Other(format!("MCP stdio connect to '{}' failed: {}", name, e)))?;
Ok(Self { service, server_name: name.to_string() })
} else {
Err(Error::ConfigError(format!(
"MCP server '{}' must have either 'command' (stdio) or 'url' (HTTP)",
name
)))
}
}

pub async fn list_tools(&self) -> Result<Vec<Tool>> {
self.service
.list_all_tools()
.await
.map_err(|e| Error::Other(format!("MCP list_tools failed for '{}': {}", self.server_name, e)))
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

pub async fn call_tool(&self, name: &str, args_json: &str) -> Result<String> {
let params = build_call_params(name, args_json)?;

let result = self
.service
.peer()
.call_tool(params)
.await
.map_err(|e| Error::ToolExecutionError(format!("MCP tool '{}' on '{}' failed: {}", name, self.server_name, e)))?;

if result.is_error.unwrap_or(false) {
return Err(Error::ToolExecutionError(extract_text_content(&result.content)));
}

Ok(extract_text_content(&result.content))
}
}

fn build_call_params(name: &str, args_json: &str) -> Result<CallToolRequestParams> {
let trimmed = args_json.trim();
if trimmed.is_empty() || trimmed == "{}" {
return Ok(CallToolRequestParams::new(name.to_string()));
}
let map: serde_json::Map<String, serde_json::Value> = serde_json::from_str(trimmed)?;
Ok(CallToolRequestParams::new(name.to_string()).with_arguments(map))
}

fn extract_text_content(content: &[Content]) -> String {
content
.iter()
.map(|item| match &**item {
RawContent::Text(t) => t.text.clone(),
RawContent::Image(i) => format!("[image: {}]", i.mime_type),
RawContent::Audio(a) => format!("[audio: {}]", a.mime_type),
RawContent::Resource(r) => match &r.resource {
ResourceContents::TextResourceContents { text, .. } => text.clone(),
ResourceContents::BlobResourceContents { uri, mime_type, .. } => {
format!("[blob: {} ({})]", uri, mime_type.as_deref().unwrap_or("unknown"))
}
},
RawContent::ResourceLink(l) => format!("[resource: {}]", l.uri),
})
.collect::<Vec<_>>()
.join("\n")
}
53 changes: 53 additions & 0 deletions src/mcp/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
mod client;
mod tool_handler;

use std::collections::BTreeMap;
use std::sync::Arc;

use client::McpClient;
use tool_handler::McpToolHandler;

use crate::{config::McpServerConfig, error::Result, tools::ToolHandler};

pub(crate) async fn load_mcp_tools(configs: &BTreeMap<String, McpServerConfig>) -> Vec<Box<dyn ToolHandler>> {
let mut handlers: Vec<Box<dyn ToolHandler>> = Vec::new();

for (name, config) in configs {
match connect_server(name, config).await {
Ok(server_handlers) => {
tracing::info!(
server = %name,
count = server_handlers.len(),
"MCP server connected"
);
handlers.extend(server_handlers);
}
Err(e) => {
tracing::warn!(server = %name, error = %e, "MCP server failed to connect");
}
}
Comment on lines +15 to +28
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

find . -name "mod.rs" -path "*/mcp/*" -type f

Repository: weirdstuff-dev/openheim

Length of output: 84


🏁 Script executed:

find . -name "*.rs" | head -20

Repository: weirdstuff-dev/openheim

Length of output: 485


🏁 Script executed:

fd -e rs mcp | head -20

Repository: weirdstuff-dev/openheim

Length of output: 49


🏁 Script executed:

rg "connect_server" -A 5 -B 5

Repository: weirdstuff-dev/openheim

Length of output: 1234


🏁 Script executed:

rg "McpClient" -A 10 -B 2

Repository: weirdstuff-dev/openheim

Length of output: 4000


🏁 Script executed:

wc -l src/mcp/mod.rs

Repository: weirdstuff-dev/openheim

Length of output: 85


🏁 Script executed:

cat -n src/mcp/mod.rs

Repository: weirdstuff-dev/openheim

Length of output: 1954


🏁 Script executed:

wc -l src/mcp/client.rs

Repository: weirdstuff-dev/openheim

Length of output: 88


🏁 Script executed:

cat -n src/mcp/client.rs

Repository: weirdstuff-dev/openheim

Length of output: 4104


🏁 Script executed:

rg "timeout|Timeout" src/mcp/

Repository: weirdstuff-dev/openheim

Length of output: 49


Add a connection timeout to prevent MCP servers from blocking startup indefinitely.

McpClient::connect (both HTTP and stdio paths) and client.list_tools() have no timeout. A misconfigured or unresponsive MCP server will block startup indefinitely with no error output.

Proposed fix: wrap each `connect_server` call with `tokio::time::timeout`
+use std::time::Duration;
+
+const MCP_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);

 pub(crate) async fn load_mcp_tools(configs: &BTreeMap<String, McpServerConfig>) -> Vec<Box<dyn ToolHandler>> {
     let mut handlers: Vec<Box<dyn ToolHandler>> = Vec::new();

     for (name, config) in configs {
-        match connect_server(name, config).await {
-            Ok(server_handlers) => {
+        match tokio::time::timeout(MCP_CONNECT_TIMEOUT, connect_server(name, config)).await {
+            Ok(Ok(server_handlers)) => {
                 tracing::info!(
                     server = %name,
                     count = server_handlers.len(),
                     "MCP server connected"
                 );
                 handlers.extend(server_handlers);
             }
-            Err(e) => {
-                tracing::warn!(server = %name, error = %e, "MCP server failed to connect");
+            Ok(Err(e)) => {
+                tracing::warn!(server = %name, error = %e, "MCP server failed to connect");
+            }
+            Err(_) => {
+                tracing::warn!(server = %name, timeout_secs = MCP_CONNECT_TIMEOUT.as_secs(), "MCP server connection timed out");
             }
         }
     }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
for (name, config) in configs {
match connect_server(name, config).await {
Ok(server_handlers) => {
tracing::info!(
server = %name,
count = server_handlers.len(),
"MCP server connected"
);
handlers.extend(server_handlers);
}
Err(e) => {
tracing::warn!(server = %name, error = %e, "MCP server failed to connect");
}
}
use std::time::Duration;
const MCP_CONNECT_TIMEOUT: Duration = Duration::from_secs(30);
pub(crate) async fn load_mcp_tools(configs: &BTreeMap<String, McpServerConfig>) -> Vec<Box<dyn ToolHandler>> {
let mut handlers: Vec<Box<dyn ToolHandler>> = Vec::new();
for (name, config) in configs {
match tokio::time::timeout(MCP_CONNECT_TIMEOUT, connect_server(name, config)).await {
Ok(Ok(server_handlers)) => {
tracing::info!(
server = %name,
count = server_handlers.len(),
"MCP server connected"
);
handlers.extend(server_handlers);
}
Ok(Err(e)) => {
tracing::warn!(server = %name, error = %e, "MCP server failed to connect");
}
Err(_) => {
tracing::warn!(server = %name, timeout_secs = MCP_CONNECT_TIMEOUT.as_secs(), "MCP server connection timed out");
}
}
}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/mcp/mod.rs` around lines 15 - 28, Wrap each call to connect_server(name,
config).await in a tokio::time::timeout(...) and handle the timeout case by
logging a warn with the server name and skipping that server; also add timeouts
inside McpClient::connect (both HTTP and stdio paths) and around
client.list_tools() so those internal awaits return an explicit timeout error
instead of blocking forever. Ensure you import tokio::time::Duration/timeout,
choose a sensible Duration constant, translate a timeout error into an Err
returned from connect_server (or into the same logging path used for other
connection errors) so connect_server, McpClient::connect, and callers uniformly
handle and report timeouts.

}

handlers
}

async fn connect_server(name: &str, config: &McpServerConfig) -> Result<Vec<Box<dyn ToolHandler>>> {
let client = Arc::new(McpClient::connect(name, config).await?);
let tools = client.list_tools().await?;

// Sanitise the prefix: hyphens and spaces become underscores so the
// combined name is a valid identifier for tool-call APIs.
let prefix: String = name
.chars()
.map(|c| if c.is_alphanumeric() || c == '_' { c } else { '_' })
.collect();

let handlers = tools
.iter()
.map(|tool| -> Box<dyn ToolHandler> {
Box::new(McpToolHandler::new(Arc::clone(&client), tool, &prefix))
})
.collect();

Ok(handlers)
}
52 changes: 52 additions & 0 deletions src/mcp/tool_handler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use std::sync::Arc;

use async_trait::async_trait;
use rmcp::model::Tool as McpTool;

use crate::{
core::models::{FunctionDefinition, Tool},
error::Result,
tools::ToolHandler,
};

use super::client::McpClient;

pub struct McpToolHandler {
client: Arc<McpClient>,
/// Original tool name as reported by the MCP server.
tool_name: String,
/// Name exposed to the LLM: `{server_prefix}__{tool_name}`.
prefixed_name: String,
description: String,
schema: serde_json::Value,
}

impl McpToolHandler {
pub fn new(client: Arc<McpClient>, tool: &McpTool, server_prefix: &str) -> Self {
let tool_name = tool.name.to_string();
let prefixed_name = format!("{}__{}", server_prefix, tool_name);
let description = tool.description.as_deref().unwrap_or("").to_string();
let schema = serde_json::to_value(&tool.input_schema)
.unwrap_or_else(|_| serde_json::json!({"type": "object", "properties": {}}));

Self { client, tool_name, prefixed_name, description, schema }
}
}

#[async_trait]
impl ToolHandler for McpToolHandler {
fn definition(&self) -> Tool {
Tool {
tool_type: "function".to_string(),
function: FunctionDefinition {
name: self.prefixed_name.clone(),
description: self.description.clone(),
parameters: self.schema.clone(),
},
}
}

async fn execute(&self, args: &str) -> Result<String> {
self.client.call_tool(&self.tool_name, args).await
}
}
Loading