Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

updates to piston

  • Loading branch information...
commit b06e9e8d120e37c671fec6d43299adc6168b35a3 1 parent 20e656a
@poelzi authored
View
15 external/piston/__init__.py
@@ -0,0 +1,15 @@
+try:
+ import pkg_resources
+ pkg_resources.declare_namespace(__name__)
+except ImportError:
+ # don't prevent use of paste if pkg_resources isn't installed
+ from pkgutil import extend_path
+ __path__ = extend_path(__path__, __name__)
+
+try:
+ import modulefinder
+except ImportError:
+ pass
+else:
+ for p in __path__:
+ modulefinder.AddPackagePath(__name__, p)
View
53 external/piston/authentication.py
@@ -1,3 +1,5 @@
+import binascii
+
import oauth
from django.http import HttpResponse, HttpResponseRedirect
from django.contrib.auth.models import User, AnonymousUser
@@ -45,17 +47,20 @@ def is_authenticated(self, request):
if not auth_string:
return False
- (authmeth, auth) = auth_string.split(" ", 1)
-
- if not authmeth.lower() == 'basic':
+ try:
+ (authmeth, auth) = auth_string.split(" ", 1)
+
+ if not authmeth.lower() == 'basic':
+ return False
+
+ auth = auth.strip().decode('base64')
+ (username, password) = auth.split(':', 1)
+ except (ValueError, binascii.Error):
return False
-
- auth = auth.strip().decode('base64')
- (username, password) = auth.split(':', 1)
request.user = self.auth_func(username=username, password=password) \
or AnonymousUser()
-
+
return not request.user in (False, None, AnonymousUser())
def challenge(self):
@@ -64,6 +69,20 @@ def challenge(self):
resp.status_code = 401
return resp
+ def __repr__(self):
+ return u'<HTTPBasic: realm=%s>' % self.realm
+
+class HttpBasicSimple(HttpBasicAuthentication):
+ def __init__(self, realm, username, password):
+ self.user = User.objects.get(username=username)
+ self.password = password
+
+ super(HttpBasicSimple, self).__init__(auth_func=self.hash, realm=realm)
+
+ def hash(self, username, password):
+ if username == self.user.username and password == self.password:
+ return self.user
+
def load_data_store():
'''Load data store for OAuth Consumers, Tokens, Nonces and Resources
'''
@@ -92,9 +111,19 @@ def initialize_server_request(request):
"""
Shortcut for initialization.
"""
+ if request.method == "POST": #and \
+# request.META['CONTENT_TYPE'] == "application/x-www-form-urlencoded":
+ params = dict(request.REQUEST.items())
+ else:
+ params = { }
+
+ # Seems that we want to put HTTP_AUTHORIZATION into 'Authorization'
+ # for oauth.py to understand. Lovely.
+ request.META['Authorization'] = request.META.get('HTTP_AUTHORIZATION', '')
+
oauth_request = oauth.OAuthRequest.from_request(
request.method, request.build_absolute_uri(),
- headers=request.META, parameters=dict(request.REQUEST.items()),
+ headers=request.META, parameters=params,
query_string=request.environ.get('QUERY_STRING', ''))
if oauth_request:
@@ -138,8 +167,8 @@ def oauth_request_token(request):
def oauth_auth_view(request, token, callback, params):
form = forms.OAuthAuthenticationForm(initial={
'oauth_token': token.key,
- 'oauth_callback': callback,
- })
+ 'oauth_callback': token.get_callback_url() or callback,
+ })
return render_to_response('piston/authorize_token.html',
{ 'form': form }, RequestContext(request))
@@ -160,7 +189,7 @@ def oauth_user_auth(request):
callback = oauth_server.get_callback(oauth_request)
except:
callback = None
-
+
if request.method == "GET":
params = oauth_request.get_normalized_parameters()
@@ -177,6 +206,7 @@ def oauth_user_auth(request):
args = '?'+token.to_string(only_key=True)
else:
args = '?error=%s' % 'Access not granted by user.'
+ print "FORM ERROR", form.errors
if not callback:
callback = getattr(settings, 'OAUTH_CALLBACK_VIEW')
@@ -230,6 +260,7 @@ def is_authenticated(self, request):
if consumer and token:
request.user = token.user
+ request.consumer = consumer
request.throttle_extra = token.consumer.id
return True
View
65 external/piston/doc.py
@@ -1,6 +1,7 @@
import inspect, handler
from piston.handler import typemapper
+from piston.handler import handler_tracker
from django.core.urlresolvers import get_resolver, get_callable, get_script_prefix
from django.shortcuts import render_to_response
@@ -12,7 +13,7 @@ def generate_doc(handler_cls):
for the given handler. Use this to generate
documentation for your API.
"""
- if not type(handler_cls) is handler.HandlerMetaClass:
+ if isinstance(type(handler_cls), handler.HandlerMetaClass):
raise ValueError("Give me handler, not %s" % type(handler_cls))
return HandlerDocumentation(handler_cls)
@@ -36,7 +37,8 @@ def iter_args(self):
else:
yield (arg, None)
- def get_signature(self, parse_optional=True):
+ @property
+ def signature(self, parse_optional=True):
spec = ""
for argn, argdef in self.iter_args():
@@ -53,18 +55,25 @@ def get_signature(self, parse_optional=True):
return spec.replace("=None", "=<optional>")
return spec
-
- signature = property(get_signature)
- def get_doc(self):
+ @property
+ def doc(self):
return inspect.getdoc(self.method)
- doc = property(get_doc)
-
- def get_name(self):
+ @property
+ def name(self):
return self.method.__name__
-
- name = property(get_name)
+
+ @property
+ def http_name(self):
+ if self.name == 'read':
+ return 'GET'
+ elif self.name == 'create':
+ return 'POST'
+ elif self.name == 'delete':
+ return 'DELETE'
+ elif self.name == 'update':
+ return 'PUT'
def __repr__(self):
return "<Method: %s>" % self.name
@@ -75,8 +84,12 @@ def __init__(self, handler):
def get_methods(self, include_default=False):
for method in "read create update delete".split():
- met = getattr(self.handler, method)
- stale = inspect.getmodule(met) is handler
+ met = getattr(self.handler, method, None)
+
+ if not met:
+ continue
+
+ stale = inspect.getmodule(met.im_func) is not inspect.getmodule(self.handler)
if not self.handler.is_anonymous:
if met and (not stale or include_default):
@@ -92,16 +105,24 @@ def get_all_methods(self):
@property
def is_anonymous(self):
- return handler.is_anonymous
+ return self.handler.is_anonymous
def get_model(self):
return getattr(self, 'model', None)
- def get_doc(self):
+ @property
+ def has_anonymous(self):
+ return self.handler.anonymous
+
+ @property
+ def anonymous(self):
+ if self.has_anonymous:
+ return HandlerDocumentation(self.handler.anonymous)
+
+ @property
+ def doc(self):
return self.handler.__doc__
- doc = property(get_doc)
-
@property
def name(self):
return self.handler.__name__
@@ -159,8 +180,16 @@ def documentation_view(request):
"""
docs = [ ]
- for handler, (model, anonymous) in typemapper.iteritems():
+ for handler in handler_tracker:
docs.append(generate_doc(handler))
-
+
+ def _compare(doc1, doc2):
+ #handlers and their anonymous counterparts are put next to each other.
+ name1 = doc1.name.replace("Anonymous", "")
+ name2 = doc2.name.replace("Anonymous", "")
+ return cmp(name1, name2)
+
+ docs.sort(_compare)
+
return render_to_response('documentation.html',
{ 'docs': docs }, RequestContext(request))
View
171 external/piston/emitters.py
@@ -1,6 +1,7 @@
from __future__ import generators
import decimal, re, inspect
+import copy
try:
# yaml isn't standard with python. It shouldn't be required if it
@@ -24,11 +25,13 @@ def any(iterable):
from django.utils import simplejson
from django.utils.xmlutils import SimplerXMLGenerator
from django.utils.encoding import smart_unicode
+from django.core.urlresolvers import reverse, NoReverseMatch
from django.core.serializers.json import DateTimeAwareJSONEncoder
from django.http import HttpResponse
from django.core import serializers
from utils import HttpStatusCode, Mimer
+from validate_jsonp import is_valid_jsonp_callback_value
try:
import cStringIO as StringIO
@@ -40,6 +43,9 @@ def any(iterable):
except ImportError:
import pickle
+# Allow people to change the reverser (default `permalink`).
+reverser = permalink
+
class Emitter(object):
"""
Super emitter. All other emitters should subclass
@@ -47,8 +53,15 @@ class Emitter(object):
conveniently returns a serialized `dict`. This is
usually the only method you want to use in your
emitter. See below for examples.
+
+ `RESERVED_FIELDS` was introduced when better resource
+ method detection came, and we accidentially caught these
+ as the methods on the handler. Issue58 says that's no good.
"""
EMITTERS = { }
+ RESERVED_FIELDS = set([ 'read', 'update', 'create',
+ 'delete', 'model', 'anonymous',
+ 'allowed_methods', 'fields', 'exclude' ])
def __init__(self, payload, typemapper, handler, fields=(), anonymous=True):
self.typemapper = typemapper
@@ -56,29 +69,30 @@ def __init__(self, payload, typemapper, handler, fields=(), anonymous=True):
self.handler = handler
self.fields = fields
self.anonymous = anonymous
-
+
if isinstance(self.data, Exception):
raise
-
- def method_fields(self, data, fields):
- if not data:
+
+ def method_fields(self, handler, fields):
+ if not handler:
return { }
- has = dir(data)
ret = dict()
-
- for field in fields:
- if field in has:
- ret[field] = getattr(data, field)
-
+
+ for field in fields - Emitter.RESERVED_FIELDS:
+ t = getattr(handler, str(field), None)
+
+ if t and callable(t):
+ ret[field] = t
+
return ret
-
+
def construct(self):
"""
Recursively serialize a lot of types, and
in cases where it doesn't recognize the type,
it will fall back to Django's `smart_unicode`.
-
+
Returns `dict`.
"""
def _any(thing, fields=()):
@@ -86,17 +100,17 @@ def _any(thing, fields=()):
Dispatch, all types are routed through here.
"""
ret = None
-
+
if isinstance(thing, QuerySet):
ret = _qs(thing, fields=fields)
- elif isinstance(thing, (tuple, list)):
- ret = _list(thing)
+ elif isinstance(thing, (tuple, list, set)):
+ ret = _list(thing, fields=fields)
elif isinstance(thing, dict):
- ret = _dict(thing)
+ ret = _dict(thing, fields)
elif isinstance(thing, decimal.Decimal):
ret = str(thing)
elif isinstance(thing, Model):
- ret = _model(thing, fields=fields)
+ ret = _model(thing, fields)
elif isinstance(thing, HttpResponse):
raise HttpStatusCode(thing)
elif inspect.isfunction(thing):
@@ -106,6 +120,8 @@ def _any(thing, fields=()):
f = thing.__emittable__
if inspect.ismethod(f) and len(inspect.getargspec(f)[0]) == 1:
ret = _any(f())
+ elif repr(thing).startswith("<django.db.models.fields.related.RelatedManager"):
+ ret = _any(thing.all())
else:
ret = smart_unicode(thing, strings_only=True)
@@ -116,19 +132,19 @@ def _fk(data, field):
Foreign keys.
"""
return _any(getattr(data, field.name))
-
+
def _related(data, fields=()):
"""
Foreign keys.
"""
return [ _model(m, fields) for m in data.iterator() ]
-
+
def _m2m(data, field, fields=()):
"""
Many to many (re-route to `_model`.)
"""
return [ _model(m, fields) for m in getattr(data, field.name).iterator() ]
-
+
def _model(data, fields=()):
"""
Models. Will respect the `fields` and/or
@@ -137,7 +153,7 @@ def _model(data, fields=()):
ret = { }
handler = self.in_typemapper(type(data), self.anonymous)
get_absolute_uri = False
-
+
if handler or fields:
v = lambda f: getattr(data, f.attname)
@@ -152,27 +168,30 @@ def _model(data, fields=()):
if 'absolute_uri' in get_fields:
get_absolute_uri = True
-
+
if not get_fields:
get_fields = set([ f.attname.replace("_id", "", 1)
- for f in data._meta.fields ])
-
+ for f in data._meta.fields + data._meta.virtual_fields])
+
+ if hasattr(mapped, 'extra_fields'):
+ get_fields.update(mapped.extra_fields)
+
# sets can be negated.
for exclude in exclude_fields:
if isinstance(exclude, basestring):
get_fields.discard(exclude)
-
+
elif isinstance(exclude, re._pattern_type):
for field in get_fields.copy():
if exclude.match(field):
get_fields.discard(field)
-
+
else:
get_fields = set(fields)
met_fields = self.method_fields(handler, get_fields)
- for f in data._meta.local_fields:
+ for f in data._meta.local_fields + data._meta.virtual_fields:
if f.serialize and not any([ p in met_fields for p in [ f.attname, f.name ]]):
if not f.rel:
if f.attname in get_fields:
@@ -182,16 +201,15 @@ def _model(data, fields=()):
if f.attname[:-3] in get_fields:
ret[f.name] = _fk(data, f)
get_fields.remove(f.name)
-
+
for mf in data._meta.many_to_many:
if mf.serialize and mf.attname not in met_fields:
if mf.attname in get_fields:
ret[mf.name] = _m2m(data, mf)
get_fields.remove(mf.name)
-
+
# try to get the remainder of fields
for maybe_field in get_fields:
-
if isinstance(maybe_field, (list, tuple)):
model, fields = maybe_field
inst = getattr(data, model, None)
@@ -211,11 +229,11 @@ def _model(data, fields=()):
# using different names.
ret[maybe_field] = _any(met_fields[maybe_field](data))
- else:
+ else:
maybe = getattr(data, maybe_field, None)
if maybe:
if callable(maybe):
- if len(inspect.getargspec(maybe)[0]) == 1:
+ if len(inspect.getargspec(maybe)[0]) <= 1:
ret[maybe_field] = _any(maybe())
else:
ret[maybe_field] = _any(maybe)
@@ -228,65 +246,68 @@ def _model(data, fields=()):
else:
for f in data._meta.fields:
ret[f.attname] = _any(getattr(data, f.attname))
-
+
fields = dir(data.__class__) + ret.keys()
add_ons = [k for k in dir(data) if k not in fields]
-
+
for k in add_ons:
ret[k] = _any(getattr(data, k))
-
+
# resouce uri
if self.in_typemapper(type(data), self.anonymous):
handler = self.in_typemapper(type(data), self.anonymous)
if hasattr(handler, 'resource_uri'):
- url_id, fields = handler.resource_uri()
- ret['resource_uri'] = permalink( lambda: (url_id,
- (getattr(data, f) for f in fields) ) )()
-
+ url_id, fields = handler.resource_uri(data)
+
+ try:
+ ret['resource_uri'] = reverser( lambda: (url_id, fields) )()
+ except NoReverseMatch, e:
+ pass
+
if hasattr(data, 'get_api_url') and 'resource_uri' not in ret:
try: ret['resource_uri'] = data.get_api_url()
except: pass
-
+
# absolute uri
if hasattr(data, 'get_absolute_url') and get_absolute_uri:
try: ret['absolute_uri'] = data.get_absolute_url()
except: pass
-
+
return ret
-
+
def _qs(data, fields=()):
"""
Querysets.
"""
return [ _any(v, fields) for v in data ]
-
- def _list(data):
+
+ def _list(data, fields=()):
"""
Lists.
"""
- return [ _any(v) for v in data ]
-
- def _dict(data):
+ return [ _any(v, fields) for v in data ]
+
+ def _dict(data, fields=()):
"""
Dictionaries.
"""
- return dict([ (k, _any(v)) for k, v in data.iteritems() ])
-
+ return dict([ (k, _any(v, fields)) for k, v in data.iteritems() ])
+
# Kickstart the seralizin'.
return _any(self.data, self.fields)
-
+
def in_typemapper(self, model, anonymous):
for klass, (km, is_anon) in self.typemapper.iteritems():
if model is km and is_anon is anonymous:
return klass
-
+
def render(self):
"""
This super emitter does not implement `render`,
this is a job for the specific emitter below.
"""
raise NotImplementedError("Please implement render.")
-
+
def stream_render(self, request, stream=True):
"""
Tells our patched middleware not to look
@@ -295,7 +316,7 @@ def stream_render(self, request, stream=True):
more memory friendly for large datasets.
"""
yield self.render(request)
-
+
@classmethod
def get(cls, format):
"""
@@ -305,19 +326,19 @@ def get(cls, format):
return cls.EMITTERS.get(format)
raise ValueError("No emitters found for type %s" % format)
-
+
@classmethod
def register(cls, name, klass, content_type='text/plain'):
"""
Register an emitter.
-
+
Parameters::
- `name`: The name of the emitter ('json', 'xml', 'yaml', ...)
- `klass`: The emitter class.
- `content_type`: The content type to serve response as.
"""
cls.EMITTERS[name] = (klass, content_type)
-
+
@classmethod
def unregister(cls, name):
"""
@@ -325,7 +346,7 @@ def unregister(cls, name):
want to provide output in one of the built-in emitters.
"""
return cls.EMITTERS.pop(name, None)
-
+
class XMLEmitter(Emitter):
def _to_xml(self, xml, data):
if isinstance(data, (list, tuple)):
@@ -343,16 +364,16 @@ def _to_xml(self, xml, data):
def render(self, request):
stream = StringIO.StringIO()
-
+
xml = SimplerXMLGenerator(stream, "utf-8")
xml.startDocument()
xml.startElement("response", {})
-
+
self._to_xml(xml, self.construct())
-
+
xml.endElement("response")
xml.endDocument()
-
+
return stream.getvalue()
Emitter.register('xml', XMLEmitter, 'text/xml; charset=utf-8')
@@ -363,18 +384,18 @@ class JSONEmitter(Emitter):
JSON emitter, understands timestamps.
"""
def render(self, request):
- cb = request.GET.get('callback')
+ cb = request.GET.get('callback', None)
seria = simplejson.dumps(self.construct(), cls=DateTimeAwareJSONEncoder, ensure_ascii=False, indent=4)
# Callback
- if cb:
+ if cb and is_valid_jsonp_callback_value(cb):
return '%s(%s)' % (cb, seria)
return seria
-
+
Emitter.register('json', JSONEmitter, 'application/json; charset=utf-8')
-Mimer.register(simplejson.loads, ('application/json','application/json; charset=utf-8', 'application/json; charset=UTF-8'))
-
+Mimer.register(simplejson.loads, ('application/json',))
+
class YAMLEmitter(Emitter):
"""
YAML emitter, uses `safe_dump` to omit the
@@ -385,7 +406,7 @@ def render(self, request):
if yaml: # Only register yaml if it was import successfully.
Emitter.register('yaml', YAMLEmitter, 'application/x-yaml; charset=utf-8')
- Mimer.register(yaml.load, ('application/x-yaml',))
+ Mimer.register(lambda s: dict(yaml.load(s)), ('application/x-yaml',))
class PickleEmitter(Emitter):
"""
@@ -393,9 +414,19 @@ class PickleEmitter(Emitter):
"""
def render(self, request):
return pickle.dumps(self.construct())
-
+
Emitter.register('pickle', PickleEmitter, 'application/python-pickle')
-Mimer.register(pickle.loads, ('application/python-pickle',))
+
+"""
+WARNING: Accepting arbitrary pickled data is a huge security concern.
+The unpickler has been disabled by default now, and if you want to use
+it, please be aware of what implications it will have.
+
+Read more: http://nadiana.com/python-pickle-insecure
+
+Uncomment the line below to enable it. You're doing so at your own risk.
+"""
+# Mimer.register(pickle.loads, ('application/python-pickle',))
class DjangoEmitter(Emitter):
"""
@@ -410,5 +441,5 @@ def render(self, request, format='xml'):
response = serializers.serialize(format, self.data, indent=True)
return response
-
+
Emitter.register('django', DjangoEmitter, 'text/xml; charset=utf-8')
View
46 external/piston/fixtures/models.json
@@ -0,0 +1,46 @@
+[
+ {
+ "pk": 2,
+ "model": "auth.user",
+ "fields": {
+ "username": "pistontestuser",
+ "first_name": "Piston",
+ "last_name": "User",
+ "is_active": true,
+ "is_superuser": false,
+ "is_staff": false,
+ "last_login": "2009-08-03 13:11:53",
+ "groups": [],
+ "user_permissions": [],
+ "password": "sha1$b6c1f$83d5879f3854f6e9d27f393e3bcb4b8db05cf671",
+ "email": "pistontestuser@example.com",
+ "date_joined": "2009-08-03 13:11:53"
+ }
+ },
+ {
+ "pk": 3,
+ "model": "auth.user",
+ "fields": {
+ "username": "pistontestconsumer",
+ "first_name": "Piston",
+ "last_name": "Consumer",
+ "is_active": true,
+ "is_superuser": false,
+ "is_staff": false,
+ "last_login": "2009-08-03 13:11:53",
+ "groups": [],
+ "user_permissions": [],
+ "password": "sha1$b6c1f$83d5879f3854f6e9d27f393e3bcb4b8db05cf671",
+ "email": "pistontestconsumer@example.com",
+ "date_joined": "2009-08-03 13:11:53"
+ }
+ },
+ {
+ "pk": 1,
+ "model": "sites.site",
+ "fields": {
+ "domain": "example.com",
+ "name": "example.com"
+ }
+ }
+]
View
27 external/piston/fixtures/oauth.json
@@ -0,0 +1,27 @@
+[
+ {
+ "pk": 1,
+ "model": "piston.consumer",
+ "fields": {
+ "status": "accepted",
+ "name": "Piston Test Consumer",
+ "secret": "T5XkNMkcjffDpC9mNQJbyQnJXGsenYbz",
+ "user": 2,
+ "key": "8aZSFj3W54h8J8sCpx",
+ "description": "A test consumer record for Piston unit tests."
+ }
+ },
+ {
+ "pk": 1,
+ "model": "piston.token",
+ "fields": {
+ "is_approved": true,
+ "timestamp": 1249347414,
+ "token_type": 2,
+ "secret": "qSWZq36t7yvkBquetYBkd8JxnuCu9jKk",
+ "user": 2,
+ "key": "Y7358vL5hDBbeP3HHL",
+ "consumer": 1
+ }
+ }
+]
View
2  external/piston/forms.py
@@ -23,7 +23,7 @@ def merge_from_initial(self):
class OAuthAuthenticationForm(forms.Form):
oauth_token = forms.CharField(widget=forms.HiddenInput)
- oauth_callback = forms.URLField(widget=forms.HiddenInput)
+ oauth_callback = forms.CharField(widget=forms.HiddenInput, required=False)
authorize_access = forms.BooleanField(required=True)
csrf_signature = forms.CharField(widget=forms.HiddenInput)
View
90 external/piston/handler.py
@@ -1,7 +1,11 @@
+import warnings
+
from utils import rc
from django.core.exceptions import ObjectDoesNotExist, MultipleObjectsReturned
+from django.conf import settings
typemapper = { }
+handler_tracker = [ ]
class HandlerMetaClass(type):
"""
@@ -10,10 +14,25 @@ class HandlerMetaClass(type):
"""
def __new__(cls, name, bases, attrs):
new_cls = type.__new__(cls, name, bases, attrs)
-
+
+ def already_registered(model, anon):
+ for k, (m, a) in typemapper.iteritems():
+ if model == m and anon == a:
+ return k
+
if hasattr(new_cls, 'model'):
+ if already_registered(new_cls.model, new_cls.is_anonymous):
+ if not getattr(settings, 'PISTON_IGNORE_DUPE_MODELS', False):
+ warnings.warn("Handler already registered for model %s, "
+ "you may experience inconsistent results." % new_cls.model.__name__)
+
typemapper[new_cls] = (new_cls.model, new_cls.is_anonymous)
-
+ else:
+ typemapper[new_cls] = (None, new_cls.is_anonymous)
+
+ if name not in ('BaseHandler', 'AnonymousBaseHandler'):
+ handler_tracker.append(new_cls)
+
return new_cls
class BaseHandler(object):
@@ -21,40 +40,43 @@ class BaseHandler(object):
Basehandler that gives you CRUD for free.
You are supposed to subclass this for specific
functionality.
-
+
All CRUD methods (`read`/`update`/`create`/`delete`)
receive a request as the first argument from the
resource. Use this for checking `request.user`, etc.
"""
__metaclass__ = HandlerMetaClass
-
+
allowed_methods = ('GET', 'POST', 'PUT', 'DELETE')
anonymous = is_anonymous = False
exclude = ( 'id', )
fields = ( )
-
+
def flatten_dict(self, dct):
return dict([ (str(k), dct.get(k)) for k in dct.keys() ])
-
+
def has_model(self):
- return hasattr(self, 'model')
-
+ return hasattr(self, 'model') or hasattr(self, 'queryset')
+
+ def queryset(self, request):
+ return self.model.objects.all()
+
def value_from_tuple(tu, name):
for int_, n in tu:
if n == name:
return int_
return None
-
+
def exists(self, **kwargs):
if not self.has_model():
raise NotImplementedError
-
+
try:
self.model.objects.get(**kwargs)
return True
except self.model.DoesNotExist:
return False
-
+
def read(self, request, *args, **kwargs):
if not self.has_model():
return rc.NOT_IMPLEMENTED
@@ -63,22 +85,22 @@ def read(self, request, *args, **kwargs):
if pkfield in kwargs:
try:
- return self.model.objects.get(pk=kwargs.get(pkfield))
+ return self.queryset(request).get(pk=kwargs.get(pkfield))
except ObjectDoesNotExist:
return rc.NOT_FOUND
except MultipleObjectsReturned: # should never happen, since we're using a PK
return rc.BAD_REQUEST
else:
- return self.model.objects.filter(*args, **kwargs)
-
+ return self.queryset(request).filter(*args, **kwargs)
+
def create(self, request, *args, **kwargs):
if not self.has_model():
return rc.NOT_IMPLEMENTED
-
- attrs = self.flatten_dict(request.POST)
-
+
+ attrs = self.flatten_dict(request.data)
+
try:
- inst = self.model.objects.get(**attrs)
+ inst = self.queryset(request).get(**attrs)
return rc.DUPLICATE_ENTRY
except self.model.DoesNotExist:
inst = self.model(**attrs)
@@ -86,17 +108,37 @@ def create(self, request, *args, **kwargs):
return inst
except self.model.MultipleObjectsReturned:
return rc.DUPLICATE_ENTRY
-
+
def update(self, request, *args, **kwargs):
- # TODO: This doesn't work automatically yet.
- return rc.NOT_IMPLEMENTED
-
+ if not self.has_model():
+ return rc.NOT_IMPLEMENTED
+
+ pkfield = self.model._meta.pk.name
+
+ if pkfield not in kwargs:
+ # No pk was specified
+ return rc.BAD_REQUEST
+
+ try:
+ inst = self.queryset(request).get(pk=kwargs.get(pkfield))
+ except ObjectDoesNotExist:
+ return rc.NOT_FOUND
+ except MultipleObjectsReturned: # should never happen, since we're using a PK
+ return rc.BAD_REQUEST
+
+ attrs = self.flatten_dict(request.data)
+ for k,v in attrs.iteritems():
+ setattr( inst, k, v )
+
+ inst.save()
+ return rc.ALL_OK
+
def delete(self, request, *args, **kwargs):
if not self.has_model():
raise NotImplementedError
try:
- inst = self.model.objects.get(*args, **kwargs)
+ inst = self.queryset(request).get(*args, **kwargs)
inst.delete()
@@ -105,7 +147,7 @@ def delete(self, request, *args, **kwargs):
return rc.DUPLICATE_ENTRY
except self.model.DoesNotExist:
return rc.NOT_HERE
-
+
class AnonymousBaseHandler(BaseHandler):
"""
Anonymous handler.
View
37 external/piston/handlers_doc.py
@@ -0,0 +1,37 @@
+from piston.doc import generate_doc
+from piston.handler import handler_tracker
+import re
+
+def generate_piston_documentation(app, docname, source):
+ e = re.compile(r"^\.\. piston_handlers:: ([\w\.]+)$")
+ old_source = source[0].split("\n")
+ new_source = old_source[:]
+ for line_nr, line in enumerate(old_source):
+ m = e.match(line)
+ if m:
+ module = m.groups()[0]
+ try:
+ __import__(module)
+ except ImportError:
+ pass
+ else:
+ new_lines = []
+ for handler in handler_tracker:
+ doc = generate_doc(handler)
+ new_lines.append(doc.name)
+ new_lines.append("-" * len(doc.name))
+ new_lines.append('::\n')
+ new_lines.append('\t' + doc.get_resource_uri_template() + '\n')
+ new_lines.append('Accepted methods:')
+ for method in doc.allowed_methods:
+ new_lines.append('\t* ' + method)
+ new_lines.append('')
+ if doc.doc:
+ new_lines.append(doc.doc)
+ new_source[line_nr:line_nr+1] = new_lines
+
+ source[0] = "\n".join(new_source)
+ return source
+
+def setup(app):
+ app.connect('source-read', generate_piston_documentation)
View
32 external/piston/managers.py
@@ -1,10 +1,23 @@
from django.db import models
from django.contrib.auth.models import User
-KEY_SIZE = 16
-SECRET_SIZE = 16
+KEY_SIZE = 18
+SECRET_SIZE = 32
-class ConsumerManager(models.Manager):
+class KeyManager(models.Manager):
+ '''Add support for random key/secret generation
+ '''
+ def generate_random_codes(self):
+ key = User.objects.make_random_password(length=KEY_SIZE)
+ secret = User.objects.make_random_password(length=SECRET_SIZE)
+
+ while self.filter(key__exact=key, secret__exact=secret).count():
+ secret = User.objects.make_random_password(length=SECRET_SIZE)
+
+ return key, secret
+
+
+class ConsumerManager(KeyManager):
def create_consumer(self, name, description=None, user=None):
"""
Shortcut to create a consumer with random key/secret.
@@ -18,10 +31,11 @@ def create_consumer(self, name, description=None, user=None):
consumer.description = description
if created:
- consumer.generate_random_codes()
+ consumer.key, consumer.secret = self.generate_random_codes()
+ consumer.save()
return consumer
-
+
_default_consumer = None
class ResourceManager(models.Manager):
@@ -36,7 +50,7 @@ def get_default_resource(self, name):
return self._default_resource
-class TokenManager(models.Manager):
+class TokenManager(KeyManager):
def create_token(self, consumer, token_type, timestamp, user=None):
"""
Shortcut to create a token with random key/secret.
@@ -47,6 +61,8 @@ def create_token(self, consumer, token_type, timestamp, user=None):
user=user)
if created:
- token.generate_random_codes()
+ token.key, token.secret = self.generate_random_codes()
+ token.save()
- return token
+ return token
+
View
126 external/piston/models.py
@@ -1,22 +1,29 @@
-import urllib
+import urllib, time, urlparse
+
+# Django imports
+from django.db.models.signals import post_save, post_delete
from django.db import models
from django.contrib.auth.models import User
-from django.contrib import admin
-from django.conf import settings
from django.core.mail import send_mail, mail_admins
-from django.template import loader
+# Piston imports
from managers import TokenManager, ConsumerManager, ResourceManager
+from signals import consumer_post_save, consumer_post_delete
KEY_SIZE = 18
SECRET_SIZE = 32
+VERIFIER_SIZE = 10
CONSUMER_STATES = (
- ('pending', 'Pending approval'),
+ ('pending', 'Pending'),
('accepted', 'Accepted'),
('canceled', 'Canceled'),
+ ('rejected', 'Rejected')
)
+def generate_random(length=SECRET_SIZE):
+ return User.objects.make_random_password(length=length)
+
class Nonce(models.Model):
token_key = models.CharField(max_length=KEY_SIZE)
consumer_key = models.CharField(max_length=KEY_SIZE)
@@ -25,19 +32,6 @@ class Nonce(models.Model):
def __unicode__(self):
return u"Nonce %s for %s" % (self.key, self.consumer_key)
-admin.site.register(Nonce)
-
-class Resource(models.Model):
- name = models.CharField(max_length=255)
- url = models.TextField(max_length=2047)
- is_readonly = models.BooleanField(default=True)
-
- objects = ResourceManager()
-
- def __unicode__(self):
- return u"Resource %s with url %s" % (self.name, self.url)
-
-admin.site.register(Resource)
class Consumer(models.Model):
name = models.CharField(max_length=255)
@@ -55,54 +49,26 @@ def __unicode__(self):
return u"Consumer %s with key %s" % (self.name, self.key)
def generate_random_codes(self):
+ """
+ Used to generate random key/secret pairings. Use this after you've
+ added the other data in place of save().
+
+ c = Consumer()
+ c.name = "My consumer"
+ c.description = "An app that makes ponies from the API."
+ c.user = some_user_object
+ c.generate_random_codes()
+ """
key = User.objects.make_random_password(length=KEY_SIZE)
-
- secret = User.objects.make_random_password(length=SECRET_SIZE)
+ secret = generate_random(SECRET_SIZE)
while Consumer.objects.filter(key__exact=key, secret__exact=secret).count():
- secret = User.objects.make_random_password(length=SECRET_SIZE)
+ secret = generate_random(SECRET_SIZE)
self.key = key
self.secret = secret
self.save()
- # --
-
- def save(self, **kwargs):
- super(Consumer, self).save(**kwargs)
-
- if self.id and self.user:
- subject = "API Consumer"
- rcpt = [ self.user.email, ]
-
- if self.status == "accepted":
- template = "api/mails/consumer_accepted.txt"
- subject += " was accepted!"
- elif self.status == "canceled":
- template = "api/mails/consumer_canceled.txt"
- subject += " has been canceled"
- else:
- template = "api/mails/consumer_pending.txt"
- subject += " application received"
-
- for admin in settings.ADMINS:
- bcc.append(admin[1])
-
- body = loader.render_to_string(template,
- { 'consumer': self, 'user': self.user })
-
- send_mail(subject, body, settings.DEFAULT_FROM_EMAIL,
- rcpt, fail_silently=True)
-
- if self.status == 'pending':
- mail_admins(subject, body, fail_silently=True)
-
- if settings.DEBUG:
- print "Mail being sent, to=%s" % rcpt
- print "Subject: %s" % subject
- print body
-
-admin.site.register(Consumer)
class Token(models.Model):
REQUEST = 1
@@ -111,13 +77,17 @@ class Token(models.Model):
key = models.CharField(max_length=KEY_SIZE)
secret = models.CharField(max_length=SECRET_SIZE)
+ verifier = models.CharField(max_length=VERIFIER_SIZE)
token_type = models.IntegerField(choices=TOKEN_TYPES)
- timestamp = models.IntegerField()
+ timestamp = models.IntegerField(default=long(time.time()))
is_approved = models.BooleanField(default=False)
user = models.ForeignKey(User, null=True, blank=True, related_name='tokens')
consumer = models.ForeignKey(Consumer)
+ callback = models.CharField(max_length=255, null=True, blank=True)
+ callback_confirmed = models.BooleanField(default=False)
+
objects = TokenManager()
def __unicode__(self):
@@ -126,21 +96,51 @@ def __unicode__(self):
def to_string(self, only_key=False):
token_dict = {
'oauth_token': self.key,
- 'oauth_token_secret': self.secret
+ 'oauth_token_secret': self.secret,
+ 'oauth_callback_confirmed': 'true',
}
+
+ if self.verifier:
+ token_dict.update({ 'oauth_verifier': self.verifier })
+
if only_key:
del token_dict['oauth_token_secret']
+
return urllib.urlencode(token_dict)
def generate_random_codes(self):
key = User.objects.make_random_password(length=KEY_SIZE)
- secret = User.objects.make_random_password(length=SECRET_SIZE)
+ secret = generate_random(SECRET_SIZE)
while Token.objects.filter(key__exact=key, secret__exact=secret).count():
- secret = User.objects.make_random_password(length=SECRET_SIZE)
+ secret = generate_random(SECRET_SIZE)
self.key = key
self.secret = secret
self.save()
-admin.site.register(Token)
+ # -- OAuth 1.0a stuff
+
+ def get_callback_url(self):
+ if self.callback and self.verifier:
+ # Append the oauth_verifier.
+ parts = urlparse.urlparse(self.callback)
+ scheme, netloc, path, params, query, fragment = parts[:6]
+ if query:
+ query = '%s&oauth_verifier=%s' % (query, self.verifier)
+ else:
+ query = 'oauth_verifier=%s' % self.verifier
+ return urlparse.urlunparse((scheme, netloc, path, params,
+ query, fragment))
+ return self.callback
+
+ def set_callback(self, callback):
+ if callback != "oob": # out of band, says "we can't do this!"
+ self.callback = callback
+ self.callback_confirmed = True
+ self.save()
+
+
+# Attach our signals
+post_save.connect(consumer_post_save, sender=Consumer)
+post_delete.connect(consumer_post_delete, sender=Consumer)
View
427 external/piston/oauth.py
@@ -1,50 +1,81 @@
+"""
+The MIT License
+
+Copyright (c) 2007 Leah Culver
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
import cgi
import urllib
import time
import random
import urlparse
import hmac
-import base64
+import binascii
+
VERSION = '1.0' # Hi Blaine!
HTTP_METHOD = 'GET'
SIGNATURE_METHOD = 'PLAINTEXT'
-# Generic exception class
-class OAuthError(RuntimeError):
- def get_message(self):
- return self._message
-
- def set_message(self, message):
- self._message = message
-
- message = property(get_message, set_message)
+class OAuthError(RuntimeError):
+ """Generic exception class."""
def __init__(self, message='OAuth error occured.'):
self.message = message
-# optional WWW-Authenticate header (401 error)
def build_authenticate_header(realm=''):
- return { 'WWW-Authenticate': 'OAuth realm="%s"' % realm }
+ """Optional WWW-Authenticate header (401 error)"""
+ return {'WWW-Authenticate': 'OAuth realm="%s"' % realm}
-# url escape
def escape(s):
- # escape '/' too
+ """Escape a URL including any /."""
return urllib.quote(s, safe='~')
-# util function: current timestamp
-# seconds since epoch (UTC)
+def _utf8_str(s):
+ """Convert unicode to utf-8."""
+ if isinstance(s, unicode):
+ return s.encode("utf-8")
+ else:
+ return str(s)
+
def generate_timestamp():
+ """Get seconds since epoch (UTC)."""
return int(time.time())
-# util function: nonce
-# pseudorandom number
def generate_nonce(length=8):
- return ''.join(str(random.randint(0, 9)) for i in range(length))
+ """Generate pseudorandom number."""
+ return ''.join([str(random.randint(0, 9)) for i in range(length)])
+
+def generate_verifier(length=8):
+ """Generate pseudorandom number."""
+ return ''.join([str(random.randint(0, 9)) for i in range(length)])
+
-# OAuthConsumer is a data type that represents the identity of the Consumer
-# via its shared secret with the Service Provider.
class OAuthConsumer(object):
+ """Consumer of OAuth authentication.
+
+ OAuthConsumer is a data type that represents the identity of the Consumer
+ via its shared secret with the Service Provider.
+
+ """
key = None
secret = None
@@ -52,39 +83,79 @@ def __init__(self, key, secret):
self.key = key
self.secret = secret
-# OAuthToken is a data type that represents an End User via either an access
-# or request token.
+
class OAuthToken(object):
- # access tokens and request tokens
+ """OAuthToken is a data type that represents an End User via either an access
+ or request token.
+
+ key -- the token
+ secret -- the token secret
+
+ """
key = None
secret = None
+ callback = None
+ callback_confirmed = None
+ verifier = None
- '''
- key = the token
- secret = the token secret
- '''
def __init__(self, key, secret):
self.key = key
self.secret = secret
- def to_string(self):
- return urllib.urlencode({'oauth_token': self.key, 'oauth_token_secret': self.secret})
+ def set_callback(self, callback):
+ self.callback = callback
+ self.callback_confirmed = 'true'
+
+ def set_verifier(self, verifier=None):
+ if verifier is not None:
+ self.verifier = verifier
+ else:
+ self.verifier = generate_verifier()
+
+ def get_callback_url(self):
+ if self.callback and self.verifier:
+ # Append the oauth_verifier.
+ parts = urlparse.urlparse(self.callback)
+ scheme, netloc, path, params, query, fragment = parts[:6]
+ if query:
+ query = '%s&oauth_verifier=%s' % (query, self.verifier)
+ else:
+ query = 'oauth_verifier=%s' % self.verifier
+ return urlparse.urlunparse((scheme, netloc, path, params,
+ query, fragment))
+ return self.callback
- # return a token from something like:
- # oauth_token_secret=digg&oauth_token=digg
- @staticmethod
+ def to_string(self):
+ data = {
+ 'oauth_token': self.key,
+ 'oauth_token_secret': self.secret,
+ }
+ if self.callback_confirmed is not None:
+ data['oauth_callback_confirmed'] = self.callback_confirmed
+ return urllib.urlencode(data)
+
def from_string(s):
+ """ Returns a token from something like:
+ oauth_token_secret=xxx&oauth_token=xxx
+ """
params = cgi.parse_qs(s, keep_blank_values=False)
key = params['oauth_token'][0]
secret = params['oauth_token_secret'][0]
- return OAuthToken(key, secret)
+ token = OAuthToken(key, secret)
+ try:
+ token.callback_confirmed = params['oauth_callback_confirmed'][0]
+ except KeyError:
+ pass # 1.0, no callback confirmed.
+ return token
+ from_string = staticmethod(from_string)
def __str__(self):
return self.to_string()
-# OAuthRequest represents the request and can be serialized
+
class OAuthRequest(object):
- '''
+ """OAuthRequest represents the request and can be serialized.
+
OAuth parameters:
- oauth_consumer_key
- oauth_token
@@ -93,9 +164,10 @@ class OAuthRequest(object):
- oauth_timestamp
- oauth_nonce
- oauth_version
+ - oauth_verifier
... any additional parameters, as defined by the Service Provider.
- '''
- parameters = None # oauth parameters
+ """
+ parameters = None # OAuth parameters.
http_method = HTTP_METHOD
http_url = None
version = VERSION
@@ -115,93 +187,107 @@ def get_parameter(self, parameter):
raise OAuthError('Parameter not found: %s' % parameter)
def _get_timestamp_nonce(self):
- return self.get_parameter('oauth_timestamp'), self.get_parameter('oauth_nonce')
+ return self.get_parameter('oauth_timestamp'), self.get_parameter(
+ 'oauth_nonce')
- # get any non-oauth parameters
def get_nonoauth_parameters(self):
+ """Get any non-OAuth parameters."""
parameters = {}
for k, v in self.parameters.iteritems():
- # ignore oauth parameters
+ # Ignore oauth parameters.
if k.find('oauth_') < 0:
parameters[k] = v
return parameters
- # serialize as a header for an HTTPAuth request
def to_header(self, realm=''):
+ """Serialize as a header for an HTTPAuth request."""
auth_header = 'OAuth realm="%s"' % realm
- # add the oauth parameters
+ # Add the oauth parameters.
if self.parameters:
for k, v in self.parameters.iteritems():
- auth_header += ', %s="%s"' % (k, escape(str(v)))
+ if k[:6] == 'oauth_':
+ auth_header += ', %s="%s"' % (k, escape(str(v)))
return {'Authorization': auth_header}
- # serialize as post data for a POST request
def to_postdata(self):
- return '&'.join('%s=%s' % (escape(str(k)), escape(str(v))) for k, v in self.parameters.iteritems())
+ """Serialize as post data for a POST request."""
+ return '&'.join(['%s=%s' % (escape(str(k)), escape(str(v))) \
+ for k, v in self.parameters.iteritems()])
- # serialize as a url for a GET request
def to_url(self):
+ """Serialize as a URL for a GET request."""
return '%s?%s' % (self.get_normalized_http_url(), self.to_postdata())
- # return a string that consists of all the parameters that need to be signed
def get_normalized_parameters(self):
+ """Return a string that contains the parameters that must be signed."""
params = self.parameters
try:
- # exclude the signature if it exists
+ # Exclude the signature if it exists.
del params['oauth_signature']
except:
pass
- key_values = params.items()
- # sort lexicographically, first after key, then after value
+ # Escape key values before sorting.
+ key_values = [(escape(_utf8_str(k)), escape(_utf8_str(v))) \
+ for k,v in params.items()]
+ # Sort lexicographically, first after key, then after value.
key_values.sort()
- # combine key value pairs in string and escape
- return '&'.join('%s=%s' % (escape(str(k)), escape(str(v))) for k, v in key_values)
+ # Combine key value pairs into a string.
+ return '&'.join(['%s=%s' % (k, v) for k, v in key_values])
- # just uppercases the http method
def get_normalized_http_method(self):
+ """Uppercases the http method."""
return self.http_method.upper()
- # parses the url and rebuilds it to be scheme://host/path
def get_normalized_http_url(self):
+ """Parses the URL and rebuilds it to be scheme://host/path."""
parts = urlparse.urlparse(self.http_url)
- url_string = '%s://%s%s' % (parts[0], parts[1], parts[2]) # scheme, netloc, path
- return url_string
-
- # set the signature parameter to the result of build_signature
+ scheme, netloc, path = parts[:3]
+ # Exclude default port numbers.
+ if scheme == 'http' and netloc[-3:] == ':80':
+ netloc = netloc[:-3]
+ elif scheme == 'https' and netloc[-4:] == ':443':
+ netloc = netloc[:-4]
+ return '%s://%s%s' % (scheme, netloc, path)
+
def sign_request(self, signature_method, consumer, token):
- # set the signature method
- self.set_parameter('oauth_signature_method', signature_method.get_name())
- # set the signature
- self.set_parameter('oauth_signature', self.build_signature(signature_method, consumer, token))
+ """Set the signature parameter to the result of build_signature."""
+ # Set the signature method.
+ self.set_parameter('oauth_signature_method',
+ signature_method.get_name())
+ # Set the signature.
+ self.set_parameter('oauth_signature',
+ self.build_signature(signature_method, consumer, token))
def build_signature(self, signature_method, consumer, token):
- # call the build signature method within the signature method
+ """Calls the build signature method within the signature method."""
return signature_method.build_signature(self, consumer, token)
- @staticmethod
- def from_request(http_method, http_url, headers=None, parameters=None, query_string=None):
- # combine multiple parameter sources
+ def from_request(http_method, http_url, headers=None, parameters=None,
+ query_string=None):
+ """Combines multiple parameter sources."""
if parameters is None:
parameters = {}
- # headers
- if headers and 'HTTP_AUTHORIZATION' in headers:
- auth_header = headers['HTTP_AUTHORIZATION']
- # check that the authorization header is OAuth
- if auth_header.index('OAuth') > -1:
+ # Headers
+ if headers and 'Authorization' in headers:
+ auth_header = headers['Authorization']
+ # Check that the authorization header is OAuth.
+ if auth_header[:6] == 'OAuth ':
+ auth_header = auth_header[6:]
try:
- # get the parameters from the header
+ # Get the parameters from the header.
header_params = OAuthRequest._split_header(auth_header)
parameters.update(header_params)
except:
- raise OAuthError('Unable to parse OAuth parameters from Authorization header.')
+ raise OAuthError('Unable to parse OAuth parameters from '
+ 'Authorization header.')
- # GET or POST query string
+ # GET or POST query string.
if query_string:
query_params = OAuthRequest._split_url_string(query_string)
parameters.update(query_params)
- # URL parameters
+ # URL parameters.
param_str = urlparse.urlparse(http_url)[4] # query
url_params = OAuthRequest._split_url_string(param_str)
parameters.update(url_params)
@@ -210,9 +296,11 @@ def from_request(http_method, http_url, headers=None, parameters=None, query_str
return OAuthRequest(http_method, http_url, parameters)
return None
+ from_request = staticmethod(from_request)
- @staticmethod
- def from_consumer_and_token(oauth_consumer, token=None, http_method=HTTP_METHOD, http_url=None, parameters=None):
+ def from_consumer_and_token(oauth_consumer, token=None,
+ callback=None, verifier=None, http_method=HTTP_METHOD,
+ http_url=None, parameters=None):
if not parameters:
parameters = {}
@@ -228,49 +316,57 @@ def from_consumer_and_token(oauth_consumer, token=None, http_method=HTTP_METHOD,
if token:
parameters['oauth_token'] = token.key
+ parameters['oauth_callback'] = token.callback
+ # 1.0a support for verifier.
+ parameters['oauth_verifier'] = verifier
+ elif callback:
+ # 1.0a support for callback in the request token request.
+ parameters['oauth_callback'] = callback
return OAuthRequest(http_method, http_url, parameters)
+ from_consumer_and_token = staticmethod(from_consumer_and_token)
- @staticmethod
- def from_token_and_callback(token, callback=None, http_method=HTTP_METHOD, http_url=None, parameters=None):
+ def from_token_and_callback(token, callback=None, http_method=HTTP_METHOD,
+ http_url=None, parameters=None):
if not parameters:
parameters = {}
parameters['oauth_token'] = token.key
if callback:
- parameters['oauth_callback'] = escape(callback)
+ parameters['oauth_callback'] = callback
return OAuthRequest(http_method, http_url, parameters)
+ from_token_and_callback = staticmethod(from_token_and_callback)
- # util function: turn Authorization: header into parameters, has to do some unescaping
- @staticmethod
def _split_header(header):
+ """Turn Authorization: header into parameters."""
params = {}
parts = header.split(',')
for param in parts:
- # ignore realm parameter
- if param.find('OAuth realm') > -1:
+ # Ignore realm parameter.
+ if param.find('realm') > -1:
continue
- # remove whitespace
+ # Remove whitespace.
param = param.strip()
- # split key-value
+ # Split key-value.
param_parts = param.split('=', 1)
- # remove quotes and unescape the value
+ # Remove quotes and unescape the value.
params[param_parts[0]] = urllib.unquote(param_parts[1].strip('\"'))
return params
-
- # util function: turn url string into parameters, has to do some unescaping
- @staticmethod
+ _split_header = staticmethod(_split_header)
+
def _split_url_string(param_str):
+ """Turn URL string into parameters."""
parameters = cgi.parse_qs(param_str, keep_blank_values=False)
for k, v in parameters.iteritems():
parameters[k] = urllib.unquote(v[0])
return parameters
+ _split_url_string = staticmethod(_split_url_string)
-# OAuthServer is a worker to check a requests validity against a data store
class OAuthServer(object):
- timestamp_threshold = 300 # in seconds, five minutes
+ """A worker to check the validity of a request against a data store."""
+ timestamp_threshold = 300 # In seconds, five minutes.
version = VERSION
signature_methods = None
data_store = None
@@ -279,7 +375,7 @@ def __init__(self, data_store=None, signature_methods=None):
self.data_store = data_store
self.signature_methods = signature_methods or {}
- def set_data_store(self, oauth_data_store):
+ def set_data_store(self, data_store):
self.data_store = data_store
def get_data_store(self):
@@ -289,57 +385,64 @@ def add_signature_method(self, signature_method):
self.signature_methods[signature_method.get_name()] = signature_method
return self.signature_methods
- # process a request_token request
- # returns the request token on success
def fetch_request_token(self, oauth_request):
+ """Processes a request_token request and returns the
+ request token on success.
+ """
try:
- # get the request token for authorization
+ # Get the request token for authorization.
token = self._get_token(oauth_request, 'request')
except OAuthError:
- # no token required for the initial token request
+ # No token required for the initial token request.
version = self._get_version(oauth_request)
consumer = self._get_consumer(oauth_request)
+ try:
+ callback = self.get_callback(oauth_request)
+ except OAuthError:
+ callback = None # 1.0, no callback specified.
self._check_signature(oauth_request, consumer, None)
- # fetch a new token
- token = self.data_store.fetch_request_token(consumer)
+ # Fetch a new token.
+ token = self.data_store.fetch_request_token(consumer, callback)
return token
- # process an access_token request
- # returns the access token on success
def fetch_access_token(self, oauth_request):
+ """Processes an access_token request and returns the
+ access token on success.
+ """
version = self._get_version(oauth_request)
consumer = self._get_consumer(oauth_request)
- # get the request token
+ verifier = self._get_verifier(oauth_request)
+ # Get the request token.
token = self._get_token(oauth_request, 'request')
self._check_signature(oauth_request, consumer, token)
- new_token = self.data_store.fetch_access_token(consumer, token)
+ new_token = self.data_store.fetch_access_token(consumer, token, verifier)
return new_token
- # verify an api call, checks all the parameters
def verify_request(self, oauth_request):
+ """Verifies an api call and checks all the parameters."""
# -> consumer and token
version = self._get_version(oauth_request)
consumer = self._get_consumer(oauth_request)
- # get the access token
+ # Get the access token.
token = self._get_token(oauth_request, 'access')
self._check_signature(oauth_request, consumer, token)
parameters = oauth_request.get_nonoauth_parameters()
return consumer, token, parameters
- # authorize a request token
def authorize_token(self, token, user):
+ """Authorize a request token."""
return self.data_store.authorize_request_token(token, user)
-
- # get the callback url
+
def get_callback(self, oauth_request):
+ """Get the callback URL."""
return oauth_request.get_parameter('oauth_callback')
-
- # optional support for the authenticate header
+
def build_authenticate_header(self, realm=''):
+ """Optional support for the authenticate header."""
return {'WWW-Authenticate': 'OAuth realm="%s"' % realm}
- # verify the correct version request for this server
def _get_version(self, oauth_request):
+ """Verify the correct version request for this server."""
try:
version = oauth_request.get_parameter('oauth_version')
except:
@@ -348,37 +451,40 @@ def _get_version(self, oauth_request):
raise OAuthError('OAuth version %s not supported.' % str(version))
return version
- # figure out the signature with some defaults
def _get_signature_method(self, oauth_request):
+ """Figure out the signature with some defaults."""
try:
- signature_method = oauth_request.get_parameter('oauth_signature_method')
+ signature_method = oauth_request.get_parameter(
+ 'oauth_signature_method')
except:
signature_method = SIGNATURE_METHOD
try:
- # get the signature method object
+ # Get the signature method object.
signature_method = self.signature_methods[signature_method]
except:
signature_method_names = ', '.join(self.signature_methods.keys())
- raise OAuthError('Signature method %s not supported try one of the following: %s' % (signature_method, signature_method_names))
+ raise OAuthError('Signature method %s not supported try one of the '
+ 'following: %s' % (signature_method, signature_method_names))
return signature_method
def _get_consumer(self, oauth_request):
consumer_key = oauth_request.get_parameter('oauth_consumer_key')
- if not consumer_key:
- raise OAuthError('Invalid consumer key.')
consumer = self.data_store.lookup_consumer(consumer_key)
if not consumer:
raise OAuthError('Invalid consumer.')
return consumer
- # try to find the token for the provided request token key
def _get_token(self, oauth_request, token_type='access'):
+ """Try to find the token for the provided request token key."""
token_field = oauth_request.get_parameter('oauth_token')
token = self.data_store.lookup_token(token_type, token_field)
if not token:
raise OAuthError('Invalid %s token: %s' % (token_type, token_field))
return token
+
+ def _get_verifier(self, oauth_request):
+ return oauth_request.get_parameter('oauth_verifier')
def _check_signature(self, oauth_request, consumer, token):
timestamp, nonce = oauth_request._get_timestamp_nonce()
@@ -389,29 +495,35 @@ def _check_signature(self, oauth_request, consumer, token):
signature = oauth_request.get_parameter('oauth_signature')
except:
raise OAuthError('Missing signature.')
- # validate the signature
- valid_sig = signature_method.check_signature(oauth_request, consumer, token, signature)
+ # Validate the signature.
+ valid_sig = signature_method.check_signature(oauth_request, consumer,
+ token, signature)
if not valid_sig:
- key, base = signature_method.build_signature_base_string(oauth_request, consumer, token)
- raise OAuthError('Invalid signature. Expected signature base string: %s' % base)
+ key, base = signature_method.build_signature_base_string(
+ oauth_request, consumer, token)
+ raise OAuthError('Invalid signature. Expected signature base '
+ 'string: %s' % base)
built = signature_method.build_signature(oauth_request, consumer, token)
def _check_timestamp(self, timestamp):
- # verify that timestamp is recentish
+ """Verify that timestamp is recentish."""
timestamp = int(timestamp)
now = int(time.time())
lapsed = now - timestamp
if lapsed > self.timestamp_threshold:
- raise OAuthError('Expired timestamp: given %d and now %s has a greater difference than threshold %d' % (timestamp, now, self.timestamp_threshold))
+ raise OAuthError('Expired timestamp: given %d and now %s has a '
+ 'greater difference than threshold %d' %
+ (timestamp, now, self.timestamp_threshold))
def _check_nonce(self, consumer, token, nonce):
- # verify that the nonce is uniqueish
+ """Verify that the nonce is uniqueish."""
nonce = self.data_store.lookup_nonce(consumer, token, nonce)
if nonce:
raise OAuthError('Nonce already used: %s' % str(nonce))
-# OAuthClient is a worker to attempt to execute a request
+
class OAuthClient(object):
+ """OAuthClient is a worker to attempt to execute a request."""
consumer = None
token = None
@@ -426,62 +538,65 @@ def get_token(self):
return self.token
def fetch_request_token(self, oauth_request):
- # -> OAuthToken
+ """-> OAuthToken."""
raise NotImplementedError
def fetch_access_token(self, oauth_request):
- # -> OAuthToken
+ """-> OAuthToken."""
raise NotImplementedError
def access_resource(self, oauth_request):
- # -> some protected resource
+ """-> Some protected resource."""
raise NotImplementedError
-# OAuthDataStore is a database abstraction used to lookup consumers and tokens
+
class OAuthDataStore(object):
+ """A database abstraction used to lookup consumers and tokens."""
def lookup_consumer(self, key):
- # -> OAuthConsumer
+ """-> OAuthConsumer."""
raise NotImplementedError
def lookup_token(self, oauth_consumer, token_type, token_token):
- # -> OAuthToken
+ """-> OAuthToken."""
raise NotImplementedError
- def lookup_nonce(self, oauth_consumer, oauth_token, nonce, timestamp):
- # -> OAuthToken
+ def lookup_nonce(self, oauth_consumer, oauth_token, nonce):
+ """-> OAuthToken."""
raise NotImplementedError
- def fetch_request_token(self, oauth_consumer):
- # -> OAuthToken
+ def fetch_request_token(self, oauth_consumer, oauth_callback):
+ """-> OAuthToken."""
raise NotImplementedError
- def fetch_access_token(self, oauth_consumer, oauth_token):
- # -> OAuthToken
+ def fetch_access_token(self, oauth_consumer, oauth_token, oauth_verifier):
+ """-> OAuthToken."""
raise NotImplementedError
def authorize_request_token(self, oauth_token, user):
- # -> OAuthToken
+ """-> OAuthToken."""
raise NotImplementedError
-# OAuthSignatureMethod is a strategy class that implements a signature method
+
class OAuthSignatureMethod(object):
+ """A strategy class that implements a signature method."""
def get_name(self):
- # -> str
+ """-> str."""
raise NotImplementedError
def build_signature_base_string(self, oauth_request, oauth_consumer, oauth_token):
- # -> str key, str raw
+ """-> str key, str raw."""
raise NotImplementedError
def build_signature(self, oauth_request, oauth_consumer, oauth_token):
- # -> str
+ """-> str."""
raise NotImplementedError
def check_signature(self, oauth_request, consumer, token, signature):
built = self.build_signature(oauth_request, consumer, token)
return built == signature
+
class OAuthSignatureMethod_HMAC_SHA1(OAuthSignatureMethod):
def get_name(self):
@@ -501,19 +616,21 @@ def build_signature_base_string(self, oauth_request, consumer, token):
return key, raw
def build_signature(self, oauth_request, consumer, token):
- # build the base signature string
- key, raw = self.build_signature_base_string(oauth_request, consumer, token)
+ """Builds the base signature string."""
+ key, raw = self.build_signature_base_string(oauth_request, consumer,
+ token)
- # hmac object
+ # HMAC object.
try:
import hashlib # 2.5
hashed = hmac.new(key, raw, hashlib.sha1)
except:
- import sha # deprecated
+ import sha # Deprecated
hashed = hmac.new(key, raw, sha)
- # calculate the digest base 64
- return base64.b64encode(hashed.digest())
+ # Calculate the digest base 64.
+ return binascii.b2a_base64(hashed.digest())[:-1]
+
class OAuthSignatureMethod_PLAINTEXT(OAuthSignatureMethod):
@@ -521,11 +638,13 @@ def get_name(self):
return 'PLAINTEXT'
def build_signature_base_string(self, oauth_request, consumer, token):
- # concatenate the consumer key and secret
- sig = escape(consumer.secret) + '&'
+ """Concatenates the consumer key and secret."""
+ sig = '%s&' % escape(consumer.secret)
if token:
sig = sig + escape(token.secret)
- return sig
+ return sig, sig
def build_signature(self, oauth_request, consumer, token):
- return self.build_signature_base_string(oauth_request, consumer, token)
+ key, raw = self.build_signature_base_string(oauth_request, consumer,
+ token)
+ return key
View
0  external/piston/piston/utils.py.orig
No changes.
View
260 external/piston/resource.py
@@ -6,6 +6,9 @@
from django.views.decorators.vary import vary_on_headers
from django.conf import settings
from django.core.mail import send_mail, EmailMessage
+from django.core.exceptions import ValidationError
+from django.db.models.query import QuerySet
+from django.http import Http404
from emitters import Emitter
from handler import typemapper
@@ -14,6 +17,8 @@
from utils import coerce_put_post, FormValidationError, HttpStatusCode
from utils import rc, format_error, translate_mime, MimerDataException
+CHALLENGE = object()
+
class Resource(object):
"""
Resource. Create one for your URL mappings, just
@@ -22,20 +27,23 @@ class Resource(object):
is an authentication handler. If not specified,
`NoAuthentication` will be used by default.
"""
- callmap = { 'GET': 'read', 'POST': 'create',
+ callmap = { 'GET': 'read', 'POST': 'create',
'PUT': 'update', 'DELETE': 'delete' }
-
+
def __init__(self, handler, authentication=None):
if not callable(handler):
raise AttributeError, "Handler not callable."
-
+
self.handler = handler()
-
+ self.csrf_exempt = getattr(self.handler, 'csrf_exempt', True)
+
if not authentication:
- self.authentication = NoAuthentication()
- else:
+ self.authentication = (NoAuthentication(),)
+ elif isinstance(authentication, (list, tuple)):
self.authentication = authentication
-
+ else:
+ self.authentication = (authentication,)
+
# Erroring
self.email_errors = getattr(settings, 'PISTON_EMAIL_ERRORS', True)
self.display_errors = getattr(settings, 'PISTON_DISPLAY_ERRORS', True)
@@ -52,12 +60,68 @@ def determine_emitter(self, request, *args, **kwargs):
that as well.
"""
em = kwargs.pop('emitter_format', None)
-
+
if not em:
em = request.GET.get('format', 'json')
return em
-
+
+ def form_validation_response(self, e):
+ """
+ Method to return form validation error information.
+ You will probably want to override this in your own
+ `Resource` subclass.
+ """
+ resp = rc.BAD_REQUEST
+ resp.write(' '+str(e.form.errors))
+ return resp
+