Skip to content

Commit

Permalink
code clean
Browse files Browse the repository at this point in the history
  • Loading branch information
saxix committed Dec 28, 2013
1 parent b7de13a commit d5e6d86
Show file tree
Hide file tree
Showing 8 changed files with 14 additions and 85 deletions.
6 changes: 6 additions & 0 deletions concurrency/db/backends/common.py
@@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-

class TriggerMixin(object):
def drop_triggers(self):
for trigger_name in self.list_triggers():
self.drop_trigger(trigger_name)
7 changes: 3 additions & 4 deletions concurrency/db/backends/mysql/base.py
@@ -1,8 +1,9 @@
from django.db.backends.mysql.base import DatabaseWrapper as MySQLDatabaseWrapper
from concurrency.db.backends.common import TriggerMixin
from concurrency.db.backends.mysql.creation import MySQLCreation


class DatabaseWrapper(MySQLDatabaseWrapper):
class DatabaseWrapper(TriggerMixin, MySQLDatabaseWrapper):
def __init__(self, *args, **kwargs):
super(DatabaseWrapper, self).__init__(*args, **kwargs)
self.creation = MySQLCreation(self)
Expand All @@ -20,6 +21,4 @@ def drop_trigger(self, trigger_name):
result = cursor.execute("DROP TRIGGER IF EXISTS %s;" % trigger_name)
return result

def drop_triggers(self):
for trigger_name in self.list_triggers():
self.drop_trigger(trigger_name)

3 changes: 0 additions & 3 deletions concurrency/db/backends/mysql/creation.py
Expand Up @@ -16,9 +16,6 @@ class MySQLCreation(DatabaseCreation):
FOR EACH ROW SET NEW.{field.column} = OLD.{field.column}+1;
"""

def __init__(self, connection):
super(MySQLCreation, self).__init__(connection)
self.trigger_fields = []

def _create_trigger(self, field):
import MySQLdb as Database
Expand Down
10 changes: 2 additions & 8 deletions concurrency/db/backends/postgresql_psycopg2/base.py
@@ -1,19 +1,17 @@
import logging
import re
from django.db.backends.postgresql_psycopg2.base import DatabaseWrapper as PgDatabaseWrapper
from concurrency.db.backends.common import TriggerMixin
from concurrency.db.backends.postgresql_psycopg2.creation import PgCreation

logger = logging.getLogger(__name__)


class DatabaseWrapper(PgDatabaseWrapper):
class DatabaseWrapper(TriggerMixin, PgDatabaseWrapper):
def __init__(self, *args, **kwargs):
super(DatabaseWrapper, self).__init__(*args, **kwargs)
self.creation = PgCreation(self)

def _clone(self):
return self.__class__(self.settings_dict, self.alias)

def list_triggers(self):
cursor = self.cursor()
stm = "select * from pg_trigger where tgname LIKE 'concurrency_%%'; "
Expand All @@ -30,7 +28,3 @@ def drop_trigger(self, trigger_name):
logger.debug(stm)
result = cursor.execute(stm)
return result

def drop_triggers(self):
for trigger_name in self.list_triggers():
self.drop_trigger(trigger_name)
8 changes: 1 addition & 7 deletions concurrency/db/backends/postgresql_psycopg2/creation.py
@@ -1,4 +1,3 @@

from django.db.backends.postgresql_psycopg2.creation import DatabaseCreation
from concurrency.db.backends.utils import get_trigger_name

Expand Down Expand Up @@ -36,17 +35,12 @@ class PgCreation(DatabaseCreation):
EXECUTE PROCEDURE {trigger_name}_si();
"""

def __init__(self, connection):
super(PgCreation, self).__init__(connection)
self.trigger_fields = []

def _create_trigger(self, field):
from django.db.utils import DatabaseError

opts = field.model._meta
trigger_name = get_trigger_name(field, opts)

# cursor = self.connection._clone().cursor()

stm = self.sql.format(trigger_name=trigger_name,
opts=opts,
field=field)
Expand Down
10 changes: 2 additions & 8 deletions concurrency/db/backends/sqlite3/base.py
@@ -1,16 +1,14 @@
# from django.db.backends.sqlite3.base import *
from django.db.backends.sqlite3.base import DatabaseWrapper as Sqlite3DatabaseWrapper
from concurrency.db.backends.common import TriggerMixin
from concurrency.db.backends.sqlite3.creation import Sqlite3Creation


class DatabaseWrapper(Sqlite3DatabaseWrapper):
class DatabaseWrapper(TriggerMixin, Sqlite3DatabaseWrapper):
def __init__(self, *args, **kwargs):
super(DatabaseWrapper, self).__init__(*args, **kwargs)
self.creation = Sqlite3Creation(self)

def _clone(self):
return self.__class__(self.settings_dict, self.alias)

def list_triggers(self):
cursor = self.cursor()
result = cursor.execute("select name from sqlite_master where type = 'trigger';")
Expand All @@ -20,7 +18,3 @@ def drop_trigger(self, trigger_name):
cursor = self.cursor()
result = cursor.execute("DROP TRIGGER IF EXISTS %s;" % trigger_name)
return result

def drop_triggers(self):
for trigger_name in self.list_triggers():
self.drop_trigger(trigger_name)
4 changes: 0 additions & 4 deletions concurrency/db/backends/sqlite3/creation.py
Expand Up @@ -19,10 +19,6 @@ class Sqlite3Creation(DatabaseCreation):
END; ##
"""

def __init__(self, connection):
super(Sqlite3Creation, self).__init__(connection)
self.trigger_fields = []

def _create_trigger(self, field):
from django.db.utils import DatabaseError
cursor = self.connection.cursor()
Expand Down
51 changes: 0 additions & 51 deletions concurrency/fields.py
Expand Up @@ -173,35 +173,10 @@ class IntegerVersionField(VersionField):
"""
form_class = forms.VersionField

# def get_internal_type(self):
# return "BigIntegerField"

def _get_next_version(self, model_instance):
old_value = getattr(model_instance, self.attname, 0)
return max(int(old_value) + 1, (int(time.time() * 1000000) - OFFSET))

# def pre_save(self, model_instance, add):
# if conf.PROTOCOL >= 2:
# if add:
# value = self._get_next_version(model_instance)
# self._set_version_value(model_instance, value)
# return getattr(model_instance, self.attname)
#
# value = self._get_next_version(model_instance)
# self._set_version_value(model_instance, value)
# return value

# @staticmethod
# def _wrap_save(func):
# from concurrency.api import concurrency_check
#
# def inner(self, force_insert=False, force_update=False, using=None, **kwargs):
# if self._concurrencymeta.enabled:
# concurrency_check(self, force_insert, force_update, using, **kwargs)
# return func(self, force_insert, force_update, using, **kwargs)
#
# return update_wrapper(inner, func)


class AutoIncVersionField(VersionField):
"""
Expand All @@ -210,33 +185,9 @@ class AutoIncVersionField(VersionField):
"""
form_class = forms.VersionField

# def get_internal_type(self):
# return "BigIntegerField"

def _get_next_version(self, model_instance):
return int(getattr(model_instance, self.attname, 0)) + 1

# def pre_save(self, model_instance, add):
# if conf.PROTOCOL >= 2:
# if add:
# value = self._get_next_version(model_instance)
# self._set_version_value(model_instance, value)
# return getattr(model_instance, self.attname)
# value = self._get_next_version(model_instance)
# self._set_version_value(model_instance, value)
# return value

# @staticmethod
# def _wrap_save(func):
# from concurrency.api import concurrency_check
#
# def inner(self, force_insert=False, force_update=False, using=None, **kwargs):
# if self._concurrencymeta.enabled:
# concurrency_check(self, force_insert, force_update, using, **kwargs)
# return func(self, force_insert, force_update, using, **kwargs)
#
# return update_wrapper(inner, func)


class TriggerVersionField(VersionField):
"""
Expand All @@ -257,8 +208,6 @@ def pre_save(self, model_instance, add):
# always returns the same value
return int(getattr(model_instance, self.attname, 0))

# def _set_version_value(self, model_instance, value):
# pass # noop here

@staticmethod
def _increment_version_number(obj):
Expand Down

0 comments on commit d5e6d86

Please sign in to comment.