Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pgdog/src/backend/pool/address.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use pgdog_config::Role;
use serde::{Deserialize, Serialize};
use url::Url;

use super::Password;
use super::{password::PasswordSource, Password};
use crate::backend::auth::{azure_workload_identity, rds_iam};
use crate::backend::pool::dns_cache::DnsCache;
use crate::backend::pool::token_cache::TokenCache;
Expand Down Expand Up @@ -108,14 +108,14 @@ impl Address {
let token = TokenCache::global()
.get_or_fetch(self, rds_iam::token)
.await?;
vec![token.into()]
vec![Password::new(&token, PasswordSource::RdsIam)]
}

ServerAuth::AzureWorkloadIdentity => {
let token = TokenCache::global()
.get_or_fetch(self, azure_workload_identity::token)
.await?;
vec![token.into()]
vec![Password::new(&token, PasswordSource::AzureIdentity)]
}
};

Expand Down
37 changes: 29 additions & 8 deletions pgdog/src/backend/pool/password.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::{
fmt::Display,
hash::Hash,
ops::Deref,
sync::{
Expand All @@ -9,27 +10,39 @@ use std::{

use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PasswordSource {
Config,
RdsIam,
AzureIdentity,
}

impl Display for PasswordSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Config => write!(f, "config"),
Self::RdsIam => write!(f, "rds iam"),
Self::AzureIdentity => write!(f, "azure workload identity"),
}
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Password {
pub(crate) password: String,
pub(crate) valid: Arc<AtomicBool>,
pub(crate) source: PasswordSource,
}

impl From<String> for Password {
fn from(password: String) -> Self {
Self {
password,
valid: Arc::new(AtomicBool::new(true)),
}
Self::new(&password, PasswordSource::Config)
}
}

impl From<&str> for Password {
fn from(password: &str) -> Self {
Self {
password: password.to_string(),
valid: Arc::new(AtomicBool::new(true)),
}
Self::new(password, PasswordSource::Config)
}
}

Expand Down Expand Up @@ -76,4 +89,12 @@ impl Password {
pub(crate) fn valid(&self, valid: bool) {
self.valid.store(valid, Ordering::Relaxed)
}

pub(crate) fn new(password: &str, source: PasswordSource) -> Self {
Self {
password: password.to_owned(),
source,
valid: Arc::new(AtomicBool::new(true)),
}
}
}
15 changes: 8 additions & 7 deletions pgdog/src/backend/server.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! PostgreSQL server connection.

use std::time::Duration;
use std::{ops::Deref, time::Duration};

use bytes::{BufMut, BytesMut};
use rustls_pki_types::ServerName;
Expand Down Expand Up @@ -142,7 +142,7 @@ impl Server {
addr: &Address,
options: ServerOptions,
connect_reason: ConnectReason,
auth_secret: &str,
auth_secret: &super::pool::Password,
) -> Result<Self, Error> {
debug!("=> {}", addr);
let stream = TcpStream::connect(addr.addr().await?).await?;
Expand Down Expand Up @@ -242,7 +242,7 @@ impl Server {
match auth {
Authentication::Ok => break,
Authentication::ClearTextPassword => {
let password = Password::new_password(auth_secret);
let password = Password::new_password(auth_secret.deref());
stream.send_flush(&password).await?;
}
Authentication::Sasl(_) => {
Expand Down Expand Up @@ -318,10 +318,11 @@ impl Server {
let params: Parameters = params.into();

info!(
"new server connection [{}, auth: {}, reason: {}] {}",
addr,
"new server connection: auth={}, source={}, reason={} [{}] {}",
auth_type,
auth_secret.source,
connect_reason,
addr,
if stream.is_tls() { "🔒" } else { "" },
);

Expand Down Expand Up @@ -1157,10 +1158,10 @@ impl Drop for Server {
self.stats().disconnect();
if let Some(mut stream) = self.stream.take() {
info!(
"closing server connection [{}, state: {}, reason: {}]",
self.addr,
"closing server connection: state={}, reason={} [{}]",
self.stats.get_state(),
self.disconnect_reason.take().unwrap_or_default(),
self.addr,
);

spawn(async move {
Expand Down
Loading