diff --git a/tests/__init__.py b/tests/__init__.py index 77f8b78..4b94f09 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -116,11 +116,11 @@ async def ensure_bucket(self): location = { "LocationConstraint": self.region_name, } - async with s3client.get_client() as client: - try: - await client.create_bucket( - Bucket=self.bucket_name, - CreateBucketConfiguration=location, - ) - except client.exceptions.BucketAlreadyOwnedByYou: - pass + client = await s3client.get_client() + try: + await client.create_bucket( + Bucket=self.bucket_name, + CreateBucketConfiguration=location, + ) + except client.exceptions.BucketAlreadyOwnedByYou: + pass diff --git a/thumbor_aws/s3_client.py b/thumbor_aws/s3_client.py index 4011582..0ea27b0 100755 --- a/thumbor_aws/s3_client.py +++ b/thumbor_aws/s3_client.py @@ -21,7 +21,7 @@ class S3Client: - __session: AioSession = None + __client: AioBaseClient = None context: Context = None configuration: Dict[str, object] = None @@ -83,20 +83,21 @@ def file_acl(self) -> str: @property def session(self) -> AioSession: - """Singleton Session used for connecting with AWS""" - if self.__session is None: - self.__session = get_session() - return self.__session - - def get_client(self) -> AioBaseClient: - """Gets a connected client to use for S3""" - return self.session.create_client( - "s3", - region_name=self.region_name, - aws_secret_access_key=self.secret_access_key, - aws_access_key_id=self.access_key_id, - endpoint_url=self.endpoint_url, - ) + """Session used for connecting with AWS""" + return get_session() + + async def get_client(self) -> AioBaseClient: + """Singleton client for S3""" + if self.__client is None: + client = self.session.create_client( + "s3", + region_name=self.region_name, + aws_secret_access_key=self.secret_access_key, + aws_access_key_id=self.access_key_id, + endpoint_url=self.endpoint_url, + ) + self.__client = await client.__aenter__() + return self.__client async def upload( self, @@ -107,88 +108,88 @@ async def upload( ) -> str: """Uploads a File to S3""" - async with self.get_client() as client: - response = None - try: - settings = dict( - Bucket=self.bucket_name, - Key=path, - Body=data, - ContentType=content_type, - ) - if self.file_acl is not None: - settings["ACL"] = self.file_acl - - response = await client.put_object(**settings) - except Exception as error: - msg = f"Unable to upload image to {path}: {error} ({type(error)})" - logger.error(msg) - raise RuntimeError(msg) # pylint: disable=raise-missing-from - status_code = self.get_status_code(response) - if status_code != 200: - msg = f"Unable to upload image to {path}: Status Code {status_code}" - logger.error(msg) - raise RuntimeError(msg) - - location = self.get_location(response) - if location is None: - msg = ( - f"Unable to process response from AWS to {path}: " - "Location Headers was not found in response" - ) - logger.warning(msg) - location = default_location.format( - bucket_name=self.bucket_name - ) - - return f"{location.rstrip('/')}/{path.lstrip('/')}" + client = await self.get_client() + response = None + try: + settings = dict( + Bucket=self.bucket_name, + Key=path, + Body=data, + ContentType=content_type, + ) + if self.file_acl is not None: + settings["ACL"] = self.file_acl + + response = await client.put_object(**settings) + except Exception as error: + msg = f"Unable to upload image to {path}: {error} ({type(error)})" + logger.error(msg) + raise RuntimeError(msg) # pylint: disable=raise-missing-from + status_code = self.get_status_code(response) + if status_code != 200: + msg = f"Unable to upload image to {path}: Status Code {status_code}" + logger.error(msg) + raise RuntimeError(msg) + + location = self.get_location(response) + if location is None: + msg = ( + f"Unable to process response from AWS to {path}: " + "Location Headers was not found in response" + ) + logger.warning(msg) + location = default_location.format( + bucket_name=self.bucket_name + ) + + return f"{location.rstrip('/')}/{path.lstrip('/')}" async def get_data( self, path: str, expiration: int = _default ) -> (int, bytes, Optional[datetime.datetime]): """Gets an object's data from S3""" - async with self.get_client() as client: - try: - response = await client.get_object( - Bucket=self.bucket_name, Key=path - ) - except client.exceptions.NoSuchKey: - return 404, b"", None + client = await self.get_client() + try: + response = await client.get_object( + Bucket=self.bucket_name, Key=path + ) + except client.exceptions.NoSuchKey: + return 404, b"", None - status_code = self.get_status_code(response) - if status_code != 200: - msg = f"Unable to upload image to {path}: Status Code {status_code}" - logger.error(msg) - return status_code, msg, None + status_code = self.get_status_code(response) + if status_code != 200: + msg = f"Unable to upload image to {path}: Status Code {status_code}" + logger.error(msg) + return status_code, msg, None - last_modified = response["LastModified"] - if self._is_expired(last_modified, expiration): - return 410, b"", last_modified + last_modified = response["LastModified"] + if self._is_expired(last_modified, expiration): + return 410, b"", last_modified - body = await self.get_body(response) + body = await self.get_body(response) - return status_code, body, last_modified + return status_code, body, last_modified async def object_exists(self, filepath: str): """Detects whether an object exists in S3""" - async with self.get_client() as client: - try: - await client.get_object_acl( - Bucket=self.bucket_name, Key=filepath - ) - return True - except client.exceptions.NoSuchKey: - return False + client = await self.get_client() + try: + await client.get_object_acl( + Bucket=self.bucket_name, Key=filepath + ) + return True + except client.exceptions.NoSuchKey: + return False async def get_object_acl(self, filepath: str): """Gets an object's metadata""" - async with self.get_client() as client: - return await client.get_object_acl( - Bucket=self.bucket_name, Key=filepath - ) + client = await self.get_client() + return await client.get_object_acl( + Bucket=self.bucket_name, Key=filepath + ) def get_status_code(self, response: Mapping[str, Any]) -> int: """Gets the status code from an AWS response object""" diff --git a/thumbor_aws/storage.py b/thumbor_aws/storage.py index 9bb2e5a..ca74349 100644 --- a/thumbor_aws/storage.py +++ b/thumbor_aws/storage.py @@ -173,17 +173,17 @@ async def remove(self, path: str): if not exists: return - async with self.get_client() as client: - normalized_path = self.normalize_path(path) - response = await client.delete_object( - Bucket=self.bucket_name, - Key=normalized_path, + client = await self.get_client() + normalized_path = self.normalize_path(path) + response = await client.delete_object( + Bucket=self.bucket_name, + Key=normalized_path, + ) + status = self.get_status_code(response) + if status >= 300: + raise RuntimeError( + f"Failed to remove {normalized_path}: Status {status}" ) - status = self.get_status_code(response) - if status >= 300: - raise RuntimeError( - f"Failed to remove {normalized_path}: Status {status}" - ) def normalize_path(self, path: str) -> str: """Returns the path used for storage"""