Summary:
Test Plan: Unittest
Reviewers:
Subscribers:
Tasks:
Tags:
Showing 2 changed files with 195 additions and 8 deletions.
@@ -0,0 +1,187 @@
import unittest
from torch.fx.verifier import (
    ExirSpecViolationError,
    check_valid_aten_dialect,
    check_valid_edge_dialect,
    check_valid_exir,
    is_valid_aten_dialect,
    is_valid_edge_dialect,
    is_valid_exir,
)

from typing import List, Optional, Tuple, Union

from torch.testing._internal.common_utils import TestCase
import torch  # noqa: F401
import torch.nn as nn
from torch import Tensor
import torch._dynamo as torchdynamo
import copy
from functorch import make_fx
from functorch.experimental import functionalize


@torch.no_grad()
def capture(f, args):
    # Export `f` through torchdynamo into an ATen-level FX graph, then
    # functionalize it (removing mutations) and retrace with make_fx.
    torchdynamo.config.capture_scalar_outputs = True
    torchdynamo.config.guard_nn_modules = True
    torchdynamo.config.dynamic_shapes = True
    torchdynamo.config.specialize_int_float = True
    torchdynamo.config.allow_rnn = True
    torchdynamo.config.verbose = True
    torchdynamo.reset()
    f, _ = torchdynamo.export(
        f,
        *copy.deepcopy(args),
        aten_graph=True,
        tracing_mode='fake',
    )
    f = functionalize(f, remove='mutations')
    gm = make_fx(f, tracing_mode='fake', _allow_non_fake_inputs=True)(*args)
    return gm
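
# A minimal usage sketch for `capture` together with the verifier checks
# exercised in the tests below (module and input shapes are illustrative):
#
#   egm = capture(ElementwiseAdd(), (torch.randn(100), torch.randn(100)))
#   check_valid_exir(egm)        # raises ExirSpecViolationError on violation
#   assert is_valid_exir(egm)    # non-raising boolean counterpart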


class Transpose(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor, dim0: int, dim1: int) -> Tensor:
        return x.transpose(dim0, dim1)


class Mul(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, input: Tensor, other: Tensor) -> Tensor:
        # or return torch.mul(input, other)
        return input * other

    def get_random_inputs(self) -> Tuple[Tensor, Tensor]:
        return (torch.randn(3, 2), torch.randn(3, 2))


class ElementwiseAdd(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        return x + y

    def get_random_inputs(self) -> Tuple[Tensor, Tensor]:
        return (torch.randn(1, 3), torch.randn(1, 3))


class Cat(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    # def forward(self, tensors, dim=0):
    def forward(self, *args: Tensor, dim: int) -> Tensor:
        # `dim` is keyword-only, so every positional argument is a tensor
        # to concatenate.
        tensors = args
        return torch.cat(tensors, dim)


class FeedForwardBlock(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.layer_norm = nn.LayerNorm(input_dim)

        self.relu = nn.ReLU()

        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.dropout1 = nn.Dropout()

        self.linear2 = nn.Linear(hidden_dim, input_dim)
        self.dropout2 = nn.Dropout()

    def forward(self, x: Tensor) -> Tensor:
        # LayerNorm -> Linear -> Dropout -> ReLU -> Linear -> Dropout
        y = self.layer_norm(x)
        y = self.linear1(y)
        y = self.dropout1(y)
        y = self.relu(y)
        y = self.linear2(y)
        y = self.dropout2(y)
        return y
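
# Note: the tests trace FeedForwardBlock with torch.fx.symbolic_trace, which
# keeps its submodules as call_module nodes; since none of those modules are
# delegates, check_valid_exir is expected to reject the resulting graph
# (see test_exir_verifier_call_module below).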


class VerifierTest(TestCase):

    def test_exir_verifier(self) -> None:
        m = ElementwiseAdd()
        egm = capture(m, (torch.randn(100), torch.randn(100)))
        # should not raise
        check_valid_exir(egm)
        self.assertTrue(is_valid_exir(egm))

    def test_exir_verifier_call_module(self) -> None:
        m = FeedForwardBlock(10, 10)
        gm = torch.fx.symbolic_trace(m)
        # the traced graph contains call_module nodes that are not delegates
        with self.assertRaises(ExirSpecViolationError):
            check_valid_exir(gm)
        self.assertFalse(is_valid_exir(gm))

    def test_exir_verifier_no_functional(self) -> None:
        m = ElementwiseAdd()
        egm = capture(m, (torch.randn(100), torch.randn(100)))
        for node in egm.graph.nodes:
            if node.target == torch.ops.aten.add.Tensor:
                node.target = torch.ops.aten.add.out
        with self.assertRaises(ExirSpecViolationError):
            check_valid_exir(egm)
        self.assertFalse(is_valid_exir(egm))
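
    # Rewriting aten.add.Tensor to its out-variant (aten.add.out) above makes
    # the graph non-functional, which the EXIR verifier rejects. The remaining
    # tests exercise the ATen-dialect checks specifically.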

    def test_aten_dialect(self) -> None:
        m = ElementwiseAdd()
        egm = capture(m, (torch.randn(100), torch.randn(100)))
        check_valid_aten_dialect(egm)
        self.assertTrue(is_valid_aten_dialect(egm))

    def test_aten_wrong_mem_format(self) -> None:
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.nn.parameter.Parameter(
                    torch.randn(1, 3, 100, 100).to(memory_format=torch.channels_last)
                )

            def forward(self, x):
                return self.a + x

        m = TestModel()
        egm = capture(m, (torch.randn(1, 3, 100, 100),))
        egm._apply(lambda t: t.to(memory_format=torch.channels_last))
        with self.assertRaises(ExirSpecViolationError):
            check_valid_aten_dialect(egm)
        self.assertFalse(is_valid_aten_dialect(egm))

    def test_aten_wrong_mem_format_buffer(self) -> None:
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer(
                    "a",
                    torch.randn(1, 3, 100, 100).to(memory_format=torch.channels_last),
                )

            def forward(self, x):
                return self.a + x

        m = TestModel()
        egm = capture(m, (torch.randn(1, 3, 100, 100),))
        egm._apply(lambda t: t.to(memory_format=torch.channels_last))
        with self.assertRaises(ExirSpecViolationError):
            check_valid_aten_dialect(egm)
        self.assertFalse(is_valid_aten_dialect(egm))


if __name__ == '__main__':
    unittest.main()
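
To run this suite (Test Plan: Unittest), a standard invocation would be the following; the file path is not shown on this page, so test/fx/test_verifier.py is an assumed name:

    # run directly via the unittest entry point at the bottom of the file
    python test/fx/test_verifier.py
    # or through pytest
    pytest test/fx/test_verifier.py -v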