diff --git a/Cargo.toml b/Cargo.toml index 446fedd35..5cf865cd8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -exclude = ["crates/micro-hnsw-wasm", "examples/ruvLLM/esp32", "examples/ruvLLM/esp32-flash", "examples/edge-net", "examples/data"] +exclude = ["crates/micro-hnsw-wasm", "crates/ruvector-hyperbolic-hnsw", "crates/ruvector-hyperbolic-hnsw-wasm", "examples/ruvLLM/esp32", "examples/ruvLLM/esp32-flash", "examples/edge-net", "examples/data"] members = [ "crates/ruvector-core", "crates/ruvector-node", diff --git a/crates/ruvector-hyperbolic-hnsw-wasm/Cargo.lock b/crates/ruvector-hyperbolic-hnsw-wasm/Cargo.lock new file mode 100644 index 000000000..5ff725b50 --- /dev/null +++ b/crates/ruvector-hyperbolic-hnsw-wasm/Cargo.lock @@ -0,0 +1,947 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "bumpalo" +version = "3.19.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5dd9dc738b7a8311c7ade152424974d8115f2cdad61e8dab8dac9f2362298510" + +[[package]] +name = "bytemuck" +version = "1.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fbdf580320f38b612e485521afda1ee26d10cc9884efaaa750d383e13e3c5f4" + +[[package]] +name = "cast" 
+version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + +[[package]] +name = "cc" +version = "1.2.52" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd4932aefd12402b36c60956a4fe0035421f544799057659ff86f923657aada3" +dependencies = [ + "find-msvc-tools", + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "console_error_panic_hook" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06aeb73f470f66dcdbf7223caeebb85984942f22f1adb2a088cf9668146bbbc" +dependencies = [ + "cfg-if", + "wasm-bindgen", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + +[[package]] +name = "find-msvc-tools" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f449e6c6c08c865631d4890cfacf252b3d396c9bcc83adb6623cdb02a8336c41" + +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-core", + "futures-task", + "pin-project-lite", + "pin-utils", + "slab", +] + +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "js-sys", + "libc", + "wasi", + "wasm-bindgen", +] + +[[package]] +name = "glam" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "333928d5eb103c5d4050533cec0384302db6be8ef7d3cebd30ec6a35350353da" + +[[package]] +name = "glam" +version = "0.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3abb554f8ee44336b72d522e0a7fe86a29e09f839a36022fa869a7dfe941a54b" + +[[package]] +name = "glam" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4126c0479ccf7e8664c36a2d719f5f2c140fbb4f9090008098d2c291fa5b3f16" + +[[package]] +name = "glam" +version = "0.17.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"e01732b97afd8508eee3333a541b9f7610f454bb818669e66e90f5f57c93a776" + +[[package]] +name = "glam" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "525a3e490ba77b8e326fb67d4b44b4bd2f920f44d4cc73ccec50adc68e3bee34" + +[[package]] +name = "glam" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b8509e6791516e81c1a630d0bd7fbac36d2fa8712a9da8662e716b52d5051ca" + +[[package]] +name = "glam" +version = "0.20.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f43e957e744be03f5801a55472f593d43fabdebf25a4585db250f04d86b1675f" + +[[package]] +name = "glam" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "518faa5064866338b013ff9b2350dc318e14cc4fcd6cb8206d7e7c9886c98815" + +[[package]] +name = "glam" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12f597d56c1bd55a811a1be189459e8fad2bbc272616375602443bdfb37fa774" + +[[package]] +name = "glam" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e4afd9ad95555081e109fe1d21f2a30c691b5f0919c67dfa690a2e1eb6bd51c" + +[[package]] +name = "glam" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5418c17512bdf42730f9032c74e1ae39afc408745ebb2acf72fbc4691c17945" + +[[package]] +name = "glam" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "151665d9be52f9bb40fc7966565d39666f2d1e69233571b71b87791c7e0528b3" + +[[package]] +name = "glam" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e05e7e6723e3455f4818c7b26e855439f7546cf617ef669d1adedb8669e5cb9" + +[[package]] +name = "glam" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"779ae4bf7e8421cf91c0b3b64e7e8b40b862fba4d393f59150042de7c4965a94" + +[[package]] +name = "glam" +version = "0.29.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8babf46d4c1c9d92deac9f7be466f76dfc4482b6452fc5024b5e8daf6ffeb3ee" + +[[package]] +name = "glam" +version = "0.30.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19fc433e8437a212d1b6f1e68c7824af3aed907da60afa994e7f542d18d12aa9" + +[[package]] +name = "itoa" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" + +[[package]] +name = "js-sys" +version = "0.3.85" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c942ebf8e95485ca0d52d97da7c5a2c387d0e7f0ba4c35e93bfcaee045955b3" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "libc" +version = "0.2.180" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc" + +[[package]] +name = "libm" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" + +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "memchr" +version = "2.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" + +[[package]] +name = "minicov" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4869b6a491569605d66d3952bcdf03df789e5b536e5f0cf7758a7f08a55ae24d" +dependencies = [ + "cc", + "walkdir", +] + +[[package]] 
+name = "nalgebra" +version = "0.34.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4d5b3eff5cd580f93da45e64715e8c20a3996342f1e466599cf7a267a0c2f5f" +dependencies = [ + "approx", + "glam 0.14.0", + "glam 0.15.2", + "glam 0.16.0", + "glam 0.17.3", + "glam 0.18.0", + "glam 0.19.0", + "glam 0.20.5", + "glam 0.21.3", + "glam 0.22.0", + "glam 0.23.0", + "glam 0.24.2", + "glam 0.25.0", + "glam 0.27.0", + "glam 0.28.0", + "glam 0.29.3", + "glam 0.30.10", + "matrixmultiply", + "nalgebra-macros", + "num-complex", + "num-rational", + "num-traits", + "simba", + "typenum", +] + +[[package]] +name = "nalgebra-macros" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "973e7178a678cfd059ccec50887658d482ce16b0aa9da3888ddeab5cd5eb4889" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "ndarray" +version = "0.17.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520080814a7a6b4a6e9070823bb24b4531daac8c4627e08ba5de8c5ef2f2752d" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version 
= "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "oorandom" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "portable-atomic" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f89776e4d69bb58bc6993e99ffa1d11f228b839984854c7daeb5d37f87cbe950" + +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "proc-macro2" +version = "1.0.105" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "535d180e0ecab6268a3e718bb9fd44db66bbbc256257165fc699dadf70d16fe7" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc74d9a594b72ae6656596548f56f667211f8a97b3d4c3d467150794690dc40a" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rayon" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" +dependencies = [ + "either", + "rayon-core", + "wasm_sync", +] + +[[package]] +name = "rayon-core" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", + "wasm_sync", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "ruvector-hyperbolic-hnsw" +version = "0.1.0" +dependencies = [ + "nalgebra", + "ndarray", + "rand", + "rand_distr", + "serde", + "serde_json", + "thiserror", +] + +[[package]] +name = "ruvector-hyperbolic-hnsw-wasm" +version = "0.1.0" +dependencies = [ + "console_error_panic_hook", + "getrandom", + "js-sys", + "rayon", + "ruvector-hyperbolic-hnsw", + "serde", + "serde-wasm-bindgen", + "serde_json", + "wasm-bindgen", + "wasm-bindgen-rayon", + "wasm-bindgen-test", + "web-sys", +] + +[[package]] +name = "safe_arch" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96b02de82ddbe1b636e6170c21be622223aea188ef2e139be0a5b219ec215323" +dependencies = [ + "bytemuck", +] + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde-wasm-bindgen" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8302e169f0eddcc139c70f139d19d6467353af16f9fce27e8c30158036a1e16b" +dependencies = [ + "js-sys", + "serde", + "wasm-bindgen", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "simba" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c99284beb21666094ba2b75bbceda012e610f5479dfcc2d6e2426f53197ffd95" +dependencies = [ + "approx", + "num-complex", + "num-traits", + "paste", + "wide", +] + +[[package]] +name = "slab" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" + +[[package]] +name = "syn" +version = "2.0.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "thiserror" +version = "2.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + +[[package]] +name = "unicode-ident" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" + +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasm-bindgen" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64024a30ec1e37399cf85a7ffefebdb72205ca1c972291c51512360d90bd8566" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"70a6e77fd0ae8029c9ea0063f87c46fde723e7d887703d74ad2616d792e51e6f" +dependencies = [ + "cfg-if", + "futures-util", + "js-sys", + "once_cell", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "008b239d9c740232e71bd39e8ef6429d27097518b6b30bdf9086833bd5b6d608" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5256bae2d58f54820e6490f9839c49780dff84c65aeab9e772f15d5f0e913a55" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-rayon" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a16c60a56c81e4dc3b9c43d76ba5633e1c0278211d59a9cb07d61b6cd1c6583" +dependencies = [ + "crossbeam-channel", + "js-sys", + "rayon", + "wasm-bindgen", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f01b580c9ac74c8d8f0c0e4afb04eeef2acf145458e52c03845ee9cd23e3d12" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "wasm-bindgen-test" +version = "0.3.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45649196a53b0b7a15101d845d44d2dda7374fc1b5b5e2bbf58b7577ff4b346d" +dependencies = [ + "async-trait", + "cast", + "js-sys", + "libm", + "minicov", + "nu-ansi-term", + "num-traits", + "oorandom", + "serde", + "serde_json", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-bindgen-test-macro", + "wasm-bindgen-test-shared", +] + +[[package]] +name = "wasm-bindgen-test-macro" +version = "0.3.58" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f579cdd0123ac74b94e1a4a72bd963cf30ebac343f2df347da0b8df24cdebed2" 
+dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "wasm-bindgen-test-shared" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8145dd1593bf0fb137dbfa85b8be79ec560a447298955877804640e40c2d6ea" + +[[package]] +name = "wasm_sync" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cff360cade7fec41ff0e9d2cda57fe58258c5f16def0e21302394659e6bbb0ea" +dependencies = [ + "js-sys", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "web-sys" +version = "0.3.85" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "312e32e551d92129218ea9a2452120f4aabc03529ef03e4d0d82fb2780608598" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "wide" +version = "0.7.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce5da8ecb62bcd8ec8b7ea19f69a51275e91299be594ea5cc6ef7819e16cd03" +dependencies = [ + "bytemuck", + "safe_arch", +] + +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "zerocopy" +version = "0.8.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "668f5168d10b9ee831de31933dc111a459c97ec93225beb307aed970d1372dfd" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" 
+version = "0.8.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c7962b26b0a8685668b671ee4b54d007a67d4eaf05fda79ac0ecf41e32270f1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zmij" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd8f3f50b848df28f887acb68e41201b5aea6bc8a8dacc00fb40635ff9a72fea" diff --git a/crates/ruvector-hyperbolic-hnsw-wasm/Cargo.toml b/crates/ruvector-hyperbolic-hnsw-wasm/Cargo.toml new file mode 100644 index 000000000..f7543d47b --- /dev/null +++ b/crates/ruvector-hyperbolic-hnsw-wasm/Cargo.toml @@ -0,0 +1,37 @@ +[package] +name = "ruvector-hyperbolic-hnsw-wasm" +version = "0.1.0" +edition = "2021" +rust-version = "1.77" +license = "MIT" +authors = ["RuVector Team"] +repository = "https://github.com/ruvnet/ruvector" +description = "WebAssembly bindings for hyperbolic HNSW embeddings - hierarchy-aware vector search in the browser" +keywords = ["wasm", "hyperbolic", "poincare", "hnsw", "vector-search"] +categories = ["wasm", "mathematics", "algorithms"] + +[lib] +crate-type = ["cdylib", "rlib"] + +[features] +default = ["console_error_panic_hook"] +parallel = ["rayon", "wasm-bindgen-rayon"] + +[dependencies] +ruvector-hyperbolic-hnsw = { version = "0.1.0", path = "../ruvector-hyperbolic-hnsw", default-features = false } +wasm-bindgen = "0.2.106" +js-sys = "0.3" +web-sys = { version = "0.3", features = ["console"] } +getrandom = { version = "0.2", features = ["js"] } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +serde-wasm-bindgen = "0.6" +console_error_panic_hook = { version = "0.1", optional = true } +rayon = { version = "1.10", optional = true } +wasm-bindgen-rayon = { version = "1.2", optional = true } + +[package.metadata.wasm-pack.profile.release] +wasm-opt = ["-O3", "--enable-simd"] + +[dev-dependencies] +wasm-bindgen-test = "0.3" diff --git a/crates/ruvector-hyperbolic-hnsw-wasm/src/lib.rs 
b/crates/ruvector-hyperbolic-hnsw-wasm/src/lib.rs new file mode 100644 index 000000000..05c48af27 --- /dev/null +++ b/crates/ruvector-hyperbolic-hnsw-wasm/src/lib.rs @@ -0,0 +1,632 @@ +//! WebAssembly Bindings for Hyperbolic HNSW +//! +//! This module provides JavaScript/TypeScript bindings for hyperbolic embeddings +//! and HNSW search in the browser and Node.js environments. +//! +//! # Usage in JavaScript +//! +//! ```javascript +//! import init, { +//! HyperbolicIndex, +//! poincareDistance, +//! mobiusAdd, +//! expMap, +//! logMap +//! } from 'ruvector-hyperbolic-hnsw-wasm'; +//! +//! // Initialize WASM module +//! await init(); +//! +//! // Create index +//! const index = new HyperbolicIndex(16, 1.0); // ef_search=16, curvature=1.0 +//! +//! // Insert vectors +//! index.insert(new Float32Array([0.1, 0.2, 0.3])); +//! index.insert(new Float32Array([-0.1, 0.15, 0.25])); +//! +//! // Search +//! const results = index.search(new Float32Array([0.15, 0.1, 0.2]), 2); +//! console.log(results); // [{id: 0, distance: 0.123}, ...] +//! +//! // Use low-level math operations +//! const d = poincareDistance( +//! new Float32Array([0.3, 0.2]), +//! new Float32Array([-0.1, 0.4]), +//! 1.0 +//! ); +//! 
``` + +use ruvector_hyperbolic_hnsw::{ + exp_map, frechet_mean, log_map, mobius_add, mobius_scalar_mult, poincare_distance, + project_to_ball, HyperbolicHnsw, HyperbolicHnswConfig, PoincareConfig, ShardedHyperbolicHnsw, + TangentCache, DEFAULT_CURVATURE, EPS, +}; +use serde::{Deserialize, Serialize}; +use wasm_bindgen::prelude::*; + +/// Install the `console_error_panic_hook` hook (only compiled with that feature enabled) +#[cfg(feature = "console_error_panic_hook")] +fn set_panic_hook() { + console_error_panic_hook::set_once(); +} + +/// Initialize the WASM module +/// +/// Registered as the wasm-bindgen `start` function; installs the panic hook when the +/// `console_error_panic_hook` feature is enabled and does nothing otherwise. +#[wasm_bindgen(start)] +pub fn init() { + #[cfg(feature = "console_error_panic_hook")] + set_panic_hook(); +} + +// ============================================================================ +// Low-Level Math Operations +// ============================================================================ + +/// Compute Poincaré distance between two points +/// +/// @param u - First point (Float32Array) +/// @param v - Second point (Float32Array) +/// @param curvature - Curvature parameter (positive) +/// @returns Geodesic distance in hyperbolic space +#[wasm_bindgen(js_name = poincareDistance)] +pub fn wasm_poincare_distance(u: &[f32], v: &[f32], curvature: f32) -> f32 { + poincare_distance(u, v, curvature) +} + +/// Möbius addition in Poincaré ball +/// +/// Computes the hyperbolic analog of vector addition: x ⊕_c y +/// +/// @param x - First point (Float32Array) +/// @param y - Second point (Float32Array) +/// @param curvature - Curvature parameter +/// @returns Result of Möbius addition (Float32Array) +#[wasm_bindgen(js_name = mobiusAdd)] +pub fn wasm_mobius_add(x: &[f32], y: &[f32], curvature: f32) -> Vec<f32> { + mobius_add(x, y, curvature) +} + +/// Möbius scalar multiplication +/// +/// Computes r ⊗_c x for scalar r and point x +/// +/// @param r - Scalar value +/// @param x - Point in Poincaré ball (Float32Array) +/// @param curvature - Curvature parameter +/// @returns Scaled point (Float32Array) +#[wasm_bindgen(js_name = mobiusScalarMult)] +pub fn wasm_mobius_scalar_mult(r: f32, x: &[f32],
curvature: f32) -> Vec<f32> { + mobius_scalar_mult(r, x, curvature) +} + +/// Exponential map at point p +/// +/// Maps a tangent vector v at point p to the Poincaré ball +/// +/// @param v - Tangent vector (Float32Array) +/// @param p - Base point (Float32Array) +/// @param curvature - Curvature parameter +/// @returns Point on the manifold (Float32Array) +#[wasm_bindgen(js_name = expMap)] +pub fn wasm_exp_map(v: &[f32], p: &[f32], curvature: f32) -> Vec<f32> { + exp_map(v, p, curvature) +} + +/// Logarithmic map at point p +/// +/// Maps a point y to the tangent space at point p +/// +/// @param y - Target point (Float32Array) +/// @param p - Base point (Float32Array) +/// @param curvature - Curvature parameter +/// @returns Tangent vector at p (Float32Array) +#[wasm_bindgen(js_name = logMap)] +pub fn wasm_log_map(y: &[f32], p: &[f32], curvature: f32) -> Vec<f32> { + log_map(y, p, curvature) +} + +/// Project point to Poincaré ball +/// +/// Ensures ||x|| < 1/√c - eps for numerical stability +/// +/// @param x - Point to project (Float32Array) +/// @param curvature - Curvature parameter +/// @returns Projected point (Float32Array) +#[wasm_bindgen(js_name = projectToBall)] +pub fn wasm_project_to_ball(x: &[f32], curvature: f32) -> Vec<f32> { + project_to_ball(x, curvature, EPS) +} + +/// Compute Fréchet mean (hyperbolic centroid) +/// +/// @param points - Array of points as flat Float32Array (length must be a multiple of dim) +/// @param dim - Dimension of each point (must be non-zero) +/// @param curvature - Curvature parameter +/// @returns Centroid point (Float32Array) +/// @throws If points is empty, dim is zero, points.length is not a multiple of dim, +/// or the underlying configuration / mean computation reports an error +#[wasm_bindgen(js_name = frechetMean)] +pub fn wasm_frechet_mean(points: &[f32], dim: usize, curvature: f32) -> Result<Vec<f32>, JsValue> { + // Reject a ragged buffer up front: a trailing partial chunk is not a valid point. + if points.is_empty() || dim == 0 || points.len() % dim != 0 { + return Err(JsValue::from_str("Empty points or invalid dimension")); + } + + // Borrow each point straight from the flat input instead of copying it first. + let point_refs: Vec<&[f32]> = points.chunks_exact(dim).collect(); + + let config = PoincareConfig::with_curvature(curvature) +
.map_err(|e| JsValue::from_str(&e.to_string()))?; + + frechet_mean(&point_refs, None, &config).map_err(|e| JsValue::from_str(&e.to_string())) +} + +// ============================================================================ +// Search Result Type +// ============================================================================ + +/// Search result from hyperbolic HNSW +#[derive(Debug, Clone, Serialize, Deserialize)] +#[wasm_bindgen] +pub struct WasmSearchResult { + /// Vector ID + pub id: usize, + /// Hyperbolic distance to query + pub distance: f32, +} + +#[wasm_bindgen] +impl WasmSearchResult { + #[wasm_bindgen(constructor)] + pub fn new(id: usize, distance: f32) -> Self { + Self { id, distance } + } +} + +// ============================================================================ +// Hyperbolic HNSW Index +// ============================================================================ + +/// Hyperbolic HNSW Index for hierarchy-aware vector search +/// +/// @example +/// ```javascript +/// const index = new HyperbolicIndex(16, 1.0); +/// index.insert(new Float32Array([0.1, 0.2])); +/// index.insert(new Float32Array([-0.1, 0.3])); +/// const results = index.search(new Float32Array([0.05, 0.25]), 2); +/// ``` +#[wasm_bindgen] +pub struct HyperbolicIndex { + inner: HyperbolicHnsw, +} + +#[wasm_bindgen] +impl HyperbolicIndex { + /// Create a new hyperbolic HNSW index + /// + /// @param ef_search - Size of dynamic candidate list during search (default: 50) + /// @param curvature - Curvature parameter for Poincaré ball (default: 1.0) + #[wasm_bindgen(constructor)] + pub fn new(ef_search: Option, curvature: Option) -> Self { + let mut config = HyperbolicHnswConfig::default(); + config.ef_search = ef_search.unwrap_or(50); + config.curvature = curvature.unwrap_or(DEFAULT_CURVATURE); + + Self { + inner: HyperbolicHnsw::new(config), + } + } + + /// Create with custom configuration + /// + /// @param config - JSON configuration object + #[wasm_bindgen(js_name = 
fromConfig)] + pub fn from_config(config: JsValue) -> Result { + let config: HyperbolicHnswConfig = + serde_wasm_bindgen::from_value(config).map_err(|e| JsValue::from_str(&e.to_string()))?; + Ok(Self { + inner: HyperbolicHnsw::new(config), + }) + } + + /// Insert a vector into the index + /// + /// @param vector - Vector to insert (Float32Array) + /// @returns ID of inserted vector + #[wasm_bindgen] + pub fn insert(&mut self, vector: &[f32]) -> Result { + self.inner + .insert(vector.to_vec()) + .map_err(|e| JsValue::from_str(&e.to_string())) + } + + /// Insert batch of vectors + /// + /// @param vectors - Flat array of vectors + /// @param dim - Dimension of each vector + /// @returns Array of inserted IDs + #[wasm_bindgen(js_name = insertBatch)] + pub fn insert_batch(&mut self, vectors: &[f32], dim: usize) -> Result, JsValue> { + let vecs: Vec> = vectors.chunks(dim).map(|c| c.to_vec()).collect(); + self.inner + .insert_batch(vecs) + .map_err(|e| JsValue::from_str(&e.to_string())) + } + + /// Search for k nearest neighbors + /// + /// @param query - Query vector (Float32Array) + /// @param k - Number of neighbors to return + /// @returns Array of search results as JSON + #[wasm_bindgen] + pub fn search(&self, query: &[f32], k: usize) -> Result { + let results = self + .inner + .search(query, k) + .map_err(|e| JsValue::from_str(&e.to_string()))?; + + let wasm_results: Vec = results + .into_iter() + .map(|r| WasmSearchResult::new(r.id, r.distance)) + .collect(); + + serde_wasm_bindgen::to_value(&wasm_results).map_err(|e| JsValue::from_str(&e.to_string())) + } + + /// Search with tangent space pruning (optimized) + /// + /// @param query - Query vector (Float32Array) + /// @param k - Number of neighbors to return + /// @returns Array of search results as JSON + #[wasm_bindgen(js_name = searchWithPruning)] + pub fn search_with_pruning(&self, query: &[f32], k: usize) -> Result { + let results = self + .inner + .search_with_pruning(query, k) + .map_err(|e| 
JsValue::from_str(&e.to_string()))?; + + let wasm_results: Vec = results + .into_iter() + .map(|r| WasmSearchResult::new(r.id, r.distance)) + .collect(); + + serde_wasm_bindgen::to_value(&wasm_results).map_err(|e| JsValue::from_str(&e.to_string())) + } + + /// Build tangent cache for optimized search + #[wasm_bindgen(js_name = buildTangentCache)] + pub fn build_tangent_cache(&mut self) -> Result<(), JsValue> { + self.inner + .build_tangent_cache() + .map_err(|e| JsValue::from_str(&e.to_string())) + } + + /// Get number of vectors in index + #[wasm_bindgen] + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Check if index is empty + #[wasm_bindgen(js_name = isEmpty)] + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + /// Get vector dimension + #[wasm_bindgen] + pub fn dim(&self) -> Option { + self.inner.dim() + } + + /// Update curvature parameter + /// + /// @param curvature - New curvature value (must be positive) + #[wasm_bindgen(js_name = setCurvature)] + pub fn set_curvature(&mut self, curvature: f32) -> Result<(), JsValue> { + self.inner + .set_curvature(curvature) + .map_err(|e| JsValue::from_str(&e.to_string())) + } + + /// Get a vector by ID + /// + /// @param id - Vector ID + /// @returns Vector data or null if not found + #[wasm_bindgen(js_name = getVector)] + pub fn get_vector(&self, id: usize) -> Option> { + self.inner.get_vector(id).map(|v| v.to_vec()) + } + + /// Export index configuration as JSON + #[wasm_bindgen(js_name = exportConfig)] + pub fn export_config(&self) -> Result { + serde_wasm_bindgen::to_value(&self.inner.config) + .map_err(|e| JsValue::from_str(&e.to_string())) + } +} + +// ============================================================================ +// Sharded Index +// ============================================================================ + +/// Sharded Hyperbolic HNSW with per-shard curvature +/// +/// @example +/// ```javascript +/// const manager = new ShardedIndex(1.0); +/// 
manager.insertToShard("taxonomy", new Float32Array([0.1, 0.2]), 0); +/// manager.insertToShard("taxonomy", new Float32Array([0.3, 0.1]), 3); +/// manager.updateCurvature("taxonomy", 0.5); +/// const results = manager.search(new Float32Array([0.2, 0.15]), 5); +/// ``` +#[wasm_bindgen] +pub struct ShardedIndex { + inner: ShardedHyperbolicHnsw, +} + +#[wasm_bindgen] +impl ShardedIndex { + /// Create a new sharded index + /// + /// @param default_curvature - Default curvature for new shards + #[wasm_bindgen(constructor)] + pub fn new(default_curvature: f32) -> Self { + Self { + inner: ShardedHyperbolicHnsw::new(default_curvature), + } + } + + /// Insert vector with automatic shard assignment + /// + /// @param vector - Vector to insert (Float32Array) + /// @param depth - Optional hierarchy depth for shard assignment + /// @returns Global vector ID + #[wasm_bindgen] + pub fn insert(&mut self, vector: &[f32], depth: Option) -> Result { + self.inner + .insert(vector.to_vec(), depth) + .map_err(|e| JsValue::from_str(&e.to_string())) + } + + /// Insert vector into specific shard + /// + /// @param shard_id - Target shard ID + /// @param vector - Vector to insert (Float32Array) + /// @returns Global vector ID + #[wasm_bindgen(js_name = insertToShard)] + pub fn insert_to_shard(&mut self, shard_id: &str, vector: &[f32]) -> Result { + self.inner + .insert_to_shard(shard_id, vector.to_vec()) + .map_err(|e| JsValue::from_str(&e.to_string())) + } + + /// Search across all shards + /// + /// @param query - Query vector (Float32Array) + /// @param k - Number of neighbors to return + /// @returns Array of search results as JSON + #[wasm_bindgen] + pub fn search(&self, query: &[f32], k: usize) -> Result { + let results = self + .inner + .search(query, k) + .map_err(|e| JsValue::from_str(&e.to_string()))?; + + let wasm_results: Vec = results + .into_iter() + .map(|(id, r)| WasmSearchResult::new(id, r.distance)) + .collect(); + + serde_wasm_bindgen::to_value(&wasm_results).map_err(|e| 
JsValue::from_str(&e.to_string())) + } + + /// Update curvature for a shard + /// + /// @param shard_id - Shard ID + /// @param curvature - New curvature value + #[wasm_bindgen(js_name = updateCurvature)] + pub fn update_curvature(&mut self, shard_id: &str, curvature: f32) -> Result<(), JsValue> { + self.inner + .update_curvature(shard_id, curvature) + .map_err(|e| JsValue::from_str(&e.to_string())) + } + + /// Set canary curvature for A/B testing + /// + /// @param shard_id - Shard ID + /// @param curvature - Canary curvature value + /// @param traffic - Percentage of traffic for canary (0-100) + #[wasm_bindgen(js_name = setCanaryCurvature)] + pub fn set_canary_curvature(&mut self, shard_id: &str, curvature: f32, traffic: u8) { + self.inner.registry.set_canary(shard_id, curvature, traffic); + } + + /// Promote canary to production + /// + /// @param shard_id - Shard ID + #[wasm_bindgen(js_name = promoteCanary)] + pub fn promote_canary(&mut self, shard_id: &str) -> Result<(), JsValue> { + if let Some(shard_curv) = self.inner.registry.shards.get_mut(shard_id) { + shard_curv.promote_canary(); + } + self.inner + .reload_curvatures() + .map_err(|e| JsValue::from_str(&e.to_string())) + } + + /// Rollback canary + /// + /// @param shard_id - Shard ID + #[wasm_bindgen(js_name = rollbackCanary)] + pub fn rollback_canary(&mut self, shard_id: &str) { + if let Some(shard_curv) = self.inner.registry.shards.get_mut(shard_id) { + shard_curv.rollback_canary(); + } + } + + /// Build tangent caches for all shards + #[wasm_bindgen(js_name = buildCaches)] + pub fn build_caches(&mut self) -> Result<(), JsValue> { + self.inner + .build_caches() + .map_err(|e| JsValue::from_str(&e.to_string())) + } + + /// Get total vector count + #[wasm_bindgen] + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Check if empty + #[wasm_bindgen(js_name = isEmpty)] + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + /// Get number of shards + #[wasm_bindgen(js_name = numShards)] 
+ pub fn num_shards(&self) -> usize { + self.inner.num_shards() + } + + /// Get curvature registry as JSON + #[wasm_bindgen(js_name = getRegistry)] + pub fn get_registry(&self) -> Result { + serde_wasm_bindgen::to_value(&self.inner.registry) + .map_err(|e| JsValue::from_str(&e.to_string())) + } +} + +// ============================================================================ +// Tangent Cache Operations +// ============================================================================ + +/// Tangent space cache for fast pruning +#[wasm_bindgen] +pub struct WasmTangentCache { + inner: TangentCache, +} + +#[wasm_bindgen] +impl WasmTangentCache { + /// Create tangent cache from points + /// + /// @param points - Flat array of points + /// @param dim - Dimension of each point + /// @param curvature - Curvature parameter + #[wasm_bindgen(constructor)] + pub fn new(points: &[f32], dim: usize, curvature: f32) -> Result { + let point_vecs: Vec> = points.chunks(dim).map(|c| c.to_vec()).collect(); + let indices: Vec = (0..point_vecs.len()).collect(); + + let cache = TangentCache::new(&point_vecs, &indices, curvature) + .map_err(|e| JsValue::from_str(&e.to_string()))?; + + Ok(Self { inner: cache }) + } + + /// Get centroid of the cache + #[wasm_bindgen] + pub fn centroid(&self) -> Vec { + self.inner.centroid.clone() + } + + /// Get tangent coordinates for a query + /// + /// @param query - Query point (Float32Array) + /// @returns Tangent coordinates (Float32Array) + #[wasm_bindgen(js_name = queryTangent)] + pub fn query_tangent(&self, query: &[f32]) -> Vec { + self.inner.query_tangent(query) + } + + /// Compute tangent distance squared (for fast pruning) + /// + /// @param query_tangent - Query in tangent space (Float32Array) + /// @param idx - Index of cached point + /// @returns Squared distance in tangent space + #[wasm_bindgen(js_name = tangentDistanceSquared)] + pub fn tangent_distance_squared(&self, query_tangent: &[f32], idx: usize) -> f32 { + 
self.inner.tangent_distance_squared(query_tangent, idx) + } + + /// Get number of cached points + #[wasm_bindgen] + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Get dimension + #[wasm_bindgen] + pub fn dim(&self) -> usize { + self.inner.dim() + } +} + +// ============================================================================ +// Utility Functions +// ============================================================================ + +/// Get library version +#[wasm_bindgen(js_name = getVersion)] +pub fn get_version() -> String { + ruvector_hyperbolic_hnsw::VERSION.to_string() +} + +/// Get default curvature value +#[wasm_bindgen(js_name = getDefaultCurvature)] +pub fn get_default_curvature() -> f32 { + DEFAULT_CURVATURE +} + +/// Get numerical stability epsilon +#[wasm_bindgen(js_name = getEps)] +pub fn get_eps() -> f32 { + EPS +} + +/// Compute vector norm +#[wasm_bindgen(js_name = vectorNorm)] +pub fn vector_norm(x: &[f32]) -> f32 { + ruvector_hyperbolic_hnsw::norm(x) +} + +/// Compute squared vector norm +#[wasm_bindgen(js_name = vectorNormSquared)] +pub fn vector_norm_squared(x: &[f32]) -> f32 { + ruvector_hyperbolic_hnsw::norm_squared(x) +} + +#[cfg(test)] +mod tests { + use super::*; + use wasm_bindgen_test::*; + + #[wasm_bindgen_test] + fn test_poincare_distance() { + let u = vec![0.3, 0.2]; + let v = vec![-0.1, 0.4]; + let d = wasm_poincare_distance(&u, &v, 1.0); + assert!(d > 0.0); + } + + #[wasm_bindgen_test] + fn test_mobius_add() { + let x = vec![0.2, 0.1]; + let y = vec![0.1, -0.1]; + let z = wasm_mobius_add(&x, &y, 1.0); + assert_eq!(z.len(), 2); + } + + #[wasm_bindgen_test] + fn test_hyperbolic_index() { + let mut index = HyperbolicIndex::new(Some(16), Some(1.0)); + + index.insert(&[0.1, 0.2, 0.3]).unwrap(); + index.insert(&[-0.1, 0.15, 0.25]).unwrap(); + index.insert(&[0.2, -0.1, 0.1]).unwrap(); + + assert_eq!(index.len(), 3); + assert!(!index.is_empty()); + assert_eq!(index.dim(), Some(3)); + } +} diff --git 
a/crates/ruvector-hyperbolic-hnsw/Cargo.lock b/crates/ruvector-hyperbolic-hnsw/Cargo.lock new file mode 100644 index 000000000..9729cd8c3 --- /dev/null +++ b/crates/ruvector-hyperbolic-hnsw/Cargo.lock @@ -0,0 +1,1145 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + +[[package]] +name = "anstyle" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" + +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "bit-set" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" + +[[package]] +name = "bitflags" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" + +[[package]] +name = "bumpalo" +version = "3.19.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5dd9dc738b7a8311c7ade152424974d8115f2cdad61e8dab8dac9f2362298510" + +[[package]] +name = "bytemuck" +version = "1.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fbdf580320f38b612e485521afda1ee26d10cc9884efaaa750d383e13e3c5f4" + +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + +[[package]] +name = "clap" +version = "4.5.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6e6ff9dcd79cff5cd969a17a545d79e84ab086e444102a591e288a8aa3ce394" +dependencies = [ + "clap_builder", +] + +[[package]] +name = "clap_builder" +version = "4.5.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"fa42cf4d2b7a41bc8f663a7cab4031ebafa1bf3875705bfaf8466dc60ab52c00" +dependencies = [ + "anstyle", + "clap_lex", +] + +[[package]] +name = "clap_lex" +version = "0.7.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3e64b0cc0439b12df2fa678eae89a1c56a529fd067a9115f7827f1fffd22b32" + +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + +[[package]] +name = "crunchy" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" + +[[package]] +name = "either" +version = 
"1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", +] + +[[package]] +name = "glam" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "333928d5eb103c5d4050533cec0384302db6be8ef7d3cebd30ec6a35350353da" + +[[package]] +name = "glam" +version = "0.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3abb554f8ee44336b72d522e0a7fe86a29e09f839a36022fa869a7dfe941a54b" + +[[package]] +name = "glam" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4126c0479ccf7e8664c36a2d719f5f2c140fbb4f9090008098d2c291fa5b3f16" + +[[package]] +name = "glam" +version = "0.17.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "e01732b97afd8508eee3333a541b9f7610f454bb818669e66e90f5f57c93a776" + +[[package]] +name = "glam" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "525a3e490ba77b8e326fb67d4b44b4bd2f920f44d4cc73ccec50adc68e3bee34" + +[[package]] +name = "glam" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b8509e6791516e81c1a630d0bd7fbac36d2fa8712a9da8662e716b52d5051ca" + +[[package]] +name = "glam" +version = "0.20.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f43e957e744be03f5801a55472f593d43fabdebf25a4585db250f04d86b1675f" + +[[package]] +name = "glam" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "518faa5064866338b013ff9b2350dc318e14cc4fcd6cb8206d7e7c9886c98815" + +[[package]] +name = "glam" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12f597d56c1bd55a811a1be189459e8fad2bbc272616375602443bdfb37fa774" + +[[package]] +name = "glam" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e4afd9ad95555081e109fe1d21f2a30c691b5f0919c67dfa690a2e1eb6bd51c" + +[[package]] +name = "glam" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5418c17512bdf42730f9032c74e1ae39afc408745ebb2acf72fbc4691c17945" + +[[package]] +name = "glam" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "151665d9be52f9bb40fc7966565d39666f2d1e69233571b71b87791c7e0528b3" + +[[package]] +name = "glam" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e05e7e6723e3455f4818c7b26e855439f7546cf617ef669d1adedb8669e5cb9" + +[[package]] +name = "glam" +version = "0.28.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "779ae4bf7e8421cf91c0b3b64e7e8b40b862fba4d393f59150042de7c4965a94" + +[[package]] +name = "glam" +version = "0.29.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8babf46d4c1c9d92deac9f7be466f76dfc4482b6452fc5024b5e8daf6ffeb3ee" + +[[package]] +name = "glam" +version = "0.30.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19fc433e8437a212d1b6f1e68c7824af3aed907da60afa994e7f542d18d12aa9" + +[[package]] +name = "half" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" +dependencies = [ + "cfg-if", + "crunchy", + "zerocopy", +] + +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + +[[package]] +name = "is-terminal" +version = "0.4.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" +dependencies = [ + "hermit-abi", + "libc", + "windows-sys", +] + +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" + +[[package]] +name = "js-sys" +version = "0.3.85" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c942ebf8e95485ca0d52d97da7c5a2c387d0e7f0ba4c35e93bfcaee045955b3" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "libc" +version = "0.2.180" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc" + +[[package]] +name = "libm" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" + +[[package]] +name = "linux-raw-sys" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" + +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "memchr" +version = "2.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" + +[[package]] +name = "nalgebra" +version = "0.34.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4d5b3eff5cd580f93da45e64715e8c20a3996342f1e466599cf7a267a0c2f5f" +dependencies = [ + "approx", + "glam 0.14.0", + "glam 0.15.2", + "glam 0.16.0", + "glam 0.17.3", + "glam 0.18.0", + "glam 0.19.0", + "glam 0.20.5", + "glam 0.21.3", + "glam 0.22.0", + "glam 0.23.0", + "glam 0.24.2", + "glam 0.25.0", + "glam 0.27.0", + "glam 0.28.0", + "glam 0.29.3", + "glam 0.30.10", + "matrixmultiply", + "nalgebra-macros", + "num-complex", + "num-rational", + "num-traits", + "simba", + "typenum", +] + +[[package]] +name = "nalgebra-macros" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "973e7178a678cfd059ccec50887658d482ce16b0aa9da3888ddeab5cd5eb4889" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "ndarray" +version = "0.17.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"520080814a7a6b4a6e9070823bb24b4531daac8c4627e08ba5de8c5ef2f2752d" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "oorandom" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + +[[package]] +name = "portable-atomic" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f89776e4d69bb58bc6993e99ffa1d11f228b839984854c7daeb5d37f87cbe950" + +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "proc-macro2" +version = "1.0.105" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "535d180e0ecab6268a3e718bb9fd44db66bbbc256257165fc699dadf70d16fe7" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "proptest" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bee689443a2bd0a16ab0348b52ee43e3b2d1b1f931c8aa5c9f8de4c86fbe8c40" +dependencies = [ + "bit-set", + 
"bit-vec", + "bitflags", + "num-traits", + "rand 0.9.2", + "rand_chacha 0.9.0", + "rand_xorshift", + "regex-syntax", + "rusty-fork", + "tempfile", + "unarray", +] + +[[package]] +name = "quick-error" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" + +[[package]] +name = "quote" +version = "1.0.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc74d9a594b72ae6656596548f56f667211f8a97b3d4c3d467150794690dc40a" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.5", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.5", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.17", +] + +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom 0.3.4", +] + +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand 0.8.5", +] + +[[package]] +name = "rand_xorshift" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "513962919efc330f829edb2535844d1b912b0fbe2ca165d613e4e8788bb05a5a" +dependencies = [ + "rand_core 0.9.5", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rayon" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "regex" +version = "1.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843bc0191f75f3e22651ae5f1e72939ab2f72a4bc30fa80a066bd66edefc24d4" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"5276caf25ac86c8d810222b3dbb938e512c55c6831a10f3e6ed1c93b84041f1c" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" + +[[package]] +name = "rustix" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "rusty-fork" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc6bf79ff24e648f6da1f8d1f011e9cac26491b619e6b9280f2b47f1774e6ee2" +dependencies = [ + "fnv", + "quick-error", + "tempfile", + "wait-timeout", +] + +[[package]] +name = "ruvector-hyperbolic-hnsw" +version = "0.1.0" +dependencies = [ + "approx", + "criterion", + "nalgebra", + "ndarray", + "proptest", + "rand 0.8.5", + "rand_distr", + "rayon", + "serde", + "serde_json", + "thiserror", +] + +[[package]] +name = "safe_arch" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96b02de82ddbe1b636e6170c21be622223aea188ef2e139be0a5b219ec215323" +dependencies = [ + "bytemuck", +] + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "simba" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c99284beb21666094ba2b75bbceda012e610f5479dfcc2d6e2426f53197ffd95" +dependencies = [ + "approx", + "num-complex", + "num-traits", + "paste", + "wide", +] + +[[package]] +name = "syn" +version = "2.0.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tempfile" +version = "3.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c" +dependencies = [ + "fastrand", + "getrandom 0.3.4", + "once_cell", + "rustix", + "windows-sys", +] + +[[package]] +name = "thiserror" +version = "2.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8" +dependencies = [ + "thiserror-impl", +] + 
+[[package]] +name = "thiserror-impl" +version = "2.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + +[[package]] +name = "unarray" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" + +[[package]] +name = "unicode-ident" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" + +[[package]] +name = "wait-timeout" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ac3b126d3914f9849036f826e054cbabdc8519970b8998ddaf3b5bd3c65f11" +dependencies = [ + "libc", +] + +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasip2" +version = "1.0.1+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64024a30ec1e37399cf85a7ffefebdb72205ca1c972291c51512360d90bd8566" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "008b239d9c740232e71bd39e8ef6429d27097518b6b30bdf9086833bd5b6d608" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5256bae2d58f54820e6490f9839c49780dff84c65aeab9e772f15d5f0e913a55" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f01b580c9ac74c8d8f0c0e4afb04eeef2acf145458e52c03845ee9cd23e3d12" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "web-sys" +version = "0.3.85" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "312e32e551d92129218ea9a2452120f4aabc03529ef03e4d0d82fb2780608598" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "wide" +version = "0.7.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce5da8ecb62bcd8ec8b7ea19f69a51275e91299be594ea5cc6ef7819e16cd03" +dependencies = [ + "bytemuck", + "safe_arch", +] + +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys", +] + 
+[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "wit-bindgen" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" + +[[package]] +name = "zerocopy" +version = "0.8.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "668f5168d10b9ee831de31933dc111a459c97ec93225beb307aed970d1372dfd" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c7962b26b0a8685668b671ee4b54d007a67d4eaf05fda79ac0ecf41e32270f1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zmij" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd8f3f50b848df28f887acb68e41201b5aea6bc8a8dacc00fb40635ff9a72fea" diff --git a/crates/ruvector-hyperbolic-hnsw/Cargo.toml b/crates/ruvector-hyperbolic-hnsw/Cargo.toml new file mode 100644 index 000000000..045ae231a --- /dev/null +++ b/crates/ruvector-hyperbolic-hnsw/Cargo.toml @@ -0,0 +1,52 @@ +[package] +name = "ruvector-hyperbolic-hnsw" +version = "0.1.0" +edition = "2021" +rust-version = "1.77" +license = "MIT" +authors = ["RuVector Team"] +repository = "https://github.com/ruvnet/ruvector" +description = "Hyperbolic (Poincaré ball) embeddings with HNSW integration for hierarchy-aware vector search" +keywords = ["hyperbolic", "poincare", "hnsw", "vector-search", "embeddings"] +categories = ["algorithms", 
"science", "mathematics"] + +[lib] +crate-type = ["rlib"] + +[features] +default = ["simd", "parallel"] +simd = [] +parallel = ["rayon"] +wasm = [] + +[dependencies] +# Math and numerics (exact versions as specified) +nalgebra = "0.34.1" +ndarray = "0.17.1" + +# Parallel processing +rayon = { version = "1.10", optional = true } + +# Serialization +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" + +# Error handling +thiserror = "2.0" + +# Random number generation +rand = "0.8" +rand_distr = "0.4" + +[dev-dependencies] +criterion = { version = "0.5", features = ["html_reports"] } +approx = "0.5" +proptest = "1.5" + +[[bench]] +name = "hyperbolic_bench" +harness = false + +[[test]] +name = "math_tests" +path = "tests/math_tests.rs" diff --git a/crates/ruvector-hyperbolic-hnsw/README.md b/crates/ruvector-hyperbolic-hnsw/README.md new file mode 100644 index 000000000..61a75c9a1 --- /dev/null +++ b/crates/ruvector-hyperbolic-hnsw/README.md @@ -0,0 +1,242 @@ +# ruvector-hyperbolic-hnsw + +Hyperbolic (Poincaré ball) embeddings with HNSW integration for hierarchy-aware vector search. + +## Why Hyperbolic Space? + +Hierarchies compress naturally in hyperbolic space. Taxonomies, catalogs, ICD trees, product facets, org charts, and long-tail tags all fit better than in Euclidean space, which means higher recall on deep leaves without blowing up memory or latency. 
+ +## Key Features + +- **Poincaré Ball Model**: Store vectors in the Poincaré ball with clamp `r < 1 − eps` +- **HNSW Speed Trick**: Prune with cheap tangent-space proxy, rank with true hyperbolic distance +- **Per-Shard Curvature**: Different parts of the hierarchy can have different optimal curvatures +- **Dual-Space Index**: Keep a synchronized Euclidean ANN for fallback and mutual-ranking fusion +- **Production Guardrails**: Numerical stability, canary testing, hot curvature reload + +## Installation + +### Rust + +```toml +[dependencies] +ruvector-hyperbolic-hnsw = "0.1.0" +``` + +### WebAssembly + +```bash +cd crates/ruvector-hyperbolic-hnsw-wasm +wasm-pack build --target web --release +``` + +### TypeScript/JavaScript + +```typescript +import init, { + HyperbolicIndex, + poincareDistance, + mobiusAdd, + expMap, + logMap +} from 'ruvector-hyperbolic-hnsw-wasm'; + +await init(); + +const index = new HyperbolicIndex(16, 1.0); +index.insert(new Float32Array([0.1, 0.2, 0.3])); +const results = index.search(new Float32Array([0.15, 0.1, 0.2]), 5); +``` + +## Quick Start + +```rust +use ruvector_hyperbolic_hnsw::{HyperbolicHnsw, HyperbolicHnswConfig}; + +// Create index with default settings +let mut index = HyperbolicHnsw::default_config(); + +// Insert vectors (automatically projected to Poincaré ball) +index.insert(vec![0.1, 0.2, 0.3]).unwrap(); +index.insert(vec![-0.1, 0.15, 0.25]).unwrap(); +index.insert(vec![0.2, -0.1, 0.1]).unwrap(); + +// Search for nearest neighbors +let results = index.search(&[0.15, 0.1, 0.2], 2).unwrap(); +for r in results { + println!("ID: {}, Distance: {:.4}", r.id, r.distance); +} +``` + +## HNSW Speed Trick + +The core optimization: + +1. Precompute `u = log_c(x)` at a shard centroid `c` +2. During neighbor selection, use Euclidean `||u_q - u_p||` to prune +3. 
Run exact Poincaré distance only on top N candidates before final ranking + +```rust +use ruvector_hyperbolic_hnsw::{HyperbolicHnsw, HyperbolicHnswConfig}; + +let mut config = HyperbolicHnswConfig::default(); +config.use_tangent_pruning = true; +config.prune_factor = 10; // Consider 10x candidates in tangent space + +let mut index = HyperbolicHnsw::new(config); + +// ... insert vectors ... + +// Build tangent cache for pruning optimization +index.build_tangent_cache().unwrap(); + +// Search with pruning (faster!) +let results = index.search_with_pruning(&[0.1, 0.15], 5).unwrap(); +``` + +## Core Mathematical Operations + +```rust +use ruvector_hyperbolic_hnsw::poincare::{ + mobius_add, exp_map, log_map, poincare_distance, project_to_ball +}; + +let x = vec![0.3, 0.2]; +let y = vec![-0.1, 0.4]; +let c = 1.0; // Curvature + +// Möbius addition (hyperbolic vector addition) +let z = mobius_add(&x, &y, c); + +// Geodesic distance in hyperbolic space +let d = poincare_distance(&x, &y, c); + +// Map to tangent space at x +let v = log_map(&y, &x, c); + +// Map back to manifold +let y_recovered = exp_map(&v, &x, c); +``` + +## Sharded Index with Per-Shard Curvature + +```rust +use ruvector_hyperbolic_hnsw::{ShardedHyperbolicHnsw, ShardStrategy}; + +let mut manager = ShardedHyperbolicHnsw::new(1.0); + +// Insert with hierarchy depth information +manager.insert(vec![0.1, 0.2], Some(0)).unwrap(); // Root level +manager.insert(vec![0.3, 0.1], Some(3)).unwrap(); // Deeper level + +// Update curvature for specific shard +manager.update_curvature("radius_1", 0.5).unwrap(); + +// Canary testing for new curvature +manager.registry.set_canary("radius_1", 0.3, 10); // 10% traffic + +// Search across all shards +let results = manager.search(&[0.2, 0.15], 5).unwrap(); +``` + +## Numerical Stability + +All operations include numerical safeguards: + +- **Norm clamping**: Points projected with `eps = 1e-5` +- **Projection after updates**: All operations keep points inside the ball +- 
**Stable acosh**: Uses `log1p` expansions for safety +- **Clamp arguments**: `atanh` (arctanh) arguments are bounded away from ±1 + +## Evaluation Protocol + +### Datasets + +- WordNet +- DBpedia slices +- Synthetic scale-free tree +- Domain taxonomy + +### Primary Metrics + +- **recall@k** (1, 5, 10) +- **Mean rank** +- **NDCG** + +### Hierarchy Metrics + +- **Radius vs depth Spearman correlation** +- **Distance distortion** +- **Ancestor AUPRC** + +### Baselines + +- Euclidean HNSW +- OPQ/PQ compressed +- Simple mutual-ranking fusion + +### Ablations + +- Tangent proxy vs full hyperbolic +- Fixed vs learnable curvature c +- Global vs shard centroids + +## Production Integration + +### Reflex Loop (on writes) + +Small Möbius deltas and tangent-space micro updates that never push points outside the ball. + +```rust +use ruvector_hyperbolic_hnsw::tangent_micro_update; + +let updated = tangent_micro_update( + &point, + &delta, + &centroid, + curvature, + 0.1 // max step size +); +``` + +### Habit (nightly) + +Riemannian SGD passes to clean neighborhoods and optionally relearn per-shard curvature. Run canary first. + +### Structural (periodic) + +Rebuild of HNSW with true hyperbolic metric, curvature retune, and shard reshuffle if hierarchy preservation drops below SLO. 
+ +## Dependencies (Exact Versions) + +```toml +nalgebra = "0.34.1" +ndarray = "0.17.1" +wasm-bindgen = "0.2.106" +``` + +## Benchmarks + +```bash +cd crates/ruvector-hyperbolic-hnsw +cargo bench +``` + +Benchmark suite includes: + +- Poincaré distance computation +- Möbius addition +- exp/log map operations +- HNSW insert and search +- Tangent cache building +- Search with vs without pruning + +## License + +MIT + +## Related + +- [ruvector-attention](../ruvector-attention) - Hyperbolic attention mechanisms +- [micro-hnsw-wasm](../micro-hnsw-wasm) - Minimal HNSW for WASM +- [ruvector-math](../ruvector-math) - General math primitives diff --git a/crates/ruvector-hyperbolic-hnsw/benches/hyperbolic_bench.rs b/crates/ruvector-hyperbolic-hnsw/benches/hyperbolic_bench.rs new file mode 100644 index 000000000..931789cee --- /dev/null +++ b/crates/ruvector-hyperbolic-hnsw/benches/hyperbolic_bench.rs @@ -0,0 +1,178 @@ +//! Benchmarks for hyperbolic HNSW operations +//! +//! Metrics as specified in evaluation protocol: +//! - p50 and p95 latency +//! - Memory overhead +//! 
- Search recall@k + +use criterion::{black_box, criterion_group, criterion_main, Criterion, BenchmarkId}; +use ruvector_hyperbolic_hnsw::*; + +fn bench_poincare_distance(c: &mut Criterion) { + let dims = [8, 32, 128, 512]; + + let mut group = c.benchmark_group("poincare_distance"); + + for dim in dims { + let x: Vec = (0..dim).map(|i| (i as f32 * 0.01) % 0.9).collect(); + let y: Vec = (0..dim).map(|i| ((i as f32 * 0.02) + 0.1) % 0.9).collect(); + + group.bench_with_input(BenchmarkId::new("dim", dim), &dim, |b, _| { + b.iter(|| poincare_distance(black_box(&x), black_box(&y), 1.0)) + }); + } + + group.finish(); +} + +fn bench_mobius_add(c: &mut Criterion) { + let dims = [8, 32, 128]; + + let mut group = c.benchmark_group("mobius_add"); + + for dim in dims { + let x: Vec = (0..dim).map(|i| (i as f32 * 0.01) % 0.5).collect(); + let y: Vec = (0..dim).map(|i| ((i as f32 * 0.02) + 0.1) % 0.5).collect(); + + group.bench_with_input(BenchmarkId::new("dim", dim), &dim, |b, _| { + b.iter(|| mobius_add(black_box(&x), black_box(&y), 1.0)) + }); + } + + group.finish(); +} + +fn bench_exp_log_map(c: &mut Criterion) { + let dim = 32; + let p: Vec = (0..dim).map(|i| (i as f32 * 0.01) % 0.3).collect(); + let v: Vec = (0..dim).map(|i| ((i as f32 * 0.005) - 0.1) % 0.2).collect(); + let q: Vec = (0..dim).map(|i| ((i as f32 * 0.02) + 0.1) % 0.4).collect(); + + let mut group = c.benchmark_group("exp_log_map"); + + group.bench_function("exp_map", |b| { + b.iter(|| exp_map(black_box(&v), black_box(&p), 1.0)) + }); + + group.bench_function("log_map", |b| { + b.iter(|| log_map(black_box(&q), black_box(&p), 1.0)) + }); + + group.finish(); +} + +fn bench_hnsw_insert(c: &mut Criterion) { + let sizes = [100, 500, 1000]; + + let mut group = c.benchmark_group("hnsw_insert"); + group.sample_size(20); + + for size in sizes { + let vectors: Vec> = (0..size) + .map(|i| vec![ + (i as f32 * 0.01) % 0.8, + ((i as f32 * 0.02) + 0.1) % 0.8, + ]) + .collect(); + + group.bench_with_input(BenchmarkId::new("n", 
size), &vectors, |b, vecs| { + b.iter(|| { + let mut hnsw = HyperbolicHnsw::default_config(); + for v in vecs { + hnsw.insert(v.clone()).unwrap(); + } + }) + }); + } + + group.finish(); +} + +/// Benchmark k-NN search on a 1000-point index for several values of k +fn bench_hnsw_search(c: &mut Criterion) { + let ks = [1, 5, 10, 50]; + + // Build index once + let mut hnsw = HyperbolicHnsw::default_config(); + for i in 0..1000 { + let v = vec![ + (i as f32 * 0.01) % 0.8, + ((i as f32 * 0.02) + 0.1) % 0.8, + ]; + hnsw.insert(v).unwrap(); + } + + let query = vec![0.4, 0.4]; + + let mut group = c.benchmark_group("hnsw_search"); + + for k in ks { + group.bench_with_input(BenchmarkId::new("k", k), &k, |b, &k| { + b.iter(|| hnsw.search(black_box(&query), k)) + }); + } + + group.finish(); +} + +/// Benchmark tangent cache construction for several collection sizes +fn bench_tangent_cache(c: &mut Criterion) { + let sizes = [100, 500, 1000]; + + let mut group = c.benchmark_group("tangent_cache"); + group.sample_size(20); + + for size in sizes { + let points: Vec<Vec<f32>> = (0..size) + .map(|i| vec![ + (i as f32 * 0.01) % 0.8, + ((i as f32 * 0.02) + 0.1) % 0.8, + ]) + .collect(); + // NOTE(review): element type assumed to be usize (point indices) - confirm against TangentCache::new + let indices: Vec<usize> = (0..size).collect(); + + group.bench_with_input(BenchmarkId::new("build", size), &(points.clone(), indices.clone()), |b, (p, i)| { + b.iter(|| TangentCache::new(black_box(p), black_box(i), 1.0)) + }); + } + + group.finish(); +} + +/// Compare standard search against tangent-pruned search on the same 1000-point index +fn bench_search_with_pruning(c: &mut Criterion) { + // Build index with tangent cache + let mut hnsw = HyperbolicHnsw::default_config(); + for i in 0..1000 { + let v = vec![ + (i as f32 * 0.01) % 0.8, + ((i as f32 * 0.02) + 0.1) % 0.8, + ]; + hnsw.insert(v).unwrap(); + } + hnsw.build_tangent_cache().unwrap(); + + let query = vec![0.4, 0.4]; + + let mut group = c.benchmark_group("search_comparison"); + + group.bench_function("standard_search", |b| { + b.iter(|| hnsw.search(black_box(&query), 10)) + }); + + group.bench_function("pruning_search", |b| { + b.iter(|| hnsw.search_with_pruning(black_box(&query), 10)) + }); + + group.finish(); +} + +criterion_group!( + benches, + bench_poincare_distance, + 
bench_mobius_add, + bench_exp_log_map, + bench_hnsw_insert, + bench_hnsw_search, + bench_tangent_cache, + bench_search_with_pruning, +); + +criterion_main!(benches); diff --git a/crates/ruvector-hyperbolic-hnsw/src/error.rs b/crates/ruvector-hyperbolic-hnsw/src/error.rs new file mode 100644 index 000000000..3161c5283 --- /dev/null +++ b/crates/ruvector-hyperbolic-hnsw/src/error.rs @@ -0,0 +1,42 @@ +//! Error types for hyperbolic HNSW operations + +use thiserror::Error; + +/// Errors that can occur during hyperbolic operations +#[derive(Error, Debug, Clone)] +pub enum HyperbolicError { + /// Vector is outside the Poincaré ball + #[error("Vector norm {norm} exceeds ball radius (1/sqrt(c) - eps) for curvature c={curvature}")] + OutsideBall { norm: f32, curvature: f32 }, + + /// Invalid curvature parameter + #[error("Invalid curvature: {0}. Must be positive.")] + InvalidCurvature(f32), + + /// Dimension mismatch between vectors + #[error("Dimension mismatch: expected {expected}, got {got}")] + DimensionMismatch { expected: usize, got: usize }, + + /// Numerical instability detected + #[error("Numerical instability: {0}")] + NumericalInstability(String), + + /// Shard not found + #[error("Shard not found: {0}")] + ShardNotFound(String), + + /// Index out of bounds + #[error("Index {index} out of bounds for size {size}")] + IndexOutOfBounds { index: usize, size: usize }, + + /// Empty collection + #[error("Cannot perform operation on empty collection")] + EmptyCollection, + + /// Search failed + #[error("Search failed: {0}")] + SearchFailed(String), +} + +/// Result type for hyperbolic operations: `Ok(T)` on success, `Err(HyperbolicError)` on failure +pub type HyperbolicResult<T> = Result<T, HyperbolicError>; diff --git a/crates/ruvector-hyperbolic-hnsw/src/hnsw.rs b/crates/ruvector-hyperbolic-hnsw/src/hnsw.rs new file mode 100644 index 000000000..104848a58 --- /dev/null +++ b/crates/ruvector-hyperbolic-hnsw/src/hnsw.rs @@ -0,0 +1,650 @@ +//! HNSW Adapter with Hyperbolic Distance Support +//! +//! 
This module provides HNSW (Hierarchical Navigable Small World) graph +//! implementation optimized for hyperbolic space using the Poincaré ball model. +//! +//! # Key Features +//! +//! - Hyperbolic distance metric for neighbor selection +//! - Tangent space pruning for accelerated search +//! - Configurable curvature per index +//! - Dual-space search (Euclidean fallback) + +use crate::error::{HyperbolicError, HyperbolicResult}; +use crate::poincare::{fused_norms, norm_squared, poincare_distance, poincare_distance_from_norms, project_to_ball, EPS}; +use crate::tangent::TangentCache; +use serde::{Deserialize, Serialize}; + +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +/// Distance metric type +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum DistanceMetric { + /// Poincaré ball hyperbolic distance + Poincare, + /// Standard Euclidean distance + Euclidean, + /// Cosine similarity (converted to distance) + Cosine, + /// Hybrid: Euclidean for pruning, Poincaré for ranking + Hybrid, +} + +/// HNSW configuration for hyperbolic space +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HyperbolicHnswConfig { + /// Maximum number of connections per node (M parameter) + pub max_connections: usize, + /// Maximum connections for layer 0 (M0 = 2*M typically) + pub max_connections_0: usize, + /// Size of dynamic candidate list during construction (ef_construction) + pub ef_construction: usize, + /// Size of dynamic candidate list during search (ef) + pub ef_search: usize, + /// Level multiplier for layer selection (ml = 1/ln(M)) + pub level_mult: f32, + /// Curvature parameter for Poincaré ball + pub curvature: f32, + /// Distance metric + pub metric: DistanceMetric, + /// Pruning factor for tangent space optimization + pub prune_factor: usize, + /// Whether to use tangent space pruning + pub use_tangent_pruning: bool, +} + +impl Default for HyperbolicHnswConfig { + fn default() -> Self { + Self { + max_connections: 16, + 
max_connections_0: 32, + ef_construction: 200, + ef_search: 50, + level_mult: 1.0 / (16.0_f32).ln(), + curvature: 1.0, + metric: DistanceMetric::Poincare, + prune_factor: 10, + use_tangent_pruning: true, + } + } +} + +/// A node in the HNSW graph +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HnswNode { + /// Node ID + pub id: usize, + /// Vector in Poincaré ball + pub vector: Vec<f32>, + /// Connections at each level (level -> neighbor ids) + pub connections: Vec<Vec<usize>>, + /// Maximum level this node appears in + pub level: usize, +} + +impl HnswNode { + /// Create a node with one empty neighbor list for each level in `0..=max_level` + pub fn new(id: usize, vector: Vec<f32>, max_level: usize) -> Self { + let connections = (0..=max_level).map(|_| Vec::new()).collect(); + Self { + id, + vector, + connections, + level: max_level, + } + } +} + +/// Search result with distance +/// +/// Equality and ordering look at `distance` only (`id` is ignored) and follow +/// `f32::total_cmp`, so a NaN distance is ordered instead of causing a panic. +#[derive(Debug, Clone)] +pub struct SearchResult { + pub id: usize, + pub distance: f32, +} + +impl PartialEq for SearchResult { + fn eq(&self, other: &Self) -> bool { + // Kept consistent with `Ord::cmp`, as the `Eq`/`Ord` contracts require + self.cmp(other) == std::cmp::Ordering::Equal + } +} + +impl Eq for SearchResult {} + +impl PartialOrd for SearchResult { + fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { + Some(self.cmp(other)) + } +} + +impl Ord for SearchResult { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + // Total order over f32: never panics, unlike `partial_cmp(..).unwrap()` on NaN + self.distance.total_cmp(&other.distance) + } +} + +/// Hyperbolic HNSW Index +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HyperbolicHnsw { + /// Configuration + pub config: HyperbolicHnswConfig, + /// All nodes in the graph + nodes: Vec<HnswNode>, + /// Entry point node ID + entry_point: Option<usize>, + /// Maximum level in the graph + max_level: usize, + /// Tangent cache for pruning (not serialized) + #[serde(skip)] + tangent_cache: Option<TangentCache>, +} + +impl HyperbolicHnsw { + /// Create a new empty HNSW index + pub fn new(config: HyperbolicHnswConfig) -> Self { + Self { + config, + nodes: Vec::new(), + entry_point: None, + max_level: 0, + tangent_cache: None, + } + } + + /// Create with default configuration + pub 
fn default_config() -> Self { + Self::new(HyperbolicHnswConfig::default()) + } + + /// Get the number of nodes in the index + pub fn len(&self) -> usize { + self.nodes.len() + } + + /// Check if the index is empty + pub fn is_empty(&self) -> bool { + self.nodes.is_empty() + } + + /// Get the dimension of vectors + pub fn dim(&self) -> Option { + self.nodes.first().map(|n| n.vector.len()) + } + + /// Compute distance between two vectors (optimized with fused norms) + #[inline] + fn distance(&self, a: &[f32], b: &[f32]) -> f32 { + match self.config.metric { + DistanceMetric::Poincare | DistanceMetric::Hybrid => { + // Use fused_norms for single-pass computation + let (diff_sq, norm_a_sq, norm_b_sq) = fused_norms(a, b); + poincare_distance_from_norms(diff_sq, norm_a_sq, norm_b_sq, self.config.curvature) + } + DistanceMetric::Euclidean => { + let (diff_sq, _, _) = fused_norms(a, b); + diff_sq.sqrt() + } + DistanceMetric::Cosine => { + let len = a.len().min(b.len()); + let mut dot_ab = 0.0f32; + let mut norm_a_sq = 0.0f32; + let mut norm_b_sq = 0.0f32; + + // Fused computation + for i in 0..len { + let ai = a[i]; + let bi = b[i]; + dot_ab += ai * bi; + norm_a_sq += ai * ai; + norm_b_sq += bi * bi; + } + + let norm_prod = (norm_a_sq * norm_b_sq).sqrt(); + 1.0 - dot_ab / (norm_prod + EPS) + } + } + } + + /// Compute distance with pre-computed query norm (for batch search) + #[inline] + fn distance_with_query_norm(&self, query: &[f32], query_norm_sq: f32, point: &[f32]) -> f32 { + match self.config.metric { + DistanceMetric::Poincare | DistanceMetric::Hybrid => { + let (diff_sq, _, point_norm_sq) = fused_norms(query, point); + poincare_distance_from_norms(diff_sq, query_norm_sq, point_norm_sq, self.config.curvature) + } + _ => self.distance(query, point) + } + } + + /// Generate random level for a new node + fn random_level(&self) -> usize { + let r: f32 = rand::random(); + (-r.ln() * self.config.level_mult) as usize + } + + /// Insert a vector into the index + pub fn 
insert(&mut self, vector: Vec) -> HyperbolicResult { + // Project to ball for safety + let vector = project_to_ball(&vector, self.config.curvature, EPS); + + let id = self.nodes.len(); + let level = self.random_level(); + + // Create new node + let node = HnswNode::new(id, vector.clone(), level); + self.nodes.push(node); + + if self.entry_point.is_none() { + self.entry_point = Some(id); + self.max_level = level; + return Ok(id); + } + + let entry_id = self.entry_point.unwrap(); + + // Search for entry point at top levels + let mut current = entry_id; + for l in (level + 1..=self.max_level).rev() { + current = self.search_layer_single(&vector, current, l)?; + } + + // Insert at levels [0, min(level, max_level)] + let insert_level = level.min(self.max_level); + for l in (0..=insert_level).rev() { + let neighbors = self.search_layer(&vector, current, self.config.ef_construction, l)?; + + // Select best neighbors + let max_conn = if l == 0 { + self.config.max_connections_0 + } else { + self.config.max_connections + }; + + let selected: Vec = neighbors.iter().take(max_conn).map(|r| r.id).collect(); + + // Add bidirectional connections + self.nodes[id].connections[l] = selected.clone(); + + for &neighbor_id in &selected { + self.nodes[neighbor_id].connections[l].push(id); + + // Prune if too many connections + if self.nodes[neighbor_id].connections[l].len() > max_conn { + self.prune_connections(neighbor_id, l, max_conn)?; + } + } + + if !neighbors.is_empty() { + current = neighbors[0].id; + } + } + + // Update entry point if new node has higher level + if level > self.max_level { + self.entry_point = Some(id); + self.max_level = level; + } + + // Invalidate tangent cache + self.tangent_cache = None; + + Ok(id) + } + + /// Insert batch of vectors + pub fn insert_batch(&mut self, vectors: Vec>) -> HyperbolicResult> { + let mut ids = Vec::with_capacity(vectors.len()); + for vector in vectors { + ids.push(self.insert(vector)?); + } + Ok(ids) + } + + /// Search for single 
nearest neighbor at a layer (greedy) + fn search_layer_single(&self, query: &[f32], entry: usize, level: usize) -> HyperbolicResult { + let mut current = entry; + let mut current_dist = self.distance(query, &self.nodes[current].vector); + + loop { + let mut changed = false; + + for &neighbor in &self.nodes[current].connections[level] { + let dist = self.distance(query, &self.nodes[neighbor].vector); + if dist < current_dist { + current_dist = dist; + current = neighbor; + changed = true; + } + } + + if !changed { + break; + } + } + + Ok(current) + } + + /// Search layer with ef candidates + fn search_layer( + &self, + query: &[f32], + entry: usize, + ef: usize, + level: usize, + ) -> HyperbolicResult> { + use std::collections::{BinaryHeap, HashSet}; + + let entry_dist = self.distance(query, &self.nodes[entry].vector); + + let mut visited = HashSet::new(); + visited.insert(entry); + + // Candidates (min-heap by distance) + let mut candidates: BinaryHeap> = BinaryHeap::new(); + candidates.push(std::cmp::Reverse(SearchResult { + id: entry, + distance: entry_dist, + })); + + // Results (max-heap by distance for easy pruning) + let mut results: BinaryHeap = BinaryHeap::new(); + results.push(SearchResult { + id: entry, + distance: entry_dist, + }); + + while let Some(std::cmp::Reverse(current)) = candidates.pop() { + // Check if we can stop early + if let Some(furthest) = results.peek() { + if current.distance > furthest.distance && results.len() >= ef { + break; + } + } + + // Explore neighbors + for &neighbor in &self.nodes[current.id].connections[level] { + if visited.contains(&neighbor) { + continue; + } + visited.insert(neighbor); + + let dist = self.distance(query, &self.nodes[neighbor].vector); + + let should_add = results.len() < ef + || results + .peek() + .map(|r| dist < r.distance) + .unwrap_or(true); + + if should_add { + candidates.push(std::cmp::Reverse(SearchResult { + id: neighbor, + distance: dist, + })); + results.push(SearchResult { + id: neighbor, + 
distance: dist, + }); + + if results.len() > ef { + results.pop(); + } + } + } + } + + let mut result_vec: Vec = results.into_iter().collect(); + result_vec.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); + + Ok(result_vec) + } + + /// Prune connections to keep only the best + fn prune_connections( + &mut self, + node_id: usize, + level: usize, + max_conn: usize, + ) -> HyperbolicResult<()> { + let node_vector = self.nodes[node_id].vector.clone(); + let connections = &self.nodes[node_id].connections[level]; + + let mut scored: Vec<(usize, f32)> = connections + .iter() + .map(|&id| (id, self.distance(&node_vector, &self.nodes[id].vector))) + .collect(); + + scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); + + self.nodes[node_id].connections[level] = + scored.into_iter().take(max_conn).map(|(id, _)| id).collect(); + + Ok(()) + } + + /// Search for k nearest neighbors + pub fn search(&self, query: &[f32], k: usize) -> HyperbolicResult> { + if self.is_empty() { + return Ok(Vec::new()); + } + + let query = project_to_ball(query, self.config.curvature, EPS); + let entry = self.entry_point.unwrap(); + + // Navigate to lowest level from top + let mut current = entry; + for l in (1..=self.max_level).rev() { + current = self.search_layer_single(&query, current, l)?; + } + + // Search at layer 0 with ef_search candidates + let ef = self.config.ef_search.max(k); + let mut results = self.search_layer(&query, current, ef, 0)?; + + results.truncate(k); + Ok(results) + } + + /// Search with tangent space pruning (optimized for hyperbolic) + pub fn search_with_pruning(&self, query: &[f32], k: usize) -> HyperbolicResult> { + // Fall back to regular search if no tangent cache + if self.tangent_cache.is_none() || !self.config.use_tangent_pruning { + return self.search(query, k); + } + + let cache = self.tangent_cache.as_ref().unwrap(); + let query = project_to_ball(query, self.config.curvature, EPS); + + // Phase 1: Fast tangent space filtering + let query_tangent 
= cache.query_tangent(&query); + + let mut candidates: Vec<(usize, f32)> = (0..cache.len()) + .map(|i| { + let tangent_dist = cache.tangent_distance_squared(&query_tangent, i); + (cache.point_indices[i], tangent_dist) + }) + .collect(); + + // Sort by tangent distance + candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); + + // Keep top prune_factor * k candidates + let num_candidates = (k * self.config.prune_factor).min(candidates.len()); + candidates.truncate(num_candidates); + + // Phase 2: Exact Poincaré distance for finalists + let mut results: Vec = candidates + .into_iter() + .map(|(id, _)| { + let dist = self.distance(&query, &self.nodes[id].vector); + SearchResult { id, distance: dist } + }) + .collect(); + + results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); + results.truncate(k); + + Ok(results) + } + + /// Build tangent cache for all points + pub fn build_tangent_cache(&mut self) -> HyperbolicResult<()> { + if self.is_empty() { + return Ok(()); + } + + let vectors: Vec> = self.nodes.iter().map(|n| n.vector.clone()).collect(); + let indices: Vec = (0..self.nodes.len()).collect(); + + self.tangent_cache = Some(TangentCache::new(&vectors, &indices, self.config.curvature)?); + + Ok(()) + } + + /// Get a reference to a node's vector + pub fn get_vector(&self, id: usize) -> Option<&[f32]> { + self.nodes.get(id).map(|n| n.vector.as_slice()) + } + + /// Update curvature and rebuild tangent cache + pub fn set_curvature(&mut self, curvature: f32) -> HyperbolicResult<()> { + if curvature <= 0.0 { + return Err(HyperbolicError::InvalidCurvature(curvature)); + } + + self.config.curvature = curvature; + + // Reproject all vectors + for node in &mut self.nodes { + node.vector = project_to_ball(&node.vector, curvature, EPS); + } + + // Rebuild tangent cache + if self.tangent_cache.is_some() { + self.build_tangent_cache()?; + } + + Ok(()) + } + + /// Get all vectors as a slice + pub fn vectors(&self) -> Vec<&[f32]> { + self.nodes.iter().map(|n| 
n.vector.as_slice()).collect() + } +} + +/// Dual-space index for fallback and mutual ranking fusion +#[derive(Debug)] +pub struct DualSpaceIndex { + /// Hyperbolic index (primary) + pub hyperbolic: HyperbolicHnsw, + /// Euclidean index (fallback) + pub euclidean: HyperbolicHnsw, + /// Fusion weight for hyperbolic results (0-1) + pub fusion_weight: f32, +} + +impl DualSpaceIndex { + /// Create a new dual-space index + pub fn new(curvature: f32, fusion_weight: f32) -> Self { + let mut hyp_config = HyperbolicHnswConfig::default(); + hyp_config.curvature = curvature; + hyp_config.metric = DistanceMetric::Poincare; + + let mut euc_config = HyperbolicHnswConfig::default(); + euc_config.metric = DistanceMetric::Euclidean; + + Self { + hyperbolic: HyperbolicHnsw::new(hyp_config), + euclidean: HyperbolicHnsw::new(euc_config), + fusion_weight: fusion_weight.clamp(0.0, 1.0), + } + } + + /// Insert into both indices + pub fn insert(&mut self, vector: Vec) -> HyperbolicResult { + self.euclidean.insert(vector.clone())?; + self.hyperbolic.insert(vector) + } + + /// Search with mutual ranking fusion + pub fn search(&self, query: &[f32], k: usize) -> HyperbolicResult> { + let hyp_results = self.hyperbolic.search(query, k * 2)?; + let euc_results = self.euclidean.search(query, k * 2)?; + + // Combine and re-rank using fusion + use std::collections::HashMap; + + let mut scores: HashMap = HashMap::new(); + + // Add hyperbolic scores + for (rank, r) in hyp_results.iter().enumerate() { + let score = self.fusion_weight * (1.0 / (rank as f32 + 1.0)); + *scores.entry(r.id).or_insert(0.0) += score; + } + + // Add Euclidean scores + for (rank, r) in euc_results.iter().enumerate() { + let score = (1.0 - self.fusion_weight) * (1.0 / (rank as f32 + 1.0)); + *scores.entry(r.id).or_insert(0.0) += score; + } + + // Sort by combined score (higher is better) + let mut combined: Vec<(usize, f32)> = scores.into_iter().collect(); + combined.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + + // Return 
top k with hyperbolic distances + Ok(combined + .into_iter() + .take(k) + .map(|(id, _)| { + let dist = self.hyperbolic.distance( + query, + self.hyperbolic.get_vector(id).unwrap_or(&[]), + ); + SearchResult { id, distance: dist } + }) + .collect()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_hnsw_insert_search() { + let mut hnsw = HyperbolicHnsw::default_config(); + + // Insert some vectors + for i in 0..10 { + let v = vec![0.1 * i as f32, 0.05 * i as f32]; + hnsw.insert(v).unwrap(); + } + + assert_eq!(hnsw.len(), 10); + + // Search + let query = vec![0.3, 0.15]; + let results = hnsw.search(&query, 3).unwrap(); + + assert_eq!(results.len(), 3); + assert!(results[0].distance <= results[1].distance); + } + + #[test] + fn test_dual_space() { + let mut dual = DualSpaceIndex::new(1.0, 0.5); + + for i in 0..10 { + let v = vec![0.1 * i as f32, 0.05 * i as f32]; + dual.insert(v).unwrap(); + } + + let query = vec![0.3, 0.15]; + let results = dual.search(&query, 3).unwrap(); + + assert_eq!(results.len(), 3); + } +} diff --git a/crates/ruvector-hyperbolic-hnsw/src/lib.rs b/crates/ruvector-hyperbolic-hnsw/src/lib.rs new file mode 100644 index 000000000..03435daf7 --- /dev/null +++ b/crates/ruvector-hyperbolic-hnsw/src/lib.rs @@ -0,0 +1,210 @@ +//! Hyperbolic Embeddings with HNSW Integration for RuVector +//! +//! This crate provides hyperbolic (Poincaré ball) embeddings integrated with +//! HNSW (Hierarchical Navigable Small World) graphs for hierarchy-aware +//! vector search. +//! +//! # Overview +//! +//! Hierarchies compress naturally in hyperbolic space. Taxonomies, catalogs, +//! ICD trees, product facets, org charts, and long-tail tags all fit better +//! than in Euclidean space, which means higher recall on deep leaves without +//! blowing up memory or latency. +//! +//! # Key Features +//! +//! - **Poincaré Ball Model**: Store vectors in the Poincaré ball with proper +//! geometric operations (Möbius addition, exp/log maps) +//! 
- **Tangent Space Pruning**: Prune HNSW candidates with cheap Euclidean +//! distance in tangent space before exact hyperbolic ranking +//! - **Per-Shard Curvature**: Different parts of the hierarchy can have +//! different optimal curvatures +//! - **Dual-Space Index**: Keep a synchronized Euclidean index for fallback +//! and mutual ranking fusion +//! +//! # Quick Start +//! +//! ```rust +//! use ruvector_hyperbolic_hnsw::{HyperbolicHnsw, HyperbolicHnswConfig}; +//! +//! // Create index with default settings +//! let mut index = HyperbolicHnsw::default_config(); +//! +//! // Insert vectors (automatically projected to Poincaré ball) +//! index.insert(vec![0.1, 0.2, 0.3]).unwrap(); +//! index.insert(vec![-0.1, 0.15, 0.25]).unwrap(); +//! index.insert(vec![0.2, -0.1, 0.1]).unwrap(); +//! +//! // Search for nearest neighbors +//! let results = index.search(&[0.15, 0.1, 0.2], 2).unwrap(); +//! for r in results { +//! println!("ID: {}, Distance: {:.4}", r.id, r.distance); +//! } +//! ``` +//! +//! # HNSW Speed Trick +//! +//! The core optimization is: +//! 1. Precompute `u = log_c(x)` at a shard centroid `c` +//! 2. During neighbor selection, use Euclidean `||u_q - u_p||` to prune +//! 3. Run exact Poincaré distance only on top N candidates before final ranking +//! +//! ```rust +//! use ruvector_hyperbolic_hnsw::{HyperbolicHnsw, HyperbolicHnswConfig}; +//! +//! let mut config = HyperbolicHnswConfig::default(); +//! config.use_tangent_pruning = true; +//! config.prune_factor = 10; // Consider 10x candidates in tangent space +//! +//! let mut index = HyperbolicHnsw::new(config); +//! // ... insert vectors ... +//! +//! // Build tangent cache for pruning optimization +//! # index.insert(vec![0.1, 0.2]).unwrap(); +//! index.build_tangent_cache().unwrap(); +//! +//! // Search with pruning +//! let results = index.search_with_pruning(&[0.1, 0.15], 5).unwrap(); +//! ``` +//! +//! # Sharded Index with Per-Shard Curvature +//! +//! ```rust +//! 
use ruvector_hyperbolic_hnsw::{ShardedHyperbolicHnsw, ShardStrategy}; +//! +//! let mut manager = ShardedHyperbolicHnsw::new(1.0); +//! +//! // Insert with hierarchy depth information +//! manager.insert(vec![0.1, 0.2], Some(0)).unwrap(); // Root level +//! manager.insert(vec![0.3, 0.1], Some(3)).unwrap(); // Deeper level +//! +//! // Update curvature for specific shard +//! manager.update_curvature("radius_1", 0.5).unwrap(); +//! +//! // Search across all shards +//! let results = manager.search(&[0.2, 0.15], 5).unwrap(); +//! ``` +//! +//! # Mathematical Operations +//! +//! The `poincare` module provides low-level hyperbolic geometry operations: +//! +//! ```rust +//! use ruvector_hyperbolic_hnsw::poincare::{ +//! mobius_add, exp_map, log_map, poincare_distance, project_to_ball +//! }; +//! +//! let x = vec![0.3, 0.2]; +//! let y = vec![-0.1, 0.4]; +//! let c = 1.0; // Curvature +//! +//! // Möbius addition (hyperbolic vector addition) +//! let z = mobius_add(&x, &y, c); +//! +//! // Geodesic distance in hyperbolic space +//! let d = poincare_distance(&x, &y, c); +//! +//! // Map to tangent space at x +//! let v = log_map(&y, &x, c); +//! +//! // Map back to manifold +//! let y_recovered = exp_map(&v, &x, c); +//! ``` +//! +//! # Numerical Stability +//! +//! All operations include numerical safeguards: +//! - Norm clamping with `eps = 1e-5` +//! - Projection after every update +//! - Stable `acosh` and `log1p` implementations +//! +//! # Feature Flags +//! +//! - `simd`: Enable SIMD acceleration (default) +//! - `parallel`: Enable parallel processing with rayon (default) +//! 
- `wasm`: Enable WebAssembly compatibility + +pub mod error; +pub mod hnsw; +pub mod poincare; +pub mod shard; +pub mod tangent; + +// Re-exports +pub use error::{HyperbolicError, HyperbolicResult}; +pub use hnsw::{ + DistanceMetric, DualSpaceIndex, HnswNode, HyperbolicHnsw, HyperbolicHnswConfig, SearchResult, +}; +pub use poincare::{ + conformal_factor, conformal_factor_from_norm_sq, dot, exp_map, frechet_mean, fused_norms, + hyperbolic_midpoint, log_map, log_map_at_centroid, mobius_add, mobius_add_inplace, + mobius_scalar_mult, norm, norm_squared, parallel_transport, poincare_distance, + poincare_distance_batch, poincare_distance_from_norms, poincare_distance_squared, + project_to_ball, project_to_ball_inplace, PoincareConfig, DEFAULT_CURVATURE, EPS, +}; +pub use shard::{ + CurvatureRegistry, HierarchyMetrics, HyperbolicShard, ShardCurvature, ShardStrategy, + ShardedHyperbolicHnsw, +}; +pub use tangent::{tangent_micro_update, PrunedCandidate, TangentCache, TangentPruner}; + +/// Library version +pub const VERSION: &str = env!("CARGO_PKG_VERSION"); + +/// Prelude for common imports +pub mod prelude { + pub use crate::error::{HyperbolicError, HyperbolicResult}; + pub use crate::hnsw::{HyperbolicHnsw, HyperbolicHnswConfig, SearchResult}; + pub use crate::poincare::{exp_map, log_map, mobius_add, poincare_distance, project_to_ball}; + pub use crate::shard::{ShardedHyperbolicHnsw, ShardStrategy}; + pub use crate::tangent::{TangentCache, TangentPruner}; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_basic_workflow() { + // Create index + let mut index = HyperbolicHnsw::default_config(); + + // Insert vectors + for i in 0..10 { + let v = vec![0.1 * i as f32, 0.05 * i as f32, 0.02 * i as f32]; + index.insert(v).unwrap(); + } + + // Search + let query = vec![0.35, 0.175, 0.07]; + let results = index.search(&query, 3).unwrap(); + + assert_eq!(results.len(), 3); + // Results should be sorted by distance + for i in 1..results.len() { + assert!(results[i 
- 1].distance <= results[i].distance); + } + } + + #[test] + fn test_hierarchy_preservation() { + // Create points at different "depths" + let points: Vec> = (0..20) + .map(|i| { + // Points further from origin represent deeper hierarchy + let depth = i / 4; + let radius = 0.1 + 0.15 * depth as f32; + let angle = (i % 4) as f32 * std::f32::consts::PI / 2.0; + vec![radius * angle.cos(), radius * angle.sin()] + }) + .collect(); + + let depths: Vec = (0..20).map(|i| i / 4).collect(); + + // Compute metrics + let metrics = HierarchyMetrics::compute(&points, &depths, 1.0).unwrap(); + + // Radius should correlate positively with depth + assert!(metrics.radius_depth_correlation > 0.5); + } +} diff --git a/crates/ruvector-hyperbolic-hnsw/src/poincare.rs b/crates/ruvector-hyperbolic-hnsw/src/poincare.rs new file mode 100644 index 000000000..14c6b3b3a --- /dev/null +++ b/crates/ruvector-hyperbolic-hnsw/src/poincare.rs @@ -0,0 +1,627 @@ +//! Poincaré Ball Model Operations for Hyperbolic Geometry +//! +//! This module implements core operations in the Poincaré ball model of hyperbolic space, +//! providing mathematically correct implementations with numerical stability guarantees. +//! +//! # Mathematical Background +//! +//! The Poincaré ball model represents hyperbolic space as the interior of a unit ball +//! in Euclidean space. Points are constrained to satisfy ||x|| < 1/√c where c > 0 is +//! the curvature parameter. +//! +//! # Key Operations +//! +//! - **Möbius Addition**: The hyperbolic analog of vector addition +//! - **Exponential Map**: Maps tangent vectors to the manifold +//! - **Logarithmic Map**: Maps manifold points to tangent space +//! 
- **Poincaré Distance**: The geodesic distance in hyperbolic space + +use crate::error::{HyperbolicError, HyperbolicResult}; +use serde::{Deserialize, Serialize}; + +/// Small epsilon for numerical stability (as specified: eps=1e-5) +pub const EPS: f32 = 1e-5; + +/// Default curvature parameter (negative curvature, c > 0) +pub const DEFAULT_CURVATURE: f32 = 1.0; + +/// Configuration for Poincaré ball operations +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub struct PoincareConfig { + /// Curvature parameter (c > 0 for hyperbolic space) + pub curvature: f32, + /// Numerical stability epsilon + pub eps: f32, + /// Maximum iterations for iterative algorithms (e.g., Fréchet mean) + pub max_iter: usize, + /// Convergence tolerance + pub tol: f32, +} + +impl Default for PoincareConfig { + fn default() -> Self { + Self { + curvature: DEFAULT_CURVATURE, + eps: EPS, + max_iter: 100, + tol: 1e-6, + } + } +} + +impl PoincareConfig { + /// Create configuration with custom curvature + pub fn with_curvature(curvature: f32) -> HyperbolicResult { + if curvature <= 0.0 { + return Err(HyperbolicError::InvalidCurvature(curvature)); + } + Ok(Self { + curvature, + ..Default::default() + }) + } + + /// Maximum allowed norm for points in the ball + #[inline] + pub fn max_norm(&self) -> f32 { + (1.0 / self.curvature.sqrt()) - self.eps + } +} + +// ============================================================================ +// Optimized Core Operations (SIMD-friendly) +// ============================================================================ + +/// Compute the squared Euclidean norm of a slice (optimized with unrolling) +#[inline] +pub fn norm_squared(x: &[f32]) -> f32 { + let len = x.len(); + let mut sum = 0.0f32; + + // Process 4 elements at a time for better SIMD utilization + let chunks = len / 4; + let remainder = len % 4; + + let mut i = 0; + for _ in 0..chunks { + let a = x[i]; + let b = x[i + 1]; + let c = x[i + 2]; + let d = x[i + 3]; + sum += a * a + b * b + c 
/// Compute `||u - v||²`, `||u||²` and `||v||²` in a single pass.
///
/// Only the first `min(u.len(), v.len())` components of each slice are
/// used; any extra trailing components of the longer slice are ignored.
///
/// Returns `(diff_sq, norm_u_sq, norm_v_sq)`.
#[inline]
pub fn fused_norms(u: &[f32], v: &[f32]) -> (f32, f32, f32) {
    let len = u.len().min(v.len());
    // Trim both slices to the common length once, so the loops below need no
    // per-element bounds checks.
    let (u, v) = (&u[..len], &v[..len]);

    let mut diff_sq = 0.0f32;
    let mut norm_u_sq = 0.0f32;
    let mut norm_v_sq = 0.0f32;

    // Process 4 elements at a time; `chunks_exact` guarantees every chunk has
    // exactly four elements.
    for (uc, vc) in u.chunks_exact(4).zip(v.chunks_exact(4)) {
        let (u0, u1, u2, u3) = (uc[0], uc[1], uc[2], uc[3]);
        let (v0, v1, v2, v3) = (vc[0], vc[1], vc[2], vc[3]);
        let (d0, d1, d2, d3) = (u0 - v0, u1 - v1, u2 - v2, u3 - v3);

        diff_sq += d0 * d0 + d1 * d1 + d2 * d2 + d3 * d3;
        norm_u_sq += u0 * u0 + u1 * u1 + u2 * u2 + u3 * u3;
        norm_v_sq += v0 * v0 + v1 * v1 + v2 * v2 + v3 * v3;
    }

    // Up to three leftover elements.
    let tail = len - len % 4;
    for (&ui, &vi) in u[tail..].iter().zip(&v[tail..]) {
        let di = ui - vi;
        diff_sq += di * di;
        norm_u_sq += ui * ui;
        norm_v_sq += vi * vi;
    }

    (diff_sq, norm_u_sq, norm_v_sq)
}
max_norm; + + if norm_sq < max_norm_sq || norm_sq < eps * eps { + x.to_vec() + } else { + let scale = max_norm / norm_sq.sqrt(); + x.iter().map(|&xi| scale * xi).collect() + } +} + +/// Project in-place (avoids allocation when possible) +#[inline] +pub fn project_to_ball_inplace(x: &mut [f32], c: f32, eps: f32) { + let c = c.abs().max(EPS); + let norm_sq = norm_squared(x); + let max_norm = (1.0 / c.sqrt()) - eps; + let max_norm_sq = max_norm * max_norm; + + if norm_sq >= max_norm_sq && norm_sq >= eps * eps { + let scale = max_norm / norm_sq.sqrt(); + for xi in x.iter_mut() { + *xi *= scale; + } + } +} + +/// Compute the conformal factor λ_x at point x +/// +/// λ_x = 2 / (1 - c||x||²) +#[inline] +pub fn conformal_factor(x: &[f32], c: f32) -> f32 { + let norm_sq = norm_squared(x); + 2.0 / (1.0 - c * norm_sq).max(EPS) +} + +/// Conformal factor from pre-computed norm squared +#[inline] +pub fn conformal_factor_from_norm_sq(norm_sq: f32, c: f32) -> f32 { + 2.0 / (1.0 - c * norm_sq).max(EPS) +} + +// ============================================================================ +// Poincaré Distance (Optimized) +// ============================================================================ + +/// Poincaré distance between two points (optimized with fused norms) +/// +/// Uses the formula: +/// d(u, v) = (1/√c) acosh(1 + 2c ||u - v||² / ((1 - c||u||²)(1 - c||v||²))) +#[inline] +pub fn poincare_distance(u: &[f32], v: &[f32], c: f32) -> f32 { + let c = c.abs().max(EPS); + + // Fused computation: single pass for all three norms + let (diff_sq, norm_u_sq, norm_v_sq) = fused_norms(u, v); + + poincare_distance_from_norms(diff_sq, norm_u_sq, norm_v_sq, c) +} + +/// Poincaré distance from pre-computed norms (for batch operations) +#[inline] +pub fn poincare_distance_from_norms(diff_sq: f32, norm_u_sq: f32, norm_v_sq: f32, c: f32) -> f32 { + let sqrt_c = c.sqrt(); + + let lambda_u = (1.0 - c * norm_u_sq).max(EPS); + let lambda_v = (1.0 - c * norm_v_sq).max(EPS); + + let numerator 
/// Inverse hyperbolic cosine with guards for the numerically awkward ranges.
///
/// Returns `0.0` for every `x <= 1.0`. NaN input yields NaN.
#[inline]
fn fast_acosh(x: f32) -> f32 {
    if x <= 1.0 {
        return 0.0;
    }

    if x < 1e6 {
        // With δ = x - 1: acosh(1 + δ) = ln(1 + δ + √(δ(δ + 2))).
        // Working from δ avoids the cancellation in `x * x - 1.0` when x is
        // just above 1, and `ln_1p` keeps precision when the argument is
        // small, so no separate small-δ approximation is needed.
        let delta = x - 1.0;
        (delta + (delta * (delta + 2.0)).sqrt()).ln_1p()
    } else {
        // For very large x: acosh(x) ≈ ln(2x). Written as ln(x) + ln(2) so
        // that doubling x cannot overflow to infinity.
        x.ln() + std::f32::consts::LN_2
    }
}
in 0..chunks { + let (x0, x1, x2, x3) = (x[i], x[i+1], x[i+2], x[i+3]); + let (y0, y1, y2, y3) = (y[i], y[i+1], y[i+2], y[i+3]); + + norm_x_sq += x0 * x0 + x1 * x1 + x2 * x2 + x3 * x3; + norm_y_sq += y0 * y0 + y1 * y1 + y2 * y2 + y3 * y3; + dot_xy += x0 * y0 + x1 * y1 + x2 * y2 + x3 * y3; + i += 4; + } + + for j in 0..remainder { + let xi = x[i + j]; + let yi = y[i + j]; + norm_x_sq += xi * xi; + norm_y_sq += yi * yi; + dot_xy += xi * yi; + } + + // Compute coefficients + let coef_x = 1.0 + 2.0 * c * dot_xy + c * norm_y_sq; + let coef_y = 1.0 - c * norm_x_sq; + let denom = (1.0 + 2.0 * c * dot_xy + c * c * norm_x_sq * norm_y_sq).max(EPS); + let inv_denom = 1.0 / denom; + + // Compute result + let mut result = Vec::with_capacity(len); + for j in 0..len { + result.push((coef_x * x[j] + coef_y * y[j]) * inv_denom); + } + + // Project back into ball + project_to_ball_inplace(&mut result, c, EPS); + result +} + +/// Möbius addition in-place (modifies first argument) +#[inline] +pub fn mobius_add_inplace(x: &mut [f32], y: &[f32], c: f32) { + let c = c.abs().max(EPS); + let len = x.len().min(y.len()); + + let norm_x_sq = norm_squared(x); + let norm_y_sq = norm_squared(y); + let dot_xy = dot(x, y); + + let coef_x = 1.0 + 2.0 * c * dot_xy + c * norm_y_sq; + let coef_y = 1.0 - c * norm_x_sq; + let denom = (1.0 + 2.0 * c * dot_xy + c * c * norm_x_sq * norm_y_sq).max(EPS); + let inv_denom = 1.0 / denom; + + for j in 0..len { + x[j] = (coef_x * x[j] + coef_y * y[j]) * inv_denom; + } + + project_to_ball_inplace(x, c, EPS); +} + +/// Möbius scalar multiplication +/// +/// r ⊗_c x = (1/√c) tanh(r · arctanh(√c ||x||)) · (x / ||x||) +pub fn mobius_scalar_mult(r: f32, x: &[f32], c: f32) -> Vec { + let c = c.abs().max(EPS); + let sqrt_c = c.sqrt(); + let norm_x = norm(x); + + if norm_x < EPS { + return x.to_vec(); + } + + let arctanh_arg = (sqrt_c * norm_x).min(1.0 - EPS); + let arctanh_val = arctanh_arg.atanh(); + let scale = (1.0 / sqrt_c) * (r * arctanh_val).tanh() / norm_x; + + 
x.iter().map(|&xi| scale * xi).collect() +} + +// ============================================================================ +// Exp/Log Maps (Optimized) +// ============================================================================ + +/// Exponential map at point p +/// +/// exp_p(v) = p ⊕_c (tanh(√c λ_p ||v|| / 2) · v / (√c ||v||)) +pub fn exp_map(v: &[f32], p: &[f32], c: f32) -> Vec { + let c = c.abs().max(EPS); + let sqrt_c = c.sqrt(); + + let norm_p_sq = norm_squared(p); + let lambda_p = conformal_factor_from_norm_sq(norm_p_sq, c); + + let norm_v = norm(v); + + if norm_v < EPS { + return p.to_vec(); + } + + let scaled_norm = sqrt_c * lambda_p * norm_v / 2.0; + let coef = scaled_norm.tanh() / (sqrt_c * norm_v); + + let transported: Vec = v.iter().map(|&vi| coef * vi).collect(); + + mobius_add(p, &transported, c) +} + +/// Logarithmic map at point p +/// +/// log_p(y) = (2 / (√c λ_p)) arctanh(√c ||−p ⊕_c y||) · (−p ⊕_c y) / ||−p ⊕_c y|| +pub fn log_map(y: &[f32], p: &[f32], c: f32) -> Vec { + let c = c.abs().max(EPS); + let sqrt_c = c.sqrt(); + + // Compute -p ⊕_c y + let neg_p: Vec = p.iter().map(|&pi| -pi).collect(); + let diff = mobius_add(&neg_p, y, c); + let norm_diff = norm(&diff); + + if norm_diff < EPS { + return vec![0.0; y.len()]; + } + + let norm_p_sq = norm_squared(p); + let lambda_p = conformal_factor_from_norm_sq(norm_p_sq, c); + + let arctanh_arg = (sqrt_c * norm_diff).min(1.0 - EPS); + let coef = (2.0 / (sqrt_c * lambda_p)) * arctanh_arg.atanh() / norm_diff; + + diff.iter().map(|&di| coef * di).collect() +} + +/// Logarithmic map at a shard centroid for tangent space coordinates +pub fn log_map_at_centroid(x: &[f32], centroid: &[f32], c: f32) -> Vec { + log_map(x, centroid, c) +} + +// ============================================================================ +// Fréchet Mean & Utilities +// ============================================================================ + +/// Compute the Fréchet mean (hyperbolic centroid) of points +pub fn 
frechet_mean( + points: &[&[f32]], + weights: Option<&[f32]>, + config: &PoincareConfig, +) -> HyperbolicResult> { + if points.is_empty() { + return Err(HyperbolicError::EmptyCollection); + } + + let dim = points[0].len(); + let c = config.curvature; + + // Validate dimensions + for p in points.iter() { + if p.len() != dim { + return Err(HyperbolicError::DimensionMismatch { + expected: dim, + got: p.len(), + }); + } + } + + // Set up weights + let uniform_weights: Vec; + let w = if let Some(weights) = weights { + if weights.len() != points.len() { + return Err(HyperbolicError::DimensionMismatch { + expected: points.len(), + got: weights.len(), + }); + } + weights + } else { + uniform_weights = vec![1.0 / points.len() as f32; points.len()]; + &uniform_weights + }; + + // Initialize with Euclidean weighted mean, projected to ball + let mut mean = vec![0.0; dim]; + for (point, &weight) in points.iter().zip(w) { + for (i, &val) in point.iter().enumerate() { + mean[i] += weight * val; + } + } + project_to_ball_inplace(&mut mean, c, config.eps); + + // Riemannian gradient descent + let learning_rate = 0.1; + let mut grad = vec![0.0; dim]; + + for _ in 0..config.max_iter { + // Reset gradient + for g in grad.iter_mut() { + *g = 0.0; + } + + // Compute Riemannian gradient + for (point, &weight) in points.iter().zip(w) { + let log_result = log_map(point, &mean, c); + for (i, &val) in log_result.iter().enumerate() { + grad[i] += weight * val; + } + } + + // Check convergence + if norm(&grad) < config.tol { + break; + } + + // Update step + let update: Vec = grad.iter().map(|&g| learning_rate * g).collect(); + mean = exp_map(&update, &mean, c); + } + + Ok(mean) +} + +/// Hyperbolic midpoint between two points +pub fn hyperbolic_midpoint(x: &[f32], y: &[f32], c: f32) -> Vec { + let log_y = log_map(y, x, c); + let half_log: Vec = log_y.iter().map(|&v| 0.5 * v).collect(); + exp_map(&half_log, x, c) +} + +/// Parallel transport a tangent vector from p to q +pub fn 
parallel_transport(v: &[f32], p: &[f32], q: &[f32], c: f32) -> Vec { + let c = c.abs().max(EPS); + + let lambda_p = conformal_factor(p, c); + let lambda_q = conformal_factor(q, c); + let scale = lambda_p / lambda_q; + + v.iter().map(|&vi| scale * vi).collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_project_to_ball() { + let x = vec![0.5, 0.5, 0.5]; + let projected = project_to_ball(&x, 1.0, EPS); + assert!(norm(&projected) < 1.0 - EPS); + } + + #[test] + fn test_mobius_add_identity() { + let x = vec![0.3, 0.2, 0.1]; + let zero = vec![0.0, 0.0, 0.0]; + + let result = mobius_add(&x, &zero, 1.0); + for (a, b) in x.iter().zip(result.iter()) { + assert!((a - b).abs() < 1e-5); + } + } + + #[test] + fn test_exp_log_inverse() { + let p = vec![0.1, 0.2, 0.1]; + let v = vec![0.1, -0.1, 0.05]; + + let q = exp_map(&v, &p, 1.0); + let v_recovered = log_map(&q, &p, 1.0); + + for (a, b) in v.iter().zip(v_recovered.iter()) { + assert!((a - b).abs() < 1e-4); + } + } + + #[test] + fn test_poincare_distance_symmetry() { + let u = vec![0.3, 0.2]; + let v = vec![-0.1, 0.4]; + + let d1 = poincare_distance(&u, &v, 1.0); + let d2 = poincare_distance(&v, &u, 1.0); + + assert!((d1 - d2).abs() < 1e-6); + } + + #[test] + fn test_poincare_distance_origin() { + let origin = vec![0.0, 0.0]; + let d = poincare_distance(&origin, &origin, 1.0); + assert!(d.abs() < 1e-6); + } + + #[test] + fn test_fused_norms() { + let u = vec![0.3, 0.2, 0.1]; + let v = vec![0.1, 0.4, 0.2]; + + let (diff_sq, norm_u_sq, norm_v_sq) = fused_norms(&u, &v); + + let expected_diff_sq: f32 = u.iter().zip(v.iter()) + .map(|(a, b)| (a - b) * (a - b)).sum(); + let expected_norm_u_sq = norm_squared(&u); + let expected_norm_v_sq = norm_squared(&v); + + assert!((diff_sq - expected_diff_sq).abs() < 1e-6); + assert!((norm_u_sq - expected_norm_u_sq).abs() < 1e-6); + assert!((norm_v_sq - expected_norm_v_sq).abs() < 1e-6); + } +} diff --git a/crates/ruvector-hyperbolic-hnsw/src/shard.rs 
b/crates/ruvector-hyperbolic-hnsw/src/shard.rs new file mode 100644 index 000000000..e4a3f5084 --- /dev/null +++ b/crates/ruvector-hyperbolic-hnsw/src/shard.rs @@ -0,0 +1,575 @@ +//! Shard Management with Curvature Registry +//! +//! This module implements per-shard curvature management for hierarchical data. +//! Different parts of the hierarchy may have different optimal curvatures. +//! +//! # Features +//! +//! - Per-shard curvature configuration +//! - Hot reload of curvature parameters +//! - Canary testing for curvature updates +//! - Hierarchy preservation metrics + +use crate::error::{HyperbolicError, HyperbolicResult}; +use crate::hnsw::{HyperbolicHnsw, HyperbolicHnswConfig, SearchResult}; +use crate::poincare::{frechet_mean, poincare_distance, project_to_ball, PoincareConfig, EPS}; +use crate::tangent::TangentCache; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +#[cfg(feature = "parallel")] +use rayon::prelude::*; + +/// Curvature configuration for a shard +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ShardCurvature { + /// Current active curvature + pub current: f32, + /// Canary curvature (for testing) + pub canary: Option, + /// Traffic percentage for canary (0-100) + pub canary_traffic: u8, + /// Learned curvature from data + pub learned: Option, + /// Last update timestamp + pub updated_at: i64, +} + +impl Default for ShardCurvature { + fn default() -> Self { + Self { + current: 1.0, + canary: None, + canary_traffic: 0, + learned: None, + updated_at: 0, + } + } +} + +impl ShardCurvature { + /// Get the effective curvature (considering canary traffic) + pub fn effective(&self, use_canary: bool) -> f32 { + if use_canary && self.canary.is_some() && self.canary_traffic > 0 { + self.canary.unwrap() + } else { + self.current + } + } + + /// Promote canary to current + pub fn promote_canary(&mut self) { + if let Some(c) = self.canary { + self.current = c; + self.canary = None; + self.canary_traffic = 0; + } + } + + 
/// Rollback canary + pub fn rollback_canary(&mut self) { + self.canary = None; + self.canary_traffic = 0; + } +} + +/// Curvature registry for managing per-shard curvatures +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct CurvatureRegistry { + /// Shard curvatures by shard ID + pub shards: HashMap, + /// Global default curvature + pub default_curvature: f32, + /// Registry version (for hot reload) + pub version: u64, +} + +impl CurvatureRegistry { + /// Create a new registry with default curvature + pub fn new(default_curvature: f32) -> Self { + Self { + shards: HashMap::new(), + default_curvature, + version: 0, + } + } + + /// Get curvature for a shard + pub fn get(&self, shard_id: &str) -> f32 { + self.shards + .get(shard_id) + .map(|s| s.current) + .unwrap_or(self.default_curvature) + } + + /// Get curvature with canary consideration + pub fn get_effective(&self, shard_id: &str, use_canary: bool) -> f32 { + self.shards + .get(shard_id) + .map(|s| s.effective(use_canary)) + .unwrap_or(self.default_curvature) + } + + /// Set curvature for a shard + pub fn set(&mut self, shard_id: &str, curvature: f32) { + let entry = self.shards.entry(shard_id.to_string()).or_default(); + entry.current = curvature; + entry.updated_at = chrono_timestamp(); + self.version += 1; + } + + /// Set canary curvature + pub fn set_canary(&mut self, shard_id: &str, curvature: f32, traffic: u8) { + let entry = self.shards.entry(shard_id.to_string()).or_default(); + entry.canary = Some(curvature); + entry.canary_traffic = traffic.min(100); + entry.updated_at = chrono_timestamp(); + self.version += 1; + } + + /// Promote all canaries + pub fn promote_all_canaries(&mut self) { + for (_, shard) in self.shards.iter_mut() { + shard.promote_canary(); + } + self.version += 1; + } + + /// Rollback all canaries + pub fn rollback_all_canaries(&mut self) { + for (_, shard) in self.shards.iter_mut() { + shard.rollback_canary(); + } + self.version += 1; + } + + /// Record learned 
curvature + pub fn set_learned(&mut self, shard_id: &str, curvature: f32) { + let entry = self.shards.entry(shard_id.to_string()).or_default(); + entry.learned = Some(curvature); + entry.updated_at = chrono_timestamp(); + } +} + +fn chrono_timestamp() -> i64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs() as i64) + .unwrap_or(0) +} + +/// A single shard in the sharded HNSW system +#[derive(Debug)] +pub struct HyperbolicShard { + /// Shard ID + pub id: String, + /// HNSW index for this shard + pub index: HyperbolicHnsw, + /// Tangent cache + pub tangent_cache: Option, + /// Shard centroid + pub centroid: Vec, + /// Hierarchy depth range (min, max) + pub depth_range: (usize, usize), + /// Number of vectors in shard + pub count: usize, +} + +impl HyperbolicShard { + /// Create a new shard + pub fn new(id: String, curvature: f32) -> Self { + let mut config = HyperbolicHnswConfig::default(); + config.curvature = curvature; + + Self { + id, + index: HyperbolicHnsw::new(config), + tangent_cache: None, + centroid: Vec::new(), + depth_range: (0, 0), + count: 0, + } + } + + /// Insert a vector + pub fn insert(&mut self, vector: Vec) -> HyperbolicResult { + let id = self.index.insert(vector)?; + self.count += 1; + // Invalidate tangent cache + self.tangent_cache = None; + Ok(id) + } + + /// Build tangent cache + pub fn build_cache(&mut self) -> HyperbolicResult<()> { + if self.count == 0 { + return Ok(()); + } + + let vectors: Vec> = self + .index + .vectors() + .iter() + .map(|v| v.to_vec()) + .collect(); + let indices: Vec = (0..vectors.len()).collect(); + + self.tangent_cache = Some(TangentCache::new( + &vectors, + &indices, + self.index.config.curvature, + )?); + + if let Some(cache) = &self.tangent_cache { + self.centroid = cache.centroid.clone(); + } + + Ok(()) + } + + /// Search with tangent pruning + pub fn search(&self, query: &[f32], k: usize) -> HyperbolicResult> { + self.index.search(query, k) + } + + /// Update 
curvature + pub fn set_curvature(&mut self, curvature: f32) -> HyperbolicResult<()> { + self.index.set_curvature(curvature)?; + // Rebuild cache with new curvature + if self.tangent_cache.is_some() { + self.build_cache()?; + } + Ok(()) + } +} + +/// Sharded hyperbolic HNSW manager +#[derive(Debug)] +pub struct ShardedHyperbolicHnsw { + /// Shards by ID + pub shards: HashMap, + /// Curvature registry + pub registry: CurvatureRegistry, + /// Global ID to shard mapping + pub id_to_shard: Vec<(String, usize)>, + /// Shard assignment strategy + pub strategy: ShardStrategy, +} + +/// Strategy for assigning vectors to shards +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ShardStrategy { + /// Assign by hash + Hash, + /// Assign by hierarchy depth + Depth, + /// Assign by radius (distance from origin) + Radius, + /// Round-robin + RoundRobin, +} + +impl Default for ShardStrategy { + fn default() -> Self { + Self::Radius + } +} + +impl ShardedHyperbolicHnsw { + /// Create a new sharded manager + pub fn new(default_curvature: f32) -> Self { + Self { + shards: HashMap::new(), + registry: CurvatureRegistry::new(default_curvature), + id_to_shard: Vec::new(), + strategy: ShardStrategy::default(), + } + } + + /// Create or get a shard + pub fn get_or_create_shard(&mut self, shard_id: &str) -> &mut HyperbolicShard { + let curvature = self.registry.get(shard_id); + self.shards + .entry(shard_id.to_string()) + .or_insert_with(|| HyperbolicShard::new(shard_id.to_string(), curvature)) + } + + /// Determine shard for a vector + pub fn assign_shard(&self, vector: &[f32], depth: Option) -> String { + match self.strategy { + ShardStrategy::Hash => { + let hash: u64 = vector.iter().fold(0u64, |acc, &v| { + acc.wrapping_add((v.to_bits() as u64).wrapping_mul(31)) + }); + format!("shard_{}", hash % (self.shards.len().max(1) as u64)) + } + ShardStrategy::Depth => { + let d = depth.unwrap_or(0); + format!("depth_{}", d / 10) // Group by depth buckets + } + ShardStrategy::Radius => { + 
let radius: f32 = vector.iter().map(|v| v * v).sum::().sqrt(); + let bucket = (radius * 10.0) as usize; + format!("radius_{}", bucket) + } + ShardStrategy::RoundRobin => { + let idx = self.id_to_shard.len() % self.shards.len().max(1); + self.shards + .keys() + .nth(idx) + .cloned() + .unwrap_or_else(|| "default".to_string()) + } + } + } + + /// Insert vector with automatic shard assignment + pub fn insert(&mut self, vector: Vec, depth: Option) -> HyperbolicResult { + let shard_id = self.assign_shard(&vector, depth); + let shard = self.get_or_create_shard(&shard_id); + let local_id = shard.insert(vector)?; + + let global_id = self.id_to_shard.len(); + self.id_to_shard.push((shard_id, local_id)); + + Ok(global_id) + } + + /// Insert into specific shard + pub fn insert_to_shard( + &mut self, + shard_id: &str, + vector: Vec, + ) -> HyperbolicResult { + let shard = self.get_or_create_shard(shard_id); + let local_id = shard.insert(vector)?; + + let global_id = self.id_to_shard.len(); + self.id_to_shard.push((shard_id.to_string(), local_id)); + + Ok(global_id) + } + + /// Search across all shards + pub fn search(&self, query: &[f32], k: usize) -> HyperbolicResult> { + let mut all_results: Vec<(usize, SearchResult)> = Vec::new(); + + for (shard_id, shard) in &self.shards { + let results = shard.search(query, k)?; + for result in results { + // Map local ID to global ID + if let Some((global_id, _)) = self.id_to_shard.iter().enumerate().find(|(_, (s, l))| s == shard_id && *l == result.id) { + all_results.push((global_id, result)); + } + } + } + + // Sort by distance and take top k + all_results.sort_by(|a, b| a.1.distance.partial_cmp(&b.1.distance).unwrap()); + all_results.truncate(k); + + Ok(all_results) + } + + /// Build all tangent caches + pub fn build_caches(&mut self) -> HyperbolicResult<()> { + for shard in self.shards.values_mut() { + shard.build_cache()?; + } + Ok(()) + } + + /// Update curvature for a shard + pub fn update_curvature(&mut self, shard_id: &str, 
curvature: f32) -> HyperbolicResult<()> { + self.registry.set(shard_id, curvature); + if let Some(shard) = self.shards.get_mut(shard_id) { + shard.set_curvature(curvature)?; + } + Ok(()) + } + + /// Hot reload curvatures from registry + pub fn reload_curvatures(&mut self) -> HyperbolicResult<()> { + for (shard_id, shard) in self.shards.iter_mut() { + let curvature = self.registry.get(shard_id); + shard.set_curvature(curvature)?; + } + Ok(()) + } + + /// Get total vector count + pub fn len(&self) -> usize { + self.id_to_shard.len() + } + + /// Check if empty + pub fn is_empty(&self) -> bool { + self.id_to_shard.is_empty() + } + + /// Get number of shards + pub fn num_shards(&self) -> usize { + self.shards.len() + } +} + +/// Metrics for hierarchy preservation +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct HierarchyMetrics { + /// Spearman correlation between radius and depth + pub radius_depth_correlation: f32, + /// Average distance distortion + pub distance_distortion: f32, + /// Ancestor preservation (AUPRC) + pub ancestor_auprc: f32, + /// Mean rank + pub mean_rank: f32, + /// NDCG scores + pub ndcg: HashMap, +} + +impl HierarchyMetrics { + /// Compute hierarchy metrics + pub fn compute( + points: &[Vec], + depths: &[usize], + curvature: f32, + ) -> HyperbolicResult { + if points.is_empty() || points.len() != depths.len() { + return Err(HyperbolicError::EmptyCollection); + } + + // Compute radii + let radii: Vec = points + .iter() + .map(|p| p.iter().map(|v| v * v).sum::().sqrt()) + .collect(); + + // Spearman correlation between radius and depth + let radius_depth_correlation = spearman_correlation(&radii, depths); + + // Distance distortion (sample-based for efficiency) + let sample_size = points.len().min(100); + let mut distortion_sum = 0.0; + let mut distortion_count = 0; + + for i in 0..sample_size { + for j in (i + 1)..sample_size { + let hyp_dist = poincare_distance(&points[i], &points[j], curvature); + let depth_diff = (depths[i] 
as f32 - depths[j] as f32).abs(); + + if depth_diff > 0.0 { + distortion_sum += (hyp_dist - depth_diff).abs() / depth_diff; + distortion_count += 1; + } + } + } + + let distance_distortion = if distortion_count > 0 { + distortion_sum / distortion_count as f32 + } else { + 0.0 + }; + + Ok(Self { + radius_depth_correlation, + distance_distortion, + ancestor_auprc: 0.0, // Requires ground truth + mean_rank: 0.0, // Requires ground truth + ndcg: HashMap::new(), + }) + } +} + +/// Compute Spearman rank correlation +fn spearman_correlation(x: &[f32], y: &[usize]) -> f32 { + if x.len() != y.len() || x.is_empty() { + return 0.0; + } + + let n = x.len(); + + // Compute ranks for x + let mut x_indexed: Vec<(usize, f32)> = x.iter().cloned().enumerate().collect(); + x_indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); + let mut x_ranks = vec![0.0; n]; + for (rank, (idx, _)) in x_indexed.iter().enumerate() { + x_ranks[*idx] = rank as f32; + } + + // Compute ranks for y + let mut y_indexed: Vec<(usize, usize)> = y.iter().cloned().enumerate().collect(); + y_indexed.sort_by_key(|a| a.1); + let mut y_ranks = vec![0.0; n]; + for (rank, (idx, _)) in y_indexed.iter().enumerate() { + y_ranks[*idx] = rank as f32; + } + + // Compute Spearman correlation + let mean_x: f32 = x_ranks.iter().sum::() / n as f32; + let mean_y: f32 = y_ranks.iter().sum::() / n as f32; + + let mut cov = 0.0; + let mut var_x = 0.0; + let mut var_y = 0.0; + + for i in 0..n { + let dx = x_ranks[i] - mean_x; + let dy = y_ranks[i] - mean_y; + cov += dx * dy; + var_x += dx * dx; + var_y += dy * dy; + } + + if var_x == 0.0 || var_y == 0.0 { + return 0.0; + } + + cov / (var_x * var_y).sqrt() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_curvature_registry() { + let mut registry = CurvatureRegistry::new(1.0); + + registry.set("shard_1", 0.5); + assert_eq!(registry.get("shard_1"), 0.5); + assert_eq!(registry.get("shard_2"), 1.0); // Default + + registry.set_canary("shard_1", 0.3, 50); + 
assert_eq!(registry.get_effective("shard_1", false), 0.5); + assert_eq!(registry.get_effective("shard_1", true), 0.3); + } + + #[test] + fn test_sharded_hnsw() { + let mut manager = ShardedHyperbolicHnsw::new(1.0); + + for i in 0..20 { + let v = vec![0.1 * i as f32, 0.05 * i as f32]; + manager.insert(v, Some(i / 5)).unwrap(); + } + + assert_eq!(manager.len(), 20); + + let query = vec![0.3, 0.15]; + let results = manager.search(&query, 5).unwrap(); + assert!(!results.is_empty()); + } + + #[test] + fn test_spearman() { + let x = vec![1.0, 2.0, 3.0, 4.0, 5.0]; + let y = vec![1, 2, 3, 4, 5]; + let corr = spearman_correlation(&x, &y); + assert!((corr - 1.0).abs() < 0.01); + + let y_rev = vec![5, 4, 3, 2, 1]; + let corr_rev = spearman_correlation(&x, &y_rev); + assert!((corr_rev + 1.0).abs() < 0.01); + } +} diff --git a/crates/ruvector-hyperbolic-hnsw/src/tangent.rs b/crates/ruvector-hyperbolic-hnsw/src/tangent.rs new file mode 100644 index 000000000..07bb9cf24 --- /dev/null +++ b/crates/ruvector-hyperbolic-hnsw/src/tangent.rs @@ -0,0 +1,348 @@ +//! Tangent Space Operations for HNSW Pruning Optimization +//! +//! This module implements the key optimization for hyperbolic HNSW: +//! - Precompute tangent space coordinates at shard centroids +//! - Use cheap Euclidean distance in tangent space for pruning +//! - Only compute exact Poincaré distance for final ranking +//! +//! # HNSW Speed Trick +//! +//! The core insight is that for points near a centroid c: +//! 1. Map points to tangent space: u = log_c(x) +//! 2. Euclidean distance ||u_q - u_p|| approximates hyperbolic distance +//! 3. Prune candidates using fast Euclidean comparisons +//! 4. 
Rank final top-N candidates with exact Poincaré distance + +use crate::error::{HyperbolicError, HyperbolicResult}; +use crate::poincare::{ + conformal_factor, frechet_mean, log_map, norm, norm_squared, poincare_distance, + project_to_ball, PoincareConfig, EPS, +}; +use serde::{Deserialize, Serialize}; + +/// Tangent space cache for a shard +/// +/// Stores precomputed tangent coordinates for fast pruning. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TangentCache { + /// Centroid point (base of tangent space) + pub centroid: Vec, + /// Precomputed tangent coordinates for all points in shard + pub tangent_coords: Vec>, + /// Original point indices + pub point_indices: Vec, + /// Curvature parameter + pub curvature: f32, + /// Cached conformal factor at centroid + conformal: f32, +} + +impl TangentCache { + /// Create a new tangent cache for a shard + /// + /// # Arguments + /// * `points` - Points in the shard (Poincaré ball coordinates) + /// * `indices` - Original indices of the points + /// * `curvature` - Curvature parameter + pub fn new(points: &[Vec], indices: &[usize], curvature: f32) -> HyperbolicResult { + if points.is_empty() { + return Err(HyperbolicError::EmptyCollection); + } + + let config = PoincareConfig::with_curvature(curvature)?; + + // Compute centroid as Fréchet mean + let point_refs: Vec<&[f32]> = points.iter().map(|p| p.as_slice()).collect(); + let centroid = frechet_mean(&point_refs, None, &config)?; + + // Precompute tangent coordinates + let tangent_coords: Vec> = points + .iter() + .map(|p| log_map(p, &centroid, curvature)) + .collect(); + + let conformal = conformal_factor(&centroid, curvature); + + Ok(Self { + centroid, + tangent_coords, + point_indices: indices.to_vec(), + curvature, + conformal, + }) + } + + /// Create from centroid directly (for incremental updates) + pub fn from_centroid( + centroid: Vec, + points: &[Vec], + indices: &[usize], + curvature: f32, + ) -> HyperbolicResult { + let tangent_coords: Vec> = points + 
.iter() + .map(|p| log_map(p, &centroid, curvature)) + .collect(); + + let conformal = conformal_factor(&centroid, curvature); + + Ok(Self { + centroid, + tangent_coords, + point_indices: indices.to_vec(), + curvature, + conformal, + }) + } + + /// Get tangent coordinates for a query point + pub fn query_tangent(&self, query: &[f32]) -> Vec { + log_map(query, &self.centroid, self.curvature) + } + + /// Fast Euclidean distance in tangent space (for pruning) + #[inline] + pub fn tangent_distance_squared(&self, query_tangent: &[f32], idx: usize) -> f32 { + if idx >= self.tangent_coords.len() { + return f32::MAX; + } + + let p = &self.tangent_coords[idx]; + query_tangent + .iter() + .zip(p.iter()) + .map(|(&a, &b)| (a - b) * (a - b)) + .sum() + } + + /// Exact Poincaré distance for final ranking + pub fn exact_distance(&self, query: &[f32], idx: usize, points: &[Vec]) -> f32 { + if idx >= points.len() { + return f32::MAX; + } + poincare_distance(query, &points[idx], self.curvature) + } + + /// Add a new point to the cache (for incremental updates) + pub fn add_point(&mut self, point: &[f32], index: usize) { + let tangent = log_map(point, &self.centroid, self.curvature); + self.tangent_coords.push(tangent); + self.point_indices.push(index); + } + + /// Update centroid and recompute all tangent coordinates + pub fn recompute_centroid(&mut self, points: &[Vec]) -> HyperbolicResult<()> { + if points.is_empty() { + return Err(HyperbolicError::EmptyCollection); + } + + let config = PoincareConfig::with_curvature(self.curvature)?; + let point_refs: Vec<&[f32]> = points.iter().map(|p| p.as_slice()).collect(); + self.centroid = frechet_mean(&point_refs, None, &config)?; + + self.tangent_coords = points + .iter() + .map(|p| log_map(p, &self.centroid, self.curvature)) + .collect(); + + self.conformal = conformal_factor(&self.centroid, self.curvature); + + Ok(()) + } + + /// Get number of points in cache + pub fn len(&self) -> usize { + self.tangent_coords.len() + } + + /// Check if cache 
is empty + pub fn is_empty(&self) -> bool { + self.tangent_coords.is_empty() + } + + /// Get the dimension of the tangent space + pub fn dim(&self) -> usize { + self.centroid.len() + } +} + +/// Tangent space pruning result +#[derive(Debug, Clone)] +pub struct PrunedCandidate { + /// Original index + pub index: usize, + /// Tangent space distance (for initial ranking) + pub tangent_dist: f32, + /// Exact Poincaré distance (computed lazily) + pub exact_dist: Option, +} + +/// Tangent space pruner for HNSW neighbor selection +/// +/// Implements the two-phase search: +/// 1. Fast pruning using Euclidean distance in tangent space +/// 2. Exact ranking using Poincaré distance for top candidates +pub struct TangentPruner { + /// Tangent caches for each shard + caches: Vec, + /// Number of candidates to consider in exact phase + top_n: usize, + /// Pruning factor (how many candidates to keep from tangent phase) + prune_factor: usize, +} + +impl TangentPruner { + /// Create a new pruner + /// + /// # Arguments + /// * `top_n` - Number of final results + /// * `prune_factor` - Multiplier for candidates to consider (e.g., 10 means consider 10*top_n) + pub fn new(top_n: usize, prune_factor: usize) -> Self { + Self { + caches: Vec::new(), + top_n, + prune_factor, + } + } + + /// Add a shard cache + pub fn add_cache(&mut self, cache: TangentCache) { + self.caches.push(cache); + } + + /// Get shard caches + pub fn caches(&self) -> &[TangentCache] { + &self.caches + } + + /// Get mutable shard caches + pub fn caches_mut(&mut self) -> &mut [TangentCache] { + &mut self.caches + } + + /// Search across all shards with tangent pruning + /// + /// Returns top_n candidates sorted by exact Poincaré distance. 
+ pub fn search( + &self, + query: &[f32], + points: &[Vec], + curvature: f32, + ) -> Vec { + let num_prune = self.top_n * self.prune_factor; + let mut candidates: Vec = Vec::with_capacity(num_prune); + + // Phase 1: Tangent space pruning across all shards + for cache in &self.caches { + let query_tangent = cache.query_tangent(query); + + for (local_idx, &global_idx) in cache.point_indices.iter().enumerate() { + let tangent_dist = cache.tangent_distance_squared(&query_tangent, local_idx); + candidates.push(PrunedCandidate { + index: global_idx, + tangent_dist, + exact_dist: None, + }); + } + } + + // Sort by tangent distance and keep top prune_factor * top_n + candidates.sort_by(|a, b| a.tangent_dist.partial_cmp(&b.tangent_dist).unwrap()); + candidates.truncate(num_prune); + + // Phase 2: Exact Poincaré distance for finalists + for candidate in &mut candidates { + if candidate.index < points.len() { + candidate.exact_dist = + Some(poincare_distance(query, &points[candidate.index], curvature)); + } + } + + // Sort by exact distance and return top_n + candidates.sort_by(|a, b| { + a.exact_dist + .unwrap_or(f32::MAX) + .partial_cmp(&b.exact_dist.unwrap_or(f32::MAX)) + .unwrap() + }); + candidates.truncate(self.top_n); + + candidates + } +} + +/// Compute micro tangent update for incremental operations +/// +/// For small updates (reflex loop), compute tangent-space delta +/// that keeps the point inside the ball. 
+pub fn tangent_micro_update( + point: &[f32], + delta: &[f32], + centroid: &[f32], + curvature: f32, + max_step: f32, +) -> Vec { + // Get current tangent coordinates + let tangent = log_map(point, centroid, curvature); + + // Apply bounded delta in tangent space + let delta_norm = norm(delta); + let scale = if delta_norm > max_step { + max_step / delta_norm + } else { + 1.0 + }; + + let new_tangent: Vec = tangent + .iter() + .zip(delta.iter()) + .map(|(&t, &d)| t + scale * d) + .collect(); + + // Map back to ball and project + let new_point = crate::poincare::exp_map(&new_tangent, centroid, curvature); + project_to_ball(&new_point, curvature, EPS) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tangent_cache_creation() { + let points = vec![ + vec![0.1, 0.2, 0.1], + vec![-0.1, 0.15, 0.05], + vec![0.2, -0.1, 0.1], + ]; + let indices: Vec = (0..3).collect(); + + let cache = TangentCache::new(&points, &indices, 1.0).unwrap(); + + assert_eq!(cache.len(), 3); + assert_eq!(cache.dim(), 3); + } + + #[test] + fn test_tangent_pruning() { + let points = vec![ + vec![0.1, 0.2], + vec![-0.1, 0.15], + vec![0.2, -0.1], + vec![0.05, 0.05], + ]; + let indices: Vec = (0..4).collect(); + + let cache = TangentCache::new(&points, &indices, 1.0).unwrap(); + + let mut pruner = TangentPruner::new(2, 2); + pruner.add_cache(cache); + + let query = vec![0.08, 0.1]; + let results = pruner.search(&query, &points, 1.0); + + assert_eq!(results.len(), 2); + // Results should be sorted by exact distance + assert!(results[0].exact_dist.unwrap() <= results[1].exact_dist.unwrap()); + } +} diff --git a/crates/ruvector-hyperbolic-hnsw/tests/math_tests.rs b/crates/ruvector-hyperbolic-hnsw/tests/math_tests.rs new file mode 100644 index 000000000..63c5f8bce --- /dev/null +++ b/crates/ruvector-hyperbolic-hnsw/tests/math_tests.rs @@ -0,0 +1,531 @@ +//! Comprehensive Mathematical Correctness Tests for Hyperbolic Operations +//! +//! 
These tests verify the mathematical properties of Poincaré ball operations +//! as specified in the evaluation protocol. + +use ruvector_hyperbolic_hnsw::poincare::*; +use ruvector_hyperbolic_hnsw::tangent::*; +use ruvector_hyperbolic_hnsw::hnsw::*; +use ruvector_hyperbolic_hnsw::shard::*; + +// ============================================================================ +// Poincaré Ball Properties +// ============================================================================ + +#[test] +fn test_mobius_add_identity() { + // x ⊕ 0 = x (right identity) + let x = vec![0.3, 0.2, 0.1]; + let zero = vec![0.0, 0.0, 0.0]; + + let result = mobius_add(&x, &zero, 1.0); + + for (a, b) in x.iter().zip(result.iter()) { + assert!((a - b).abs() < 1e-5, "Right identity failed"); + } +} + +#[test] +fn test_mobius_add_inverse() { + // x ⊕ (-x) ≈ 0 (inverse element) + let x = vec![0.3, 0.2]; + let neg_x: Vec = x.iter().map(|v| -v).collect(); + + let result = mobius_add(&x, &neg_x, 1.0); + let result_norm = norm(&result); + + // Result should be close to zero + assert!(result_norm < 0.1, "Inverse element failed: norm = {}", result_norm); +} + +#[test] +fn test_mobius_add_gyrocommutative() { + // Gyrocommutative: x ⊕ y ≈ gyr[x,y](y ⊕ x) (holds for small vectors) + let x = vec![0.1, 0.05]; + let y = vec![0.08, -0.03]; + + let xy = mobius_add(&x, &y, 1.0); + let yx = mobius_add(&y, &x, 1.0); + + // For small vectors, these should be similar + let diff: f32 = xy.iter().zip(yx.iter()).map(|(a, b)| (a - b).abs()).sum(); + assert!(diff < 0.5, "Gyrocommutative property check: diff = {}", diff); +} + +#[test] +fn test_exp_log_inverse() { + // log_p(exp_p(v)) = v (inverse relationship) + let p = vec![0.1, 0.2, 0.1]; + let v = vec![0.1, -0.1, 0.05]; + + let q = exp_map(&v, &p, 1.0); + let v_recovered = log_map(&q, &p, 1.0); + + for (a, b) in v.iter().zip(v_recovered.iter()) { + assert!((a - b).abs() < 1e-4, "exp-log inverse failed: expected {}, got {}", a, b); + } +} + +#[test] +fn 
test_log_exp_inverse() { + // exp_p(log_p(q)) = q (inverse relationship) + let p = vec![0.1, 0.15]; + let q = vec![0.2, -0.1]; + + let v = log_map(&q, &p, 1.0); + let q_recovered = exp_map(&v, &p, 1.0); + + for (a, b) in q.iter().zip(q_recovered.iter()) { + assert!((a - b).abs() < 1e-4, "log-exp inverse failed: expected {}, got {}", a, b); + } +} + +#[test] +fn test_distance_symmetry() { + // d(x, y) = d(y, x) + let x = vec![0.3, 0.2, 0.1]; + let y = vec![-0.1, 0.4, 0.2]; + + let d1 = poincare_distance(&x, &y, 1.0); + let d2 = poincare_distance(&y, &x, 1.0); + + assert!((d1 - d2).abs() < 1e-6, "Symmetry failed: {} vs {}", d1, d2); +} + +#[test] +fn test_distance_identity() { + // d(x, x) = 0 + let x = vec![0.3, 0.2, 0.1]; + let d = poincare_distance(&x, &x, 1.0); + + assert!(d.abs() < 1e-6, "Identity of indiscernibles failed: d = {}", d); +} + +#[test] +fn test_distance_non_negative() { + // d(x, y) >= 0 + let x = vec![0.3, 0.2]; + let y = vec![-0.1, 0.4]; + + let d = poincare_distance(&x, &y, 1.0); + assert!(d >= 0.0, "Non-negativity failed: d = {}", d); +} + +#[test] +fn test_distance_triangle_inequality() { + // d(x, z) <= d(x, y) + d(y, z) + let x = vec![0.1, 0.2]; + let y = vec![0.3, 0.1]; + let z = vec![-0.1, 0.35]; + + let dxz = poincare_distance(&x, &z, 1.0); + let dxy = poincare_distance(&x, &y, 1.0); + let dyz = poincare_distance(&y, &z, 1.0); + + assert!(dxz <= dxy + dyz + 1e-5, + "Triangle inequality failed: {} > {} + {}", dxz, dxy, dyz); +} + +// ============================================================================ +// Numerical Stability +// ============================================================================ + +#[test] +fn test_projection_keeps_points_inside() { + // All projected points should satisfy ||x|| < 1/sqrt(c) - eps + let test_points = vec![ + vec![0.5, 0.5, 0.5], + vec![0.9, 0.9], + vec![10.0, 10.0, 10.0], + vec![-5.0, 3.0], + ]; + + for point in test_points { + let projected = project_to_ball(&point, 1.0, EPS); + let n = 
norm(&projected); + // Use <= with small tolerance for floating point + assert!(n <= 1.0 - EPS + 1e-7, + "Projection failed: norm {} >= max {}", n, 1.0 - EPS); + } +} + +#[test] +fn test_near_boundary_stability() { + // Operations near the boundary should remain stable + let near_boundary = vec![0.99 - EPS, 0.0]; + let small_vec = vec![0.01, 0.01]; + + // Should not panic or produce NaN/Inf + let result = mobius_add(&near_boundary, &small_vec, 1.0); + assert!(!result.iter().any(|v| v.is_nan() || v.is_infinite()), + "Near boundary operation produced NaN/Inf"); + + let n = norm(&result); + assert!(n < 1.0 - EPS, "Result escaped ball boundary"); +} + +#[test] +fn test_zero_vector_handling() { + // Operations with zero vector should be stable + let zero = vec![0.0, 0.0, 0.0]; + let x = vec![0.3, 0.2, 0.1]; + + // exp_map with zero tangent should return base point + let result = exp_map(&zero, &x, 1.0); + for (a, b) in x.iter().zip(result.iter()) { + assert!((a - b).abs() < 1e-5, "exp_map with zero failed"); + } + + // log_map of same point should be zero + let log_result = log_map(&x, &x, 1.0); + assert!(norm(&log_result) < 1e-5, "log_map of same point should be zero"); +} + +#[test] +fn test_small_curvature_stability() { + // Small curvatures should work (approaches Euclidean) + let x = vec![0.3, 0.2]; + let y = vec![0.1, 0.4]; + + let d_small_c = poincare_distance(&x, &y, 0.01); + let d_euclidean: f32 = x.iter().zip(y.iter()) + .map(|(a, b)| (a - b) * (a - b)) + .sum::() + .sqrt(); + + // For small curvature, should approach Euclidean + // The ratio should be bounded + assert!(!d_small_c.is_nan() && !d_small_c.is_infinite(), + "Small curvature produced invalid result"); +} + +#[test] +fn test_large_curvature_stability() { + // Large curvatures should work (stronger hyperbolic effect) + let x = vec![0.1, 0.1]; + let y = vec![0.2, 0.1]; + + let d_large_c = poincare_distance(&x, &y, 10.0); + + assert!(!d_large_c.is_nan() && !d_large_c.is_infinite(), + "Large curvature 
produced invalid result: {}", d_large_c); +} + +// ============================================================================ +// Frechet Mean Properties +// ============================================================================ + +#[test] +fn test_frechet_mean_single_point() { + // Frechet mean of single point is that point + let points = vec![vec![0.3, 0.2]]; + let point_refs: Vec<&[f32]> = points.iter().map(|p| p.as_slice()).collect(); + let config = PoincareConfig::default(); + + let mean = frechet_mean(&point_refs, None, &config).unwrap(); + + for (a, b) in points[0].iter().zip(mean.iter()) { + assert!((a - b).abs() < 1e-4, "Single point mean failed"); + } +} + +#[test] +fn test_frechet_mean_symmetric() { + // Mean of symmetric points should be near origin + let points = vec![ + vec![0.3, 0.0], + vec![-0.3, 0.0], + ]; + let point_refs: Vec<&[f32]> = points.iter().map(|p| p.as_slice()).collect(); + let config = PoincareConfig::default(); + + let mean = frechet_mean(&point_refs, None, &config).unwrap(); + + // Mean should be close to origin + let mean_norm = norm(&mean); + assert!(mean_norm < 0.1, "Symmetric mean not near origin: {}", mean_norm); +} + +// ============================================================================ +// Tangent Space Operations +// ============================================================================ + +#[test] +fn test_tangent_cache_creation() { + let points = vec![ + vec![0.1, 0.2, 0.1], + vec![-0.1, 0.15, 0.05], + vec![0.2, -0.1, 0.1], + ]; + let indices: Vec = (0..3).collect(); + + let cache = TangentCache::new(&points, &indices, 1.0).unwrap(); + + assert_eq!(cache.len(), 3); + assert_eq!(cache.dim(), 3); + + // Centroid should be inside ball + let centroid_norm = norm(&cache.centroid); + assert!(centroid_norm < 1.0 - EPS, "Centroid outside ball"); +} + +#[test] +fn test_tangent_distance_ordering() { + // Tangent distance should roughly preserve hyperbolic distance ordering + let points = vec![ + vec![0.1, 0.1], 
+ vec![0.2, 0.1], + vec![0.5, 0.3], + ]; + let indices: Vec = (0..3).collect(); + + let cache = TangentCache::new(&points, &indices, 1.0).unwrap(); + + let query = vec![0.12, 0.11]; + let query_tangent = cache.query_tangent(&query); + + let mut tangent_dists: Vec<(usize, f32)> = (0..3) + .map(|i| (i, cache.tangent_distance_squared(&query_tangent, i))) + .collect(); + tangent_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); + + let mut hyp_dists: Vec<(usize, f32)> = (0..3) + .map(|i| (i, poincare_distance(&query, &points[i], 1.0))) + .collect(); + hyp_dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); + + // First nearest neighbor should match + assert_eq!(tangent_dists[0].0, hyp_dists[0].0, + "First neighbor mismatch: tangent says {}, hyperbolic says {}", + tangent_dists[0].0, hyp_dists[0].0); +} + +// ============================================================================ +// HNSW Integration +// ============================================================================ + +#[test] +fn test_hnsw_insert_and_search() { + let mut hnsw = HyperbolicHnsw::default_config(); + + // Insert points + for i in 0..20 { + let v = vec![0.1 * (i as f32 % 5.0), 0.05 * (i as f32 / 5.0)]; + hnsw.insert(v).unwrap(); + } + + assert_eq!(hnsw.len(), 20); + + // Search + let query = vec![0.25, 0.125]; + let results = hnsw.search(&query, 5).unwrap(); + + assert_eq!(results.len(), 5); + + // Results should be sorted by distance + for i in 1..results.len() { + assert!(results[i-1].distance <= results[i].distance, + "Results not sorted at index {}: {} > {}", + i, results[i-1].distance, results[i].distance); + } +} + +#[test] +fn test_hnsw_nearest_is_correct() { + let mut hnsw = HyperbolicHnsw::default_config(); + + let points = vec![ + vec![0.0, 0.0], + vec![0.5, 0.0], + vec![0.0, 0.5], + vec![0.3, 0.3], + ]; + + for p in &points { + hnsw.insert(p.clone()).unwrap(); + } + + // Query near origin + let query = vec![0.05, 0.05]; + let results = hnsw.search(&query, 1).unwrap(); + + 
// Should find point at origin (id 0) + assert_eq!(results[0].id, 0, "Expected nearest to be origin"); +} + +#[test] +fn test_hnsw_curvature_update() { + let mut hnsw = HyperbolicHnsw::default_config(); + + hnsw.insert(vec![0.1, 0.2]).unwrap(); + hnsw.insert(vec![0.3, 0.1]).unwrap(); + + // Update curvature + hnsw.set_curvature(2.0).unwrap(); + + assert!((hnsw.config.curvature - 2.0).abs() < 1e-6); + + // Search should still work + let results = hnsw.search(&[0.2, 0.15], 2).unwrap(); + assert_eq!(results.len(), 2); +} + +// ============================================================================ +// Shard Management +// ============================================================================ + +#[test] +fn test_curvature_registry() { + let mut registry = CurvatureRegistry::new(1.0); + + registry.set("shard_1", 0.5); + assert!((registry.get("shard_1") - 0.5).abs() < 1e-6); + assert!((registry.get("unknown") - 1.0).abs() < 1e-6); // Default + + // Canary testing + registry.set_canary("shard_1", 0.3, 50); + assert!((registry.get_effective("shard_1", false) - 0.5).abs() < 1e-6); + assert!((registry.get_effective("shard_1", true) - 0.3).abs() < 1e-6); + + // Promote canary + if let Some(shard) = registry.shards.get_mut("shard_1") { + shard.promote_canary(); + } + assert!((registry.get("shard_1") - 0.3).abs() < 1e-6); +} + +#[test] +fn test_sharded_hnsw() { + let mut manager = ShardedHyperbolicHnsw::new(1.0); + + for i in 0..30 { + let v = vec![0.1 * (i as f32 % 6.0), 0.05 * (i as f32 / 6.0)]; + manager.insert(v, Some(i / 10)).unwrap(); + } + + assert_eq!(manager.len(), 30); + assert!(manager.num_shards() > 0); + + // Search + let results = manager.search(&[0.25, 0.125], 5).unwrap(); + assert!(!results.is_empty()); +} + +// ============================================================================ +// Hierarchy Metrics +// ============================================================================ + +#[test] +fn test_hierarchy_metrics_radius_correlation() { + 
// Points with radius proportional to depth should have positive correlation + let points: Vec> = (0..20).map(|i| { + let depth = i / 4; + let radius = 0.1 + 0.15 * depth as f32; + let angle = (i % 4) as f32 * std::f32::consts::PI / 2.0; + vec![radius * angle.cos(), radius * angle.sin()] + }).collect(); + + let depths: Vec = (0..20).map(|i| i / 4).collect(); + + let metrics = HierarchyMetrics::compute(&points, &depths, 1.0).unwrap(); + + assert!(metrics.radius_depth_correlation > 0.5, + "Expected positive correlation, got {}", metrics.radius_depth_correlation); +} + +// ============================================================================ +// Dual Space Index +// ============================================================================ + +#[test] +fn test_dual_space_index() { + let mut dual = DualSpaceIndex::new(1.0, 0.5); + + for i in 0..15 { + let v = vec![0.1 * i as f32, 0.05 * i as f32]; + dual.insert(v).unwrap(); + } + + let results = dual.search(&[0.35, 0.175], 5).unwrap(); + + assert_eq!(results.len(), 5); + + // Results should be sorted + for i in 1..results.len() { + assert!(results[i-1].distance <= results[i].distance); + } +} + +// ============================================================================ +// Edge Cases +// ============================================================================ + +#[test] +fn test_empty_index_search() { + let hnsw = HyperbolicHnsw::default_config(); + + let results = hnsw.search(&[0.1, 0.2], 5).unwrap(); + assert!(results.is_empty()); +} + +#[test] +fn test_single_element_search() { + let mut hnsw = HyperbolicHnsw::default_config(); + hnsw.insert(vec![0.3, 0.2]).unwrap(); + + let results = hnsw.search(&[0.1, 0.2], 5).unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0].id, 0); +} + +#[test] +fn test_k_larger_than_index() { + let mut hnsw = HyperbolicHnsw::default_config(); + + for i in 0..3 { + hnsw.insert(vec![0.1 * i as f32, 0.1]).unwrap(); + } + + let results = hnsw.search(&[0.15, 0.1], 
10).unwrap(); + assert_eq!(results.len(), 3); +} + +// ============================================================================ +// Performance Characteristics +// ============================================================================ + +#[test] +fn test_insert_performance() { + let mut hnsw = HyperbolicHnsw::default_config(); + + // Should handle 100 insertions without panic + for i in 0..100 { + let v = vec![ + 0.05 * (i % 10) as f32, + 0.05 * (i / 10) as f32, + ]; + hnsw.insert(v).unwrap(); + } + + assert_eq!(hnsw.len(), 100); +} + +#[test] +fn test_search_performance() { + let mut hnsw = HyperbolicHnsw::default_config(); + + for i in 0..100 { + let v = vec![ + 0.05 * (i % 10) as f32, + 0.05 * (i / 10) as f32, + ]; + hnsw.insert(v).unwrap(); + } + + // Should handle multiple searches + for _ in 0..10 { + let query = vec![0.25, 0.25]; + let results = hnsw.search(&query, 10).unwrap(); + assert_eq!(results.len(), 10); + } +}