
Commit

Merge branch 'master' of https://github.com/pytorch/pytorch
ynonaolga committed Nov 16, 2022
2 parents 9854d6b + 9d28775 commit b8bb397
Showing 4 changed files with 33 additions and 189 deletions.
28 changes: 28 additions & 0 deletions test/dynamo/test_misc.py
@@ -2885,6 +2885,34 @@ def func(x, y):
self.assertTrue(same(ref, res))
self.assertTrue(same(x, x1))

def test_if_cond_nn_mod(self):
class MockModule(torch.nn.Module):
def __init__(self, output_relu=True):
super(MockModule, self).__init__()
self.relu = torch.nn.ReLU() if output_relu else None

def forward(self, x):
x = torch.sin(x)
if self.relu:
x = self.relu(x)
return x

model = MockModule()
opt_model = torch._dynamo.optimize("eager", nopython=True)(model)

x = torch.rand(4)
ref = model(x)
res = opt_model(x)
self.assertTrue(same(ref, res))

model = MockModule(output_relu=False)
opt_model = torch._dynamo.optimize("eager", nopython=True)(model)

x = torch.rand(4)
ref = model(x)
res = opt_model(x)
self.assertTrue(same(ref, res))


class CustomFunc(torch.autograd.Function):
@staticmethod
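For context, the added test_if_cond_nn_mod exercises the common pattern of gating on an optional submodule inside forward(). A minimal standalone sketch of the same pattern (illustrative only; Net is a hypothetical stand-in for MockModule, assuming a torch build with torch._dynamo available, as in the test above):

import torch
import torch._dynamo

class Net(torch.nn.Module):
    # Hypothetical example module, mirroring MockModule from the new test.
    def __init__(self, use_relu=True):
        super().__init__()
        self.relu = torch.nn.ReLU() if use_relu else None

    def forward(self, x):
        x = torch.sin(x)
        if self.relu:  # condition on an optional nn.Module attribute
            x = self.relu(x)
        return x

# nopython=True makes dynamo raise instead of falling back on a graph break,
# so this only succeeds if the `if self.relu:` branch is handled in-graph.
opt_net = torch._dynamo.optimize("eager", nopython=True)(Net())
print(opt_net(torch.rand(4)))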
92 changes: 0 additions & 92 deletions test/dynamo/test_repros.py
@@ -1938,98 +1938,6 @@ def fn(x):
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(cnt.op_count, 1)

@patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True)
def test_rewrite_assert_with_msg(self):
def f(x):
b = x.sin()
assert x[0] == 3, "First dim need to be 3"
return x.cos() + b

args = (torch.Tensor([3, 4, 5]),)
cnt = torch._dynamo.testing.CompileCounter()

opt_f = torch._dynamo.optimize(cnt, nopython=True)(f)
self.assertTrue(same(f(*args), opt_f(*args)))
self.assertEqual(cnt.op_count, 6)
self.assertEqual(cnt.frame_count, 1)

exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
self.assertTrue(same(exported(*args), f(*args)))

with self.assertRaisesRegex(AssertionError, ""):
exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))

@patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True)
def test_not_rewrite_assert_for_other_errors(self):
def f(x):
b = x.sin()
if not x.sum() <= 3:
raise ValueError("input sum needs to be 3")
return x.cos() + b

args = (torch.Tensor([3, 4, 5]),)
opt_fn = torch._dynamo.optimize("eager")(f)
with self.assertRaisesRegex(ValueError, "input sum needs to be 3"):
opt_fn(*args)

# TODO (tmanlaibaatar) handle data-dependent fstring in assert statement.
@patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True)
def test_rewrite_assert_with_fstring_msg(self):
def f(x):
b = x.sin()
assert x[0] == 3, f"First dim need to be {x[0]}"
return x.cos() + b

args = (torch.Tensor([3, 4, 5]),)
with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "generic_jump"):
exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))

@patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True)
def test_rewrite_assert_without_msg(self):
def f(x):
b = x.sin()
assert x[0] == 3
return x.cos() + b

args = (torch.Tensor([3, 4, 5]),)
exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
self.assertTrue(same(exported(*args), f(*args)))

with self.assertRaisesRegex(AssertionError, ""):
exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))

@patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True)
def test_rewrite_assert_noop(self):
def f(x):
b = x.sin()
assert True
assert x.dtype == torch.float32
return x.cos() + b

args = (torch.Tensor([3, 4, 5]),)
exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5]))
self.assertTrue(same(exported(*args), f(*args)))

cnt = torch._dynamo.testing.CompileCounter()
opt_f = torch._dynamo.optimize(cnt, nopython=True)(f)
self.assertTrue(same(f(*args), opt_f(*args)))
# torch._assert shouldn't be in the graph
self.assertEqual(cnt.op_count, 3)
self.assertEqual(cnt.frame_count, 1)

exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))
self.assertTrue(same(exported(*args), f(*args)))

@patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", False)
def test_not_rewrite_assert(self):
def f(x):
b = x.sin()
assert x[0] == 3
return x.cos() + b

with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "generic_jump"):
torch._dynamo.export(f, torch.Tensor([3, 4, 5]))


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
3 changes: 0 additions & 3 deletions torch/_dynamo/config.py
@@ -87,9 +87,6 @@
# if an exception is encountered
replay_record_enabled = False

# Rewrite assert statement in python with torch._assert
rewrite_assert_with_torch_assert = True

# Show a warning on every graph break
print_graph_breaks = False

99 changes: 5 additions & 94 deletions torch/_dynamo/symbolic_convert.py
@@ -53,7 +53,6 @@
fake_tensors_available,
graph_break_dup_warning_checker,
istype,
proxy_args_kwargs,
)
from .variables.base import MutableLocal, typestr, VariableTracker
from .variables.builder import VariableBuilder, wrap_fx_proxy
@@ -122,103 +121,10 @@ def impl(self: "InstructionTranslatorBase", inst: Instruction):
return impl


def _detect_and_normalize_assert_statement(
self: "InstructionTranslatorBase", truth_fn: typing.Callable, push: bool
):
# Detect whether this jump instruction is part of an assert statement and
# normalize the assert by pushing a dummy error message when none is given.
#
# A Python 3.9 assertion compiles to the following bytecode:
# 18 POP_JUMP_IF_TRUE 28
# 20 LOAD_ASSERTION_ERROR
# 22 LOAD_CONST 3 ('Assert message') -> optional instruction
# 24 CALL_FUNCTION 1 -> optional instruction
# 26 RAISE_VARARGS
#
# A Python 3.8 assertion compiles to the following bytecode:
# 18 POP_JUMP_IF_TRUE 28
# 20 LOAD_GLOBAL 0 (Assertion type)
# 22 LOAD_CONST 3 ('Assert message') -> optional instruction
# 24 CALL_FUNCTION 1 -> optional instruction
# 26 RAISE_VARARGS 1
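# Illustrative sketch (hypothetical example, for reference only): the sequences
# above can be reproduced with the standard-library `dis` module, e.g.
#
#     import dis
#
#     def f(x):
#         assert x == 3, "Assert message"
#         return x
#
#     dis.dis(f)
#
# On CPython 3.9 the disassembly contains POP_JUMP_IF_TRUE,
# LOAD_ASSERTION_ERROR, LOAD_CONST ('Assert message'), CALL_FUNCTION 1 and
# RAISE_VARARGS 1, matching the pattern this helper looks for.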

if (truth_fn is not operator.truth) or push:
return False

current_instruction_pointer = self.instruction_pointer
inst = self.instructions[current_instruction_pointer]
# Detect LOAD_ASSERTION_ERROR or LOAD_GLOBAL 0
if sys.version_info < (3, 9):
if inst.opname != "LOAD_GLOBAL" or inst.argval != "AssertionError":
return False
else:
if inst.opname != "LOAD_ASSERTION_ERROR":
return False

current_instruction_pointer += 1

if current_instruction_pointer >= len(self.instructions):
return False

inst = self.instructions[current_instruction_pointer]
has_error_msg = False
# Detect RAISE_VARARGS or LOAD_CONST
if inst.opname == "LOAD_CONST":
if not isinstance(inst.argval, str):
return False
self.LOAD_CONST(inst)
has_error_msg = True

# If it is LOAD_CONST, it must be followed by CALL_FUNCTION
current_instruction_pointer += 1
if current_instruction_pointer >= len(self.instructions):
return False
inst = self.instructions[current_instruction_pointer]
if inst.opname != "CALL_FUNCTION":
return False

# CALL_FUNCTION should be followed by RAISE_VARARGS
current_instruction_pointer += 1
if current_instruction_pointer >= len(self.instructions):
return False
inst = self.instructions[current_instruction_pointer]

if inst.opname != "RAISE_VARARGS":
return False

if not has_error_msg:
# Push dummy value instead of error message
self.push(ConstantVariable("assertion error"))

return True


def generic_jump(truth_fn: typing.Callable, push: bool):
def inner(self: "InstructionTranslatorBase", inst: Instruction):
value: VariableTracker = self.pop()
self.output.guards.update(value.guards)
if (
config.rewrite_assert_with_torch_assert
and _detect_and_normalize_assert_statement(self, truth_fn, push)
):
error_msg: VariableTracker = self.pop()
self.output.guards.update(error_msg.guards)
# Skip over things like `assert True`
if value.is_python_constant() and bool(value.as_python_constant()):
self.jump(inst)
return

# Manually insert torch._assert instead of python assert and jump over
# assert related instructions as we don't need them anymore.
self.output.create_proxy(
"call_function",
torch._assert,
*proxy_args_kwargs((value, error_msg), {}),
current_tx=self,
)
self.jump(inst)
return

if value.is_python_constant():
if truth_fn(value.as_python_constant()):
push and self.push(value)
@@ -252,6 +158,11 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction):
+ if_next
+ if_jump
)
elif isinstance(value, NNModuleVariable):
# Equivalent of "self.nn_module is not None"
if truth_fn(value):
push and self.push(value)
self.jump(inst)
elif not isinstance(value, TensorVariable) and value.has_unpack_var_sequence(
self
):
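The new NNModuleVariable branch above mirrors eager Python semantics: a plain nn.Module defines neither __bool__ nor __len__, so using it in an if condition behaves like checking that it is not None, which is what the new test's `if self.relu:` relies on. A minimal sketch of the eager-mode behavior (illustrative only):

import torch

relu = torch.nn.ReLU()
print(bool(relu))        # True: ordinary modules are always truthy
print(relu is not None)  # True: so `if self.relu:` acts like an is-not-None check

# The None case takes the other branch, as in the test's
# MockModule(output_relu=False) configuration.
print(bool(None))        # False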
