From ca4c4a2a9781e1aa19e2fbbb965838f689a1eed4 Mon Sep 17 00:00:00 2001 From: Keita Ichihashi Date: Mon, 8 Apr 2024 23:43:51 +0900 Subject: [PATCH] Add support for composite unique constraint --- .../apps/migrations/auto/diffable_table.py | 73 ++++- .../apps/migrations/auto/migration_manager.py | 213 +++++++++++++- piccolo/apps/migrations/auto/operations.py | 19 ++ piccolo/apps/migrations/auto/schema_differ.py | 108 +++++++ .../apps/migrations/auto/schema_snapshot.py | 19 ++ piccolo/apps/migrations/commands/new.py | 1 + piccolo/constraint.py | 55 ++++ piccolo/query/methods/alter.py | 20 ++ piccolo/query/methods/create.py | 6 + piccolo/table.py | 17 ++ .../migrations/auto/test_migration_manager.py | 277 ++++++++++++++++++ .../migrations/auto/test_schema_differ.py | 87 ++++++ 12 files changed, 888 insertions(+), 7 deletions(-) create mode 100644 piccolo/constraint.py diff --git a/piccolo/apps/migrations/auto/diffable_table.py b/piccolo/apps/migrations/auto/diffable_table.py index aa609f041..d6172fce5 100644 --- a/piccolo/apps/migrations/auto/diffable_table.py +++ b/piccolo/apps/migrations/auto/diffable_table.py @@ -5,14 +5,17 @@ from piccolo.apps.migrations.auto.operations import ( AddColumn, + AddConstraint, AlterColumn, DropColumn, + DropConstraint, ) from piccolo.apps.migrations.auto.serialisation import ( deserialise_params, serialise_params, ) from piccolo.columns.base import Column +from piccolo.constraint import Constraint from piccolo.table import Table, create_table_class @@ -55,6 +58,8 @@ class TableDelta: add_columns: t.List[AddColumn] = field(default_factory=list) drop_columns: t.List[DropColumn] = field(default_factory=list) alter_columns: t.List[AlterColumn] = field(default_factory=list) + add_constraints: t.List[AddConstraint] = field(default_factory=list) + drop_constraints: t.List[DropConstraint] = field(default_factory=list) def __eq__(self, value: TableDelta) -> bool: # type: ignore """ @@ -85,6 +90,19 @@ def __eq__(self, value) -> bool: return False +@dataclass +class ConstraintComparison: + constraint: Constraint + + def __hash__(self) -> int: + return self.constraint.__hash__() + + def __eq__(self, value) -> bool: + if isinstance(value, ConstraintComparison): + return self.constraint._meta.name == value.constraint._meta.name + return False + + @dataclass class DiffableTable: """ @@ -96,6 +114,7 @@ class DiffableTable: tablename: str schema: t.Optional[str] = None columns: t.List[Column] = field(default_factory=list) + constraints: t.List[Constraint] = field(default_factory=list) previous_class_name: t.Optional[str] = None def __post_init__(self) -> None: @@ -189,10 +208,54 @@ def __sub__(self, value: DiffableTable) -> TableDelta: ) ) + add_constraints = [ + AddConstraint( + table_class_name=self.class_name, + constraint_name=i.constraint._meta.name, + constraint_class_name=i.constraint.__class__.__name__, + constraint_class=i.constraint.__class__, + params=i.constraint._meta.params, + schema=self.schema, + ) + for i in sorted( + { + ConstraintComparison(constraint=constraint) + for constraint in self.constraints + } + - { + ConstraintComparison(constraint=constraint) + for constraint in value.constraints + }, + key=lambda x: x.constraint._meta.name, + ) + ] + + drop_constraints = [ + DropConstraint( + table_class_name=self.class_name, + constraint_name=i.constraint._meta.name, + tablename=value.tablename, + schema=self.schema, + ) + for i in sorted( + { + ConstraintComparison(constraint=constraint) + for constraint in value.constraints + } + - { + ConstraintComparison(constraint=constraint) + for constraint in self.constraints + }, + key=lambda x: x.constraint._meta.name, + ) + ] + return TableDelta( add_columns=add_columns, drop_columns=drop_columns, alter_columns=alter_columns, + add_constraints=add_constraints, + drop_constraints=drop_constraints, ) def __hash__(self) -> int: @@ -218,10 +281,14 @@ def to_table_class(self) -> t.Type[Table]: """ Converts the DiffableTable into a Table subclass. """ + class_members: t.Dict[str, t.Any] = {} + for column in self.columns: + class_members[column._meta.name] = column + for constraint in self.constraints: + class_members[constraint._meta.name] = constraint + return create_table_class( class_name=self.class_name, class_kwargs={"tablename": self.tablename, "schema": self.schema}, - class_members={ - column._meta.name: column for column in self.columns - }, + class_members=class_members, ) diff --git a/piccolo/apps/migrations/auto/migration_manager.py b/piccolo/apps/migrations/auto/migration_manager.py index fca36e8e7..b1744274a 100644 --- a/piccolo/apps/migrations/auto/migration_manager.py +++ b/piccolo/apps/migrations/auto/migration_manager.py @@ -10,12 +10,14 @@ AlterColumn, ChangeTableSchema, DropColumn, + DropConstraint, RenameColumn, RenameTable, ) from piccolo.apps.migrations.auto.serialisation import deserialise_params from piccolo.columns import Column, column_types from piccolo.columns.column_types import ForeignKey, Serial +from piccolo.constraint import Constraint, UniqueConstraint from piccolo.engine import engine_finder from piccolo.query import Query from piccolo.query.base import DDL @@ -127,6 +129,65 @@ def table_class_names(self) -> t.List[str]: return list({i.table_class_name for i in self.alter_columns}) +@dataclass +class AddConstraintClass: + constraint: Constraint + table_class_name: str + tablename: str + schema: t.Optional[str] + + +@dataclass +class AddConstraintCollection: + add_constraints: t.List[AddConstraintClass] = field(default_factory=list) + + def append(self, add_constraint: AddConstraintClass): + self.add_constraints.append(add_constraint) + + def for_table_class_name( + self, table_class_name: str + ) -> t.List[AddConstraintClass]: + return [ + i + for i in self.add_constraints + if i.table_class_name == table_class_name + ] + + def constraints_for_table_class_name( + self, table_class_name: str + ) -> t.List[Constraint]: + return [ + i.constraint + for i in self.add_constraints + if i.table_class_name == table_class_name + ] + + @property + def table_class_names(self) -> t.List[str]: + return list({i.table_class_name for i in self.add_constraints}) + + +@dataclass +class DropConstraintCollection: + drop_constraints: t.List[DropConstraint] = field(default_factory=list) + + def append(self, drop_constraint: DropConstraint): + self.drop_constraints.append(drop_constraint) + + def for_table_class_name( + self, table_class_name: str + ) -> t.List[DropConstraint]: + return [ + i + for i in self.drop_constraints + if i.table_class_name == table_class_name + ] + + @property + def table_class_names(self) -> t.List[str]: + return list({i.table_class_name for i in self.drop_constraints}) + + AsyncFunction = t.Callable[[], t.Coroutine] @@ -159,6 +220,12 @@ class MigrationManager: alter_columns: AlterColumnCollection = field( default_factory=AlterColumnCollection ) + add_constraints: AddConstraintCollection = field( + default_factory=AddConstraintCollection + ) + drop_constraints: DropConstraintCollection = field( + default_factory=DropConstraintCollection + ) raw: t.List[t.Union[t.Callable, AsyncFunction]] = field( default_factory=list ) @@ -345,6 +412,47 @@ def alter_column( ) ) + def add_constraint( + self, + table_class_name: str, + tablename: str, + constraint_name: str, + constraint_class: t.Type[Constraint], + params: t.Dict[str, t.Any], + schema: t.Optional[str] = None, + ): + if constraint_class is UniqueConstraint: + constraint = UniqueConstraint(**params) + else: + raise ValueError("Unrecognised constraint type") + + constraint._meta.name = constraint_name + + self.add_constraints.append( + AddConstraintClass( + constraint=constraint, + table_class_name=table_class_name, + tablename=tablename, + schema=schema, + ) + ) + + def drop_constraint( + self, + table_class_name: str, + tablename: str, + constraint_name: str, + schema: t.Optional[str] = None, + ): + self.drop_constraints.append( + DropConstraint( + table_class_name=table_class_name, + constraint_name=constraint_name, + tablename=tablename, + schema=schema, + ) + ) + def add_raw(self, raw: t.Union[t.Callable, AsyncFunction]): """ A migration manager can execute arbitrary functions or coroutines when @@ -740,16 +848,24 @@ async def _run_add_tables(self, backwards: bool = False): add_columns: t.List[AddColumnClass] = ( self.add_columns.for_table_class_name(add_table.class_name) ) + add_constraints: t.List[AddConstraintClass] = ( + self.add_constraints.for_table_class_name(add_table.class_name) + ) + class_members: t.Dict[str, t.Any] = {} + for add_column in add_columns: + class_members[add_column.column._meta.name] = add_column.column + for add_constraint in add_constraints: + class_members[add_constraint.constraint._meta.name] = ( + add_constraint.constraint + ) + _Table: t.Type[Table] = create_table_class( class_name=add_table.class_name, class_kwargs={ "tablename": add_table.tablename, "schema": add_table.schema, }, - class_members={ - add_column.column._meta.name: add_column.column - for add_column in add_columns - }, + class_members=class_members, ) table_classes.append(_Table) @@ -922,6 +1038,91 @@ async def _run_change_table_schema(self, backwards: bool = False): ) ) + async def _run_add_constraints(self, backwards: bool = False): + if backwards: + for add_constraint in self.add_constraints.add_constraints: + if add_constraint.table_class_name in [ + i.class_name for i in self.add_tables + ]: + # Don't reverse the add constraint as the table is going to + # be deleted. + continue + + _Table = create_table_class( + class_name=add_constraint.table_class_name, + class_kwargs={ + "tablename": add_constraint.tablename, + "schema": add_constraint.schema, + }, + ) + + await self._run_query( + _Table.alter().drop_constraint( + add_constraint.constraint._meta.name + ) + ) + else: + for table_class_name in self.add_constraints.table_class_names: + if table_class_name in [i.class_name for i in self.add_tables]: + continue # No need to add constraints to new tables + + add_constraints: t.List[AddConstraintClass] = ( + self.add_constraints.for_table_class_name(table_class_name) + ) + + _Table = create_table_class( + class_name=add_constraints[0].table_class_name, + class_kwargs={ + "tablename": add_constraints[0].tablename, + "schema": add_constraints[0].schema, + }, + ) + + for add_constraint in add_constraints: + await self._run_query( + _Table.alter().add_constraint( + add_constraint.constraint + ) + ) + + async def _run_drop_constraints(self, backwards: bool = False): + if backwards: + for drop_constraint in self.drop_constraints.drop_constraints: + _Table = await self.get_table_from_snapshot( + table_class_name=drop_constraint.table_class_name, + app_name=self.app_name, + offset=-1, + ) + constraint_to_restore = _Table._meta.get_constraint_by_name( + drop_constraint.constraint_name + ) + await self._run_query( + _Table.alter().add_constraint(constraint_to_restore) + ) + else: + for table_class_name in self.drop_constraints.table_class_names: + constraints = self.drop_constraints.for_table_class_name( + table_class_name + ) + + if not constraints: + continue + + _Table = create_table_class( + class_name=table_class_name, + class_kwargs={ + "tablename": constraints[0].tablename, + "schema": constraints[0].schema, + }, + ) + + for constraint in constraints: + await self._run_query( + _Table.alter().drop_constraint( + constraint_name=constraint.constraint_name + ) + ) + async def run(self, backwards: bool = False): direction = "backwards" if backwards else "forwards" if self.preview: @@ -958,6 +1159,10 @@ async def run(self, backwards: bool = False): # "ALTER COLUMN TYPE is not supported inside a transaction" if engine.engine_type != "cockroach": await self._run_alter_columns(backwards=backwards) + await self._run_add_constraints(backwards=backwards) + await self._run_drop_constraints(backwards=backwards) if engine.engine_type == "cockroach": await self._run_alter_columns(backwards=backwards) + await self._run_add_constraints(backwards=backwards) + await self._run_drop_constraints(backwards=backwards) diff --git a/piccolo/apps/migrations/auto/operations.py b/piccolo/apps/migrations/auto/operations.py index 0676bdbd4..8fc741e96 100644 --- a/piccolo/apps/migrations/auto/operations.py +++ b/piccolo/apps/migrations/auto/operations.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from piccolo.columns.base import Column +from piccolo.constraint import Constraint @dataclass @@ -63,3 +64,21 @@ class AddColumn: column_class: t.Type[Column] params: t.Dict[str, t.Any] schema: t.Optional[str] = None + + +@dataclass +class AddConstraint: + table_class_name: str + constraint_name: str + constraint_class_name: str + constraint_class: t.Type[Constraint] + params: t.Dict[str, t.Any] + schema: t.Optional[str] = None + + +@dataclass +class DropConstraint: + table_class_name: str + constraint_name: str + tablename: str + schema: t.Optional[str] = None diff --git a/piccolo/apps/migrations/auto/schema_differ.py b/piccolo/apps/migrations/auto/schema_differ.py index 1d095b938..8fe1028d0 100644 --- a/piccolo/apps/migrations/auto/schema_differ.py +++ b/piccolo/apps/migrations/auto/schema_differ.py @@ -613,6 +613,69 @@ def add_columns(self) -> AlterStatements: extra_definitions=extra_definitions, ) + @property + def add_constraints(self) -> AlterStatements: + response: t.List[str] = [] + extra_imports: t.List[Import] = [] + extra_definitions: t.List[Definition] = [] + for table in self.schema: + snapshot_table = self._get_snapshot_table(table.class_name) + if snapshot_table: + delta: TableDelta = table - snapshot_table + else: + continue + + for add_constraint in delta.add_constraints: + constraint_class = add_constraint.constraint_class + extra_imports.append( + Import( + module=constraint_class.__module__, + target=constraint_class.__name__, + expect_conflict_with_global_name=getattr( + UniqueGlobalNames, + f"COLUMN_{constraint_class.__name__.upper()}", + None, + ), + ) + ) + + schema_str = ( + "None" + if add_constraint.schema is None + else f'"{add_constraint.schema}"' + ) + + response.append( + f"manager.add_constraint(table_class_name='{table.class_name}', tablename='{table.tablename}', constraint_name='{add_constraint.constraint_name}', constraint_class={constraint_class.__name__}, params={add_constraint.params}, schema={schema_str})" # noqa: E501 + ) + return AlterStatements( + statements=response, + extra_imports=extra_imports, + extra_definitions=extra_definitions, + ) + + @property + def drop_constraints(self) -> AlterStatements: + response = [] + for table in self.schema: + snapshot_table = self._get_snapshot_table(table.class_name) + if snapshot_table: + delta: TableDelta = table - snapshot_table + else: + continue + + for constraint in delta.drop_constraints: + schema_str = ( + "None" + if constraint.schema is None + else f'"{constraint.schema}"' + ) + + response.append( + f"manager.drop_constraint(table_class_name='{table.class_name}', tablename='{table.tablename}', constraint_name='{constraint.constraint_name}', schema={schema_str})" # noqa: E501 + ) + return AlterStatements(statements=response) + @property def rename_columns(self) -> AlterStatements: alter_statements = AlterStatements() @@ -679,6 +742,48 @@ def new_table_columns(self) -> AlterStatements: extra_definitions=extra_definitions, ) + @property + def new_table_constraints(self) -> AlterStatements: + new_tables: t.List[DiffableTable] = list( + set(self.schema) - set(self.schema_snapshot) + ) + + response: t.List[str] = [] + extra_imports: t.List[Import] = [] + extra_definitions: t.List[Definition] = [] + for table in new_tables: + if ( + table.class_name + in self.rename_tables_collection.new_class_names + ): + continue + + for constraint in table.constraints: + extra_imports.append( + Import( + module=constraint.__class__.__module__, + target=constraint.__class__.__name__, + expect_conflict_with_global_name=getattr( + UniqueGlobalNames, + f"COLUMN_{constraint.__class__.__name__.upper()}", + None, + ), + ) + ) + + schema_str = ( + "None" if table.schema is None else f'"{table.schema}"' + ) + + response.append( + f"manager.add_constraint(table_class_name='{table.class_name}', tablename='{table.tablename}', constraint_name='{constraint._meta.name}', constraint_class={constraint.__class__.__name__}, params={constraint._meta.params}, schema={schema_str})" # noqa: E501 + ) + return AlterStatements( + statements=response, + extra_imports=extra_imports, + extra_definitions=extra_definitions, + ) + ########################################################################### def get_alter_statements(self) -> t.List[AlterStatements]: @@ -691,10 +796,13 @@ def get_alter_statements(self) -> t.List[AlterStatements]: "Renamed tables": self.rename_tables, "Tables which changed schema": self.change_table_schemas, "Created table columns": self.new_table_columns, + "Created table constraints": self.new_table_constraints, "Dropped columns": self.drop_columns, "Columns added to existing tables": self.add_columns, "Renamed columns": self.rename_columns, "Altered columns": self.alter_columns, + "Dropped constraints": self.drop_constraints, + "Constraints added to existing tables": self.add_constraints, } for message, statements in alter_statements.items(): diff --git a/piccolo/apps/migrations/auto/schema_snapshot.py b/piccolo/apps/migrations/auto/schema_snapshot.py index 45963b717..50d8128d7 100644 --- a/piccolo/apps/migrations/auto/schema_snapshot.py +++ b/piccolo/apps/migrations/auto/schema_snapshot.py @@ -112,4 +112,23 @@ def get_snapshot(self) -> t.List[DiffableTable]: rename_column.new_db_column_name ) + add_constraints = ( + manager.add_constraints.constraints_for_table_class_name( + table.class_name + ) + ) + table.constraints.extend(add_constraints) + + drop_constraints = ( + manager.drop_constraints.for_table_class_name( + table.class_name + ) + ) + for drop_constraint in drop_constraints: + table.constraints = [ + i + for i in table.constraints + if i._meta.name != drop_constraint.constraint_name + ] + return tables diff --git a/piccolo/apps/migrations/commands/new.py b/piccolo/apps/migrations/commands/new.py index ff123aaa2..c18c36d31 100644 --- a/piccolo/apps/migrations/commands/new.py +++ b/piccolo/apps/migrations/commands/new.py @@ -191,6 +191,7 @@ async def get_alter_statements( class_name=i.__name__, tablename=i._meta.tablename, columns=i._meta.non_default_columns, + constraints=i._meta.constraints, schema=i._meta.schema, ) for i in app_config.table_classes diff --git a/piccolo/constraint.py b/piccolo/constraint.py new file mode 100644 index 000000000..6bb9e1c95 --- /dev/null +++ b/piccolo/constraint.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import typing as t +from abc import abstractmethod +from dataclasses import dataclass, field + + +class Constraint: + def __init__(self, **kwargs) -> None: + self._meta = ConstraintMeta(params=kwargs) + + def __hash__(self): + return hash(self._meta.name) + + @property + @abstractmethod + def ddl(self) -> str: + raise NotImplementedError + + +@dataclass +class ConstraintMeta: + # Used for representing the table in migrations. + params: t.Dict[str, t.Any] = field(default_factory=dict) + + # Set by the Table Metaclass: + _name: t.Optional[str] = None + + @property + def name(self) -> str: + if not self._name: + raise ValueError( + "`_name` isn't defined - the Table Metaclass should set it." + ) + return self._name + + @name.setter + def name(self, value: str): + self._name = value + + +class UniqueConstraint(Constraint): + def __init__( + self, + unique_columns: t.List[str], + **kwargs, + ) -> None: + self.unique_columns = unique_columns + kwargs.update({"unique_columns": unique_columns}) + super().__init__(**kwargs) + + @property + def ddl(self) -> str: + unique_columns_string = ", ".join(self.unique_columns) + return f"UNIQUE ({unique_columns_string})" diff --git a/piccolo/query/methods/alter.py b/piccolo/query/methods/alter.py index 040b2f883..10a2427b8 100644 --- a/piccolo/query/methods/alter.py +++ b/piccolo/query/methods/alter.py @@ -6,6 +6,7 @@ from piccolo.columns.base import Column from piccolo.columns.column_types import ForeignKey, Numeric, Varchar +from piccolo.constraint import Constraint from piccolo.query.base import DDL from piccolo.utils.warnings import Level, colored_warning @@ -177,6 +178,17 @@ def ddl(self) -> str: return f'ALTER COLUMN "{self.column_name}" TYPE VARCHAR({self.length})' +@dataclass +class AddConstraint(AlterStatement): + __slots__ = ("constraint",) + + constraint: Constraint + + @property + def ddl(self) -> str: + return f"ADD CONSTRAINT {self.constraint._meta.name} {self.constraint.ddl}" # noqa: E501 + + @dataclass class DropConstraint(AlterStatement): __slots__ = ("constraint_name",) @@ -275,6 +287,7 @@ class Alter(DDL): __slots__ = ( "_add_foreign_key_constraint", "_add", + "_add_constraint", "_drop_constraint", "_drop_default", "_drop_table", @@ -294,6 +307,7 @@ def __init__(self, table: t.Type[Table], **kwargs): super().__init__(table, **kwargs) self._add_foreign_key_constraint: t.List[AddForeignKeyConstraint] = [] self._add: t.List[AddColumn] = [] + self._add_constraint: t.List[AddConstraint] = [] self._drop_constraint: t.List[DropConstraint] = [] self._drop_default: t.List[DropDefault] = [] self._drop_table: t.Optional[DropTable] = None @@ -490,6 +504,10 @@ def _get_constraint_name(self, column: t.Union[str, ForeignKey]) -> str: tablename = self.table._meta.tablename return f"{tablename}_{column_name}_fk" + def add_constraint(self, constraint: Constraint) -> Alter: + self._add_constraint.append(AddConstraint(constraint=constraint)) + return self + def drop_constraint(self, constraint_name: str) -> Alter: self._drop_constraint.append( DropConstraint(constraint_name=constraint_name) @@ -590,6 +608,8 @@ def default_ddl(self) -> t.Sequence[str]: self._set_default, self._set_digits, self._set_schema, + self._add_constraint, + self._drop_constraint, ) ] diff --git a/piccolo/query/methods/create.py b/piccolo/query/methods/create.py index 68cccf6b2..418c45e5b 100644 --- a/piccolo/query/methods/create.py +++ b/piccolo/query/methods/create.py @@ -3,6 +3,7 @@ import typing as t from piccolo.query.base import DDL +from piccolo.query.methods.alter import AddConstraint from piccolo.query.methods.create_index import CreateIndex if t.TYPE_CHECKING: # pragma: no cover @@ -87,4 +88,9 @@ def default_ddl(self) -> t.Sequence[str]: ).ddl ) + for constraint in self.table._meta.constraints: + ddl.append( + f"ALTER TABLE {self.table._meta.get_formatted_tablename()} {AddConstraint(constraint=constraint).ddl}" # noqa: E501 + ) + return ddl diff --git a/piccolo/table.py b/piccolo/table.py index b4fcbf942..6a2e5926a 100644 --- a/piccolo/table.py +++ b/piccolo/table.py @@ -28,6 +28,7 @@ ) from piccolo.columns.readable import Readable from piccolo.columns.reference import LAZY_COLUMN_REFERENCES +from piccolo.constraint import Constraint from piccolo.custom_types import TableInstance from piccolo.engine import Engine, engine_finder from piccolo.query import ( @@ -84,6 +85,7 @@ class TableMeta: primary_key: Column = field(default_factory=Column) json_columns: t.List[t.Union[JSON, JSONB]] = field(default_factory=list) secret_columns: t.List[Secret] = field(default_factory=list) + constraints: t.List[Constraint] = field(default_factory=list) auto_update_columns: t.List[Column] = field(default_factory=list) tags: t.List[str] = field(default_factory=list) help_text: t.Optional[str] = None @@ -170,6 +172,15 @@ def get_column_by_name(self, name: str) -> Column: return column_object + def get_constraint_by_name(self, name: str) -> Constraint: + """ + Returns a constraint which matches the given name. + """ + for constraint in self.constraints: + if constraint._meta.name == name: + return constraint + raise ValueError(f"No matching constraint found with name == {name}") + def get_auto_update_values(self) -> t.Dict[Column, t.Any]: """ If columns have ``auto_update`` defined, then we retrieve these values. @@ -276,6 +287,7 @@ def __init_subclass__( auto_update_columns: t.List[Column] = [] primary_key: t.Optional[Column] = None m2m_relationships: t.List[M2M] = [] + constraints: t.List[Constraint] = [] attribute_names = itertools.chain( *[i.__dict__.keys() for i in reversed(cls.__mro__)] @@ -328,6 +340,10 @@ def __init_subclass__( attribute._meta._table = cls m2m_relationships.append(attribute) + if isinstance(attribute, Constraint): + attribute._meta._name = attribute_name + constraints.append(attribute) + if not primary_key: primary_key = cls._create_serial_primary_key() setattr(cls, "id", primary_key) @@ -351,6 +367,7 @@ def __init_subclass__( help_text=help_text, _db=db, m2m_relationships=m2m_relationships, + constraints=constraints, schema=schema, ) diff --git a/tests/apps/migrations/auto/test_migration_manager.py b/tests/apps/migrations/auto/test_migration_manager.py index a1988a029..b5d14ccb3 100644 --- a/tests/apps/migrations/auto/test_migration_manager.py +++ b/tests/apps/migrations/auto/test_migration_manager.py @@ -11,6 +11,7 @@ from piccolo.columns.base import OnDelete, OnUpdate from piccolo.columns.column_types import ForeignKey from piccolo.conf.apps import AppConfig +from piccolo.constraint import UniqueConstraint from piccolo.table import Table, sort_table_classes from piccolo.utils.lazy_loader import LazyLoader from tests.base import AsyncMock, DBTestCase, engine_is, engines_only @@ -336,6 +337,282 @@ def test_add_column(self) -> None: if engine_is("cockroach"): self.assertEqual(response, [{"id": row_id, "name": "Dave"}]) + @engines_only("postgres", "cockroach") + def test_add_table_with_unique_constraint(self): + """ + Test adding a table with a unique constraint to a MigrationManager. + """ + # Create table with unique constraint + manager = MigrationManager() + manager.add_table(class_name="Musician", tablename="musician") + manager.add_column( + table_class_name="Musician", + tablename="musician", + column_name="name", + column_class_name="Varchar", + ) + manager.add_column( + table_class_name="Musician", + tablename="musician", + column_name="label", + column_class_name="Varchar", + ) + manager.add_constraint( + table_class_name="Musician", + tablename="musician", + constraint_name="unique_name_label", + constraint_class=UniqueConstraint, + params={ + "unique_columns": ["name", "label"], + }, + ) + asyncio.run(manager.run()) + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + with self.assertRaises(asyncpg.exceptions.UniqueViolationError): + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + + # Reverse + asyncio.run(manager.run(backwards=True)) + self.assertTrue(not self.table_exists("musician")) + + @engines_only("postgres", "cockroach") + @patch.object( + BaseMigrationManager, "get_migration_managers", new_callable=AsyncMock + ) + @patch.object(BaseMigrationManager, "get_app_config") + def test_drop_table_with_unique_constraint( + self, get_app_config: MagicMock, get_migration_managers: MagicMock + ): + """ + Test dropping a table with a unique constraint to a MigrationManager. + """ + # Create table + manager_1 = MigrationManager() + manager_1.add_table(class_name="Musician", tablename="musician") + manager_1.add_column( + table_class_name="Musician", + tablename="musician", + column_name="name", + column_class_name="Varchar", + ) + manager_1.add_column( + table_class_name="Musician", + tablename="musician", + column_name="label", + column_class_name="Varchar", + ) + manager_1.add_constraint( + table_class_name="Musician", + tablename="musician", + constraint_name="unique_name_label", + constraint_class=UniqueConstraint, + params={ + "unique_columns": ["name", "label"], + }, + ) + asyncio.run(manager_1.run()) + + # Drop table + manager_2 = MigrationManager() + manager_2.drop_table( + class_name="Musician", + tablename="musician", + ) + asyncio.run(manager_2.run()) + self.assertTrue(not self.table_exists("musician")) + + # Reverse + get_migration_managers.return_value = [manager_1] + app_config = AppConfig(app_name="music", migrations_folder_path="") + get_app_config.return_value = app_config + asyncio.run(manager_2.run(backwards=True)) + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + with self.assertRaises(asyncpg.exceptions.UniqueViolationError): + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + + # Reverse + asyncio.run(manager_1.run(backwards=True)) + self.assertTrue(not self.table_exists("musician")) + + @engines_only("postgres", "cockroach") + @patch.object( + BaseMigrationManager, "get_migration_managers", new_callable=AsyncMock + ) + @patch.object(BaseMigrationManager, "get_app_config") + def test_rename_table_with_unique_constraint( + self, get_app_config: MagicMock, get_migration_managers: MagicMock + ): + """ + Test renaming a table with a unique constraint to a MigrationManager. + """ + # Create table + manager_1 = MigrationManager() + manager_1.add_table(class_name="Musician", tablename="musician") + manager_1.add_column( + table_class_name="Musician", + tablename="musician", + column_name="name", + column_class_name="Varchar", + ) + manager_1.add_column( + table_class_name="Musician", + tablename="musician", + column_name="label", + column_class_name="Varchar", + ) + manager_1.add_constraint( + table_class_name="Musician", + tablename="musician", + constraint_name="unique_name_label", + constraint_class=UniqueConstraint, + params={ + "unique_columns": ["name", "label"], + }, + ) + asyncio.run(manager_1.run()) + + # Rename table + manager_2 = MigrationManager() + manager_2.rename_table( + old_class_name="Musician", + old_tablename="musician", + new_class_name="Musician2", + new_tablename="musician2", + ) + asyncio.run(manager_2.run()) + self.assertTrue(not self.table_exists("musician")) + self.run_sync("INSERT INTO musician2 VALUES (default, 'a', 'a');") + with self.assertRaises(asyncpg.exceptions.UniqueViolationError): + self.run_sync("INSERT INTO musician2 VALUES (default, 'a', 'a');") + + # Reverse + get_migration_managers.return_value = [manager_1] + app_config = AppConfig(app_name="music", migrations_folder_path="") + get_app_config.return_value = app_config + asyncio.run(manager_2.run(backwards=True)) + self.assertTrue(not self.table_exists("musician2")) + with self.assertRaises(asyncpg.exceptions.UniqueViolationError): + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + + # Reverse + asyncio.run(manager_1.run(backwards=True)) + self.assertTrue(not self.table_exists("musician")) + self.assertTrue(not self.table_exists("musician2")) + + @engines_only("postgres") + @patch.object( + BaseMigrationManager, "get_migration_managers", new_callable=AsyncMock + ) + @patch.object(BaseMigrationManager, "get_app_config") + def test_add_unique_constraint( + self, get_app_config: MagicMock, get_migration_managers: MagicMock + ): + """ + Test adding a unique constraint to a MigrationManager. + Cockroach DB doesn't support dropping unique constraints with ALTER TABLE DROP CONSTRAINT. + https://github.com/cockroachdb/cockroach/issues/42840 + """ # noqa: E501 + manager_1 = MigrationManager() + manager_1.add_table(class_name="Musician", tablename="musician") + manager_1.add_column( + table_class_name="Musician", + tablename="musician", + column_name="name", + column_class_name="Varchar", + ) + manager_1.add_column( + table_class_name="Musician", + tablename="musician", + column_name="label", + column_class_name="Varchar", + ) + asyncio.run(manager_1.run()) + + manager_2 = MigrationManager() + manager_2.add_constraint( + table_class_name="Musician", + tablename="musician", + constraint_name="musician_unique", + constraint_class=UniqueConstraint, + params={ + "unique_columns": ["name", "label"], + }, + ) + asyncio.run(manager_2.run()) + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + with self.assertRaises(asyncpg.exceptions.UniqueViolationError): + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + + # Reverse + get_migration_managers.return_value = [manager_1] + app_config = AppConfig(app_name="music", migrations_folder_path="") + get_app_config.return_value = app_config + asyncio.run(manager_2.run(backwards=True)) + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + + # Reverse + asyncio.run(manager_1.run(backwards=True)) + self.assertTrue(not self.table_exists("musician")) + + @engines_only("postgres") + @patch.object( + BaseMigrationManager, "get_migration_managers", new_callable=AsyncMock + ) + @patch.object(BaseMigrationManager, "get_app_config") + def test_drop_unique_constraint( + self, get_app_config: MagicMock, get_migration_managers: MagicMock + ): + """ + Test dropping a unique constraint with a MigrationManager. + Cockroach DB doesn't support dropping unique constraints with ALTER TABLE DROP CONSTRAINT. + https://github.com/cockroachdb/cockroach/issues/42840 + """ # noqa: E501 + manager_1 = MigrationManager() + manager_1.add_table(class_name="Musician", tablename="musician") + manager_1.add_column( + table_class_name="Musician", + tablename="musician", + column_name="name", + column_class_name="Varchar", + ) + manager_1.add_column( + table_class_name="Musician", + tablename="musician", + column_name="label", + column_class_name="Varchar", + ) + manager_1.add_constraint( + table_class_name="Musician", + tablename="musician", + constraint_name="musician_unique", + constraint_class=UniqueConstraint, + params={ + "unique_columns": ["name", "label"], + }, + ) + asyncio.run(manager_1.run()) + + manager_2 = MigrationManager() + manager_2.drop_constraint( + table_class_name="Musician", + tablename="musician", + constraint_name="musician_unique", + ) + asyncio.run(manager_2.run()) + + # Reverse + get_migration_managers.return_value = [manager_1] + app_config = AppConfig(app_name="music", migrations_folder_path="") + get_app_config.return_value = app_config + asyncio.run(manager_2.run(backwards=True)) + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + with self.assertRaises(asyncpg.exceptions.UniqueViolationError): + self.run_sync("INSERT INTO musician VALUES (default, 'a', 'a');") + + # Reverse + asyncio.run(manager_1.run(backwards=True)) + self.assertTrue(not self.table_exists("musician")) + @engines_only("postgres", "cockroach") def test_add_column_with_index(self): """ diff --git a/tests/apps/migrations/auto/test_schema_differ.py b/tests/apps/migrations/auto/test_schema_differ.py index 9cf6d26f2..b9fe5252c 100644 --- a/tests/apps/migrations/auto/test_schema_differ.py +++ b/tests/apps/migrations/auto/test_schema_differ.py @@ -13,6 +13,7 @@ SchemaDiffer, ) from piccolo.columns.column_types import Numeric, Varchar +from piccolo.constraint import UniqueConstraint class TestSchemaDiffer(TestCase): @@ -488,6 +489,92 @@ def test_db_column_name(self) -> None: "manager.alter_column(table_class_name='Ticket', tablename='ticket', column_name='price', db_column_name='custom', params={'digits': (4, 2)}, old_params={'digits': (5, 2)}, column_class=Numeric, old_column_class=Numeric, schema=None)", # noqa ) + def test_add_constraint(self) -> None: + """ + Test adding a constraint to an existing table. + """ + name_column = Varchar() + name_column._meta.name = "name" + + genre_column = Varchar() + genre_column._meta.name = "genre" + + name_unique_constraint = UniqueConstraint(unique_columns=["name"]) + name_unique_constraint._meta.name = "unique_name" + + name_genre_unique_constraint = UniqueConstraint(unique_columns=["name", "genre"]) + name_genre_unique_constraint._meta.name = "unique_name_genre" + + schema: t.List[DiffableTable] = [ + DiffableTable( + class_name="Band", + tablename="band", + columns=[name_column, genre_column], + constraints=[name_unique_constraint, name_genre_unique_constraint], + ) + ] + schema_snapshot: t.List[DiffableTable] = [ + DiffableTable( + class_name="Band", + tablename="band", + columns=[name_column, genre_column], + constraints=[name_unique_constraint], + ) + ] + + schema_differ = SchemaDiffer( + schema=schema, schema_snapshot=schema_snapshot, auto_input="y" + ) + + self.assertTrue(len(schema_differ.add_constraints.statements) == 1) + self.assertEqual( + schema_differ.add_constraints.statements[0], + "manager.add_constraint(table_class_name='Band', tablename='band', constraint_name='unique_name_genre', constraint_class=UniqueConstraint, params={'unique_columns': ['name', 'genre']}, schema=None)" # noqa: E501 + ) + + def test_drop_constraint(self) -> None: + """ + Test dropping a constraint from an existing table. + """ + name_column = Varchar() + name_column._meta.name = "name" + + genre_column = Varchar() + genre_column._meta.name = "genre" + + name_unique_constraint = UniqueConstraint(unique_columns=["name"]) + name_unique_constraint._meta.name = "unique_name" + + name_genre_unique_constraint = UniqueConstraint(unique_columns=["name", "genre"]) + name_genre_unique_constraint._meta.name = "unique_name_genre" + + schema: t.List[DiffableTable] = [ + DiffableTable( + class_name="Band", + tablename="band", + columns=[name_column, genre_column], + constraints=[name_unique_constraint], + ) + ] + schema_snapshot: t.List[DiffableTable] = [ + DiffableTable( + class_name="Band", + tablename="band", + columns=[name_column, genre_column], + constraints=[name_unique_constraint, name_genre_unique_constraint], + ) + ] + + schema_differ = SchemaDiffer( + schema=schema, schema_snapshot=schema_snapshot, auto_input="y" + ) + + self.assertTrue(len(schema_differ.drop_constraints.statements) == 1) + self.assertEqual( + schema_differ.drop_constraints.statements[0], + "manager.drop_constraint(table_class_name='Band', tablename='band', constraint_name='unique_name_genre', schema=None)" # noqa: E501 + ) + def test_alter_default(self): pass