From e483b0440eb32c702eb36738afcf8dcc080d1fee Mon Sep 17 00:00:00 2001 From: Ryan Barrett Date: Thu, 10 Nov 2011 00:53:13 -0800 Subject: [PATCH] check access token if provided --- README | 4 +++- fql.py | 10 ++++++++++ fql_test.py | 45 +++++++++++++++++++++++++++++++-------------- graph.py | 8 ++++++++ graph_test.py | 18 ++++++++++++++++-- oauth.py | 7 ++++--- oauth_test.py | 39 +++++++++++++++++++++------------------ testutil.py | 44 ++++++++++++++++++++------------------------ 8 files changed, 113 insertions(+), 62 deletions(-) diff --git a/README b/README index c7818fd..b51286a 100644 --- a/README +++ b/README @@ -24,6 +24,7 @@ at the `/...` endpoint. It supports: * aliases as well as ids * read access to all connection types except `insights`, `mutualfriends`, `payments`, `subscriptions`, and `Comment/likes` * multiple selection via `?ids=...` + * checks access token if provided * most error codes and messages [https://developers.facebook.com/docs/reference/fql/ FQL] is served at the @@ -32,8 +33,9 @@ at the `/...` endpoint. It supports: * read access to all tables except `insights` and `permissions` * indexable columns. returns an error if a non-indexable column is used in a `WHERE` clause. * all functions: `me(), now(), strlen(), substr(), strpos()` - * most error codes and messages + * checks access token if provided * JSON and XML output formats + * most error codes and messages [http://developers.facebook.com/docs/authentication/ OAuth authentication] is served at the `/dialog/oauth` and `/oauth/access_token` endpoints. It supports: diff --git a/fql.py b/fql.py index 915a7d5..a830c62 100644 --- a/fql.py +++ b/fql.py @@ -16,6 +16,7 @@ from sqlparse import tokens import webapp2 +import oauth import schemautil @@ -64,6 +65,10 @@ class MissingParamError(FqlError): code = -1 msg = 'The parameter %s is required' +class InvalidAccessTokenError(FqlError): + code = 190 + msg = 'Invalid access token signature.' + class Fql(object): """A parsed FQL statement. Just a thin wrapper around sqlparse.sql.Statement. @@ -260,6 +265,11 @@ def get(self): query = self.request.get(query_arg) if not query: raise MissingParamError(query_arg) + + token = self.request.get('access_token') + if token and not oauth.AccessTokenHandler.is_valid_token(self.conn, token): + raise InvalidAccessTokenError() + logging.debug('Received FQL query: %s' % query) fql = Fql(self.schema, query, self.me) diff --git a/fql_test.py b/fql_test.py index 59d9ac0..5e058ab 100755 --- a/fql_test.py +++ b/fql_test.py @@ -62,33 +62,37 @@ def setUp(self): super(FqlHandlerTest, self).setUp(fql.FqlHandler) insert_test_data(self.conn) - def expect_fql(self, fql, expected, format='json'): + def expect_fql(self, fql, expected, args=None): """Runs an FQL query and checks the response. Args: fql: string expected: list or dict that the JSON response should match - format: passed as the format URL query parameter + args: dict, extra query parameters """ - self.expect('/method/fql.query', expected, {'format': format, 'query': fql}) + full_args = {'format': 'json', 'query': fql} + if args: + full_args.update(args) + self.expect('/method/fql.query', expected, full_args) - def expect_error(self, query, error): + def expect_error(self, query, error, args=None): """Runs a query and checks that it returns the given error code and message. Args: fql: string error: expected error + args: dict, extra query parameters """ + request_args = {'format': 'json', 'query': query, 'method': 'fql.query'} + if args: + request_args.update(args) + expected = { 'error_code': error.code, 'error_msg': error.msg, - 'request_args': [ - # order here matters, since the list is compared by value. - {'key': 'query', 'value': query}, - {'key': 'format', 'value': 'json'}, - {'key': 'method', 'value': 'fql.query'}, - ]} - self.expect_fql(query, expected) + 'request_args': [{'key': k, 'value': v} for k, v in request_args.items()], + } + self.expect_fql(query, expected, args=args) def test_example_data(self): dataset = testutil.maybe_read(schemautil.FqlDataset) @@ -199,7 +203,7 @@ def test_xml_format(self): %s """ % self.ME, - format='xml') + args={'format': 'xml'}) def test_format_defaults_to_xml(self): for format in ('foo', ''): @@ -211,7 +215,7 @@ def test_format_defaults_to_xml(self): alice """, - format=format) + args={'format': format}) def test_xml_format_error(self): self.expect_fql( @@ -235,7 +239,7 @@ def test_xml_format_error(self): """, - format='xml') + args={'format': 'xml'}) def test_no_select_error(self): self.expect_error('INSERT id FROM profile WHERE id = me()', @@ -265,6 +269,19 @@ def test_sqlite_error(self): def test_no_query_error(self): self.expect_error('', fql.MissingParamError('query')) + def test_access_token(self): + self.conn.execute( + 'INSERT INTO oauth_access_tokens(code, token) VALUES("asdf", "qwert")') + self.conn.commit() + self.expect_fql('SELECT username FROM profile WHERE id = me()', + [{'username': 'alice'}], + args={'access_token': 'qwert'}) + + def test_invalid_access_token(self): + self.expect_error('SELECT username FROM profile WHERE id = me()', + fql.InvalidAccessTokenError(), + args={'access_token': 'bad'}) + if __name__ == '__main__': unittest.main() diff --git a/graph.py b/graph.py index 8abbf3b..7d508d4 100644 --- a/graph.py +++ b/graph.py @@ -15,6 +15,7 @@ import webapp2 +import oauth import schemautil # the one connection that returns an HTTP 302 redirect instead of a normal @@ -85,6 +86,9 @@ class ObjectsNotFoundError(GraphError): class AccessTokenError(JsonError): message = 'An access token is required to request this resource.' +class ValidationError(JsonError): + message = 'Error validating application.' + class AliasNotFoundError(JsonError): status = 404 message = '(#803) Some of the aliases you requested do not exist: %s' @@ -175,6 +179,10 @@ def get(self, id, connection): id = None try: + token = self.request.get('access_token') + if token and not oauth.AccessTokenHandler.is_valid_token(self.conn, token): + raise ValidationError() + namedict = self.prepare_ids(id) if connection: diff --git a/graph_test.py b/graph_test.py index 84f0f4c..f467153 100755 --- a/graph_test.py +++ b/graph_test.py @@ -56,12 +56,12 @@ def expect_redirect(self, path, redirect_to): self.assertEquals(302, resp.status_int) self.assertEquals(redirect_to, resp.headers['Location']) - def expect_error(self, path, exception): + def expect_error(self, path, exception, args=None): """Args: path: string exception: expected instance of a GraphError subclass """ - self.expect(path, exception.message, expected_status=exception.status) + self.expect(path, exception.message, expected_status=exception.status, args=args) class ObjectTest(TestBase): @@ -120,6 +120,20 @@ def test_ids_always_prefers_alias(self): self.expect('/?ids=alice,1', {'alice': self.alice}) self.expect('/?ids=1,alice', {'alice': self.alice}) + def test_access_token(self): + self.conn.execute( + 'INSERT INTO oauth_access_tokens(code, token) VALUES("asdf", "qwert")') + self.conn.commit() + + token = {'access_token': 'qwert'} + self.expect('/alice', self.alice, args=token) + self.expect('/alice/albums', self.alice_albums, args=token) + + def test_invalid_access_token(self): + for path in '/alice', '/alice/albums': + self.expect_error(path, graph.ValidationError(), + args={'access_token': 'bad'}) + class ConnectionTest(TestBase): diff --git a/oauth.py b/oauth.py index f3327fd..1812d39 100644 --- a/oauth.py +++ b/oauth.py @@ -158,10 +158,11 @@ class AccessTokenHandler(BaseHandler): ROUTES = [(r'/oauth/access_token/?', 'oauth.AccessTokenHandler')] - def is_valid_token(self, access_token): + @staticmethod + def is_valid_token(conn, access_token): """Returns True if the given access token is valid, False otherwise.""" - cursor = self.conn.execute('SELECT token FROM oauth_access_tokens WHERE token = ?', - (access_token,)) + cursor = conn.execute('SELECT token FROM oauth_access_tokens WHERE token = ?', + (access_token,)) return cursor.fetchone() is not None def get(self): diff --git a/oauth_test.py b/oauth_test.py index a8b0030..986fab9 100755 --- a/oauth_test.py +++ b/oauth_test.py @@ -24,11 +24,6 @@ class OAuthHandlerTest(testutil.HandlerTest): def setUp(self): super(OAuthHandlerTest, self).setUp(oauth.AuthCodeHandler, oauth.AccessTokenHandler) - self.handler = oauth.AccessTokenHandler() - self.auth_code_args = { - 'client_id': '123', - 'redirect_uri': 'http://x/y', - } self.access_token_args = { 'client_id': '123', 'client_secret': '456', @@ -36,10 +31,18 @@ def setUp(self): 'code': None # filled in by individual tests } - def expect_oauth_redirect(self, redirect_re='http://x/y\?code=(.+)'): + def expect_oauth_redirect(self, redirect_re='http://x/y\?code=(.+)', + args=None): """Requests an access code, checks the redirect, and returns the code. """ - resp = self.get_response('/dialog/oauth', args=self.auth_code_args) + full_args = { + 'client_id': '123', + 'redirect_uri': 'http://x/y', + } + if args: + full_args.update(args) + + resp = self.get_response('/dialog/oauth', args=full_args) self.assertEquals('302 Moved Temporarily', resp.status) location = resp.headers['Location'] match = re.match(redirect_re, location) @@ -50,12 +53,12 @@ def test_auth_code(self): self.expect_oauth_redirect() def test_auth_code_with_redirect_uri_with_params(self): - self.auth_code_args['redirect_uri'] = 'http://x/y?foo=bar' - self.expect_oauth_redirect('http://x/y\?code=(.+)&foo=bar') + self.expect_oauth_redirect('http://x/y\?code=(.+)&foo=bar', + args={'redirect_uri': 'http://x/y?foo=bar'}) def test_auth_code_with_state(self): - self.auth_code_args['state'] = 'my_state' - self.expect_oauth_redirect('http://x/y\?state=my_state&code=(.+)') + self.expect_oauth_redirect('http://x/y\?state=my_state&code=(.+)', + args={'state': 'my_state'}) def test_auth_code_missing_args(self): for arg in ('client_id', 'redirect_uri'): @@ -71,7 +74,7 @@ def test_access_token(self): args = urlparse.parse_qs(resp.body) self.assertEquals(2, len(args), `args`) self.assertEquals('999999', args['expires'][0]) - assert self.handler.is_valid_token(args['access_token'][0]) + assert oauth.AccessTokenHandler.is_valid_token(self.conn, args['access_token'][0]) def test_access_token_nonexistent_auth_code(self): self.access_token_args['code'] = 'xyz' @@ -79,8 +82,8 @@ def test_access_token_nonexistent_auth_code(self): assert 'not found' in resp.body def test_nonexistent_access_token(self): - self.assertFalse(self.handler.is_valid_token('')) - self.assertFalse(self.handler.is_valid_token('xyz')) + self.assertFalse(oauth.AccessTokenHandler.is_valid_token(self.conn, '')) + self.assertFalse(oauth.AccessTokenHandler.is_valid_token(self.conn, 'xyz')) def test_access_token_missing_args(self): for arg in ('client_id', 'client_secret'): @@ -105,13 +108,13 @@ def test_app_login(self): self.access_token_args['grant_type'] = 'client_credentials' resp = self.get_response('/oauth/access_token', args=self.access_token_args) args = urlparse.parse_qs(resp.body) - assert self.handler.is_valid_token(args['access_token'][0]) + assert oauth.AccessTokenHandler.is_valid_token(self.conn, args['access_token'][0]) def test_client_side_flow(self): - self.auth_code_args['response_type'] = 'token' token = self.expect_oauth_redirect( - 'http://x/y#access_token=(.+)&expires_in=999999') - assert self.handler.is_valid_token(token) + 'http://x/y#access_token=(.+)&expires_in=999999', + args={'response_type': 'token'}) + assert oauth.AccessTokenHandler.is_valid_token(self.conn, token) if __name__ == '__main__': diff --git a/testutil.py b/testutil.py index ae4339d..71c5150 100644 --- a/testutil.py +++ b/testutil.py @@ -1,6 +1,4 @@ """Unit test utilities. - -TODO: put a utility in server that combines all routes, then use that in all tests """ __author__ = ['Ryan Barrett '] @@ -63,6 +61,7 @@ def expect(self, path, expected, args=None, expected_status=200): expected_status: integer, expected HTTP response status """ response = None + results = None try: response = self.get_response(path, args=args) self.assertEquals(expected_status, response.status_int) @@ -73,6 +72,7 @@ def expect(self, path, expected, args=None, expected_status=200): results = json.loads(response) if not isinstance(expected, list): expected = [expected] + if not isinstance(results, list): results = [results] expected.sort() results.sort() @@ -82,7 +82,7 @@ def expect(self, path, expected, args=None, expected_status=200): except: print >> sys.stderr, '\nquery: %s %s' % (path, args) print >> sys.stderr, 'expected: %r' % expected - print >> sys.stderr, 'received: %r' % response + print >> sys.stderr, 'received: %r' % results if results else response raise def get_response(self, path, args=None): @@ -90,27 +90,23 @@ def get_response(self, path, args=None): path = '%s?%s' % (path, urllib.urlencode(args)) return self.app.get_response(path) + # TODO: for the love of god, refactor, or even better, find a more supported + # utility somewhere else. def assert_dict_equals(self, expected, actual): msgs = [] - for key in set(expected.keys()) | set(actual.keys()): - e = expected.get(key, None) - a = actual.get(key, None) - if isinstance(e, re._pattern_type): - if not re.match(e, a): - msgs.append("%s: %r doesn't match %s" % (key, e, a)) - elif isinstance(e, dict) and isinstance(a, dict): - self.assert_dict_equals(e, a) - # this is only here because we don't exactly match FB in whether we return - # or omit some "empty" values, e.g. 0, null, ''. see the TODO in graph_on_fql.py. - elif not e and not a: - continue - else: - if isinstance(e, list) and isinstance(a, list): - e.sort() - a.sort() - if e != a: - msgs.append('%s: %r != %r' % (key, e, a)) - - if msgs: - self.fail('\n'.join(msgs)) + if isinstance(expected, re._pattern_type): + if not re.match(expected, actual): + self.fail("%r doesn't match %s" % (expected, actual)) + # this is only here because we don't exactly match FB in whether we return + # or omit some "empty" values, e.g. 0, null, ''. see the TODO in graph_on_fql.py. + elif not expected and not actual: + return True + elif isinstance(expected, dict) and isinstance(actual, dict): + for key in set(expected.keys()) | set(actual.keys()): + self.assert_dict_equals(expected.get(key), actual.get(key)) + else: + if isinstance(expected, list) and isinstance(actual, list): + expected.sort() + actual.sort() + self.assertEquals(expected, actual)