Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/pytorch/pytorch into fix_…
Browse files Browse the repository at this point in the history
…pow_complex
  • Loading branch information
RockingJavaBean committed Sep 28, 2020
2 parents 903d525 + 95a97e5 commit 9c9101e
Show file tree
Hide file tree
Showing 16 changed files with 480 additions and 251 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
Expand Up @@ -272,6 +272,7 @@ namespace c10 {
_(prim, grad) \
_(aten, zero_) \
_(aten, fill_) \
_(aten, masked_fill_) \
FORALL_ATEN_BASE_SYMBOLS(_) \
_(onnx, Add) \
_(onnx, Concat) \
Expand Down
71 changes: 71 additions & 0 deletions test/jit/test_freezing.py
Expand Up @@ -524,6 +524,77 @@ def forward(self, x):
self.assertEqual(output_s, output_f)


def test_freeze_module_with_preserve_sub_module(self):
class SubModule(nn.Module):
def __init__(self):
super(SubModule, self).__init__()
self.a = torch.tensor([1.1])
self.b = 2.2

def forward(self, x):
return self.a

class TestModule(nn.Module):
def __init__(self):
super(TestModule, self).__init__()
self.sub1 = SubModule() # aliasing
self.sub2 = SubModule()

def forward(self, x):
return self.sub2(x) + self.sub1(x)
m = TestModule()
ms = torch.jit.script(m)
ms.eval()
mf = torch._C._freeze_module(ms._c, ["sub1"])

# Test that 'sub1' is preserved entirely and 'sub2' is completely folded
self.assertTrue(mf.hasattr('sub1'))
self.assertTrue(mf.sub1.hasattr('a'))
self.assertTrue(mf.sub1.hasattr('b'))
self.assertFalse(mf.hasattr('sub2'))
input = torch.randn(2, 2)
output_s = ms.forward(input)
output_f = mf.forward(input)
self.assertEqual(output_s, output_f)

def test_freeze_module_with_preserve_sub_module_and_mutation(self):
class SubModule(nn.Module):
def __init__(self):
super(SubModule, self).__init__()
self.a = torch.tensor([1.1])
self.b = 2.2

def forward(self, x):
self.a[0] = 3.3
return self.a

class TestModule(nn.Module):
def __init__(self):
super(TestModule, self).__init__()
self.sub1 = SubModule() # aliasing
self.sub2 = SubModule()

def forward(self, x):
return self.sub2(x) + self.sub1(x)
m = TestModule()
ms = torch.jit.script(m)
ms.eval()
mf = torch._C._freeze_module(ms._c, ["sub1"])

# Test that be both sub1 and sub1 are preserved and 'b' is preserved
# even if it is not used. To fulfill user request to preserve 'sub1'
self.assertTrue(mf.hasattr('sub1'))
self.assertTrue(mf.sub1.hasattr('a'))
self.assertTrue(mf.sub1.hasattr('b'))
self.assertTrue(mf.hasattr('sub2'))
self.assertTrue(mf.sub2.hasattr('a'))
self.assertTrue(mf.sub2.hasattr('b'))
input = torch.randn(2, 2)
output_s = ms.forward(input)
output_f = mf.forward(input)
self.assertEqual(output_s, output_f)


def test_freeze_module_with_helperfunction(self):
class SubModule(nn.Module):
def __init__(self):
Expand Down
Expand Up @@ -8,6 +8,11 @@ graph {
output: "2"
name: "SoftmaxCrossEntropyLoss_0"
op_type: "SoftmaxCrossEntropyLoss"
attribute {
name: "ignore_index"
i: -100
type: INT
}
attribute {
name: "reduction"
s: "mean"
Expand Down
Expand Up @@ -8,6 +8,11 @@ graph {
output: "2"
name: "SoftmaxCrossEntropyLoss_0"
op_type: "SoftmaxCrossEntropyLoss"
attribute {
name: "ignore_index"
i: -100
type: INT
}
attribute {
name: "reduction"
s: "mean"
Expand Down
Expand Up @@ -8,6 +8,11 @@ graph {
output: "2"
name: "SoftmaxCrossEntropyLoss_0"
op_type: "SoftmaxCrossEntropyLoss"
attribute {
name: "ignore_index"
i: -100
type: INT
}
attribute {
name: "reduction"
s: "none"
Expand Down
Expand Up @@ -8,6 +8,11 @@ graph {
output: "2"
name: "SoftmaxCrossEntropyLoss_0"
op_type: "SoftmaxCrossEntropyLoss"
attribute {
name: "ignore_index"
i: -100
type: INT
}
attribute {
name: "reduction"
s: "mean"
Expand Down
Expand Up @@ -9,6 +9,11 @@ graph {
output: "3"
name: "SoftmaxCrossEntropyLoss_0"
op_type: "SoftmaxCrossEntropyLoss"
attribute {
name: "ignore_index"
i: -100
type: INT
}
attribute {
name: "reduction"
s: "mean"
Expand Down
16 changes: 4 additions & 12 deletions test/onnx/test_models.py
Expand Up @@ -49,7 +49,6 @@ class TestModels(TestCase):
opset_version = _export_onnx_opset_version

def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7):
self.is_script_test_enabled = True
with torch.onnx.select_model_mode_for_export(model, None):
graph = torch.onnx.utils._trace(model, inputs, OperatorExportTypes.ONNX)
torch._C._jit_pass_lint(graph)
Expand Down Expand Up @@ -94,14 +93,12 @@ def test_srresnet(self):
self.exportTest(toC(SRResNet(rescale_factor=4, n_filters=64, n_blocks=8)), toC(x))

@skipIfNoLapack
@disableScriptTest()
def test_super_resolution(self):
x = Variable(
torch.randn(BATCH_SIZE, 1, 224, 224).fill_(1.0)
)
self.exportTest(toC(SuperResolutionNet(upscale_factor=3)), toC(x), atol=1e-6)

@disableScriptTest()
def test_alexnet(self):
x = Variable(
torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
Expand Down Expand Up @@ -137,13 +134,12 @@ def test_vgg19_bn(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
self.exportTest(toC(vgg19_bn()), toC(x))

@disableScriptTest()
def test_resnet(self):
# ResNet50 model
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
self.exportTest(toC(resnet50()), toC(x), atol=1e-6)

@disableScriptTest()
@disableScriptTest() # None type in outputs
def test_inception(self):
x = Variable(
torch.randn(BATCH_SIZE, 3, 299, 299) + 1.)
Expand Down Expand Up @@ -208,22 +204,20 @@ def test_qat_resnet(self):

self.exportTest(toC(qat_resnet50), toC(x))

@disableScriptTest()
@disableScriptTest() # None type in outputs
def test_googlenet(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
self.exportTest(toC(googlenet()), toC(x), rtol=1e-3, atol=1e-5)

@disableScriptTest()
def test_mnasnet(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
self.exportTest(toC(mnasnet1_0()), toC(x), rtol=1e-3, atol=1e-5)

@disableScriptTest()
def test_mobilenet(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
self.exportTest(toC(mobilenet_v2()), toC(x), rtol=1e-3, atol=1e-5)

@disableScriptTest()
@disableScriptTest() # prim_data
def test_shufflenet(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
self.exportTest(toC(shufflenet_v2_x1_0()), toC(x), rtol=1e-3, atol=1e-5)
Expand All @@ -238,20 +232,18 @@ def test_deeplab(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
self.exportTest(toC(deeplabv3_resnet101()), toC(x), rtol=1e-3, atol=1e-5)

@disableScriptTest()
def test_r3d_18_video(self):
x = Variable(torch.randn(1, 3, 4, 112, 112).fill_(1.0))
self.exportTest(toC(r3d_18()), toC(x), rtol=1e-3, atol=1e-5)

@disableScriptTest()
def test_mc3_18_video(self):
x = Variable(torch.randn(1, 3, 4, 112, 112).fill_(1.0))
self.exportTest(toC(mc3_18()), toC(x), rtol=1e-3, atol=1e-5)

@disableScriptTest()
def test_r2plus1d_18_video(self):
x = Variable(torch.randn(1, 3, 4, 112, 112).fill_(1.0))
self.exportTest(toC(r2plus1d_18()), toC(x), rtol=1e-3, atol=1e-5)


if __name__ == '__main__':
run_tests()
24 changes: 21 additions & 3 deletions test/onnx/test_models_onnxruntime.py
Expand Up @@ -15,13 +15,31 @@ def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):
input=inputs, rtol=rtol, atol=atol)

if self.is_script_test_enabled and opset_version > 11:
TestModels.use_new_jit_passes = True
TestModels.onnx_shape_inference = True

outputs = model(inputs)
script_model = torch.jit.script(model)
run_model_test(self, script_model, False, example_outputs=outputs,
input=inputs, rtol=rtol, atol=atol, use_new_jit_passes=True)
input=inputs, rtol=rtol, atol=atol)


TestModels = type(str("TestModels"),
(unittest.TestCase,),
dict(TestModels.__dict__,
is_script_test_enabled=False,
exportTest=exportTest))


# model tests for scripting with new JIT APIs and shape inference
TestModels_new_jit_API = type(str("TestModels_new_jit_API"),
(unittest.TestCase,),
dict(TestModels.__dict__,
exportTest=exportTest,
is_script_test_enabled=True,
use_new_jit_passes=True,
onnx_shape_inference=True))


if __name__ == '__main__':
TestModels.is_script_test_enabled = True
TestModels.exportTest = exportTest
unittest.main()

0 comments on commit 9c9101e

Please sign in to comment.