diff --git a/README.md b/README.md index cbca10b..4cda380 100644 --- a/README.md +++ b/README.md @@ -20,10 +20,12 @@ The architecture follows a defense-in-depth approach, with multiple layers of is ## File Descriptions -- `server.py`: The main Python file that implements a TCP server which executes Python code sent via TCP connections. It maintains stateful sessions with unique IDs, allowing variables and functions defined in one execution to be available in subsequent executions within the same session. -- `run.sh`: A shell script that runs the Python server inside a Docker container using gVisor's runsc runtime for isolation. It mounts the server.py file into the container and exposes port 8000. -- `test.sh`: A shell script that runs the test_tcp.py script to test the server. +- `server.py`: The main Python file that implements both TCP and WebSocket servers which execute Python code sent via connections. It maintains stateful sessions with unique IDs, allowing variables and functions defined in one execution to be available in subsequent executions within the same session. +- `run.sh`: A shell script that runs the Python server inside a Docker container using gVisor's runsc runtime for isolation. It mounts the server.py file into the container and exposes ports 8000 (TCP) and 8001 (WebSocket). +- `test.sh`: A shell script that runs the test_tcp.py script to test the TCP server. - `test_tcp.py`: A Python script that tests the TCP server by connecting to it, sending Python code to execute, and demonstrating session persistence. +- `test_ws.sh`: A shell script that runs the test_ws.py script to test the WebSocket server. +- `test_ws.py`: A Python script that tests the WebSocket server by connecting to it, sending Python code to execute, and demonstrating session persistence. - `.gitignore`: A configuration file that specifies files to be ignored by version control. ## Setup Instructions @@ -116,6 +118,91 @@ The server uses a simple protocol for communication: - If a `session_id` is provided, the code is executed in the context of that session. - If the provided `session_id` doesn't exist, an error is returned. +### WebSocket Protocol + +The server also supports WebSocket connections on port 8001. The WebSocket protocol is simpler than the TCP protocol since WebSockets handle message framing automatically. + +1. **Request Format**: + ```json + { + "code": "Python code to execute", + "session_id": "optional-session-id" + } + ``` + +2. **Response Format**: + ```json + { + "status": "ok|error", + "output": "execution output (if status is ok)", + "error": "error message (if status is error)", + "session_id": "session-id" + } + ``` + +3. **Session Management**: + - If no `session_id` is provided in the request, a new session is created with a unique ID. + - If a `session_id` is provided, the code is executed in the context of that session. + - If the provided `session_id` doesn't exist, an error is returned. + +### Testing WebSocket Connection + +You can test the WebSocket connection using the provided test_ws.sh script: + +```bash +./test_ws.sh +``` + +This will run the test_ws.py script, which connects to the server via WebSocket, sends Python code to execute, and demonstrates session persistence. + +### Python WebSocket Client Example + +Here's a simple example of how to use the server from Python with WebSockets: + +```python +import asyncio +import websockets +import json + +async def send_code(websocket, code, session_id=None): + # Prepare request + request = {"code": code} + if session_id: + request["session_id"] = session_id + + # Convert to JSON + request_json = json.dumps(request) + + # Send the message + await websocket.send(request_json) + + # Receive the response + response_json = await websocket.recv() + return json.loads(response_json) + +async def main(): + # Connect to the server + uri = "ws://localhost:8001" + async with websockets.connect(uri) as websocket: + # Receive initial greeting + greeting = await websocket.recv() + print(f"Server greeting: {greeting}") + + # Execute code in a new session + response = await send_code(websocket, "x = 42\nprint(f'x = {x}')") + print(f"Response: {json.dumps(response, indent=2)}") + + # Save the session ID for later use + session_id = response.get("session_id") + + # Execute more code in the same session + response = await send_code(websocket, "y = x * 2\nprint(f'y = {y}')", session_id) + print(f"Response: {json.dumps(response, indent=2)}") + +if __name__ == "__main__": + asyncio.run(main()) +``` + ### Python Client Example Here's a simple example of how to use the server from Python: diff --git a/run.sh b/run.sh index 6ac0a23..dd8fc48 100755 --- a/run.sh +++ b/run.sh @@ -1,5 +1,8 @@ +#!/bin/bash + docker run --runtime=runsc --rm -it \ -v "$(pwd)/server.py:/server.py" \ -p 8000:8000 \ + -p 8001:8001 \ python:3.9.21-alpine3.21 \ - python /server.py + sh -c "pip install websockets && python /server.py" \ No newline at end of file diff --git a/server.py b/server.py index bc11ad9..6b85fae 100644 --- a/server.py +++ b/server.py @@ -7,12 +7,54 @@ import logging import uuid import threading +import asyncio +import websockets from concurrent.futures import ThreadPoolExecutor # Dictionary to store session environments sessions = {} sessions_lock = threading.Lock() +def process_request(code, session_id=None): + """Process a code execution request.""" + # If no session_id provided, create a new session + if not session_id: + session_id = str(uuid.uuid4()) + with sessions_lock: + sessions[session_id] = {} + logging.info(f"Created new session: {session_id}") + # If session_id provided but doesn't exist, return error + elif session_id not in sessions: + return { + "status": "error", + "error": f"Session {session_id} not found" + } + + # Execute the code in the session's environment + output = io.StringIO() + try: + old_stdout = sys.stdout + try: + sys.stdout = output + with sessions_lock: + exec(code, sessions[session_id]) + finally: + sys.stdout = old_stdout + + result = output.getvalue() + return { + "status": "ok", + "output": result, + "session_id": session_id + } + except Exception: + tb = traceback.format_exc() + return { + "status": "error", + "error": tb, + "session_id": session_id + } + class PythonREPLHandler(socketserver.BaseRequestHandler): def handle(self): """Handle incoming TCP connections.""" @@ -40,7 +82,7 @@ def handle(self): session_id = request.get('session_id', None) # Process the request - response = self.process_request(code, session_id) + response = process_request(code, session_id) self.send_response(response) except json.JSONDecodeError: @@ -104,68 +146,109 @@ def send_response(self, response_dict): except Exception as e: logging.error(f"Error sending response: {str(e)}") - def process_request(self, code, session_id=None): - """Process a code execution request.""" - # If no session_id provided, create a new session - if not session_id: - session_id = str(uuid.uuid4()) - with sessions_lock: - sessions[session_id] = {} - logging.info(f"Created new session: {session_id}") - # If session_id provided but doesn't exist, return error - elif session_id not in sessions: - return { - "status": "error", - "error": f"Session {session_id} not found" - } +# WebSocket handler +async def websocket_handler(websocket, path): + """Handle incoming WebSocket connections.""" + client_address = websocket.remote_address + logging.info(f"WebSocket connection established from {client_address}") + + try: + # Initial greeting with protocol info + await websocket.send(json.dumps({ + "status": "ok", + "message": "Python REPL Server. Send JSON with 'code' to execute. Optional 'session_id' to continue a session." + })) - # Execute the code in the session's environment - output = io.StringIO() - try: - old_stdout = sys.stdout + async for message in websocket: try: - sys.stdout = output - with sessions_lock: - exec(code, sessions[session_id]) - finally: - sys.stdout = old_stdout + request = json.loads(message) - result = output.getvalue() - return { - "status": "ok", - "output": result, - "session_id": session_id - } - except Exception: - tb = traceback.format_exc() - return { - "status": "error", - "error": tb, - "session_id": session_id - } + # Extract code and optional session_id + code = request.get('code', '') + session_id = request.get('session_id', None) + + # Process the request + response = process_request(code, session_id) + await websocket.send(json.dumps(response)) + + except json.JSONDecodeError: + await websocket.send(json.dumps({ + "status": "error", + "error": "Invalid JSON format" + })) + except Exception as e: + await websocket.send(json.dumps({ + "status": "error", + "error": str(e) + })) + except websockets.exceptions.ConnectionClosed: + logging.info(f"WebSocket connection from {client_address} closed by client") + except Exception as e: + logging.error(f"Error handling WebSocket connection from {client_address}: {str(e)}") + finally: + logging.info(f"WebSocket connection from {client_address} closed") class ThreadedTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer): allow_reuse_address = True daemon_threads = True -def main(): - # Use TCP configuration - host = "0.0.0.0" # Listen on all interfaces - port = 8000 - - logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") - - # Create and start the server +def run_tcp_server(host, port): + """Run the TCP server.""" server = ThreadedTCPServer((host, port), PythonREPLHandler) - logging.info(f"Python REPL server listening on TCP {host}:{port}") + try: server.serve_forever() except KeyboardInterrupt: - logging.info("Server is shutting down") + logging.info("TCP Server is shutting down") finally: server.server_close() - logging.info("Server shut down") + logging.info("TCP Server shut down") + +async def run_websocket_server(host, port): + """Run the WebSocket server.""" + server = await websockets.serve(websocket_handler, host, port) + logging.info(f"Python REPL server listening on WebSocket ws://{host}:{port}") + + try: + await asyncio.Future() # Run forever + except asyncio.CancelledError: + logging.info("WebSocket Server is shutting down") + server.close() + await server.wait_closed() + logging.info("WebSocket Server shut down") + +def main(): + # Server configuration + host = "0.0.0.0" # Listen on all interfaces + tcp_port = 8000 + ws_port = 8001 + + logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") + + # Check if command line arguments specify the connection type + if len(sys.argv) > 1 and sys.argv[1] == "ws-only": + # WebSocket only mode + asyncio.run(run_websocket_server(host, ws_port)) + elif len(sys.argv) > 1 and sys.argv[1] == "tcp-only": + # TCP only mode + run_tcp_server(host, tcp_port) + else: + # Run both TCP and WebSocket servers + # Create a thread for the TCP server + tcp_thread = threading.Thread(target=run_tcp_server, args=(host, tcp_port)) + tcp_thread.daemon = True + tcp_thread.start() + + # Run the WebSocket server in the main thread + try: + asyncio.run(run_websocket_server(host, ws_port)) + except KeyboardInterrupt: + logging.info("Servers are shutting down") + finally: + # TCP server will shut down automatically when the main thread exits + # because it's a daemon thread + logging.info("Servers shut down") if __name__ == "__main__": main() \ No newline at end of file diff --git a/test_ws.py b/test_ws.py new file mode 100644 index 0000000..83dbb1d --- /dev/null +++ b/test_ws.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +import asyncio +import websockets +import json +import sys + +async def send_receive(websocket, request_dict): + """Send a request to the server and receive the response.""" + # Convert request to JSON + request_json = json.dumps(request_dict) + + # Send the message + await websocket.send(request_json) + + # Receive the response + response_json = await websocket.recv() + return json.loads(response_json) + +async def main(): + # Server connection details + host = "localhost" + port = 8001 # WebSocket port + uri = f"ws://{host}:{port}" + + try: + # Connect to the server + async with websockets.connect(uri) as websocket: + print(f"Connected to {uri}") + + # Receive initial greeting + greeting = await websocket.recv() + print(f"Server greeting: {greeting}") + + # Test 1: Execute code without a session ID (creates a new session) + print("\n--- Test 1: Execute code without a session ID ---") + response = await send_receive(websocket, { + "code": "x = 42\nprint(f'x = {x}')" + }) + print(f"Response: {json.dumps(response, indent=2)}") + + # Save the session ID for later use + session_id = response.get("session_id") + print(f"Session ID: {session_id}") + + # Test 2: Execute code in the same session (using the session ID) + print("\n--- Test 2: Execute code in the same session ---") + response = await send_receive(websocket, { + "code": "y = x * 2\nprint(f'y = {y}')", + "session_id": session_id + }) + print(f"Response: {json.dumps(response, indent=2)}") + + # Test 3: Define a function in the session + print("\n--- Test 3: Define a function in the session ---") + response = await send_receive(websocket, { + "code": """ +def greet(name): + return f"Hello, {name}!" +print(greet("World")) +""", + "session_id": session_id + }) + print(f"Response: {json.dumps(response, indent=2)}") + + # Test 4: Call the function defined in the previous request + print("\n--- Test 4: Call the function defined in the previous request ---") + response = await send_receive(websocket, { + "code": "print(greet('Python'))", + "session_id": session_id + }) + print(f"Response: {json.dumps(response, indent=2)}") + + # Test 5: Create a new session + print("\n--- Test 5: Create a new session ---") + response = await send_receive(websocket, { + "code": "print('This is a new session')" + }) + print(f"Response: {json.dumps(response, indent=2)}") + new_session_id = response.get("session_id") + print(f"New Session ID: {new_session_id}") + + # Test 6: Verify the new session doesn't have access to variables from the first session + print("\n--- Test 6: Verify session isolation ---") + response = await send_receive(websocket, { + "code": "try:\n print(f'x = {x}')\nexcept NameError as e:\n print(f'Error: {e}')", + "session_id": new_session_id + }) + print(f"Response: {json.dumps(response, indent=2)}") + + # Test 7: Try to access a non-existent session + print("\n--- Test 7: Try to access a non-existent session ---") + response = await send_receive(websocket, { + "code": "print('This should fail')", + "session_id": "non-existent-session-id" + }) + print(f"Response: {json.dumps(response, indent=2)}") + + except Exception as e: + print(f"Error: {e}") + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/test_ws.sh b/test_ws.sh new file mode 100755 index 0000000..e44afd9 --- /dev/null +++ b/test_ws.sh @@ -0,0 +1,2 @@ +#!/bin/bash +python3 test_ws.py \ No newline at end of file