diff --git a/src/cli.rs b/src/cli.rs index 4f4d4d219..0c8c84995 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -455,9 +455,9 @@ pub struct Options { long = "oidc-scope", name = "oidc-scope", env = "P_OIDC_SCOPE", - default_value = "openid profile email", + default_value = "openid profile email offline_access", required = false, - help = "OIDC scope to request (default: openid profile email)" + help = "OIDC scope to request (default: openid profile email offline_access)" )] pub scope: String, diff --git a/src/handlers/http/middleware.rs b/src/handlers/http/middleware.rs index dee3933e8..7b7d6652a 100644 --- a/src/handlers/http/middleware.rs +++ b/src/handlers/http/middleware.rs @@ -24,16 +24,25 @@ use actix_web::{ dev::{Service, ServiceRequest, ServiceResponse, Transform, forward_ready}, error::{ErrorBadRequest, ErrorForbidden, ErrorUnauthorized}, http::header::{self, HeaderName}, + web::Data, }; +use chrono::{Duration, Utc}; use futures_util::future::LocalBoxFuture; use crate::{ handlers::{ AUTHORIZATION_KEY, KINESIS_COMMON_ATTRIBUTES_KEY, LOG_SOURCE_KEY, LOG_SOURCE_KINESIS, - STREAM_NAME_HEADER_KEY, + STREAM_NAME_HEADER_KEY, http::rbac::RBACError, }, + oidc::DiscoveredClient, option::Mode, parseable::PARSEABLE, + rbac::{ + EXPIRY_DURATION, + map::{SessionKey, mut_sessions, mut_users, sessions, users}, + roles_to_permission, user, + }, + utils::get_user_from_request, }; use crate::{ rbac::Users, @@ -160,8 +169,97 @@ where let auth_result: Result<_, Error> = (self.auth_method)(&mut req, self.action); + let http_req = req.request().clone(); + let key: Result = extract_session_key(&mut req); + let userid: Result = get_user_from_request(&http_req); + let fut = self.service.call(req); Box::pin(async move { + let Ok(key) = key else { + return Err(ErrorUnauthorized( + "Your session has expired or is no longer valid. Please re-authenticate to access this resource.", + )); + }; + + // if session is expired, refresh token + if sessions().is_session_expired(&key) { + let oidc_client = match http_req.app_data::>>() { + Some(client) => { + let c = client.clone().into_inner(); + c.as_ref().clone() + } + None => None, + }; + + if let Some(client) = oidc_client + && let Ok(userid) = userid + { + let bearer_to_refresh = { + if let Some(user) = users().get(&userid) { + match &user.ty { + user::UserType::OAuth(oauth) if oauth.bearer.is_some() => { + Some(oauth.clone()) + } + _ => None, + } + } else { + None + } + }; + + if let Some(oauth_data) = bearer_to_refresh { + let Ok(refreshed_token) = client + .refresh_token(&oauth_data, Some(PARSEABLE.options.scope.as_str())) + .await + else { + return Err(ErrorUnauthorized( + "Your session has expired or is no longer valid. Please re-authenticate to access this resource.", + )); + }; + + let expires_in = + if let Some(expires_in) = refreshed_token.expires_in.as_ref() { + if *expires_in > u32::MAX.into() { + EXPIRY_DURATION + } else { + let v = i64::from(*expires_in as u32); + Duration::seconds(v) + } + } else { + EXPIRY_DURATION + }; + + let user_roles = { + let mut users_guard = mut_users(); + if let Some(user) = users_guard.get_mut(&userid) { + if let user::UserType::OAuth(oauth) = &mut user.ty { + oauth.bearer = Some(refreshed_token); + } + user.roles().to_vec() + } else { + return Err(ErrorUnauthorized( + "Your session has expired or is no longer valid. Please re-authenticate to access this resource.", + )); + } + }; + + mut_sessions().track_new( + userid.clone(), + key.clone(), + Utc::now() + expires_in, + roles_to_permission(user_roles), + ); + } else if let Some(user) = users().get(&userid) { + mut_sessions().track_new( + userid.clone(), + key.clone(), + Utc::now() + EXPIRY_DURATION, + roles_to_permission(user.roles()), + ); + } + } + } + match auth_result? { rbac::Response::UnAuthorized => { return Err(ErrorForbidden( diff --git a/src/handlers/http/modal/ingest_server.rs b/src/handlers/http/modal/ingest_server.rs index 628bd9f0f..0440e857c 100644 --- a/src/handlers/http/modal/ingest_server.rs +++ b/src/handlers/http/modal/ingest_server.rs @@ -51,7 +51,7 @@ use crate::{ use super::IngestorMetadata; use super::{ - OpenIdClient, ParseableServer, + ParseableServer, ingest::{ingestor_logstream, ingestor_rbac, ingestor_role}, }; @@ -62,7 +62,7 @@ pub struct IngestServer; #[async_trait] impl ParseableServer for IngestServer { // configure the api routes - fn configure_routes(config: &mut web::ServiceConfig, _oidc_client: Option) { + fn configure_routes(config: &mut web::ServiceConfig) { config .service( // Base path "{url}/api/v1" diff --git a/src/handlers/http/modal/mod.rs b/src/handlers/http/modal/mod.rs index 844975e5f..c8be6c89a 100644 --- a/src/handlers/http/modal/mod.rs +++ b/src/handlers/http/modal/mod.rs @@ -18,7 +18,11 @@ use std::{fmt, path::Path, sync::Arc}; -use actix_web::{App, HttpServer, middleware::from_fn, web::ServiceConfig}; +use actix_web::{ + App, HttpServer, + middleware::from_fn, + web::{self, ServiceConfig}, +}; use actix_web_prometheus::PrometheusMetrics; use anyhow::Context; use async_trait::async_trait; @@ -67,7 +71,7 @@ include!(concat!(env!("OUT_DIR"), "/generated.rs")); #[async_trait] pub trait ParseableServer { /// configure the router - fn configure_routes(config: &mut ServiceConfig, oidc_client: Option) + fn configure_routes(config: &mut ServiceConfig) where Self: Sized; @@ -96,7 +100,7 @@ pub trait ParseableServer { let client = config .connect(&format!("{API_BASE_PATH}/{API_VERSION}/o/code")) .await?; - Some(Arc::new(client)) + Some(client) } None => None, @@ -116,8 +120,9 @@ pub trait ParseableServer { // fn that creates the app let create_app_fn = move || { App::new() + .app_data(web::Data::new(oidc_client.clone())) .wrap(prometheus.clone()) - .configure(|config| Self::configure_routes(config, oidc_client.clone())) + .configure(|config| Self::configure_routes(config)) .wrap(from_fn(health_check::check_shutdown_middleware)) .wrap(actix_web::middleware::Logger::default()) .wrap(actix_web::middleware::Compress::default()) diff --git a/src/handlers/http/modal/query_server.rs b/src/handlers/http/modal/query_server.rs index f1a4249c7..c345d3112 100644 --- a/src/handlers/http/modal/query_server.rs +++ b/src/handlers/http/modal/query_server.rs @@ -42,14 +42,14 @@ use crate::Server; use crate::parseable::PARSEABLE; use super::query::{querier_ingest, querier_logstream, querier_rbac, querier_role}; -use super::{NodeType, OpenIdClient, ParseableServer, QuerierMetadata, load_on_init}; +use super::{NodeType, ParseableServer, QuerierMetadata, load_on_init}; pub struct QueryServer; pub static QUERIER_META: OnceCell> = OnceCell::const_new(); #[async_trait] impl ParseableServer for QueryServer { // configure the api routes - fn configure_routes(config: &mut ServiceConfig, oidc_client: Option) { + fn configure_routes(config: &mut ServiceConfig) { config .service( web::scope(&base_path()) @@ -66,7 +66,7 @@ impl ParseableServer for QueryServer { .service(Server::get_dashboards_webscope()) .service(Server::get_filters_webscope()) .service(Server::get_llm_webscope()) - .service(Server::get_oauth_webscope(oidc_client)) + .service(Server::get_oauth_webscope()) .service(Self::get_user_role_webscope()) .service(Server::get_roles_webscope()) .service(Server::get_counts_webscope().wrap(from_fn( diff --git a/src/handlers/http/modal/server.rs b/src/handlers/http/modal/server.rs index 6e3ba9ea7..7b145ebb1 100644 --- a/src/handlers/http/modal/server.rs +++ b/src/handlers/http/modal/server.rs @@ -61,7 +61,6 @@ use crate::{ }; // use super::generate; -use super::OpenIdClient; use super::ParseableServer; use super::generate; use super::load_on_init; @@ -70,7 +69,7 @@ pub struct Server; #[async_trait] impl ParseableServer for Server { - fn configure_routes(config: &mut web::ServiceConfig, oidc_client: Option) { + fn configure_routes(config: &mut web::ServiceConfig) { // there might be a bug in the configure routes method config .service( @@ -91,7 +90,7 @@ impl ParseableServer for Server { .service(Self::get_dashboards_webscope()) .service(Self::get_filters_webscope()) .service(Self::get_llm_webscope()) - .service(Self::get_oauth_webscope(oidc_client)) + .service(Self::get_oauth_webscope()) .service(Self::get_user_role_webscope()) .service(Self::get_roles_webscope()) .service(Self::get_counts_webscope().wrap(from_fn( @@ -570,17 +569,11 @@ impl Server { } // get the oauth webscope - pub fn get_oauth_webscope(oidc_client: Option) -> Scope { - let oauth = web::scope("/o") + pub fn get_oauth_webscope() -> Scope { + web::scope("/o") .service(resource("/login").route(web::get().to(oidc::login))) .service(resource("/logout").route(web::get().to(oidc::logout))) - .service(resource("/code").route(web::get().to(oidc::reply_login))); - - if let Some(client) = oidc_client { - oauth.app_data(web::Data::from(client)) - } else { - oauth - } + .service(resource("/code").route(web::get().to(oidc::reply_login))) } // get list of roles diff --git a/src/handlers/http/oidc.rs b/src/handlers/http/oidc.rs index 5f3506d42..ad8523467 100644 --- a/src/handlers/http/oidc.rs +++ b/src/handlers/http/oidc.rs @@ -16,7 +16,7 @@ * */ -use std::{collections::HashSet, sync::Arc}; +use std::collections::HashSet; use actix_web::{ HttpRequest, HttpResponse, @@ -24,10 +24,12 @@ use actix_web::{ http::header::{self, ContentType}, web::{self, Data}, }; +use chrono::{Duration, TimeDelta}; use http::StatusCode; -use openid::{Options, Token, Userinfo}; +use openid::{Bearer, Options, Token, Userinfo}; use regex::Regex; use serde::Deserialize; +use tracing::error; use ulid::Ulid; use url::Url; @@ -36,7 +38,7 @@ use crate::{ oidc::{Claims, DiscoveredClient}, parseable::PARSEABLE, rbac::{ - self, Users, + self, EXPIRY_DURATION, Users, map::{DEFAULT_ROLE, SessionKey}, user::{self, GroupUser, User, UserType}, }, @@ -72,14 +74,20 @@ pub async fn login( )); } - let oidc_client = req.app_data::>(); + let oidc_client = match req.app_data::>>() { + Some(client) => { + let c = client.clone().into_inner(); + c.as_ref().clone() + } + None => None, + }; let session_key = extract_session_key_from_req(&req).ok(); let (session_key, oidc_client) = match (session_key, oidc_client) { (None, None) => return Ok(redirect_no_oauth_setup(query.redirect.clone())), (None, Some(client)) => { return Ok(redirect_to_oidc( query, - client, + &client, PARSEABLE.options.scope.to_string().as_str(), )); } @@ -103,8 +111,11 @@ pub async fn login( ) if basic.verify_password(&password) => { let user_cookie = cookie_username(&username); let user_id_cookie = cookie_userid(&username); - let session_cookie = - exchange_basic_for_cookie(user, SessionKey::BasicAuth { username, password }); + let session_cookie = exchange_basic_for_cookie( + user, + SessionKey::BasicAuth { username, password }, + EXPIRY_DURATION, + ); Ok(redirect_to_client( query.redirect.as_str(), [user_cookie, user_id_cookie, session_cookie], @@ -121,7 +132,7 @@ pub async fn login( if let Some(oidc_client) = oidc_client { redirect_to_oidc( query, - oidc_client, + &oidc_client, PARSEABLE.options.scope.to_string().as_str(), ) } else { @@ -134,7 +145,13 @@ pub async fn login( } pub async fn logout(req: HttpRequest, query: web::Query) -> HttpResponse { - let oidc_client = req.app_data::>(); + let oidc_client = match req.app_data::>>() { + Some(client) => { + let c = client.clone().into_inner(); + c.as_ref().clone() + } + None => None, + }; let Some(session) = extract_session_key_from_req(&req).ok() else { return redirect_to_client(query.redirect.as_str(), None); }; @@ -155,11 +172,12 @@ pub async fn logout(req: HttpRequest, query: web::Query) -> /// Handler for code callback /// User should be redirected to page they were trying to access with cookie pub async fn reply_login( - oidc_client: Data, + req: HttpRequest, login_query: web::Query, ) -> Result { - let oidc_client = Data::into_inner(oidc_client); - let Ok((mut claims, user_info)): Result<(Claims, Userinfo), anyhow::Error> = + let oidc_client = req.app_data::>>().unwrap(); + let oidc_client = oidc_client.clone().into_inner().as_ref().clone().unwrap(); + let Ok((mut claims, user_info, bearer)): Result<(Claims, Userinfo, Bearer), anyhow::Error> = request_token(oidc_client, &login_query).await else { return Ok(HttpResponse::Unauthorized().finish()); @@ -178,6 +196,10 @@ pub async fn reply_login( } }; let user_info: user::UserInfo = user_info.into(); + + // if provider has group A, and parseable as has role A + // then user will automatically get assigned role A + // else, the default oidc role (inside parseable) will get assigned let group: HashSet = claims .other .remove("groups") @@ -223,12 +245,25 @@ pub async fn reply_login( final_roles.clone_from(&default_role); } + let expires_in = if let Some(expires_in) = bearer.expires_in.as_ref() { + // need an i64 somehow + if *expires_in > u32::MAX.into() { + EXPIRY_DURATION + } else { + let v = i64::from(*expires_in as u32); + Duration::seconds(v) + } + } else { + EXPIRY_DURATION + }; + let user = match (existing_user, final_roles) { - (Some(user), roles) => update_user_if_changed(user, roles, user_info).await?, - (None, roles) => put_user(&user_id, roles, user_info).await?, + (Some(user), roles) => update_user_if_changed(user, roles, user_info, bearer).await?, + (None, roles) => put_user(&user_id, roles, user_info, bearer).await?, }; let id = Ulid::new(); - Users.new_session(&user, SessionKey::SessionId(id)); + + Users.new_session(&user, SessionKey::SessionId(id), expires_in); let redirect_url = login_query .state @@ -270,10 +305,14 @@ fn find_existing_user(user_info: &user::UserInfo) -> Option { None } -fn exchange_basic_for_cookie(user: &User, key: SessionKey) -> Cookie<'static> { +fn exchange_basic_for_cookie( + user: &User, + key: SessionKey, + expires_in: TimeDelta, +) -> Cookie<'static> { let id = Ulid::new(); Users.remove_session(&key); - Users.new_session(user, SessionKey::SessionId(id)); + Users.new_session(user, SessionKey::SessionId(id), expires_in); cookie_session(id) } @@ -288,7 +327,8 @@ fn redirect_to_oidc( state: Some(redirect), ..Default::default() }); - let url: String = auth_url.into(); + let mut url: String = auth_url.into(); + url.push_str("&access_type=offline&prompt=consent"); HttpResponse::TemporaryRedirect() .insert_header((header::LOCATION, url)) .finish() @@ -348,9 +388,9 @@ pub fn cookie_userid(user_id: &str) -> Cookie<'static> { } pub async fn request_token( - oidc_client: Arc, + oidc_client: DiscoveredClient, login_query: &Login, -) -> anyhow::Result<(Claims, Userinfo)> { +) -> anyhow::Result<(Claims, Userinfo, Bearer)> { let mut token: Token = oidc_client.request_token(&login_query.code).await?.into(); let Some(id_token) = token.id_token.as_mut() else { return Err(anyhow::anyhow!("No id_token provided")); @@ -361,7 +401,8 @@ pub async fn request_token( let claims = id_token.payload().expect("payload is decoded").clone(); let userinfo = oidc_client.request_userinfo(&token).await?; - Ok((claims, userinfo)) + let bearer = token.bearer; + Ok((claims, userinfo, bearer)) } // put new user in metadata if does not exits @@ -370,21 +411,27 @@ pub async fn put_user( userid: &str, group: HashSet, user_info: user::UserInfo, + bearer: Bearer, ) -> Result { let mut metadata = get_metadata().await?; - let user = metadata + let mut user = metadata .users .iter() .find(|user| user.userid() == userid) .cloned() .unwrap_or_else(|| { - let user = User::new_oauth(userid.to_owned(), group, user_info); + let user = User::new_oauth(userid.to_owned(), group, user_info, None); metadata.users.push(user.clone()); user }); put_metadata(&metadata).await?; + + // modify before storing + if let user::UserType::OAuth(oauth) = &mut user.ty { + oauth.bearer = Some(bearer); + } Users.put_user(user.clone()); Ok(user) } @@ -393,6 +440,7 @@ pub async fn update_user_if_changed( mut user: User, group: HashSet, user_info: user::UserInfo, + bearer: Bearer, ) -> Result { // Store the old username before modifying the user object let old_username = user.userid().to_string(); @@ -408,8 +456,12 @@ pub async fn update_user_if_changed( false }; - // update user only if roles, userinfo has changed, or userid needs migration - if roles == &group && oauth_user.user_info == user_info && !needs_userid_migration { + // update user only if roles, userinfo has changed, or userid needs migration, or bearer is updated + if roles == &group + && oauth_user.user_info == user_info + && !needs_userid_migration + && oauth_user.bearer.as_ref() == Some(&bearer) + { return Ok(user); } @@ -438,6 +490,10 @@ pub async fn update_user_if_changed( } put_metadata(&metadata).await?; Users.delete_user(&old_username); + // update oauth bearer + if let user::UserType::OAuth(oauth) = &mut user.ty { + oauth.bearer = Some(bearer); + } Users.put_user(user.clone()); Ok(user) } diff --git a/src/rbac/map.rs b/src/rbac/map.rs index 5377d10d7..e8836c824 100644 --- a/src/rbac/map.rs +++ b/src/rbac/map.rs @@ -29,6 +29,7 @@ use super::{ }; use chrono::{DateTime, Utc}; use once_cell::sync::{Lazy, OnceCell}; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; pub type Roles = HashMap>; @@ -167,6 +168,26 @@ pub struct Sessions { } impl Sessions { + // only checks if the session is expired or not + pub fn is_session_expired(&self, key: &SessionKey) -> bool { + // fetch userid from session key + let userid = if let Some((user, _)) = self.active_sessions.get(key) { + user + } else { + return false; + }; + + // check against user sessions if this session is still valid + let Some(session) = self.user_sessions.get(userid) else { + return false; + }; + + session + .par_iter() + .find_first(|(sessionid, expiry)| sessionid.eq(key) && expiry < &Utc::now()) + .is_some() + } + // track new session key pub fn track_new( &mut self, diff --git a/src/rbac/mod.rs b/src/rbac/mod.rs index 4eb115778..64cea51af 100644 --- a/src/rbac/mod.rs +++ b/src/rbac/mod.rs @@ -23,7 +23,7 @@ pub mod utils; use std::collections::{HashMap, HashSet}; -use chrono::{DateTime, Days, Utc}; +use chrono::{DateTime, Duration, TimeDelta, Utc}; use itertools::Itertools; use role::model::DefaultPrivilege; use serde::Serialize; @@ -37,6 +37,8 @@ use self::map::SessionKey; use self::role::{Permission, RoleBuilder}; use self::user::UserType; +pub const EXPIRY_DURATION: Duration = Duration::hours(1); + #[derive(PartialEq)] pub enum Response { Authorized, @@ -147,11 +149,11 @@ impl Users { mut_sessions().remove_session(session) } - pub fn new_session(&self, user: &User, session: SessionKey) { + pub fn new_session(&self, user: &User, session: SessionKey, expires_in: TimeDelta) { mut_sessions().track_new( user.userid().to_owned(), session, - Utc::now() + Days::new(7), + Utc::now() + expires_in, roles_to_permission(user.roles()), ) } @@ -228,7 +230,7 @@ pub struct UsersPrism { pub user_groups: HashSet, } -fn roles_to_permission(roles: Vec) -> Vec { +pub fn roles_to_permission(roles: Vec) -> Vec { let mut perms = HashSet::new(); for role in &roles { let role_map = &map::roles(); diff --git a/src/rbac/user.rs b/src/rbac/user.rs index 300bf90d9..8e8b62ab8 100644 --- a/src/rbac/user.rs +++ b/src/rbac/user.rs @@ -23,6 +23,7 @@ use argon2::{ password_hash::{PasswordHasher, SaltString, rand_core::OsRng}, }; +use openid::Bearer; use rand::distributions::{Alphanumeric, DistString}; use crate::{ @@ -38,7 +39,7 @@ use crate::{ #[serde(untagged)] pub enum UserType { Native(Basic), - OAuth(OAuth), + OAuth(Box), } #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -66,12 +67,18 @@ impl User { ) } - pub fn new_oauth(userid: String, roles: HashSet, user_info: UserInfo) -> Self { + pub fn new_oauth( + userid: String, + roles: HashSet, + user_info: UserInfo, + bearer: Option, + ) -> Self { Self { - ty: UserType::OAuth(OAuth { + ty: UserType::OAuth(Box::new(OAuth { userid: user_info.sub.clone().unwrap_or(userid), user_info, - }), + bearer, + })), roles, user_groups: HashSet::new(), } @@ -80,7 +87,7 @@ impl User { pub fn userid(&self) -> &str { match self.ty { UserType::Native(Basic { ref username, .. }) => username, - UserType::OAuth(OAuth { ref userid, .. }) => userid, + UserType::OAuth(ref oauth) => &oauth.userid, } } @@ -175,6 +182,19 @@ pub fn get_admin_user() -> User { pub struct OAuth { pub userid: String, pub user_info: UserInfo, + pub bearer: Option, +} + +impl AsRef for Box { + /// Returns a reference to the bearer token. + /// + /// # Panics + /// Panics if bearer is None. This should never happen in practice as + /// bearer is always set to Some when OIDC is configured and this trait + /// is only called by refresh_token after verifying bearer.is_some(). + fn as_ref(&self) -> &Bearer { + self.bearer.as_ref().unwrap() + } } #[derive(Debug, Default, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -255,8 +275,10 @@ impl GroupUser { username: username.clone(), method: "native".to_string(), }, - UserType::OAuth(OAuth { userid, user_info }) => { + UserType::OAuth(oauth) => { // For OAuth users, derive the display username from user_info + let user_info = &oauth.user_info; + let userid = &oauth.userid; let display_username = user_info .name .clone() diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 1d84558ab..4ef8063a6 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -59,7 +59,7 @@ pub fn extract_datetime(path: &str) -> Option { } pub fn get_user_from_request(req: &HttpRequest) -> Result { - let session_key = extract_session_key_from_req(req).unwrap(); + let session_key = extract_session_key_from_req(req).map_err(|_| RBACError::UserDoesNotExist)?; let user_id = Users.get_userid_from_session(&session_key); if user_id.is_none() { return Err(RBACError::UserDoesNotExist);