Permalink
Browse files

- support for running with 0.8 version of Cassandra (0.7 still supported too)

- support for compound primary key generation
- support for query set update operations
- support for cascading deletes
- cleaned up code for reconnecting to Cassandra when connection is disrupted
  • Loading branch information...
1 parent a0202fa commit 9d46d30e918d2e462b0a44fe3c5432d5673e49a1 @vaterlaus committed Oct 20, 2011
Showing with 202 additions and 99 deletions.
  1. +39 −15 django_cassandra/db/base.py
  2. +111 −81 django_cassandra/db/compiler.py
  3. +1 −1 django_cassandra/db/creation.py
  4. +29 −0 django_cassandra/db/utils.py
  5. +1 −1 settings.py
  6. +21 −1 tests/tests.py
@@ -23,13 +23,18 @@
from thrift.protocol import TBinaryProtocol
from cassandra import Cassandra
from cassandra.ttypes import *
+import re
import time
from .creation import DatabaseCreation
from .introspection import DatabaseIntrospection
class DatabaseFeatures(NonrelDatabaseFeatures):
    string_based_auto_field = True
-
+
+    def __init__(self, connection):
+        """Initialise the feature flags for this connection.
+
+        Reads the optional CASSANDRA_ENABLE_CASCADING_DELETES entry from
+        connection.settings_dict; when the entry is absent the
+        supports_deleting_related_objects flag is set to False.
+        """
+        super(DatabaseFeatures, self).__init__(connection)
+        cascading_deletes = connection.settings_dict.get(
+            'CASSANDRA_ENABLE_CASCADING_DELETES', False)
+        self.supports_deleting_related_objects = cascading_deletes
+
class DatabaseOperations(NonrelDatabaseOperations):
compiler_module = __name__.rsplit('.', 1)[0] + '.compiler'
@@ -38,7 +43,7 @@ def pk_default_value(self):
Use None as the value to indicate to the insert compiler that it needs
to auto-generate a guid to use for the id. The case where this gets hit
is when you create a model instance with no arguments. We override from
- the default implementation (which returns 'DEFAULT') beacuse it's possible
+ the default implementation (which returns 'DEFAULT') because it's possible
that someone would explicitly initialize the id field to be that value and
we wouldn't want to override that. But None would never be a valid value
for the id.
@@ -185,28 +190,47 @@ def get_db_connection(self, set_keyspace=False, login=False):
if not self._db_connection.is_connected():
self._db_connection.open(False, False)
-
- #version = client.describe_version()
+
+ try:
+ version_string = self._db_connection.get_client().describe_version()
# FIXME: Should do some version check here to make sure that we're
# talking to a cassandra daemon that supports the operations we require
-
+ m = re.match('^([0-9]+)\.([0-9]+)\.([0-9]+)$', version_string)
+ major_version = int(m.group(1))
+ minor_version = int(m.group(2))
+ patch_version = int(m.group(3))
+ except Exception, e:
+ raise DatabaseError('Invalid Thrift version string', e)
+
+ # Determine supported features based on the API version
+ self.supports_replication_factor_as_strategy_option = major_version >= 19 and minor_version >= 10
+
if login:
self._db_connection.login()
if set_keyspace:
try:
self._db_connection.set_keyspace()
except Exception, e:
- replication_factor = self.settings_dict.get('CASSANDRA_REPLICATION_FACTOR')
- if not replication_factor:
- replication_factor = 1
- replication_strategy_class = self.settings_dict.get('CASSANDRA_REPLICATION_STRATEGY')
- if not replication_strategy_class:
- replication_strategy_class = 'org.apache.cassandra.locator.SimpleStrategy'
- keyspace_def = KsDef(name=self._db_connection.keyspace,
- strategy_class=replication_strategy_class,
- replication_factor=replication_factor,
- cf_defs=[])
+ replication_factor = self.settings_dict.get('CASSANDRA_REPLICATION_FACTOR', 1)
+ strategy_class = self.settings_dict.get('CASSANDRA_REPLICATION_STRATEGY', 'org.apache.cassandra.locator.SimpleStrategy')
+ strategy_options = self.settings_dict.get('CASSANDRA_REPLICATION_STRATEGY_OPTIONS', {})
+ if type(strategy_options) != dict:
+ raise DatabaseError('CASSANDRA_REPLICATION_STRATEGY_OPTIONS must be a dictionary')
+
+ keyspace_def_args = {
+ 'name': self._db_connection.keyspace,
+ 'strategy_class': strategy_class,
+ 'strategy_options': strategy_options,
+ 'cf_defs': []}
+
+ if self.supports_replication_factor_as_strategy_option:
+ if 'replication_factor' not in strategy_options:
+ strategy_options['replication_factor'] = str(replication_factor)
+ else:
+ keyspace_def_args['replication_factor'] = replication_factor
+
+ keyspace_def = KsDef(**keyspace_def_args)
self._db_connection.get_client().system_add_keyspace(keyspace_def)
self._db_connection.set_keyspace()
@@ -20,7 +20,8 @@
from django.db.models import ForeignKey
from django.db.models.sql.where import AND, OR, WhereNode
-from django.db.utils import DatabaseError, IntegrityError
+from django.db.models.sql.constants import MULTI
+from django.db.utils import DatabaseError
from functools import wraps
@@ -97,43 +98,37 @@ def _get_rows_by_pk(self, range_predicate):
slice_predicate = SlicePredicate(slice_range=SliceRange(start='',
finish='', count=self.connection.max_column_count))
- for attempt in (1,2):
- try:
- if range_predicate._is_exact():
- column_list = db_connection.get_client().get_slice(range_predicate.start,
- column_parent, slice_predicate, self.connection.read_consistency_level)
- if column_list:
- row = self._convert_column_list_to_row(column_list, self.pk_column, range_predicate.start)
- rows = [row]
- else:
- rows = []
- else:
- if range_predicate.start != None:
- key_start = range_predicate.start
- if not range_predicate.start_inclusive:
- key_start = key_start + chr(1)
- else:
- key_start = ''
-
- if range_predicate.end != None:
- key_end = range_predicate.end
- if not range_predicate.end_inclusive:
- key_end = key_end[:-1] + chr(ord(key_end[-1])-1) + (chr(126) * 16)
- else:
- key_end = ''
-
- key_range = KeyRange(start_key=key_start, end_key=key_end,
- count=self.connection.max_key_count)
- key_slice = db_connection.get_client().get_range_slices(column_parent,
- slice_predicate, key_range, self.connection.read_consistency_level)
-
- rows = self._convert_key_slice_to_rows(key_slice)
- break
- except TTransportException, e:
- # Only retry once, so if it's the second time through, propagate the exception
- if attempt == 2:
- raise e
- db_connection.reopen()
+ if range_predicate._is_exact():
+ column_list = call_cassandra_with_reconnect(db_connection,
+ db_connection.get_client().get_slice, range_predicate.start,
+ column_parent, slice_predicate, self.connection.read_consistency_level)
+ if column_list:
+ row = self._convert_column_list_to_row(column_list, self.pk_column, range_predicate.start)
+ rows = [row]
+ else:
+ rows = []
+ else:
+ if range_predicate.start != None:
+ key_start = range_predicate.start
+ if not range_predicate.start_inclusive:
+ key_start = key_start + chr(1)
+ else:
+ key_start = ''
+
+ if range_predicate.end != None:
+ key_end = range_predicate.end
+ if not range_predicate.end_inclusive:
+ key_end = key_end[:-1] + chr(ord(key_end[-1])-1) + (chr(126) * 16)
+ else:
+ key_end = ''
+
+ key_range = KeyRange(start_key=key_start, end_key=key_end,
+ count=self.connection.max_key_count)
+ key_slice = call_cassandra_with_reconnect(db_connection,
+ db_connection.get_client().get_range_slices, column_parent,
+ slice_predicate, key_range, self.connection.read_consistency_level)
+
+ rows = self._convert_key_slice_to_rows(key_slice)
return rows
@@ -170,17 +165,12 @@ def _get_rows_by_indexed_column(self, range_predicate):
index_clause = IndexClause(index_expressions, '', self.connection.max_key_count)
slice_predicate = SlicePredicate(slice_range=SliceRange(start='', finish='', count=self.connection.max_column_count))
- for attempt in (1,2):
- try:
- key_slice = db_connection.get_client().get_indexed_slices(column_parent, index_clause, slice_predicate, self.connection.read_consistency_level)
- rows = self._convert_key_slice_to_rows(key_slice)
- break
- except TTransportException, e:
- # Only retry once, so if it's the second time through, propagate the exception
- if attempt == 2:
- raise e
- db_connection.reopen()
-
+ key_slice = call_cassandra_with_reconnect(db_connection,
+ db_connection.get_client().get_indexed_slices,
+ column_parent, index_clause, slice_predicate,
+ self.connection.read_consistency_level)
+ rows = self._convert_key_slice_to_rows(key_slice)
+
return rows
def get_row_range(self, range_predicate):
@@ -200,17 +190,12 @@ def get_all_rows(self):
key_range = KeyRange(start_token = '0', end_token = '0', count=self.connection.max_key_count)
#end_key = u'\U0010ffff'.encode('utf-8')
#key_range = KeyRange(start_key='\x01', end_key=end_key, count=self.connection.max_key_count)
- for attempt in (1,2):
- try:
- key_slice = db_connection.get_client().get_range_slices(column_parent, slice_predicate, key_range, self.connection.read_consistency_level)
- rows = self._convert_key_slice_to_rows(key_slice)
- break
- except Exception, e:
- # Only retry once, so if it's the second time through, propagate the exception
- if attempt == 2:
- raise e
- db_connection.reopen()
-
+
+ key_slice = call_cassandra_with_reconnect(db_connection,
+ db_connection.get_client().get_range_slices, column_parent,
+ slice_predicate, key_range, self.connection.read_consistency_level)
+ rows = self._convert_key_slice_to_rows(key_slice)
+
return rows
def _get_query_results(self):
@@ -237,7 +222,7 @@ def fetch(self, low_mark, high_mark):
except Exception, e:
# FIXME: Can get rid of this exception handling code eventually,
# but it's useful for debugging for now.
- traceback.print_exc()
+ #traceback.print_exc()
raise e
for entity in results:
@@ -260,15 +245,10 @@ def delete(self):
for item in results:
mutation_map[item[self.pk_column]] = {column_family: [Mutation(deletion=Deletion(timestamp=timestamp))]}
db_connection = self.connection.db_connection
- for attempt in (1,2):
- try:
- db_connection.get_client().batch_mutate(mutation_map, self.connection.write_consistency_level)
- break
- except TTransportException, e:
- # Only retry once, so if it's the second time through, propagate the exception
- if attempt == 2:
- raise e
- db_connection.reopen()
+ call_cassandra_with_reconnect(db_connection,
+ db_connection.get_client().batch_mutate, mutation_map,
+ self.connection.write_consistency_level)
+
@safe_call
def order_by(self, ordering):
@@ -499,20 +479,70 @@ def insert(self, data, return_id=False):
db_connection = self.connection.db_connection
column_family = self.query.get_meta().db_table
- for attempt in (1,2):
- try:
- db_connection.get_client().batch_mutate({key: {column_family: mutation_list}}, self.connection.write_consistency_level)
- break
- except TTransportException, e:
- # Only retry once, so if it's the second time through, propagate the exception
- if attempt == 2:
- raise e
- db_connection.reopen()
+ call_cassandra_with_reconnect(db_connection,
+ db_connection.get_client().batch_mutate, {key: {column_family: mutation_list}},
+ self.connection.write_consistency_level)
+
if return_id:
return key
class SQLUpdateCompiler(NonrelUpdateCompiler, SQLCompiler):
-    pass
-
+    """Compiler that applies a query set update as one Cassandra batch mutation."""
+
+    def __init__(self, *args, **kwargs):
+        super(SQLUpdateCompiler, self).__init__(*args, **kwargs)
+
+    def execute_sql(self, result_type=MULTI):
+        """Write the query's new column values to every row the query matches.
+
+        The rows are found by iterating self.results_iter(); for each one a
+        list of column mutations is queued under the row's primary key and
+        the whole map is sent with a single batch_mutate call.
+
+        Returns the number of rows iterated (and therefore updated).
+        Raises DatabaseError if a non-nullable field is being set to None or
+        if the primary key column is not among the fields being fetched.
+        """
+        # Column name -> converted value to write into each matching row.
+        data = {}
+        for field, model, value in self.query.values:
+            assert field is not None
+            if not field.null and value is None:
+                raise DatabaseError("You can't set %s (a non-nullable "
+                                    "field) to None!" % field.name)
+            db_type = field.db_type(connection=self.connection)
+            value = self.convert_value_for_db(db_type, value)
+            column_name = field.column
+            # Normalise unicode column names to UTF-8 byte strings once, up
+            # front. Calling .decode() on a unicode object (as before) leaves
+            # it unicode and fails outright on non-ASCII names.
+            # NOTE(review): assumes the Thrift Column name must be a str --
+            # confirm against the Cassandra Thrift bindings in use.
+            if isinstance(column_name, unicode):
+                column_name = column_name.encode('utf-8')
+            data[column_name] = value
+
+        # TODO: Add compound key check here -- ensure that we're not updating
+        # any of the fields that are components in the compound key.
+
+        # TODO: This isn't super efficient because executing the query will
+        # fetch all of the columns for each row even though all we really need
+        # is the key for the row. Should be pretty straightforward to change
+        # the CassandraQuery class to support custom slice predicates.
+
+        # Locate the position of the primary key within each result row.
+        pk_column = self.query.get_meta().pk.column
+        pk_index = -1
+        for index, fetched_field in enumerate(self.get_fields()):
+            if fetched_field.column == pk_column:
+                pk_index = index
+                break
+        if pk_index == -1:
+            raise DatabaseError('Invalid primary key column')
+
+        row_count = 0
+        column_family = self.query.get_meta().db_table
+        # A single timestamp is shared by every column written in this update.
+        timestamp = get_next_timestamp()
+        batch_mutate_data = {}
+        for result in self.results_iter():
+            row_count += 1
+            mutation_list = []
+            for name, value in data.items():
+                column = Column(name=name, value=value, timestamp=timestamp)
+                mutation_list.append(Mutation(
+                    column_or_supercolumn=ColumnOrSuperColumn(column=column)))
+            batch_mutate_data[result[pk_index]] = {column_family: mutation_list}
+
+        db_connection = self.connection.db_connection
+        call_cassandra_with_reconnect(db_connection,
+            db_connection.get_client().batch_mutate,
+            batch_mutate_data, self.connection.write_consistency_level)
+
+        return row_count
+
class SQLDeleteCompiler(NonrelDeleteCompiler, SQLCompiler):
pass
@@ -67,7 +67,7 @@ def sql_create_model(self, model, style, known_models=set()):
if field.db_index:
column_name = str(field.db_column if field.db_column else field.column)
column_def = ColumnDef(name=column_name, validation_class='BytesType',
- index_type=IndexType.KEYS, index_name=column_name)
+ index_type=IndexType.KEYS)
column_metadata.append(column_def)
cfdef_settings = self.connection.column_family_def_defaults.copy()
@@ -13,6 +13,8 @@
# limitations under the License.
import time
+from thrift.transport import TTransport
+from django.db.utils import DatabaseError
def _cmp_to_key(comparison_function):
"""
@@ -167,3 +169,30 @@ def convert_string_to_list(s):
def convert_list_to_string(l):
return unicode(l)
+
+
+class CassandraConnectionError(DatabaseError):
+    """DatabaseError carrying a fixed 'error connecting' message."""
+
+    def __init__(self):
+        message = 'Error connecting to Cassandra database'
+        super(CassandraConnectionError, self).__init__(message)
+
+
+class CassandraAccessError(DatabaseError):
+    """DatabaseError carrying a fixed 'error accessing' message."""
+
+    def __init__(self):
+        # super() must name this class; naming an undefined class here would
+        # raise NameError on construction and hide the real failure.
+        super(CassandraAccessError, self).__init__('Error accessing Cassandra database')
+
+
+def call_cassandra_with_reconnect(connection, fn, *args, **kwargs):
+    """Call fn(*args, **kwargs), reopening the connection once on transport failure.
+
+    If the first call raises TTransportException, connection.reopen() is
+    invoked and the call is attempted a second time. Returns whatever fn
+    returns.
+
+    Raises CassandraConnectionError if a TTransportException escapes the
+    reopen or the second attempt, and CassandraAccessError for any other
+    exception raised along the way.
+    """
+    # NOTE(review): fn is whatever callable the caller passed in; if it is
+    # bound to a client object that reopen() replaces, the retry still goes
+    # through the old object -- confirm against reopen()'s implementation.
+    try:
+        try:
+            return fn(*args, **kwargs)
+        except TTransport.TTransportException:
+            connection.reopen()
+            return fn(*args, **kwargs)
+    except TTransport.TTransportException:
+        raise CassandraConnectionError()
+    except Exception:
+        raise CassandraAccessError()
+
+
View
@@ -19,7 +19,7 @@
'PORT': '9160', # Set to empty string for default. Not used with sqlite3.
'SUPPORTS_TRANSACTIONS': False,
'CASSANDRA_REPLICATION_FACTOR': 1,
- 'CASSANDRA_REPLICATION_STRATEGY': 'org.apache.cassandra.locator.SimpleStrategy'
+ 'CASSANDRA_ENABLE_CASCADING_DELETES': True
}
}
Oops, something went wrong.

0 comments on commit 9d46d30

Please sign in to comment.