Skip to content

Commit

Permalink
chore: raise domain specific errors for sql module
Browse files Browse the repository at this point in the history
Raise domain-specific errors in the SQL module. This includes raising
`TransientError`s whenever an operation is retryable.

Also remove the unnecessary factory functions for the sql draw streams
and use appropriate factory class methods.

Also included are other minor fixes and improvements.
  • Loading branch information
kennedykori committed Jul 15, 2023
1 parent fde7885 commit 470582e
Show file tree
Hide file tree
Showing 12 changed files with 171 additions and 78 deletions.
20 changes: 10 additions & 10 deletions app/mods/common/src/sghi/idr/client/common/domain/etl_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,21 +134,21 @@ def _proto_definition_to_proto_instance(
dmf_def,
)
return SimpleETLProtocol(
id=protocol_definition["id"],
name=protocol_definition["name"],
description=protocol_definition.get("description"),
data_sink_factory=protocol_definition["data_sink_factory"],
data_source_factory=protocol_definition["data_source_factory"],
data_processor_factory=protocol_definition[
id=protocol_definition["id"], # pyright: ignore
name=protocol_definition["name"], # pyright: ignore
description=protocol_definition.get("description"), # pyright: ignore # noqa: E501
data_sink_factory=protocol_definition["data_sink_factory"], # pyright: ignore # noqa: E501
data_source_factory=protocol_definition["data_source_factory"], # pyright: ignore # noqa: E501
data_processor_factory=protocol_definition[ # pyright: ignore
"data_processor_factory"
],
metadata_consumer=protocol_definition[
metadata_consumer=protocol_definition[ # pyright: ignore
"metadata_consumer_factory"
](),
metadata_supplier=protocol_definition[
metadata_supplier=protocol_definition[ # pyright: ignore
"metadata_supplier_factory"
](),
drain_metadata_factory=dmf,
drain_metadata_factory=dmf, # pyright: ignore
)

@staticmethod
Expand All @@ -161,7 +161,7 @@ def _get_data_meta_factory_instance(
case Callable():
return dmf_def()
case _:
assert_never(dmf_def)
assert_never(dmf_def) # pyright: ignore


class FromFactoriesETLProtocolSupplier(ETLProtocolSupplier):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ class ProtocolDefinition(_ProtocolDefinitionOptional, total=True):
@cache
def _get_required_proto_definition_fields() -> set[str]:
all_fields: set[str] = set(
typed_dict_keys(_RawProtocolDefinition).keys(),
typed_dict_keys(_RawProtocolDefinition).keys(), # type: ignore
)
optional_fields: set[str] = set(
typed_dict_keys(_ProtocolDefinitionOptional).keys(),
typed_dict_keys(_ProtocolDefinitionOptional).keys(), # type: ignore
)
return all_fields.difference(optional_fields)

Expand Down
4 changes: 2 additions & 2 deletions app/mods/http/src/sghi/idr/client/http/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@


class HTTPTransportError(IDRClientError):
"""Unknown error occurred while performing a HTTP Transport operation."""
"""Unknown error occurred while performing an HTTP Transport operation."""

...


class HTTPTransportTransientError(HTTPTransportError, TransientError):
"""
Recoverable error occurred while performing a HTTP Transport operation.
Recoverable error occurred while performing an HTTP Transport operation.
"""

...
Expand Down
35 changes: 28 additions & 7 deletions app/mods/http/src/sghi/idr/client/http/lib/common.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections.abc import Sequence

from requests.exceptions import (
ConnectionError,
RequestException,
Expand All @@ -9,6 +11,22 @@
from ..exceptions import HTTPTransportError, HTTPTransportTransientError
from ..typings import ResponsePredicate

# =============================================================================
# CONSTANTS
# =============================================================================

# Exception types from the ``requests`` library that indicate a retryable
# (transient) failure.
# NOTE(review): ``RequestException`` is the *base* class of all ``requests``
# errors — including the three subclasses listed alongside it — so its
# presence here effectively marks every ``requests`` error as transient.
# Confirm this widening (vs. checking only ConnectionError/Timeout/
# TooManyRedirects) is intentional.
# Annotated as a variadic tuple rather than ``Sequence`` because consumers
# pass this value directly to ``isinstance()``, which requires a type or a
# tuple of types.
KNOWN_REQUESTS_TRANSIENT_BASE_ERRORS: tuple[type[BaseException], ...] = (
    ConnectionError,
    RequestException,
    Timeout,
    TooManyRedirects,
)


# =============================================================================
# UTILITIES
# =============================================================================


def if_response_has_status_factory(
*http_status: int,
Expand Down Expand Up @@ -63,23 +81,26 @@ def if_request_accepted(response: Response) -> bool:


def to_appropriate_domain_error(
    exp: BaseException,
    message: str | None = None,
    known_transient_errors: tuple[type[BaseException], ...] = KNOWN_REQUESTS_TRANSIENT_BASE_ERRORS,  # noqa: E501
) -> HTTPTransportError:
    """Map an exception to the appropriate domain error.

    Given an exception, either map the error to an
    :exp:`HTTPTransportTransientError` if the error is retryable, or else
    map it to a :exp:`HTTPTransportError`.

    :param exp: An exception to be mapped to the appropriate domain specific
        error.
    :param message: An optional error message to pass to the returned
        exception.
    :param known_transient_errors: A tuple of exception types that are safe to
        retry.

    :return: An appropriate domain specific error based on the given
        exception.
    """
    # FIX: the parameter was annotated ``tuple[type[BaseException]]``, which
    # is the type of a *1-tuple*; ``tuple[type[BaseException], ...]`` is the
    # correct variadic form. ``isinstance`` accepts such a tuple, so a single
    # check covers all known transient error types.
    if isinstance(exp, known_transient_errors):
        return HTTPTransportTransientError(message=message)
    return HTTPTransportError(message=message)
18 changes: 9 additions & 9 deletions app/mods/idr_server/src/sghi/idr/client/idr_server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@

def fyj_cbs_etl_protocol_factory() -> FYJCBSETLProtocol:
return SimpleETLProtocol(
id="fyj-cbs",
name="FyJ CBS ETL Protocol",
description="Fahari ya Jamii, CBS ETL Protocol",
data_sink_factory=HTTPDataSink.from_data_sink_meta,
data_source_factory=SimpleSQLDatabase.from_data_source_meta,
data_processor_factory=IDRServerDataProcessor,
drain_metadata_factory=fyj_cbs_drain_meta_factory(),
metadata_consumer=fyj_cbs_metadata_consumer_factory(),
metadata_supplier=fyj_cbs_metadata_supplier_factory(),
id="fyj-cbs", # pyright: ignore
name="FyJ CBS ETL Protocol", # pyright: ignore
description="Fahari ya Jamii, CBS ETL Protocol", # pyright: ignore
data_sink_factory=HTTPDataSink.from_data_sink_meta, # pyright: ignore
data_source_factory=SimpleSQLDatabase.from_data_source_meta, # pyright: ignore # noqa: E501
data_processor_factory=IDRServerDataProcessor, # pyright: ignore
drain_metadata_factory=fyj_cbs_drain_meta_factory(), # pyright: ignore
metadata_consumer=fyj_cbs_metadata_consumer_factory(), # pyright: ignore # noqa: E501
metadata_supplier=fyj_cbs_metadata_supplier_factory(), # pyright: ignore # noqa: E501
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
)
from sghi.idr.client.http.lib.http_transport import HTTPTransport
from sghi.idr.client.sql.domain import (
PDDataFrameDataSourceStream,
SimpleSQLDatabaseDescriptor,
SimpleSQLQuery,
pd_data_frame_data_source_stream_factory,
)
from toolz import pipe
from toolz.curried import map
Expand Down Expand Up @@ -328,7 +328,7 @@ def handle_get_data_source_meta_response(
_result["database_name"],
),
"isolation_level": "REPEATABLE READ",
"data_source_stream_factory": pd_data_frame_data_source_stream_factory, # noqa: E501
"data_source_stream_factory": PDDataFrameDataSourceStream.of, # noqa: E501
},
),
map(lambda _kwargs: SimpleSQLDatabaseDescriptor(**_kwargs)),
Expand Down
4 changes: 0 additions & 4 deletions app/mods/sql/src/sghi/idr/client/sql/domain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
SimpleSQLDatabase,
SimpleSQLDataSourceStream,
SQLRawData,
pd_data_frame_data_source_stream_factory,
simple_data_source_stream_factory,
)

__all__ = [
Expand All @@ -28,6 +26,4 @@
"SimpleSQLDatabaseDescriptor",
"SimpleSQLDataSourceStream",
"SimpleSQLQuery",
"pd_data_frame_data_source_stream_factory",
"simple_data_source_stream_factory",
]
7 changes: 2 additions & 5 deletions app/mods/sql/src/sghi/idr/client/sql/domain/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,9 @@ class SimpleSQLDatabaseDescriptor(BaseSQLDataSourceMetadata):

@property
def data_source_stream_factory(self) -> DataSourceStreamFactory:
from .operations import pd_data_frame_data_source_stream_factory
from .operations import SimpleSQLDataSourceStream

return (
self._data_source_stream_factory
or pd_data_frame_data_source_stream_factory
)
return self._data_source_stream_factory or SimpleSQLDataSourceStream.of

@property
def database_url(self) -> str | URL:
Expand Down
84 changes: 47 additions & 37 deletions app/mods/sql/src/sghi/idr/client/sql/domain/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
)
from sghi.idr.client.core.lib import type_fqn
from sqlalchemy import Connection, CursorResult, Engine, Row, create_engine
from sqlalchemy.exc import SQLAlchemyError

from ..lib import to_appropriate_domain_error
from ..typings import ReadIsolationLevels
from .metadata import (
BaseSQLDataSourceMetadata,
Expand Down Expand Up @@ -48,37 +50,6 @@
"""The maximum number of rows to extract from a database at any one time."""


# =============================================================================
# HELPERS
# =============================================================================


def pd_data_frame_data_source_stream_factory(
sql_data_source: "BaseSQLDataSource[Any, SimpleSQLQuery, PDDataFrame]",
extract_metadata: SimpleSQLQuery,
) -> "PDDataFrameDataSourceStream":
# TODO: Add a check to ensure that the `sql_data_source` given is not
# disposed.
return PDDataFrameDataSourceStream(
sql_data_source,
extract_metadata,
sql_data_source.engine.connect(),
)


def simple_data_source_stream_factory(
sql_data_source: "BaseSQLDataSource[Any, SimpleSQLQuery, SQLRawData]",
extract_metadata: SimpleSQLQuery,
) -> "SimpleSQLDataSourceStream":
# TODO: Add a check to ensure that the `sql_data_source` given is not
# disposed.
return SimpleSQLDataSourceStream(
sql_data_source,
extract_metadata,
sql_data_source.engine.connect(),
)


# =============================================================================
# BASE OPERATIONS CLASSES
# =============================================================================
Expand All @@ -92,14 +63,11 @@ class BaseSQLDataSource(
):
"""An SQL Database."""

_engine: Engine = field()
_data_source_stream_factory: Callable[
["BaseSQLDataSource[Any, _EM, _RD]", _EM],
DataSourceStream[_EM, _RD],
] = field(
default=simple_data_source_stream_factory,
kw_only=True,
)
_engine: Engine = field()
] = field()

def __attrs_post_init__(self) -> None:
self._logger: Logger = logging.getLogger(type_fqn(self.__class__))
Expand Down Expand Up @@ -209,7 +177,7 @@ def of_sqlite_in_memory(
data_source_stream_factory: Callable[
["SimpleSQLDatabase[Any]", SimpleSQLQuery],
DataSourceStream[SimpleSQLQuery, Any],
] = simple_data_source_stream_factory,
] | None = None,
) -> Self:
return cls(
name=name, # pyright: ignore
Expand Down Expand Up @@ -275,12 +243,34 @@ def draw(self) -> tuple[PDDataFrame, float]:
self.draw_metadata.name,
)
raise DataSourceStream.StopDraw from None
except SQLAlchemyError as exp:
_err_msg: str = "Error while drawing from sql source."
self._logger.exception(_err_msg)
raise to_appropriate_domain_error(exp, message=_err_msg) from exp

def dispose(self) -> None:
    """Release the stream's resources and mark it as disposed."""
    # Flag the stream as disposed *before* releasing resources so the
    # instance reads as unusable even if the close below raises.
    self._is_disposed = True
    self._connection.close()
    self._logger.debug("Disposal complete.")

@classmethod
def of(
    cls,
    sql_data_source: BaseSQLDataSource[Any, SimpleSQLQuery, PDDataFrame],
    draw_metadata: SimpleSQLQuery,
) -> Self:
    """Create a new stream drawing from the given SQL data source.

    A fresh connection is opened on the data source's engine; any
    ``SQLAlchemyError`` raised while connecting or constructing the
    stream is translated into the appropriate domain error.
    """
    try:
        connection = sql_data_source.engine.connect()
        stream = cls(
            data_source=sql_data_source,  # pyright: ignore
            draw_metadata=draw_metadata,  # pyright: ignore
            connection=connection,  # pyright: ignore
        )
    except SQLAlchemyError as error:
        raise to_appropriate_domain_error(
            error,
            message="Unable to initialize 'PDDataFrameDataSourceStream'.",
        ) from error
    return stream


@define(order=False, slots=True)
class SimpleSQLDataSourceStream(
Expand Down Expand Up @@ -334,9 +324,29 @@ def draw(self) -> tuple[SQLRawData, float]:
self.draw_metadata.name,
)
raise DataSourceStream.StopDraw from None
except SQLAlchemyError as exp:
_err_msg: str = "Error while drawing from sql source."
self._logger.exception(_err_msg)
raise to_appropriate_domain_error(exp, message=_err_msg) from exp

def dispose(self) -> None:
    """Release the stream's resources and mark it as disposed."""
    # Flag the stream as disposed *before* releasing resources so the
    # instance reads as unusable even if one of the close calls raises.
    self._is_disposed = True
    # Close the cursor/result before the connection that produced it.
    self._extraction_result.close()
    self._connection.close()
    self._logger.debug("Disposal complete.")

@classmethod
def of(
    cls,
    # FIX: was annotated ``BaseSQLDataSource[Any, SimpleSQLQuery,
    # PDDataFrame]`` — a copy-paste from ``PDDataFrameDataSourceStream.of``.
    # This stream yields ``SQLRawData`` (cf. the factory function it
    # replaces), so the source must produce ``SQLRawData``.
    sql_data_source: BaseSQLDataSource[Any, SimpleSQLQuery, SQLRawData],
    draw_metadata: SimpleSQLQuery,
) -> Self:
    """Create a new stream drawing raw SQL data from the given source.

    Opens a fresh connection on the data source's engine; any
    ``SQLAlchemyError`` raised while connecting or constructing the
    stream is translated into the appropriate domain error.
    """
    try:
        return cls(
            data_source=sql_data_source,  # pyright: ignore
            draw_metadata=draw_metadata,  # pyright: ignore
            connection=sql_data_source.engine.connect(),  # pyright: ignore
        )
    except SQLAlchemyError as exp:
        _err_msg: str = "Unable to initialize 'SimpleSQLDataSourceStream'."
        raise to_appropriate_domain_error(exp, message=_err_msg) from exp
11 changes: 11 additions & 0 deletions app/mods/sql/src/sghi/idr/client/sql/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from sghi.idr.client.core.exceptions import IDRClientError, TransientError


class SQLError(IDRClientError):
    """Indicates a non-specific failure while operating on SQL resources."""


class SQLTransientError(SQLError, TransientError):
    """Indicates a recoverable failure while operating on an SQL resource.

    Operations that raise this error are safe to retry.
    """
8 changes: 8 additions & 0 deletions app/mods/sql/src/sghi/idr/client/sql/lib/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Public interface of the ``lib`` sub-package.

Re-exports the error-mapping helper and the database-instance
configuration/initialization utilities.
"""
from .common import to_appropriate_domain_error
from .config import DBInstanceConfig, DBInstancesInitializer

# Names exported via ``from sghi.idr.client.sql.lib import *``.
__all__ = [
    "DBInstanceConfig",
    "DBInstancesInitializer",
    "to_appropriate_domain_error",
]

0 comments on commit 470582e

Please sign in to comment.