Skip to content

Commit

Permalink
Refactored and added tests for backend choosing. Functionally still w…
Browse files Browse the repository at this point in the history
…orks the same, but mostly rewritten.
  • Loading branch information
rgalanakis committed May 31, 2014
1 parent 664bff0 commit f360492
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 17 deletions.
37 changes: 20 additions & 17 deletions goless/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,22 +92,25 @@ def propogate_exc(self, errtype, *args):

return GeventBackend()

_backends = {
"stackless": _make_stackless,
"gevent": _make_gevent

_default_backends = {
'stackless': _make_stackless,
'gevent': _make_gevent
}

current = None

GOLESS_BACKEND = os.getenv("GOLESS_BACKEND", '')
if GOLESS_BACKEND:
if GOLESS_BACKEND not in _backends:
raise RuntimeError(
"Invalid backend %r specified. Valid backends are: %s"
% (GOLESS_BACKEND, _backends.keys()))
current = _backends[GOLESS_BACKEND]()
else:
try:
current = _make_stackless()
except ImportError:
current = _make_gevent()

def calculate_backend(name_from_env, backends=None):
if backends is None:
backends = _default_backends
if name_from_env:
if name_from_env not in backends:
raise RuntimeError(
'Invalid backend %r specified. Valid backends are: %s'
% (name_from_env, _default_backends.keys()))
return backends[name_from_env]()
for maker in backends.values():
return maker()
raise RuntimeError('No backend could be created.')


current = calculate_backend(os.getenv('GOLESS_BACKEND', ''))
28 changes: 28 additions & 0 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from . import BaseTests
from goless import backends

test_backends = dict(
a=lambda: 'be_A',
b=lambda: 'be_B',
)


class CalcBackendTests(BaseTests):
def calc(self, name, testbackends=test_backends):
return backends.calculate_backend(name, testbackends)

def test_valid_envvar_name(self):
be = self.calc('a')
self.assertEqual(be, 'be_A')

def test_invalid_envvar_name(self):
with self.assertRaises(RuntimeError):
self.calc('invalid')

def test_default(self):
be = self.calc('')
self.assertIn(be, [v() for v in test_backends.values()])

def test_all_invalid(self):
with self.assertRaises(RuntimeError):
self.calc('', {})

0 comments on commit f360492

Please sign in to comment.