diff --git a/extension/flat_tensor/serialize/flat_tensor.fbs b/extension/flat_tensor/serialize/flat_tensor.fbs index 35491593c47..7d47a61d2b0 100644 --- a/extension/flat_tensor/serialize/flat_tensor.fbs +++ b/extension/flat_tensor/serialize/flat_tensor.fbs @@ -35,8 +35,8 @@ table TensorMetadata { // To retrieve a given tensor: // 1. segment_base_offset: from the file header. // 2. segment_offset: segments[segment_index].offset - // 3. tensor_offset: segments[segment_offset].tensor_metadata[j].offset - // Find the relevant index j by matching on tensor fqn. + // 3. tensor_offset: the offset within the segment. If there is only one item + // in the segment, offset=0. offset: uint64; } @@ -55,6 +55,15 @@ table DataSegment { size: uint64; } +// Attributes a name to data referenced by FlatTensor.segments. +table NamedData { + // The unique id of the data blob. + key: string; + + // Index of the segment in FlatTensor.segments. + segment_index: uint32; +} + // FlatTensor is a flatbuffer-based format for storing and loading tensors. table FlatTensor { // Schema version. @@ -70,6 +79,10 @@ table FlatTensor { // List of data segments that follow the FlatTensor data in this file, sorted by // offset. Elements in this schema can refer to these segments by index. segments: [DataSegment]; + + // List of blobs keyed by a unique name. Note that multiple 'NamedData' + // entries could point to the same segment index. + named_data: [NamedData]; } root_type FlatTensor; diff --git a/extension/flat_tensor/serialize/flat_tensor_schema.py b/extension/flat_tensor/serialize/flat_tensor_schema.py index 5ede6ced5bf..9581442c2d8 100644 --- a/extension/flat_tensor/serialize/flat_tensor_schema.py +++ b/extension/flat_tensor/serialize/flat_tensor_schema.py @@ -31,9 +31,16 @@ class DataSegment: size: int +@dataclass +class NamedData: + key: str + segment_index: int + + @dataclass class FlatTensor: version: int tensor_alignment: int tensors: List[TensorMetadata] segments: List[DataSegment] + named_data: List[NamedData] diff --git a/extension/flat_tensor/serialize/serialize.py b/extension/flat_tensor/serialize/serialize.py index abdccc17dec..683530adbfd 100644 --- a/extension/flat_tensor/serialize/serialize.py +++ b/extension/flat_tensor/serialize/serialize.py @@ -282,6 +282,7 @@ def serialize( tensor_alignment=self.config.tensor_alignment, tensors=flat_tensor_metadata, segments=[DataSegment(offset=0, size=len(flat_tensor_data))], + named_data=[], ) flatbuffer_payload = _serialize_to_flatbuffer(flat_tensor)