Commit
refactor dataset build
tianweidut committed Sep 8, 2022
1 parent 8a99e2c commit c5e5dd1
Showing 43 changed files with 949 additions and 975 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -50,6 +50,8 @@ example/PennFudanPed/data/
example/PennFudanPed/models/
example/nmt/models/
example/nmt/data/
example/speech_command/data/
example/speech_command/models/

# UI
node_modules
20 changes: 16 additions & 4 deletions client/starwhale/api/_impl/dataset/__init__.py
@@ -1,14 +1,20 @@
from starwhale.core.dataset.type import (
Link,
Text,
Audio,
Image,
Binary,
LinkType,
MIMEType,
DataField,
ClassLabel,
S3LinkAuth,
BoundingBox,
GrayscaleImage,
LocalFSLinkAuth,
DefaultS3LinkAuth,
COCOObjectAnnotation,
)

from .mnist import MNISTBuildExecutor
from .loader import get_data_loader, SWDSBinDataLoader, UserRawDataLoader
from .builder import BuildExecutor, SWDSBinBuildExecutor, UserRawBuildExecutor

@@ -20,11 +26,17 @@
"S3LinkAuth",
"MIMEType",
"LinkType",
"DataField",
"BuildExecutor", # SWDSBinBuildExecutor alias
"UserRawBuildExecutor",
"SWDSBinBuildExecutor",
"MNISTBuildExecutor",
"SWDSBinDataLoader",
"UserRawDataLoader",
"Binary",
"Text",
"Audio",
"Image",
"ClassLabel",
"BoundingBox",
"GrayscaleImage",
"COCOObjectAnnotation",
]
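
For reference, the expanded __all__ above means user dataset code can pull the artifact and annotation types from the same module as the build executors. A minimal sketch of the new import surface (illustrative, not part of this commit):

from starwhale.api._impl.dataset import (
    Link,
    Text,
    Audio,
    Image,
    Binary,
    MIMEType,
    ClassLabel,
    BoundingBox,
    GrayscaleImage,
    COCOObjectAnnotation,
    BuildExecutor,  # alias of SWDSBinBuildExecutor
    UserRawBuildExecutor,
)
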
172 changes: 60 additions & 112 deletions client/starwhale/api/_impl/dataset/builder.py
@@ -1,6 +1,3 @@
from __future__ import annotations

import sys
import struct
import typing as t
import tempfile
@@ -19,8 +16,10 @@
from starwhale.core.dataset import model
from starwhale.core.dataset.type import (
Link,
Binary,
LinkAuth,
MIMEType,
BaseArtifact,
DatasetSummary,
D_ALIGNMENT_SIZE,
D_FILE_VOLUME_SIZE,
@@ -45,10 +44,7 @@ def __init__(
dataset_name: str,
dataset_version: str,
project_name: str,
data_dir: Path = Path("."),
workdir: Path = Path("./sw_output"),
data_filter: str = "*",
label_filter: str = "*",
alignment_bytes_size: int = D_ALIGNMENT_SIZE,
volume_bytes_size: int = D_FILE_VOLUME_SIZE,
append: bool = False,
@@ -58,10 +54,6 @@
) -> None:
# TODO: add more docstring for args
# TODO: validate group upper and lower?
self.data_dir = data_dir
self.data_filter = data_filter
self.label_filter = label_filter

self.workdir = workdir
self.data_output_dir = workdir / "data"
ensure_dir(self.data_output_dir)
@@ -118,6 +110,10 @@ def __exit__(

print("cleanup done.")

@abstractmethod
def iter_item(self) -> t.Generator[t.Tuple[t.Any, t.Any], None, None]:
raise NotImplementedError

@abstractmethod
def make_swds(self) -> DatasetSummary:
raise NotImplementedError
@@ -128,59 +124,16 @@ def _merge_forked_summary(self, s: DatasetSummary) -> DatasetSummary:
s.rows += _fs.rows
s.unchanged_rows += _fs.rows
s.data_byte_size += _fs.data_byte_size
s.label_byte_size += _fs.label_byte_size
s.annotations = list(set(s.annotations) | set(_fs.annotations))
s.include_link |= _fs.include_link
s.include_user_raw |= _fs.include_user_raw

return s

def _iter_files(
self, filter: str, sort_key: t.Optional[t.Any] = None
) -> t.Generator[Path, None, None]:
_key = sort_key
if _key is not None and not callable(_key):
raise Exception(f"data_sort_func({_key}) is not callable.")

_files = sorted(self.data_dir.rglob(filter), key=_key)
for p in _files:
if not p.is_file():
continue
yield p

def iter_data_files(self) -> t.Generator[Path, None, None]:
return self._iter_files(self.data_filter, self.data_sort_func())

def iter_label_files(self) -> t.Generator[Path, None, None]:
return self._iter_files(self.label_filter, self.label_sort_func())

def iter_all_dataset_slice(self) -> t.Generator[t.Any, None, None]:
for p in self.iter_data_files():
for d in self.iter_data_slice(str(p.absolute())):
yield p, d

def iter_all_label_slice(self) -> t.Generator[t.Any, None, None]:
for p in self.iter_label_files():
for d in self.iter_label_slice(str(p.absolute())):
yield p, d

@abstractmethod
def iter_data_slice(self, path: str) -> t.Generator[t.Any, None, None]:
raise NotImplementedError

@abstractmethod
def iter_label_slice(self, path: str) -> t.Generator[t.Any, None, None]:
raise NotImplementedError

@property
def data_format_type(self) -> DataFormatType:
raise NotImplementedError

def data_sort_func(self) -> t.Any:
return None

def label_sort_func(self) -> t.Any:
return None


class SWDSBinBuildExecutor(BaseBuildExecutor):
"""
@@ -244,39 +197,45 @@ def make_swds(self) -> DatasetSummary:
ds_copy_candidates[fno] = dwriter_path

increased_rows = 0
total_label_size, total_data_size = 0, 0
total_data_size = 0
dataset_annotations: t.Dict[str, t.Any] = {}

for idx, ((_, data), (_, label)) in enumerate(
zip(self.iter_all_dataset_slice(), self.iter_all_label_slice()),
start=self._forked_last_idx + 1,
for idx, (row_data, row_annotations) in enumerate(
self.iter_item(), start=self._forked_last_idx + 1
):
if isinstance(data, (tuple, list)):
_data_content, _data_mime_type = data
if not isinstance(row_annotations, dict):
raise FormatError(f"annotations({row_annotations}) must be dict type")

_artifact: BaseArtifact
if isinstance(row_data, bytes):
_artifact = Binary(row_data, self.default_data_mime_type)
elif isinstance(row_data, BaseArtifact):
_artifact = row_data
else:
_data_content, _data_mime_type = data, self.default_data_mime_type
raise NoSupportError(f"data type {type(row_data)}")

if not isinstance(_data_content, bytes):
raise FormatError("data content must be bytes type")
if not dataset_annotations:
# TODO: check annotations type and name
dataset_annotations = row_annotations

_bin_section = self._write(dwriter, _data_content)
_bin_section = self._write(dwriter, _artifact.to_bytes())
self.tabular_dataset.put(
TabularDatasetRow(
id=idx,
data_uri=str(fno),
label=label,
data_format=self.data_format_type,
object_store_type=ObjectStoreType.LOCAL,
data_offset=_bin_section.raw_data_offset,
data_size=_bin_section.raw_data_size,
_swds_bin_offset=_bin_section.offset,
_swds_bin_size=_bin_section.size,
data_origin=DataOriginType.NEW,
data_mime_type=_data_mime_type or self.default_data_mime_type,
data_type=_artifact.astype(),
annotations=row_annotations,
)
)

total_data_size += _bin_section.size
total_label_size += sys.getsizeof(label)

wrote_size += _bin_section.size
if wrote_size > self.volume_bytes_size:
@@ -303,10 +262,10 @@ def make_swds(self) -> DatasetSummary:
summary = DatasetSummary(
rows=increased_rows,
increased_rows=increased_rows,
label_byte_size=total_label_size,
data_byte_size=total_data_size,
include_user_raw=False,
include_link=False,
annotations=list(dataset_annotations.keys()),
)
return self._merge_forked_summary(summary)

@@ -335,39 +294,53 @@ def _copy_files(

_dest_path.symlink_to(_obj_path)

# TODO: tune performance scan after put in a second
for row in self.tabular_dataset.scan(*row_pos):
self.tabular_dataset.update(
row_id=row.id, data_uri=map_fno_sign[int(row.data_uri)]
)

def iter_data_slice(self, path: str) -> t.Generator[t.Any, None, None]:
with Path(path).open() as f:
yield f.read()

def iter_label_slice(self, path: str) -> t.Generator[t.Any, None, None]:
yield Path(path).name


BuildExecutor = SWDSBinBuildExecutor


class UserRawBuildExecutor(BaseBuildExecutor):
def make_swds(self) -> DatasetSummary:
increased_rows = 0
total_label_size, total_data_size = 0, 0
total_data_size = 0
auth_candidates = {}
include_link = False

map_path_sign: t.Dict[str, t.Tuple[str, Path]] = {}
dataset_annotations: t.Dict[str, t.Any] = {}

for idx, (data, (_, label)) in enumerate(
zip(self.iter_all_dataset_slice(), self.iter_all_label_slice()),
for idx, (row_data, row_annotations) in enumerate(
self.iter_item(),
start=self._forked_last_idx + 1,
):
if isinstance(data, Link):
_remote_link = data
if not isinstance(row_annotations, dict):
raise FormatError(f"annotations({row_annotations}) must be dict type")

if not dataset_annotations:
# TODO: check annotations type and name
dataset_annotations = row_annotations

if not isinstance(row_data, Link):
raise FormatError(f"data({row_data}) must be Link type")

if row_data.with_local_fs_data:
_local_link = row_data
_data_fpath = _local_link.uri
if _data_fpath not in map_path_sign:
map_path_sign[_data_fpath] = DatasetStorage.save_data_file(
_data_fpath
)
data_uri, _ = map_path_sign[_data_fpath]
auth = ""
object_store_type = ObjectStoreType.LOCAL
else:
_remote_link = row_data
data_uri = _remote_link.uri
data_offset, data_size = _remote_link.offset, _remote_link.size
if _remote_link.auth:
auth = _remote_link.auth.name
auth_candidates[
@@ -377,42 +350,23 @@ def make_swds(self) -> DatasetSummary:
auth = ""
object_store_type = ObjectStoreType.REMOTE
include_link = True
data_mime_type = _remote_link.mime_type
elif isinstance(data, (tuple, list)):
_data_fpath, _local_link = data
if _data_fpath not in map_path_sign:
map_path_sign[_data_fpath] = DatasetStorage.save_data_file(
_data_fpath
)

if not isinstance(_local_link, Link):
raise NoSupportError("data only support Link type")

data_mime_type = _local_link.mime_type
data_offset, data_size = _local_link.offset, _local_link.size
data_uri, _ = map_path_sign[_data_fpath]
auth = ""
object_store_type = ObjectStoreType.LOCAL
else:
raise FormatError(f"data({data}) type error, no list, tuple or Link")

self.tabular_dataset.put(
TabularDatasetRow(
id=idx,
data_uri=data_uri,
label=label,
data_format=self.data_format_type,
object_store_type=object_store_type,
data_offset=data_offset,
data_size=data_size,
data_offset=row_data.offset,
data_size=row_data.size,
data_origin=DataOriginType.NEW,
auth_name=auth,
data_mime_type=data_mime_type,
data_type=row_data.astype(),
annotations=row_annotations,
)
)

total_data_size += data_size
total_label_size += sys.getsizeof(label)
total_data_size += row_data.size
increased_rows += 1

self._copy_files(map_path_sign)
@@ -422,10 +376,10 @@ def make_swds(self) -> DatasetSummary:
summary = DatasetSummary(
rows=increased_rows,
increased_rows=increased_rows,
label_byte_size=total_label_size,
data_byte_size=total_data_size,
include_link=include_link,
include_user_raw=True,
annotations=list(dataset_annotations.keys()),
)
return self._merge_forked_summary(summary)

@@ -444,12 +398,6 @@ def _copy_auth(self, auth_candidates: t.Dict[str, LinkAuth]) -> None:
for auth in auth_candidates.values():
f.write("\n".join(auth.dump_env()))

def iter_data_slice(self, path: str) -> t.Generator[t.Any, None, None]:
yield 0, Path(path).stat().st_size

def iter_label_slice(self, path: str) -> t.Generator[t.Any, None, None]:
yield Path(path).name

@property
def data_format_type(self) -> DataFormatType:
return DataFormatType.USER_RAW
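
The net effect for dataset authors: instead of implementing iter_data_slice/iter_label_slice over paired data and label files, a build executor now implements a single iter_item() generator that yields (data, annotations) tuples, where annotations is a dict whose keys become the annotation names recorded in DatasetSummary. Below is a minimal sketch of the two executor styles against the refactored base classes; the file layout and the Link constructor keyword names are assumptions (this diff only shows Link attributes such as uri, offset, size, and with_local_fs_data):

import typing as t
from pathlib import Path

from starwhale.api._impl.dataset import (
    Link,
    SWDSBinBuildExecutor,
    UserRawBuildExecutor,
)


class BytesBuildExecutor(SWDSBinBuildExecutor):
    # swds-bin style: yield raw bytes (make_swds wraps them in Binary)
    # or any BaseArtifact subclass, plus a dict of annotations.
    def iter_item(self) -> t.Generator[t.Tuple[t.Any, t.Any], None, None]:
        for fpath in sorted(Path("data").glob("*.bin")):  # hypothetical layout
            yield fpath.read_bytes(), {"label": fpath.stem}


class RawLinkBuildExecutor(UserRawBuildExecutor):
    # user-raw style: data must be a Link; with_local_fs_data marks a link
    # pointing at a local file that _copy_files will sign and symlink.
    def iter_item(self) -> t.Generator[t.Tuple[t.Any, t.Any], None, None]:
        for fpath in sorted(Path("data").glob("*.bin")):  # hypothetical layout
            link = Link(  # keyword names assumed for illustration
                uri=str(fpath),
                offset=0,
                size=fpath.stat().st_size,
                with_local_fs_data=True,
            )
            yield link, {"label": fpath.stem}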