Skip to content

Commit

Permalink
add session package
Browse files Browse the repository at this point in the history
  • Loading branch information
Wei Zhuo committed Mar 9, 2016
1 parent 5dc775d commit ddf4503
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 0 deletions.
1 change: 1 addition & 0 deletions cocopot/session/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .globals import session, session_config
59 changes: 59 additions & 0 deletions cocopot/session/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@

class SessionConfig(object):
def __init__(self, **kwargs):
pass

class ModelDict(dict):
def __init__(self, *a, **k):
super(ModelDict, self).__init__(*a, **k)
self.dirty = False

def __setitem__(self, *a, **k):
self.mark_dirty()
super(ModelDict, self).__setitem__(*a, **k)

def __delitem__(self, *a, **k):
self.mark_dirty()
super(ModelDict, self).__delitem__(*a, **k)

def mark_clean(self):
self.dirty = False

def mark_dirty(self):
self.dirty = True

def clear(self, *a, **k):
self.mark_dirty()
super(ModelDict, self).clear(*a, **k)

def pop(self, *a, **k):
self.mark_dirty()
super(ModelDict, self).pop(*a, **k)

def popitem(self, *a, **k):
self.mark_dirty()
super(ModelDict, self).popitem(*a, **k)

def setdefault(self, *a, **k):
self.mark_dirty()
super(ModelDict, self).setdefault(*a, **k)

def update(self, *a, **k):
self.mark_dirty()
super(ModelDict, self).update(*a, **k)

def __repr__(self):
return '<%s %s>' % (
self.__class__.__name__,
dict.__repr__(self)
)

class BaseSession(ModelDict):
def __init__(self):
pass

def open(self, request):
raise NotImplementedError()

def save(self, response):
raise NotImplementedError()
32 changes: 32 additions & 0 deletions cocopot/session/cookie_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@

from .base import BaseSession
import hashlib
from .utils import sign_payload, validate_payload, load_payload, dump_payload
from .globals import session_config


class SecureCookieSession(BaseSession):
salt = 'cookie-session'

def __init__(self):
pass

def decode_session(self, data):
validated, ret = validate_payload(data)
if validated:
return load_payload(ret)
return None

def encode_session(self):
s = dump_payload(data)
return sign_payload(s, session_config.get('secret_key'), salt)

def open(self, request):
value = request.get_cookie('session') or ''
data = self.decode_session(value) or {}
self.update(data)


def save(self, response):
data = self.encode_session(dict(self))
response.set_cookie('session', data)
19 changes: 19 additions & 0 deletions cocopot/session/globals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@

from .base import SessionConfig

session_config = SessionConfig()

from cocopot.local import LocalProxy

def create_session(*args):
return None

def _lookup_session_object():
from cocopot import request
session = getattr(request, 'session')
if not session:
setattr(request, 'session', create_session(request))
return getattr(request, 'session')


session = LocalProxy(_lookup_session_object)
114 changes: 114 additions & 0 deletions cocopot/session/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
try:
import simplejson as json
except:
import json
import sys
import hmac
import zlib
import time
import base64
from datetime import datetime

# 2011/01/01 in UTC
EPOCH = 1293840000

class BadPayload(Exception):
def __init__(self, message, original_error=None):
Exception.__init__(self, message)
#: If available, the error that indicates why the payload
#: was not valid. This might be `None`.
self.original_error = original_error

class SignatureExpired(Exception):
def __init__(self, message, payload=None):
Exception.__init__(self, message)
self.payload = payload

def to_bytes(s, encoding='utf-8', errors='strict'):
if isinstance(s, text_type):
s = s.encode(encoding, errors)
return s

def base64_encode(string):
"""base64 encodes a single bytestring (and is tolerant to getting
called with a unicode string).
The resulting bytestring is safe for putting into URLs.
"""
string = to_bytes(string)
return base64.urlsafe_b64encode(string).strip(b'=')


def base64_decode(string):
"""base64 decodes a single bytestring (and is tolerant to getting
called with a unicode string).
The result is also a bytestring.
"""
string = to_bytes(string, encoding='ascii', errors='ignore')
return base64.urlsafe_b64decode(string + b'=' * (-len(string) % 4))

def load_payload(payload):
decompress = False
if payload.startswith(b'.'):
payload = payload[1:]
decompress = True
try:
jsondata = base64_decode(payload)
except Exception as e:
raise BadPayload('Could not base64 decode the payload because of '
'an exception', original_error=e)
if decompress:
try:
jsondata = zlib.decompress(jsondata)
except Exception as e:
raise BadPayload('Could not zlib decompress the payload before '
'decoding the payload', original_error=e)
return json.loads(jsondata)

def dump_payload(data):
json = json.dumps(data, separators=(',', ':'))
is_compressed = False
compressed = zlib.compress(json)
if len(compressed) < (len(json) - 1):
json = compressed
is_compressed = True
base64d = base64_encode(json)
if is_compressed:
base64d = b'.' + base64d
return base64d

def gen_signature(value, key, salt):
value = to_bytes(value)
mac = hmac.new(key, msg=salt, digestmod=hashlib.sha1)
key = mac.digest()
mac = hmac.new(key, msg=value, digestmod=hashlib.sha1)
sig = mac.digest()
return base64_encode(sig)

def sign_payload(value, key, salt):
sep = "."
ts = int(time.time() - EPOCH)
value = '%s%s%s'%(value, sep, ts)
return value + sep + gen_signature(value)

def format_time(ts):
return time.strftime("%Y-%m-%d %H:%M:%S UTC", ts)

def validate_payload(value, key, salt, max_age=None):
parts = value.rsplit(".", 1)
if len(parts) != 2:
return False, None
signature = parts[-1]
value = parts[0]
if gen_signature(value, key, salt) != signature:
try:
ts = int(value.rsplit(".", 1)[-1])
except:
return False, "timestamp not valid"
if max_age is not None:
age = int(time.time() - EPOCH) - ts
if age > max_age:
raise SignatureExpired(
'Signature age %s > %s seconds, Expired at %s' % (age, max_age, format_time(int(time.time())+EPOCH)),
payload=value)
return False, "Signature %s wrong!"%(signature)
return True, value

0 comments on commit ddf4503

Please sign in to comment.