Skip to content

Commit fd5aed5

Browse files
committed
Updated testing functions to support device template.
1 parent c1c012e commit fd5aed5

File tree

1 file changed

+35
-50
lines changed

1 file changed

+35
-50
lines changed

test/dynamo/test_error_messages.py

Lines changed: 35 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1403,19 +1403,16 @@ def f(sz, x):
14031403
S0 = 420
14041404
S1 = N - S0
14051405

1406-
self.assertExpectedInlineMunged(
1406+
with self.assertRaisesRegex(
14071407
Exception,
1408-
lambda: f(
1408+
re.escape(
1409+
"""got RuntimeError("test_clarity_list::iterator_mismatch() Expected a value of type 'List[int]' for argument 'sizes' but instead found type 'immutable_list(SymInt, SymInt)'."""
1410+
),
1411+
):
1412+
f(
14091413
torch.tensor([S0, S1], device=device),
14101414
make_tensor(N, dtype=torch.float32, device=device),
1411-
),
1412-
"""\
1413-
Dynamo failed to run FX node with fake tensors: call_function test_clarity_list.iterator_mismatch.default(*(FakeTensor(..., size=(7312,)), [u0, u1]), **{}): got RuntimeError("test_clarity_list::iterator_mismatch() Expected a value of type 'List[int]' for argument 'sizes' but instead found type 'immutable_list(SymInt, SymInt)'.\\nPosition: 1\\nValue: [u0, u1]\\nDeclaration: test_clarity_list::iterator_mismatch(Tensor input, int[] sizes) -> Tensor[]\\nCast error details: Unable to cast Python instance of type <class 'torch.fx.immutable_collections.immutable_list'> to C++ type '?' (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)")
1414-
1415-
from user code:
1416-
File "test_error_messages.py", line N, in f
1417-
r0, r1 = torch.ops.test_clarity_list.iterator_mismatch.default(x, [s0, s1])""",
1418-
)
1415+
)
14191416

14201417
def test_tuple_iterator_contents_error_message(self, device):
14211418
lib_name = "test_clarity_tuple"
@@ -1431,19 +1428,16 @@ def f(sz, x):
14311428
S0 = 420
14321429
S1 = N - S0
14331430

1434-
self.assertExpectedInlineMunged(
1431+
with self.assertRaisesRegex(
14351432
Exception,
1436-
lambda: f(
1433+
re.escape(
1434+
"""got RuntimeError("test_clarity_tuple::iterator_mismatch() Expected a value of type 'List[int]' for argument 'sizes' but instead found type 'tuple(SymInt, SymInt)'."""
1435+
),
1436+
):
1437+
f(
14371438
torch.tensor((S0, S1), device=device),
14381439
make_tensor(N, dtype=torch.float32, device=device),
1439-
),
1440-
"""\
1441-
Dynamo failed to run FX node with fake tensors: call_function test_clarity_tuple.iterator_mismatch.default(*(FakeTensor(..., size=(7312,)), (u0, u1)), **{}): got RuntimeError("test_clarity_tuple::iterator_mismatch() Expected a value of type 'List[int]' for argument 'sizes' but instead found type 'tuple(SymInt, SymInt)'.\\nPosition: 1\\nValue: (u0, u1)\\nDeclaration: test_clarity_tuple::iterator_mismatch(Tensor input, int[] sizes) -> Tensor[]\\nCast error details: Unable to cast Python instance of type <class 'tuple'> to C++ type '?' (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)")
1442-
1443-
from user code:
1444-
File "test_error_messages.py", line N, in f
1445-
r0, r1 = torch.ops.test_clarity_tuple.iterator_mismatch.default(x, (s0, s1))""",
1446-
)
1440+
)
14471441

14481442
def test_dict_iterator_contents_error_message(self, device):
14491443
lib_name = "test_clarity_dict"
@@ -1461,19 +1455,16 @@ def f(sz, x):
14611455
S0 = 420
14621456
S1 = N - S0
14631457

1464-
self.assertExpectedInlineMunged(
1458+
with self.assertRaisesRegex(
14651459
Exception,
1466-
lambda: f(
1460+
re.escape(
1461+
"""got RuntimeError("test_clarity_dict::iterator_mismatch() Expected a value of type 'List[int]' for argument 'sizes' but instead found type 'immutable_dict(int, int)'."""
1462+
),
1463+
):
1464+
f(
14671465
torch.tensor((S0, S1), device=device),
14681466
make_tensor(N, dtype=torch.float32, device=device),
1469-
),
1470-
"""\
1471-
Dynamo failed to run FX node with fake tensors: call_function test_clarity_dict.iterator_mismatch.default(*(FakeTensor(..., size=(7312,)), {1: u0, 2: u1}), **{}): got RuntimeError("test_clarity_dict::iterator_mismatch() Expected a value of type 'List[int]' for argument 'sizes' but instead found type 'immutable_dict(int, int)'.\\nPosition: 1\\nValue: {1: u0, 2: u1}\\nDeclaration: test_clarity_dict::iterator_mismatch(Tensor input, int[] sizes) -> Tensor[]\\nCast error details: Unable to cast Python instance of type <class 'torch.fx.immutable_collections.immutable_dict'> to C++ type '?' (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)")
1472-
1473-
from user code:
1474-
File "test_error_messages.py", line N, in f
1475-
r0, r1 = torch.ops.test_clarity_dict.iterator_mismatch.default(""",
1476-
)
1467+
)
14771468

14781469
def test_named_tuple_iterator_contents_error_message(self, device):
14791470
lib_name = "test_clarity_named_tuple"
@@ -1492,19 +1483,16 @@ def f(sz, x):
14921483
S0 = 420
14931484
S1 = N - S0
14941485

1495-
self.assertExpectedInlineMunged(
1486+
with self.assertRaisesRegex(
14961487
Exception,
1497-
lambda: f(
1488+
re.escape(
1489+
"""got RuntimeError("test_clarity_named_tuple::iterator_mismatch() Expected a value of type 'List[int]' for argument 'sizes' but instead found type 'inSizes (aka NamedTuple(size_0, size_1))'."""
1490+
),
1491+
):
1492+
f(
14981493
torch.tensor([S0, S1], device=device),
14991494
make_tensor(N, dtype=torch.float32, device=device),
1500-
),
1501-
"""\
1502-
Dynamo failed to run FX node with fake tensors: call_function test_clarity_named_tuple.iterator_mismatch.default(*(FakeTensor(..., size=(7312,)), inSizes(size_0=u0, size_1=u1)), **{}): got RuntimeError("test_clarity_named_tuple::iterator_mismatch() Expected a value of type 'List[int]' for argument 'sizes' but instead found type 'inSizes (aka NamedTuple(size_0, size_1))'.\\nPosition: 1\\nValue: inSizes(size_0=u0, size_1=u1)\\nDeclaration: test_clarity_named_tuple::iterator_mismatch(Tensor input, int[] sizes) -> Tensor[]\\nCast error details: Unable to cast Python instance of type <class 'torch._dynamo.variables.functions.inSizes'> to C++ type '?' (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)")
1503-
1504-
from user code:
1505-
File "test_error_messages.py", line N, in f
1506-
r0, r1 = torch.ops.test_clarity_named_tuple.iterator_mismatch.default(""",
1507-
)
1495+
)
15081496

15091497
def test_noniter_contents_error_message(self, device):
15101498
lib_name = "test_clarity_noniter"
@@ -1520,19 +1508,16 @@ def f(sz, x):
15201508
S0 = 420
15211509
S1 = N - S0
15221510

1523-
self.assertExpectedInlineMunged(
1511+
with self.assertRaisesRegex(
15241512
Exception,
1525-
lambda: f(
1513+
re.escape(
1514+
"""got RuntimeError("test_clarity_noniter::iterator_mismatch() Expected a value of type 'List[int]' for argument 'sizes' but instead found type 'SymInt'."""
1515+
),
1516+
):
1517+
f(
15261518
torch.tensor([S0, S1], device=device),
15271519
make_tensor(N, dtype=torch.float32, device=device),
1528-
),
1529-
"""\
1530-
Dynamo failed to run FX node with fake tensors: call_function test_clarity_noniter.iterator_mismatch.default(*(FakeTensor(..., size=(7312,)), u0), **{}): got RuntimeError("test_clarity_noniter::iterator_mismatch() Expected a value of type 'List[int]' for argument 'sizes' but instead found type 'SymInt'.\\nPosition: 1\\nValue: u0\\nDeclaration: test_clarity_noniter::iterator_mismatch(Tensor input, int[] sizes) -> Tensor[]\\nCast error details: Unable to cast Python instance of type <class 'torch.SymInt'> to C++ type '?' (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details)")
1531-
1532-
from user code:
1533-
File "test_error_messages.py", line N, in f
1534-
r0, r1 = torch.ops.test_clarity_noniter.iterator_mismatch.default(x, s0)""",
1535-
)
1520+
)
15361521

15371522

15381523
instantiate_device_type_tests(ErrorMessageClarityTest, globals())

0 commit comments

Comments
 (0)