Skip to content

Commit

Permalink
Merge 3fce612 into 644d427
Browse files Browse the repository at this point in the history
  • Loading branch information
dantownsend committed Feb 13, 2021
2 parents 644d427 + 3fce612 commit dcfe27f
Show file tree
Hide file tree
Showing 9 changed files with 194 additions and 25 deletions.
2 changes: 2 additions & 0 deletions piccolo/apps/migrations/auto/diffable_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def __sub__(self, value: DiffableTable) -> TableDelta:
column_name=column._meta.name,
params=deserialise_params(delta),
old_params=old_params,
column_class=column.__class__,
old_column_class=existing_column.__class__,
)
)

Expand Down
56 changes: 52 additions & 4 deletions piccolo/apps/migrations/auto/migration_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,8 @@ def alter_column(
column_name: str,
params: t.Dict[str, t.Any],
old_params: t.Dict[str, t.Any],
column_class: t.Optional[t.Type[Column]] = None,
old_column_class: t.Optional[t.Type[Column]] = None,
):
"""
All possible alterations aren't currently supported.
Expand All @@ -263,6 +265,8 @@ def alter_column(
column_name=column_name,
params=params,
old_params=old_params,
column_class=column_class,
old_column_class=old_column_class,
)
)

Expand Down Expand Up @@ -328,9 +332,53 @@ async def _run_alter_columns(self, backwards=False):
_Table: t.Type[Table] = type(table_class_name, (Table,), {})
_Table._meta.tablename = alter_columns[0].tablename

for column in alter_columns:
params = column.old_params if backwards else column.params
column_name = column.column_name
for alter_column in alter_columns:

params = (
alter_column.old_params
if backwards
else alter_column.params
)

old_params = (
alter_column.params
if backwards
else alter_column.old_params
)

###############################################################

# Change the column type if possible
column_class = (
alter_column.old_column_class
if backwards
else alter_column.column_class
)
old_column_class = (
alter_column.column_class
if backwards
else alter_column.old_column_class
)

if (old_column_class is not None) and (
column_class is not None
):
if old_column_class != column_class:
old_column = old_column_class(**old_params)
old_column._meta._table = _Table
old_column._meta._name = alter_column.column_name

new_column = column_class(**params)
new_column._meta._table = _Table
new_column._meta._name = alter_column.column_name

await _Table.alter().set_column_type(
old_column=old_column, new_column=new_column
)

###############################################################

column_name = alter_column.column_name

null = params.get("null")
if null is not None:
Expand Down Expand Up @@ -383,7 +431,7 @@ async def _run_alter_columns(self, backwards=False):
digits = params.get("digits", ...)
if digits is not ...:
await _Table.alter().set_digits(
column=column.column_name, digits=digits,
column=alter_column.column_name, digits=digits,
).run()

async def _run_drop_tables(self, backwards=False):
Expand Down
2 changes: 2 additions & 0 deletions piccolo/apps/migrations/auto/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class AlterColumn:
tablename: str
params: t.Dict[str, t.Any]
old_params: t.Dict[str, t.Any]
column_class: t.Optional[t.Type[Column]] = None
old_column_class: t.Optional[t.Type[Column]] = None


@dataclass
Expand Down
36 changes: 32 additions & 4 deletions piccolo/apps/migrations/auto/schema_differ.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,17 +308,45 @@ def alter_columns(self) -> AlterStatements:
else:
continue

for i in delta.alter_columns:
new_params = serialise_params(i.params)
for alter_column in delta.alter_columns:
new_params = serialise_params(alter_column.params)
extra_imports.extend(new_params.extra_imports)
extra_definitions.extend(new_params.extra_definitions)

old_params = serialise_params(i.old_params)
old_params = serialise_params(alter_column.old_params)
extra_imports.extend(old_params.extra_imports)
extra_definitions.extend(old_params.extra_definitions)

column_class = (
alter_column.column_class.__name__
if alter_column.column_class
else "None"
)

old_column_class = (
alter_column.old_column_class.__name__
if alter_column.old_column_class
else "None"
)

if alter_column.column_class is not None:
extra_imports.append(
Import(
module=alter_column.column_class.__module__,
target=alter_column.column_class.__name__,
)
)

if alter_column.old_column_class is not None:
extra_imports.append(
Import(
module=alter_column.old_column_class.__module__,
target=alter_column.old_column_class.__name__,
)
)

response.append(
f"manager.alter_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{i.column_name}', params={new_params.params}, old_params={old_params.params})" # noqa: E501
f"manager.alter_column(table_class_name='{table.class_name}', tablename='{table.tablename}', column_name='{alter_column.column_name}', params={new_params.params}, old_params={old_params.params}, column_class={column_class}, old_column_class={old_column_class})" # noqa: E501
)

return AlterStatements(
Expand Down
2 changes: 2 additions & 0 deletions piccolo/columns/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .column_types import ( # noqa: F401
BigInt,
Boolean,
Bytea,
Date,
Expand All @@ -14,6 +15,7 @@
Real,
Secret,
Serial,
SmallInt,
Text,
Timestamp,
Timestamptz,
Expand Down
9 changes: 5 additions & 4 deletions piccolo/columns/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,15 +462,16 @@ def get_sql_value(self, value: t.Any) -> t.Any:

return output

@property
def column_type(self):
return self.__class__.__name__.upper()

@property
def querystring(self) -> QueryString:
"""
Used when creating tables.
"""
column_type = getattr(
self, "column_type", self.__class__.__name__.upper()
)
query = f'"{self._meta.name}" {column_type}'
query = f'"{self._meta.name}" {self.column_type}'
if self._meta.primary:
query += " PRIMARY"
if self._meta.key:
Expand Down
50 changes: 40 additions & 10 deletions piccolo/query/methods/alter.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,24 @@ def querystring(self) -> QueryString:
return QueryString(f"ALTER COLUMN {self.column_name} DROP DEFAULT")


@dataclass
class SetColumnType(AlterStatement):
__slots__ = ("old_column", "new_column")

old_column: Column
new_column: Column

@property
def querystring(self) -> QueryString:
if self.new_column._meta._table is None:
self.new_column._meta._table = self.old_column._meta.table

column_name = self.old_column._meta.name
return QueryString(
f"ALTER COLUMN {column_name} TYPE {self.new_column.column_type}"
)


@dataclass
class SetDefault(AlterColumnStatement):
__slots__ = ("column", "value")
Expand All @@ -107,7 +125,7 @@ def querystring(self) -> QueryString:


@dataclass
class Unique(AlterColumnStatement):
class SetUnique(AlterColumnStatement):
__slots__ = ("boolean",)

boolean: bool
Expand All @@ -129,7 +147,7 @@ def querystring(self) -> QueryString:


@dataclass
class Null(AlterColumnStatement):
class SetNull(AlterColumnStatement):
__slots__ = ("boolean",)

boolean: bool
Expand Down Expand Up @@ -255,13 +273,14 @@ class Alter(Query):
"_drop_default",
"_drop_table",
"_drop",
"_null",
"_rename_columns",
"_rename_table",
"_set_column_type",
"_set_default",
"_set_digits",
"_set_length",
"_unique",
"_set_null",
"_set_unique",
)

def __init__(self, table: t.Type[Table]):
Expand All @@ -272,13 +291,14 @@ def __init__(self, table: t.Type[Table]):
self._drop_default: t.List[DropDefault] = []
self._drop_table: t.Optional[DropTable] = None
self._drop: t.List[DropColumn] = []
self._null: t.List[Null] = []
self._rename_columns: t.List[RenameColumn] = []
self._rename_table: t.List[RenameTable] = []
self._set_column_type: t.List[SetColumnType] = []
self._set_default: t.List[SetDefault] = []
self._set_digits: t.List[SetDigits] = []
self._set_length: t.List[SetLength] = []
self._unique: t.List[Unique] = []
self._set_null: t.List[SetNull] = []
self._set_unique: t.List[SetUnique] = []

def add_column(self, name: str, column: Column) -> Alter:
"""
Expand Down Expand Up @@ -333,6 +353,15 @@ def rename_column(
self._rename_columns.append(RenameColumn(column, new_name))
return self

def set_column_type(self, old_column: Column, new_column: Column) -> Alter:
"""
Change the type of a column.
"""
self._set_column_type.append(
SetColumnType(old_column=old_column, new_column=new_column)
)
return self

def set_default(self, column: Column, value: t.Any) -> Alter:
"""
Set the default for a column.
Expand All @@ -349,7 +378,7 @@ def set_null(
Band.alter().set_null(Band.name, True)
Band.alter().set_null('name', True)
"""
self._null.append(Null(column, boolean))
self._set_null.append(SetNull(column, boolean))
return self

def set_unique(
Expand All @@ -359,7 +388,7 @@ def set_unique(
Band.alter().set_unique(Band.name, True)
Band.alter().set_unique('name', True)
"""
self._unique.append(Unique(column, boolean))
self._set_unique.append(SetUnique(column, boolean))
return self

def set_length(self, column: t.Union[str, Varchar], length: int) -> Alter:
Expand Down Expand Up @@ -472,8 +501,9 @@ def querystrings(self) -> t.Sequence[QueryString]:
self._rename_table,
self._drop,
self._drop_default,
self._unique,
self._null,
self._set_column_type,
self._set_unique,
self._set_null,
self._set_length,
self._set_default,
self._set_digits,
Expand Down
2 changes: 1 addition & 1 deletion tests/apps/migrations/auto/test_schema_differ.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def test_alter_column_precision(self):
self.assertTrue(len(schema_differ.alter_columns.statements) == 1)
self.assertEqual(
schema_differ.alter_columns.statements[0],
"manager.alter_column(table_class_name='Ticket', tablename='ticket', column_name='price', params={'digits': (4, 2)}, old_params={'digits': (5, 2)})", # noqa
"manager.alter_column(table_class_name='Ticket', tablename='ticket', column_name='price', params={'digits': (4, 2)}, old_params={'digits': (5, 2)}, column_class=Numeric, old_column_class=Numeric)", # noqa
)

def test_alter_default(self):
Expand Down
60 changes: 58 additions & 2 deletions tests/table/test_alter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from piccolo.columns.column_types import Varchar
from unittest import TestCase

from piccolo.columns import Integer, Numeric
from piccolo.columns import BigInt, Integer, Numeric
from piccolo.table import Table

from ..base import DBTestCase, postgres_only
Expand Down Expand Up @@ -113,7 +114,6 @@ def test_unique(self):
self.assertTrue(len(response), 2)


# TODO - make it work for SQLite. Should work.
@postgres_only
class TestMultiple(DBTestCase):
"""
Expand All @@ -137,6 +137,62 @@ def test_multiple(self):
self.assertTrue("column_b" in column_names)


# TODO - test more conversions.
@postgres_only
class TestSetColumnType(DBTestCase):
def test_integer_to_bigint(self):
"""
Test converting an Integer column to BigInt.
"""
self.insert_row()

alter_query = Band.alter().set_column_type(
old_column=Band.popularity, new_column=BigInt()
)
alter_query.run_sync()

query = """
SELECT data_type FROM information_schema.columns
WHERE table_name = 'band'
AND table_catalog = 'piccolo'
AND column_name = 'popularity'
"""

response = Band.raw(query).run_sync()
self.assertEqual(response[0]["data_type"].upper(), "BIGINT")

popularity = (
Band.select(Band.popularity).first().run_sync()["popularity"]
)
self.assertEqual(popularity, 1000)

def test_integer_to_varchar(self):
"""
Test converting an Integer column to Varchar.
"""
self.insert_row()

alter_query = Band.alter().set_column_type(
old_column=Band.popularity, new_column=Varchar()
)
alter_query.run_sync()

query = """
SELECT data_type FROM information_schema.columns
WHERE table_name = 'band'
AND table_catalog = 'piccolo'
AND column_name = 'popularity'
"""

response = Band.raw(query).run_sync()
self.assertEqual(response[0]["data_type"].upper(), "CHARACTER VARYING")

popularity = (
Band.select(Band.popularity).first().run_sync()["popularity"]
)
self.assertEqual(popularity, "1000")


@postgres_only
class TestSetNull(DBTestCase):
def test_set_null(self):
Expand Down

0 comments on commit dcfe27f

Please sign in to comment.