Skip to content

Commit

Permalink
Merge branch 'master' into fix/usersync-update-role
Browse files Browse the repository at this point in the history
  • Loading branch information
Avantol13 committed Feb 19, 2021
2 parents 193c1b8 + 6e55fcd commit d74020b
Show file tree
Hide file tree
Showing 20 changed files with 1,423 additions and 357 deletions.
36 changes: 36 additions & 0 deletions bin/fence_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
force_update_google_link,
migrate_database,
google_list_authz_groups,
update_user_visas,
)
from fence.settings import CONFIG_SEARCH_FOLDERS

Expand Down Expand Up @@ -339,6 +340,28 @@ def parse_arguments():
"Fence is providing access to. Includes Fence Project.auth_id and Google Bucket "
"Access Group",
)
update_visas = subparsers.add_parser(
"update-visas",
help="Update visas and refresh tokens for users with valid visas and refresh tokens.",
)
update_visas.add_argument(
"--chunk-size",
required=False,
help="size of chunk of users we want to take from each query to db. Default value: 10",
)
update_visas.add_argument(
"--concurrency",
required=False,
help="number of concurrent users going through the visa update flow. Default value: 5",
)
update_visas.add_argument(
"--thread-pool-size",
required=False,
help="number of Docker container CPU used for jwt verifcation. Default value: 3",
)
update_visas.add_argument(
"--buffer-size", required=False, help="max size of queue. Default value: 10"
)

return parser.parse_args()

Expand Down Expand Up @@ -382,6 +405,9 @@ def main():
STORAGE_CREDENTIALS = os.environ.get("STORAGE_CREDENTIALS") or config.get(
"STORAGE_CREDENTIALS"
)
usersync = config.get("USERSYNC", {})
sync_from_visas = usersync.get("sync_from_visas", False)
fallback_to_dbgap_sftp = usersync.get("fallback_to_dbgap_sftp", False)

arborist = None
if args.arborist:
Expand Down Expand Up @@ -444,6 +470,8 @@ def main():
sync_from_local_yaml_file=args.yaml,
folder=args.folder,
arborist=arborist,
sync_from_visas=sync_from_visas,
fallback_to_dbgap_sftp=fallback_to_dbgap_sftp,
)
elif args.action == "dbgap-download-access-files":
download_dbgap_files(
Expand Down Expand Up @@ -543,6 +571,14 @@ def main():
)
elif args.action == "migrate":
migrate_database(DB)
elif args.action == "update-visas":
update_user_visas(
DB,
chunk_size=args.chunk_size,
concurrency=args.concurrency,
thread_pool_size=args.thread_pool_size,
buffer_size=args.buffer_size,
)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion fence/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def set_csrf(response):
"""
if not flask.request.cookies.get("csrftoken"):
secure = config.get("SESSION_COOKIE_SECURE", True)
response.set_cookie("csrftoken", random_str(40), secure=secure)
response.set_cookie("csrftoken", random_str(40), secure=secure, httponly=True)

if flask.request.method in ["POST", "PUT", "DELETE"]:
current_session.commit()
Expand Down
3 changes: 1 addition & 2 deletions fence/blueprints/login/ras.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jwt
from flask_sqlalchemy_session import current_session

from fence.models import GA4GHVisaV1, IdentityProvider, User
from fence.models import GA4GHVisaV1, IdentityProvider

from fence.blueprints.login.base import DefaultOAuth2Login, DefaultOAuth2Callback

Expand Down Expand Up @@ -57,7 +57,6 @@ def post_login(self, user, token_result):
expires=int(decoded_visa["exp"]),
ga4gh_visa=encoded_visa,
)

current_session.add(visa)
current_session.commit()

Expand Down
8 changes: 7 additions & 1 deletion fence/config-default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,6 @@ dbGaP:
# 'studyX': ['/orgA/', '/orgB/']
# 'studyX.c2': ['/orgB/', '/orgC/']
# 'studyZ': ['/orgD/']

# Regex to match an accession number that has consent information in forms like:
# phs00301123.c999
# phs000123.v3.p1.c3
Expand Down Expand Up @@ -770,3 +769,10 @@ SYNAPSE_AUTHZ_TTL: 86400
RAS_REFRESH_EXPIRATION: 1296000
# Number of projects that can be registered to a Google Service Account
SERVICE_ACCOUNT_LIMIT: 6
# Settings for usersync with visas
USERSYNC:
sync_from_visas: false
# fallback to dbgap sftp when there are no valid visas for a user i.e. if they're expired or if they're malformed
fallback_to_dbgap_sftp: false
visa_types:
ras: [https://ras.nih.gov/visas/v1, https://ras.nih.gov/visas/v1.1]
181 changes: 181 additions & 0 deletions fence/job/visa_update_cronjob.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import asyncio
import datetime
import time

from cdislogging import get_logger
from userdatamodel.driver import SQLAlchemyDriver

from fence.config import config
from fence.models import (
GA4GHVisaV1,
User,
UpstreamRefreshToken,
query_for_user,
)
from fence.resources.openid.ras_oauth2 import RASOauth2Client as RASClient


logger = get_logger(__name__, log_level="debug")


class Visa_Token_Update(object):
    """
    Cron job that refreshes upstream refresh tokens and GA4GH visas for users.

    Users are read from the db in chunks and pushed through two asyncio queues:
    producer -> (queue) -> workers -> (updater_queue) -> updaters.
    """

    def __init__(
        self,
        chunk_size=None,
        concurrency=None,
        thread_pool_size=None,
        buffer_size=None,
        logger=logger,
    ):
        """
        args:
            chunk_size: size of chunk of users we want to take from each iteration
            concurrency: number of concurrent users going through the visa update flow
            thread_pool_size: number of Docker container CPU used for jwt verification
            buffer_size: max size of queue
            logger: logger used by the job and handed to the visa clients

        The sizing arguments may be ints or numeric strings (command line values
        arrive as strings); None or 0 selects the default.
        """
        self.chunk_size = self._int_or_default(chunk_size, 10)
        self.concurrency = self._int_or_default(concurrency, 5)
        self.thread_pool_size = self._int_or_default(thread_pool_size, 3)
        self.buffer_size = self._int_or_default(buffer_size, 10)
        self.n_workers = self.thread_pool_size + self.concurrency
        self.logger = logger

        self.visa_types = config.get("USERSYNC", {}).get("visa_types", {})

        # Initialize visa clients:
        oidc = config.get("OPENID_CONNECT", {})
        if "ras" not in oidc:
            self.logger.error("RAS client not configured")
            self.ras_client = None
        else:
            self.ras_client = RASClient(
                oidc["ras"],
                HTTP_PROXY=config.get("HTTP_PROXY"),
                logger=logger,
            )

    @staticmethod
    def _int_or_default(value, default):
        """
        Coerce a sizing option to int; None, 0 and "0" fall back to `default`.

        Raises ValueError if `value` is a non-numeric string.
        """
        return int(value or 0) or default

    async def update_tokens(self, db_session):
        """
        Initialize a producer-consumer workflow.
        Producer: Collects users from db and feeds it to the workers
        Worker: Takes in the users from the Producer and passes it to the Updater to update the tokens and passes those updated tokens for JWT validation
        Updater: Updates refresh_tokens and visas by calling the update_user_visas from the correct client

        Any task still pending when this coroutine finishes (normally or with
        an error) is cancelled before returning.
        """
        start_time = time.time()
        self.logger.info("Initializing Visa Update Cronjob . . .")
        self.logger.info("Total concurrency size: {}".format(self.concurrency))
        self.logger.info("Total thread pool size: {}".format(self.thread_pool_size))
        self.logger.info("Total buffer size: {}".format(self.buffer_size))
        self.logger.info("Total number of workers: {}".format(self.n_workers))

        queue = asyncio.Queue(maxsize=self.buffer_size)
        updater_queue = asyncio.Queue(maxsize=self.n_workers)
        loop = asyncio.get_event_loop()

        producers = loop.create_task(self.producer(db_session, queue, chunk_idx=0))
        workers = [
            loop.create_task(
                self._work_until_produced(j, queue, updater_queue, producers)
            )
            for j in range(self.n_workers)
        ]
        updaters = [
            loop.create_task(self.updater(i, updater_queue, db_session))
            for i in range(self.concurrency)
        ]

        try:
            await asyncio.gather(producers)
            self.logger.info("Producers done producing")
            await queue.join()

            await asyncio.gather(*workers)
            # blocks until everything in updater_queue is complete
            await updater_queue.join()
        finally:
            # Updaters loop forever, and after a failure other tasks may still
            # be pending: cancel them all and wait for the cancellations to land.
            tasks = [producers] + workers + updaters
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)

        self.logger.info(
            "Visa cron job completed in {}".format(
                datetime.timedelta(seconds=time.time() - start_time)
            )
        )

    async def _work_until_produced(self, name, queue, updater_queue, producer_task):
        """
        Keep running `worker` until the producer has finished AND the queue is drained.

        `worker` returns as soon as the queue is momentarily empty. Without this
        loop every worker could exit while the producer still has users to
        enqueue, leaving the producer blocked forever on a full queue.
        """
        while True:
            await self.worker(name, queue, updater_queue)
            if producer_task.done() and queue.empty():
                return
            # Yield to the event loop so the producer can refill the queue.
            await asyncio.sleep(0)

    async def get_user_from_db(self, db_session, chunk_idx):
        """
        Window function to get chunks of data from the table

        Returns the list of users in the window
        [chunk_size * chunk_idx, chunk_size * (chunk_idx + 1)).
        """
        # NOTE(review): the query has no ORDER BY, so consecutive windows are
        # presumably only stable if the db returns rows in a consistent order
        # - confirm.
        start, stop = self.chunk_size * chunk_idx, self.chunk_size * (chunk_idx + 1)
        users = db_session.query(User).slice(start, stop).all()
        return users

    async def producer(self, db_session, queue, chunk_idx):
        """
        Produces users from db and puts them in a queue for processing

        Starts at chunk `chunk_idx` and stops after the first chunk that is
        empty or shorter than chunk_size.
        """
        chunk_size = self.chunk_size
        while True:
            users = await self.get_user_from_db(db_session, chunk_idx)

            if not users:
                break
            for user in users:
                self.logger.info("Producer producing user {}".format(user.username))
                await queue.put(user)
            if len(users) < chunk_size:
                break
            chunk_idx += 1

    async def worker(self, name, queue, updater_queue):
        """
        Create tasks to pass to updater to update visas AND pass updated visas to _verify_jwt_token for verification

        Returns once `queue` is empty.
        """
        while not queue.empty():
            user = await queue.get()
            try:
                await updater_queue.put(user)
                self._verify_jwt_token(user.ga4gh_visas_v1)
            finally:
                # Always balance the get() so queue.join() cannot hang on a failure.
                queue.task_done()

    async def updater(self, name, updater_queue, db_session):
        """
        Update visas in the updater_queue

        Runs until cancelled. A failure for one user is logged and does not
        stop the updater.
        """
        while True:
            user = await updater_queue.get()
            username = None
            try:
                username = user.username
                visas = user.ga4gh_visas_v1
                if visas:
                    # Resolve every client BEFORE updating: update_user_visas
                    # resets user.ga4gh_visas_v1, so the collection must not be
                    # iterated while it is being replaced. One call per client
                    # is enough since each call refreshes the user's visas.
                    clients = []
                    for visa in visas:
                        client = self._pick_client(visa)
                        if client not in clients:
                            clients.append(client)
                    for client in clients:
                        self.logger.info(
                            "Updater {} updating visa for user {}".format(
                                name, username
                            )
                        )
                        client.update_user_visas(user, db_session)
                else:
                    self.logger.info(
                        "User {} doesnt have visa. Skipping . . .".format(username)
                    )
            except Exception:
                self.logger.exception(
                    "Updater {} could not update visas for user {}".format(
                        name, username
                    )
                )
            finally:
                # Always balance the get() so updater_queue.join() cannot hang.
                updater_queue.task_done()

    def _pick_client(self, visa):
        """
        Pick oidc client according to the visa provider

        Raises Exception if the visa type is not configured or if the matching
        client is not set up.
        """
        ras_visa_types = (self.visa_types or {}).get("ras") or []
        if visa.type not in ras_visa_types:
            raise Exception(
                "Visa type {} not configured in fence-config".format(visa.type)
            )
        client = self.ras_client
        if not client:
            raise Exception(
                "Visa Client not set up or not available for type {}".format(visa.type)
            )
        return client

    def _verify_jwt_token(self, visa):
        # NOT IMPLEMENTED
        # TODO: Once local jwt verification is ready use thread_pool_size to determine how many users we want to verify the token for
        pass
17 changes: 12 additions & 5 deletions fence/resources/openid/idp_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def get_user_id(self, code):
"""
raise NotImplementedError()

def get_access_token(self, user, token_endpoint):
def get_access_token(self, user, token_endpoint, db_session=None):

"""
Get access_token using a refresh_token and store it in upstream_refresh_token table.
Expand All @@ -158,11 +158,18 @@ def get_access_token(self, user, token_endpoint):
)
new_refresh_token = token_response["refresh_token"]

self.store_refresh_token(user, refresh_token=new_refresh_token, expires=expires)
self.store_refresh_token(
user,
refresh_token=new_refresh_token,
expires=expires,
db_session=db_session,
)

return token_response

def store_refresh_token(self, user, refresh_token, expires):
def store_refresh_token(
self, user, refresh_token, expires, db_session=current_session
):
"""
Store refresh token in db.
"""
Expand All @@ -172,6 +179,6 @@ def store_refresh_token(self, user, refresh_token, expires):
refresh_token=refresh_token,
expires=expires,
)
current_db_session = current_session.object_session(upstream_refresh_token)
current_db_session = db_session.object_session(upstream_refresh_token)
current_db_session.add(upstream_refresh_token)
current_session.commit()
db_session.commit()
10 changes: 5 additions & 5 deletions fence/resources/openid/ras_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,22 +103,22 @@ def get_user_id(self, code):
return {"username": username}

@backoff.on_exception(backoff.expo, Exception, **DEFAULT_BACKOFF_SETTINGS)
def update_user_visas(self, user):
def update_user_visas(self, user, db_session=current_session):
"""
Updates user's RAS refresh token and uses the new access token to retrieve new visas from
RAS's /userinfo endpoint and update the db with the new visa.
- delete user's visas from db if we're not able to get a new access_token
- delete user's visas from db if we're not able to get a new visa
"""
user.ga4gh_visas_v1 = []
current_session.commit()
db_session.commit()

try:
token_endpoint = self.get_value_from_discovery_doc("token_endpoint", "")
userinfo_endpoint = self.get_value_from_discovery_doc(
"userinfo_endpoint", ""
)
token = self.get_access_token(user, token_endpoint)
token = self.get_access_token(user, token_endpoint, db_session)
userinfo = self.get_userinfo(token, userinfo_endpoint)
encoded_visas = userinfo.get("ga4gh_passport_v1", [])
except Exception as e:
Expand All @@ -139,12 +139,12 @@ def update_user_visas(self, user):
ga4gh_visa=encoded_visa,
)

current_db_session = current_session.object_session(visa)
current_db_session = db_session.object_session(visa)

current_db_session.add(visa)
except Exception as e:
err_msg = (
f"Could not process visa '{encoded_visa}' - skipping this visa"
)
self.logger.exception("{}: {}".format(err_msg, e), exc_info=True)
current_session.commit()
db_session.commit()

0 comments on commit d74020b

Please sign in to comment.