diff --git a/.circleci/config.yml b/.circleci/config.yml index 63bfee52..09fc893e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -12,13 +12,15 @@ jobs: - image: cimg/rust:1.58.1 environment: RUST_LOG: info - - image: cimg/postgres:14.0 - auth: - username: mydockerhub-user - password: $DOCKERHUB_PASSWORD + - image: postgres:14 + # auth: + # username: mydockerhub-user + # password: $DOCKERHUB_PASSWORD environment: POSTGRES_USER: postgres POSTGRES_DB: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_HOST_AUTH_METHOD: scram-sha-256 # Add steps to the job # See: https://circleci.com/docs/2.0/configuration-reference/#steps steps: diff --git a/.circleci/run_tests.sh b/.circleci/run_tests.sh index 9dbca9d2..f325d166 100644 --- a/.circleci/run_tests.sh +++ b/.circleci/run_tests.sh @@ -12,7 +12,7 @@ function start_pgcat() { } # Setup the database with shards and user -psql -e -h 127.0.0.1 -p 5432 -U postgres -f tests/sharding/query_routing_setup.sql +PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 5432 -U postgres -f tests/sharding/query_routing_setup.sql PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard0 -i PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard1 -i PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard2 -i @@ -72,7 +72,7 @@ psql -h 127.0.0.1 -p 6432 -d pgbouncer -c "SET client_encoding TO 'utf8'" > /dev (! psql -e -h 127.0.0.1 -p 6432 -d random_db -c 'SHOW STATS' > /dev/null) # Start PgCat in debug to demonstrate failover better -start_pgcat "debug" +start_pgcat "trace" # Add latency to the replica at port 5433 slightly above the healthcheck timeout toxiproxy-cli toxic add -t latency -a latency=300 postgres_replica diff --git a/Cargo.lock b/Cargo.lock index 19df07b0..51bbc6e5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -45,6 +45,12 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" +[[package]] +name = "base64" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd" + [[package]] name = "bb8" version = "0.7.1" @@ -109,22 +115,23 @@ dependencies = [ [[package]] name = "crypto-common" -version = "0.1.1" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "683d6b536309245c849479fba3da410962a43ed8e51c26b729208ec0ac2798d0" +checksum = "57952ca27b5e3606ff4dd79b0020231aaf9d6aa76dc05fd30137538c50bd3ce8" dependencies = [ "generic-array", + "typenum", ] [[package]] name = "digest" -version = "0.10.1" +version = "0.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b697d66081d42af4fba142d56918a3cb21dc8eb63372c6b85d14f44fb9c5979b" +checksum = "f2fb860ca6fafa5552fb6d0e816a69c8e49f0908bf524e30a90d97c85892d506" dependencies = [ "block-buffer", "crypto-common", - "generic-array", + "subtle", ] [[package]] @@ -205,6 +212,15 @@ dependencies = [ "libc", ] +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "humantime" version = "2.1.0" @@ -356,10 +372,12 @@ version = "0.1.0-beta2" dependencies = [ "arc-swap", "async-trait", + "base64", "bb8", "bytes", "chrono", "env_logger", + "hmac", "log", "md-5", "num_cpus", @@ -370,7 +388,9 @@ dependencies = [ "serde", "serde_derive", "sha-1", + "sha2", "sqlparser", + "stringprep", "tokio", "toml", ] @@ -511,6 +531,17 @@ dependencies = [ "digest", ] +[[package]] +name = "sha2" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55deaec60f81eefe3cce0dc50bda92d6d8e88f2a27df7c5033b42afeb1ed2676" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "signal-hook-registry" version = "1.4.0" @@ -541,6 +572,22 @@ dependencies = [ "log", ] +[[package]] +name = "stringprep" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ee348cb74b87454fff4b551cbf727025810a004f88aeacae7f85b87f4e9a1c1" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + +[[package]] +name = "subtle" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" + [[package]] name = "syn" version = "1.0.86" @@ -572,6 +619,21 @@ dependencies = [ "winapi", ] +[[package]] +name = "tinyvec" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" + [[package]] name = "tokio" version = "1.16.1" @@ -617,6 +679,21 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987" +[[package]] +name = "unicode-bidi" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "099b7128301d285f79ddd55b9a83d5e6b9e97c92e0ea0daebee7263e932de992" + +[[package]] +name = "unicode-normalization" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d54590932941a9e9266f0832deed84ebe1bf2e4c9e4a3554d393d18f5e854bf9" +dependencies = [ + "tinyvec", +] + [[package]] name = "unicode-xid" version = "0.2.2" diff --git a/Cargo.toml b/Cargo.toml index d41657ad..ae18d7b7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgcat" -version = "0.1.0-beta2" +version = "0.2.0-beta1" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -25,3 +25,7 @@ log = "0.4" arc-swap = "1" env_logger = "0.9" parking_lot = "0.11" +hmac = "0.12" +sha2 = "0.10" +base64 = "0.13" +stringprep = "0.1" \ No newline at end of file diff --git a/src/constants.rs b/src/constants.rs index a0b63c09..0900d7cc 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -14,6 +14,13 @@ pub const CANCEL_REQUEST_CODE: i32 = 80877102; // AuthenticationMD5Password pub const MD5_ENCRYPTED_PASSWORD: i32 = 5; +// SASL +pub const SASL: i32 = 10; +pub const SASL_CONTINUE: i32 = 11; +pub const SASL_FINAL: i32 = 12; +pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256"; +pub const NONCE_LENGTH: usize = 24; + // AuthenticationOk pub const AUTHENTICATION_SUCCESSFUL: i32 = 0; diff --git a/src/main.rs b/src/main.rs index c7a82e17..c22391ed 100644 --- a/src/main.rs +++ b/src/main.rs @@ -54,6 +54,7 @@ mod errors; mod messages; mod pool; mod query_router; +mod scram; mod server; mod sharding; mod stats; diff --git a/src/scram.rs b/src/scram.rs new file mode 100644 index 00000000..58096fa8 --- /dev/null +++ b/src/scram.rs @@ -0,0 +1,311 @@ +// SCRAM authentication...largely copy/pasted from +// https://github.com/sfackler/rust-postgres/. + +use bytes::BytesMut; +use hmac::{Hmac, Mac}; +use rand::{self, Rng}; +use sha2::digest::FixedOutput; +use sha2::{Digest, Sha256}; + +use std::fmt::Write; + +use crate::constants::*; +use crate::errors::Error; + +fn normalize(pass: &[u8]) -> Vec { + let pass = match std::str::from_utf8(pass) { + Ok(pass) => pass, + Err(_) => return pass.to_vec(), + }; + + match stringprep::saslprep(pass) { + Ok(pass) => pass.into_owned().into_bytes(), + Err(_) => pass.as_bytes().to_vec(), + } +} + +pub struct ScramSha256 { + password: String, + salted_password: [u8; 32], + auth_message: String, + message: BytesMut, + nonce: String, +} + +impl ScramSha256 { + pub fn new(password: &str) -> ScramSha256 { + let mut rng = rand::thread_rng(); + let nonce = (0..NONCE_LENGTH) + .map(|_| { + let mut v = rng.gen_range(0x21u8..0x7e); + if v == 0x2c { + v = 0x7e + } + v as char + }) + .collect::(); + + Self::from_nonce(password, &nonce) + } + + pub fn from_nonce(password: &str, nonce: &str) -> ScramSha256 { + let message = BytesMut::from(&format!("{}n=,r={}", "n,,", nonce).as_bytes()[..]); + + ScramSha256 { + password: password.to_string(), + nonce: String::from(nonce), + message, + salted_password: [0u8; 32], + auth_message: String::new(), + } + } + + pub fn message(&mut self) -> BytesMut { + self.message.clone() + } + + pub fn update(&mut self, message: &BytesMut) -> Result { + let server_message = Message::parse(message)?; + + if !server_message.nonce.starts_with(&self.nonce) { + // trace!("Bad server nonce"); + return Err(Error::ProtocolSyncError); + } + + let salt = match base64::decode(&server_message.salt) { + Ok(salt) => salt, + Err(_) => return Err(Error::ProtocolSyncError), + }; + + let salted_password = Self::hi( + &normalize(&self.password.as_bytes()[..]), + &salt, + server_message.iterations, + ); + self.salted_password = salted_password; + + let mut hmac = Hmac::::new_from_slice(&salted_password) + .expect("HMAC is able to accept all key sizes"); + hmac.update(b"Client Key"); + let client_key = hmac.finalize().into_bytes(); + + let mut hash = Sha256::default(); + hash.update(client_key.as_slice()); + let stored_key = hash.finalize_fixed(); + + let mut cbind_input = vec![]; + cbind_input.extend("n,,".as_bytes()); + let cbind_input = base64::encode(&cbind_input); + + self.message.clear(); + write!( + &mut self.message, + "c={},r={}", + cbind_input, server_message.nonce + ) + .unwrap(); + + let auth_message = format!( + "n=,r={},{},{}", + self.nonce, + String::from_utf8_lossy(&message[..]), + String::from_utf8_lossy(&self.message[..]) + ); + + let mut hmac = Hmac::::new_from_slice(&stored_key) + .expect("HMAC is able to accept all key sizes"); + hmac.update(auth_message.as_bytes()); + let client_signature = hmac.finalize().into_bytes(); + + let mut client_proof = client_key; + for (proof, signature) in client_proof.iter_mut().zip(client_signature) { + *proof ^= signature; + } + + write!(&mut self.message, ",p={}", base64::encode(&*client_proof)).unwrap(); + + self.auth_message = auth_message; + + Ok(self.message.clone()) + } + + pub fn finish(&mut self, message: &BytesMut) -> Result<(), Error> { + let final_message = FinalMessage::parse(message)?; + + let verifier = match base64::decode(&final_message.value) { + Ok(verifier) => verifier, + Err(_) => return Err(Error::ProtocolSyncError), + }; + + let mut hmac = Hmac::::new_from_slice(&self.salted_password) + .expect("HMAC is able to accept all key sizes"); + hmac.update(b"Server Key"); + let server_key = hmac.finalize().into_bytes(); + + let mut hmac = Hmac::::new_from_slice(&server_key) + .expect("HMAC is able to accept all key sizes"); + hmac.update(self.auth_message.as_bytes()); + + match hmac.verify_slice(&verifier) { + Ok(_) => Ok(()), + Err(_) => return Err(Error::ServerError), + } + } + + // https://github.com/sfackler/rust-postgres/blob/c3a029e60c1c0bd0be947049859b8fa5bd5ac220/postgres-protocol/src/authentication/sasl.rs#L35 + fn hi(str: &[u8], salt: &[u8], i: u32) -> [u8; 32] { + let mut hmac = + Hmac::::new_from_slice(str).expect("HMAC is able to accept all key sizes"); + hmac.update(salt); + hmac.update(&[0, 0, 0, 1]); + let mut prev = hmac.finalize().into_bytes(); + + let mut hi = prev; + + for _ in 1..i { + let mut hmac = Hmac::::new_from_slice(str).expect("already checked above"); + hmac.update(&prev); + prev = hmac.finalize().into_bytes(); + + for (hi, prev) in hi.iter_mut().zip(prev) { + *hi ^= prev; + } + } + + hi.into() + } +} + +#[derive(Default, Debug)] +struct Message { + nonce: String, + salt: String, + iterations: u32, +} + +impl Message { + fn parse(message: &BytesMut) -> Result { + if !message.starts_with(b"r=") { + return Err(Error::ProtocolSyncError); + } + + let mut i = 2; + + while message[i] != b',' && i < message.len() { + i += 1; + } + + let nonce = String::from_utf8_lossy(&message[2..i]).to_string(); + + // Skip the , + i += 1; + + if !&message[i..].starts_with(b"s=") { + return Err(Error::ProtocolSyncError); + } + + // Skip the s= + i += 2; + + let s = i; + while message[i] != b',' && i < message.len() { + i += 1; + } + + let salt = String::from_utf8_lossy(&message[s..i]).to_string(); + + // Skip the , + i += 1; + + if !&message[i..].starts_with(b"i=") { + return Err(Error::ProtocolSyncError); + } + + i += 2; + + let iterations = match String::from_utf8_lossy(&message[i..]).parse::() { + Ok(it) => it, + Err(_) => return Err(Error::ProtocolSyncError), + }; + + Ok(Message { + nonce, + salt, + iterations, + }) + } +} + +struct FinalMessage { + value: String, +} + +impl FinalMessage { + pub fn parse(message: &BytesMut) -> Result { + if !message.starts_with(b"v=") { + return Err(Error::ProtocolSyncError); + } + + Ok(FinalMessage { + value: String::from_utf8_lossy(&message[2..]).to_string(), + }) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn parse_server_first_message() { + let message = BytesMut::from( + &"r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096".as_bytes()[..], + ); + let message = Message::parse(&message).unwrap(); + assert_eq!(message.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j"); + assert_eq!(message.salt, "QSXCR+Q6sek8bf92"); + assert_eq!(message.iterations, 4096); + } + + #[test] + fn parse_server_last_message() { + let f = FinalMessage::parse(&BytesMut::from( + &"v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw".as_bytes()[..], + )) + .unwrap(); + assert_eq!( + f.value, + "U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw".to_string() + ); + } + + // recorded auth exchange from psql + #[test] + fn exchange() { + let password = "foobar"; + let nonce = "9IZ2O01zb9IgiIZ1WJ/zgpJB"; + + let client_first = "n,,n=,r=9IZ2O01zb9IgiIZ1WJ/zgpJB"; + let server_first = + "r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i\ + =4096"; + let client_final = + "c=biws,r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,p=AmNKosjJzS3\ + 1NTlQYNs5BTeQjdHdk7lOflDo5re2an8="; + let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw="; + + let mut scram = ScramSha256::from_nonce(password, nonce); + + let message = scram.message(); + assert_eq!(std::str::from_utf8(&message).unwrap(), client_first); + + let result = scram + .update(&BytesMut::from(&server_first.as_bytes()[..])) + .unwrap(); + assert_eq!(std::str::from_utf8(&result).unwrap(), client_final); + + scram + .finish(&BytesMut::from(&server_final.as_bytes()[..])) + .unwrap(); + } +} diff --git a/src/server.rs b/src/server.rs index 0dd051e5..3670af9b 100644 --- a/src/server.rs +++ b/src/server.rs @@ -12,6 +12,7 @@ use crate::config::{Address, User}; use crate::constants::*; use crate::errors::Error; use crate::messages::*; +use crate::scram::ScramSha256; use crate::stats::Reporter; use crate::ClientServerMap; @@ -89,6 +90,8 @@ impl Server { // We'll be handling multiple packets, but they will all be structured the same. // We'll loop here until this exchange is complete. + let mut scram = ScramSha256::new(&user.password); + loop { let code = match stream.read_u8().await { Ok(code) => code as char, @@ -130,6 +133,83 @@ impl Server { AUTHENTICATION_SUCCESSFUL => (), + SASL => { + debug!("Starting SASL authentication"); + let sasl_len = (len - 8) as usize; + let mut sasl_auth = vec![0u8; sasl_len]; + match stream.read_exact(&mut sasl_auth).await { + Ok(_) => (), + Err(_) => return Err(Error::SocketError), + }; + + let sasl_type = String::from_utf8_lossy(&sasl_auth[..sasl_len - 2]); + + if sasl_type == SCRAM_SHA_256 { + debug!("Using {}", SCRAM_SHA_256); + + // Send client message + let sasl_response = scram.message(); + let mut res = BytesMut::new(); + res.put_u8(b'p'); + res.put_i32( + 4 + SCRAM_SHA_256.len() as i32 + + 1 + + sasl_response.len() as i32 + + 4, + ); + res.put_slice(&format!("{}\0", SCRAM_SHA_256).as_bytes()[..]); + res.put_i32(sasl_response.len() as i32); + res.put(sasl_response); + + write_all(&mut stream, res).await?; + } else { + error!("Unsupported SCRAM version: {}", sasl_type); + return Err(Error::ServerError); + } + } + + SASL_CONTINUE => { + trace!("Continuing SASL"); + + let mut sasl_data = vec![0u8; (len - 8) as usize]; + + match stream.read_exact(&mut sasl_data).await { + Ok(_) => (), + Err(_) => return Err(Error::SocketError), + }; + + let msg = BytesMut::from(&sasl_data[..]); + let sasl_response = scram.update(&msg)?; + + let mut res = BytesMut::new(); + res.put_u8(b'p'); + res.put_i32(4 + sasl_response.len() as i32); + res.put(sasl_response); + + write_all(&mut stream, res).await?; + } + + SASL_FINAL => { + trace!("Final SASL"); + + let mut sasl_final = vec![0u8; len as usize - 8]; + match stream.read_exact(&mut sasl_final).await { + Ok(_) => (), + Err(_) => return Err(Error::SocketError), + }; + + match scram.finish(&BytesMut::from(&sasl_final[..])) { + Ok(_) => { + debug!("SASL authentication successful"); + } + + Err(err) => { + debug!("SASL authentication failed"); + return Err(err); + } + }; + } + _ => { error!("Unsupported authentication mechanism: {}", auth_code); return Err(Error::ServerError);