Skip to content

Commit

Permalink
Merge pull request #69 from piccolo-orm/custom_columns
Browse files Browse the repository at this point in the history
allow custom column types in migrations
  • Loading branch information
dantownsend committed Feb 10, 2021
2 parents fe40806 + 490d35c commit 644d427
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 17 deletions.
1 change: 1 addition & 0 deletions piccolo/apps/migrations/auto/diffable_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __sub__(self, value: DiffableTable) -> TableDelta:
table_class_name=self.class_name,
column_name=i._meta.name,
column_class_name=i.__class__.__name__,
column_class=i.__class__,
params=i._meta.params,
)
for i in (set(self.columns) - set(value.columns))
Expand Down
21 changes: 19 additions & 2 deletions piccolo/apps/migrations/auto/migration_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,27 @@ def add_column(
table_class_name: str,
tablename: str,
column_name: str,
column_class_name: str,
column_class_name: str = "",
column_class: t.Optional[t.Type[Column]] = None,
params: t.Dict[str, t.Any] = {},
):
column_class = getattr(column_types, column_class_name)
"""
Add a new column to the table.
:param column_class_name:
The column type was traditionally specified as a string, using this
variable. This didn't allow users to define custom column types
though, which is why newer migrations directly reference a
``Column`` subclass using ``column_class``.
:param column_class:
A direct reference to a ``Column`` subclass.
"""
column_class = column_class or getattr(column_types, column_class_name)

if column_class is None:
raise ValueError("Unrecognised column type")

cleaned_params = deserialise_params(params=params)
column = column_class(**cleaned_params)
column._meta.name = column_name
Expand Down
2 changes: 2 additions & 0 deletions piccolo/apps/migrations/auto/operations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from piccolo.columns.base import Column
import typing as t


Expand Down Expand Up @@ -39,4 +40,5 @@ class AddColumn:
table_class_name: str
column_name: str
column_class_name: str
column_class: t.Type[Column]
params: t.Dict[str, t.Any]
43 changes: 29 additions & 14 deletions piccolo/apps/migrations/auto/schema_differ.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,9 @@ def _get_snapshot_table(

@property
def alter_columns(self) -> AlterStatements:
response = []
extra_imports = []
extra_definitions = []
response: t.List[str] = []
extra_imports: t.List[Import] = []
extra_definitions: t.List[str] = []
for table in self.schema:
snapshot_table = self._get_snapshot_table(table.class_name)
if snapshot_table:
Expand Down Expand Up @@ -351,30 +351,38 @@ def drop_columns(self) -> AlterStatements:

@property
def add_columns(self) -> AlterStatements:
response = []
extra_imports = []
extra_definitions = []
response: t.List[str] = []
extra_imports: t.List[Import] = []
extra_definitions: t.List[str] = []
for table in self.schema:
snapshot_table = self._get_snapshot_table(table.class_name)
if snapshot_table:
delta: TableDelta = table - snapshot_table
else:
continue

for column in delta.add_columns:
for add_column in delta.add_columns:
if (
column.column_name
add_column.column_name
in self.rename_columns_collection.new_column_names
):
continue

params = serialise_params(column.params)
params = serialise_params(add_column.params)
cleaned_params = params.params
extra_imports.extend(params.extra_imports)
extra_definitions.extend(params.extra_definitions)

column_class = add_column.column_class
extra_imports.append(
Import(
module=column_class.__module__,
target=column_class.__name__,
)
)

response.append(
f"manager.add_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{column.column_name}', column_class_name='{column.column_class_name}', params={str(cleaned_params)})" # noqa: E501
f"manager.add_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{add_column.column_name}', column_class_name='{add_column.column_class_name}', column_class={column_class.__name__}, params={str(cleaned_params)})" # noqa: E501
)
return AlterStatements(
statements=response,
Expand All @@ -399,9 +407,9 @@ def new_table_columns(self) -> AlterStatements:
set(self.schema) - set(self.schema_snapshot)
)

response = []
extra_imports = []
extra_definitions = []
response: t.List[str] = []
extra_imports: t.List[Import] = []
extra_definitions: t.List[str] = []
for table in new_tables:
if (
table.class_name
Expand All @@ -417,8 +425,15 @@ def new_table_columns(self) -> AlterStatements:
extra_imports.extend(_params.extra_imports)
extra_definitions.extend(_params.extra_definitions)

extra_imports.append(
Import(
module=column.__class__.__module__,
target=column.__class__.__name__,
)
)

response.append(
f"manager.add_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{column._meta.name}', column_class_name='{column.__class__.__name__}', params={str(cleaned_params)})" # noqa: E501
f"manager.add_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{column._meta.name}', column_class_name='{column.__class__.__name__}', column_class={column.__class__.__name__}, params={str(cleaned_params)})" # noqa: E501
)
return AlterStatements(
statements=response,
Expand Down
2 changes: 1 addition & 1 deletion tests/apps/migrations/auto/test_schema_differ.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_add_column(self):
self.assertTrue(len(schema_differ.add_columns.statements) == 1)
self.assertEqual(
schema_differ.add_columns.statements[0],
"manager.add_column(table_class_name='Band', tablename='band', column_name='genre', column_class_name='Varchar', params={'length': 255, 'default': '', 'null': False, 'primary': False, 'key': False, 'unique': False, 'index': False})", # noqa
"manager.add_column(table_class_name='Band', tablename='band', column_name='genre', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary': False, 'key': False, 'unique': False, 'index': False})", # noqa
)

def test_drop_column(self):
Expand Down

0 comments on commit 644d427

Please sign in to comment.