diff --git a/crates/forge_services/src/mcp/service.rs b/crates/forge_services/src/mcp/service.rs index 78329354ea..63075e9482 100644 --- a/crates/forge_services/src/mcp/service.rs +++ b/crates/forge_services/src/mcp/service.rs @@ -27,6 +27,7 @@ pub struct ForgeMcpService { tools: Arc>>>>, failed_servers: Arc>>, previous_config_hash: Arc>, + init_lock: Arc>, manager: Arc, infra: Arc, } @@ -50,6 +51,7 @@ where tools: Default::default(), failed_servers: Default::default(), previous_config_hash: Arc::new(Mutex::new(Default::default())), + init_lock: Arc::new(Mutex::new(())), manager, infra, } @@ -101,7 +103,17 @@ where async fn init_mcp(&self) -> anyhow::Result<()> { let mcp = self.manager.read_mcp_config(None).await?; - // If config is unchanged, skip reinitialization + // Fast path: if config is unchanged, skip reinitialization without acquiring + // the lock + if !self.is_config_modified(&mcp).await { + return Ok(()); + } + + // Serialise concurrent initialisations so only one caller runs update_mcp at a + // time + let _guard = self.init_lock.lock().await; + + // Double-check under the lock: a concurrent caller may have already updated if !self.is_config_modified(&mcp).await { return Ok(()); } @@ -110,9 +122,10 @@ where } async fn update_mcp(&self, mcp: McpConfig) -> Result<(), anyhow::Error> { - // Update the hash with the new config + // Compute the hash early before mcp is consumed, but write it only after + // all connections are established so waiters on init_lock see a consistent + // state. let new_hash = mcp.cache_key(); - *self.previous_config_hash.lock().await = new_hash; self.clear_tools().await; // Clear failed servers map before attempting new connections @@ -149,6 +162,11 @@ where } } + // Write the hash only after join_all finishes so that any waiter on + // init_lock re-checks is_config_modified only once self.tools is fully + // populated, preventing "Tool not found" races. + *self.previous_config_hash.lock().await = new_hash; + Ok(()) } @@ -194,6 +212,10 @@ where /// when list() or call() is invoked, avoiding interactive OAuth during /// reload. async fn refresh_cache(&self) -> anyhow::Result<()> { + // Hold init_lock so we don't race with an in-flight update_mcp: without + // this, clear_tools could run while connections are still being + // established, leaving waiters released into an empty tool map. + let _guard = self.init_lock.lock().await; // Clear the infra cache and reset config hash to force re-init on next access self.infra.cache_clear().await?; *self.previous_config_hash.lock().await = Default::default(); @@ -239,10 +261,141 @@ where #[cfg(test)] mod tests { - use forge_app::domain::{ServerName, ToolName}; + use std::collections::BTreeMap; + use std::sync::Arc; + + use fake::{Fake, Faker}; + use forge_app::domain::{ + ConfigOperation, Environment, McpConfig, McpServerConfig, Scope, ServerName, ToolCallFull, + ToolDefinition, ToolName, ToolOutput, + }; + use forge_app::{ + EnvironmentInfra, KVStore, McpClientInfra, McpConfigManager, McpServerInfra, McpService, + }; + use forge_config::ForgeConfig; use pretty_assertions::assert_eq; + use serde::de::DeserializeOwned; + + use super::{ForgeMcpService, generate_mcp_tool_name}; + + // ── Mock MCP client ────────────────────────────────────────────────────── + + #[derive(Clone)] + struct MockMcpClient; + + #[async_trait::async_trait] + impl McpClientInfra for MockMcpClient { + async fn list(&self) -> anyhow::Result> { + Ok(vec![ToolDefinition::new("test_tool")]) + } + + async fn call( + &self, + _tool_name: &ToolName, + _input: serde_json::Value, + ) -> anyhow::Result { + Ok(ToolOutput::text("mock result")) + } + } + + // ── Mock config manager ────────────────────────────────────────────────── + + struct MockMcpManager; + + #[async_trait::async_trait] + impl McpConfigManager for MockMcpManager { + async fn read_mcp_config(&self, _scope: Option<&Scope>) -> anyhow::Result { + let mut servers = BTreeMap::new(); + servers.insert( + ServerName::from("test-server".to_string()), + McpServerConfig::new_stdio("echo", vec![], None), + ); + Ok(McpConfig { mcp_servers: servers }) + } + + async fn write_mcp_config( + &self, + _config: &McpConfig, + _scope: &Scope, + ) -> anyhow::Result<()> { + Ok(()) + } + } + + // ── Mock infrastructure ────────────────────────────────────────────────── + + #[derive(Clone)] + struct MockInfra; + + #[async_trait::async_trait] + impl McpServerInfra for MockInfra { + type Client = MockMcpClient; - use super::generate_mcp_tool_name; + async fn connect( + &self, + _config: McpServerConfig, + _env_vars: &BTreeMap, + _environment: &Environment, + ) -> anyhow::Result { + Ok(MockMcpClient) + } + } + + #[async_trait::async_trait] + impl KVStore for MockInfra { + async fn cache_get(&self, _key: &K) -> anyhow::Result> + where + K: std::hash::Hash + Sync, + V: serde::Serialize + DeserializeOwned + Send, + { + Ok(None) + } + + async fn cache_set(&self, _key: &K, _value: &V) -> anyhow::Result<()> + where + K: std::hash::Hash + Sync, + V: serde::Serialize + Sync, + { + Ok(()) + } + + async fn cache_clear(&self) -> anyhow::Result<()> { + Ok(()) + } + } + + impl EnvironmentInfra for MockInfra { + type Config = ForgeConfig; + + fn get_env_var(&self, _key: &str) -> Option { + None + } + + fn get_env_vars(&self) -> BTreeMap { + BTreeMap::new() + } + + fn get_environment(&self) -> Environment { + Faker.fake() + } + + fn get_config(&self) -> anyhow::Result { + Ok(ForgeConfig::default()) + } + + fn update_environment( + &self, + _ops: Vec, + ) -> impl std::future::Future> + Send { + async { Ok(()) } + } + } + + // ── Fixture ────────────────────────────────────────────────────────────── + + fn fixture() -> ForgeMcpService { + ForgeMcpService::new(Arc::new(MockMcpManager), Arc::new(MockInfra)) + } #[test] fn test_generate_mcp_tool_name_uses_legacy_format() { @@ -286,4 +439,35 @@ mod tests { let actual = ToolName::new("mcp_github_tool_create_issue").to_legacy_mcp_name(); assert_eq!(actual, None); } + + // ── Concurrent initialisation test ────────────────────────────────────── + + /// Verify that two concurrent callers of `get_mcp_servers` do not race: + /// after both futures settle, every registered tool must be callable + /// without a "Tool not found" error. + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_concurrent_init_does_not_race() { + let service = Arc::new(fixture()); + + let s1 = service.clone(); + let s2 = service.clone(); + let (r1, r2) = tokio::join!(s1.get_mcp_servers(), s2.get_mcp_servers()); + r1.unwrap(); + r2.unwrap(); + + let servers = service.get_mcp_servers().await.unwrap(); + let tool_name = servers + .get_servers() + .values() + .flat_map(|tools| tools.iter()) + .next() + .expect("at least one tool must be registered") + .name + .clone(); + + let call = ToolCallFull::new(tool_name); + let actual = service.execute_mcp(call).await.unwrap(); + let expected = ToolOutput::text("mock result"); + assert_eq!(actual, expected); + } }