diff --git a/Cargo.lock b/Cargo.lock
index 4ddf5c1..253a725 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1745,7 +1745,6 @@ version = "0.15.0"
 dependencies = [
  "anyhow",
  "async-trait",
- "flate2",
  "postgresql_archive",
  "postgresql_commands",
  "postgresql_embedded",
@@ -1757,7 +1756,6 @@ dependencies = [
  "semver",
  "serde",
  "serde_json",
- "tar",
  "target-triple",
  "tempfile",
  "test-log",
@@ -1765,7 +1763,6 @@ dependencies = [
 "tokio",
 "tracing",
 "url",
- "zip",
 ]
 
 [[package]]
diff --git a/examples/zonky/src/main.rs b/examples/zonky/src/main.rs
index b9ef79f..d98e4f7 100644
--- a/examples/zonky/src/main.rs
+++ b/examples/zonky/src/main.rs
@@ -9,7 +9,7 @@ use postgresql_embedded::{PostgreSQL, Result, Settings};
 async fn main() -> Result<()> {
     let settings = Settings {
         releases_url: zonky::URL.to_string(),
-        version: VersionReq::parse("=16.2.0")?,
+        version: VersionReq::parse("=16.3.0")?,
         ..Default::default()
     };
     let mut postgresql = PostgreSQL::new(settings);
diff --git a/postgresql_archive/Cargo.toml b/postgresql_archive/Cargo.toml
index c28e559..622c86c 100644
--- a/postgresql_archive/Cargo.toml
+++ b/postgresql_archive/Cargo.toml
@@ -12,7 +12,7 @@ version.workspace = true
 [dependencies]
 anyhow = { workspace = true }
 async-trait = { workspace = true }
-flate2 = { workspace = true, optional = true }
+flate2 = { workspace = true }
 hex = { workspace = true }
 http = { workspace = true }
 human_bytes = { workspace = true, default-features = false }
@@ -29,15 +29,15 @@ serde = { workspace = true, features = ["derive"] }
 serde_json = { workspace = true, optional = true }
 sha1 = { workspace = true, optional = true }
 sha2 = { workspace = true, optional = true }
-tar = { workspace = true, optional = true }
+tar = { workspace = true }
 target-triple = { workspace = true, optional = true }
 tempfile = { workspace = true }
 thiserror = { workspace = true }
 tokio = { workspace = true, features = ["full"], optional = true }
 tracing = { workspace = true, features = ["log"] }
 url = { workspace = true }
-xz2 = { workspace = true, optional = true }
-zip = { workspace = true, optional = true }
+xz2 = { workspace = true }
+zip = { workspace = true }
 
 [dev-dependencies]
 criterion = { workspace = true }
@@ -66,17 +66,11 @@ rustls-tls = ["reqwest/rustls-tls-native-roots"]
 sha1 = ["dep:sha1"]
 sha2 = ["dep:sha2"]
 theseus = [
-    "dep:flate2",
-    "dep:tar",
     "dep:target-triple",
     "github",
     "sha2",
 ]
 zonky = [
-    "dep:flate2",
-    "dep:tar",
-    "dep:xz2",
-    "dep:zip",
     "maven",
 ]
 
diff --git a/postgresql_archive/src/archive.rs b/postgresql_archive/src/archive.rs
index a7d7809..62c024a 100644
--- a/postgresql_archive/src/archive.rs
+++ b/postgresql_archive/src/archive.rs
@@ -3,6 +3,7 @@
 use crate::error::Result;
 use crate::{extractor, repository};
+use regex::Regex;
 use semver::{Version, VersionReq};
 use std::path::{Path, PathBuf};
 use tracing::instrument;
@@ -43,7 +44,10 @@ pub async fn get_archive(url: &str, version_req: &VersionReq) -> Result<(Version
 #[instrument(skip(bytes))]
 pub async fn extract(url: &str, bytes: &Vec<u8>, out_dir: &Path) -> Result<Vec<PathBuf>> {
     let extractor_fn = extractor::registry::get(url)?;
-    extractor_fn(bytes, out_dir)
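+    // Map every entry to out_dir so the registry API keeps its previous single-directory behavior.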
+    let mut extract_directories = extractor::ExtractDirectories::default();
+    extract_directories.add_mapping(Regex::new(".*")?, out_dir.to_path_buf());
+    extractor_fn(bytes, extract_directories)
 }
 
 #[cfg(test)]
diff --git a/postgresql_archive/src/configuration/theseus/extractor.rs b/postgresql_archive/src/configuration/theseus/extractor.rs
index 2279dec..da23ef5 100644
--- a/postgresql_archive/src/configuration/theseus/extractor.rs
+++ b/postgresql_archive/src/configuration/theseus/extractor.rs
@@ -1,14 +1,11 @@
+use crate::extractor::{tar_gz_extract, ExtractDirectories};
 use crate::Error::Unexpected;
 use crate::Result;
-use flate2::bufread::GzDecoder;
-use human_bytes::human_bytes;
-use num_format::{Locale, ToFormattedString};
-use std::fs::{create_dir_all, remove_dir_all, remove_file, rename, File};
-use std::io::{copy, BufReader, Cursor};
+use regex::Regex;
+use std::fs::{create_dir_all, remove_dir_all, remove_file, rename};
 use std::path::{Path, PathBuf};
 use std::thread::sleep;
 use std::time::Duration;
-use tar::Archive;
 use tracing::{debug, instrument, warn};
 
-/// Extracts the compressed tar `bytes` to the [out_dir](Path).
+/// Extracts the compressed tar `bytes` to the directory mapped in `extract_directories`.
@@ -17,18 +14,14 @@ use tracing::{debug, instrument, warn};
 ///
 /// # Errors
 /// Returns an error if the extraction fails.
 #[allow(clippy::cast_precision_loss)]
 #[instrument(skip(bytes))]
-pub fn extract(bytes: &Vec<u8>, out_dir: &Path) -> Result<Vec<PathBuf>> {
-    let mut files = Vec::new();
-    let input = BufReader::new(Cursor::new(bytes));
-    let decoder = GzDecoder::new(input);
-    let mut archive = Archive::new(decoder);
-    let mut extracted_bytes = 0;
+pub fn extract(bytes: &Vec<u8>, extract_directories: ExtractDirectories) -> Result<Vec<PathBuf>> {
+    let out_dir = extract_directories.get_path(".")?;
     let parent_dir = if let Some(parent) = out_dir.parent() {
         parent
     } else {
         debug!("No parent directory for {}", out_dir.to_string_lossy());
-        out_dir
+        out_dir.as_path()
     };
 
     create_dir_all(parent_dir)?;
@@ -42,55 +35,15 @@ pub fn extract(bytes: &Vec<u8>, out_dir: &Path) -> Result<Vec<PathBuf>> {
             out_dir.to_string_lossy()
         );
         remove_file(&lock_file)?;
-        return Ok(files);
+        return Ok(Vec::new());
     }
 
     let extract_dir = tempfile::tempdir_in(parent_dir)?.into_path();
     debug!("Extracting archive to {}", extract_dir.to_string_lossy());
-
-    for archive_entry in archive.entries()? {
-        let mut entry = archive_entry?;
-        let entry_header = entry.header();
-        let entry_type = entry_header.entry_type();
-        let entry_size = entry_header.size()?;
-        #[cfg(unix)]
-        let file_mode = entry_header.mode()?;
-
-        let entry_header_path = entry_header.path()?.to_path_buf();
-        let prefix = match entry_header_path.components().next() {
-            Some(component) => component.as_os_str().to_str().unwrap_or_default(),
-            None => {
-                return Err(Unexpected(
-                    "Failed to get file header path prefix".to_string(),
-                ));
-            }
-        };
-        let stripped_entry_header_path = entry_header_path.strip_prefix(prefix)?.to_path_buf();
-        let mut entry_name = extract_dir.clone();
-        entry_name.push(stripped_entry_header_path);
-
-        if entry_type.is_dir() || entry_name.is_dir() {
-            create_dir_all(&entry_name)?;
-        } else if entry_type.is_file() {
-            let mut output_file = File::create(&entry_name)?;
-            copy(&mut entry, &mut output_file)?;
-            extracted_bytes += entry_size;
-
-            #[cfg(unix)]
-            {
-                use std::os::unix::fs::PermissionsExt;
-                output_file.set_permissions(std::fs::Permissions::from_mode(file_mode))?;
-            }
-            files.push(entry_name);
-        } else if entry_type.is_symlink() {
-            #[cfg(unix)]
-            if let Some(symlink_target) = entry.link_name()? {
-                let symlink_path = entry_name.clone();
-                std::os::unix::fs::symlink(symlink_target.as_ref(), symlink_path)?;
-                files.push(entry_name);
-            }
-        }
-    }
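+    // Extract into the temporary directory via the shared tar.gz extractor; it is promoted to out_dir below.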
+    let mut archive_extract_directories = ExtractDirectories::default();
+    archive_extract_directories.add_mapping(Regex::new(".*")?, extract_dir.clone());
+    let files = tar_gz_extract(bytes, archive_extract_directories)?;
 
     if out_dir.exists() {
         debug!(
@@ -113,13 +65,6 @@ pub fn extract(bytes: &Vec<u8>, out_dir: &Path) -> Result<Vec<PathBuf>> {
         remove_file(lock_file)?;
     }
 
-    let number_of_files = files.len();
-    debug!(
-        "Extracting {} files totalling {}",
-        number_of_files.to_formatted_string(&Locale::en),
-        human_bytes(extracted_bytes as f64)
-    );
-
     Ok(files)
 }
diff --git a/postgresql_archive/src/configuration/zonky/extractor.rs b/postgresql_archive/src/configuration/zonky/extractor.rs
index a27c7d7..bd24fe0 100644
--- a/postgresql_archive/src/configuration/zonky/extractor.rs
+++ b/postgresql_archive/src/configuration/zonky/extractor.rs
@@ -1,15 +1,13 @@
+use crate::extractor::{tar_xz_extract, ExtractDirectories};
 use crate::Error::Unexpected;
 use crate::Result;
-use human_bytes::human_bytes;
-use num_format::{Locale, ToFormattedString};
-use std::fs::{create_dir_all, remove_dir_all, remove_file, rename, File};
-use std::io::{copy, BufReader, Cursor};
+use regex::Regex;
+use std::fs::{create_dir_all, remove_dir_all, remove_file, rename};
+use std::io::Cursor;
 use std::path::{Path, PathBuf};
 use std::thread::sleep;
 use std::time::Duration;
-use tar::Archive;
 use tracing::{debug, instrument, warn};
-use xz2::bufread::XzDecoder;
 use zip::ZipArchive;
 
-/// Extracts the compressed tar `bytes` to the [out_dir](Path).
+/// Extracts the compressed archive `bytes` to the directory mapped in `extract_directories`.
@@ -19,13 +17,13 @@ use zip::ZipArchive;
 ///
 /// # Errors
 /// Returns an error if the extraction fails.
 #[allow(clippy::case_sensitive_file_extension_comparisons)]
 #[allow(clippy::cast_precision_loss)]
 #[instrument(skip(bytes))]
-pub fn extract(bytes: &Vec<u8>, out_dir: &Path) -> Result<Vec<PathBuf>> {
-    let mut files = Vec::new();
+pub fn extract(bytes: &Vec<u8>, extract_directories: ExtractDirectories) -> Result<Vec<PathBuf>> {
+    let out_dir = extract_directories.get_path(".")?;
     let parent_dir = if let Some(parent) = out_dir.parent() {
         parent
     } else {
         debug!("No parent directory for {}", out_dir.to_string_lossy());
-        out_dir
+        out_dir.as_path()
     };
 
     create_dir_all(parent_dir)?;
@@ -39,7 +37,7 @@ pub fn extract(bytes: &Vec<u8>, out_dir: &Path) -> Result<Vec<PathBuf>> {
             out_dir.to_string_lossy()
         );
         remove_file(&lock_file)?;
-        return Ok(files);
+        return Ok(Vec::new());
     }
 
     let extract_dir = tempfile::tempdir_in(parent_dir)?.into_path();
@@ -64,51 +62,10 @@ pub fn extract(bytes: &Vec<u8>, out_dir: &Path) -> Result<Vec<PathBuf>> {
         return Err(Unexpected("Failed to find archive file".to_string()));
     }
 
-    let input = BufReader::new(Cursor::new(archive_bytes));
-    let decoder = XzDecoder::new(input);
-    let mut archive = Archive::new(decoder);
-    let mut extracted_bytes = 0;
-
-    for archive_entry in archive.entries()? {
-        let mut entry = archive_entry?;
-        let entry_header = entry.header();
-        let entry_type = entry_header.entry_type();
-        let entry_size = entry_header.size()?;
-        #[cfg(unix)]
-        let file_mode = entry_header.mode()?;
-
-        let entry_header_path = entry_header.path()?.to_path_buf();
-        let mut entry_name = extract_dir.clone();
-        entry_name.push(entry_header_path);
-
-        if let Some(parent) = entry_name.parent() {
-            if !parent.exists() {
-                create_dir_all(parent)?;
-            }
-        }
-
-        if entry_type.is_dir() || entry_name.is_dir() {
-            create_dir_all(&entry_name)?;
-        } else if entry_type.is_file() {
-            let mut output_file = File::create(&entry_name)?;
-            copy(&mut entry, &mut output_file)?;
-            extracted_bytes += entry_size;
-
-            #[cfg(unix)]
-            {
-                use std::os::unix::fs::PermissionsExt;
-                output_file.set_permissions(std::fs::Permissions::from_mode(file_mode))?;
-            }
-            files.push(entry_name);
-        } else if entry_type.is_symlink() {
-            #[cfg(unix)]
-            if let Some(symlink_target) = entry.link_name()? {
-                let symlink_path = entry_name.clone();
-                std::os::unix::fs::symlink(symlink_target.as_ref(), symlink_path)?;
-                files.push(entry_name);
-            }
-        }
-    }
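+    // The zip wrapper was unpacked above; delegate the embedded tar.xz to the shared extractor.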
+    let mut archive_extract_directories = ExtractDirectories::default();
+    archive_extract_directories.add_mapping(Regex::new(".*")?, extract_dir.clone());
+    let files = tar_xz_extract(&archive_bytes, archive_extract_directories)?;
 
     if out_dir.exists() {
         debug!(
@@ -131,13 +87,6 @@ pub fn extract(bytes: &Vec<u8>, out_dir: &Path) -> Result<Vec<PathBuf>> {
         remove_file(lock_file)?;
     }
 
-    let number_of_files = files.len();
-    debug!(
-        "Extracting {} files totalling {}",
-        number_of_files.to_formatted_string(&Locale::en),
-        human_bytes(extracted_bytes as f64)
-    );
-
     Ok(files)
 }
diff --git a/postgresql_archive/src/extractor/mod.rs b/postgresql_archive/src/extractor/mod.rs
index d108990..56b41b9 100644
--- a/postgresql_archive/src/extractor/mod.rs
+++ b/postgresql_archive/src/extractor/mod.rs
@@ -1 +1,10 @@
+mod model;
 pub mod registry;
+mod tar_gz_extractor;
+mod tar_xz_extractor;
+mod zip_extractor;
+
+pub use model::ExtractDirectories;
+pub use tar_gz_extractor::extract as tar_gz_extract;
+pub use tar_xz_extractor::extract as tar_xz_extract;
+pub use zip_extractor::extract as zip_extract;
diff --git a/postgresql_archive/src/extractor/model.rs b/postgresql_archive/src/extractor/model.rs
new file mode 100644
index 0000000..258c266
--- /dev/null
+++ b/postgresql_archive/src/extractor/model.rs
@@ -0,0 +1,126 @@
+use crate::{Error, Result};
+use regex::Regex;
+use std::fmt::Display;
+use std::path::PathBuf;
+
+/// Maps files in an archive to the directories they should be extracted to, based upon the
+/// first associated regex that matches the file path.
+#[derive(Debug)]
+pub struct ExtractDirectories {
+    mappings: Vec<(Regex, PathBuf)>,
+}
+
+impl ExtractDirectories {
+    /// Creates a new ExtractDirectories instance.
+    #[must_use]
+    pub fn new(mappings: Vec<(Regex, PathBuf)>) -> Self {
+        Self { mappings }
+    }
+
+    /// Adds a new mapping to the ExtractDirectories instance.
+    pub fn add_mapping(&mut self, regex: Regex, path: PathBuf) {
+        self.mappings.push((regex, path));
+    }
+
+    /// Returns the path associated with the first regex that matches the file path.
+    /// If no regex matches, an error is returned.
+    ///
+    /// # Errors
+    /// Returns an error if no regex matches the file path.
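+    ///
+    /// # Examples
+    /// A minimal usage sketch (the regex and paths are illustrative):
+    /// ```
+    /// use postgresql_archive::extractor::ExtractDirectories;
+    /// use regex::Regex;
+    /// use std::path::PathBuf;
+    ///
+    /// let mut extract_directories = ExtractDirectories::default();
+    /// extract_directories.add_mapping(Regex::new(r"\.so$").unwrap(), PathBuf::from("lib"));
+    /// assert_eq!(PathBuf::from("lib"), extract_directories.get_path("libpq.so").unwrap());
+    /// assert!(extract_directories.get_path("README.md").is_err());
+    /// ```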
+    pub fn get_path(&self, file_path: &str) -> Result<PathBuf> {
+        for (regex, path) in &self.mappings {
+            if regex.is_match(file_path) {
+                return Ok(path.clone());
+            }
+        }
+        Err(Error::Unexpected(format!(
+            "No regex matched the file path: {file_path}"
+        )))
+    }
+}
+
+/// Default implementation for ExtractDirectories.
+impl Default for ExtractDirectories {
+    /// Creates a new ExtractDirectories instance with an empty mappings vector.
+    fn default() -> Self {
+        ExtractDirectories::new(Vec::new())
+    }
+}
+
+/// Display implementation for ExtractDirectories.
+impl Display for ExtractDirectories {
+    /// Formats the ExtractDirectories instance.
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        for (regex, path) in &self.mappings {
+            writeln!(f, "{} -> {}", regex, path.display())?;
+        }
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_new() -> Result<()> {
+        let mappings = vec![(Regex::new(".*")?, PathBuf::from("test"))];
+        let extract_directories = ExtractDirectories::new(mappings);
+        let path = extract_directories.get_path("foo")?;
+        assert_eq!("test", path.to_string_lossy());
+        Ok(())
+    }
+
+    #[test]
+    fn test_default() {
+        let extract_directories = ExtractDirectories::default();
+        let result = extract_directories.get_path("foo");
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_add_mapping() -> Result<()> {
+        let mut extract_directories = ExtractDirectories::default();
+        extract_directories.add_mapping(Regex::new(".*")?, PathBuf::from("test"));
+        let path = extract_directories.get_path("foo")?;
+        assert_eq!("test", path.to_string_lossy());
+        Ok(())
+    }
+
+    #[test]
+    fn test_get_path() -> Result<()> {
+        let mappings = vec![
+            (Regex::new("test")?, PathBuf::from("test")),
+            (Regex::new("foo")?, PathBuf::from("bar")),
+        ];
+        let extract_directories = ExtractDirectories::new(mappings);
+        let path = extract_directories.get_path("foo")?;
+        assert_eq!("bar", path.to_string_lossy());
+        Ok(())
+    }
+
+    #[test]
+    fn test_display() -> Result<()> {
+        let mappings = vec![
+            (Regex::new("test")?, PathBuf::from("test")),
+            (Regex::new("foo")?, PathBuf::from("bar")),
+        ];
+        let extract_directories = ExtractDirectories::new(mappings);
+        let display = extract_directories.to_string();
+        assert_eq!("test -> test\nfoo -> bar\n", display);
+        Ok(())
+    }
+}
diff --git a/postgresql_archive/src/extractor/registry.rs b/postgresql_archive/src/extractor/registry.rs
index 9ccd0d4..06ca8c8 100644
--- a/postgresql_archive/src/extractor/registry.rs
+++ b/postgresql_archive/src/extractor/registry.rs
@@ -2,16 +2,18 @@
 use crate::configuration::theseus;
 #[cfg(feature = "zonky")]
 use crate::configuration::zonky;
+use crate::extractor::ExtractDirectories;
 use crate::Error::{PoisonedLock, UnsupportedExtractor};
 use crate::Result;
-use std::path::{Path, PathBuf};
+use std::path::PathBuf;
 use std::sync::{Arc, LazyLock, Mutex, RwLock};
 
 static REGISTRY: LazyLock<Arc<Mutex<RepositoryRegistry>>> =
     LazyLock::new(|| Arc::new(Mutex::new(RepositoryRegistry::default())));
 
 type SupportsFn = fn(&str) -> Result<bool>;
-type ExtractFn = fn(&Vec<u8>, &Path) -> Result<Vec<PathBuf>>;
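+/// Extraction function that writes archive entries to the directories mapped in [`ExtractDirectories`].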
+type ExtractFn = fn(&Vec<u8>, ExtractDirectories) -> Result<Vec<PathBuf>>;
 
 /// Singleton struct to store extractors
 #[allow(clippy::type_complexity)]
@@ -99,13 +100,16 @@ pub fn get(url: &str) -> Result<ExtractFn> {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use regex::Regex;
 
     #[test]
     fn test_register() -> Result<()> {
         register(|url| Ok(url == "https://foo.com"), |_, _| Ok(Vec::new()))?;
         let url = "https://foo.com";
         let extractor = get(url)?;
-        assert!(extractor(&Vec::new(), Path::new("foo")).is_ok());
+        let mut extract_directories = ExtractDirectories::default();
+        extract_directories.add_mapping(Regex::new(".*")?, PathBuf::from("test"));
+        assert!(extractor(&Vec::new(), extract_directories).is_ok());
         Ok(())
     }
 
diff --git a/postgresql_archive/src/extractor/tar_gz_extractor.rs b/postgresql_archive/src/extractor/tar_gz_extractor.rs
new file mode 100644
index 0000000..bca6a92
--- /dev/null
+++ b/postgresql_archive/src/extractor/tar_gz_extractor.rs
@@ -0,0 +1,82 @@
+use crate::extractor::ExtractDirectories;
+use crate::Error::Unexpected;
+use crate::Result;
+use flate2::bufread::GzDecoder;
+use human_bytes::human_bytes;
+use num_format::{Locale, ToFormattedString};
+use std::fs::{create_dir_all, File};
+use std::io::{copy, BufReader, Cursor};
+use std::path::PathBuf;
+use tar::Archive;
+use tracing::{debug, instrument, warn};
+
+/// Extracts the compressed tar `bytes` to paths defined in `extract_directories`.
+///
+/// # Errors
+/// Returns an error if the extraction fails.
+#[allow(clippy::cast_precision_loss)]
+#[instrument(skip(bytes))]
+pub fn extract(bytes: &Vec<u8>, extract_directories: ExtractDirectories) -> Result<Vec<PathBuf>> {
+    let mut files = Vec::new();
+    let input = BufReader::new(Cursor::new(bytes));
+    let decoder = GzDecoder::new(input);
+    let mut archive = Archive::new(decoder);
+    let mut extracted_bytes = 0;
+
+    for archive_entry in archive.entries()? {
+        let mut entry = archive_entry?;
+        let entry_header = entry.header();
+        let entry_type = entry_header.entry_type();
+        let entry_size = entry_header.size()?;
+        #[cfg(unix)]
+        let file_mode = entry_header.mode()?;
+
+        let entry_header_path = entry_header.path()?.to_path_buf();
+        let prefix = match entry_header_path.components().next() {
+            Some(component) => component.as_os_str().to_str().unwrap_or_default(),
+            None => {
+                return Err(Unexpected(
+                    "Failed to get file header path prefix".to_string(),
+                ));
+            }
+        };
+        let stripped_entry_header_path = entry_header_path.strip_prefix(prefix)?.to_path_buf();
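+        // Entries whose top-level path component matches no mapping are skipped.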
+        let Ok(extract_dir) = extract_directories.get_path(prefix) else {
+            continue;
+        };
+        let mut entry_name = extract_dir.clone();
+        entry_name.push(stripped_entry_header_path);
+
+        if entry_type.is_dir() || entry_name.is_dir() {
+            create_dir_all(&entry_name)?;
+        } else if entry_type.is_file() {
+            let mut output_file = File::create(&entry_name)?;
+            copy(&mut entry, &mut output_file)?;
+            extracted_bytes += entry_size;
+
+            #[cfg(unix)]
+            {
+                use std::os::unix::fs::PermissionsExt;
+                output_file.set_permissions(std::fs::Permissions::from_mode(file_mode))?;
+            }
+            files.push(entry_name);
+        } else if entry_type.is_symlink() {
+            #[cfg(unix)]
+            if let Some(symlink_target) = entry.link_name()? {
+                let symlink_path = entry_name.clone();
+                std::os::unix::fs::symlink(symlink_target.as_ref(), symlink_path)?;
+                files.push(entry_name);
+            }
+        }
+    }
+
+    let number_of_files = files.len();
+    debug!(
+        "Extracted {} files totalling {}",
+        number_of_files.to_formatted_string(&Locale::en),
+        human_bytes(extracted_bytes as f64)
+    );
+
+    Ok(files)
+}
diff --git a/postgresql_archive/src/extractor/tar_xz_extractor.rs b/postgresql_archive/src/extractor/tar_xz_extractor.rs
new file mode 100644
index 0000000..cecdd90
--- /dev/null
+++ b/postgresql_archive/src/extractor/tar_xz_extractor.rs
@@ -0,0 +1,82 @@
+use crate::extractor::ExtractDirectories;
+use crate::Error::Unexpected;
+use crate::Result;
+use human_bytes::human_bytes;
+use num_format::{Locale, ToFormattedString};
+use std::fs::{create_dir_all, File};
+use std::io::{copy, BufReader, Cursor};
+use std::path::PathBuf;
+use tar::Archive;
+use tracing::{debug, instrument, warn};
+use xz2::bufread::XzDecoder;
+
+/// Extracts the compressed tar `bytes` to paths defined in `extract_directories`.
+///
+/// # Errors
+/// Returns an error if the extraction fails.
+#[allow(clippy::cast_precision_loss)]
+#[instrument(skip(bytes))]
+pub fn extract(bytes: &Vec<u8>, extract_directories: ExtractDirectories) -> Result<Vec<PathBuf>> {
+    let mut files = Vec::new();
+    let input = BufReader::new(Cursor::new(bytes));
+    let decoder = XzDecoder::new(input);
+    let mut archive = Archive::new(decoder);
+    let mut extracted_bytes = 0;
+
+    for archive_entry in archive.entries()? {
+        let mut entry = archive_entry?;
+        let entry_header = entry.header();
+        let entry_type = entry_header.entry_type();
+        let entry_size = entry_header.size()?;
+        #[cfg(unix)]
+        let file_mode = entry_header.mode()?;
+
+        let entry_header_path = entry_header.path()?.to_path_buf();
+        let prefix = match entry_header_path.components().next() {
+            Some(component) => component.as_os_str().to_str().unwrap_or_default(),
+            None => {
+                return Err(Unexpected(
+                    "Failed to get file header path prefix".to_string(),
+                ));
+            }
+        };
+        let stripped_entry_header_path = entry_header_path.strip_prefix(prefix)?.to_path_buf();
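+        // As in the tar.gz extractor, entries with no mapped directory are skipped.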
+        let Ok(extract_dir) = extract_directories.get_path(prefix) else {
+            continue;
+        };
+        let mut entry_name = extract_dir.clone();
+        entry_name.push(stripped_entry_header_path);
+
+        if entry_type.is_dir() || entry_name.is_dir() {
+            create_dir_all(&entry_name)?;
+        } else if entry_type.is_file() {
+            let mut output_file = File::create(&entry_name)?;
+            copy(&mut entry, &mut output_file)?;
+            extracted_bytes += entry_size;
+
+            #[cfg(unix)]
+            {
+                use std::os::unix::fs::PermissionsExt;
+                output_file.set_permissions(std::fs::Permissions::from_mode(file_mode))?;
+            }
+            files.push(entry_name);
+        } else if entry_type.is_symlink() {
+            #[cfg(unix)]
+            if let Some(symlink_target) = entry.link_name()? {
+                let symlink_path = entry_name.clone();
+                std::os::unix::fs::symlink(symlink_target.as_ref(), symlink_path)?;
+                files.push(entry_name);
+            }
+        }
+    }
+
+    let number_of_files = files.len();
+    debug!(
+        "Extracted {} files totalling {}",
+        number_of_files.to_formatted_string(&Locale::en),
+        human_bytes(extracted_bytes as f64)
+    );
+
+    Ok(files)
+}
diff --git a/postgresql_archive/src/extractor/zip_extractor.rs b/postgresql_archive/src/extractor/zip_extractor.rs
new file mode 100644
index 0000000..a57285f
--- /dev/null
+++ b/postgresql_archive/src/extractor/zip_extractor.rs
@@ -0,0 +1,55 @@
+use crate::extractor::ExtractDirectories;
+use crate::Result;
+use human_bytes::human_bytes;
+use num_format::{Locale, ToFormattedString};
+use std::fs::create_dir_all;
+use std::io::Cursor;
+use std::path::PathBuf;
+use std::{fs, io};
+use tracing::{debug, instrument, warn};
+use zip::ZipArchive;
+
+/// Extracts the zip archive `bytes` to paths defined in `extract_directories`.
+///
+/// # Errors
+/// Returns an error if the extraction fails.
+#[allow(clippy::cast_precision_loss)]
+#[instrument(skip(bytes))]
+pub fn extract(bytes: &Vec<u8>, extract_directories: ExtractDirectories) -> Result<Vec<PathBuf>> {
+    let mut files = Vec::new();
+    let reader = Cursor::new(bytes);
+    let mut archive =
+        ZipArchive::new(reader).map_err(|_| io::Error::new(io::ErrorKind::Other, "Zip error"))?;
+    let mut extracted_bytes = 0;
+
+    for i in 0..archive.len() {
+        let mut file = archive
+            .by_index(i)
+            .map_err(|_| io::Error::new(io::ErrorKind::Other, "Zip error"))?;
+        let file_path = PathBuf::from(file.name());
+        let file_path = PathBuf::from(file_path.file_name().unwrap_or_default());
+        let file_name = file_path.to_string_lossy();
+
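+        // Entries are flattened to their file name; names with no mapped directory are skipped.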
"dep:flate2", "dep:serde_json", - "dep:tar", "postgresql_archive/github", ] tensor-chord = [ "dep:target-triple", - "dep:zip", "postgresql_archive/github", ] tokio = [ diff --git a/postgresql_extensions/src/repository/portal_corp/repository.rs b/postgresql_extensions/src/repository/portal_corp/repository.rs index 1d06a3d..b9a51b8 100644 --- a/postgresql_extensions/src/repository/portal_corp/repository.rs +++ b/postgresql_extensions/src/repository/portal_corp/repository.rs @@ -4,14 +4,13 @@ use crate::repository::portal_corp::URL; use crate::repository::Repository; use crate::Result; use async_trait::async_trait; +use postgresql_archive::extractor::{zip_extract, ExtractDirectories}; use postgresql_archive::get_archive; use postgresql_archive::repository::github::repository::GitHub; +use regex::Regex; use semver::{Version, VersionReq}; use std::fmt::Debug; -use std::io::Cursor; use std::path::PathBuf; -use std::{fs, io}; -use zip::ZipArchive; /// PortalCorp repository. #[derive(Debug)] @@ -78,37 +77,11 @@ impl Repository for PortalCorp { extension_dir: PathBuf, archive: &[u8], ) -> Result> { - let reader = Cursor::new(archive); - let mut archive = ZipArchive::new(reader) - .map_err(|_| io::Error::new(io::ErrorKind::Other, "Zip error"))?; - let mut files = Vec::new(); - - for i in 0..archive.len() { - let mut file = archive - .by_index(i) - .map_err(|_| io::Error::new(io::ErrorKind::Other, "Zip error"))?; - let file_path = PathBuf::from(file.name()); - let file_path = PathBuf::from(file_path.file_name().unwrap_or_default()); - let file_name = file_path.to_string_lossy(); - - if file_name.ends_with(".dll") - || file_name.ends_with(".dylib") - || file_name.ends_with(".so") - { - let mut out = Vec::new(); - io::copy(&mut file, &mut out)?; - let path = PathBuf::from(&library_dir).join(file_path); - fs::write(&path, out)?; - files.push(path); - } else if file_name.ends_with(".control") || file_name.ends_with(".sql") { - let mut out = Vec::new(); - io::copy(&mut file, &mut out)?; - let path = PathBuf::from(&extension_dir).join(file_path); - fs::write(&path, out)?; - files.push(path); - } - } - + let mut extract_directories = ExtractDirectories::default(); + extract_directories.add_mapping(Regex::new(r"(\.dll|\.dylib|\.so)")?, library_dir); + extract_directories.add_mapping(Regex::new(r"(\.control|\.sql)")?, extension_dir); + let bytes = &archive.to_vec(); + let files = zip_extract(bytes, extract_directories)?; Ok(files) } } diff --git a/postgresql_extensions/src/repository/steampipe/repository.rs b/postgresql_extensions/src/repository/steampipe/repository.rs index ca19bed..718fda9 100644 --- a/postgresql_extensions/src/repository/steampipe/repository.rs +++ b/postgresql_extensions/src/repository/steampipe/repository.rs @@ -5,15 +5,13 @@ use crate::repository::{steampipe, Repository}; use crate::Error::ExtensionNotFound; use crate::Result; use async_trait::async_trait; -use flate2::bufread::GzDecoder; +use postgresql_archive::extractor::{tar_gz_extract, ExtractDirectories}; use postgresql_archive::get_archive; use postgresql_archive::repository::github::repository::GitHub; +use regex::Regex; use semver::{Version, VersionReq}; use std::fmt::Debug; -use std::fs; -use std::io::Read; use std::path::PathBuf; -use tar::Archive; /// Steampipe repository. 
+        let mut extract_directories = ExtractDirectories::default();
+        extract_directories.add_mapping(Regex::new(r"\.(dll|dylib|so)$")?, library_dir);
+        extract_directories.add_mapping(Regex::new(r"\.(control|sql)$")?, extension_dir);
+        let bytes = &archive.to_vec();
+        let files = zip_extract(bytes, extract_directories)?;
         Ok(files)
     }
 }
diff --git a/postgresql_extensions/src/repository/steampipe/repository.rs b/postgresql_extensions/src/repository/steampipe/repository.rs
index ca19bed..718fda9 100644
--- a/postgresql_extensions/src/repository/steampipe/repository.rs
+++ b/postgresql_extensions/src/repository/steampipe/repository.rs
@@ -5,15 +5,13 @@ use crate::repository::{steampipe, Repository};
 use crate::Error::ExtensionNotFound;
 use crate::Result;
 use async_trait::async_trait;
-use flate2::bufread::GzDecoder;
+use postgresql_archive::extractor::{tar_gz_extract, ExtractDirectories};
 use postgresql_archive::get_archive;
 use postgresql_archive::repository::github::repository::GitHub;
+use regex::Regex;
 use semver::{Version, VersionReq};
 use std::fmt::Debug;
-use std::fs;
-use std::io::Read;
 use std::path::PathBuf;
-use tar::Archive;
 
 /// Steampipe repository.
 #[derive(Debug)]
@@ -92,30 +90,11 @@ impl Repository for Steampipe {
         extension_dir: PathBuf,
         archive: &[u8],
     ) -> Result<Vec<PathBuf>> {
-        let tar = GzDecoder::new(archive);
-        let mut archive = Archive::new(tar);
-        let mut files = Vec::new();
-
-        for file in archive.entries()? {
-            let mut file = file?;
-            let file_path = PathBuf::from(file.path()?.file_name().unwrap_or_default());
-            let file_name = file_path.to_string_lossy();
-
-            if file_name.ends_with(".dylib") || file_name.ends_with(".so") {
-                let mut bytes = Vec::new();
-                file.read_to_end(&mut bytes)?;
-                let path = PathBuf::from(&library_dir).join(file_path);
-                fs::write(&path, bytes)?;
-                files.push(path);
-            } else if file_name.ends_with(".control") || file_name.ends_with(".sql") {
-                let mut bytes = Vec::new();
-                file.read_to_end(&mut bytes)?;
-                let path = PathBuf::from(&extension_dir).join(file_path);
-                fs::write(&path, bytes)?;
-                files.push(path);
-            }
-        }
-
+        let mut extract_directories = ExtractDirectories::default();
+        extract_directories.add_mapping(Regex::new(r"\.(dll|dylib|so)$")?, library_dir);
+        extract_directories.add_mapping(Regex::new(r"\.(control|sql)$")?, extension_dir);
+        let bytes = &archive.to_vec();
+        let files = tar_gz_extract(bytes, extract_directories)?;
         Ok(files)
     }
 }
diff --git a/postgresql_extensions/src/repository/tensor_chord/repository.rs b/postgresql_extensions/src/repository/tensor_chord/repository.rs
index cb85593..e14040e 100644
--- a/postgresql_extensions/src/repository/tensor_chord/repository.rs
+++ b/postgresql_extensions/src/repository/tensor_chord/repository.rs
@@ -4,14 +4,13 @@ use crate::repository::tensor_chord::URL;
 use crate::repository::Repository;
 use crate::Result;
 use async_trait::async_trait;
+use postgresql_archive::extractor::{zip_extract, ExtractDirectories};
 use postgresql_archive::get_archive;
 use postgresql_archive::repository::github::repository::GitHub;
+use regex::Regex;
 use semver::{Version, VersionReq};
 use std::fmt::Debug;
-use std::io::Cursor;
 use std::path::PathBuf;
-use std::{fs, io};
-use zip::ZipArchive;
 
 /// TensorChord repository.
 #[derive(Debug)]
@@ -78,34 +77,11 @@ impl Repository for TensorChord {
         extension_dir: PathBuf,
         archive: &[u8],
     ) -> Result<Vec<PathBuf>> {
-        let reader = Cursor::new(archive);
-        let mut archive = ZipArchive::new(reader)
-            .map_err(|_| io::Error::new(io::ErrorKind::Other, "Zip error"))?;
-        let mut files = Vec::new();
-
-        for i in 0..archive.len() {
-            let mut file = archive
-                .by_index(i)
-                .map_err(|_| io::Error::new(io::ErrorKind::Other, "Zip error"))?;
-            let file_path = PathBuf::from(file.name());
-            let file_path = PathBuf::from(file_path.file_name().unwrap_or_default());
-            let file_name = file_path.to_string_lossy();
-
-            if file_name.ends_with(".dylib") || file_name.ends_with(".so") {
-                let mut out = Vec::new();
-                io::copy(&mut file, &mut out)?;
-                let path = PathBuf::from(&library_dir).join(file_path);
-                fs::write(&path, out)?;
-                files.push(path);
-            } else if file_name.ends_with(".control") || file_name.ends_with(".sql") {
-                let mut out = Vec::new();
-                io::copy(&mut file, &mut out)?;
-                let path = PathBuf::from(&extension_dir).join(file_path);
-                fs::write(&path, out)?;
-                files.push(path);
-            }
-        }
-
+        let mut extract_directories = ExtractDirectories::default();
+        extract_directories.add_mapping(Regex::new(r"\.(dll|dylib|so)$")?, library_dir);
+        extract_directories.add_mapping(Regex::new(r"\.(control|sql)$")?, extension_dir);
+        let bytes = &archive.to_vec();
+        let files = zip_extract(bytes, extract_directories)?;
         Ok(files)
     }
 }