Skip to content

Commit

Permalink
Merge 3a60f9a into 2e021e4
Browse files Browse the repository at this point in the history
  • Loading branch information
Avantol13 committed Jun 10, 2021
2 parents 2e021e4 + 3a60f9a commit fd8910f
Show file tree
Hide file tree
Showing 9 changed files with 34 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ repos:
- id: no-commit-to-branch
args: [--branch, develop, --branch, master, --pattern, release/.*]
- repo: https://github.com/psf/black
rev: 21.5b1
rev: 21.5b2
hooks:
- id: black
18 changes: 15 additions & 3 deletions fence/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def build_redirect_url(hostname, path):
return redirect_base + path


def login_user(username, provider, fence_idp=None, shib_idp=None):
def login_user(username, provider, fence_idp=None, shib_idp=None, email=None):
"""
Login a user with the given username and provider. Set values in Flask
session to indicate the user being logged in. In addition, commit the user
Expand All @@ -66,7 +66,10 @@ def login_user(username, provider, fence_idp=None, shib_idp=None):
Args:
username (str): specific username of user to be logged in
provider (str): specfic idp of user to be logged in
fence_idp (str, optional): Downstreawm fence IdP
shib_idp (str, optional): Downstreawm shibboleth IdP
email (str, optional): email of user (may or may not match username depending
on the IdP)
"""

def set_flask_session_values(user):
Expand Down Expand Up @@ -96,7 +99,10 @@ def set_flask_session_values(user):
set_flask_session_values(user)
return
else:
user = User(username=username)
if email:
user = User(username=username, email=email)
else:
user = User(username=username)

idp = (
current_session.query(IdentityProvider)
Expand All @@ -106,6 +112,12 @@ def set_flask_session_values(user):
if not idp:
idp = IdentityProvider(name=provider)

if email and user.email != email:
logger.info(
f"Updating username {user.username}'s email from {user.email} to {email}"
)
user.email = email

user.identity_provider = idp
current_session.add(user)
current_session.commit()
Expand Down
8 changes: 4 additions & 4 deletions fence/blueprints/login/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def default_login():
# fall back on ENABLED_IDENTITY_PROVIDERS.default
default_idp = config["ENABLED_IDENTITY_PROVIDERS"]["default"]
else:
logger.warn("DEFAULT_LOGIN_IDP not configured")
logger.warning("DEFAULT_LOGIN_IDP not configured")
default_idp = None

# other login options
Expand All @@ -89,7 +89,7 @@ def default_login():
for idp, details in enabled_providers.items()
]
else:
logger.warn("LOGIN_OPTIONS not configured or empty")
logger.warning("LOGIN_OPTIONS not configured or empty")
login_options = []

def absolute_login_url(provider_id, fence_idp=None, shib_idp=None):
Expand Down Expand Up @@ -325,15 +325,15 @@ def get_all_shib_idps():
all_shib_idps = []
for shib_idp in res.json():
if "entityID" not in shib_idp:
logger.warn(
logger.warning(
f"get_all_shib_idps(): 'entityID' field not in IDP data: {shib_idp}. Skipping this IDP."
)
continue
idp = shib_idp["entityID"]
if len(shib_idp.get("DisplayNames", [])) > 0:
name = get_shib_idp_en_name(shib_idp["DisplayNames"])
else:
logger.warn(
logger.warning(
f"get_all_shib_idps(): 'DisplayNames' field not in IDP data: {shib_idp}. Using IDP ID '{idp}' as IDP name."
)
name = idp
Expand Down
12 changes: 8 additions & 4 deletions fence/blueprints/login/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def get(self):


class DefaultOAuth2Callback(Resource):
def __init__(self, idp_name, client, username_field="email"):
def __init__(self, idp_name, client, username_field="email", email_field="email"):
"""
Construct a resource for a login callback endpoint
Expand All @@ -66,10 +66,13 @@ def __init__(self, idp_name, client, username_field="email"):
Some instaniation of this base client class or a child class
username_field (str, optional): default field from response to
retrieve the username
email_field (str, optional): default field from response to
retrieve the email (if available)
"""
self.idp_name = idp_name
self.client = client
self.username_field = username_field
self.email_field = email_field

def get(self):
# Check if user granted access
Expand Down Expand Up @@ -97,8 +100,9 @@ def get(self):
code = flask.request.args.get("code")
result = self.client.get_user_id(code)
username = result.get(self.username_field)
email = result.get(self.email_field)
if username:
resp = _login(username, self.idp_name)
resp = _login(username, self.idp_name, email=email)
self.post_login(flask.g.user, result)
return resp
raise UserError(result)
Expand All @@ -118,12 +122,12 @@ def create_login_log(idp_name):
)


def _login(username, idp_name):
def _login(username, idp_name, email=None):
"""
Login user with given username, then redirect if session has a saved
redirect.
"""
login_user(username, idp_name)
login_user(username, idp_name, email=email)
if flask.session.get("redirect"):
return flask.redirect(flask.session.get("redirect"))
return flask.jsonify({"username": username})
2 changes: 2 additions & 0 deletions fence/blueprints/login/fence_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,13 @@ def get(self):
tokens["id_token"], scope="openid", purpose="id", attempt_refresh=True
)
username = id_token_claims["context"]["user"]["name"]
email = id_token_claims["context"]["user"].get("email")
login_user(
username,
IdentityProvider.fence,
fence_idp=flask.session.get("fence_idp"),
shib_idp=flask.session.get("shib_idp"),
email=email,
)
self.post_login()

Expand Down
2 changes: 1 addition & 1 deletion fence/resources/audit_service_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, service_url, logger):
logger.info("Enabling audit logs")
self.ping()
else:
logger.warn("NOT enabling audit logs")
logger.warning("NOT enabling audit logs")

def ping(self):
max_tries = 3
Expand Down
2 changes: 1 addition & 1 deletion fence/resources/openid/idp_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def get_auth_url(self):

def get_user_id(self, code):
"""
Must implement in inheriting class. Should return dictionary with "email" field
Must implement in inheriting class. Should return dictionary with necessary field(s)
for successfully logged in user OR "error" field with details of the error.
"""
raise NotImplementedError()
Expand Down
2 changes: 1 addition & 1 deletion fence/sync/sync_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -1806,7 +1806,7 @@ def _add_dbgap_study_to_arborist(self, dbgap_study, dbgap_config):

def _is_arborist_healthy(self):
if not self.arborist_client:
self.logger.warn("no arborist client set; skipping arborist dbgap sync")
self.logger.warning("no arborist client set; skipping arborist dbgap sync")
return False
if not self.arborist_client.healthy():
# TODO (rudyardrichter, 2019-01-07): add backoff/retry here
Expand Down
2 changes: 1 addition & 1 deletion tests/oidc/discovery/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@ def test_oidc_config_fields(app, client):

for field in recommended_fields:
if field not in response.json:
warnings.warn(
warnings.warning(
"OIDC configuration response missing recommended field: " + field
)

0 comments on commit fd8910f

Please sign in to comment.