Skip to content

Commit 8eaba6b

Browse files
tarun292 authored and facebook-github-bot committed
Ensure that lifted tensor constants don't show up as inputs in emitted program (#1897)
Summary: Currently lifted tensor constants are showing up as inputs to the emitted program. This shouldn't be the case as they're embedded inside the program as constants and the user will not be passing these in as inputs. Reviewed By: angelayi Differential Revision: D53584903
1 parent 7c80cd3 commit 8eaba6b

File tree

3 files changed

+61
-6
lines changed

3 files changed

+61
-6
lines changed

exir/emit/_emitter.py

Lines changed: 21 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1308,18 +1308,33 @@ def placeholder(
13081308
if isinstance(target, str) and (
13091309
target in self.exported_program.graph_signature.inputs_to_parameters
13101310
or target in self.exported_program.graph_signature.inputs_to_buffers
1311+
or target
1312+
in self.exported_program.graph_signature.inputs_to_lifted_tensor_constants
13111313
):
1312-
1313-
fqn = (
1314-
self.exported_program.graph_signature.inputs_to_parameters[target]
1315-
if target in self.exported_program.graph_signature.inputs_to_parameters
1316-
else self.exported_program.graph_signature.inputs_to_buffers[target]
1317-
)
1314+
if (
1315+
target
1316+
in self.exported_program.graph_signature.inputs_to_lifted_tensor_constants
1317+
):
1318+
fqn = self.exported_program.graph_signature.inputs_to_lifted_tensor_constants[
1319+
target
1320+
]
1321+
elif target in self.exported_program.graph_signature.inputs_to_buffers:
1322+
fqn = self.exported_program.graph_signature.inputs_to_buffers[target]
1323+
else:
1324+
fqn = self.exported_program.graph_signature.inputs_to_parameters[target]
13181325
if fqn in self.exported_program.state_dict:
13191326
spec = TensorSpec.from_tensor(
13201327
self.exported_program.state_dict[fqn], const=True
13211328
)
13221329
const_tensor = True
1330+
elif fqn in self.exported_program.constants:
1331+
print(
1332+
f"Found constant {self.exported_program.constants[fqn]} in constants dict"
1333+
)
1334+
spec = TensorSpec.from_tensor(
1335+
self.exported_program.constants[fqn], const=True
1336+
)
1337+
const_tensor = True
13231338
else:
13241339
buffers = self.exported_program.named_buffers()
13251340
buf = next((x[1] for x in buffers if x[0] == fqn), None)

exir/emit/test/test_emit.py

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1335,3 +1335,28 @@ def forward(self, x):
13351335
# confirm that the buffer was emitted
13361336
self.assertEqual(len(program.constant_buffer), 2)
13371337
self.assertEqual(len(program.constant_buffer[1].storage), 8)
1338+
1339+
def test_emit_lifted_tensor_constant(self):
1340+
class LiftedConstants(nn.Module):
1341+
def __init__(self):
1342+
super().__init__()
1343+
1344+
def forward(self, x):
1345+
x = x * torch.tensor([[4, 3], [1, 2], [5, 6]], dtype=torch.float)
1346+
return x
1347+
1348+
model = LiftedConstants()
1349+
1350+
program = to_edge(
1351+
export(
1352+
model,
1353+
(torch.ones(3, 2),),
1354+
)
1355+
).to_executorch()
1356+
1357+
program = program._emitter_output.program
1358+
exec_plan = program.execution_plan[0]
1359+
# There should only be 1 input to this model.
1360+
self.assertEqual(len(exec_plan.inputs), 1)
1361+
self.assertEqual(len(program.constant_buffer), 2)
1362+
self.assertEqual(len(program.constant_buffer[1].storage), 24)

exir/lowered_backend_module.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -465,6 +465,21 @@ def _get_new_signature(
465465
]
466466
else:
467467
new_constants[buffer_name] = original_program.constants[buffer_name]
468+
elif node.name in old_signature.inputs_to_lifted_tensor_constants:
469+
constant_name = old_signature.inputs_to_lifted_tensor_constants[
470+
node.name
471+
]
472+
# add constant to graph signature
473+
input_specs.append(
474+
InputSpec(
475+
kind=InputKind.CONSTANT_TENSOR,
476+
arg=TensorArgument(name=node.name),
477+
target=constant_name,
478+
)
479+
)
480+
481+
# add constant to new_constants
482+
new_constants[constant_name] = original_program.constants[constant_name]
468483
else:
469484
# not param or buffer then user input
470485
input_specs.append(

0 commit comments

Comments (0)