Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions requests_oauthlib/oauth2_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,8 @@ def refresh_token(self, token_url, refresh_token=None, body='', auth=None,
self.token['refresh_token'] = refresh_token
return self.token

def request(self, method, url, data=None, headers=None, withhold_token=False, **kwargs):
def request(self, method, url, data=None, headers=None, withhold_token=False,
client_id=None, client_secret=None, **kwargs):
"""Intercept all requests and add the OAuth 2 token if present."""
if not is_secure_transport(url):
raise InsecureTransportError()
Expand All @@ -332,7 +333,13 @@ def request(self, method, url, data=None, headers=None, withhold_token=False, **
if self.auto_refresh_url:
log.debug('Auto refresh is set, attempting to refresh at %s.',
self.auto_refresh_url)
token = self.refresh_token(self.auto_refresh_url, **kwargs)
auth = None
if client_id and client_secret:
log.debug('Encoding client_id "%s" with client_secret as Basic auth credentials.', client_id)
auth = requests.auth.HTTPBasicAuth(client_id, client_secret)
token = self.refresh_token(
self.auto_refresh_url, auth=auth, **kwargs
)
if self.token_updater:
log.debug('Updating token to %s using %s.',
token, self.token_updater)
Expand Down
20 changes: 20 additions & 0 deletions tests/test_oauth2_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import mock
import time
from base64 import b64encode
from copy import deepcopy
try:
from unittest2 import TestCase
Expand Down Expand Up @@ -91,6 +92,8 @@ def test_refresh_token_request(self):
del self.expired_token['expires_at']

def fake_refresh(r, **kwargs):
if "/refresh" in r.url:
self.assertNotIn("Authorization", r.headers)
resp = mock.MagicMock()
resp.text = json.dumps(self.token)
return resp
Expand Down Expand Up @@ -118,6 +121,23 @@ def token_updater(token):
auth.send = fake_refresh
auth.get('https://i.b')

def fake_refresh_with_auth(r, **kwargs):
if "/refresh" in r.url:
self.assertIn("Authorization", r.headers)
encoded = b64encode(b"foo:bar")
content = (b"Basic " + encoded).decode('latin1')
self.assertEqual(r.headers["Authorization"], content)
resp = mock.MagicMock()
resp.text = json.dumps(self.token)
return resp

for client in self.clients:
auth = OAuth2Session(client=client, token=self.expired_token,
auto_refresh_url='https://i.b/refresh',
token_updater=token_updater)
auth.send = fake_refresh_with_auth
auth.get('https://i.b', client_id='foo', client_secret='bar')

@mock.patch("time.time", new=lambda: fake_time)
def test_token_from_fragment(self):
mobile = MobileApplicationClient(self.client_id)
Expand Down