conv2d (#61093)
Summary: Pull Request resolved: #61093

Test Plan: Imported from OSS

Reviewed By: eellison

Differential Revision: D29562478

Pulled By: migeed-z

fbshipit-source-id: d41f3a9526ee52a9571cb861be03bf9ae176a373
migeed-z authored and facebook-github-bot committed Jul 9, 2021
1 parent 5fbc853 commit d52ebf2
Showing 2 changed files with 226 additions and 1 deletion.
165 changes: 165 additions & 0 deletions test/fx/test_gradual_type.py
@@ -7,6 +7,11 @@
from torch.fx.experimental.rewriter import RewritingTracer
from torch.fx import GraphModule

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)

class AnnotationsTest(unittest.TestCase):

def test_annotations(self):
@@ -330,6 +335,166 @@ def forward(self, x: Dyn):
with self.assertRaises(TypeError):
tc.type_check()

def test_type_check_conv2D(self):
class BasicBlock(torch.nn.Module):
def __init__(self, inplanes, planes, stride=1, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = torch.nn.BatchNorm2d
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)

def forward(self, x: Dyn):
identity = x
out: TensorType((2, 2, Dyn, 4)) = self.conv1(x)
out += identity
return out

B = BasicBlock(2, 2)
ast_rewriter = RewritingTracer()
graph = ast_rewriter.trace(B)
traced = GraphModule(ast_rewriter.root, graph, "gm")
tc = GraphTypeChecker({}, traced)
tc.type_check()
for n in graph.nodes:
if n.op == 'placeholder':
assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn))
if n.op == 'call_function':
assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn))
if n.op == 'output':
assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn))
if n.op == 'call_module':
assert n.type == TensorType((2, 2, Dyn, 4))

def test_type_check_conv2D_2(self):
class BasicBlock(torch.nn.Module):
def __init__(self, inplanes, planes, stride=1, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = torch.nn.BatchNorm2d
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)

def forward(self, x: TensorType((5, 2, 3, 4))):
identity = x
out = self.conv1(x)
out += identity
return out

B = BasicBlock(2, 2)
b = B.forward(torch.rand(5, 2, 3, 4))

ast_rewriter = RewritingTracer()
graph = ast_rewriter.trace(B)
traced = GraphModule(ast_rewriter.root, graph, "gm")
tc = GraphTypeChecker({}, traced)
tc.type_check()
t = TensorType((5, 2, 3, 4))
for n in graph.nodes:
if n.op == 'placeholder':
assert n.type == t
if n.op == 'call_function':
assert n.type == t
if n.op == 'output':
assert torch.Size(n.type.__args__) == b.shape
if n.op == 'call_module':
assert n.type == t

B = BasicBlock(1, 2)
ast_rewriter = RewritingTracer()
graph = ast_rewriter.trace(B)
traced = GraphModule(ast_rewriter.root, graph, "gm")
tc = GraphTypeChecker({}, traced)
with self.assertRaises(TypeError):
tc.type_check()

def test_type_check_conv2D_2_fully_static(self):
annotation_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
(10, Dyn, 13, 14), (Dyn, Dyn, Dyn, 3)]
input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
(10, 15, 13, 14), (1, 2, 2, 3)]
intermediate_types = [(1, Dyn, Dyn, 7), (2, Dyn, 4, 6), (10, 15, Dyn, 5),
(10, 15, 7, 7), (1, Dyn, Dyn, Dyn)]
in_planes_list = [2, 5, 15, 15, 2]
stride_list = [1, 2, 3, 2, 2]
out_planes_list = [2, 5, 15, 15, 2]
groups_list = [1, 5, 5, 5, 2]
dilation_list = [1, 2, 3, 3, 3]
padding_list = [1, 2, 3, 3, 3]
kernel_size_list = [1, 2, 3, 3, 3]
output_types = [(1, 2, Dyn, 7), (2, 5, 4, 6), (10, 15, Dyn, 5), (10, 15, 7, 7), (1, 2, Dyn, Dyn)]

for i in range(5):
annotation = annotation_list[i]
input = input_list[i]
in_planes = in_planes_list[i]
stride = stride_list[i]
out_planes = out_planes_list[i]
groups = groups_list[i]
dilation = dilation_list[i]
padding = padding_list[i]
kernel_size = kernel_size_list[i]
intermediate_type = intermediate_types[i]

class BasicBlock(torch.nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation):
super(BasicBlock, self).__init__()
self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes,
kernel_size=kernel_size, stride=stride,
padding=padding, groups=groups, bias=False, dilation=dilation)

def forward(self, x):
out = self.conv1(x)
return out

B = BasicBlock(in_planes, out_planes, kernel_size, stride, padding, groups, dilation)
ast_rewriter = RewritingTracer()
graph = ast_rewriter.trace(B)
traced = GraphModule(ast_rewriter.root, graph, "gm")

# annotate our argument
for n in graph.nodes:
if n.op == 'placeholder':
n.type = TensorType(annotation)

b = B.forward(torch.rand(input))
tc = GraphTypeChecker({}, traced)
tc.type_check()

for n in graph.nodes:
if n.op == 'output':
assert is_consistent(n.type, TensorType(b.size()))

# test with intermediate annotations
class BasicBlock(torch.nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation):
super(BasicBlock, self).__init__()
self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes,
kernel_size=kernel_size, stride=stride,
padding=padding, groups=groups, bias=False, dilation=dilation)

def forward(self, x):
out = self.conv1(x)
return out

B = BasicBlock(in_planes, out_planes, kernel_size, stride, padding, groups, dilation)
ast_rewriter = RewritingTracer()
graph = ast_rewriter.trace(B)
traced = GraphModule(ast_rewriter.root, graph, "gm")

            # annotate our intermediate nodes
for n in traced.graph.nodes:
if n.op == 'call_module':
n.type = TensorType(intermediate_type)

tc = GraphTypeChecker({}, traced)
tc.type_check()

for n in traced.graph.nodes:
if n.op == 'output':
assert n.type == TensorType(output_types[i])
assert is_consistent(n.type, TensorType(b.size()))


if __name__ == '__main__':
unittest.main()
62 changes: 61 additions & 1 deletion torch/fx/experimental/graph_gradual_typechecker.py
@@ -5,6 +5,7 @@
from typing import Callable, Dict
from torch.fx.node import Target, Node
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.conv import Conv2d


_INFERENCE_RULES: Dict[Target, Callable] = {}
@@ -22,7 +23,7 @@ def expand_to_tensor_dim(t, n):
return TensorType(tuple(dims))
elif isinstance(t, TensorType):
if len(t.__args__) != n:
-            raise TypeError(f'Cannot apply matching. Tensor {t} has rank {len(t.__args__)}. It should have rank {n}')
+            raise TypeError(f'Cannot extend tensor dimension. Tensor {t} has rank {len(t.__args__)}. It should have rank {n}')
return t
else:
raise TypeError(f'Cannot match the type {t}')
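
# Editor's note (illustration, not part of this commit): expected behaviour of
# expand_to_tensor_dim, assuming the elided branch expands Dyn into n dimensions:
#   expand_to_tensor_dim(Dyn, 4)                       -> TensorType((Dyn, Dyn, Dyn, Dyn))
#   expand_to_tensor_dim(TensorType((1, 2, 3, 4)), 4)  -> TensorType((1, 2, 3, 4))
#   expand_to_tensor_dim(TensorType((1, 2)), 4)        -> raises TypeError (rank mismatch)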
@@ -207,6 +208,65 @@ def bn2d_inference_rule(n: Node, module_instance):
else:
raise TypeError(f'Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}')

def calculate(d_in, module_instance, index):
"""
For calculating h_in and w_out.
"""
if d_in == Dyn:
return Dyn

elif isinstance(d_in, int):
n = d_in + 2 * module_instance.padding[index] - \
module_instance.dilation[index] * \
(module_instance.kernel_size[index] - 1) - 1

        return (n // module_instance.stride[index]) + 1
else:
raise TypeError(f'{d_in} in {module_instance} must be a number or Dyn')
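
# Editor's sketch (illustration, not part of this commit): `calculate` applies the
# standard Conv2d output-size formula to one spatial dimension:
#   d_out = floor((d_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1
# e.g. a 3x3 convolution with padding=1, dilation=1, stride=1 and h_in=3 gives
#   (3 + 2*1 - 1*(3 - 1) - 1) // 1 + 1 == 3, so the spatial size is preserved.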


def get_greatest_upper_bound(type1, type2):
"""
Get the most precise type that's consistent with the given types
"""
if type1 == Dyn:
return type2
elif type2 == Dyn:
return type1
elif isinstance(type1, TensorType) and isinstance(type2, TensorType):
assert is_consistent(type1, type2)
gub = [t1 if is_more_precise(t1, t2) else t2 for (t1, t2) in zip(type1.__args__, type2.__args__)]
return TensorType(tuple(gub))
else:
raise NotImplementedError(f'Greatest upper bound not yet implemented for these types {type1}, {type2}')
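
# Editor's illustration (not part of this commit): the greatest upper bound keeps the
# more precise dimension wherever the two types are consistent, e.g.
#   get_greatest_upper_bound(TensorType((1, Dyn, 3, Dyn)), TensorType((Dyn, 2, 3, 4)))
#       -> TensorType((1, 2, 3, 4))
#   get_greatest_upper_bound(Dyn, TensorType((1, 2)))  -> TensorType((1, 2))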

@register_inference_rule(Conv2d)
def conv2d_inference_rule(n: Node, module_instance):
"""
    Given a Conv2d instance and a node, check the following conditions:
    - the input type can be expanded to a rank-4 tensor type: t = (x_1, x_2, H, W)
    - the current node type can be expanded to a rank-4 tensor type: t' = (x_1', x_2', x_3', x_4')
    - x_2 is consistent with the module's in_channels
    - let o = (x_1, out_channels, H_out, W_out);
      the output type is then the greatest upper bound of o and the existing node type t'.
"""
assert isinstance(n.args[0], Node)
n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4)
arg_type = n.args[0].type
curr_node_type = expand_to_tensor_dim(n.type, 4)

if is_consistent(arg_type.__args__[1], module_instance.in_channels):
w_in = arg_type.__args__[3]
h_in = arg_type.__args__[2]
h_out = calculate(h_in, module_instance, 0)
w_out = calculate(w_in, module_instance, 1)
new_type = TensorType((arg_type.__args__[0], module_instance.out_channels, h_out, w_out))
gub = get_greatest_upper_bound(new_type, curr_node_type)
n.type = gub
return n.type
else:
        raise TypeError(f'Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}')
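
# Editor's sketch (illustration, not part of this commit), assuming unannotated nodes
# are treated as Dyn: for the test helper conv3x3(2, 2) (kernel_size=3, padding=1,
# dilation=1, stride=1) applied to a placeholder annotated TensorType((5, 2, 3, 4)),
#   h_out = (3 + 2 - 2 - 1) // 1 + 1 == 3,   w_out = (4 + 2 - 2 - 1) // 1 + 1 == 4,
# so new_type = TensorType((5, 2, 3, 4)); its greatest upper bound with the conv
# node's Dyn type is TensorType((5, 2, 3, 4)), as asserted in test_type_check_conv2D_2.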

class GraphTypeChecker:
def __init__(self, env, traced):
self.env = env
