-
Notifications
You must be signed in to change notification settings - Fork 21.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Test Plan: Unittest Reviewers: Subscribers: Tasks: Tags:
- Loading branch information
Showing
3 changed files
with
367 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
# Owner(s): ["module: fx"] | ||
import sys | ||
import unittest | ||
from torch.fx.verifier import ( | ||
SpecViolationError, | ||
check_valid_aten_dialect, | ||
check_valid, | ||
is_valid_aten_dialect, | ||
is_valid, | ||
) | ||
|
||
|
||
from typing import Tuple | ||
|
||
|
||
from torch.testing._internal.common_utils import TestCase | ||
import torch # noqa: F401 | ||
import torch.nn as nn | ||
from torch import Tensor | ||
import torch._dynamo as torchdynamo | ||
import copy | ||
from functorch import make_fx | ||
from functorch.experimental import functionalize | ||
|
||
|
||
@torch.no_grad() | ||
def capture(f, args): | ||
torchdynamo.config.capture_scalar_outputs = True | ||
torchdynamo.config.guard_nn_modules = True | ||
torchdynamo.config.dynamic_shapes = True | ||
torchdynamo.config.specialize_int_float = True | ||
torchdynamo.config.allow_rnn = True | ||
torchdynamo.config.verbose = True | ||
torchdynamo.reset() | ||
graphmodule, _ = torchdynamo.export( | ||
f, | ||
*copy.deepcopy(args), | ||
aten_graph=True, | ||
tracing_mode='fake', | ||
) | ||
|
||
def graph_with_interpreter(*args): | ||
with torch.fx.traceback.preserve_node_meta(): | ||
return torch.fx.Interpreter(graphmodule).run(*args) | ||
|
||
functionalized_callable = functionalize( | ||
graph_with_interpreter, | ||
remove='mutations_and_views', | ||
) | ||
gm = make_fx(functionalized_callable, tracing_mode='fake', _allow_non_fake_inputs=True)(*args) | ||
return gm | ||
|
||
|
||
class Transpose(nn.Module): | ||
def __init__(self) -> None: | ||
super().__init__() | ||
|
||
def forward(self, x: Tensor, dim0: int, dim1: int) -> Tensor: | ||
return x.transpose(dim0, dim1) | ||
|
||
|
||
class Mul(nn.Module): | ||
def __init__(self) -> None: | ||
super().__init__() | ||
|
||
def forward(self, input: Tensor, other: Tensor) -> Tensor: | ||
# or return torch.mul(input, other) | ||
return input * other | ||
|
||
def get_random_inputs(self) -> Tuple[Tensor, Tensor]: | ||
return (torch.randn(3, 2), torch.randn(3, 2)) | ||
|
||
|
||
class ElementwiseAdd(nn.Module): | ||
def __init__(self) -> None: | ||
super().__init__() | ||
|
||
def forward(self, x: Tensor, y: Tensor) -> Tensor: | ||
return x + y | ||
|
||
def get_random_inputs(self) -> Tuple[Tensor, Tensor]: | ||
return (torch.randn(1, 3), torch.randn(1, 3)) | ||
|
||
|
||
class Cat(nn.Module): | ||
def __init__(self) -> None: | ||
super().__init__() | ||
|
||
# def forward(self, tensors, dim=0): | ||
def forward(self, *args: Tensor, dim: int) -> Tensor: | ||
tensors = args[:-1] | ||
return torch.cat(tensors, dim) | ||
|
||
|
||
class FeedForwardBlock(nn.Module): | ||
def __init__(self, input_dim: int, hidden_dim: int) -> None: | ||
super().__init__() | ||
self.input_dim = input_dim | ||
self.hidden_dim = hidden_dim | ||
|
||
self.layer_norm = nn.LayerNorm(input_dim) | ||
|
||
self.relu = nn.ReLU() | ||
|
||
self.linear1 = nn.Linear(input_dim, hidden_dim) | ||
self.dropout1 = nn.Dropout() | ||
|
||
self.linear2 = nn.Linear(hidden_dim, input_dim) | ||
self.dropout2 = nn.Dropout() | ||
|
||
def forward(self, x: Tensor) -> Tensor: | ||
# LayerNorm -> Linear -> Dropout -> ReLU -> Linear -> Dropout | ||
y = self.layer_norm(x) | ||
y = self.linear1(y) | ||
y = self.dropout1(y) | ||
y = self.relu(y) | ||
y = self.linear2(y) | ||
y = self.dropout2(y) | ||
return y | ||
|
||
|
||
|
||
class VerifierTest(TestCase): | ||
|
||
@unittest.skipIf(sys.version_info >= (3, 11), "dynamo doesnt support 3.11") | ||
def test_verifier(self) -> None: | ||
m = ElementwiseAdd() | ||
egm = capture(m, (torch.randn(100), torch.randn(100))) | ||
# assert not throw | ||
check_valid(egm) | ||
self.assertTrue(is_valid(egm)) | ||
|
||
@unittest.skipIf(sys.version_info >= (3, 11), "dynamo doesnt support 3.11") | ||
def testr_verifier_call_module(self) -> None: | ||
m = FeedForwardBlock(10, 10) | ||
gm = torch.fx.symbolic_trace(m) | ||
# this would have modules that are not delegates | ||
with self.assertRaises(SpecViolationError): | ||
check_valid(gm) | ||
self.assertFalse(is_valid(gm)) | ||
|
||
@unittest.skipIf(sys.version_info >= (3, 11), "dynamo doesnt support 3.11") | ||
def test_verifier_no_functional(self) -> None: | ||
m = ElementwiseAdd() | ||
egm = capture(m, (torch.randn(100), torch.randn(100))) | ||
for node in egm.graph.nodes: | ||
if node.target == torch.ops.aten.add.Tensor: | ||
node.target = torch.ops.aten.add.out | ||
with self.assertRaises(SpecViolationError): | ||
check_valid(egm) | ||
self.assertFalse(is_valid(egm)) | ||
|
||
@unittest.skipIf(sys.version_info >= (3, 11), "dynamo doesnt support 3.11") | ||
def test_aten_dialect(self) -> None: | ||
m = ElementwiseAdd() | ||
egm = capture(m, (torch.randn(100), torch.randn(100))) | ||
check_valid_aten_dialect(egm) | ||
self.assertTrue(is_valid_aten_dialect(egm)) | ||
|
||
@unittest.skipIf(sys.version_info >= (3, 11), "dynamo doesnt support 3.11") | ||
def test_aten_wrong_mem_format(self) -> None: | ||
class TestModel(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.a = torch.nn.parameter.Parameter( | ||
torch.randn(1, 3, 100, 100).to(memory_format=torch.channels_last) | ||
) | ||
|
||
def forward(self, x): | ||
return self.a + x | ||
|
||
m = TestModel() | ||
egm = capture(m, (torch.randn(1, 3, 100, 100),)) | ||
egm._apply(lambda t: t.to(memory_format=torch.channels_last)) | ||
with self.assertRaises(SpecViolationError): | ||
check_valid_aten_dialect(egm) | ||
self.assertFalse(is_valid_aten_dialect(egm)) | ||
|
||
@unittest.skipIf(sys.version_info >= (3, 11), "dynamo doesnt support 3.11") | ||
def test_aten_wrong_mem_format_buffer(self) -> None: | ||
class TestModel(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.register_buffer( | ||
"a", | ||
torch.randn(1, 3, 100, 100).to(memory_format=torch.channels_last), | ||
) | ||
|
||
def forward(self, x): | ||
return self.a + x | ||
|
||
m = TestModel() | ||
egm = capture(m, (torch.randn(1, 3, 100, 100),)) | ||
egm._apply(lambda t: t.to(memory_format=torch.channels_last)) | ||
with self.assertRaises(SpecViolationError): | ||
check_valid_aten_dialect(egm) | ||
self.assertFalse(is_valid_aten_dialect(egm)) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
import itertools | ||
import operator | ||
from collections.abc import Iterable | ||
|
||
import torch | ||
from torch._ops import OpOverload | ||
from torch._subclasses.fake_tensor import FakeTensor | ||
from torch.fx import GraphModule | ||
|
||
|
||
ALLOWED_META_KEYS = {"spec", "stack_trace"} | ||
|
||
@torch.fx._compatibility.compatibility(is_backward_compatible=False) | ||
class SpecViolationError(Exception): | ||
pass | ||
|
||
@torch.fx._compatibility.compatibility(is_backward_compatible=False) | ||
def is_functional(op: OpOverload) -> bool: | ||
return not op._schema.is_mutable | ||
|
||
|
||
@torch.fx._compatibility.compatibility(is_backward_compatible=False) | ||
def _check_has_fake_tensor(node: torch.fx.Node) -> None: | ||
def _check_is_fake_tensor(val): | ||
if isinstance(val, FakeTensor): | ||
return True | ||
if isinstance(val, Iterable): | ||
return all(_check_is_fake_tensor(x) for x in val) | ||
return False | ||
|
||
val = node.meta.get("val") | ||
if not _check_is_fake_tensor(val): | ||
raise SpecViolationError("Node.meta {} is missing val field.".format(node.name)) | ||
|
||
|
||
@torch.fx._compatibility.compatibility(is_backward_compatible=False) | ||
def check_valid(gm: GraphModule) -> None: # noqa: C901 | ||
|
||
for node in gm.graph.nodes: | ||
# TODO(T140410192): should have fake tensor for all dialects | ||
# _check_has_fake_tensor(node) | ||
# TODO(qihan): Check for "val" which is not there yet. | ||
if node.op == "call_method": | ||
# what is delegates in ATen dialect? | ||
raise SpecViolationError( | ||
"call_module can only be used for delegates, got a object of class '{}.{}' instead".format( | ||
type(node.args[0]).__module__, type(node.args[0]).__name__ | ||
), | ||
) | ||
|
||
if node.op == "call_module": | ||
raise SpecViolationError( | ||
"call_module is not valid: got a class '{}' ".format(node.target), | ||
) | ||
|
||
if node.op == "call_function": | ||
op_name = ( | ||
node.target.name | ||
if hasattr(node.target, "name") | ||
else node.target.__name__ | ||
) | ||
is_builtin_func = (node.target == operator.getitem or node.target.__name__ in [ | ||
'while_loop', | ||
'cond', | ||
]) | ||
if not isinstance(node.target, OpOverload) and not is_builtin_func: | ||
raise SpecViolationError( | ||
"Operator '{}' is not a registered Op".format(op_name), | ||
) | ||
# All ops functional | ||
# TODO(qihan): use node.target.is_functional: when PR/83134 lands | ||
if not is_builtin_func and not is_functional(node.target): | ||
raise SpecViolationError( | ||
"operator '{}' is not functional".format(op_name), | ||
) | ||
|
||
if isinstance(node.target, OpOverload): | ||
stacktrace = node.meta.get("stack_trace") | ||
|
||
if stacktrace is None: | ||
raise SpecViolationError( | ||
"node of name '{}' for operator '{}' is missing stackstrace".format( | ||
node.name, op_name | ||
), | ||
) | ||
|
||
|
||
@torch.fx._compatibility.compatibility(is_backward_compatible=False) | ||
def is_valid(gm: GraphModule) -> bool: | ||
try: | ||
check_valid(gm) | ||
return True | ||
except SpecViolationError: | ||
return False | ||
|
||
|
||
@torch.fx._compatibility.compatibility(is_backward_compatible=False) | ||
def check_valid_aten_dialect(gm: GraphModule) -> None: | ||
"""Raises exception if gm is not in aten dialect. | ||
Args: | ||
gm: GraphModule | ||
""" | ||
# need to be first valid | ||
check_valid(gm) | ||
# Operators be aten cannonical | ||
for n in gm.graph.nodes: | ||
if n.op == "call_function" and isinstance(n.target, OpOverload): | ||
if ( | ||
torch.Tag.core not in n.target.tags # type: ignore[attr-defined] | ||
and torch.Tag.view_copy not in n.target.tags # type: ignore[attr-defined] | ||
): | ||
# NOTE(qihan): whether view_copy operators are marked as canonical is still under | ||
# discussion. | ||
raise SpecViolationError( | ||
"Operator {}.{} is not Aten Canonical.".format( | ||
n.target.__module__, n.target.__name__ | ||
) | ||
) | ||
|
||
# Tensors be of contiguous format | ||
for name, param in itertools.chain(gm.named_parameters(), gm.named_buffers()): | ||
if isinstance(param, torch.Tensor): | ||
if not param.is_contiguous(): | ||
raise SpecViolationError( | ||
f"Tensors in Aten dialect must be contiguous, {name} is not contiguous" | ||
) | ||
|
||
|
||
@torch.fx._compatibility.compatibility(is_backward_compatible=False) | ||
def is_valid_aten_dialect(gm: GraphModule) -> bool: | ||
try: | ||
check_valid_aten_dialect(gm) | ||
return True | ||
except SpecViolationError: | ||
return False | ||
|
||
|
||
@torch.fx._compatibility.compatibility(is_backward_compatible=False) | ||
def check_valid_edge_dialect(gm: GraphModule) -> None: | ||
check_valid_aten_dialect(gm) | ||
|
||
# Additionally, edge dialect's operator must have same input dtype | ||
for n in gm.graph.nodes: | ||
if n.op == "call_function" and isinstance(n.target, OpOverload): | ||
_check_has_fake_tensor(n) | ||
dtypes = set() | ||
for arg in n.args: | ||
if isinstance(arg, torch.Tensor): | ||
dtypes.add(arg.dtype) | ||
if isinstance(arg, torch.fx.Node): | ||
dtypes.add(arg.meta["val"].dtype) | ||
if len(dtypes) > 1: | ||
raise SpecViolationError( | ||
"Operators of Edge dialect in should work on tensors of same dtype" | ||
) | ||
|
||
|
||
@torch.fx._compatibility.compatibility(is_backward_compatible=False) | ||
def is_valid_edge_dialect(gm: GraphModule) -> bool: | ||
try: | ||
check_valid_edge_dialect(gm) | ||
return True | ||
except SpecViolationError: | ||
return False |