From a679f5a5243445733deb88d20dd04b3e21362973 Mon Sep 17 00:00:00 2001 From: Aryan Yadav <149820333+Mintsolester@users.noreply.github.com> Date: Wed, 20 Aug 2025 13:25:06 +0000 Subject: [PATCH] feat(typings): add typings support for list_models --- psqlpy_sqlalchemy/dialect.py | 194 ++++++++++++++++++++--------------- 1 file changed, 110 insertions(+), 84 deletions(-) diff --git a/psqlpy_sqlalchemy/dialect.py b/psqlpy_sqlalchemy/dialect.py index 110f82f..a38e426 100644 --- a/psqlpy_sqlalchemy/dialect.py +++ b/psqlpy_sqlalchemy/dialect.py @@ -1,9 +1,7 @@ +from __future__ import annotations from types import ModuleType import typing as t from collections import deque -from collections.abc import MutableMapping, Sequence -from typing import Any, Optional, Tuple - import psqlpy from psqlpy import row_factories from sqlalchemy import URL, util @@ -20,21 +18,24 @@ if t.TYPE_CHECKING: from sqlalchemy.engine.interfaces import DBAPICursor, _DBAPICursorDescription +# ---------- Type Aliases ---------- +Row = t.Tuple[t.Any, ...] +Rows = t.Tuple[Row, ...] + +# ---------- PG-specific colspecs ---------- class _PGString(sqltypes.String): - render_bind_cast = True + render_bind_cast: t.ClassVar[bool] = True class _PGJSONIntIndexType(sqltypes.JSON.JSONIntIndexType): __visit_name__ = "json_int_index" - - render_bind_cast = True + render_bind_cast: t.ClassVar[bool] = True class _PGJSONStrIndexType(sqltypes.JSON.JSONStrIndexType): __visit_name__ = "json_str_index" - - render_bind_cast = True + render_bind_cast: t.ClassVar[bool] = True class _PGJSONPathType(JSONPathType): @@ -42,70 +43,79 @@ class _PGJSONPathType(JSONPathType): class _PGInterval(INTERVAL): - render_bind_cast = True + render_bind_cast: t.ClassVar[bool] = True class _PGTimeStamp(sqltypes.DateTime): - render_bind_cast = True + render_bind_cast: t.ClassVar[bool] = True class _PGDate(sqltypes.Date): - render_bind_cast = True + render_bind_cast: t.ClassVar[bool] = True class _PGTime(sqltypes.Time): - render_bind_cast = True + render_bind_cast: t.ClassVar[bool] = True class _PGInteger(sqltypes.Integer): - render_bind_cast = True + render_bind_cast: t.ClassVar[bool] = True class _PGSmallInteger(sqltypes.SmallInteger): - render_bind_cast = True + render_bind_cast: t.ClassVar[bool] = True class _PGNullType(sqltypes.NullType): - render_bind_cast = True + render_bind_cast: t.ClassVar[bool] = True class _PGBigInteger(sqltypes.BigInteger): - render_bind_cast = True + render_bind_cast: t.ClassVar[bool] = True class _PGBoolean(sqltypes.Boolean): - render_bind_cast = True + render_bind_cast: t.ClassVar[bool] = True +# ---------- Execution context ---------- class PGExecutionContext_psqlpy(PGExecutionContext): def create_server_side_cursor(self) -> "DBAPICursor": return self._dbapi_connection.cursor(server_side=True) +# ---------- Async cursors ---------- class AsyncAdapt_psqlpy_cursor(AsyncAdapt_dbapi_cursor): __slots__ = ( "_arraysize", "_description", "_invalidate_schema_cache_asof", "_rowcount", + "_rows", + "_cursor", + "_adapt_connection", + "_connection", ) _adapt_connection: "AsyncAdapt_psqlpy_connection" _connection: psqlpy.Connection + _rows: t.Deque[Row] + _cursor: t.Optional[psqlpy.Cursor] - def __init__(self, adapt_connection: AsyncAdapt_dbapi_connection): - self._adapt_connection = adapt_connection - self._connection = adapt_connection._connection + def __init__(self, adapt_connection: AsyncAdapt_dbapi_connection) -> None: + self._adapt_connection = t.cast("AsyncAdapt_psqlpy_connection", adapt_connection) + self._connection = t.cast(psqlpy.Connection, adapt_connection._connection) self._rows = deque() - self._description: t.Optional[t.List[t.Tuple[t.Any, ...]]] = None - self._arraysize = 1 - self._rowcount = -1 - self._invalidate_schema_cache_asof = 0 + self._description: t.Optional["_DBAPICursorDescription"] = None + self._arraysize: int = 1 + self._rowcount: int = -1 + self._invalidate_schema_cache_asof: int = 0 + self._cursor = None async def _prepare_execute( self, querystring: str, - parameters: t.Union[t.Sequence[t.Any], t.Mapping[str, Any], None] = None, + parameters: t.Union[t.Sequence[t.Any], t.Mapping[str, t.Any], None] = None, ) -> None: if self._adapt_connection._transaction: await self._adapt_connection._start_transaction() @@ -114,29 +124,29 @@ async def _prepare_execute( querystring=querystring, parameters=parameters, ) - self._description = [ - (column.name, column.table_oid, None, None, None, None, None) - for column in prepared_stmt.columns() - ] + self._description = t.cast( + "_DBAPICursorDescription", + [ + (column.name, column.table_oid, None, None, None, None, None) + for column in prepared_stmt.columns() + ], + ) if self.server_side: - self._cursor = self._connection.cursor( - querystring, - parameters, - ) + self._cursor = self._connection.cursor(querystring, parameters) await self._cursor.start() self._rowcount = -1 return results = await prepared_stmt.execute() - rows: Tuple[Tuple[Any, ...], ...] = tuple( + rows: Rows = tuple( tuple(value for _, value in row) for row in results.row_factory(row_factories.tuple_row) ) self._rows = deque(rows) @property - def description(self) -> "Optional[_DBAPICursorDescription]": + def description(self) -> "t.Optional[_DBAPICursorDescription]": return self._description @property @@ -157,13 +167,12 @@ async def _executemany( seq_of_parameters: t.Sequence[t.Sequence[t.Any]], ) -> None: adapt_connection = self._adapt_connection - self._description = None if not adapt_connection._started: await adapt_connection._start_transaction() - return await self._connection.execute_many( + await self._connection.execute_many( operation, seq_of_parameters, True, @@ -171,15 +180,17 @@ async def _executemany( def execute( self, - operation: t.Any, - parameters: t.Union[t.Sequence[t.Any], t.Mapping[str, Any], None] = None, + operation: str, + parameters: t.Union[t.Sequence[t.Any], t.Mapping[str, t.Any], None] = None, ) -> None: await_only(self._prepare_execute(operation, parameters)) - def executemany(self, operation, seq_of_parameters) -> None: - return await_only(self._executemany(operation, seq_of_parameters)) + def executemany( + self, operation: str, seq_of_parameters: t.Sequence[t.Sequence[t.Any]] + ) -> None: + await_only(self._executemany(operation, seq_of_parameters)) - def setinputsizes(self, *inputsizes): + def setinputsizes(self, *inputsizes: t.Any) -> t.NoReturn: raise NotImplementedError @@ -189,53 +200,50 @@ class AsyncAdapt_psqlpy_ss_cursor( ): _cursor: psqlpy.Cursor - def __init__(self, adapt_connection): + def __init__(self, adapt_connection: AsyncAdapt_psqlpy_connection) -> None: self._adapt_connection = adapt_connection self._connection = adapt_connection._connection self.await_ = adapt_connection.await_ - self._cursor = self._connection.cursor() - def _convert_result( - self, - result: psqlpy.QueryResult, - ) -> Tuple[Tuple[Any, ...], ...]: + def _convert_result(self, result: psqlpy.QueryResult) -> Rows: return tuple( tuple(value for _, value in row) for row in result.row_factory(row_factories.tuple_row) ) - def close(self): + def close(self) -> None: if self._cursor is not None: self._cursor.close() - self._cursor = None + self._cursor = None # type: ignore[assignment] - def fetchone(self): + def fetchone(self) -> Rows: result = self.await_(self._cursor.fetchone()) - return self._convert_result(result=result) + return self._convert_result(result) - def fetchmany(self, size=None): + def fetchmany(self, size: t.Optional[int] = None) -> Rows: result = self.await_(self._cursor.fetchmany(size=size)) - return self._convert_result(result=result) + return self._convert_result(result) - def fetchall(self): + def fetchall(self) -> Rows: result = self.await_(self._cursor.fetchall()) - return self._convert_result(result=result) + return self._convert_result(result) - def __iter__(self): + def __iter__(self) -> t.Iterator[Rows]: iterator = self._cursor.__aiter__() while True: try: result = self.await_(iterator.__anext__()) - rows = self._convert_result(result=result) + rows = self._convert_result(result) yield rows except StopAsyncIteration: break +# ---------- Async connection ---------- class AsyncAdapt_psqlpy_connection(AsyncAdapt_dbapi_connection): - _cursor_cls = AsyncAdapt_psqlpy_cursor - _ss_cursor_cls = AsyncAdapt_psqlpy_ss_cursor + _cursor_cls: t.Type[AsyncAdapt_psqlpy_cursor] = AsyncAdapt_psqlpy_cursor + _ss_cursor_cls: t.Type[AsyncAdapt_psqlpy_ss_cursor] = AsyncAdapt_psqlpy_ss_cursor _connection: psqlpy.Connection @@ -251,7 +259,18 @@ class AsyncAdapt_psqlpy_connection(AsyncAdapt_dbapi_connection): "readonly", ) - def __init__(self, dbapi, connection): + # attributes managed here or by base: + _invalidate_schema_cache_asof: int + _isolation_setting: t.Optional[psqlpy.IsolationLevel] + _prepared_statement_cache: t.Any + _prepared_statement_name_func: t.Any + _started: bool + _transaction: t.Optional[psqlpy.Transaction] + deferrable: bool + isolation_level: t.Optional[psqlpy.IsolationLevel] + readonly: t.Union[bool, psqlpy.ReadVariant] + + def __init__(self, dbapi: t.Any, connection: psqlpy.Connection) -> None: super().__init__(dbapi, connection) self.isolation_level = self._isolation_setting = None self.readonly = False @@ -264,7 +283,7 @@ async def _start_transaction(self) -> None: await transaction.begin() self._transaction = transaction - def set_isolation_level(self, level): + def set_isolation_level(self, level: t.Optional[psqlpy.IsolationLevel]) -> None: self.isolation_level = self._isolation_setting = level def rollback(self) -> None: @@ -275,41 +294,48 @@ def commit(self) -> None: await_only(self._connection.commit()) self._transaction = None - def close(self): + def close(self) -> None: self.rollback() self._connection.close() - def cursor(self, server_side=False): + def cursor(self, server_side: bool = False) -> t.Union[ + AsyncAdapt_psqlpy_ss_cursor, AsyncAdapt_psqlpy_cursor + ]: if server_side: return self._ss_cursor_cls(self) return self._cursor_cls(self) +# ---------- DBAPI facade ---------- class PSQLPyAdaptDBAPI: - def __init__(self, psqlpy) -> None: - self.psqlpy = psqlpy + paramstyle: t.ClassVar[str] = "numeric_dollar" + + def __init__(self, psqlpy_mod: ModuleType) -> None: + self.psqlpy = psqlpy_mod self.paramstyle = "numeric_dollar" for k, v in self.psqlpy.__dict__.items(): if k != "connect": self.__dict__[k] = v - def connect(self, *arg, **kw): + def connect(self, *arg: t.Any, **kw: t.Any) -> AsyncAdapt_psqlpy_connection: creator_fn = kw.pop("async_creator_fn", self.psqlpy.connect) - return AsyncAdapt_psqlpy_connection(self, await_only(creator_fn(*arg, **kw))) + conn = t.cast(psqlpy.Connection, await_only(creator_fn(*arg, **kw))) + return AsyncAdapt_psqlpy_connection(self, conn) +# ---------- Dialect ---------- class PSQLPyAsyncDialect(PGDialect): - driver = "psqlpy" - is_async = True + driver: t.ClassVar[str] = "psqlpy" + is_async: t.ClassVar[bool] = True - execution_ctx_cls = PGExecutionContext_psqlpy - supports_statement_cache = True - supports_server_side_cursors = True - default_paramstyle = "numeric_dollar" - supports_sane_multi_rowcount = True + execution_ctx_cls: t.ClassVar[t.Type[PGExecutionContext]] = PGExecutionContext_psqlpy + supports_statement_cache: t.ClassVar[bool] = True + supports_server_side_cursors: t.ClassVar[bool] = True + default_paramstyle: t.ClassVar[str] = "numeric_dollar" + supports_sane_multi_rowcount: t.ClassVar[bool] = True - colspecs = util.update_copy( + colspecs: t.ClassVar[t.Dict[t.Any, t.Any]] = util.update_copy( PGDialect.colspecs, { sqltypes.String: _PGString, @@ -342,29 +368,29 @@ def _isolation_lookup(self) -> t.Dict[str, psqlpy.IsolationLevel]: def set_isolation_level( self, dbapi_connection: AsyncAdapt_psqlpy_connection, - level, - ): + level: str, + ) -> None: dbapi_connection.set_isolation_level(self._isolation_lookup[level]) - def set_readonly(self, connection, value): + def set_readonly(self, connection: AsyncAdapt_psqlpy_connection, value: bool) -> None: if value is True: connection.readonly = psqlpy.ReadVariant.ReadOnly else: connection.readonly = psqlpy.ReadVariant.ReadWrite - def get_readonly(self, connection): + def get_readonly(self, connection: AsyncAdapt_psqlpy_connection) -> t.Union[bool, psqlpy.ReadVariant]: return connection.readonly - def set_deferrable(self, connection, value): + def set_deferrable(self, connection: AsyncAdapt_psqlpy_connection, value: bool) -> None: connection.deferrable = value - def get_deferrable(self, connection): + def get_deferrable(self, connection: AsyncAdapt_psqlpy_connection) -> bool: return connection.deferrable def create_connect_args( self, url: URL, - ) -> Tuple[Sequence[str], MutableMapping[str, Any]]: + ) -> t.Tuple[t.Sequence[str], t.MutableMapping[str, t.Any]]: opts = url.translate_connect_args() return ( [], @@ -378,4 +404,4 @@ def create_connect_args( ) -dialect = PSQLPyAsyncDialect +dialect: PSQLPyAsyncDialect = PSQLPyAsyncDialect()