-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Wei Zhuo
committed
Mar 9, 2016
1 parent
5dc775d
commit ddf4503
Showing
5 changed files
with
225 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .globals import session, session_config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
|
||
class SessionConfig(object): | ||
def __init__(self, **kwargs): | ||
pass | ||
|
||
class ModelDict(dict): | ||
def __init__(self, *a, **k): | ||
super(ModelDict, self).__init__(*a, **k) | ||
self.dirty = False | ||
|
||
def __setitem__(self, *a, **k): | ||
self.mark_dirty() | ||
super(ModelDict, self).__setitem__(*a, **k) | ||
|
||
def __delitem__(self, *a, **k): | ||
self.mark_dirty() | ||
super(ModelDict, self).__delitem__(*a, **k) | ||
|
||
def mark_clean(self): | ||
self.dirty = False | ||
|
||
def mark_dirty(self): | ||
self.dirty = True | ||
|
||
def clear(self, *a, **k): | ||
self.mark_dirty() | ||
super(ModelDict, self).clear(*a, **k) | ||
|
||
def pop(self, *a, **k): | ||
self.mark_dirty() | ||
super(ModelDict, self).pop(*a, **k) | ||
|
||
def popitem(self, *a, **k): | ||
self.mark_dirty() | ||
super(ModelDict, self).popitem(*a, **k) | ||
|
||
def setdefault(self, *a, **k): | ||
self.mark_dirty() | ||
super(ModelDict, self).setdefault(*a, **k) | ||
|
||
def update(self, *a, **k): | ||
self.mark_dirty() | ||
super(ModelDict, self).update(*a, **k) | ||
|
||
def __repr__(self): | ||
return '<%s %s>' % ( | ||
self.__class__.__name__, | ||
dict.__repr__(self) | ||
) | ||
|
||
class BaseSession(ModelDict): | ||
def __init__(self): | ||
pass | ||
|
||
def open(self, request): | ||
raise NotImplementedError() | ||
|
||
def save(self, response): | ||
raise NotImplementedError() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
|
||
from .base import BaseSession | ||
import hashlib | ||
from .utils import sign_payload, validate_payload, load_payload, dump_payload | ||
from .globals import session_config | ||
|
||
|
||
class SecureCookieSession(BaseSession): | ||
salt = 'cookie-session' | ||
|
||
def __init__(self): | ||
pass | ||
|
||
def decode_session(self, data): | ||
validated, ret = validate_payload(data) | ||
if validated: | ||
return load_payload(ret) | ||
return None | ||
|
||
def encode_session(self): | ||
s = dump_payload(data) | ||
return sign_payload(s, session_config.get('secret_key'), salt) | ||
|
||
def open(self, request): | ||
value = request.get_cookie('session') or '' | ||
data = self.decode_session(value) or {} | ||
self.update(data) | ||
|
||
|
||
def save(self, response): | ||
data = self.encode_session(dict(self)) | ||
response.set_cookie('session', data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
|
||
from .base import SessionConfig | ||
|
||
session_config = SessionConfig() | ||
|
||
from cocopot.local import LocalProxy | ||
|
||
def create_session(*args): | ||
return None | ||
|
||
def _lookup_session_object(): | ||
from cocopot import request | ||
session = getattr(request, 'session') | ||
if not session: | ||
setattr(request, 'session', create_session(request)) | ||
return getattr(request, 'session') | ||
|
||
|
||
session = LocalProxy(_lookup_session_object) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
try: | ||
import simplejson as json | ||
except: | ||
import json | ||
import sys | ||
import hmac | ||
import zlib | ||
import time | ||
import base64 | ||
from datetime import datetime | ||
|
||
# 2011/01/01 in UTC | ||
EPOCH = 1293840000 | ||
|
||
class BadPayload(Exception): | ||
def __init__(self, message, original_error=None): | ||
Exception.__init__(self, message) | ||
#: If available, the error that indicates why the payload | ||
#: was not valid. This might be `None`. | ||
self.original_error = original_error | ||
|
||
class SignatureExpired(Exception): | ||
def __init__(self, message, payload=None): | ||
Exception.__init__(self, message) | ||
self.payload = payload | ||
|
||
def to_bytes(s, encoding='utf-8', errors='strict'): | ||
if isinstance(s, text_type): | ||
s = s.encode(encoding, errors) | ||
return s | ||
|
||
def base64_encode(string): | ||
"""base64 encodes a single bytestring (and is tolerant to getting | ||
called with a unicode string). | ||
The resulting bytestring is safe for putting into URLs. | ||
""" | ||
string = to_bytes(string) | ||
return base64.urlsafe_b64encode(string).strip(b'=') | ||
|
||
|
||
def base64_decode(string): | ||
"""base64 decodes a single bytestring (and is tolerant to getting | ||
called with a unicode string). | ||
The result is also a bytestring. | ||
""" | ||
string = to_bytes(string, encoding='ascii', errors='ignore') | ||
return base64.urlsafe_b64decode(string + b'=' * (-len(string) % 4)) | ||
|
||
def load_payload(payload): | ||
decompress = False | ||
if payload.startswith(b'.'): | ||
payload = payload[1:] | ||
decompress = True | ||
try: | ||
jsondata = base64_decode(payload) | ||
except Exception as e: | ||
raise BadPayload('Could not base64 decode the payload because of ' | ||
'an exception', original_error=e) | ||
if decompress: | ||
try: | ||
jsondata = zlib.decompress(jsondata) | ||
except Exception as e: | ||
raise BadPayload('Could not zlib decompress the payload before ' | ||
'decoding the payload', original_error=e) | ||
return json.loads(jsondata) | ||
|
||
def dump_payload(data): | ||
json = json.dumps(data, separators=(',', ':')) | ||
is_compressed = False | ||
compressed = zlib.compress(json) | ||
if len(compressed) < (len(json) - 1): | ||
json = compressed | ||
is_compressed = True | ||
base64d = base64_encode(json) | ||
if is_compressed: | ||
base64d = b'.' + base64d | ||
return base64d | ||
|
||
def gen_signature(value, key, salt): | ||
value = to_bytes(value) | ||
mac = hmac.new(key, msg=salt, digestmod=hashlib.sha1) | ||
key = mac.digest() | ||
mac = hmac.new(key, msg=value, digestmod=hashlib.sha1) | ||
sig = mac.digest() | ||
return base64_encode(sig) | ||
|
||
def sign_payload(value, key, salt): | ||
sep = "." | ||
ts = int(time.time() - EPOCH) | ||
value = '%s%s%s'%(value, sep, ts) | ||
return value + sep + gen_signature(value) | ||
|
||
def format_time(ts): | ||
return time.strftime("%Y-%m-%d %H:%M:%S UTC", ts) | ||
|
||
def validate_payload(value, key, salt, max_age=None): | ||
parts = value.rsplit(".", 1) | ||
if len(parts) != 2: | ||
return False, None | ||
signature = parts[-1] | ||
value = parts[0] | ||
if gen_signature(value, key, salt) != signature: | ||
try: | ||
ts = int(value.rsplit(".", 1)[-1]) | ||
except: | ||
return False, "timestamp not valid" | ||
if max_age is not None: | ||
age = int(time.time() - EPOCH) - ts | ||
if age > max_age: | ||
raise SignatureExpired( | ||
'Signature age %s > %s seconds, Expired at %s' % (age, max_age, format_time(int(time.time())+EPOCH)), | ||
payload=value) | ||
return False, "Signature %s wrong!"%(signature) | ||
return True, value |