Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 156 additions & 0 deletions apps/skit/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,96 @@ impl Default for CorsConfig {
}
}

fn default_label_fallback() -> String {
"other".to_string()
}

fn default_request_labels() -> Vec<RequestLabelConfig> {
// Ships the `service` dimension oneshot dashboards build against; operators
// can extend or replace it without a recompile.
vec![RequestLabelConfig {
name: "service".to_string(),
header: "X-StreamKit-Service".to_string(),
allowed: vec!["tts".to_string(), "stt".to_string()],
fallback: default_label_fallback(),
}]
}

/// A bounded metric label sourced from a trusted request header.
///
/// The header value is trimmed and lowercased, then matched against `allowed`;
/// anything not in the allowlist (or a missing header) collapses to `fallback`,
/// so client-supplied headers can never inflate metric cardinality.
#[derive(Deserialize, Serialize, Debug, Clone, JsonSchema)]
pub struct RequestLabelConfig {
/// Metric label key (e.g. `service`).
pub name: String,
/// Trusted request header to read the value from (e.g. `X-StreamKit-Service`).
pub header: String,
/// Permitted values, matched case-insensitively after trimming.
#[serde(default)]
pub allowed: Vec<String>,
/// Value emitted when the header is absent or its value is not in `allowed`.
#[serde(default = "default_label_fallback")]
pub fallback: String,
}

/// Configuration for request-scoped metric labeling.
#[derive(Deserialize, Serialize, Debug, Clone, JsonSchema)]
pub struct MetricsConfig {
/// Bounded labels attached to request metrics, each sourced from a trusted
/// request header. Applied to all HTTP request metrics and to oneshot
/// pipeline metrics.
#[serde(default = "default_request_labels")]
pub request_labels: Vec<RequestLabelConfig>,
}

impl Default for MetricsConfig {
fn default() -> Self {
Self { request_labels: default_request_labels() }
}
}

impl MetricsConfig {
/// Label keys emitted by built-in request instruments; configured labels
/// must not collide with these (a duplicate key makes Prometheus reject the
/// whole series on scrape).
const RESERVED_LABEL_NAMES: [&'static str; 4] =
["status", "http.method", "http.route", "http.status_code"];

/// Lowercase and trim every allowlist entry so the per-request hot path only
/// has to normalize the incoming header value.
fn normalize(&mut self) {
for label in &mut self.request_labels {
for allowed in &mut label.allowed {
*allowed = allowed.trim().to_ascii_lowercase();
}
}
}

/// Reject label names that collide with built-in metric keys or each other.
///
/// # Errors
///
/// Returns an error if a configured label name is reserved by a built-in
/// metric or duplicates another configured label name.
pub fn validate(&self) -> Result<(), String> {
let mut seen = std::collections::HashSet::new();
for label in &self.request_labels {
if Self::RESERVED_LABEL_NAMES.contains(&label.name.as_str()) {
return Err(format!(
"metrics request_label name '{}' is reserved by built-in metrics",
label.name
));
}
if !seen.insert(label.name.as_str()) {
return Err(format!("duplicate metrics request_label name '{}'", label.name));
}
}
Ok(())
}
}

/// Telemetry and observability configuration (OpenTelemetry, tokio-console).
#[derive(Deserialize, Serialize, Debug, Clone, JsonSchema)]
pub struct TelemetryConfig {
Expand Down Expand Up @@ -322,6 +412,9 @@ pub struct ServerConfig {
/// CORS configuration for cross-origin requests
#[serde(default)]
pub cors: CorsConfig,
/// Bounded request-metric labeling configuration.
#[serde(default)]
pub metrics: MetricsConfig,
#[cfg(feature = "moq")]
pub moq_address: Option<String>,
/// TLS certificate for the MoQ WebTransport listener.
Expand Down Expand Up @@ -350,6 +443,7 @@ impl Default for ServerConfig {
max_body_size: default_max_body_size(),
base_path: None,
cors: CorsConfig::default(),
metrics: MetricsConfig::default(),
#[cfg(feature = "moq")]
moq_address: None,
#[cfg(feature = "moq")]
Expand Down Expand Up @@ -1031,10 +1125,14 @@ pub fn load(config_path: &str) -> Result<ConfigLoadResult, Box<figment::Error>>
figment.merge(Env::prefixed("SK_").split("__")).extract().map_err(Box::new)?;

normalize_permissions_config(&mut config);
config.server.metrics.normalize();

if let Err(e) = config.mcp.validate() {
return Err(Box::new(figment::Error::from(e)));
}
if let Err(e) = config.server.metrics.validate() {
return Err(Box::new(figment::Error::from(e)));
}

Ok(ConfigLoadResult { config, file_missing })
}
Expand Down Expand Up @@ -1445,4 +1543,62 @@ allowed_plugins = []
Ok(())
});
}

fn request_label(name: &str) -> RequestLabelConfig {
RequestLabelConfig {
name: name.to_string(),
header: "X-Test".to_string(),
allowed: vec![],
fallback: "other".to_string(),
}
}

#[test]
fn metrics_validate_rejects_reserved_label_name() {
let metrics = MetricsConfig { request_labels: vec![request_label("status")] };
assert!(metrics.validate().is_err());
}

#[test]
fn metrics_validate_rejects_duplicate_label_name() {
let metrics = MetricsConfig {
request_labels: vec![request_label("service"), request_label("service")],
};
assert!(metrics.validate().is_err());
}

#[test]
fn metrics_validate_accepts_default() {
assert!(MetricsConfig::default().validate().is_ok());
}

#[test]
fn metrics_normalize_lowercases_allowlist() {
let mut metrics = MetricsConfig {
request_labels: vec![RequestLabelConfig {
name: "service".to_string(),
header: "X-StreamKit-Service".to_string(),
allowed: vec![" TTS ".to_string(), "Stt".to_string()],
fallback: "other".to_string(),
}],
};
metrics.normalize();
assert_eq!(metrics.request_labels[0].allowed, vec!["tts".to_string(), "stt".to_string()]);
}

#[test]
fn load_rejects_reserved_metrics_label_name() {
figment::Jail::expect_with(|jail| {
jail.create_file(
"skit.toml",
r#"[[server.metrics.request_labels]]
name = "http.route"
header = "X-StreamKit-Service"
allowed = ["tts"]
"#,
)?;
assert!(load("skit.toml").is_err(), "reserved label name must fail load");
Ok(())
});
}
}
1 change: 1 addition & 0 deletions apps/skit/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub mod marketplace_installer;
pub mod marketplace_security;
#[cfg(feature = "mcp")]
pub mod mcp;
pub mod metrics_labels;
#[cfg(feature = "moq")]
pub mod moq_gateway;
pub mod mse_gateway;
Expand Down
1 change: 1 addition & 0 deletions apps/skit/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ mod marketplace_installer;
mod marketplace_security;
#[cfg(feature = "mcp")]
mod mcp;
mod metrics_labels;
#[cfg(feature = "moq")]
mod moq_gateway;
mod mse_gateway;
Expand Down
103 changes: 103 additions & 0 deletions apps/skit/src/metrics_labels.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// SPDX-FileCopyrightText: © 2025 StreamKit Contributors
//
// SPDX-License-Identifier: MPL-2.0

//! Resolve bounded metric labels from trusted request headers.
//!
//! The values are constrained to operator-configured allowlists so
//! client-supplied headers can never inflate metric cardinality.

use axum::http::HeaderMap;
use opentelemetry::KeyValue;

use crate::config::RequestLabelConfig;

/// Bounded request labels resolved once per request and stashed in request
/// extensions so downstream handlers can reuse them without re-parsing headers.
#[derive(Clone)]
pub struct ResolvedRequestLabels(pub Vec<KeyValue>);

fn normalize(value: &str) -> String {
value.trim().to_ascii_lowercase()
}

/// Constrain a header value to an allowlist, falling back when it is absent or
/// unrecognized. The incoming value is normalized (trim + lowercase); `allowed`
/// entries are expected to be pre-normalized at config load.
fn classify(value: Option<&str>, allowed: &[String], fallback: &str) -> String {
match value.map(normalize) {
Some(v) if allowed.contains(&v) => v,
_ => fallback.to_string(),
}
}

/// Resolve configured request labels into bounded metric key-values.
pub fn resolve_request_labels(labels: &[RequestLabelConfig], headers: &HeaderMap) -> Vec<KeyValue> {
labels
.iter()
.map(|label| {
let value = headers.get(label.header.as_str()).and_then(|v| v.to_str().ok());
KeyValue::new(label.name.clone(), classify(value, &label.allowed, &label.fallback))
})
.collect()
}

#[cfg(test)]
mod tests {
use super::*;
use axum::http::HeaderValue;

fn label(name: &str, header: &str, allowed: &[&str]) -> RequestLabelConfig {
RequestLabelConfig {
name: name.to_string(),
header: header.to_string(),
allowed: allowed.iter().map(|s| (*s).to_string()).collect(),
fallback: "other".to_string(),
}
}

#[test]
fn classify_allows_listed_values() {
let allowed = vec!["tts".to_string(), "stt".to_string()];
assert_eq!(classify(Some("tts"), &allowed, "other"), "tts");
assert_eq!(classify(Some("stt"), &allowed, "other"), "stt");
}

#[test]
fn classify_normalizes_case_and_whitespace() {
let allowed = vec!["tts".to_string()];
assert_eq!(classify(Some(" TTS "), &allowed, "other"), "tts");
}

#[test]
fn classify_unknown_empty_and_absent_fall_back() {
let allowed = vec!["tts".to_string()];
assert_eq!(classify(Some("kokoro"), &allowed, "other"), "other");
assert_eq!(classify(Some(""), &allowed, "other"), "other");
assert_eq!(classify(None, &allowed, "other"), "other");
}

#[test]
fn classify_empty_allowlist_always_falls_back() {
assert_eq!(classify(Some("tts"), &[], "other"), "other");
}

#[test]
fn resolve_emits_one_keyvalue_per_label() {
let labels = vec![label("service", "X-StreamKit-Service", &["tts", "stt"])];
let mut headers = HeaderMap::new();
headers.insert("X-StreamKit-Service", HeaderValue::from_static("STT"));

let resolved = resolve_request_labels(&labels, &headers);
assert_eq!(resolved.len(), 1);
assert_eq!(resolved[0].key.as_str(), "service");
assert_eq!(resolved[0].value.as_str(), "stt");
}

#[test]
fn resolve_falls_back_when_header_missing() {
let labels = vec![label("service", "X-StreamKit-Service", &["tts", "stt"])];
let resolved = resolve_request_labels(&labels, &HeaderMap::new());
assert_eq!(resolved[0].value.as_str(), "other");
}
}
18 changes: 15 additions & 3 deletions apps/skit/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1562,13 +1562,24 @@ async fn static_handler(
}
}

async fn metrics_middleware(req: axum::http::Request<Body>, next: Next) -> Response {
async fn metrics_middleware(
State(app_state): State<Arc<AppState>>,
mut req: axum::http::Request<Body>,
next: Next,
) -> Response {
let start = Instant::now();
let method = req.method().clone();
let path = req.extensions().get::<MatchedPath>().map_or_else(
|| req.uri().path().to_owned(),
|matched_path| matched_path.as_str().to_owned(),
);
let configured_labels = crate::metrics_labels::resolve_request_labels(
&app_state.config.server.metrics.request_labels,
req.headers(),
);
// Let downstream handlers reuse the resolved labels instead of re-parsing headers.
req.extensions_mut()
.insert(crate::metrics_labels::ResolvedRequestLabels(configured_labels.clone()));

let response = next.run(req).await;

Expand All @@ -1590,11 +1601,12 @@ async fn metrics_middleware(req: axum::http::Request<Body>, next: Next) -> Respo
})
.clone();

let labels = [
let mut labels = vec![
KeyValue::new("http.method", method.to_string()),
KeyValue::new("http.route", path),
KeyValue::new("http.status_code", status),
];
labels.extend(configured_labels);

counter.add(1, &labels);
histogram.record(latency, &labels);
Expand Down Expand Up @@ -1958,7 +1970,7 @@ pub fn create_app(
.on_response(DefaultOnResponse::new().level(tracing::Level::DEBUG))
.on_failure(DefaultOnFailure::new().level(tracing::Level::WARN)),
))
.layer(middleware::from_fn(metrics_middleware))
.layer(middleware::from_fn_with_state(Arc::clone(&app_state), metrics_middleware))
.layer(SetResponseHeaderLayer::if_not_present(
header::X_CONTENT_TYPE_OPTIONS,
header::HeaderValue::from_static("nosniff"),
Expand Down
Loading
Loading