Permalink
Browse files

add batch insert to ColumnFamilyMap, along with accompanying tests

  • Loading branch information...
1 parent 50e668b commit c25681d291ff8de9e11acc56cffd294bc4f66bcb @kylemcc kylemcc committed May 11, 2012
Showing with 65 additions and 9 deletions.
  1. +27 −9 pycassa/columnfamilymap.py
  2. +38 −0 tests/test_columnfamilymap.py
View
@@ -205,6 +205,17 @@ def get_indexed_slices(self, *args, **kwargs):
combined = self.combine_columns(columns)
yield create_instance(self.cls, key=key, **combined)
+ def _get_instance_as_dict(self, instance, columns=None):
+ fields = columns or self.fields
+ instance_dict = {}
+ for field in fields:
+ val = getattr(instance, field, None)
+ if val is not None and not isinstance(val, CassandraType):
+ instance_dict[field] = val
+ if self.super:
+ instance_dict = {instance.super_column: instance_dict}
+ return instance_dict
+
def insert(self, instance, columns=None, timestamp=None, ttl=None,
write_consistency_level=None):
"""
@@ -222,19 +233,26 @@ def insert(self, instance, columns=None, timestamp=None, ttl=None,
else:
fields = columns
- insert_dict = {}
- for field in fields:
- val = getattr(instance, field, None)
- if val is not None and not isinstance(val, CassandraType):
- insert_dict[field] = val
-
- if self.super:
- insert_dict = {instance.super_column: insert_dict}
-
+ insert_dict = self._get_instance_as_dict(instance, columns=fields)
return ColumnFamily.insert(self, instance.key, insert_dict,
timestamp=timestamp, ttl=ttl,
write_consistency_level=write_consistency_level)
+ def batch_insert(self, instances, timestamp=None, ttl=None,
+ write_consistency_level=None):
+ """
+ Insert or update stored instances.
+
+ `instances` should be a list containing instances of `cls` to store.
+ """
+ insert_dict = dict(
+ [(instance.key, self._get_instance_as_dict(instance))
+ for instance in instances]
+ )
+ return ColumnFamily.batch_insert(self, insert_dict,
+ timestamp=timestamp, ttl=ttl,
+ write_consistency_level=write_consistency_level)
+
def remove(self, instance, columns=None, write_consistency_level=None):
"""
Removes a stored instance.
@@ -196,6 +196,24 @@ def test_has_defaults(self):
assert_equal(instance.floatcol, TestUTF8.floatcol.default)
assert_equal(instance.datetimecol, TestUTF8.datetimecol.default)
+ def test_batch_insert(self):
+ instances = []
+ for i in range(3):
+ instance = TestUTF8()
+ instance.key = uuid.uuid4()
+ instance.strcol = 'instance%s' % (i + 1)
+ instances.append(instance)
+
+ for i in instances:
+ assert_raises(NotFoundException, self.map.get, i.key)
+
+ self.map.batch_insert(instances)
+
+ for i in instances:
+ get_instance = self.map.get(i.key)
+ assert_equal(get_instance.key, i.key)
+ assert_equal(get_instance.strcol, i.strcol)
+
class TestSuperColumnFamilyMap(unittest.TestCase):
def setUp(self):
@@ -237,3 +255,23 @@ def test_super_remove(self):
self.map.remove(instance2)
assert_equal(len(self.map.get(instance1.key)), 1)
assert_equal(self.map.get(instance1.key)[instance1.super_column], instance1)
+
+ def test_batch_insert_super(self):
+ instances = []
+ for i in range(3):
+ instance = self.instance('super_batch%s' % (i + 1))
+ instances.append(instance)
+
+ for i in instances:
+ assert_raises(NotFoundException, self.map.get, i.key)
+
+ self.map.batch_insert(instances)
+
+ for i in instances:
+ result = self.map.get(i.key)
+ get_instance = result[i.super_column]
+ assert_equal(len(result), 1)
+ assert_equal(get_instance.key, i.key)
+ assert_equal(get_instance.super_column, i.super_column)
+ assert_equal(get_instance.strcol, i.strcol)
+

0 comments on commit c25681d

Please sign in to comment.