Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ s3 = ["aioboto3>=15.4.0"]

[dependency-groups]
dev = [
"aiosqlite>=0.21.0",
"pytest>=8.4.2",
"pytest-aioboto3>=0.6.0",
"pytest-asyncio>=1.2.0",
"sqlalchemy>=2.0.44",
]

[tool.uv.build-backend]
Expand Down
3 changes: 2 additions & 1 deletion src/cloud_storage/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .base import AsyncStorageFile
from .s3 import AsyncS3Storage

__version__ = "0.1.0"
__all__ = ["AsyncS3Storage"]
__all__ = ["AsyncStorageFile", "AsyncS3Storage"]
32 changes: 27 additions & 5 deletions src/cloud_storage/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,39 @@


class AsyncBaseStorage:
def get_secure_key(self, key: str) -> str:
def get_name(self, name: str) -> str:
raise NotImplementedError()

async def get_size(self, key: str) -> int:
async def get_size(self, name: str) -> int:
raise NotImplementedError()

async def get_url(self, key: str) -> str:
async def get_url(self, name: str) -> str:
raise NotImplementedError()

async def upload(self, file: BinaryIO, key: str) -> str:
async def upload(self, file: BinaryIO, name: str) -> str:
raise NotImplementedError()

async def delete(self, key: str) -> None:
async def delete(self, name: str) -> None:
raise NotImplementedError()


class AsyncStorageFile:
    """A named file handle bound to an async storage backend.

    Thin convenience wrapper: every operation delegates to the underlying
    ``AsyncBaseStorage``, always supplying this file's name.
    """

    def __init__(self, name: str, storage: AsyncBaseStorage):
        self._name: str = name
        self._storage: AsyncBaseStorage = storage

    @property
    def name(self) -> str:
        """Read-only storage name of this file."""
        return self._name

    async def get_size(self) -> int:
        """Return the file's size in bytes as reported by the backend."""
        return await self._storage.get_size(self._name)

    async def get_url(self) -> str:
        """Return a URL for this file as produced by the backend."""
        return await self._storage.get_url(self._name)

    async def upload(self, file: BinaryIO) -> str:
        """Upload *file* under this file's name; return the stored name."""
        return await self._storage.upload(file, self._name)

    async def delete(self) -> None:
        """Remove this file from the backend."""
        await self._storage.delete(self._name)
Empty file.
34 changes: 34 additions & 0 deletions src/cloud_storage/integrations/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Any, override
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.types import TypeDecorator, TypeEngine, Unicode

from cloud_storage.base import AsyncBaseStorage, AsyncStorageFile


class AsyncFileType(TypeDecorator[Any]):
    """SQLAlchemy column type that persists a file name as Unicode and
    rehydrates it on load as an :class:`AsyncStorageFile` bound to *storage*.
    """

    impl: TypeEngine[Any] | type[TypeEngine[Any]] = Unicode
    cache_ok: bool | None = True

    def __init__(self, storage: AsyncBaseStorage, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # Backend used to construct AsyncStorageFile handles on result load.
        self.storage: AsyncBaseStorage = storage

    @override
    def process_bind_param(self, value: Any, dialect: Dialect) -> str | None:
        """Convert a bound value to the stored string.

        Accepts ``None`` (passed through as NULL — hence the ``str | None``
        return, fixing the previous ``-> str`` annotation), a plain string
        name, anything with a truthy ``name`` attribute (e.g.
        ``AsyncStorageFile``), and falls back to ``str(value)``.
        """
        if value is None:
            return value
        if isinstance(value, str):
            return value

        name = getattr(value, "name", None)
        if name:
            return name
        return str(value)

    @override
    def process_result_value(
        self, value: Any | None, dialect: Dialect
    ) -> AsyncStorageFile | None:
        """Wrap the stored name in an AsyncStorageFile; NULL stays None."""
        if value is None:
            return None
        return AsyncStorageFile(name=value, storage=self.storage)
36 changes: 18 additions & 18 deletions src/cloud_storage/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ def _get_s3_client(self) -> Any:
)

@override
def get_secure_key(self, key: str) -> str:
parts = Path(key).parts
def get_name(self, name: str) -> str:
parts = Path(name).parts
safe_parts: list[str] = []

for part in parts:
Expand All @@ -69,13 +69,13 @@ def get_secure_key(self, key: str) -> str:
return str(safe_path)

@override
async def get_size(self, key: str) -> int:
key = self.get_secure_key(key)
async def get_size(self, name: str) -> int:
name = self.get_name(name)

async with self._get_s3_client() as s3_client:
try:
response = await s3_client.head_object(Bucket=self.bucket_name, Key=key)
return int(response.get("ContentLength", 0))
res = await s3_client.head_object(Bucket=self.bucket_name, Key=name)
return int(res.get("ContentLength", 0))
except ClientError as e:
code = e.response.get("Error", {}).get("Code")
status = e.response.get("ResponseMetadata", {}).get("HTTPStatusCode")
Expand All @@ -85,39 +85,39 @@ async def get_size(self, key: str) -> int:
raise

@override
async def get_url(self, key: str, expires_in: int = 3600) -> str:
async def get_url(self, name: str) -> str:
if self.custom_domain:
return f"{self._http_scheme}://{self.custom_domain}/{key}"
return f"{self._http_scheme}://{self.custom_domain}/{name}"
elif self.querystring_auth:
async with self._get_s3_client() as s3_client:
params = {"Bucket": self.bucket_name, "Key": key}
params = {"Bucket": self.bucket_name, "Key": name}
return await s3_client.generate_presigned_url(
"get_object", Params=params, ExpiresIn=expires_in
"get_object", Params=params
)
else:
url = f"{self._http_scheme}://{self.endpoint_url}/{self.bucket_name}/{key}"
url = f"{self._http_scheme}://{self.endpoint_url}/{self.bucket_name}/{name}"
return url

@override
async def upload(self, file: BinaryIO, key: str) -> str:
key = self.get_secure_key(key)
content_type, _ = mimetypes.guess_type(key)
async def upload(self, file: BinaryIO, name: str) -> str:
name = self.get_name(name)
content_type, _ = mimetypes.guess_type(name)
extra_args = {"ContentType": content_type or "application/octet-stream"}
if self.default_acl:
extra_args["ACL"] = self.default_acl

async with self._get_s3_client() as s3_client:
file.seek(0)
await s3_client.put_object(
Bucket=self.bucket_name, Key=key, Body=file, **extra_args
Bucket=self.bucket_name, Key=name, Body=file, **extra_args
)
return key
return name

@override
async def delete(self, key: str) -> None:
async def delete(self, name: str) -> None:
async with self._get_s3_client() as s3_client:
try:
await s3_client.delete_object(Bucket=self.bucket_name, Key=key)
await s3_client.delete_object(Bucket=self.bucket_name, Key=name)
except ClientError as e:
if e.response.get("Error", {}).get("Code") != "NoSuchKey":
raise
17 changes: 17 additions & 0 deletions tests/test_integrations/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Any
import pytest

from cloud_storage import AsyncS3Storage


@pytest.fixture
async def s3_test_storage(s3_test_env: Any) -> AsyncS3Storage:
    """Build an AsyncS3Storage wired to the test bucket and endpoint.

    ``s3_test_env`` yields ``(bucket_name, endpoint_without_scheme)``;
    the credentials are dummies — presumably backed by a mocked S3
    service (see ``s3_test_env``).
    """
    bucket, endpoint = s3_test_env
    return AsyncS3Storage(
        bucket_name=bucket,
        endpoint_url=endpoint,
        aws_access_key_id="fake-access-key",
        aws_secret_access_key="fake-secret-key",
        use_ssl=False,
    )
75 changes: 75 additions & 0 deletions tests/test_integrations/test_sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from io import BytesIO
from typing import Any
import pytest
from sqlalchemy import Column, Integer
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.ext.asyncio.session import async_sessionmaker
from sqlalchemy.orm import declarative_base

from cloud_storage import AsyncStorageFile
from cloud_storage.integrations.sqlalchemy import AsyncFileType

Base = declarative_base()


class Document(Base):
    """Minimal ORM model exercising AsyncFileType on a file column."""

    __tablename__: str = "documents"
    id: Column[int] = Column(Integer, primary_key=True)
    # storage=None is a placeholder; the test injects a real storage onto
    # the column type before use.
    file: Column[str] = Column(AsyncFileType(storage=None))  # pyright: ignore[reportArgumentType]


@pytest.mark.asyncio
async def test_sqlalchemy_filetype_with_s3(s3_test_storage: Any):
    """Round-trip a file name through AsyncFileType and exercise the
    rehydrated AsyncStorageFile against the S3 test backend.
    """
    storage = s3_test_storage
    # Inject the live storage into the column type (declared with storage=None).
    Document.__table__.columns.file.type.storage = storage

    # In-memory async SQLite engine and session factory.
    engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
    async_session = async_sessionmaker(
        engine, expire_on_commit=False, class_=AsyncSession
    )

    # Create db tables.
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    # Demo file object.
    file_name = "uploads/test-file.txt"
    file_content = b"SQLAlchemy + S3 integration test"
    file_obj = BytesIO(file_content)

    # Upload so the rehydrated file handle has something to operate on.
    await storage.upload(file_obj, file_name)

    # Insert record into db.
    async with async_session() as session:
        doc = Document(file=file_name)
        session.add(doc)
        await session.commit()
        doc_id = doc.id

    # Fetch record back and run checks.
    async with async_session() as session:
        doc = await session.get(Document, doc_id)
        # Fix: previously `if doc is None: return`, which made the test
        # pass vacuously when the row was missing. Fail loudly instead.
        assert doc is not None, "document not found after commit"

        # Result value is rehydrated as an AsyncStorageFile.
        assert isinstance(doc.file, AsyncStorageFile)
        assert doc.file.name == file_name

        # Delegated methods work against the storage backend.
        url = await doc.file.get_url()
        assert file_name in url

        size = await doc.file.get_size()
        assert size == len(file_content)

        # Deleting should not raise; a missing object reports size 0.
        await doc.file.delete()
        size_after_delete = await storage.get_size(file_name)
        assert size_after_delete == 0

    # Close all connections.
    await engine.dispose()
36 changes: 17 additions & 19 deletions tests/test_s3_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
@pytest.mark.asyncio
async def test_s3_storage_methods(s3_test_env: Any):
bucket_name, endpoint_without_scheme = s3_test_env

storage = AsyncS3Storage(
bucket_name=bucket_name,
endpoint_url=endpoint_without_scheme,
Expand All @@ -17,28 +16,27 @@ async def test_s3_storage_methods(s3_test_env: Any):
use_ssl=False,
)

file_name = "test/file.txt"
file_content = b"hello moto"
file_obj = BytesIO(file_content)

key = "test/file.txt"

# upload test
returned_key = await storage.upload(file_obj, key)
assert returned_key == storage.get_secure_key(key)
returned_name = await storage.upload(file_obj, file_name)
assert returned_name == storage.get_name(file_name)

# get url test without custom domain or querystring_auth
url = await storage.get_url(key)
assert key in url
url = await storage.get_url(file_name)
assert file_name in url

# get size test
size = await storage.get_size(key)
size = await storage.get_size(file_name)
assert size == len(file_content)

    # delete test (should succeed silently)
await storage.delete(key)
await storage.delete(file_name)

# get size test after delete (should return 0)
size_after_delete = await storage.get_size(key)
size_after_delete = await storage.get_size(file_name)
assert size_after_delete == 0


Expand All @@ -55,8 +53,8 @@ async def test_s3_storage_querystring_auth(s3_test_env: Any):
querystring_auth=True,
)

key = "test/file.txt"
url = await storage.get_url(key)
name = "test/file.txt"
url = await storage.get_url(name)

assert url.count("AWSAccessKeyId=") == 1
assert url.count("Signature=") == 1
Expand All @@ -76,11 +74,11 @@ async def test_s3_storage_custom_domain(s3_test_env: Any):
custom_domain="cdn.example.com",
)

key = "test/file.txt"
url = await storage.get_url(key)
name = "test/file.txt"
url = await storage.get_url(name)

assert url.startswith("http://cdn.example.com/")
assert key in await storage.get_url(key)
assert name in await storage.get_url(name)


@pytest.mark.asyncio
Expand All @@ -93,8 +91,8 @@ async def test_get_secure_key_normalization():
use_ssl=False,
)

raw_key = "../../weird ../file name.txt"
normalized_key = storage.get_secure_key(raw_key)
raw_name = "../../weird ../file name.txt"
normalized_name = storage.get_name(raw_name)

assert ".." not in normalized_key
assert ".txt" in normalized_key
assert ".." not in normalized_name
assert ".txt" in normalized_name
Loading