Skip to content

Commit

Permalink
Add timezone and use_utc parameters into RegisterTortoise class. (
Browse files Browse the repository at this point in the history
#1649)

* add use_tz and timezone in RegisterTortoise

* add tests for east-8 timezone

* Update main.py

* Update _tests.py

* add code to adjust the timezone
  • Loading branch information
Abeautifulsnow committed Jun 18, 2024
1 parent 222aaf5 commit bc18384
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 6 deletions.
38 changes: 37 additions & 1 deletion examples/fastapi/_tests.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# mypy: no-disallow-untyped-decorators
# pylint: disable=E0611,E0401
import datetime
from typing import AsyncGenerator

import pytest
import pytz
from asgi_lifespan import LifespanManager
from httpx import ASGITransport, AsyncClient
from main import app
from main import app, app_east
from models import Users


Expand Down Expand Up @@ -33,3 +35,37 @@ async def test_create_user(client: AsyncClient) -> None: # nosec

user_obj = await Users.get(id=user_id)
assert user_obj.id == user_id


@pytest.fixture(scope="module")
async def client_east() -> AsyncGenerator[AsyncClient, None]:
async with LifespanManager(app_east):
transport = ASGITransport(app=app_east)
async with AsyncClient(transport=transport, base_url="http://test") as c:
yield c


@pytest.mark.anyio
async def test_create_user_east(client_east: AsyncClient) -> None: # nosec
response = await client_east.post("/users_east", json={"username": "admin"})
assert response.status_code == 200, response.text
data = response.json()
assert data["username"] == "admin"
assert "id" in data
user_id = data["id"]

user_obj = await Users.get(id=user_id)
assert user_obj.id == user_id

# Verify that the time zone is East 8.
created_at = user_obj.created_at

# Asia/Shanghai timezone
asia_tz = pytz.timezone("Asia/Shanghai")
asia_now = datetime.datetime.now(pytz.utc).astimezone(asia_tz)
assert created_at.hour - asia_now.hour == 0

# UTC timezone
utc_tz = pytz.timezone("UTC")
utc_now = datetime.datetime.now(pytz.utc).astimezone(utc_tz)
assert created_at.hour - utc_now.hour == 8
50 changes: 50 additions & 0 deletions examples/fastapi/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,26 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# db connections closed


@asynccontextmanager
async def lifespan_east(app: FastAPI) -> AsyncGenerator[None, None]:
# app startup
async with RegisterTortoise(
app,
db_url="sqlite://:memory:",
modules={"models": ["models"]},
generate_schemas=True,
add_exception_handlers=True,
use_tz=False,
timezone="Asia/Shanghai",
):
# db connected
yield
# app teardown
# db connections closed


app = FastAPI(title="Tortoise ORM FastAPI example", lifespan=lifespan)
app_east = FastAPI(title="Tortoise ORM FastAPI example", lifespan=lifespan_east)


class Status(BaseModel):
Expand Down Expand Up @@ -72,3 +91,34 @@ async def delete_user(user_id: int):
if not deleted_count:
raise HTTPException(status_code=404, detail=f"User {user_id} not found")
return Status(message=f"Deleted user {user_id}")


############################ East 8 ############################
@app_east.get("/users_east", response_model=List[User_Pydantic])
async def get_users_east():
return await User_Pydantic.from_queryset(Users.all())


@app_east.post("/users_east", response_model=User_Pydantic)
async def create_user_east(user: UserIn_Pydantic):
user_obj = await Users.create(**user.model_dump(exclude_unset=True))
return await User_Pydantic.from_tortoise_orm(user_obj)


@app_east.get("/user_east/{user_id}", response_model=User_Pydantic)
async def get_user_east(user_id: int):
return await User_Pydantic.from_queryset_single(Users.get(id=user_id))


@app_east.put("/user_east/{user_id}", response_model=User_Pydantic)
async def update_user_east(user_id: int, user: UserIn_Pydantic):
await Users.filter(id=user_id).update(**user.model_dump(exclude_unset=True))
return await User_Pydantic.from_queryset_single(Users.get(id=user_id))


@app_east.delete("/user_east/{user_id}", response_model=Status)
async def delete_user_east(user_id: int):
deleted_count = await Users.filter(id=user_id).delete()
if not deleted_count:
raise HTTPException(status_code=404, detail=f"User {user_id} not found")
return Status(message=f"Deleted user {user_id}")
32 changes: 27 additions & 5 deletions tortoise/contrib/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,31 +96,53 @@ def __init__(
modules: Optional[Dict[str, Iterable[Union[str, ModuleType]]]] = None,
generate_schemas: bool = False,
add_exception_handlers: bool = False,
use_tz: bool = False,
timezone: str = "UTC",
) -> None:
self.app = app
self.config = config
self.config_file = config_file
self.db_url = db_url
self.modules = modules
self.generate_schemas = generate_schemas
self.use_tz = use_tz
self.timezone = timezone

if add_exception_handlers:

@app.exception_handler(DoesNotExist)
async def doesnotexist_exception_handler(request: "Request", exc: DoesNotExist):
async def doesnotexist_exception_handler(
request: "Request", exc: DoesNotExist
):
return JSONResponse(status_code=404, content={"detail": str(exc)})

@app.exception_handler(IntegrityError)
async def integrityerror_exception_handler(request: "Request", exc: IntegrityError):
async def integrityerror_exception_handler(
request: "Request", exc: IntegrityError
):
return JSONResponse(
status_code=422,
content={"detail": [{"loc": [], "msg": str(exc), "type": "IntegrityError"}]},
content={
"detail": [
{"loc": [], "msg": str(exc), "type": "IntegrityError"}
]
},
)

async def init_orm(self) -> None: # pylint: disable=W0612
config, config_file = self.config, self.config_file
db_url, modules = self.db_url, self.modules
await Tortoise.init(config=config, config_file=config_file, db_url=db_url, modules=modules)
logger.info("Tortoise-ORM started, %s, %s", connections._get_storage(), Tortoise.apps)
await Tortoise.init(
config=config,
config_file=config_file,
db_url=db_url,
modules=modules,
use_tz=self.use_tz,
timezone=self.timezone,
)
logger.info(
"Tortoise-ORM started, %s, %s", connections._get_storage(), Tortoise.apps
)
if self.generate_schemas:
logger.info("Tortoise-ORM generating schema")
await Tortoise.generate_schemas()
Expand Down

0 comments on commit bc18384

Please sign in to comment.