Skip to content

Commit

Permalink
black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
jtroussard committed Oct 1, 2019
1 parent c04c01c commit 2ad25f7
Showing 1 changed file with 47 additions and 161 deletions.
208 changes: 47 additions & 161 deletions requests_oauthlib/oauth2_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,9 @@
import logging

from oauthlib.common import generate_token, urldecode
from oauthlib.oauth2 import (
WebApplicationClient,
InsecureTransportError,
)
from oauthlib.oauth2 import WebApplicationClient, InsecureTransportError
from oauthlib.oauth2 import LegacyApplicationClient
from oauthlib.oauth2 import (
TokenExpiredError,
is_secure_transport,
)
from oauthlib.oauth2 import TokenExpiredError, is_secure_transport
import requests

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -81,9 +75,7 @@ def __init__(
:param kwargs: Arguments to pass to the Session constructor.
"""
super(OAuth2Session, self).__init__(**kwargs)
self._client = client or WebApplicationClient(
client_id, token=token
)
self._client = client or WebApplicationClient(client_id, token=token)
self.token = token or {}
self.scope = scope
self.redirect_uri = redirect_uri
Expand All @@ -109,15 +101,10 @@ def new_state(self):
"""Generates a state string to be used in authorizations."""
try:
self._state = self.state()
log.debug(
"Generated new state %s.", self._state
)
log.debug("Generated new state %s.", self._state)
except TypeError:
self._state = self.state
log.debug(
"Re-using previously supplied state %s.",
self._state,
)
log.debug("Re-using previously supplied state %s.", self._state)
return self._state

@property
Expand Down Expand Up @@ -253,14 +240,11 @@ def fetch_token(
authorization_response, state=self._state
)
code = self._client.code
elif not code and isinstance(
self._client, WebApplicationClient
):
elif not code and isinstance(self._client, WebApplicationClient):
code = self._client.code
if not code:
raise ValueError(
"Please supply either code or "
"authorization_response parameters."
"Please supply either code or " "authorization_response parameters."
)

# Earlier versions of this library build an HTTPBasicAuth header out of
Expand All @@ -275,9 +259,7 @@ def fetch_token(
# 4.3.2 - Resource Owner Password Credentials Grant
# https://tools.ietf.org/html/rfc6749#section-4.3.2

if isinstance(
self._client, LegacyApplicationClient
):
if isinstance(self._client, LegacyApplicationClient):
if username is None:
raise ValueError(
"`LegacyApplicationClient` requires both the "
Expand Down Expand Up @@ -319,14 +301,8 @@ def fetch_token(
"as Basic auth credentials.",
client_id,
)
client_secret = (
client_secret
if client_secret is not None
else ""
)
auth = requests.auth.HTTPBasicAuth(
client_id, client_secret
)
client_secret = client_secret if client_secret is not None else ""
auth = requests.auth.HTTPBasicAuth(client_id, client_secret)

if include_client_id:
# this was pulled out of the params
Expand All @@ -349,15 +325,13 @@ def fetch_token(
self.token = {}
request_kwargs = {}
if method.upper() == "POST":
request_kwargs[
"params" if force_querystring else "data"
] = dict(urldecode(body))
request_kwargs["params" if force_querystring else "data"] = dict(
urldecode(body)
)
elif method.upper() == "GET":
request_kwargs["params"] = dict(urldecode(body))
else:
raise ValueError(
"The method kwarg must be POST or GET."
)
raise ValueError("The method kwarg must be POST or GET.")

r = self.request(
method=method,
Expand All @@ -370,37 +344,20 @@ def fetch_token(
**request_kwargs
)

log.debug(
"Request to fetch token completed with status %s.",
r.status_code,
)
log.debug("Request to fetch token completed with status %s.", r.status_code)
log.debug("Request url was %s", r.request.url)
log.debug(
"Request headers were %s", r.request.headers
)
log.debug("Request headers were %s", r.request.headers)
log.debug("Request body was %s", r.request.body)
log.debug(
"Response headers were %s and content %s.",
r.headers,
r.text,
)
log.debug("Response headers were %s and content %s.", r.headers, r.text)
log.debug(
"Invoking %d token response hooks.",
len(
self.compliance_hook[
"access_token_response"
]
),
len(self.compliance_hook["access_token_response"]),
)
for hook in self.compliance_hook[
"access_token_response"
]:
for hook in self.compliance_hook["access_token_response"]:
log.debug("Invoking hook %s.", hook)
r = hook(r)

self._client.parse_request_body_response(
r.text, scope=self.scope
)
self._client.parse_request_body_response(r.text, scope=self.scope)
self.token = self._client.token
log.debug("Obtained token %s.", self.token)
return self.token
Expand Down Expand Up @@ -444,38 +401,26 @@ def refresh_token(
:return: A token dict
"""
if not token_url:
raise ValueError(
"No token endpoint set for auto_refresh."
)
raise ValueError("No token endpoint set for auto_refresh.")

if not is_secure_transport(token_url):
raise InsecureTransportError()

refresh_token = refresh_token or self.token.get(
"refresh_token"
)
refresh_token = refresh_token or self.token.get("refresh_token")

log.debug(
"Adding auto refresh key word arguments %s.",
self.auto_refresh_kwargs,
"Adding auto refresh key word arguments %s.", self.auto_refresh_kwargs
)
kwargs.update(self.auto_refresh_kwargs)
body = self._client.prepare_refresh_body(
body=body,
refresh_token=refresh_token,
scope=self.scope,
**kwargs
)
log.debug(
"Prepared refresh token request body %s", body
body=body, refresh_token=refresh_token, scope=self.scope, **kwargs
)
log.debug("Prepared refresh token request body %s", body)

if headers is None:
headers = {
"Accept": "application/json",
"Content-Type": (
"application/x-www-form-urlencoded;charset=UTF-8"
),
"Content-Type": ("application/x-www-form-urlencoded;charset=UTF-8"),
}

r = self.post(
Expand All @@ -488,36 +433,19 @@ def refresh_token(
withhold_token=True,
proxies=proxies,
)
log.debug(
"Request to refresh token completed with status %s.",
r.status_code,
)
log.debug(
"Response headers were %s and content %s.",
r.headers,
r.text,
)
log.debug("Request to refresh token completed with status %s.", r.status_code)
log.debug("Response headers were %s and content %s.", r.headers, r.text)
log.debug(
"Invoking %d token response hooks.",
len(
self.compliance_hook[
"refresh_token_response"
]
),
len(self.compliance_hook["refresh_token_response"]),
)
for hook in self.compliance_hook[
"refresh_token_response"
]:
for hook in self.compliance_hook["refresh_token_response"]:
log.debug("Invoking hook %s.", hook)
r = hook(r)

self.token = self._client.parse_request_body_response(
r.text, scope=self.scope
)
self.token = self._client.parse_request_body_response(r.text, scope=self.scope)
if not "refresh_token" in self.token:
log.debug(
"No new refresh token given. Re-using old."
)
log.debug("No new refresh token given. Re-using old.")
self.token["refresh_token"] = refresh_token
return self.token

Expand All @@ -536,33 +464,20 @@ def request(
if not is_secure_transport(url):
raise InsecureTransportError()
if self.token and not withhold_token:
log.debug(
"Adding token %s to request.", self.token
)
log.debug("Adding token %s to request.", self.token)
try:
url, headers, data = self._client.add_token(
url,
http_method=method,
body=data,
headers=headers,
url, http_method=method, body=data, headers=headers
)
# Moving this compliance hook invocation until after the access_token is added and handling the
# token venacular within the compliance hook.
log.debug(
"Invoking %d protected resource request hooks.",
len(
self.compliance_hook[
"protected_request"
]
),
len(self.compliance_hook["protected_request"]),
)
for hook in self.compliance_hook[
"protected_request"
]:
for hook in self.compliance_hook["protected_request"]:
log.debug("Invoking hook %s.", hook)
url, headers, data = hook(
url, headers, data
)
url, headers, data = hook(url, headers, data)

# Attempt to retrieve and save new access token if expired
except TokenExpiredError:
Expand All @@ -574,60 +489,33 @@ def request(

# We mustn't pass auth twice.
auth = kwargs.pop("auth", None)
if (
client_id
and client_secret
and (auth is None)
):
if client_id and client_secret and (auth is None):
log.debug(
'Encoding client_id "%s" with client_secret as Basic auth credentials.',
client_id,
)
auth = requests.auth.HTTPBasicAuth(
client_id, client_secret
)
auth = requests.auth.HTTPBasicAuth(client_id, client_secret)
token = self.refresh_token(
self.auto_refresh_url,
auth=auth,
**kwargs
self.auto_refresh_url, auth=auth, **kwargs
)
if self.token_updater:
log.debug(
"Updating token to %s using %s.",
token,
self.token_updater,
"Updating token to %s using %s.", token, self.token_updater
)
self.token_updater(token)
url, headers, data = self._client.add_token(
url,
http_method=method,
body=data,
headers=headers,
url, http_method=method, body=data, headers=headers
)
else:
raise TokenUpdated(token)
else:
raise

log.debug(
"Requesting url %s using method %s.",
url,
method,
)
log.debug(
"Supplying headers %s and data %s",
headers,
data,
)
log.debug(
"Passing through key word arguments %s.", kwargs
)
log.debug("Requesting url %s using method %s.", url, method)
log.debug("Supplying headers %s and data %s", headers, data)
log.debug("Passing through key word arguments %s.", kwargs)
return super(OAuth2Session, self).request(
method,
url,
headers=headers,
data=data,
**kwargs
method, url, headers=headers, data=data, **kwargs
)

def register_compliance_hook(self, hook_type, hook):
Expand All @@ -643,8 +531,6 @@ def register_compliance_hook(self, hook_type, hook):
"""
if hook_type not in self.compliance_hook:
raise ValueError(
"Hook type %s is not in %s.",
hook_type,
self.compliance_hook,
"Hook type %s is not in %s.", hook_type, self.compliance_hook
)
self.compliance_hook[hook_type].add(hook)

0 comments on commit 2ad25f7

Please sign in to comment.