Skip to content

Commit

Permalink
Merge 2762484 into ddd7ecd
Browse files Browse the repository at this point in the history
  • Loading branch information
Avantol13 committed Jun 15, 2021
2 parents ddd7ecd + 2762484 commit da4ae00
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ repos:
- id: no-commit-to-branch
args: [--branch, develop, --branch, master, --pattern, release/.*]
- repo: https://github.com/psf/black
rev: 21.5b1
rev: 21.5b2
hooks:
- id: black
28 changes: 25 additions & 3 deletions fence/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def build_redirect_url(hostname, path):
return redirect_base + path


def login_user(username, provider, fence_idp=None, shib_idp=None):
def login_user(username, provider, fence_idp=None, shib_idp=None, email=None):
"""
Login a user with the given username and provider. Set values in Flask
session to indicate the user being logged in. In addition, commit the user
Expand All @@ -66,7 +66,10 @@ def login_user(username, provider, fence_idp=None, shib_idp=None):
Args:
username (str): specific username of user to be logged in
provider (str): specfic idp of user to be logged in
fence_idp (str, optional): Downstreawm fence IdP
shib_idp (str, optional): Downstreawm shibboleth IdP
email (str, optional): email of user (may or may not match username depending
on the IdP)
"""

def set_flask_session_values(user):
Expand All @@ -89,14 +92,19 @@ def set_flask_session_values(user):

user = query_for_user(session=current_session, username=username)
if user:
_update_users_email(user, email)

# This expression is relevant to those users who already have user and
# idp info persisted to the database. We return early to avoid
# unnecessarily re-saving that user and idp info.
if user.identity_provider and user.identity_provider.name == provider:
set_flask_session_values(user)
return
else:
user = User(username=username)
if email:
user = User(username=username, email=email)
else:
user = User(username=username)

idp = (
current_session.query(IdentityProvider)
Expand Down Expand Up @@ -249,3 +257,17 @@ def wrapper(*args, **kwargs):
def admin_login_required(function):
"""Compose the login required and admin required decorators."""
return login_required({"admin"})(admin_required(function))


def _update_users_email(user, email):
"""
Update email if provided and doesn't match db entry.
"""
if email and user.email != email:
logger.info(
f"Updating username {user.username}'s email from {user.email} to {email}"
)
user.email = email

current_session.add(user)
current_session.commit()
12 changes: 8 additions & 4 deletions fence/blueprints/login/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def get(self):


class DefaultOAuth2Callback(Resource):
def __init__(self, idp_name, client, username_field="email"):
def __init__(self, idp_name, client, username_field="email", email_field="email"):
"""
Construct a resource for a login callback endpoint
Expand All @@ -66,10 +66,13 @@ def __init__(self, idp_name, client, username_field="email"):
Some instaniation of this base client class or a child class
username_field (str, optional): default field from response to
retrieve the username
email_field (str, optional): default field from response to
retrieve the email (if available)
"""
self.idp_name = idp_name
self.client = client
self.username_field = username_field
self.email_field = email_field

def get(self):
# Check if user granted access
Expand Down Expand Up @@ -97,8 +100,9 @@ def get(self):
code = flask.request.args.get("code")
result = self.client.get_user_id(code)
username = result.get(self.username_field)
email = result.get(self.email_field)
if username:
resp = _login(username, self.idp_name)
resp = _login(username, self.idp_name, email=email)
self.post_login(flask.g.user, result)
return resp
raise UserError(result)
Expand All @@ -118,12 +122,12 @@ def prepare_login_log(idp_name):
}


def _login(username, idp_name):
def _login(username, idp_name, email=None):
"""
Login user with given username, then redirect if session has a saved
redirect.
"""
login_user(username, idp_name)
login_user(username, idp_name, email=email)
if flask.session.get("redirect"):
return flask.redirect(flask.session.get("redirect"))
return flask.jsonify({"username": username})
2 changes: 2 additions & 0 deletions fence/blueprints/login/fence_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,13 @@ def get(self):
tokens["id_token"], scope="openid", purpose="id", attempt_refresh=True
)
username = id_token_claims["context"]["user"]["name"]
email = id_token_claims["context"]["user"].get("email")
login_user(
username,
IdentityProvider.fence,
fence_idp=flask.session.get("fence_idp"),
shib_idp=flask.session.get("shib_idp"),
email=email,
)
self.post_login()

Expand Down
2 changes: 1 addition & 1 deletion fence/resources/openid/idp_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def get_auth_url(self):

def get_user_id(self, code):
"""
Must implement in inheriting class. Should return dictionary with "email" field
Must implement in inheriting class. Should return dictionary with necessary field(s)
for successfully logged in user OR "error" field with details of the error.
"""
raise NotImplementedError()
Expand Down
2 changes: 1 addition & 1 deletion fence/resources/openid/ras_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def get_user_id(self, code):
self.logger.exception("{}: {}".format(err_msg, e))
return {"error": err_msg}

return {"username": username}
return {"username": username, "email": userinfo.get("email")}

@backoff.on_exception(backoff.expo, Exception, **DEFAULT_BACKOFF_SETTINGS)
def update_user_visas(self, user, db_session=current_session):
Expand Down
2 changes: 1 addition & 1 deletion tests/oidc/discovery/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@ def test_oidc_config_fields(app, client):

for field in recommended_fields:
if field not in response.json:
warnings.warn(
warnings.warning(
"OIDC configuration response missing recommended field: " + field
)

0 comments on commit da4ae00

Please sign in to comment.