Skip to content

Commit

Permalink
Merge pull request #10 from gmr/master
Browse files Browse the repository at this point in the history
Attempt to fix the reconnection race condition
  • Loading branch information
nvllsvm committed Apr 15, 2021
2 parents d074075 + 661a598 commit d9b2ed5
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 15 deletions.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.7.0
1.8.0
2 changes: 1 addition & 1 deletion bootstrap
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ report_done

cat > build/test-environment<<EOF
export ASYNC_TEST_TIMEOUT=5
export POSTGRES_URL=postgresql://postgres@${TEST_HOST}:$(get_exposed_port postgres 5432)/postgres
export POSTGRES_URL=postgresql://postgres@${TEST_HOST}:$(get_exposed_port postgres 5432)/postgres?application_name=sprockets_postgres
EOF

printf "\nBootstrap complete\n\nDon't forget to \"source build/test-environment\"\n"
64 changes: 51 additions & 13 deletions sprockets_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,18 +361,27 @@ async def postgres_connector(self,
as cursor:
yield PostgresConnector(
cursor, on_error, on_duration, timeout)
except (asyncio.TimeoutError, psycopg2.OperationalError) as err:
except (asyncio.TimeoutError,
psycopg2.OperationalError,
RuntimeError) as err:
if isinstance(err, psycopg2.OperationalError) and _attempt == 1:
LOGGER.critical('Disconnected from Postgres: %s', err)
if not self._postgres_reconnect.locked():
async with self._postgres_reconnect:
LOGGER.info('Reconnecting to Postgres with new Pool')
if await self._postgres_connect():
await self._postgres_connected.wait()
async with self.postgres_connector(
on_error, on_duration, timeout,
_attempt + 1) as connector:
yield connector
return
try:
await asyncio.wait_for(
self._postgres_connected.wait(),
self._postgres_settings['timeout'])
except asyncio.TimeoutError as error:
err = error
else:
async with self.postgres_connector(
on_error, on_duration, timeout,
_attempt + 1) as connector:
yield connector
return
if on_error is None:
raise ConnectionException(str(err))
exc = on_error(
Expand All @@ -382,6 +391,12 @@ async def postgres_connector(self,
else: # postgres_status.on_error does not return an exception
yield None

@property
def postgres_is_connected(self) -> bool:
"""Returns `True` if Postgres is currently connected"""
return self._postgres_connected is not None \
and self._postgres_connected.is_set()

async def postgres_status(self) -> dict:
"""Invoke from the ``/status`` RequestHandler to check that there is
a Postgres connection handler available and return info about the
Expand All @@ -408,8 +423,7 @@ async def postgres_status(self) -> dict:
}
"""
if not self._postgres_connected or \
not self._postgres_connected.is_set():
if not self.postgres_is_connected:
return {
'available': False,
'pool_size': 0,
Expand Down Expand Up @@ -497,12 +511,12 @@ async def _postgres_connect(self) -> bool:
else:
url = self._postgres_settings['url']

if self._postgres_pool:
self._postgres_pool.close()

safe_url = self._obscure_url_password(url)
LOGGER.debug('Connecting to %s', safe_url)

if self._postgres_pool and not self._postgres_pool.closed:
self._postgres_pool.close()

try:
self._postgres_pool = await pool.Pool.from_pool_fill(
url,
Expand All @@ -513,7 +527,7 @@ async def _postgres_connect(self) -> bool:
enable_json=self._postgres_settings['enable_json'],
enable_uuid=self._postgres_settings['enable_uuid'],
echo=False,
on_connect=None,
on_connect=self._on_postgres_connect,
pool_recycle=self._postgres_settings['connection_ttl'])
except psycopg2.Error as error: # pragma: nocover
LOGGER.warning(
Expand All @@ -535,6 +549,9 @@ def _obscure_url_password(url):
url = parse.urlunparse(parsed._replace(netloc=netloc))
return url

async def _on_postgres_connect(self, conn):
LOGGER.debug('New postgres connection %s', conn)

async def _postgres_on_start(self,
_app: web.Application,
loop: ioloop.IOLoop):
Expand Down Expand Up @@ -640,6 +657,7 @@ async def postgres_callproc(self,
:rtype: :class:`~sprockets_postgres.QueryResult`
"""
self._postgres_connection_check()
async with self.application.postgres_connector(
self.on_postgres_error,
self.on_postgres_timing,
Expand Down Expand Up @@ -679,6 +697,7 @@ async def postgres_execute(self,
:rtype: :class:`~sprockets_postgres.QueryResult`
"""
self._postgres_connection_check()
async with self.application.postgres_connector(
self.on_postgres_error,
self.on_postgres_timing,
Expand Down Expand Up @@ -726,6 +745,7 @@ async def post(self):
likely be more specific.
"""
self._postgres_connection_check()
async with self.application.postgres_connector(
self.on_postgres_error,
self.on_postgres_timing,
Expand Down Expand Up @@ -771,6 +791,11 @@ def on_postgres_error(self,
raise problemdetails.Problem(
status_code=409, title='Unique Violation')
raise web.HTTPError(409, reason='Unique Violation')
elif isinstance(exc, psycopg2.OperationalError):
if problemdetails:
raise problemdetails.Problem(
status_code=503, title='Database Error')
raise web.HTTPError(503, reason='Database Error')
elif isinstance(exc, psycopg2.Error):
if problemdetails:
raise problemdetails.Problem(
Expand Down Expand Up @@ -801,6 +826,19 @@ def on_postgres_timing(self,
LOGGER.debug('Postgres query %s duration: %s',
metric_name, duration)

def _postgres_connection_check(self):
"""Ensures Postgres is connected, exiting the request in error if not
:raises: problemdetails.Problem
:raises: web.HTTPError
"""
if not self.application.postgres_is_connected:
if problemdetails:
raise problemdetails.Problem(
status_code=503, title='Database Connection Error')
raise web.HTTPError(503, reason='Database Connection Error')


class StatusRequestHandler(web.RequestHandler):
"""A RequestHandler that can be used to expose API health or status"""
Expand Down
30 changes: 30 additions & 0 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,38 @@ async def test_postgres_status_before_first_connection(self):
'pool_free': 0})


class ReconnectionTestCast(TestCase):

@ttesting.gen_test
async def test_postgres_reconnect(self):
response = await self.http_client.fetch(self.get_url('/callproc'))
self.assertEqual(response.code, 200)
self.assertIsInstance(
uuid.UUID(json.loads(response.body)['value']), uuid.UUID)

# Force close all open connections for tests
conn = await aiopg.connect(os.environ['POSTGRES_URL'].split('?')[0])
cursor = await conn.cursor()
await cursor.execute(
'SELECT pg_terminate_backend(pid)'
' FROM pg_stat_activity'
" WHERE application_name = 'sprockets_postgres'")
await cursor.fetchall()
await asyncio.sleep(1)
response = await self.http_client.fetch(
self.get_url('/callproc'), raise_error=False)
self.assertEqual(response.code, 200)
conn.close()


class RequestHandlerMixinTestCase(TestCase):

def test_postgres_connected(self):
response = self.fetch('/status')
data = json.loads(response.body)
self.assertEqual(data['status'], 'ok')
self.assertTrue(self.app.postgres_is_connected)

def test_postgres_status(self):
response = self.fetch('/status')
data = json.loads(response.body)
Expand Down

0 comments on commit d9b2ed5

Please sign in to comment.