diff --git a/exir/TARGETS b/exir/TARGETS index 853d5e199ba..402e9a21bd1 100644 --- a/exir/TARGETS +++ b/exir/TARGETS @@ -79,6 +79,16 @@ runtime.python_library( ], ) +runtime.python_library( + name = "tensor_layout", + srcs = [ + "tensor_layout.py", + ], + deps = [ + ":scalar_type", + ] +) + runtime.python_library( name = "memory", srcs = [ diff --git a/exir/_serialize/TARGETS b/exir/_serialize/TARGETS index 1b8b76b7835..51bad73ab5c 100644 --- a/exir/_serialize/TARGETS +++ b/exir/_serialize/TARGETS @@ -64,5 +64,6 @@ runtime.python_library( deps = [ "//executorch/exir:schema", "//executorch/exir:tensor", + "//executorch/exir:tensor_layout", ], ) diff --git a/exir/_serialize/_serialize.py b/exir/_serialize/_serialize.py index e2147458545..06e81997654 100644 --- a/exir/_serialize/_serialize.py +++ b/exir/_serialize/_serialize.py @@ -16,12 +16,12 @@ DataEntry, DataPayload, DataSerializer, - TensorLayout, ) from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.emit import EmitterOutput from executorch.exir.schema import Tensor, TensorDataLocation +from executorch.exir.tensor_layout import TensorLayout def serialize_for_executorch( diff --git a/exir/_serialize/data_serializer.py b/exir/_serialize/data_serializer.py index e828b4d0ae3..cee34506b66 100644 --- a/exir/_serialize/data_serializer.py +++ b/exir/_serialize/data_serializer.py @@ -3,7 +3,7 @@ from typing import Dict, Optional, Sequence from executorch.exir._serialize._cord import Cord -from executorch.extension.flat_tensor.serialize.flat_tensor_schema import TensorLayout +from executorch.exir.tensor_layout import TensorLayout @dataclass diff --git a/exir/tensor_layout.py b/exir/tensor_layout.py new file mode 100644 index 00000000000..f8f77ebeea3 --- /dev/null +++ b/exir/tensor_layout.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from dataclasses import dataclass +from typing import List + +from executorch.exir.scalar_type import ScalarType + + +# Note: keep this in sync with the TensorLayout definition in +# executorch/extension/flat_tensor/serialize/flat_tensor.fbs +@dataclass +class TensorLayout: + scalar_type: ScalarType + sizes: List[int] + dim_order: List[int] diff --git a/extension/flat_tensor/serialize/TARGETS b/extension/flat_tensor/serialize/TARGETS index 229f6930f4e..b9ccadf9f23 100644 --- a/extension/flat_tensor/serialize/TARGETS +++ b/extension/flat_tensor/serialize/TARGETS @@ -13,6 +13,9 @@ runtime.python_library( visibility = [ "//executorch/...", ], + deps = [ + "//executorch/exir:tensor_layout", + ] ) runtime.python_library( diff --git a/extension/flat_tensor/serialize/flat_tensor.fbs b/extension/flat_tensor/serialize/flat_tensor.fbs index abf331697d6..4b71e13e2c4 100644 --- a/extension/flat_tensor/serialize/flat_tensor.fbs +++ b/extension/flat_tensor/serialize/flat_tensor.fbs @@ -7,6 +7,8 @@ namespace flat_tensor_flatbuffer; file_identifier "FT01"; file_extension "ptd"; +// Note: keep this in sync with the python definition in +// executorch/exir/tensor_layout.py table TensorLayout { scalar_type: executorch_flatbuffer.ScalarType; diff --git a/extension/flat_tensor/serialize/flat_tensor_schema.py b/extension/flat_tensor/serialize/flat_tensor_schema.py index 53b0fe98ea9..2fcf2c6eb81 100644 --- a/extension/flat_tensor/serialize/flat_tensor_schema.py +++ b/extension/flat_tensor/serialize/flat_tensor_schema.py @@ -9,18 +9,11 @@ from dataclasses import dataclass from typing import List, Optional -from executorch.exir.scalar_type import ScalarType +from executorch.exir.tensor_layout import TensorLayout # Note: check executorch/extension/data_format/flat_tensor.fbs for explanations of these fields. -@dataclass -class TensorLayout: - scalar_type: ScalarType - sizes: List[int] - dim_order: List[int] - - @dataclass class DataSegment: offset: int diff --git a/extension/flat_tensor/test/test_serialize.py b/extension/flat_tensor/test/test_serialize.py index 13402e60a65..726a8845c2e 100644 --- a/extension/flat_tensor/test/test_serialize.py +++ b/extension/flat_tensor/test/test_serialize.py @@ -22,7 +22,7 @@ from executorch.exir._serialize.padding import aligned_size from executorch.exir.schema import ScalarType -from executorch.extension.flat_tensor.serialize.flat_tensor_schema import TensorLayout +from executorch.exir.tensor_layout import TensorLayout from executorch.extension.flat_tensor.serialize.serialize import ( _deserialize_to_flat_tensor,