Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
384 changes: 82 additions & 302 deletions Cargo.lock

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@ rustyline = "10.0"
tera = "1.19.0"
log = "0.4"
env_logger = "0.10"
pgwire-lite = "0.1.0"
zip = "0.6"
reqwest = { version = "0.11", features = ["blocking", "json"] }
reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] }
indicatif = "0.17"
unicode-width = "0.1.10"
once_cell = "1.17.0"
Expand Down
2 changes: 1 addition & 1 deletion src/commands/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use std::path::Path;
use std::process;

use log::{debug, error, info};
use pgwire_lite::PgwireLite;

use crate::core::config::{get_full_context, render_globals, render_string_value};
use crate::core::env::load_env_vars;
Expand All @@ -25,6 +24,7 @@ use crate::core::utils::{
use crate::resource::manifest::{Manifest, Resource};
use crate::resource::validation::validate_manifest;
use crate::template::engine::TemplateEngine;
use crate::utils::pgwire::PgwireLite;

/// Core state for all command operations, equivalent to Python's StackQLBase.
pub struct CommandRunner {
Expand Down
2 changes: 1 addition & 1 deletion src/core/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ use std::thread;
use std::time::{Duration, Instant};

use log::{debug, error, info, warn};
use pgwire_lite::PgwireLite;

use crate::utils::pgwire::PgwireLite;
use crate::utils::query::{execute_query, QueryResult};

/// Exit with error message. Matches Python's `catch_error_and_exit`.
Expand Down
4 changes: 2 additions & 2 deletions src/utils/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
use std::process;

use colored::*;
use pgwire_lite::PgwireLite;

use crate::globals::{server_host, server_port};
use crate::utils::pgwire::PgwireLite;

/// Creates a new PgwireLite client connection
pub fn create_client() -> PgwireLite {
Expand All @@ -38,7 +38,7 @@ pub fn create_client() -> PgwireLite {
});

println!("Connected to stackql server at {}:{}", host, port);
println!("Using libpq version: {}", client.libpq_version());
println!("Using pgwire client: {}", client.libpq_version());

client
}
1 change: 1 addition & 0 deletions src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub mod connection;
pub mod display;
pub mod download;
pub mod logging;
pub mod pgwire;
pub mod platform;
pub mod query;
pub mod server;
Expand Down
322 changes: 322 additions & 0 deletions src/utils/pgwire.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,322 @@
// utils/pgwire.rs

//! Pure-Rust PostgreSQL simple-query wire protocol client.
//!
//! Implements only what stackql-deploy needs: unencrypted TCP connections
//! to a local StackQL server using the PostgreSQL simple query protocol (v3).
//! No native dependencies (replaces pgwire-lite → libpq-sys).

use std::collections::HashMap;
use std::io::{Read, Write};
use std::net::TcpStream;

/// A single column value returned from a query.
pub enum Value {
String(String),
Null,
Bool(bool),
Integer(i64),
Float(f64),
Bytes(Vec<u8>),
}

/// A server notice (NOTICE, WARNING, etc.).
pub struct Notice {
pub fields: HashMap<String, String>,
}

/// The result of a [`PgwireLite::query`] call.
pub struct PgQueryResult {
pub column_names: Vec<String>,
pub rows: Vec<HashMap<String, Value>>,
pub notices: Vec<Notice>,
/// Row count reported by CommandComplete (INSERT/UPDATE/DELETE n).
pub row_count: usize,
}

/// Minimal PostgreSQL wire-protocol client.
pub struct PgwireLite {
stream: TcpStream,
}

impl PgwireLite {
/// Connect to a PostgreSQL-protocol server (e.g. StackQL) at `host:port`.
///
/// `_ssl` and `_verbosity` are accepted for API compatibility but ignored;
/// the connection is always unencrypted (StackQL default).
pub fn new(host: &str, port: u16, _ssl: bool, _verbosity: &str) -> Result<Self, String> {
let addr = format!("{}:{}", host, port);
let stream = TcpStream::connect(&addr)
.map_err(|e| format!("Connection to {} failed: {}", addr, e))?;

let mut client = PgwireLite { stream };
client.startup()?;
Ok(client)
}

/// Returns a version string (no libpq; just identifies the client).
pub fn libpq_version(&self) -> String {
"pure-rust-pgwire-client".to_string()
}

// ------------------------------------------------------------------
// Startup handshake
// ------------------------------------------------------------------

fn startup(&mut self) -> Result<(), String> {
// Protocol version 3.0 = 0x00_03_00_00
const PROTOCOL_V3: i32 = 196608;

// Startup message: user=stackql, database=stackql, then double-null
let params = b"user\0stackql\0database\0stackql\0\0";
let total_len = 4 + 4 + params.len(); // length field + protocol + params

let mut msg = Vec::with_capacity(total_len);
msg.extend_from_slice(&(total_len as i32).to_be_bytes());
msg.extend_from_slice(&PROTOCOL_V3.to_be_bytes());
msg.extend_from_slice(params);

self.stream
.write_all(&msg)
.map_err(|e| format!("Startup write error: {}", e))?;

// Process auth / parameter-status messages until ReadyForQuery
loop {
let msg_type = self.read_byte()?;
let payload_len = self.read_i32()? as usize;
// payload_len includes the 4 bytes of the length field itself
let data = self.read_bytes(payload_len.saturating_sub(4))?;

match msg_type {
b'R' => {
// AuthenticationRequest
let auth_type =
i32::from_be_bytes(data[..4].try_into().map_err(|_| "Bad auth")?);
if auth_type != 0 {
return Err(format!(
"Unsupported authentication type {} from server",
auth_type
));
}
// AuthenticationOk — nothing to do
}
b'K' => {} // BackendKeyData — ignore
b'S' => {} // ParameterStatus — ignore
b'Z' => break, // ReadyForQuery
b'E' => return Err(parse_error_fields(&data)),
b'N' => {} // NoticeResponse during startup — ignore
_ => {} // Unknown message type — skip
}
}

Ok(())
}

// ------------------------------------------------------------------
// Query
// ------------------------------------------------------------------

/// Execute a simple (non-prepared) SQL query and return structured results.
pub fn query(&mut self, sql: &str) -> Result<PgQueryResult, String> {
// Send Query message: 'Q' | int32(len) | sql\0
let sql_bytes = sql.as_bytes();
let payload_len = 4 + sql_bytes.len() + 1; // length field + sql + null

let mut msg = Vec::with_capacity(1 + payload_len);
msg.push(b'Q');
msg.extend_from_slice(&(payload_len as i32).to_be_bytes());
msg.extend_from_slice(sql_bytes);
msg.push(0u8);

self.stream
.write_all(&msg)
.map_err(|e| format!("Query write error: {}", e))?;

// Collect response messages
let mut column_names: Vec<String> = Vec::new();
let mut rows: Vec<HashMap<String, Value>> = Vec::new();
let mut notices: Vec<Notice> = Vec::new();
let mut row_count: usize = 0;

loop {
let msg_type = self.read_byte()?;
let payload_len = self.read_i32()? as usize;
let data = self.read_bytes(payload_len.saturating_sub(4))?;

match msg_type {
b'T' => {
// RowDescription
column_names = parse_row_description(&data);
}
b'D' => {
// DataRow
let row = parse_data_row(&data, &column_names);
rows.push(row);
}
b'C' => {
// CommandComplete — tag like "SELECT 5", "INSERT 0 1", "UPDATE 3"
let tag = std::str::from_utf8(data.strip_suffix(b"\0").unwrap_or(&data))
.unwrap_or("")
.to_string();
if let Some(n) = tag.split_whitespace().last().and_then(|s| s.parse().ok()) {
row_count = n;
}
}
b'N' => {
notices.push(parse_notice_fields(&data));
}
b'E' => {
return Err(parse_error_fields(&data));
}
b'I' => {} // EmptyQueryResponse
b'Z' => break, // ReadyForQuery — done
_ => {}
}
}

Ok(PgQueryResult {
column_names,
rows,
notices,
row_count,
})
}

// ------------------------------------------------------------------
// Low-level I/O helpers
// ------------------------------------------------------------------

fn read_byte(&mut self) -> Result<u8, String> {
let mut buf = [0u8; 1];
self.stream
.read_exact(&mut buf)
.map_err(|e| format!("Read error: {}", e))?;
Ok(buf[0])
}

fn read_i32(&mut self) -> Result<i32, String> {
let mut buf = [0u8; 4];
self.stream
.read_exact(&mut buf)
.map_err(|e| format!("Read error: {}", e))?;
Ok(i32::from_be_bytes(buf))
}

fn read_bytes(&mut self, n: usize) -> Result<Vec<u8>, String> {
let mut buf = vec![0u8; n];
self.stream
.read_exact(&mut buf)
.map_err(|e| format!("Read error: {}", e))?;
Ok(buf)
}
}

// ------------------------------------------------------------------
// Message parsers (free functions for readability)
// ------------------------------------------------------------------

fn parse_row_description(data: &[u8]) -> Vec<String> {
let mut names = Vec::new();
if data.len() < 2 {
return names;
}
let num_fields = u16::from_be_bytes([data[0], data[1]]) as usize;
let mut pos = 2;

for _ in 0..num_fields {
// Null-terminated field name
let Some(null_off) = data[pos..].iter().position(|&b| b == 0) else {
break;
};
let name = String::from_utf8_lossy(&data[pos..pos + null_off]).into_owned();
names.push(name);
// Skip: name + null(1) + tableOID(4) + attrNum(2) + typeOID(4) + typeSize(2)
// + typeMod(4) + formatCode(2) = 19 bytes after the null
pos += null_off + 1 + 18;
}
names
}

fn parse_data_row(data: &[u8], columns: &[String]) -> HashMap<String, Value> {
let mut row = HashMap::new();
if data.len() < 2 {
return row;
}
let num_cols = u16::from_be_bytes([data[0], data[1]]) as usize;
let mut pos = 2;

for col_name in columns.iter().take(num_cols) {
if pos + 4 > data.len() {
break;
}
let col_len = i32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
pos += 4;

let value = if col_len < 0 {
Value::Null
} else {
let len = col_len as usize;
if pos + len > data.len() {
break;
}
let s = String::from_utf8_lossy(&data[pos..pos + len]).into_owned();
pos += len;
Value::String(s)
};

row.insert(col_name.clone(), value);
}
row
}

fn parse_notice_fields(data: &[u8]) -> Notice {
let mut fields = HashMap::new();
let mut pos = 0;

while pos < data.len() {
let field_code = data[pos];
pos += 1;
if field_code == 0 {
break;
}
let Some(null_off) = data[pos..].iter().position(|&b| b == 0) else {
break;
};
let value = String::from_utf8_lossy(&data[pos..pos + null_off]).into_owned();
pos += null_off + 1;

let key = match field_code {
b'S' => "severity",
b'M' => "message",
b'D' => "detail",
b'H' => "hint",
b'C' => "code",
b'P' => "position",
b'W' => "where",
_ => continue,
};
fields.insert(key.to_string(), value);
}

Notice { fields }
}

fn parse_error_fields(data: &[u8]) -> String {
let mut pos = 0;
while pos < data.len() {
let field_code = data[pos];
pos += 1;
if field_code == 0 {
break;
}
let Some(null_off) = data[pos..].iter().position(|&b| b == 0) else {
break;
};
let value = String::from_utf8_lossy(&data[pos..pos + null_off]).into_owned();
pos += null_off + 1;
if field_code == b'M' {
return value;
}
}
"Unknown server error".to_string()
}
2 changes: 1 addition & 1 deletion src/utils/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
//! }
//! ```

use pgwire_lite::{PgwireLite, Value};
use crate::utils::pgwire::{PgwireLite, Value};

/// Represents a column in a query result.
pub struct QueryResultColumn {
Expand Down