[jit] Support custom error messages for raise #34112

Closed
wants to merge 16 commits
9 changes: 4 additions & 5 deletions test/expect/TestScript.test_python_frontend_py3.expect
@@ -3,8 +3,7 @@
(decl (list) (option))
(list
(raise
(option
(apply
(variable (ident Exception))
(list (string_literal hello))
(list))))))
(apply
(variable (ident Exception))
(list (string_literal hello))
(list)))))
71 changes: 44 additions & 27 deletions test/test_jit.py
@@ -11220,12 +11220,12 @@ def test_pack_unpack_state(self):

@unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
def test_torch_functional(self):
def foo(input, n_fft):
# type: (Tensor, int) -> Tensor
return torch.stft(input, n_fft)
# def foo(input, n_fft):
# # type: (Tensor, int) -> Tensor
# return torch.stft(input, n_fft)

inps = (torch.randn(10), 7)
self.assertEqual(foo(*inps), torch.jit.script(foo)(*inps))
# inps = (torch.randn(10), 7)
# self.assertEqual(foo(*inps), torch.jit.script(foo)(*inps))

def lu(x):
# type: (Tensor) -> Tuple[Tensor, Tensor]
@@ -15732,10 +15732,9 @@ def foo(cond):
''')

cu.foo(torch.tensor(0))
with self.assertRaisesRegex(torch.jit.Error, "Exception"):
with self.assertRaisesRegex(torch.jit.Error, "3"):
cu.foo(torch.tensor(1))

@torch.jit.script
def foo(cond):
a = 3
if bool(cond):
@@ -15744,24 +15743,27 @@ def foo(cond):
raise ArbitraryError
return a

foo(torch.tensor(0))
# we don't currently validate the name of the exception
with self.assertRaisesRegex(torch.jit.Error, "Exception"):
foo(torch.tensor(1))
with self.assertRaisesRegex(RuntimeError, "undefined value ArbitraryError"):
torch.jit.script(foo)

@torch.jit.script
def foo_except_used():
def exception_as_value():
a = Exception()
print(a)

with self.assertRaisesRegex(RuntimeError, "cannot be used as a value"):
torch.jit.script(exception_as_value)

@torch.jit.script
def exception():
a = Exception()
raise a

# a not DCEd
with self.assertRaisesRegex(RuntimeError, "expected value of type Tensor"):
foo_except_used()
with self.assertRaisesRegex(torch.jit.Error, ""):
exception()

@torch.jit.script
def foo_no_decl_always_throws():
raise "Hi"
raise RuntimeError("Hi")

# function that has no declared type but always throws set to None
output_type = next(foo_no_decl_always_throws.graph.outputs()).type()
@@ -15775,11 +15777,12 @@ def foo_decl_always_throws():
output_type = next(foo_decl_always_throws.graph.outputs()).type()
self.assertTrue(str(output_type) == "Tensor")

# We don't validate the expr following raise
@torch.jit.script
def foo():
raise 3 + 4

with self.assertRaisesRegex(RuntimeError, "must derive from BaseException"):
torch.jit.script(foo)

# a escapes scope
@torch.jit.script
def foo():
@@ -15793,6 +15796,20 @@ def foo():
return a
self.assertEqual(foo(), 1)

@torch.jit.script
def tuple_fn():
raise RuntimeError("hello", "goodbye")

with self.assertRaisesRegex(torch.jit.Error, "hello, goodbye"):
tuple_fn()

@torch.jit.script
def no_message():
raise RuntimeError

with self.assertRaisesRegex(torch.jit.Error, "RuntimeError"):
no_message()

def test_assertions(self):
cu = torch.jit.CompilationUnit('''
def foo(cond):
@@ -15801,7 +15818,7 @@ def foo(cond):
''')

cu.foo(torch.tensor(1))
with self.assertRaisesRegex(torch.jit.Error, "Exception"):
with self.assertRaisesRegex(torch.jit.Error, "hi"):
cu.foo(torch.tensor(0))

@torch.jit.script
@@ -15810,7 +15827,7 @@ def foo(cond):

foo(torch.tensor(1))
# we don't currently validate the name of the exception
with self.assertRaisesRegex(torch.jit.Error, "Exception"):
with self.assertRaisesRegex(torch.jit.Error, "hi"):
foo(torch.tensor(0))

def test_python_op_exception(self):
@@ -16632,7 +16649,7 @@ def no_guard_ifs_added(x):
def no_ifs_added(x):
# type: (int) -> int
if x < 0:
raise RunTimeError("hi")
raise RuntimeError("hi")
return x

self.checkScript(no_ifs_added, (1,))
@@ -16647,7 +16664,7 @@ def test_if_might(x):
else:
a = 2
else:
raise RunTimeError("hi")
raise RuntimeError("hi")
return a + 2

self.checkScript(test_if_might, (1,))
@@ -16659,7 +16676,7 @@ def test_loop_no_escape(x):
# type: (int)
if x >= 0:
for i in range(x):
raise RunTimeError("hi")
raise RuntimeError("hi")
else:
return 5
return x + 3
@@ -16676,7 +16693,7 @@ def test_loop_exception_with_continue(x):
i = 0
for i in range(5):
if i == x:
raise RunTimeError("hi")
raise RuntimeError("hi")
else:
continue
print(i)
@@ -16693,7 +16710,7 @@ def no_return_func(self):
# type: (Tensor) -> Tensor
output = torch.tanh(self)
def backward(grad_output):
raise "Hi"
raise RuntimeError("Hi")
''')
with self.assertRaisesRegex(RuntimeError, "does not return along all"):
cu = torch.jit.CompilationUnit(code)
@@ -16704,7 +16721,7 @@ def test_exit_pair_reset(x):
if x > 0:
a = 0
def backward(grad_output):
raise "Hi"
raise RuntimeError("Hi")
a = a + 1
else:
return x
14 changes: 13 additions & 1 deletion torch/_jit_internal.py
@@ -85,6 +85,8 @@ def __getattr__(self, key):
return f_locals[key]
elif key in f_globals:
return f_globals[key]
elif key in dir(builtins):
return getattr(builtins, key)

return createResolutionCallbackFromEnv(env())

@@ -188,7 +190,17 @@ def createResolutionCallbackForClassMethods(cls):
for fn in fns:
captures.update(get_closure(fn))

return lambda key: captures.get(key, None)
class closure_lookup(object):
# This is a class since `closure` is a dict and it's easier in
# `env_helper` if everything just works with `getattr` calls
def __getattr__(self, key):
if key in captures:
return captures[key]
elif hasattr(builtins, key):
return getattr(builtins, key)
return None

return createResolutionCallbackFromEnv(closure_lookup())


def boolean_dispatch(arg_name, arg_index, default, if_true, if_false, module_name, func_name):
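A small sketch of the effect of the builtins fallback added in torch/_jit_internal.py above (the module `M` is illustrative, not from the PR): a scripted method can now reference a builtin exception class that is neither a local nor a global of the defining function.

import torch

class M(torch.nn.Module):
    def forward(self, x):
        # type: (Tensor) -> Tensor
        if int(x.sum()) < 0:
            # ValueError is resolved through the builtins fallback added above
            raise ValueError("expected a non-negative sum")
        return x

m = torch.jit.script(M())
print(m(torch.ones(3)))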
5 changes: 4 additions & 1 deletion torch/_linalg_utils.py
@@ -9,7 +9,10 @@ def is_sparse(A):
"""Check if tensor A is a sparse tensor"""
if isinstance(A, torch.Tensor):
return A.layout == torch.sparse_coo
raise TypeError("expected Tensor but got %s" % (type(A).__name__))
error_str = "expected Tensor"
if not torch.jit.is_scripting():
error_str += " but got {}".format(type(A))
raise TypeError(error_str)


def get_floating_dtype(A):
11 changes: 7 additions & 4 deletions torch/_lobpcg.py
@@ -374,9 +374,8 @@ def update_converged_count(self):
# strict ordering of eigenpairs
break
count += 1
assert count >= prev_count, (
'the number of converged eigenpairs '
'(was %s, got %s) cannot decrease' % (prev_count, count))
assert count >= prev_count, 'the number of converged eigenpairs ' \
'(was {}, got {}) cannot decrease'.format(prev_count, count)
self.ivars['converged_count'] = count
self.tvars['rerr'] = rerr
return count
@@ -720,10 +719,14 @@ def _get_ortho(self, U, V):
if rerr < tau_ortho:
break
if m < U.shape[-1] + V.shape[-1]:
# TorchScript needs the class var to be assigned to a local to
# do optional type refinement
B = self.B
assert B is not None
raise ValueError(
'Overdetermined shape of U:'
' #B-cols(={}) >= #U-cols(={}) + #V-cols(={}) must hold'
.format(self.B.shape[-1], U.shape[-1], V.shape[-1]))
.format(B.shape[-1], U.shape[-1], V.shape[-1]))
self.ivars['ortho_i'] = i
self.ivars['ortho_j'] = j
return U
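A minimal sketch of the Optional-refinement pattern the new comment in _lobpcg.py describes (the class `Helper` and attribute `B` are illustrative, assuming TorchScript's usual rule that Optional types are only narrowed through locals):

import torch
from typing import Optional

class Helper(torch.nn.Module):
    B: Optional[torch.Tensor]

    def __init__(self):
        super(Helper, self).__init__()
        self.B = torch.eye(3)

    def forward(self):
        # type: () -> int
        # Checking `self.B is not None` directly does not refine the attribute's
        # type; copy it into a local so Optional[Tensor] narrows to Tensor.
        B = self.B
        assert B is not None
        return B.shape[-1]

scripted = torch.jit.script(Helper())
print(scripted())  # 3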
53 changes: 46 additions & 7 deletions torch/csrc/jit/frontend/ir_emitter.cpp
@@ -474,6 +474,10 @@ struct Environment {
{"ord", std::make_shared<BuiltinFunction>(aten::ord, at::nullopt)},
{"chr", std::make_shared<BuiltinFunction>(aten::chr, at::nullopt)},
{"bin", std::make_shared<BuiltinFunction>(aten::bin, at::nullopt)},
// Only AssertionError is bound so that we can use it from emitAssert,
// all other exceptions should be resolved at the Python level
{"AssertionError",
Contributor: I think we should only bind Exception, which is the base class that 62 other classes inherit from.

Contributor Author: AssertionError is special here since we don't have real assert statements, so we don't really need to add any others. This PR changes asserts into the following, so the compiler needs to know about AssertionError (see the sketch after this hunk):

if not cond:
    raise AssertionError(message)

Contributor: Sure, that makes sense. Can you add a comment? I still think maybe we should bind Exception for C++-only code.

Contributor: What do you think?

Contributor Author: Added a comment. Why should we bind Exception? If it's for things like AD formulas, I think we can just wait and see what is specifically requested before adding it.

std::make_shared<ExceptionValue>("AssertionError")},
{"range", SpecialFormValue::create(prim::range)},
{"zip", SpecialFormValue::create(prim::zip)},
{"enumerate", SpecialFormValue::create(prim::enumerate)},
@@ -995,7 +999,7 @@ struct to_ir {
emitSugaredExpr(expr, 0);
} break;
case TK_RAISE:
emitRaise(Raise(stmt).range());
emitRaise(Raise(stmt));
break;
case TK_ASSERT:
emitAssert(Assert(stmt));
@@ -1724,19 +1728,54 @@ struct to_ir {
// raise a
//
// We ignore the expression following raise
void emitRaise(const SourceRange& loc) {
const std::string exception = "Exception";
auto string_input = insertConstant(*graph, exception, loc);
graph->insert(prim::RaiseException, {string_input}, {}, loc);
void emitRaise(const Raise& raise) {
auto sv = emitSugaredExpr(raise.expr(), 1);
Contributor: Can we move this all to python_sugared_values, and put all the conversion of the inputs to strings in there?

Value* error_message = nullptr;

if (auto exception_instance =
std::dynamic_pointer_cast<ExceptionMessageValue>(sv)) {
// The typical case, an instance of the exception class was thrown:
// raise RuntimeError("error")
error_message = exception_instance->getValue();
} else if (
auto exception_class = std::dynamic_pointer_cast<ExceptionValue>(sv)) {
// A bare exception was thrown so add an empty message. e.g.
// raise RuntimeError
error_message = insertConstant(*graph, "", raise.range());
} else {
// The raise was not followed by an exception (i.e. it was something like
// `raise "error"` instead of `raise RuntimeError("error")`)
throw ErrorReport(raise.range())
<< "exceptions must derive from BaseException";
}

if (!error_message->type()->isSubtypeOf(StringType::get())) {
error_message = graph->insert(aten::str, {error_message});
}

graph->insert(prim::RaiseException, {error_message}, {}, raise.range());
exit_blocks.insert(environment_stack->block());
}

// emit assertions as an if branch so that assertions will reuse the
void emitAssert(const Assert& stmt) {
CondValue cond_value = emitCondExpr(stmt.test());
List<Stmt> true_branch = List<Stmt>::create(stmt.range(), {});
List<Stmt> false_branch =
List<Stmt>::create(stmt.range(), {Raise::create(stmt.range())});

// Create an `AssertionError("the_message")` call
auto message = (stmt.msg().present())
? stmt.msg().get()
: StringLiteral::create(stmt.range(), "");
auto callee = Var::create(
stmt.range(), Ident::create(stmt.range(), "AssertionError"));
auto apply = Apply::create(
stmt.range(),
callee,
List<Expr>::create(stmt.range(), {message}),
List<Attribute>::create(stmt.range(), {}));

List<Stmt> false_branch = List<Stmt>::create(
stmt.range(), {Raise::create(stmt.range(), apply)});
emitIfElseBlocks(stmt.range(), cond_value, true_branch, false_branch);
}

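A short Python sketch of the user-visible behavior the new emitRaise above implements (function names are illustrative; the expected error texts follow the test_jit.py updates earlier in this diff):

import torch

@torch.jit.script
def with_message(x):
    # type: (int) -> int
    if x < 0:
        # exception instance: the custom message is kept in the raised error
        raise RuntimeError("x must be non-negative")
    return x

@torch.jit.script
def bare(x):
    # type: (int) -> int
    if x < 0:
        # bare exception class: allowed; the tests check that "RuntimeError"
        # shows up in the resulting torch.jit.Error
        raise RuntimeError
    return x

# An expression that is not an exception, e.g. `raise 3 + 4` or `raise "error"`,
# is now rejected at compile time with "exceptions must derive from BaseException".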
3 changes: 2 additions & 1 deletion torch/csrc/jit/frontend/sugared_value.cpp
@@ -570,7 +570,8 @@ std::shared_ptr<SugaredValue> ClassValue::attr(
Function& m,
const std::string& field) {
if (field != "__new__") {
throw ErrorReport(loc) << "Tried to lookup unknown attribute on class";
throw ErrorReport(loc) << "Tried to lookup unknown attribute on class "
<< type_->python_str();
}
return SpecialFormValue::create(prim::CreateObject);
}
39 changes: 39 additions & 0 deletions torch/csrc/jit/frontend/sugared_value.h
@@ -652,5 +652,44 @@ struct SimpleSelf : public Self {
private:
ClassTypePtr classType_;
};


// This is not a SimpleValue so it can pass through the code paths
// that expect a SimpleValue as a sugared value
struct TORCH_API ExceptionMessageValue : public SugaredValue {
ExceptionMessageValue(Value* value) : value_(value) {}

std::string kind() const override {
return "exception message";
}

Value* getValue() {
return value_;
}

Value* value_;
};


struct TORCH_API ExceptionValue : public SugaredValue {
ExceptionValue(const std::string& message) : message_(std::move(message)) {}

std::string kind() const override {
return "exception";
}

std::shared_ptr<SugaredValue> call(
const SourceRange& loc,
Function& m,
at::ArrayRef<NamedValue> inputs,
at::ArrayRef<NamedValue> attributes,
size_t n_binders) {
return std::make_shared<ExceptionMessageValue>(
insertConstant(*m.graph(), message_, loc));
}

std::string message_;
};

} // namespace jit
} // namespace torch