Skip to content

Commit

Permalink
Properly handles reconnections
Browse files Browse the repository at this point in the history
Fixes #8
  • Loading branch information
sdispater committed Jun 24, 2015
1 parent 079f19e commit 063924c
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 6 deletions.
2 changes: 1 addition & 1 deletion orator/connectors/connector.py
Expand Up @@ -4,7 +4,7 @@
class Connector(object):

RESERVED_KEYWORDS = [
'log_queries', 'driver', 'prefix'
'log_queries', 'driver', 'prefix', 'name'
]

def get_api(self):
Expand Down
3 changes: 2 additions & 1 deletion orator/connectors/mysql_connector.py
Expand Up @@ -22,7 +22,8 @@ class MySqlConnector(Connector):

RESERVED_KEYWORDS = [
'log_queries', 'driver', 'prefix',
'engine', 'charset', 'collation'
'engine', 'charset', 'collation',
'name'
]

def connect(self, config):
Expand Down
16 changes: 16 additions & 0 deletions orator/database_manager.py
@@ -1,10 +1,13 @@
# -*- coding: utf-8 -*-

import threading
import logging
from .connections.connection_resolver_interface import ConnectionResolverInterface
from .connectors.connection_factory import ConnectionFactory
from .exceptions import ArgumentError

logger = logging.getLogger('orator.database_manager')


class BaseDatabaseManager(ConnectionResolverInterface):

Expand Down Expand Up @@ -36,6 +39,7 @@ def connection(self, name=None):
name, type = self._parse_connection_name(name)

if name not in self._connections:
logger.debug('Initiating connection %s' % name)
connection = self._make_connection(name)

self._set_connection_for_type(connection, type)
Expand Down Expand Up @@ -80,13 +84,17 @@ def disconnect(self, name=None):
if name is None:
name = self.get_default_connection()

logger.debug('Disconnecting %s' % name)

if name in self._connections:
self._connections[name].disconnect()

def reconnect(self, name=None):
if name is None:
name = self.get_default_connection()

logger.debug('Reconnecting %s' % name)

self.disconnect(name)

if name not in self._connections:
Expand All @@ -95,14 +103,20 @@ def reconnect(self, name=None):
return self._refresh_api_connections(name)

def _refresh_api_connections(self, name):
logger.debug('Refreshing api connections for %s' % name)

fresh = self._make_connection(name)

return self._connections[name]\
.set_connection(fresh.get_connection())\
.set_read_connection(fresh.get_read_connection())

def _make_connection(self, name):
logger.debug('Making connection for %s' % name)

config = self._get_config(name)
if 'name' not in config:
config['name'] = name

if name in self._extensions:
return self._extensions[name](config, name)
Expand All @@ -115,6 +129,8 @@ def _make_connection(self, name):
return self._factory.make(config, name)

def _prepare(self, connection):
logger.debug('Preparing connection %s' % connection.get_name())

def reconnector(connection_):
self.reconnect(connection_.get_name())

Expand Down
17 changes: 15 additions & 2 deletions tests/test_database_manager.py
Expand Up @@ -2,7 +2,7 @@

from . import OratorTestCase
from . import mock
from .utils import MockConnection, MockManager, MockFactory
from .utils import MockConnection, MockManager

from orator.database_manager import DatabaseManager

Expand All @@ -23,15 +23,20 @@ def test_connection_method_create_a_new_connection_if_needed(self):

def test_manager_uses_factory_to_create_connections(self):
manager = self._get_real_manager()
original_make = manager._factory.make
manager._factory.make = mock.MagicMock()
manager.connection()

manager._factory.make.assert_called_with(
{
'name': 'sqlite',
'driver': 'sqlite',
'database': ':memory:'
}, 'sqlite'
)

manager._factory.make = original_make

def test_connection_can_select_connections(self):
manager = self._get_manager()
self.assertEqual(manager.connection(), manager.connection('sqlite'))
Expand All @@ -55,6 +60,14 @@ def test_default_database_with_one_database(self):

self.assertEqual('sqlite', manager.get_default_connection())

def test_reconnect(self):
manager = self._get_real_manager()

api_connection = manager.connection().get_connection()
manager.reconnect()
self.assertIsNot(manager.connection().get_connection(), api_connection)
self.assertIsNotNone(manager.connection().get_connection())

def _get_manager(self):
manager = MockManager({
'default': 'sqlite',
Expand All @@ -77,7 +90,7 @@ def _get_real_manager(self):
'driver': 'sqlite',
'database': ':memory:'
}
}, MockFactory().prepare_mock())
})

return manager

Expand Down
8 changes: 6 additions & 2 deletions tests/utils.py
Expand Up @@ -11,10 +11,14 @@

class MockConnection(ConnectionInterface):

def __init__(self, name=None):
if name:
self.get_name = lambda: name

def set_reconnector(self, reconnector):
return mock.MagicMock()

def prepare_mock(self):
def prepare_mock(self, name=None):
self.table = mock.MagicMock()
self.select = mock.MagicMock()
self.insert = mock.MagicMock()
Expand All @@ -38,7 +42,7 @@ class MockManager(DatabaseManager):

def prepare_mock(self):
self._make_connection = mock.MagicMock(
side_effect=lambda name: MockConnection().prepare_mock()
side_effect=lambda name: MockConnection(name).prepare_mock()
)

return self
Expand Down

0 comments on commit 063924c

Please sign in to comment.