Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 189 additions & 5 deletions crates/forge_services/src/mcp/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
tools: Arc<RwLock<HashMap<ToolName, ToolHolder<McpExecutor<C>>>>>,
failed_servers: Arc<RwLock<HashMap<ServerName, String>>>,
previous_config_hash: Arc<Mutex<u64>>,
init_lock: Arc<Mutex<()>>,
manager: Arc<M>,
infra: Arc<I>,
}
Expand All @@ -50,6 +51,7 @@
tools: Default::default(),
failed_servers: Default::default(),
previous_config_hash: Arc::new(Mutex::new(Default::default())),
init_lock: Arc::new(Mutex::new(())),
manager,
infra,
}
Expand Down Expand Up @@ -101,7 +103,17 @@
async fn init_mcp(&self) -> anyhow::Result<()> {
let mcp = self.manager.read_mcp_config(None).await?;

// If config is unchanged, skip reinitialization
// Fast path: if config is unchanged, skip reinitialization without acquiring
// the lock
if !self.is_config_modified(&mcp).await {
return Ok(());
}

// Serialise concurrent initialisations so only one caller runs update_mcp at a
// time
let _guard = self.init_lock.lock().await;

// Double-check under the lock: a concurrent caller may have already updated
if !self.is_config_modified(&mcp).await {
return Ok(());
}
Expand All @@ -110,9 +122,10 @@
}

async fn update_mcp(&self, mcp: McpConfig) -> Result<(), anyhow::Error> {
// Update the hash with the new config
// Compute the hash early before mcp is consumed, but write it only after
// all connections are established so waiters on init_lock see a consistent
// state.
let new_hash = mcp.cache_key();
*self.previous_config_hash.lock().await = new_hash;
self.clear_tools().await;

// Clear failed servers map before attempting new connections
Expand Down Expand Up @@ -149,6 +162,11 @@
}
}

// Write the hash only after join_all finishes so that any waiter on
// init_lock re-checks is_config_modified only once self.tools is fully
// populated, preventing "Tool not found" races.
*self.previous_config_hash.lock().await = new_hash;

Ok(())
}

Expand Down Expand Up @@ -194,6 +212,10 @@
/// when list() or call() is invoked, avoiding interactive OAuth during
/// reload.
async fn refresh_cache(&self) -> anyhow::Result<()> {
// Hold init_lock so we don't race with an in-flight update_mcp: without
// this, clear_tools could run while connections are still being
// established, leaving waiters released into an empty tool map.
let _guard = self.init_lock.lock().await;
// Clear the infra cache and reset config hash to force re-init on next access
self.infra.cache_clear().await?;
*self.previous_config_hash.lock().await = Default::default();
Expand Down Expand Up @@ -239,10 +261,141 @@

#[cfg(test)]
mod tests {
use forge_app::domain::{ServerName, ToolName};
use std::collections::BTreeMap;
use std::sync::Arc;

use fake::{Fake, Faker};
use forge_app::domain::{
ConfigOperation, Environment, McpConfig, McpServerConfig, Scope, ServerName, ToolCallFull,
ToolDefinition, ToolName, ToolOutput,
};
use forge_app::{
EnvironmentInfra, KVStore, McpClientInfra, McpConfigManager, McpServerInfra, McpService,
};
use forge_config::ForgeConfig;
use pretty_assertions::assert_eq;
use serde::de::DeserializeOwned;

use super::{ForgeMcpService, generate_mcp_tool_name};

// ── Mock MCP client ──────────────────────────────────────────────────────

#[derive(Clone)]
struct MockMcpClient;

#[async_trait::async_trait]
impl McpClientInfra for MockMcpClient {
async fn list(&self) -> anyhow::Result<Vec<ToolDefinition>> {
Ok(vec![ToolDefinition::new("test_tool")])
}

async fn call(
&self,
_tool_name: &ToolName,
_input: serde_json::Value,
) -> anyhow::Result<ToolOutput> {
Ok(ToolOutput::text("mock result"))
}
}

// ── Mock config manager ──────────────────────────────────────────────────

struct MockMcpManager;

#[async_trait::async_trait]
impl McpConfigManager for MockMcpManager {
async fn read_mcp_config(&self, _scope: Option<&Scope>) -> anyhow::Result<McpConfig> {
let mut servers = BTreeMap::new();
servers.insert(
ServerName::from("test-server".to_string()),
McpServerConfig::new_stdio("echo", vec![], None),
);
Ok(McpConfig { mcp_servers: servers })
}

async fn write_mcp_config(
&self,
_config: &McpConfig,
_scope: &Scope,
) -> anyhow::Result<()> {
Ok(())
}
}

// ── Mock infrastructure ──────────────────────────────────────────────────

#[derive(Clone)]
struct MockInfra;

#[async_trait::async_trait]
impl McpServerInfra for MockInfra {
type Client = MockMcpClient;

use super::generate_mcp_tool_name;
async fn connect(
&self,
_config: McpServerConfig,
_env_vars: &BTreeMap<String, String>,
_environment: &Environment,
) -> anyhow::Result<MockMcpClient> {
Ok(MockMcpClient)
}
}

#[async_trait::async_trait]
impl KVStore for MockInfra {
async fn cache_get<K, V>(&self, _key: &K) -> anyhow::Result<Option<V>>
where
K: std::hash::Hash + Sync,
V: serde::Serialize + DeserializeOwned + Send,
{
Ok(None)
}

async fn cache_set<K, V>(&self, _key: &K, _value: &V) -> anyhow::Result<()>
where
K: std::hash::Hash + Sync,
V: serde::Serialize + Sync,
{
Ok(())
}

async fn cache_clear(&self) -> anyhow::Result<()> {
Ok(())
}
}

impl EnvironmentInfra for MockInfra {
type Config = ForgeConfig;

fn get_env_var(&self, _key: &str) -> Option<String> {
None
}

fn get_env_vars(&self) -> BTreeMap<String, String> {
BTreeMap::new()
}

fn get_environment(&self) -> Environment {
Faker.fake()
}

fn get_config(&self) -> anyhow::Result<ForgeConfig> {
Ok(ForgeConfig::default())
}

fn update_environment(

Check warning on line 386 in crates/forge_services/src/mcp/service.rs

View workflow job for this annotation

GitHub Actions / Lint Fix

this function can be simplified using the `async fn` syntax
&self,
_ops: Vec<ConfigOperation>,
) -> impl std::future::Future<Output = anyhow::Result<()>> + Send {
async { Ok(()) }
}
}

// ── Fixture ──────────────────────────────────────────────────────────────

fn fixture() -> ForgeMcpService<MockMcpManager, MockInfra, MockMcpClient> {
ForgeMcpService::new(Arc::new(MockMcpManager), Arc::new(MockInfra))
}

#[test]
fn test_generate_mcp_tool_name_uses_legacy_format() {
Expand Down Expand Up @@ -286,4 +439,35 @@
let actual = ToolName::new("mcp_github_tool_create_issue").to_legacy_mcp_name();
assert_eq!(actual, None);
}

// ── Concurrent initialisation test ──────────────────────────────────────

/// Verify that two concurrent callers of `get_mcp_servers` do not race:
/// after both futures settle, every registered tool must be callable
/// without a "Tool not found" error.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_concurrent_init_does_not_race() {
let service = Arc::new(fixture());

let s1 = service.clone();
let s2 = service.clone();
let (r1, r2) = tokio::join!(s1.get_mcp_servers(), s2.get_mcp_servers());
r1.unwrap();
r2.unwrap();

let servers = service.get_mcp_servers().await.unwrap();
let tool_name = servers
.get_servers()
.values()
.flat_map(|tools| tools.iter())
.next()
.expect("at least one tool must be registered")
.name
.clone();

let call = ToolCallFull::new(tool_name);
let actual = service.execute_mcp(call).await.unwrap();
let expected = ToolOutput::text("mock result");
assert_eq!(actual, expected);
}
}
Loading