diff --git a/s3_management/manage.py b/s3_management/manage.py index ef7ae74fb..51fede761 100644 --- a/s3_management/manage.py +++ b/s3_management/manage.py @@ -1,12 +1,15 @@ #!/usr/bin/env python import argparse +import base64 +import dataclasses +import functools import time from os import path, makedirs from datetime import datetime from collections import defaultdict -from typing import Iterator, List, Type, Dict, Set, TypeVar, Optional +from typing import Iterable, List, Type, Dict, Set, TypeVar, Optional from re import sub, match, search from packaging.version import parse @@ -14,7 +17,6 @@ S3 = boto3.resource('s3') -CLIENT = boto3.client('s3') BUCKET = S3.Bucket('pytorch') ACCEPTED_FILE_EXTENSIONS = ("whl", "zip", "tar.gz") @@ -107,6 +109,23 @@ S3IndexType = TypeVar('S3IndexType', bound='S3Index') + +@dataclasses.dataclass(frozen=True) +@functools.total_ordering +class S3Object: + key: str + checksum: str | None + + def __str__(self): + return self.key + + def __eq__(self, other): + return self.key == other.key + + def __lt__(self, other): + return self.key < other.key + + def extract_package_build_time(full_package_name: str) -> datetime: result = search(PACKAGE_DATE_REGEX, full_package_name) if result is not None: @@ -124,7 +143,7 @@ def between_bad_dates(package_build_time: datetime): class S3Index: - def __init__(self: S3IndexType, objects: List[str], prefix: str) -> None: + def __init__(self: S3IndexType, objects: List[S3Object], prefix: str) -> None: self.objects = objects self.prefix = prefix.rstrip("/") self.html_name = PREFIXES_WITH_HTML[self.prefix] @@ -134,7 +153,7 @@ def __init__(self: S3IndexType, objects: List[str], prefix: str) -> None: path.dirname(obj) for obj in objects if path.dirname != prefix } - def nightly_packages_to_show(self: S3IndexType) -> Set[str]: + def nightly_packages_to_show(self: S3IndexType) -> Set[S3Object]: """Finding packages to show based on a threshold we specify Basically takes our S3 packages, normalizes the version for easier @@ -174,8 +193,8 @@ def nightly_packages_to_show(self: S3IndexType) -> Set[str]: if self.normalize_package_version(obj) in to_hide }) - def is_obj_at_root(self, obj:str) -> bool: - return path.dirname(obj) == self.prefix + def is_obj_at_root(self, obj: S3Object) -> bool: + return path.dirname(str(obj)) == self.prefix def _resolve_subdir(self, subdir: Optional[str] = None) -> str: if not subdir: @@ -187,7 +206,7 @@ def gen_file_list( self, subdir: Optional[str]=None, package_name: Optional[str] = None - ) -> Iterator[str]: + ) -> Iterable[S3Object]: objects = ( self.nightly_packages_to_show() if self.prefix == 'whl/nightly' else self.objects @@ -197,23 +216,23 @@ def gen_file_list( if package_name is not None: if self.obj_to_package_name(obj) != package_name: continue - if self.is_obj_at_root(obj) or obj.startswith(subdir): + if self.is_obj_at_root(obj) or str(obj).startswith(subdir): yield obj def get_package_names(self, subdir: Optional[str] = None) -> List[str]: return sorted(set(self.obj_to_package_name(obj) for obj in self.gen_file_list(subdir))) - def normalize_package_version(self: S3IndexType, obj: str) -> str: + def normalize_package_version(self: S3IndexType, obj: S3Object) -> str: # removes the GPU specifier from the package name as well as # unnecessary things like the file extension, architecture name, etc. return sub( r"%2B.*", "", - "-".join(path.basename(obj).split("-")[:2]) + "-".join(path.basename(str(obj)).split("-")[:2]) ) - def obj_to_package_name(self, obj: str) -> str: - return path.basename(obj).split('-', 1)[0] + def obj_to_package_name(self, obj: S3Object) -> str: + return path.basename(str(obj)).split('-', 1)[0] def to_legacy_html( self, @@ -258,7 +277,8 @@ def to_simple_package_html( out.append('
') out.append('