Skip to content

Commit

Permalink
[JIT] Constant prop getattr (#49806)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #49806

Fix for #47089

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D25696791

Pulled By: eellison

fbshipit-source-id: 914c17b8effef7f4f341775ac2b8150ee4703efd
  • Loading branch information
Elias Ellison authored and facebook-github-bot committed Dec 28, 2020
1 parent 268441c commit fc559bd
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
30 changes: 30 additions & 0 deletions test/jit/test_freezing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1331,3 +1331,33 @@ def forward(self, x: torch.Tensor, key: str) -> Any:
m.eval()
with self.assertRaisesRegex(RuntimeError, "Freezing modules containing prim::ModuleDictIndex is not supported"):
mf = torch._C._freeze_module(m._c)


def test_freeze_non_module_class_getattr(self):
class BoxCoder(object):
def __init__(self, bbox_xform_clip):
# type: (float) -> None
self.bbox_xform_clip = bbox_xform_clip

def decode(self, input):
return input * self.bbox_xform_clip

class MyModule(torch.nn.Module):
__annotations__ = {
'box_coder': BoxCoder,
}

def __init__(self):
super(MyModule, self).__init__()
self.box_coder = BoxCoder(50.)

def forward(self, input):
return self.box_coder.decode(input)

model = MyModule()
model.eval()
script_model = torch.jit.freeze(torch.jit.script(model))
inp = torch.randn([4, 4])
output_eager = model(inp)
self.assertEqual(model(inp), script_model(inp))
FileCheck().check_not("GetAttr").run(script_model.graph)
4 changes: 4 additions & 0 deletions torch/csrc/jit/passes/constant_propagation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ c10::optional<std::vector<IValue>> runNodeIfInputsAreConstant(
case prim::CreateObject: {
createObject(stack, n->output()->type()->expect<ClassType>());
} break;
case prim::GetAttr: {
auto attr = pop(stack).toObject()->getAttr(n->s(attr::name));
push(stack, attr);
} break;
case prim::isinstance: {
isinstance(stack, n->tys(attr::types));
} break;
Expand Down

0 comments on commit fc559bd

Please sign in to comment.