Skip to content

Commit

Permalink
Merge pull request #105 from piccolo-orm/choices
Browse files Browse the repository at this point in the history
Choices
  • Loading branch information
dantownsend committed May 25, 2021
2 parents c85b202 + 94b2241 commit 255c5c1
Show file tree
Hide file tree
Showing 27 changed files with 465 additions and 38 deletions.
56 changes: 56 additions & 0 deletions docs/src/piccolo/schema/advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,59 @@ use mixins to reduce the amount of repetition.
class Manager(FavouriteMixin, Table):
name = Varchar()
-------------------------------------------------------------------------------

Choices
-------

You can specify choices for a column, using Python's ``Enum`` support.

.. code-block:: python
from enum import Enum
from piccolo.columns import Varchar
from piccolo.table import Table
class Shirt(Table):
class Size(str, Enum):
small = 's'
medium = 'm'
large = 'l'
size = Varchar(length=1, choices=Size)
We can then use the ``Enum`` in our queries.

.. code-block:: python
>>> Shirt(size=Shirt.Size.large).save().run_sync()
>>> Shirt.select().run_sync()
[{'id': 1, 'size': 'l'}]
Note how the value stored in the database is the ``Enum`` value (in this case ``'l'``).

You can also use the ``Enum`` in ``where`` clauses, and in most other situations
where a query requires a value.

.. code-block:: python
>>> Shirt.insert(
>>> Shirt(size=Shirt.Size.small),
>>> Shirt(size=Shirt.Size.medium)
>>> ).run_sync()
>>> Shirt.select().where(Shirt.size == Shirt.Size.small).run_sync()
[{'id': 1, 'size': 's'}]
Advantages
~~~~~~~~~~

By using choices, you get the following benefits:

* Signalling to other programmers what values are acceptable for the column.
* Improved storage efficiency (we can store ``'l'`` instead of ``'large'``).
* Piccolo admin support (in progress)
37 changes: 33 additions & 4 deletions piccolo/apps/migrations/auto/serialisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,22 @@ def __lt__(self, other):
return repr(self) < repr(other)


@dataclass
class SerialisedEnumType:
enum_type: t.Type[Enum]

def __hash__(self):
return hash(self.__repr__())

def __eq__(self, other):
return self.__hash__() == other.__hash__()

def __repr__(self):
class_name = self.enum_type.__name__
params = {i.name: i.value for i in self.enum_type}
return f"Enum('{class_name}', {params})"


@dataclass
class SerialisedCallable:
callable_: t.Callable
Expand Down Expand Up @@ -162,10 +178,7 @@ def serialise_params(params: t.Dict[str, t.Any]) -> SerialisedParams:
for key, value in params.items():

# Builtins, such as str, list and dict.
if (
hasattr(value, "__module__")
and value.__module__ == builtins.__name__
):
if inspect.getmodule(value) == builtins:
params[key] = SerialisedBuiltin(builtin=value)
continue

Expand Down Expand Up @@ -238,6 +251,20 @@ def serialise_params(params: t.Dict[str, t.Any]) -> SerialisedParams:
)
continue

# Enum types
if inspect.isclass(value) and issubclass(value, Enum):
params[key] = SerialisedEnumType(enum_type=value)
extra_imports.append(Import(module="enum", target="Enum"))
for member in value:
type_ = type(member.value)
module = inspect.getmodule(type_)

if module and module != builtins:
module_name = module.__name__
extra_imports.append(
Import(module=module_name, target=type_.__name__)
)

# Functions
if inspect.isfunction(value):
if value.__name__ == "<lambda>":
Expand Down Expand Up @@ -300,5 +327,7 @@ def deserialise_params(params: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
params[key] = value.callable_
elif isinstance(value, SerialisedTableType):
params[key] = value.table_type
elif isinstance(value, SerialisedEnumType):
params[key] = value.enum_type

return params
5 changes: 3 additions & 2 deletions piccolo/apps/migrations/commands/new.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,11 @@ async def _create_new_migration(app_config: AppConfig, auto=False) -> None:
chain(*[i.statements for i in alter_statements])
)
extra_imports = sorted(
list(set(chain(*[i.extra_imports for i in alter_statements])))
list(set(chain(*[i.extra_imports for i in alter_statements]))),
key=lambda x: x.__repr__(),
)
extra_definitions = sorted(
list(set(chain(*[i.extra_definitions for i in alter_statements])))
list(set(chain(*[i.extra_definitions for i in alter_statements]))),
)

if sum([len(i.statements) for i in alter_statements]) == 0:
Expand Down
61 changes: 61 additions & 0 deletions piccolo/columns/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
NotLike,
)
from piccolo.columns.combination import Where
from piccolo.columns.choices import Choice
from piccolo.columns.defaults.base import Default
from piccolo.columns.reference import LazyTableReference
from piccolo.columns.indexes import IndexMethod
Expand Down Expand Up @@ -124,6 +125,7 @@ class ColumnMeta:
index_method: IndexMethod = IndexMethod.btree
required: bool = False
help_text: t.Optional[str] = None
choices: t.Optional[t.Type[Enum]] = None

# Used for representing the table in migrations and the playground.
params: t.Dict[str, t.Any] = field(default_factory=dict)
Expand Down Expand Up @@ -164,6 +166,30 @@ def engine_type(self) -> str:
else:
raise ValueError("The table has no engine defined.")

def get_choices_dict(self) -> t.Optional[t.Dict[str, t.Any]]:
"""
Return the choices Enum as a dict. It maps the attribute name to a
dict containing the display name, and value.
"""
if self.choices is None:
return None
else:
output = {}
for element in self.choices:
if isinstance(element.value, Choice):
display_name = element.value.display_name
value = element.value.value
else:
display_name = element.name.replace("_", " ").title()
value = element.value

output[element.name] = {
"display_name": display_name,
"value": value,
}

return output

def get_full_name(self, just_alias=False) -> str:
"""
Returns the full column name, taking into account joins.
Expand All @@ -183,6 +209,8 @@ def get_full_name(self, just_alias=False) -> str:
else:
return f'{alias} AS "{column_name}"'

###########################################################################

def copy(self) -> ColumnMeta:
kwargs = self.__dict__.copy()
kwargs.update(
Expand Down Expand Up @@ -266,11 +294,14 @@ def __init__(
index_method: IndexMethod = IndexMethod.btree,
required: bool = False,
help_text: t.Optional[str] = None,
choices: t.Optional[t.Type[Enum]] = None,
**kwargs,
) -> None:
# Used for migrations.
# We deliberately omit 'required', and 'help_text' as they don't effect
# the actual schema.
# 'choices' isn't used directly in the schema, but may be important
# for data migrations.
kwargs.update(
{
"null": null,
Expand All @@ -279,6 +310,7 @@ def __init__(
"unique": unique,
"index": index,
"index_method": index_method,
"choices": choices,
}
)

Expand All @@ -288,6 +320,9 @@ def __init__(
"not nullable."
)

if choices is not None:
self._validate_choices(choices, allowed_type=self.value_type)

self._meta = ColumnMeta(
null=null,
primary=primary,
Expand All @@ -298,6 +333,7 @@ def __init__(
params=kwargs,
required=required,
help_text=help_text,
choices=choices,
)

self.alias: t.Optional[str] = None
Expand All @@ -324,12 +360,37 @@ def _validate_default(
elif callable(default):
self._validated = True
return True
elif (
isinstance(default, Enum) and type(default.value) in allowed_types
):
self._validated = True
return True
else:
raise ValueError(
f"The default {default} isn't one of the permitted types - "
f"{allowed_types}"
)

def _validate_choices(
self, choices: t.Type[Enum], allowed_type: t.Type[t.Any]
) -> bool:
"""
Make sure the choices value has values of the allowed_type.
"""
for element in choices:
if isinstance(element.value, allowed_type):
continue
elif isinstance(element.value, Choice) and isinstance(
element.value.value, allowed_type
):
continue
else:
raise ValueError(
f"{element.name} doesn't have the correct type"
)

return True

def is_in(self, values: t.List[t.Any]) -> Where:
if len(values) == 0:
raise ValueError(
Expand Down
32 changes: 32 additions & 0 deletions piccolo/columns/choices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from __future__ import annotations
from dataclasses import dataclass
import typing as t


@dataclass
class Choice:
"""
When defining enums for ``Column`` choices, they can either be defined
like:
.. code-block:: python
class Title(Enum):
mr = 1
mrs = 2
If using Piccolo Admin, the values shown will be ``Mr`` and ``Mrs``. If you
want more control, you can use ``Choice`` for the value instead.
.. code-block:: python
class Title(Enum):
mr = Choice(value=1, display_name="Mr.")
mrs = Choice(value=1, display_name="Mrs.")
Now the values shown will be ``Mr.`` and ``Mrs.``.
"""

value: t.Any
display_name: str
22 changes: 14 additions & 8 deletions piccolo/columns/column_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import copy
import decimal
from enum import Enum
import typing as t
import uuid
from datetime import date, datetime, time, timedelta
Expand Down Expand Up @@ -165,7 +166,7 @@ class Band(Table):
def __init__(
self,
length: int = 255,
default: t.Union[str, t.Callable[[], str], None] = "",
default: t.Union[str, Enum, t.Callable[[], str], None] = "",
**kwargs,
) -> None:
self._validate_default(default, (str, None))
Expand Down Expand Up @@ -251,7 +252,9 @@ class Band(Table):
concat_delegate: ConcatDelegate = ConcatDelegate()

def __init__(
self, default: t.Union[str, None, t.Callable[[], str]] = "", **kwargs
self,
default: t.Union[str, Enum, None, t.Callable[[], str]] = "",
**kwargs,
) -> None:
self._validate_default(default, (str, None))
self.default = default
Expand Down Expand Up @@ -333,7 +336,9 @@ class Band(Table):
math_delegate = MathDelegate()

def __init__(
self, default: t.Union[int, t.Callable[[], int], None] = 0, **kwargs
self,
default: t.Union[int, Enum, t.Callable[[], int], None] = 0,
**kwargs,
) -> None:
self._validate_default(default, (int, None))
self.default = default
Expand Down Expand Up @@ -771,7 +776,7 @@ class Band(Table):

def __init__(
self,
default: t.Union[bool, t.Callable[[], bool], None] = False,
default: t.Union[bool, Enum, t.Callable[[], bool], None] = False,
**kwargs,
) -> None:
self._validate_default(default, (bool, None))
Expand Down Expand Up @@ -841,7 +846,7 @@ def __init__(
self,
digits: t.Optional[t.Tuple[int, int]] = None,
default: t.Union[
decimal.Decimal, t.Callable[[], decimal.Decimal], None
decimal.Decimal, Enum, t.Callable[[], decimal.Decimal], None
] = decimal.Decimal(0.0),
**kwargs,
) -> None:
Expand Down Expand Up @@ -897,7 +902,7 @@ class Concert(Table):

def __init__(
self,
default: t.Union[float, t.Callable[[], float], None] = 0.0,
default: t.Union[float, Enum, t.Callable[[], float], None] = 0.0,
**kwargs,
) -> None:
self._validate_default(default, (float, None))
Expand Down Expand Up @@ -1087,7 +1092,7 @@ class Band(Table):
def __init__(
self,
references: t.Union[t.Type[Table], LazyTableReference, str],
default: t.Union[int, None] = None,
default: t.Union[int, Enum, None] = None,
null: bool = True,
on_delete: OnDelete = OnDelete.cascade,
on_update: OnUpdate = OnUpdate.cascade,
Expand Down Expand Up @@ -1324,6 +1329,7 @@ def __init__(
default: t.Union[
bytes,
bytearray,
Enum,
t.Callable[[], bytes],
t.Callable[[], bytearray],
None,
Expand Down Expand Up @@ -1376,7 +1382,7 @@ class Ticket(Table):
def __init__(
self,
base_column: Column,
default: t.Union[t.List, t.Callable[[], t.List], None] = list,
default: t.Union[t.List, Enum, t.Callable[[], t.List], None] = list,
**kwargs,
) -> None:
if isinstance(base_column, ForeignKey):
Expand Down
Loading

0 comments on commit 255c5c1

Please sign in to comment.