From e911aec2e2bf4b4db417c462acc50f6a2dd131c0 Mon Sep 17 00:00:00 2001 From: lxdlam Date: Thu, 20 Mar 2025 15:00:27 +0800 Subject: [PATCH] feat: support context protocol in PythonInterpreter --- dspy/primitives/python_interpreter.py | 34 +++++++++++------- tests/primitives/test_python_interpreter.py | 40 ++++++++++----------- 2 files changed, 41 insertions(+), 33 deletions(-) diff --git a/dspy/primitives/python_interpreter.py b/dspy/primitives/python_interpreter.py index f3332f0f95..9d420090ec 100644 --- a/dspy/primitives/python_interpreter.py +++ b/dspy/primitives/python_interpreter.py @@ -1,11 +1,14 @@ import json import subprocess +from types import TracebackType from typing import Any, Dict, List, Optional import os + class InterpreterError(ValueError): pass + class PythonInterpreter: r""" PythonInterpreter that runs code in a sandboxed environment using Deno and Pyodide. @@ -16,22 +19,15 @@ class PythonInterpreter: Example Usage: ```python code_string = "print('Hello'); 1 + 2" - interp = PythonInterpreter() - output = interp(code_string) - print(output) # If final statement is non-None, prints the numeric result, else prints captured output - interp.shutdown() + with PythonInterpreter() as interp: + output = interp(code_string) # If final statement is non-None, prints the numeric result, else prints captured output ``` """ - def __init__( - self, - deno_command: Optional[List[str]] = None - ) -> None: + def __init__(self, deno_command: Optional[List[str]] = None) -> None: if isinstance(deno_command, dict): deno_command = None # no-op, just a guard in case someone passes a dict - self.deno_command = deno_command or [ - "deno", "run", "--allow-read", self._get_runner_path() - ] + self.deno_command = deno_command or ["deno", "run", "--allow-read", self._get_runner_path()] self.deno_process = None def _get_runner_path(self) -> str: @@ -46,7 +42,7 @@ def _ensure_deno_process(self) -> None: stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - text=True + text=True, ) except FileNotFoundError as e: install_instructions = ( @@ -77,7 +73,7 @@ def _serialize_value(self, value: Any) -> str: elif isinstance(value, (int, float, bool)): return str(value) elif value is None: - return 'None' + return "None" elif isinstance(value, list) or isinstance(value, dict): return json.dumps(value) else: @@ -129,6 +125,18 @@ def execute( # If there's no error, return the "output" field return result.get("output", None) + def __enter__(self): + return self + + # All exception fields are ignored and the runtime will automatically re-raise the exception + def __exit__( + self, + _exc_type: Optional[type[BaseException]], + _exc_val: Optional[BaseException], + _exc_tb: Optional[TracebackType], + ): + self.shutdown() + def __call__( self, code: str, diff --git a/tests/primitives/test_python_interpreter.py b/tests/primitives/test_python_interpreter.py index 60e94bb2c3..ec20c48ea7 100644 --- a/tests/primitives/test_python_interpreter.py +++ b/tests/primitives/test_python_interpreter.py @@ -10,35 +10,35 @@ def test_execute_simple_code(): - interpreter = PythonInterpreter() - code = "print('Hello, World!')" - result = interpreter.execute(code) - assert result == "Hello, World!\n", "Simple print statement should return 'Hello World!\n'" + with PythonInterpreter() as interpreter: + code = "print('Hello, World!')" + result = interpreter.execute(code) + assert result == "Hello, World!\n", "Simple print statement should return 'Hello World!\n'" def test_import(): - interpreter = PythonInterpreter() - code = "import math\nresult = math.sqrt(4)\nresult" - result = interpreter.execute(code) - assert result == 2, "Should be able to import and use math.sqrt" + with PythonInterpreter() as interpreter: + code = "import math\nresult = math.sqrt(4)\nresult" + result = interpreter.execute(code) + assert result == 2, "Should be able to import and use math.sqrt" def test_user_variable_definitions(): - interpreter = PythonInterpreter() - code = "result = number + 1\nresult" - result = interpreter.execute(code, variables={"number": 4}) - assert result == 5, "User variable assignment should work" + with PythonInterpreter() as interpreter: + code = "result = number + 1\nresult" + result = interpreter.execute(code, variables={"number": 4}) + assert result == 5, "User variable assignment should work" def test_failure_syntax_error(): - interpreter = PythonInterpreter() - code = "+++" - with pytest.raises(SyntaxError, match="Invalid Python syntax"): - interpreter.execute(code) + with PythonInterpreter() as interpreter: + code = "+++" + with pytest.raises(SyntaxError, match="Invalid Python syntax"): + interpreter.execute(code) def test_failure_zero_division(): - interpreter = PythonInterpreter() - code = "1+0/0" - with pytest.raises(InterpreterError, match="ZeroDivisionError"): - interpreter.execute(code) + with PythonInterpreter() as interpreter: + code = "1+0/0" + with pytest.raises(InterpreterError, match="ZeroDivisionError"): + interpreter.execute(code)