Permalink
Browse files

Add PooledColumnFamily class

  • Loading branch information...
1 parent 2a6d99c commit 6b8d679320a78316ce1a661adbcf9e22139d6e4e @thobbs thobbs committed Nov 11, 2010
Showing with 471 additions and 3 deletions.
  1. +24 −0 pycassa/batch.py
  2. +97 −3 pycassa/columnfamily.py
  3. +350 −0 tests/test_pooledcolumnfamily.py
View
24 pycassa/batch.py
@@ -145,3 +145,27 @@ def remove(self, key, columns=None, super_column=None, timestamp=None):
super_column=super_column,
timestamp=timestamp)
+class PooledCfMutator(CfMutator):
+
+ def __init__(self, *args, **kwargs):
+ super(PooledCfMutator, self).__init__(*args, **kwargs)
+
+ def send(self, write_consistency_level=None):
+ if write_consistency_level is None:
+ write_consistency_level = self.write_consistency_level
+ mutations = {}
+ conn = None
+ self._lock.acquire()
+ try:
+ for key, column_family, cols in self._buffer:
+ mutations.setdefault(key, {}).setdefault(column_family, []).extend(cols)
+ if mutations:
+ conn = self._column_family.pool.get()
+ conn.batch_mutate(mutations, write_consistency_level)
+ self._buffer = []
+ finally:
+ if conn:
+ conn.return_to_pool()
+ self._lock.release()
+
+
View
100 pycassa/columnfamily.py
@@ -16,7 +16,7 @@
import uuid
import struct
-from batch import CfMutator
+from batch import CfMutator, PooledCfMutator
if hasattr(struct, 'Struct'): # new in Python 2.5
_have_struct = True
@@ -26,7 +26,7 @@
else:
_have_struct = False
-__all__ = ['gm_timestamp', 'ColumnFamily']
+__all__ = ['gm_timestamp', 'ColumnFamily', 'PooledColumnFamily']
_TYPES = ['BytesType', 'LongType', 'IntegerType', 'UTF8Type', 'AsciiType',
'LexicalUUIDType', 'TimeUUIDType']
@@ -134,7 +134,7 @@ def __init__(self, client, column_family, buffer_size=1024,
col_fam = None
try:
- col_fam = client.get_keyspace_description(use_dict_for_col_metadata=True)[self.column_family]
+ col_fam = self.client.get_keyspace_description(use_dict_for_col_metadata=True)[self.column_family]
except KeyError:
raise NotFoundException('Column family %s not found.' % self.column_family)
@@ -820,3 +820,97 @@ def truncate(self):
"""
self.client.truncate(self.column_family)
+
+class PooledColumnFamily(ColumnFamily):
+ """
+ A ColumnFamily that uses a :class:`.Pool` object instead of a
+ :class:`.Connection` object to perform its operations.
+
+ """
+
+ def __init__(self, pool, keyspace, **kwargs):
+ """
+ A ColumnFamily that uses a :class:`.Pool` object instead of a
+ :class:`.Connection` object to perform its operations. Connections
+ are automatically retrieved before every operation and returned to the
+ pool when the operation completes.
+
+ """
+
+ self.pool = pool
+ conn = self.pool.get()
+ super(PooledColumnFamily, self).__init__(conn, keyspace, **kwargs)
+ conn.return_to_pool()
+
+ def get(self, *args, **kwargs):
+ self.client = self.pool.get()
+ try:
+ return super(PooledColumnFamily, self).get(*args, **kwargs)
+ finally:
+ self.client.return_to_pool()
+
+ def multiget(self, *args, **kwargs):
+ self.client = self.pool.get()
+ try:
+ return super(PooledColumnFamily, self).multiget(*args, **kwargs)
+ finally:
+ self.client.return_to_pool()
+
+ def get_indexed_slices(self, *args, **kwargs):
+ self.client = self.pool.get()
+ try:
+ return super(PooledColumnFamily, self).get_indexed_slices(*args, **kwargs)
+ finally:
+ self.client.return_to_pool()
+
+ def get_count(self, *args, **kwargs):
+ self.client = self.pool.get()
+ try:
+ return super(PooledColumnFamily, self).get_count(*args, **kwargs)
+ finally:
+ self.client.return_to_pool()
+
+ def multiget_count(self, *args, **kwargs):
+ self.client = self.pool.get()
+ try:
+ return super(PooledColumnFamily, self).multiget_count(*args, **kwargs)
+ finally:
+ self.client.return_to_pool()
+
+ def get_range(self, *args, **kwargs):
+ self.client = self.pool.get()
+ try:
+ return super(PooledColumnFamily, self).get_range(*args, **kwargs)
+ finally:
+ self.client.return_to_pool()
+
+ def insert(self, *args, **kwargs):
+ self.client = self.pool.get()
+ try:
+ return super(PooledColumnFamily, self).insert(*args, **kwargs)
+ finally:
+ self.client.return_to_pool()
+
+ def batch_insert(self, *args, **kwargs):
+ self.client = self.pool.get()
+ try:
+ return super(PooledColumnFamily, self).batch_insert(*args, **kwargs)
+ finally:
+ self.client.return_to_pool()
+
+ def remove(self, *args, **kwargs):
+ self.client = self.pool.get()
+ try:
+ return super(PooledColumnFamily, self).remove(*args, **kwargs)
+ finally:
+ self.client.return_to_pool()
+
+ def truncate(self, *args, **kwargs):
+ self.client = self.pool.get()
+ try:
+ return super(PooledColumnFamily, self).truncate(*args, **kwargs)
+ finally:
+ self.client.return_to_pool()
+
+ def batch(self, queue_size=100, write_consistency_level=None):
+ return PooledCfMutator(self, queue_size, self._wcl(write_consistency_level))
View
350 tests/test_pooledcolumnfamily.py
@@ -0,0 +1,350 @@
+from pycassa import connect, connect_thread_local, index, PooledColumnFamily,\
+ QueuePool, ConsistencyLevel, NotFoundException
+
+from nose.tools import assert_raises, assert_equal, assert_true
+
+import struct
+
+class TestDict(dict):
+ pass
+
+class TestColumnFamily:
+ def setUp(self):
+ credentials = {'username': 'jsmith', 'password': 'havebadpass'}
+ self.pool = QueuePool(pool_size=5, keyspace='Keyspace1', credentials=credentials)
+ self.cf = PooledColumnFamily(self.pool, 'Standard2',
+ write_consistency_level=ConsistencyLevel.ONE,
+ buffer_size=2, timestamp=self.timestamp,
+ dict_class=TestDict)
+ try:
+ self.timestamp_n = int(self.cf.get('meta')['timestamp'])
+ except NotFoundException:
+ self.timestamp_n = 0
+ self.clear()
+
+ def tearDown(self):
+ self.cf.insert('meta', {'timestamp': str(self.timestamp_n)})
+
+ # Since the timestamp passed to Cassandra will be in the same second
+ # with the default timestamp function, causing problems with removing
+ # and inserting (Cassandra doesn't know which is later), we supply our own
+ def timestamp(self):
+ self.timestamp_n += 1
+ return self.timestamp_n
+
+ def clear(self):
+ for key, columns in self.cf.get_range(include_timestamp=True):
+ for value, timestamp in columns.itervalues():
+ self.timestamp_n = max(self.timestamp_n, timestamp)
+ self.cf.remove(key)
+
+ def test_empty(self):
+ key = 'TestColumnFamily.test_empty'
+ assert_raises(NotFoundException, self.cf.get, key)
+ assert len(self.cf.multiget([key])) == 0
+ for key, columns in self.cf.get_range():
+ assert len(columns) == 0
+
+ def test_insert_get(self):
+ key = 'TestColumnFamily.test_insert_get'
+ columns = {'1': 'val1', '2': 'val2'}
+ assert_raises(NotFoundException, self.cf.get, key)
+ self.cf.insert(key, columns)
+ assert self.cf.get(key) == columns
+
+ def test_insert_multiget(self):
+ key1 = 'TestColumnFamily.test_insert_multiget1'
+ columns1 = {'1': 'val1', '2': 'val2'}
+ key2 = 'test_insert_multiget1'
+ columns2 = {'3': 'val1', '4': 'val2'}
+ missing_key = 'key3'
+
+ self.cf.insert(key1, columns1)
+ self.cf.insert(key2, columns2)
+ rows = self.cf.multiget([key1, key2, missing_key])
+ assert len(rows) == 2
+ assert rows[key1] == columns1
+ assert rows[key2] == columns2
+ assert missing_key not in rows
+
+ def test_insert_get_count(self):
+ key = 'TestColumnFamily.test_insert_get_count'
+ columns = {'1': 'val1', '2': 'val2'}
+ self.cf.insert(key, columns)
+ assert self.cf.get_count(key) == 2
+
+ assert_equal(self.cf.get_count(key, column_start='1'), 2)
+ assert_equal(self.cf.get_count(key, column_finish='2'), 2)
+ assert_equal(self.cf.get_count(key, column_start='1', column_finish='2'), 2)
+ assert_equal(self.cf.get_count(key, column_start='1', column_finish='1'), 1)
+ assert_equal(self.cf.get_count(key, columns=['1','2']), 2)
+ assert_equal(self.cf.get_count(key, columns=['1']), 1)
+
+ def test_insert_multiget_count(self):
+ keys = ['TestColumnFamily.test_insert_multiget_count1',
+ 'TestColumnFamily.test_insert_multiget_count2',
+ 'TestColumnFamily.test_insert_multiget_count3']
+ columns = {'1': 'val1', '2': 'val2'}
+ for key in keys:
+ self.cf.insert(key, columns)
+ result = self.cf.multiget_count(keys)
+ assert_equal(result[keys[0]], 2)
+ assert_equal(result[keys[1]], 2)
+ assert_equal(result[keys[2]], 2)
+
+ result = self.cf.multiget_count(keys, column_start='1')
+ assert_equal(len(result), 3)
+ assert_equal(result[keys[0]], 2)
+
+ result = self.cf.multiget_count(keys, column_finish='2')
+ assert_equal(len(result), 3)
+ assert_equal(result[keys[0]], 2)
+
+ result = self.cf.multiget_count(keys, column_start='1', column_finish='2')
+ assert_equal(len(result), 3)
+ assert_equal(result[keys[0]], 2)
+
+ result = self.cf.multiget_count(keys, column_start='1', column_finish='1')
+ assert_equal(len(result), 3)
+ assert_equal(result[keys[0]], 1)
+
+ result = self.cf.multiget_count(keys, columns=['1','2'])
+ assert_equal(len(result), 3)
+ assert_equal(result[keys[0]], 2)
+
+ result = self.cf.multiget_count(keys, columns=['1'])
+ assert_equal(len(result), 3)
+ assert_equal(result[keys[0]], 1)
+
+ def test_insert_get_range(self):
+ keys = ['TestColumnFamily.test_insert_get_range%s' % i for i in xrange(5)]
+ columns = {'1': 'val1', '2': 'val2'}
+ for key in keys:
+ self.cf.insert(key, columns)
+
+ rows = list(self.cf.get_range(start=keys[0], finish=keys[-1]))
+ assert len(rows) == len(keys)
+ for i, (k, c) in enumerate(rows):
+ assert k == keys[i]
+ assert c == columns
+
+ def test_get_range_batching(self):
+ self.cf.truncate()
+
+ keys = []
+ columns = {'c': 'v'}
+ for i in range(100, 201):
+ keys.append('key%d' % i)
+ self.cf.insert('key%d' % i, columns)
+
+ for i in range(201, 301):
+ self.cf.insert('key%d' % i, columns)
+
+ count = 0
+ for (k,v) in self.cf.get_range(row_count=100, buffer_size=10):
+ assert_true(k in keys, 'key "%s" should be in keys' % k)
+ count += 1
+ assert_equal(count, 100)
+
+ count = 0
+ for (k,v) in self.cf.get_range(row_count=100, buffer_size=1000):
+ assert_true(k in keys, 'key "%s" should be in keys' % k)
+ count += 1
+ assert_equal(count, 100)
+
+ count = 0
+ for (k,v) in self.cf.get_range(row_count=100, buffer_size=150):
+ assert_true(k in keys, 'key "%s" should be in keys' % k)
+ count += 1
+ assert_equal(count, 100)
+
+ count = 0
+ for (k,v) in self.cf.get_range(row_count=100, buffer_size=7):
+ assert_true(k in keys, 'key "%s" should be in keys' % k)
+ count += 1
+ assert_equal(count, 100)
+
+ count = 0
+ for (k,v) in self.cf.get_range(row_count=100, buffer_size=2):
+ assert_true(k in keys, 'key "%s" should be in keys' % k)
+ count += 1
+ assert_equal(count, 100)
+
+ # Put the remaining keys in our list
+ for i in range(201, 301):
+ keys.append('key%d' % i)
+
+ count = 0
+ for (k,v) in self.cf.get_range(row_count=10000, buffer_size=2):
+ assert_true(k in keys, 'key "%s" should be in keys' % k)
+ count += 1
+ assert_equal(count, 201)
+
+ count = 0
+ for (k,v) in self.cf.get_range(row_count=10000, buffer_size=7):
+ assert_true(k in keys, 'key "%s" should be in keys' % k)
+ count += 1
+ assert_equal(count, 201)
+
+ count = 0
+ for (k,v) in self.cf.get_range(row_count=10000, buffer_size=200):
+ assert_true(k in keys, 'key "%s" should be in keys' % k)
+ count += 1
+ assert_equal(count, 201)
+
+ count = 0
+ for (k,v) in self.cf.get_range(row_count=10000, buffer_size=10000):
+ assert_true(k in keys, 'key "%s" should be in keys' % k)
+ count += 1
+ assert_equal(count, 201)
+
+ # Don't give a row count
+ count = 0
+ for (k,v) in self.cf.get_range(buffer_size=2):
+ assert_true(k in keys, 'key "%s" should be in keys' % k)
+ count += 1
+ assert_equal(count, 201)
+
+ count = 0
+ for (k,v) in self.cf.get_range(buffer_size=77):
+ assert_true(k in keys, 'key "%s" should be in keys' % k)
+ count += 1
+ assert_equal(count, 201)
+
+ count = 0
+ for (k,v) in self.cf.get_range(buffer_size=200):
+ assert_true(k in keys, 'key "%s" should be in keys' % k)
+ count += 1
+ assert_equal(count, 201)
+
+ count = 0
+ for (k,v) in self.cf.get_range(buffer_size=10000):
+ assert_true(k in keys, 'key "%s" should be in keys' % k)
+ count += 1
+ assert_equal(count, 201)
+
+ self.cf.truncate()
+
+ def insert_insert_get_indexed_slices(self):
+ indexed_cf = ColumnFamily(self.client, 'Indexed1')
+
+ columns = {'birthdate': 1L}
+
+ keys = []
+ for i in range(1,4):
+ indexed_cf.insert('key%d' % i, columns)
+ keys.append('key%d')
+
+ expr = index.create_index_expression(column_name='birthdate', value=1L)
+ clause = index.create_index_clause([expr])
+
+ count = 0
+ for key,cols in indexed_cf.get_indexed_slices(clause):
+ assert cols == columns
+ assert key in keys
+ count += 1
+ assert_equal(count, 3)
+
+ def test_get_indexed_slices_batching(self):
+ indexed_cf = PooledColumnFamily(self.pool, 'Indexed1')
+
+ columns = {'birthdate': 1L}
+
+ for i in range(200):
+ indexed_cf.insert('key%d' % i, columns)
+
+ expr = index.create_index_expression(column_name='birthdate', value=1L)
+ clause = index.create_index_clause([expr], count=10)
+
+ result = list(indexed_cf.get_indexed_slices(clause, buffer_size=2))
+ assert_equal(len(result), 10)
+ result = list(indexed_cf.get_indexed_slices(clause, buffer_size=10))
+ assert_equal(len(result), 10)
+ result = list(indexed_cf.get_indexed_slices(clause, buffer_size=77))
+ assert_equal(len(result), 10)
+ result = list(indexed_cf.get_indexed_slices(clause, buffer_size=200))
+ assert_equal(len(result), 10)
+ result = list(indexed_cf.get_indexed_slices(clause, buffer_size=1000))
+ assert_equal(len(result), 10)
+
+ clause = index.create_index_clause([expr], count=250)
+
+ result = list(indexed_cf.get_indexed_slices(clause, buffer_size=2))
+ assert_equal(len(result), 200)
+ result = list(indexed_cf.get_indexed_slices(clause, buffer_size=10))
+ assert_equal(len(result), 200)
+ result = list(indexed_cf.get_indexed_slices(clause, buffer_size=77))
+ assert_equal(len(result), 200)
+ result = list(indexed_cf.get_indexed_slices(clause, buffer_size=200))
+ assert_equal(len(result), 200)
+ result = list(indexed_cf.get_indexed_slices(clause, buffer_size=1000))
+ assert_equal(len(result), 200)
+
+ def test_remove(self):
+ key = 'TestColumnFamily.test_remove'
+ columns = {'1': 'val1', '2': 'val2'}
+ self.cf.insert(key, columns)
+
+ self.cf.remove(key, columns=['2'])
+ del columns['2']
+ assert self.cf.get(key) == {'1': 'val1'}
+
+ self.cf.remove(key)
+ assert_raises(NotFoundException, self.cf.get, key)
+
+ def test_dict_class(self):
+ key = 'TestColumnFamily.test_dict_class'
+ self.cf.insert(key, {'1': 'val1'})
+ assert isinstance(self.cf.get(key), TestDict)
+
+class TestSuperColumnFamily:
+ def setUp(self):
+ credentials = {'username': 'jsmith', 'password': 'havebadpass'}
+ self.pool = QueuePool(pool_size=5, keyspace='Keyspace1', credentials=credentials)
+ self.cf = PooledColumnFamily(self.pool, 'Super2',
+ write_consistency_level=ConsistencyLevel.ONE,
+ buffer_size=2, timestamp=self.timestamp,
+ super=True)
+
+ try:
+ self.timestamp_n = int(self.cf.get('meta')['meta']['timestamp'])
+ except NotFoundException:
+ self.timestamp_n = 0
+ self.clear()
+
+ def tearDown(self):
+ self.cf.insert('meta', {'meta': {'timestamp': str(self.timestamp_n)}})
+
+ # Since the timestamp passed to Cassandra will be in the same second
+ # with the default timestamp function, causing problems with removing
+ # and inserting (Cassandra doesn't know which is later), we supply our own
+ def timestamp(self):
+ self.timestamp_n += 1
+ return self.timestamp_n
+
+ def clear(self):
+ for key, columns in self.cf.get_range(include_timestamp=True):
+ for subcolumns in columns.itervalues():
+ for value, timestamp in subcolumns.itervalues():
+ self.timestamp_n = max(self.timestamp_n, timestamp)
+ self.cf.remove(key)
+
+ def test_super(self):
+ key = 'TestSuperColumnFamily.test_super'
+ columns = {'1': {'sub1': 'val1', 'sub2': 'val2'}, '2': {'sub3': 'val3', 'sub4': 'val4'}}
+ assert_raises(NotFoundException, self.cf.get, key)
+ self.cf.insert(key, columns)
+ assert self.cf.get(key) == columns
+ assert self.cf.multiget([key]) == {key: columns}
+ assert list(self.cf.get_range(start=key, finish=key)) == [(key, columns)]
+
+ def test_super_column_argument(self):
+ key = 'TestSuperColumnFamily.test_super_columns_argument'
+ sub12 = {'sub1': 'val1', 'sub2': 'val2'}
+ sub34 = {'sub3': 'val3', 'sub4': 'val4'}
+ columns = {'1': sub12, '2': sub34}
+ self.cf.insert(key, columns)
+ assert self.cf.get(key, super_column='1') == sub12
+ assert_raises(NotFoundException, self.cf.get, key, super_column='3')
+ assert self.cf.multiget([key], super_column='1') == {key: sub12}
+ assert list(self.cf.get_range(start=key, finish=key, super_column='1')) == [(key, sub12)]

0 comments on commit 6b8d679

Please sign in to comment.