tune dataset meta schema fields with swds_bin_offset and swds_bin_size
tianweidut committed Aug 28, 2022
1 parent ede5aec commit 50b5438
Showing 5 changed files with 63 additions and 27 deletions.
28 changes: 20 additions & 8 deletions client/starwhale/api/_impl/dataset/builder.py
@@ -202,19 +202,29 @@ class SWDSBinBuildExecutor(BaseBuildExecutor):

    _DATA_FMT = SWDS_DATA_FNAME_FMT

    def _write(self, writer: t.Any, data: bytes) -> t.Tuple[int, int]:
    class _BinSection(t.NamedTuple):
        offset: int
        size: int
        raw_data_offset: int
        raw_data_size: int

    def _write(self, writer: t.Any, data: bytes) -> _BinSection:
        size = len(data)
        crc = crc32(data)  # TODO: crc is right?
        start = writer.tell()
        padding_size = self._get_padding_size(size + _header_size)

        # TODO: remove idx field
        _header = _header_struct.pack(
            _header_magic, crc, 0, size, padding_size, _header_version, _data_magic
        )
        _padding = b"\0" * padding_size
        writer.write(_header + data + _padding)
        return start, _header_size + size + padding_size
        return self._BinSection(
            offset=start,
            size=_header_size + size + padding_size,
            raw_data_offset=start + _header_size,
            raw_data_size=size,
        )

    def _get_padding_size(self, size: int) -> int:
        remain = (size + _header_size) % self.alignment_bytes_size
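
The _BinSection tuple separates the aligned on-disk section (header + payload + padding) from the raw payload inside it. Below is a minimal sketch of the arithmetic, assuming the 32-byte header implied by the tests later in this commit and a pad-to-alignment rule matching _get_padding_size; both are inferences, not the verbatim source:

_HEADER_SIZE = 32  # assumption: the tests expect raw data to start 32 bytes in

def bin_section(start: int, payload: bytes, alignment: int = 64) -> dict:
    # compute where the whole section and the raw payload live on disk
    size = len(payload)
    remain = (size + _HEADER_SIZE) % alignment
    padding = 0 if remain == 0 else alignment - remain
    return {
        "offset": start,                          # -> _swds_bin_offset
        "size": _HEADER_SIZE + size + padding,    # -> _swds_bin_size
        "raw_data_offset": start + _HEADER_SIZE,  # -> data_offset
        "raw_data_size": size,                    # -> data_size
    }

# a 784-byte MNIST image at the start of a volume, 64-byte alignment:
# {"offset": 0, "size": 832, "raw_data_offset": 32, "raw_data_size": 784}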
@@ -248,25 +258,27 @@ def make_swds(self) -> DatasetSummary:
            if not isinstance(_data_content, bytes):
                raise FormatError("data content must be bytes type")

            data_offset, data_size = self._write(dwriter, _data_content)
            _bin_section = self._write(dwriter, _data_content)
            self.tabular_dataset.put(
                TabularDatasetRow(
                    id=idx,
                    data_uri=str(fno),
                    label=label,
                    data_format=self.data_format_type,
                    object_store_type=ObjectStoreType.LOCAL,
                    data_offset=data_offset,
                    data_size=data_size,
                    data_offset=_bin_section.raw_data_offset,
                    data_size=_bin_section.raw_data_size,
                    _swds_bin_offset=_bin_section.offset,
                    _swds_bin_size=_bin_section.size,
                    data_origin=DataOriginType.NEW,
                    data_mime_type=_data_mime_type or self.default_data_mime_type,
                )
            )

            total_data_size += data_size
            total_data_size += _bin_section.size
            total_label_size += sys.getsizeof(label)

            wrote_size += data_size
            wrote_size += _bin_section.size
            if wrote_size > self.volume_bytes_size:
                wrote_size = 0
                fno += 1
12 changes: 9 additions & 3 deletions client/starwhale/api/_impl/dataset/loader.py
@@ -69,9 +69,15 @@ def _get_key_compose(self, row: TabularDatasetRow, store: ObjectStore) -> str:
        if store.key_prefix:
            data_uri = os.path.join(store.key_prefix, data_uri.lstrip("/"))

        _key_compose = (
            f"{data_uri}:{row.data_offset}:{row.data_offset + row.data_size - 1}"
        )
        if self.kind == DataFormatType.SWDS_BIN:
            offset, size = (
                int(row.extra_kw["_swds_bin_offset"]),
                int(row.extra_kw["_swds_bin_size"]),
            )
        else:
            offset, size = row.data_offset, row.data_size

        _key_compose = f"{data_uri}:{offset}:{offset + size - 1}"
        return _key_compose

    def __iter__(self) -> t.Generator[t.Tuple[DataField, DataField], None, None]:
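With the new fields, the loader picks which pair of coordinates becomes the ranged read: SWDS-BIN rows must fetch the whole aligned section (header and padding included), while other formats keep the raw payload coordinates. A sketch of the resulting key, using the figures from the tests below (the URI value is hypothetical):

def key_compose(data_uri: str, offset: int, size: int) -> str:
    # the byte range is inclusive: start at `offset`, end at `offset + size - 1`
    return f"{data_uri}:{offset}:{offset + size - 1}"

key_compose("a1b2c3", 0, 8160)   # SWDS-BIN section  -> "a1b2c3:0:8159"
key_compose("a1b2c3", 32, 784)   # raw payload only  -> "a1b2c3:32:815"
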
16 changes: 9 additions & 7 deletions client/starwhale/core/dataset/tabular.py
@@ -50,18 +50,19 @@ def __init__(
        data_origin: DataOriginType = DataOriginType.NEW,
        data_mime_type: MIMEType = MIMEType.UNDEFINED,
        auth_name: str = "",
        **kw: t.Any,
        **kw: t.Union[str, int, float],
    ) -> None:
        self.id = id
        self.data_uri = data_uri.strip()
        self.data_format = data_format
        self.data_format = DataFormatType(data_format)
        self.data_offset = data_offset
        self.data_size = data_size
        self.data_origin = data_origin
        self.object_store_type = object_store_type
        self.data_mime_type = data_mime_type
        self.data_origin = DataOriginType(data_origin)
        self.object_store_type = ObjectStoreType(object_store_type)
        self.data_mime_type = MIMEType(data_mime_type)
        self.auth_name = auth_name
        self.label = self._parse_label(label)
        self.extra_kw = kw

        # TODO: add non-starwhale object store related fields, such as address, authority
        # TODO: add data uri crc for versioning
@@ -94,8 +95,7 @@ def _do_validate(self) -> None:
        if self.data_origin not in DataOriginType:
            raise NoSupportError(f"data origin: {self.data_origin}")

        # TODO: support non-starwhale remote object store, for index-only feature
        if self.object_store_type != ObjectStoreType.LOCAL:
        if self.object_store_type not in ObjectStoreType:
            raise NoSupportError(f"object store {self.object_store_type}")

    def __str__(self) -> str:
@@ -110,6 +110,8 @@ def __repr__(self) -> str:

    def asdict(self) -> t.Dict[str, t.Union[str, bytes, int]]:
        d = deepcopy(self.__dict__)
        d.pop("extra_kw", None)
        d.update(self.extra_kw)
        for k, v in d.items():
            if isinstance(v, Enum):
                d[k] = v.value
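Unknown keyword arguments now land in extra_kw, and asdict hoists them back to the top level, so _swds_bin_offset and _swds_bin_size are persisted as ordinary columns alongside the enum values. A sketch of the round trip; field values echo the tests, and the constructed row is illustrative rather than copied from the source:

row = TabularDatasetRow(
    id=0,
    data_uri="a1b2c3",
    label=b"0",
    data_format=DataFormatType.SWDS_BIN,
    object_store_type=ObjectStoreType.LOCAL,
    data_offset=32,
    data_size=784,
    _swds_bin_offset=0,
    _swds_bin_size=832,
)
d = row.asdict()
assert d["_swds_bin_offset"] == 0   # hoisted out of extra_kw
assert "extra_kw" not in d          # the container itself is dropped
assert d["data_format"] == DataFormatType.SWDS_BIN.value  # Enum -> value
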
5 changes: 4 additions & 1 deletion client/tests/sdk/test_dataset.py
@@ -117,6 +117,7 @@ def test_swds_bin_workflow(self) -> None:
            label_filter="mnist-data-*",
            alignment_bytes_size=64,
            volume_bytes_size=100,
            data_mime_type=MIMEType.GRAYSCALE,
        ) as e:
            assert e.data_tmpdir.exists()
            summary = e.make_swds()
@@ -163,5 +164,7 @@ def test_swds_bin_workflow(self) -> None:
        tdb = TabularDataset(name="mnist", version="112233", project="self")
        meta = list(tdb.scan(0, 1))[0]
        assert meta.id == 0
        assert meta.data_offset == 0
        assert meta.data_offset == 32
        assert meta.extra_kw["_swds_bin_offset"] == 0
        assert meta.data_uri in data_files_sign
        assert meta.data_mime_type == MIMEType.GRAYSCALE
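
The updated assertions follow from the layout above: the first record's section starts at _swds_bin_offset == 0, and its raw payload begins one header later.

# assuming the 32-byte header implied by builder.py's struct pack:
# meta.data_offset == meta.extra_kw["_swds_bin_offset"] + 32 == 32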
29 changes: 21 additions & 8 deletions client/tests/sdk/test_loader.py
@@ -16,7 +16,11 @@
    UserRawDataLoader,
)
from starwhale.core.dataset.type import DatasetSummary
from starwhale.core.dataset.store import DatasetStorage
from starwhale.core.dataset.store import (
    DatasetStorage,
    S3StorageBackend,
    LocalFSStorageBackend,
)
from starwhale.core.dataset.tabular import TabularDatasetRow

from .. import ROOT_DIR
@@ -234,8 +238,10 @@ def test_swds_bin_s3(
                id=0,
                object_store_type=ObjectStoreType.LOCAL,
                data_uri=fname,
                data_offset=0,
                data_size=8160,
                data_offset=32,
                data_size=784,
                _swds_bin_offset=0,
                _swds_bin_size=8160,
                label=b"0",
                data_origin=DataOriginType.NEW,
                data_format=DataFormatType.SWDS_BIN,
@@ -282,6 +288,7 @@ def test_swds_bin_s3(

        assert list(loader._stores.keys()) == ["local."]
        backend = loader._stores["local."].backend
        assert isinstance(backend, S3StorageBackend)
        assert backend.kind == SWDSBackendType.S3
        assert backend.s3.Object.call_args[0] == (
            "starwhale",
@@ -314,8 +321,10 @@ def test_swds_bin_local_fs(self, m_scan: MagicMock, m_summary: MagicMock) -> None:
                id=0,
                object_store_type=ObjectStoreType.LOCAL,
                data_uri=fname,
                data_offset=0,
                data_size=8160,
                data_offset=32,
                data_size=784,
                _swds_bin_offset=0,
                _swds_bin_size=8160,
                label=b"0",
                data_origin=DataOriginType.NEW,
                data_format=DataFormatType.SWDS_BIN,
@@ -326,8 +335,10 @@ def test_swds_bin_local_fs(self, m_scan: MagicMock, m_summary: MagicMock) -> None:
                id=1,
                object_store_type=ObjectStoreType.LOCAL,
                data_uri=fname,
                data_offset=0,
                data_size=8160,
                data_offset=32,
                data_size=784,
                _swds_bin_offset=0,
                _swds_bin_size=8160,
                label=b"1",
                data_origin=DataOriginType.NEW,
                data_format=DataFormatType.SWDS_BIN,
@@ -353,6 +364,8 @@ def test_swds_bin_local_fs(self, m_scan: MagicMock, m_summary: MagicMock) -> None:
        assert _data.ext_attr == {"ds_name": "mnist", "ds_version": "1122334455667788"}

        assert list(loader._stores.keys()) == ["local."]
        assert loader._stores["local."].backend.kind == SWDSBackendType.LocalFS
        backend = loader._stores["local."].backend
        assert isinstance(backend, LocalFSStorageBackend)
        assert backend.kind == SWDSBackendType.LocalFS
        assert loader._stores["local."].bucket == str(data_dir)
        assert not loader._stores["local."].key_prefix
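
Both store flavours consume the same uri:start:end key. A hedged sketch of how a backend might resolve it; the function name and layout are hypothetical, not the project's real API. A local-FS backend can seek and read, while an S3 backend would instead send the header Range: bytes=<start>-<end>.

import os

def read_range(key: str, bucket_root: str) -> bytes:
    # key looks like "<uri>:<start>:<end>", with an inclusive end
    uri, start, end = key.rsplit(":", 2)
    begin, stop = int(start), int(end)
    with open(os.path.join(bucket_root, uri), "rb") as f:
        f.seek(begin)
        return f.read(stop - begin + 1)  # inclusive range -> +1 bytes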
