Skip to content

Commit

Permalink
Begin collecting completed upload parts immediately for more responsive callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
bhperry committed Feb 27, 2024
1 parent 497a3ad commit a397947
Showing 1 changed file with 41 additions and 28 deletions.
69 changes: 41 additions & 28 deletions saturnfs/client/file_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ def __init__(
self.num_workers = num_workers
self.exit_on_timeout = exit_on_timeout
self.upload_queue: Queue[Optional[UploadChunk]] = Queue(2 * self.num_workers)
self.completed_queue: Queue[Optional[ObjectStorageCompletePart]] = Queue()
self.completed_queue: Queue[Union[ObjectStorageCompletePart, UploadStop]] = Queue()
self.stop = Event()

for _ in range(self.num_workers):
Expand All @@ -461,32 +461,42 @@ def __init__(
def upload_chunks(
    self, chunks: Iterable[UploadChunk], callback: Optional[Callback] = None
) -> Tuple[List[ObjectStorageCompletePart], bool]:
    """
    Upload the given chunks and collect the parts that completed.

    Chunks are fed to the upload queue by a background producer thread
    (started by ``_producer``), so completed parts can be collected, and the
    callback updated, while later chunks are still being read.

    Args:
        chunks: Chunks to upload.
        callback: Optional progress callback, passed through to ``_collect``.

    Returns:
        A ``(completed_parts, uploads_finished)`` tuple. An empty iterable
        yields ``([], True)``.
    """
    first_part_number = self._producer(chunks)
    if first_part_number == -1:
        # No chunks given
        return [], True

    # The producer thread waits on the workers and signals the end of the upload,
    # so collection can begin immediately instead of after every chunk is queued.
    completed_parts, uploads_finished = self._collect(first_part_number, callback=callback)
    self.stop.clear()
    return completed_parts, uploads_finished

def close(self):
    """
    Signal shutdown to the upload workers.

    Places one ``None`` sentinel on the upload queue per worker.
    """
    enqueue = self.upload_queue.put
    remaining = self.num_workers
    while remaining > 0:
        enqueue(None)
        remaining -= 1

def _producer(self, chunks: Iterable[UploadChunk]) -> int:
    """
    Queue the first chunk, then feed the remaining chunks from a background thread.

    Args:
        chunks: Chunks to upload. Any iterable is accepted; it is iterated once.

    Returns:
        The part number of the first chunk, or ``-1`` if ``chunks`` is empty
        (in which case no thread is started).
    """
    # Create the iterator exactly once. Calling iter() here and then looping over
    # ``chunks`` again in the thread would restart a re-iterable input (e.g. a list)
    # from the beginning and upload the first chunk twice.
    chunk_iter = iter(chunks)
    first_chunk = next(chunk_iter, None)
    if first_chunk is None:
        return -1
    self.upload_queue.put(first_chunk)

    def _producer_thread():
        all_chunks_read: bool = False
        uploads_finished: bool = False
        try:
            try:
                for chunk in chunk_iter:
                    if not self._put_chunk(chunk):
                        break
                else:
                    all_chunks_read = True
            finally:
                # Wait on the workers even if reading the chunks raised, so parts
                # still in flight are reported before the stop sentinel.
                uploads_finished = self._wait()
        finally:
            # Always signal end of upload to the collector; without this a failure
            # in this thread would leave the collector blocked on the queue forever.
            self.completed_queue.put(
                UploadStop(error=not (uploads_finished and all_chunks_read))
            )

    Thread(target=_producer_thread, daemon=True).start()
    return first_chunk.part.part_number

def _put_chunk(self, chunk: UploadChunk, poll_interval: int = 5) -> bool:
while True:
try:
Expand Down Expand Up @@ -524,7 +534,7 @@ def _worker(self):
self.upload_queue.task_done()
self.completed_queue.put(completed_part)

def _wait(self):
def _wait(self) -> bool:
# Wait for workers to finish processing all chunks, or exit due to expired signatures
uploads_finished = False

Expand Down Expand Up @@ -552,24 +562,22 @@ def _uploads_finished():
# Join the thread instead of queue to ensure there is no race-condition when
# worker threads are signaled to shutdown (otherwise uploads_finished_thread could leak)
uploads_finished_thread.join()

# Signal error to the collector
self.completed_queue.put(None)
return uploads_finished

def _collect(
self,
first_part: int,
num_parts: int,
callback: Optional[Callback] = None,
):
# Collect completed parts
completed_parts: List[ObjectStorageCompletePart] = []
uploads_finished: bool = False
while True:
completed_part = self.completed_queue.get()
if completed_part is None:
# Error detected in one or more workers
if isinstance(completed_part, UploadStop):
# End of upload detected
# Producer only puts UploadStop on the queue after _wait() has returned
uploads_finished = not completed_part.error
self.completed_queue.task_done()
break

Expand All @@ -579,10 +587,6 @@ def _collect(
if callback is not None:
callback.relative_update(completed_part.size)

if len(completed_parts) == num_parts:
uploads_finished = True
break

completed_parts.sort(key=lambda p: p.part_number)
completed_len = len(completed_parts)
if not uploads_finished and completed_len > 0:
Expand Down Expand Up @@ -663,6 +667,15 @@ class UploadChunk:
data: Any


@dataclass
class UploadStop:
    """
    Sentinel marking the end of an upload.

    It is placed on the completed-parts queue after the last completed part.
    """

    # True when the upload ended without every chunk being read and uploaded.
    error: bool = False


def set_last_modified(local_path: str, last_modified: datetime):
    """
    Set the access and modification times of ``local_path`` to ``last_modified``.

    Args:
        local_path: Path of the local file to update.
        last_modified: Timestamp to apply to both times.
    """
    epoch_seconds = last_modified.timestamp()
    os.utime(local_path, times=(epoch_seconds, epoch_seconds))

0 comments on commit a397947

Please sign in to comment.