diff --git a/client/starwhale/api/_impl/dataset/builder.py b/client/starwhale/api/_impl/dataset/builder.py
index dfcc0c2743..1268e74959 100644
--- a/client/starwhale/api/_impl/dataset/builder.py
+++ b/client/starwhale/api/_impl/dataset/builder.py
@@ -202,19 +202,29 @@ class SWDSBinBuildExecutor(BaseBuildExecutor):

     _DATA_FMT = SWDS_DATA_FNAME_FMT

-    def _write(self, writer: t.Any, data: bytes) -> t.Tuple[int, int]:
+    class _BinSection(t.NamedTuple):
+        offset: int
+        size: int
+        raw_data_offset: int
+        raw_data_size: int
+
+    def _write(self, writer: t.Any, data: bytes) -> _BinSection:
         size = len(data)
         crc = crc32(data)  # TODO: crc is right?
         start = writer.tell()
         padding_size = self._get_padding_size(size + _header_size)

-        # TODO: remove idx field
         _header = _header_struct.pack(
             _header_magic, crc, 0, size, padding_size, _header_version, _data_magic
         )
         _padding = b"\0" * padding_size
         writer.write(_header + data + _padding)
-        return start, _header_size + size + padding_size
+        return self._BinSection(
+            offset=start,
+            size=_header_size + size + padding_size,
+            raw_data_offset=start + _header_size,
+            raw_data_size=size,
+        )

     def _get_padding_size(self, size: int) -> int:
         remain = (size + _header_size) % self.alignment_bytes_size
@@ -248,7 +258,7 @@ def make_swds(self) -> DatasetSummary:
             if not isinstance(_data_content, bytes):
                 raise FormatError("data content must be bytes type")

-            data_offset, data_size = self._write(dwriter, _data_content)
+            _bin_section = self._write(dwriter, _data_content)
             self.tabular_dataset.put(
                 TabularDatasetRow(
                     id=idx,
@@ -256,17 +266,19 @@ def make_swds(self) -> DatasetSummary:
                     label=label,
                     data_format=self.data_format_type,
                     object_store_type=ObjectStoreType.LOCAL,
-                    data_offset=data_offset,
-                    data_size=data_size,
+                    data_offset=_bin_section.raw_data_offset,
+                    data_size=_bin_section.raw_data_size,
+                    _swds_bin_offset=_bin_section.offset,
+                    _swds_bin_size=_bin_section.size,
                     data_origin=DataOriginType.NEW,
                     data_mime_type=_data_mime_type or self.default_data_mime_type,
                 )
             )

-            total_data_size += data_size
+            total_data_size += _bin_section.size
             total_label_size += sys.getsizeof(label)

-            wrote_size += data_size
+            wrote_size += _bin_section.size
             if wrote_size > self.volume_bytes_size:
                 wrote_size = 0
                 fno += 1
diff --git a/client/starwhale/api/_impl/dataset/loader.py b/client/starwhale/api/_impl/dataset/loader.py
index 7fee463a46..27771b59b0 100644
--- a/client/starwhale/api/_impl/dataset/loader.py
+++ b/client/starwhale/api/_impl/dataset/loader.py
@@ -69,9 +69,15 @@ def _get_key_compose(self, row: TabularDatasetRow, store: ObjectStore) -> str:
         if store.key_prefix:
             data_uri = os.path.join(store.key_prefix, data_uri.lstrip("/"))

-        _key_compose = (
-            f"{data_uri}:{row.data_offset}:{row.data_offset + row.data_size - 1}"
-        )
+        if self.kind == DataFormatType.SWDS_BIN:
+            offset, size = (
+                int(row.extra_kw["_swds_bin_offset"]),
+                int(row.extra_kw["_swds_bin_size"]),
+            )
+        else:
+            offset, size = row.data_offset, row.data_size
+
+        _key_compose = f"{data_uri}:{offset}:{offset + size - 1}"
         return _key_compose

     def __iter__(self) -> t.Generator[t.Tuple[DataField, DataField], None, None]:
diff --git a/client/starwhale/core/dataset/tabular.py b/client/starwhale/core/dataset/tabular.py
index c9c0a0afbf..8ec7e52d1e 100644
--- a/client/starwhale/core/dataset/tabular.py
+++ b/client/starwhale/core/dataset/tabular.py
@@ -50,18 +50,19 @@ def __init__(
         data_origin: DataOriginType = DataOriginType.NEW,
         data_mime_type: MIMEType = MIMEType.UNDEFINED,
         auth_name: str = "",
-        **kw: t.Any,
+        **kw: t.Union[str, int, float],
     ) -> None:
         self.id = id
         self.data_uri = data_uri.strip()
-        self.data_format = data_format
+        self.data_format = DataFormatType(data_format)
         self.data_offset = data_offset
         self.data_size = data_size
-        self.data_origin = data_origin
-        self.object_store_type = object_store_type
-        self.data_mime_type = data_mime_type
+        self.data_origin = DataOriginType(data_origin)
+        self.object_store_type = ObjectStoreType(object_store_type)
+        self.data_mime_type = MIMEType(data_mime_type)
         self.auth_name = auth_name
         self.label = self._parse_label(label)
+        self.extra_kw = kw

         # TODO: add non-starwhale object store related fields, such as address, authority
         # TODO: add data uri crc for versioning
@@ -94,8 +95,7 @@ def _do_validate(self) -> None:
         if self.data_origin not in DataOriginType:
             raise NoSupportError(f"data origin: {self.data_origin}")

-        # TODO: support non-starwhale remote object store, for index-only feature
-        if self.object_store_type != ObjectStoreType.LOCAL:
+        if self.object_store_type not in ObjectStoreType:
             raise NoSupportError(f"object store {self.object_store_type}")

     def __str__(self) -> str:
@@ -110,6 +110,8 @@ def __repr__(self) -> str:

     def asdict(self) -> t.Dict[str, t.Union[str, bytes, int]]:
         d = deepcopy(self.__dict__)
+        d.pop("extra_kw", None)
+        d.update(self.extra_kw)
         for k, v in d.items():
             if isinstance(v, Enum):
                 d[k] = v.value
diff --git a/client/tests/sdk/test_dataset.py b/client/tests/sdk/test_dataset.py
index aa8f83af59..d261f69d2a 100644
--- a/client/tests/sdk/test_dataset.py
+++ b/client/tests/sdk/test_dataset.py
@@ -117,6 +117,7 @@ def test_swds_bin_workflow(self) -> None:
             label_filter="mnist-data-*",
             alignment_bytes_size=64,
             volume_bytes_size=100,
+            data_mime_type=MIMEType.GRAYSCALE,
         ) as e:
             assert e.data_tmpdir.exists()
             summary = e.make_swds()
@@ -163,5 +164,7 @@ def test_swds_bin_workflow(self) -> None:
         tdb = TabularDataset(name="mnist", version="112233", project="self")
         meta = list(tdb.scan(0, 1))[0]
         assert meta.id == 0
-        assert meta.data_offset == 0
+        assert meta.data_offset == 32
+        assert meta.extra_kw["_swds_bin_offset"] == 0
         assert meta.data_uri in data_files_sign
+        assert meta.data_mime_type == MIMEType.GRAYSCALE
diff --git a/client/tests/sdk/test_loader.py b/client/tests/sdk/test_loader.py
index 6eddffb739..013420aec8 100644
--- a/client/tests/sdk/test_loader.py
+++ b/client/tests/sdk/test_loader.py
@@ -16,7 +16,11 @@
     UserRawDataLoader,
 )
 from starwhale.core.dataset.type import DatasetSummary
-from starwhale.core.dataset.store import DatasetStorage
+from starwhale.core.dataset.store import (
+    DatasetStorage,
+    S3StorageBackend,
+    LocalFSStorageBackend,
+)
 from starwhale.core.dataset.tabular import TabularDatasetRow

 from .. import ROOT_DIR
@@ -234,8 +238,10 @@ def test_swds_bin_s3(
                 id=0,
                 object_store_type=ObjectStoreType.LOCAL,
                 data_uri=fname,
-                data_offset=0,
-                data_size=8160,
+                data_offset=32,
+                data_size=784,
+                _swds_bin_offset=0,
+                _swds_bin_size=8160,
                 label=b"0",
                 data_origin=DataOriginType.NEW,
                 data_format=DataFormatType.SWDS_BIN,
@@ -282,6 +288,7 @@ def test_swds_bin_s3(

         assert list(loader._stores.keys()) == ["local."]
         backend = loader._stores["local."].backend
+        assert isinstance(backend, S3StorageBackend)
         assert backend.kind == SWDSBackendType.S3
         assert backend.s3.Object.call_args[0] == (
             "starwhale",
@@ -314,8 +321,10 @@ def test_swds_bin_local_fs(self, m_scan: MagicMock, m_summary: MagicMock) -> Non
                 id=0,
                 object_store_type=ObjectStoreType.LOCAL,
                 data_uri=fname,
-                data_offset=0,
-                data_size=8160,
+                data_offset=32,
+                data_size=784,
+                _swds_bin_offset=0,
+                _swds_bin_size=8160,
                 label=b"0",
                 data_origin=DataOriginType.NEW,
                 data_format=DataFormatType.SWDS_BIN,
@@ -326,8 +335,10 @@ def test_swds_bin_local_fs(self, m_scan: MagicMock, m_summary: MagicMock) -> Non
                 id=1,
                 object_store_type=ObjectStoreType.LOCAL,
                 data_uri=fname,
-                data_offset=0,
-                data_size=8160,
+                data_offset=32,
+                data_size=784,
+                _swds_bin_offset=0,
+                _swds_bin_size=8160,
                 label=b"1",
                 data_origin=DataOriginType.NEW,
                 data_format=DataFormatType.SWDS_BIN,
@@ -353,6 +364,8 @@ def test_swds_bin_local_fs(self, m_scan: MagicMock, m_summary: MagicMock) -> Non
         assert _data.ext_attr == {"ds_name": "mnist", "ds_version": "1122334455667788"}

         assert list(loader._stores.keys()) == ["local."]
-        assert loader._stores["local."].backend.kind == SWDSBackendType.LocalFS
+        backend = loader._stores["local."].backend
+        assert isinstance(backend, LocalFSStorageBackend)
+        assert backend.kind == SWDSBackendType.LocalFS
         assert loader._stores["local."].bucket == str(data_dir)
         assert not loader._stores["local."].key_prefix
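
For orientation, a minimal sketch (not part of the patch) of how the new _BinSection fields map onto the swds-bin on-disk layout. HEADER_SIZE = 32 is inferred from the updated test assertion `meta.data_offset == 32`, and bin_section_for is a hypothetical stand-in for the bookkeeping _write does together with _get_padding_size:

    import typing as t

    HEADER_SIZE = 32  # assumption: inferred from the test assertion `meta.data_offset == 32`


    class BinSection(t.NamedTuple):
        # mirrors the fields of builder.py's new _BinSection
        offset: int           # start of the whole section: header + data + padding
        size: int             # total section size, padded to the alignment boundary
        raw_data_offset: int  # start of the raw payload: offset + HEADER_SIZE
        raw_data_size: int    # payload size, excluding header and padding


    def bin_section_for(start: int, data_len: int, alignment: int = 64) -> BinSection:
        # pad header + data up to the next alignment boundary, following the
        # logic visible in _get_padding_size
        remain = (data_len + HEADER_SIZE) % alignment
        padding = 0 if remain == 0 else alignment - remain
        return BinSection(
            offset=start,
            size=HEADER_SIZE + data_len + padding,
            raw_data_offset=start + HEADER_SIZE,
            raw_data_size=data_len,
        )


    # first record of a 784-byte MNIST image, matching the updated test assertions
    s = bin_section_for(start=0, data_len=784, alignment=64)
    assert (s.offset, s.raw_data_offset, s.raw_data_size) == (0, 32, 784)

    # for SWDS_BIN rows the loader now builds its byte-range key over the whole
    # section rather than the raw payload ("<data_uri>" is a placeholder)
    key = f"<data_uri>:{s.offset}:{s.offset + s.size - 1}"

The design point: data_offset/data_size now address the raw payload that user code consumes, while the private _swds_bin_offset/_swds_bin_size keys address the full aligned section (header + data + padding) that the loader must fetch and decode, which is why _get_key_compose branches on DataFormatType.SWDS_BIN.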