Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions requests_oauthlib/oauth1_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def __init__(self, client_key,
verifier=None,
client_class=None,
force_include_body=False,
base_url=None,
**kwargs):
"""Construct the OAuth 1 session.

Expand Down Expand Up @@ -138,6 +139,8 @@ def __init__(self, client_key,
`requests_oauthlib.OAuth1` instead of the default
:param force_include_body: Always include the request body in the
signature creation.
:param base_url: An optional string to use as a prefix for all requests
from this session.
:param **kwargs: Additional keyword arguments passed to `OAuth1`
"""
super(OAuth1Session, self).__init__()
Expand All @@ -154,6 +157,7 @@ def __init__(self, client_key,
force_include_body=force_include_body,
**kwargs)
self.auth = self._client
self.base_url = base_url

def authorization_url(self, url, request_token=None, **kwargs):
"""Create an authorization URL by appending request_token and optional
Expand Down Expand Up @@ -334,3 +338,11 @@ def rebuild_auth(self, prepared_request, response):
prepared_request.headers.pop('Authorization', True)
prepared_request.prepare_auth(self.auth)
return

def prepare_request(self, request):
"""
If we have a `base_url`, prepend it to the URL.
"""
if self.base_url:
request.url = self.base_url + request.url
return super(OAuth1Session, self).prepare_request(request)
13 changes: 12 additions & 1 deletion requests_oauthlib/oauth2_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class OAuth2Session(requests.Session):

def __init__(self, client_id=None, client=None, auto_refresh_url=None,
auto_refresh_kwargs=None, scope=None, redirect_uri=None, token=None,
state=None, token_updater=None, **kwargs):
state=None, token_updater=None, base_url=None, **kwargs):
"""Construct a new OAuth 2 client session.

:param client_id: Client id obtained during registration
Expand All @@ -61,6 +61,8 @@ def __init__(self, client_id=None, client=None, auto_refresh_url=None,
set a TokenUpdated warning will be raised when a token
has been refreshed. This warning will carry the token
in its token argument.
:param base_url: An optional string to use as a prefix for all requests
from this session.
:param kwargs: Arguments to pass to the Session constructor.
"""
super(OAuth2Session, self).__init__(**kwargs)
Expand All @@ -75,6 +77,7 @@ def __init__(self, client_id=None, client=None, auto_refresh_url=None,
self.token_updater = token_updater
self._client = client or WebApplicationClient(client_id, token=token)
self._client._populate_attributes(token or {})
self.base_url = base_url

# Allow customizations for non compliant providers through various
# hooks to adjust requests and responses.
Expand Down Expand Up @@ -298,3 +301,11 @@ def register_compliance_hook(self, hook_type, hook):
raise ValueError('Hook type %s is not in %s.',
hook_type, self.compliance_hook)
self.compliance_hook[hook_type].add(hook)

def prepare_request(self, request):
"""
If we have a `base_url`, prepend it to the URL.
"""
if self.base_url:
request.url = self.base_url + request.url
return super(OAuth2Session, self).prepare_request(request)