Skip to content

Commit 767f6aa

Browse files
malfetpytorchmergebot
authored andcommitted
[JIT][Security] Do not blindly eval input string (#89189)
Introduce `_eval_no_call` method, that evaluates statement only if it does not contain any calls(done by examining the bytecode), thus preventing command injection exploit Added simple unit test to check for that `torch.jit.annotations.get_signature` would not result in calling random code. Although, this code path exists for Python-2 compatibility, and perhaps should be simply removed. Fixes #88868 Pull Request resolved: #89189 Approved by: https://github.com/suo
1 parent fbbf368 commit 767f6aa

File tree

3 files changed

+21
-3
lines changed

3 files changed

+21
-3
lines changed

Diff for: test/test_jit.py

+8
Original file line numberDiff line numberDiff line change
@@ -3951,6 +3951,14 @@ def invalid4(a):
39513951
return a + 2
39523952
torch.jit.script(invalid4)
39533953

3954+
def test_calls_in_type_annotations(self):
3955+
with self.assertRaisesRegex(RuntimeError, "Type annotation should not contain calls"):
3956+
def spooky(a):
3957+
# type: print("Hello") -> Tensor # noqa: F723
3958+
return a + 2
3959+
print(torch.__file__)
3960+
torch.jit.annotations.get_signature(spooky, None, 1, True)
3961+
39543962
def test_is_optional(self):
39553963
ann = Union[List[int], List[float]]
39563964
torch._jit_internal.is_optional(ann)

Diff for: torch/csrc/jit/frontend/script_type_parser.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ std::vector<IValue> ScriptTypeParser::evaluateDefaults(
316316
// We then run constant prop on this graph and check the results are
317317
// constant. This approach avoids having to have separate handling of
318318
// default arguments from standard expressions by piecing together existing
319-
// machinery for graph generation, constant propgation, and constant
319+
// machinery for graph generation, constant propagation, and constant
320320
// extraction.
321321
auto tuple_type = Subscript::create(
322322
r,

Diff for: torch/jit/annotations.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ast
2+
import dis
23
import enum
34
import inspect
45
import re
@@ -144,6 +145,15 @@ def check_fn(fn, loc):
144145
raise torch.jit.frontend.FrontendError(loc, "Expected a single top-level function")
145146

146147

148+
def _eval_no_call(stmt, glob, loc):
149+
"""Evaluate statement as long as it does not contain any method/function calls"""
150+
bytecode = compile(stmt, "", mode="eval")
151+
for insn in dis.get_instructions(bytecode):
152+
if "CALL" in insn.opname:
153+
raise RuntimeError(f"Type annotation should not contain calls, but '{stmt}' does")
154+
return eval(bytecode, glob, loc) # type: ignore[arg-type] # noqa: P204
155+
156+
147157
def parse_type_line(type_line, rcb, loc):
148158
"""Parses a type annotation specified as a comment.
149159
@@ -154,15 +164,15 @@ def parse_type_line(type_line, rcb, loc):
154164
arg_ann_str, ret_ann_str = split_type_line(type_line)
155165

156166
try:
157-
arg_ann = eval(arg_ann_str, {}, EvalEnv(rcb)) # type: ignore[arg-type] # noqa: P204
167+
arg_ann = _eval_no_call(arg_ann_str, {}, EvalEnv(rcb))
158168
except (NameError, SyntaxError) as e:
159169
raise RuntimeError("Failed to parse the argument list of a type annotation") from e
160170

161171
if not isinstance(arg_ann, tuple):
162172
arg_ann = (arg_ann,)
163173

164174
try:
165-
ret_ann = eval(ret_ann_str, {}, EvalEnv(rcb)) # type: ignore[arg-type] # noqa: P204
175+
ret_ann = _eval_no_call(ret_ann_str, {}, EvalEnv(rcb))
166176
except (NameError, SyntaxError) as e:
167177
raise RuntimeError("Failed to parse the return type of a type annotation") from e
168178

0 commit comments

Comments
 (0)