Skip to content

Commit

Permalink
Fixed issue with MySQLdb and connection pooling
Browse files Browse the repository at this point in the history
  • Loading branch information
nullism committed Apr 9, 2013
1 parent 81a71b9 commit b3cbe26
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/pyormish/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
"""
from model import Model
import session
__version__ = "0.9.2"
__version__ = "0.9.3"
6 changes: 3 additions & 3 deletions src/pyormish/examples/example-users-messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ def _delete(self):

if __name__ == "__main__":

Model.session = session.SQLite(':memory:')
Model.db_config = dict(DB_TYPE='sqlite', DB_PATH=':memory:')

# Create the users table for this example
Model.session.execute('''
Model().connection.execute('''
CREATE TABLE users (
id INTEGER PRIMARY KEY ASC,
username VARCHAR(255),
Expand All @@ -63,7 +63,7 @@ def _delete(self):
''')

# Create the messages table for this example
Model.session.execute('''
Model().connection.execute('''
CREATE TABLE messages (
id INTEGER PRIMARY KEY ASC,
to_user_id INTEGER,
Expand Down
7 changes: 4 additions & 3 deletions src/pyormish/examples/example-users.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/python
from pyormish import Model, session
from pyormish import Model
import hashlib


Expand All @@ -24,17 +24,18 @@ def _set_password(self, value):

if __name__ == "__main__":

Model.session = session.SQLite(':memory:')
Model.db_config = dict(DB_PATH=':memory:',DB_TYPE='sqlite')

# Create the users table for this example
Model.session.execute('''
Model().connection.execute('''
CREATE TABLE users (
id INTEGER PRIMARY KEY ASC,
username VARCHAR(255) UNIQUE,
fullname VARCHAR(255),
password VARCHAR(256)
)
''')


# Let's create some users
user_list = [
Expand Down
60 changes: 45 additions & 15 deletions src/pyormish/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
limitations under the License.
"""
import logging
import session
connection = None


class Model(object):
Expand All @@ -32,20 +34,49 @@ class Model(object):
_JOINS = None
_ORDER_FIELDS = None

session = None
db = None
db_config = None
connection = None
d = None

def __init__(self, _id=None):
"""Build SQL queries for object and
if _id is present, attempt to load an
object where _PRIMARY_KEY = _id
"""
if not self.session:
global connection # shared connections

# You may be wondering what's up with recreating
# a mysql connection EACH time a Model is
# instantiated. The reason deals with concurrency.
# In very rare situations (1 out of about 10K)
# a connection will dissappear. This is the best
# solution I could think of. If you have a better
# suggestion, PLEASE let me know.
db_type = self.db_config.get('DB_TYPE','mysql')

if db_type == 'mysql':
self.connection = session.MySQL(
self.db_config['DB_HOST'],
self.db_config['DB_USER'],
self.db_config['DB_PASS'],
self.db_config['DB_NAME']
)

elif db_type == 'postgres':
if not connection:
connection = session.Postgres(self.db_config['DB_CONN_STRING'])
self.connection = connection

elif db_type == 'sqlite':
if not connection:
connection = session.SQLite(self.db_config['DB_PATH'])
self.connection = connection

if not self.connection:
raise StandardError("No database connection specified")
self.db = self.session

self.make_sql()
if(_id):
if _id:
olist = self.get_many([_id])
if not olist:
return None
Expand Down Expand Up @@ -142,9 +173,9 @@ def create(self, **kwargs):
sql = 'INSERT INTO `%s` (%s) VALUES (%s)'%(
self._TABLE_NAME, ','.join(c_fs), ','.join(v_fs))

if not self.db.execute(sql, kwargs):
if not self.connection.execute(sql, kwargs):
return None
_id = self.db._cursor.lastrowid
_id = self.connection._cursor.lastrowid
obj = self.get_many([_id])[0]
obj._create()
return obj
Expand All @@ -157,7 +188,7 @@ def commit(self):
for k in self.__dict__.keys():
if getattr(self, '_set_%s'%(k), None):
self.__dict__[k] = self.__dict__['_'+k]
if not self.db.execute(sql, self.__dict__):
if not self.connection.execute(sql, self.__dict__):
sql = sql % self.__dict__
raise StandardError("Unable to commit ```%s```"%(sql))
self._commit()
Expand All @@ -168,7 +199,7 @@ def delete(self):
if not self._DELETE_SQL:
raise StandardError("_DELETE_SQL is not defined")
for sql in self._DELETE_SQL:
self.db.execute(sql, self.__dict__)
self.connection.execute(sql, self.__dict__)
self._delete()
del(self)

Expand All @@ -191,7 +222,7 @@ def get_by_fields(self, **kwargs):
wheres.append("`%s`=%%(%s)s"%(k, k))

sql = self._GET_ID_SQL + " %s"%(" AND ".join(wheres))
rows = self.db.select(sql, kwargs)
rows = self.connection.select(sql, kwargs)
if not rows:
return None
key, _id = rows[0].popitem()
Expand All @@ -206,7 +237,7 @@ def get_by_where(self, where, **kwargs):
if "WHERE" not in self._GET_ID_SQL.upper()+where.upper():
where = "WHERE " + where
sql = self._GET_ID_SQL + " " + where
rows = self.db.select(sql, kwargs)
rows = self.connection.select(sql, kwargs)
if not rows:
return None
key, _id = rows[0].popitem()
Expand All @@ -218,7 +249,6 @@ def get_many(self, ids, order_fields=None):
"""
if not self._GET_MANY_SQL:
raise StandardError("_GET_MANY_SQL is not defined")

ids = [str(int(i)) for i in ids]
sql = self._GET_MANY_SQL % ','.join(ids)
if order_fields:
Expand All @@ -230,7 +260,7 @@ def get_many(self, ids, order_fields=None):
o_fs.append('`%s`.`%s` %s'%(self._TABLE_NAME, o[0], o[1]))
sql = sql + ' ORDER BY %s'%(','.join(o_fs))

dl = self.db.select(sql)
dl = self.connection.select(sql)
if not dl:
return []
return self._build_objects(dl)
Expand All @@ -239,7 +269,7 @@ def get_many_by_query(self, sql, **kwargs):
"""Return multiple objects from the database
based on match from query sql (str).
"""
dl = self.db.select(sql, kwargs)
dl = self.connection.select(sql, kwargs)
return self._build_objects(dl)

def get_many_by_fields(self, **kwargs):
Expand Down Expand Up @@ -269,7 +299,7 @@ def get_many_by_where(self, where, **kwargs):
sql = self._GET_ID_SQL + " " + where
sql = sql + " LIMIT %s,%s"%(int(kwargs.get('_start',0)),
int(kwargs.get('_limit',self._GET_LIMIT)))
rows = self.db.select(sql, kwargs)
rows = self.connection.select(sql, kwargs)
ids = [r.popitem()[1] for r in rows]
if not ids:
return []
Expand Down

0 comments on commit b3cbe26

Please sign in to comment.