feat: max parallelization (#45)
* feat: add max_parallelization

* chore: bump version

* refactor: formatting
imathews committed Jun 13, 2024
1 parent 4b3ec3b commit 14437b1
Showing 16 changed files with 967 additions and 460 deletions.
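
The headline change replaces the hard-coded download concurrency (previously min(8, file_count) threads) with a caller-supplied max_parallelization. A minimal usage sketch, assuming the table-level wrapper accepts and forwards the new parameter (the call site below is illustrative; only Export.download_files appears in this diff):

import redivis

table = redivis.user("demo").dataset("some_dataset").table("some_table")

# Cap concurrent download threads at 4 instead of the former fixed 8.
table.download_files(path="./downloads", overwrite=True, max_parallelization=4)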
2 changes: 1 addition & 1 deletion src/redivis/_version.py
@@ -1 +1 @@
__version__ = "0.15.2"
__version__ = "0.15.3"
19 changes: 13 additions & 6 deletions src/redivis/classes/Base.py
@@ -1,15 +1,22 @@


class Base:
def __getitem__(self, key):
return (
self.properties[key] if hasattr(self, 'properties') and self.properties is not None and key in self.properties else None
self.properties[key]
if hasattr(self, "properties")
and self.properties is not None
and key in self.properties
else None
)

def __repr__(self) -> str:
field_strings = []
for key, field in vars(self).items():
if key != 'properties' and key != 'scoped_reference' and field is not None and not isinstance(field, Base):
field_strings.append(f'{key}:{field!r}')
if (
key != "properties"
and key != "scoped_reference"
and field is not None
and not isinstance(field, Base)
):
field_strings.append(f"{key}:{field!r}")

return f"<{self.__class__.__name__} {' '.join(field_strings)}>"
return f"<{self.__class__.__name__} {' '.join(field_strings)}>"
21 changes: 17 additions & 4 deletions src/redivis/classes/Dataset.py
@@ -5,6 +5,7 @@

from ..common.api_request import make_request, make_paginated_request


class Dataset(Base):
def __init__(
self,
@@ -19,13 +20,24 @@ def __init__(
self.user = user
self.organization = organization

if version and version != "current" and version != "next" and not version.lower().startswith("v"):
if (
version
and version != "current"
and version != "next"
and not version.lower().startswith("v")
):
version = f"v{version}"

self.qualified_reference = properties["qualifiedReference"] if "qualifiedReference" in (properties or {}) else (
f"{(self.organization or self.user).name}.{self.name}:{version}"
)
self.scoped_reference = properties["scopedReference"] if "scopedReference" in (properties or {}) else f"{self.name}:{version}"
self.qualified_reference = (
properties["qualifiedReference"]
if "qualifiedReference" in (properties or {})
else (f"{(self.organization or self.user).name}.{self.name}:{version}")
)
self.scoped_reference = (
properties["scopedReference"]
if "scopedReference" in (properties or {})
else f"{self.name}:{version}"
)
self.uri = f"/datasets/{quote_uri(self.qualified_reference, '')}"
self.properties = properties

@@ -128,6 +140,7 @@ def update(self, *, name=None, public_access_level=None, description=None):
update_properties(self, res)
return self


def update_properties(instance, properties):
instance.properties = properties
instance.qualified_reference = properties["qualifiedReference"]
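
The version normalization above means bare numeric versions gain a "v" prefix, while "current", "next", and already-prefixed values pass through. A hedged sketch of the resulting references (names illustrative):

# "1_0" is neither "current"/"next" nor "v"-prefixed, so it becomes "v1_0":
redivis.user("demo").dataset("some_dataset", version="1_0")
# -> qualified_reference "demo.some_dataset:v1_0", scoped_reference "some_dataset:v1_0"

# "current" passes through unchanged:
redivis.user("demo").dataset("some_dataset", version="current")
# -> qualified_reference "demo.some_dataset:current"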
96 changes: 70 additions & 26 deletions src/redivis/classes/Export.py
@@ -17,7 +17,7 @@ def __init__(
id,
*,
table=None,
properties = {},
properties={},
):
self.table = table
self.properties = properties
@@ -27,7 +27,7 @@ def get(self):
self.properties = make_request(method="GET", path=self.uri)
return self

def download_files(self, *, path, overwrite, progress=True):
def download_files(self, *, path, overwrite, progress=True, max_parallelization):
# TODO: if overwriting file, first check if file is the same size, and if so check md5 hash, and skip if identical (need md5 in header)
self.wait_for_finish()
file_count = self.properties["fileCount"]
@@ -37,18 +37,27 @@ def download_files(self, *, path, overwrite, progress=True):
if path is None:
path = os.getcwd()
if file_count > 1:
if not hasattr(self.table.properties, 'name'):
if not hasattr(self.table.properties, "name"):
self.table.get()
escaped_table_name = re.sub(r'\W+', '_', self.table.properties['name']).lower()
escaped_table_name = re.sub(
r"\W+", "_", self.table.properties["name"]
).lower()
path = os.path.join(path, escaped_table_name)
elif path.endswith(os.sep) or (not os.path.exists(path) and '.' not in path):
elif path.endswith(os.sep) or (not os.path.exists(path) and "." not in path):
is_dir = True
elif file_count > 1:
raise Exception(f"Path '{path}' is a file, but the export consists of multiple files. Please specify the path to a directory")

if overwrite is False and os.path.exists(path):
raise Exception(
f"File already exists at '{path}'. Set parameter overwrite=True to overwrite existing files.")
f"Path '{path}' is a file, but the export consists of multiple files. Please specify the path to a directory"
)

if (
overwrite is False
and os.path.exists(path)
and (not is_dir or file_count > 1)
):
raise Exception(
f"File already exists at '{path}'. Set parameter overwrite=True to overwrite existing files."
)

# Make sure output directory exists
if is_dir:
@@ -58,25 +67,42 @@

pbar = None
if progress:
pbar = tqdm(total=self.properties['size'], leave=False, unit='iB', unit_scale=True)

pbar = tqdm(
total=self.properties["size"], leave=False, unit="iB", unit_scale=True
)

cancel_event = Event()
# Code to use multiprocess... would simplify exiting on stop, but progress doesn't currently work
# with concurrent.futures.ProcessPoolExecutor(max_workers=len(read_session["streams"]), mp_context=mp.get_context('fork')) as executor:

output_file_paths = []
# TODO: this should use async, not threads
# See https://github.com/googleapis/python-bigquery/blob/main/google/cloud/bigquery/_pandas_helpers.py#L920
with concurrent.futures.ThreadPoolExecutor(max_workers=min(8, file_count)) as executor:
futures = [executor.submit(download_file, uri=f'{self.uri}/download', download_path=path, file_number=file_number, is_dir=is_dir, overwrite=overwrite, pbar=pbar, cancel_event=cancel_event)
for file_number in range(file_count)]
with concurrent.futures.ThreadPoolExecutor(
max_workers=min(max_parallelization, file_count)
) as executor:
futures = [
executor.submit(
download_file,
uri=f"{self.uri}/download",
download_path=path,
file_number=file_number,
is_dir=is_dir,
overwrite=overwrite,
pbar=pbar,
cancel_event=cancel_event,
)
for file_number in range(file_count)
]

not_done = futures

try:
while not_done and not cancel_event.is_set():
# next line 'sleeps' this main thread, letting the thread pool run
freshly_done, not_done = concurrent.futures.wait(not_done, timeout=0.2)
freshly_done, not_done = concurrent.futures.wait(
not_done, timeout=0.2
)
for future in freshly_done:
# Call result() on any finished threads to raise any exceptions encountered.
output_file_paths.append(future.result())
@@ -93,7 +119,7 @@ def download_files(self, *, path, overwrite, progress=True):
def wait_for_finish(self, *, progress=True):
iter_count = 0
if progress:
pbar = tqdm(total=100, leave=False, unit='%', unit_scale=True)
pbar = tqdm(total=100, leave=False, unit="%", unit_scale=True)
pbar.set_description(f"Preparing download...")

while True:
@@ -105,43 +131,62 @@ def wait_for_finish(self, *, progress=True):
if progress:
pbar.close()
raise Exception(
f"Query job failed with message: {self.properties['message']}"
f"Export job failed with message: {self.properties['message']}"
)
elif self.properties["status"] == "cancelled":
if progress:
pbar.close()
raise Exception(f"Query job was cancelled")
raise Exception(f"Export job was cancelled")
else:
iter_count += 1
if progress:
pbar.update(self.properties["percentCompleted"] - pbar.n)
time.sleep(min(iter_count * 0.5, 2))
self.get()


def get_filename(s):
fname = re.findall("filename\*=([^;]+)", s, flags=re.IGNORECASE)
if not fname:
fname = re.findall("filename=([^;]+)", s, flags=re.IGNORECASE)
if "utf-8''" in fname[0].lower():
fname = re.sub("utf-8''", '', fname[0], flags=re.IGNORECASE)
fname = urllib.unquote(fname).decode('utf8')
fname = re.sub("utf-8''", "", fname[0], flags=re.IGNORECASE)
fname = urllib.unquote(fname).decode("utf8")
else:
fname = fname[0]
# clean space and double quotes
return fname.strip().strip('"')


def download_file(*, uri, file_number, download_path, is_dir=False, overwrite=False, pbar=None, cancel_event):
with closing(make_request(method="GET", path=uri, query={"filePart": file_number}, stream=True, parse_response=False)) as r:
def download_file(
*,
uri,
file_number,
download_path,
is_dir=False,
overwrite=False,
pbar=None,
cancel_event,
):
with closing(
make_request(
method="GET",
path=uri,
query={"filePart": file_number},
stream=True,
parse_response=False,
)
) as r:
if is_dir:
file_name = get_filename(r.headers.get('Content-Disposition'))
file_name = get_filename(r.headers.get("Content-Disposition"))
download_path = os.path.join(download_path, file_name)

if overwrite is False and os.path.exists(download_path):
raise Exception(
f"File already exists at '{download_path}'. Set parameter overwrite=True to overwrite existing files.")
f"File already exists at '{download_path}'. Set parameter overwrite=True to overwrite existing files."
)

with open(download_path, 'wb') as f:
with open(download_path, "wb") as f:
for chunk in r.iter_content(chunk_size=1024 * 1024):
if cancel_event.is_set():
break
@@ -150,4 +195,3 @@ def download_file(*, uri, file_number, download_path, is_dir=False, overwrite=Fa
f.write(chunk)

return download_path
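
The threading changes above follow a common bounded-pool-plus-cancellation pattern: a ThreadPoolExecutor capped at min(max_parallelization, file_count), a shared Event polled by the workers, and a main-thread wait loop that surfaces worker exceptions and reacts quickly to Ctrl-C. A self-contained sketch of just that pattern (simplified; the real workers stream HTTP chunks to disk):

import concurrent.futures
import time
from threading import Event

def download_stub(file_number, cancel_event):
    # Stand-in for download_file: check the cancel flag between "chunks".
    for _ in range(10):
        if cancel_event.is_set():
            return None
        time.sleep(0.05)  # simulate writing one chunk
    return file_number

cancel_event = Event()
max_parallelization = 8
file_count = 20
results = []

with concurrent.futures.ThreadPoolExecutor(
    max_workers=min(max_parallelization, file_count)
) as executor:
    not_done = [
        executor.submit(download_stub, n, cancel_event) for n in range(file_count)
    ]
    try:
        while not_done:
            # Sleep the main thread briefly so a KeyboardInterrupt is handled promptly.
            done, not_done = concurrent.futures.wait(not_done, timeout=0.2)
            for future in done:
                results.append(future.result())  # re-raises any worker exception
    except KeyboardInterrupt:
        cancel_event.set()  # tell in-flight workers to stop
        raise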

56 changes: 41 additions & 15 deletions src/redivis/classes/File.py
@@ -8,6 +8,7 @@
from ..common.api_request import make_request
from urllib.parse import quote as quote_uri


class File(Base):
def __init__(
self,
@@ -21,15 +22,23 @@ def __init__(
self.uri = f"/rawFiles/{quote_uri(id, '')}"
self.properties = {
**{"kind": "rawFile", "id": id, "uri": self.uri},
**properties
**properties,
}

def get(self):
res = make_request(method="HEAD", path=self.uri, parse_response=False)
parse_headers(self, res)
return self

def download(self, path=None, *, overwrite=False, progress=True, on_progress=None, cancel_event=None):
def download(
self,
path=None,
*,
overwrite=False,
progress=True,
on_progress=None,
cancel_event=None,
):
is_dir = False
if path is None:
path = os.getcwd()
@@ -39,22 +48,35 @@ def download(self, path=None, *, overwrite=False, progress=True, on_progress=Non
elif os.path.exists(path) and os.path.isdir(path):
is_dir = True

with make_request(method="GET", path=f'{self.uri}', query={"allowRedirect": "true"}, stream=True, parse_response=False) as r:
with make_request(
method="GET",
path=f"{self.uri}",
query={"allowRedirect": "true"},
stream=True,
parse_response=False,
) as r:
parse_headers(self, r)
name = self.properties["name"]

file_name = os.path.join(path, name) if is_dir else path

if overwrite is False and os.path.exists(file_name):
raise Exception(f"File already exists at '{file_name}'. Set parameter overwrite=True to overwrite existing files.")
raise Exception(
f"File already exists at '{file_name}'. Set parameter overwrite=True to overwrite existing files."
)

# Make sure output directory exists
pathlib.Path(file_name).parent.mkdir(exist_ok=True, parents=True)

with open(file_name, 'wb') as f:
with open(file_name, "wb") as f:
if progress:
pbar = tqdm(total=self.properties['size'], leave=False, unit='iB', unit_scale=True)
for chunk in r.iter_content(chunk_size=1024*1024):
pbar = tqdm(
total=self.properties["size"],
leave=False,
unit="iB",
unit_scale=True,
)
for chunk in r.iter_content(chunk_size=1024 * 1024):
if cancel_event and cancel_event.is_set():
os.remove(file_name)
return None
@@ -70,34 +92,38 @@ def download(self, path=None, *, overwrite=False, progress=True, on_progress=Non
return file_name

def read(self, *, as_text=False):
r = make_request(method="GET", path=f'{self.uri}', parse_response=False)
r = make_request(method="GET", path=f"{self.uri}", parse_response=False)
parse_headers(self, r)
if as_text:
return r.text
else:
return r.content

def stream(self):
r = make_request(method="GET", path=f'{self.uri}', parse_response=False, stream=True)
r = make_request(
method="GET", path=f"{self.uri}", parse_response=False, stream=True
)
parse_headers(self, r)
return BytesIO(r.content)


def parse_headers(file, res):
file.properties["name"] = get_filename(res.headers['content-disposition'])
file.properties["name"] = get_filename(res.headers["content-disposition"])
# TODO: md5Hash from x-goog-hash
file.properties["contentType"] = res.headers['content-type']
file.properties["size"] = int(res.headers['content-length'] or res.headers['x-goog-stored-content-length'])
file.properties["contentType"] = res.headers["content-type"]
file.properties["size"] = int(
res.headers["content-length"] or res.headers["x-goog-stored-content-length"]
)


def get_filename(s):
fname = re.findall("filename\*=([^;]+)", s, flags=re.IGNORECASE)
if not fname:
fname = re.findall("filename=([^;]+)", s, flags=re.IGNORECASE)
if "utf-8''" in fname[0].lower():
fname = re.sub("utf-8''", '', fname[0], flags=re.IGNORECASE)
fname = urllib.unquote(fname).decode('utf8')
fname = re.sub("utf-8''", "", fname[0], flags=re.IGNORECASE)
fname = urllib.unquote(fname).decode("utf8")
else:
fname = fname[0]
# clean space and double quotes
return fname.strip().strip('"')
return fname.strip().strip('"')
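
Both Export.py and File.py carry an identical get_filename helper. For reference, what it returns for typical Content-Disposition headers (inputs illustrative):

get_filename('attachment; filename="data_part_0.csv"')
# -> 'data_part_0.csv'  (surrounding quotes stripped)

get_filename("attachment; filename=data.csv; size=1024")
# -> 'data.csv'  (value captured up to the next ';')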