Skip to content

Commit

Permalink
Minor bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
danwos committed Feb 2, 2017
1 parent f95de0c commit 39f698c
Showing 1 changed file with 23 additions and 9 deletions.
32 changes: 23 additions & 9 deletions groundwork_database/patterns/gw_sql_pattern.py
Expand Up @@ -90,11 +90,18 @@ def register(self, database, database_url, description, plugin=None):
raise DatabaseExistException("Database %s already registered by %s" % (
database, self._databases[database].plugin.name))

new_database = Database(database, database_url, description, plugin)
if plugin is None:
new_database = Database(database, database_url, description, app=self.app)
else:
new_database = Database(database, database_url, description, plugin=plugin)

self._databases[database] = new_database
self.log.debug("Database registered: %s" % database)

self.app.signals.send("db_registered", database=new_database, plugin=plugin)
if plugin is not None:
plugin.signals.send("db_registered", database=new_database)
else:
self.app.signals.send("db_registered", plugin=self.app, database=new_database)
return new_database

def unregister(self, database):
Expand Down Expand Up @@ -142,11 +149,12 @@ def get(self, name=None, plugin=None):


class Database:
def __init__(self, name, url, description, plugin):
def __init__(self, name, url, description, plugin=None, app=None):
self.name = name
self.database_url = url
self.description = description
self.plugin = plugin
self.app = app

self.engine = create_engine(url)

Expand All @@ -161,7 +169,7 @@ def __init__(self, name, url, description, plugin):
# Fore more visit: http://stackoverflow.com/a/28025843
self.Base.query = self.session.query_property()

self.classes = DatabaseClass(self, self.plugin)
self.classes = DatabaseClass(self, self.plugin, self.app)

def create_all(self):
return self.Base.metadata.create_all(self.engine)
Expand All @@ -186,10 +194,11 @@ def close(self, *args, **kwargs):


class DatabaseClass:
def __init__(self, database, plugin):
def __init__(self, database, plugin=None, app=None):
self.database = database
self._Base = database.Base
self.plugin = plugin
self.app = app
self._classes = {}

def register(self, clazz, name=None):
Expand All @@ -211,18 +220,23 @@ def register(self, clazz, name=None):
# self._classes[name] = TempClass

self._classes[name] = clazz
self.plugin.signals.send("db_class_registered", database=self.database, db_class=clazz)
if self.plugin is not None:
self.plugin.signals.send("db_class_registered", database=self.database, db_class=clazz)
else:
self.app.signals.send("db_class_registered", database=self.database, db_class=clazz, plugin=self.app)

return self._classes[name]

def unregister(self, name):
return self._classes.pop(name, None)

def get(self, clazz_name=None):
if clazz_name is not None and clazz_name in self._classes.keys():
if clazz_name is None:
return self._classes
elif clazz_name in self._classes.keys():
return self._classes[clazz_name]
else:
return None

return None


class DatabaseExistException(BaseException):
Expand Down

0 comments on commit 39f698c

Please sign in to comment.