Merged
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -46,7 +46,7 @@ bytes = "1.7"
chrono = "0.4.38"
clap = "4.5"
cql2 = "0.3.0"
duckdb = "1.1.1"
duckdb = "=1.1.1"
fluent-uri = "0.3.2"
futures = "0.3.31"
geo = "0.29.3"
1 change: 1 addition & 0 deletions crates/duckdb/CHANGELOG.md
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Added

- Configure s3 credential chain ([#621](https://github.com/stac-utils/stac-rs/pull/621))
- Read hive partitioned datasets, `Config` structure ([#624](https://github.com/stac-utils/stac-rs/pull/624))

## [0.1.1] - 2025-01-31

92 changes: 76 additions & 16 deletions crates/duckdb/src/lib.rs
@@ -86,6 +86,21 @@ pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug)]
pub struct Client {
connection: Connection,
config: Config,
}

/// Configuration for a client.
#[derive(Debug)]
pub struct Config {
/// Whether to enable the s3 credential chain, which allows s3:// url access.
///
/// True by default.
pub use_s3_credential_chain: bool,

/// Whether to enable hive partitioning.
///
/// False by default.
pub use_hive_partitioning: bool,
}

/// A SQL query.
@@ -109,43 +124,68 @@ impl Client {
/// let client = Client::new().unwrap();
/// ```
pub fn new() -> Result<Client> {
Client::with_config(Config::default())
}

/// Creates a new client with the provided configuration.
///
/// # Examples
///
/// ```
/// use stac_duckdb::{Client, Config};
///
/// let config = Config {
/// use_s3_credential_chain: true,
/// use_hive_partitioning: true,
/// };
/// let client = Client::with_config(config);
/// ```
pub fn with_config(config: Config) -> Result<Client> {
let connection = Connection::open_in_memory()?;
connection.execute("INSTALL spatial", [])?;
connection.execute("LOAD spatial", [])?;
connection.execute("INSTALL icu", [])?;
connection.execute("LOAD icu", [])?;
connection.execute("CREATE SECRET (TYPE S3, PROVIDER CREDENTIAL_CHAIN)", [])?;
Ok(Client { connection })
if config.use_s3_credential_chain {
connection.execute("CREATE SECRET (TYPE S3, PROVIDER CREDENTIAL_CHAIN)", [])?;
}
Ok(Client { connection, config })
}

/// Returns one or more [stac::Collection] from the items in the stac-geoparquet file.
pub fn collections(&self, href: &str) -> Result<Vec<Collection>> {
let start_datetime= if self.connection.prepare(&format!(
"SELECT column_name FROM (DESCRIBE SELECT * from read_parquet('{}')) where column_name = 'start_datetime'",
href
"SELECT column_name FROM (DESCRIBE SELECT * from {}) where column_name = 'start_datetime'",
self.read_parquet_str(href)
))?.query([])?.next()?.is_some() {
"strftime(min(coalesce(start_datetime, datetime)), '%xT%X%z')"
} else {
"strftime(min(datetime), '%xT%X%z')"
};
let end_datetime= if self.connection.prepare(&format!(
"SELECT column_name FROM (DESCRIBE SELECT * from read_parquet('{}')) where column_name = 'end_datetime'",
href
))?.query([])?.next()?.is_some() {
let end_datetime = if self
.connection
.prepare(&format!(
"SELECT column_name FROM (DESCRIBE SELECT * from {}) where column_name = 'end_datetime'",
self.read_parquet_str(href)
))?
.query([])?
.next()?
.is_some()
{
"strftime(max(coalesce(end_datetime, datetime)), '%xT%X%z')"
} else {
"strftime(max(datetime), '%xT%X%z')"
};
let mut statement = self.connection.prepare(&format!(
"SELECT DISTINCT collection FROM read_parquet('{}')",
href
"SELECT DISTINCT collection FROM {}",
self.read_parquet_str(href)
))?;
let mut collections = Vec::new();
for row in statement.query_map([], |row| row.get::<_, String>(0))? {
let collection_id = row?;
let mut statement = self.connection.prepare(&
format!("SELECT ST_AsGeoJSON(ST_Extent_Agg(geometry)), {}, {} FROM read_parquet('{}') WHERE collection = $1", start_datetime, end_datetime,
href
format!("SELECT ST_AsGeoJSON(ST_Extent_Agg(geometry)), {}, {} FROM {} WHERE collection = $1", start_datetime, end_datetime,
self.read_parquet_str(href)
))?;
let row = statement.query_row([&collection_id], |row| {
Ok((
@@ -235,8 +275,8 @@ impl Client {
let fields = std::mem::take(&mut search.items.fields);

let mut statement = self.connection.prepare(&format!(
"SELECT column_name FROM (DESCRIBE SELECT * from read_parquet('{}'))",
href
"SELECT column_name FROM (DESCRIBE SELECT * from {})",
self.read_parquet_str(href)
))?;
let mut columns = Vec::new();
// Can we use SQL magic to make our query not depend on which columns are present?
@@ -354,14 +394,25 @@ impl Client {
}
Ok(Query {
sql: format!(
"SELECT {} FROM read_parquet('{}'){}",
"SELECT {} FROM {}{}",
columns.join(","),
href,
self.read_parquet_str(href),
suffix,
),
params,
})
}

fn read_parquet_str(&self, href: &str) -> String {
if self.config.use_hive_partitioning {
format!(
"read_parquet('{}', filename=true, hive_partitioning=1)",
href
)
} else {
format!("read_parquet('{}', filename=true)", href)
}
}
}

/// Return this crate's version.
@@ -396,6 +447,15 @@ fn to_geoarrow_record_batch(mut record_batch: RecordBatch) -> Result<RecordBatch
Ok(record_batch)
}

impl Default for Config {
fn default() -> Self {
Config {
use_hive_partitioning: false,
use_s3_credential_chain: true,
}
}
}

#[cfg(test)]
mod tests {
use super::Client;
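For reference, a minimal usage sketch of the `Config` API introduced in this diff, assuming the defaults shown above; the dataset href is hypothetical and relies on DuckDB's `read_parquet` glob support for a hive-partitioned stac-geoparquet layout:

```rust
use stac_duckdb::{Client, Config};

// Default configuration: s3 credential chain enabled, hive partitioning disabled.
let _client = Client::new().unwrap();

// Opt in to hive partitioning (and keep s3 access) for a partitioned dataset.
let client = Client::with_config(Config {
    use_s3_credential_chain: true,
    use_hive_partitioning: true,
})
.unwrap();

// Hypothetical hive-partitioned layout, e.g. year=2024/month=01/items.parquet.
let collections = client.collections("data/**/*.parquet").unwrap();
println!("found {} collection(s)", collections.len());
```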