diff --git a/requests_oauthlib/oauth2_session.py b/requests_oauthlib/oauth2_session.py index 999183e2..7b1ebc8f 100644 --- a/requests_oauthlib/oauth2_session.py +++ b/requests_oauthlib/oauth2_session.py @@ -312,7 +312,8 @@ def refresh_token(self, token_url, refresh_token=None, body='', auth=None, self.token['refresh_token'] = refresh_token return self.token - def request(self, method, url, data=None, headers=None, withhold_token=False, **kwargs): + def request(self, method, url, data=None, headers=None, withhold_token=False, + client_id=None, client_secret=None, **kwargs): """Intercept all requests and add the OAuth 2 token if present.""" if not is_secure_transport(url): raise InsecureTransportError() @@ -332,7 +333,13 @@ def request(self, method, url, data=None, headers=None, withhold_token=False, ** if self.auto_refresh_url: log.debug('Auto refresh is set, attempting to refresh at %s.', self.auto_refresh_url) - token = self.refresh_token(self.auto_refresh_url, **kwargs) + auth = None + if client_id and client_secret: + log.debug('Encoding client_id "%s" with client_secret as Basic auth credentials.', client_id) + auth = requests.auth.HTTPBasicAuth(client_id, client_secret) + token = self.refresh_token( + self.auto_refresh_url, auth=auth, **kwargs + ) if self.token_updater: log.debug('Updating token to %s using %s.', token, self.token_updater) diff --git a/tests/test_oauth2_session.py b/tests/test_oauth2_session.py index bb5201e9..a222df00 100644 --- a/tests/test_oauth2_session.py +++ b/tests/test_oauth2_session.py @@ -2,6 +2,7 @@ import json import mock import time +from base64 import b64encode from copy import deepcopy try: from unittest2 import TestCase @@ -91,6 +92,8 @@ def test_refresh_token_request(self): del self.expired_token['expires_at'] def fake_refresh(r, **kwargs): + if "/refresh" in r.url: + self.assertNotIn("Authorization", r.headers) resp = mock.MagicMock() resp.text = json.dumps(self.token) return resp @@ -118,6 +121,23 @@ def token_updater(token): auth.send = fake_refresh auth.get('https://i.b') + def fake_refresh_with_auth(r, **kwargs): + if "/refresh" in r.url: + self.assertIn("Authorization", r.headers) + encoded = b64encode(b"foo:bar") + content = (b"Basic " + encoded).decode('latin1') + self.assertEqual(r.headers["Authorization"], content) + resp = mock.MagicMock() + resp.text = json.dumps(self.token) + return resp + + for client in self.clients: + auth = OAuth2Session(client=client, token=self.expired_token, + auto_refresh_url='https://i.b/refresh', + token_updater=token_updater) + auth.send = fake_refresh_with_auth + auth.get('https://i.b', client_id='foo', client_secret='bar') + @mock.patch("time.time", new=lambda: fake_time) def test_token_from_fragment(self): mobile = MobileApplicationClient(self.client_id)