Merged
1 change: 0 additions & 1 deletion exir/emit/_emit_program.py
@@ -164,7 +164,6 @@ def emit_program(
         operators=[],
         delegates=[],
         operator_cache={},
-        delegate_cache={},
         emit_stacktrace=emit_stacktrace,
         emit_mutable_buffer_names=emit_mutable_buffer_names,
     )
43 changes: 19 additions & 24 deletions exir/emit/_emitter.py
@@ -147,8 +147,6 @@ class _EmitterState:
     operators: List[Operator]
     delegates: List[BackendDelegate]
     operator_cache: Dict[Tuple[str, str], int]
-    # delegate_cache: the key is hash(delegated_payload) and the value is the index in delegates
-    delegate_cache: Dict[str, int]
     emit_stacktrace: bool
     emit_mutable_buffer_names: bool

@@ -1092,7 +1090,7 @@ def _emit_delegate(
         delegate's blob."""
         processed_bytes = lowered_module.processed_bytes
         hashed = hashlib.sha256(processed_bytes).hexdigest()
-        delegate_index = self.emitter_state.delegate_cache.get(hashed)
+        delegate_index = self.program_state.backend_delegate_data_cache.get(hashed)
Contributor: What's the difference between emitter state and program state? How is backend_delegate_data_cache different from delegate_cache?

Contributor Author: Emitter state is per-method; program state covers all the methods in the program.
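A minimal sketch of that split, assuming simplified stand-in types (the field names follow this diff, but these are not the real ExecuTorch classes):

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

@dataclass
class ProgramStateSketch:
    # Shared by every method in the program: processed blobs are
    # deduplicated here, keyed by the sha256(processed_bytes) hex digest.
    backend_delegate_data_cache: Dict[str, int] = field(default_factory=dict)
    backend_delegate_data: List[bytes] = field(default_factory=list)

@dataclass
class EmitterStateSketch:
    # Rebuilt for each method: after this change it holds one delegate
    # entry per call_delegate, so per-call compile specs survive.
    delegates: List[dict] = field(default_factory=list)
    operator_cache: Dict[Tuple[str, str], int] = field(default_factory=dict)
```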

         delegate_ret = None

         if isinstance(self.node.meta["spec"], list):
@@ -1130,28 +1128,20 @@ def _emit_delegate(
         if delegate_index is None:
             # Allocate an entry for the data. TODO(T150113674): Reuse any duplicate entries if
             # present.
-            hashed = hashlib.sha256(processed_bytes).hexdigest()
-            data_index: Optional[int] = (
-                self.program_state.backend_delegate_data_cache.get(hashed)
-            )
-            if data_index is None:
-                data_index = len(self.program_state.backend_delegate_data)
-                self.program_state.backend_delegate_data_cache[hashed] = data_index
-                self.program_state.backend_delegate_data.append(
-                    BackendDelegateInlineData(data=processed_bytes)
-                )
+            delegate_index = len(self.program_state.backend_delegate_data_cache)
+            self.program_state.backend_delegate_data_cache[hashed] = delegate_index
+            self.program_state.backend_delegate_data.append(
+                BackendDelegateInlineData(data=processed_bytes)
+            )

-            backend_delegate = BackendDelegate(
Contributor: I'm not quite sure why the previous logic didn't work but the new one does. It seems the logic is the same here.

Contributor Author: The previous logic refers to the BackendDelegate created for the deduplicated processed blob. We may have different compile specs but the same processed blob, in which case the compile specs are lost. You can take a look at the test case as well, and try it on the old code.
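A minimal repro of that failure mode, assuming hypothetical stand-in structures (emit_old, plain dicts) rather than the real emitter API:

```python
import hashlib

def emit_old(calls, delegates, delegate_cache):
    """Old-style caching: one cache keyed only by the blob hash."""
    for processed_bytes, compile_specs in calls:
        hashed = hashlib.sha256(processed_bytes).hexdigest()
        if hashed not in delegate_cache:
            delegate_cache[hashed] = len(delegates)
            delegates.append({"blob": processed_bytes, "specs": compile_specs})
        # A later call with the same blob but different specs hits the
        # cache and reuses the first entry, so its specs are dropped.

calls = [(b"blob", ["fast"]), (b"blob", ["small"])]
delegates, cache = [], {}
emit_old(calls, delegates, cache)
assert len(delegates) == 1
assert delegates[0]["specs"] == ["fast"]  # the ["small"] specs were lost
```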

-                id=lowered_module.backend_id,
-                processed=BackendDelegateDataReference(
-                    location=DataLocation.INLINE, index=data_index
-                ),
-                compile_specs=lowered_module.compile_specs,
-            )
-            delegate_index = len(self.emitter_state.delegate_cache)
-            self.emitter_state.delegates.append(backend_delegate)
-            self.emitter_state.delegate_cache[hashed] = delegate_index

+        backend_delegate = BackendDelegate(
+            id=lowered_module.backend_id,
+            processed=BackendDelegateDataReference(
+                location=DataLocation.INLINE, index=delegate_index
+            ),
+            compile_specs=lowered_module.compile_specs,
+        )
+        self.emitter_state.delegates.append(backend_delegate)
         # TODO(angelayi) Will need to emit the kwargs too, in the correct order according to the
         # function's spec and with default arguments. This requires us to store the function's spec
         # in to_backend()
@@ -1164,7 +1154,12 @@ def _emit_delegate(
             delegate_args.append(elem.id)

         self.chain.instructions.append(
-            Instruction(DelegateCall(delegate_index=delegate_index, args=delegate_args))
+            Instruction(
+                DelegateCall(
+                    delegate_index=len(self.emitter_state.delegates) - 1,
Contributor: Hmm, why is the delegate index defined as len(self.emitter_state.delegates) - 1?

Contributor Author: The new logic creates a separate BackendDelegate for each call_delegate, so the index of the entry just appended is len(self.emitter_state.delegates) - 1.
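A sketch of the new flow with the same hypothetical stand-ins as above: the blob is deduplicated program-wide, each call_delegate appends its own delegate entry, and the instruction records the index of that entry, i.e. len(delegates) - 1:

```python
import hashlib

def emit_new(calls, delegates, data, data_cache):
    instructions = []
    for processed_bytes, compile_specs in calls:
        hashed = hashlib.sha256(processed_bytes).hexdigest()
        data_index = data_cache.get(hashed)
        if data_index is None:
            data_index = len(data)
            data_cache[hashed] = data_index
            data.append(processed_bytes)  # one shared copy of the blob
        # One delegate entry per call, each keeping its own compile specs.
        delegates.append({"data": data_index, "specs": compile_specs})
        instructions.append(len(delegates) - 1)  # entry just appended
    return instructions

calls = [(b"blob", ["fast"]), (b"blob", ["small"])]
delegates, data, cache = [], [], {}
assert emit_new(calls, delegates, data, cache) == [0, 1]
assert len(data) == 1  # blob stored once
assert [d["specs"] for d in delegates] == [["fast"], ["small"]]
```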

+                    args=delegate_args,
+                )
+            )
         )

         return delegate_ret
54 changes: 54 additions & 0 deletions exir/emit/test/test_emit.py
@@ -1770,6 +1770,60 @@ def forward(self, x):
            len(edge_program_manager.executorch_program.backend_delegate_data), 1
        )

    def test_delegate_deduplicate_with_different_compile_specs(self) -> None:
        class LowerableSubModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.sin(x)

        lowered = LowerableSubModel()
        example_input = (torch.ones(1),)

        lowered_edge = to_edge(export(lowered, example_input))

        from executorch.exir.backend.compile_spec_schema import CompileSpec

        compile_specs1 = [CompileSpec("config", b"fast")]
        compile_specs2 = [CompileSpec("config", b"small")]
        lowered_module1 = to_backend(
            "BackendWithCompilerDemo", lowered_edge.exported_program(), compile_specs1
        )
        lowered_module2 = to_backend(
            "BackendWithCompilerDemo", lowered_edge.exported_program(), compile_specs2
        )

        class CompositeModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lowerable1 = lowered_module1
                self.lowerable2 = lowered_module2

            def forward(self, x):
                a = self.lowerable1(x)
                b = self.lowerable2(a)
                return a, b

        composite_model = CompositeModel()
        model_inputs = (torch.ones(1),)
        edge_prog = to_edge(export(composite_model, model_inputs)).to_executorch()

        exported_program = edge_prog.exported_program()
        program = emit_program({"method1": exported_program}, False).program
        self.assertEqual(len(program.execution_plan), 1)

        plan = program.execution_plan[0]
        # Two delegates that point to the same blob.
        self.assertEqual(len(plan.delegates), 2)
Contributor: Is it for checking the number of call_delegate instructions?

Contributor Author: Yes.

        self.assertEqual(plan.delegates[0].processed.index, 0)
        self.assertEqual(plan.delegates[1].processed.index, 0)
        # Compile specs are different.
        self.assertEqual(plan.delegates[0].compile_specs, compile_specs1)
        self.assertEqual(plan.delegates[1].compile_specs, compile_specs2)
        # Only one delegate blob in the backend_delegate_data.
        self.assertEqual(len(program.backend_delegate_data), 1)

    def test_constant_tagged_mutable_tensors(self) -> None:
        class Net(nn.Module):
            def __init__(self):