1 change: 0 additions & 1 deletion exir/emit/_emit_program.py
@@ -164,7 +164,6 @@ def emit_program(
             operators=[],
             delegates=[],
             operator_cache={},
-            delegate_cache={},
             emit_stacktrace=emit_stacktrace,
             emit_mutable_buffer_names=emit_mutable_buffer_names,
         )
43 changes: 19 additions & 24 deletions exir/emit/_emitter.py
@@ -146,8 +146,6 @@ class _EmitterState:
     operators: List[Operator]
     delegates: List[BackendDelegate]
     operator_cache: Dict[Tuple[str, str], int]
-    # delegate_cache: the key is hash(delegated_payload) and the value is the index in delegates
-    delegate_cache: Dict[str, int]
     emit_stacktrace: bool
     emit_mutable_buffer_names: bool

@@ -1091,7 +1089,7 @@ def _emit_delegate(
         delegate's blob."""
         processed_bytes = lowered_module.processed_bytes
         hashed = hashlib.sha256(processed_bytes).hexdigest()
-        delegate_index = self.emitter_state.delegate_cache.get(hashed)
+        delegate_index = self.program_state.backend_delegate_data_cache.get(hashed)
         delegate_ret = None

         if isinstance(self.node.meta["spec"], list):
@@ -1129,28 +1127,20 @@ def _emit_delegate(
         if delegate_index is None:
             # Allocate an entry for the data. TODO(T150113674): Reuse any duplicate entries if
             # present.
-            hashed = hashlib.sha256(processed_bytes).hexdigest()
-            data_index: Optional[int] = (
-                self.program_state.backend_delegate_data_cache.get(hashed)
+            delegate_index = len(self.program_state.backend_delegate_data_cache)
+            self.program_state.backend_delegate_data_cache[hashed] = delegate_index
+            self.program_state.backend_delegate_data.append(
+                BackendDelegateInlineData(data=processed_bytes)
             )
-            if data_index is None:
-                data_index = len(self.program_state.backend_delegate_data)
-                self.program_state.backend_delegate_data_cache[hashed] = data_index
-                self.program_state.backend_delegate_data.append(
-                    BackendDelegateInlineData(data=processed_bytes)
-                )
-
-            backend_delegate = BackendDelegate(
-                id=lowered_module.backend_id,
-                processed=BackendDelegateDataReference(
-                    location=DataLocation.INLINE, index=data_index
-                ),
-                compile_specs=lowered_module.compile_specs,
-            )
-            delegate_index = len(self.emitter_state.delegate_cache)
-            self.emitter_state.delegates.append(backend_delegate)
-            self.emitter_state.delegate_cache[hashed] = delegate_index

+        backend_delegate = BackendDelegate(
+            id=lowered_module.backend_id,
+            processed=BackendDelegateDataReference(
+                location=DataLocation.INLINE, index=delegate_index
+            ),
+            compile_specs=lowered_module.compile_specs,
+        )
+        self.emitter_state.delegates.append(backend_delegate)
         # TODO(angelayi) Will need to emit the kwargs too, in the correct order according to the
         # function's spec and with default arguments. This requires us to store the function's spec
         # in to_backend()
@@ -1163,7 +1153,12 @@ def _emit_delegate(
                 delegate_args.append(elem.id)

         self.chain.instructions.append(
-            Instruction(DelegateCall(delegate_index=delegate_index, args=delegate_args))
+            Instruction(
+                DelegateCall(
+                    delegate_index=len(self.emitter_state.delegates) - 1,
+                    args=delegate_args,
+                )
+            )
         )

         return delegate_ret
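Taken together, the _emitter.py changes key payload deduplication on the program-wide backend_delegate_data_cache (sha256 of the processed bytes) instead of the removed per-method delegate_cache, and always append a fresh BackendDelegate entry so each lowered module keeps its own compile specs; the DelegateCall then targets the entry just appended. A minimal standalone sketch of that scheme (simplified stand-in names, not the emitter's actual API):

import hashlib
from typing import Dict, List, Tuple

# Illustration only; these stand in for program_state.backend_delegate_data(_cache)
# and emitter_state.delegates.
blob_cache: Dict[str, int] = {}          # sha256 hex digest -> index into blobs
blobs: List[bytes] = []                  # one entry per unique payload
delegates: List[Tuple[int, bytes]] = []  # (blob index, compile-spec payload)

def emit_delegate(processed_bytes: bytes, compile_spec: bytes) -> int:
    """Emit one delegate entry and return its index (the DelegateCall target)."""
    hashed = hashlib.sha256(processed_bytes).hexdigest()
    blob_index = blob_cache.get(hashed)
    if blob_index is None:
        # First time this payload is seen: store the blob exactly once.
        blob_index = len(blobs)
        blob_cache[hashed] = blob_index
        blobs.append(processed_bytes)
    # Always append a new delegate entry, so two modules with the same payload
    # but different compile specs are never collapsed into one.
    delegates.append((blob_index, compile_spec))
    return len(delegates) - 1

emit_delegate(b"payload", b"fast")
emit_delegate(b"payload", b"small")
assert len(blobs) == 1 and len(delegates) == 2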
54 changes: 54 additions & 0 deletions exir/emit/test/test_emit.py
@@ -1770,6 +1770,60 @@ def forward(self, x):
             len(edge_program_manager.executorch_program.backend_delegate_data), 1
         )

+    def test_delegate_deduplicate_with_different_compile_specs(self) -> None:
+        class LowerableSubModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.sin(x)
+
+        lowered = LowerableSubModel()
+        example_input = (torch.ones(1),)
+
+        lowered_edge = to_edge(export(lowered, example_input))
+
+        from executorch.exir.backend.compile_spec_schema import CompileSpec
+
+        compile_specs1 = [CompileSpec("config", b"fast")]
+        compile_specs2 = [CompileSpec("config", b"small")]
+        lowered_module1 = to_backend(
+            "BackendWithCompilerDemo", lowered_edge.exported_program(), compile_specs1
+        )
+        lowered_module2 = to_backend(
+            "BackendWithCompilerDemo", lowered_edge.exported_program(), compile_specs2
+        )
+
+        class CompositeModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lowerable1 = lowered_module1
+                self.lowerable2 = lowered_module2
+
+            def forward(self, x):
+                a = self.lowerable1(x)
+                b = self.lowerable2(a)
+                return a, b
+
+        composite_model = CompositeModel()
+        model_inputs = (torch.ones(1),)
+        edge_prog = to_edge(export(composite_model, model_inputs)).to_executorch()
+
+        exported_program = edge_prog.exported_program()
+        program = emit_program({"method1": exported_program}, False).program
+        self.assertEqual(len(program.execution_plan), 1)
+
+        plan = program.execution_plan[0]
+        # Two delegates that point to the same blob.
+        self.assertEqual(len(plan.delegates), 2)
+        self.assertEqual(plan.delegates[0].processed.index, 0)
+        self.assertEqual(plan.delegates[1].processed.index, 0)
+        # Compile specs are different.
+        self.assertEqual(plan.delegates[0].compile_specs, compile_specs1)
+        self.assertEqual(plan.delegates[1].compile_specs, compile_specs2)
+        # Only one delegate blob in the backend_delegate_data.
+        self.assertEqual(len(program.backend_delegate_data), 1)
+
     def test_constant_tagged_mutable_tensors(self) -> None:
         class Net(nn.Module):
             def __init__(self):
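The new test encodes the intended invariant directly: one inline blob in backend_delegate_data, two delegate entries with distinct compile specs, both pointing at data index 0. For quick inspection outside a test, a small helper along these lines can print the same mapping; this is a sketch that assumes the emitted schema dataclasses live in executorch.exir.schema and that CompileSpec exposes a key field:

from executorch.exir.schema import Program  # assumed location of the emitted schema


def summarize_delegates(program: Program) -> None:
    """Print which inline data blob each delegate in the first plan points to."""
    plan = program.execution_plan[0]
    for i, delegate in enumerate(plan.delegates):
        spec_keys = [spec.key for spec in delegate.compile_specs]
        print(
            f"delegate {i}: backend={delegate.id}, "
            f"data index={delegate.processed.index}, compile specs={spec_keys}"
        )
    print(f"inline blobs: {len(program.backend_delegate_data)}")


# With the program emitted in the test above, this is expected to show two
# delegates sharing data index 0 and a single inline blob.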