Skip to content

Commit fc559bd

Browse files
Elias Ellisonfacebook-github-bot
authored andcommitted
[JIT] Constant prop getattr (#49806)
Summary: Pull Request resolved: #49806 Fix for #47089 Test Plan: Imported from OSS Reviewed By: navahgar Differential Revision: D25696791 Pulled By: eellison fbshipit-source-id: 914c17b8effef7f4f341775ac2b8150ee4703efd
1 parent 268441c commit fc559bd

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

test/jit/test_freezing.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,3 +1331,33 @@ def forward(self, x: torch.Tensor, key: str) -> Any:
13311331
m.eval()
13321332
with self.assertRaisesRegex(RuntimeError, "Freezing modules containing prim::ModuleDictIndex is not supported"):
13331333
mf = torch._C._freeze_module(m._c)
1334+
1335+
1336+
def test_freeze_non_module_class_getattr(self):
1337+
class BoxCoder(object):
1338+
def __init__(self, bbox_xform_clip):
1339+
# type: (float) -> None
1340+
self.bbox_xform_clip = bbox_xform_clip
1341+
1342+
def decode(self, input):
1343+
return input * self.bbox_xform_clip
1344+
1345+
class MyModule(torch.nn.Module):
1346+
__annotations__ = {
1347+
'box_coder': BoxCoder,
1348+
}
1349+
1350+
def __init__(self):
1351+
super(MyModule, self).__init__()
1352+
self.box_coder = BoxCoder(50.)
1353+
1354+
def forward(self, input):
1355+
return self.box_coder.decode(input)
1356+
1357+
model = MyModule()
1358+
model.eval()
1359+
script_model = torch.jit.freeze(torch.jit.script(model))
1360+
inp = torch.randn([4, 4])
1361+
output_eager = model(inp)
1362+
self.assertEqual(model(inp), script_model(inp))
1363+
FileCheck().check_not("GetAttr").run(script_model.graph)

torch/csrc/jit/passes/constant_propagation.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ c10::optional<std::vector<IValue>> runNodeIfInputsAreConstant(
5454
case prim::CreateObject: {
5555
createObject(stack, n->output()->type()->expect<ClassType>());
5656
} break;
57+
case prim::GetAttr: {
58+
auto attr = pop(stack).toObject()->getAttr(n->s(attr::name));
59+
push(stack, attr);
60+
} break;
5761
case prim::isinstance: {
5862
isinstance(stack, n->tys(attr::types));
5963
} break;

0 commit comments

Comments
 (0)