Skip to content

Commit

Permalink
feat(frontend): redact sql option in log (#16871)
Browse files Browse the repository at this point in the history
  • Loading branch information
zwang28 committed May 28, 2024
1 parent ac93e24 commit 6c81beb
Show file tree
Hide file tree
Showing 10 changed files with 131 additions and 13 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions src/common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,11 @@ pub struct BatchConfig {
/// This is the secs used to mask a worker unavailable temporarily.
#[serde(default = "default::batch::mask_worker_temporary_secs")]
pub mask_worker_temporary_secs: usize,

/// Keywords on which SQL option redaction is based in the query log.
/// A SQL option with a name containing any of these keywords will be redacted.
#[serde(default = "default::batch::redact_sql_option_keywords")]
pub redact_sql_option_keywords: Vec<String>,
}

/// The section `[streaming]` in `risingwave.toml`.
Expand Down Expand Up @@ -1749,6 +1754,20 @@ pub mod default {
pub fn mask_worker_temporary_secs() -> usize {
30
}

/// Default value of `redact_sql_option_keywords`: the keyword list used to decide which
/// SQL options are redacted in the query log.
pub fn redact_sql_option_keywords() -> Vec<String> {
    const DEFAULT_KEYWORDS: [&str; 6] = [
        "credential",
        "key",
        "password",
        "private",
        "secret",
        "token",
    ];
    DEFAULT_KEYWORDS
        .iter()
        .map(|keyword| (*keyword).to_owned())
        .collect()
}
}

pub mod compaction_config {
Expand Down
1 change: 1 addition & 0 deletions src/config/docs.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ This page is automatically generated by `./risedev generate-example-config`
| frontend_compute_runtime_worker_threads | frontend compute runtime worker threads | 4 |
| mask_worker_temporary_secs | This is the secs used to mask a worker unavailable temporarily. | 30 |
| max_batch_queries_per_frontend_node | This is the max number of batch queries per frontend node. | |
| redact_sql_option_keywords | Keywords on which SQL option redaction is based in the query log. A SQL option with a name containing any of these keywords will be redacted. | ["credential", "key", "password", "private", "secret", "token"] |
| statement_timeout_in_sec | Timeout for a batch query in seconds. | 3600 |
| worker_threads_num | The thread number of the batch task runtime in the compute node. The default value is decided by `tokio`. | |

Expand Down
1 change: 1 addition & 0 deletions src/config/example.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ enable_barrier_read = false
statement_timeout_in_sec = 3600
frontend_compute_runtime_worker_threads = 4
mask_worker_temporary_secs = 30
redact_sql_option_keywords = ["credential", "key", "password", "private", "secret", "token"]

[batch.developer]
batch_connector_message_buffer_size = 16
Expand Down
23 changes: 20 additions & 3 deletions src/frontend/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ risingwave_expr_impl::enable!();

#[macro_use]
mod catalog;

use std::collections::HashSet;

pub use catalog::TableCatalog;
mod binder;
pub use binder::{bind_data_type, Binder};
Expand Down Expand Up @@ -168,8 +171,22 @@ pub fn start(opts: FrontendOpts) -> Pin<Box<dyn Future<Output = ()> + Send>> {
Box::pin(async move {
let listen_addr = opts.listen_addr.clone();
let session_mgr = Arc::new(SessionManagerImpl::new(opts).await.unwrap());
pg_serve(&listen_addr, session_mgr, TlsConfig::new_default())
.await
.unwrap();
let redact_sql_option_keywords = Arc::new(
session_mgr
.env()
.batch_config()
.redact_sql_option_keywords
.iter()
.map(|s| s.to_lowercase())
.collect::<HashSet<_>>(),
);
pg_serve(
&listen_addr,
session_mgr,
TlsConfig::new_default(),
Some(redact_sql_option_keywords),
)
.await
.unwrap()
})
}
4 changes: 4 additions & 0 deletions src/frontend/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,10 @@ impl SessionManagerImpl {
})
}

/// Returns a shared reference to the [`FrontendEnv`] held by this session manager.
pub fn env(&self) -> &FrontendEnv {
&self.env
}

fn insert_session(&self, session: Arc<SessionImpl>) {
let active_sessions = {
let mut write_guard = self.env.sessions_map.write();
Expand Down
1 change: 1 addition & 0 deletions src/sqlparser/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ normal = ["workspace-hack"]
itertools = { workspace = true }
serde = { version = "1.0", features = ["derive"], optional = true }
thiserror = "1.0.61"
tokio = { version = "0.2", package = "madsim-tokio" }
tracing = "0.1"
tracing-subscriber = "0.3"
winnow = { version = "0.6.8", git = "https://github.com/TennyZhuang/winnow.git", rev = "a6b1f04" }
Expand Down
26 changes: 25 additions & 1 deletion src/sqlparser/src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ use alloc::{
};
use core::fmt;
use core::fmt::Display;
use std::collections::HashSet;
use std::sync::Arc;

use itertools::Itertools;
#[cfg(feature = "serde")]
Expand Down Expand Up @@ -59,6 +61,12 @@ pub use crate::ast::ddl::{
use crate::keywords::Keyword;
use crate::parser::{IncludeOption, IncludeOptionItem, Parser, ParserError};

/// Shared, reference-counted set of keywords used to decide which SQL options are redacted.
pub type RedactSqlOptionKeywordsRef = Arc<HashSet<String>>;

// Task-local keyword set consulted by the `Display` impl of `SqlOption`.
// In the code shown here it is set by `Statement::to_redacted_string`; when it is not
// set, `SqlOption` is formatted with its real value.
tokio::task_local! {
pub static REDACT_SQL_OPTION_KEYWORDS: RedactSqlOptionKeywordsRef;
}

pub struct DisplaySeparated<'a, T>
where
T: fmt::Display,
Expand Down Expand Up @@ -2584,7 +2592,17 @@ pub struct SqlOption {

impl fmt::Display for SqlOption {
/// Writes the option as `name = value`, or as `name = [REDACTED]` when the task-local
/// `REDACT_SQL_OPTION_KEYWORDS` is set and the lower-cased option name contains any of
/// its keywords.
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// NOTE(review): the next line looks like the pre-change body retained by the diff
// view; the statements after it are the new implementation. Confirm against the file.
write!(f, "{} = {}", self.name, self.value)
// `try_with` fails when the task-local is not set; that case falls back to `false`,
// i.e. no redaction.
let should_redact = REDACT_SQL_OPTION_KEYWORDS
.try_with(|keywords| {
// Only the option name is lower-cased here; keywords are compared as stored.
let sql_option_name = self.name.real_value().to_lowercase();
// Substring match: a keyword anywhere inside the name triggers redaction.
keywords.iter().any(|k| sql_option_name.contains(k))
})
.unwrap_or(false);
if should_redact {
write!(f, "{} = [REDACTED]", self.name)
} else {
write!(f, "{} = {}", self.name, self.value)
}
}
}

Expand Down Expand Up @@ -3142,6 +3160,12 @@ impl fmt::Display for DiscardType {
}
}

impl Statement {
/// Renders this statement with `to_string()` while `keywords` is installed as the
/// task-local `REDACT_SQL_OPTION_KEYWORDS` value.
///
/// Any `SqlOption` formatted during that call whose lower-cased name contains one of
/// `keywords` is written as `name = [REDACTED]` instead of `name = value`.
pub fn to_redacted_string(&self, keywords: RedactSqlOptionKeywordsRef) -> String {
// The task-local is scoped to the closure passed to `sync_scope`.
REDACT_SQL_OPTION_KEYWORDS.sync_scope(keywords, || self.to_string())
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
53 changes: 47 additions & 6 deletions src/utils/pgwire/src/pg_protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use risingwave_common::types::DataType;
use risingwave_common::util::panic::FutureCatchUnwindExt;
use risingwave_common::util::query_log::*;
use risingwave_common::{PG_VERSION, SERVER_ENCODING, STANDARD_CONFORMING_STRINGS};
use risingwave_sqlparser::ast::Statement;
use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
use risingwave_sqlparser::parser::Parser;
use thiserror_ext::AsReport;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
Expand Down Expand Up @@ -101,6 +101,8 @@ where

// Client Address
peer_addr: AddressRef,

redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
}

/// Configures TLS encryption for connections.
Expand Down Expand Up @@ -152,16 +154,31 @@ pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> {
}

/// Record `sql` in the `sql` field of the current tracing span.
///
/// When `redact_sql_option_keywords` is `Some`, the text is first passed through
/// `redact_sql` with those keywords; when it is `None`, the text is recorded as given.
/// The recorded value is wrapped in `TruncatedFmt` with `RW_QUERY_LOG_TRUNCATE_LEN`.
fn record_sql_in_span(sql: &str) {
fn record_sql_in_span(sql: &str, redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>) {
// Build the text that will actually be logged (redacted only if keywords were given).
let redacted_sql = if let Some(keywords) = redact_sql_option_keywords {
redact_sql(sql, keywords)
} else {
sql.to_owned()
};
tracing::Span::current().record(
"sql",
tracing::field::display(truncated_fmt::TruncatedFmt(
&sql,
&redacted_sql,
*RW_QUERY_LOG_TRUNCATE_LEN,
)),
);
}

/// Re-renders `sql` with the values of sensitive SQL options hidden.
///
/// The input is parsed into statements; each statement is formatted through
/// `Statement::to_redacted_string` with `keywords`, and the results are joined with `;`.
/// If parsing fails, the input is returned unchanged, i.e. without any redaction.
fn redact_sql(sql: &str, keywords: RedactSqlOptionKeywordsRef) -> String {
    Parser::parse_sql(sql).map_or_else(
        // Unparsable input: fall back to the original text.
        |_| sql.to_owned(),
        |statements| {
            statements
                .into_iter()
                .map(|statement| statement.to_redacted_string(keywords.clone()))
                .join(";")
        },
    )
}

impl<S, SM> PgProtocol<S, SM>
where
S: AsyncWrite + AsyncRead + Unpin,
Expand All @@ -172,6 +189,7 @@ where
session_mgr: Arc<SM>,
tls_config: Option<TlsConfig>,
peer_addr: AddressRef,
redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
) -> Self {
Self {
stream: Conn::Unencrypted(PgStream {
Expand All @@ -193,6 +211,7 @@ where
statement_portal_dependency: Default::default(),
ignore_util_sync: false,
peer_addr,
redact_sql_option_keywords,
}
}

Expand Down Expand Up @@ -555,7 +574,7 @@ where
async fn process_query_msg(&mut self, query_string: io::Result<&str>) -> PsqlResult<()> {
let sql: Arc<str> =
Arc::from(query_string.map_err(|err| PsqlError::SimpleQueryError(Box::new(err)))?);
record_sql_in_span(&sql);
record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());
let session = self.session.clone().unwrap();

session.check_idle_in_transaction_timeout()?;
Expand Down Expand Up @@ -664,7 +683,7 @@ where

fn process_parse_msg(&mut self, msg: FeParseMessage) -> PsqlResult<()> {
let sql = cstr_to_str(&msg.sql_bytes).unwrap();
record_sql_in_span(sql);
record_sql_in_span(sql, self.redact_sql_option_keywords.clone());
let session = self.session.clone().unwrap();
let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_string();

Expand Down Expand Up @@ -798,7 +817,7 @@ where
} else {
let portal = self.get_portal(&portal_name)?;
let sql: Arc<str> = Arc::from(format!("{}", portal));
record_sql_in_span(&sql);
record_sql_in_span(&sql, self.redact_sql_option_keywords.clone());

session.check_idle_in_transaction_timeout()?;
let _exec_context_guard = session.init_exec_context(sql.clone());
Expand Down Expand Up @@ -1205,3 +1224,25 @@ pub mod truncated_fmt {
}
}
}

#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Options whose names contain a keyword are redacted in both the `WITH (...)` clause
    /// and the `ENCODE ... (...)` clause; all other options keep their values.
    #[test]
    fn test_redact_parsable_sql() {
        let input = r"
create source temp (k bigint, v varchar) with (
connector = 'datagen',
v1 = 123,
v2 = 'with',
v3 = false,
v4 = '',
) FORMAT plain ENCODE json (a='1',b='2')
";
        let redact_keywords: HashSet<String> =
            ["v2", "v4", "b"].into_iter().map(String::from).collect();
        let expected = "CREATE SOURCE temp (k BIGINT, v CHARACTER VARYING) WITH (connector = 'datagen', v1 = 123, v2 = [REDACTED], v3 = false, v4 = [REDACTED]) FORMAT PLAIN ENCODE JSON (a = '1', b = [REDACTED])";
        assert_eq!(redact_sql(input, Arc::new(redact_keywords)), expected);
    }
}
15 changes: 12 additions & 3 deletions src/utils/pgwire/src/pg_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use bytes::Bytes;
use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
use parking_lot::Mutex;
use risingwave_common::types::DataType;
use risingwave_sqlparser::ast::Statement;
use risingwave_sqlparser::ast::{RedactSqlOptionKeywordsRef, Statement};
use serde::Deserialize;
use thiserror_ext::AsReport;
use tokio::io::{AsyncRead, AsyncWrite};
Expand Down Expand Up @@ -251,6 +251,7 @@ pub async fn pg_serve(
addr: &str,
session_mgr: Arc<impl SessionManager>,
tls_config: Option<TlsConfig>,
redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
) -> io::Result<()> {
let listener = Listener::bind(addr).await?;
tracing::info!(addr, "server started");
Expand Down Expand Up @@ -281,6 +282,7 @@ pub async fn pg_serve(
session_mgr.clone(),
tls_config.clone(),
Arc::new(peer_addr),
redact_sql_option_keywords.clone(),
));
}

Expand All @@ -299,11 +301,18 @@ pub async fn handle_connection<S, SM>(
session_mgr: Arc<SM>,
tls_config: Option<TlsConfig>,
peer_addr: AddressRef,
redact_sql_option_keywords: Option<RedactSqlOptionKeywordsRef>,
) where
S: AsyncWrite + AsyncRead + Unpin,
SM: SessionManager,
{
let mut pg_proto = PgProtocol::new(stream, session_mgr, tls_config, peer_addr);
let mut pg_proto = PgProtocol::new(
stream,
session_mgr,
tls_config,
peer_addr,
redact_sql_option_keywords,
);
loop {
let msg = match pg_proto.read_message().await {
Ok(msg) => msg,
Expand Down Expand Up @@ -486,7 +495,7 @@ mod tests {
let pg_config = pg_config.into();

let session_mgr = Arc::new(MockSessionManager {});
tokio::spawn(async move { pg_serve(&bind_addr, session_mgr, None).await });
tokio::spawn(async move { pg_serve(&bind_addr, session_mgr, None, None).await });
// wait for server to start
tokio::time::sleep(std::time::Duration::from_millis(100)).await;

Expand Down

0 comments on commit 6c81beb

Please sign in to comment.