Skip to content

Commit

Permalink
Merge pull request #80 from ziadsawalha/context
Browse files Browse the repository at this point in the history
feat(middleware): add context middleware
  • Loading branch information
Paul Nelson authored and Paul Nelson committed Aug 28, 2015
2 parents 5519d3b + 9dc15f4 commit e7fa493
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ Includes sample middleware for use with WSGI apps including bottle.

Middleware included:
- CORS: handles CORS requests
- Context: handles setting a threadlocal context and adds a transaction ID.


## <a name="rest"></a>REST API Tooling
Expand Down
92 changes: 92 additions & 0 deletions simpl/middleware/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) 2011-2015 Rackspace US, Inc.
# All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

"""Context WSGI Middleware.
Creates a context for the WSGI call and adds the following to the context:
- transaction_id: a UUID to identify the call (this can be passed in the
context to remote workers)
- base_url: the URL of the incoming call (overrideable)
The transaction id is returned in responses as an X-Transaction-Id header.
Example usage:
# Disable all CORS checks:
import bottle
from simpl.middleware import context
app = bottle.default_app()
chain = context.ContextMiddleware(app, override_url="https://my_app.io")
bottle.run(app=chain)
"""

import logging
import uuid

from simpl import threadlocal

LOG = logging.getLogger(__name__)


class ContextMiddleware(object): # pylint: disable=R0903

"""Adds a call context to the call environ which holds call data."""

def __init__(self, app, override_url=None):
"""Add a call context to the call environ which holds authn+z data."""
self.app = app
self.override_url = override_url

def __call__(self, environ, start_response):
"""Handle WSGI Request."""
if self.override_url:
url = self.override_url
else:
# PEP333: wsgi.url_scheme, HTTP_HOST, SERVER_NAME, and SERVER_PORT
# can be used to reconstruct a request's complete URL

# Much of the following is copied from bottle.py
http = environ.get('HTTP_X_FORWARDED_PROTO') \
or environ.get('wsgi.url_scheme', 'http')
host = environ.get('HTTP_X_FORWARDED_HOST') \
or environ.get('HTTP_HOST')
if not host:
# HTTP 1.1 requires a Host-header. This is for HTTP/1.0
# clients.
host = environ.get('SERVER_NAME', '127.0.0.1')
port = environ.get('SERVER_PORT')
if port and port != ('80' if http == 'http' else '443'):
host += ':' + port
url = "%s://%s" % (http, host)

# Use a default empty context
transaction_id = uuid.uuid4().hex
context = threadlocal.default()
context['base_url'] = url
context['transaction_id'] = transaction_id
environ['context'] = context
LOG.debug("Context created: base_url=%s, tid=%s", url, transaction_id)
return self.app(environ, self.start_response_callback(start_response,
transaction_id))

@staticmethod
def start_response_callback(start_response, transaction_id):
"""Intercept upstream start_response and adds our headers."""
def callback(status, headers, exc_info=None):
"""Add our headers to response using a closure."""
headers.append(('X-Transaction-Id', transaction_id))
# Call upstream start_response
start_response(status, headers, exc_info)
return callback
98 changes: 98 additions & 0 deletions tests/test_middleware_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# pylint: disable=C0103,R0904,R0903

# Copyright (c) 2011-2015 Rackspace US, Inc.
# All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

"""Tests for Context middleware."""

import unittest

import mock
from webtest.debugapp import debug_app

from simpl.middleware import context


class TestContextMiddleware(unittest.TestCase):

def setUp(self):
self.filter = context.ContextMiddleware(debug_app)
self.headers = []

def start_response(self, status, headers, exc_info=None):
"""Emulate WSGI start_response."""
self.headers += headers

def test_url_override(self):
env = {'REQUEST_METHOD': 'GET',
'PATH_INFO': '/'}
self.filter.override_url = "http://OVERRIDDEN"
self.filter(env, self.start_response)
self.assertEqual('http://OVERRIDDEN', env['context']['base_url'])

def test_no_url_scheme(self):
with self.assertRaises(KeyError):
self.filter({}, self.start_response)

def test_http_host(self):
env = {'REQUEST_METHOD': 'GET',
'PATH_INFO': '/',
'wsgi.url_scheme': 'http',
'HTTP_HOST': 'MOCK'}
self.filter(env, self.start_response)
self.assertEqual('http://MOCK', env['context']['base_url'])

def test_server_name(self):
env = {'REQUEST_METHOD': 'GET',
'PATH_INFO': '/',
'wsgi.url_scheme': 'http',
'SERVER_NAME': 'MOCK',
'SERVER_PORT': '80'}
self.filter(env, self.start_response)
self.assertEqual('http://MOCK', env['context']['base_url'])

def test_https_weird_port(self):
env = {'REQUEST_METHOD': 'GET',
'PATH_INFO': '/',
'wsgi.url_scheme': 'https',
'SERVER_NAME': 'MOCK',
'SERVER_PORT': '444'}
self.filter(env, self.start_response)
self.assertEqual('https://MOCK:444', env['context']['base_url'])

def test_http_weird_port(self):
env = {'REQUEST_METHOD': 'GET',
'PATH_INFO': '/',
'wsgi.url_scheme': 'http',
'SERVER_NAME': 'MOCK',
'SERVER_PORT': '81'}
self.filter(env, self.start_response)
self.assertEqual('http://MOCK:81', env['context']['base_url'])

@mock.patch.object(context.uuid, 'uuid4')
def test_transaction_id(self, mock_uuid):
mock_uuid.return_value = mock.Mock(hex="12345abc")
env = {'REQUEST_METHOD': 'GET',
'PATH_INFO': '/',
'wsgi.url_scheme': 'http',
'SERVER_NAME': 'MOCK',
'SERVER_PORT': '80'}
self.filter(env, self.start_response)
self.assertIn(('X-Transaction-Id', '12345abc'), self.headers)
self.assertIn('transaction_id', env['context'])
self.assertEqual('12345abc', env['context']['transaction_id'])


if __name__ == '__main__':
unittest.main()
1 change: 0 additions & 1 deletion tests/test_middleware_cors.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,5 @@ def test_conditional_import(self):
with self.assertRaises(RuntimeError):
app({}, self.start_response)


if __name__ == '__main__':
unittest.main()

0 comments on commit e7fa493

Please sign in to comment.