diff --git a/exir/serde/schema.py b/exir/serde/schema.py index 16ac724917c..6d250ee7923 100644 --- a/exir/serde/schema.py +++ b/exir/serde/schema.py @@ -402,3 +402,4 @@ class LoweredBackendModule: original_module: export_schema.ExportedProgram original_state_dict: str original_constants: str + named_data_store: Optional[bytes] = None diff --git a/exir/serde/serialize.py b/exir/serde/serialize.py index c25b9b3a771..c9605018c4a 100644 --- a/exir/serde/serialize.py +++ b/exir/serde/serialize.py @@ -22,6 +22,7 @@ import torch import torch.export.exported_program as ep from executorch.exir import delegate +from executorch.exir._serialize._named_data_store import NamedDataStoreOutput from executorch.exir.backend.compile_spec_schema import ( CompileSpec as delegate_CompileSpec, ) @@ -268,6 +269,7 @@ def serialize_bytes(b: bytes) -> str: assert isinstance(serialized_artifact.exported_program, schema.ExportedProgram) serialized_processed_bytes = serialize_bytes(lowered_module.processed_bytes) + named_data_store = json.dumps(export_serialize._dataclass_to_dict(lowered_module.named_data_store_output),cls=export_serialize.EnumEncoder) if lowered_module.named_data_store_output else None serialized_lowered_module = SerdeLoweredBackendModule( original_module=serialized_artifact.exported_program, @@ -276,6 +278,7 @@ def serialize_bytes(b: bytes) -> str: processed_bytes=serialized_processed_bytes, compile_specs=serialized_compile_spec, backend_id=lowered_module.backend_id, + named_data_store=named_data_store, ) json_lowered_module = json.dumps( @@ -556,11 +559,19 @@ def deserialize_lowered_module( None, ) + if serialized_lowered_module.named_data_store is None: + named_data_store = None + else: + named_data_store = export_serialize._dict_to_dataclass(NamedDataStoreOutput, json.loads(serialized_lowered_module.named_data_store)) + for buffer in named_data_store.buffers: + buffer.buffer = base64.b64decode(buffer.buffer.encode("ascii")) + lowered_module = ExirLoweredBackendModule( original_module, backend_id, processed_bytes, compile_specs, + named_data_store ) self.module.register_module(serialized_lowered_module_arg.name, lowered_module) return self.graph.get_attr(serialized_lowered_module_arg.name) diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS index 650a77c6ef6..63f76656a03 100644 --- a/exir/tests/TARGETS +++ b/exir/tests/TARGETS @@ -98,6 +98,7 @@ python_unittest( "//executorch/exir/backend/test:backend_with_compiler_demo", "//executorch/exir/backend/test:op_partitioner_demo", "//executorch/exir/serde:serialize", + "//executorch/backends/xnnpack/partition:xnnpack_partitioner", ], ) diff --git a/exir/tests/test_serde.py b/exir/tests/test_serde.py index 5b09ddf07c1..67821d0bffb 100644 --- a/exir/tests/test_serde.py +++ b/exir/tests/test_serde.py @@ -7,12 +7,16 @@ # pyre-strict import io +import tempfile import unittest from typing import Tuple import executorch.exir as exir import torch +from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( + XnnpackFloatingPointPartitioner, +) from executorch.exir import to_edge from executorch.exir.backend.backend_api import CompileSpec, to_backend from executorch.exir.backend.test.backend_with_compiler_demo import ( @@ -20,6 +24,10 @@ ) from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo +from executorch.exir.program._program import ( + EdgeProgramManager, + to_edge_transform_and_lower, +) from executorch.exir.serde.serialize import deserialize, serialize from torch import nn from torch.export import export @@ -202,6 +210,33 @@ def forward(self, a, x, b): edge_new = deserialize(serialize(edge.exported_program())) self.check_ep(edge.exported_program(), edge_new, inputs) + def test_delegate_xnnpack(self) -> None: + class SimpleConv1DModel(nn.Module): + def __init__(self): + super(SimpleConv1DModel, self).__init__() + self.conv1 = nn.Conv1d( + in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = self.conv1(x) + return x + + x = torch.randn(64, 1, 100) + model = SimpleConv1DModel() + ep = torch.export.export(model, (x,)) + edge_orig = to_edge_transform_and_lower( + ep, partitioner=[XnnpackFloatingPointPartitioner()] + ) + + with tempfile.NamedTemporaryFile() as f: + exir.save(edge_orig.exported_program(), f) + edge_deserialized = EdgeProgramManager(exir.load(f)) + self.assertTrue( + edge_orig.to_executorch().buffer + == edge_deserialized.to_executorch().buffer + ) + def test_meta_stack_trace_module_hierarchy(self) -> None: class Model(nn.Module): def __init__(self):