-
Notifications
You must be signed in to change notification settings - Fork 683
Dedup delegate blobs in emitter #14564
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -147,8 +147,6 @@ class _EmitterState: | |
operators: List[Operator] | ||
delegates: List[BackendDelegate] | ||
operator_cache: Dict[Tuple[str, str], int] | ||
# delegate_cache: the key is hash(delegated_payload) and the value is the index in delegates | ||
delegate_cache: Dict[str, int] | ||
emit_stacktrace: bool | ||
emit_mutable_buffer_names: bool | ||
|
||
|
@@ -1092,7 +1090,7 @@ def _emit_delegate( | |
delegate's blob.""" | ||
processed_bytes = lowered_module.processed_bytes | ||
hashed = hashlib.sha256(processed_bytes).hexdigest() | ||
delegate_index = self.emitter_state.delegate_cache.get(hashed) | ||
delegate_index = self.program_state.backend_delegate_data_cache.get(hashed) | ||
delegate_ret = None | ||
|
||
if isinstance(self.node.meta["spec"], list): | ||
|
@@ -1130,28 +1128,20 @@ def _emit_delegate( | |
if delegate_index is None: | ||
# Allocate an entry for the data. TODO(T150113674): Reuse any duplicate entries if | ||
# present. | ||
hashed = hashlib.sha256(processed_bytes).hexdigest() | ||
data_index: Optional[int] = ( | ||
self.program_state.backend_delegate_data_cache.get(hashed) | ||
delegate_index = len(self.program_state.backend_delegate_data_cache) | ||
self.program_state.backend_delegate_data_cache[hashed] = delegate_index | ||
self.program_state.backend_delegate_data.append( | ||
BackendDelegateInlineData(data=processed_bytes) | ||
) | ||
if data_index is None: | ||
data_index = len(self.program_state.backend_delegate_data) | ||
self.program_state.backend_delegate_data_cache[hashed] = data_index | ||
self.program_state.backend_delegate_data.append( | ||
BackendDelegateInlineData(data=processed_bytes) | ||
) | ||
|
||
backend_delegate = BackendDelegate( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not quite sure why the previous logic didn't work but the new one does. It seems the logic is the same here There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The previous logic refers to the BackendDelegate created for the deduplicated processed blob. We may have different compile specs, but the same processed blob, in which case the compile specs are lost. You can take a look at the test case as well, and try it on the old code. |
||
id=lowered_module.backend_id, | ||
processed=BackendDelegateDataReference( | ||
location=DataLocation.INLINE, index=data_index | ||
), | ||
compile_specs=lowered_module.compile_specs, | ||
) | ||
delegate_index = len(self.emitter_state.delegate_cache) | ||
self.emitter_state.delegates.append(backend_delegate) | ||
self.emitter_state.delegate_cache[hashed] = delegate_index | ||
|
||
backend_delegate = BackendDelegate( | ||
id=lowered_module.backend_id, | ||
processed=BackendDelegateDataReference( | ||
location=DataLocation.INLINE, index=delegate_index | ||
), | ||
compile_specs=lowered_module.compile_specs, | ||
) | ||
self.emitter_state.delegates.append(backend_delegate) | ||
# TODO(angelayi) Will need to emit the kwargs too, in the correct order according to the | ||
# function's spec and with default arguments. This requires us to store the function's spec | ||
# in to_backend() | ||
|
@@ -1164,7 +1154,12 @@ def _emit_delegate( | |
delegate_args.append(elem.id) | ||
|
||
self.chain.instructions.append( | ||
Instruction(DelegateCall(delegate_index=delegate_index, args=delegate_args)) | ||
Instruction( | ||
DelegateCall( | ||
delegate_index=len(self.emitter_state.delegates) - 1, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. hmm why is the delegate index defined as There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. the new logic creates a separate BackendDelegate for each call_delegate, so it corresponds to the length of self.emitter_state.delegates. |
||
args=delegate_args, | ||
) | ||
) | ||
) | ||
|
||
return delegate_ret | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1770,6 +1770,60 @@ def forward(self, x): | |
len(edge_program_manager.executorch_program.backend_delegate_data), 1 | ||
) | ||
|
||
def test_delegate_deduplicate_with_different_compile_specs(self) -> None: | ||
class LowerableSubModel(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, x): | ||
return torch.sin(x) | ||
|
||
lowered = LowerableSubModel() | ||
example_input = (torch.ones(1),) | ||
|
||
lowered_edge = to_edge(export(lowered, example_input)) | ||
|
||
from executorch.exir.backend.compile_spec_schema import CompileSpec | ||
|
||
compile_specs1 = [CompileSpec("config", b"fast")] | ||
compile_specs2 = [CompileSpec("config", b"small")] | ||
lowered_module1 = to_backend( | ||
"BackendWithCompilerDemo", lowered_edge.exported_program(), compile_specs1 | ||
) | ||
lowered_module2 = to_backend( | ||
"BackendWithCompilerDemo", lowered_edge.exported_program(), compile_specs2 | ||
) | ||
|
||
class CompositeModel(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.lowerable1 = lowered_module1 | ||
self.lowerable2 = lowered_module2 | ||
|
||
def forward(self, x): | ||
a = self.lowerable1(x) | ||
b = self.lowerable2(a) | ||
return a, b | ||
|
||
composite_model = CompositeModel() | ||
model_inputs = (torch.ones(1),) | ||
edge_prog = to_edge(export(composite_model, model_inputs)).to_executorch() | ||
|
||
exported_program = edge_prog.exported_program() | ||
program = emit_program({"method1": exported_program}, False).program | ||
self.assertEqual(len(program.execution_plan), 1) | ||
|
||
plan = program.execution_plan[0] | ||
# Two delegates that point to the same blob. | ||
self.assertEqual(len(plan.delegates), 2) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is it for checking the number of call_delegate instructions? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes |
||
self.assertEqual(plan.delegates[0].processed.index, 0) | ||
self.assertEqual(plan.delegates[1].processed.index, 0) | ||
# Compile specs are different. | ||
self.assertEqual(plan.delegates[0].compile_specs, compile_specs1) | ||
self.assertEqual(plan.delegates[1].compile_specs, compile_specs2) | ||
# Only one delegate blob in the backend_delegate_data. | ||
self.assertEqual(len(program.backend_delegate_data), 1) | ||
|
||
def test_constant_tagged_mutable_tensors(self) -> None: | ||
class Net(nn.Module): | ||
def __init__(self): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's the difference between emitter state and program state? How is
backend_delegate_data_cache
different than delegate_cache
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
emitter state is per-method, program state covers all the methods in the program