diff --git a/exir/emit/_emit_program.py b/exir/emit/_emit_program.py
index 0618871bd40..d25ee3c538b 100644
--- a/exir/emit/_emit_program.py
+++ b/exir/emit/_emit_program.py
@@ -164,7 +164,6 @@ def emit_program(
             operators=[],
             delegates=[],
             operator_cache={},
-            delegate_cache={},
             emit_stacktrace=emit_stacktrace,
             emit_mutable_buffer_names=emit_mutable_buffer_names,
         )
diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py
index 7701ca7b8ff..15e0b23d36f 100644
--- a/exir/emit/_emitter.py
+++ b/exir/emit/_emitter.py
@@ -147,8 +147,6 @@ class _EmitterState:
     operators: List[Operator]
     delegates: List[BackendDelegate]
     operator_cache: Dict[Tuple[str, str], int]
-    # delegate_cache: the key is hash(delegated_payload) and the value is the index in delegates
-    delegate_cache: Dict[str, int]
     emit_stacktrace: bool
     emit_mutable_buffer_names: bool
 
@@ -1092,7 +1090,7 @@ def _emit_delegate(
         delegate's blob."""
         processed_bytes = lowered_module.processed_bytes
         hashed = hashlib.sha256(processed_bytes).hexdigest()
-        delegate_index = self.emitter_state.delegate_cache.get(hashed)
+        delegate_index = self.program_state.backend_delegate_data_cache.get(hashed)
         delegate_ret = None
 
         if isinstance(self.node.meta["spec"], list):
@@ -1130,28 +1128,20 @@ def _emit_delegate(
         if delegate_index is None:
             # Allocate an entry for the data. TODO(T150113674): Reuse any duplicate entries if
             # present.
-            hashed = hashlib.sha256(processed_bytes).hexdigest()
-            data_index: Optional[int] = (
-                self.program_state.backend_delegate_data_cache.get(hashed)
+            delegate_index = len(self.program_state.backend_delegate_data_cache)
+            self.program_state.backend_delegate_data_cache[hashed] = delegate_index
+            self.program_state.backend_delegate_data.append(
+                BackendDelegateInlineData(data=processed_bytes)
             )
-            if data_index is None:
-                data_index = len(self.program_state.backend_delegate_data)
-                self.program_state.backend_delegate_data_cache[hashed] = data_index
-                self.program_state.backend_delegate_data.append(
-                    BackendDelegateInlineData(data=processed_bytes)
-                )
-
-            backend_delegate = BackendDelegate(
-                id=lowered_module.backend_id,
-                processed=BackendDelegateDataReference(
-                    location=DataLocation.INLINE, index=data_index
-                ),
-                compile_specs=lowered_module.compile_specs,
-            )
-            delegate_index = len(self.emitter_state.delegate_cache)
-            self.emitter_state.delegates.append(backend_delegate)
-            self.emitter_state.delegate_cache[hashed] = delegate_index
+        backend_delegate = BackendDelegate(
+            id=lowered_module.backend_id,
+            processed=BackendDelegateDataReference(
+                location=DataLocation.INLINE, index=delegate_index
+            ),
+            compile_specs=lowered_module.compile_specs,
+        )
+        self.emitter_state.delegates.append(backend_delegate)
 
         # TODO(angelayi) Will need to emit the kwargs too, in the correct order according to the
         # function's spec and with default arguments. This requires us to store the function's spec
         # in to_backend()
@@ -1164,7 +1154,12 @@ def _emit_delegate(
             delegate_args.append(elem.id)
 
         self.chain.instructions.append(
-            Instruction(DelegateCall(delegate_index=delegate_index, args=delegate_args))
+            Instruction(
+                DelegateCall(
+                    delegate_index=len(self.emitter_state.delegates) - 1,
+                    args=delegate_args,
+                )
+            )
         )
         return delegate_ret
 
diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py
index 649b795ad8f..dcc3544875a 100644
--- a/exir/emit/test/test_emit.py
+++ b/exir/emit/test/test_emit.py
@@ -1770,6 +1770,60 @@ def forward(self, x):
             len(edge_program_manager.executorch_program.backend_delegate_data), 1
         )
 
+    def test_delegate_deduplicate_with_different_compile_specs(self) -> None:
+        class LowerableSubModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.sin(x)
+
+        lowered = LowerableSubModel()
+        example_input = (torch.ones(1),)
+
+        lowered_edge = to_edge(export(lowered, example_input))
+
+        from executorch.exir.backend.compile_spec_schema import CompileSpec
+
+        compile_specs1 = [CompileSpec("config", b"fast")]
+        compile_specs2 = [CompileSpec("config", b"small")]
+        lowered_module1 = to_backend(
+            "BackendWithCompilerDemo", lowered_edge.exported_program(), compile_specs1
+        )
+        lowered_module2 = to_backend(
+            "BackendWithCompilerDemo", lowered_edge.exported_program(), compile_specs2
+        )
+
+        class CompositeModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lowerable1 = lowered_module1
+                self.lowerable2 = lowered_module2
+
+            def forward(self, x):
+                a = self.lowerable1(x)
+                b = self.lowerable2(a)
+                return a, b
+
+        composite_model = CompositeModel()
+        model_inputs = (torch.ones(1),)
+        edge_prog = to_edge(export(composite_model, model_inputs)).to_executorch()
+
+        exported_program = edge_prog.exported_program()
+        program = emit_program({"method1": exported_program}, False).program
+        self.assertEqual(len(program.execution_plan), 1)
+
+        plan = program.execution_plan[0]
+        # Two delegates that point to the same blob.
+        self.assertEqual(len(plan.delegates), 2)
+        self.assertEqual(plan.delegates[0].processed.index, 0)
+        self.assertEqual(plan.delegates[1].processed.index, 0)
+        # Compile specs are different.
+        self.assertEqual(plan.delegates[0].compile_specs, compile_specs1)
+        self.assertEqual(plan.delegates[1].compile_specs, compile_specs2)
+        # Only one delegate blob in the backend_delegate_data.
+        self.assertEqual(len(program.backend_delegate_data), 1)
+
     def test_constant_tagged_mutable_tensors(self) -> None:
         class Net(nn.Module):
             def __init__(self):
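
For reference, a minimal sketch of the deduplication scheme this patch implements, using simplified stand-in names (Blob list, DelegateEntry, ProgramState) rather than the real ExecuTorch types: the processed blob is cached by its sha256 digest so identical payloads share a single backend_delegate_data entry, while every delegate call site still gets its own BackendDelegate record carrying its own compile specs, and the DelegateCall instruction points at the entry just appended.

# Hypothetical sketch; DelegateEntry/ProgramState/emit_delegate are simplified
# illustrations of the dedup logic, not the actual ExecuTorch emitter API.
import hashlib
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class DelegateEntry:
    backend_id: str
    data_index: int  # index into the program-level blob list
    compile_specs: Tuple[Tuple[str, bytes], ...]


@dataclass
class ProgramState:
    blobs: List[bytes] = field(default_factory=list)
    blob_cache: Dict[str, int] = field(default_factory=dict)  # sha256 -> blob index
    delegates: List[DelegateEntry] = field(default_factory=list)


def emit_delegate(state: ProgramState, backend_id: str,
                  processed: bytes, compile_specs) -> int:
    """Returns the delegate index that a DelegateCall-style instruction would use."""
    hashed = hashlib.sha256(processed).hexdigest()
    data_index = state.blob_cache.get(hashed)
    if data_index is None:
        # First time this payload is seen: store the blob exactly once.
        data_index = len(state.blobs)
        state.blob_cache[hashed] = data_index
        state.blobs.append(processed)
    # A delegate entry is always appended, even when the blob is shared,
    # because compile specs can differ between call sites.
    state.delegates.append(DelegateEntry(backend_id, data_index, tuple(compile_specs)))
    return len(state.delegates) - 1


state = ProgramState()
i1 = emit_delegate(state, "DemoBackend", b"payload", [("config", b"fast")])
i2 = emit_delegate(state, "DemoBackend", b"payload", [("config", b"small")])
assert len(state.blobs) == 1   # one shared blob
assert (i1, i2) == (0, 1)      # two delegate entries, each with its own specs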