Merge pull request #18 from saturncloud/bhperry/rsync-callback
Add callback on rsync CLI command
bhperry committed Feb 27, 2024
2 parents 70f7468 + ab13a51 commit 143e126
Showing 3 changed files with 69 additions and 38 deletions.
17 changes: 15 additions & 2 deletions saturnfs/cli/commands.py
@@ -177,19 +177,32 @@ def delete(path: str, recursive: bool):
@cli.command("rsync")
@click.argument("source_path", type=str)
@click.argument("destination_path", type=str)
@click.option("--quiet", "-q", is_flag=True, default=False, help="Do not print file operations")
@click.option(
"-d",
"--delete-missing",
is_flag=True,
default=False,
help="Delete paths from the destination that are missing in the source",
)
def rsync(source_path: str, destination_path: str, delete_missing: bool):
def rsync(source_path: str, destination_path: str, delete_missing: bool, quiet: bool):
"""
Recursively sync files between two directory trees
"""
sfs = SaturnFS()
sfs.rsync(source_path, destination_path, delete_missing=delete_missing)

    src_is_local = not source_path.startswith(settings.SATURNFS_FILE_PREFIX)
    dst_is_local = not destination_path.startswith(settings.SATURNFS_FILE_PREFIX)
    if src_is_local and dst_is_local:
        raise SaturnError(PathErrors.AT_LEAST_ONE_REMOTE_PATH)

    if quiet:
        callback = NoOpCallback()
    else:
        operation = file_op(src_is_local, dst_is_local)
        callback = FileOpCallback(operation=operation)

    sfs.rsync(source_path, destination_path, delete_missing=delete_missing, callback=callback)
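For reference, a sketch of how the updated command could be exercised through click's test runner rather than a shell. This is illustrative, not part of the change: the remote path below is a made-up example (whatever prefix settings.SATURNFS_FILE_PREFIX resolves to), and actually running it requires valid Saturn Cloud credentials.

# Illustrative only, not part of this commit
from click.testing import CliRunner

from saturnfs.cli.commands import cli

runner = CliRunner()
# Default: FileOpCallback prints one line per file operation
runner.invoke(cli, ["rsync", "./data", "sfs://example-org/example-owner/data"])
# --quiet: NoOpCallback suppresses the per-file output
runner.invoke(cli, ["rsync", "./data", "sfs://example-org/example-owner/data", "--quiet"])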


@cli.command("ls")
74 changes: 45 additions & 29 deletions saturnfs/client/file_transfer.py
@@ -452,7 +452,7 @@ def __init__(
        self.num_workers = num_workers
        self.exit_on_timeout = exit_on_timeout
        self.upload_queue: Queue[Optional[UploadChunk]] = Queue(2 * self.num_workers)
        self.completed_queue: Queue[Optional[ObjectStorageCompletePart]] = Queue()
        self.completed_queue: Queue[Union[ObjectStorageCompletePart, UploadStop]] = Queue()
        self.stop = Event()

        for _ in range(self.num_workers):
@@ -461,32 +461,46 @@ def __init__(
    def upload_chunks(
        self, chunks: Iterable[UploadChunk], callback: Optional[Callback] = None
    ) -> Tuple[List[ObjectStorageCompletePart], bool]:
        num_parts: int = 0
        first_part: int = -1
        all_chunks_read: bool = False
        for chunk in chunks:
            if first_part == -1:
                first_part = chunk.part.part_number
            if not self._put_chunk(chunk):
                break
            num_parts += 1
        else:
            all_chunks_read = True

        if first_part == -1:
        first_part_number = self._producer_init(chunks)
        if first_part_number == -1:
            # No chunks given
            return [], True

        self._wait()
        completed_parts, uploads_finished = self._collect(first_part, num_parts, callback=callback)
        completed_parts, uploads_finished = self._collect(first_part_number, callback=callback)
        self.stop.clear()
        return completed_parts, uploads_finished and all_chunks_read
        return completed_parts, uploads_finished

    def close(self):
        # Signal shutdown to upload workers
        for _ in range(self.num_workers):
            self.upload_queue.put(None)

    def _producer_init(self, chunks: Iterable[UploadChunk]) -> int:
        # Grab first chunk from iterable to determine the starting part_number
        first_chunk = next(iter(chunks), None)
        if first_chunk is None:
            return -1
        self.upload_queue.put(first_chunk)

        # Start producer thread
        Thread(target=self._producer, kwargs={"chunks": chunks}, daemon=True).start()
        return first_chunk.part.part_number
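A note on the next(iter(chunks), None) trick above, sketched outside the diff: it works as intended when chunks is a one-shot iterator such as a generator, so the producer thread started afterwards resumes from the second chunk; with a plain list the first chunk would be enqueued twice.

# Illustrative only, not saturnfs code
def parts():
    yield "part-1"
    yield "part-2"

chunks = parts()
first = next(iter(chunks), None)  # "part-1"; iter() on a generator returns the generator itself
rest = list(chunks)               # ["part-2"], iteration resumes where it left off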

    def _producer(self, chunks: Iterable[UploadChunk]):
        # Iterate chunks onto the upload_queue until completed or error detected
        all_chunks_read: bool = False
        for chunk in chunks:
            if not self._put_chunk(chunk):
                break
        else:
            all_chunks_read = True

        # Wait for workers to finish processing the queue
        uploads_finished = self._wait()

        # Signal end of upload to the collector
        self.completed_queue.put(UploadStop(error=not (uploads_finished and all_chunks_read)))

    def _put_chunk(self, chunk: UploadChunk, poll_interval: int = 5) -> bool:
        while True:
            try:
@@ -524,7 +538,7 @@ def _worker(self):
            self.upload_queue.task_done()
            self.completed_queue.put(completed_part)

    def _wait(self):
    def _wait(self) -> bool:
        # Wait for workers to finish processing all chunks, or exit due to expired signatures
        uploads_finished = False

@@ -552,24 +566,21 @@ def _uploads_finished():
        # Join the thread instead of queue to ensure there is no race-condition when
        # worker threads are signaled to shutdown (otherwise uploads_finished_thread could leak)
        uploads_finished_thread.join()

        # Signal error to the collector
        self.completed_queue.put(None)
        return uploads_finished

    def _collect(
        self,
        first_part: int,
        num_parts: int,
        callback: Optional[Callback] = None,
    ):
        # Collect completed parts
        completed_parts: List[ObjectStorageCompletePart] = []
        uploads_finished: bool = False
        while True:
            completed_part = self.completed_queue.get()
            if completed_part is None:
                # Error detected in one or more workers
            # Producer only puts None on the queue when all workers are done
            if isinstance(completed_part, UploadStop):
                # End of upload detected
                uploads_finished = not completed_part.error
                self.completed_queue.task_done()
                break

@@ -579,10 +590,6 @@ def _collect(
            if callback is not None:
                callback.relative_update(completed_part.size)

            if len(completed_parts) == num_parts:
                uploads_finished = True
                break

        completed_parts.sort(key=lambda p: p.part_number)
        completed_len = len(completed_parts)
        if not uploads_finished and completed_len > 0:
@@ -663,6 +670,15 @@ class UploadChunk:
    data: Any


@dataclass
class UploadStop:
"""
Sentinel type to mark end of upload
"""

error: bool = False


def set_last_modified(local_path: str, last_modified: datetime):
    timestamp = last_modified.timestamp()
    os.utime(local_path, (timestamp, timestamp))
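The restructuring above moves chunk iteration into a producer thread and replaces the bare None sentinel with UploadStop, so the collector learns both when the stream ended and whether it ended cleanly. A stripped-down sketch of that producer/worker/collector pattern, with names chosen for illustration rather than taken from saturnfs:

# Illustrative sketch of the pattern, not saturnfs code
from dataclasses import dataclass
from queue import Queue
from threading import Thread
from typing import List, Union


@dataclass
class Stop:
    """Sentinel marking the end of the stream; error=True means it ended early."""
    error: bool = False


NUM_WORKERS = 2
work_queue: "Queue[Union[int, None]]" = Queue()
done_queue: "Queue[Union[int, Stop]]" = Queue()


def worker() -> None:
    # Consume items until the shutdown signal (None), pushing results to the collector
    while True:
        item = work_queue.get()
        if item is None:
            work_queue.task_done()
            break
        done_queue.put(item * item)
        work_queue.task_done()


def producer(items) -> None:
    # Feed the workers, shut them down, then tell the collector the stream is over
    ok = True
    try:
        for item in items:
            work_queue.put(item)
    except Exception:
        ok = False
    for _ in range(NUM_WORKERS):
        work_queue.put(None)
    work_queue.join()
    done_queue.put(Stop(error=not ok))


def collect() -> List[int]:
    results: List[int] = []
    while True:
        result = done_queue.get()
        if isinstance(result, Stop):
            break  # result.error says whether the stream completed cleanly
        results.append(result)
    return results


if __name__ == "__main__":
    for _ in range(NUM_WORKERS):
        Thread(target=worker, daemon=True).start()
    Thread(target=producer, args=(range(5),), daemon=True).start()
    print(sorted(collect()))  # [0, 1, 4, 9, 16]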
16 changes: 9 additions & 7 deletions saturnfs/client/saturnfs.py
@@ -19,6 +19,7 @@
from fsspec.spec import AbstractBufferedFile, AbstractFileSystem, _Cached
from fsspec.utils import other_paths
from saturnfs import settings
from saturnfs.cli.callback import FileOpCallback
from saturnfs.client.file_transfer import (
    DownloadPart,
    FileTransferClient,
@@ -1248,18 +1249,19 @@ async def _cp_file(
        # by put/get instead of opening as a buffered file
        proto1, path1 = split_protocol(url)
        proto2, path2 = split_protocol(url2)
        if isinstance(callback, FileOpCallback) and not callback.inner:
            callback.branch(path1, path2, kwargs)
        else:
            kwargs["callback"] = callback

        if self._is_local(proto1) and self._is_saturnfs(proto2):
            if blocksize < settings.S3_MIN_PART_SIZE:
                blocksize = settings.S3_MIN_PART_SIZE
            return self.sfs.put_file(
                path1, path2, callback=callback, block_size=blocksize, **kwargs
            )
            return self.sfs.put_file(path1, path2, block_size=blocksize, **kwargs)
        elif self._is_saturnfs(proto1) and self._is_local(proto2):
            return self.sfs.get_file(
                path1, path2, callback=callback, block_size=blocksize, **kwargs
            )
            return self.sfs.get_file(path1, path2, block_size=blocksize, **kwargs)

        return await super()._cp_file(url, url2, blocksize, callback, **kwargs)
        return await super()._cp_file(url, url2, blocksize, **kwargs)
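For the branching above: fsspec callbacks expose branch(path_1, path_2, kwargs), which lets a top-level callback install a per-file child callback into kwargs["callback"]. A hedged sketch of what a FileOpCallback-style class could look like follows; the real implementation lives in saturnfs.cli.callback and may differ.

# Illustrative only, not the actual saturnfs FileOpCallback
from fsspec.callbacks import Callback


class PrintingFileOpCallback(Callback):
    def __init__(self, operation: str = "copy", inner: bool = False, **kwargs):
        self.operation = operation
        self.inner = inner  # True for per-file children created by branch()
        super().__init__(**kwargs)

    def branch(self, path_1: str, path_2: str, kwargs):
        # Announce the file operation and hand the transfer its own child callback
        print(f"{self.operation} {path_1} -> {path_2}")
        kwargs["callback"] = PrintingFileOpCallback(self.operation, inner=True)

Under this reading, the not callback.inner guard in _cp_file would keep an already-branched child callback from being branched a second time.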

    def _is_local(self, protocol: str) -> bool:
        if isinstance(LocalFileSystem.protocol, tuple):