Skip to content

Commit

Permalink
refactor: move all of basic authentication logic into basic_auth modu…
Browse files Browse the repository at this point in the history
…le (#346)

* refactor: Move all of Basic authentication logic into basic_auth module
* Do not import server_info macro explicitly

---------

Co-authored-by: Jose Quintana <1700322+joseluisq@users.noreply.github.com>
  • Loading branch information
palant and joseluisq committed Apr 18, 2024
1 parent cc6784a commit 76531e6
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 46 deletions.
198 changes: 196 additions & 2 deletions src/basic_auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,64 @@

use bcrypt::verify as bcrypt_verify;
use headers::{authorization::Basic, Authorization, HeaderMap, HeaderMapExt};
use hyper::StatusCode;
use hyper::{header::WWW_AUTHENTICATE, Body, Request, Response, StatusCode};

use crate::{error_page, handler::RequestHandlerOpts, http_ext::MethodExt, Error};

/// Initializes `Basic` HTTP Authorization handling
pub(crate) fn init(credentials: &str, handler_opts: &mut RequestHandlerOpts) {
credentials.trim().clone_into(&mut handler_opts.basic_auth);
server_info!(
"basic authentication: enabled={}",
!handler_opts.basic_auth.is_empty()
);
}

/// Handles `Basic` HTTP Authorization Schema
pub(crate) fn pre_process(
opts: &RequestHandlerOpts,
req: &Request<Body>,
) -> Option<Result<Response<Body>, Error>> {
if opts.basic_auth.is_empty() {
return None;
}

let method = req.method();
if method.is_options() {
return None;
}

let uri = req.uri();
if let Some((user_id, password)) = opts.basic_auth.split_once(':') {
let err = check_request(req.headers(), user_id, password).err()?;
tracing::warn!("basic authentication failed {:?}", err);
let mut result = error_page::error_response(
uri,
method,
&StatusCode::UNAUTHORIZED,
&opts.page404,
&opts.page50x,
);
if let Ok(ref mut resp) = result {
resp.headers_mut().insert(
WWW_AUTHENTICATE,
"Basic realm=\"Static Web Server\", charset=\"UTF-8\""
.parse()
.unwrap(),
);
}
Some(result)
} else {
tracing::error!("invalid basic authentication `user_id:password` pairs");
Some(error_page::error_response(
uri,
method,
&StatusCode::INTERNAL_SERVER_ERROR,
&opts.page404,
&opts.page50x,
))
}
}

/// Check for a `Basic` HTTP Authorization Schema of an incoming request
/// and uses `bcrypt` for password hashing verification.
Expand All @@ -33,8 +90,71 @@ pub fn check_request(headers: &HeaderMap, userid: &str, password: &str) -> Resul

#[cfg(test)]
mod tests {
use super::check_request;
use super::{check_request, pre_process};
use crate::{handler::RequestHandlerOpts, Error};
use headers::HeaderMap;
use hyper::{header::WWW_AUTHENTICATE, Body, Request, Response, StatusCode};

fn make_request(method: &str, auth_header: &str) -> Request<Body> {
let mut builder = Request::builder();
if !auth_header.is_empty() {
builder = builder.header("Authorization", auth_header);
}
builder.method(method).uri("/").body(Body::empty()).unwrap()
}

fn is_401(result: Option<Result<Response<Body>, Error>>) -> bool {
if let Some(Ok(response)) = result {
response.status() == StatusCode::UNAUTHORIZED
&& response.headers().get(WWW_AUTHENTICATE).is_some()
} else {
false
}
}

fn is_500(result: Option<Result<Response<Body>, Error>>) -> bool {
if let Some(Ok(response)) = result {
response.status() == StatusCode::INTERNAL_SERVER_ERROR
} else {
false
}
}

#[test]
fn test_auth_disabled() {
assert!(pre_process(
&RequestHandlerOpts {
basic_auth: "".into(),
..Default::default()
},
&make_request("GET", "Basic anE6anE=")
)
.is_none());
}

#[test]
fn test_invalid_auth_configuration() {
assert!(is_500(pre_process(
&RequestHandlerOpts {
basic_auth: "xyz".into(),
..Default::default()
},
&make_request("GET", "Basic anE6anE=")
)));
}

#[test]
fn test_options_request() {
assert!(pre_process(
&RequestHandlerOpts {
basic_auth: "jq:$2y$05$32zazJ1yzhlDHnt26L3MFOgY0HVqPmDUvG0KUx6cjf9RDiUGp/M9q"
.into(),
..Default::default()
},
&make_request("OPTIONS", "")
)
.is_none());
}

#[test]
fn test_valid_auth() {
Expand All @@ -46,19 +166,45 @@ mod tests {
"$2y$05$32zazJ1yzhlDHnt26L3MFOgY0HVqPmDUvG0KUx6cjf9RDiUGp/M9q"
)
.is_ok());

assert!(pre_process(
&RequestHandlerOpts {
basic_auth: "jq:$2y$05$32zazJ1yzhlDHnt26L3MFOgY0HVqPmDUvG0KUx6cjf9RDiUGp/M9q"
.into(),
..Default::default()
},
&make_request("GET", "Basic anE6anE=")
)
.is_none());
}

#[test]
fn test_invalid_auth_header() {
let headers = HeaderMap::new();
assert!(check_request(&headers, "jq", "").is_err());

assert!(is_401(pre_process(
&RequestHandlerOpts {
basic_auth: "jq:".into(),
..Default::default()
},
&make_request("GET", "")
)));
}

#[test]
fn test_invalid_auth_pairs() {
let mut headers = HeaderMap::new();
headers.insert("Authorization", "Basic anE6anE=".parse().unwrap());
assert!(check_request(&headers, "xyz", "").is_err());

assert!(is_401(pre_process(
&RequestHandlerOpts {
basic_auth: "xyz:".into(),
..Default::default()
},
&make_request("GET", "Basic anE6anE=")
)));
}

#[test]
Expand All @@ -74,6 +220,36 @@ mod tests {
assert!(check_request(&headers, "jq", "password").is_err());
assert!(check_request(&headers, "", "password").is_err());
assert!(check_request(&headers, "jq", "").is_err());

assert!(is_401(pre_process(
&RequestHandlerOpts {
basic_auth: "abc:$2y$05$32zazJ1yzhlDHnt26L3MFOgY0HVqPmDUvG0KUx6cjf9RDiUGp/M9q"
.into(),
..Default::default()
},
&make_request("GET", "Basic anE6anE=")
)));
assert!(is_401(pre_process(
&RequestHandlerOpts {
basic_auth: "jq:password".into(),
..Default::default()
},
&make_request("GET", "Basic anE6anE=")
)));
assert!(is_401(pre_process(
&RequestHandlerOpts {
basic_auth: ":password".into(),
..Default::default()
},
&make_request("GET", "Basic anE6anE=")
)));
assert!(is_401(pre_process(
&RequestHandlerOpts {
basic_auth: "jq:".into(),
..Default::default()
},
&make_request("GET", "Basic anE6anE=")
)));
}

#[test]
Expand All @@ -86,6 +262,15 @@ mod tests {
"$2y$05$32zazJ1yzhlDHnt26L3MFOgY0HVqPmDUvG0KUx6cjf9RDiUGp/M9q"
)
.is_err());

assert!(is_401(pre_process(
&RequestHandlerOpts {
basic_auth: "jq:$2y$05$32zazJ1yzhlDHnt26L3MFOgY0HVqPmDUvG0KUx6cjf9RDiUGp/M9q"
.into(),
..Default::default()
},
&make_request("GET", "Basic xyz")
)));
}

#[test]
Expand All @@ -98,5 +283,14 @@ mod tests {
"$2y$05$32zazJ1yzhlDHnt26L3MFOgY0HVqPmDUvG0KUx6cjf9RDiUGp/M9q"
)
.is_err());

assert!(is_401(pre_process(
&RequestHandlerOpts {
basic_auth: "jq:$2y$05$32zazJ1yzhlDHnt26L3MFOgY0HVqPmDUvG0KUx6cjf9RDiUGp/M9q"
.into(),
..Default::default()
},
&make_request("GET", "abcd")
)));
}
}
2 changes: 0 additions & 2 deletions src/cors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ use headers::{
use http::header;
use std::collections::HashSet;

use crate::server_info;

/// It defines CORS instance.
#[derive(Clone, Debug)]
pub struct Cors {
Expand Down
35 changes: 4 additions & 31 deletions src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::{future::Future, net::IpAddr, net::SocketAddr, path::PathBuf, sync::Arc
use crate::compression;

#[cfg(feature = "basic-auth")]
use {crate::basic_auth, hyper::header::WWW_AUTHENTICATE};
use crate::basic_auth;

#[cfg(feature = "fallback-page")]
use crate::fallback_page;
Expand Down Expand Up @@ -242,37 +242,10 @@ impl RequestHandler {
};
}

#[cfg(feature = "basic-auth")]
// `Basic` HTTP Authorization Schema
if !self.opts.basic_auth.is_empty() && !method.is_options() {
if let Some((user_id, password)) = self.opts.basic_auth.split_once(':') {
if let Err(err) = basic_auth::check_request(headers, user_id, password) {
tracing::warn!("basic authentication failed {:?}", err);
let mut resp = error_page::error_response(
uri,
method,
&StatusCode::UNAUTHORIZED,
&self.opts.page404,
&self.opts.page50x,
)?;
resp.headers_mut().insert(
WWW_AUTHENTICATE,
"Basic realm=\"Static Web Server\", charset=\"UTF-8\""
.parse()
.unwrap(),
);
return Ok(resp);
}
} else {
tracing::error!("invalid basic authentication `user_id:password` pairs");
return error_page::error_response(
uri,
method,
&StatusCode::INTERNAL_SERVER_ERROR,
&self.opts.page404,
&self.opts.page50x,
);
}
#[cfg(feature = "basic-auth")]
if let Some(response) = basic_auth::pre_process(&self.opts, req) {
return response;
}

// Maintenance Mode
Expand Down
18 changes: 7 additions & 11 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ use {
hyper::service::{make_service_fn, service_fn},
};

#[cfg(feature = "basic-auth")]
use crate::basic_auth;

use crate::{cors, health, helpers, Settings};
use crate::{service::RouterService, Context, Result};

Expand Down Expand Up @@ -308,15 +311,6 @@ impl Server {
general.cors_expose_headers.trim(),
);

#[cfg(feature = "basic-auth")]
// `Basic` HTTP Authentication Schema option
let basic_auth = general.basic_auth.trim().to_owned();
#[cfg(feature = "basic-auth")]
server_info!(
"basic authentication: enabled={}",
!general.basic_auth.is_empty()
);

// Log remote address option
let log_remote_address = general.log_remote_address;
server_info!("log remote address: enabled={}", log_remote_address);
Expand Down Expand Up @@ -365,8 +359,6 @@ impl Server {
page50x: page50x.clone(),
#[cfg(feature = "fallback-page")]
page_fallback,
#[cfg(feature = "basic-auth")]
basic_auth,
log_remote_address,
redirect_trailing_slash,
ignore_hidden_files,
Expand All @@ -382,6 +374,10 @@ impl Server {
#[cfg(all(unix, feature = "experimental"))]
metrics::init(general.experimental_metrics, &mut handler_opts);

// `Basic` HTTP Authentication Schema option
#[cfg(feature = "basic-auth")]
basic_auth::init(&general.basic_auth, &mut handler_opts);

// Maintenance mode option
handler_opts.maintenance_mode = general.maintenance_mode;
handler_opts.maintenance_mode_status = general.maintenance_mode_status;
Expand Down

0 comments on commit 76531e6

Please sign in to comment.