Skip to content

Commit

Permalink
(PXP-6339): RAS visa validation (#933)
Browse files Browse the repository at this point in the history
* feat(visa-validate): Add GA4GH_VISA_ISSUER_WHITELIST config val

* feat(visa-validate): validate visas

* feat(visa-validate): Add try/except and logs around validation

* feat(visa-validate): Update deps, lockfile

* feat(visa-validate): Validate visas in visa update cronjob

* feat(visa-validate): Add comment to existing code about clearing user visas

* feat(visa-validate): Rm dead code -- visa validation happens in client update fn

* feat(visa-validate): Catch exception when visa is not JWT

* test(visa-validate): Add visa issuer whitelist to test config

* test(visa-validate): Update RAS tests with public key cache

* feat(visa-validate): Fix dangerous default empty dict argument

* feat(visa-validate): Rm unused imports

* feat(visa-validate): Bump project minor version

* feat(visa-validate): Add RAS visa issuers and self to default whitelist

* feat(visa-validate): Whitelist --> Allowlist

* test(visa-validate): Add test for client fetching pkeys

* feat(visa-validate): Bump project minor version again

* leapfrogging 5.1.0 (push audit logs to AWS SQS)

* fix(cirrus): Stay on cirrus 1.3.x
  • Loading branch information
vpsx committed Jun 16, 2021
1 parent 8fc5367 commit 7cbc74c
Show file tree
Hide file tree
Showing 8 changed files with 375 additions and 119 deletions.
57 changes: 50 additions & 7 deletions fence/blueprints/login/ras.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import flask
import jwt
import os
from authutils.errors import JWTError
from authutils.token.core import validate_jwt
from authutils.token.keys import get_public_key_for_token
from cdislogging import get_logger
from flask_sqlalchemy_session import current_session
import urllib.request, urllib.error
from urllib.parse import urlparse, parse_qs
Expand All @@ -13,6 +17,8 @@
from fence.scripting.fence_create import init_syncer
from fence.utils import get_valid_expiration

logger = get_logger(__name__)


class RASLogin(DefaultOAuth2Login):
def __init__(self):
Expand Down Expand Up @@ -45,13 +51,50 @@ def post_login(self, user=None, token_result=None):
encoded_visas = flask.g.userinfo.get("ga4gh_passport_v1", [])

for encoded_visa in encoded_visas:
# TODO: These visas must be validated!!!
# i.e. (Remove `verify=False` in jwt.decode call)
# But: need a routine for getting public keys per visa.
# And we probably want to cache them.
# Also needs any ga4gh-specific validation.
# For now just read them without validation:
decoded_visa = jwt.decode(encoded_visa, verify=False)
try:
# Do not move out of loop unless we can assume every visa has same issuer and kid
public_key = get_public_key_for_token(
encoded_visa, attempt_refresh=True
)
except Exception as e:
# (But don't log the visa contents!)
logger.error(
"Could not get public key to validate visa: {}. Discarding visa.".format(
e
)
)
continue

try:
# Validate the visa per GA4GH AAI "Embedded access token" format rules.
# pyjwt also validates signature and expiration.
decoded_visa = validate_jwt(
encoded_visa,
public_key,
# Embedded token must not contain aud claim
aud=None,
# Embedded token must contain scope claim, which must include openid
scope={"openid"},
issuers=config.get("GA4GH_VISA_ISSUER_ALLOWLIST", []),
# Embedded token must contain iss, sub, iat, exp claims
# options={"require": ["iss", "sub", "iat", "exp"]},
# ^ FIXME 2021-05-13: Above needs pyjwt>=v2.0.0, which requires cryptography>=3.
# Once we can unpin and upgrade cryptography and pyjwt, switch to above "options" arg.
# For now, pyjwt 1.7.1 is able to require iat and exp;
# authutils' validate_jwt (i.e. the function being called) checks issuers already (see above);
# and we will check separately for sub below.
options={
"require_iat": True,
"require_exp": True,
},
)

# Also require 'sub' claim (see note above about pyjwt and the options arg).
if "sub" not in decoded_visa:
raise JWTError("Visa is missing the 'sub' claim.")
except Exception as e:
logger.error("Visa failed validation: {}. Discarding visa.".format(e))
continue

visa = GA4GHVisaV1(
user=user,
Expand Down
5 changes: 5 additions & 0 deletions fence/config-default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,11 @@ ASSUME_ROLE_CACHE_SECONDS: 1800

# RAS refresh_tokens expire in 15 days
RAS_REFRESH_EXPIRATION: 1296000
# List of JWT issuers from which Fence will accept GA4GH visas
GA4GH_VISA_ISSUER_ALLOWLIST:
- '{{BASE_URL}}'
- 'https://sts.nih.gov'
- 'https://stsstg.nih.gov'
# Number of projects that can be registered to a Google Service Account
SERVICE_ACCOUNT_LIMIT: 6
# Settings for usersync with visas
Expand Down
20 changes: 11 additions & 9 deletions fence/job/visa_update_cronjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ def __init__(
self.n_workers = self.thread_pool_size + self.concurrency
self.logger = logger

# This job runs without an application context, so it cannot use the
# current_app.jwt_public_keys cache.
# This is a simple dict with the same lifetime as the job.
# When there are many visas from many issuers it will make sense to
# implement a more persistent cache.
self.pkey_cache = {}

self.visa_types = config.get("USERSYNC", {}).get("visa_types", {})

# Initialize visa clients:
Expand Down Expand Up @@ -128,17 +135,17 @@ async def producer(self, db_session, queue, chunk_idx):

async def worker(self, name, queue, updater_queue):
"""
Create tasks to pass to updater to update visas AND pass updated visas to _verify_jwt_token for verification
Create tasks to pass to updater to update visas.
"""
while not queue.empty():
user = await queue.get()
await updater_queue.put(user)
self._verify_jwt_token(user.ga4gh_visas_v1)
queue.task_done()

async def updater(self, name, updater_queue, db_session):
"""
Update visas in the updater_queue
Update visas in the updater_queue.
Note that only visas which pass validation will be saved.
"""
while True:
user = await updater_queue.get()
Expand All @@ -150,7 +157,7 @@ async def updater(self, name, updater_queue, db_session):
name, user.username
)
)
client.update_user_visas(user, db_session)
client.update_user_visas(user, self.pkey_cache, db_session)
else:
# clear expired refresh tokens
if user.upstream_refresh_tokens:
Expand Down Expand Up @@ -179,8 +186,3 @@ def _pick_client(self, visa):
"Visa Client not set up or not available for type {}".format(visa.type)
)
return client

def _verify_jwt_token(self, visa):
# NOT IMPLEMENTED
# TODO: Once local jwt verification is ready use thread_pool_size to determine how many users we want to verify the token for
pass
153 changes: 136 additions & 17 deletions fence/resources/openid/ras_oauth2.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
import base64
import flask
import httpx
import requests
import jwt
import backoff
from flask_sqlalchemy_session import current_session
from jose import jwt as jose_jwt

from authutils.errors import JWTError
from authutils.token.core import get_iss, get_keys_url, get_kid, validate_jwt
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa

from fence.config import config
from fence.models import GA4GHVisaV1
from fence.utils import DEFAULT_BACKOFF_SETTINGS
from .idp_oauth2 import Oauth2ClientBase
Expand Down Expand Up @@ -116,13 +124,17 @@ def get_user_id(self, code):
return {"username": username, "email": userinfo.get("email")}

@backoff.on_exception(backoff.expo, Exception, **DEFAULT_BACKOFF_SETTINGS)
def update_user_visas(self, user, db_session=current_session):
def update_user_visas(self, user, pkey_cache, db_session=current_session):
"""
Updates user's RAS refresh token and uses the new access token to retrieve new visas from
RAS's /userinfo endpoint and update the db with the new visa.
- delete user's visas from db if we're not able to get a new access_token
- delete user's visas from db if we're not able to get a new visa
- delete user's visas from db if we're not able to get new visas
- only visas which pass validation are added to the database
"""
# Note: in the cronjob this is called per-user per-visa.
# The assignment below clears ALL of the user's visas, so once there are
# clients other than RAS, this code as it stands can remove visas that the
# user obtained from those other clients.
user.ga4gh_visas_v1 = []
db_session.commit()

Expand All @@ -141,23 +153,130 @@ def update_user_visas(self, user, db_session=current_session):

for encoded_visa in encoded_visas:
try:
# TODO: These visas must be validated!!!
decoded_visa = jwt.decode(encoded_visa, verify=False)
visa = GA4GHVisaV1(
user=user,
source=decoded_visa["ga4gh_visa_v1"]["source"],
type=decoded_visa["ga4gh_visa_v1"]["type"],
asserted=int(decoded_visa["ga4gh_visa_v1"]["asserted"]),
expires=int(decoded_visa["exp"]),
ga4gh_visa=encoded_visa,
visa_issuer = get_iss(encoded_visa)
visa_kid = get_kid(encoded_visa)
except Exception as e:
self.logger.error(
"Could not get issuer or kid from visa: {}. Discarding visa.".format(
e
)
)
continue # Not raise: If visa malformed, does not make sense to retry

# See if pkey is in cronjob cache; if not, update cache.
public_key = pkey_cache.get(visa_issuer, {}).get(visa_kid)
if not public_key:
jwks_url = get_keys_url(visa_issuer)
try:
jwt_public_keys = httpx.get(jwks_url).json()["keys"]
except Exception as e:
raise JWTError(
"Could not get public key to validate visa: Could not fetch keys from JWKs url: {}".format(
e
)
)

issuer_public_keys = {}
try:
for key in jwt_public_keys:
if "kty" in key and key["kty"] == "RSA":
self.logger.debug(
"Serializing RSA public key (kid: {}) to PEM format.".format(
key["kid"]
)
)
# Decode public numbers https://tools.ietf.org/html/rfc7518#section-6.3.1
n_padded_bytes = base64.urlsafe_b64decode(
key["n"] + "=" * (4 - len(key["n"]) % 4)
)
e_padded_bytes = base64.urlsafe_b64decode(
key["e"] + "=" * (4 - len(key["e"]) % 4)
)
n = int.from_bytes(n_padded_bytes, "big", signed=False)
e = int.from_bytes(e_padded_bytes, "big", signed=False)
# Serialize and encode public key--PyJWT decode/validation requires PEM
rsa_public_key = rsa.RSAPublicNumbers(e, n).public_key(
default_backend()
)
public_bytes = rsa_public_key.public_bytes(
serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo,
)
# Cache the encoded key by issuer
issuer_public_keys[key["kid"]] = public_bytes
else:
self.logger.debug(
"Key type (kty) is not 'RSA'; assuming PEM format. "
"Skipping key serialization. (kid: {})".format(key[0])
)
issuer_public_keys[key[0]] = key[1]

pkey_cache.update({visa_issuer: issuer_public_keys})
self.logger.info(
"Refreshed cronjob pkey cache for visa issuer {}".format(
visa_issuer
)
)
except Exception as e:
self.logger.error(
"Could not refresh cronjob pkey cache for visa issuer {}: "
"Something went wrong during serialization: {}. Discarding visa.".format(
visa_issuer, e
)
)
continue # Not raise: If issuer publishing malformed keys, does not make sense to retry

public_key = pkey_cache.get(visa_issuer, {}).get(visa_kid)

if not public_key:
self.logger.error(
"Could not get public key to validate visa: Successfully fetched "
"issuer's keys but did not find the visa's key id among them. Discarding visa."
)
continue # Not raise: If issuer not publishing pkey, does not make sense to retry

current_db_session = db_session.object_session(visa)
try:
# Validate the visa per GA4GH AAI "Embedded access token" format rules.
# pyjwt also validates signature and expiration.
decoded_visa = validate_jwt(
encoded_visa,
public_key,
# Embedded token must not contain aud claim
aud=None,
# Embedded token must contain scope claim, which must include openid
scope={"openid"},
issuers=config.get("GA4GH_VISA_ISSUER_ALLOWLIST", []),
# Embedded token must contain iss, sub, iat, exp claims
# options={"require": ["iss", "sub", "iat", "exp"]},
# ^ FIXME 2021-05-13: Above needs pyjwt>=v2.0.0, which requires cryptography>=3.
# Once we can unpin and upgrade cryptography and pyjwt, switch to above "options" arg.
# For now, pyjwt 1.7.1 is able to require iat and exp;
# authutils' validate_jwt (i.e. the function being called) checks issuers already (see above);
# and we will check separately for sub below.
options={
"require_iat": True,
"require_exp": True,
},
)

current_db_session.add(visa)
# Also require 'sub' claim (see note above about pyjwt and the options arg).
if "sub" not in decoded_visa:
raise JWTError("Visa is missing the 'sub' claim.")
except Exception as e:
err_msg = (
f"Could not process visa '{encoded_visa}' - skipping this visa"
self.logger.error(
"Visa failed validation: {}. Discarding visa.".format(e)
)
self.logger.exception("{}: {}".format(err_msg, e), exc_info=True)
continue

visa = GA4GHVisaV1(
user=user,
source=decoded_visa["ga4gh_visa_v1"]["source"],
type=decoded_visa["ga4gh_visa_v1"]["type"],
asserted=int(decoded_visa["ga4gh_visa_v1"]["asserted"]),
expires=int(decoded_visa["exp"]),
ga4gh_visa=encoded_visa,
)

current_db_session = db_session.object_session(visa)
current_db_session.add(visa)
db_session.commit()
Loading

0 comments on commit 7cbc74c

Please sign in to comment.