Skip to content

Commit

Permalink
added Numeric and Real column types
Browse files Browse the repository at this point in the history
  • Loading branch information
dantownsend committed May 25, 2020
1 parent b3ad082 commit eebf552
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 0 deletions.
73 changes: 73 additions & 0 deletions piccolo/columns/column_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations
import copy
from datetime import datetime
import decimal
import typing as t
import uuid

Expand Down Expand Up @@ -231,6 +232,78 @@ def __init__(self, default: bool = False, **kwargs) -> None:
super().__init__(**kwargs)


###############################################################################


class Numeric(Column):
"""
Used to represent values precisely. The value is returned as a Decimal.
"""

value_type = decimal.Decimal

@property
def column_type(self):
if self.precision and self.scale:
return f"NUMERIC({self.precision}, {self.scale})"
else:
return "NUMERIC"

def __init__(
self,
precision: t.Optional[int] = None,
scale: t.Optional[int] = None,
default: decimal.Decimal = decimal.Decimal(0.0),
**kwargs,
) -> None:
if (precision, scale).count(None) == 1:
raise ValueError(
"The precision and scale args should either both be None, or "
"neither be None."
)

self.default = default
self.precision = precision
self.scale = scale
kwargs.update(
{"default": default, "precision": precision, "scale": scale}
)
super().__init__(**kwargs)


class Decimal(Numeric):
"""
An alias for Numeric.
"""

pass


class Real(Column):
"""
Can be used instead of Numeric when precision isn't as important. The value
is returned as a float.
"""

value_type = float

def __init__(self, default: float = 0.0, **kwargs) -> None:
self.default = default
kwargs.update({"default": default})
super().__init__(**kwargs)


class Float(Real):
"""
An alias for Real.
"""

pass


###############################################################################


class ForeignKey(Integer):
"""
Returns an integer, representing the referenced row's ID.
Expand Down
26 changes: 26 additions & 0 deletions piccolo/engine/sqlite.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations
import contextvars
from dataclasses import dataclass
from decimal import Decimal
import os
import sqlite3
import typing as t
Expand All @@ -14,6 +15,31 @@
from piccolo.querystring import QueryString
from piccolo.utils.sync import run_sync

###############################################################################

# We need to register some adapters so sqlite returns types which are more
# consistent with the Postgres engine.


def convert_numeric_out(value: bytes):
"""
Converts the value coming from sqlite.
"""
return Decimal(value.decode("ascii"))


def convert_numeric_in(value):
"""
Converts the value being passed into sqlite.
"""
return value if isinstance(value, float) else float(value)


sqlite3.register_converter("Numeric", convert_numeric_out)
sqlite3.register_adapter(Decimal, convert_numeric_in)

###############################################################################


@dataclass
class AsyncBatch(Batch):
Expand Down
30 changes: 30 additions & 0 deletions tests/columns/test_numeric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from decimal import Decimal
from unittest import TestCase

from piccolo.table import Table
from piccolo.columns.column_types import Numeric


class MyTable(Table):
column_a = Numeric()
column_b = Numeric(precision=3, scale=2)


class TestNumeric(TestCase):
def setUp(self):
MyTable.create_table().run_sync()

def tearDown(self):
MyTable.alter().drop_table().run_sync()

def test_creation(self):
row = MyTable(column_a=Decimal(1.23), column_b=Decimal(1.23))
row.save().run_sync()

_row = MyTable.objects().first().run_sync()

self.assertTrue(type(_row.column_a) == Decimal)
self.assertTrue(type(_row.column_b) == Decimal)

self.assertAlmostEqual(_row.column_a, Decimal(1.23))
self.assertEqual(_row.column_b, Decimal("1.23"))
24 changes: 24 additions & 0 deletions tests/columns/test_real.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from unittest import TestCase

from piccolo.table import Table
from piccolo.columns.column_types import Real


class MyTable(Table):
column_a = Real()


class TestReal(TestCase):
def setUp(self):
MyTable.create_table().run_sync()

def tearDown(self):
MyTable.alter().drop_table().run_sync()

def test_creation(self):
row = MyTable(column_a=1.23)
row.save().run_sync()

_row = MyTable.objects().first().run_sync()
self.assertTrue(type(_row.column_a) == float)
self.assertAlmostEqual(_row.column_a, 1.23)

0 comments on commit eebf552

Please sign in to comment.