Browse files

_get_host_handlers() now returns all host matches.

This approach has more clearly defined precedence rules than the
previous insertion-time strategy implemented in add_handlers().

It also correctly leaves pattern matching in the hands of the regular
expression evaluator as opposed to directly comparing pattern strings.
  • Loading branch information...
1 parent a5fffe3 commit 364321653935686a71097a3995acd7cf20c74d7a @jparise jparise committed Nov 18, 2012
Showing with 52 additions and 27 deletions.
  1. +35 −0 tornado/test/web_test.py
  2. +17 −27 tornado/web.py
View
35 tornado/test/web_test.py
@@ -847,6 +847,41 @@ def test_static_url(self):
wsgi_safe.append(CustomStaticFileTest)
+class HostMatchingTest(WebTestCase):
+ class Handler(RequestHandler):
+ def initialize(self, reply):
+ self.reply = reply
+
+ def get(self):
+ self.write(self.reply)
+
+ def get_handlers(self):
+ return [("/foo", HostMatchingTest.Handler, {"reply": "wildcard"})]
+
+ def test_host_matching(self):
+ self.app.add_handlers("www.example.com",
+ [("/foo", HostMatchingTest.Handler, {"reply": "[0]"})])
+ self.app.add_handlers(r"www\.example\.com",
+ [("/bar", HostMatchingTest.Handler, {"reply": "[1]"})])
+ self.app.add_handlers("www.example.com",
+ [("/baz", HostMatchingTest.Handler, {"reply": "[2]"})])
+
+ response = self.fetch("/foo")
+ self.assertEqual(response.body, b("wildcard"))
+ response = self.fetch("/bar")
+ self.assertEqual(response.code, 404)
+ response = self.fetch("/baz")
+ self.assertEqual(response.code, 404)
+
+ response = self.fetch("/foo", headers={'Host': 'www.example.com'})
+ self.assertEqual(response.body, b("[0]"))
+ response = self.fetch("/bar", headers={'Host': 'www.example.com'})
+ self.assertEqual(response.body, b("[1]"))
+ response = self.fetch("/baz", headers={'Host': 'www.example.com'})
+ self.assertEqual(response.body, b("[2]"))
+wsgi_safe.append(HostMatchingTest)
+
+
class NamedURLSpecGroupsTest(WebTestCase):
def get_handlers(self):
class EchoHandler(RequestHandler):
View
44 tornado/web.py
@@ -1314,32 +1314,21 @@ def listen(self, port, address="", **kwargs):
def add_handlers(self, host_pattern, host_handlers):
"""Appends the given handlers to our handler list.
- Note that host patterns are processed sequentially in the
- order they were added, and only the first matching pattern is
- used.
+ Host patterns are processed sequentially in the order they were
+ added. All matching patterns will be considered.
"""
if not host_pattern.endswith("$"):
host_pattern += "$"
-
- # Search for an existing handlers entry for this host pattern.
- handlers = None
- for entry in self.handlers:
- if entry[0].pattern == host_pattern:
- handlers = entry[1]
- break
-
- # Otherwise, add a new handlers entry for this host pattern.
- if handlers is None:
- handlers = []
- # The handlers with the wildcard host_pattern are a special
- # case - they're added in the constructor but should have lower
- # precedence than the more-precise handlers added later.
- # If a wildcard handler group exists, it should always be last
- # in the list, so insert new groups just before it.
- if self.handlers and self.handlers[-1][0].pattern == '.*$':
- self.handlers.insert(-1, (re.compile(host_pattern), handlers))
- else:
- self.handlers.append((re.compile(host_pattern), handlers))
+ handlers = []
+ # The handlers with the wildcard host_pattern are a special
+ # case - they're added in the constructor but should have lower
+ # precedence than the more-precise handlers added later.
+ # If a wildcard handler group exists, it should always be last
+ # in the list, so insert new groups just before it.
+ if self.handlers and self.handlers[-1][0].pattern == '.*$':
+ self.handlers.insert(-1, (re.compile(host_pattern), handlers))
+ else:
+ self.handlers.append((re.compile(host_pattern), handlers))
for spec in host_handlers:
if type(spec) is type(()):
@@ -1371,15 +1360,16 @@ def add_transform(self, transform_class):
def _get_host_handlers(self, request):
host = request.host.lower().split(':')[0]
+ matches = []
for pattern, handlers in self.handlers:
if pattern.match(host):
- return handlers
+ matches.extend(handlers)
# Look for default host if not behind load balancer (for debugging)
- if "X-Real-Ip" not in request.headers:
+ if not matches and "X-Real-Ip" not in request.headers:
for pattern, handlers in self.handlers:
if pattern.match(self.default_host):
- return handlers
- return None
+ matches.extend(handlers)
+ return matches or None
def _load_ui_methods(self, methods):
if type(methods) is types.ModuleType:

0 comments on commit 3643216

Please sign in to comment.