Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

woo code

  • Loading branch information...
commit eb3112935be309e8a1251b101664c496329bdb3e 1 parent 27b86fb
@ssadler authored
Showing with 293 additions and 0 deletions.
  1. +13 −0 LICENSE
  2. +11 −0 setup.py
  3. +166 −0 squirrel.py
  4. +103 −0 test_squirrel.py
View
13 LICENSE
@@ -0,0 +1,13 @@
+Copyright 2012 Scott Sadler
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
View
11 setup.py
@@ -0,0 +1,11 @@
+#!/usr/bin/env python
+
+from setuptools import setup, find_packages
+
+setup(
+ name='squirrel',
+ description='Psycopg2 wrapper for tornadp',
+ version='0.1',
+ author='scott sadler',
+ py_modules=['squirrel'],
+)
View
166 squirrel.py
@@ -0,0 +1,166 @@
+import logging as _logging
+import weakref
+import psycopg2
+import psycopg2.extensions
+from collections import deque
+from functools import partial
+from tornado import ioloop, stack_context, gen
+
+
+logger = _logging.getLogger('squirrel.pool')
+
+
+class ConnectionPool(object):
+ """
+ Manages connections and provisions cursors.
+ """
+ def __init__(self, io_loop, max_connections=10, **connect_kwargs):
+ self.io_loop = io_loop
+ self.max_connections = max_connections
+ connect_kwargs['async'] = 1
+ self.connect_kwargs = connect_kwargs
+ self.connections = deque()
+ self.queue = deque()
+ self.owed = 0
+ self._refs = set()
+ self._closed = False
+
+ def cursor(self, callback):
+ """ Get a cursor """
+ dispatch = partial(self._dispatch, callback)
+ dispatch = stack_context.wrap(dispatch)
+ self.queue.append(dispatch)
+ self._process_queue()
+
+ def execute(self, sql, args, callback):
+ """ Shortcut to execute a query """
+ self.cursor(lambda cursor: cursor.execute(sql, args, callback))
+
+ def close(self):
+ """
+ Closes the pool. The effect of this is not that the pool can no
+ longer be used, but that connections will be closed if not in use.
+ """
+ self._closed = True
+ while self.connections:
+ self.connections.pop().close()
+
+ def _process_queue(self):
+ with stack_context.NullContext():
+ while self.queue and self.owed < self.max_connections:
+ self.owed += 1
+ self.queue.popleft()()
+
+ def _checkin(self, connection, err=None):
+ """ Called automatically on dereference of CursorFairy """
+ self.owed -= 1
+ if err:
+ logger.error("Error: %s, closing connection" % err)
+ connection.close()
+ elif self._closed:
+ connection.close()
+ elif not connection.closed:
+ self.connections.append(connection)
+
+ @gen.engine
+ def _dispatch(self, callback):
+ try:
+ connection = self.connections.popleft()
+ except IndexError:
+ connection = psycopg2.connect(**self.connect_kwargs)
+ try:
+ yield gen.Task(poll, connection, self.io_loop)
+ except Exception as e:
+ self._checkin(connection, e)
+ raise
+ else:
+ fairy = self._make_fairy(connection.cursor())
+ callback(fairy)
+
+ def _make_fairy(self, cursor):
+ # We must be careful here not to make a reference to the fairy,
+ # or it will never be dereferenced. But, we must make a reference to
+ # the cursor, or it will be dereferenced with the fairy.
+ on_deref = partial(self._on_fairy_deref, cursor)
+ fairy = CursorFairy(self.io_loop, cursor)
+ ref = weakref.ref(fairy, on_deref)
+ self._refs.add(ref)
+ return fairy
+
+ def _on_fairy_deref(self, cursor, ref):
+ with stack_context.NullContext():
+ self._refs.remove(ref)
+ self._checkin(cursor.connection)
+ self.io_loop.add_callback(self._process_queue)
+
+
+class CursorFairy(object):
+ CONNECTION_WARN = False
+
+ def __init__(self, io_loop, cursor):
+ self._io_loop = io_loop
+ self._cursor = cursor
+
+ def __getattr__(self, name):
+ """ Proxy missing attribute lookups to the cursor """
+ return getattr(self._cursor, name)
+
+ @property
+ def connection(self):
+ if not self.CONNECTION_WARN:
+ self.CONNECTION_WARN = True
+ logger.warning("Using the connection directly may cause "
+ "inconsistent state of the poller!")
+ return self._cursor.connection
+
+ def execute(self, sql, args, callback):
+ self._cursor.execute(sql, args)
+ self.poll(callback)
+
+ def poll(self, callback):
+ # bind self as first argument of callback.
+ # this makes the cursor available to the
+ # callee and ensures we aren't dereferenced until
+ # the query has finished executing.
+ callback = partial(callback, self)
+ poll(self._cursor.connection, self._io_loop, callback)
+
+
+class poll(object):
+ """
+ A poller that polls the PostgreSQL connection and calls the callback
+ when the connection state is `POLL_OK`, or an error occurs.
+ """
+ def __init__(self, connection, io_loop, callback):
+ self.connection = connection
+ self.io_loop = io_loop
+ self.callback = callback
+ self.tick(connection.fileno(), 0)
+
+ def tick(self, fd, events):
+ mask = -1
+ try:
+ mask = STATE_MAP.get(self.connection.poll())
+ if mask > 0:
+ if events == 0:
+ self.io_loop.add_handler(fd, self.tick, mask)
+ elif events > 0:
+ self.io_loop.update_handler(fd, mask)
+ elif mask < 0:
+ raise psycopg2.OperationalError("Connection has unknown error state")
+ except:
+ self.callback = None
+ raise
+ finally:
+ if mask <= 0:
+ if events:
+ self.io_loop.remove_handler(fd)
+ if mask == 0:
+ self.callback()
+
+STATE_MAP = {
+ psycopg2.extensions.POLL_OK: 0,
+ psycopg2.extensions.POLL_READ: ioloop.IOLoop.ERROR | ioloop.IOLoop.READ,
+ psycopg2.extensions.POLL_WRITE: ioloop.IOLoop.ERROR | ioloop.IOLoop.WRITE,
+ psycopg2.extensions.POLL_ERROR: -1,
+}
View
103 test_squirrel.py
@@ -0,0 +1,103 @@
+import psycopg2
+import itertools
+from mock import patch
+from squirrel import ConnectionPool
+from tornado.testing import AsyncTestCase
+from tornado import gen
+
+
+class ConnectionPoolTestCase(AsyncTestCase):
+ def async(self, func, *args, **kwargs):
+ args = args + (self.stop,)
+ func(*args, **kwargs)
+ return self.wait()
+
+ def setUp(self):
+ super(ConnectionPoolTestCase, self).setUp()
+ dsn = "host=127.0.0.1 dbname=test port=5432"
+ self.provider = ConnectionPool(dsn=dsn,
+ io_loop=self.io_loop)
+
+ def test_connect_error_propagates_exception(self):
+ provider = ConnectionPool(dsn="host=127.0.0.1 dbname=test port=6432",
+ io_loop=self.io_loop)
+ provider.cursor(self.stop)
+ self.assertRaises(psycopg2.OperationalError, self.wait)
+
+ def test_query(self):
+ cursor = self.async(self.provider.cursor)
+ self.async(cursor.execute, 'select 1', ())
+ self.assertEqual((1,), cursor.fetchone())
+
+ def test_query_error_propagates_error(self):
+ cursor = self.async(self.provider.cursor)
+ cursor.execute("I AM BAD", (), self.stop)
+ self.assertRaises(psycopg2.ProgrammingError, self.wait)
+
+ def test_query_error_dereferences(self):
+ return
+ try:
+ self.async(self.provider.execute, "I AM BAD", ())
+ except:
+ import sys
+ sys.exc_clear()
+ self.assertEqual(1, len(self.provider.connections))
+
+ def test_query_error_propagates_error_2(self):
+ self.provider.execute("I AM BAD", (), self.stop)
+ self.assertRaises(psycopg2.ProgrammingError, self.wait)
+
+ @patch.object(ConnectionPool, '_checkin')
+ def test_connection_checkin_on_deref(self, checkin):
+ cursor = self.async(self.provider.cursor)
+ connection = cursor.connection
+ cursor = None
+ checkin.assert_called_once_with(connection)
+
+ @patch.object(ConnectionPool, '_checkin')
+ def test_shorthand_execute_doesnt_deref_fairy(self, checkin):
+ self.provider.cursor(self.stop)
+ self.wait().execute("select 1", (), self.stop)
+ self.assertEqual(0, checkin.call_count)
+ self.wait() # cursor returned here but we dont reference it
+ self.assertEqual(1, checkin.call_count)
+
+ def test_100_queries(self):
+ n = 100
+ c = itertools.count().next
+
+ @gen.engine
+ def query(i):
+ if i % 2:
+ cursor = yield gen.Task(self.provider.execute, 'select %s', (i,))
+ self.assertEqual((i,), cursor.fetchone())
+ else:
+ try:
+ cursor = yield gen.Task(self.provider.cursor)
+ yield gen.Task(cursor.execute, 'select _%s' % i, ())
+ except psycopg2.ProgrammingError as e:
+ self.assertIn('select _%s' % i, e.pgerror)
+
+ if c() == n - 1:
+ self.stop()
+
+ for i in range(n):
+ query(i)
+
+ self.wait()
+
+ def test_close_pool_eventually_closes_everything(self):
+ cursor1 = self.async(self.provider.cursor)
+ cursor2 = self.async(self.provider.cursor)
+ conn2 = cursor2.connection
+ cursor2 = None
+ self.assertEqual(1, len(self.provider.connections))
+ self.assertEqual(1, self.provider.owed)
+ self.provider.close()
+ self.assertTrue(conn2.closed)
+ self.assertEqual(0, len(self.provider.connections))
+ conn1 = cursor1.connection
+ cursor1 = None
+ self.assertEqual(0, self.provider.owed)
+ self.assertTrue(conn1.closed)
+
Please sign in to comment.
Something went wrong with that request. Please try again.