From 075155e77de802e28cd86ce7ea01cc4ba68eae48 Mon Sep 17 00:00:00 2001 From: cutecutecat Date: Mon, 25 Aug 2025 16:38:29 +0800 Subject: [PATCH] feat: record queries by sqlite Signed-off-by: cutecutecat --- Cargo.lock | 59 ++++++++++ Cargo.toml | 1 + scripts/train.py | 2 +- src/index/functions.rs | 24 ++++ src/index/gucs.rs | 46 ++++++++ src/index/scanners.rs | 2 + src/index/vchordg/am/mod.rs | 10 +- src/index/vchordg/scanners/default.rs | 21 +++- src/index/vchordrq/am/mod.rs | 15 ++- src/index/vchordrq/scanners/default.rs | 20 +++- src/index/vchordrq/scanners/maxsim.rs | 2 + src/lib.rs | 2 + src/recorder/hook.rs | 53 +++++++++ src/recorder/mod.rs | 26 +++++ src/recorder/text.rs | 40 +++++++ src/recorder/types.rs | 62 +++++++++++ src/recorder/worker.rs | 146 +++++++++++++++++++++++++ src/sql/finalize.sql | 96 ++++++++++++++++ tests/vchordrq/recall.slt | 132 +++++++++++++++++++++- 19 files changed, 744 insertions(+), 15 deletions(-) create mode 100644 src/recorder/hook.rs create mode 100644 src/recorder/mod.rs create mode 100644 src/recorder/text.rs create mode 100644 src/recorder/types.rs create mode 100644 src/recorder/worker.rs diff --git a/Cargo.lock b/Cargo.lock index ab9b7d02..d178abb2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -420,6 +420,18 @@ dependencies = [ "once_cell", ] +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fixedbitset" version = "0.5.7" @@ -499,6 +511,15 @@ dependencies = [ "foldhash", ] +[[package]] +name = "hashlink" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1" +dependencies = [ + "hashbrown", +] + [[package]] name = "heck" version = "0.5.0" @@ -733,6 +754,17 @@ dependencies = [ "libc", ] +[[package]] +name = "libsqlite3-sys" +version = "0.35.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "133c182a6a2c87864fe97778797e46c7e999672690dc9fa3ee8e241aa4a9c13f" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linux-raw-sys" version = "0.9.4" @@ -999,6 +1031,12 @@ dependencies = [ "syn", ] +[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + [[package]] name = "potential_utf" version = "0.1.2" @@ -1157,6 +1195,20 @@ version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "caf4aa5b0f434c91fe5c7f1ecb6a5ece2130b02ad2a590589dda5146df959001" +[[package]] +name = "rusqlite" +version = "0.37.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "165ca6e57b20e1351573e3729b958bc62f0e48025386970b6e4d29e7a7e71f3f" +dependencies = [ + "bitflags", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink", + "libsqlite3-sys", + "smallvec", +] + [[package]] name = "rustc-hash" version = "2.1.1" @@ -1593,6 +1645,7 @@ dependencies = [ "pgrx-catalog", "rabitq", "rand", + "rusqlite", "seq-macro", "serde", "simd", @@ -1639,6 +1692,12 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "vector" version = "0.0.0" diff --git a/Cargo.toml b/Cargo.toml index 9a087514..67ea8192 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ paste.workspace = true pgrx = { version = "=0.16.0", default-features = false, features = ["cshim"] } pgrx-catalog = "0.3.1" rand.workspace = true +rusqlite = { version = "0.37.0", features = ["bundled"] } seq-macro.workspace = true serde.workspace = true toml = "0.9.5" diff --git a/scripts/train.py b/scripts/train.py index e5f38925..54f86c28 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -29,7 +29,7 @@ import numpy as np DEFAULT_LISTS = 4096 -N_ITER = 25 +N_ITER = 10 CHUNKS = 10 SEED = 42 MAX_POINTS_PER_CLUSTER = 256 diff --git a/src/index/functions.rs b/src/index/functions.rs index 3711fac4..17cf6741 100644 --- a/src/index/functions.rs +++ b/src/index/functions.rs @@ -13,6 +13,8 @@ // Copyright (c) 2025 TensorChord Inc. use crate::index::storage::PostgresRelation; +use crate::recorder::dump; +use pgrx::iter::SetOfIterator; use pgrx::pg_sys::Oid; use pgrx_catalog::{PgAm, PgClass, PgClassRelkind}; @@ -84,3 +86,25 @@ impl Drop for Index { } } } + +#[pgrx::pg_extern(sql = "")] +fn _vchordrq_sampled_vectors(indexrelid: Oid) -> SetOfIterator<'static, String> { + let pg_am = PgAm::search_amname(c"vchordrq").unwrap(); + let Some(pg_am) = pg_am.get() else { + pgrx::error!("vchord is not installed"); + }; + let pg_class = PgClass::search_reloid(indexrelid).unwrap(); + let Some(pg_class) = pg_class.get() else { + pgrx::error!("the relation does not exist"); + }; + if pg_class.relkind() != PgClassRelkind::Index { + pgrx::error!("the relation {:?} is not an index", pg_class.relname()); + } + if pg_class.relam() != pg_am.oid() { + pgrx::error!("the index {:?} is not a vchordrq index", pg_class.relname()); + } + // The user must have access to the index, if not, raise an error from Postgres. + let _relation = Index::open(indexrelid, pgrx::pg_sys::AccessShareLock as _); + let queries = dump(indexrelid.to_u32()); + SetOfIterator::new(queries) +} diff --git a/src/index/gucs.rs b/src/index/gucs.rs index 2e039c14..ed501a52 100644 --- a/src/index/gucs.rs +++ b/src/index/gucs.rs @@ -27,6 +27,12 @@ pub enum PostgresIo { ReadStream, } +static VCHORDRQ_QUERY_SAMPLING_ENABLE: GucSetting = GucSetting::::new(false); + +static VCHORDRQ_QUERY_SAMPLING_MAX_RECORDS: GucSetting = GucSetting::::new(0); + +static VCHORDRQ_QUERY_SAMPLING_RATE: GucSetting = GucSetting::::new(0.0); + static VCHORDG_ENABLE_SCAN: GucSetting = GucSetting::::new(true); static VCHORDG_EF_SEARCH: GucSetting = GucSetting::::new(64); @@ -158,6 +164,34 @@ pub fn init() { GucContext::Userset, GucFlags::default(), ); + GucRegistry::define_bool_guc( + c"vchordrq.query_sampling_enable", + c"`query_sampling_enable` argument of vchordrq.", + c"`query_sampling_enable` argument of vchordrq.", + &VCHORDRQ_QUERY_SAMPLING_ENABLE, + GucContext::Userset, + GucFlags::default(), + ); + GucRegistry::define_int_guc( + c"vchordrq.query_sampling_max_records", + c"`query_sampling_max_records` argument of vchordrq.", + c"`query_sampling_max_records` argument of vchordrq.", + &VCHORDRQ_QUERY_SAMPLING_MAX_RECORDS, + 0, + 10000, + GucContext::Userset, + GucFlags::default(), + ); + GucRegistry::define_float_guc( + c"vchordrq.query_sampling_rate", + c"`query_sampling_rate` argument of vchordrq.", + c"`query_sampling_rate` argument of vchordrq.", + &VCHORDRQ_QUERY_SAMPLING_RATE, + 0.0, + 1.0, + GucContext::Userset, + GucFlags::default(), + ); unsafe { #[cfg(any(feature = "pg13", feature = "pg14"))] pgrx::pg_sys::EmitWarningsOnPlaceholders(c"vchordrq".as_ptr()); @@ -331,3 +365,15 @@ pub fn vchordrq_io_rerank() -> Io { PostgresIo::ReadStream => Io::Stream, } } + +pub fn vchordrq_query_sampling_enable() -> bool { + VCHORDRQ_QUERY_SAMPLING_ENABLE.get() +} + +pub fn vchordrq_query_sampling_max_records() -> u32 { + VCHORDRQ_QUERY_SAMPLING_MAX_RECORDS.get() as u32 +} + +pub fn vchordrq_query_sampling_rate() -> f64 { + VCHORDRQ_QUERY_SAMPLING_RATE.get() +} diff --git a/src/index/scanners.rs b/src/index/scanners.rs index b4203b97..97d8b1b4 100644 --- a/src/index/scanners.rs +++ b/src/index/scanners.rs @@ -13,6 +13,7 @@ // Copyright (c) 2025 TensorChord Inc. use crate::index::fetcher::Fetcher; +use crate::recorder::Recorder; use algo::{Bump, Page, RelationPrefetch, RelationRead, RelationReadStream}; use pgrx::pg_sys::Datum; @@ -44,6 +45,7 @@ pub trait SearchBuilder: 'static { options: Self::Options, fetcher: impl Fetcher + 'b, bump: &'b impl Bump, + recorder: impl Recorder, ) -> Box + 'b> where R: RelationRead + RelationPrefetch + RelationReadStream, diff --git a/src/index/vchordg/am/mod.rs b/src/index/vchordg/am/mod.rs index dc6a233e..d3f5dca5 100644 --- a/src/index/vchordg/am/mod.rs +++ b/src/index/vchordg/am/mod.rs @@ -20,6 +20,7 @@ use crate::index::scanners::SearchBuilder; use crate::index::storage::PostgresRelation; use crate::index::vchordg::opclass::opfamily; use crate::index::vchordg::scanners::*; +use crate::recorder::DefaultRecorder; use pgrx::datum::Internal; use pgrx::pg_sys::Datum; use std::cell::LazyCell; @@ -372,6 +373,13 @@ pub unsafe extern "C-unwind" fn amrescan( ) }) }; + // Query recorde is disable for vchordg indexes for now. + let recorder = DefaultRecorder { + enable: false, + rate: None, + max_records: 0, + index: (*(*scan).indexRelation).rd_id.to_u32(), + }; // PAY ATTENTATION: `scanning` references `bump`, so `scanning` must be dropped before `bump`. let bump = scanner.bump.as_ref(); scanner.scanning = match opfamily { @@ -397,7 +405,7 @@ pub unsafe extern "C-unwind" fn amrescan( LazyCell::new(Box::new(move || { // only do this since `PostgresRelation` has no destructor let index = bump.alloc(index.clone()); - builder.build(index, options, fetcher, bump) + builder.build(index, options, fetcher, bump, recorder) })) } }; diff --git a/src/index/vchordg/scanners/default.rs b/src/index/vchordg/scanners/default.rs index 30683d8e..54a2c560 100644 --- a/src/index/vchordg/scanners/default.rs +++ b/src/index/vchordg/scanners/default.rs @@ -18,6 +18,7 @@ use crate::index::scanners::{Io, SearchBuilder}; use crate::index::vchordg::algo::*; use crate::index::vchordg::opclass::Opfamily; use crate::index::vchordg::scanners::SearchOptions; +use crate::recorder::{Recorder, halfvec_out, vector_out}; use algo::accessor::{Dot, L2S}; use algo::*; use distance::Distance; @@ -26,6 +27,7 @@ use std::num::NonZero; use vchordg::operator::{self}; use vchordg::types::{DistanceKind, OwnedVector, VectorKind}; use vchordg::*; +use vector::VectorOwned; use vector::vect::{VectBorrowed, VectOwned}; pub struct DefaultBuilder { @@ -78,6 +80,7 @@ impl SearchBuilder for DefaultBuilder { options: SearchOptions, _fetcher: impl Fetcher + 'b, bump: &'b impl Bump, + recorder: impl Recorder, ) -> Box + 'b> where R: RelationRead + RelationPrefetch + RelationReadStream, @@ -120,7 +123,7 @@ impl SearchBuilder for DefaultBuilder { match (opfamily.vector_kind(), opfamily.distance_kind()) { (VectorKind::Vecf32, DistanceKind::L2S) => { type Op = operator::Op, L2S>; - let unprojected = if let OwnedVector::Vecf32(vector) = vector { + let unprojected = if let OwnedVector::Vecf32(vector) = vector.clone() { VectBorrowed::new(bump.alloc_slice(vector.slice())) } else { unreachable!() @@ -215,7 +218,7 @@ impl SearchBuilder for DefaultBuilder { } (VectorKind::Vecf16, DistanceKind::L2S) => { type Op = operator::Op, L2S>; - let unprojected = if let OwnedVector::Vecf16(vector) = vector { + let unprojected = if let OwnedVector::Vecf16(vector) = vector.clone() { VectBorrowed::new(bump.alloc_slice(vector.slice())) } else { unreachable!() @@ -310,7 +313,7 @@ impl SearchBuilder for DefaultBuilder { } (VectorKind::Vecf32, DistanceKind::Dot) => { type Op = operator::Op, Dot>; - let unprojected = if let OwnedVector::Vecf32(vector) = vector { + let unprojected = if let OwnedVector::Vecf32(vector) = vector.clone() { VectBorrowed::new(bump.alloc_slice(vector.slice())) } else { unreachable!() @@ -405,7 +408,7 @@ impl SearchBuilder for DefaultBuilder { } (VectorKind::Vecf16, DistanceKind::Dot) => { type Op = operator::Op, Dot>; - let unprojected = if let OwnedVector::Vecf16(vector) = vector { + let unprojected = if let OwnedVector::Vecf16(vector) = vector.clone() { VectBorrowed::new(bump.alloc_slice(vector.slice())) } else { unreachable!() @@ -509,6 +512,16 @@ impl SearchBuilder for DefaultBuilder { } else { iter }; + if recorder.is_enabled() { + match &vector { + OwnedVector::Vecf32(v) => { + recorder.send(&vector_out(v.as_borrowed())); + } + OwnedVector::Vecf16(v) => { + recorder.send(&halfvec_out(v.as_borrowed())); + } + } + } Box::new(iter.map(move |(distance, pointer)| { let (key, _) = pointer_to_kv(pointer); (opfamily.output(distance), key, recheck) diff --git a/src/index/vchordrq/am/mod.rs b/src/index/vchordrq/am/mod.rs index b8603c35..45e774b4 100644 --- a/src/index/vchordrq/am/mod.rs +++ b/src/index/vchordrq/am/mod.rs @@ -20,6 +20,7 @@ use crate::index::scanners::SearchBuilder; use crate::index::storage::PostgresRelation; use crate::index::vchordrq::opclass::{Opfamily, opfamily}; use crate::index::vchordrq::scanners::*; +use crate::recorder::DefaultRecorder; use pgrx::datum::Internal; use pgrx::pg_sys::Datum; use std::cell::LazyCell; @@ -455,6 +456,16 @@ pub unsafe extern "C-unwind" fn amrescan( ) }) }; + let rate = match gucs::vchordrq_query_sampling_rate() { + 0.0 => None, + rate => Some(rate), + }; + let recorder = DefaultRecorder { + enable: gucs::vchordrq_query_sampling_enable(), + rate, + max_records: gucs::vchordrq_query_sampling_max_records(), + index: (*(*scan).indexRelation).rd_id.to_u32(), + }; // PAY ATTENTATION: `scanning` references `bump`, so `scanning` must be dropped before `bump`. let bump = scanner.bump.as_ref(); scanner.scanning = match opfamily { @@ -480,7 +491,7 @@ pub unsafe extern "C-unwind" fn amrescan( LazyCell::new(Box::new(move || { // only do this since `PostgresRelation` has no destructor let index = bump.alloc(index.clone()); - builder.build(index, options, fetcher, bump) + builder.build(index, options, fetcher, bump, recorder) })) } Opfamily::VectorMaxsim | Opfamily::HalfvecMaxsim => { @@ -500,7 +511,7 @@ pub unsafe extern "C-unwind" fn amrescan( LazyCell::new(Box::new(move || { // only do this since `PostgresRelation` has no destructor let index = bump.alloc(index.clone()); - builder.build(index, options, fetcher, bump) + builder.build(index, options, fetcher, bump, recorder) })) } }; diff --git a/src/index/vchordrq/scanners/default.rs b/src/index/vchordrq/scanners/default.rs index 8ac6af34..c61d0bab 100644 --- a/src/index/vchordrq/scanners/default.rs +++ b/src/index/vchordrq/scanners/default.rs @@ -19,6 +19,7 @@ use crate::index::vchordrq::algo::*; use crate::index::vchordrq::filter::filter; use crate::index::vchordrq::opclass::Opfamily; use crate::index::vchordrq::scanners::SearchOptions; +use crate::recorder::{Recorder, halfvec_out, vector_out}; use algo::accessor::{Dot, L2S}; use algo::prefetcher::*; use algo::*; @@ -82,6 +83,7 @@ impl SearchBuilder for DefaultBuilder { options: SearchOptions, mut fetcher: impl Fetcher + 'b, bump: &'b impl Bump, + recorder: impl Recorder, ) -> Box + 'b> where R: RelationRead + RelationPrefetch + RelationReadStream, @@ -120,7 +122,7 @@ impl SearchBuilder for DefaultBuilder { match (opfamily.vector_kind(), opfamily.distance_kind()) { (VectorKind::Vecf32, DistanceKind::L2S) => { type Op = operator::Op, L2S>; - let unprojected = if let OwnedVector::Vecf32(vector) = vector { + let unprojected = if let OwnedVector::Vecf32(vector) = vector.clone() { vector } else { unreachable!() @@ -260,7 +262,7 @@ impl SearchBuilder for DefaultBuilder { } (VectorKind::Vecf32, DistanceKind::Dot) => { type Op = operator::Op, Dot>; - let unprojected = if let OwnedVector::Vecf32(vector) = vector { + let unprojected = if let OwnedVector::Vecf32(vector) = vector.clone() { vector } else { unreachable!() @@ -400,7 +402,7 @@ impl SearchBuilder for DefaultBuilder { } (VectorKind::Vecf16, DistanceKind::L2S) => { type Op = operator::Op, L2S>; - let unprojected = if let OwnedVector::Vecf16(vector) = vector { + let unprojected = if let OwnedVector::Vecf16(vector) = vector.clone() { vector } else { unreachable!() @@ -540,7 +542,7 @@ impl SearchBuilder for DefaultBuilder { } (VectorKind::Vecf16, DistanceKind::Dot) => { type Op = operator::Op, Dot>; - let unprojected = if let OwnedVector::Vecf16(vector) = vector { + let unprojected = if let OwnedVector::Vecf16(vector) = vector.clone() { vector } else { unreachable!() @@ -689,6 +691,16 @@ impl SearchBuilder for DefaultBuilder { } else { iter }; + if recorder.is_enabled() { + match &vector { + OwnedVector::Vecf32(v) => { + recorder.send(&vector_out(v.as_borrowed())); + } + OwnedVector::Vecf16(v) => { + recorder.send(&halfvec_out(v.as_borrowed())); + } + } + } Box::new(iter.map(move |(distance, pointer)| { let (key, _) = pointer_to_kv(pointer); (distance, key, recheck) diff --git a/src/index/vchordrq/scanners/maxsim.rs b/src/index/vchordrq/scanners/maxsim.rs index eecf8158..65603995 100644 --- a/src/index/vchordrq/scanners/maxsim.rs +++ b/src/index/vchordrq/scanners/maxsim.rs @@ -18,6 +18,7 @@ use crate::index::vchordrq::algo::*; use crate::index::vchordrq::filter::filter; use crate::index::vchordrq::opclass::Opfamily; use crate::index::vchordrq::scanners::SearchOptions; +use crate::recorder::Recorder; use algo::accessor::Dot; use algo::prefetcher::*; use algo::*; @@ -72,6 +73,7 @@ impl SearchBuilder for MaxsimBuilder { options: SearchOptions, mut fetcher: impl Fetcher + 'b, bump: &'b impl Bump, + _sender: impl Recorder, ) -> Box + 'b> where R: RelationRead + RelationPrefetch + RelationReadStream, diff --git a/src/lib.rs b/src/lib.rs index b458d51c..25ff506d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,6 +18,7 @@ mod datatype; mod index; +mod recorder; mod upgrade; pgrx::pg_module_magic!( @@ -53,6 +54,7 @@ unsafe extern "C-unwind" fn _pg_init() { } IS_MAIN.set(true); index::init(); + recorder::init(); unsafe { #[cfg(any(feature = "pg13", feature = "pg14"))] pgrx::pg_sys::EmitWarningsOnPlaceholders(c"vchord".as_ptr()); diff --git a/src/recorder/hook.rs b/src/recorder/hook.rs new file mode 100644 index 00000000..8dadee99 --- /dev/null +++ b/src/recorder/hook.rs @@ -0,0 +1,53 @@ +// This software is licensed under a dual license model: +// +// GNU Affero General Public License v3 (AGPLv3): You may use, modify, and +// distribute this software under the terms of the AGPLv3. +// +// Elastic License v2 (ELv2): You may also use, modify, and distribute this +// software under the Elastic License v2, which has specific restrictions. +// +// We welcome any commercial collaboration or support. For inquiries +// regarding the licenses, please contact us at: +// vectorchord-inquiry@tensorchord.ai +// +// Copyright (c) 2025 TensorChord Inc. + +use crate::recorder::worker::{delete_database, delete_index}; + +static mut PREV_OBJECT_ACCESS: pgrx::pg_sys::object_access_hook_type = None; + +#[pgrx::pg_guard] +unsafe extern "C-unwind" fn recorder_object_access( + access: pgrx::pg_sys::ObjectAccessType::Type, + class_id: pgrx::pg_sys::Oid, + object_id: pgrx::pg_sys::Oid, + sub_id: ::std::os::raw::c_int, + arg: *mut ::std::os::raw::c_void, +) { + unsafe { + use pgrx::pg_sys::submodules::ffi::pg_guard_ffi_boundary; + if let Some(prev_object_access_hook) = PREV_OBJECT_ACCESS { + #[allow(ffi_unwind_calls, reason = "protected by pg_guard_ffi_boundary")] + pg_guard_ffi_boundary(|| { + prev_object_access_hook(access, class_id, object_id, sub_id, arg) + }); + } + if access == pgrx::pg_sys::ObjectAccessType::OAT_DROP + && class_id == pgrx::pg_sys::DatabaseRelationId + { + delete_database(object_id.to_u32()); + } else if access == pgrx::pg_sys::ObjectAccessType::OAT_DROP + && class_id == pgrx::pg_sys::RelationRelationId + { + delete_index(object_id.to_u32()); + } + } +} + +pub fn init() { + assert!(crate::is_main()); + unsafe { + PREV_OBJECT_ACCESS = pgrx::pg_sys::object_access_hook; + pgrx::pg_sys::object_access_hook = Some(recorder_object_access); + } +} diff --git a/src/recorder/mod.rs b/src/recorder/mod.rs new file mode 100644 index 00000000..77a09d4b --- /dev/null +++ b/src/recorder/mod.rs @@ -0,0 +1,26 @@ +// This software is licensed under a dual license model: +// +// GNU Affero General Public License v3 (AGPLv3): You may use, modify, and +// distribute this software under the terms of the AGPLv3. +// +// Elastic License v2 (ELv2): You may also use, modify, and distribute this +// software under the Elastic License v2, which has specific restrictions. +// +// We welcome any commercial collaboration or support. For inquiries +// regarding the licenses, please contact us at: +// vectorchord-inquiry@tensorchord.ai +// +// Copyright (c) 2025 TensorChord Inc. + +pub use text::{halfvec_out, vector_out}; +pub use types::{DefaultRecorder, Recorder}; +pub use worker::dump; + +mod hook; +mod text; +mod types; +mod worker; + +pub fn init() { + hook::init(); +} diff --git a/src/recorder/text.rs b/src/recorder/text.rs new file mode 100644 index 00000000..321599af --- /dev/null +++ b/src/recorder/text.rs @@ -0,0 +1,40 @@ +// This software is licensed under a dual license model: +// +// GNU Affero General Public License v3 (AGPLv3): You may use, modify, and +// distribute this software under the terms of the AGPLv3. +// +// Elastic License v2 (ELv2): You may also use, modify, and distribute this +// software under the Elastic License v2, which has specific restrictions. +// +// We welcome any commercial collaboration or support. For inquiries +// regarding the licenses, please contact us at: +// vectorchord-inquiry@tensorchord.ai +// +// Copyright (c) 2025 TensorChord Inc. + +use simd::f16; +use vector::vect::VectBorrowed; + +pub fn vector_out(vector: VectBorrowed<'_, f32>) -> String { + let mut result = String::from("["); + for x in vector.slice() { + if !result.ends_with('[') { + result.push(','); + } + result.push_str(&x.to_string()); + } + result.push(']'); + result +} + +pub fn halfvec_out(vector: VectBorrowed<'_, f16>) -> String { + let mut result = String::from("["); + for x in vector.slice() { + if !result.ends_with('[') { + result.push(','); + } + result.push_str(&x.to_string()); + } + result.push(']'); + result +} diff --git a/src/recorder/types.rs b/src/recorder/types.rs new file mode 100644 index 00000000..39ec9a97 --- /dev/null +++ b/src/recorder/types.rs @@ -0,0 +1,62 @@ +// This software is licensed under a dual license model: +// +// GNU Affero General Public License v3 (AGPLv3): You may use, modify, and +// distribute this software under the terms of the AGPLv3. +// +// Elastic License v2 (ELv2): You may also use, modify, and distribute this +// software under the Elastic License v2, which has specific restrictions. +// +// We welcome any commercial collaboration or support. For inquiries +// regarding the licenses, please contact us at: +// vectorchord-inquiry@tensorchord.ai +// +// Copyright (c) 2025 TensorChord Inc. + +use crate::recorder::worker::push; +use rand::Rng; +use std::cell::RefMut; + +pub trait Recorder { + fn is_enabled(&self) -> bool; + fn send(&self, sample: &str); +} + +#[derive(Debug)] +pub struct DefaultRecorder { + pub enable: bool, + pub rate: Option, + pub max_records: u32, + pub index: u32, +} + +pub struct PgRefCell(std::cell::RefCell); + +unsafe impl Send for PgRefCell {} +unsafe impl Sync for PgRefCell {} + +impl PgRefCell { + pub const fn new(x: T) -> Self { + Self(std::cell::RefCell::new(x)) + } + pub fn borrow_mut(&self) -> RefMut<'_, T> { + assert!( + crate::is_main(), + "cannot borrow the value outside main thread" + ); + self.0.borrow_mut() + } +} + +impl Recorder for DefaultRecorder { + fn is_enabled(&self) -> bool { + self.enable + } + fn send(&self, sample: &str) { + if let Some(rate) = self.rate { + let mut rng = rand::rng(); + if rng.random_bool(rate) { + push(self.index, sample, self.max_records); + } + } + } +} diff --git a/src/recorder/worker.rs b/src/recorder/worker.rs new file mode 100644 index 00000000..a6da90f0 --- /dev/null +++ b/src/recorder/worker.rs @@ -0,0 +1,146 @@ +// This software is licensed under a dual license model: +// +// GNU Affero General Public License v3 (AGPLv3): You may use, modify, and +// distribute this software under the terms of the AGPLv3. +// +// Elastic License v2 (ELv2): You may also use, modify, and distribute this +// software under the Elastic License v2, which has specific restrictions. +// +// We welcome any commercial collaboration or support. For inquiries +// regarding the licenses, please contact us at: +// vectorchord-inquiry@tensorchord.ai +// +// Copyright (c) 2025 TensorChord Inc. +use crate::recorder::types::PgRefCell; +use std::cell::RefMut; +use std::fs; +use std::path::Path; + +// Safety: The directory name must start with "pgsql_tmp" to be excluded by pg_basebackup +const RECORDER_DIR: &str = "pgsql_tmp_vchord_sampling"; +const RECORDER_VERSION: u32 = 1; + +static CONNECTION: PgRefCell> = + PgRefCell::>::new(None); + +fn get<'a>() -> Option> { + if unsafe { !pgrx::pg_sys::IsBackendPid(pgrx::pg_sys::MyProcPid) } { + return None; + } + let database_oid = unsafe { pgrx::pg_sys::MyDatabaseId.to_u32() }; + if database_oid == 0 { + return None; + } + let mut connection = CONNECTION.borrow_mut(); + if connection.is_none() + && let Err(err) = || -> rusqlite::Result<()> { + if !Path::new(RECORDER_DIR).exists() { + let _ = fs::create_dir_all(RECORDER_DIR); + } + let p = format!("{RECORDER_DIR}/database_{database_oid}.sqlite"); + let mut conn = rusqlite::Connection::open(&p)?; + conn.pragma_update(Some("main"), "journal_mode", "WAL")?; + conn.pragma_update(Some("main"), "synchronous", "NORMAL")?; + let tx = conn.transaction()?; + let version: u32 = tx + .pragma_query_value(Some("main"), "user_version", |row| row.get(0)) + .unwrap_or(RECORDER_VERSION); + if version != RECORDER_VERSION && version != 0 { + let mut statement = tx.prepare( + "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'index_%';", + )?; + let tables = statement.query_map((), |row| row.get::(0))?; + for name in tables.into_iter().flatten() { + let drop_statement = format!("DROP TABLE IF EXISTS {name}"); + tx.execute(&drop_statement, ())?; + } + } + tx.pragma_update(Some("main"), "user_version", RECORDER_VERSION)?; + tx.commit()?; + let _ = connection.insert(conn); + Ok(()) + }() + { + if err.sqlite_error_code() == Some(rusqlite::ErrorCode::DatabaseCorrupt) { + delete_database(database_oid); + } + pgrx::debug1!("Recorder: Error initializing database: {}", err); + return None; + } + RefMut::filter_map(connection, |c| c.as_mut()).ok() +} + +pub fn push(index: u32, sample: &str, max_records: u32) { + let mut connection = match get() { + Some(c) => c, + None => return, + }; + let init_statement = format!( + " + CREATE TABLE IF NOT EXISTS index_{index} (sample TEXT, create_at REAL); + CREATE INDEX IF NOT EXISTS i ON index_{index} (create_at); + " + ); + let insert_statement = + format!("INSERT INTO index_{index} (sample, create_at) VALUES (?1, unixepoch('subsec'))"); + let count_statement = format!("SELECT COUNT(create_at) FROM index_{index}"); + let maintain_statement = format!( + "DELETE FROM index_{index} WHERE rowid = ( + SELECT rowid FROM index_{index} ORDER BY create_at ASC LIMIT ?1);" + ); + if let Err(err) = || -> rusqlite::Result<()> { + let tx = connection.transaction()?; + tx.execute_batch(&init_statement)?; + tx.prepare_cached(&insert_statement)?.execute((sample,))?; + let records = tx.query_one(&count_statement, (), |row| row.get::(0))?; + if records > max_records { + tx.execute(&maintain_statement, (records - max_records,))?; + } + tx.commit()?; + Ok(()) + }() { + pgrx::debug1!("Recorder: Error pushing sample: {}", err); + } +} + +pub fn delete_index(index: u32) { + let connection = match get() { + Some(c) => c, + None => return, + }; + let drop_statement = format!("DROP TABLE IF EXISTS index_{index}"); + if let Err(e) = connection.execute(&drop_statement, ()) { + pgrx::debug1!("Recorder: Error deleting index table: {}", e); + }; +} + +pub fn delete_database(database_oid: u32) { + let _ = fs::remove_file(format!("{RECORDER_DIR}/database_{database_oid}.sqlite")); + let _ = fs::remove_file(format!("{RECORDER_DIR}/database_{database_oid}.sqlite-shm")); + let _ = fs::remove_file(format!("{RECORDER_DIR}/database_{database_oid}.sqlite-wal")); +} + +pub fn dump(index: u32) -> Vec { + let connection = match get() { + Some(c) => c, + None => return Vec::new(), + }; + let load_statement = format!("SELECT sample FROM index_{index} ORDER BY create_at DESC"); + match || -> rusqlite::Result> { + let mut stmt = connection.prepare(&load_statement)?; + let mut rows = stmt.query(())?; + let mut result = Vec::new(); + while let Some(row) = rows.next()? { + if let Ok(sample) = row.get::(0) { + result.push(sample); + } + } + Ok(result) + }() { + Ok(v) => v, + Err(e) => { + pgrx::debug1!("Recorder: Error loading samples: {}", e); + Vec::new() + } + } +} diff --git a/src/sql/finalize.sql b/src/sql/finalize.sql index d6c186ff..a520ba39 100644 --- a/src/sql/finalize.sql +++ b/src/sql/finalize.sql @@ -142,6 +142,102 @@ IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vchord_vector_ CREATE FUNCTION quantize_to_scalar8(halfvec) RETURNS scalar8 IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vchord_halfvec_quantize_to_scalar8_wrapper'; +CREATE FUNCTION vchordrq_sampled_vectors(regclass) +RETURNS SETOF TEXT +STRICT LANGUAGE c AS 'MODULE_PATHNAME', '_vchordrq_sampled_vectors_wrapper'; + +CREATE OR REPLACE FUNCTION vchordrq_sampled_queries(regclass) +RETURNS TABLE( + schema_name NAME, + table_name NAME, + column_name NAME, + operator TEXT, + vector_text TEXT +) +LANGUAGE plpgsql +STRICT AS $$ +DECLARE + ext_schema TEXT; + query_text TEXT; +BEGIN + SELECT n.nspname + INTO ext_schema + FROM pg_catalog.pg_extension e + JOIN pg_catalog.pg_namespace n ON n.oid = e.extnamespace + WHERE e.extname = 'vchord'; + + IF ext_schema IS NULL THEN + RAISE EXCEPTION 'vchord is not installed'; + END IF; + + query_text := format( + $q$ + WITH index_metadata AS ( + SELECT + NS.nspname AS schema_name, + C.relname AS table_name, + PA.attname AS column_name, + CASE + WHEN OP.opcname LIKE '%%l2%%' THEN '<->' + WHEN OP.opcname LIKE '%%ip%%' THEN '<#>' + WHEN OP.opcname LIKE '%%cosine%%' THEN '<=>' + ELSE '' + END AS operator + FROM + pg_catalog.pg_index X + JOIN + pg_catalog.pg_class C ON C.oid = X.indrelid + JOIN + pg_catalog.pg_namespace NS ON C.relnamespace = NS.oid + JOIN + pg_catalog.pg_class I ON I.oid = X.indexrelid + JOIN + pg_catalog.pg_am A ON A.oid = I.relam + LEFT JOIN + pg_catalog.pg_opclass AS OP ON OP.oid = X.indclass[0] + LEFT JOIN + pg_catalog.pg_attribute PA ON PA.attrelid = X.indrelid AND PA.attnum = X.indkey[0] + WHERE + A.amname = 'vchordrq' + AND C.relkind = 'r' + AND X.indnatts = 1 + AND X.indexrelid = %1$s + ) + SELECT + im.schema_name, + im.table_name, + im.column_name, + im.operator, + s.vector_text + FROM + index_metadata im, + LATERAL %2$I.vchordrq_sampled_vectors(%1$s) AS s(vector_text); + $q$, + $1::oid, + ext_schema + ); + RETURN QUERY EXECUTE query_text; +END; +$$; + +CREATE VIEW vchordrq_sampled_queries AS +SELECT + record.schema_name, + record.table_name, + record.column_name, + record.operator, + record.vector_text +FROM + ( + SELECT i.oid + FROM pg_catalog.pg_class AS i + JOIN pg_catalog.pg_index AS ix ON i.oid = ix.indexrelid + JOIN pg_catalog.pg_opclass AS opc ON ix.indclass[0] = opc.oid + JOIN pg_catalog.pg_am AS am ON opc.opcmethod = am.oid + WHERE am.amname = 'vchordrq' + ) AS index_oids +CROSS JOIN LATERAL vchordrq_sampled_queries(index_oids.oid::regclass) AS record; + CREATE FUNCTION vchordrq_amhandler(internal) RETURNS index_am_handler IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vchordrq_amhandler_wrapper'; diff --git a/tests/vchordrq/recall.slt b/tests/vchordrq/recall.slt index 04584fd0..ca3fb095 100644 --- a/tests/vchordrq/recall.slt +++ b/tests/vchordrq/recall.slt @@ -1,12 +1,12 @@ statement ok -CREATE TABLE t (val vector(3)); +CREATE TABLE t (id SERIAL PRIMARY KEY, val vector(3)); statement ok INSERT INTO t (val) SELECT ARRAY[i * 0.0001, i * 0.00005, i * 0.0002]::vector(3) FROM generate_series(1, 10000) as s(i); statement ok -CREATE INDEX ON t USING vchordrq (val vector_l2_ops); +CREATE INDEX idx1 ON t USING vchordrq (val vector_l2_ops); statement ok SET vchordrq.epsilon = 0.8; @@ -53,4 +53,130 @@ SHOW vchordrq.epsilon; 0.8 statement ok -DROP TABLE t; \ No newline at end of file +CREATE TABLE t_dim4 (val vector(4), id SERIAL PRIMARY KEY); + +statement ok +INSERT INTO t_dim4 (val) +SELECT ARRAY[i * 0.0001, i * 0.00005, i * 0.0002, i * 0.001]::vector(4) FROM generate_series(1, 10000) as s(i); + +statement ok +CREATE INDEX idx2 ON t_dim4 USING vchordrq (val vector_l2_ops); + +statement ok +ALTER SYSTEM SET vchordrq.query_sampling_max_records = 1; + +statement ok +ALTER SYSTEM SET vchordrq.query_sampling_rate = 1; + +statement ok +ALTER SYSTEM SET vchordrq.query_sampling_enable = on; + +statement ok +SELECT pg_reload_conf(); + +query I retry 5 backoff 1s +SHOW vchordrq.query_sampling_enable; +---- +on + +statement ok +SELECT * from t ORDER BY val <-> '[0.50, 0.25, 1.00]'; + +statement ok +SELECT * from t_dim4 ORDER BY val <-> '[1.00, 0.50, 0.25, 0]'; + +query I +SELECT vector_text from vchordrq_sampled_queries('idx1'); +---- +[0.5,0.25,1] + +query I +SELECT vector_text from vchordrq_sampled_queries('idx2'); +---- +[1,0.5,0.25,0] + +query I +SELECT COUNT(*) from vchordrq_sampled_queries; +---- +2 + +statement ok +SELECT * from t_dim4 ORDER BY val <-> '[2.1, 0.3, 0.7, 0.9]'; + +query I +SELECT * from vchordrq_sampled_queries('idx2'); +---- +public t_dim4 val <-> [2.1,0.3,0.7,0.9] + +query I +SELECT AVG(recall_value) +FROM ( + SELECT + vchordrq_evaluate_query_recall( + query => format( + 'SELECT ctid FROM %I.%I ORDER BY %I %s ''%s'' LIMIT 10', + lq.schema_name, + lq.table_name, + lq.column_name, + lq.operator, + lq.vector_text + ) + ) AS recall_value + FROM + vchordrq_sampled_queries('idx2') AS lq +) AS eval_results; +---- +1 + +statement ok +CREATE TABLE t_expr (id integer); + +statement ok +INSERT INTO t_expr (id) SELECT id FROM generate_series(1, 10000) s(id); + +statement ok +CREATE INDEX idx3 ON t_expr USING vchordrq ((ARRAY[id::real, id::real, id::real]::vector(3)) vector_l2_ops); + +statement ok +SELECT id FROM t_expr ORDER BY ARRAY[id::real, id::real, id::real]::vector(3) <-> '[1.0000, 0.5000, 0.2500]' limit 1; + +query I +SELECT column_name from vchordrq_sampled_queries('idx3'); +---- +NULL + +query I +SELECT vector_text from vchordrq_sampled_queries('idx3'); +---- +[1,0.5,0.25] + +statement ok +SET search_path='@'; + +query I +SELECT vector_text from public.vchordrq_sampled_queries('public.idx3'); +---- +[1,0.5,0.25] + +query I +SELECT COUNT(*) from public.vchordrq_sampled_queries; +---- +3 + +statement ok +RESET search_path; + +statement ok +ALTER SYSTEM RESET vchordrq.query_sampling_enable; + +statement ok +ALTER SYSTEM RESET vchordrq.query_sampling_max_records; + +statement ok +ALTER SYSTEM RESET vchordrq.query_sampling_rate; + +statement ok +SELECT pg_reload_conf(); + +statement ok +DROP TABLE t, t_dim4, t_expr; \ No newline at end of file