-
Notifications
You must be signed in to change notification settings - Fork 21.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Test Plan: Unittest Reviewers: Subscribers: Tasks: Tags:
- Loading branch information
Showing
3 changed files
with
348 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
import unittest | ||
from torch.fx.verifier import ( | ||
ExirSpecViolationError, | ||
check_valid_aten_dialect, | ||
check_valid_edge_dialect, | ||
check_valid_exir, | ||
is_valid_aten_dialect, | ||
is_valid_edge_dialect, | ||
is_valid_exir, | ||
) | ||
|
||
|
||
from typing import List, Optional, Tuple, Union | ||
|
||
|
||
from torch.testing._internal.common_utils import TestCase | ||
import torch # noqa: F401 | ||
import torch.nn as nn | ||
from torch import Tensor | ||
import torch._dynamo as torchdynamo | ||
import copy | ||
from functorch import make_fx | ||
from functorch.experimental import functionalize | ||
|
||
|
||
@torch.no_grad() | ||
def capture(f, args): | ||
torchdynamo.config.capture_scalar_outputs = True | ||
torchdynamo.config.guard_nn_modules = True | ||
torchdynamo.config.dynamic_shapes = True | ||
torchdynamo.config.specialize_int_float = True | ||
torchdynamo.config.allow_rnn = True | ||
torchdynamo.config.verbose = True | ||
torchdynamo.reset() | ||
f, _ = torchdynamo.export( | ||
f, | ||
*copy.deepcopy(args), | ||
aten_graph=True, | ||
tracing_mode='fake', | ||
) | ||
f = functionalize(f, remove='mutations') | ||
gm = make_fx(f, tracing_mode='fake', _allow_non_fake_inputs=True)(*args) | ||
return gm | ||
|
||
|
||
class Transpose(nn.Module): | ||
def __init__(self) -> None: | ||
super().__init__() | ||
|
||
def forward(self, x: Tensor, dim0: int, dim1: int) -> Tensor: | ||
return x.transpose(dim0, dim1) | ||
|
||
|
||
class Mul(nn.Module): | ||
def __init__(self) -> None: | ||
super().__init__() | ||
|
||
def forward(self, input: Tensor, other: Tensor) -> Tensor: | ||
# or return torch.mul(input, other) | ||
return input * other | ||
|
||
def get_random_inputs(self) -> Tuple[Tensor, Tensor]: | ||
return (torch.randn(3, 2), torch.randn(3, 2)) | ||
|
||
|
||
class ElementwiseAdd(nn.Module): | ||
def __init__(self) -> None: | ||
super().__init__() | ||
|
||
def forward(self, x: Tensor, y: Tensor) -> Tensor: | ||
return x + y | ||
|
||
def get_random_inputs(self) -> Tuple[Tensor, Tensor]: | ||
return (torch.randn(1, 3), torch.randn(1, 3)) | ||
|
||
|
||
class Cat(nn.Module): | ||
def __init__(self) -> None: | ||
super().__init__() | ||
|
||
# def forward(self, tensors, dim=0): | ||
def forward(self, *args: Tensor, dim: int) -> Tensor: | ||
tensors = args[:-1] | ||
return torch.cat(tensors, dim) | ||
|
||
|
||
class FeedForwardBlock(nn.Module): | ||
def __init__(self, input_dim: int, hidden_dim: int) -> None: | ||
super().__init__() | ||
self.input_dim = input_dim | ||
self.hidden_dim = hidden_dim | ||
|
||
self.layer_norm = nn.LayerNorm(input_dim) | ||
|
||
self.relu = nn.ReLU() | ||
|
||
self.linear1 = nn.Linear(input_dim, hidden_dim) | ||
self.dropout1 = nn.Dropout() | ||
|
||
self.linear2 = nn.Linear(hidden_dim, input_dim) | ||
self.dropout2 = nn.Dropout() | ||
|
||
def forward(self, x: Tensor) -> Tensor: | ||
# LayerNorm -> Linear -> Dropout -> ReLU -> Linear -> Dropout | ||
y = self.layer_norm(x) | ||
y = self.linear1(y) | ||
y = self.dropout1(y) | ||
y = self.relu(y) | ||
y = self.linear2(y) | ||
y = self.dropout2(y) | ||
return y | ||
|
||
|
||
|
||
class VerifierTest(TestCase): | ||
|
||
def test_exir_verifier(self) -> None: | ||
m = ElementwiseAdd() | ||
egm = capture(m, (torch.randn(100), torch.randn(100))) | ||
# assert not throw | ||
check_valid_exir(egm) | ||
self.assertTrue(is_valid_exir(egm)) | ||
|
||
def test_exir_verifier_call_module(self) -> None: | ||
m = FeedForwardBlock(10, 10) | ||
gm = torch.fx.symbolic_trace(m) | ||
# this would have modules that are not delegates | ||
with self.assertRaises(ExirSpecViolationError): | ||
check_valid_exir(gm) | ||
self.assertFalse(is_valid_exir(gm)) | ||
|
||
def test_exir_verifier_no_functional(self) -> None: | ||
m = ElementwiseAdd() | ||
egm = capture(m, (torch.randn(100), torch.randn(100))) | ||
for node in egm.graph.nodes: | ||
if node.target == torch.ops.aten.add.Tensor: | ||
node.target = torch.ops.aten.add.out | ||
with self.assertRaises(ExirSpecViolationError): | ||
check_valid_exir(egm) | ||
self.assertFalse(is_valid_exir(egm)) | ||
|
||
def test_aten_dialect(self) -> None: | ||
m = ElementwiseAdd() | ||
egm = capture(m, (torch.randn(100), torch.randn(100))) | ||
check_valid_aten_dialect(egm) | ||
self.assertTrue(is_valid_aten_dialect(egm)) | ||
|
||
def test_aten_wrong_mem_format(self) -> None: | ||
class TestModel(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.a = torch.nn.parameter.Parameter( | ||
torch.randn(1, 3, 100, 100).to(memory_format=torch.channels_last) | ||
) | ||
|
||
def forward(self, x): | ||
return self.a + x | ||
|
||
m = TestModel() | ||
egm = capture(m, (torch.randn(1, 3, 100, 100),)) | ||
egm._apply(lambda t: t.to(memory_format=torch.channels_last)) | ||
with self.assertRaises(ExirSpecViolationError): | ||
check_valid_aten_dialect(egm) | ||
self.assertFalse(is_valid_aten_dialect(egm)) | ||
|
||
def test_aten_wrong_mem_format_buffer(self) -> None: | ||
class TestModel(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.register_buffer( | ||
"a", | ||
torch.randn(1, 3, 100, 100).to(memory_format=torch.channels_last), | ||
) | ||
|
||
def forward(self, x): | ||
return self.a + x | ||
|
||
m = TestModel() | ||
egm = capture(m, (torch.randn(1, 3, 100, 100),)) | ||
egm._apply(lambda t: t.to(memory_format=torch.channels_last)) | ||
with self.assertRaises(ExirSpecViolationError): | ||
check_valid_aten_dialect(egm) | ||
self.assertFalse(is_valid_aten_dialect(egm)) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
import functools | ||
import itertools | ||
import operator | ||
from collections.abc import Iterable | ||
|
||
import torch | ||
from torch._ops import OpOverload | ||
from torch._subclasses.fake_tensor import FakeTensor | ||
from torch.fx import GraphModule | ||
|
||
|
||
ALLOWED_META_KEYS = {"spec", "stack_trace"} | ||
|
||
class ExirSpecViolationError(Exception): | ||
pass | ||
|
||
|
||
def is_functional(op: OpOverload) -> bool: | ||
return not op._schema.is_mutable | ||
|
||
|
||
|
||
def _check_has_fake_tensor(node: torch.fx.Node) -> None: | ||
def _check_is_fake_tensor(val): | ||
if isinstance(val, FakeTensor): | ||
return True | ||
if isinstance(val, Iterable): | ||
return all(_check_is_fake_tensor(x) for x in val) | ||
return False | ||
|
||
val = node.meta.get("val") | ||
if not _check_is_fake_tensor(val): | ||
raise Exir("Node.meta {} is missing val field.".format(node.name)) | ||
|
||
|
||
def check_valid_exir(gm: GraphModule) -> None: # noqa: C901 | ||
|
||
for node in gm.graph.nodes: | ||
# TODO(T140410192): EXIR should have fake tensor for all dialects | ||
# _check_has_fake_tensor(node) | ||
# TODO(qihan): Check for "val" which is not there yet. | ||
if node.op == "call_method": | ||
# what is delegates in ATen dialect? | ||
raise ExirSpecViolationError( | ||
"call_module can only be used for delegates, got a object of class '{}.{}' instead".format( | ||
type(node.args[0]).__module__, type(node.args[0]).__name__ | ||
), | ||
) | ||
|
||
if node.op == "call_module": | ||
raise ExirSpecViolationError( | ||
"call_module is not valid EXIR: got a class '{}' ".format(node.target), | ||
) | ||
|
||
if node.op == "call_function": | ||
op_name = ( | ||
node.target.name | ||
if hasattr(node.target, "name") | ||
else node.target.__name__ | ||
) | ||
is_builtin_func = (node.target == operator.getitem or node.target.__name__ in [ | ||
'while_loop', | ||
'cond', | ||
]) | ||
if not isinstance(node.target, OpOverload) and not is_builtin_func: | ||
raise ExirSpecViolationError( | ||
"Operator '{}' is not a registered Op".format(op_name), | ||
) | ||
# All ops functional | ||
# TODO(qihan): use node.target.is_functional: when PR/83134 lands | ||
if not is_builtin_func and not is_functional(node.target): | ||
raise ExirSpecViolationError( | ||
"operator '{}' is not functional".format(op_name), | ||
) | ||
|
||
if isinstance(node.target, OpOverload): | ||
stacktrace = node.meta.get("stack_trace") | ||
|
||
# TODO(qihqi) enable | ||
# if stacktrace is None: | ||
# raise ExirSpecViolationError( | ||
# "node of name '{}' for operator '{}' is missing stackstrace".format( | ||
# node.name, op_name | ||
# ), | ||
# ) | ||
|
||
|
||
def is_valid_exir(gm: GraphModule) -> bool: | ||
try: | ||
check_valid_exir(gm) | ||
return True | ||
except ExirSpecViolationError: | ||
return False | ||
|
||
|
||
def check_valid_aten_dialect(gm: GraphModule) -> None: | ||
"""Raises exception if gm is not in aten dialect. | ||
Args: | ||
gm: GraphModule | ||
""" | ||
# need to be first valid exir | ||
check_valid_exir(gm) | ||
# Operators be aten cannonical | ||
for n in gm.graph.nodes: | ||
if n.op == "call_function" and isinstance(n.target, OpOverload): | ||
if ( | ||
torch.Tag.core not in n.target.tags | ||
and torch.Tag.view_copy not in n.target.tags | ||
): | ||
# NOTE(qihan): whether view_copy operators are marked as canonical is still under | ||
# discussion. | ||
raise ExirSpecViolationError( | ||
"Operator {}.{} is not Aten Canonical.".format( | ||
n.target.__module__, n.target.__name__ | ||
) | ||
) | ||
|
||
# Tensors be of contiguous format | ||
for name, param in itertools.chain(gm.named_parameters(), gm.named_buffers()): | ||
if isinstance(param, torch.Tensor): | ||
if not param.is_contiguous(): | ||
raise ExirSpecViolationError( | ||
f"Tensors in Aten dialect must be contiguous, {name} is not contiguous" | ||
) | ||
|
||
|
||
def is_valid_aten_dialect(gm: GraphModule) -> bool: | ||
try: | ||
check_valid_aten_dialect(gm) | ||
return True | ||
except ExirSpecViolationError: | ||
return False | ||
|
||
|
||
def check_valid_edge_dialect(gm: GraphModule) -> None: | ||
check_valid_aten_dialect(gm) | ||
|
||
# Additionally, edge dialect's operator must have same input dtype | ||
for n in gm.graph.nodes: | ||
if n.op == "call_function" and isinstance(n.target, OpOverload): | ||
_check_has_fake_tensor(n) | ||
dtypes = set() | ||
for arg in n.args: | ||
if isinstance(arg, torch.Tensor): | ||
dtypes.add(arg.dtype) | ||
if isinstance(arg, torch.fx.Node): | ||
dtypes.add(arg.meta["val"].dtype) | ||
if len(dtypes) > 1: | ||
raise ExirSpecViolationError( | ||
"Operators of Edge dialect in should work on tensors of same dtype" | ||
) | ||
|
||
|
||
def is_valid_edge_dialect(gm: GraphModule) -> bool: | ||
try: | ||
check_valid_edge_dialect(gm) | ||
return True | ||
except ExirSpecViolationError: | ||
return False |