Skip to content

Commit

Permalink
Merge 90ee044 into 8157fcf
Browse files Browse the repository at this point in the history
  • Loading branch information
Avantol13 committed Jun 11, 2021
2 parents 8157fcf + 90ee044 commit 12ef55d
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ repos:
- id: no-commit-to-branch
args: [--branch, develop, --branch, master, --pattern, release/.*]
- repo: https://github.com/psf/black
rev: 21.5b1
rev: 21.5b2
hooks:
- id: black
29 changes: 26 additions & 3 deletions fence/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def build_redirect_url(hostname, path):
return redirect_base + path


def login_user(username, provider, fence_idp=None, shib_idp=None):
def login_user(username, provider, fence_idp=None, shib_idp=None, email=None):
"""
Login a user with the given username and provider. Set values in Flask
session to indicate the user being logged in. In addition, commit the user
Expand All @@ -66,7 +66,10 @@ def login_user(username, provider, fence_idp=None, shib_idp=None):
Args:
username (str): specific username of user to be logged in
provider (str): specfic idp of user to be logged in
fence_idp (str, optional): Downstreawm fence IdP
shib_idp (str, optional): Downstreawm shibboleth IdP
email (str, optional): email of user (may or may not match username depending
on the IdP)
"""

def set_flask_session_values(user):
Expand All @@ -89,14 +92,19 @@ def set_flask_session_values(user):

user = query_for_user(session=current_session, username=username)
if user:
_update_users_email(user, email)

# This expression is relevant to those users who already have user and
# idp info persisted to the database. We return early to avoid
# unnecessarily re-saving that user and idp info.
if user.identity_provider and user.identity_provider.name == provider:
set_flask_session_values(user)
return
else:
user = User(username=username)
if email:
user = User(username=username, email=email)
else:
user = User(username=username)

idp = (
current_session.query(IdentityProvider)
Expand All @@ -106,6 +114,8 @@ def set_flask_session_values(user):
if not idp:
idp = IdentityProvider(name=provider)

_update_users_email(user, email)

user.identity_provider = idp
current_session.add(user)
current_session.commit()
Expand Down Expand Up @@ -249,3 +259,16 @@ def wrapper(*args, **kwargs):
def admin_login_required(function):
"""Compose the login required and admin required decorators."""
return login_required({"admin"})(admin_required(function))


def _update_users_email(user, email):
"""
Update email if provided and doesn't match db entry.
NOTE: This does NOT commit to the db, do so outside this function
"""
if email and user.email != email:
logger.info(
f"Updating username {user.username}'s email from {user.email} to {email}"
)
user.email = email
12 changes: 8 additions & 4 deletions fence/blueprints/login/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def get(self):


class DefaultOAuth2Callback(Resource):
def __init__(self, idp_name, client, username_field="email"):
def __init__(self, idp_name, client, username_field="email", email_field="email"):
"""
Construct a resource for a login callback endpoint
Expand All @@ -66,10 +66,13 @@ def __init__(self, idp_name, client, username_field="email"):
Some instaniation of this base client class or a child class
username_field (str, optional): default field from response to
retrieve the username
email_field (str, optional): default field from response to
retrieve the email (if available)
"""
self.idp_name = idp_name
self.client = client
self.username_field = username_field
self.email_field = email_field

def get(self):
# Check if user granted access
Expand Down Expand Up @@ -97,8 +100,9 @@ def get(self):
code = flask.request.args.get("code")
result = self.client.get_user_id(code)
username = result.get(self.username_field)
email = result.get(self.email_field)
if username:
resp = _login(username, self.idp_name)
resp = _login(username, self.idp_name, email=email)
self.post_login(flask.g.user, result)
return resp
raise UserError(result)
Expand All @@ -118,12 +122,12 @@ def prepare_login_log(idp_name):
}


def _login(username, idp_name):
def _login(username, idp_name, email=None):
"""
Login user with given username, then redirect if session has a saved
redirect.
"""
login_user(username, idp_name)
login_user(username, idp_name, email=email)
if flask.session.get("redirect"):
return flask.redirect(flask.session.get("redirect"))
return flask.jsonify({"username": username})
2 changes: 2 additions & 0 deletions fence/blueprints/login/fence_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,13 @@ def get(self):
tokens["id_token"], scope="openid", purpose="id", attempt_refresh=True
)
username = id_token_claims["context"]["user"]["name"]
email = id_token_claims["context"]["user"].get("email")
login_user(
username,
IdentityProvider.fence,
fence_idp=flask.session.get("fence_idp"),
shib_idp=flask.session.get("shib_idp"),
email=email,
)
self.post_login()

Expand Down
120 changes: 120 additions & 0 deletions fence/resources/audit_service_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import flask
import requests
import time

from fence.config import config
from fence.errors import InternalError


def get_request_url():
request_url = flask.request.url
base_url = config.get("BASE_URL", "")
if request_url.startswith(base_url):
request_url = request_url[len(base_url) :]
return request_url


def is_audit_enabled(category=None):
enable_audit_logs = config.get("ENABLE_AUDIT_LOGS") or {}
if category:
return enable_audit_logs and enable_audit_logs.get(category, False)
return enable_audit_logs and any(v for v in enable_audit_logs.values())


class AuditServiceClient:
def __init__(self, service_url, logger):
self.service_url = service_url.rstrip("/")
self.logger = logger

# audit logs should not be enabled if the audit-service is unavailable
if is_audit_enabled():
logger.info("Enabling audit logs")
self.ping()
else:
logger.warning("NOT enabling audit logs")

def ping(self):
max_tries = 3
status_url = f"{self.service_url}/_status"
self.logger.debug(f"Checking audit-service availability at {status_url}")
wait_time = 1
for t in range(max_tries):
r = requests.get(status_url)
if r.status_code == 200:
return # all good!
if t + 1 < max_tries:
self.logger.debug(f"Retrying... (got status code {r.status_code})")
time.sleep(wait_time)
wait_time *= 2
raise Exception(
f"Audit logs are enabled but audit-service is unreachable at {status_url}: {r.text}"
)

def check_response(self, resp, body):
# The audit-service returns 201 before inserting the log in the DB.
# This request should only error if the input is incorrect (status
# code 422) or if the service is unreachable.
if resp.status_code != 201:
try:
err = resp.json()
except Exception:
err = resp.text
self.logger.error(f"Unable to POST audit log `{body}`. Details:\n{err}")
raise InternalError("Unable to create audit log")

def create_presigned_url_log(
self,
username,
sub,
guid,
resource_paths,
action,
protocol,
):
if not is_audit_enabled("presigned_url"):
return

url = f"{self.service_url}/log/presigned_url"
body = {
"request_url": get_request_url(),
"status_code": 200, # only record successful requests for now
"username": username,
"sub": sub,
"guid": guid,
"resource_paths": resource_paths,
"action": action,
"protocol": protocol,
}
resp = requests.post(url, json=body)
self.check_response(resp, body)

def create_login_log(
self,
username,
sub,
idp,
fence_idp=None,
shib_idp=None,
client_id=None,
):
if not is_audit_enabled("login"):
return

# special case for idp=fence when falling back on
# fence_idp=shibboleth and shib_idp=NIH
if shib_idp == "None":
shib_idp = None

url = f"{self.service_url}/log/login"
body = {
"request_url": get_request_url(),
"status_code": 200, # only record successful requests for now
"username": username,
"sub": sub,
"idp": idp,
"fence_idp": fence_idp,
"shib_idp": shib_idp,
"client_id": client_id,
}
resp = requests.post(url, json=body)
self.check_response(resp, body)
2 changes: 1 addition & 1 deletion fence/resources/openid/idp_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def get_auth_url(self):

def get_user_id(self, code):
"""
Must implement in inheriting class. Should return dictionary with "email" field
Must implement in inheriting class. Should return dictionary with necessary field(s)
for successfully logged in user OR "error" field with details of the error.
"""
raise NotImplementedError()
Expand Down
2 changes: 1 addition & 1 deletion tests/oidc/discovery/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@ def test_oidc_config_fields(app, client):

for field in recommended_fields:
if field not in response.json:
warnings.warn(
warnings.warning(
"OIDC configuration response missing recommended field: " + field
)

0 comments on commit 12ef55d

Please sign in to comment.