Skip to content
This repository has been archived by the owner on Sep 12, 2023. It is now read-only.

Commit

Permalink
feat(health): Custom health checks (#237)
Browse files Browse the repository at this point in the history
* ✨ feat(health): custom healtchecks

* ♻️ refactor(health): rename sqlachemy engine logging name

* ♻️ refactor(health): always return all health checks

* 🐛 fix: linters

* ♻️ refactor(health): add AppHealthCheck

* ♻️ refactor: remove commented code

* ✅ test(health check): fix integration

* 🐛 fix: pyright
  • Loading branch information
gazorby committed Jan 13, 2023
1 parent b76a503 commit bfe64f4
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 22 deletions.
4 changes: 4 additions & 0 deletions src/starlite_saqlalchemy/exceptions.py
Expand Up @@ -51,6 +51,10 @@ class AuthorizationError(StarliteSaqlalchemyError):
"""A user tried to do something they shouldn't have."""


class HealthCheckConfigurationError(StarliteSaqlalchemyError):
"""An error occurred while registering an health check."""


class _HTTPConflictException(HTTPException):
"""Request conflict with the current state of the target resource."""

Expand Down
100 changes: 89 additions & 11 deletions src/starlite_saqlalchemy/health.py
Expand Up @@ -4,25 +4,103 @@
"""
from __future__ import annotations

import contextlib
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

from sqlalchemy.ext.asyncio import AsyncSession
from starlite import get
from pydantic import BaseModel
from starlite import Controller, Response, get
from starlite.exceptions import ServiceUnavailableException

from starlite_saqlalchemy import settings
from starlite_saqlalchemy.repository.sqlalchemy import SQLAlchemyRepository
from starlite_saqlalchemy.settings import AppSettings

if TYPE_CHECKING:
from typing import Any

from starlite import Request


class HealthCheckFailure(ServiceUnavailableException):
"""Raise for health check failure."""

def __init__(
self,
health: dict[str, bool],
*args: Any,
detail: str = "",
status_code: int | None = None,
headers: dict[str, str] | None = None,
extra: dict[str, Any] | list[Any] | None = None,
) -> None:
"""Initialize HealthCheckFailure with an additional health arg."""
super().__init__(*args, detail, status_code, headers, extra)
self.health = health


class AbstractHealthCheck(ABC):
"""Base protocol for implementing health checks."""

name: str = ""

async def live(self) -> bool:
"""Run a liveness check.
Returns:
True if the service is running, False otherwise
"""
return await self.ready() # pragma: no cover

@abstractmethod
async def ready(self) -> bool:
"""Run readiness check.
Returns:
True if the service is ready to serve requests, False otherwise
"""


class AppHealthCheck(AbstractHealthCheck):
"""Simple health check that does not require any dependencies."""

name = "app"

async def ready(self) -> bool:
"""Readiness check used when no other health check is available."""
return True


class HealthResource(BaseModel):
"""Health data returned by the health endpoint."""

app: AppSettings
health: dict[str, bool]


def health_failure_exception_handler(
_: Request, exc: HealthCheckFailure
) -> Response[HealthResource]:
"""Return all health checks data on `HealthCheckFailure`."""
return Response(
status_code=HealthCheckFailure.status_code,
content=HealthResource(app=settings.app, health=exc.health),
)


class HealthController(Controller):
"""Holds health endpoints."""

exception_handlers = {HealthCheckFailure: health_failure_exception_handler}
health_checks: list[AbstractHealthCheck] = []

@get(path=settings.api.HEALTH_PATH, tags=["Misc"])
async def health_check(db_session: AsyncSession) -> AppSettings:
"""Check database available and returns app config info."""
with contextlib.suppress(Exception):
if await SQLAlchemyRepository.check_health(db_session):
return settings.app
raise HealthCheckFailure("DB not ready.")
@get(path=settings.api.HEALTH_PATH, tags=["Misc"], raises=[HealthCheckFailure])
async def health_check(self) -> HealthResource:
"""Run registered health checks."""
health: dict[str, bool] = {}
for health_check in self.health_checks:
try:
health[health_check.name] = await health_check.ready()
except Exception: # pylint: disable=broad-except
health[health_check.name] = False
if not all(health.values()):
raise HealthCheckFailure(health=health)
return HealthResource(app=settings.app, health=health)
19 changes: 17 additions & 2 deletions src/starlite_saqlalchemy/init_plugin.py
Expand Up @@ -51,15 +51,22 @@ def example_handler() -> dict:
settings,
sqlalchemy_plugin,
)
from starlite_saqlalchemy.health import health_check
from starlite_saqlalchemy.exceptions import HealthCheckConfigurationError
from starlite_saqlalchemy.health import (
AbstractHealthCheck,
AppHealthCheck,
HealthController,
)
from starlite_saqlalchemy.service import make_service_callback
from starlite_saqlalchemy.sqlalchemy_plugin import SQLAlchemyHealthCheck
from starlite_saqlalchemy.type_encoders import type_encoders_map
from starlite_saqlalchemy.worker import create_worker_instance

if TYPE_CHECKING:
from starlite.config.app import AppConfig
from starlite.types import TypeEncodersMap


T = TypeVar("T")


Expand Down Expand Up @@ -160,6 +167,7 @@ class PluginConfig(BaseModel):
"""Chain of structlog log processors."""
type_encoders: TypeEncodersMap = type_encoders_map
"""Map of type to serializer callable."""
health_checks: Sequence[type[AbstractHealthCheck]] = [AppHealthCheck, SQLAlchemyHealthCheck]


class ConfigureApp:
Expand Down Expand Up @@ -283,7 +291,14 @@ def configure_health_check(self, app_config: AppConfig) -> None:
app_config: The Starlite application config object.
"""
if self.config.do_health_check:
app_config.route_handlers.append(health_check)
healt_checks: list[AbstractHealthCheck] = []
for health_check in self.config.health_checks:
health_check_instance = health_check()
if not health_check_instance.name:
raise HealthCheckConfigurationError(f"{health_check}.name must be set.")
healt_checks.append(health_check_instance)
HealthController.health_checks = healt_checks
app_config.route_handlers.append(HealthController)

def configure_logging(self, app_config: AppConfig) -> None:
"""Configure application logging.
Expand Down
2 changes: 1 addition & 1 deletion src/starlite_saqlalchemy/settings.py
Expand Up @@ -135,7 +135,7 @@ class Config:
SAQ_LEVEL: int = 30
"""Level to log SAQ logs."""
SQLALCHEMY_LEVEL: int = 30
"""Level to log SAQ logs."""
"""Level to log SQLAlchemy logs."""
UVICORN_ACCESS_LEVEL: int = 30
"""Level to log uvicorn access logs."""
UVICORN_ERROR_LEVEL: int = 20
Expand Down
26 changes: 26 additions & 0 deletions src/starlite_saqlalchemy/sqlalchemy_plugin.py
Expand Up @@ -3,13 +3,16 @@

from typing import TYPE_CHECKING, cast

from sqlalchemy import text
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from starlite.plugins.sql_alchemy import SQLAlchemyConfig, SQLAlchemyPlugin
from starlite.plugins.sql_alchemy.config import (
SESSION_SCOPE_KEY,
SESSION_TERMINUS_ASGI_EVENTS,
)

from starlite_saqlalchemy import db, settings
from starlite_saqlalchemy.health import AbstractHealthCheck

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
Expand Down Expand Up @@ -40,6 +43,29 @@ async def before_send_handler(message: "Message", _: "State", scope: "Scope") ->
del scope[SESSION_SCOPE_KEY] # type:ignore[misc]


class SQLAlchemyHealthCheck(AbstractHealthCheck):
"""SQLAlchemy health check."""

name: str = "db"

def __init__(self) -> None:
self.engine = create_async_engine(
settings.db.URL, logging_name="starlite_saqlalchemy.health"
)
self.session_maker = async_sessionmaker(bind=self.engine)

async def ready(self) -> bool:
"""Perform a health check on the database.
Returns:
`True` if healthy.
"""
async with self.session_maker() as session:
return ( # type:ignore[no-any-return] # pragma: no cover
await session.execute(text("SELECT 1"))
).scalar_one() == 1


config = SQLAlchemyConfig(
before_send_handler=before_send_handler,
dependency_key=settings.api.DB_SESSION_DEPENDENCY_KEY,
Expand Down
8 changes: 7 additions & 1 deletion tests/integration/conftest.py
Expand Up @@ -19,6 +19,8 @@
from starlite import Provide, Router

from starlite_saqlalchemy import db, sqlalchemy_plugin, worker
from starlite_saqlalchemy.health import AppHealthCheck, HealthController
from starlite_saqlalchemy.sqlalchemy_plugin import SQLAlchemyHealthCheck
from tests.utils import controllers

if TYPE_CHECKING:
Expand Down Expand Up @@ -187,11 +189,15 @@ async def _seed_db(engine: AsyncEngine, authors: list[Author]) -> abc.AsyncItera

@pytest.fixture(autouse=True)
def _patch_db(app: Starlite, engine: AsyncEngine, monkeypatch: pytest.MonkeyPatch) -> None:
session_maker = async_sessionmaker(bind=engine)
monkeypatch.setitem(app.state, sqlalchemy_plugin.config.engine_app_state_key, engine)
sqla_health_check = SQLAlchemyHealthCheck()
monkeypatch.setattr(sqla_health_check, "session_maker", session_maker)
monkeypatch.setattr(HealthController, "health_checks", [AppHealthCheck(), sqla_health_check])
monkeypatch.setitem(
app.state,
sqlalchemy_plugin.config.session_maker_app_state_key,
async_sessionmaker(bind=engine),
session_maker,
)


Expand Down
81 changes: 74 additions & 7 deletions tests/unit/test_health.py
Expand Up @@ -2,10 +2,19 @@
from typing import TYPE_CHECKING
from unittest.mock import AsyncMock

import pytest
from starlite import Starlite
from starlite.status_codes import HTTP_200_OK, HTTP_503_SERVICE_UNAVAILABLE

from starlite_saqlalchemy import settings
from starlite_saqlalchemy.repository.sqlalchemy import SQLAlchemyRepository
from starlite_saqlalchemy import init_plugin, settings
from starlite_saqlalchemy.exceptions import HealthCheckConfigurationError
from starlite_saqlalchemy.health import (
AbstractHealthCheck,
AppHealthCheck,
HealthController,
HealthResource,
)
from starlite_saqlalchemy.sqlalchemy_plugin import SQLAlchemyHealthCheck

if TYPE_CHECKING:
from pytest import MonkeyPatch
Expand All @@ -17,25 +26,83 @@ def test_health_check(client: "TestClient", monkeypatch: "MonkeyPatch") -> None:
Checks that we call the repository method and the response content.
"""
repo_health_mock = AsyncMock()
monkeypatch.setattr(SQLAlchemyRepository, "check_health", repo_health_mock)
repo_health_mock = AsyncMock(return_value=True)
monkeypatch.setattr(SQLAlchemyHealthCheck, "ready", repo_health_mock)
resp = client.get(settings.api.HEALTH_PATH)
assert resp.status_code == HTTP_200_OK
assert resp.json() == settings.app.dict()
health = HealthResource(
app=settings.app,
health={SQLAlchemyHealthCheck.name: True, AppHealthCheck.name: True},
)
assert resp.json() == health.dict()
repo_health_mock.assert_called_once()


def test_health_check_false_response(client: "TestClient", monkeypatch: "MonkeyPatch") -> None:
"""Test health check response if check method returns `False`"""
repo_health_mock = AsyncMock(return_value=False)
monkeypatch.setattr(SQLAlchemyRepository, "check_health", repo_health_mock)
monkeypatch.setattr(SQLAlchemyHealthCheck, "ready", repo_health_mock)
resp = client.get(settings.api.HEALTH_PATH)
assert resp.status_code == HTTP_503_SERVICE_UNAVAILABLE
health = HealthResource(
app=settings.app,
health={SQLAlchemyHealthCheck.name: False, AppHealthCheck.name: True},
)
assert resp.json() == health.dict()


def test_health_check_exception_raised(client: "TestClient", monkeypatch: "MonkeyPatch") -> None:
"""Test expected response from check if exception raised in handler."""
repo_health_mock = AsyncMock(side_effect=ConnectionError)
monkeypatch.setattr(SQLAlchemyRepository, "check_health", repo_health_mock)
monkeypatch.setattr(SQLAlchemyHealthCheck, "ready", repo_health_mock)
resp = client.get(settings.api.HEALTH_PATH)
assert resp.status_code == HTTP_503_SERVICE_UNAVAILABLE
health = HealthResource(
app=settings.app,
health={SQLAlchemyHealthCheck.name: False, AppHealthCheck.name: True},
)
assert resp.json() == health.dict()


def test_health_custom_health_check(client: "TestClient", monkeypatch: "MonkeyPatch") -> None:
"""Test registering custom health checks."""

class MyHealthCheck(AbstractHealthCheck):
"""Custom health check."""

name = "MyHealthCheck"

async def ready(self) -> bool:
"""Readiness check."""
return False

HealthController.health_checks.append(MyHealthCheck())
repo_health_mock = AsyncMock(return_value=True)
monkeypatch.setattr(SQLAlchemyHealthCheck, "ready", repo_health_mock)
resp = client.get(settings.api.HEALTH_PATH)
assert resp.status_code == HTTP_503_SERVICE_UNAVAILABLE
health = HealthResource(
app=settings.app,
health={
AppHealthCheck.name: True,
SQLAlchemyHealthCheck.name: True,
MyHealthCheck.name: False,
},
)
assert resp.json() == health.dict()


def test_health_check_no_name_error() -> None:
"""Test registering an health check without specifying its name raise an
error."""

class MyHealthCheck(AbstractHealthCheck):
"""Custom health check."""

async def ready(self) -> bool:
"""Readiness check."""
return False

config = init_plugin.PluginConfig(health_checks=[MyHealthCheck])
with pytest.raises(HealthCheckConfigurationError):
Starlite(route_handlers=[], on_app_init=[init_plugin.ConfigureApp(config=config)])

0 comments on commit bfe64f4

Please sign in to comment.