Add verifier for EXIR Aten dialect.
Summary:

Test Plan:
Unittest

Reviewers:

Subscribers:

Tasks:

Tags:
qihqi committed Feb 16, 2023
1 parent a005dd1 commit e386ca1
Showing 3 changed files with 367 additions and 0 deletions.
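For orientation (not part of the diff itself): the commit adds torch/fx/verifier.py with check/is-valid helpers for the base and ATen dialects plus an edge-dialect variant, and a test file exercising them. A minimal, hedged usage sketch, where `gm` stands for a GraphModule assumed to already be exported to the ATen dialect (for example via the capture() helper in the tests below):

    from torch.fx.verifier import (
        SpecViolationError,
        check_valid,
        check_valid_aten_dialect,
        is_valid_aten_dialect,
    )

    # gm: a torch.fx.GraphModule already exported to the ATen dialect (assumed here).
    try:
        check_valid(gm)               # raises SpecViolationError on any violation
        check_valid_aten_dialect(gm)  # additionally requires ATen-canonical ops and contiguous tensors
    except SpecViolationError as err:
        print(f"graph rejected by the verifier: {err}")

    # Boolean form of the same checks, wrapping the exception internally:
    assert is_valid_aten_dialect(gm)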
201 changes: 201 additions & 0 deletions test/fx/test_verifier.py
@@ -0,0 +1,201 @@
# Owner(s): ["module: fx"]
import sys
import unittest
from torch.fx.verifier import (
SpecViolationError,
check_valid_aten_dialect,
check_valid,
is_valid_aten_dialect,
is_valid,
)


from typing import Tuple


from torch.testing._internal.common_utils import TestCase
import torch # noqa: F401
import torch.nn as nn
from torch import Tensor
import torch._dynamo as torchdynamo
import copy
from functorch import make_fx
from functorch.experimental import functionalize


@torch.no_grad()
def capture(f, args):
torchdynamo.config.capture_scalar_outputs = True
torchdynamo.config.guard_nn_modules = True
torchdynamo.config.dynamic_shapes = True
torchdynamo.config.specialize_int_float = True
torchdynamo.config.allow_rnn = True
torchdynamo.config.verbose = True
torchdynamo.reset()
graphmodule, _ = torchdynamo.export(
f,
*copy.deepcopy(args),
aten_graph=True,
tracing_mode='fake',
)

def graph_with_interpreter(*args):
with torch.fx.traceback.preserve_node_meta():
return torch.fx.Interpreter(graphmodule).run(*args)

functionalized_callable = functionalize(
graph_with_interpreter,
remove='mutations_and_views',
)
gm = make_fx(functionalized_callable, tracing_mode='fake', _allow_non_fake_inputs=True)(*args)
return gm


class Transpose(nn.Module):
def __init__(self) -> None:
super().__init__()

def forward(self, x: Tensor, dim0: int, dim1: int) -> Tensor:
return x.transpose(dim0, dim1)


class Mul(nn.Module):
def __init__(self) -> None:
super().__init__()

def forward(self, input: Tensor, other: Tensor) -> Tensor:
# or return torch.mul(input, other)
return input * other

def get_random_inputs(self) -> Tuple[Tensor, Tensor]:
return (torch.randn(3, 2), torch.randn(3, 2))


class ElementwiseAdd(nn.Module):
def __init__(self) -> None:
super().__init__()

def forward(self, x: Tensor, y: Tensor) -> Tensor:
return x + y

def get_random_inputs(self) -> Tuple[Tensor, Tensor]:
return (torch.randn(1, 3), torch.randn(1, 3))


class Cat(nn.Module):
def __init__(self) -> None:
super().__init__()

# def forward(self, tensors, dim=0):
def forward(self, *args: Tensor, dim: int) -> Tensor:
# dim is keyword-only, so every positional argument is a tensor to concatenate
return torch.cat(args, dim)


class FeedForwardBlock(nn.Module):
def __init__(self, input_dim: int, hidden_dim: int) -> None:
super().__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim

self.layer_norm = nn.LayerNorm(input_dim)

self.relu = nn.ReLU()

self.linear1 = nn.Linear(input_dim, hidden_dim)
self.dropout1 = nn.Dropout()

self.linear2 = nn.Linear(hidden_dim, input_dim)
self.dropout2 = nn.Dropout()

def forward(self, x: Tensor) -> Tensor:
# LayerNorm -> Linear -> Dropout -> ReLU -> Linear -> Dropout
y = self.layer_norm(x)
y = self.linear1(y)
y = self.dropout1(y)
y = self.relu(y)
y = self.linear2(y)
y = self.dropout2(y)
return y



class VerifierTest(TestCase):

@unittest.skipIf(sys.version_info >= (3, 11), "dynamo doesn't support 3.11")
def test_verifier(self) -> None:
m = ElementwiseAdd()
egm = capture(m, (torch.randn(100), torch.randn(100)))
# should not raise
check_valid(egm)
self.assertTrue(is_valid(egm))

@unittest.skipIf(sys.version_info >= (3, 11), "dynamo doesn't support 3.11")
def test_verifier_call_module(self) -> None:
m = FeedForwardBlock(10, 10)
gm = torch.fx.symbolic_trace(m)
# this graph contains call_module nodes that are not delegates
with self.assertRaises(SpecViolationError):
check_valid(gm)
self.assertFalse(is_valid(gm))

@unittest.skipIf(sys.version_info >= (3, 11), "dynamo doesn't support 3.11")
def test_verifier_no_functional(self) -> None:
m = ElementwiseAdd()
egm = capture(m, (torch.randn(100), torch.randn(100)))
for node in egm.graph.nodes:
if node.target == torch.ops.aten.add.Tensor:
node.target = torch.ops.aten.add.out
with self.assertRaises(SpecViolationError):
check_valid(egm)
self.assertFalse(is_valid(egm))

@unittest.skipIf(sys.version_info >= (3, 11), "dynamo doesn't support 3.11")
def test_aten_dialect(self) -> None:
m = ElementwiseAdd()
egm = capture(m, (torch.randn(100), torch.randn(100)))
check_valid_aten_dialect(egm)
self.assertTrue(is_valid_aten_dialect(egm))

@unittest.skipIf(sys.version_info >= (3, 11), "dynamo doesn't support 3.11")
def test_aten_wrong_mem_format(self) -> None:
class TestModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.a = torch.nn.parameter.Parameter(
torch.randn(1, 3, 100, 100).to(memory_format=torch.channels_last)
)

def forward(self, x):
return self.a + x

m = TestModel()
egm = capture(m, (torch.randn(1, 3, 100, 100),))
egm._apply(lambda t: t.to(memory_format=torch.channels_last))
with self.assertRaises(SpecViolationError):
check_valid_aten_dialect(egm)
self.assertFalse(is_valid_aten_dialect(egm))

@unittest.skipIf(sys.version_info >= (3, 11), "dynamo doesn't support 3.11")
def test_aten_wrong_mem_format_buffer(self) -> None:
class TestModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer(
"a",
torch.randn(1, 3, 100, 100).to(memory_format=torch.channels_last),
)

def forward(self, x):
return self.a + x

m = TestModel()
egm = capture(m, (torch.randn(1, 3, 100, 100),))
egm._apply(lambda t: t.to(memory_format=torch.channels_last))
with self.assertRaises(SpecViolationError):
check_valid_aten_dialect(egm)
self.assertFalse(is_valid_aten_dialect(egm))


if __name__ == '__main__':
unittest.main()
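The tests above only exercise the base and ATen-dialect checks; verifier.py (further down in this commit) also defines edge-dialect helpers that are not covered here. A hedged sketch of what a follow-up test in the same style might look like — whether it passes depends on the exported graph carrying the fake-tensor 'val' metadata that the edge-dialect check reads:

    # Hypothetical follow-up test, not part of this commit.
    from torch.fx.verifier import check_valid_edge_dialect, is_valid_edge_dialect

    class VerifierEdgeDialectTest(TestCase):
        @unittest.skipIf(sys.version_info >= (3, 11), "dynamo doesn't support 3.11")
        def test_edge_dialect(self) -> None:
            m = ElementwiseAdd()
            egm = capture(m, (torch.randn(100), torch.randn(100)))
            # Both inputs share a dtype, so the edge-dialect dtype rule should hold,
            # assuming make_fx recorded FakeTensor 'val' metadata on every node.
            check_valid_edge_dialect(egm)
            self.assertTrue(is_valid_edge_dialect(egm))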
1 change: 1 addition & 0 deletions test/test_fx.py
@@ -46,6 +46,7 @@
from fx.test_common_passes import TestCommonPass # noqa: F401
from fx.test_cse_pass import TestCSEPass # noqa: F401
from fx.test_matcher_utils import TestMatcher # noqa: F401
from fx.test_verifier import VerifierTest # noqa: F401

from fx.test_gradual_type import AnnotationsTest # noqa: F401
from fx.test_gradual_type import TypeCheckerTest # noqa: F401
165 changes: 165 additions & 0 deletions torch/fx/verifier.py
@@ -0,0 +1,165 @@
import itertools
import operator
from collections.abc import Iterable

import torch
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx import GraphModule


ALLOWED_META_KEYS = {"spec", "stack_trace"}

@torch.fx._compatibility.compatibility(is_backward_compatible=False)
class SpecViolationError(Exception):
pass

@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def is_functional(op: OpOverload) -> bool:
return not op._schema.is_mutable


@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def _check_has_fake_tensor(node: torch.fx.Node) -> None:
def _check_is_fake_tensor(val):
if isinstance(val, FakeTensor):
return True
if isinstance(val, Iterable):
return all(_check_is_fake_tensor(x) for x in val)
return False

val = node.meta.get("val")
if not _check_is_fake_tensor(val):
raise SpecViolationError("Node.meta {} is missing val field.".format(node.name))


@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def check_valid(gm: GraphModule) -> None: # noqa: C901

for node in gm.graph.nodes:
# TODO(T140410192): should have fake tensor for all dialects
# _check_has_fake_tensor(node)
# TODO(qihan): Check for "val" which is not there yet.
if node.op == "call_method":
# what is delegates in ATen dialect?
raise SpecViolationError(
"call_module can only be used for delegates, got a object of class '{}.{}' instead".format(
type(node.args[0]).__module__, type(node.args[0]).__name__
),
)

if node.op == "call_module":
raise SpecViolationError(
"call_module is not valid: got a call to module '{}'".format(node.target),
)

if node.op == "call_function":
op_name = (
node.target.name
if hasattr(node.target, "name")
else node.target.__name__
)
is_builtin_func = (node.target == operator.getitem or node.target.__name__ in [
'while_loop',
'cond',
])
if not isinstance(node.target, OpOverload) and not is_builtin_func:
raise SpecViolationError(
"Operator '{}' is not a registered Op".format(op_name),
)
# All ops must be functional
# TODO(qihan): use node.target.is_functional when PR #83134 lands
if not is_builtin_func and not is_functional(node.target):
raise SpecViolationError(
"operator '{}' is not functional".format(op_name),
)

if isinstance(node.target, OpOverload):
stacktrace = node.meta.get("stack_trace")

if stacktrace is None:
raise SpecViolationError(
"node of name '{}' for operator '{}' is missing stackstrace".format(
node.name, op_name
),
)


@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def is_valid(gm: GraphModule) -> bool:
try:
check_valid(gm)
return True
except SpecViolationError:
return False


@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def check_valid_aten_dialect(gm: GraphModule) -> None:
"""Raises exception if gm is not in aten dialect.
Args:
gm: GraphModule
"""
# must pass the base validity check first
check_valid(gm)
# Operators must be ATen canonical
for n in gm.graph.nodes:
if n.op == "call_function" and isinstance(n.target, OpOverload):
if (
torch.Tag.core not in n.target.tags # type: ignore[attr-defined]
and torch.Tag.view_copy not in n.target.tags # type: ignore[attr-defined]
):
# NOTE(qihan): whether view_copy operators are marked as canonical is still under
# discussion.
raise SpecViolationError(
"Operator {}.{} is not Aten Canonical.".format(
n.target.__module__, n.target.__name__
)
)

# Tensors must be in contiguous memory format
for name, param in itertools.chain(gm.named_parameters(), gm.named_buffers()):
if isinstance(param, torch.Tensor):
if not param.is_contiguous():
raise SpecViolationError(
f"Tensors in Aten dialect must be contiguous, {name} is not contiguous"
)


@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def is_valid_aten_dialect(gm: GraphModule) -> bool:
try:
check_valid_aten_dialect(gm)
return True
except SpecViolationError:
return False


@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def check_valid_edge_dialect(gm: GraphModule) -> None:
check_valid_aten_dialect(gm)

# Additionally, operators in the edge dialect must have inputs of the same dtype
for n in gm.graph.nodes:
if n.op == "call_function" and isinstance(n.target, OpOverload):
_check_has_fake_tensor(n)
dtypes = set()
for arg in n.args:
if isinstance(arg, torch.Tensor):
dtypes.add(arg.dtype)
if isinstance(arg, torch.fx.Node):
dtypes.add(arg.meta["val"].dtype)
if len(dtypes) > 1:
raise SpecViolationError(
"Operators of Edge dialect in should work on tensors of same dtype"
)


@torch.fx._compatibility.compatibility(is_backward_compatible=False)
def is_valid_edge_dialect(gm: GraphModule) -> bool:
try:
check_valid_edge_dialect(gm)
return True
except SpecViolationError:
return False
