diff --git a/requests_oauthlib/oauth2_session.py b/requests_oauthlib/oauth2_session.py index 93cc4d7..3df4a9f 100644 --- a/requests_oauthlib/oauth2_session.py +++ b/requests_oauthlib/oauth2_session.py @@ -1,4 +1,7 @@ import logging +import time +import calendar +from datetime import datetime from oauthlib.common import generate_token, urldecode from oauthlib.oauth2 import WebApplicationClient, InsecureTransportError @@ -171,6 +174,39 @@ def authorized(self): """ return bool(self.access_token) + def _add_expires_at(self, token, response_date=None): + """Add expires_at to token if expires_in is present. + + OAuth2 responses often include expires_in (seconds until token expires) but + oauthlib expects expires_at (timestamp when token expires) for expiration checks. + + Uses response Date header if provided. Falls back to current time if + Date header is not provided or malformed. + + :param token: OAuth2 token dict + :param response_date: Optional Date header value from response (e.g. "Thu, 14 Mar 2024 08:30:00 GMT") + :return: Token dict with expires_at if expires_in was present + """ + if token and 'expires_in' in token: + # RFC 6749 requires expires_in to be 1*DIGIT, but some providers send it as string + expires_in = int(token['expires_in']) + + # Try to use response Date header if provided + if response_date: + try: + # Parse HTTP date format (RFC 7231) + dt = datetime.strptime(response_date, "%a, %d %b %Y %H:%M:%S GMT") + # Convert UTC time tuple to Unix timestamp (returns integer) + token['expires_at'] = calendar.timegm(dt.utctimetuple()) + expires_in + return token + except (TypeError, ValueError): + # Skip if Date header is malformed or uses non-standard format + log.debug("Failed to parse Date header: %s", response_date) + + # Fall back to current time (truncate to second for conservative expiry) + token['expires_at'] = int(time.time()) + expires_in + return token + def authorization_url(self, url, state=None, **kwargs): """Form an authorization URL. @@ -404,7 +440,7 @@ def fetch_token( r = hook(r) self._client.parse_request_body_response(r.text, scope=self.scope) - self.token = self._client.token + self.token = self._add_expires_at(self._client.token, response_date=r.headers.get('Date')) log.debug("Obtained token %s.", self.token) return self.token @@ -494,6 +530,7 @@ def refresh_token( r = hook(r) self.token = self._client.parse_request_body_response(r.text, scope=self.scope) + self.token = self._add_expires_at(self.token, response_date=r.headers.get('Date')) if "refresh_token" not in self.token: log.debug("No new refresh token given. Re-using old.") self.token["refresh_token"] = refresh_token diff --git a/tests/test_compliance_fixes.py b/tests/test_compliance_fixes.py index 9ad6d09..aaec230 100644 --- a/tests/test_compliance_fixes.py +++ b/tests/test_compliance_fixes.py @@ -115,9 +115,9 @@ def test_fetch_access_token(self): authorization_response="https://i.b/?code=hello", ) # Times should be close - approx_expires_at = round(time.time()) + 3600 + approx_expires_at = int(time.time()) + 3600 actual_expires_at = token.pop("expires_at") - self.assertAlmostEqual(actual_expires_at, approx_expires_at, places=2) + self.assertEqual(actual_expires_at, approx_expires_at) # Other token values exact self.assertEqual(token, {"access_token": "mailchimp", "expires_in": 3600}) @@ -289,9 +289,9 @@ def test_fetch_access_token(self): authorization_response="https://i.b/?code=hello", ) - approx_expires_at = round(time.time()) + 86400 + approx_expires_at = int(time.time()) + 86400 actual_expires_at = token.pop("expires_at") - self.assertAlmostEqual(actual_expires_at, approx_expires_at, places=2) + self.assertEqual(actual_expires_at, approx_expires_at) self.assertEqual( token, diff --git a/tests/test_oauth2_session.py b/tests/test_oauth2_session.py index 7e3e63c..6f80b4b 100644 --- a/tests/test_oauth2_session.py +++ b/tests/test_oauth2_session.py @@ -40,7 +40,7 @@ def setUp(self): "access_token": "asdfoiw37850234lkjsdfsdf", "refresh_token": "sldvafkjw34509s8dfsdf", "expires_in": 3600, - "expires_at": fake_time + 3600, + "expires_at": int(fake_time) + 3600, } # use someclientid:someclientsecret to easily differentiate between client and user credentials # these are the values used in oauthlib tests @@ -401,10 +401,10 @@ def test_cleans_previous_token_before_fetching_new_one(self): """ new_token = deepcopy(self.token) - past = time.time() - 7200 now = time.time() - self.token["expires_at"] = past - new_token["expires_at"] = now + 3600 + past = now - 7200 + self.token["expires_at"] = int(past) + new_token["expires_at"] = int(now) + 3600 url = "https://example.com/token" with mock.patch("time.time", lambda: now): @@ -488,6 +488,35 @@ def test_token_proxy(self): with self.assertRaises(AttributeError): del sess.token + @mock.patch("time.time", new=lambda: fake_time) + def test_add_expires_at_from_expires_in(self): + """Test that expires_at is correctly calculated from expires_in""" + sess = OAuth2Session("someclientid") + now = int(fake_time) + + # Test with missing expires_in (should not modify token) + token = {"access_token": "foo"} + updated_token = sess._add_expires_at(token) + self.assertNotIn('expires_at', updated_token) + + # Test with Date header + date_str = "Thu, 14 Mar 2024 08:30:00 GMT" + token = {"access_token": "foo", "expires_in": 3600} + updated_token = sess._add_expires_at(token, response_date=date_str) + self.assertIn('expires_at', updated_token) + expected_timestamp = 1710405000 + 3600 # 2024-03-14 08:30:00 UTC + 1 hour + self.assertEqual(updated_token['expires_at'], expected_timestamp) + + # Test with malformed Date header + updated_token = sess._add_expires_at(token, response_date="invalid date format") + self.assertIn('expires_at', updated_token) + self.assertEqual(updated_token['expires_at'], now + 3600) + + # Test with missing Date header + updated_token = sess._add_expires_at(token, response_date=None) + self.assertIn('expires_at', updated_token) + self.assertEqual(updated_token['expires_at'], now + 3600) + def test_authorized_false(self): sess = OAuth2Session("someclientid") self.assertFalse(sess.authorized)