Skip to content
This repository has been archived by the owner on Jul 4, 2021. It is now read-only.

Commit

Permalink
check access token if provided
Browse files Browse the repository at this point in the history
  • Loading branch information
Ryan Barrett committed Nov 10, 2011
1 parent dd5aa19 commit e483b04
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 62 deletions.
4 changes: 3 additions & 1 deletion README
Expand Up @@ -24,6 +24,7 @@ at the `/...` endpoint. It supports:
* aliases as well as ids
* read access to all connection types except `insights`, `mutualfriends`, `payments`, `subscriptions`, and `Comment/likes`
* multiple selection via `?ids=...`
* checks access token if provided
* most error codes and messages

[https://developers.facebook.com/docs/reference/fql/ FQL] is served at the
Expand All @@ -32,8 +33,9 @@ at the `/...` endpoint. It supports:
* read access to all tables except `insights` and `permissions`
* indexable columns. returns an error if a non-indexable column is used in a `WHERE` clause.
* all functions: `me(), now(), strlen(), substr(), strpos()`
* most error codes and messages
* checks access token if provided
* JSON and XML output formats
* most error codes and messages

[http://developers.facebook.com/docs/authentication/ OAuth authentication] is
served at the `/dialog/oauth` and `/oauth/access_token` endpoints. It supports:
Expand Down
10 changes: 10 additions & 0 deletions fql.py
Expand Up @@ -16,6 +16,7 @@
from sqlparse import tokens
import webapp2

import oauth
import schemautil


Expand Down Expand Up @@ -64,6 +65,10 @@ class MissingParamError(FqlError):
code = -1
msg = 'The parameter %s is required'

class InvalidAccessTokenError(FqlError):
code = 190
msg = 'Invalid access token signature.'


class Fql(object):
"""A parsed FQL statement. Just a thin wrapper around sqlparse.sql.Statement.
Expand Down Expand Up @@ -260,6 +265,11 @@ def get(self):
query = self.request.get(query_arg)
if not query:
raise MissingParamError(query_arg)

token = self.request.get('access_token')
if token and not oauth.AccessTokenHandler.is_valid_token(self.conn, token):
raise InvalidAccessTokenError()

logging.debug('Received FQL query: %s' % query)

fql = Fql(self.schema, query, self.me)
Expand Down
45 changes: 31 additions & 14 deletions fql_test.py
Expand Up @@ -62,33 +62,37 @@ def setUp(self):
super(FqlHandlerTest, self).setUp(fql.FqlHandler)
insert_test_data(self.conn)

def expect_fql(self, fql, expected, format='json'):
def expect_fql(self, fql, expected, args=None):
"""Runs an FQL query and checks the response.
Args:
fql: string
expected: list or dict that the JSON response should match
format: passed as the format URL query parameter
args: dict, extra query parameters
"""
self.expect('/method/fql.query', expected, {'format': format, 'query': fql})
full_args = {'format': 'json', 'query': fql}
if args:
full_args.update(args)
self.expect('/method/fql.query', expected, full_args)

def expect_error(self, query, error):
def expect_error(self, query, error, args=None):
"""Runs a query and checks that it returns the given error code and message.
Args:
fql: string
error: expected error
args: dict, extra query parameters
"""
request_args = {'format': 'json', 'query': query, 'method': 'fql.query'}
if args:
request_args.update(args)

expected = {
'error_code': error.code,
'error_msg': error.msg,
'request_args': [
# order here matters, since the list is compared by value.
{'key': 'query', 'value': query},
{'key': 'format', 'value': 'json'},
{'key': 'method', 'value': 'fql.query'},
]}
self.expect_fql(query, expected)
'request_args': [{'key': k, 'value': v} for k, v in request_args.items()],
}
self.expect_fql(query, expected, args=args)

def test_example_data(self):
dataset = testutil.maybe_read(schemautil.FqlDataset)
Expand Down Expand Up @@ -199,7 +203,7 @@ def test_xml_format(self):
<id>%s</id>
</profile>
</fql_query_response>""" % self.ME,
format='xml')
args={'format': 'xml'})

def test_format_defaults_to_xml(self):
for format in ('foo', ''):
Expand All @@ -211,7 +215,7 @@ def test_format_defaults_to_xml(self):
<username>alice</username>
</profile>
</fql_query_response>""",
format=format)
args={'format': format})

def test_xml_format_error(self):
self.expect_fql(
Expand All @@ -235,7 +239,7 @@ def test_xml_format_error(self):
</arg>
</request_args>
</error_response>""",
format='xml')
args={'format': 'xml'})

def test_no_select_error(self):
self.expect_error('INSERT id FROM profile WHERE id = me()',
Expand Down Expand Up @@ -265,6 +269,19 @@ def test_sqlite_error(self):
def test_no_query_error(self):
self.expect_error('', fql.MissingParamError('query'))

def test_access_token(self):
self.conn.execute(
'INSERT INTO oauth_access_tokens(code, token) VALUES("asdf", "qwert")')
self.conn.commit()
self.expect_fql('SELECT username FROM profile WHERE id = me()',
[{'username': 'alice'}],
args={'access_token': 'qwert'})

def test_invalid_access_token(self):
self.expect_error('SELECT username FROM profile WHERE id = me()',
fql.InvalidAccessTokenError(),
args={'access_token': 'bad'})


if __name__ == '__main__':
unittest.main()
8 changes: 8 additions & 0 deletions graph.py
Expand Up @@ -15,6 +15,7 @@

import webapp2

import oauth
import schemautil

# the one connection that returns an HTTP 302 redirect instead of a normal
Expand Down Expand Up @@ -85,6 +86,9 @@ class ObjectsNotFoundError(GraphError):
class AccessTokenError(JsonError):
message = 'An access token is required to request this resource.'

class ValidationError(JsonError):
message = 'Error validating application.'

class AliasNotFoundError(JsonError):
status = 404
message = '(#803) Some of the aliases you requested do not exist: %s'
Expand Down Expand Up @@ -175,6 +179,10 @@ def get(self, id, connection):
id = None

try:
token = self.request.get('access_token')
if token and not oauth.AccessTokenHandler.is_valid_token(self.conn, token):
raise ValidationError()

namedict = self.prepare_ids(id)

if connection:
Expand Down
18 changes: 16 additions & 2 deletions graph_test.py
Expand Up @@ -56,12 +56,12 @@ def expect_redirect(self, path, redirect_to):
self.assertEquals(302, resp.status_int)
self.assertEquals(redirect_to, resp.headers['Location'])

def expect_error(self, path, exception):
def expect_error(self, path, exception, args=None):
"""Args:
path: string
exception: expected instance of a GraphError subclass
"""
self.expect(path, exception.message, expected_status=exception.status)
self.expect(path, exception.message, expected_status=exception.status, args=args)


class ObjectTest(TestBase):
Expand Down Expand Up @@ -120,6 +120,20 @@ def test_ids_always_prefers_alias(self):
self.expect('/?ids=alice,1', {'alice': self.alice})
self.expect('/?ids=1,alice', {'alice': self.alice})

def test_access_token(self):
self.conn.execute(
'INSERT INTO oauth_access_tokens(code, token) VALUES("asdf", "qwert")')
self.conn.commit()

token = {'access_token': 'qwert'}
self.expect('/alice', self.alice, args=token)
self.expect('/alice/albums', self.alice_albums, args=token)

def test_invalid_access_token(self):
for path in '/alice', '/alice/albums':
self.expect_error(path, graph.ValidationError(),
args={'access_token': 'bad'})


class ConnectionTest(TestBase):

Expand Down
7 changes: 4 additions & 3 deletions oauth.py
Expand Up @@ -158,10 +158,11 @@ class AccessTokenHandler(BaseHandler):

ROUTES = [(r'/oauth/access_token/?', 'oauth.AccessTokenHandler')]

def is_valid_token(self, access_token):
@staticmethod
def is_valid_token(conn, access_token):
"""Returns True if the given access token is valid, False otherwise."""
cursor = self.conn.execute('SELECT token FROM oauth_access_tokens WHERE token = ?',
(access_token,))
cursor = conn.execute('SELECT token FROM oauth_access_tokens WHERE token = ?',
(access_token,))
return cursor.fetchone() is not None

def get(self):
Expand Down
39 changes: 21 additions & 18 deletions oauth_test.py
Expand Up @@ -24,22 +24,25 @@ class OAuthHandlerTest(testutil.HandlerTest):
def setUp(self):
super(OAuthHandlerTest, self).setUp(oauth.AuthCodeHandler,
oauth.AccessTokenHandler)
self.handler = oauth.AccessTokenHandler()
self.auth_code_args = {
'client_id': '123',
'redirect_uri': 'http://x/y',
}
self.access_token_args = {
'client_id': '123',
'client_secret': '456',
'redirect_uri': 'http://x/y',
'code': None # filled in by individual tests
}

def expect_oauth_redirect(self, redirect_re='http://x/y\?code=(.+)'):
def expect_oauth_redirect(self, redirect_re='http://x/y\?code=(.+)',
args=None):
"""Requests an access code, checks the redirect, and returns the code.
"""
resp = self.get_response('/dialog/oauth', args=self.auth_code_args)
full_args = {
'client_id': '123',
'redirect_uri': 'http://x/y',
}
if args:
full_args.update(args)

resp = self.get_response('/dialog/oauth', args=full_args)
self.assertEquals('302 Moved Temporarily', resp.status)
location = resp.headers['Location']
match = re.match(redirect_re, location)
Expand All @@ -50,12 +53,12 @@ def test_auth_code(self):
self.expect_oauth_redirect()

def test_auth_code_with_redirect_uri_with_params(self):
self.auth_code_args['redirect_uri'] = 'http://x/y?foo=bar'
self.expect_oauth_redirect('http://x/y\?code=(.+)&foo=bar')
self.expect_oauth_redirect('http://x/y\?code=(.+)&foo=bar',
args={'redirect_uri': 'http://x/y?foo=bar'})

def test_auth_code_with_state(self):
self.auth_code_args['state'] = 'my_state'
self.expect_oauth_redirect('http://x/y\?state=my_state&code=(.+)')
self.expect_oauth_redirect('http://x/y\?state=my_state&code=(.+)',
args={'state': 'my_state'})

def test_auth_code_missing_args(self):
for arg in ('client_id', 'redirect_uri'):
Expand All @@ -71,16 +74,16 @@ def test_access_token(self):
args = urlparse.parse_qs(resp.body)
self.assertEquals(2, len(args), `args`)
self.assertEquals('999999', args['expires'][0])
assert self.handler.is_valid_token(args['access_token'][0])
assert oauth.AccessTokenHandler.is_valid_token(self.conn, args['access_token'][0])

def test_access_token_nonexistent_auth_code(self):
self.access_token_args['code'] = 'xyz'
resp = self.get_response('/oauth/access_token/', args=self.access_token_args)
assert 'not found' in resp.body

def test_nonexistent_access_token(self):
self.assertFalse(self.handler.is_valid_token(''))
self.assertFalse(self.handler.is_valid_token('xyz'))
self.assertFalse(oauth.AccessTokenHandler.is_valid_token(self.conn, ''))
self.assertFalse(oauth.AccessTokenHandler.is_valid_token(self.conn, 'xyz'))

def test_access_token_missing_args(self):
for arg in ('client_id', 'client_secret'):
Expand All @@ -105,13 +108,13 @@ def test_app_login(self):
self.access_token_args['grant_type'] = 'client_credentials'
resp = self.get_response('/oauth/access_token', args=self.access_token_args)
args = urlparse.parse_qs(resp.body)
assert self.handler.is_valid_token(args['access_token'][0])
assert oauth.AccessTokenHandler.is_valid_token(self.conn, args['access_token'][0])

def test_client_side_flow(self):
self.auth_code_args['response_type'] = 'token'
token = self.expect_oauth_redirect(
'http://x/y#access_token=(.+)&expires_in=999999')
assert self.handler.is_valid_token(token)
'http://x/y#access_token=(.+)&expires_in=999999',
args={'response_type': 'token'})
assert oauth.AccessTokenHandler.is_valid_token(self.conn, token)


if __name__ == '__main__':
Expand Down
44 changes: 20 additions & 24 deletions testutil.py
@@ -1,6 +1,4 @@
"""Unit test utilities.
TODO: put a utility in server that combines all routes, then use that in all tests
"""

__author__ = ['Ryan Barrett <mockfacebook@ryanb.org>']
Expand Down Expand Up @@ -63,6 +61,7 @@ def expect(self, path, expected, args=None, expected_status=200):
expected_status: integer, expected HTTP response status
"""
response = None
results = None
try:
response = self.get_response(path, args=args)
self.assertEquals(expected_status, response.status_int)
Expand All @@ -73,6 +72,7 @@ def expect(self, path, expected, args=None, expected_status=200):
results = json.loads(response)
if not isinstance(expected, list):
expected = [expected]
if not isinstance(results, list):
results = [results]
expected.sort()
results.sort()
Expand All @@ -82,35 +82,31 @@ def expect(self, path, expected, args=None, expected_status=200):
except:
print >> sys.stderr, '\nquery: %s %s' % (path, args)
print >> sys.stderr, 'expected: %r' % expected
print >> sys.stderr, 'received: %r' % response
print >> sys.stderr, 'received: %r' % results if results else response
raise

def get_response(self, path, args=None):
if args:
path = '%s?%s' % (path, urllib.urlencode(args))
return self.app.get_response(path)

# TODO: for the love of god, refactor, or even better, find a more supported
# utility somewhere else.
def assert_dict_equals(self, expected, actual):
msgs = []

for key in set(expected.keys()) | set(actual.keys()):
e = expected.get(key, None)
a = actual.get(key, None)
if isinstance(e, re._pattern_type):
if not re.match(e, a):
msgs.append("%s: %r doesn't match %s" % (key, e, a))
elif isinstance(e, dict) and isinstance(a, dict):
self.assert_dict_equals(e, a)
# this is only here because we don't exactly match FB in whether we return
# or omit some "empty" values, e.g. 0, null, ''. see the TODO in graph_on_fql.py.
elif not e and not a:
continue
else:
if isinstance(e, list) and isinstance(a, list):
e.sort()
a.sort()
if e != a:
msgs.append('%s: %r != %r' % (key, e, a))

if msgs:
self.fail('\n'.join(msgs))
if isinstance(expected, re._pattern_type):
if not re.match(expected, actual):
self.fail("%r doesn't match %s" % (expected, actual))
# this is only here because we don't exactly match FB in whether we return
# or omit some "empty" values, e.g. 0, null, ''. see the TODO in graph_on_fql.py.
elif not expected and not actual:
return True
elif isinstance(expected, dict) and isinstance(actual, dict):
for key in set(expected.keys()) | set(actual.keys()):
self.assert_dict_equals(expected.get(key), actual.get(key))
else:
if isinstance(expected, list) and isinstance(actual, list):
expected.sort()
actual.sort()
self.assertEquals(expected, actual)

0 comments on commit e483b04

Please sign in to comment.