From 4d91c4deac4bba1d9db52c78a2f46d950b5c8171 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Thu, 21 May 2026 11:31:23 -0700 Subject: [PATCH] feat: identify the source of auth for new server connections --- pgdog/src/backend/pool/address.rs | 6 ++--- pgdog/src/backend/pool/password.rs | 37 +++++++++++++++++++++++------- pgdog/src/backend/server.rs | 15 ++++++------ 3 files changed, 40 insertions(+), 18 deletions(-) diff --git a/pgdog/src/backend/pool/address.rs b/pgdog/src/backend/pool/address.rs index f7b46eda0..590b5e97b 100644 --- a/pgdog/src/backend/pool/address.rs +++ b/pgdog/src/backend/pool/address.rs @@ -7,7 +7,7 @@ use pgdog_config::Role; use serde::{Deserialize, Serialize}; use url::Url; -use super::Password; +use super::{password::PasswordSource, Password}; use crate::backend::auth::{azure_workload_identity, rds_iam}; use crate::backend::pool::dns_cache::DnsCache; use crate::backend::pool::token_cache::TokenCache; @@ -108,14 +108,14 @@ impl Address { let token = TokenCache::global() .get_or_fetch(self, rds_iam::token) .await?; - vec![token.into()] + vec![Password::new(&token, PasswordSource::RdsIam)] } ServerAuth::AzureWorkloadIdentity => { let token = TokenCache::global() .get_or_fetch(self, azure_workload_identity::token) .await?; - vec![token.into()] + vec![Password::new(&token, PasswordSource::AzureIdentity)] } }; diff --git a/pgdog/src/backend/pool/password.rs b/pgdog/src/backend/pool/password.rs index 7b7de3507..bf7cb0390 100644 --- a/pgdog/src/backend/pool/password.rs +++ b/pgdog/src/backend/pool/password.rs @@ -1,4 +1,5 @@ use std::{ + fmt::Display, hash::Hash, ops::Deref, sync::{ @@ -9,27 +10,39 @@ use std::{ use serde::{Deserialize, Serialize}; +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum PasswordSource { + Config, + RdsIam, + AzureIdentity, +} + +impl Display for PasswordSource { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Config => write!(f, "config"), + Self::RdsIam => write!(f, "rds iam"), + Self::AzureIdentity => write!(f, "azure workload identity"), + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Password { pub(crate) password: String, pub(crate) valid: Arc, + pub(crate) source: PasswordSource, } impl From for Password { fn from(password: String) -> Self { - Self { - password, - valid: Arc::new(AtomicBool::new(true)), - } + Self::new(&password, PasswordSource::Config) } } impl From<&str> for Password { fn from(password: &str) -> Self { - Self { - password: password.to_string(), - valid: Arc::new(AtomicBool::new(true)), - } + Self::new(password, PasswordSource::Config) } } @@ -76,4 +89,12 @@ impl Password { pub(crate) fn valid(&self, valid: bool) { self.valid.store(valid, Ordering::Relaxed) } + + pub(crate) fn new(password: &str, source: PasswordSource) -> Self { + Self { + password: password.to_owned(), + source, + valid: Arc::new(AtomicBool::new(true)), + } + } } diff --git a/pgdog/src/backend/server.rs b/pgdog/src/backend/server.rs index 89b1b4828..306fcda7e 100644 --- a/pgdog/src/backend/server.rs +++ b/pgdog/src/backend/server.rs @@ -1,6 +1,6 @@ //! PostgreSQL server connection. -use std::time::Duration; +use std::{ops::Deref, time::Duration}; use bytes::{BufMut, BytesMut}; use rustls_pki_types::ServerName; @@ -142,7 +142,7 @@ impl Server { addr: &Address, options: ServerOptions, connect_reason: ConnectReason, - auth_secret: &str, + auth_secret: &super::pool::Password, ) -> Result { debug!("=> {}", addr); let stream = TcpStream::connect(addr.addr().await?).await?; @@ -242,7 +242,7 @@ impl Server { match auth { Authentication::Ok => break, Authentication::ClearTextPassword => { - let password = Password::new_password(auth_secret); + let password = Password::new_password(auth_secret.deref()); stream.send_flush(&password).await?; } Authentication::Sasl(_) => { @@ -318,10 +318,11 @@ impl Server { let params: Parameters = params.into(); info!( - "new server connection [{}, auth: {}, reason: {}] {}", - addr, + "new server connection: auth={}, source={}, reason={} [{}] {}", auth_type, + auth_secret.source, connect_reason, + addr, if stream.is_tls() { "🔒" } else { "" }, ); @@ -1157,10 +1158,10 @@ impl Drop for Server { self.stats().disconnect(); if let Some(mut stream) = self.stream.take() { info!( - "closing server connection [{}, state: {}, reason: {}]", - self.addr, + "closing server connection: state={}, reason={} [{}]", self.stats.get_state(), self.disconnect_reason.take().unwrap_or_default(), + self.addr, ); spawn(async move {