|
1 | 1 | import asyncio
|
2 | 2 | import logging
|
| 3 | +import signal |
3 | 4 | import traceback
|
4 | 5 | from collections.abc import Sequence
|
5 | 6 | from typing import Any
|
@@ -30,11 +31,21 @@ def get_allowed_commands(self) -> list[str]:
|
30 | 31 | """Get the allowed commands"""
|
31 | 32 | return self.executor.validator.get_allowed_commands()
|
32 | 33 |
|
| 34 | + def get_allowed_patterns(self) -> list[str]: |
| 35 | + """Get the allowed regex patterns""" |
| 36 | + return [ |
| 37 | + pattern.pattern |
| 38 | + for pattern in self.executor.validator._get_allowed_patterns() |
| 39 | + ] |
| 40 | + |
33 | 41 | def get_tool_description(self) -> Tool:
|
| 42 | + """Get the tool description for the execute command""" |
| 43 | + allowed_commands = ", ".join(self.get_allowed_commands()) |
| 44 | + allowed_patterns = ", ".join(self.get_allowed_patterns()) |
34 | 45 | """Get the tool description for the execute command"""
|
35 | 46 | return Tool(
|
36 | 47 | name=self.name,
|
37 |
| - description=f"{self.description}\nAllowed commands: {', '.join(self.get_allowed_commands())}", |
| 48 | + description=f"{self.description}\nAllowed commands: {allowed_commands}\nAllowed patterns: {allowed_patterns}", |
38 | 49 | inputSchema={
|
39 | 50 | "type": "object",
|
40 | 51 | "properties": {
|
@@ -142,13 +153,64 @@ async def call_tool(name: str, arguments: Any) -> Sequence[TextContent]:
|
142 | 153 | async def main() -> None:
|
143 | 154 | """Main entry point for the MCP shell server"""
|
144 | 155 | logger.info(f"Starting MCP shell server v{__version__}")
|
| 156 | + |
| 157 | + # Setup signal handling |
| 158 | + loop = asyncio.get_running_loop() |
| 159 | + stop_event = asyncio.Event() |
| 160 | + |
| 161 | + def handle_signal(): |
| 162 | + if not stop_event.is_set(): # Prevent duplicate handling |
| 163 | + logger.info("Received shutdown signal, starting cleanup...") |
| 164 | + stop_event.set() |
| 165 | + |
| 166 | + # Register signal handlers |
| 167 | + for sig in (signal.SIGTERM, signal.SIGINT): |
| 168 | + loop.add_signal_handler(sig, handle_signal) |
| 169 | + |
145 | 170 | try:
|
146 | 171 | from mcp.server.stdio import stdio_server
|
147 | 172 |
|
148 | 173 | async with stdio_server() as (read_stream, write_stream):
|
149 |
| - await app.run( |
150 |
| - read_stream, write_stream, app.create_initialization_options() |
| 174 | + # Run the server until stop_event is set |
| 175 | + server_task = asyncio.create_task( |
| 176 | + app.run(read_stream, write_stream, app.create_initialization_options()) |
151 | 177 | )
|
| 178 | + |
| 179 | + # Create task for stop event |
| 180 | + stop_task = asyncio.create_task(stop_event.wait()) |
| 181 | + |
| 182 | + # Wait for either server completion or stop signal |
| 183 | + done, pending = await asyncio.wait( |
| 184 | + [server_task, stop_task], return_when=asyncio.FIRST_COMPLETED |
| 185 | + ) |
| 186 | + |
| 187 | + # Check for exceptions in completed tasks |
| 188 | + for task in done: |
| 189 | + try: |
| 190 | + await task |
| 191 | + except Exception: |
| 192 | + raise # Re-raise the exception |
| 193 | + |
| 194 | + # Cancel any pending tasks |
| 195 | + for task in pending: |
| 196 | + task.cancel() |
| 197 | + try: |
| 198 | + await task |
| 199 | + except asyncio.CancelledError: |
| 200 | + pass |
| 201 | + |
152 | 202 | except Exception as e:
|
153 | 203 | logger.error(f"Server error: {str(e)}")
|
154 | 204 | raise
|
| 205 | + finally: |
| 206 | + # Cleanup signal handlers |
| 207 | + for sig in (signal.SIGTERM, signal.SIGINT): |
| 208 | + loop.remove_signal_handler(sig) |
| 209 | + |
| 210 | + # Ensure all processes are terminated |
| 211 | + if hasattr(tool_handler, "executor") and hasattr( |
| 212 | + tool_handler.executor, "process_manager" |
| 213 | + ): |
| 214 | + await tool_handler.executor.process_manager.cleanup_processes() |
| 215 | + |
| 216 | + logger.info("Server shutdown complete") |
0 commit comments