Skip to content

Commit

Permalink
Merge 67ed2e5 into 28d169a
Browse files Browse the repository at this point in the history
  • Loading branch information
BinamB committed Jan 5, 2021
2 parents 28d169a + 67ed2e5 commit fdfc0f6
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 3 deletions.
47 changes: 45 additions & 2 deletions fence/resources/openid/ras_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import requests
import jwt
import backoff
import json
from flask_sqlalchemy_session import current_session
from jose import jwt as jose_jwt

Expand Down Expand Up @@ -48,6 +49,17 @@ def get_userinfo(self, token, userinfo_endpoint):
res = requests.get(userinfo_endpoint, headers=header)
return res.json()

def validate_passport(self, validation_endpoint, passport):
"""
Validate passport with RAS's validation endpoint.
TODO: Remove this once we can locally validate passports
NOTE: RAS has an option to query a single visa with the /passport/validate?visa= but not using these since we hit the limit for an http header pretty quick
"""
payload = json.dumps(passport)
headers = {"Content-Type": "application/json"}
res = requests.post(validation_endpoint, headers=headers, data=payload)
return res.text

def get_user_id(self, code):

err_msg = "Can't get user's info"
Expand All @@ -58,11 +70,16 @@ def get_user_id(self, code):
userinfo_endpoint = self.get_value_from_discovery_doc(
"userinfo_endpoint", ""
)
validation_endpoint = (
self.get_value_from_discovery_doc("issuer", "") + "/passport/validate"
)

token = self.get_token(token_endpoint, code)
keys = self.get_jwt_keys(jwks_endpoint)
userinfo = self.get_userinfo(token, userinfo_endpoint)

validation = self.validate_passport(validation_endpoint, userinfo)

claims = jose_jwt.decode(
token["id_token"],
keys,
Expand Down Expand Up @@ -93,7 +110,16 @@ def get_user_id(self, code):
self.logger.info("Using {} field as username.".format(field_name))

# Save userinfo and token in flask.g for later use in post_login
flask.g.userinfo = userinfo
flask.g.userinfo = {}
if validation == "Valid":
self.logger.info("Passport validated")
flask.g.userinfo = userinfo
else:
self.logger.error(
"Passport validation failed. Not storing passport in database. Error: {}".format(
validation
)
)
flask.g.tokens = token

except Exception as e:
Expand All @@ -120,12 +146,29 @@ def update_user_visas(self, user):
)
token = self.get_access_token(user, token_endpoint)
userinfo = self.get_userinfo(token, userinfo_endpoint)
encoded_visas = userinfo.get("ga4gh_passport_v1", [])
except Exception as e:
err_msg = "Could not retrieve visa"
self.logger.exception("{}: {}".format(err_msg, e))
raise

validation_endpoint = (
self.get_value_from_discovery_doc("issuer", "") + "/passport/validate"
)
validation = self.validate_passport(validation_endpoint, userinfo)
encoded_visas = (
userinfo.get("ga4gh_passport_v1", []) if validation == "Valid" else []
)

encoded_visas = []
if validation == "Valid":
encoded_visas = userinfo.get("ga4gh_passport_v1", [])
else:
self.logger.error(
"Passport validation failed. Not storing passport in database. Error: {}".format(
validation
)
)

for encoded_visa in encoded_visas:
try:
# TODO: These visas must be validated!!!
Expand Down
102 changes: 101 additions & 1 deletion tests/ras/test_ras.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def test_store_refresh_token(db_session):
assert final_query.expires == new_expire


@mock.patch("fence.resources.openid.ras_oauth2.RASOauth2Client.validate_passport")
@mock.patch("fence.resources.openid.ras_oauth2.RASOauth2Client.get_userinfo")
@mock.patch("fence.resources.openid.ras_oauth2.RASOauth2Client.get_access_token")
@mock.patch(
Expand All @@ -107,6 +108,7 @@ def test_update_visa_token(
mock_discovery,
mock_get_token,
mock_userinfo,
mock_validate_passport,
config,
db_session,
rsa_private_key,
Expand Down Expand Up @@ -176,13 +178,104 @@ def test_update_visa_token(
userinfo_response["ga4gh_passport_v1"] = [encoded_visa]
mock_userinfo.return_value = userinfo_response

mock_validate_passport.return_value = "Valid"

ras_client.update_user_visas(test_user)

query_visa = db_session.query(GA4GHVisaV1).first()
assert query_visa.ga4gh_visa
assert query_visa.ga4gh_visa == encoded_visa


@mock.patch("fence.resources.openid.ras_oauth2.RASOauth2Client.validate_passport")
@mock.patch("fence.resources.openid.ras_oauth2.RASOauth2Client.get_userinfo")
@mock.patch("fence.resources.openid.ras_oauth2.RASOauth2Client.get_access_token")
@mock.patch(
"fence.resources.openid.ras_oauth2.RASOauth2Client.get_value_from_discovery_doc"
)
def test_update_visa_token_validation_invalid(
mock_discovery,
mock_get_token,
mock_userinfo,
mock_validate_passport,
config,
db_session,
rsa_private_key,
kid,
kid_2,
):
"""
Test to check visa table is not updated when passport validation fails
"""

mock_discovery.return_value = "https://ras/token_endpoint"
new_token = "refresh12345abcdefg"
token_response = {
"access_token": "abcdef12345",
"id_token": "id12345abcdef",
"refresh_token": new_token,
}
mock_get_token.return_value = token_response

userinfo_response = {
"sub": "abcd-asdj-sajpiasj12iojd-asnoin",
"name": "",
"preferred_username": "someuser@era.com",
"UID": "",
"UserID": "admin_user",
"email": "",
}

test_user = add_test_user(db_session)
add_visa_manually(db_session, test_user, rsa_private_key, kid)
add_refresh_token(db_session, test_user)

visa_query = db_session.query(GA4GHVisaV1).filter_by(user=test_user).first()
initial_visa = visa_query.ga4gh_visa
assert initial_visa

oidc = config.get("OPENID_CONNECT", {})
ras_client = RASClient(
oidc["ras"],
HTTP_PROXY=config.get("HTTP_PROXY"),
logger=logger,
)

new_visa = {
"iss": "https://stsstg.nih.gov",
"sub": "abcde12345aspdij",
"iat": int(time.time()),
"exp": int(time.time()) + 1000,
"scope": "openid ga4gh_passport_v1 email profile",
"jti": "jtiajoidasndokmasdl",
"txn": "sapidjspa.asipidja",
"name": "",
"ga4gh_visa_v1": {
"type": "https://ras/visa/v1",
"asserted": int(time.time()),
"value": "https://nig/passport/dbgap",
"source": "https://ncbi/gap",
},
}

headers = {"kid": kid_2}

encoded_visa = jwt.encode(
new_visa, key=rsa_private_key, headers=headers, algorithm="RS256"
).decode("utf-8")

userinfo_response["ga4gh_passport_v1"] = [encoded_visa]
mock_userinfo.return_value = userinfo_response

mock_validate_passport.return_value = "Invalid"

ras_client.update_user_visas(test_user)

query_visa = db_session.query(GA4GHVisaV1).filter_by(user=test_user).all()
assert len(query_visa) == 0


@mock.patch("fence.resources.openid.ras_oauth2.RASOauth2Client.validate_passport")
@mock.patch("fence.resources.openid.ras_oauth2.RASOauth2Client.get_userinfo")
@mock.patch("fence.resources.openid.ras_oauth2.RASOauth2Client.get_access_token")
@mock.patch(
Expand All @@ -192,6 +285,7 @@ def test_update_visa_empty_visa_returned(
mock_discovery,
mock_get_token,
mock_userinfo,
mock_validate_passport,
config,
db_session,
rsa_private_key,
Expand Down Expand Up @@ -223,6 +317,8 @@ def test_update_visa_empty_visa_returned(

mock_userinfo.return_value = userinfo_response

mock_validate_passport.return_value = "Valid"

test_user = add_test_user(db_session)
add_visa_manually(db_session, test_user, rsa_private_key, kid)
add_refresh_token(db_session, test_user)
Expand All @@ -244,6 +340,7 @@ def test_update_visa_empty_visa_returned(
assert query_visa == None


@mock.patch("fence.resources.openid.ras_oauth2.RASOauth2Client.validate_passport")
@mock.patch("fence.resources.openid.ras_oauth2.RASOauth2Client.get_userinfo")
@mock.patch("fence.resources.openid.ras_oauth2.RASOauth2Client.get_access_token")
@mock.patch(
Expand All @@ -253,6 +350,7 @@ def test_update_visa_token_with_invalid_visa(
mock_discovery,
mock_get_token,
mock_userinfo,
mock_validate_passport,
config,
db_session,
rsa_private_key,
Expand Down Expand Up @@ -300,7 +398,7 @@ def test_update_visa_token_with_invalid_visa(

new_visa = {
"iss": "https://stsstg.nih.gov",
"sub": "abcde12345aspdij",
"sub": "abcde12345aspdijk",
"iat": int(time.time()),
"exp": int(time.time()) + 1000,
"scope": "openid ga4gh_passport_v1 email profile",
Expand All @@ -324,6 +422,8 @@ def test_update_visa_token_with_invalid_visa(
userinfo_response["ga4gh_passport_v1"] = [encoded_visa, [], encoded_visa]
mock_userinfo.return_value = userinfo_response

mock_validate_passport.return_value = "Valid"

ras_client.update_user_visas(test_user)

query_visas = db_session.query(GA4GHVisaV1).filter_by(user=test_user).all()
Expand Down

0 comments on commit fdfc0f6

Please sign in to comment.