Skip to content

Commit

Permalink
Removed the slow _rowfactory function and replaced it with a cx_Oracl…
Browse files Browse the repository at this point in the history
…e 5 outputtypehandler function that operates only on each column, instead of each returned value.
  • Loading branch information
unknown authored and shaib committed Sep 25, 2012
1 parent e696437 commit 49b1e59
Showing 1 changed file with 39 additions and 55 deletions.
94 changes: 39 additions & 55 deletions django/db/backends/oracle/base.py
Expand Up @@ -671,7 +671,7 @@ class FormatStylePlaceholderCursor(object):
def __init__(self, connection):
self.cursor = connection.cursor()
# Necessary to retrieve decimal values without rounding error.
self.cursor.numbersAsStrings = True
self.cursor.outputtypehandler = self._outputtypehandler
# Default arraysize of 1 is highly sub-optimal.
self.cursor.arraysize = 100

Expand Down Expand Up @@ -743,20 +743,15 @@ def executemany(self, query, params=None):
six.reraise(utils.DatabaseError, utils.DatabaseError(*tuple(e.args)), sys.exc_info()[2])

def fetchone(self):
row = self.cursor.fetchone()
if row is None:
return row
return _rowfactory(row, self.cursor)
return self.cursor.fetchone()

def fetchmany(self, size=None):
if size is None:
size = self.arraysize
return tuple([_rowfactory(r, self.cursor)
for r in self.cursor.fetchmany(size)])
return tuple(self.cursor.fetchmany(size))

def fetchall(self):
return tuple([_rowfactory(r, self.cursor)
for r in self.cursor.fetchall()])
return tuple(self.cursor.fetchall())

def var(self, *args):
return VariableWrapper(self.cursor.var(*args))
Expand All @@ -771,71 +766,46 @@ def __getattr__(self, attr):
return getattr(self.cursor, attr)

def __iter__(self):
return CursorIterator(self.cursor)
return iter(self.cursor)


class CursorIterator(object):

"""Cursor iterator wrapper that invokes our custom row factory."""

def __init__(self, cursor):
self.cursor = cursor
self.iter = iter(cursor)

def __iter__(self):
return self

def __next__(self):
return _rowfactory(next(self.iter), self.cursor)

next = __next__ # Python 2 compatibility


def _rowfactory(row, cursor):
# Cast numeric values as the appropriate Python type based upon the
# cursor description, and convert strings to unicode.
casted = []
for value, desc in zip(row, cursor.description):
if value is not None and desc[1] is Database.NUMBER:
precision, scale = desc[4:6]
def _outputtypehandler(self, cursor, name, default_type, length, precision, scale):
if default_type is Database.NUMBER:
if scale == -127:
if precision == 0:
# NUMBER column: decimal-precision floating point
# This will normally be an integer from a sequence,
# but it could be a decimal value.
if '.' in value:
value = decimal.Decimal(value)
else:
value = int(value)
return cursor.var(str, 100, cursor.arraysize,
outconverter=_decimal_or_int)
else:
# FLOAT column: binary-precision floating point.
# This comes from FloatField columns.
value = float(value)
return cursor.var(default_type, arraysize=cursor.arraysize,
outconverter=float)
elif precision > 0:
# NUMBER(p,s) column: decimal-precision fixed point.
# This comes from IntField and DecimalField columns.
if scale == 0:
value = int(value)
return cursor.var(default_type, arraysize=cursor.arraysize,
outconverter=int)
else:
value = decimal.Decimal(value)
elif '.' in value:
return cursor.var(str, 100, cursor.arraysize,
outconverter=decimal.Decimal)
else:
# No type information. This normally comes from a
# mathematical expression in the SELECT list. Guess int
# or Decimal based on whether it has a decimal point.
value = decimal.Decimal(value)
else:
value = int(value)
return cursor.var(str, 100, cursor.arraysize,
outconverter=_decimal_or_int)
# datetimes are returned as TIMESTAMP, except the results
# of "dates" queries, which are returned as DATETIME.
elif desc[1] in (Database.TIMESTAMP, Database.DATETIME):
# Confirm that dt is naive before overwriting its tzinfo.
if settings.USE_TZ and value is not None and timezone.is_naive(value):
value = value.replace(tzinfo=timezone.utc)
elif desc[1] in (Database.STRING, Database.FIXED_CHAR,
Database.LONG_STRING):
value = to_unicode(value)
casted.append(value)
return tuple(casted)
elif default_type in (Database.TIMESTAMP, Database.DATETIME) and settings.USE_TZ:
return cursor.var(default_type, arraysize=cursor.arraysize,
outconverter=_add_tzinfo)
elif default_type in (Database.STRING, Database.FIXED_CHAR,
Database.LONG_STRING):
return cursor.var(default_type, length, cursor.arraysize,
outconverter=to_unicode)


def to_unicode(s):
Expand All @@ -848,6 +818,20 @@ def to_unicode(s):
return s


def _decimal_or_int(value):
if '.' in value:
return decimal.Decimal(value)
else:
return int(value)


def _add_tzinfo(value):
# Confirm that dt is naive before overwriting its tzinfo.
if value is not None and timezone.is_naive(value):
value = value.replace(tzinfo=timezone.utc)
return value


def _get_sequence_reset_sql():
# TODO: colorize this SQL code with style.SQL_KEYWORD(), etc.
return """
Expand Down

0 comments on commit 49b1e59

Please sign in to comment.