Skip to content

Commit

Permalink
fix missing-upload issue in loader consumption
Browse files Browse the repository at this point in the history
  • Loading branch information
tianweidut committed Oct 11, 2023
1 parent 3951f84 commit 58e6803
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 39 deletions.
76 changes: 37 additions & 39 deletions client/starwhale/api/_impl/dataset/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,43 +188,41 @@ def _get_processed_key_range(self) -> t.Optional[t.List[t.Tuple[t.Any, t.Any]]]:

# Current server side implementation only supports the original key range as the processedData parameter,
# so we need to wait for all the keys in the original key range to be processed.
while not self._key_processed_queue.empty():
key = self._key_processed_queue.get(block=True)
while True:
try:
key = self._key_processed_queue.get(block=False)
except queue.Empty:
break

# TODO: tune performance for find key range
with self._lock:
for rk in self._key_range_dict:
if (rk[0] is None or rk[0] <= key) and (
rk[1] is None or key < rk[1]
):
self._key_range_dict[rk]["processed_cnt"] += 1
break
else:
raise RuntimeError(
f"key({key}) not found in key range dict:{self._key_range_dict}"
)
for rk in self._key_range_dict:
if (rk[0] is None or rk[0] <= key) and (rk[1] is None or key < rk[1]):
self._key_range_dict[rk]["processed_cnt"] += 1
break
else:
raise RuntimeError(
f"key({key}) not found in key range dict:{self._key_range_dict}"
)

processed_range_keys = []
with self._lock:
for rk in list(self._key_range_dict.keys()):
if (
self._key_range_dict[rk]["processed_cnt"]
== self._key_range_dict[rk]["rows_cnt"]
):
processed_range_keys.append(rk)
del self._key_range_dict[rk]
for rk in list(self._key_range_dict.keys()):
if (
self._key_range_dict[rk]["processed_cnt"]
== self._key_range_dict[rk]["rows_cnt"]
):
processed_range_keys.append(rk)
del self._key_range_dict[rk]

return processed_range_keys

def _check_all_processed_done(self) -> bool:
with self._lock:
unfinished = self._expected_rows_cnt - self._processed_rows_cnt
if unfinished < 0:
raise ValueError(
f"unfinished rows cnt({unfinished}) < 0, processed rows cnt has been called more than expected"
)
else:
return unfinished == 0
unfinished = self._expected_rows_cnt - self._processed_rows_cnt
if unfinished < 0:
raise ValueError(
f"unfinished rows cnt({unfinished}) < 0, processed rows cnt has been called more than expected"
)
else:
return unfinished == 0

def _iter_meta(self) -> t.Generator[TabularDatasetRow, None, None]:
if not self.session_consumption:
Expand All @@ -233,14 +231,15 @@ def _iter_meta(self) -> t.Generator[TabularDatasetRow, None, None]:
yield row
else:
while True:
pk = self._get_processed_key_range()
rt = self.session_consumption.get_scan_range(pk)
if rt is None:
if self._check_all_processed_done():
with self._lock:
pk = self._get_processed_key_range()
rt = self.session_consumption.get_scan_range(pk)
if rt is None and self._check_all_processed_done():
break
else:
time.sleep(1)
continue

if rt is None:
time.sleep(1)
continue

rows_cnt = 0
if self.dataset_uri.instance.is_cloud:
Expand Down Expand Up @@ -379,11 +378,10 @@ def __iter__(
else:
yield row
with self._lock:
if self._key_processed_queue is not None:
self._key_processed_queue.put(row.index)
self._processed_rows_cnt += 1

if self._key_processed_queue is not None:
self._key_processed_queue.put(row.index)

console.debug(
"queue details:"
f"meta fetcher(qsize:{self._meta_fetched_queue.qsize()}, alive: {meta_fetcher.is_alive()}), "
Expand Down
79 changes: 79 additions & 0 deletions client/tests/sdk/test_loader.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import os
import time
import queue
import random
import shutil
import typing as t
import tempfile
import threading
from itertools import chain
from collections import defaultdict
from unittest.mock import patch, MagicMock

from requests_mock import Mocker
Expand Down Expand Up @@ -733,6 +738,80 @@ def test_processed_key_range(self) -> None:
loader._key_processed_queue.put(30)
loader._get_processed_key_range()

@patch("starwhale.api._impl.dataset.loader.TabularDataset.scan")
def test_session_consumption(self, mock_scan: MagicMock) -> None:
mock_sc = MagicMock()
mock_sc.session_start = None
mock_sc.session_end = None
mock_sc.batch_size = 1

start_key, end_key = 0, 5002
chunk_size = 10
count = end_key - start_key

def _chunk() -> t.Iterator[t.Tuple[int, int]]:
r = range(start_key, end_key)
for i in range(0, len(r), chunk_size):
lst = r[i : i + chunk_size]
yield (lst[0], lst[-1] + 1)

allocated_keys = list(_chunk())
mock_sc.get_scan_range.side_effect = allocated_keys + [None] * 20

def _mock_scan(*args: t.Any, **kwargs: t.Any) -> t.Any:
_s, _e = args
if _s is None:
_s = start_key
if _e is None:
_e = end_key

for i in range(_s, _e):
# simulate data unpacking
time.sleep(random.randint(0, 2) / 1000)
yield TabularDatasetRow(id=i, features={"label": i})

mock_scan.side_effect = _mock_scan

exceptions = []

def _consume_loader(consumed_ids: t.Dict[str, list], name: str) -> None:
try:
loader = get_data_loader(
dataset_uri=self.dataset_uri, session_consumption=mock_sc
)
for item in loader:
consumed_ids[name].append(item.index)
# simulate data processing(predicting, etc.)
time.sleep(random.randint(1, 3) / 1000)
except Exception as e:
exceptions.append(e)
raise

consumed_ids = defaultdict(list)
loader_threads = []
for i in range(0, 10):
_n = f"loader-{i}"
_t = threading.Thread(
name=_n,
target=_consume_loader,
args=(consumed_ids, _n),
daemon=True,
)
_t.start()
loader_threads.append(_t)

for _t in loader_threads:
_t.join()

assert len(exceptions) == 0
assert len(list(consumed_ids.values())[0]) < count
assert len(list(chain(*consumed_ids.values()))) == count

submit_processed_keys = sorted(
chain(*[s[0][0] for s in mock_sc.get_scan_range.call_args_list if s[0][0]])
)
assert submit_processed_keys == allocated_keys

@patch("starwhale.core.dataset.model.StandaloneDataset.summary")
@patch("starwhale.api._impl.dataset.loader.TabularDataset.scan")
def test_loader_with_scan_exception(
Expand Down

0 comments on commit 58e6803

Please sign in to comment.