Skip to content

Commit

Permalink
nomnigraph - easy - some code cleanup for transformations_test (#12101)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #12101

clean up some duplicate test code

Reviewed By: ZolotukhinM

Differential Revision: D10051914

fbshipit-source-id: 698ff144a85e8c70572116c5ddb415cd2396b4e3
  • Loading branch information
duc0 authored and facebook-github-bot committed Oct 1, 2018
1 parent 006171f commit e43ffb0
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 87 deletions.
4 changes: 4 additions & 0 deletions caffe2/python/test_util.py
Expand Up @@ -30,6 +30,10 @@ def randBlobsFloat32(names, *dims, **kwargs):
randBlobFloat32(name, *dims, **kwargs)


def numOps(net):
return len(net.Proto().op)


def str_compare(a, b, encoding="utf8"):
if isinstance(a, bytes):
a = a.decode(encoding)
Expand Down
126 changes: 39 additions & 87 deletions caffe2/python/transformations_test.py
Expand Up @@ -30,134 +30,86 @@


class TestTransformations(tu.TestCase):
def test_transformer_AddNNPACK(self):
def _base_test_net(self):
net = core.Net("net")
net.Conv(["X", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW")
net.Relu(["Y"], ["Y2"])
return net

def _add_nnpack(self, net):
transformer.AddNNPACK(net)
assert tu.str_compare(net.Proto().op[0].engine, "NNPACK")

def test_transformer_FuseNNPACKConvRelu(self):
net = core.Net("net")
net.Conv(["X", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW")
net.Relu(["Y"], ["Y2"])
transformer.AddNNPACK(net) # get the NNPACK engine
assert tu.str_compare(net.Proto().op[0].engine, "NNPACK")
def _fuse_nnpack_convrelu(self, net, expected_result_num_ops,
expected_activation_arg=True):
self._add_nnpack(net)
transformer.FuseNNPACKConvRelu(net)
assert len(net.Proto().op) == 1
self.assertEquals(tu.numOps(net), expected_result_num_ops)
has_activation_arg = False
for arg in net.Proto().op[0].arg:
if tu.str_compare(arg.name, "activation"):
assert tu.str_compare(arg.s, "Relu")
has_activation_arg = True
assert has_activation_arg
if expected_activation_arg:
assert has_activation_arg
else:
assert not has_activation_arg

def test_transformer_AddNNPACK(self):
net = self._base_test_net()
net.Relu(["Y"], ["Y2"])
self._add_nnpack(net)

def test_transformer_FuseNNPACKConvRelu(self):
net = self._base_test_net()
net.Relu(["Y"], ["Y2"])
self._fuse_nnpack_convrelu(net, 1)

def test_noFuseNNPACKConvRelu(self):
net = core.Net("net")
net.Conv(["X", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW")
net = self._base_test_net()
net.Relu(["Y"], ["Y2"])
net.Relu(["Y"], ["Y3"])
transformer.AddNNPACK(net) # get the NNPACK engine
assert tu.str_compare(net.Proto().op[0].engine, "NNPACK")
transformer.FuseNNPACKConvRelu(net)
assert len(net.Proto().op) == 3
has_activation_arg = False
for arg in net.Proto().op[0].arg:
if tu.str_compare(arg.name, "activation") and tu.str_compare(arg.s, "Relu"):
has_activation_arg = True
assert not has_activation_arg
self._fuse_nnpack_convrelu(net, 3, expected_activation_arg=False)

def test_transformer_FuseNNPACKConvReluNoInplace(self):
net = core.Net("net")
net.Conv(["X", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW")
net = self._base_test_net()
net.Relu(["Y"], ["X"])
transformer.AddNNPACK(net) # get the NNPACK engine
assert tu.str_compare(net.Proto().op[0].engine, "NNPACK")
transformer.FuseNNPACKConvRelu(net)
assert len(net.Proto().op) == 1
has_activation_arg = False
for arg in net.Proto().op[0].arg:
if tu.str_compare(arg.name, "activation"):
assert tu.str_compare(arg.s, "Relu")
has_activation_arg = True
assert has_activation_arg
self._fuse_nnpack_convrelu(net, 1)
assert net.Proto().op[0].output[0] != net.Proto().op[0].input[0]

def test_transformer_FuseNNPACKConvReluInplaceRelu(self):
net = core.Net("net")
net.Conv(["X", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW")
net = self._base_test_net()
net.Relu(["Y"], ["Y"])
transformer.AddNNPACK(net) # get the NNPACK engine
assert tu.str_compare(net.Proto().op[0].engine, "NNPACK")
transformer.FuseNNPACKConvRelu(net)
assert len(net.Proto().op) == 1
has_activation_arg = False
for arg in net.Proto().op[0].arg:
if tu.str_compare(arg.name, "activation"):
assert tu.str_compare(arg.s, "Relu")
has_activation_arg = True
assert has_activation_arg
self._fuse_nnpack_convrelu(net, 1)
assert net.Proto().op[0].output[0] != net.Proto().op[0].input[0]

def test_transformer_FuseNNPACKConvReluPingPongNaming(self):
net = core.Net("net")
net.Conv(["X", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW")
net = self._base_test_net()
net.Relu(["Y"], ["X"])
net.Conv(["X", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW")
transformer.AddNNPACK(net) # get the NNPACK engine
assert tu.str_compare(net.Proto().op[0].engine, "NNPACK")
transformer.FuseNNPACKConvRelu(net)
assert len(net.Proto().op) == 2
has_activation_arg = False
for arg in net.Proto().op[0].arg:
if tu.str_compare(arg.name, "activation"):
assert tu.str_compare(arg.s, "Relu")
has_activation_arg = True
assert has_activation_arg
self._fuse_nnpack_convrelu(net, 2)
assert net.Proto().op[0].output[0] != net.Proto().op[0].input[0]
assert net.Proto().op[1].output[0] != net.Proto().op[1].input[0]

def test_transformer_FuseNNPACKConvReluFollowedByMultipleInputOp(self):
net = core.Net("net")
net.Conv(["X", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW")
net = self._base_test_net()
net.Relu(["Y"], ["Y2"])
net.Conv(["Y2", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW")
net.Relu(["Y"], ["Y2"])
transformer.AddNNPACK(net) # get the NNPACK engine
assert tu.str_compare(net.Proto().op[0].engine, "NNPACK")
transformer.FuseNNPACKConvRelu(net)
assert len(net.Proto().op) == 2
has_activation_arg = False
for arg in net.Proto().op[0].arg:
if tu.str_compare(arg.name, "activation"):
assert tu.str_compare(arg.s, "Relu")
has_activation_arg = True
assert has_activation_arg
self._fuse_nnpack_convrelu(net, 2)
assert net.Proto().op[0].output[0] != net.Proto().op[0].input[0]
assert net.Proto().op[1].output[0] != net.Proto().op[1].input[0]

def test_transformer_FuseNNPACKConvReluInplaceFollowedByMultipleInputOp(self):
net = core.Net("net")
net.Conv(["X", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW")
net = self._base_test_net()
net.Relu(["Y"], ["Y"])
net.Conv(["Y", "w", "b"], ["Y2"], stride=1, pad=0, kernel=3, order="NCHW")
net.Relu(["Y2"], ["Y2"])
transformer.AddNNPACK(net) # get the NNPACK engine
assert tu.str_compare(net.Proto().op[0].engine, "NNPACK")
transformer.FuseNNPACKConvRelu(net)
assert len(net.Proto().op) == 2
has_activation_arg = False
for arg in net.Proto().op[0].arg:
if tu.str_compare(arg.name, "activation"):
assert tu.str_compare(arg.s, "Relu")
has_activation_arg = True
assert has_activation_arg
self._fuse_nnpack_convrelu(net, 2)
assert net.Proto().op[0].output[0] != net.Proto().op[0].input[0]
assert net.Proto().op[1].output[0] != net.Proto().op[1].input[0]

def test_transformer_SinkMaxPool(self):
net = core.Net("net")
net.Conv(["X", "w", "b"], ["Y"], stride=1, pad=0, kernel=3, order="NCHW")
net = self._base_test_net()
net.MaxPool(["Y"], ["Y1"], kernel=3)
net.Relu(["Y1"], ["Y1"])
transformer.SinkMaxPool(net)
Expand Down Expand Up @@ -205,7 +157,7 @@ def test_transformer_FuseConvBN(self, size, input_channels, seed, order, epsilon
transformer.FuseConvBN(net)

# Ensure fusion
assert len(net.Proto().op) == 1
assert tu.numOps(net) == 1
workspace.RunNetOnce(net)
postTransformOutput = workspace.FetchBlob("Y2").flatten()
# Check that there is no numerical difference
Expand Down Expand Up @@ -256,7 +208,7 @@ def test_transformer_FuseConvBNNoConvBias(self, size, input_channels, seed, orde
transformer.FuseConvBN(net)

# Ensure fusion
assert len(net.Proto().op) == 1
assert tu.numOps(net) == 1
workspace.RunNetOnce(net)
postTransformOutput = workspace.FetchBlob("Y2").flatten()
# Check that there is no numerical difference
Expand Down Expand Up @@ -307,7 +259,7 @@ def test_transformer_FuseConvBNNoConvBiasDuplicatedName(self, size, input_channe
transformer.FuseConvBN(net)

# Ensure fusion
assert len(net.Proto().op) == 1
assert tu.numOps(net) == 1
workspace.RunNetOnce(net)
postTransformOutput = workspace.FetchBlob("Y2").flatten()
print("pre")
Expand Down Expand Up @@ -365,7 +317,7 @@ def test_transformer_FuseConv3DBN(
transformer.FuseConvBN(net)

# Ensure fusion
assert len(net.Proto().op) == 1
assert tu.numOps(net) == 1
workspace.RunNetOnce(net)
postTransformOutput = workspace.FetchBlob("Y2").flatten()
# Check that there is no numerical difference
Expand Down

0 comments on commit e43ffb0

Please sign in to comment.