Skip to content

Commit

Permalink
Merge 61817b4 into c360c34
Browse files Browse the repository at this point in the history
  • Loading branch information
martinhoefling committed Mar 29, 2015
2 parents c360c34 + 61817b4 commit eb4a32a
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 29 deletions.
84 changes: 81 additions & 3 deletions tornado/test/web_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from tornado.testing import AsyncHTTPTestCase, ExpectLog, gen_test
from tornado.test.util import unittest
from tornado.util import u, ObjectDict, unicode_type, timedelta_to_seconds
from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish, removeslash, addslash, RedirectHandler as WebRedirectHandler
from tornado.web import RequestHandler, authenticated, Application, asynchronous, url, HTTPError, StaticFileHandler, _create_signature_v1, create_signed_value, decode_signed_value, ErrorHandler, UIModule, MissingArgumentError, stream_request_body, Finish, removeslash, addslash, RedirectHandler as WebRedirectHandler, get_signature_key_version

import binascii
import contextlib
Expand Down Expand Up @@ -71,10 +71,14 @@ def get(self):

class CookieTestRequestHandler(RequestHandler):
# stub out enough methods to make the secure_cookie functions work
def __init__(self):
def __init__(self, cookie_secret='0123456789', key_version=None):
# don't call super.__init__
self._cookies = {}
self.application = ObjectDict(settings=dict(cookie_secret='0123456789'))
if key_version is None:
self.application = ObjectDict(settings=dict(cookie_secret=cookie_secret))
else:
self.application = ObjectDict(settings=dict(cookie_secret=cookie_secret,
key_version=key_version))

def get_cookie(self, name):
return self._cookies.get(name)
Expand Down Expand Up @@ -128,6 +132,44 @@ def test_arbitrary_bytes(self):
self.assertEqual(handler.get_secure_cookie('foo', min_version=1), b'\xe9')


# See SignedValueTest below for more.
class SecureCookieV2Test(unittest.TestCase):
KEY_VERSIONS = {
0: 'ajklasdf0ojaisdf',
1: 'aslkjasaolwkjsdf'
}
def test_round_trip(self):
handler = CookieTestRequestHandler()
handler.set_secure_cookie('foo', b'bar', version=2)
self.assertEqual(handler.get_secure_cookie('foo', min_version=2), b'bar')

def test_key_version_roundtrip(self):
handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
key_version=0)
handler.set_secure_cookie('foo', b'bar')
self.assertEqual(handler.get_secure_cookie('foo'), b'bar')

def test_key_version_increment_version(self):
handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
key_version=0)
handler.set_secure_cookie('foo', b'bar')
new_handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
key_version=1)
new_handler._cookies = handler._cookies
self.assertEqual(new_handler.get_secure_cookie('foo'), b'bar')

def test_key_version_invalidate_version(self):
handler = CookieTestRequestHandler(cookie_secret=self.KEY_VERSIONS,
key_version=1)
handler.set_secure_cookie('foo', b'bar')
new_key_versions = self.KEY_VERSIONS.copy()
new_key_versions.pop(1)
new_handler = CookieTestRequestHandler(cookie_secret=new_key_versions,
key_version=1)
new_handler._cookies = handler._cookies
self.assertEqual(new_handler.get_secure_cookie('foo'), None)


class CookieTest(WebTestCase):
def get_handlers(self):
class SetCookieHandler(RequestHandler):
Expand Down Expand Up @@ -2139,6 +2181,7 @@ def test_client_close(self):

class SignedValueTest(unittest.TestCase):
SECRET = "It's a secret to everybody"
SECRET_DICT = {0: "asdfbasdf", 1: "12312312", 2: "2342342"}

def past(self):
return self.present() - 86400 * 32
Expand Down Expand Up @@ -2245,6 +2288,41 @@ def test_non_ascii(self):
clock=self.present)
self.assertEqual(value, decoded)

def test_key_versioning_read_write_default_key(self):
value = b"\xe9"
signed = create_signed_value(SignedValueTest.SECRET_DICT,
"key", value, clock=self.present)
decoded = decode_signed_value(SignedValueTest.SECRET_DICT,
"key", signed, clock=self.present)
self.assertEqual(value, decoded)

def test_key_versioning_read_write_non_default_key(self):
value = b"\xe9"
signed = create_signed_value(SignedValueTest.SECRET_DICT,
"key", value, clock=self.present,
key_version=1)
decoded = decode_signed_value(SignedValueTest.SECRET_DICT,
"key", signed, clock=self.present)
self.assertEqual(value, decoded)

def test_key_versioning_invalid_key(self):
value = b"\xe9"
signed = create_signed_value(SignedValueTest.SECRET_DICT,
"key", value, clock=self.present)
newkeys = SignedValueTest.SECRET_DICT.copy()
newkeys.pop(0)
decoded = decode_signed_value(newkeys,
"key", signed, clock=self.present)
self.assertEqual(None, decoded)

def test_key_version_retreival(self):
value = b"\xe9"
signed = create_signed_value(SignedValueTest.SECRET_DICT,
"key", value, clock=self.present,
key_version=1)
key_version = get_signature_key_version(signed)
self.assertEqual(1, key_version)


@wsgi_safe
class XSRFTest(SimpleHandlerTestCase):
Expand Down
119 changes: 93 additions & 26 deletions tornado/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,19 @@ def get(self):
DEFAULT_SIGNED_VALUE_MIN_VERSION = 1
"""The oldest signed value accepted by `.RequestHandler.get_secure_cookie`.
May be overrided by passing a ``min_version`` keyword argument.
May be overridden by passing a ``min_version`` keyword argument.
.. versionadded:: 3.2.1
"""

DEFAULT_SIGN_KEY_VERSION = 0
"""The current key index used by `.RequestHandler.set_secure_cookie`.
May be overridden by passing a ``key_version`` keyword argument.
.. versionadded:: x.x.x
"""


class RequestHandler(object):
"""Subclass this class and define `get()` or `post()` to make a handler.
Expand Down Expand Up @@ -613,8 +621,15 @@ def create_signed_value(self, name, value, version=None):
and made it the default.
"""
self.require_setting("cookie_secret", "secure cookies")
return create_signed_value(self.application.settings["cookie_secret"],
name, value, version=version)
secret = self.application.settings["cookie_secret"]
key_version = None
if isinstance(secret, dict):
if self.application.settings.get("key_version") is None:
raise Exception("key_version setting must be used for secret_key dicts")
key_version = self.application.settings["key_version"]

return create_signed_value(secret, name, value, version=version,
key_version=key_version)

def get_secure_cookie(self, name, value=None, max_age_days=31,
min_version=None):
Expand All @@ -635,6 +650,17 @@ def get_secure_cookie(self, name, value=None, max_age_days=31,
name, value, max_age_days=max_age_days,
min_version=min_version)

def get_secure_cookie_key_version(self, name, value=None):
"""Returns the signing key version of the secure cookie.
The version is returned as int.
"""
self.require_setting("cookie_secret", "secure cookies")
if value is None:
value = self.get_cookie(name)
return get_signature_key_version(value)


def redirect(self, url, permanent=False, status=None):
"""Sends a redirect to the given (optionally relative) URL.
Expand Down Expand Up @@ -2961,11 +2987,18 @@ def _time_independent_equals(a, b):
return result == 0


def create_signed_value(secret, name, value, version=None, clock=None):
def create_signed_value(secret, name, value, version=None, clock=None,
key_version=None):
if version is None:
version = DEFAULT_SIGNED_VALUE_VERSION
if clock is None:
clock = time.time

if key_version is None:
key_version = DEFAULT_SIGN_KEY_VERSION
else:
assert version >= 2, 'Version must be at least 2 for key version support'

timestamp = utf8(str(int(clock())))
value = base64.b64encode(utf8(value))
if version == 1:
Expand All @@ -2982,20 +3015,24 @@ def create_signed_value(secret, name, value, version=None, clock=None):
#
# The fields are:
# - format version (i.e. 2; no length prefix)
# - key version (currently 0; reserved for future
# key rotation features)
# - key version (integer, default is 0)
# - timestamp (integer seconds since epoch)
# - name (not encoded; assumed to be ~alphanumeric)
# - value (base64-encoded)
# - signature (hex-encoded; no length prefix)
def format_field(s):
return utf8("%d:" % len(s)) + utf8(s)
to_sign = b"|".join([
b"2|1:0",
b"2",
format_field(str(key_version)),
format_field(timestamp),
format_field(name),
format_field(value),
b''])

if isinstance(secret, dict):
secret = secret[key_version]

signature = _create_signature_v2(secret, to_sign)
return to_sign + signature
else:
Expand All @@ -3006,21 +3043,10 @@ def format_field(s):
_signed_value_version_re = re.compile(br"^([1-9][0-9]*)\|(.*)$")


def decode_signed_value(secret, name, value, max_age_days=31,
clock=None, min_version=None):
if clock is None:
clock = time.time
if min_version is None:
min_version = DEFAULT_SIGNED_VALUE_MIN_VERSION
if min_version > 2:
raise ValueError("Unsupported min_version %d" % min_version)
if not value:
return None

# Figure out what version this is. Version 1 did not include an
def _get_version(value):
# Figures out what version value is. Version 1 did not include an
# explicit version field and started with arbitrary base64 data,
# which makes this tricky.
value = utf8(value)
m = _signed_value_version_re.match(value)
if m is None:
version = 1
Expand All @@ -3037,6 +3063,22 @@ def decode_signed_value(secret, name, value, max_age_days=31,
version = 1
except ValueError:
version = 1
return version


def decode_signed_value(secret, name, value, max_age_days=31,
clock=None, min_version=None):
if clock is None:
clock = time.time
if min_version is None:
min_version = DEFAULT_SIGNED_VALUE_MIN_VERSION
if min_version > 2:
raise ValueError("Unsupported min_version %d" % min_version)
if not value:
return None

value = utf8(value)
version = _get_version(value)

if version < min_version:
return None
Expand Down Expand Up @@ -3080,7 +3122,7 @@ def _decode_signed_value_v1(secret, name, value, max_age_days, clock):
return None


def _decode_signed_value_v2(secret, name, value, max_age_days, clock):
def _decode_fields_v2(value):
def _consume_field(s):
length, _, rest = s.partition(b':')
n = int(length)
Expand All @@ -3091,16 +3133,28 @@ def _consume_field(s):
raise ValueError("malformed v2 signed value field")
rest = rest[n + 1:]
return field_value, rest

rest = value[2:] # remove version number
key_version, rest = _consume_field(rest)
timestamp, rest = _consume_field(rest)
name_field, rest = _consume_field(rest)
value_field, passed_sig = _consume_field(rest)
return int(key_version), timestamp, name_field, value_field, passed_sig


def _decode_signed_value_v2(secret, name, value, max_age_days, clock):
try:
key_version, rest = _consume_field(rest)
timestamp, rest = _consume_field(rest)
name_field, rest = _consume_field(rest)
value_field, rest = _consume_field(rest)
key_version, timestamp, name_field, value_field, passed_sig = _decode_fields_v2(value)
except ValueError:
return None
passed_sig = rest
signed_string = value[:-len(passed_sig)]

if isinstance(secret, dict):
try:
secret = secret[key_version]
except KeyError:
return None

expected_sig = _create_signature_v2(secret, signed_string)
if not _time_independent_equals(passed_sig, expected_sig):
return None
Expand All @@ -3116,6 +3170,19 @@ def _consume_field(s):
return None


def get_signature_key_version(value):
value = utf8(value)
version = _get_version(value)
if version < 2:
return None
try:
key_version, _, _, _, _ = _decode_fields_v2(value)
except ValueError:
return None

return key_version


def _create_signature_v1(secret, *parts):
hash = hmac.new(utf8(secret), digestmod=hashlib.sha1)
for part in parts:
Expand Down

0 comments on commit eb4a32a

Please sign in to comment.