Skip to content

Commit

Permalink
feature-2924: Add an option to suppress server identification headers (
Browse files Browse the repository at this point in the history
…#3770)

Co-authored-by: Gerard Guillemas Martos <gguillemas@users.noreply.github.com>
Co-authored-by: Micha de Vries <micha@devrie.sh>
  • Loading branch information
3 people committed Jun 10, 2024
1 parent 2913917 commit a11f1bc
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 9 deletions.
1 change: 1 addition & 0 deletions src/cli/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ pub struct Config {
pub key: Option<PathBuf>,
pub tick_interval: Duration,
pub engine: Option<EngineOptions>,
pub no_identification_headers: bool,
}
7 changes: 6 additions & 1 deletion src/cli/start.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ pub struct StartCommandArguments {
#[arg(env = "SURREAL_BIND", short = 'b', long = "bind")]
#[arg(default_value = "127.0.0.1:8000")]
listen_addresses: Vec<SocketAddr>,

#[arg(help = "Whether to suppress the server name and version headers")]
#[arg(env = "SURREAL_NO_IDENTIFICATION_HEADERS", long)]
#[arg(default_value_t = false)]
no_identification_headers: bool,
//
// Database options
//
Expand Down Expand Up @@ -142,6 +145,7 @@ pub async fn init(
log,
tick_interval,
no_banner,
no_identification_headers,
..
}: StartCommandArguments,
) -> Result<(), Error> {
Expand Down Expand Up @@ -171,6 +175,7 @@ pub async fn init(
user,
pass,
tick_interval,
no_identification_headers,
crt: web.as_ref().and_then(|x| x.web_crt.clone()),
key: web.as_ref().and_then(|x| x.web_key.clone()),
engine: None,
Expand Down
22 changes: 17 additions & 5 deletions src/net/headers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,25 @@ pub use db::SurrealDatabase;
pub use id::SurrealId;
pub use ns::SurrealNamespace;

pub fn add_version_header() -> SetResponseHeaderLayer<HeaderValue> {
let val = format!("{PKG_NAME}-{}", *PKG_VERSION);
SetResponseHeaderLayer::if_not_present(VERSION.to_owned(), HeaderValue::try_from(val).unwrap())
pub fn add_version_header(enabled: bool) -> SetResponseHeaderLayer<Option<HeaderValue>> {
let header_value = if enabled {
let val = format!("{PKG_NAME}-{}", *PKG_VERSION);
Some(HeaderValue::try_from(val).unwrap())
} else {
None
};

SetResponseHeaderLayer::if_not_present(VERSION.to_owned(), header_value)
}

pub fn add_server_header() -> SetResponseHeaderLayer<HeaderValue> {
SetResponseHeaderLayer::if_not_present(SERVER, HeaderValue::try_from(SERVER_NAME).unwrap())
pub fn add_server_header(enabled: bool) -> SetResponseHeaderLayer<Option<HeaderValue>> {
let header_value = if enabled {
Some(HeaderValue::try_from(SERVER_NAME).unwrap())
} else {
None
};

SetResponseHeaderLayer::if_not_present(SERVER, header_value)
}

// Parse a TypedHeader, returning None if the header is missing and an error if the header is invalid.
Expand Down
4 changes: 2 additions & 2 deletions src/net/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ pub async fn init(ct: CancellationToken) -> Result<(), Error> {
.layer(HttpMetricsLayer)
.layer(SetSensitiveResponseHeadersLayer::from_shared(headers))
.layer(AsyncRequireAuthorizationLayer::new(auth::SurrealAuth))
.layer(headers::add_server_header())
.layer(headers::add_version_header())
.layer(headers::add_server_header(!opt.no_identification_headers))
.layer(headers::add_version_header(!opt.no_identification_headers))
.layer(
CorsLayer::new()
.allow_methods([
Expand Down
29 changes: 28 additions & 1 deletion tests/http_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ mod http_integration {
use test_log::test;
use ulid::Ulid;

use super::common::{self, PASS, USER};
use super::common::{self, StartServerArguments, PASS, USER};

#[test(tokio::test)]
async fn basic_auth() -> Result<(), Box<dyn std::error::Error>> {
Expand Down Expand Up @@ -352,6 +352,33 @@ mod http_integration {
Ok(())
}

#[test(tokio::test)]
async fn no_server_id_headers() -> Result<(), Box<dyn std::error::Error>> {
// default server has the id headers
{
let (addr, _server) = common::start_server_with_defaults().await.unwrap();
let url = &format!("http://{addr}/health");

let res = Client::default().get(url).send().await?;
assert!(res.headers().contains_key("server"));
assert!(res.headers().contains_key("surreal-version"));
}

// turn on the no-identification-headers option to suppress headers
{
let mut start_server_arguments = StartServerArguments::default();
start_server_arguments.args.push_str(" --no-identification-headers");
let (addr, _server) = common::start_server(start_server_arguments).await.unwrap();
let url = &format!("http://{addr}/health");

let res = Client::default().get(url).send().await?;
assert!(!res.headers().contains_key("server"));
assert!(!res.headers().contains_key("surreal-version"));
}

Ok(())
}

#[test(tokio::test)]
async fn import_endpoint() -> Result<(), Box<dyn std::error::Error>> {
let (addr, _server) = common::start_server_with_defaults().await.unwrap();
Expand Down

0 comments on commit a11f1bc

Please sign in to comment.