diff --git a/Cargo.lock b/Cargo.lock index 6326b4e..2ddf99b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -18,6 +18,7 @@ dependencies = [ "core", "index", "storage", + "tempfile", ] [[package]] @@ -141,6 +142,22 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "errno" +version = "0.3.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad" +dependencies = [ + "libc", + "windows-sys 0.60.2", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "getrandom" version = "0.3.3" @@ -210,7 +227,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" dependencies = [ "cfg-if", - "windows-targets", + "windows-targets 0.53.2", ] [[package]] @@ -240,6 +257,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "linux-raw-sys" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" + [[package]] name = "lz4-sys" version = "1.11.1+lz4-1.10.0" @@ -272,6 +295,12 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + [[package]] name = "peeking_take_while" version = "0.1.2" @@ -369,6 +398,19 @@ version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" +[[package]] +name = "rustix" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" +dependencies = [ + "bitflags 2.9.1", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.59.0", +] + [[package]] name = "serde" version = "1.0.219" @@ -425,6 +467,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tempfile" +version = "3.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1" +dependencies = [ + "fastrand", + "getrandom", + "once_cell", + "rustix", + "windows-sys 0.59.0", +] + [[package]] name = "unicode-ident" version = "1.0.18" @@ -446,64 +501,146 @@ dependencies = [ "wit-bindgen-rt", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets 0.53.2", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + [[package]] name = "windows-targets" version = "0.53.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c66f69fcc9ce11da9966ddb31a40968cad001c5bedeb5c2b82ede4253ab48aef" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_gnullvm", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.53.0", + "windows_aarch64_msvc 0.53.0", + "windows_i686_gnu 0.53.0", + "windows_i686_gnullvm 0.53.0", + "windows_i686_msvc 0.53.0", + "windows_x86_64_gnu 0.53.0", + "windows_x86_64_gnullvm 0.53.0", + "windows_x86_64_msvc 0.53.0", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + [[package]] name = "windows_aarch64_gnullvm" version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + [[package]] name = "windows_aarch64_msvc" version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + [[package]] name = "windows_i686_gnu" version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + [[package]] name = "windows_i686_gnullvm" version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + [[package]] name = "windows_i686_msvc" version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + [[package]] name = "windows_x86_64_gnu" version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + [[package]] name = "windows_x86_64_gnullvm" version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + [[package]] name = "windows_x86_64_msvc" version = "0.53.0" diff --git a/crates/api/Cargo.toml b/crates/api/Cargo.toml index ef365d5..c21bcb7 100644 --- a/crates/api/Cargo.toml +++ b/crates/api/Cargo.toml @@ -7,4 +7,5 @@ license = "MIT" [dependencies] core = { path = "../core" } index = { path = "../index" } -storage = { path = "../storage" } \ No newline at end of file +storage = { path = "../storage" } +tempfile = "3.20.0" diff --git a/crates/api/src/lib.rs b/crates/api/src/lib.rs index 36ef19d..4f53fe9 100644 --- a/crates/api/src/lib.rs +++ b/crates/api/src/lib.rs @@ -1,54 +1,219 @@ -use core::DbError; - -// use core::{DenseVector, Payload, Point, PointId}; -// use std::{fmt::Error, path::PathBuf, sync::Arc}; - -// use index::{IndexType, VectorIndex}; -// use storage::{StorageEngine, StorageType}; - -// pub struct VectorDb { -// storage: Arc, -// index: Arc, -// } - -// impl VectorDb { -// pub fn new(storage: Arc, index: Arc) -> Self { -// Self { storage, index } -// } - -// pub fn insert(&self, vector: DenseVector, payload: Payload) -> Result { -// // Add to storage and index -// Ok(0) -// } - -// pub fn delete(&self, id: PointId) -> Result<(), Error> { -// // Remove from storage -// // Remove from index -// Ok(()) -// } - -// pub fn get(&self, id: PointId) -> Result, Error> { -// // Search for the Point with given id in storage -// Ok(None) -// } - -// pub fn search(&self, query: DenseVector, limit: usize) -> Result, Error> { -// // Use vector index to find similar vectors -// // Return vector ids with similarity scores -// Ok(vec![]) -// } -// } - -// pub struct DbConfig { -// pub storage_type: StorageType, -// pub index_type: IndexType, -// pub data_path: PathBuf, -// pub dimension: usize, -// } - -pub fn init_api_server() -> Result<(), DbError> { +use core::{DbError, IndexedVector, Similarity}; + +use core::{DenseVector, Payload, Point, PointId}; +use std::path::PathBuf; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, RwLock}; + +use index::flat::FlatIndex; +use index::{IndexType, VectorIndex}; +use storage::rocks_db::RocksDbStorage; +use storage::{StorageEngine, StorageType}; + +static NEXT_ID: AtomicU64 = AtomicU64::new(1); + +fn generate_point_id() -> u64 { + NEXT_ID.fetch_add(1, Ordering::Relaxed) +} + +pub struct VectorDb { + storage: Arc, + index: Arc>, // Using a RwLock instead of Mutex to improve concurrency +} + +impl VectorDb { + fn _new(storage: Arc, index: Arc>) -> Self { + Self { storage, index } + } + + //TODO: Make this an atomic operation + pub fn insert(&self, vector: DenseVector, payload: Payload) -> Result { + // Generate a new point id + let point_id = generate_point_id(); + self.storage + .insert_point(point_id, Some(vector.clone()), Some(payload))?; + + // Get write lock on the index + let mut index = self.index.write().map_err(|_| DbError::LockError)?; + index.insert(IndexedVector { + vector, + id: point_id, + })?; + + Ok(point_id) + } + + //TODO: Make this an atomic operation + pub fn delete(&self, id: PointId) -> Result<(), DbError> { + // Remove from storage + self.storage.delete_point(id)?; + // Remove from index + let mut index = self.index.write().map_err(|_| DbError::LockError)?; + index.delete(id)?; + Ok(()) + } + + pub fn get(&self, id: PointId) -> Result, DbError> { + // Search for the Point with given id in storage + let payload = self.storage.get_payload(id)?; + let vector = self.storage.get_vector(id)?; + if payload.is_some() || vector.is_some() { + Ok(Some(Point { + id, + payload, + vector, + })) + } else { + Ok(None) + } + } + + pub fn search( + &self, + query: DenseVector, + similarity: Similarity, + limit: usize, + ) -> Result, DbError> { + // Use vector index to find similar vectors + let index = self.index.read().map_err(|_| DbError::LockError)?; + + //TODO: Add feat of returning similarity scores in the search + let vectors = index.search(query, similarity, limit)?; + + Ok(vectors) + } +} + +pub struct DbConfig { + pub storage_type: StorageType, + pub index_type: IndexType, + pub data_path: PathBuf, + pub dimension: usize, +} + +pub fn init_api(config: DbConfig) -> Result { // Initialize the storage engine + let storage = match config.storage_type { + StorageType::RocksDb => Arc::new(RocksDbStorage::new(config.data_path)?), + _ => Arc::new(RocksDbStorage::new(config.data_path)?), + }; + // Initialize the vector index - // Start server - Ok(()) + let index: Arc> = match config.index_type { + IndexType::Flat => Arc::new(RwLock::new(FlatIndex::new())), + _ => Arc::new(RwLock::new(FlatIndex::new())), + }; + + // Init the db + let db = VectorDb::_new(storage, index); + + Ok(db) +} + +#[cfg(test)] +mod tests { + + // TODO: Add more exhaustive tests + + use super::*; + use tempfile::tempdir; + + // Helper function to create a test database + fn create_test_db() -> VectorDb { + let temp_dir = tempdir().unwrap(); + let config = DbConfig { + storage_type: StorageType::RocksDb, + index_type: IndexType::Flat, + data_path: temp_dir.path().to_path_buf(), + dimension: 3, + }; + init_api(config).unwrap() + } + + #[test] + fn test_insert_and_get() { + let db = create_test_db(); + let vector = vec![1.0, 2.0, 3.0]; + let payload = Payload {}; + + // Test insert + let id = db.insert(vector.clone(), payload).unwrap(); + assert!(id > 0); + + // Test get + let point = db.get(id).unwrap().unwrap(); + assert_eq!(point.id, id); + assert_eq!(point.vector.as_ref().unwrap(), &vector); + assert_eq!(point.payload.as_ref().unwrap(), &payload); + } + + #[test] + fn test_delete() { + let db = create_test_db(); + let vector = vec![1.0, 2.0, 3.0]; + let payload = Payload {}; + + // Insert a point + let id = db.insert(vector, payload).unwrap(); + + assert!(db.get(id).unwrap().is_some()); + db.delete(id).unwrap(); + assert!(db.get(id).unwrap().is_none()); + } + + #[test] + fn test_search() { + let db = create_test_db(); + + // Insert some points + let vectors = vec![ + vec![1.0, 0.0, 0.0], + vec![0.0, 1.0, 0.0], + vec![0.0, 0.0, 1.0], + ]; + + let mut ids = Vec::new(); + for vector in vectors { + let id = db.insert(vector, Payload {}).unwrap(); + ids.push(id); + } + + // Search for the closest vector to [1.0, 0.1, 0.1] + let query = vec![1.0, 0.1, 0.1]; + let results = db.search(query, Similarity::Cosine, 1).unwrap(); + + assert_eq!(results.len(), 1); + assert_eq!(results[0], ids[0]); // The first vector should be closest + } + + #[test] + fn test_search_limit() { + let db = create_test_db(); + + // Insert 5 points + let mut ids = Vec::new(); + for i in 0..5 { + let vector = vec![i as f32, 0.0, 0.0]; + let id = db.insert(vector, Payload {}).unwrap(); + ids.push(id); + } + + // Search with limit 3 + let query = vec![0.0, 0.0, 0.0]; + let results = db.search(query, Similarity::Euclidean, 3).unwrap(); + + assert_eq!(results.len(), 3); + } + + #[test] + fn test_empty_database() { + let db = create_test_db(); + + // Get non-existent point + assert!(db.get(999).unwrap().is_none()); + + let query = vec![1.0, 2.0, 3.0]; + let results = db.search(query, Similarity::Cosine, 10).unwrap(); + assert_eq!(results.len(), 0); + } } diff --git a/crates/core/src/error.rs b/crates/core/src/error.rs index a762944..1079c5e 100644 --- a/crates/core/src/error.rs +++ b/crates/core/src/error.rs @@ -5,4 +5,5 @@ pub enum DbError { SerializationError(String), DeserializationError, IndexError(String), + LockError, } diff --git a/crates/server/src/main.rs b/crates/server/src/main.rs index 9ab5dea..5e83e23 100644 --- a/crates/server/src/main.rs +++ b/crates/server/src/main.rs @@ -1,7 +1,7 @@ use std::fmt::Error; // Import from other crates // use core; // Import the entire crate -use api::init_api_server; +// use api::init_api_server; // use index::some_module; // Import specific module // use storage::{Type1, Type2}; // Import specific types // use api::prelude::*; // Import everything from prelude @@ -10,6 +10,6 @@ fn main() -> Result<(), Error> { // Start tracing // Load configs for DB // Start API and/or gRPC server - let _ = init_api_server(); + // let _ = init_api_server(); Ok(()) }