Skip to content

Commit

Permalink
Merge branch 'master' into fix/service-account-limit
Browse files Browse the repository at this point in the history
  • Loading branch information
BinamB authored Aug 5, 2020
2 parents 52cc836 + d672b46 commit c025ab3
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 23 deletions.
12 changes: 12 additions & 0 deletions fence/blueprints/login/ras.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from fence.blueprints.login.base import DefaultOAuth2Login, DefaultOAuth2Callback

from fence.config import config


class RASLogin(DefaultOAuth2Login):
def __init__(self):
Expand Down Expand Up @@ -58,3 +60,13 @@ def post_login(self, user, token_result):

current_session.add(visa)
current_session.commit()

# Store refresh token in db
refresh_token = flask.g.tokens.get("refresh_token")
id_token = flask.g.tokens.get("id_token")
decoded_id = jwt.decode(id_token, verify=False)
# Add 15 days to iat to calculate refresh token expiration time
expires = int(decoded_id.get("iat")) + config["RAS_REFRESH_EXPIRATION"]
flask.current_app.ras_client.store_refresh_token(
user=user, refresh_token=refresh_token, expires=expires
)
2 changes: 2 additions & 0 deletions fence/config-default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -765,3 +765,5 @@ SYNAPSE_URI: 'https://repo-prod.prod.sagebase.org/auth/v1'
SYNAPSE_JWKS_URI:
SYNAPSE_DISCOVERY_URL:
SYNAPSE_AUTHZ_TTL: 86400
# RAS refresh_tokens expire in 15 days
RAS_REFRESH_EXPIRATION: 1296000
25 changes: 20 additions & 5 deletions fence/models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
"""
Define sqlalchemy models.
The models here inherit from the `Base` in userdatamodel, so when the fence app
is initialized, the resulting db session includes everything from userdatamodel
and this file.
The `migrate` function in this file is called every init and can be used for
database migrations.
"""
Expand Down Expand Up @@ -71,11 +69,9 @@ def query_for_user(session, username):
class ClientAuthType(Enum):
"""
List the possible types of OAuth client authentication, which are
- None (no authentication).
- Basic (using basic HTTP authorization header to include the client ID & secret).
- POST (the client ID & secret are included in the body of a POST request).
These all have a corresponding string which identifies them to authlib.
"""

Expand Down Expand Up @@ -555,6 +551,26 @@ class GA4GHVisaV1(Base):
expires = Column(BigInteger, nullable=False)


class UpstreamRefreshToken(Base):
# General table to store any refresh_token sent from any oidc client

__tablename__ = "upstream_refresh_token"

id = Column(BigInteger, primary_key=True)

user_id = Column(Integer, ForeignKey(User.id, ondelete="CASCADE"), nullable=False)
user = relationship(
"User",
backref=backref(
"upstream_refresh_tokens",
cascade="all, delete-orphan",
passive_deletes=True,
),
)
refresh_token = Column(Text, nullable=False)
expires = Column(BigInteger, nullable=False)


to_timestamp = (
"CREATE OR REPLACE FUNCTION pc_datetime_to_timestamp(datetoconvert timestamp) "
"RETURNS BIGINT AS "
Expand Down Expand Up @@ -1041,7 +1057,6 @@ def _remove_policy(driver, md):
def _add_google_project_id(driver, md):
"""
Add new unique not null field to GoogleServiceAccount.
In order to do this without errors, we have to:
- add the field and allow null (for all previous rows)
- update all null entries to be unique
Expand Down
43 changes: 43 additions & 0 deletions fence/resources/openid/idp_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
from cached_property import cached_property
from jose import jwt
import requests
import time
from fence.errors import AuthError
from fence.models import UpstreamRefreshToken
from flask_sqlalchemy_session import current_session


class Oauth2ClientBase(object):
Expand Down Expand Up @@ -128,3 +132,42 @@ def get_user_id(self, code):
for successfully logged in user OR "error" field with details of the error.
"""
raise NotImplementedError()

def get_access_token(self, user, token_endpoint):

"""
Get access_token using a refresh_token
"""
refresh_token = None
expires = None

# get refresh_token and expiration from db
for row in user.upstream_refresh_tokens:
refresh_token = row.refresh_token
expires = row.expires

if not refresh_token:
raise AuthError("User doesnt have a refresh token")
if time.time() > expires:
raise AuthError("Refresh token expired. Please login again.")

token_response = self.session.refresh_token(
url=token_endpoint, proxies=self.get_proxies(), refresh_token=refresh_token,
)
new_refresh_token = token_response["refresh_token"]

self.store_refresh_token(user, refresh_token=new_refresh_token, expires=expires)

return token_response

def store_refresh_token(self, user, refresh_token, expires):
"""
Store refresh token in db.
"""
user.upstream_refresh_tokens = []
upstream_refresh_token = UpstreamRefreshToken(
user=user, refresh_token=refresh_token, expires=expires,
)
current_db_session = current_session.object_session(upstream_refresh_token)
current_db_session.add(upstream_refresh_token)
current_session.commit()
40 changes: 22 additions & 18 deletions fence/resources/openid/ras_oauth2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# need new RAS
import flask
from .idp_oauth2 import Oauth2ClientBase
from jose import jwt
Expand All @@ -9,7 +8,6 @@ class RASOauth2Client(Oauth2ClientBase):
"""
client for interacting with RAS oauth 2,
as openid connect is supported under oauth2
"""

RAS_DISCOVERY_URL = "https://stsstg.nih.gov/.well-known/openid-configuration"
Expand Down Expand Up @@ -57,34 +55,40 @@ def get_user_id(self, code):

token = self.get_token(token_endpoint, code)
keys = self.get_jwt_keys(jwks_endpoint)
userinfo = self.get_userinfo(token, userinfo_endpoint)

claims = jwt.decode(
token["id_token"],
keys,
options={"verify_aud": False, "verify_at_hash": False},
)

userinfo = self.get_userinfo(token, userinfo_endpoint)
username = None
if userinfo.get("UserID"):
username = userinfo["UserID"]
field_name = "UserID"
elif userinfo.get("preferred_username"):
username = userinfo["preferred_username"]
field_name = "preferred_username"
elif claims.get("sub"):
username = claims["sub"]
field_name = "sub"
if not username:
self.logger.error(
"{}, received claims: {} and userinfo: {}".format(
err_msg, claims, userinfo
)
)
return {"error": err_msg}

# Save userinfo in flask.g.user for later use in post_login
self.logger.info("Using {} field as username.".format(field_name))

# Save userinfo and token in flask.g for later use in post_login
flask.g.userinfo = userinfo
flask.g.tokens = token

except Exception as e:
self.logger.exception("{}: {}".format(err_msg, e))
return {"error": err_msg}

username = None
if userinfo.get("UserID"):
username = userinfo["UserID"]
elif userinfo.get("preferred_username"):
username = userinfo["preferred_username"]
elif claims.get("sub"):
username = claims["sub"]
if not username:
logger.error(
"{}, received claims: {} and userinfo: {}".format(
err_msg, claims, userinfo
)
)
return {"error": err_msg}
return {"username": username}

0 comments on commit c025ab3

Please sign in to comment.