Skip to content

Commit

Permalink
Override default get_file for faster download
Browse files Browse the repository at this point in the history
  • Loading branch information
bhperry committed Jun 12, 2023
1 parent c46bf32 commit bc1e67f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
27 changes: 20 additions & 7 deletions saturnfs/client/file_transfer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from io import BytesIO
import os
from datetime import datetime
from typing import Any, BinaryIO, List, Optional, Tuple

from fsspec import Callback
from saturnfs import settings
from saturnfs.client.aws import AWSPresignedClient
from saturnfs.errors import ExpiredSignature
from saturnfs.schemas import (
Expand Down Expand Up @@ -70,27 +72,38 @@ def download(
presigned_download: ObjectStoragePresignedDownload,
local_path: str,
callback: Optional[Callback] = None,
block_size: int = settings.S3_MIN_PART_SIZE,
):
dirname = os.path.dirname(local_path)
if dirname:
os.makedirs(dirname, exist_ok=True)

with open(local_path, "wb") as f:
self.download_outfile(
presigned_download, f, callback=callback, block_size=block_size
)
set_last_modified(local_path, presigned_download.updated_at)

# Stream the object behind a presigned download URL into an already-open
# binary file object.
#
# Parameters:
#   presigned_download: supplies the `url` that is fetched.
#   outfile: writable binary file object that receives each chunk.
#   callback: optional fsspec Callback; when given, its size is set from the
#       response's Content-Length header and it is advanced by the number of
#       bytes written per chunk.
#   block_size: chunk size passed to `response.iter_content`; defaults to
#       settings.S3_MIN_PART_SIZE.
def download_outfile(
self,
presigned_download: ObjectStoragePresignedDownload,
outfile: BytesIO,
callback: Optional[Callback] = None,
block_size: int = settings.S3_MIN_PART_SIZE,
):
# stream=True so the body is consumed chunk by chunk via iter_content below.
response = self.aws.get(presigned_download.url, stream=True)
if callback is not None:
content_length = response.headers.get("Content-Length")
# A missing or empty Content-Length header results in set_size(None).
callback.set_size(int(content_length) if content_length else None)

# NOTE(review): the next five lines reference `local_path` and `f`, neither of
# which is a parameter or local of download_outfile. They look like the
# pre-commit loop shown as removed in this diff view -- confirm they are not
# part of the committed file.
with open(local_path, "wb") as f:
for chunk in response.iter_content(None):
chunk_size = f.write(chunk)
if callback is not None:
callback.relative_update(chunk_size)
# Write the body to `outfile` in block_size chunks, reporting the byte count
# returned by each write to the callback.
for chunk in response.iter_content(block_size):
bytes_written = outfile.write(chunk)
if callback is not None:
callback.relative_update(bytes_written)

# Issue one zero-length update when the callback's size is 0 -- presumably so
# an empty download still triggers a callback update; TODO confirm intent.
if callback is not None and callback.size == 0:
callback.relative_update(0)

set_last_modified(local_path, presigned_download.updated_at)

# Close this object by delegating to the close() method of the client held in
# self.aws.
def close(self):
self.aws.close()

Expand Down
7 changes: 6 additions & 1 deletion saturnfs/client/saturnfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,12 @@ def get_file(
outfile: Optional[BytesIO] = None,
**kwargs,
):
super().get_file(rpath, lpath, callback=callback, outfile=outfile, **kwargs)
remote = ObjectStorage.parse(rpath)
download = self.object_storage_client.download_file(remote)
if outfile is not None:
self.file_transfer.download_outfile(download, outfile, callback=callback, **kwargs)
else:
self.file_transfer.download(download, lpath, callback=callback, **kwargs)

def get_bulk(self, rpaths: List[str], lpaths: List[str], callback: Callback = DEFAULT_CALLBACK):
callback.set_size(len(lpaths))
Expand Down

0 comments on commit bc1e67f

Please sign in to comment.