folder_name arg is optional for S3MediaStorage (#179)
dantownsend committed Aug 25, 2022
1 parent: e07a7c1 · commit: 336cf56
Showing 2 changed files with 160 additions and 14 deletions.
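In short: the ``folder_name`` argument to ``S3MediaStorage`` can now be omitted (or passed explicitly as ``None``), in which case files are stored at the root of the bucket instead of under a folder. A minimal usage sketch, assuming a Piccolo ``Movie`` table with a ``poster`` column as in the test file below; the bucket name, region, and import path are hypothetical:

from piccolo_api.media.s3 import S3MediaStorage

from movies.tables import Movie  # hypothetical module with a ``poster`` column

# With a folder, stored keys look like "posters/<uuid>.jpg".
storage_with_folder = S3MediaStorage(
    column=Movie.poster,
    bucket_name="my-bucket",  # hypothetical bucket
    folder_name="posters",
    connection_kwargs={"region_name": "us-east-1"},
)

# New in this commit: folder_name can be left out, so stored keys
# sit at the bucket root, e.g. "<uuid>.jpg".
storage_at_root = S3MediaStorage(
    column=Movie.poster,
    bucket_name="my-bucket",  # hypothetical bucket
)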
41 changes: 27 additions & 14 deletions piccolo_api/media/s3.py
@@ -22,7 +22,7 @@ def __init__(
         self,
         column: t.Union[Text, Varchar, Array],
         bucket_name: str,
-        folder_name: str,
+        folder_name: t.Optional[str] = None,
         connection_kwargs: t.Dict[str, t.Any] = None,
         sign_urls: bool = True,
         signed_url_expiry: int = 3600,
@@ -165,6 +165,13 @@ async def store_file(
 
         return file_key
 
+    def _prepend_folder_name(self, file_key: str) -> str:
+        folder_name = self.folder_name
+        if folder_name:
+            return str(pathlib.Path(folder_name, file_key))
+        else:
+            return file_key
+
     def store_file_sync(
         self, file_name: str, file: t.IO, user: t.Optional[BaseUser] = None
     ) -> str:
@@ -182,7 +189,7 @@ def store_file_sync(
         client.upload_fileobj(
             file,
             self.bucket_name,
-            str(pathlib.Path(self.folder_name, file_key)),
+            self._prepend_folder_name(file_key),
             ExtraArgs=upload_metadata,
         )
 
@@ -225,7 +232,7 @@ def generate_file_url_sync(
             ClientMethod="get_object",
             Params={
                 "Bucket": self.bucket_name,
-                "Key": str(pathlib.Path(self.folder_name, file_key)),
+                "Key": self._prepend_folder_name(file_key),
             },
             ExpiresIn=self.signed_url_expiry,
         )
@@ -249,7 +256,7 @@ def get_file_sync(self, file_key: str) -> t.Optional[t.IO]:
         s3_client = self.get_client()
         response = s3_client.get_object(
             Bucket=self.bucket_name,
-            Key=str(pathlib.Path(self.folder_name, file_key)),
+            Key=self._prepend_folder_name(file_key),
         )
         return response["Body"]
 
@@ -273,7 +280,7 @@ def delete_file_sync(self, file_key: str):
         s3_client = self.get_client()
         return s3_client.delete_object(
             Bucket=self.bucket_name,
-            Key=str(pathlib.Path(self.folder_name, file_key)),
+            Key=self._prepend_folder_name(file_key),
         )
 
     async def bulk_delete_files(self, file_keys: t.List[str]):
@@ -289,7 +296,6 @@ def bulk_delete_files_sync(self, file_keys: t.List[str]):
 
         batch_size = 100
         iteration = 0
-        folder_name = self.folder_name
 
         while True:
             batch = file_keys[
@@ -305,7 +311,9 @@ def bulk_delete_files_sync(self, file_keys: t.List[str]):
                 Bucket=self.bucket_name,
                 Delete={
                     "Objects": [
-                        {"Key": str(pathlib.Path(folder_name, file_key))}
+                        {
+                            "Key": self._prepend_folder_name(file_key),
+                        }
                         for file_key in file_keys
                     ],
                 },
@@ -323,13 +331,16 @@ def get_file_keys_sync(self) -> t.List[str]:
         start_after = None
 
         while True:
-            extra_kwargs: t.Dict[str, t.Any] = (
-                {"StartAfter": start_after} if start_after else {}
-            )
+            extra_kwargs: t.Dict[str, t.Any] = {}
+
+            if start_after:
+                extra_kwargs["StartAfter"] = start_after
+
+            if self.folder_name:
+                extra_kwargs["Prefix"] = f"{self.folder_name}/"
 
             response = s3_client.list_objects_v2(
                 Bucket=self.bucket_name,
-                Prefix=self.folder_name,
                 **extra_kwargs,
             )
 
@@ -344,9 +355,11 @@ def get_file_keys_sync(self) -> t.List[str]:
                 # https://github.com/nedbat/coveragepy/issues/772
                 break  # pragma: no cover
 
-        prefix = f"{self.folder_name}/"
-
-        return [i.lstrip(prefix) for i in keys]
+        if self.folder_name:
+            prefix = f"{self.folder_name}/"
+            return [i.lstrip(prefix) for i in keys]
+        else:
+            return keys
 
     async def get_file_keys(self) -> t.List[str]:
         """
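The ``get_file_keys_sync`` changes above assemble ``Prefix`` and ``StartAfter`` conditionally, then strip the folder prefix from the returned keys. A standalone sketch of the same listing pattern against boto3's ``list_objects_v2`` (bucket and folder names are hypothetical; the prefix is stripped by slicing, since ``str.lstrip`` treats its argument as a set of characters rather than a literal prefix):

import typing as t

import boto3


def list_keys(bucket_name: str, folder_name: t.Optional[str] = None) -> t.List[str]:
    """
    List every key in the bucket, optionally restricted to a folder.
    """
    client = boto3.client("s3", region_name="us-east-1")
    prefix = f"{folder_name}/" if folder_name else ""

    kwargs: t.Dict[str, t.Any] = {"Bucket": bucket_name}
    if prefix:
        kwargs["Prefix"] = prefix

    keys: t.List[str] = []
    while True:
        response = client.list_objects_v2(**kwargs)
        contents = response.get("Contents", [])
        if not contents:
            break
        keys.extend(obj["Key"] for obj in contents)
        # Keys come back in lexicographic order, so paginating with
        # StartAfter on the last key fetches the next page.
        kwargs["StartAfter"] = keys[-1]

    # Slice off the folder prefix to mirror the storage class's output.
    return [key[len(prefix):] for key in keys]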
133 changes: 133 additions & 0 deletions tests/media/test_s3.py
@@ -193,3 +193,136 @@ def test_unsigned(self, get_client: MagicMock, uuid_module: MagicMock):
             url,
             f"https://{bucket_name}.s3.amazonaws.com/{folder_name}/{file_key}",  # noqa: E501
         )
+
+    @patch("piccolo_api.media.base.uuid")
+    @patch("piccolo_api.media.s3.S3MediaStorage.get_client")
+    def test_no_folder(self, get_client: MagicMock, uuid_module: MagicMock):
+        """
+        Make sure we can store files, and retrieve them when the
+        ``folder_name`` is ``None``.
+        """
+        uuid_module.uuid4.return_value = uuid.UUID(
+            "fd0125c7-8777-4976-83c1-81605d5ab155"
+        )
+        bucket_name = "bucket123"
+
+        with mock_s3():
+            s3 = boto3.resource("s3", region_name="us-east-1")
+            s3.create_bucket(Bucket=bucket_name)
+
+            connection_kwargs = {
+                "aws_access_key_id": "abc123",
+                "aws_secret_access_key": "xyz123",
+                "region_name": "us-east-1",
+            }
+
+            get_client.return_value = boto3.client("s3", **connection_kwargs)
+
+            storage = S3MediaStorage(
+                column=Movie.poster,
+                bucket_name=bucket_name,
+                folder_name=None,
+                connection_kwargs=connection_kwargs,
+                upload_metadata={
+                    "ACL": "public-read",
+                    "Metadata": {"visibility": "premium"},
+                    "CacheControl": "max-age=86400",
+                },
+            )
+
+            with open(
+                os.path.join(os.path.dirname(__file__), "test_files/bulb.jpg"),
+                "rb",
+            ) as test_file:
+                # Store the file
+                file_key = asyncio.run(
+                    storage.store_file(file_name="bulb.jpg", file=test_file)
+                )
+
+                # Retrieve the URL for the file
+                url = asyncio.run(
+                    storage.generate_file_url(file_key, root_url="")
+                )
+
+                path, params = url.split("?", 1)
+
+                self.assertEqual(
+                    path,
+                    f"https://{bucket_name}.s3.amazonaws.com/{file_key}",  # noqa: E501
+                )
+
+                # We're parsing a string like this:
+                # AWSAccessKeyId=abc123&Signature=abc123&Expires=1659437428
+                params_list = [i.split("=") for i in params.split("&")]
+
+                params_dict = {i[0]: i[1] for i in params_list}
+
+                self.assertEqual(
+                    params_dict["AWSAccessKeyId"],
+                    connection_kwargs["aws_access_key_id"],
+                )
+                self.assertIn("Signature", params_dict)
+                self.assertIn("Expires", params_dict)
+
+                # Get the file
+                file = asyncio.run(storage.get_file(file_key=file_key))
+                assert file is not None
+                self.assertEqual(
+                    file.read(),
+                    # We need to reopen the test file, in case it's closed:
+                    open(test_file.name, "rb").read(),
+                )
+
+                # List file keys
+                file_keys = asyncio.run(storage.get_file_keys())
+                self.assertListEqual(file_keys, [file_key])
+
+                # Delete the file
+                asyncio.run(storage.delete_file(file_key=file_key))
+                file_keys = asyncio.run(storage.get_file_keys())
+                self.assertListEqual(file_keys, [])
+
+                # Test bulk deletion
+                file_keys = []
+                for file_name in ("file_1.txt", "file_2.txt", "file_3.txt"):
+                    file = io.BytesIO(b"test")
+                    file_key = asyncio.run(
+                        storage.store_file(file_name=file_name, file=file)
+                    )
+                    file_keys.append(file_key)
+
+                asyncio.run(storage.bulk_delete_files(file_keys=file_keys[:2]))
+
+                self.assertListEqual(
+                    asyncio.run(storage.get_file_keys()), file_keys[2:]
+                )
+
+
+class TestFolderName(TestCase):
+    """
+    Make sure the folder name is correctly added to the file key.
+    """
+
+    def test_with_folder_name(self):
+        storage = S3MediaStorage(
+            column=Movie.poster,
+            bucket_name="test_bucket",
+            folder_name="test_folder",
+            connection_kwargs={},
+        )
+        self.assertEqual(
+            storage._prepend_folder_name(file_key="abc123.jpeg"),
+            "test_folder/abc123.jpeg",
+        )
+
+    def test_without_folder_name(self):
+        storage = S3MediaStorage(
+            column=Movie.poster,
+            bucket_name="test_bucket",
+            folder_name=None,
+            connection_kwargs={},
+        )
+        self.assertEqual(
+            storage._prepend_folder_name(file_key="abc123.jpeg"),
+            "abc123.jpeg",
+        )
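And the batching pattern from ``bulk_delete_files_sync`` above, sketched standalone: S3's ``delete_objects`` accepts at most 1000 keys per request, and the storage class stays well under that with batches of 100 (bucket name and helper name are hypothetical):

import typing as t

import boto3


def bulk_delete(bucket_name: str, keys: t.List[str], batch_size: int = 100) -> None:
    """
    Delete keys in fixed-size batches, one ``delete_objects`` call each.
    """
    client = boto3.client("s3", region_name="us-east-1")
    for start in range(0, len(keys), batch_size):
        batch = keys[start : start + batch_size]
        client.delete_objects(
            Bucket=bucket_name,
            Delete={"Objects": [{"Key": key} for key in batch]},
        )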
