Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Add DecimalType support

Fixes #102
  • Loading branch information...
commit 4771d122a46fee30a25ea8b00c8639b2c103bc7b 1 parent 92de56f
@thobbs thobbs authored
View
40 pycassa/marshal.py
@@ -7,6 +7,7 @@
import struct
import calendar
from datetime import datetime
+from decimal import Decimal
import pycassa.util as util
@@ -26,17 +27,18 @@ def unpack(self, v):
return Struct()
-_bool_packer = make_packer('>B')
-_float_packer = make_packer('>f')
+_bool_packer = make_packer('>B')
+_float_packer = make_packer('>f')
_double_packer = make_packer('>d')
_long_packer = make_packer('>q')
_int_packer = make_packer('>i')
_short_packer = make_packer('>H')
-_BASIC_TYPES = ['BytesType', 'LongType', 'IntegerType', 'UTF8Type',
+_BASIC_TYPES = ('BytesType', 'LongType', 'IntegerType', 'UTF8Type',
'AsciiType', 'LexicalUUIDType', 'TimeUUIDType',
'CounterColumnType', 'FloatType', 'DoubleType',
- 'DateType', 'BooleanType', 'UUIDType', 'Int32Type']
+ 'DateType', 'BooleanType', 'UUIDType', 'Int32Type',
+ 'DecimalType')
def extract_type_name(typestr):
if typestr is None:
@@ -50,7 +52,7 @@ def extract_type_name(typestr):
index = typestr.rfind('.')
if index != -1:
- typestr = typestr[index + 1: ]
+ typestr = typestr[index + 1:]
if typestr not in _BASIC_TYPES:
typestr = 'BytesType'
return typestr
@@ -59,7 +61,7 @@ def _get_inner_type(typestr):
""" Given a str like 'org.apache...ReversedType(LongType)',
return just 'LongType' """
first_paren = typestr.find('(')
- return typestr[first_paren + 1 : -1]
+ return typestr[first_paren + 1:-1]
def _get_inner_types(typestr):
""" Given a str like 'org.apache...CompositeType(LongType, DoubleType)',
@@ -75,7 +77,7 @@ def _to_timestamp(v):
# Expects Value to be either date or datetime
try:
converted = calendar.timegm(v.utctimetuple())
- converted = converted * 1e3 + getattr(v, 'microsecond', 0)/1e3
+ converted = converted * 1e3 + getattr(v, 'microsecond', 0) / 1e3
except AttributeError:
# Ints and floats are valid timestamps too
if type(v) not in _number_types:
@@ -90,7 +92,7 @@ def get_composite_packer(typestr=None, composite_type=None):
if typestr:
packers = map(packer_for, _get_inner_types(typestr))
elif composite_type:
- packers = [c.pack for c in composite_type.components]
+ packers = [c.pack for c in composite_type.components]
len_packer = _short_packer.pack
@@ -143,8 +145,8 @@ def unpack_composite(bytestr):
while bytestr:
unpacker = i.next()
length = len_unpacker(bytestr[:2])
- components.append(unpacker(bytestr[2:2+length]))
- bytestr = bytestr[3+length:]
+ components.append(unpacker(bytestr[2:2 + length]))
+ bytestr = bytestr[3 + length:]
return tuple(components)
return unpack_composite
@@ -181,6 +183,17 @@ def pack_float(v, _=None):
return _float_packer.pack(v)
return pack_float
+ elif data_type == 'DecimalType':
+ def pack_decimal(dec, _=None):
+ sign, digits, exponent = dec.as_tuple()
+ unscaled = int(''.join(map(str, digits)))
+ if sign:
+ unscaled *= -1
+ scale = _int_packer.pack(-exponent)
+ unscaled = encode_int(unscaled)
+ return scale + unscaled
+ return pack_decimal
+
elif data_type == 'LongType':
def pack_long(v, _=None):
return _long_packer.pack(v)
@@ -260,6 +273,13 @@ def unpacker_for(typestr):
elif data_type == 'FloatType':
return lambda v: _float_packer.unpack(v)[0]
+ elif data_type == 'DecimalType':
+ def unpack_decimal(v):
+ scale = _int_packer.unpack(v[:4])[0]
+ unscaled = decode_int(v[4:])
+ return Decimal('%de%d' % (unscaled, -scale))
+ return unpack_decimal
+
elif data_type == 'LongType':
return lambda v: _long_packer.unpack(v)[0]
View
1  pycassa/system_manager.py
@@ -34,6 +34,7 @@
COUNTER_COLUMN_TYPE = types.CounterColumnType()
DOUBLE_TYPE = types.DoubleType()
FLOAT_TYPE = types.FloatType()
+DECIMAL_TYPE = types.DecimalType()
BOOLEAN_TYPE = types.BooleanType()
DATE_TYPE = types.DateType()
View
11 pycassa/types.py
@@ -28,7 +28,7 @@
__all__ = ('CassandraType', 'BytesType', 'LongType', 'IntegerType',
'AsciiType', 'UTF8Type', 'TimeUUIDType', 'LexicalUUIDType',
- 'CounterColumnType', 'DoubleType', 'FloatType',
+ 'CounterColumnType', 'DoubleType', 'FloatType', 'DecimalType',
'BooleanType', 'DateType', 'OldPycassaDateType',
'IntermediateDateType', 'CompositeType')
@@ -109,7 +109,14 @@ class DoubleType(CassandraType):
pass
class FloatType(CassandraType):
- """ Stores data as an 4 byte float """
+ """ Stores data as a 4 byte float """
+ pass
+
+class DecimalType(CassandraType):
+ """
+ Stores an unlimited precision decimal number. `decimal.Decimal`
+ objects are used by pycassa to represent these objects.
+ """
pass
class BooleanType(CassandraType):
View
10 tests/test_autopacking.py
@@ -6,7 +6,7 @@
from pycassa.types import (LongType, IntegerType, TimeUUIDType, LexicalUUIDType,
AsciiType, UTF8Type, BytesType, CompositeType,
OldPycassaDateType, IntermediateDateType, DateType,
- BooleanType, CassandraType)
+ BooleanType, CassandraType, DecimalType)
from pycassa.index import create_index_expression, create_index_clause
import pycassa.marshal as marshal
@@ -16,6 +16,7 @@
from datetime import date, datetime
from uuid import uuid1
+from decimal import Decimal
import uuid
import unittest
import time
@@ -46,6 +47,7 @@ def setup_class(cls):
sys.create_column_family(TEST_KS, 'StdLong', comparator_type=LongType())
sys.create_column_family(TEST_KS, 'StdInteger', comparator_type=IntegerType())
sys.create_column_family(TEST_KS, 'StdBigInteger', comparator_type=IntegerType())
+ sys.create_column_family(TEST_KS, 'StdDecimal', comparator_type=DecimalType())
sys.create_column_family(TEST_KS, 'StdTimeUUID', comparator_type=TimeUUIDType())
sys.create_column_family(TEST_KS, 'StdLexicalUUID', comparator_type=LexicalUUIDType())
sys.create_column_family(TEST_KS, 'StdAscii', comparator_type=AsciiType())
@@ -58,6 +60,7 @@ def setup_class(cls):
cls.cf_long = ColumnFamily(pool, 'StdLong')
cls.cf_int = ColumnFamily(pool, 'StdInteger')
cls.cf_big_int = ColumnFamily(pool, 'StdBigInteger')
+ cls.cf_decimal = ColumnFamily(pool, 'StdDecimal')
cls.cf_time = ColumnFamily(pool, 'StdTimeUUID')
cls.cf_lex = ColumnFamily(pool, 'StdLexicalUUID')
cls.cf_ascii = ColumnFamily(pool, 'StdAscii')
@@ -100,6 +103,11 @@ def test_standard_column_family(self):
3 + int(time.time() * 10 ** 6)]
type_groups.append(self.make_group(TestCFs.cf_big_int, big_int_cols))
+ decimal_cols = [Decimal('1.123456789123456789'),
+ Decimal('2.123456789123456789'),
+ Decimal('3.123456789123456789')]
+ type_groups.append(self.make_group(TestCFs.cf_decimal, decimal_cols))
+
time_cols = [TIME1, TIME2, TIME3]
type_groups.append(self.make_group(TestCFs.cf_time, time_cols))
Please sign in to comment.
Something went wrong with that request. Please try again.