Skip to content

Commit

Permalink
Merge pull request pycassa#141 from kylemcc/master
Browse files Browse the repository at this point in the history
Added batch_insert to ColumnFamilyMap + minor convenience change
  • Loading branch information
thobbs committed May 16, 2012
2 parents 50e668b + 459c7ff commit 66cc8ac
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 10 deletions.
38 changes: 28 additions & 10 deletions pycassa/columnfamilymap.py
Expand Up @@ -19,7 +19,7 @@

def create_instance(cls, **kwargs):
instance = cls()
instance.__dict__.update(kwargs)
map(lambda (k,v): setattr(instance, k, v), kwargs.iteritems())
return instance

class ColumnFamilyMap(ColumnFamily):
Expand Down Expand Up @@ -205,6 +205,17 @@ def get_indexed_slices(self, *args, **kwargs):
combined = self.combine_columns(columns)
yield create_instance(self.cls, key=key, **combined)

def _get_instance_as_dict(self, instance, columns=None):
fields = columns or self.fields
instance_dict = {}
for field in fields:
val = getattr(instance, field, None)
if val is not None and not isinstance(val, CassandraType):
instance_dict[field] = val
if self.super:
instance_dict = {instance.super_column: instance_dict}
return instance_dict

def insert(self, instance, columns=None, timestamp=None, ttl=None,
write_consistency_level=None):
"""
Expand All @@ -222,19 +233,26 @@ def insert(self, instance, columns=None, timestamp=None, ttl=None,
else:
fields = columns

insert_dict = {}
for field in fields:
val = getattr(instance, field, None)
if val is not None and not isinstance(val, CassandraType):
insert_dict[field] = val

if self.super:
insert_dict = {instance.super_column: insert_dict}

insert_dict = self._get_instance_as_dict(instance, columns=fields)
return ColumnFamily.insert(self, instance.key, insert_dict,
timestamp=timestamp, ttl=ttl,
write_consistency_level=write_consistency_level)

def batch_insert(self, instances, timestamp=None, ttl=None,
write_consistency_level=None):
"""
Insert or update stored instances.
`instances` should be a list containing instances of `cls` to store.
"""
insert_dict = dict(
[(instance.key, self._get_instance_as_dict(instance))
for instance in instances]
)
return ColumnFamily.batch_insert(self, insert_dict,
timestamp=timestamp, ttl=ttl,
write_consistency_level=write_consistency_level)

def remove(self, instance, columns=None, write_consistency_level=None):
"""
Removes a stored instance.
Expand Down
38 changes: 38 additions & 0 deletions tests/test_columnfamilymap.py
Expand Up @@ -196,6 +196,24 @@ def test_has_defaults(self):
assert_equal(instance.floatcol, TestUTF8.floatcol.default)
assert_equal(instance.datetimecol, TestUTF8.datetimecol.default)

def test_batch_insert(self):
instances = []
for i in range(3):
instance = TestUTF8()
instance.key = uuid.uuid4()
instance.strcol = 'instance%s' % (i + 1)
instances.append(instance)

for i in instances:
assert_raises(NotFoundException, self.map.get, i.key)

self.map.batch_insert(instances)

for i in instances:
get_instance = self.map.get(i.key)
assert_equal(get_instance.key, i.key)
assert_equal(get_instance.strcol, i.strcol)

class TestSuperColumnFamilyMap(unittest.TestCase):

def setUp(self):
Expand Down Expand Up @@ -237,3 +255,23 @@ def test_super_remove(self):
self.map.remove(instance2)
assert_equal(len(self.map.get(instance1.key)), 1)
assert_equal(self.map.get(instance1.key)[instance1.super_column], instance1)

def test_batch_insert_super(self):
instances = []
for i in range(3):
instance = self.instance('super_batch%s' % (i + 1))
instances.append(instance)

for i in instances:
assert_raises(NotFoundException, self.map.get, i.key)

self.map.batch_insert(instances)

for i in instances:
result = self.map.get(i.key)
get_instance = result[i.super_column]
assert_equal(len(result), 1)
assert_equal(get_instance.key, i.key)
assert_equal(get_instance.super_column, i.super_column)
assert_equal(get_instance.strcol, i.strcol)

0 comments on commit 66cc8ac

Please sign in to comment.