Preserve submodule with __set_state__ in freezing (#47308)
Summary:
This PR does the following:

-  Fail freezing if the input (top-level) module has a __setstate__ method.
-  Preserve the attributes of submodules that have a __setstate__ method (a condensed sketch of both behaviors follows this list).
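
A minimal sketch of the resulting behavior, condensed from the tests added below. The module names Sub and Top are illustrative stand-ins, not part of the change:

import torch
import torch.nn as nn

class Sub(nn.Module):
    def __init__(self):
        super().__init__()
        self.tensor = torch.randn(2, 2)

    @torch.jit.export
    def __getstate__(self):
        return (self.tensor, self.training)

    @torch.jit.export
    def __setstate__(self, state):
        self.tensor = 2 * state[0]
        self.training = state[1]

    def forward(self, x):
        return x + self.tensor

class Top(nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = Sub()

    def forward(self, x):
        return self.sub(x)

# Case 1: the module being frozen defines __setstate__ itself -> freezing fails.
scripted_sub = torch.jit.script(Sub()).eval()
# torch.jit.freeze(scripted_sub)  # RuntimeError: cannot freeze a module that has __set_state__

# Case 2: a submodule defines __setstate__ -> the submodule and its attributes
# are preserved in the frozen module instead of being folded into the graph.
scripted_top = torch.jit.script(Top()).eval()
frozen = torch.jit.freeze(scripted_top)
assert frozen._c.hasattr('sub')
assert frozen._c.sub.hasattr('tensor')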

Fixes #{issue number}

Pull Request resolved: #47308

Reviewed By: eellison

Differential Revision: D24711613

Pulled By: bzinodev

fbshipit-source-id: 22e51417454aaf85cc0ae4acb2dc7fc822f149a2
bzinodev authored and facebook-github-bot committed Dec 10, 2020
1 parent a480ca5 commit a3e1bd1
Showing 2 changed files with 112 additions and 2 deletions.
94 changes: 94 additions & 0 deletions test/jit/test_freezing.py
@@ -140,6 +140,100 @@ def forward(self, x):
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_setstate(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.tensor = torch.randn(2, 2)

            @torch.jit.export
            def __getstate__(self):
                return (self.tensor, self.training)

            @torch.jit.export
            def __setstate__(self, state):
                self.tensor = 2 * state[0]
                self.training = state[1]

            def forward(self, x):
                return x + self.tensor

        m = torch.jit.script(M())
        m.eval()
        with self.assertRaisesRegex(RuntimeError, "cannot freeze a module that has __set_state__"):
            mf = torch.jit.freeze(m)

    def test_freeze_module_with_submodule_setstate(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.tensor = torch.randn(2, 2)

            @torch.jit.export
            def __getstate__(self):
                return (self.tensor, self.training)

            @torch.jit.export
            def __setstate__(self, state):
                self.tensor = 2 * state[0]
                self.training = state[1]

            def forward(self, x):
                return x + self.tensor

        class TestModule(nn.Module):
            def __init__(self):
                super(TestModule, self).__init__()
                self.sub = M()
                self.a = torch.randn(2, 2)
                self.b = 4

            def forward(self, x):
                return self.sub(x) + self.a

        m = torch.jit.script(TestModule())
        m.eval()
        input = torch.randn(2, 2)
        output_s = m.forward(input)
        mf = torch.jit.freeze(m)

        output_f = mf.forward(input)
        buffer = io.BytesIO()
        torch.jit.save(mf._c, buffer)
        buffer.seek(0)
        loaded = torch.jit.load(buffer)
        output_l = loaded.forward(input)

        # Check that the frozen module looks like:
        # module m {
        #   attributes {
        #     sub = ...
        #   }
        #   ...
        #   submodule {
        #     module m {
        #       attributes {
        #         training = ...
        #         tensor = ...
        #       }
        #       ...
        #     }
        #   }
        # }
        mf = mf._c
        self.assertFalse(mf.hasattr('a'))
        self.assertTrue(mf.hasattr('sub'))
        self.assertTrue(mf.sub.hasattr('tensor'))
        self.assertTrue(mf.sub.hasattr('training'))

        # __setstate__ is executed when cloning the module for freezing
        self.assertEqual(mf.sub.tensor, 2 * m.sub.tensor)
        self.assertEqual(output_s + m.sub.tensor, output_f)

        # __setstate__ is executed when loading the frozen module
        self.assertEqual(loaded.sub.tensor, 2 * mf.sub.tensor)
        self.assertEqual(output_l, mf.sub.tensor + output_f)

    def test_freeze_module_with_fork(self):
        class SubModule(nn.Module):
            def __init__(self):
20 changes: 18 additions & 2 deletions torch/csrc/jit/passes/freeze_module.cpp
@@ -241,11 +241,20 @@ class AttributePropagator {
          }

          auto attr = attrModule.attr(name);
          auto mptr = attrModule._ivalue();
          if (n->kind() == prim::GetAttr) {
            auto type = n->output()->type();
            // Do not record submodules. Their attributes are tracked
            // individually.
            if (attr.isObject() || !AliasDb::isMutableType(attr.type())) {
            if (attr.isObject()) {
              auto submodule = attr.toModule();
              // A submodule that defines __setstate__ is recorded as a
              // mutable attribute so it is preserved rather than folded.
              if (submodule.find_method("__setstate__")) {
                insertMutableAttr(name, attr, mptr);
              }
              continue;
            }

            if (!AliasDb::isMutableType(attr.type())) {
              continue;
            }
            usedAttrs_.insert(attr);
@@ -256,7 +265,6 @@
                n->kind() == prim::GetAttr ? "attribute: " + name + " in %" +
                    n->output()->debugName() + " has inplace writer"
                                           : "attribute: " + name + " is set");
            auto mptr = attrModule._ivalue();
            insertMutableAttr(name, attr, mptr);
          }
        } else if (n->kind() == prim::fork) {
@@ -525,6 +533,11 @@ class AttributePropagator {
        return true;
      }
    }

    // Preserve any submodule that defines __setstate__: treat it as escaping
    // so the whole submodule and its attributes are not folded away.
    if (subModule.find_method("__setstate__")) {
      return true;
    }

    return preservedSubModule_.count(subModule._ivalue());
  }

@@ -751,6 +764,9 @@ Module freeze_module(
    std::vector<std::string> preservedAttrs,
    bool freezeInterfaces,
    bool preserveParameters) {
  // Reject a top-level module that defines __setstate__: freezing folds
  // attributes, and __setstate__ would rebuild them when the module is loaded.
  TORCH_CHECK(
      !module.find_method("__setstate__"),
      "cannot freeze a module that has __set_state__");
  Method method = module.get_method("forward");
  // Check that module does not return itself.
  for (auto& output : method.graph()->outputs()) {
