Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 50 additions & 15 deletions crates/redisctl-mcp/src/state.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Application state and credential resolution

#[cfg(any(feature = "cloud", feature = "enterprise"))]
#[cfg(any(feature = "cloud", feature = "enterprise", feature = "database"))]
use std::collections::HashMap;
use std::sync::Arc;

Expand Down Expand Up @@ -30,12 +30,14 @@ pub enum CredentialSource {
},
}

/// Cached API clients (per-profile for multi-cluster support)
/// Cached API clients and connections (per-profile for multi-cluster support)
pub struct CachedClients {
#[cfg(feature = "cloud")]
pub cloud: HashMap<String, CloudClient>,
#[cfg(feature = "enterprise")]
pub enterprise: HashMap<String, EnterpriseClient>,
#[cfg(feature = "database")]
pub database: HashMap<String, redis::aio::MultiplexedConnection>,
}

/// Shared application state
Expand Down Expand Up @@ -85,6 +87,8 @@ impl AppState {
cloud: HashMap::new(),
#[cfg(feature = "enterprise")]
enterprise: HashMap::new(),
#[cfg(feature = "database")]
database: HashMap::new(),
}),
})
}
Expand Down Expand Up @@ -340,23 +344,46 @@ impl AppState {
Ok(format!("{}://{}{}:{}{}", scheme, auth, host, port, db_path))
}

/// Get Redis connection for direct database operations
/// Get or create a cached Redis connection for a resolved URL.
///
/// Connections are cached by URL. If a cached connection fails a PING
/// health check, it is evicted and a fresh connection is created.
#[cfg(feature = "database")]
#[allow(dead_code)]
pub async fn redis_connection(&self) -> Result<redis::aio::MultiplexedConnection> {
let url = self
.database_url
.as_ref()
.cloned()
.or_else(|| std::env::var("REDIS_URL").ok())
.context("No Redis URL configured")?;

let client = redis::Client::open(url.as_str()).context("Failed to create Redis client")?;
pub async fn redis_connection_for_url(
&self,
url: &str,
) -> Result<redis::aio::MultiplexedConnection> {
// Check cache first
{
let clients = self.clients.read().await;
if let Some(conn) = clients.database.get(url) {
// Quick health check -- if PING fails the connection is stale
let mut test_conn = conn.clone();
if redis::cmd("PING")
.query_async::<String>(&mut test_conn)
.await
.is_ok()
{
return Ok(conn.clone());
}
// Fall through to evict + reconnect
}
}

client
// Create new connection (or reconnect after eviction)
let client = redis::Client::open(url).context("Failed to create Redis client")?;
let conn = client
.get_multiplexed_async_connection()
.await
.context("Failed to connect to Redis")
.context("Failed to connect to Redis")?;

// Cache it
{
let mut clients = self.clients.write().await;
clients.database.insert(url.to_string(), conn.clone());
}

Ok(conn)
}

/// Check if write operations are allowed by the global policy tier.
Expand Down Expand Up @@ -395,6 +422,8 @@ impl Clone for AppState {
cloud: HashMap::new(),
#[cfg(feature = "enterprise")]
enterprise: HashMap::new(),
#[cfg(feature = "database")]
database: HashMap::new(),
}),
}
}
Expand Down Expand Up @@ -427,6 +456,8 @@ impl AppState {
cloud,
#[cfg(feature = "enterprise")]
enterprise: HashMap::new(),
#[cfg(feature = "database")]
database: HashMap::new(),
}),
}
}
Expand All @@ -446,6 +477,8 @@ impl AppState {
#[cfg(feature = "cloud")]
cloud: HashMap::new(),
enterprise,
#[cfg(feature = "database")]
database: HashMap::new(),
}),
}
}
Expand All @@ -466,6 +499,8 @@ impl AppState {
clients: RwLock::new(CachedClients {
cloud: cloud_map,
enterprise: enterprise_map,
#[cfg(feature = "database")]
database: HashMap::new(),
}),
}
}
Expand Down
41 changes: 8 additions & 33 deletions crates/redisctl-mcp/src/tools/redis/diagnostics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,8 @@ pub fn health_check(state: Arc<AppState>) -> Tool {
.extractor_handler_typed::<_, _, _, HealthCheckInput>(
state,
|State(state): State<Arc<AppState>>, Json(input): Json<HealthCheckInput>| async move {
let url = super::resolve_redis_url(input.url, input.profile.as_deref(), &state)?;

let client = redis::Client::open(url.as_str()).tool_context("Invalid URL")?;

let mut conn = client
.get_multiplexed_async_connection()
.await
.tool_context("Connection failed")?;
let mut conn =
super::get_connection(input.url, input.profile.as_deref(), &state).await?;

// PING
let ping_response: String = redis::cmd("PING")
Expand Down Expand Up @@ -248,14 +242,8 @@ pub fn key_summary(state: Arc<AppState>) -> Tool {
.extractor_handler_typed::<_, _, _, KeySummaryInput>(
state,
|State(state): State<Arc<AppState>>, Json(input): Json<KeySummaryInput>| async move {
let url = super::resolve_redis_url(input.url, input.profile.as_deref(), &state)?;

let client = redis::Client::open(url.as_str()).tool_context("Invalid URL")?;

let mut conn = client
.get_multiplexed_async_connection()
.await
.tool_context("Connection failed")?;
let mut conn =
super::get_connection(input.url, input.profile.as_deref(), &state).await?;

// TYPE
let key_type: String = redis::cmd("TYPE")
Expand Down Expand Up @@ -368,14 +356,8 @@ pub fn hotkeys(state: Arc<AppState>) -> Tool {
.extractor_handler_typed::<_, _, _, HotkeysInput>(
state,
|State(state): State<Arc<AppState>>, Json(input): Json<HotkeysInput>| async move {
let url = super::resolve_redis_url(input.url, input.profile.as_deref(), &state)?;

let client = redis::Client::open(url.as_str()).tool_context("Invalid URL")?;

let mut conn = client
.get_multiplexed_async_connection()
.await
.tool_context("Connection failed")?;
let mut conn =
super::get_connection(input.url, input.profile.as_deref(), &state).await?;

let pattern = input.pattern.as_deref().unwrap_or("*");
let sample_size = input.sample_size.unwrap_or(1000).min(MAX_SAMPLE_SIZE);
Expand Down Expand Up @@ -511,15 +493,8 @@ pub fn connection_summary(state: Arc<AppState>) -> Tool {
state,
|State(state): State<Arc<AppState>>,
Json(input): Json<ConnectionSummaryInput>| async move {
let url = super::resolve_redis_url(input.url, input.profile.as_deref(), &state)?;

let client = redis::Client::open(url.as_str())
.tool_context("Invalid URL")?;

let mut conn = client
.get_multiplexed_async_connection()
.await
.tool_context("Connection failed")?;
let mut conn =
super::get_connection(input.url, input.profile.as_deref(), &state).await?;

// CLIENT LIST
let client_list_raw: String = redis::cmd("CLIENT")
Expand Down
Loading
Loading