Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@ Cargo.lock
# Added by cargo

/target

# Dev TLS certs
.certs/
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ aws-smithy-http = "0.54"
aws-smithy-types = "0.54"
aws-types = "0.54"
axum = { version = "0.6", features = ["headers"] }
axum-server = { version = "0.4.7", features = ["tls-rustls"] }
clap = { version = "4.2.1", features = ["derive", "env"] }
expanduser = "1.2.2"
http = "*"
hyper = { version = "0.14", features = ["full"] }
maligned = "0.2.1"
Expand Down
86 changes: 79 additions & 7 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
//! * [ndarray] provides [NumPy](https://numpy.orgq)-like n-dimensional arrays used in numerical
//! computation.

use std::{net::SocketAddr, process::exit, str::FromStr, time::Duration};

use axum_server::{tls_rustls::RustlsConfig, Handle};
use clap::Parser;
use expanduser::expanduser;
use tokio::signal;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

Expand All @@ -44,6 +48,26 @@ struct CommandLineArgs {
/// The port to which the proxy should bind
#[arg(long, default_value_t = 8080, env = "S3_ACTIVE_STORAGE_PORT")]
port: u16,
/// Flag indicating whether HTTPS should be used
#[arg(long, default_value_t = false, env = "S3_ACTIVE_STORAGE_HTTPS")]
https: bool,
/// Path to the certificate file to be used for HTTPS encryption
#[arg(
long,
default_value = "~/.config/s3-active-storage/certs/cert.pem",
env = "S3_ACTIVE_STORAGE_CERT_FILE"
)]
cert_file: String,
/// Path to the key file to be used for HTTPS encryption
#[arg(
long,
default_value = "~/.config/s3-active-storage/certs/key.pem",
env = "S3_ACTIVE_STORAGE_KEY_FILE"
)]
key_file: String,
/// Maximum time in seconds to wait for operations to complete upon receiving `ctrl+c` signal.
#[arg(long, default_value_t = 60, env = "S3_ACTIVE_STORAGE_SHUTDOWN_TIMEOUT")]
graceful_shutdown_timeout: u64,
}

/// Application entry point
Expand All @@ -54,13 +78,59 @@ async fn main() {
init_tracing();

let router = app::router();
let addr = SocketAddr::from_str(&format!("{}:{}", args.host, args.port))
.expect("invalid host name, IP address or port number");

// Catch ctrl+c and try to shutdown gracefully
let handle = Handle::new();
tokio::spawn(shutdown_signal(
handle.clone(),
args.graceful_shutdown_timeout,
));

// run it with hyper
axum::Server::bind(&format!("{}:{}", args.host, args.port).parse().unwrap())
.serve(router.into_make_service())
.with_graceful_shutdown(shutdown_signal())
.await
.unwrap();
if args.https {
// Expand files
let abs_cert_file = expanduser(args.cert_file)
.expect("Failed to expand ~ to user name. Please provide an absolute path instead.")
.canonicalize()
.expect("failed to determine absolute path to TLS cerficate file");
let abs_key_file = expanduser(args.key_file)
.expect("Failed to expand ~ to user name. Please provide an absolute path instead.")
.canonicalize()
.expect("failed to determine absolute path to TLS key file");
// Check files exist
if !abs_cert_file.exists() {
println!(
"TLS certificate file expected at '{}' but not found.",
abs_cert_file.display()
);
exit(1)
}
if !abs_key_file.exists() {
println!(
"TLS key file expected at '{}' but not found.",
abs_key_file.display()
);
exit(1)
}
// Set up TLS config
let tls_config = RustlsConfig::from_pem_file(abs_cert_file, abs_key_file)
.await
.expect("Failed to load TLS certificate files");
// run HTTPS server with hyper
axum_server::bind_rustls(addr, tls_config)
.handle(handle)
.serve(router.into_make_service())
.await
.unwrap();
} else {
// run HTTP server with hyper
axum_server::bind(addr)
.handle(handle)
.serve(router.into_make_service())
.await
.unwrap();
}
}

/// Initlialise tracing (logging)
Expand All @@ -80,7 +150,7 @@ fn init_tracing() {
/// Graceful shutdown handler
///
/// Installs signal handlers to catch Ctrl-C or SIGTERM and trigger a graceful shutdown.
async fn shutdown_signal() {
async fn shutdown_signal(handle: Handle, timeout: u64) {
let ctrl_c = async {
signal::ctrl_c()
.await
Expand All @@ -104,4 +174,6 @@ async fn shutdown_signal() {
}

println!("signal received, starting graceful shutdown");
// Force shutdown if graceful shutdown takes longer than 10s
handle.graceful_shutdown(Some(Duration::from_secs(timeout)));
}