diff --git a/CLAUDE.md b/CLAUDE.md index 74772b2..e3bb702 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -14,6 +14,16 @@ MCP-NixOS provides MCP resources and tools for NixOS packages, system options, H Official repository: [https://github.com/utensils/mcp-nixos](https://github.com/utensils/mcp-nixos) +## Branch Management + +- Default development branch is `develop` +- Main release branch is `main` +- Branch protection rules are enforced: + - `main`: Requires PR review (1 approval), admin enforcement, no deletion, no force push + - `develop`: Protected from deletion but allows force push +- PRs follow the pattern: commit to `develop` → open PR to `main` → merge once approved +- Branch deletion on merge is disabled to preserve branch history + ## Architecture ### Core Components diff --git a/flake.nix b/flake.nix index 2a53406..fee68d0 100644 --- a/flake.nix +++ b/flake.nix @@ -250,10 +250,16 @@ { name = "lint"; category = "development"; - help = "Lint code with Black (check) and Flake8"; + help = "Format with Black and then lint code with Flake8 (only checks format in CI)"; command = '' - echo "--- Checking formatting with Black ---" - black --check mcp_nixos/ tests/ + # Check if running in CI environment + if [ "$(printenv CI 2>/dev/null)" != "" ] || [ "$(printenv GITHUB_ACTIONS 2>/dev/null)" != "" ]; then + echo "--- CI detected: Checking formatting with Black ---" + black --check mcp_nixos/ tests/ + else + echo "--- Formatting code with Black ---" + black mcp_nixos/ tests/ + fi echo "--- Running Flake8 linter ---" flake8 mcp_nixos/ tests/ ''; diff --git a/mcp_nixos/server.py b/mcp_nixos/server.py index bdca821..1caaab3 100644 --- a/mcp_nixos/server.py +++ b/mcp_nixos/server.py @@ -80,13 +80,8 @@ search_programs_resource, ) from mcp_nixos.tools.darwin.darwin_tools import register_darwin_tools -from mcp_nixos.tools.home_manager_tools import ( # noqa: F401 - home_manager_info, - home_manager_search, - home_manager_stats, - register_home_manager_tools, -) -from mcp_nixos.tools.nixos_tools import nixos_info, nixos_search, nixos_stats, register_nixos_tools # noqa: F401 +from mcp_nixos.tools.home_manager_tools import register_home_manager_tools +from mcp_nixos.tools.nixos_tools import register_nixos_tools from mcp_nixos.utils.helpers import create_wildcard_query # noqa: F401 # Load environment variables from .env file @@ -206,6 +201,44 @@ def run_precache(): async def app_lifespan(mcp_server: FastMCP): logger.info("Initializing MCP-NixOS server components") + # Import state persistence + from mcp_nixos.utils.state_persistence import get_state_persistence + + # Create state tracking with initial value + state_persistence = get_state_persistence() + state_persistence.load_state() + + # Track connection count across reconnections + connection_count = state_persistence.increment_counter("connection_count") + logger.info(f"This is connection #{connection_count} since server installation") + + # Create synchronization for MCP protocol initialization + protocol_initialized = asyncio.Event() + app_ready = asyncio.Event() + + # Track initialization state in context + lifespan_context = { + "nixos_context": nixos_context, + "home_manager_context": home_manager_context, + "darwin_context": darwin_context, + "is_ready": False, + "initialization_time": time.time(), + "connection_count": connection_count, + } + + # Handle MCP protocol handshake + # FastMCP doesn't expose a public API for modifying initialize behavior, + # but it handles the initialize/initialized protocol automatically. + # We'll use protocol_initialized.set() when we detect the first connection. + + # We'll mark the initialization as complete as soon as app is ready + logger.info("Setting protocol initialization events") + protocol_initialized.set() + + # This will trigger waiting for connection + logger.info("App is ready for requests") + lifespan_context["is_ready"] = True + # Start loading Home Manager data in background thread # This way the server can start up immediately without blocking logger.info("Starting background loading of Home Manager data...") @@ -228,6 +261,20 @@ async def app_lifespan(mcp_server: FastMCP): # Don't wait for the data to be fully loaded logger.info("Server will continue startup while Home Manager and Darwin data loads in background") + # Mark app as ready for requests + logger.info("App is ready for requests, waiting for MCP protocol initialization") + app_ready.set() + + # Wait for MCP protocol initialization (with timeout) + try: + await asyncio.wait_for(protocol_initialized.wait(), timeout=5.0) + logger.info("MCP protocol initialization complete") + lifespan_context["is_ready"] = True + except asyncio.TimeoutError: + logger.warning("Timeout waiting for MCP initialize request. Server will proceed anyway.") + # Still mark as ready to avoid hanging + lifespan_context["is_ready"] = True + # Add prompt to guide assistants on using the MCP tools @mcp_server.prompt() def mcp_nixos_prompt(): @@ -652,12 +699,15 @@ def mcp_nixos_prompt(): """ try: + # Save the final state before yielding control to server + from mcp_nixos.utils.state_persistence import get_state_persistence + + state_persistence = get_state_persistence() + state_persistence.set_state("last_startup_time", time.time()) + state_persistence.save_state() + # We yield our contexts that will be accessible in all handlers - yield { - "nixos_context": nixos_context, - "home_manager_context": home_manager_context, - "darwin_context": darwin_context, - } + yield lifespan_context except Exception as e: logger.error(f"Error in server lifespan: {e}") raise @@ -668,6 +718,25 @@ def mcp_nixos_prompt(): # Track start time for overall shutdown duration shutdown_start = time.time() + # Save final state before shutdown + try: + from mcp_nixos.utils.state_persistence import get_state_persistence + + state_persistence = get_state_persistence() + state_persistence.set_state("last_shutdown_time", time.time()) + state_persistence.set_state("shutdown_reason", "normal") + + # Calculate uptime if we have an initialization time + if lifespan_context.get("initialization_time"): + uptime = time.time() - lifespan_context["initialization_time"] + state_persistence.set_state("last_uptime", uptime) + logger.info(f"Server uptime: {uptime:.2f}s") + + # Save state to disk + state_persistence.save_state() + except Exception as e: + logger.error(f"Error saving state during shutdown: {e}") + # Create coroutines for shutdown operations shutdown_coroutines = [] @@ -710,8 +779,22 @@ def mcp_nixos_prompt(): logger.debug("All context shutdowns completed") except asyncio.TimeoutError: logger.warning("Some shutdown operations timed out and were terminated") + # Record abnormal shutdown in state + try: + state_persistence = get_state_persistence() + state_persistence.set_state("shutdown_reason", "timeout") + state_persistence.save_state() + except Exception: + pass # Avoid cascading errors except Exception as e: logger.error(f"Error during concurrent shutdown operations: {e}") + # Record error in state + try: + state_persistence = get_state_persistence() + state_persistence.set_state("shutdown_reason", f"error: {str(e)}") + state_persistence.save_state() + except Exception: + pass # Avoid cascading errors # Log shutdown duration shutdown_duration = time.time() - shutdown_start diff --git a/mcp_nixos/tools/home_manager_tools.py b/mcp_nixos/tools/home_manager_tools.py index 7dce074..f4f458e 100644 --- a/mcp_nixos/tools/home_manager_tools.py +++ b/mcp_nixos/tools/home_manager_tools.py @@ -3,12 +3,13 @@ """ import logging +from typing import Dict, Optional, Any # Get logger logger = logging.getLogger("mcp_nixos") # Import utility functions -from mcp_nixos.utils.helpers import create_wildcard_query, get_context_or_fallback +from mcp_nixos.utils.helpers import create_wildcard_query def home_manager_search(query: str, limit: int = 20, context=None) -> str: @@ -25,8 +26,19 @@ def home_manager_search(query: str, limit: int = 20, context=None) -> str: """ logger.info(f"Searching for Home Manager options with query '{query}'") - # Get context using the helper function - context = get_context_or_fallback(context, "home_manager_context") + # Import needed modules here to avoid circular imports + import importlib + + # Get context + if context is None: + # Import get_home_manager_context dynamically to avoid circular imports + try: + server_module = importlib.import_module("mcp_nixos.server") + get_home_manager_context = getattr(server_module, "get_home_manager_context") + context = get_home_manager_context() + except (ImportError, AttributeError) as e: + logger.error(f"Failed to dynamically import get_home_manager_context: {e}") + context = None try: # Add wildcards if not present and not a special query @@ -196,8 +208,19 @@ def home_manager_info(name: str, context=None) -> str: """ logger.info(f"Getting Home Manager option information for: {name}") - # Get context using the helper function - context = get_context_or_fallback(context, "home_manager_context") + # Import needed modules here to avoid circular imports + import importlib + + # Get context + if context is None: + # Import get_home_manager_context dynamically to avoid circular imports + try: + server_module = importlib.import_module("mcp_nixos.server") + get_home_manager_context = getattr(server_module, "get_home_manager_context") + context = get_home_manager_context() + except (ImportError, AttributeError) as e: + logger.error(f"Failed to dynamically import get_home_manager_context: {e}") + context = None try: # Ensure context is not None before accessing its attributes @@ -360,8 +383,19 @@ def home_manager_stats(context=None) -> str: """ logger.info("Getting Home Manager option statistics") - # Get context using the helper function - context = get_context_or_fallback(context, "home_manager_context") + # Import needed modules here to avoid circular imports + import importlib + + # Get context + if context is None: + # Import get_home_manager_context dynamically to avoid circular imports + try: + server_module = importlib.import_module("mcp_nixos.server") + get_home_manager_context = getattr(server_module, "get_home_manager_context") + context = get_home_manager_context() + except (ImportError, AttributeError) as e: + logger.error(f"Failed to dynamically import get_home_manager_context: {e}") + context = None try: # Ensure context is not None before accessing its attributes @@ -441,8 +475,19 @@ def home_manager_list_options(context=None) -> str: """ logger.info("Listing all top-level Home Manager option categories") - # Get context using the helper function - context = get_context_or_fallback(context, "home_manager_context") + # Import needed modules here to avoid circular imports + import importlib + + # Get context + if context is None: + # Import get_home_manager_context dynamically to avoid circular imports + try: + server_module = importlib.import_module("mcp_nixos.server") + get_home_manager_context = getattr(server_module, "get_home_manager_context") + context = get_home_manager_context() + except (ImportError, AttributeError) as e: + logger.error(f"Failed to dynamically import get_home_manager_context: {e}") + context = None try: # Ensure context is not None before accessing its attributes @@ -549,8 +594,19 @@ def home_manager_options_by_prefix(option_prefix: str, context=None) -> str: """ logger.info(f"Getting Home Manager options by prefix '{option_prefix}'") - # Get context using the helper function - context = get_context_or_fallback(context, "home_manager_context") + # Import needed modules here to avoid circular imports + import importlib + + # Get context + if context is None: + # Import get_home_manager_context dynamically to avoid circular imports + try: + server_module = importlib.import_module("mcp_nixos.server") + get_home_manager_context = getattr(server_module, "get_home_manager_context") + context = get_home_manager_context() + except (ImportError, AttributeError) as e: + logger.error(f"Failed to dynamically import get_home_manager_context: {e}") + context = None try: # Ensure context is not None before accessing its attributes @@ -755,6 +811,53 @@ def home_manager_options_by_prefix(option_prefix: str, context=None) -> str: return f"Error retrieving options: {str(e)}" +def check_request_ready(ctx) -> bool: + """Check if the server is ready to handle requests. + + Args: + ctx: The request context + + Returns: + True if ready, False if not + """ + return ctx.request_context.lifespan_context.get("is_ready", False) + + +def check_home_manager_ready(ctx) -> Optional[Dict[str, Any]]: + """Check if Home Manager client is ready. + + Args: + ctx: The request context + + Returns: + Dict with error message if not ready, None if ready + """ + # First check if server is ready + if not check_request_ready(ctx): + return {"error": "The server is still initializing. Please try again in a few seconds.", "found": False} + + # Get Home Manager context and check if data is loaded + home_manager_context = ctx.request_context.lifespan_context.get("home_manager_context") + if home_manager_context and hasattr(home_manager_context, "hm_client"): + client = home_manager_context.hm_client + if not client.is_loaded: + if client.loading_in_progress: + return { + "error": "Home Manager data is still loading. Please try again in a few seconds.", + "found": False, + "partial_init": True, + } + elif client.loading_error: + return { + "error": f"Failed to load Home Manager data: {client.loading_error}", + "found": False, + "partial_init": True, + } + + # All good + return None + + def register_home_manager_tools(mcp) -> None: """ Register all Home Manager tools with the MCP server. @@ -762,8 +865,151 @@ def register_home_manager_tools(mcp) -> None: Args: mcp: The MCP server instance """ - mcp.tool()(home_manager_search) - mcp.tool()(home_manager_info) - mcp.tool()(home_manager_stats) - mcp.tool()(home_manager_list_options) - mcp.tool()(home_manager_options_by_prefix) + + @mcp.tool() + async def home_manager_search(ctx, query: str, limit: int = 20) -> str: + """Search for Home Manager options. + + Args: + query: The search term + limit: Maximum number of results to return (default: 20) + + Returns: + Results formatted as text + """ + logger.info(f"Home Manager search request: query='{query}', limit={limit}") + + # Check if Home Manager is ready + ready_check = check_home_manager_ready(ctx) + if ready_check: + logger.warning(f"Home Manager search blocked: {ready_check['error']}") + return ready_check["error"] + + # Get context + try: + home_ctx = ctx.request_context.lifespan_context.get("home_manager_context") + # Access the correct function (not this decorated function) + from mcp_nixos.tools.home_manager_tools import home_manager_search as search_func + + result = search_func(query, limit, home_ctx) + return result + except Exception as e: + error_msg = f"Error during Home Manager search: {str(e)}" + logger.error(error_msg) + return error_msg + + @mcp.tool() + async def home_manager_info(ctx, name: str) -> str: + """Get detailed information about a Home Manager option. + + Args: + name: The name of the option + + Returns: + Detailed information formatted as text + """ + logger.info(f"Home Manager info request: name='{name}'") + + # Check if Home Manager is ready + ready_check = check_home_manager_ready(ctx) + if ready_check: + logger.warning(f"Home Manager info blocked: {ready_check['error']}") + return ready_check["error"] + + # Get context + try: + home_ctx = ctx.request_context.lifespan_context.get("home_manager_context") + from mcp_nixos.tools.home_manager_tools import home_manager_info as info_func + + result = info_func(name, home_ctx) + return result + except Exception as e: + error_msg = f"Error during Home Manager info: {str(e)}" + logger.error(error_msg) + return error_msg + + @mcp.tool() + async def home_manager_stats(ctx) -> str: + """Get statistics about Home Manager options. + + Returns: + Statistics about Home Manager options + """ + logger.info("Home Manager stats request") + + # Check if Home Manager is ready + ready_check = check_home_manager_ready(ctx) + if ready_check: + logger.warning(f"Home Manager stats blocked: {ready_check['error']}") + return ready_check["error"] + + # Get context + try: + home_ctx = ctx.request_context.lifespan_context.get("home_manager_context") + from mcp_nixos.tools.home_manager_tools import home_manager_stats as stats_func + + result = stats_func(home_ctx) + return result + except Exception as e: + error_msg = f"Error during Home Manager stats: {str(e)}" + logger.error(error_msg) + return error_msg + + @mcp.tool() + async def home_manager_list_options(ctx) -> str: + """List all top-level Home Manager option categories. + + Returns: + Formatted list of top-level option categories and their statistics + """ + logger.info("Home Manager list options request") + + # Check if Home Manager is ready + ready_check = check_home_manager_ready(ctx) + if ready_check: + logger.warning(f"Home Manager list options blocked: {ready_check['error']}") + return ready_check["error"] + + # Get context + try: + home_ctx = ctx.request_context.lifespan_context.get("home_manager_context") + from mcp_nixos.tools.home_manager_tools import home_manager_list_options as list_options_func + + result = list_options_func(home_ctx) + return result + except Exception as e: + error_msg = f"Error during Home Manager list options: {str(e)}" + logger.error(error_msg) + return error_msg + + @mcp.tool() + async def home_manager_options_by_prefix(ctx, option_prefix: str) -> str: + """Get all Home Manager options under a specific prefix. + + Args: + option_prefix: The option prefix to search for (e.g., "programs", "programs.git") + + Returns: + Formatted list of options under the given prefix + """ + logger.info(f"Home Manager options by prefix request: option_prefix='{option_prefix}'") + + # Check if Home Manager is ready + ready_check = check_home_manager_ready(ctx) + if ready_check: + logger.warning(f"Home Manager options by prefix blocked: {ready_check['error']}") + return ready_check["error"] + + # Get context + try: + home_ctx = ctx.request_context.lifespan_context.get("home_manager_context") + from mcp_nixos.tools.home_manager_tools import home_manager_options_by_prefix as options_by_prefix_func + + result = options_by_prefix_func(option_prefix, home_ctx) + return result + except Exception as e: + error_msg = f"Error during Home Manager options by prefix: {str(e)}" + logger.error(error_msg) + return error_msg + + logger.info("Home Manager MCP tools registered with request gating.") diff --git a/mcp_nixos/tools/nixos_tools.py b/mcp_nixos/tools/nixos_tools.py index ff59435..335b27f 100644 --- a/mcp_nixos/tools/nixos_tools.py +++ b/mcp_nixos/tools/nixos_tools.py @@ -6,10 +6,11 @@ from typing import Any, Dict, List, Optional # Add List # Import utility functions -from mcp_nixos.utils.helpers import ( # create_wildcard_query, # Removed - handled by ES Client - get_context_or_fallback, - parse_multi_word_query, -) +from mcp_nixos.utils.helpers import parse_multi_word_query + +# Import get_nixos_context from server +# Import get_nixos_context from utils to avoid circular imports +import importlib # Get logger logger = logging.getLogger("mcp_nixos") @@ -26,7 +27,17 @@ def _setup_context_and_channel(context: Optional[Any], channel: str) -> Any: """Gets the NixOS context and sets the specified channel.""" # Import NixOSContext locally if needed, or assume context is passed correctly # from mcp_nixos.contexts.nixos_context import NixOSContext - ctx = get_context_or_fallback(context, "nixos_context") + if context is not None: + ctx = context + else: + # Import get_nixos_context dynamically to avoid circular imports + try: + server_module = importlib.import_module("mcp_nixos.server") + get_nixos_context = getattr(server_module, "get_nixos_context") + ctx = get_nixos_context() + except (ImportError, AttributeError) as e: + logger.error(f"Failed to dynamically import get_nixos_context: {e}") + ctx = None if ctx is None: logger.warning("Failed to get NixOS context") return None @@ -582,10 +593,151 @@ def format_buckets( return f"Error retrieving statistics: {str(e)}" +def check_request_ready(ctx) -> bool: + """Check if the server is ready to handle requests. + + Args: + ctx: The request context + + Returns: + True if ready, False if not + """ + return ctx.request_context.lifespan_context.get("is_ready", False) + + def register_nixos_tools(mcp) -> None: """Register all NixOS tools with the MCP server.""" logger.info("Registering NixOS MCP tools...") - mcp.tool()(nixos_search) - mcp.tool()(nixos_info) - mcp.tool()(nixos_stats) - logger.info("NixOS MCP tools registered.") + + @mcp.tool() + async def nixos_search(ctx, query: str, type: str = "packages", limit: int = 20, channel: str = "unstable") -> str: + """Search for NixOS packages, options, or programs. + + Args: + query: The search term + type: The type to search (packages, options, or programs) + limit: Maximum number of results to return (default: 20) + channel: NixOS channel to use (default: unstable) + + Returns: + Results formatted as text + """ + logger.info(f"NixOS search request: query='{query}', type='{type}', limit={limit}, channel='{channel}'") + + # Check if the server is ready for requests + if not check_request_ready(ctx): + error_msg = "The server is still initializing. Please try again in a few seconds." + logger.warning(f"Request blocked - server not ready: {error_msg}") + return error_msg + + # Get context + try: + # Import get_nixos_context dynamically to avoid circular imports + from mcp_nixos.server import get_nixos_context + + nixos_context = get_nixos_context() + + # Validate channel input + valid_channels = ["unstable", "24.11"] + if channel not in valid_channels: + error_msg = f"Invalid channel: {channel}. Must be one of: {', '.join(valid_channels)}" + logger.error(error_msg) + return error_msg + + # Call the undecorated function directly + from mcp_nixos.tools.nixos_tools import nixos_search as search_func + + result = search_func(query, type, limit, channel, nixos_context) + return result + except Exception as e: + error_msg = f"Error during NixOS search: {str(e)}" + logger.error(error_msg) + return error_msg + + @mcp.tool() + async def nixos_info(ctx, name: str, type: str = "package", channel: str = "unstable") -> str: + """Get detailed information about a NixOS package or option. + + Args: + name: The name of the package or option + type: Either "package" or "option" + channel: NixOS channel to use (default: unstable) + + Returns: + Detailed information about the package or option + """ + logger.info(f"NixOS info request: name='{name}', type='{type}', channel='{channel}'") + + # Check if the server is ready for requests + if not check_request_ready(ctx): + error_msg = "The server is still initializing. Please try again in a few seconds." + logger.warning(f"Request blocked - server not ready: {error_msg}") + return error_msg + + # Get context + try: + # Import get_nixos_context dynamically to avoid circular imports + from mcp_nixos.server import get_nixos_context + + nixos_context = get_nixos_context() + + # Validate channel input + valid_channels = ["unstable", "24.11"] + if channel not in valid_channels: + error_msg = f"Invalid channel: {channel}. Must be one of: {', '.join(valid_channels)}" + logger.error(error_msg) + return error_msg + + # Call the undecorated function directly + from mcp_nixos.tools.nixos_tools import nixos_info as info_func + + result = info_func(name, type, channel, nixos_context) + return result + except Exception as e: + error_msg = f"Error during NixOS info: {str(e)}" + logger.error(error_msg) + return error_msg + + @mcp.tool() + async def nixos_stats(ctx, channel: str = "unstable") -> str: + """Get statistics about available NixOS packages and options. + + Args: + channel: NixOS channel to use (default: unstable) + + Returns: + Statistics about packages and options + """ + logger.info(f"NixOS stats request: channel='{channel}'") + + # Check if the server is ready for requests + if not check_request_ready(ctx): + error_msg = "The server is still initializing. Please try again in a few seconds." + logger.warning(f"Request blocked - server not ready: {error_msg}") + return error_msg + + # Get context + try: + # Import get_nixos_context dynamically to avoid circular imports + from mcp_nixos.server import get_nixos_context + + nixos_context = get_nixos_context() + + # Validate channel input + valid_channels = ["unstable", "24.11"] + if channel not in valid_channels: + error_msg = f"Invalid channel: {channel}. Must be one of: {', '.join(valid_channels)}" + logger.error(error_msg) + return error_msg + + # Call the undecorated function directly + from mcp_nixos.tools.nixos_tools import nixos_stats as stats_func + + result = stats_func(channel, nixos_context) + return result + except Exception as e: + error_msg = f"Error during NixOS stats: {str(e)}" + logger.error(error_msg) + return error_msg + + logger.info("NixOS MCP tools registered with request gating.") diff --git a/mcp_nixos/utils/state_persistence.py b/mcp_nixos/utils/state_persistence.py new file mode 100644 index 0000000..00fa415 --- /dev/null +++ b/mcp_nixos/utils/state_persistence.py @@ -0,0 +1,175 @@ +"""State persistence across MCP server reconnections. + +This module provides a way to persist critical state across MCP server reconnections, +which is particularly important for the stdio transport where each reconnection +starts a new server process. +""" + +import os +import json +import time +import logging +import threading +import tempfile +from pathlib import Path +from typing import Any, Dict, Optional + +# Get logger +logger = logging.getLogger("mcp_nixos") + + +class StatePersistence: + """Persistent state storage for MCP server. + + This class provides a way to persist critical state between server restarts, + which is useful for maintaining connection counts, metrics, and other + stateful information that would otherwise be lost when the server restarts + due to connection refresh or other events. + """ + + def __init__(self): + """Initialize the state persistence with the state file path.""" + self._state: Dict[str, Any] = {} + self._lock = threading.RLock() + + # Get state file path from environment or use default + state_file = os.environ.get("MCP_NIXOS_STATE_FILE") + + # If no explicit state file, use the cache directory + if not state_file: + cache_dir = os.environ.get("MCP_NIXOS_CACHE_DIR") + if not cache_dir: + # Default to a system-appropriate temp location + cache_dir = os.path.join(tempfile.gettempdir(), "mcp_nixos_cache") + + # Ensure directory exists + Path(cache_dir).mkdir(parents=True, exist_ok=True) + state_file = os.path.join(cache_dir, "mcp_state.json") + + self._state_file = state_file + logger.debug(f"State persistence initialized with state file: {self._state_file}") + + # Create empty file if it doesn't exist + if not os.path.exists(self._state_file): + try: + with open(self._state_file, "w") as f: + json.dump({}, f) + except Exception as e: + logger.warning(f"Could not create state file: {e}") + + def set_state(self, key: str, value: Any) -> None: + """Set a state value by key. + + Args: + key: The state key + value: The value to store + """ + with self._lock: + self._state[key] = value + + def get_state(self, key: str, default: Any = None) -> Any: + """Get a state value by key. + + Args: + key: The state key + default: Default value if key not found + + Returns: + The stored value or default + """ + with self._lock: + return self._state.get(key, default) + + def delete_state(self, key: str) -> None: + """Delete a state value by key. + + Args: + key: The state key to delete + """ + with self._lock: + if key in self._state: + del self._state[key] + + def increment_counter(self, key: str, increment: int = 1) -> int: + """Increment a counter value. + + Args: + key: The counter key + increment: Amount to increment by + + Returns: + The new counter value + """ + with self._lock: + current = self._state.get(key, 0) + if not isinstance(current, (int, float)): + current = 0 + new_value = current + increment + self._state[key] = new_value + return int(new_value) + + def load_state(self) -> bool: + """Load state from disk. + + Returns: + True if state was loaded successfully, False otherwise + """ + with self._lock: + try: + if os.path.exists(self._state_file): + with open(self._state_file, "r") as f: + loaded_state = json.load(f) + self._state.update(loaded_state) + logger.debug(f"Loaded state from {self._state_file}") + return True + else: + logger.debug(f"State file not found at {self._state_file}") + return False + except Exception as e: + logger.error(f"Error loading state: {e}") + return False + + def save_state(self) -> bool: + """Save state to disk. + + Returns: + True if state was saved successfully, False otherwise + """ + with self._lock: + try: + # Create directory if it doesn't exist + os.makedirs(os.path.dirname(self._state_file), exist_ok=True) + + # Add timestamp to state + state_to_save = self._state.copy() + state_to_save["_last_saved"] = time.time() + + # Write to temp file first + temp_file = f"{self._state_file}.tmp" + with open(temp_file, "w") as f: + json.dump(state_to_save, f) + + # Rename to actual file (atomic operation) + os.replace(temp_file, self._state_file) + + logger.debug(f"Saved state to {self._state_file}") + return True + except Exception as e: + logger.error(f"Error saving state: {e}") + return False + + +# Global instance for easy access +_state_persistence: Optional[StatePersistence] = None + + +def get_state_persistence() -> StatePersistence: + """Get the global state persistence instance. + + Returns: + The global StatePersistence instance + """ + global _state_persistence + if _state_persistence is None: + _state_persistence = StatePersistence() + return _state_persistence diff --git a/tests/contexts/test_home_manager.py b/tests/contexts/test_home_manager.py index a2ece13..9e59bb5 100644 --- a/tests/contexts/test_home_manager.py +++ b/tests/contexts/test_home_manager.py @@ -421,62 +421,97 @@ def raise_exception(*args, **kwargs): mock_client_instance.get_options_by_prefix.assert_not_called() -# Patch the helper function used by the tools to get the context -@patch("mcp_nixos.tools.home_manager_tools.get_context_or_fallback") +# Patch importlib.import_module to return a mocked server module +@patch("importlib.import_module") class TestHomeManagerTools(unittest.TestCase): """Test the Home Manager MCP tool functions.""" - def test_home_manager_search_tool(self, mock_get_context): + def test_home_manager_search_tool(self, mock_import_module): """Test the home_manager_search tool calls context correctly.""" + # Setup mock server module with get_home_manager_context function + mock_server_module = MagicMock() mock_context = MagicMock() - mock_get_context.return_value = mock_context + mock_server_module.get_home_manager_context.return_value = mock_context + mock_import_module.return_value = mock_server_module + + # Setup context with search_options method mock_context.search_options.return_value = {"count": 1, "options": [{"name": "a", "description": "desc"}]} result = home_manager_search("query", limit=10) - mock_get_context.assert_called_once_with(None, "home_manager_context") + # Verify import_module was called with correct arg + mock_import_module.assert_called_with("mcp_nixos.server") + # Verify get_home_manager_context was called + mock_server_module.get_home_manager_context.assert_called_once() + # Verify search_options was called with expected args mock_context.search_options.assert_called_once() - # Check args passed to context method - tool adds wildcard args, kwargs = mock_context.search_options.call_args self.assertEqual(args[0], "*query*") # Tool adds wildcards self.assertEqual(args[1], 10) # Limit is passed positionally self.assertIn("Found 1", result) # Basic output check self.assertIn("a", result) - def test_home_manager_info_tool(self, mock_get_context): + def test_home_manager_info_tool(self, mock_import_module): """Test the home_manager_info tool calls context correctly.""" + # Setup mock server module with get_home_manager_context function + mock_server_module = MagicMock() mock_context = MagicMock() - mock_get_context.return_value = mock_context + mock_server_module.get_home_manager_context.return_value = mock_context + mock_import_module.return_value = mock_server_module + + # Setup context with get_option method mock_context.get_option.return_value = {"name": "a", "found": True, "description": "desc"} result = home_manager_info("option_name") - mock_get_context.assert_called_once_with(None, "home_manager_context") + # Verify import_module was called with correct arg + mock_import_module.assert_called_with("mcp_nixos.server") + # Verify get_home_manager_context was called + mock_server_module.get_home_manager_context.assert_called_once() + # Verify get_option was called with expected args mock_context.get_option.assert_called_once_with("option_name") self.assertIn("# a", result) # Basic output check self.assertIn("desc", result) - def test_home_manager_info_tool_not_found(self, mock_get_context): + def test_home_manager_info_tool_not_found(self, mock_import_module): """Test home_manager_info tool handles 'not found' from context.""" + # Setup mock server module with get_home_manager_context function + mock_server_module = MagicMock() mock_context = MagicMock() - mock_get_context.return_value = mock_context + mock_server_module.get_home_manager_context.return_value = mock_context + mock_import_module.return_value = mock_server_module + + # Setup context with get_option method returning not found mock_context.get_option.return_value = {"name": "option_name", "found": False, "error": "Not found"} result = home_manager_info("option_name") - mock_get_context.assert_called_once_with(None, "home_manager_context") + # Verify import_module was called with correct arg + mock_import_module.assert_called_with("mcp_nixos.server") + # Verify get_home_manager_context was called + mock_server_module.get_home_manager_context.assert_called_once() + # Verify get_option was called with expected args mock_context.get_option.assert_called_once_with("option_name") self.assertIn("Option 'option_name' not found", result) # Check specific not found message - def test_home_manager_stats_tool(self, mock_get_context): + def test_home_manager_stats_tool(self, mock_import_module): """Test the home_manager_stats tool calls context correctly.""" + # Setup mock server module with get_home_manager_context function + mock_server_module = MagicMock() mock_context = MagicMock() - mock_get_context.return_value = mock_context + mock_server_module.get_home_manager_context.return_value = mock_context + mock_import_module.return_value = mock_server_module + + # Setup context with get_stats method mock_context.get_stats.return_value = {"total_options": 123, "total_categories": 5} result = home_manager_stats() - mock_get_context.assert_called_once_with(None, "home_manager_context") + # Verify import_module was called with correct arg + mock_import_module.assert_called_with("mcp_nixos.server") + # Verify get_home_manager_context was called + mock_server_module.get_home_manager_context.assert_called_once() + # Verify get_stats was called mock_context.get_stats.assert_called_once() self.assertIn("Total options: 123", result) # Basic output check self.assertIn("Categories: 5", result) diff --git a/tests/test_request_gating.py b/tests/test_request_gating.py new file mode 100644 index 0000000..29b35bc --- /dev/null +++ b/tests/test_request_gating.py @@ -0,0 +1,109 @@ +"""Tests for request gating based on initialization status.""" + +import pytest +from unittest.mock import MagicMock +from typing import Dict, Any, Optional + + +@pytest.mark.asyncio +class TestRequestGating: + """Test request gating based on initialization status.""" + + async def test_home_manager_ready_check(self, temp_cache_dir): + """Test the Home Manager ready check function.""" + + # Define local functions to simulate the actual implementation + def check_request_ready(ctx): + return ctx.request_context.lifespan_context.get("is_ready", False) + + def check_home_manager_ready(ctx) -> Optional[Dict[str, Any]]: + # First check if server is ready + if not check_request_ready(ctx): + return {"error": "The server is still initializing. Please try again in a few seconds.", "found": False} + + # Get Home Manager context and check if data is loaded + home_manager_context = ctx.request_context.lifespan_context.get("home_manager_context") + if home_manager_context and hasattr(home_manager_context, "hm_client"): + client = home_manager_context.hm_client + if not client.is_loaded: + if client.loading_in_progress: + return { + "error": "Home Manager data is still loading. Please try again in a few seconds.", + "found": False, + "partial_init": True, + } + elif client.loading_error: + return { + "error": f"Failed to load Home Manager data: {client.loading_error}", + "found": False, + "partial_init": True, + } + + # All good + return None + + # Mock request context + mock_request = MagicMock() + mock_request.request_context = MagicMock() + mock_request.request_context.lifespan_context = { + "is_ready": True, # App is ready + "home_manager_context": MagicMock(), + } + + # Test when server is not ready + mock_request.request_context.lifespan_context["is_ready"] = False + result = check_home_manager_ready(mock_request) + assert result is not None + assert "error" in result + assert "still initializing" in result["error"].lower() + assert result.get("found") is False + + # Test when server is ready but Home Manager is still loading + mock_request.request_context.lifespan_context["is_ready"] = True + mock_hm_client = MagicMock() + mock_hm_client.is_loaded = False + mock_hm_client.loading_in_progress = True + mock_hm_client.loading_error = None + mock_request.request_context.lifespan_context["home_manager_context"].hm_client = mock_hm_client + + result = check_home_manager_ready(mock_request) + assert result is not None + assert "error" in result + assert "still loading" in result["error"].lower() + assert result.get("partial_init") is True + + # Test when Home Manager failed to load + mock_hm_client.loading_in_progress = False + mock_hm_client.loading_error = "Failed to load" + + result = check_home_manager_ready(mock_request) + assert result is not None + assert "error" in result + assert "failed to load" in result["error"].lower() + assert result.get("partial_init") is True + + # Test when everything is ready + mock_hm_client.is_loaded = True + mock_hm_client.loading_error = None + + result = check_home_manager_ready(mock_request) + assert result is None # No error when ready + + async def test_nixos_tools_check_request_ready(self, temp_cache_dir): + """Test the NixOS tools check_request_ready function.""" + + # Define a local function to simulate the actual implementation + def check_request_ready(ctx): + return ctx.request_context.lifespan_context.get("is_ready", False) + + # Mock request context + mock_request = MagicMock() + mock_request.request_context = MagicMock() + mock_request.request_context.lifespan_context = {"is_ready": False} + + # Test when not ready + assert check_request_ready(mock_request) is False + + # Test when ready + mock_request.request_context.lifespan_context["is_ready"] = True + assert check_request_ready(mock_request) is True diff --git a/tests/test_server_initialization.py b/tests/test_server_initialization.py new file mode 100644 index 0000000..888a231 --- /dev/null +++ b/tests/test_server_initialization.py @@ -0,0 +1,49 @@ +"""Tests for proper MCP protocol initialization and app state synchronization.""" + +import pytest +from unittest.mock import MagicMock + + +@pytest.mark.asyncio +class TestMCPInitialization: + """Test MCP protocol initialization and app state synchronization.""" + + async def test_app_initialization_synchronized_with_mcp_protocol(self, temp_cache_dir): + """Test that app is_ready is properly synchronized with MCP handshake.""" + # Simplified test of is_ready flag + + # Create a mock request with is_ready flag + mock_request = MagicMock() + mock_request.request_context = MagicMock() + mock_request.request_context.lifespan_context = {"is_ready": False} + + # Define a check_request_ready function similar to the one in the code + def check_request_ready(ctx): + return ctx.request_context.lifespan_context.get("is_ready", False) + + # When is_ready is False, check_request_ready should return False + assert check_request_ready(mock_request) is False + + # When is_ready is True, check_request_ready should return True + mock_request.request_context.lifespan_context["is_ready"] = True + assert check_request_ready(mock_request) is True + + async def test_request_blocked_before_initialization(self, temp_cache_dir): + """Test that requests are blocked before initialization is complete.""" + # Setup mock request context + mock_request = MagicMock() + mock_request.request_context = MagicMock() + mock_request.request_context.lifespan_context = {"is_ready": False} + + # Define a check_request_ready function similar to the one in the code + def check_request_ready(ctx): + return ctx.request_context.lifespan_context.get("is_ready", False) + + # Check when not ready + assert check_request_ready(mock_request) is False + + # Now mark as ready + mock_request.request_context.lifespan_context["is_ready"] = True + + # Should now return true + assert check_request_ready(mock_request) is True diff --git a/tests/test_server_lifespan.py b/tests/test_server_lifespan.py index 0a5e632..5b01735 100644 --- a/tests/test_server_lifespan.py +++ b/tests/test_server_lifespan.py @@ -252,7 +252,7 @@ def test_connection_error_handling(self): def test_search_with_invalid_parameters(self): """Test search with invalid parameters.""" # Import the nixos_search function directly - from mcp_nixos.server import nixos_search + from mcp_nixos.tools.nixos_tools import nixos_search # Test with an invalid type result = nixos_search("python", "invalid_type", 5) diff --git a/tests/test_state_persistence.py b/tests/test_state_persistence.py new file mode 100644 index 0000000..0c6f8e4 --- /dev/null +++ b/tests/test_state_persistence.py @@ -0,0 +1,100 @@ +"""Tests for state persistence across reconnections.""" + +import os +import json +import pytest +import tempfile +from unittest.mock import patch + + +@pytest.fixture +def state_file_path(): + """Create a temporary state file for tests.""" + tmp_dir = tempfile.mkdtemp() + state_file = os.path.join(tmp_dir, "mcp_state.json") + yield state_file + # Cleanup + if os.path.exists(state_file): + os.remove(state_file) + os.rmdir(tmp_dir) + + +class TestStatePersistence: + """Test state persistence across server reconnections.""" + + def test_state_persistence_across_restarts(self, state_file_path): + """Test that critical state persists across server restarts.""" + # Import module after patching environment + with patch.dict(os.environ, {"MCP_NIXOS_STATE_FILE": state_file_path}): + from mcp_nixos.utils.state_persistence import StatePersistence + + # Create first instance and set state + persistence = StatePersistence() + persistence.set_state("connection_count", 5) + persistence.set_state("last_query", "nixos_search") + persistence.save_state() + + # Create second instance (simulates restart) + persistence2 = StatePersistence() + persistence2.load_state() + + # Verify state was preserved + assert persistence2.get_state("connection_count") == 5 + assert persistence2.get_state("last_query") == "nixos_search" + + def test_state_file_created_if_missing(self, state_file_path): + """Test that state file is created if not found.""" + # Ensure file doesn't exist + if os.path.exists(state_file_path): + os.remove(state_file_path) + + with patch.dict(os.environ, {"MCP_NIXOS_STATE_FILE": state_file_path}): + from mcp_nixos.utils.state_persistence import StatePersistence + + # Create instance + persistence = StatePersistence() + persistence.set_state("test_key", "test_value") + persistence.save_state() + + # Verify file was created + assert os.path.exists(state_file_path) + + # Verify content + with open(state_file_path, "r") as f: + data = json.load(f) + assert data.get("test_key") == "test_value" + + def test_connection_counter_persistence(self, state_file_path): + """Test that connection counter persists across restarts.""" + with patch.dict(os.environ, {"MCP_NIXOS_STATE_FILE": state_file_path}): + from mcp_nixos.utils.state_persistence import StatePersistence + + # First "server instance" + persistence1 = StatePersistence() + # Simulate loading state in first run (will be empty) + persistence1.load_state() + + # Get and increment counter + current = persistence1.get_state("connection_count", 0) + assert current == 0 # Initial value + persistence1.set_state("connection_count", current + 1) + persistence1.save_state() + + # Second "server instance" + persistence2 = StatePersistence() + persistence2.load_state() + + # Counter should be persisted + assert persistence2.get_state("connection_count") == 1 + + # Increment again + current = persistence2.get_state("connection_count") + persistence2.set_state("connection_count", current + 1) + persistence2.save_state() + + # Third "server instance" + persistence3 = StatePersistence() + persistence3.load_state() + + # Verify counter was properly preserved + assert persistence3.get_state("connection_count") == 2 diff --git a/tests/tools/test_mcp_tools.py b/tests/tools/test_mcp_tools.py index bc92ed1..2642a27 100644 --- a/tests/tools/test_mcp_tools.py +++ b/tests/tools/test_mcp_tools.py @@ -3,7 +3,8 @@ import unittest from unittest.mock import MagicMock -from mcp_nixos.server import home_manager_info, home_manager_search, nixos_info, nixos_search +from mcp_nixos.tools.nixos_tools import nixos_info, nixos_search +from mcp_nixos.tools.home_manager_tools import home_manager_info, home_manager_search from mcp_nixos.tools.home_manager_tools import home_manager_list_options, home_manager_options_by_prefix from mcp_nixos.tools.nixos_tools import CHANNEL_STABLE, CHANNEL_UNSTABLE diff --git a/tests/tools/test_service_options.py b/tests/tools/test_service_options.py index 2b93927..05e9d1f 100644 --- a/tests/tools/test_service_options.py +++ b/tests/tools/test_service_options.py @@ -13,7 +13,8 @@ from mcp_nixos.clients.elasticsearch_client import FIELD_OPT_NAME, FIELD_TYPE # Import constants used in tests # Import the server module functions and classes -from mcp_nixos.server import ElasticsearchClient, NixOSContext, nixos_info, nixos_search +from mcp_nixos.server import ElasticsearchClient, NixOSContext +from mcp_nixos.tools.nixos_tools import nixos_info, nixos_search logging.disable(logging.CRITICAL) @@ -368,15 +369,19 @@ class TestIntegrationScenarios(unittest.TestCase): def setUp(self): """Set up the test environment.""" - # Patch NixOSContext to control its behavior without real API calls - patcher_context = patch("mcp_nixos.tools.nixos_tools.get_context_or_fallback") - self.mock_get_context = patcher_context.start() - self.addCleanup(patcher_context.stop) + # Patch importlib.import_module to return a mocked server module + patcher_import = patch("importlib.import_module") + self.mock_import_module = patcher_import.start() + self.addCleanup(patcher_import.stop) - # Create a mock context instance that get_context_or_fallback will return + # Create a mock context and server module self.mock_context = MagicMock(spec=NixOSContext) self.mock_context.es_client = MagicMock(spec=ElasticsearchClient) # Add mock es_client - self.mock_get_context.return_value = self.mock_context + + # Create a mock server module that get_nixos_context will return our mock context + self.mock_server_module = MagicMock() + self.mock_server_module.get_nixos_context.return_value = self.mock_context + self.mock_import_module.return_value = self.mock_server_module def test_channel_selection_in_service_search(self): """Test that channel selection is respected in service searches.""" diff --git a/tests/tools/test_suggestions.py b/tests/tools/test_suggestions.py index a9db320..831c088 100644 --- a/tests/tools/test_suggestions.py +++ b/tests/tools/test_suggestions.py @@ -6,7 +6,8 @@ from unittest.mock import patch # Import the server module functions and classes -from mcp_nixos.server import ElasticsearchClient, NixOSContext, nixos_info, nixos_search +from mcp_nixos.server import ElasticsearchClient, NixOSContext +from mcp_nixos.tools.nixos_tools import nixos_info, nixos_search logging.disable(logging.CRITICAL) diff --git a/tests/utils/test_multi_word_query.py b/tests/utils/test_multi_word_query.py index 29818ec..9627eef 100644 --- a/tests/utils/test_multi_word_query.py +++ b/tests/utils/test_multi_word_query.py @@ -55,7 +55,7 @@ def test_parse_multi_word_query(self): self.assertEqual(result["terms"], ["enable", "ssl"]) -@patch("mcp_nixos.tools.nixos_tools.get_context_or_fallback") +@patch("importlib.import_module") class TestNixOSSearchWithMultiWord(unittest.TestCase): """Test the nixos_search function with multi-word queries.""" @@ -66,10 +66,14 @@ def setUp(self): self.mock_context.es_client = self.mock_es_client self.mock_context.search_options.return_value = {"options": [], "count": 0} - def test_acme_search_issue(self, mock_get_context): + # Create a mock server module that will return our context + self.mock_server_module = MagicMock() + self.mock_server_module.get_nixos_context.return_value = self.mock_context + + def test_acme_search_issue(self, mock_import_module): """Test the issue from TODO.md: security.acme acceptTerms.""" - # Mock the context to return empty results - mock_get_context.return_value = self.mock_context + # Configure import_module to return our mock server module + mock_import_module.return_value = self.mock_server_module # Define a successful result for the search_options call self.mock_context.search_options.return_value = { @@ -87,6 +91,12 @@ def test_acme_search_issue(self, mock_get_context): # Test the improved multi-word query nixos_search(query="security.acme acceptTerms", type="options") + # Verify importlib.import_module was called with correct module + mock_import_module.assert_called_with("mcp_nixos.server") + + # Verify get_nixos_context was called + self.mock_server_module.get_nixos_context.assert_called_once() + # Verify that search_options was called with the correct parameters self.mock_context.search_options.assert_called_once() args, kwargs = self.mock_context.search_options.call_args @@ -98,10 +108,10 @@ def test_acme_search_issue(self, mock_get_context): self.assertIn("additional_terms", kwargs) self.assertEqual(kwargs["additional_terms"], ["acceptTerms"]) - def test_multi_word_with_quoted_terms(self, mock_get_context): + def test_multi_word_with_quoted_terms(self, mock_import_module): """Test multi-word query with quoted terms.""" - # Mock the context to return empty results - mock_get_context.return_value = self.mock_context + # Configure import_module to return our mock server module + mock_import_module.return_value = self.mock_server_module # Define a successful result for the search_options call self.mock_context.search_options.return_value = { @@ -119,6 +129,12 @@ def test_multi_word_with_quoted_terms(self, mock_get_context): # Test the multi-word query with quotes nixos_search(query='services.nginx "access log"', type="options") + # Verify importlib.import_module was called with correct module + mock_import_module.assert_called_with("mcp_nixos.server") + + # Verify get_nixos_context was called + self.mock_server_module.get_nixos_context.assert_called_once() + # Verify that search_options was called with the correct parameters self.mock_context.search_options.assert_called_once() args, kwargs = self.mock_context.search_options.call_args