From 69aa18b179c496dd9dd214ea4d12a2250144a62a Mon Sep 17 00:00:00 2001 From: overtrue Date: Sat, 16 May 2026 13:46:38 +0800 Subject: [PATCH] feat(requests): support custom amz headers --- crates/cli/src/commands/mod.rs | 39 ++++++++ crates/cli/tests/help_contract.rs | 1 + crates/core/src/alias.rs | 75 ++++++++++++++++ crates/core/src/lib.rs | 5 +- crates/s3/src/client.rs | 144 ++++++++++++++++++++++++++++-- docs/reference/rc/README.md | 1 + 6 files changed, 255 insertions(+), 10 deletions(-) diff --git a/crates/cli/src/commands/mod.rs b/crates/cli/src/commands/mod.rs index 7f1e125..cc7381e 100644 --- a/crates/cli/src/commands/mod.rs +++ b/crates/cli/src/commands/mod.rs @@ -7,6 +7,7 @@ use std::io::{IsTerminal, stderr, stdout}; use clap::{Parser, Subcommand, ValueEnum}; +use rc_core::{RequestHeader, set_global_request_headers}; use crate::exit_code::ExitCode; use crate::output::OutputConfig; @@ -74,10 +75,18 @@ pub struct Cli { #[arg(long, global = true, default_value = "false")] pub debug: bool, + /// Add an x-amz-* request header to signed S3 requests + #[arg(short = 'H', long = "header", global = true, value_parser = parse_request_header)] + pub request_headers: Vec, + #[command(subcommand)] pub command: Commands, } +fn parse_request_header(value: &str) -> Result { + RequestHeader::parse(value).map_err(|error| error.to_string()) +} + #[derive(Copy, Clone, Debug, Eq, PartialEq, ValueEnum)] pub enum OutputFormat { Auto, @@ -247,6 +256,7 @@ pub enum Commands { /// Execute the CLI command and return an exit code pub async fn execute(cli: Cli) -> ExitCode { + set_global_request_headers(cli.request_headers.clone()); let output_options = GlobalOutputOptions::from_cli(&cli); match cli.command { @@ -443,6 +453,35 @@ mod tests { assert_eq!(resolved.json, !std::io::stdout().is_terminal()); } + #[test] + fn cli_accepts_global_custom_amz_header() { + let cli = Cli::try_parse_from([ + "rc", + "-H", + "x-amz-bucket-encrypt-enabled:1", + "bucket", + "list", + "local/", + ]) + .expect("parse custom header"); + + assert_eq!(cli.request_headers.len(), 1); + assert_eq!(cli.request_headers[0].name, "x-amz-bucket-encrypt-enabled"); + assert_eq!(cli.request_headers[0].value, "1"); + } + + #[test] + fn cli_rejects_non_amz_custom_header() { + let error = Cli::try_parse_from(["rc", "-H", "authorization:secret", "ls", "local/"]) + .expect_err("non amz header should fail"); + + assert!( + error + .to_string() + .contains("Only x-amz-* custom request headers are supported") + ); + } + #[test] fn cli_accepts_bucket_cors_subcommand() { let cli = Cli::try_parse_from(["rc", "bucket", "cors", "list", "local/my-bucket"]) diff --git a/crates/cli/tests/help_contract.rs b/crates/cli/tests/help_contract.rs index 8e5aad9..b51fa3a 100644 --- a/crates/cli/tests/help_contract.rs +++ b/crates/cli/tests/help_contract.rs @@ -14,6 +14,7 @@ const GLOBAL_OPTIONS: &[&str] = &[ "--no-progress", "--quiet", "--debug", + "--header", "--help", "--version", ]; diff --git a/crates/core/src/alias.rs b/crates/core/src/alias.rs index 9b96ab8..7ee194b 100644 --- a/crates/core/src/alias.rs +++ b/crates/core/src/alias.rs @@ -4,6 +4,7 @@ //! including connection details and credentials. use std::env; +use std::sync::{OnceLock, RwLock}; use serde::{Deserialize, Serialize}; use url::Url; @@ -12,6 +13,80 @@ use crate::config::ConfigManager; use crate::error::{Error, Result}; const RC_HOST_PREFIX: &str = "RC_HOST_"; +const CUSTOM_HEADER_PREFIX: &str = "x-amz-"; + +static GLOBAL_REQUEST_HEADERS: OnceLock>> = OnceLock::new(); + +/// Custom S3 request header applied to remote operations. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RequestHeader { + pub name: String, + pub value: String, +} + +impl RequestHeader { + pub fn parse(value: &str) -> Result { + let (name, header_value) = value.split_once(':').ok_or_else(|| { + Error::Config( + "Header must use NAME:VALUE format, for example x-amz-meta-key:value".into(), + ) + })?; + + let name = name.trim().to_ascii_lowercase(); + let header_value = header_value.trim().to_string(); + + if name.is_empty() { + return Err(Error::Config("Header name must not be empty".into())); + } + + if header_value.is_empty() { + return Err(Error::Config("Header value must not be empty".into())); + } + + if !name.starts_with(CUSTOM_HEADER_PREFIX) { + return Err(Error::Config( + "Only x-amz-* custom request headers are supported".into(), + )); + } + + if !name + .bytes() + .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_')) + { + return Err(Error::Config(format!("Invalid header name '{name}'"))); + } + + if !header_value.is_ascii() || header_value.bytes().any(|b| matches!(b, b'\r' | b'\n')) { + return Err(Error::Config(format!("Invalid value for header '{name}'"))); + } + + Ok(Self { + name, + value: header_value, + }) + } +} + +/// Set process-wide custom request headers for this CLI invocation. +pub fn set_global_request_headers(headers: Vec) { + let storage = GLOBAL_REQUEST_HEADERS.get_or_init(|| RwLock::new(Vec::new())); + let mut guard = storage + .write() + .expect("global request header lock should not be poisoned"); + *guard = headers; +} + +/// Get process-wide custom request headers for this CLI invocation. +pub fn global_request_headers() -> Vec { + let Some(storage) = GLOBAL_REQUEST_HEADERS.get() else { + return Vec::new(); + }; + + storage + .read() + .expect("global request header lock should not be poisoned") + .clone() +} /// Retry configuration for an alias #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index c4f41f9..3fc526e 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -21,7 +21,10 @@ pub mod retry; pub mod select; pub mod traits; -pub use alias::{Alias, AliasManager, validate_alias_endpoint}; +pub use alias::{ + Alias, AliasManager, RequestHeader, global_request_headers, set_global_request_headers, + validate_alias_endpoint, +}; pub use config::{Config, ConfigManager}; pub use cors::{CorsConfiguration, CorsRule}; pub use error::{Error, Result}; diff --git a/crates/s3/src/client.rs b/crates/s3/src/client.rs index 8aa5b70..dbc7c40 100644 --- a/crates/s3/src/client.rs +++ b/crates/s3/src/client.rs @@ -9,21 +9,26 @@ use aws_sigv4::http_request::{ SignableBody, SignableRequest, SignatureLocation, SigningSettings, sign, }; use aws_sigv4::sign::v4; +use aws_smithy_runtime_api::box_error::BoxError; use aws_smithy_runtime_api::client::http::{ HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, SharedHttpConnector, }; +use aws_smithy_runtime_api::client::interceptors::Intercept; +use aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextMut; use aws_smithy_runtime_api::client::orchestrator::HttpRequest; use aws_smithy_runtime_api::client::result::ConnectorError; use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents; use aws_smithy_runtime_api::http::{Response, StatusCode}; use aws_smithy_types::body::SdkBody; +use aws_smithy_types::config_bag::ConfigBag; use bytes::Bytes; use jiff::Timestamp; use quick_xml::de::from_str as from_xml_str; use rc_core::{ Alias, BucketNotification, Capabilities, CorsRule, Error, LifecycleRule, ListOptions, ListResult, NotificationTarget, ObjectInfo, ObjectStore, ObjectVersion, - ObjectVersionListResult, RemotePath, ReplicationConfiguration, Result, SelectOptions, + ObjectVersionListResult, RemotePath, ReplicationConfiguration, RequestHeader, Result, + SelectOptions, global_request_headers, }; use reqwest::Method; use reqwest::header::{CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue}; @@ -671,8 +676,10 @@ impl HttpClient for ReqwestConnector { /// S3 client wrapper pub struct S3Client { inner: aws_sdk_s3::Client, + presign_inner: aws_sdk_s3::Client, xml_http_client: reqwest::Client, alias: Alias, + request_headers: Vec, } /// Request-level options for delete operations. @@ -682,6 +689,33 @@ pub struct DeleteRequestOptions { pub force_delete: bool, } +#[derive(Debug, Clone)] +struct CustomHeaderInterceptor { + headers: Vec, +} + +impl Intercept for CustomHeaderInterceptor { + fn name(&self) -> &'static str { + "CustomHeaderInterceptor" + } + + fn modify_before_signing( + &self, + context: &mut BeforeTransmitInterceptorContextMut<'_>, + _runtime_components: &RuntimeComponents, + _cfg: &mut ConfigBag, + ) -> std::result::Result<(), BoxError> { + let request = context.request_mut(); + for header in &self.headers { + request + .headers_mut() + .try_insert(header.name.clone(), header.value.clone()) + .map_err(|error| Box::new(error) as BoxError)?; + } + Ok(()) + } +} + impl S3Client { /// Create a new S3 client from an alias configuration pub async fn new(alias: Alias) -> Result { @@ -718,7 +752,8 @@ impl S3Client { let config = config_loader.load().await; // Build S3 client with path-style addressing for compatibility - let s3_config = aws_sdk_s3::config::Builder::from(&config) + let request_headers = global_request_headers(); + let mut s3_config_builder = aws_sdk_s3::config::Builder::from(&config) .force_path_style(force_path_style_for_alias(&alias)) // Improve compatibility with S3-compatible backends by only sending request // checksums when the operation explicitly requires them. @@ -727,15 +762,27 @@ impl S3Client { ) .response_checksum_validation( aws_sdk_s3::config::ResponseChecksumValidation::WhenRequired, - ) - .build(); + ); + + let presign_s3_config = s3_config_builder.clone().build(); + + if !request_headers.is_empty() { + s3_config_builder = s3_config_builder.interceptor(CustomHeaderInterceptor { + headers: request_headers.clone(), + }); + } + + let s3_config = s3_config_builder.build(); let client = aws_sdk_s3::Client::from_conf(s3_config); + let presign_client = aws_sdk_s3::Client::from_conf(presign_s3_config); Ok(Self { inner: client, + presign_inner: presign_client, xml_http_client, alias, + request_headers, }) } @@ -1148,6 +1195,14 @@ impl S3Client { ); } + for header in &self.request_headers { + let name = HeaderName::from_bytes(header.name.as_bytes()) + .map_err(|e| Error::Auth(format!("Invalid custom header name: {e}")))?; + let value = HeaderValue::from_str(&header.value) + .map_err(|e| Error::Auth(format!("Invalid custom header value: {e}")))?; + headers.insert(name, value); + } + let signed_headers = self .sign_xml_request(&method, url.as_str(), &headers, &body) .await?; @@ -1924,7 +1979,7 @@ impl ObjectStore for S3Client { .map_err(|e| Error::General(format!("presign_get config: {e}")))?; let request = self - .inner + .presign_inner .get_object() .bucket(&path.bucket) .key(&path.key) @@ -1946,7 +2001,11 @@ impl ObjectStore for S3Client { .build() .map_err(|e| Error::General(format!("presign_put config: {e}")))?; - let mut builder = self.inner.put_object().bucket(&path.bucket).key(&path.key); + let mut builder = self + .presign_inner + .put_object() + .bucket(&path.bucket) + .key(&path.key); if let Some(ct) = content_type { builder = builder.content_type(ct); @@ -2767,6 +2826,14 @@ mod tests { fn test_s3_client_with_endpoint( endpoint: &str, response: Option>, + ) -> (S3Client, CaptureRequestReceiver) { + test_s3_client_with_endpoint_and_headers(endpoint, response, Vec::new()) + } + + fn test_s3_client_with_endpoint_and_headers( + endpoint: &str, + response: Option>, + request_headers: Vec, ) -> (S3Client, CaptureRequestReceiver) { let (http_client, request_receiver) = capture_request(response); let credentials = Credentials::new( @@ -2776,20 +2843,31 @@ mod tests { None, "rc-test-credentials", ); - let config = aws_sdk_s3::config::Builder::new() + let mut config_builder = aws_sdk_s3::config::Builder::new() .credentials_provider(credentials) .endpoint_url(endpoint) .region(aws_sdk_s3::config::Region::new("us-east-1")) .force_path_style(true) .behavior_version_latest() - .http_client(http_client) - .build(); + .http_client(http_client); + + let presign_config = config_builder.clone().build(); + + if !request_headers.is_empty() { + config_builder = config_builder.interceptor(CustomHeaderInterceptor { + headers: request_headers.clone(), + }); + } + + let config = config_builder.build(); let alias = Alias::new("test", endpoint, "access-key", "secret-key"); let client = S3Client { inner: aws_sdk_s3::Client::from_conf(config), + presign_inner: aws_sdk_s3::Client::from_conf(presign_config), xml_http_client: reqwest::Client::new(), alias, + request_headers, }; (client, request_receiver) @@ -3469,6 +3547,54 @@ mod tests { assert_eq!(request.headers().get("x-rustfs-force-delete"), Some("true")); } + #[tokio::test] + async fn custom_headers_are_added_before_sending_sdk_requests() { + let (client, request_receiver) = test_s3_client_with_endpoint_and_headers( + "https://example.com", + None, + vec![RequestHeader { + name: "x-amz-bucket-encrypt-enabled".to_string(), + value: "1".to_string(), + }], + ); + let path = RemotePath::new("test", "bucket", "key.txt"); + + let _ = client.delete_object(&path).await; + + let request = request_receiver.expect_request(); + assert_eq!( + request.headers().get("x-amz-bucket-encrypt-enabled"), + Some("1") + ); + assert!( + request + .headers() + .get("authorization") + .expect("authorization header") + .contains("x-amz-bucket-encrypt-enabled") + ); + } + + #[tokio::test] + async fn custom_headers_are_not_required_by_presigned_urls() { + let (client, _request_receiver) = test_s3_client_with_endpoint_and_headers( + "https://example.com", + None, + vec![RequestHeader { + name: "x-amz-bucket-encrypt-enabled".to_string(), + value: "1".to_string(), + }], + ); + let path = RemotePath::new("test", "bucket", "key.txt"); + + let url = client + .presign_get(&path, 3600) + .await + .expect("presign get should succeed"); + + assert!(!url.contains("x-amz-bucket-encrypt-enabled")); + } + #[tokio::test] async fn delete_object_without_force_delete_omits_rustfs_header() { let (client, request_receiver) = test_s3_client(None); diff --git a/docs/reference/rc/README.md b/docs/reference/rc/README.md index 402aa1d..d3551cb 100644 --- a/docs/reference/rc/README.md +++ b/docs/reference/rc/README.md @@ -55,6 +55,7 @@ Global options shown in command syntax use the same meaning everywhere: | `--no-progress` | Disable progress bars. | | `-q, --quiet` | Suppress non-error output. | | `--debug` | Enable debug logging. | +| `-H, --header NAME:VALUE` | Add an `x-amz-*` header to signed S3 requests. | ## Credentials