Revert D22898051: [pytorch][PR] Fix freeze_module pass for sharedtype
Test Plan: revert-hammer

Differential Revision: D22898051 (4665f3f)

Original commit changeset: 8b1d80f0eb40

fbshipit-source-id: 4dc0ba274282a157509db16df13269eed6cd5be9
zou3519 authored and facebook-github-bot committed Aug 12, 2020
1 parent bda0007 commit 3d3752d
Showing 2 changed files with 11 additions and 79 deletions.
54 changes: 0 additions & 54 deletions test/jit/test_freezing.py
@@ -5,8 +5,6 @@

from torch.testing import FileCheck

-from torch.jit._recursive import wrap_cpp_module
-
import io

if __name__ == '__main__':
@@ -1028,55 +1026,3 @@ def modify_a(self, x):
        fm = torch._C._freeze_module(m._c, ["modify_a"])
        FileCheck().check('prim::GetAttr[name="a"]').run(fm.forward.graph)
        FileCheck().check('prim::GetAttr[name="b"]').run(fm.modify_a.graph)

-    def test_module_with_shared_type_instances(self):
-        class Child(nn.Module):
-            def __init__(self):
-                super(Child, self).__init__()
-                self.conv1 = nn.Conv2d(1, 1, 1)
-
-            def forward(self, x):
-                x = self.conv1(x)
-                return x
-
-        class Parent(nn.Module):
-            def __init__(self):
-                super(Parent, self).__init__()
-                self.quant = torch.quantization.QuantStub()
-                self.conv1 = nn.Conv2d(1, 1, 1)
-                self.child = Child()
-                self.child2 = Child()
-                self.dequant = torch.quantization.DeQuantStub()
-
-            def forward(self, x):
-                x = self.quant(x)
-                x = self.conv1(x)
-                x = self.child(x)
-                x = self.child2(x)
-                x = self.dequant(x)
-                return x
-
-        def _static_quant(model):
-            model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
-            torch.quantization.prepare(model, inplace=True)
-            model(torch.rand(4, 1, 4, 4))
-            model = torch.quantization.convert(model, inplace=False)
-            return model
-
-        current_dtype = torch.get_default_dtype()
-        torch.set_default_dtype(torch.float32)
-        data = torch.randn(4, 1, 4, 4)
-        m = Parent()
-        m = _static_quant(m)
-        m = torch.jit.script(m)
-        m.eval()
-        torch._C._jit_pass_inline(m.graph)
-        m_frozen = wrap_cpp_module(torch._C._freeze_module(m._c))
-        # Earlier bug resulted in _packed_params set to false.
-        FileCheck().check_not('_packed_params = False').run(m_frozen._c.dump_to_str(True, True, False))
-
-        m_res = m(data)
-        # It used to segfault while running frozen module.
-        m_frozen_res = m_frozen(data)
-        self.assertEqual(m_res, m_frozen_res)
-        torch.set_default_dtype(current_dtype)
36 changes: 11 additions & 25 deletions torch/csrc/jit/passes/freeze_module.cpp
@@ -11,25 +11,6 @@ namespace torch {
namespace jit {

namespace {
-ModulePtr getModulePtrForGetAttrNode(
-    const Node* node,
-    const std::shared_ptr<Graph>& graph,
-    const Module& graph_input_module) {
-  std::vector<std::string> names;
-  names.clear();
-  while (!(node->outputs()[0]->type() == graph->inputs()[0]->type())) {
-    TORCH_INTERNAL_ASSERT(
-        node->kind() == prim::GetAttr, "Expected prim::GetAttr nodes");
-    names.insert(names.begin(), node->s(attr::name));
-    node = node->inputs()[0]->node();
-  }
-  // Copy/paste from quantization/helper.h
-  Module m = graph_input_module;
-  for (const auto& p : names) {
-    m = m.attr(p).toModule();
-  }
-  return m._ivalue();
-}

class AttributePropagator {
public:
@@ -452,12 +433,17 @@ class AttributePropagator {
        }
        if (n->kind() == prim::GetAttr) {
          auto& name = n->s(attr::name);
-          auto mptr =
-              getModulePtrForGetAttrNode(n->input(0)->node(), graph, module_);
-          auto module = Module(mptr);
-          if (module.type() == n->inputs()[0]->type() && module.hasattr(name)) {
-            auto attr = module.attr(name);
-            insertMutableAttr(name, attr, mptr);
+          for (auto& mptr : modules) {
+            auto module = Module(mptr);
+            if (module.type() == n->inputs()[0]->type() &&
+                module.hasattr(name)) {
+              auto attr = module.attr(name);
+              insertMutableAttr(name, attr, mptr);
+              if (attr.isModule()) {
+                modules.insert(attr.toModule()._ivalue());
+              }
+              break;
+            }
          }
        } else if (n->kind() == prim::fork) {
          applyToForkSubgraph(
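For reference, a minimal sketch of the shared-type setup this pass change deals with, distilled from the removed test above (static quantization omitted) and using only calls that test exercised. Two sibling submodules built from the same Python class share one scripted type; the removed getModulePtrForGetAttrNode helper resolved each attribute lookup to the exact instance by walking the prim::GetAttr chain, whereas the restored loop over `modules` matches only on type.

import torch
import torch.nn as nn
from torch.jit._recursive import wrap_cpp_module

class Child(nn.Module):
    def __init__(self):
        super(Child, self).__init__()
        self.conv1 = nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return self.conv1(x)

class Parent(nn.Module):
    def __init__(self):
        super(Parent, self).__init__()
        self.child = Child()
        self.child2 = Child()

    def forward(self, x):
        return self.child2(self.child(x))

data = torch.randn(4, 1, 4, 4)
m = torch.jit.script(Parent())
m.eval()
torch._C._jit_pass_inline(m.graph)
# child and child2 are distinct instances but share one scripted type, so a
# freezing-pass lookup that matches only on module type cannot tell them apart.
m_frozen = wrap_cpp_module(torch._C._freeze_module(m._c))
print(torch.allclose(m(data), m_frozen(data)))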
