Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 21 additions & 13 deletions dspy/primitives/python_interpreter.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import json
import subprocess
from types import TracebackType
from typing import Any, Dict, List, Optional
import os


class InterpreterError(ValueError):
pass


class PythonInterpreter:
r"""
PythonInterpreter that runs code in a sandboxed environment using Deno and Pyodide.
Expand All @@ -16,22 +19,15 @@ class PythonInterpreter:
Example Usage:
```python
code_string = "print('Hello'); 1 + 2"
interp = PythonInterpreter()
output = interp(code_string)
print(output) # If final statement is non-None, prints the numeric result, else prints captured output
interp.shutdown()
with PythonInterpreter() as interp:
output = interp(code_string) # If final statement is non-None, prints the numeric result, else prints captured output
```
"""

def __init__(
self,
deno_command: Optional[List[str]] = None
) -> None:
def __init__(self, deno_command: Optional[List[str]] = None) -> None:
if isinstance(deno_command, dict):
deno_command = None # no-op, just a guard in case someone passes a dict
self.deno_command = deno_command or [
"deno", "run", "--allow-read", self._get_runner_path()
]
self.deno_command = deno_command or ["deno", "run", "--allow-read", self._get_runner_path()]
self.deno_process = None

def _get_runner_path(self) -> str:
Expand All @@ -46,7 +42,7 @@ def _ensure_deno_process(self) -> None:
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True
text=True,
)
except FileNotFoundError as e:
install_instructions = (
Expand Down Expand Up @@ -77,7 +73,7 @@ def _serialize_value(self, value: Any) -> str:
elif isinstance(value, (int, float, bool)):
return str(value)
elif value is None:
return 'None'
return "None"
elif isinstance(value, list) or isinstance(value, dict):
return json.dumps(value)
else:
Expand Down Expand Up @@ -129,6 +125,18 @@ def execute(
# If there's no error, return the "output" field
return result.get("output", None)

def __enter__(self):
return self

# All exception fields are ignored and the runtime will automatically re-raise the exception
def __exit__(
self,
_exc_type: Optional[type[BaseException]],
_exc_val: Optional[BaseException],
_exc_tb: Optional[TracebackType],
):
self.shutdown()

def __call__(
self,
code: str,
Expand Down
40 changes: 20 additions & 20 deletions tests/primitives/test_python_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,35 @@


def test_execute_simple_code():
interpreter = PythonInterpreter()
code = "print('Hello, World!')"
result = interpreter.execute(code)
assert result == "Hello, World!\n", "Simple print statement should return 'Hello World!\n'"
with PythonInterpreter() as interpreter:
code = "print('Hello, World!')"
result = interpreter.execute(code)
assert result == "Hello, World!\n", "Simple print statement should return 'Hello World!\n'"


def test_import():
interpreter = PythonInterpreter()
code = "import math\nresult = math.sqrt(4)\nresult"
result = interpreter.execute(code)
assert result == 2, "Should be able to import and use math.sqrt"
with PythonInterpreter() as interpreter:
code = "import math\nresult = math.sqrt(4)\nresult"
result = interpreter.execute(code)
assert result == 2, "Should be able to import and use math.sqrt"


def test_user_variable_definitions():
interpreter = PythonInterpreter()
code = "result = number + 1\nresult"
result = interpreter.execute(code, variables={"number": 4})
assert result == 5, "User variable assignment should work"
with PythonInterpreter() as interpreter:
code = "result = number + 1\nresult"
result = interpreter.execute(code, variables={"number": 4})
assert result == 5, "User variable assignment should work"


def test_failure_syntax_error():
interpreter = PythonInterpreter()
code = "+++"
with pytest.raises(SyntaxError, match="Invalid Python syntax"):
interpreter.execute(code)
with PythonInterpreter() as interpreter:
code = "+++"
with pytest.raises(SyntaxError, match="Invalid Python syntax"):
interpreter.execute(code)


def test_failure_zero_division():
interpreter = PythonInterpreter()
code = "1+0/0"
with pytest.raises(InterpreterError, match="ZeroDivisionError"):
interpreter.execute(code)
with PythonInterpreter() as interpreter:
code = "1+0/0"
with pytest.raises(InterpreterError, match="ZeroDivisionError"):
interpreter.execute(code)
Loading