Skip to content

Commit

Permalink
Improve typing to accommodate sqlalchemy v2
Browse files Browse the repository at this point in the history
Index name can be null.

Fixes: #1168
Change-Id: Id7c944e19a9facd7d3862d43f84fd70aedace999
  • Loading branch information
CaselIT authored and zzzeek committed Feb 26, 2023
1 parent cbc1330 commit 714b744
Show file tree
Hide file tree
Showing 10 changed files with 48 additions and 38 deletions.
13 changes: 6 additions & 7 deletions alembic/autogenerate/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,16 +640,15 @@ def _compare_indexes_and_uniques(
or sqla_compat._constraint_is_named(c.const, autogen_context.dialect)
}

conn_uniques_by_name: Dict[sqla_compat._ConstraintName, _uq_constraint_sig]
conn_indexes_by_name: Dict[sqla_compat._ConstraintName, _ix_constraint_sig]

conn_uniques_by_name = {c.name: c for c in conn_unique_constraints}
conn_indexes_by_name: Dict[Optional[str], _ix_constraint_sig] = {
c.name: c for c in conn_indexes_sig
}
conn_indexes_by_name = {c.name: c for c in conn_indexes_sig}
conn_names = {
c.name: c
for c in conn_unique_constraints.union(
conn_indexes_sig # type:ignore[arg-type]
)
if c.name is not None
for c in conn_unique_constraints.union(conn_indexes_sig)
if sqla_compat.constraint_name_defined(c.name)
}

doubled_constraints = {
Expand Down
13 changes: 5 additions & 8 deletions alembic/autogenerate/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from sqlalchemy import sql
from sqlalchemy import types as sqltypes
from sqlalchemy.sql.elements import conv
from sqlalchemy.sql.elements import quoted_name

from .. import util
from ..operations import ops
Expand All @@ -26,12 +27,10 @@
from typing import Literal

from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.schema import CheckConstraint
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import DefaultClause
from sqlalchemy.sql.schema import FetchedValue
from sqlalchemy.sql.schema import ForeignKey
from sqlalchemy.sql.schema import ForeignKeyConstraint
Expand All @@ -55,12 +54,12 @@

def _render_gen_name(
autogen_context: AutogenContext,
name: Optional[Union[quoted_name, str]],
name: sqla_compat._ConstraintName,
) -> Optional[Union[quoted_name, str, _f_name]]:
if isinstance(name, conv):
return _f_name(_alembic_autogenerate_prefix(autogen_context), name)
else:
return name
return sqla_compat.constraint_name_or_none(name)


def _indent(text: str) -> str:
Expand Down Expand Up @@ -554,7 +553,7 @@ def _ident(name: Optional[Union[quoted_name, str]]) -> Optional[str]:
"""
if name is None:
return name
elif isinstance(name, sql.elements.quoted_name):
elif isinstance(name, quoted_name):
return str(name)
elif isinstance(name, str):
return name
Expand Down Expand Up @@ -721,9 +720,7 @@ def _render_column(column: Column, autogen_context: AutogenContext) -> str:
}


def _should_render_server_default_positionally(
server_default: Union[Computed, DefaultClause]
) -> bool:
def _should_render_server_default_positionally(server_default: Any) -> bool:
return sqla_compat._server_default_is_computed(
server_default
) or sqla_compat._server_default_is_identity(server_default)
Expand Down
7 changes: 4 additions & 3 deletions alembic/ddl/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp):

def __init__(
self,
constraint_name: Optional[str],
constraint_name: sqla_compat._ConstraintName,
table_name: Union[str, quoted_name],
elements: Union[
Sequence[Tuple[str, str]],
Expand All @@ -443,15 +443,16 @@ def from_constraint( # type:ignore[override]
cls, constraint: ExcludeConstraint
) -> CreateExcludeConstraintOp:
constraint_table = sqla_compat._table_for_constraint(constraint)

return cls(
constraint.name,
constraint_table.name,
[
(expr, op)
for expr, name, op in constraint._render_exprs # type:ignore[attr-defined] # noqa
],
where=constraint.where,
where=cast(
"Optional[Union[BinaryExpression, str]]", constraint.where
),
schema=constraint_table.schema,
_orig_constraint=constraint,
deferrable=constraint.deferrable,
Expand Down
2 changes: 1 addition & 1 deletion alembic/op.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ def create_foreign_key(
"""

def create_index(
index_name: str,
index_name: Optional[str],
table_name: str,
columns: Sequence[Union[str, TextClause, Function]],
schema: Optional[str] = None,
Expand Down
7 changes: 4 additions & 3 deletions alembic/operations/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ..util.sqla_compat import _remove_column_from_collection
from ..util.sqla_compat import _resolve_for_variant
from ..util.sqla_compat import _select
from ..util.sqla_compat import constraint_name_defined

if TYPE_CHECKING:
from typing import Literal
Expand Down Expand Up @@ -268,7 +269,7 @@ def _grab_table_elements(self) -> None:
# because
# we have no way to determine _is_type_bound() for these.
pass
elif const.name:
elif constraint_name_defined(const.name):
self.named_constraints[const.name] = const
else:
self.unnamed_constraints.append(const)
Expand Down Expand Up @@ -662,7 +663,7 @@ def drop_table_comment(self, table):
"""

def add_constraint(self, const: Constraint) -> None:
if not const.name:
if not constraint_name_defined(const.name):
raise ValueError("Constraint must have a name")
if isinstance(const, sql_schema.PrimaryKeyConstraint):
if self.table.primary_key in self.unnamed_constraints:
Expand All @@ -681,7 +682,7 @@ def drop_constraint(self, const: Constraint) -> None:
if col_const.name == const.name:
self.columns[col.name].constraints.remove(col_const)
else:
assert const.name
assert constraint_name_defined(const.name)
const = self.named_constraints.pop(const.name)
except KeyError:
if _is_type_bound(const):
Expand Down
21 changes: 8 additions & 13 deletions alembic/operations/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,7 @@ def to_diff_tuple(
return ("remove_constraint", self.to_constraint())

@classmethod
def from_constraint(
cls,
constraint: Constraint,
) -> DropConstraintOp:
def from_constraint(cls, constraint: Constraint) -> DropConstraintOp:
types = {
"unique_constraint": "unique",
"foreign_key_constraint": "foreignkey",
Expand All @@ -169,7 +166,7 @@ def from_constraint(

constraint_table = sqla_compat._table_for_constraint(constraint)
return cls(
constraint.name,
sqla_compat.constraint_name_or_none(constraint.name),
constraint_table.name,
schema=constraint_table.schema,
type_=types[constraint.__visit_name__],
Expand Down Expand Up @@ -274,9 +271,8 @@ def __init__(
def from_constraint(cls, constraint: Constraint) -> CreatePrimaryKeyOp:
constraint_table = sqla_compat._table_for_constraint(constraint)
pk_constraint = cast("PrimaryKeyConstraint", constraint)

return cls(
pk_constraint.name,
sqla_compat.constraint_name_or_none(pk_constraint.name),
constraint_table.name,
pk_constraint.columns.keys(),
schema=constraint_table.schema,
Expand Down Expand Up @@ -411,7 +407,7 @@ def from_constraint(
kw["initially"] = uq_constraint.initially
kw.update(uq_constraint.dialect_kwargs)
return cls(
uq_constraint.name,
sqla_compat.constraint_name_or_none(uq_constraint.name),
constraint_table.name,
[c.name for c in uq_constraint.columns],
schema=constraint_table.schema,
Expand Down Expand Up @@ -567,7 +563,7 @@ def from_constraint(cls, constraint: Constraint) -> CreateForeignKeyOp:
kw["referent_schema"] = target_schema
kw.update(fk_constraint.dialect_kwargs)
return cls(
fk_constraint.name,
sqla_compat.constraint_name_or_none(fk_constraint.name),
source_table,
target_table,
source_columns,
Expand Down Expand Up @@ -753,9 +749,8 @@ def from_constraint(
constraint_table = sqla_compat._table_for_constraint(constraint)

ck_constraint = cast("CheckConstraint", constraint)

return cls(
ck_constraint.name,
sqla_compat.constraint_name_or_none(ck_constraint.name),
constraint_table.name,
cast("ColumnElement[Any]", ck_constraint.sqltext),
schema=constraint_table.schema,
Expand Down Expand Up @@ -863,7 +858,7 @@ class CreateIndexOp(MigrateOperation):

def __init__(
self,
index_name: str,
index_name: Optional[str],
table_name: str,
columns: Sequence[Union[str, TextClause, ColumnElement[Any]]],
schema: Optional[str] = None,
Expand Down Expand Up @@ -914,7 +909,7 @@ def to_index(
def create_index(
cls,
operations: Operations,
index_name: str,
index_name: Optional[str],
table_name: str,
columns: Sequence[Union[str, TextClause, Function]],
schema: Optional[str] = None,
Expand Down
2 changes: 1 addition & 1 deletion alembic/operations/schemaobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def column(self, name: str, type_: TypeEngine, **kw) -> Column:

def index(
self,
name: str,
name: Optional[str],
tablename: Optional[str],
columns: Sequence[Union[str, TextClause, ColumnElement[Any]]],
schema: Optional[str] = None,
Expand Down
17 changes: 17 additions & 0 deletions alembic/util/sqla_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.visitors import traverse
from typing_extensions import TypeGuard

if TYPE_CHECKING:
from sqlalchemy import Index
Expand Down Expand Up @@ -103,6 +104,22 @@ def _safe_int(value: str) -> Union[int, str]:
_identity_attrs = _identity_options_attrs + ("on_null",)
has_identity = True

if sqla_2:
from sqlalchemy.sql.base import _NoneName
else:
from sqlalchemy.util import symbol as _NoneName # type: ignore[assignment]

_ConstraintName = Union[None, str, _NoneName]


def constraint_name_defined(name: _ConstraintName) -> TypeGuard[str]:
return isinstance(name, str)


def constraint_name_or_none(name: _ConstraintName) -> Optional[str]:
return name if constraint_name_defined(name) else None


AUTOINCREMENT_DEFAULT = "auto"


Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ install_requires =
Mako
importlib-metadata;python_version<"3.9"
importlib-resources;python_version<"3.9"
typing-extensions>=4

[options.extras_require]
tz =
Expand Down
3 changes: 1 addition & 2 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ commands=
basepython = python3
deps=
mypy
sqlalchemy>=1.4.0
sqlalchemy2-stubs
sqlalchemy>=2
mako
types-pkg-resources
types-python-dateutil
Expand Down

0 comments on commit 714b744

Please sign in to comment.