Skip to content

Commit

Permalink
- sanitize variable naming
Browse files Browse the repository at this point in the history
  • Loading branch information
dataflake committed Jun 11, 2018
1 parent 02f88ce commit f53c6c4
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 117 deletions.
82 changes: 43 additions & 39 deletions Products/ZMySQLDA/DA.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,58 +112,65 @@ def _pool_key(self):
"""
return self.getPhysicalPath()

def _getConnection(self):
""" Helper method to retrieve an existing or create a new connection
"""
try:
return self._v_database_connection
except AttributeError:
self.connect(self.connection_string)
return self._v_database_connection

security.declareProtected(use_database_methods, 'connect')

def connect(self, s):
def connect(self, conn_string):
""" Base API. Opens connection to mysql. Raises if problems.
:string: s -- The database connection string
:string: conn_string -- The database connection string
"""
pool_key = self._pool_key()
connection = database_connection_pool.get(pool_key)
conn = database_connection_pool.get(pool_key)

if connection is not None and connection.connection == s:
self._v_database_connection = connection
self._v_connected = connection.connected_timestamp
if conn is not None and conn.connection == conn_string:
self._v_database_connection = conn
self._v_connected = conn.connected_timestamp
else:
if connection is not None:
connection.closeConnection()
DB = self.factory()
DB = DBPool(DB, create_db=self.auto_create_db,
use_unicode=self.use_unicode)
if conn is not None:
conn.closeConnection()

conn_pool = DBPool(self.factory(), create_db=self.auto_create_db,
use_unicode=self.use_unicode)
database_connection_pool_lock.acquire()
try:
database_connection_pool[pool_key] = connection = DB(s)
conn = conn_pool(conn_string)
database_connection_pool[pool_key] = conn
finally:
database_connection_pool_lock.release()
self._v_database_connection = connection

self._v_database_connection = conn
# XXX If date is used as such, it can be wrong because an
# existing connection may be reused. But this is suposedly
# only used as a marker to know if connection was successfull.
self._v_connected = connection.connected_timestamp
self._v_connected = conn.connected_timestamp

return self # ??? why doesn't this return the connection ???

security.declareProtected(use_database_methods, 'sql_quote__')

def sql_quote__(self, v, escapes={}):
def sql_quote__(self, sql_str, escapes={}):
""" Base API. Used to massage SQL strings for use in queries.
:string: v -- The raw SQ string to transform.
:string: sql_str -- The raw SQL string to transform.
:dict: escapes -- Additional escape transformations.
Default: empty ``dict``.
"""
try:
connection = self._v_database_connection
except AttributeError:
self.connect(self.connection_string)
connection = self._v_database_connection
connection = self._getConnection()

if self.use_unicode and isinstance(v, six.text_type):
return connection.unicode_literal(v)
if self.use_unicode and isinstance(sql_str, six.text_type):
return connection.unicode_literal(sql_str)
else:
return connection.string_literal(v)
return connection.string_literal(sql_str)

security.declareProtected(change_database_methods, 'manage_edit')

Expand Down Expand Up @@ -201,24 +208,21 @@ def tpValues(self):
Used in the Zope ZMI ``Browse`` tab
"""
r = []
try:
c = self._v_database_connection
except AttributeError:
self.connect(self.connection_string)
c = self._v_database_connection
for d in c.tables(rdb=0):
t_list = []
connection = self._getConnection()

for t_info in connection.tables(rdb=0):
try:
name = d["table_name"]
b = TableBrowser()
b.__name__ = name
b._d = d
b._c = c
b.icon = table_icons.get(d["table_type"], "text")
r.append(b)
t_browser = TableBrowser()
t_browser.__name__ = t_info["table_name"]
t_browser._d = t_info
t_browser._c = connection
t_browser.icon = table_icons.get(t_info["table_type"], "text")
t_list.append(t_browser)
except Exception:
pass
return r

return t_list


InitializeClass(Connection)
Expand Down
151 changes: 76 additions & 75 deletions Products/ZMySQLDA/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,31 +368,35 @@ def _parse_connection_string(cls, connection, use_unicode=False):
def tables(self, rdb=0, _care=("TABLE", "VIEW")):
""" Returns list of tables.
"""
r = []
a = r.append
result = self._query("SHOW TABLES")
row = result.fetch_row(1)
t_list = []
db_result = self._query("SHOW TABLES")
row = db_result.fetch_row(1)
while row:
table_name = row[0][0]
a({"table_name": table_name, "table_type": "table"})
row = result.fetch_row(1)
return r
t_list.append({"table_name": row[0][0], "table_type": "table"})
row = db_result.fetch_row(1)
return t_list

def columns(self, table_name):
""" Returns list of column descriptions for ``table_name``.
"""
c_list = []
try:
# Field, Type, Null, Key, Default, Extra
c = self._query("SHOW COLUMNS FROM %s" % table_name)
db_result = self._query("SHOW COLUMNS FROM %s" % table_name)
except Exception:
return ()
r = []
for Field, Type, Null, Key, Default, Extra in c.fetch_row(0):
info = {}
field_default = ""

for Field, Type, Null, Key, Default, Extra in db_result.fetch_row(0):
info = {'name': Field,
'extra': (Extra,),
'nullable': (Null == "YES") and 1 or 0}

if Default is not None:
info["default"] = Default
field_default = "DEFAULT '%s'" % Default
else:
field_default = ''

if "(" in Type:
end = Type.rfind(")")
short_type, size = Type[:end].split("(", 1)
Expand All @@ -404,23 +408,18 @@ def columns(self, table_name):
info["scale"] = int(size)
else:
short_type = Type

if short_type in field_icons:
info["icon"] = short_type
else:
info["icon"] = icon_xlate.get(short_type, "what")
info["name"] = Field

info["type"] = short_type
info["extra"] = (Extra,)
info["description"] = " ".join(
[
Type,
field_default,
Extra or "",
key_types.get(Key, Key or ""),
Null == "NO" and "NOT NULL" or "",
]
)
info["nullable"] = (Null == "YES") and 1 or 0
info["description"] = " ".join([Type,
field_default,
Extra or "",
key_types.get(Key, Key or ""),
Null == "NO" and "NOT NULL" or ""])
if Key:
info["index"] = True
info["key"] = Key
Expand All @@ -429,8 +428,10 @@ def columns(self, table_name):
info["unique"] = True
elif Key == "UNI":
info["unique"] = True
r.append(info)
return r

c_list.append(info)

return c_list

def variables(self):
""" Return dictionary of current mysql variable/values.
Expand All @@ -452,50 +453,56 @@ def _query(self, query, force_reconnect=False):
"""
try:
self.db.query(query)
except OperationalError as m:
if m.args[0] in query_syntax_error:
raise OperationalError(m.args[0],
"%s: %s" % (m.args[1], query))
except OperationalError as exc:
if exc.args[0] in query_syntax_error:
raise OperationalError(exc.args[0],
"%s: %s" % (exc.args[1], query))

if not force_reconnect and \
(self._mysql_lock or self._transactions) or \
m.args[0] not in hosed_connection:
exc.args[0] not in hosed_connection:
LOG.warning("query failed: %s" % (query,))
raise

# Hm. maybe the db is hosed. Let's restart it.
if m.args[0] in hosed_connection:
msg = "%s Forcing a reconnect." % hosed_connection[m.args[0]]
if exc.args[0] in hosed_connection:
msg = "%s Forcing a reconnect." % hosed_connection[exc.args[0]]
LOG.error(msg)
self._forceReconnection()
self.db.query(query)
except ProgrammingError as m:
if m.args[0] in hosed_connection:
except ProgrammingError as exc:
if exc.args[0] in hosed_connection:
self._forceReconnection()
msg = "%s Forcing a reconnect." % hosed_connection[m.args[0]]
msg = "%s Forcing a reconnect." % hosed_connection[exc.args[0]]
LOG.error(msg)
else:
LOG.warning("query failed: %s" % (query,))
raise

return self.db.store_result()

def query(self, query_string, max_rows=1000):
""" Execute ``query_string`` and return at most ``max_rows``.
def query(self, sql_string, max_rows=1000):
""" Execute ``sql_string`` and return at most ``max_rows``.
"""
self._use_TM and self._register()
desc = None
result = ()
for qs in filter(None, [q.strip() for q in query_string.split("\0")]):
rows = ()

for qs in filter(None, [q.strip() for q in sql_string.split("\0")]):
qtype = qs.split(None, 1)[0].upper()
if qtype == "SELECT" and max_rows:
qs = "%s LIMIT %d" % (qs, max_rows)
c = self._query(qs)
if desc is not None:
if c and (c.describe() != desc):
msg = "Multiple select schema are not allowed."
raise ProgrammingError(msg)
if c:
desc = c.describe()
result = c.fetch_row(max_rows)
db_results = self._query(qs)

if desc is not None and \
db_results and \
db_results.describe() != desc:
msg = "Multiple select schema are not allowed."
raise ProgrammingError(msg)

if db_results:
desc = db_results.describe()
rows = db_results.fetch_row(max_rows)
else:
desc = None

Expand All @@ -507,28 +514,24 @@ def query(self, query_string, max_rows=1000):
return (), ()

items = []
func = items.append
defs = self.defs
for d in desc:
item = {
"name": d[0],
"type": defs.get(d[1], "t"),
"width": d[2],
"null": d[6],
}
func(item)
return items, result

def string_literal(self, s):
for info in desc:
items.append({"name": info[0],
"type": self.defs.get(info[1], "t"),
"width": info[2],
"null": info[6]})

return items, rows

def string_literal(self, sql_str):
""" Called from zope to quote/escape strings for inclusion
in a query.
"""
return self.db.string_literal(s)
return self.db.string_literal(sql_str)

def unicode_literal(self, s):
def unicode_literal(self, sql_str):
""" Similar to string_literal but encodes it first.
"""
return self.db.unicode_literal(s)
return self.db.unicode_literal(sql_str)

# Zope 2-phase transaction handling methods

Expand Down Expand Up @@ -607,12 +610,10 @@ def _abort(self, *ignored):

def _mysql_version(self):
""" Return mysql server version.
Note instances of this class are not persistent.
"""
_version = getattr(self, "_version", None)
if not _version:
self._version = _version = self.variables().get("version")
return _version
if getattr(self, "_version", None) is None:
self._version = self.variables().get("version")
return self._version

def savepoint(self):
""" Basic savepoint support.
Expand All @@ -636,10 +637,10 @@ class _SavePoint(object):
""" Simple savepoint object
"""

def __init__(self, dm):
self.dm = dm
self.ident = ident = str(time.time()).replace(".", "sp")
dm._query("SAVEPOINT %s" % ident)
def __init__(self, db_conn):
self.db_conn = db_conn
self.ident = str(time.time()).replace(".", "sp")
db_conn._query("SAVEPOINT %s" % self.ident)

def rollback(self):
self.dm._query("ROLLBACK TO %s" % self.ident)
self.db_conn._query("ROLLBACK TO %s" % self.ident)
6 changes: 3 additions & 3 deletions Products/ZMySQLDA/tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,13 +344,13 @@ def _makeOne(self):

def test_initialization(self):
sp = self._makeOne()
self.assertIsInstance(sp.dm, FakeConnection)
self.assertEqual(sp.dm.last_query, 'SAVEPOINT %s' % sp.ident)
self.assertIsInstance(sp.db_conn, FakeConnection)
self.assertEqual(sp.db_conn.last_query, 'SAVEPOINT %s' % sp.ident)

def test_rollback(self):
sp = self._makeOne()
sp.rollback()
self.assertEqual(sp.dm.last_query, 'ROLLBACK TO %s' % sp.ident)
self.assertEqual(sp.db_conn.last_query, 'ROLLBACK TO %s' % sp.ident)


def test_suite():
Expand Down

0 comments on commit f53c6c4

Please sign in to comment.