From 714b74478313a69042bf61cc42a585fc6dddf2e4 Mon Sep 17 00:00:00 2001 From: CaselIT Date: Fri, 10 Feb 2023 22:24:11 +0100 Subject: [PATCH] Improve typing to accommodate sqlalchemy v2 Index name can be null. Fixes: #1168 Change-Id: Id7c944e19a9facd7d3862d43f84fd70aedace999 --- alembic/autogenerate/compare.py | 13 ++++++------- alembic/autogenerate/render.py | 13 +++++-------- alembic/ddl/postgresql.py | 7 ++++--- alembic/op.pyi | 2 +- alembic/operations/batch.py | 7 ++++--- alembic/operations/ops.py | 21 ++++++++------------- alembic/operations/schemaobj.py | 2 +- alembic/util/sqla_compat.py | 17 +++++++++++++++++ setup.cfg | 1 + tox.ini | 3 +-- 10 files changed, 48 insertions(+), 38 deletions(-) diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py index 828a4cd5..c2181b8c 100644 --- a/alembic/autogenerate/compare.py +++ b/alembic/autogenerate/compare.py @@ -640,16 +640,15 @@ def _compare_indexes_and_uniques( or sqla_compat._constraint_is_named(c.const, autogen_context.dialect) } + conn_uniques_by_name: Dict[sqla_compat._ConstraintName, _uq_constraint_sig] + conn_indexes_by_name: Dict[sqla_compat._ConstraintName, _ix_constraint_sig] + conn_uniques_by_name = {c.name: c for c in conn_unique_constraints} - conn_indexes_by_name: Dict[Optional[str], _ix_constraint_sig] = { - c.name: c for c in conn_indexes_sig - } + conn_indexes_by_name = {c.name: c for c in conn_indexes_sig} conn_names = { c.name: c - for c in conn_unique_constraints.union( - conn_indexes_sig # type:ignore[arg-type] - ) - if c.name is not None + for c in conn_unique_constraints.union(conn_indexes_sig) + if sqla_compat.constraint_name_defined(c.name) } doubled_constraints = { diff --git a/alembic/autogenerate/render.py b/alembic/autogenerate/render.py index 4a144db7..dc841f83 100644 --- a/alembic/autogenerate/render.py +++ b/alembic/autogenerate/render.py @@ -17,6 +17,7 @@ from sqlalchemy import sql from sqlalchemy import types as sqltypes from sqlalchemy.sql.elements import conv +from sqlalchemy.sql.elements import quoted_name from .. import util from ..operations import ops @@ -26,12 +27,10 @@ from typing import Literal from sqlalchemy.sql.elements import ColumnElement - from sqlalchemy.sql.elements import quoted_name from sqlalchemy.sql.elements import TextClause from sqlalchemy.sql.schema import CheckConstraint from sqlalchemy.sql.schema import Column from sqlalchemy.sql.schema import Constraint - from sqlalchemy.sql.schema import DefaultClause from sqlalchemy.sql.schema import FetchedValue from sqlalchemy.sql.schema import ForeignKey from sqlalchemy.sql.schema import ForeignKeyConstraint @@ -55,12 +54,12 @@ def _render_gen_name( autogen_context: AutogenContext, - name: Optional[Union[quoted_name, str]], + name: sqla_compat._ConstraintName, ) -> Optional[Union[quoted_name, str, _f_name]]: if isinstance(name, conv): return _f_name(_alembic_autogenerate_prefix(autogen_context), name) else: - return name + return sqla_compat.constraint_name_or_none(name) def _indent(text: str) -> str: @@ -554,7 +553,7 @@ def _ident(name: Optional[Union[quoted_name, str]]) -> Optional[str]: """ if name is None: return name - elif isinstance(name, sql.elements.quoted_name): + elif isinstance(name, quoted_name): return str(name) elif isinstance(name, str): return name @@ -721,9 +720,7 @@ def _render_column(column: Column, autogen_context: AutogenContext) -> str: } -def _should_render_server_default_positionally( - server_default: Union[Computed, DefaultClause] -) -> bool: +def _should_render_server_default_positionally(server_default: Any) -> bool: return sqla_compat._server_default_is_computed( server_default ) or sqla_compat._server_default_is_identity(server_default) diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py index 32674d2a..e7c85bdc 100644 --- a/alembic/ddl/postgresql.py +++ b/alembic/ddl/postgresql.py @@ -419,7 +419,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp): def __init__( self, - constraint_name: Optional[str], + constraint_name: sqla_compat._ConstraintName, table_name: Union[str, quoted_name], elements: Union[ Sequence[Tuple[str, str]], @@ -443,7 +443,6 @@ def from_constraint( # type:ignore[override] cls, constraint: ExcludeConstraint ) -> CreateExcludeConstraintOp: constraint_table = sqla_compat._table_for_constraint(constraint) - return cls( constraint.name, constraint_table.name, @@ -451,7 +450,9 @@ def from_constraint( # type:ignore[override] (expr, op) for expr, name, op in constraint._render_exprs # type:ignore[attr-defined] # noqa ], - where=constraint.where, + where=cast( + "Optional[Union[BinaryExpression, str]]", constraint.where + ), schema=constraint_table.schema, _orig_constraint=constraint, deferrable=constraint.deferrable, diff --git a/alembic/op.pyi b/alembic/op.pyi index 5c089e83..7a5710eb 100644 --- a/alembic/op.pyi +++ b/alembic/op.pyi @@ -576,7 +576,7 @@ def create_foreign_key( """ def create_index( - index_name: str, + index_name: Optional[str], table_name: str, columns: Sequence[Union[str, TextClause, Function]], schema: Optional[str] = None, diff --git a/alembic/operations/batch.py b/alembic/operations/batch.py index 0c773c68..fe32eec2 100644 --- a/alembic/operations/batch.py +++ b/alembic/operations/batch.py @@ -33,6 +33,7 @@ from ..util.sqla_compat import _remove_column_from_collection from ..util.sqla_compat import _resolve_for_variant from ..util.sqla_compat import _select +from ..util.sqla_compat import constraint_name_defined if TYPE_CHECKING: from typing import Literal @@ -268,7 +269,7 @@ def _grab_table_elements(self) -> None: # because # we have no way to determine _is_type_bound() for these. pass - elif const.name: + elif constraint_name_defined(const.name): self.named_constraints[const.name] = const else: self.unnamed_constraints.append(const) @@ -662,7 +663,7 @@ def drop_table_comment(self, table): """ def add_constraint(self, const: Constraint) -> None: - if not const.name: + if not constraint_name_defined(const.name): raise ValueError("Constraint must have a name") if isinstance(const, sql_schema.PrimaryKeyConstraint): if self.table.primary_key in self.unnamed_constraints: @@ -681,7 +682,7 @@ def drop_constraint(self, const: Constraint) -> None: if col_const.name == const.name: self.columns[col.name].constraints.remove(col_const) else: - assert const.name + assert constraint_name_defined(const.name) const = self.named_constraints.pop(const.name) except KeyError: if _is_type_bound(const): diff --git a/alembic/operations/ops.py b/alembic/operations/ops.py index 3cdd170d..8e1144bb 100644 --- a/alembic/operations/ops.py +++ b/alembic/operations/ops.py @@ -154,10 +154,7 @@ def to_diff_tuple( return ("remove_constraint", self.to_constraint()) @classmethod - def from_constraint( - cls, - constraint: Constraint, - ) -> DropConstraintOp: + def from_constraint(cls, constraint: Constraint) -> DropConstraintOp: types = { "unique_constraint": "unique", "foreign_key_constraint": "foreignkey", @@ -169,7 +166,7 @@ def from_constraint( constraint_table = sqla_compat._table_for_constraint(constraint) return cls( - constraint.name, + sqla_compat.constraint_name_or_none(constraint.name), constraint_table.name, schema=constraint_table.schema, type_=types[constraint.__visit_name__], @@ -274,9 +271,8 @@ def __init__( def from_constraint(cls, constraint: Constraint) -> CreatePrimaryKeyOp: constraint_table = sqla_compat._table_for_constraint(constraint) pk_constraint = cast("PrimaryKeyConstraint", constraint) - return cls( - pk_constraint.name, + sqla_compat.constraint_name_or_none(pk_constraint.name), constraint_table.name, pk_constraint.columns.keys(), schema=constraint_table.schema, @@ -411,7 +407,7 @@ def from_constraint( kw["initially"] = uq_constraint.initially kw.update(uq_constraint.dialect_kwargs) return cls( - uq_constraint.name, + sqla_compat.constraint_name_or_none(uq_constraint.name), constraint_table.name, [c.name for c in uq_constraint.columns], schema=constraint_table.schema, @@ -567,7 +563,7 @@ def from_constraint(cls, constraint: Constraint) -> CreateForeignKeyOp: kw["referent_schema"] = target_schema kw.update(fk_constraint.dialect_kwargs) return cls( - fk_constraint.name, + sqla_compat.constraint_name_or_none(fk_constraint.name), source_table, target_table, source_columns, @@ -753,9 +749,8 @@ def from_constraint( constraint_table = sqla_compat._table_for_constraint(constraint) ck_constraint = cast("CheckConstraint", constraint) - return cls( - ck_constraint.name, + sqla_compat.constraint_name_or_none(ck_constraint.name), constraint_table.name, cast("ColumnElement[Any]", ck_constraint.sqltext), schema=constraint_table.schema, @@ -863,7 +858,7 @@ class CreateIndexOp(MigrateOperation): def __init__( self, - index_name: str, + index_name: Optional[str], table_name: str, columns: Sequence[Union[str, TextClause, ColumnElement[Any]]], schema: Optional[str] = None, @@ -914,7 +909,7 @@ def to_index( def create_index( cls, operations: Operations, - index_name: str, + index_name: Optional[str], table_name: str, columns: Sequence[Union[str, TextClause, Function]], schema: Optional[str] = None, diff --git a/alembic/operations/schemaobj.py b/alembic/operations/schemaobj.py index dfda8bbe..ba09b3bb 100644 --- a/alembic/operations/schemaobj.py +++ b/alembic/operations/schemaobj.py @@ -235,7 +235,7 @@ def column(self, name: str, type_: TypeEngine, **kw) -> Column: def index( self, - name: str, + name: Optional[str], tablename: Optional[str], columns: Sequence[Union[str, TextClause, ColumnElement[Any]]], schema: Optional[str] = None, diff --git a/alembic/util/sqla_compat.py b/alembic/util/sqla_compat.py index 23255be3..9bdcfc3b 100644 --- a/alembic/util/sqla_compat.py +++ b/alembic/util/sqla_compat.py @@ -26,6 +26,7 @@ from sqlalchemy.sql.elements import quoted_name from sqlalchemy.sql.elements import TextClause from sqlalchemy.sql.visitors import traverse +from typing_extensions import TypeGuard if TYPE_CHECKING: from sqlalchemy import Index @@ -103,6 +104,22 @@ def _safe_int(value: str) -> Union[int, str]: _identity_attrs = _identity_options_attrs + ("on_null",) has_identity = True +if sqla_2: + from sqlalchemy.sql.base import _NoneName +else: + from sqlalchemy.util import symbol as _NoneName # type: ignore[assignment] + +_ConstraintName = Union[None, str, _NoneName] + + +def constraint_name_defined(name: _ConstraintName) -> TypeGuard[str]: + return isinstance(name, str) + + +def constraint_name_or_none(name: _ConstraintName) -> Optional[str]: + return name if constraint_name_defined(name) else None + + AUTOINCREMENT_DEFAULT = "auto" diff --git a/setup.cfg b/setup.cfg index 0d9ce1a7..4741eb72 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,6 +43,7 @@ install_requires = Mako importlib-metadata;python_version<"3.9" importlib-resources;python_version<"3.9" + typing-extensions>=4 [options.extras_require] tz = diff --git a/tox.ini b/tox.ini index 8b744d7c..4cc54450 100644 --- a/tox.ini +++ b/tox.ini @@ -70,8 +70,7 @@ commands= basepython = python3 deps= mypy - sqlalchemy>=1.4.0 - sqlalchemy2-stubs + sqlalchemy>=2 mako types-pkg-resources types-python-dateutil