diff --git a/s3_management/manage.py b/s3_management/manage.py
index ef7ae74fb..51fede761 100644
--- a/s3_management/manage.py
+++ b/s3_management/manage.py
@@ -1,12 +1,15 @@
 #!/usr/bin/env python
 
 import argparse
+import base64
+import dataclasses
+import functools
 import time
 
 from os import path, makedirs
 from datetime import datetime
 from collections import defaultdict
-from typing import Iterator, List, Type, Dict, Set, TypeVar, Optional
+from typing import Iterable, List, Type, Dict, Set, TypeVar, Optional
 from re import sub, match, search
 
 from packaging.version import parse
@@ -14,7 +17,6 @@
 
 
 S3 = boto3.resource('s3')
-CLIENT = boto3.client('s3')
 BUCKET = S3.Bucket('pytorch')
 
 ACCEPTED_FILE_EXTENSIONS = ("whl", "zip", "tar.gz")
@@ -107,6 +109,23 @@
 
 S3IndexType = TypeVar('S3IndexType', bound='S3Index')
 
+
+@dataclasses.dataclass(frozen=True)
+@functools.total_ordering
+class S3Object:
+    key: str
+    checksum: str | None
+
+    def __str__(self):
+        return self.key
+
+    def __eq__(self, other):
+        return self.key == other.key
+
+    def __lt__(self, other):
+        return self.key < other.key
+
+
 def extract_package_build_time(full_package_name: str) -> datetime:
     result = search(PACKAGE_DATE_REGEX, full_package_name)
     if result is not None:
@@ -124,7 +143,7 @@ def between_bad_dates(package_build_time: datetime):
 
 
 class S3Index:
-    def __init__(self: S3IndexType, objects: List[str], prefix: str) -> None:
+    def __init__(self: S3IndexType, objects: List[S3Object], prefix: str) -> None:
         self.objects = objects
         self.prefix = prefix.rstrip("/")
         self.html_name = PREFIXES_WITH_HTML[self.prefix]
@@ -134,7 +153,7 @@ def __init__(self: S3IndexType, objects: List[str], prefix: str) -> None:
             path.dirname(obj) for obj in objects if path.dirname != prefix
         }
 
-    def nightly_packages_to_show(self: S3IndexType) -> Set[str]:
+    def nightly_packages_to_show(self: S3IndexType) -> Set[S3Object]:
         """Finding packages to show based on a threshold we specify
 
         Basically takes our S3 packages, normalizes the version for easier
@@ -174,8 +193,8 @@ def nightly_packages_to_show(self: S3IndexType) -> Set[str]:
             if self.normalize_package_version(obj) in to_hide
         })
 
-    def is_obj_at_root(self, obj:str) -> bool:
-        return path.dirname(obj) == self.prefix
+    def is_obj_at_root(self, obj: S3Object) -> bool:
+        return path.dirname(str(obj)) == self.prefix
 
     def _resolve_subdir(self, subdir: Optional[str] = None) -> str:
         if not subdir:
@@ -187,7 +206,7 @@ def gen_file_list(
         self,
         subdir: Optional[str]=None,
         package_name: Optional[str] = None
-    ) -> Iterator[str]:
+    ) -> Iterable[S3Object]:
         objects = (
             self.nightly_packages_to_show() if self.prefix == 'whl/nightly'
             else self.objects
@@ -197,23 +216,23 @@ def gen_file_list(
             if package_name is not None:
                 if self.obj_to_package_name(obj) != package_name:
                     continue
-            if self.is_obj_at_root(obj) or obj.startswith(subdir):
+            if self.is_obj_at_root(obj) or str(obj).startswith(subdir):
                 yield obj
 
     def get_package_names(self, subdir: Optional[str] = None) -> List[str]:
         return sorted(set(self.obj_to_package_name(obj) for obj in self.gen_file_list(subdir)))
 
-    def normalize_package_version(self: S3IndexType, obj: str) -> str:
+    def normalize_package_version(self: S3IndexType, obj: S3Object) -> str:
         # removes the GPU specifier from the package name as well as
         # unnecessary things like the file extension, architecture name, etc.
         return sub(
             r"%2B.*",
             "",
-            "-".join(path.basename(obj).split("-")[:2])
+            "-".join(path.basename(str(obj)).split("-")[:2])
         )
 
-    def obj_to_package_name(self, obj: str) -> str:
-        return path.basename(obj).split('-', 1)[0]
+    def obj_to_package_name(self, obj: S3Object) -> str:
+        return path.basename(str(obj)).split('-', 1)[0]
 
     def to_legacy_html(
         self,
@@ -258,7 +277,8 @@ def to_simple_package_html(
         out.append('  <body>')
         out.append('    <h1>Links for {}</h1>'.format(package_name.lower().replace("_","-")))
         for obj in sorted(self.gen_file_list(subdir, package_name)):
-            out.append(f'    <a href="/{obj}">{path.basename(obj).replace("%2B","+")}</a><br/>')
+            maybe_fragment = f"#sha256={obj.checksum}" if obj.checksum else ""
+            out.append(f'    <a href="/{obj}{maybe_fragment}">{path.basename(obj).replace("%2B","+")}</a><br/>')
         # Adding html footer
         out.append('  </body>')
         out.append('</html>')
@@ -319,7 +339,6 @@ def upload_pep503_htmls(self) -> None:
                 Body=self.to_simple_package_html(subdir=subdir, package_name=pkg_name)
             )
 
-
     def save_legacy_html(self) -> None:
         for subdir in self.subdirs:
             print(f"INFO Saving {subdir}/{self.html_name}")
@@ -351,10 +370,18 @@ def from_S3(cls: Type[S3IndexType], prefix: str) -> S3IndexType:
                 for pattern in ACCEPTED_SUBDIR_PATTERNS
             ]) and obj.key.endswith(ACCEPTED_FILE_EXTENSIONS)
             if is_acceptable:
+                # Add PEP 503-compatible hashes to URLs to allow clients to avoid spurious downloads, if possible.
+                response = obj.meta.client.head_object(Bucket=BUCKET.name, Key=obj.key, ChecksumMode="ENABLED")
+                sha256 = (_b64 := response.get("ChecksumSHA256")) and base64.b64decode(_b64).hex()
                 sanitized_key = obj.key.replace("+", "%2B")
-                objects.append(sanitized_key)
+                s3_object = S3Object(
+                    key=sanitized_key,
+                    checksum=sha256,
+                )
+                objects.append(s3_object)
         return cls(objects, prefix)
 
+
 def create_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser("Manage S3 HTML indices for PyTorch")
     parser.add_argument(
@@ -366,6 +393,7 @@ def create_parser() -> argparse.ArgumentParser:
     parser.add_argument("--generate-pep503", action="store_true")
     return parser
 
+
 def main():
     parser = create_parser()
     args = parser.parse_args()
@@ -390,5 +418,6 @@ def main():
     if args.generate_pep503:
         idx.upload_pep503_htmls()
 
+
 if __name__ == "__main__":
     main()
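
Note on the checksum plumbing above: head_object with ChecksumMode="ENABLED" returns the object's SHA-256 digest base64-encoded, and only for objects that were uploaded with a checksum, while PEP 503 "#sha256=" URL fragments expect lowercase hex, hence the decode-and-hex step in from_S3. A minimal standalone sketch of that step follows; the bucket name is the one the script already targets, and the object key is a hypothetical example, not a real artifact.

#!/usr/bin/env python
# Sketch only: mirrors the ChecksumSHA256 handling added in from_S3 above.
import base64

import boto3

client = boto3.client("s3")
response = client.head_object(
    Bucket="pytorch",  # bucket the script above manages
    Key="whl/nightly/torch-2.1.0.dev20230601+cpu-cp310-none-linux_x86_64.whl",  # hypothetical key
    ChecksumMode="ENABLED",
)

# Objects uploaded without a SHA-256 checksum carry no ChecksumSHA256 field,
# which is why the patch leaves S3Object.checksum as None and emits no
# fragment for them.
b64_digest = response.get("ChecksumSHA256")
if b64_digest is not None:
    print(f"#sha256={base64.b64decode(b64_digest).hex()}")

Clients that understand PEP 503 hash fragments (pip among them) can verify a downloaded file against this digest, which is the "avoid spurious downloads" benefit the inline comment in the patch refers to.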