Skip to content

Commit

Permalink
Merge d24070e into d0fe6fc
Browse files Browse the repository at this point in the history
  • Loading branch information
BinamB committed Feb 10, 2021
2 parents d0fe6fc + d24070e commit b0d7993
Show file tree
Hide file tree
Showing 7 changed files with 436 additions and 82 deletions.
31 changes: 31 additions & 0 deletions bin/fence_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
force_update_google_link,
migrate_database,
google_list_authz_groups,
update_user_visas,
)
from fence.settings import CONFIG_SEARCH_FOLDERS

Expand Down Expand Up @@ -339,6 +340,28 @@ def parse_arguments():
"Fence is providing access to. Includes Fence Project.auth_id and Google Bucket "
"Access Group",
)
update_visas = subparsers.add_parser(
"update-visas",
help="Update visas and refresh tokens for users with valid visas and refresh tokens.",
)
update_visas.add_argument(
"--chunk-size",
required=False,
help="size of chunk of users we want to take from each query to db. Default value: 10",
)
update_visas.add_argument(
"--concurrency",
required=False,
help="number of concurrent users going through the visa update flow. Default value: 5",
)
update_visas.add_argument(
"--thread-pool-size",
required=False,
help="number of Docker container CPU used for jwt verifcation. Default value: 3",
)
update_visas.add_argument(
"--buffer-size", required=False, help="max size of queue. Default value: 10"
)

return parser.parse_args()

Expand Down Expand Up @@ -543,6 +566,14 @@ def main():
)
elif args.action == "migrate":
migrate_database(DB)
elif args.action == "update-visas":
update_user_visas(
DB,
chunk_size=args.chunk_size,
concurrency=args.concurrency,
thread_pool_size=args.thread_pool_size,
buffer_size=args.buffer_size,
)


if __name__ == "__main__":
Expand Down
170 changes: 170 additions & 0 deletions fence/job/visa_update_cronjob.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import asyncio
import datetime
import time

from cdislogging import get_logger
from userdatamodel.driver import SQLAlchemyDriver

from fence.config import config
from fence.models import (
GA4GHVisaV1,
User,
UpstreamRefreshToken,
query_for_user,
)
from fence.resources.openid.ras_oauth2 import RASOauth2Client as RASClient


logger = get_logger(__name__, log_level="debug")


class Visa_Token_Update(object):
    """
    Cronjob that refreshes GA4GH visas and upstream refresh tokens for every
    user in the database, using an asyncio producer/worker/updater pipeline.
    """

    def __init__(
        self,
        chunk_size=None,
        concurrency=None,
        thread_pool_size=None,
        buffer_size=None,
        logger=logger,
    ):
        """
        Args:
            chunk_size (int): size of chunk of users we want to take from each iteration
            concurrency (int): number of concurrent users going through the visa update flow
            thread_pool_size (int): number of Docker container CPU used for jwt verification
            buffer_size (int): max size of queue
            logger: logger instance to use (defaults to the module-level logger)
        """
        self.chunk_size = chunk_size or 10
        self.concurrency = concurrency or 5
        self.thread_pool_size = thread_pool_size or 3
        self.buffer_size = buffer_size or 10
        # enough workers to keep both the updaters and the (future) jwt
        # verification thread pool fed
        self.n_workers = self.thread_pool_size + self.concurrency
        self.logger = logger

        # Initialize visa clients. Currently only RAS visas are supported;
        # if RAS is not configured, self.ras_client is left unset and
        # _pick_client will return None for every visa.
        oidc = config.get("OPENID_CONNECT", {})
        if "ras" not in oidc:
            self.logger.error("RAS client not configured")
        else:
            self.ras_client = RASClient(
                oidc["ras"],
                HTTP_PROXY=config.get("HTTP_PROXY"),
                logger=logger,
            )

    async def update_tokens(self, db_session):
        """
        Initialize a producer-consumer workflow.

        Producer: Collects users from db and feeds it to the workers
        Worker: Takes in the users from the Producer and passes it to the Updater to update the tokens and passes those updated tokens for JWT validation
        Updater: Updates refresh_tokens and visas by calling the update_user_visas from the correct client

        Args:
            db_session: SQLAlchemy session used for all db reads/writes

        Workers and updaters run until explicitly cancelled once their input
        queues are fully drained; this avoids the race where a consumer
        observes a momentarily-empty queue and exits while the producer is
        still fetching the next chunk (which would hang ``queue.join()``).
        """
        start_time = time.time()
        self.logger.info("Initializing Visa Update Cronjob . . .")
        self.logger.info("Total concurrency size: {}".format(self.concurrency))
        self.logger.info("Total thread pool size: {}".format(self.thread_pool_size))
        self.logger.info("Total buffer size: {}".format(self.buffer_size))
        self.logger.info("Total number of workers: {}".format(self.n_workers))

        queue = asyncio.Queue(maxsize=self.buffer_size)
        updater_queue = asyncio.Queue(maxsize=self.n_workers)
        loop = asyncio.get_event_loop()

        producers = [
            loop.create_task(self.producer(db_session, queue, chunk_idx=0))
            for _ in range(1)
        ]
        workers = [
            loop.create_task(self.worker(j, queue, updater_queue))
            for j in range(self.n_workers)
        ]
        updaters = [
            loop.create_task(self.updater(i, updater_queue, db_session))
            for i in range(self.concurrency)
        ]

        await asyncio.gather(*producers)
        self.logger.info("Producers done producing")
        # blocks until every produced user has been handed to updater_queue
        await queue.join()

        # workers loop forever; cancel them once their input is drained
        for w in workers:
            w.cancel()

        # blocks until everything in updater_queue is complete
        await updater_queue.join()

        for u in updaters:
            u.cancel()

        self.logger.info(
            "Visa cron job completed in {}".format(
                datetime.timedelta(seconds=time.time() - start_time)
            )
        )

    async def get_user_from_db(self, db_session, chunk_idx):
        """
        Window function to get chunks of data from the table.

        Returns a (possibly empty) list of User rows for the given chunk.
        NOTE(review): this runs the query synchronously on the event loop —
        the coroutine never awaits; acceptable for a cronjob, but it blocks
        all other tasks while the query runs.
        """
        start, stop = self.chunk_size * chunk_idx, self.chunk_size * (chunk_idx + 1)
        users = db_session.query(User).slice(start, stop).all()
        return users

    async def producer(self, db_session, queue, chunk_idx):
        """
        Produces users from db and puts them in a queue for processing.

        Stops when a chunk comes back short (or empty) — Query.all() always
        returns a list, never None, so an emptiness check is the correct
        termination test.
        """
        while True:
            users = await self.get_user_from_db(db_session, chunk_idx)

            if not users:
                break
            for user in users:
                self.logger.info("Producer producing user {}".format(user.username))
                await queue.put(user)
            if len(users) < self.chunk_size:
                break
            chunk_idx += 1

    async def worker(self, name, queue, updater_queue):
        """
        Create tasks to pass to updater to update visas AND pass updated visas to _verify_jwt_token for verification.

        Runs until cancelled by update_tokens. (A ``while not queue.empty()``
        loop would let a worker exit the moment the queue is transiently
        empty, stranding later items and hanging ``queue.join()``.)
        """
        while True:
            user = await queue.get()
            await updater_queue.put(user)
            queue.task_done()

    async def updater(self, name, updater_queue, db_session):
        """
        Update visas in the updater_queue.

        Runs until cancelled by update_tokens. task_done() is guaranteed via
        ``finally`` so one bad user cannot deadlock ``updater_queue.join()``.
        """
        while True:
            user = await updater_queue.get()
            try:
                if user.ga4gh_visas_v1:
                    for visa in user.ga4gh_visas_v1:
                        client = self._pick_client(visa)
                        if not client:
                            # no configured client for this visa type; skip
                            # instead of raising AttributeError on None
                            self.logger.warning(
                                "Updater {} has no client for visa type {}; skipping".format(
                                    name, visa.type
                                )
                            )
                            continue
                        self.logger.info(
                            "Updater {} updating visa for user {}".format(
                                name, user.username
                            )
                        )
                        client.update_user_visas(user, db_session)
                else:
                    self.logger.info(
                        "User {} doesnt have visa. Skipping . . .".format(user.username)
                    )
            except Exception:
                # an uncaught exception would silently kill this updater
                # task and hang updater_queue.join()
                self.logger.exception(
                    "Updater {} failed to update visas for user {}".format(
                        name, user.username
                    )
                )
            finally:
                updater_queue.task_done()

    def _pick_client(self, visa):
        """
        Pick oidc client according to the visa provider.

        Returns None when no client matches the visa type.
        """
        if "ras" in visa.type:
            return self.ras_client

    def _verify_jwt_token(self, visa):
        # NOT IMPLEMENTED
        # TODO: Once local jwt verification is ready use thread_pool_size to determine how many users we want to verify the token for
        pass
17 changes: 12 additions & 5 deletions fence/resources/openid/idp_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def get_user_id(self, code):
"""
raise NotImplementedError()

def get_access_token(self, user, token_endpoint):
def get_access_token(self, user, token_endpoint, db_session=None):

"""
Get access_token using a refresh_token and store it in upstream_refresh_token table.
Expand All @@ -158,11 +158,18 @@ def get_access_token(self, user, token_endpoint):
)
new_refresh_token = token_response["refresh_token"]

self.store_refresh_token(user, refresh_token=new_refresh_token, expires=expires)
self.store_refresh_token(
user,
refresh_token=new_refresh_token,
expires=expires,
db_session=db_session,
)

return token_response

def store_refresh_token(self, user, refresh_token, expires):
def store_refresh_token(
self, user, refresh_token, expires, db_session=current_session
):
"""
Store refresh token in db.
"""
Expand All @@ -172,6 +179,6 @@ def store_refresh_token(self, user, refresh_token, expires):
refresh_token=refresh_token,
expires=expires,
)
current_db_session = current_session.object_session(upstream_refresh_token)
current_db_session = db_session.object_session(upstream_refresh_token)
current_db_session.add(upstream_refresh_token)
current_session.commit()
db_session.commit()
10 changes: 5 additions & 5 deletions fence/resources/openid/ras_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,22 +103,22 @@ def get_user_id(self, code):
return {"username": username}

@backoff.on_exception(backoff.expo, Exception, **DEFAULT_BACKOFF_SETTINGS)
def update_user_visas(self, user):
def update_user_visas(self, user, db_session=current_session):
"""
Updates user's RAS refresh token and uses the new access token to retrieve new visas from
RAS's /userinfo endpoint and update the db with the new visa.
- delete user's visas from db if we're not able to get a new access_token
- delete user's visas from db if we're not able to get a new visa
"""
user.ga4gh_visas_v1 = []
current_session.commit()
db_session.commit()

try:
token_endpoint = self.get_value_from_discovery_doc("token_endpoint", "")
userinfo_endpoint = self.get_value_from_discovery_doc(
"userinfo_endpoint", ""
)
token = self.get_access_token(user, token_endpoint)
token = self.get_access_token(user, token_endpoint, db_session)
userinfo = self.get_userinfo(token, userinfo_endpoint)
encoded_visas = userinfo.get("ga4gh_passport_v1", [])
except Exception as e:
Expand All @@ -139,12 +139,12 @@ def update_user_visas(self, user):
ga4gh_visa=encoded_visa,
)

current_db_session = current_session.object_session(visa)
current_db_session = db_session.object_session(visa)

current_db_session.add(visa)
except Exception as e:
err_msg = (
f"Could not process visa '{encoded_visa}' - skipping this visa"
)
self.logger.exception("{}: {}".format(err_msg, e), exc_info=True)
current_session.commit()
db_session.commit()
26 changes: 26 additions & 0 deletions fence/scripting/fence_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from yaml import safe_load
import json
import pprint
import asyncio

from cirrus import GoogleCloudManager
from cirrus.google_cloud.errors import GoogleAuthError
Expand Down Expand Up @@ -34,6 +35,7 @@
generate_signed_refresh_token,
issued_and_expiration_times,
)
from fence.job.visa_update_cronjob import Visa_Token_Update
from fence.models import (
Client,
GoogleServiceAccount,
Expand Down Expand Up @@ -1501,3 +1503,27 @@ def google_list_authz_groups(db):
print(", ".join(item[:-1]))

return google_authz


def update_user_visas(
    db, chunk_size=None, concurrency=None, thread_pool_size=None, buffer_size=None
):
    """
    Update visas and refresh tokens for users with valid visas and refresh tokens.

    Args:
        db (string): database instance
        chunk_size (int): size of chunk of users we want to take from each iteration
        concurrency (int): number of concurrent users going through the visa update flow
        thread_pool_size (int): number of Docker container CPU used for jwt verifcation
        buffer_size (int): max size of queue
    """

    def _as_int(value):
        # CLI arguments arrive as strings (or None); normalize to int or None
        # so Visa_Token_Update falls back to its own defaults.
        return int(value) if value else None

    job = Visa_Token_Update(
        chunk_size=_as_int(chunk_size),
        concurrency=_as_int(concurrency),
        thread_pool_size=_as_int(thread_pool_size),
        buffer_size=_as_int(buffer_size),
    )
    driver = SQLAlchemyDriver(db)
    with driver.session as db_session:
        asyncio.get_event_loop().run_until_complete(job.update_tokens(db_session))
Loading

0 comments on commit b0d7993

Please sign in to comment.