diff --git a/CHANGELOG.md b/CHANGELOG.md index 356c9c2..c76ed4f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,46 @@ All notable changes to this project will be documented in this file. +## [0.16.0] - 2025-10-08 + +### Bug Fixes + +- Keep python thread handle alive (Adrian Seyboldt) + +- No errors for unused parameters (Adrian Seyboldt) + + +### Documentation + +- Fix typo in pymc usage docs (Adrian Seyboldt) + + +### Features + +- Support step size adaptation method (Adrian Seyboldt) + +- Add argument for mindepth (Adrian Seyboldt) + +- Support free-threaded python build (Adrian Seyboldt) + +- Use new nuts-rs storage interface (Adrian Seyboldt) + +- Add zarr_store argument to write trace while sampling (Adrian Seyboldt) + + +### Miscellaneous Tasks + +- Bump actions/checkout from 4 to 5 (dependabot[bot]) + +- Bump actions/download-artifact from 4 to 5 (dependabot[bot]) + +- Bump actions/setup-python from 5 to 6 (#240) (dependabot[bot]) + +- Bump actions/attest-build-provenance from 2 to 3 (#239) (dependabot[bot]) + +- Update pyo3 (Adrian Seyboldt) + + ## [0.15.2] - 2025-07-16 ### Features diff --git a/Cargo.lock b/Cargo.lock index 26eb163..11948d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,15 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 + +[[package]] +name = "addr2line" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b5d307320b3181d6d7954e663bd7c774a838b8220fe0593c86d9fb09f498b4b" +dependencies = [ + "gimli", +] [[package]] name = "adler2" @@ -43,10 +52,10 @@ dependencies = [ ] [[package]] -name = "android-tzdata" -version = "0.1.1" +name = "allocator-api2" +version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" [[package]] name = "android_system_properties" @@ -65,27 +74,30 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstyle" -version = "1.0.11" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd" +checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" [[package]] name = "anyhow" -version = "1.0.98" +version = "1.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" +checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" [[package]] name = "arrow" -version = "55.2.0" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3f15b4c6b148206ff3a2b35002e08929c2462467b62b9c02036d9c34f9ef994" +checksum = "6e833808ff2d94ed40d9379848a950d995043c7fb3e81a30b383f4c6033821cc" dependencies = [ "arrow-arith", "arrow-array", "arrow-buffer", "arrow-cast", + "arrow-csv", "arrow-data", + "arrow-ipc", + "arrow-json", "arrow-ord", "arrow-row", "arrow-schema", @@ -95,9 +107,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "55.2.0" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30feb679425110209ae35c3fbf82404a39a4c0436bb3ec36164d8bffed2a4ce4" +checksum = "ad08897b81588f60ba983e3ca39bda2b179bdd84dced378e7df81a5313802ef8" dependencies = [ "arrow-array", "arrow-buffer", @@ -109,25 +121,26 @@ dependencies = [ [[package]] name = "arrow-array" -version = "55.2.0" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70732f04d285d49054a48b72c54f791bb3424abae92d27aafdf776c98af161c8" +checksum = "8548ca7c070d8db9ce7aa43f37393e4bfcf3f2d3681df278490772fd1673d08d" dependencies = [ "ahash", "arrow-buffer", "arrow-data", "arrow-schema", "chrono", + "chrono-tz", "half", - "hashbrown", + "hashbrown 0.16.0", "num", ] [[package]] name = "arrow-buffer" -version = "55.2.0" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "169b1d5d6cb390dd92ce582b06b23815c7953e9dfaaea75556e89d890d19993d" +checksum = "e003216336f70446457e280807a73899dd822feaf02087d31febca1363e2fccc" dependencies = [ "bytes", "half", @@ -136,9 +149,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "55.2.0" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4f12eccc3e1c05a766cafb31f6a60a46c2f8efec9b74c6e0648766d30686af8" +checksum = "919418a0681298d3a77d1a315f625916cb5678ad0d74b9c60108eb15fd083023" dependencies = [ "arrow-array", "arrow-buffer", @@ -148,29 +161,81 @@ dependencies = [ "atoi", "base64", "chrono", + "comfy-table", "half", "lexical-core", "num", "ryu", ] +[[package]] +name = "arrow-csv" +version = "56.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa9bf02705b5cf762b6f764c65f04ae9082c7cfc4e96e0c33548ee3f67012eb" +dependencies = [ + "arrow-array", + "arrow-cast", + "arrow-schema", + "chrono", + "csv", + "csv-core", + "regex", +] + [[package]] name = "arrow-data" -version = "55.2.0" +version = "56.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5c64fff1d142f833d78897a772f2e5b55b36cb3e6320376f0961ab0db7bd6d0" +dependencies = [ + "arrow-buffer", + "arrow-schema", + "half", + "num", +] + +[[package]] +name = "arrow-ipc" +version = "56.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d3594dcddccc7f20fd069bc8e9828ce37220372680ff638c5e00dea427d88f5" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", + "flatbuffers", +] + +[[package]] +name = "arrow-json" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8de1ce212d803199684b658fc4ba55fb2d7e87b213de5af415308d2fee3619c2" +checksum = "88cf36502b64a127dc659e3b305f1d993a544eab0d48cce704424e62074dc04b" dependencies = [ + "arrow-array", "arrow-buffer", + "arrow-cast", + "arrow-data", "arrow-schema", + "chrono", "half", + "indexmap", + "lexical-core", + "memchr", "num", + "serde", + "serde_json", + "simdutf8", ] [[package]] name = "arrow-ord" -version = "55.2.0" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6506e3a059e3be23023f587f79c82ef0bcf6d293587e3272d20f2d30b969b5a7" +checksum = "3c8f82583eb4f8d84d4ee55fd1cb306720cddead7596edce95b50ee418edf66f" dependencies = [ "arrow-array", "arrow-buffer", @@ -181,9 +246,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "55.2.0" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52bf7393166beaf79b4bed9bfdf19e97472af32ce5b6b48169d321518a08cae2" +checksum = "9d07ba24522229d9085031df6b94605e0f4b26e099fb7cdeec37abd941a73753" dependencies = [ "arrow-array", "arrow-buffer", @@ -194,18 +259,20 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "55.2.0" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af7686986a3bf2254c9fb130c623cdcb2f8e1f15763e7c71c310f0834da3d292" +checksum = "b3aa9e59c611ebc291c28582077ef25c97f1975383f1479b12f3b9ffee2ffabe" dependencies = [ "bitflags", + "serde", + "serde_json", ] [[package]] name = "arrow-select" -version = "55.2.0" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd2b45757d6a2373faa3352d02ff5b54b098f5e21dccebc45a21806bc34501e5" +checksum = "8c41dbbd1e97bfcaee4fcb30e29105fb2c75e4d82ae4de70b792a5d3f66b2e7a" dependencies = [ "ahash", "arrow-array", @@ -217,9 +284,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "55.2.0" +version = "56.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0377d532850babb4d927a06294314b316e23311503ed580ec6ce6a0158f49d40" +checksum = "53f5183c150fbc619eede22b861ea7c0eebed8eaac0333eaa7f6da5205fd504d" dependencies = [ "arrow-array", "arrow-buffer", @@ -232,6 +299,39 @@ dependencies = [ "regex-syntax", ] +[[package]] +name = "async-generic" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddf3728566eefa873833159754f5732fb0951d3649e6e5b891cc70d56dd41673" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "async-lock" +version = "3.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fd03604047cee9b6ce9de9f70c6cd540a0520c813cbd49bae61f33ab80ed1dc" +dependencies = [ + "event-listener", + "event-listener-strategy", + "pin-project-lite", +] + +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "atoi" version = "2.0.0" @@ -241,12 +341,44 @@ dependencies = [ "num-traits", ] +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + +[[package]] +name = "auto_impl" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffdcb70bdbc4d478427380519163274ac86e52916e10f0a8889adf0f96d3fee7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "autocfg" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "backtrace" +version = "0.3.76" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb531853791a215d7c62a30daf0dde835f381ab5de4589cfe7c649d2cbe92bd6" +dependencies = [ + "addr2line", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", + "windows-link", +] + [[package]] name = "base64" version = "0.22.1" @@ -261,9 +393,9 @@ checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba" [[package]] name = "bindgen" -version = "0.71.1" +version = "0.72.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3" +checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895" dependencies = [ "bitflags", "cexpr", @@ -276,14 +408,14 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn", + "syn 2.0.106", ] [[package]] name = "bitflags" -version = "2.9.1" +version = "2.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" +checksum = "2261d10cca569e4643e526d8dc2e62e433cc8aba21ab764233731f8d369bf394" [[package]] name = "block-buffer" @@ -294,17 +426,30 @@ dependencies = [ "generic-array", ] +[[package]] +name = "blosc-src" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb68d27ab5ceb94ae9cd343f6fbc7bb84543496d547ed7c0db6718175fd41cb6" +dependencies = [ + "cc", + "libz-sys", + "lz4-sys", + "snappy_src", + "zstd-sys", +] + [[package]] name = "bridgestan" -version = "2.6.2" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fcf23cdd20237d4699464b803c6aef49f547266514c7361c27b25875ee69298" +checksum = "f6d0e34116970162606ca313a4d3cf76b4828600877ae30959f6f122e434cb29" dependencies = [ "bindgen", "libloading", "log", "path-absolutize", - "thiserror 2.0.12", + "thiserror 2.0.17", ] [[package]] @@ -315,9 +460,23 @@ checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" [[package]] name = "bytemuck" -version = "1.23.1" +version = "1.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c76a5792e44e4abe34d3abf15636779261d45a7450612059293d1d2cfc63422" +checksum = "1fbdf580320f38b612e485521afda1ee26d10cc9884efaaa750d383e13e3c5f4" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9abbd1bc6865053c427f7198e6af43bfdedc55ab791faed4fbd361d789575ff" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] [[package]] name = "byteorder" @@ -359,10 +518,11 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.29" +version = "1.2.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c1599538de2394445747c8cf7935946e3cc27e9625f889d979bfb2aaf569362" +checksum = "e1d05d92f4b1fd76aad469d46cdd858ca761576082cd37df81416691e50199fb" dependencies = [ + "find-msvc-tools", "jobserver", "libc", "shlex", @@ -379,22 +539,40 @@ dependencies = [ [[package]] name = "cfg-if" -version = "1.0.1" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9" + +[[package]] +name = "cfg_aliases" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chrono" -version = "0.4.41" +version = "0.4.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d" +checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2" dependencies = [ - "android-tzdata", "iana-time-zone", + "js-sys", "num-traits", + "serde", + "wasm-bindgen", "windows-link", ] +[[package]] +name = "chrono-tz" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6139a8597ed92cf816dfb33f5dd6cf0bb93a6adc938f11039f371bc5bcd26c3" +dependencies = [ + "chrono", + "phf", +] + [[package]] name = "ciborium" version = "0.2.2" @@ -445,18 +623,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.41" +version = "4.5.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be92d32e80243a54711e5d7ce823c35c41c9d929dc4ab58e1276f625841aadf9" +checksum = "e2134bb3ea021b78629caa971416385309e0131b351b25e01dc16fb54e1b5fae" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.41" +version = "4.5.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "707eab41e9622f9139419d573eca0900137718000c517d47da73045f54331c3d" +checksum = "c2ba64afa3c0a6df7fa517765e31314e983f51dda798ffba27b988194fb65dc9" dependencies = [ "anstyle", "clap_lex", @@ -468,17 +646,37 @@ version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675" +[[package]] +name = "comfy-table" +version = "7.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0d05af1e006a2407bedef5af410552494ce5be9090444dbbcb57258c1af3d56" +dependencies = [ + "strum", + "strum_macros", + "unicode-width", +] + +[[package]] +name = "concurrent-queue" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "console" -version = "0.16.0" +version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e09ced7ebbccb63b4c65413d821f2e00ce54c5ca4514ddc6b3c892fdbcbc69d" +checksum = "b430743a6eb14e9764d4260d4c0d8123087d504eeb9c48f2b2a5e810dd369df4" dependencies = [ "encode_unicode", "libc", "once_cell", "unicode-width", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -507,6 +705,16 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" +[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -522,6 +730,15 @@ dependencies = [ "libc", ] +[[package]] +name = "crc32c" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a47af21622d091a8f0fb295b88bc886ac74efcc613efc19f5d0b21de5c89e47" +dependencies = [ + "rustc_version", +] + [[package]] name = "crc32fast" version = "1.5.0" @@ -533,9 +750,9 @@ dependencies = [ [[package]] name = "criterion" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bf7af66b0989381bd0be551bd7cc91912a655a58c6918420c9527b1fd8b4679" +checksum = "e1c047a62b0cc3e145fa84415a3191f628e980b194c2755aa12300a4e6cbd928" dependencies = [ "anes", "cast", @@ -556,12 +773,21 @@ dependencies = [ [[package]] name = "criterion-plot" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +checksum = "9b1bcc0dc7dfae599d84ad0b1a55f80cde8af3725da8313b528da95ef783e338" dependencies = [ "cast", - "itertools 0.10.5", + "itertools 0.13.0", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", ] [[package]] @@ -605,15 +831,57 @@ dependencies = [ "typenum", ] +[[package]] +name = "csv" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d02f3b0da4c6504f86e9cd789d8dbafab48c2321be74e9987593de5a894d93d" +dependencies = [ + "memchr", +] + [[package]] name = "deranged" -version = "0.4.0" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c9e6a11ca8224451684bc0d7d5a7adbf8f2fd6887261a1cfc3c0432f9d4068e" +checksum = "a41953f86f8a05768a6cda24def994fd2f424b04ec5c719cf89989779f199071" dependencies = [ "powerfmt", ] +[[package]] +name = "derive_more" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678" +dependencies = [ + "derive_more-impl", +] + +[[package]] +name = "derive_more-impl" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", + "unicode-xid", +] + [[package]] name = "digest" version = "0.10.7" @@ -625,15 +893,33 @@ dependencies = [ "subtle", ] +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "dyn-stack" -version = "0.13.0" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "490bd48eb68fffcfed519b4edbfd82c69cbe741d175b84f0e0cbe8c57cbe0bdd" +checksum = "1c4713e43e2886ba72b8271aa66c93d722116acf7a75555cce11dcde84388fe8" dependencies = [ "bytemuck", + "dyn-stack-macros", ] +[[package]] +name = "dyn-stack-macros" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05dbec7076f432bb132db738df90d87a4f5789e99f59e7b1219a6b8ef61eaa68" + [[package]] name = "either" version = "1.15.0" @@ -672,7 +958,7 @@ checksum = "3bf679796c0322556351f287a51b49e48f7c4986e727b5dd78c972d30e2e16cc" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] @@ -683,14 +969,41 @@ checksum = "44f23cf4b44bfce11a86ace86f8a73ffdec849c9fd00a386a53d278bd9e81fb3" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", +] + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "event-listener" +version = "5.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13b66accf52311f30a0db42147dadea9850cb48cd070028831ae5f5d4b856ab" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93" +dependencies = [ + "event-listener", + "pin-project-lite", ] [[package]] name = "faer" -version = "0.22.6" +version = "0.23.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49fce40ad65c366fbc6cd70a99d09d1008f075280bf2455e558e163c82913a9f" +checksum = "3cb922206162d9405f9fc059052b3f997bdc92745da7bfd620645f5092df20d1" dependencies = [ "bytemuck", "dyn-stack", @@ -709,20 +1022,20 @@ dependencies = [ [[package]] name = "faer-macros" -version = "0.21.0" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d0a255d1442b5825c61812a7eafda9034ec53d969c98555251085e148428e6a" +checksum = "2cc4b8cd876795d3b19ddfd59b03faa303c0b8adb9af6e188e81fc647c485bb9" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] name = "faer-traits" -version = "0.22.1" +version = "0.23.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54febfcbb90edaab562d85447a94d500f1601f11db0b30d27da87ed6542c8f91" +checksum = "24b69235b5f54416286c485fb047f2f499fc935a4eee2caadf4757f3c94c7b62" dependencies = [ "bytemuck", "dyn-stack", @@ -736,48 +1049,174 @@ dependencies = [ "reborrow", ] +[[package]] +name = "find-msvc-tools" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0399f9d26e5191ce32c498bebd31e7a3ceabc2745f0ac54af3f335126c3f24b3" + +[[package]] +name = "flatbuffers" +version = "25.9.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09b6620799e7340ebd9968d2e0708eb82cf1971e9a16821e2091b6d6e475eed5" +dependencies = [ + "bitflags", + "rustc_version", +] + [[package]] name = "flate2" -version = "1.1.2" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a3d7db9596fecd151c5f638c0ee5d5bd487b6e0ea232e5dc96d5250f6f94b1d" +checksum = "dc5a4e564e38c699f2880d3fda590bedc2e69f3f84cd48b457bd892ce61d0aa9" dependencies = [ "crc32fast", "miniz_oxide", ] [[package]] -name = "gemm" -version = "0.18.2" +name = "fnv" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab96b703d31950f1aeddded248bc95543c9efc7ac9c4a21fda8703a83ee35451" -dependencies = [ - "dyn-stack", - "gemm-c32", - "gemm-c64", - "gemm-common", - "gemm-f32", - "gemm-f64", - "num-complex", - "num-traits", - "paste", - "raw-cpuid", - "seq-macro", -] +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] -name = "gemm-c32" -version = "0.18.2" +name = "foldhash" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6db9fd9f40421d00eea9dd0770045a5603b8d684654816637732463f4073847" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "form_urlencoded" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" dependencies = [ - "dyn-stack", - "gemm-common", - "num-complex", - "num-traits", - "paste", - "raw-cpuid", - "seq-macro", + "percent-encoding", +] + +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "futures-sink" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + +[[package]] +name = "gemm" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab96b703d31950f1aeddded248bc95543c9efc7ac9c4a21fda8703a83ee35451" +dependencies = [ + "dyn-stack", + "gemm-c32", + "gemm-c64", + "gemm-common", + "gemm-f32", + "gemm-f64", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-c32" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6db9fd9f40421d00eea9dd0770045a5603b8d684654816637732463f4073847" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", ] [[package]] @@ -866,8 +1305,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi 0.11.1+wasi-snapshot-preview1", + "wasm-bindgen", ] [[package]] @@ -877,16 +1318,43 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" dependencies = [ "cfg-if", + "js-sys", "libc", "r-efi", - "wasi 0.14.2+wasi-0.2.4", + "wasi 0.14.7+wasi-0.2.4", + "wasm-bindgen", ] +[[package]] +name = "gimli" +version = "0.32.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e629b9b98ef3dd8afe6ca2bd0f89306cec16d43d907889945bc5d6687f2f13c7" + [[package]] name = "glob" -version = "0.3.2" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" + +[[package]] +name = "h2" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" +checksum = "f3c0b69cfcb4e1b9f1bf2f53f95f766e4661169728ec61cd3fe5a0166f2d1386" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] [[package]] name = "half" @@ -894,6 +1362,7 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" dependencies = [ + "bytemuck", "cfg-if", "crunchy", "num-traits", @@ -901,9 +1370,20 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.15.4" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5971ac85611da7067dbfcabef3c70ebb5606018acd9e2a3903a0da507521e0d5" +checksum = "5419bdc4f6a9207fbeba6d11b604d481addf78ecd10c11ad51e76c2f6482748d" [[package]] name = "heck" @@ -920,11 +1400,120 @@ dependencies = [ "digest", ] +[[package]] +name = "http" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "humantime" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "135b12329e5e3ce057a9f972339ea52bc954fe1e9358ef27f95e89716fbc5424" + +[[package]] +name = "hyper" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb3aa54a13a0dfe7fbe3a59e0c76093041720fdc77b110cc0fc260fafb4dc51e" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "h2", + "http", + "http-body", + "httparse", + "itoa", + "pin-project-lite", + "pin-utils", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +dependencies = [ + "http", + "hyper", + "hyper-util", + "rustls", + "rustls-native-certs", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", +] + +[[package]] +name = "hyper-util" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c6995591a8f1380fcb4ba966a252a4b29188d51d2b89e3a252f5305be65aea8" +dependencies = [ + "base64", + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "http", + "http-body", + "hyper", + "ipnet", + "libc", + "percent-encoding", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", +] + [[package]] name = "iana-time-zone" -version = "0.1.63" +version = "0.1.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8" +checksum = "33e57f83510bb73707521ebaffa789ec8caf86f9657cad665b092b581d40e9fb" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -944,6 +1533,123 @@ dependencies = [ "cc", ] +[[package]] +name = "icu_collections" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "200072f5d0e3614556f94a9930d5dc3e0662a652823904c3a75dc3b0af7fee47" +dependencies = [ + "displaydoc", + "potential_utf", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locale_core" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cde2700ccaed3872079a65fb1a78f6c0a36c91570f28755dda67bc8f7d9f00a" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_normalizer" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "436880e8e18df4d7bbc06d58432329d6458cc84531f7ac5f024e93deadb37979" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00210d6893afc98edb752b664b8890f0ef174c8adbb8d0be9710fa66fbbf72d3" + +[[package]] +name = "icu_properties" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "016c619c1eeb94efb86809b015c58f479963de65bdb6253345c1a1276f22e32b" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_locale_core", + "icu_properties_data", + "icu_provider", + "potential_utf", + "zerotrie", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "298459143998310acd25ffe6810ed544932242d3f07083eee1084d83a71bd632" + +[[package]] +name = "icu_provider" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03c80da27b5f4187909049ee2d72f276f0d9f99a42c306bd0131ecfe04d8e5af" +dependencies = [ + "displaydoc", + "icu_locale_core", + "stable_deref_trait", + "tinystr", + "writeable", + "yoke", + "zerofrom", + "zerotrie", + "zerovec", +] + +[[package]] +name = "idna" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + +[[package]] +name = "indexmap" +version = "2.11.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b0f83760fb341a774ed326568e19f5a863af4a952def8c39f9ab92fd95b88e5" +dependencies = [ + "equivalent", + "hashbrown 0.16.0", +] + [[package]] name = "indicatif" version = "0.18.0" @@ -973,12 +1679,39 @@ dependencies = [ ] [[package]] -name = "itertools" -version = "0.10.5" +name = "inventory" +version = "0.3.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +checksum = "bc61209c082fbeb19919bee74b176221b27223e27b65d781eb91af24eb1fb46e" dependencies = [ - "either", + "rustversion", +] + +[[package]] +name = "io-uring" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "046fa2d4d00aea763528b4950358d0ead425372445dc8ff86312b3c69ff7727b" +dependencies = [ + "bitflags", + "cfg-if", + "libc", +] + +[[package]] +name = "ipnet" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" + +[[package]] +name = "iri-string" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbc5ebe9c3a1a7a5127f920a418f7585e9e758e911d0466ed004f393b0e380b2" +dependencies = [ + "memchr", + "serde", ] [[package]] @@ -1007,9 +1740,9 @@ checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" [[package]] name = "jobserver" -version = "0.1.33" +version = "0.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" dependencies = [ "getrandom 0.3.3", "libc", @@ -1017,9 +1750,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.77" +version = "0.3.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" +checksum = "ec48937a97411dcb524a265206ccd4c90bb711fca92b2792c407f268825b9305" dependencies = [ "once_cell", "wasm-bindgen", @@ -1033,9 +1766,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "lexical-core" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b765c31809609075565a70b4b71402281283aeda7ecaf4818ac14a7b2ade8958" +checksum = "7d8d125a277f807e55a77304455eb7b1cb52f2b18c143b60e766c120bd64a594" dependencies = [ "lexical-parse-float", "lexical-parse-integer", @@ -1046,69 +1779,62 @@ dependencies = [ [[package]] name = "lexical-parse-float" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de6f9cb01fb0b08060209a057c048fcbab8717b4c1ecd2eac66ebfe39a65b0f2" +checksum = "52a9f232fbd6f550bc0137dcb5f99ab674071ac2d690ac69704593cb4abbea56" dependencies = [ "lexical-parse-integer", "lexical-util", - "static_assertions", ] [[package]] name = "lexical-parse-integer" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72207aae22fc0a121ba7b6d479e42cbfea549af1479c3f3a4f12c70dd66df12e" +checksum = "9a7a039f8fb9c19c996cd7b2fcce303c1b2874fe1aca544edc85c4a5f8489b34" dependencies = [ "lexical-util", - "static_assertions", ] [[package]] name = "lexical-util" -version = "1.0.6" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a82e24bf537fd24c177ffbbdc6ebcc8d54732c35b50a3f28cc3f4e4c949a0b3" -dependencies = [ - "static_assertions", -] +checksum = "2604dd126bb14f13fb5d1bd6a66155079cb9fa655b37f875b3a742c705dbed17" [[package]] name = "lexical-write-float" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5afc668a27f460fb45a81a757b6bf2f43c2d7e30cb5a2dcd3abf294c78d62bd" +checksum = "50c438c87c013188d415fbabbb1dceb44249ab81664efbd31b14ae55dabb6361" dependencies = [ "lexical-util", "lexical-write-integer", - "static_assertions", ] [[package]] name = "lexical-write-integer" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "629ddff1a914a836fb245616a7888b62903aae58fa771e1d83943035efa0f978" +checksum = "409851a618475d2d5796377cad353802345cba92c867d9fbcde9cf4eac4e14df" dependencies = [ "lexical-util", - "static_assertions", ] [[package]] name = "libc" -version = "0.2.174" +version = "0.2.176" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776" +checksum = "58f929b4d672ea937a23a1ab494143d968337a5f47e56d0815df1e0890ddf174" [[package]] name = "libloading" -version = "0.8.8" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" +checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" dependencies = [ "cfg-if", - "windows-targets 0.53.2", + "windows-link", ] [[package]] @@ -1118,49 +1844,175 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" [[package]] -name = "log" -version = "0.4.27" +name = "libz-sys" +version = "1.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" +checksum = "8b70e7a7df205e92a1a4cd9aaae7898dac0aa555503cc0a649494d0d60e7651d" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] [[package]] -name = "matrixmultiply" -version = "0.3.10" +name = "link-cplusplus" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +checksum = "7f78c730aaa7d0b9336a299029ea49f9ee53b0ed06e9202e8cb7db9bae7b8c82" dependencies = [ - "autocfg", - "rawpointer", + "cc", ] [[package]] -name = "memchr" -version = "2.7.5" +name = "litemap" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0" +checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" [[package]] -name = "memoffset" -version = "0.9.1" +name = "lock_api" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" dependencies = [ - "autocfg", + "scopeguard", ] [[package]] -name = "minimal-lexical" -version = "0.2.1" +name = "log" +version = "0.4.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" +checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" [[package]] -name = "miniz_oxide" +name = "lru" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfe949189f46fabb938b3a9a0be30fdd93fd8a09260da863399a8cf3db756ec8" +dependencies = [ + "hashbrown 0.15.5", +] + +[[package]] +name = "lru-slab" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + +[[package]] +name = "lz4-sys" +version = "1.11.1+lz4-1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bd8c0d6c6ed0cd30b3652886bb8711dc4bb01d637a68105a3d5158039b418e6" +dependencies = [ + "cc", + "libc", +] + +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "md-5" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" +dependencies = [ + "cfg-if", + "digest", +] + +[[package]] +name = "memchr" +version = "2.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" + +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + +[[package]] +name = "miniz_oxide" version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" dependencies = [ "adler2", + "simd-adler32", +] + +[[package]] +name = "mio" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c" +dependencies = [ + "libc", + "wasi 0.11.1+wasi-snapshot-preview1", + "windows-sys 0.59.0", +] + +[[package]] +name = "moka" +version = "0.12.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8261cd88c312e0004c1d51baad2980c66528dfdb2bee62003e643a4d8f86b077" +dependencies = [ + "async-lock", + "crossbeam-channel", + "crossbeam-epoch", + "crossbeam-utils", + "equivalent", + "event-listener", + "futures-util", + "parking_lot", + "portable-atomic", + "rustc_version", + "smallvec", + "tagptr", + "uuid", +] + +[[package]] +name = "monostate" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3341a273f6c9d5bef1908f17b7267bbab0e95c9bf69a0d4dcf8e9e1b2c76ef67" +dependencies = [ + "monostate-impl", + "serde", + "serde_core", +] + +[[package]] +name = "monostate-impl" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4db6d5580af57bf992f59068d4ea26fd518574ff48d7639b255a36f9de6e7e9" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", ] [[package]] @@ -1341,10 +2193,11 @@ dependencies = [ [[package]] name = "numpy" -version = "0.25.0" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29f1dee9aa8d3f6f8e8b9af3803006101bb3653866ef056d530d53ae68587191" +checksum = "9b2dba356160b54f5371b550575b78130a54718b4c6e46b3f33a6da74a27e78b" dependencies = [ + "half", "libc", "ndarray", "num-complex", @@ -1357,7 +2210,7 @@ dependencies = [ [[package]] name = "nutpie" -version = "0.15.2" +version = "0.16.0" dependencies = [ "anyhow", "arrow", @@ -1368,33 +2221,110 @@ dependencies = [ "numpy", "nuts-rs", "pyo3", - "rand 0.9.1", + "pyo3-arrow", + "pyo3-object_store", + "rand 0.9.2", "rand_chacha 0.9.0", "rand_distr", "rayon", "smallvec", "tch", - "thiserror 2.0.12", + "thiserror 2.0.17", "time-humanize", + "tokio", "upon", + "zarrs", + "zarrs_object_store", +] + +[[package]] +name = "nuts-derive" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64eac5046c75ced9bdaede15ebc30c4ce982a13e75032ae8d5c1312d1e05d82e" +dependencies = [ + "nuts-storable", + "proc-macro2", + "quote", + "syn 1.0.109", ] [[package]] name = "nuts-rs" -version = "0.16.1" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acad2be84df0d14341d8de7d30c1019ecc008f4722befbd45745092a918c0a02" +checksum = "0cbacdcc02ea7e33cf6c3389b1c4944d028d5ca038e503e656856a9407c5b2b6" dependencies = [ "anyhow", "arrow", + "arrow-schema", "faer", "itertools 0.14.0", + "nuts-derive", + "nuts-storable", "pulp", - "rand 0.9.1", + "rand 0.9.2", "rand_chacha 0.9.0", "rand_distr", "rayon", - "thiserror 2.0.12", + "serde", + "serde_json", + "thiserror 2.0.17", + "tokio", + "zarrs", +] + +[[package]] +name = "nuts-storable" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb6cf9fc84ca313648ddb112f8728eb2f9531f2e4533959dd01127eb34290b5b" + +[[package]] +name = "object" +version = "0.37.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff76201f031d8863c38aa7f905eca4f53abbfa15f609db4277d44cd8938f33fe" +dependencies = [ + "memchr", +] + +[[package]] +name = "object_store" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c1be0c6c22ec0817cdc77d3842f721a17fd30ab6965001415b5402a74e6b740" +dependencies = [ + "async-trait", + "base64", + "bytes", + "chrono", + "form_urlencoded", + "futures", + "http", + "http-body-util", + "httparse", + "humantime", + "hyper", + "itertools 0.14.0", + "md-5", + "parking_lot", + "percent-encoding", + "quick-xml", + "rand 0.9.2", + "reqwest", + "ring", + "rustls-pemfile", + "serde", + "serde_json", + "serde_urlencoded", + "thiserror 2.0.17", + "tokio", + "tracing", + "url", + "walkdir", + "wasm-bindgen-futures", + "web-time", ] [[package]] @@ -1409,6 +2339,51 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" +[[package]] +name = "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + +[[package]] +name = "page_size" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30d5b2194ed13191c1999ae0704b7839fb18384fa22e49b57eeaa97d79ce40da" +dependencies = [ + "libc", + "winapi", +] + +[[package]] +name = "parking" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" + +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + [[package]] name = "password-hash" version = "0.4.2" @@ -1444,6 +2419,12 @@ dependencies = [ "once_cell", ] +[[package]] +name = "pathdiff" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3" + [[package]] name = "pbkdf2" version = "0.11.0" @@ -1456,6 +2437,42 @@ dependencies = [ "sha2", ] +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + +[[package]] +name = "phf" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "913273894cec178f401a31ec4b656318d95473527be05c0752cc41cdc32be8b7" +dependencies = [ + "phf_shared", +] + +[[package]] +name = "phf_shared" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06005508882fb681fd97892ecff4b7fd0fee13ef1aa569f8695dae7ab9099981" +dependencies = [ + "siphasher", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "pkg-config" version = "0.3.32" @@ -1505,6 +2522,15 @@ dependencies = [ "portable-atomic", ] +[[package]] +name = "potential_utf" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84df19adbe5b5a0782edcab45899906947ab039ccf4573713735ee7de1e6b08a" +dependencies = [ + "zerovec", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -1522,19 +2548,19 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.35" +version = "0.2.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "061c1221631e079b26479d25bbf2275bfe5917ae8419cd7e34f13bfc2aa7539a" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" dependencies = [ "proc-macro2", - "syn", + "syn 2.0.106", ] [[package]] name = "proc-macro2" -version = "1.0.95" +version = "1.0.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" +checksum = "89ae43fd86e4158d6db51ad8e2b80f313af9cc74f5c0e03ccb87de09998732de" dependencies = [ "unicode-ident", ] @@ -1555,11 +2581,14 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.25.1" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8970a78afe0628a3e3430376fc5fd76b6b45c4d43360ffd6cdd40bdde72b682a" +checksum = "7ba0117f4212101ee6544044dae45abe1083d30ce7b29c4b5cbdfa2354e07383" dependencies = [ "anyhow", + "chrono", + "chrono-tz", + "indexmap", "indoc", "libc", "memoffset", @@ -1572,20 +2601,51 @@ dependencies = [ ] [[package]] -name = "pyo3-build-config" -version = "0.25.1" +name = "pyo3-arrow" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cbbf9d6d0573f13480184e789095d6b5cfa11403d8d8311931bd5d111dbf007a" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-schema", + "arrow-select", + "half", + "indexmap", + "numpy", + "pyo3", + "thiserror 1.0.69", +] + +[[package]] +name = "pyo3-async-runtimes" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "458eb0c55e7ece017adeba38f2248ff3ac615e53660d7c71a238d7d2a01c7598" +checksum = "e6ee6d4cb3e8d5b925f5cdb38da183e0ff18122eb2048d4041c9e7034d026e23" dependencies = [ + "futures", "once_cell", + "pin-project-lite", + "pyo3", + "tokio", +] + +[[package]] +name = "pyo3-build-config" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fc6ddaf24947d12a9aa31ac65431fb1b851b8f4365426e182901eabfb87df5f" +dependencies = [ "target-lexicon", ] [[package]] name = "pyo3-ffi" -version = "0.25.1" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7114fe5457c61b276ab77c5055f206295b812608083644a5c5b2640c3102565c" +checksum = "025474d3928738efb38ac36d4744a74a400c901c7596199e20e45d98eb194105" dependencies = [ "libc", "pyo3-build-config", @@ -1593,27 +2653,50 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.25.1" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8725c0a622b374d6cb051d11a0983786448f7785336139c3c94f5aa6bef7e50" +checksum = "2e64eb489f22fe1c95911b77c44cc41e7c19f3082fc81cce90f657cdc42ffded" dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn", + "syn 2.0.106", ] [[package]] name = "pyo3-macros-backend" -version = "0.25.1" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4109984c22491085343c05b0dbc54ddc405c3cf7b4374fc533f5c3313a572ccc" +checksum = "100246c0ecf400b475341b8455a9213344569af29a3c841d29270e53102e0fcf" dependencies = [ "heck", "proc-macro2", "pyo3-build-config", "quote", - "syn", + "syn 2.0.106", +] + +[[package]] +name = "pyo3-object_store" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3cda46869b9ce0e94ca68a8c2f48fdc940a543ed5e2d9272c3e7cc4bcc579fd6" +dependencies = [ + "async-trait", + "bytes", + "chrono", + "futures", + "http", + "humantime", + "itertools 0.14.0", + "object_store", + "percent-encoding", + "pyo3", + "pyo3-async-runtimes", + "serde", + "thiserror 1.0.69", + "tokio", + "url", ] [[package]] @@ -1629,48 +2712,125 @@ dependencies = [ ] [[package]] -name = "quote" -version = "1.0.40" +name = "quick-xml" +version = "0.38.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" +checksum = "42a232e7487fc2ef313d96dde7948e7a3c05101870d8985e4fd8d26aedd27b89" dependencies = [ - "proc-macro2", + "memchr", + "serde", ] [[package]] -name = "r-efi" -version = "5.3.0" +name = "quick_cache" +version = "0.6.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +checksum = "ba15f5bccfb18c666351668b97bbff66da5093f96757ca15299e4e594fe1316e" +dependencies = [ + "ahash", + "equivalent", + "hashbrown 0.16.0", + "parking_lot", +] [[package]] -name = "rand" -version = "0.8.5" +name = "quinn" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" dependencies = [ - "libc", - "rand_chacha 0.3.1", - "rand_core 0.6.4", + "bytes", + "cfg_aliases", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls", + "socket2", + "thiserror 2.0.17", + "tokio", + "tracing", + "web-time", ] [[package]] -name = "rand" -version = "0.9.1" +name = "quinn-proto" +version = "0.11.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" +checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" dependencies = [ - "rand_chacha 0.9.0", - "rand_core 0.9.3", + "bytes", + "getrandom 0.3.3", + "lru-slab", + "rand 0.9.2", + "ring", + "rustc-hash", + "rustls", + "rustls-pki-types", + "slab", + "thiserror 2.0.17", + "tinyvec", + "tracing", + "web-time", ] [[package]] -name = "rand_chacha" -version = "0.3.1" +name = "quinn-udp" +version = "0.5.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" dependencies = [ - "ppv-lite86", + "cfg_aliases", + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.60.2", +] + +[[package]] +name = "quote" +version = "1.0.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce25767e7b499d1b604768e7cde645d14cc8584231ea6b295e9c9eb22c02e1d1" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.3", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", "rand_core 0.6.4", ] @@ -1709,14 +2869,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463" dependencies = [ "num-traits", - "rand 0.9.1", + "rand 0.9.2", ] [[package]] name = "raw-cpuid" -version = "11.5.0" +version = "11.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6df7ab838ed27997ba19a4664507e6f82b41fe6e20be42929332156e5e85146" +checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" dependencies = [ "bitflags", ] @@ -1729,9 +2889,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" [[package]] name = "rayon" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" dependencies = [ "either", "rayon-core", @@ -1739,25 +2899,43 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.12.1" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" dependencies = [ "crossbeam-deque", "crossbeam-utils", ] +[[package]] +name = "rayon_iter_concurrent_limit" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d09ee01023de07fa073ce14c37cbe0a9e099c6b0b60a29cf4af6d04d9553fed7" +dependencies = [ + "rayon", +] + [[package]] name = "reborrow" version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + [[package]] name = "regex" -version = "1.11.1" +version = "1.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" +checksum = "8b5288124840bee7b386bc413c487869b360b2b4ec421ea56425128692f2a82c" dependencies = [ "aho-corasick", "memchr", @@ -1767,9 +2945,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.9" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +checksum = "833eb9ce86d40ef33cb1306d8accf7bc8ec2bfea4355cbdebb3df68b40925cad" dependencies = [ "aho-corasick", "memchr", @@ -1778,9 +2956,71 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.5" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "caf4aa5b0f434c91fe5c7f1ecb6a5ece2130b02ad2a590589dda5146df959001" + +[[package]] +name = "reqwest" +version = "0.12.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d429f34c8092b2d42c7c93cec323bb4adeb7c67698f70839adec842ec10c7ceb" +dependencies = [ + "base64", + "bytes", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-rustls", + "hyper-util", + "js-sys", + "log", + "percent-encoding", + "pin-project-lite", + "quinn", + "rustls", + "rustls-native-certs", + "rustls-pki-types", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tokio-rustls", + "tokio-util", + "tower", + "tower-http", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-streams", + "web-sys", +] + +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.16", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + +[[package]] +name = "rustc-demangle" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" +checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace" [[package]] name = "rustc-hash" @@ -1788,11 +3028,76 @@ version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + +[[package]] +name = "rustls" +version = "0.23.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd3c25631629d034ce7cd9940adc9d45762d46de2b0f57193c4443b92c6d4d40" +dependencies = [ + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-native-certs" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" +dependencies = [ + "openssl-probe", + "rustls-pki-types", + "schannel", + "security-framework", +] + +[[package]] +name = "rustls-pemfile" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "rustls-pki-types" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" +dependencies = [ + "web-time", + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10b3f4191e8a80e6b43eebabfac91e5dcecebb27a71f04e820c47ec41d314bf" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "rustversion" -version = "1.0.21" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" [[package]] name = "ryu" @@ -1819,6 +3124,50 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "schannel" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "security-framework" +version = "3.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3297343eaf830f66ede390ea39da1d462b6b0c1b000f420d0a83f898bbbe6ef" +dependencies = [ + "bitflags", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + [[package]] name = "seq-macro" version = "0.3.6" @@ -1827,34 +3176,69 @@ checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc" [[package]] name = "serde" -version = "1.0.219" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.219" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] name = "serde_json" -version = "1.0.140" +version = "1.0.145" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" +checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" dependencies = [ + "indexmap", "itoa", "memchr", "ryu", "serde", + "serde_core", +] + +[[package]] +name = "serde_repr" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", ] [[package]] @@ -1885,6 +3269,30 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "simd-adler32" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" + +[[package]] +name = "simdutf8" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" + +[[package]] +name = "siphasher" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" + +[[package]] +name = "slab" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" + [[package]] name = "smallvec" version = "1.15.1" @@ -1892,10 +3300,49 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" [[package]] -name = "static_assertions" -version = "1.1.0" +name = "snappy_src" +version = "0.2.5+snappy.1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +checksum = "4e1432067a55bcfb1fd522d2aca6537a4fcea32bba87ea86921226d14f9bad53" +dependencies = [ + "cc", + "link-cplusplus", +] + +[[package]] +name = "socket2" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "233504af464074f9d066d7b5416c5f9b894a5862a6506e306f7b816cdd6f1807" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + +[[package]] +name = "strum" +version = "0.26.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" + +[[package]] +name = "strum_macros" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.106", +] [[package]] name = "subtle" @@ -1905,9 +3352,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.104" +version = "1.0.109" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" dependencies = [ "proc-macro2", "quote", @@ -1915,10 +3362,47 @@ dependencies = [ ] [[package]] -name = "target-lexicon" +name = "syn" +version = "2.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" +dependencies = [ + "futures-core", +] + +[[package]] +name = "synstructure" version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "tagptr" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" + +[[package]] +name = "target-lexicon" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df7f62577c25e07834649fc3b39fafdc597c0a3527dc1c60129201ccfcbaa50c" [[package]] name = "tch" @@ -1948,11 +3432,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.12" +version = "2.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" +checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8" dependencies = [ - "thiserror-impl 2.0.12", + "thiserror-impl 2.0.17", ] [[package]] @@ -1963,25 +3447,34 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] name = "thiserror-impl" -version = "2.0.12" +version = "2.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" +checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", +] + +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" +dependencies = [ + "cfg-if", ] [[package]] name = "time" -version = "0.3.41" +version = "0.3.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40" +checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d" dependencies = [ "deranged", "num-conv", @@ -1992,9 +3485,9 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.4" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c" +checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b" [[package]] name = "time-humanize" @@ -2011,6 +3504,16 @@ dependencies = [ "crunchy", ] +[[package]] +name = "tinystr" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d4f6d1145dcb577acf783d4e601bc1d76a13337bb54e6233add580b07344c8b" +dependencies = [ + "displaydoc", + "zerovec", +] + [[package]] name = "tinytemplate" version = "1.2.1" @@ -2022,34 +3525,189 @@ dependencies = [ ] [[package]] -name = "torch-sys" -version = "0.20.0" +name = "tinyvec" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "tokio" +version = "1.47.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89e49afdadebb872d3145a5638b59eb0691ea23e46ca484037cfab3b76b95038" +dependencies = [ + "backtrace", + "bytes", + "io-uring", + "libc", + "mio", + "pin-project-lite", + "slab", + "socket2", + "tokio-macros", + "windows-sys 0.59.0", +] + +[[package]] +name = "tokio-macros" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + +[[package]] +name = "tokio-util" +version = "0.7.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14307c986784f72ef81c89db7d9e28d6ac26d16213b109ea501696195e6e3ce5" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "torch-sys" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad6fa4ac5662b84047081375b007f102d4968d5a0191f567a9776294445af9ac" +dependencies = [ + "anyhow", + "cc", + "libc", + "zip", +] + +[[package]] +name = "tower" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper", + "tokio", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-http" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2" +dependencies = [ + "bitflags", + "bytes", + "futures-util", + "http", + "http-body", + "iri-string", + "pin-project-lite", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + +[[package]] +name = "tracing" +version = "0.1.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "tracing-core" +version = "0.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad6fa4ac5662b84047081375b007f102d4968d5a0191f567a9776294445af9ac" +checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" dependencies = [ - "anyhow", - "cc", - "libc", - "zip", + "once_cell", ] +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + [[package]] name = "typenum" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" [[package]] name = "unicode-ident" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" +checksum = "f63a545481291138910575129486daeaf8ac54aee4387fe7906919f7830c7d9d" [[package]] name = "unicode-width" -version = "0.2.1" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + +[[package]] +name = "unicode-xid" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a1a07cc7db3810833284e8d372ccdc6da29741639ecc70c9ec107df0fa6154c" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" [[package]] name = "unindent" @@ -2063,12 +3721,59 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "323402cff2dd658f39ca17c789b502021b3f18707c91cdf22e3838e1b4023817" +[[package]] +name = "unsafe_cell_slice" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6659959f702dcdaad77bd6e42a9409a32ceccc06943ec93c8a4306be00eb6cf1" + +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "upon" version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3ead40aa15464f4d808014183fa0b030761ff6f57e162f7fc76d6a900df7a28" +[[package]] +name = "url" +version = "2.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", + "serde", +] + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + +[[package]] +name = "uuid" +version = "1.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2" +dependencies = [ + "getrandom 0.3.3", + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.5" @@ -2085,6 +3790,15 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + [[package]] name = "wasi" version = "0.11.1+wasi-snapshot-preview1" @@ -2093,44 +3807,67 @@ checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] name = "wasi" -version = "0.14.2+wasi-0.2.4" +version = "0.14.7+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "883478de20367e224c0090af9cf5f9fa85bed63a95c1abf3afc5c083ebc06e8c" +dependencies = [ + "wasip2", +] + +[[package]] +name = "wasip2" +version = "1.0.1+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" +checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" dependencies = [ - "wit-bindgen-rt", + "wit-bindgen", ] [[package]] name = "wasm-bindgen" -version = "0.2.100" +version = "0.2.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5" +checksum = "c1da10c01ae9f1ae40cbfac0bac3b1e724b320abfcf52229f80b547c0d250e2d" dependencies = [ "cfg-if", "once_cell", "rustversion", "wasm-bindgen-macro", + "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.100" +version = "0.2.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6" +checksum = "671c9a5a66f49d8a47345ab942e2cb93c7d1d0339065d4f8139c486121b43b19" dependencies = [ "bumpalo", "log", "proc-macro2", "quote", - "syn", + "syn 2.0.106", "wasm-bindgen-shared", ] +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e038d41e478cc73bae0ff9b36c60cff1c98b8f38f8d7e8061e79ee63608ac5c" +dependencies = [ + "cfg-if", + "js-sys", + "once_cell", + "wasm-bindgen", + "web-sys", +] + [[package]] name = "wasm-bindgen-macro" -version = "0.2.100" +version = "0.2.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407" +checksum = "7ca60477e4c59f5f2986c50191cd972e3a50d8a95603bc9434501cf156a9a119" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2138,31 +3875,44 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.100" +version = "0.2.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" +checksum = "9f07d2f20d4da7b26400c9f4a0511e6e0345b040694e8a75bd41d578fa4421d7" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.100" +version = "0.2.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d" +checksum = "bad67dc8b2a1a6e5448428adec4c3e84c43e561d8c9ee8a9e5aabeb193ec41d1" dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" -version = "0.3.77" +version = "0.3.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2" +checksum = "9367c417a924a74cae129e6a2ae3b47fabb1f8995595ab474029da749a8be120" dependencies = [ "js-sys", "wasm-bindgen", @@ -2178,20 +3928,42 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + [[package]] name = "winapi-util" -version = "0.1.9" +version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-core" -version = "0.61.2" +version = "0.62.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" dependencies = [ "windows-implement", "windows-interface", @@ -2202,50 +3974,59 @@ dependencies = [ [[package]] name = "windows-implement" -version = "0.60.0" +version = "0.60.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] name = "windows-interface" -version = "0.59.1" +version = "0.59.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] name = "windows-link" -version = "0.1.3" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" [[package]] name = "windows-result" -version = "0.3.4" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" dependencies = [ "windows-link", ] [[package]] name = "windows-strings" -version = "0.4.2" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" dependencies = [ "windows-link", ] +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.59.0" @@ -2261,7 +4042,16 @@ version = "0.60.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" dependencies = [ - "windows-targets 0.53.2", + "windows-targets 0.53.5", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", ] [[package]] @@ -2282,18 +4072,19 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.53.2" +version = "0.53.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c66f69fcc9ce11da9966ddb31a40968cad001c5bedeb5c2b82ede4253ab48aef" +checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" dependencies = [ - "windows_aarch64_gnullvm 0.53.0", - "windows_aarch64_msvc 0.53.0", - "windows_i686_gnu 0.53.0", - "windows_i686_gnullvm 0.53.0", - "windows_i686_msvc 0.53.0", - "windows_x86_64_gnu 0.53.0", - "windows_x86_64_gnullvm 0.53.0", - "windows_x86_64_msvc 0.53.0", + "windows-link", + "windows_aarch64_gnullvm 0.53.1", + "windows_aarch64_msvc 0.53.1", + "windows_i686_gnu 0.53.1", + "windows_i686_gnullvm 0.53.1", + "windows_i686_msvc 0.53.1", + "windows_x86_64_gnu 0.53.1", + "windows_x86_64_gnullvm 0.53.1", + "windows_x86_64_msvc 0.53.1", ] [[package]] @@ -2304,9 +4095,9 @@ checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_gnullvm" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" +checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" [[package]] name = "windows_aarch64_msvc" @@ -2316,9 +4107,9 @@ checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_aarch64_msvc" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" +checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" [[package]] name = "windows_i686_gnu" @@ -2328,9 +4119,9 @@ checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnu" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" +checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" [[package]] name = "windows_i686_gnullvm" @@ -2340,9 +4131,9 @@ checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_gnullvm" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" +checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" [[package]] name = "windows_i686_msvc" @@ -2352,9 +4143,9 @@ checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_i686_msvc" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" +checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" [[package]] name = "windows_x86_64_gnu" @@ -2364,9 +4155,9 @@ checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnu" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" +checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" [[package]] name = "windows_x86_64_gnullvm" @@ -2376,9 +4167,9 @@ checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_gnullvm" -version = "0.53.0" +version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" +checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" [[package]] name = "windows_x86_64_msvc" @@ -2388,37 +4179,281 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "windows_x86_64_msvc" -version = "0.53.0" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" + +[[package]] +name = "wit-bindgen" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" + +[[package]] +name = "writeable" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" +checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" [[package]] -name = "wit-bindgen-rt" -version = "0.39.0" +name = "yoke" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" +checksum = "5f41bb01b8226ef4bfd589436a297c53d118f65921786300e427be8d487695cc" dependencies = [ - "bitflags", + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", + "synstructure", +] + +[[package]] +name = "zarrs" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad12c7c2b91d2f6871f21efc28fd5a302809d7246974fdd99ef55bd3b16b78a0" +dependencies = [ + "async-generic", + "async-lock", + "async-trait", + "blosc-src", + "bytemuck", + "bytes", + "crc32c", + "derive_more", + "flate2", + "futures", + "getrandom 0.3.3", + "half", + "inventory", + "itertools 0.14.0", + "itoa", + "lru", + "moka", + "ndarray", + "num", + "num-complex", + "quick_cache", + "rayon", + "rayon_iter_concurrent_limit", + "serde", + "serde_json", + "thiserror 2.0.17", + "thread_local", + "unsafe_cell_slice", + "uuid", + "zarrs_data_type", + "zarrs_filesystem", + "zarrs_metadata", + "zarrs_metadata_ext", + "zarrs_plugin", + "zarrs_registry", + "zarrs_storage", + "zstd 0.13.3", +] + +[[package]] +name = "zarrs_data_type" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e68a3b9e663cf4933afcd90f460ee72986fdf6c2b4d43d0441ad049b802342" +dependencies = [ + "derive_more", + "half", + "inventory", + "num", + "thiserror 2.0.17", + "zarrs_metadata", + "zarrs_plugin", +] + +[[package]] +name = "zarrs_filesystem" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e135c32621a3a5796d917768d5c7aa7f58be9480ae00778956b82ec6409150b" +dependencies = [ + "bytes", + "derive_more", + "itertools 0.14.0", + "libc", + "page_size", + "parking_lot", + "pathdiff", + "thiserror 2.0.17", + "walkdir", + "zarrs_storage", +] + +[[package]] +name = "zarrs_metadata" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "708b938e5af9e6564d7135fb2d1e05c0deff3d7124694ff6822aa01614a6c991" +dependencies = [ + "derive_more", + "half", + "monostate", + "serde", + "serde_json", + "thiserror 2.0.17", +] + +[[package]] +name = "zarrs_metadata_ext" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4fb56ca32761b64c4b2a3db1097fbd29adfb321a129279b1db99be0a61d361a" +dependencies = [ + "derive_more", + "half", + "monostate", + "num", + "serde", + "serde_json", + "serde_repr", + "thiserror 2.0.17", + "zarrs_metadata", + "zarrs_registry", +] + +[[package]] +name = "zarrs_object_store" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5d9d0d3db426dd50dfb0b5d7cc1660bda368caeb5cd8645c60c46bc4f261a19" +dependencies = [ + "async-trait", + "futures", + "object_store", + "zarrs_storage", +] + +[[package]] +name = "zarrs_plugin" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3c9e0514d4c50f44d11285d5df70e4e586486a39826579c9d87ddc3f3dac561" +dependencies = [ + "thiserror 2.0.17", +] + +[[package]] +name = "zarrs_registry" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebe4e55522eeb87eefab89017bef78cb823f86861fd8a3cc12e9f6538c348d57" +dependencies = [ + "regex", +] + +[[package]] +name = "zarrs_storage" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bc1037a8fa8c44ccb8f5c6c85753a63ddf296fb43280f28150f0f29fda8d301" +dependencies = [ + "async-trait", + "auto_impl", + "bytes", + "derive_more", + "futures", + "itertools 0.14.0", + "parking_lot", + "thiserror 2.0.17", + "unsafe_cell_slice", ] [[package]] name = "zerocopy" -version = "0.8.26" +version = "0.8.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f" +checksum = "0894878a5fa3edfd6da3f88c4805f4c8558e2b996227a3d864f47fe11e38282c" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.26" +version = "0.8.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + +[[package]] +name = "zerofrom" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", + "synstructure", +] + +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + +[[package]] +name = "zerotrie" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36f0bbd478583f79edad978b407914f61b2972f5af6fa089686016be8f9af595" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + +[[package]] +name = "zerovec" +version = "0.11.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7aa2bd55086f1ab526693ecbe444205da57e25f4489879da80635a46d90e73b" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181" +checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.106", ] [[package]] @@ -2438,7 +4473,7 @@ dependencies = [ "pbkdf2", "sha1", "time", - "zstd", + "zstd 0.11.2+zstd.1.5.2", ] [[package]] @@ -2447,7 +4482,16 @@ version = "0.11.2+zstd.1.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" dependencies = [ - "zstd-safe", + "zstd-safe 5.0.2+zstd.1.5.2", +] + +[[package]] +name = "zstd" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" +dependencies = [ + "zstd-safe 7.2.4", ] [[package]] @@ -2460,11 +4504,20 @@ dependencies = [ "zstd-sys", ] +[[package]] +name = "zstd-safe" +version = "7.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" +dependencies = [ + "zstd-sys", +] + [[package]] name = "zstd-sys" -version = "2.0.15+zstd.1.5.7" +version = "2.0.16+zstd.1.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb81183ddd97d0c74cedf1d50d85c8d08c1b8b68ee863bdee9e706eedba1a237" +checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748" dependencies = [ "cc", "pkg-config", diff --git a/Cargo.toml b/Cargo.toml index 6f1fe11..b8410b4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "nutpie" -version = "0.15.2" +version = "0.16.0" authors = [ "Adrian Seyboldt ", "PyMC Developers ", @@ -10,7 +10,7 @@ license = "MIT" repository = "https://github.com/pymc-devs/nutpie" keywords = ["statistics", "bayes"] description = "Python wrapper for nuts-rs -- a NUTS sampler written in Rust." -rust-version = "1.76" +rust-version = "1.89" [features] extension-module = ["pyo3/extension-module"] @@ -21,30 +21,35 @@ name = "_lib" crate-type = ["cdylib"] [dependencies] -nuts-rs = "0.16.1" -numpy = "0.25.0" +nuts-rs = { version = "0.17.0", features = ["zarr", "arrow"] } +numpy = "0.26.0" rand = "0.9.0" thiserror = "2.0.3" rand_chacha = "0.9.0" -rayon = "1.10.0" -# Keep arrow in sync with nuts-rs requirements -arrow = { version = "55.1.0", default-features = false, features = ["ffi"] } +rayon = "1.11.0" anyhow = "1.0.72" itertools = "0.14.0" -bridgestan = "2.6.1" +bridgestan = "2.7.0" rand_distr = "0.5.0" -smallvec = "1.14.0" +smallvec = "1.15.0" upon = { version = "0.10.0", default-features = false, features = [] } time-humanize = { version = "0.1.3", default-features = false } indicatif = "0.18.0" tch = { version = "0.20.0", optional = true } +pyo3-object_store = "0.6.0" +# Keep zarrs crates in sync with nuts-rs requirements +zarrs = { version = "0.22.2", features = ["async"] } +zarrs_object_store = "0.5.0" +tokio = { version = "1.47.1", features = ["rt", "rt-multi-thread"] } +pyo3-arrow = "0.12.0" +arrow = { version = "56.2.0", features = ["json"] } [dependencies.pyo3] -version = "0.25.0" +version = "0.26.0" features = ["extension-module", "anyhow"] [dev-dependencies] -criterion = "0.6.0" +criterion = "0.7.0" [profile.release] lto = "fat" diff --git a/docs/pymc-usage.qmd b/docs/pymc-usage.qmd index 428143c..a045490 100644 --- a/docs/pymc-usage.qmd +++ b/docs/pymc-usage.qmd @@ -129,7 +129,7 @@ pixi add jax We can select the backend by passing the `backend` argument to the `compile_pymc_model`: ```python -compiled_jax = nutpie.compiled_pymc_model(model, backend="jax") +compiled_jax = nutpie.compile_pymc_model(model, backend="jax") trace = nutpie.sample(compiled_jax) ``` diff --git a/docs/sampling-options.qmd b/docs/sampling-options.qmd index ef6bfe4..458c6df 100644 --- a/docs/sampling-options.qmd +++ b/docs/sampling-options.qmd @@ -25,7 +25,7 @@ trace = nutpie.sample( tune=500, # Number of warmup draws for adaptation chains=6, # Number of independent chains cores=None, # Number chains that are allowed to run simultainiously - seed=12345 # Random seed for reproducibility + seed=12345 # Random seed for reproducibility ) ``` @@ -143,6 +143,60 @@ trace = nutpie.sample( ) ``` +## Zarr Storage (Experimental) + +Nutpie includes experimental support for writing traces directly to zarr storage, which can be useful for large traces that don't fit in memory or for distributed storage scenarios. The zarr format provides efficient, chunked, compressed storage for multi-dimensional arrays. + +### Basic Usage + +You can write traces directly to zarr storage by providing a `zarr_store` parameter: + +```python +import nutpie +import pymc as pm + +with pm.Model() as model: + pm.HalfNormal("a") + +compiled = nutpie.compile_pymc_model(model, backend="numba") + +# Create a local zarr store +path = "trace.zarr" +store = nutpie.zarr_store.LocalStore(path) + +trace = nutpie.sample( + compiled, + chains=2, + seed=123, + draws=100, + tune=100, + zarr_store=store +) +``` + +### Memory Considerations + +When using zarr storage, the trace object supports lazy loading: + +```python +# The trace is not loaded into memory by default +posterior_data = trace.posterior.a # Lazy access + +# Explicitly load the entire trace into memory (optional) +loaded_trace = trace.load() +posterior_data = loaded_trace.posterior.a # In-memory access +``` + +### Available Store Types + +Nutpie supports several zarr store backends: + +- `nutpie.zarr_store.LocalStore(path)` - Local filesystem storage +- `nutpie.zarr_store.S3Store(...)` - Amazon S3 storage +- `nutpie.zarr_store.GCSStore(...)` - Google Cloud Storage +- `nutpie.zarr_store.AzureStore(...)` - Azure Blob Storage +- `nutpie.zarr_store.HTTPStore(...)` - HTTP-based storage + ## Progress Monitoring Customize the sampling progress display: diff --git a/pyproject.toml b/pyproject.toml index c53f390..84e5e33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,9 +17,12 @@ classifiers = [ dependencies = [ "pyarrow >= 12.0.0", + "arro3-core >= 0.6.0", "pandas >= 2.0", "xarray >= 2025.01.2", "arviz >= 0.20.0", + "obstore >= 0.8.0", + "zarr >= 3.1.0", ] dynamic = ["version"] @@ -28,12 +31,12 @@ Homepage = "https://pymc-devs.github.io/nutpie/" Repository = "https://github.com/pymc-devs/nutpie" [project.optional-dependencies] -stan = ["bridgestan >= 2.6.1", "stanio >= 0.5.1"] +stan = ["bridgestan >= 2.7.0", "stanio >= 0.5.1"] pymc = ["pymc >= 5.20.1", "numba >= 0.60.0"] pymc-jax = ["pymc >= 5.20.1", "jax >= 0.4.27"] nnflow = ["flowjax >= 17.1.0", "equinox >= 0.11.12"] dev = [ - "bridgestan >= 2.6.1", + "bridgestan >= 2.7.0", "stanio >= 0.5.1", "pymc >= 5.20.1", "numba >= 0.60.0", @@ -44,7 +47,7 @@ dev = [ "pytest-arraydiff", ] all = [ - "bridgestan >= 2.6.1", + "bridgestan >= 2.7.0", "stanio >= 0.5.1", "pymc >= 5.20.1", "numba >= 0.60.0", @@ -76,7 +79,7 @@ features = ["pyo3/extension-module"] [tool.pytest.ini_options] markers = [ - "flow: tests for normalizing flows", - "stan: tests for Stan models", - "pymc: tests for PyMC models", + "flow: tests for normalizing flows", + "stan: tests for Stan models", + "pymc: tests for PyMC models", ] diff --git a/python/nutpie/__init__.py b/python/nutpie/__init__.py index 443f099..4084b57 100644 --- a/python/nutpie/__init__.py +++ b/python/nutpie/__init__.py @@ -2,6 +2,13 @@ from nutpie.compile_pymc import compile_pymc_model from nutpie.compile_stan import compile_stan_model from nutpie.sample import sample +from nutpie._lib import store as zarr_store __version__: str = _lib.__version__ -__all__ = ["__version__", "compile_pymc_model", "compile_stan_model", "sample"] +__all__ = [ + "__version__", + "compile_pymc_model", + "compile_stan_model", + "sample", + "zarr_store", +] diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index 36a3a09..5028f33 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -111,6 +111,7 @@ class CompiledPyMCModel(CompiledModel): _n_dim: int _shapes: dict[str, tuple[int, ...]] _coords: Optional[dict[str, Any]] + _transform_adapt_args: dict | None = None @property def n_dim(self): @@ -146,13 +147,14 @@ def with_data(self, **updates): user_data=user_data, ) - def _make_sampler(self, settings, init_mean, cores, progress_type): + def _make_sampler(self, settings, init_mean, cores, progress_type, store): model = self._make_model(init_mean) return _lib.PySampler.from_pymc( settings, cores, model, progress_type, + store, ) def _make_model(self, init_mean): @@ -164,24 +166,46 @@ def _make_model(self, init_mean): self, ) logp_fn = _lib.LogpFunc( - self.n_dim, self.compiled_logp_func.address, self.user_data.ctypes.data, self, ) - var_sizes = [prod(shape) for shape in self.shape_info[2]] var_names = self.shape_info[0] + coords = self._coords.copy() if self._coords is not None else {} + dim_sizes = {name: len(vals) for name, vals in coords.items()} + dims = self.dims.copy() if self.dims is not None else {} + var_types = ["float64"] * len(var_names) + var_shapes = self.shape_info[2] + + variables = _lib.PyVariable.new_variables( + var_names, var_types, var_shapes, dim_sizes, dims + ) + + outer_kwargs = self._transform_adapt_args + if outer_kwargs is None: + outer_kwargs = {} + + def make_adapter(*args, **kwargs): + from nutpie.transform_adapter import make_transform_adapter + + return make_transform_adapter(**outer_kwargs)(*args, **kwargs, logp_fn=None) + return _lib.PyMcModel( - self.n_dim, logp_fn, expand_fn, + variables, + self.n_dim, + dim_sizes, + coords, self.initial_point_func, - var_sizes, - var_names, + make_adapter, ) + def with_transform_adapt(self, **kwargs): + return dataclasses.replace(self, _transform_adapt_args=kwargs) + def update_user_data(user_data, user_data_storage): user_data = user_data[()] diff --git a/python/nutpie/compile_stan.py b/python/nutpie/compile_stan.py index 9e65b1e..82de950 100644 --- a/python/nutpie/compile_stan.py +++ b/python/nutpie/compile_stan.py @@ -4,7 +4,6 @@ from pathlib import Path from typing import Any, Optional -import pandas as pd from numpy.typing import NDArray from nutpie import _lib @@ -52,13 +51,25 @@ def make_adapter(*args, **kwargs): return make_transform_adapter(**outer_kwargs)(*args, **kwargs, logp_fn=None) - model = _lib.StanModel(self.library, seed, data_json, make_adapter) + coords = self._coords + if coords is None: + coords = {} + coords = coords.copy() + + dims = self.dims + if dims is None: + dims = {} + dims = dims.copy() + dim_sizes = {name: len(dim) for name, dim in coords.items()} + + model = _lib.StanModel( + self.library, dim_sizes, dims, coords, seed, data_json, make_adapter + ) coords = self._coords if coords is None: coords = {} else: coords = coords.copy() - coords["unconstrained_parameter"] = pd.Index(model.param_unc_names()) return CompiledStanModel( _coords=coords, @@ -93,13 +104,14 @@ def _make_model(self, init_mean): return self.with_data().model return self.model - def _make_sampler(self, settings, init_mean, cores, progress_type): + def _make_sampler(self, settings, init_mean, cores, progress_type, store): model = self._make_model(init_mean) return _lib.PySampler.from_stan( settings, cores, model, progress_type, + store, ) @property diff --git a/python/nutpie/compiled_pyfunc.py b/python/nutpie/compiled_pyfunc.py index 618feea..db58c28 100644 --- a/python/nutpie/compiled_pyfunc.py +++ b/python/nutpie/compiled_pyfunc.py @@ -19,6 +19,7 @@ class PyFuncModel(CompiledModel): _shared_data: dict[str, Any] _n_dim: int _variables: list[_lib.PyVariable] + _dim_sizes: dict[str, int] _coords: dict[str, Any] _raw_logp_fn: Callable | None _transform_adapt_args: dict | None = None @@ -47,13 +48,14 @@ def with_data(self, **updates): def with_transform_adapt(self, **kwargs): return dataclasses.replace(self, _transform_adapt_args=kwargs) - def _make_sampler(self, settings, init_mean, cores, progress_type): + def _make_sampler(self, settings, init_mean, cores, progress_type, store): model = self._make_model(init_mean) return _lib.PySampler.from_pyfunc( settings, cores, model, progress_type, + store, ) def _make_model(self, init_mean): @@ -85,6 +87,8 @@ def make_adapter(*args, **kwargs): make_expand_func, self._variables, self.n_dim, + dim_sizes=self._dim_sizes, + coords=self._coords, init_point_func=self._make_initial_points, transform_adapter=make_adapter, ) @@ -105,19 +109,6 @@ def from_pyfunc( make_transform_adapter=None, raw_logp_fn=None, ): - variables = [] - for name, shape, dtype in zip( - expanded_names, expanded_shapes, expanded_dtypes, strict=True - ): - shape = _lib.TensorShape(list(shape)) - if dtype == np.float64: - dtype = _lib.ExpandDtype.float64_array(shape) - elif dtype == np.float32: - dtype = _lib.ExpandDtype.float32_array(shape) - elif dtype == np.int64: - dtype = _lib.ExpandDtype.int64_array(shape) - variables.append(_lib.PyVariable(name, dtype)) - if coords is None: coords = {} if dims is None: @@ -125,10 +116,23 @@ def from_pyfunc( if shared_data is None: shared_data = {} + coords = coords.copy() + + dim_sizes = {k: len(v) for k, v in coords.items()} + shapes = [tuple(shape) for shape in expanded_shapes] + variables = _lib.PyVariable.new_variables( + expanded_names, + [str(dtype) for dtype in expanded_dtypes], + shapes, + dim_sizes, + dims, + ) + return PyFuncModel( _n_dim=ndim, dims=dims, _coords=coords, + _dim_sizes=dim_sizes, _make_logp_func=make_logp_fn, _make_expand_func=make_expand_fn, _make_initial_points=make_initial_point_fn, diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py index 0655173..edc80f0 100644 --- a/python/nutpie/sample.py +++ b/python/nutpie/sample.py @@ -54,75 +54,102 @@ def benchmark_logp(self, point, num_evals, cores): return pd.concat(times) -def _trace_to_arviz(traces, n_tune, shapes, **kwargs): - n_chains = len(traces) - - data_dict = {} - data_dict_tune = {} - stats_dict = {} - stats_dict_tune = {} - - draw_batches = [] - stats_batches = [] - for draws, stats in traces: - draw_batches.append(pyarrow.RecordBatch.from_struct_array(draws)) - stats_batches.append(pyarrow.RecordBatch.from_struct_array(stats)) - - table = pyarrow.Table.from_batches(draw_batches) - table_stats = pyarrow.Table.from_batches(stats_batches) - for name, col in zip(table.column_names, table.columns): - lengths = [len(chunk) for chunk in col.chunks] - length = max(lengths) - dtype = col.chunks[0].values.to_numpy().dtype - if dtype in [np.float64, np.float32]: - data = np.full( - (n_chains, length, *tuple(shapes[name])), np.nan, dtype=dtype - ) - else: - data = np.zeros((n_chains, length, *tuple(shapes[name])), dtype=dtype) - for i, chunk in enumerate(col.chunks): - data[i, : len(chunk)] = chunk.values.to_numpy().reshape( - (len(chunk),) + shapes[name] - ) - - data_dict[name] = data[:, n_tune:] - data_dict_tune[name] = data[:, :n_tune] +def _arrow_to_arviz(draw_batches, stat_batches, skip_vars=None, **kwargs): + if skip_vars is None: + skip_vars = [] + + n_chains = len(draw_batches) + assert n_chains == len(stat_batches) + + max_tuning = 0 + max_posterior = 0 + num_tuning = [] + + for draw, stat in zip(draw_batches, stat_batches): + tuning = stat.column("tuning") + _num_tuning = tuning.to_numpy().sum() + assert draw.num_rows == stat.num_rows + max_tuning = max(max_tuning, _num_tuning) + max_posterior = max(max_posterior, draw.num_rows - _num_tuning) + num_tuning.append(_num_tuning) + + data_tune = {} + data_posterior = {} + + stats_tune = {} + stats_posterior = {} + + dims = {} + + for i, draw in enumerate(draw_batches): + draw_tune = draw.slice(0, num_tuning[i]) + _add_arrow_data(data_tune, max_tuning, draw_tune, i, n_chains, dims, []) + draw_posterior = draw.slice(num_tuning[i], draw.num_rows - num_tuning[i]) + _add_arrow_data( + data_posterior, max_posterior, draw_posterior, i, n_chains, dims, [] + ) + for i, stat in enumerate(stat_batches): + stat_tune = stat.slice(0, num_tuning[i]) + _add_arrow_data(stats_tune, max_tuning, stat_tune, i, n_chains, dims, skip_vars) + stat_posterior = stat.slice(num_tuning[i], stat.num_rows - num_tuning[i]) + _add_arrow_data( + stats_posterior, max_posterior, stat_posterior, i, n_chains, dims, skip_vars + ) - for name, col in zip(table_stats.column_names, table_stats.columns): - if name in ["chain", "draw", "divergence_message"]: - continue - col_type = col.type - if hasattr(col_type, "list_size"): - last_shape = (col_type.list_size,) - dtype = col_type.field(0).type.to_pandas_dtype() - else: - dtype = col_type.to_pandas_dtype() - last_shape = () + return arviz.from_dict( + data_posterior, + sample_stats=stats_posterior, + warmup_posterior=data_tune, + warmup_sample_stats=stats_tune, + dims=dims, + **kwargs, + ) - lengths = [len(chunk) for chunk in col.chunks] - length = max(lengths) - if dtype in [np.float64, np.float32]: - data = np.full((n_chains, length, *last_shape), np.nan, dtype=dtype) - else: - data = np.zeros((n_chains, length, *last_shape), dtype=dtype) +def _add_arrow_data(data_dict, max_length, batch, chain, n_chains, dims, skip_vars): + num_draws = batch.num_rows - for i, chunk in enumerate(col.chunks): - if hasattr(chunk, "values"): - values = chunk.values.to_numpy(False) + for name in batch.column_names: + if name in skip_vars: + continue + col = batch.column(name) + meta = col.field.metadata + item_dims = meta.get(b"dims", []) + if item_dims: + item_dims = item_dims.decode("utf-8").split(",") + item_shape = meta.get(b"shape", []) + if item_shape: + item_shape = item_shape.decode("utf-8").split(",") + item_shape = [int(s) for s in item_shape] + total_shape = [n_chains, max_length, *item_shape] + + col = pyarrow.array(col) + + is_null = col.is_null() + + if hasattr(col, "flatten"): + col = col.flatten() + dtype = col.type.to_pandas_dtype() + + if name not in data_dict: + if dtype in [np.float64, np.float32]: + data = np.full(total_shape, np.nan, dtype=dtype) else: - values = chunk.to_numpy(False) - data[i, : len(chunk)] = values.reshape((len(chunk), *last_shape)) - stats_dict[name] = data[:, n_tune:] - stats_dict_tune[name] = data[:, :n_tune] + data = np.zeros(total_shape, dtype=dtype) + data_dict[name] = data - return arviz.from_dict( - data_dict, - sample_stats=stats_dict, - warmup_posterior=data_dict_tune, - warmup_sample_stats=stats_dict_tune, - **kwargs, - ) + dims[name] = item_dims + + values = col.to_numpy(False) + if is_null.sum() == 0: + data_dict[name][chain, :num_draws] = values.reshape( + (num_draws,) + tuple(item_shape) + ) + else: + is_null = is_null.to_numpy(False) + data_dict[name][chain, :num_draws][~is_null] = values.reshape( + ((~is_null).sum(),) + tuple(item_shape) + ) _progress_style = """ @@ -360,6 +387,15 @@ def in_colab(): return False # Probably standard Python interpreter +_ZarrStoreType = ( + _lib.store.S3Store + | _lib.store.LocalStore + | _lib.store.HTTPStore + | _lib.store.GCSStore + | _lib.store.AzureStore +) + + class _BackgroundSampler: _sampler: Any _num_divs: int @@ -369,6 +405,8 @@ class _BackgroundSampler: _chains_finished: int _compiled_model: CompiledModel _save_warmup: bool + _store: _lib.PyStorage + _zarr_store: _ZarrStoreType | None = None def __init__( self, @@ -383,6 +421,7 @@ def __init__( progress_template=None, progress_style=None, progress_rate=100, + store=None, ): self._settings = settings self._compiled_model = compiled_model @@ -391,6 +430,14 @@ def __init__( self._html = None + if store is None: + store = _lib.PyStorage.arrow() + elif type(store).__module__ == "_lib.store": + self._zarr_store = store + store = _lib.PyStorage.zarr(store) + + self._store = store + if not progress_bar: progress_type = _lib.ProgressType.none() @@ -411,8 +458,11 @@ def __init__( self.display_id = IPython.display.display(self, display_id=True) def callback(formatted): - self._html = formatted - self.display_id.update(self) + try: + self._html = formatted + self.display_id.update(self) + except Exception as e: + print(f"Error updating progress display: {e}") progress_type = _lib.ProgressType.template_callback( progress_rate, progress_template, cores, callback @@ -447,6 +497,7 @@ def callback(formatted): init_mean, cores, progress_type, + self._store, ) def wait(self, *, timeout=None): @@ -460,35 +511,64 @@ def wait(self, *, timeout=None): This resumes the sampler in case it had been paused. """ self._sampler.wait(timeout) - results = self._sampler.extract_results() + results = self._sampler.take_results() return self._extract(results) def _extract(self, results): - dims = {name: list(dim) for name, dim in self._compiled_model.dims.items()} - dims["mass_matrix_inv"] = ["unconstrained_parameter"] - dims["gradient"] = ["unconstrained_parameter"] - dims["unconstrained_draw"] = ["unconstrained_parameter"] - dims["divergence_start"] = ["unconstrained_parameter"] - dims["divergence_start_gradient"] = ["unconstrained_parameter"] - dims["divergence_end"] = ["unconstrained_parameter"] - dims["divergence_momentum"] = ["unconstrained_parameter"] - dims["transformed_gradient"] = ["unconstrained_parameter"] - dims["transformed_position"] = ["unconstrained_parameter"] - if self._return_raw_trace: return results else: - return _trace_to_arviz( - results, - self._settings.num_tune, - self._compiled_model.shapes, - dims=dims, - coords={ - name: pd.Index(vals) - for name, vals in self._compiled_model.coords.items() - }, - save_warmup=self._save_warmup, - ) + if results.is_zarr(): + from zarr.storage import ObjectStore + import obstore + import xarray as xr + + assert self._zarr_store is not None + + args, kwargs = self._zarr_store.__getnewargs_ex__() + name = self._zarr_store.__class__.__name__ + cls = getattr(obstore.store, name) + store = cls(*args, **kwargs) + + obj_store = ObjectStore(store, read_only=True) + ds = xr.open_datatree(obj_store, engine="zarr", consolidated=False) + return arviz.from_datatree(ds) + + elif results.is_arrow(): + skip_vars = [] + skips = { + "store_gradient": ["gradient"], + "store_unconstrained": ["unconstrained"], + "store_mass_matrix": [ + "mass_matrix_inv", + "mass_matrix_eigvals", + "mass_matrix_stds", + ], + "store_divergences": [ + "divergence_start", + "divergence_end", + "divergence_momentum", + "divergence_start_gradient", + ], + } + + for setting, names in skips.items(): + if not getattr(self._settings, setting, False): + skip_vars.extend(names) + + draw_batches, stat_batches = results.get_arrow_trace() + return _arrow_to_arviz( + draw_batches, + stat_batches, + skip_vars=skip_vars, + coords={ + name: pd.Index(vals) + for name, vals in self._compiled_model.coords.items() + }, + save_warmup=self._save_warmup, + ) + else: + raise ValueError("Unknown results type") def inspect(self): """Get a copy of the current state of the trace""" @@ -543,6 +623,7 @@ def sample( progress_template: str | None = None, progress_style: str | None = None, progress_rate: int = 100, + zarr_store: _ZarrStoreType | None = None, ) -> arviz.InferenceData: ... @@ -565,6 +646,7 @@ def sample( progress_template: str | None = None, progress_style: str | None = None, progress_rate: int = 100, + zarr_store: _ZarrStoreType | None = None, **kwargs, ) -> arviz.InferenceData: ... @@ -588,6 +670,7 @@ def sample( progress_template: str | None = None, progress_style: str | None = None, progress_rate: int = 100, + zarr_store: _ZarrStoreType | None = None, **kwargs, ) -> _BackgroundSampler: ... @@ -610,6 +693,7 @@ def sample( progress_template: str | None = None, progress_style: str | None = None, progress_rate: int = 100, + zarr_store: _ZarrStoreType | None = None, **kwargs, ) -> arviz.InferenceData | _BackgroundSampler: """Sample the posterior distribution for a compiled model. @@ -694,6 +778,12 @@ def sample( transform_adapt: bool, default=False Use the experimental transform adaptation algorithm during tuning. + zarr_store: nutpie.zarr_store.* + A store created using nutpie.zarr_store to store the samples + in. If None (default), the samples will be stored in + memory using an arrow table. This can be used to write + the trace directly into a zarr store, for instance + on disk or to S3 or GCS. **kwargs Pass additional arguments to nutpie._lib.PySamplerArgs @@ -750,6 +840,7 @@ def sample( progress_template=progress_template, progress_style=progress_style, progress_rate=progress_rate, + store=zarr_store, ) if not blocking: diff --git a/src/common.rs b/src/common.rs new file mode 100644 index 0000000..855352c --- /dev/null +++ b/src/common.rs @@ -0,0 +1,381 @@ +use std::collections::HashMap; + +use anyhow::{bail, Context, Result}; +use numpy::{PyArray1, PyReadonlyArray1}; +use nuts_rs::Value; +use pyo3::{ + exceptions::PyRuntimeError, + pyclass, pymethods, + types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyListMethods, PyType}, + Bound, FromPyObject, IntoPyObject, IntoPyObjectExt, Py, PyAny, PyErr, Python, +}; +use smallvec::SmallVec; + +#[derive(Debug, Clone)] +pub struct Dims(pub SmallVec<[String; 4]>); + +impl Dims { + pub fn as_slice(&self) -> &[String] { + &self.0 + } +} + +#[derive(Debug, Clone)] +pub struct Shape(pub SmallVec<[u64; 4]>); + +impl Shape { + pub fn as_slice(&self) -> &[u64] { + &self.0 + } +} + +#[derive(Debug, Clone)] +pub struct ItemType(pub nuts_rs::ItemType); + +impl ItemType { + pub fn into_inner(self) -> nuts_rs::ItemType { + self.0 + } + + pub fn as_inner(&self) -> &nuts_rs::ItemType { + &self.0 + } +} + +impl<'py> IntoPyObject<'py> for &Dims { + type Target = PyList; + type Output = Bound<'py, PyList>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> std::result::Result { + PyList::new(py, self.0.iter()) + } +} + +impl<'py> IntoPyObject<'py> for &Shape { + type Target = PyList; + type Output = Bound<'py, PyList>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> std::result::Result { + PyList::new(py, self.0.iter()) + } +} + +impl<'py> IntoPyObject<'py> for &ItemType { + type Target = PyAny; + type Output = Bound<'py, PyAny>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> std::result::Result { + let dtype_str = match self.0 { + nuts_rs::ItemType::U64 => "uint64", + nuts_rs::ItemType::I64 => "int64", + nuts_rs::ItemType::F64 => "float64", + nuts_rs::ItemType::F32 => "float32", + nuts_rs::ItemType::Bool => "bool", + nuts_rs::ItemType::String => "object", + }; + let numpy = py.import("numpy")?; + let dtype = numpy.getattr("dtype")?.call1((dtype_str,))?; + Ok(dtype) + } +} + +impl<'py> FromPyObject<'py> for ItemType { + fn extract_bound(ob: &Bound<'_, PyAny>) -> std::result::Result { + let dtype_str: &str = ob.extract()?; + let item_type = match dtype_str { + "uint64" => nuts_rs::ItemType::U64, + "int64" => nuts_rs::ItemType::I64, + "float64" => nuts_rs::ItemType::F64, + "float32" => nuts_rs::ItemType::F32, + "bool" => nuts_rs::ItemType::Bool, + "object" => nuts_rs::ItemType::String, + _ => { + return Err(PyRuntimeError::new_err(format!( + "Unsupported item type: {}", + dtype_str + ))) + } + }; + Ok(ItemType(item_type)) + } +} + +#[pyclass] +pub struct PyValue(Value); + +impl<'py> FromPyObject<'py> for PyValue { + fn extract_bound(ob: &Bound<'py, PyAny>) -> std::result::Result { + let ob = if ob.hasattr("values")? { + &ob.getattr("values")? + } else { + ob + }; + if let Ok(arr) = ob.extract::>() { + let vec = arr + .as_slice() + .map_err(|_| PyRuntimeError::new_err("Array is not contiguous"))?; + return Ok(PyValue(Value::F64(vec.to_vec()))); + } + if let Ok(arr) = ob.extract::>() { + let vec = arr + .as_slice() + .map_err(|_| PyRuntimeError::new_err("Array is not contiguous"))?; + return Ok(PyValue(Value::F32(vec.to_vec()))); + } + if let Ok(arr) = ob.extract::>() { + let vec = arr + .as_slice() + .map_err(|_| PyRuntimeError::new_err("Array is not contiguous"))?; + return Ok(PyValue(Value::I64(vec.to_vec()))); + } + if let Ok(arr) = ob.extract::>() { + let vec = arr + .as_slice() + .map_err(|_| PyRuntimeError::new_err("Array is not contiguous"))?; + return Ok(PyValue(Value::U64(vec.to_vec()))); + } + if let Ok(arr) = ob.extract::>() { + let vec = arr + .as_slice() + .map_err(|_| PyRuntimeError::new_err("Array is not contiguous"))?; + return Ok(PyValue(Value::Bool(vec.to_vec()))); + } + if let Ok(list) = ob.extract::>() { + let vec: Vec = list + .iter() + .map(|item| { + item.extract::() + .map_err(|_| PyRuntimeError::new_err("List item is not a string")) + }) + .collect::>()?; + return Ok(PyValue(Value::Strings(vec))); + } + if let Ok(arr) = ob.extract::>>() { + let vec = arr + .as_slice() + .map_err(|_| PyRuntimeError::new_err("Array is not contiguous"))?; + let vals_as_str = vec + .iter() + .map(|item| { + item.extract::(ob.py()) + .map_err(|_| PyRuntimeError::new_err("Array item is not a string")) + }) + .collect::>()?; + return Ok(PyValue(Value::Strings(vals_as_str))); + } + Err(PyRuntimeError::new_err( + "Could not convert to Value. Unsupported type.", + )) + } +} + +impl PyValue { + pub fn into_value(self) -> Value { + self.0 + } + + pub fn into_array(self, py: Python) -> Result> { + match self.0 { + Value::F64(vec) => Ok(PyArray1::from_vec(py, vec).into_any()), + Value::F32(vec) => Ok(PyArray1::from_vec(py, vec).into_any()), + Value::I64(vec) => Ok(PyArray1::from_vec(py, vec).into_any()), + Value::U64(vec) => Ok(PyArray1::from_vec(py, vec).into_any()), + Value::Bool(vec) => Ok(PyArray1::from_vec(py, vec).into_any()), + Value::Strings(vec) => Ok(PyList::new(py, vec)?.into_any()), + Value::ScalarString(val) => Ok(val.into_bound_py_any(py)?), + Value::ScalarU64(val) => Ok(val.into_bound_py_any(py)?), + Value::ScalarI64(val) => Ok(val.into_bound_py_any(py)?), + Value::ScalarF64(val) => Ok(val.into_bound_py_any(py)?), + Value::ScalarF32(val) => Ok(val.into_bound_py_any(py)?), + Value::ScalarBool(val) => Ok(val.into_bound_py_any(py)?), + } + } +} + +#[pyclass] +#[derive(Debug, Clone)] +#[non_exhaustive] +pub struct PyVariable { + #[pyo3(get)] + pub name: String, + pub item_type: ItemType, + #[pyo3(get)] + pub dims: Dims, + #[pyo3(get)] + pub shape: Shape, + #[pyo3(get)] + pub num_elements: usize, + #[pyo3(get)] + pub start_idx: Option, + #[pyo3(get)] + pub end_idx: Option, +} + +impl PyVariable { + pub fn new( + name: String, + item_type: ItemType, + shape: Option>, + all_dims: &mut HashMap>, + dim_sizes: &mut HashMap, + start_idx: Option, + ) -> anyhow::Result { + let dims = all_dims.get(&name); + + let (dims, shape) = match (dims, shape) { + (Some(dims), Some(shape)) => { + if dims.len() != shape.len() { + bail!( + "Variable '{}': number of dims ({}) does not match number of shape entries ({})", + name, + dims.len(), + shape.len(), + ); + } + for (dim, size) in dims.iter().zip(shape.iter()) { + if let Some(existing_size) = dim_sizes.get(dim) { + if *existing_size != *size { + bail!("Variable '{}': dimension '{}' has inconsistent size. Expected {}, but previously defined as {}", + name, dim, size, existing_size); + } + } + } + (dims.clone(), shape) + } + (Some(dims), None) => { + let mut inferred_shape = Vec::new(); + for dim in dims.iter() { + if let Some(size) = dim_sizes.get(dim) { + inferred_shape.push(*size); + } else { + bail!( + "Variable '{}': dimension '{}' size unknown and no shape provided", + name, + dim + ); + } + } + (dims.clone(), inferred_shape) + } + (None, Some(shape)) => { + let mut inferred_dims = Vec::new(); + for (i, size) in shape.iter().enumerate() { + let generated_name = format!("{}_dim_{}", name, i); + if dim_sizes.contains_key(&generated_name) { + bail!("Variable '{}': generated anonymous dimension name '{}' already exists.", + name, generated_name); + } + dim_sizes.insert(generated_name.clone(), *size); + inferred_dims.push(generated_name); + } + all_dims.insert(name.clone(), inferred_dims.clone()); + (inferred_dims, shape) + } + (None, None) => { + bail!("Variable '{}': no dims or shape provided", name); + } + }; + + let num_elements = shape.iter().product::() as usize; + + Ok(PyVariable { + name, + item_type, + dims: Dims(dims.into()), + shape: Shape(shape.into()), + num_elements, + start_idx, + end_idx: start_idx.map(|idx| idx + num_elements), + }) + } +} + +#[pymethods] +impl PyVariable { + #[classmethod] + fn new_variables<'py>( + cls: &Bound<'py, PyType>, + names: Vec, + item_types: Vec, + shapes: Vec>>, + dim_sizes: Py, + dims: Py, + ) -> Result> { + let mut rust_all_dims = HashMap::new(); + let mut rust_dim_sizes = HashMap::new(); + + let py = cls.py(); + + for (key, value) in dims.bind(py).iter() { + let key: String = key.extract().context("Dimension key is not a string")?; + let value: Vec = value + .extract() + .context("Dimension value is not a list of strings")?; + rust_all_dims.insert(key, value); + } + + for (key, value) in dim_sizes.bind(py).iter() { + let key: String = key + .extract() + .context("Dimension size key is not a string")?; + let value: u64 = value + .extract() + .context("Dimension size value is not an integer")?; + rust_dim_sizes.insert(key, value); + } + + let mut current_idx = 0; + + let variables = names + .into_iter() + .zip(item_types) + .zip(shapes) + .map(|((name, item_type), shape)| { + let item_type = match item_type.as_str() { + "uint64" => ItemType(nuts_rs::ItemType::U64), + "int64" => ItemType(nuts_rs::ItemType::I64), + "float64" => ItemType(nuts_rs::ItemType::F64), + "float32" => ItemType(nuts_rs::ItemType::F32), + "bool" => ItemType(nuts_rs::ItemType::Bool), + "string" => ItemType(nuts_rs::ItemType::String), + _ => bail!("Unsupported item type: {}", item_type), + }; + + let start_idx = Some(current_idx); + let var = Self::new( + name, + item_type, + shape, + &mut rust_all_dims, + &mut rust_dim_sizes, + start_idx, + ) + .context("Could not create variable")?; + current_idx += var.num_elements; + Ok(var) + }) + .collect::>>()?; + + let dim_sizes = dim_sizes.bind(py); + for key in rust_dim_sizes.keys() { + if !dim_sizes.contains(key).unwrap_or(false) { + dim_sizes + .set_item(key, rust_dim_sizes[key]) + .context("Could not update dimension sizes")?; + } + } + + let all_dims = dims.bind(py); + for key in rust_all_dims.keys() { + if !all_dims.contains(key).unwrap_or(false) { + all_dims + .set_item(key, rust_all_dims[key].clone()) + .context("Could not update all_dims")?; + } + } + Ok(variables) + } +} diff --git a/src/lib.rs b/src/lib.rs index 6154f92..287118f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +mod common; mod progress; mod pyfunc; mod pymc; diff --git a/src/progress.rs b/src/progress.rs index 2b130a4..2881c75 100644 --- a/src/progress.rs +++ b/src/progress.rs @@ -1,4 +1,12 @@ -use std::{collections::BTreeMap, sync::Arc, time::Duration}; +use std::{ + collections::BTreeMap, + sync::{ + mpsc::{sync_channel, SyncSender}, + Arc, + }, + thread::spawn, + time::Duration, +}; use anyhow::{Context, Result}; use indicatif::ProgressBar; @@ -10,20 +18,39 @@ use upon::{Engine, Value}; pub struct ProgressHandler { engine: Engine<'static>, template: String, - callback: Arc>, rate: Duration, n_cores: usize, + updates: SyncSender, } impl ProgressHandler { pub fn new(callback: Arc>, rate: Duration, template: String, n_cores: usize) -> Self { let engine = Engine::new(); + + let (update_tx, update_rx) = sync_channel(1); + + spawn(move || { + // We keep an extra gil reference alive, to ensure the + // python ThreadState is not destroyed. + // See https://github.com/PyO3/pyo3/issues/5467 + Python::attach(move |py| { + py.detach(move || loop { + let update = update_rx.recv(); + let Ok(update) = update else { break }; + let res = Python::attach(|py| callback.call1(py, (update,))); + if let Err(err) = res { + eprintln!("Error in progress callback: {err}"); + } + }); + }); + }); + Self { engine, - callback, rate, template, n_cores, + updates: update_tx, } } @@ -50,7 +77,10 @@ impl ProgressHandler { progress_to_value(progress_update_count, self.n_cores, time_sampling, progress); let rendered = template.render_from(&self.engine, &progress).to_string(); let rendered = rendered.unwrap_or_else(|err| format!("{err}")); - let _ = Python::with_gil(|py| self.callback.call1(py, (rendered,))); + if let Err(e) = self.updates.send(rendered) { + eprintln!("Could not send progress update: {e}"); + return; + } progress_update_count += 1; }; diff --git a/src/pyfunc.rs b/src/pyfunc.rs index de3ecd9..8a21875 100644 --- a/src/pyfunc.rs +++ b/src/pyfunc.rs @@ -1,75 +1,24 @@ -use std::sync::Arc; - -use anyhow::{anyhow, bail, Context, Result}; -use arrow::{ - array::{ - Array, ArrayBuilder, BooleanBuilder, Float32Builder, Float64Builder, Int64Builder, - LargeListBuilder, PrimitiveBuilder, StructBuilder, - }, - datatypes::{DataType, Field, Float32Type, Float64Type, Int64Type}, +use std::{collections::HashMap, sync::Arc}; + +use anyhow::{bail, Context, Result}; +use numpy::{ + NotContiguousError, PyArray1, PyReadonlyArray1, PyReadonlyArrayDyn, PyUntypedArrayMethods, }; -use numpy::{NotContiguousError, PyArray1, PyReadonlyArray1}; -use nuts_rs::{CpuLogpFunc, CpuMath, DrawStorage, LogpError, Model}; +use nuts_rs::{CpuLogpFunc, CpuMath, HasDims, LogpError, Model, Storable, Value}; use pyo3::{ exceptions::PyRuntimeError, pyclass, pymethods, - types::{PyAnyMethods, PyDict, PyDictMethods}, + types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyListMethods}, Bound, Py, PyAny, PyErr, Python, }; use rand::Rng; use rand_distr::{Distribution, Uniform}; -use smallvec::SmallVec; use thiserror::Error; -use crate::wrapper::PyTransformAdapt; - -#[pyclass] -#[derive(Debug, Clone)] -#[non_exhaustive] -pub struct PyVariable { - #[pyo3(get)] - pub name: String, - #[pyo3(get)] - pub dtype: ExpandDtype, -} - -impl PyVariable { - fn arrow_dtype(&self) -> DataType { - match &self.dtype { - ExpandDtype::Boolean {} => DataType::Boolean, - ExpandDtype::Float64 {} => DataType::Float64, - ExpandDtype::Float32 {} => DataType::Float32, - ExpandDtype::Int64 {} => DataType::Int64, - ExpandDtype::BooleanArray { tensor_type: _ } => { - let field = Arc::new(Field::new("item", DataType::Boolean, false)); - DataType::LargeList(field) - } - ExpandDtype::ArrayFloat64 { tensor_type: _ } => { - let field = Arc::new(Field::new("item", DataType::Float64, true)); - DataType::LargeList(field) - } - ExpandDtype::ArrayFloat32 { tensor_type: _ } => { - let field = Arc::new(Field::new("item", DataType::Float32, false)); - DataType::LargeList(field) - } - ExpandDtype::ArrayInt64 { tensor_type: _ } => { - let field = Arc::new(Field::new("item", DataType::Int64, false)); - DataType::LargeList(field) - } - } - } -} - -#[pymethods] -impl PyVariable { - #[new] - fn new(name: String, value_type: ExpandDtype) -> Self { - Self { - name, - dtype: value_type, - } - } -} +use crate::{ + common::{PyValue, PyVariable}, + wrapper::PyTransformAdapt, +}; #[pyclass] #[derive(Debug, Clone)] @@ -80,28 +29,59 @@ pub struct PyModel { variables: Arc>, transform_adapter: Option, ndim: usize, + dim_sizes: HashMap, + coords: HashMap, } #[pymethods] impl PyModel { #[new] - #[pyo3(signature = (make_logp_func, make_expand_func, variables, ndim, *, init_point_func=None, transform_adapter=None))] + #[pyo3(signature = (make_logp_func, make_expand_func, variables, ndim, dim_sizes, coords, *, init_point_func=None, transform_adapter=None))] fn new<'py>( + py: Python<'py>, make_logp_func: Py, make_expand_func: Py, variables: Vec, ndim: usize, + dim_sizes: Py, + coords: Py, init_point_func: Option>, transform_adapter: Option>, - ) -> Self { - Self { + ) -> Result { + let dim_sizes = dim_sizes + .bind(py) + .iter() + .map(|(key, value)| { + let key: String = key.extract().context("Dimension key is not a string")?; + let value: u64 = value + .extract() + .context("Dimension size value is not an integer")?; + Ok((key, value)) + }) + .collect::>>()?; + + let coords = coords + .bind(py) + .iter() + .map(|(key, value)| { + let key: String = key.extract().context("Coordinate key is not a string")?; + let value: PyValue = value + .extract() + .context("Coordinate value has incorrect type")?; + Ok((key, value.into_value())) + }) + .collect::>>()?; + + Ok(Self { make_logp_func: Arc::new(make_logp_func), make_expand_func: Arc::new(make_expand_func), init_point_func: init_point_func.map(|x| x.into()), variables: Arc::new(variables), ndim, transform_adapter: transform_adapter.map(PyTransformAdapt::new), - } + dim_sizes, + coords, + }) } } @@ -123,7 +103,7 @@ impl LogpError for PyLogpError { fn is_recoverable(&self) -> bool { match self { Self::BadLogp(_) => true, - Self::PyError(err) => Python::with_gil(|py| { + Self::PyError(err) => Python::attach(|py| { let Ok(attr) = err.value(py).getattr("is_recoverable") else { return false; }; @@ -139,32 +119,94 @@ impl LogpError for PyLogpError { pub struct PyDensity { logp: Py, + expand_func: Py, transform_adapter: Option, dim: usize, + variables: Arc>, + dim_sizes: HashMap, + coords: HashMap, } impl PyDensity { fn new( logp_clone_func: &Py, + expand_clone_func: &Py, dim: usize, transform_adapter: Option<&PyTransformAdapt>, + variables: Arc>, + dim_sizes: HashMap, + coords: HashMap, ) -> Result { - let logp_func = Python::with_gil(|py| logp_clone_func.call0(py))?; + let logp_func = Python::attach(|py| logp_clone_func.call0(py))?; + let expand_func = Python::attach(|py| expand_clone_func.call1(py, (0u64, 0u64, 0u64)))?; let transform_adapter = transform_adapter.cloned(); Ok(Self { logp: logp_func, + expand_func, transform_adapter, dim, + variables, + dim_sizes, + coords, }) } } +impl HasDims for PyDensity { + fn dim_sizes(&self) -> HashMap { + self.dim_sizes.clone() + } + + fn coords(&self) -> HashMap { + self.coords.clone() + } +} + +pub struct ExpandedVector(Vec>); + +impl Storable for ExpandedVector { + fn names(parent: &PyDensity) -> Vec<&str> { + parent + .variables + .iter() + .map(|var| var.name.as_str()) + .collect() + } + + fn item_type(parent: &PyDensity, item: &str) -> nuts_rs::ItemType { + parent + .variables + .iter() + .find(|var| var.name == item) + .map(|var| var.item_type.as_inner().clone()) + .expect("Item not found") + } + + fn dims<'a>(parent: &'a PyDensity, item: &str) -> Vec<&'a str> { + parent + .variables + .iter() + .find(|var| var.name == item) + .map(|var| var.dims.as_slice().iter().map(|s| s.as_str()).collect()) + .expect("Item not found") + } + + fn get_all<'a>(&'a mut self, parent: &'a PyDensity) -> Vec<(&'a str, Option)> { + self.0 + .iter_mut() + .zip(parent.variables.iter()) + .map(|(val, var)| (var.name.as_str(), val.take())) + .collect() + } +} + impl CpuLogpFunc for PyDensity { type LogpError = PyLogpError; - type TransformParams = Py; + type FlowParameters = Py; + type ExpandedVector = ExpandedVector; fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result { - Python::with_gil(|py| { + Python::attach(|py| { let pos_array = PyArray1::from_slice(py, position); let result = self.logp.call1(py, (pos_array,)); match result { @@ -193,6 +235,139 @@ impl CpuLogpFunc for PyDensity { self.dim } + fn expand_vector( + &mut self, + _rng: &mut R, + array: &[f64], + ) -> std::result::Result + where + R: rand::Rng + ?Sized, + { + Python::attach(|py| { + let expanded = self + .expand_func + .call1(py, (PyArray1::from_slice(py, array),)); + let Ok(expanded) = expanded else { + return Err(nuts_rs::CpuMathError::ExpandError( + "Expanding function raised an error".into(), + )); + }; + let expanded: Bound = expanded.extract(py).map_err(|_| { + nuts_rs::CpuMathError::ExpandError("Expand function did not return a dict".into()) + })?; + let values = expanded.iter(); + let vars = self.variables.iter(); + + let mut expanded = Vec::with_capacity(self.variables.len()); + for (var, (name2, val)) in vars.zip(values) { + let name2 = name2.extract::<&str>().map_err(|_| { + nuts_rs::CpuMathError::ExpandError("expand key was not a string".into()) + })?; + if var.name != name2 { + return Err(nuts_rs::CpuMathError::ExpandError(format!( + "Unexpected expand key: expected {} but found {}", + var.name, name2 + ))); + } + + if val.is_none() { + expanded.push(None); + continue; + } + + fn as_value<'py, 'a, T>( + var: &'a PyVariable, + val: &'a Bound<'py, PyAny>, + ) -> Result, nuts_rs::CpuMathError> + where + T: numpy::Element + Clone, + { + let arr: PyReadonlyArrayDyn = val.extract().map_err(|_| { + nuts_rs::CpuMathError::ExpandError(format!( + "variable {} had incorrect type", + var.name + )) + })?; + if !arr.is_c_contiguous() { + return Err(nuts_rs::CpuMathError::ExpandError( + "not c contiguous".into(), + )); + } + if !arr + .shape() + .iter() + .zip(var.shape.as_slice()) + .all(|(a, &b)| *a as u64 == b) + { + return Err(nuts_rs::CpuMathError::ExpandError("upected shape".into())); + } + Ok(arr) + } + + let val_array = match var.item_type.as_inner() { + nuts_rs::ItemType::F64 => { + let arr = as_value::(var, &val)?; + let slice = arr.as_slice().map_err(|_| { + nuts_rs::CpuMathError::ExpandError("Could not read as slice".into()) + })?; + Some(Value::F64(slice.to_vec())) + } + nuts_rs::ItemType::F32 => { + let arr = as_value::(var, &val)?; + let slice = arr.as_slice().map_err(|_| { + nuts_rs::CpuMathError::ExpandError("Could not read as slice".into()) + })?; + Some(Value::F32(slice.to_vec())) + } + nuts_rs::ItemType::I64 => { + let arr = as_value::(var, &val)?; + let slice = arr.as_slice().map_err(|_| { + nuts_rs::CpuMathError::ExpandError("Could not read as slice".into()) + })?; + Some(Value::I64(slice.to_vec())) + } + nuts_rs::ItemType::Bool => { + let arr = as_value::(var, &val)?; + let slice = arr.as_slice().map_err(|_| { + nuts_rs::CpuMathError::ExpandError("Could not read as slice".into()) + })?; + Some(Value::Bool(slice.to_vec())) + } + nuts_rs::ItemType::U64 => { + let arr = as_value::(var, &val)?; + let slice = arr.as_slice().map_err(|_| { + nuts_rs::CpuMathError::ExpandError("Could not read as slice".into()) + })?; + Some(Value::U64(slice.to_vec())) + } + nuts_rs::ItemType::String => { + let list: Bound = val.extract().map_err(|_| { + nuts_rs::CpuMathError::ExpandError("did not return list".into()) + })?; + if list.len() != var.shape.as_slice().iter().product::() as usize { + return Err(nuts_rs::CpuMathError::ExpandError( + "Incorrect number of items".into(), + )); + } + let vec: Vec = list + .iter() + .map(|item| { + item.extract::().map_err(|_| { + nuts_rs::CpuMathError::ExpandError( + "items were not all strings".into(), + ) + }) + }) + .collect::>()?; + Some(Value::Strings(vec)) + } + }; + expanded.push(val_array); + } + Ok(ExpandedVector(expanded)) + }) + } + fn inv_transform_normalize( &mut self, params: &Py, @@ -305,332 +480,21 @@ impl CpuLogpFunc for PyDensity { } } -pub struct PyTrace { - expand: Py, - variables: Arc>, - builder: StructBuilder, -} - -impl PyTrace { - pub fn new( - rng: &mut R, - chain: u64, - variables: Arc>, - make_expand_func: &Py, - capacity: usize, - ) -> Result { - let seed1 = rng.next_u64(); - let seed2 = rng.next_u64(); - let expand = Python::with_gil(|py| { - make_expand_func - .call1(py, (seed1, seed2, chain)) - .context("Failed to call expand function factory") - })?; - - let fields: Vec = variables - .iter() - .map(|variable| Field::new(variable.name.clone(), variable.arrow_dtype(), false)) - .collect(); - let builder = StructBuilder::from_fields(fields, capacity); - - Ok(Self { - expand, - variables, - builder, - }) - } -} - -pub type ShapeVec = SmallVec<[usize; 4]>; - -#[derive(Debug, Clone)] -#[non_exhaustive] -#[pyclass] -pub struct TensorShape { - pub shape: ShapeVec, - pub dims: Vec>, - size: usize, -} - -impl TensorShape { - pub fn new(shape: ShapeVec, dims: Vec>) -> Self { - let size = shape.iter().product(); - Self { shape, dims, size } - } - pub fn size(&self) -> usize { - self.size - } -} - -#[pymethods] -impl TensorShape { - #[new] - #[pyo3(signature = (shape, dims=None))] - fn py_new(shape: Vec, dims: Option>>) -> Result { - let dims = dims.unwrap_or(shape.iter().map(|_| None).collect()); - if dims.len() != shape.len() { - bail!("Number of dimensions must be the same as the shape"); - } - - let size = shape.iter().product(); - Ok(Self { - shape: shape.into(), - dims, - size, - }) - } -} - -#[non_exhaustive] -#[pyclass] -#[derive(Debug, Clone)] -pub enum ExpandDtype { - Boolean {}, - Float64 {}, - Float32 {}, - Int64 {}, - BooleanArray { tensor_type: TensorShape }, - ArrayFloat64 { tensor_type: TensorShape }, - ArrayFloat32 { tensor_type: TensorShape }, - ArrayInt64 { tensor_type: TensorShape }, -} - -#[pymethods] -impl ExpandDtype { - #[staticmethod] - fn boolean() -> Self { - Self::Boolean {} - } - - #[staticmethod] - fn float64() -> Self { - Self::Float64 {} - } - - #[staticmethod] - fn float32() -> Self { - Self::Float32 {} - } - - #[staticmethod] - fn int64() -> Self { - Self::Int64 {} - } - - #[staticmethod] - fn boolean_array(shape: TensorShape) -> Self { - Self::BooleanArray { tensor_type: shape } - } - - #[staticmethod] - fn float64_array(shape: TensorShape) -> Self { - Self::ArrayFloat64 { tensor_type: shape } - } - #[staticmethod] - fn float32_array(shape: TensorShape) -> Self { - Self::ArrayFloat32 { tensor_type: shape } - } - #[staticmethod] - fn int64_array(shape: TensorShape) -> Self { - Self::ArrayInt64 { tensor_type: shape } - } - - #[getter] - fn shape(&self) -> Option> { - match self { - Self::BooleanArray { tensor_type } => Some(tensor_type.shape.iter().cloned().collect()), - Self::ArrayFloat64 { tensor_type } => Some(tensor_type.shape.iter().cloned().collect()), - Self::ArrayFloat32 { tensor_type } => Some(tensor_type.shape.iter().cloned().collect()), - Self::ArrayInt64 { tensor_type } => Some(tensor_type.shape.iter().cloned().collect()), - _ => None, - } - } -} - -impl DrawStorage for PyTrace { - fn append_value(&mut self, point: &[f64]) -> Result<()> { - Python::with_gil(|py| { - let point = PyArray1::from_slice(py, point); - let full_point = self - .expand - .call1(py, (point,)) - .context("Failed to call expand function")? - .into_bound(py); - let point: &Bound = full_point - .downcast() - .map_err(|_| anyhow!("expand function must return a dict")) - .context("Expand function must return dict")?; - point - .iter() - .zip(self.variables.iter()) - .enumerate() - .try_for_each(|(i, ((key, value), variable))| { - let key: &str = key.extract()?; - if key != variable.name { - return Err(anyhow!("Incorrectly ordered expanded point")); - } - - match &variable.dtype { - ExpandDtype::Boolean {} => { - let builder: &mut BooleanBuilder = - self.builder.field_builder(i).context( - "Builder has incorrect type", - )?; - let value = value - .extract() - .expect("Return value from expand function could not be converted to boolean"); - builder.append_value(value) - }, - ExpandDtype::Float64 {} => { - let builder: &mut Float64Builder = - self.builder.field_builder(i).context( - "Builder has incorrect type", - )?; - builder.append_value( - value - .extract() - .expect("Return value from expand function could not be converted to float64") - ) - }, - ExpandDtype::Float32 {} => { - let builder: &mut Float32Builder = - self.builder.field_builder(i).context( - "Builder has incorrect type", - )?; - builder.append_value( - value - .extract() - .expect("Return value from expand function could not be converted to float32") - ) - }, - ExpandDtype::Int64 {} => { - let builder: &mut Int64Builder = - self.builder.field_builder(i).context( - "Builder has incorrect type", - )?; - let value = value.extract().expect("Return value from expand function could not be converted to int64"); - builder.append_value(value) - }, - ExpandDtype::BooleanArray { tensor_type } => { - let builder: &mut LargeListBuilder> = - self.builder.field_builder(i).context( - "Builder has incorrect type. Expected LargeListBuilder of Bool", - )?; - let value_builder = builder - .values() - .as_any_mut() - .downcast_mut::() - .context("Could not downcast builder to boolean type")?; - let values: PyReadonlyArray1 = value.extract().context("Could not convert object to array")?; - if values.len()? != tensor_type.size() { - bail!("Extracted array has incorrect shape"); - } - value_builder.append_slice(values.as_slice().context("Extracted array is not contiguous")?); - builder.append(true); - }, - ExpandDtype::ArrayFloat64 { tensor_type } => { - //let builder: &mut FixedSizeListBuilder> = - let builder: &mut LargeListBuilder> = - self.builder.field_builder(i).context( - "Builder has incorrect type. Expected LargeListBuilder of Float64", - )?; - let value_builder = builder - .values() - .as_any_mut() - .downcast_mut::>() - .context("Could not downcast builder to float64 type")?; - let values: PyReadonlyArray1 = value.extract().context("Could not convert object to array")?; - if values.len()? != tensor_type.size() { - bail!("Extracted array has incorrect shape"); - } - value_builder.append_slice(values.as_slice().context("Extracted array is not contiguous")?); - builder.append(true); - }, - ExpandDtype::ArrayFloat32 { tensor_type } => { - let builder: &mut LargeListBuilder> = - self.builder.field_builder(i).context( - "Builder has incorrect type. Expected LargeListBuilder of Float32", - )?; - let value_builder = builder - .values() - .as_any_mut() - .downcast_mut::>() - .context("Could not downcast builder to float32 type")?; - let values: PyReadonlyArray1 = value.extract().context("Could not convert object to array")?; - if values.len()? != tensor_type.size() { - bail!("Extracted array has incorrect shape"); - } - value_builder.append_slice(values.as_slice().context("Extracted array is not contiguous")?); - builder.append(true); - }, - ExpandDtype::ArrayInt64 {tensor_type} => { - let builder: &mut LargeListBuilder> = - self.builder.field_builder(i).context( - "Builder has incorrect type. Expected LargeListBuilder of Int64", - )?; - let value_builder = builder - .values() - .as_any_mut() - .downcast_mut::>() - .context("Could not downcast builder to i64 type")?; - let values: PyReadonlyArray1 = value.extract().context("Could not convert object to array")?; - if values.len()? != tensor_type.size() { - bail!("Extracted array has incorrect shape"); - } - value_builder.append_slice(values.as_slice().context("Extracted array is not contiguous")?); - builder.append(true); - }, - } - - Ok(()) - }).context("Could not save output of expand function to trace")?; - self.builder.append(true); - Ok(()) - }) - } - - fn finalize(mut self) -> Result> { - Ok(Arc::new(self.builder.finish())) - } - - fn inspect(&self) -> Result> { - Ok(Arc::new(self.builder.finish_cloned())) - } -} - impl Model for PyModel { type Math<'model> = CpuMath where Self: 'model; - type DrawStorage<'model, S: nuts_rs::Settings> - = PyTrace - where - Self: 'model; - - fn new_trace<'model, S: nuts_rs::Settings, R: rand::prelude::Rng + ?Sized>( - &'model self, - rng: &mut R, - chain_id: u64, - settings: &'model S, - ) -> Result> { - let draws = settings.hint_num_tune() + settings.hint_num_draws(); - PyTrace::new( - rng, - chain_id, - self.variables.clone(), - &self.make_expand_func, - draws, - ) - .context("Could not create PyTrace object") - } - - fn math(&self) -> Result> { + fn math(&self, _rng: &mut R) -> Result> { Ok(CpuMath::new(PyDensity::new( &self.make_logp_func, + &self.make_expand_func, self.ndim, self.transform_adapter.as_ref(), + self.variables.clone(), + self.dim_sizes.clone(), + self.coords.clone(), )?)) } @@ -647,7 +511,7 @@ impl Model for PyModel { let seed = rng.next_u64(); - Python::with_gil(|py| { + Python::attach(|py| { let init_point = init_func .call1(py, (seed,)) .context("Failed to initialize point")?; diff --git a/src/pymc.rs b/src/pymc.rs index 220ad5f..ce6db33 100644 --- a/src/pymc.rs +++ b/src/pymc.rs @@ -1,23 +1,23 @@ -use std::{ffi::c_void, fmt::Display, sync::Arc}; +use std::{collections::HashMap, ffi::c_void, sync::Arc}; use anyhow::{bail, Context, Result}; -use arrow::{ - array::{Array, Float64Array, LargeListArray, StructArray}, - buffer::OffsetBuffer, - datatypes::{DataType, Field, Fields}, -}; -use itertools::{izip, Itertools}; -use numpy::PyReadonlyArray1; -use nuts_rs::{CpuLogpFunc, CpuMath, DrawStorage, LogpError, Model, Settings}; +use numpy::{NotContiguousError, PyReadonlyArray1}; +use nuts_rs::{CpuLogpFunc, CpuMath, HasDims, LogpError, Model, Storable, Value}; use pyo3::{ + exceptions::PyRuntimeError, pyclass, pymethods, - types::{PyAnyMethods, PyList}, - Bound, Py, PyAny, PyObject, PyResult, Python, + types::{PyAnyMethods, PyDict, PyDictMethods}, + Py, PyAny, PyErr, PyResult, Python, }; -use rand_distr::num_traits::CheckedEuclid; +use rand::Rng; use thiserror::Error; +use crate::{ + common::{PyValue, PyVariable}, + wrapper::PyTransformAdapt, +}; + type UserData = *const std::ffi::c_void; type RawLogpFunc = unsafe extern "C" fn( @@ -40,9 +40,8 @@ type RawExpandFunc = unsafe extern "C" fn( #[derive(Clone)] pub(crate) struct LogpFunc { func: RawLogpFunc, - _keep_alive: Arc, + _keep_alive: Arc>, user_data_ptr: UserData, - dim: usize, } unsafe impl Send for LogpFunc {} @@ -51,15 +50,15 @@ unsafe impl Sync for LogpFunc {} #[pymethods] impl LogpFunc { #[new] - fn new(dim: usize, ptr: usize, user_data_ptr: usize, keep_alive: PyObject) -> Self { + fn new(ptr: usize, user_data_ptr: usize, keep_alive: Py) -> Result { let func = unsafe { std::mem::transmute::<*const c_void, RawLogpFunc>(ptr as *const c_void) }; - Self { + + Ok(Self { func, _keep_alive: Arc::new(keep_alive), user_data_ptr: user_data_ptr as UserData, - dim, - } + }) } } @@ -67,7 +66,7 @@ impl LogpFunc { #[derive(Clone)] pub(crate) struct ExpandFunc { func: RawExpandFunc, - _keep_alive: Arc, + _keep_alive: Arc>, user_data_ptr: UserData, dim: usize, expanded_dim: usize, @@ -81,7 +80,7 @@ impl ExpandFunc { expanded_dim: usize, ptr: usize, user_data_ptr: usize, - keep_alive: PyObject, + keep_alive: Py, ) -> Self { let func = unsafe { std::mem::transmute::<*const c_void, RawExpandFunc>(ptr as *const c_void) }; @@ -98,142 +97,297 @@ impl ExpandFunc { unsafe impl Send for ExpandFunc {} unsafe impl Sync for ExpandFunc {} -#[derive(Error, Debug)] -pub(crate) struct ErrorCode(std::os::raw::c_int); +impl HasDims for PyMcModelRef<'_> { + fn dim_sizes(&self) -> HashMap { + self.model.dim_sizes.clone() + } + + fn coords(&self) -> HashMap { + self.model.coords.clone() + } +} + +pub struct ExpandedVector(Vec>); + +impl<'f> Storable> for ExpandedVector { + fn names<'a>(parent: &'a PyMcModelRef<'f>) -> Vec<&'a str> { + parent + .model + .variables + .iter() + .map(|var| var.name.as_str()) + .collect() + } + + fn item_type(parent: &PyMcModelRef<'f>, item: &str) -> nuts_rs::ItemType { + parent + .model + .variables + .iter() + .find(|var| var.name == item) + .map(|var| var.item_type.as_inner().clone()) + .expect("Item not found") + } -impl Display for ErrorCode { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Logp function returned error code {}", self.0) + fn dims<'a>(parent: &'a PyMcModelRef<'f>, item: &str) -> Vec<&'a str> { + parent + .model + .variables + .iter() + .find(|var| var.name == item) + .map(|var| var.dims.as_slice().iter().map(|s| s.as_str()).collect()) + .expect("Item not found") + } + + fn get_all<'a>( + &'a mut self, + parent: &'a PyMcModelRef<'f>, + ) -> Vec<(&'a str, Option)> { + self.0 + .iter_mut() + .zip(parent.model.variables.iter()) + .map(|(val, var)| (var.name.as_str(), val.take())) + .collect() } } -impl LogpError for ErrorCode { +#[derive(Debug, Error)] +pub enum PyMcLogpError { + #[error("Python error: {0}")] + PyError(#[from] PyErr), + #[error("Python retured a non-contigous array")] + NotContiguousError(#[from] NotContiguousError), + #[error("Unknown error: {0}")] + Anyhow(#[from] anyhow::Error), + #[error("Logp function returned error code: {0}")] + ErrorCode(std::os::raw::c_int), +} + +impl LogpError for PyMcLogpError { fn is_recoverable(&self) -> bool { - self.0 > 0 + match self { + Self::PyError(err) => Python::attach(|py| { + let Ok(attr) = err.value(py).getattr("is_recoverable") else { + return false; + }; + attr.is_truthy() + .expect("Could not access is_recoverable in error check") + }), + Self::NotContiguousError(_) => false, + Self::Anyhow(_) => false, + Self::ErrorCode(code) => *code > (0 as std::os::raw::c_int), + } } } -impl CpuLogpFunc for &LogpFunc { - type LogpError = ErrorCode; - type TransformParams = (); +pub struct PyMcModelRef<'a> { + model: &'a PyMcModel, + transform_adapter: Option, +} + +impl CpuLogpFunc for PyMcModelRef<'_> { + type LogpError = PyMcLogpError; + type FlowParameters = Py; + type ExpandedVector = ExpandedVector; fn dim(&self) -> usize { - self.dim + self.model.dim } fn logp(&mut self, position: &[f64], gradient: &mut [f64]) -> Result { let mut logp = 0f64; let logp_ptr = (&mut logp) as *mut f64; - assert!(position.len() == self.dim); - assert!(gradient.len() == self.dim); + assert!(position.len() == self.model.dim); + assert!(gradient.len() == self.model.dim); let retcode = unsafe { - (self.func)( - self.dim, + (self.model.density.func)( + self.model.dim, position.as_ptr(), gradient.as_mut_ptr(), logp_ptr, - self.user_data_ptr, + self.model.density.user_data_ptr, ) }; if retcode == 0 { return Ok(logp); } - Err(ErrorCode(retcode)) + Err(PyMcLogpError::ErrorCode(retcode)) } -} - -#[derive(Clone)] -pub(crate) struct PyMcTrace<'model> { - dim: usize, - data: Vec>, - var_sizes: Vec, - var_names: Vec, - expand: &'model ExpandFunc, - count: usize, -} - -impl<'model> DrawStorage for PyMcTrace<'model> { - fn append_value(&mut self, point: &[f64]) -> Result<()> { - assert!(point.len() == self.dim); - let point = self - .expand_draw(point) - .context("Could not compute deterministic variables")?; + fn expand_vector( + &mut self, + _rng: &mut R, + array: &[f64], + ) -> std::result::Result + where + R: rand::Rng + ?Sized, + { + let mut out = vec![0f64; self.model.expand.expanded_dim].into_boxed_slice(); + let retcode = unsafe { + (self.model.expand.func)( + self.model.expand.dim, + self.model.expand.expanded_dim, + array.as_ptr(), + out.as_mut_ptr(), + self.model.expand.user_data_ptr, + ) + }; - let mut start: usize = 0; - for (&size, data) in self.var_sizes.iter().zip_eq(self.data.iter_mut()) { - let end = start.checked_add(size).unwrap(); - let vals = &point[start..end]; - data.extend_from_slice(vals); - start = end; + let mut values = Vec::new(); + for var in self.model.variables.iter() { + let start = var.start_idx.expect("Variable has no start index"); + let end = var.end_idx.expect("Variable has no end index"); + let slice = &out[start..end]; + + let value = match var.item_type.as_inner() { + nuts_rs::ItemType::U64 => { + let vec: Vec = slice.iter().map(|&x| x as u64).collect(); + nuts_rs::Value::U64(vec.into()) + } + nuts_rs::ItemType::I64 => { + let vec: Vec = slice.iter().map(|&x| x as i64).collect(); + nuts_rs::Value::I64(vec.into()) + } + nuts_rs::ItemType::F64 => { + let vec: Vec = slice.iter().map(|&x| x as f64).collect(); + nuts_rs::Value::F64(vec.into()) + } + nuts_rs::ItemType::F32 => { + let vec: Vec = slice.iter().map(|&x| x as f32).collect(); + nuts_rs::Value::F32(vec.into()) + } + nuts_rs::ItemType::Bool => { + let vec: Vec = slice.iter().map(|&x| x != 0.0).collect(); + nuts_rs::Value::Bool(vec.into()) + } + nuts_rs::ItemType::String => { + return Err(nuts_rs::CpuMathError::ExpandError( + "String type not supported in expansion".into(), + )); + } + }; + + values.push(Some(value)); } - self.count += 1; - Ok(()) + if retcode == 0 { + Ok(ExpandedVector(values)) + } else { + Err(nuts_rs::CpuMathError::ExpandError(format!( + "Expand function returned error code {}", + retcode + ))) + } + } + fn inv_transform_normalize( + &mut self, + params: &Py, + untransformed_position: &[f64], + untransformed_gradient: &[f64], + transformed_position: &mut [f64], + transformed_gradient: &mut [f64], + ) -> std::result::Result { + let logdet = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .inv_transform_normalize( + params, + untransformed_position, + untransformed_gradient, + transformed_position, + transformed_gradient, + )?; + Ok(logdet) } - fn finalize(self) -> Result> { - let (fields, arrays): (Vec<_>, _) = izip!(self.data, self.var_names, self.var_sizes) - .map(|(data, name, size)| { - let (num_arrays, rem) = data - .len() - .checked_div_rem_euclid(&size) - .unwrap_or((self.count, 0)); - assert!(rem == 0); - assert!(num_arrays == self.count); - let data = Float64Array::from(data); - let item_field = Arc::new(Field::new("item", DataType::Float64, false)); - let offsets = OffsetBuffer::from_lengths((0..num_arrays).map(|_| size)); - let array = LargeListArray::new(item_field.clone(), offsets, Arc::new(data), None); - let field = Field::new(name, DataType::LargeList(item_field), false); - (Arc::new(field), Arc::new(array) as Arc) - }) - .unzip(); + fn init_from_transformed_position( + &mut self, + params: &Py, + untransformed_position: &mut [f64], + untransformed_gradient: &mut [f64], + transformed_position: &[f64], + transformed_gradient: &mut [f64], + ) -> std::result::Result<(f64, f64), Self::LogpError> { + let (logp, logdet) = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .init_from_transformed_position( + params, + untransformed_position, + untransformed_gradient, + transformed_position, + transformed_gradient, + )?; + Ok((logp, logdet)) + } - let fields = Fields::from(fields); - Ok(Arc::new( - StructArray::try_new(fields, arrays, None).context("Could not create arrow struct")?, - )) + fn init_from_untransformed_position( + &mut self, + params: &Py, + untransformed_position: &[f64], + untransformed_gradient: &mut [f64], + transformed_position: &mut [f64], + transformed_gradient: &mut [f64], + ) -> std::result::Result<(f64, f64), Self::LogpError> { + let (logp, logdet) = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .init_from_untransformed_position( + params, + untransformed_position, + untransformed_gradient, + transformed_position, + transformed_gradient, + )?; + Ok((logp, logdet)) } - fn inspect(&self) -> Result> { - self.clone().finalize() + fn update_transformation<'a, R: rand::Rng + ?Sized>( + &'a mut self, + rng: &mut R, + untransformed_positions: impl ExactSizeIterator, + untransformed_gradients: impl ExactSizeIterator, + untransformed_logp: impl ExactSizeIterator, + params: &'a mut Py, + ) -> std::result::Result<(), Self::LogpError> { + self.transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .update_transformation( + rng, + untransformed_positions, + untransformed_gradients, + untransformed_logp, + params, + )?; + Ok(()) } -} -impl<'model> PyMcTrace<'model> { - fn new(model: &'model PyMcModel, settings: &impl Settings) -> Self { - let draws = settings.hint_num_draws() + settings.hint_num_tune(); - Self { - dim: model.dim, - data: model - .var_sizes - .iter() - .map(|&size| Vec::with_capacity(size * draws)) - .collect(), - var_sizes: model.var_sizes.clone(), - var_names: model.var_names.clone(), - expand: &model.expand, - count: 0, - } + fn new_transformation( + &mut self, + rng: &mut R, + untransformed_position: &[f64], + untransformed_gradient: &[f64], + chain: u64, + ) -> std::result::Result, Self::LogpError> { + let trafo = self + .transform_adapter + .as_mut() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .new_transformation(rng, untransformed_position, untransformed_gradient, chain)?; + Ok(trafo) } - fn expand_draw(&mut self, point: &[f64]) -> Result> { - let mut out = vec![0f64; self.expand.expanded_dim].into_boxed_slice(); - let retcode = unsafe { - (self.expand.func)( - self.expand.dim, - self.expand.expanded_dim, - point.as_ptr(), - out.as_mut_ptr(), - self.expand.user_data_ptr, - ) - }; - if retcode == 0 { - Ok(out) - } else { - Err(anyhow::Error::msg("Failed to expand a draw.")) - } + fn transformation_id(&self, params: &Py) -> std::result::Result { + let id = self + .transform_adapter + .as_ref() + .ok_or_else(|| PyRuntimeError::new_err("No transformation adapter specified"))? + .transformation_id(params)?; + Ok(id) } } @@ -244,28 +398,59 @@ pub(crate) struct PyMcModel { density: LogpFunc, expand: ExpandFunc, init_func: Arc>, - var_sizes: Vec, - var_names: Vec, + transform_adapter: Option, + variables: Arc>, + dim_sizes: HashMap, + coords: HashMap, } #[pymethods] impl PyMcModel { #[new] fn new<'py>( - dim: usize, + py: Python<'py>, density: LogpFunc, expand: ExpandFunc, + variables: Vec, + dim: usize, + dim_sizes: Py, + coords: Py, init_func: Py, - var_sizes: &Bound<'py, PyList>, - var_names: &Bound<'py, PyList>, + transform_adapter: Option>, ) -> PyResult { + let dim_sizes = dim_sizes + .bind(py) + .iter() + .map(|(key, value)| { + let key: String = key.extract().context("Dimension key is not a string")?; + let value: u64 = value + .extract() + .context("Dimension size value is not an integer")?; + Ok((key, value)) + }) + .collect::>>()?; + + let coords = coords + .bind(py) + .iter() + .map(|(key, value)| { + let key: String = key.extract().context("Coordinate key is not a string")?; + let value: PyValue = value + .extract() + .context("Coordinate value has incorrect type")?; + Ok((key, value.into_value())) + }) + .collect::>>()?; + Ok(Self { dim, density, expand, init_func: init_func.into(), - var_names: var_names.extract()?, - var_sizes: var_sizes.extract()?, + coords, + dim_sizes, + transform_adapter: transform_adapter.map(PyTransformAdapt::new), + variables: Arc::new(variables), }) } @@ -291,12 +476,13 @@ impl PyMcModel { } impl Model for PyMcModel { - type Math<'model> = CpuMath<&'model LogpFunc>; + type Math<'model> = CpuMath>; - type DrawStorage<'model, S: Settings> = PyMcTrace<'model>; - - fn math(&self) -> Result> { - Ok(CpuMath::new(&self.density)) + fn math(&self, _rng: &mut R) -> Result> { + Ok(CpuMath::new(PyMcModelRef { + model: self, + transform_adapter: self.transform_adapter.clone(), + })) } fn init_position( @@ -306,7 +492,7 @@ impl Model for PyMcModel { ) -> Result<()> { let seed = rng.next_u64(); - Python::with_gil(|py| { + Python::attach(|py| { let init_point = self .init_func .call1(py, (seed,)) @@ -329,13 +515,4 @@ impl Model for PyMcModel { })?; Ok(()) } - - fn new_trace<'model, S: Settings, R: rand::prelude::Rng + ?Sized>( - &'model self, - _rng: &mut R, - _chain_id: u64, - settings: &'model S, - ) -> Result> { - Ok(PyMcTrace::new(self, settings)) - } } diff --git a/src/stan.rs b/src/stan.rs index b10ac44..37d4be5 100644 --- a/src/stan.rs +++ b/src/stan.rs @@ -1,23 +1,23 @@ +use std::collections::HashMap; use std::sync::Arc; use std::{ffi::CString, path::PathBuf}; -use anyhow::{bail, Context}; -use arrow::array::{Array, FixedSizeListArray, Float64Array, StructArray}; -use arrow::datatypes::{DataType, Field}; +use anyhow::{bail, Context, Result}; use bridgestan::open_library; -use itertools::{izip, Itertools}; -use nuts_rs::{CpuLogpFunc, CpuMath, DrawStorage, LogpError, Model, Settings}; +use itertools::Itertools; +use nuts_rs::{CpuLogpFunc, CpuMath, HasDims, LogpError, Model, Storable, Value}; use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; use pyo3::types::{PyDict, PyTuple}; use pyo3::{exceptions::PyValueError, pyclass, pymethods, PyResult}; use rand::prelude::Distribution; -use rand::{rng, RngCore}; +use rand::{rng, Rng, RngCore}; use rand_distr::StandardNormal; use smallvec::{SmallVec, ToSmallVec}; use thiserror::Error; +use crate::common::{ItemType, PyValue, PyVariable}; use crate::wrapper::PyTransformAdapt; type InnerModel = bridgestan::Model>; @@ -79,13 +79,22 @@ impl StanVariable { #[pyclass] #[derive(Clone)] pub struct StanModel { - model: Arc, - variables: Vec, + inner: Arc, + variables: Vec, transform_adapter: Option, + dim_sizes: HashMap, + coords: HashMap, + #[pyo3(get)] + dims: HashMap>, + unc_names: Value, } /// Return meta information about the constrained parameters of the model -fn params(var_string: &str) -> anyhow::Result> { +fn params( + var_string: &str, + all_dims: &mut HashMap>, + dim_sizes: &mut HashMap, +) -> anyhow::Result> { if var_string.is_empty() { return Ok(vec![]); } @@ -143,35 +152,38 @@ fn params(var_string: &str) -> anyhow::Result> { .context(format!("Error while parsing stan variable {name}"))?; // Calculate total size of this variable - let size = shape.iter().product(); + let size: usize = shape.iter().product(); let mut end_idx = start_idx + size; // Create Parameter objects (one for real and one for imag if complex) if is_complex { - variables.push(Parameter { - name: format!("{name}.real"), - shape: shape.clone(), - size, - start_idx, - end_idx, - }); + variables.push(PyVariable::new( + format!("{name}.real"), + ItemType(nuts_rs::ItemType::F64), + Some(shape.iter().map(|&d| d as u64).collect()), + all_dims, + dim_sizes, + Some(start_idx), + )?); start_idx = end_idx; end_idx = start_idx + size; - variables.push(Parameter { - name: format!("{name}.imag"), - shape, - size, - start_idx, - end_idx, - }); + variables.push(PyVariable::new( + format!("{name}.imag"), + ItemType(nuts_rs::ItemType::F64), + Some(shape.iter().map(|&d| d as u64).collect()), + all_dims, + dim_sizes, + Some(start_idx), + )?); } else { - variables.push(Parameter { - name: name.to_string(), - shape, - size, - start_idx, - end_idx, - }); + variables.push(PyVariable::new( + name.to_string(), + ItemType(nuts_rs::ItemType::F64), + Some(shape.iter().map(|&d| d as u64).collect()), + all_dims, + dim_sizes, + Some(start_idx), + )?); } // Move to the next variable @@ -240,29 +252,85 @@ where #[pymethods] impl StanModel { #[new] - #[pyo3(signature = (lib, seed=None, data=None, transform_adapter=None))] + #[pyo3(signature = (lib, dim_sizes, dims, coords, seed=None, data=None, transform_adapter=None))] pub fn new( + py: Python<'_>, lib: StanLibrary, + dim_sizes: Py, + dims: Py, + coords: Py, seed: Option, data: Option, transform_adapter: Option>, ) -> anyhow::Result { + let mut dim_sizes = dim_sizes + .bind(py) + .iter() + .map(|(key, value)| { + let key: String = key.extract().context("Dimension key is not a string")?; + let value: u64 = value + .extract() + .context("Dimension size value is not an integer")?; + Ok((key, value)) + }) + .collect::>>()?; + + let mut dims = dims + .bind(py) + .iter() + .map(|(key, value)| { + let key: String = key.extract().context("Dimension key is not a string")?; + let value: Vec = value + .extract() + .context("Dimension value is not a list of strings")?; + Ok((key, value)) + }) + .collect::>>()?; + + let coords = coords + .bind(py) + .iter() + .map(|(key, value)| { + let key: String = key.extract().context("Coordinate key is not a string")?; + let value: PyValue = value + .extract() + .context("Coordinate value has incorrect type")?; + Ok((key, value.into_value())) + }) + .collect::>>()?; + let seed = match seed { Some(seed) => seed, None => rng().next_u32(), }; let data: Option = data.map(CString::new).transpose()?; - let model = Arc::new( - bridgestan::Model::new(lib.0, data.as_ref(), seed).map_err(anyhow::Error::new)?, - ); + let mut model = + bridgestan::Model::new(lib.0, data.as_ref(), seed).map_err(anyhow::Error::new)?; + + // TODO: bridgestan should not require mut self here + let names = model.param_unc_names(); + let mut names: Vec<_> = names.split(',').map(|v| v.to_string()).collect(); + if let Some(first) = names.first() { + if first.is_empty() { + names = vec![]; + } + }; + let unc_names = Value::Strings(names); + + let model = Arc::new(model); let var_string = model.param_names(true, true); - let variables = params(var_string)?; + let variables = params(var_string, &mut dims, &mut dim_sizes)?; let transform_adapter = transform_adapter.map(PyTransformAdapt::new); + Ok(StanModel { - model, + inner: model, variables, transform_adapter, + dim_sizes, + coords, + dims, + unc_names, }) } @@ -271,29 +339,14 @@ impl StanModel { let results: Result, _> = self .variables .iter() - .map(|var| { - out.set_item( - var.name.clone(), - StanVariable(var.clone()).into_pyobject(py)?, - ) - }) + .map(|var| out.set_item(var.name.clone(), var.clone())) .collect(); results?; Ok(out) } pub fn ndim(&self) -> usize { - self.model.param_unc_num() - } - - pub fn param_unc_names(&mut self) -> anyhow::Result> { - Ok(Arc::get_mut(&mut self.model) - .ok_or_else(|| anyhow::format_err!("Model is currently in use")) - .context("Failed to access the names of unconstrained parameters")? - .param_unc_names() - .split(',') - .map(|name| name.to_string()) - .collect()) + self.inner.param_unc_num() } /* @@ -318,8 +371,10 @@ impl StanModel { } pub struct StanDensity<'model> { - inner: &'model InnerModel, + model: &'model StanModel, + rng: bridgestan::Rng<&'model bridgestan::StanLibrary>, transform_adapter: Option, + expanded_buffer: Vec, } #[derive(Debug, Error)] @@ -340,12 +395,65 @@ impl LogpError for StanLogpError { } } +pub struct ExpandedVector(Vec>); + +impl<'model> Storable> for ExpandedVector { + fn names<'a>(parent: &'a StanDensity<'model>) -> Vec<&'a str> { + parent + .model + .variables + .iter() + .map(|var| var.name.as_str()) + .collect() + } + + fn item_type(parent: &StanDensity<'model>, item: &str) -> nuts_rs::ItemType { + parent + .model + .variables + .iter() + .find(|var| var.name == item) + .map(|var| var.item_type.as_inner().clone()) + .expect("Item not found") + } + + fn dims<'a>(parent: &'a StanDensity<'model>, item: &str) -> Vec<&'a str> { + parent + .model + .variables + .iter() + .find(|var| var.name == item) + .map(|var| var.dims.as_slice().iter().map(|s| s.as_str()).collect()) + .expect("Item not found") + } + + fn get_all<'a>(&'a mut self, parent: &'a StanDensity<'model>) -> Vec<(&'a str, Option)> { + self.0 + .iter_mut() + .zip(parent.model.variables.iter()) + .map(|(val, var)| (var.name.as_str(), val.take())) + .collect() + } +} + +impl<'model> HasDims for StanDensity<'model> { + fn dim_sizes(&self) -> HashMap { + self.model.dim_sizes.clone() + } + + fn coords(&self) -> HashMap { + self.model.coords.clone() + } +} + impl<'model> CpuLogpFunc for StanDensity<'model> { type LogpError = StanLogpError; - type TransformParams = Py; + type FlowParameters = Py; + type ExpandedVector = ExpandedVector; fn logp(&mut self, position: &[f64], grad: &mut [f64]) -> Result { let logp = self + .model .inner .log_density_gradient(position, true, true, grad)?; if !logp.is_finite() { @@ -355,7 +463,60 @@ impl<'model> CpuLogpFunc for StanDensity<'model> { } fn dim(&self) -> usize { - self.inner.param_unc_num() + self.model.inner.param_unc_num() + } + + fn vector_coord(&self) -> Option { + Some(self.model.unc_names.clone()) + } + + fn expand_vector( + &mut self, + _rng: &mut R, + array: &[f64], + ) -> Result + where + R: rand::Rng + ?Sized, + { + self.model + .inner + .param_constrain( + array, + true, + true, + &mut self.expanded_buffer, + Some(&mut self.rng), + ) + .context("Failed to constrain the parameters of the draw") + .map_err(|e| nuts_rs::CpuMathError::ExpandError(format!("{}", e)))?; + + let mut vars = Vec::new(); + + for var in self.model.variables.iter() { + let mut out = Vec::with_capacity(var.num_elements); + let start = var.start_idx.expect("Variable start index not set"); + let end = var.end_idx.expect("Variable end index not set"); + let slice = &self.expanded_buffer[start..end]; + assert!(slice.len() == var.num_elements); + + if var.num_elements == 0 { + vars.push(Some(Value::F64(out))); + continue; + } + + // The slice is in fortran order. This doesn't matter if it low dim + if var.shape.as_slice().len() < 2 { + out.extend_from_slice(slice); + vars.push(Some(Value::F64(out))); + continue; + } + + // We need to transpose + fortran_to_c_order(slice, var.shape.as_slice(), &mut out); + vars.push(Some(Value::F64(out))); + } + + Ok(ExpandedVector(vars)) } fn inv_transform_normalize( @@ -495,11 +656,11 @@ impl<'model> CpuLogpFunc for StanDensity<'model> { } } -fn fortran_to_c_order(data: &[f64], shape: &[usize], out: &mut Vec) { +fn fortran_to_c_order(data: &[f64], shape: &[u64], out: &mut Vec) { let rank = shape.len(); let strides = { - let mut strides: SmallVec<[usize; 8]> = SmallVec::with_capacity(rank); - let mut current: usize = 1; + let mut strides: SmallVec<[u64; 8]> = SmallVec::with_capacity(rank); + let mut current: u64 = 1; for &length in shape.iter() { strides.push(current); current = current @@ -510,33 +671,34 @@ fn fortran_to_c_order(data: &[f64], shape: &[usize], out: &mut Vec) { strides }; - let mut shape: SmallVec<[usize; 8]> = shape.to_smallvec(); + let mut shape: SmallVec<[u64; 8]> = shape.to_smallvec(); shape.reverse(); - let mut idx: SmallVec<[usize; 8]> = shape.iter().map(|_| 0usize).collect(); - let mut position: usize = 0; + let mut idx: SmallVec<[u64; 8]> = shape.iter().map(|_| 0u64).collect(); + let mut position: u64 = 0; 'iterate: loop { - out.push(data[position]); + out.push(data[position as usize]); - let mut axis: usize = 0; + let mut axis: u64 = 0; 'nextidx: loop { - idx[axis] += 1; - position += strides[axis]; + idx[axis as usize] += 1; + position += strides[axis as usize]; - if idx[axis] < shape[axis] { + if idx[axis as usize] < shape[axis as usize] { break 'nextidx; } - idx[axis] = 0; - position -= shape[axis] * strides[axis]; + idx[axis as usize] = 0; + position -= shape[axis as usize] * strides[axis as usize]; axis += 1; - if axis == rank { + if axis == rank as u64 { break 'iterate; } } } } +/* pub struct StanTrace<'model> { inner: &'model InnerModel, model: &'model StanModel, @@ -546,28 +708,6 @@ pub struct StanTrace<'model> { count: usize, } -impl<'model> Clone for StanTrace<'model> { - fn clone(&self) -> Self { - // TODO We should avoid this Clone implementation. - // We only need it for `StanTrace.inspect`, which - // doesn't need rng, so we could avoid this strange - // seed of zeros. - let rng = self - .model - .model - .new_rng(0) - .expect("Could not create stan rng"); - Self { - inner: self.inner, - model: self.model, - trace: self.trace.clone(), - expanded_buffer: self.expanded_buffer.clone(), - rng, - count: self.count, - } - } -} - impl<'model> DrawStorage for StanTrace<'model> { fn append_value(&mut self, point: &[f64]) -> anyhow::Result<()> { self.inner @@ -599,41 +739,13 @@ impl<'model> DrawStorage for StanTrace<'model> { self.count += 1; Ok(()) } - - fn finalize(self) -> anyhow::Result> { - let (fields, arrays): (Vec<_>, Vec<_>) = izip!(self.trace, &self.model.variables) - .map(|(data, variable)| { - let data = Float64Array::from(data); - let item_field = Arc::new(Field::new("item", DataType::Float64, false)); - let array = FixedSizeListArray::new( - item_field.clone(), - variable.size as _, - Arc::new(data), - None, - ); - let dtype = DataType::FixedSizeList(item_field, variable.size as i32); - let field = Arc::new(Field::new(variable.name.clone(), dtype.clone(), false)); - let list: Arc = Arc::new(array); - (field, list) - }) - .unzip(); - - Ok(Arc::new( - StructArray::try_new_with_length(fields.into(), arrays, None, self.count) - .context("Could not create arrow StructArray")?, - )) - } - - fn inspect(&self) -> anyhow::Result> { - self.clone().finalize() - } } +*/ impl Model for StanModel { type Math<'model> = CpuMath>; - type DrawStorage<'model, S: nuts_rs::Settings> = StanTrace<'model>; - + /* fn new_trace<'a, S: Settings, R: rand::Rng + ?Sized>( &'a self, rng: &mut R, @@ -658,11 +770,16 @@ impl Model for StanModel { count: 0, }) } + */ - fn math(&self) -> anyhow::Result> { + fn math(&self, rng: &mut R) -> anyhow::Result> { + let rng = self.inner.new_rng(rng.next_u32())?; + let num_expanded = self.inner.param_num(true, true); Ok(CpuMath::new(StanDensity { - inner: &self.model, + model: &self, + rng, transform_adapter: self.transform_adapter.clone(), + expanded_buffer: vec![0f64; num_expanded], })) } @@ -681,6 +798,8 @@ impl Model for StanModel { #[cfg(test)] mod tests { + use std::collections::HashMap; + use itertools::Itertools; use super::fortran_to_c_order; @@ -741,48 +860,51 @@ mod tests { #[test] fn parse_vars() { + let mut dims = HashMap::new(); + let mut dim_sizes = HashMap::new(); + let vars = ""; - let parsed = super::params(vars).unwrap(); + let parsed = super::params(vars, &mut dims, &mut dim_sizes).unwrap(); assert!(parsed.len() == 0); let vars = "x.1.1,x.2.1,x.3.1,x.1.2,x.2.2,x.3.2"; - let parsed = super::params(vars).unwrap(); + let parsed = super::params(vars, &mut dims, &mut dim_sizes).unwrap(); assert!(parsed.len() == 1); let parsed = parsed[0].clone(); assert!(parsed.name == "x"); - assert!(parsed.shape == vec![3, 2]); + assert!(parsed.shape.as_slice() == vec![3, 2]); // Incorrect order let vars = "x.1.2,x.1.1,x.2.1,x.2.2,x.3.1,x.3.2"; - assert!(super::params(vars).is_err()); + assert!(super::params(vars, &mut dims, &mut dim_sizes).is_err()); // Incorrect order let vars = "x.1.2.real,x.1.2.imag"; - assert!(super::params(vars).is_err()); + assert!(super::params(vars, &mut dims, &mut dim_sizes).is_err()); let vars = "x.1.1.real,x.1.1.imag,x.2.1.real,x.2.1.imag,x.3.1.real,x.3.1.imag"; - let parsed = super::params(vars).unwrap(); + let parsed = super::params(vars, &mut dims, &mut dim_sizes).unwrap(); assert!(parsed.len() == 2); let var = parsed[0].clone(); assert!(var.name == "x.real"); - assert!(var.shape == vec![3, 1]); + assert!(var.shape.as_slice() == vec![3, 1]); let var = parsed[1].clone(); assert!(var.name == "x.imag"); - assert!(var.shape == vec![3, 1]); + assert!(var.shape.as_slice() == vec![3, 1]); // Test single variable let vars = "alpha"; - let parsed = super::params(vars).unwrap(); + let parsed = super::params(vars, &mut dims, &mut dim_sizes).unwrap(); assert_eq!(parsed.len(), 1); let var = &parsed[0]; assert_eq!(var.name, "alpha"); - assert_eq!(var.shape, Vec::::new()); - assert_eq!(var.size, 1); + assert_eq!(var.shape.as_slice(), vec![0; 0]); + assert_eq!(var.num_elements, 1); // Test multiple scalar variables let vars = "alpha,beta,gamma"; - let parsed = super::params(vars).unwrap(); + let parsed = super::params(vars, &mut dims, &mut dim_sizes).unwrap(); assert_eq!(parsed.len(), 3); assert_eq!(parsed[0].name, "alpha"); assert_eq!(parsed[1].name, "beta"); @@ -790,21 +912,21 @@ mod tests { // Test 1D array let vars = "theta.1,theta.2,theta.3,theta.4"; - let parsed = super::params(vars).unwrap(); + let parsed = super::params(vars, &mut dims, &mut dim_sizes).unwrap(); assert_eq!(parsed.len(), 1); let var = &parsed[0]; assert_eq!(var.name, "theta"); - assert_eq!(var.shape, vec![4]); - assert_eq!(var.size, 4); + assert_eq!(var.shape.as_slice(), vec![4]); + assert_eq!(var.num_elements, 4); // Test variable name with colons and dots let vars = "x:1:2.4:1.1,x:1:2.4:1.2,x:1:2.4:1.3"; - let parsed = super::params(vars).unwrap(); + let parsed = super::params(vars, &mut dims, &mut dim_sizes).unwrap(); assert_eq!(parsed.len(), 1); let var = &parsed[0]; assert_eq!(var.name, "x:1:2.4:1"); - assert_eq!(var.shape, vec![3]); - assert_eq!(var.size, 3); + assert_eq!(var.shape.as_slice(), vec![3]); + assert_eq!(var.num_elements, 3); let vars = " a, @@ -1009,89 +1131,89 @@ mod tests { ultimate.2.3:2.3.5, ultimate.2.3:2.4.5 "; - let parsed = super::params(vars).unwrap(); + let parsed = super::params(vars, &mut dims, &mut dim_sizes).unwrap(); assert_eq!(parsed[0].name, "a"); - assert_eq!(parsed[0].shape, vec![0usize; 0]); + assert_eq!(parsed[0].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[1].name, "base"); - assert_eq!(parsed[1].shape, vec![0usize; 0]); + assert_eq!(parsed[1].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[2].name, "base_i"); - assert_eq!(parsed[2].shape, vec![0usize; 0]); + assert_eq!(parsed[2].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[3].name, "pair:1"); - assert_eq!(parsed[3].shape, vec![0usize; 0]); + assert_eq!(parsed[3].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[4].name, "pair:2"); - assert_eq!(parsed[4].shape, vec![0usize; 0]); + assert_eq!(parsed[4].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[5].name, "nested:1"); - assert_eq!(parsed[5].shape, vec![0usize; 0]); + assert_eq!(parsed[5].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[6].name, "nested:2:1"); - assert_eq!(parsed[6].shape, vec![0usize; 0]); + assert_eq!(parsed[6].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[7].name, "nested:2:2.real"); - assert_eq!(parsed[7].shape, vec![0usize; 0]); + assert_eq!(parsed[7].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[8].name, "nested:2:2.imag"); - assert_eq!(parsed[8].shape, vec![0usize; 0]); + assert_eq!(parsed[8].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[9].name, "arr_pair.1:1"); - assert_eq!(parsed[9].shape, vec![0usize; 0]); + assert_eq!(parsed[9].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[10].name, "arr_pair.1:2"); - assert_eq!(parsed[10].shape, vec![0usize; 0]); + assert_eq!(parsed[10].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[11].name, "arr_pair.2:1"); - assert_eq!(parsed[11].shape, vec![0usize; 0]); + assert_eq!(parsed[11].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[12].name, "arr_pair.2:2"); - assert_eq!(parsed[12].shape, vec![0usize; 0]); + assert_eq!(parsed[12].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[13].name, "arr_very_nested.1:1:1"); - assert_eq!(parsed[13].shape, vec![0usize; 0]); + assert_eq!(parsed[13].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[14].name, "arr_very_nested.1:1:2:1"); - assert_eq!(parsed[14].shape, vec![0usize; 0]); + assert_eq!(parsed[14].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[15].name, "arr_very_nested.1:1:2:2.real"); - assert_eq!(parsed[15].shape, vec![0usize; 0]); + assert_eq!(parsed[15].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[16].name, "arr_very_nested.1:1:2:2.imag"); - assert_eq!(parsed[16].shape, vec![0usize; 0]); + assert_eq!(parsed[16].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[17].name, "arr_very_nested.1:2"); - assert_eq!(parsed[17].shape, vec![0usize; 0]); + assert_eq!(parsed[17].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[18].name, "arr_very_nested.2:1:1"); - assert_eq!(parsed[18].shape, vec![0usize; 0]); + assert_eq!(parsed[18].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[19].name, "arr_very_nested.2:1:2:1"); - assert_eq!(parsed[19].shape, vec![0usize; 0]); + assert_eq!(parsed[19].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[20].name, "arr_very_nested.2:1:2:2.real"); - assert_eq!(parsed[20].shape, vec![0usize; 0]); + assert_eq!(parsed[20].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[21].name, "arr_very_nested.2:1:2:2.imag"); - assert_eq!(parsed[21].shape, vec![0usize; 0]); + assert_eq!(parsed[21].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[22].name, "arr_very_nested.2:2"); - assert_eq!(parsed[22].shape, vec![0usize; 0]); + assert_eq!(parsed[22].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[23].name, "arr_very_nested.3:1:1"); - assert_eq!(parsed[23].shape, vec![0usize; 0]); + assert_eq!(parsed[23].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[24].name, "arr_very_nested.3:1:2:1"); - assert_eq!(parsed[24].shape, vec![0usize; 0]); + assert_eq!(parsed[24].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[25].name, "arr_very_nested.3:1:2:2.real"); - assert_eq!(parsed[25].shape, vec![0usize; 0]); + assert_eq!(parsed[25].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[26].name, "arr_very_nested.3:1:2:2.imag"); - assert_eq!(parsed[26].shape, vec![0usize; 0]); + assert_eq!(parsed[26].shape.as_slice(), vec![0; 0]); assert_eq!(parsed[27].name, "arr_very_nested.3:2"); - assert_eq!(parsed[27].shape, vec![0usize; 0]); + assert_eq!(parsed[27].shape.as_slice(), vec![0; 0]); } } diff --git a/src/wrapper.rs b/src/wrapper.rs index f29620c..37322d6 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -6,27 +6,31 @@ use std::{ }; use crate::{ + common::PyVariable, progress::{IndicatifHandler, ProgressHandler}, - pyfunc::{ExpandDtype, PyModel, PyVariable, TensorShape}, + pyfunc::PyModel, pymc::{ExpandFunc, LogpFunc, PyMcModel}, stan::{StanLibrary, StanModel}, }; -use anyhow::{bail, Context, Result}; -use arrow::array::Array; +use anyhow::{anyhow, bail, Context, Result}; use numpy::{PyArray1, PyReadonlyArray1}; use nuts_rs::{ - ChainProgress, DiagGradNutsSettings, LowRankNutsSettings, ProgressCallback, Sampler, - SamplerWaitResult, Trace, TransformedNutsSettings, + ArrowConfig, ArrowTrace, ChainProgress, DiagGradNutsSettings, LowRankNutsSettings, Model, + ProgressCallback, Sampler, SamplerWaitResult, StepSizeAdaptMethod, TransformedNutsSettings, + ZarrAsyncConfig, }; use pyo3::{ - exceptions::PyTimeoutError, - ffi::Py_uintptr_t, + exceptions::{PyTimeoutError, PyValueError}, intern, prelude::*, - types::{PyList, PyTuple}, + types::PyList, }; +use pyo3_arrow::PyRecordBatch; +use pyo3_object_store::AnyObjectStore; use rand::{rng, RngCore}; +use tokio::runtime::Runtime; +use zarrs_object_store::{object_store::limit::LimitStore, AsyncObjectStore}; #[pyclass] struct PyChainProgress(ChainProgress); @@ -276,22 +280,13 @@ impl PyNutsSettings { fn initial_step(&self) -> f64 { match &self.inner { Settings::Diag(nuts_settings) => { - nuts_settings - .adapt_options - .dual_average_options - .initial_step + nuts_settings.adapt_options.step_size_settings.initial_step } Settings::LowRank(nuts_settings) => { - nuts_settings - .adapt_options - .dual_average_options - .initial_step + nuts_settings.adapt_options.step_size_settings.initial_step } Settings::Transforming(nuts_settings) => { - nuts_settings - .adapt_options - .dual_average_options - .initial_step + nuts_settings.adapt_options.step_size_settings.initial_step } } } @@ -300,22 +295,13 @@ impl PyNutsSettings { fn set_initial_step(&mut self, val: f64) { match &mut self.inner { Settings::Diag(nuts_settings) => { - nuts_settings - .adapt_options - .dual_average_options - .initial_step = val; + nuts_settings.adapt_options.step_size_settings.initial_step = val; } Settings::LowRank(nuts_settings) => { - nuts_settings - .adapt_options - .dual_average_options - .initial_step = val; + nuts_settings.adapt_options.step_size_settings.initial_step = val; } Settings::Transforming(nuts_settings) => { - nuts_settings - .adapt_options - .dual_average_options - .initial_step = val; + nuts_settings.adapt_options.step_size_settings.initial_step = val; } } } @@ -338,6 +324,24 @@ impl PyNutsSettings { } } + #[getter] + fn mindepth(&self) -> u64 { + match &self.inner { + Settings::Diag(nuts_settings) => nuts_settings.mindepth, + Settings::LowRank(nuts_settings) => nuts_settings.mindepth, + Settings::Transforming(nuts_settings) => nuts_settings.mindepth, + } + } + + #[setter(maxdepth)] + fn set_mindepth(&mut self, val: u64) { + match &mut self.inner { + Settings::Diag(nuts_settings) => nuts_settings.mindepth = val, + Settings::LowRank(nuts_settings) => nuts_settings.mindepth = val, + Settings::Transforming(nuts_settings) => nuts_settings.mindepth = val, + } + } + #[getter] fn store_gradient(&self) -> bool { match &self.inner { @@ -414,22 +418,13 @@ impl PyNutsSettings { fn set_target_accept(&self) -> f64 { match &self.inner { Settings::Diag(nuts_settings) => { - nuts_settings - .adapt_options - .dual_average_options - .target_accept + nuts_settings.adapt_options.step_size_settings.target_accept } Settings::LowRank(nuts_settings) => { - nuts_settings - .adapt_options - .dual_average_options - .target_accept + nuts_settings.adapt_options.step_size_settings.target_accept } Settings::Transforming(nuts_settings) => { - nuts_settings - .adapt_options - .dual_average_options - .target_accept + nuts_settings.adapt_options.step_size_settings.target_accept } } } @@ -438,22 +433,13 @@ impl PyNutsSettings { fn target_accept(&mut self, val: f64) { match &mut self.inner { Settings::Diag(nuts_settings) => { - nuts_settings - .adapt_options - .dual_average_options - .target_accept = val + nuts_settings.adapt_options.step_size_settings.target_accept = val } Settings::LowRank(nuts_settings) => { - nuts_settings - .adapt_options - .dual_average_options - .target_accept = val + nuts_settings.adapt_options.step_size_settings.target_accept = val } Settings::Transforming(nuts_settings) => { - nuts_settings - .adapt_options - .dual_average_options - .target_accept = val + nuts_settings.adapt_options.step_size_settings.target_accept = val } } } @@ -467,9 +453,7 @@ impl PyNutsSettings { Settings::Diag(settings) => { Ok(settings.adapt_options.mass_matrix_options.store_mass_matrix) } - Settings::Transforming(_) => { - bail!("Option store_mass_matrix not availbale for transformation adaptation") - } + Settings::Transforming(_) => Ok(false), } } @@ -561,7 +545,10 @@ impl PyNutsSettings { } #[setter(mass_matrix_eigval_cutoff)] - fn set_mass_matrix_eigval_cutoff(&mut self, val: f64) -> Result<()> { + fn set_mass_matrix_eigval_cutoff(&mut self, val: Option) -> Result<()> { + let Some(val) = val else { + return Ok(()); + }; match &mut self.inner { Settings::LowRank(inner) => inner.adapt_options.mass_matrix_options.eigval_cutoff = val, Settings::Diag(_) => { @@ -588,7 +575,10 @@ impl PyNutsSettings { } #[setter(mass_matrix_gamma)] - fn set_mass_matrix_gamma(&mut self, val: f64) -> Result<()> { + fn set_mass_matrix_gamma(&mut self, val: Option) -> Result<()> { + let Some(val) = val else { + return Ok(()); + }; match &mut self.inner { Settings::LowRank(inner) => { inner.adapt_options.mass_matrix_options.gamma = val; @@ -654,11 +644,153 @@ impl PyNutsSettings { } Ok(()) } + + #[getter] + fn step_size_adapt_method(&self) -> String { + let method = match &self.inner { + Settings::LowRank(inner) => inner.adapt_options.step_size_settings.adapt_options.method, + Settings::Diag(inner) => inner.adapt_options.step_size_settings.adapt_options.method, + Settings::Transforming(inner) => { + inner.adapt_options.step_size_settings.adapt_options.method + } + }; + + match method { + nuts_rs::StepSizeAdaptMethod::DualAverage => "dual_average", + nuts_rs::StepSizeAdaptMethod::Adam => "adam", + nuts_rs::StepSizeAdaptMethod::Fixed(_) => "fixed", + } + .to_string() + } + + #[setter(step_size_adapt_method)] + fn set_step_size_adapt_method(&mut self, method: Py) -> Result<()> { + let method = Python::attach(|py| { + if let Ok(method) = method.extract::(py) { + match method.as_str() { + "dual_average" => Ok(StepSizeAdaptMethod::DualAverage), + "adam" => Ok(StepSizeAdaptMethod::Adam), + _ => { + if let Ok(step_size) = method.parse::() { + Ok(StepSizeAdaptMethod::Fixed(step_size)) + } else { + bail!("step_size_adapt_method must be a positive float when using fixed step size"); + } + } + } + } else { + bail!("step_size_adapt_method must be a string"); + } + })?; + + match &mut self.inner { + Settings::LowRank(inner) => { + inner.adapt_options.step_size_settings.adapt_options.method = method + } + Settings::Diag(inner) => { + inner.adapt_options.step_size_settings.adapt_options.method = method + } + Settings::Transforming(inner) => { + inner.adapt_options.step_size_settings.adapt_options.method = method + } + }; + Ok(()) + } + + #[getter] + fn step_size_adam_learning_rate(&self) -> Option { + match &self.inner { + Settings::LowRank(inner) => { + if let StepSizeAdaptMethod::Adam = + inner.adapt_options.step_size_settings.adapt_options.method + { + Some( + inner + .adapt_options + .step_size_settings + .adapt_options + .adam + .learning_rate, + ) + } else { + None + } + } + Settings::Diag(inner) => { + if let StepSizeAdaptMethod::Adam = + inner.adapt_options.step_size_settings.adapt_options.method + { + Some( + inner + .adapt_options + .step_size_settings + .adapt_options + .adam + .learning_rate, + ) + } else { + None + } + } + Settings::Transforming(inner) => { + if let StepSizeAdaptMethod::Adam = + inner.adapt_options.step_size_settings.adapt_options.method + { + Some( + inner + .adapt_options + .step_size_settings + .adapt_options + .adam + .learning_rate, + ) + } else { + None + } + } + } + } + + #[setter(step_size_adam_learning_rate)] + fn set_step_size_adam_learning_rate(&mut self, val: Option) -> Result<()> { + let Some(val) = val else { + return Ok(()); + }; + match &mut self.inner { + Settings::LowRank(inner) => { + inner + .adapt_options + .step_size_settings + .adapt_options + .adam + .learning_rate = val + } + Settings::Diag(inner) => { + inner + .adapt_options + .step_size_settings + .adapt_options + .adam + .learning_rate = val + } + Settings::Transforming(inner) => { + inner + .adapt_options + .step_size_settings + .adapt_options + .adam + .learning_rate = val + } + }; + Ok(()) + } } pub(crate) enum SamplerState { - Running(Sampler), - Finished(Option), + RunningZarr(Sampler<()>), + RunningArrow(Sampler>), + FinishedZarr, + FinishedArrow(Vec), Empty, } @@ -728,57 +860,221 @@ impl ProgressType { } } +enum InnerPyStorage { + Zarr(Option), + Arrow, +} + #[pyclass] -struct PySampler(Mutex); +struct PyStorage(InnerPyStorage); #[pymethods] -impl PySampler { +impl PyStorage { #[staticmethod] - fn from_pymc( + fn zarr(object_store: AnyObjectStore) -> Self { + Self(InnerPyStorage::Zarr(Some(object_store))) + } + + #[staticmethod] + fn arrow() -> Self { + Self(InnerPyStorage::Arrow) + } +} + +#[pyclass] +struct PySampler(Mutex<(SamplerState, Runtime)>); + +impl PySampler { + fn new( settings: PyNutsSettings, cores: usize, - model: PyMcModel, + model: M, progress_type: ProgressType, - ) -> PyResult { + store: &mut PyStorage, + ) -> PyResult { let callback = progress_type.into_callback()?; - match settings.inner { - Settings::LowRank(settings) => { - let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler).into())) + let tokio_rt = Runtime::new().context("Failed to create Tokio runtime")?; + match &mut store.0 { + InnerPyStorage::Arrow => { + let storage_config = ArrowConfig::new(); + match settings.inner { + Settings::LowRank(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningArrow(sampler).into(), + tokio_rt, + )))) + } + Settings::Diag(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningArrow(sampler).into(), + tokio_rt, + )))) + } + Settings::Transforming(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningArrow(sampler).into(), + tokio_rt, + )))) + } + } } - Settings::Diag(settings) => { - let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler).into())) + InnerPyStorage::Zarr(store) => { + let object_store = store + .take() + .ok_or_else(|| anyhow!("Can not use storage configuration twice"))? + .into_dyn(); + let object_store = LimitStore::new(object_store, 50); + let store = AsyncObjectStore::new(object_store); + let store = Arc::new(store); + let storage_config = ZarrAsyncConfig::new(tokio_rt.handle().clone(), store); + match settings.inner { + Settings::LowRank(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningZarr(sampler).into(), + tokio_rt, + )))) + } + Settings::Diag(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningZarr(sampler).into(), + tokio_rt, + )))) + } + Settings::Transforming(settings) => { + let sampler = + Sampler::new(model, settings, storage_config, cores, callback)?; + Ok(PySampler(Mutex::new(( + SamplerState::RunningZarr(sampler).into(), + tokio_rt, + )))) + } + } } - Settings::Transforming(settings) => { - let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler).into())) + } + } +} + +impl PySampler { + fn wait_inner_arrow( + &self, + mut control: Sampler>, + timeout: Option, + ) -> (PyResult<()>, SamplerState) { + let start_time = Instant::now(); + let step = Duration::from_millis(100); + + loop { + let time_so_far = Instant::now().saturating_duration_since(start_time); + let next_timeout = match timeout { + Some(timeout) => { + let Some(remaining) = timeout.checked_sub(time_so_far) else { + return ( + Err(PyTimeoutError::new_err( + "Timeout while waiting for sampler to finish", + )), + SamplerState::RunningArrow(control), + ); + }; + remaining.min(step) + } + None => step, + }; + + match control.wait_timeout(next_timeout) { + SamplerWaitResult::Trace(trace) => { + return (Ok(()), SamplerState::FinishedArrow(trace)) + } + SamplerWaitResult::Timeout(new_control) => { + control = new_control; + } + SamplerWaitResult::Err(err, trace) => { + return ( + Err(err.into()), + SamplerState::FinishedArrow(trace.unwrap_or_default()), + ) + } + } + + if let Err(err) = Python::attach(|py| py.check_signals()) { + return (Err(err), SamplerState::RunningArrow(control)); } } } + fn wait_inner_zarr( + &self, + mut control: Sampler<()>, + timeout: Option, + ) -> (PyResult<()>, SamplerState) { + let start_time = Instant::now(); + let step = Duration::from_millis(100); + + loop { + let time_so_far = Instant::now().saturating_duration_since(start_time); + let next_timeout = match timeout { + Some(timeout) => { + let Some(remaining) = timeout.checked_sub(time_so_far) else { + return ( + Err(PyTimeoutError::new_err( + "Timeout while waiting for sampler to finish", + )), + SamplerState::RunningZarr(control), + ); + }; + remaining.min(step) + } + None => step, + }; + + match control.wait_timeout(next_timeout) { + SamplerWaitResult::Trace(_trace) => return (Ok(()), SamplerState::FinishedZarr), + SamplerWaitResult::Timeout(new_control) => { + control = new_control; + } + SamplerWaitResult::Err(err, _trace) => { + return (Err(err.into()), SamplerState::FinishedZarr) + } + } + + if let Err(err) = Python::attach(|py| py.check_signals()) { + return (Err(err), SamplerState::RunningZarr(control)); + } + } + } +} + +#[pymethods] +impl PySampler { + #[staticmethod] + fn from_pymc( + settings: PyNutsSettings, + cores: usize, + model: PyMcModel, + progress_type: ProgressType, + store: &mut PyStorage, + ) -> PyResult { + PySampler::new(settings, cores, model, progress_type, store) + } + #[staticmethod] fn from_stan( settings: PyNutsSettings, cores: usize, model: StanModel, progress_type: ProgressType, + store: &mut PyStorage, ) -> PyResult { - let callback = progress_type.into_callback()?; - match settings.inner { - Settings::LowRank(settings) => { - let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler).into())) - } - Settings::Diag(settings) => { - let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler).into())) - } - Settings::Transforming(settings) => { - let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler).into())) - } - } + PySampler::new(settings, cores, model, progress_type, store) } #[staticmethod] @@ -787,86 +1083,70 @@ impl PySampler { cores: usize, model: PyModel, progress_type: ProgressType, + store: &mut PyStorage, ) -> PyResult { - let callback = progress_type.into_callback()?; - match settings.inner { - Settings::LowRank(settings) => { - let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler).into())) - } - Settings::Diag(settings) => { - let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler).into())) - } - Settings::Transforming(settings) => { - let sampler = Sampler::new(model, settings, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler).into())) - } - } + PySampler::new(settings, cores, model, progress_type, store) } fn is_finished(&mut self, py: Python<'_>) -> PyResult { - py.allow_threads(|| { + self.wait(py, Some(0.001))?; + py.detach(|| { let guard = &mut self.0.lock().expect("Poisond sampler state mutex"); - let slot = guard.deref_mut(); - - let state = std::mem::replace(slot, SamplerState::Empty); - - let SamplerState::Running(sampler) = state else { - let _ = std::mem::replace(slot, state); - return Ok(true); - }; - - match sampler.wait_timeout(Duration::from_millis(1)) { - SamplerWaitResult::Trace(trace) => { - let _ = std::mem::replace(slot, SamplerState::Finished(Some(trace))); - Ok(true) - } - SamplerWaitResult::Timeout(sampler) => { - let _ = std::mem::replace(slot, SamplerState::Running(sampler)); - Ok(false) - } - SamplerWaitResult::Err(err, trace) => { - let _ = std::mem::replace(slot, SamplerState::Finished(trace)); - Err(err.into()) - } - } + Ok(matches!( + guard.deref_mut().0, + SamplerState::FinishedZarr | SamplerState::FinishedArrow(_) | SamplerState::Empty + )) }) } fn pause(&mut self, py: Python<'_>) -> PyResult<()> { - py.allow_threads(|| { - if let SamplerState::Running(ref mut control) = self + py.detach(|| { + match self .0 .lock() - .expect("Poised sampler state mutex") + .expect("Poisond sampler state mutex") .deref_mut() { - control.pause()? + (SamplerState::RunningZarr(control), _) => { + control.pause()?; + return Ok(()); + } + (SamplerState::RunningArrow(control), _) => { + control.pause()?; + return Ok(()); + } + _ => return Ok(()), } - Ok(()) }) } fn resume(&mut self, py: Python<'_>) -> PyResult<()> { - py.allow_threads(|| { - if let SamplerState::Running(ref mut control) = self + py.detach(|| { + match self .0 .lock() .expect("Poisond sampler state mutex") .deref_mut() { - control.resume()? + (SamplerState::RunningZarr(control), _) => { + control.resume()?; + return Ok(()); + } + (SamplerState::RunningArrow(control), _) => { + control.resume()?; + return Ok(()); + } + _ => return Ok(()), } - Ok(()) }) } #[pyo3(signature = (timeout_seconds=None))] fn wait(&mut self, py: Python<'_>, timeout_seconds: Option) -> PyResult<()> { - py.allow_threads(|| { + py.detach(|| { let guard = &mut self.0.lock().expect("Poisond sampler state mutex"); let slot = guard.deref_mut(); + let slot = &mut slot.0; let timeout = match timeout_seconds { Some(val) => Some(Duration::try_from_secs_f64(val).context("Invalid timeout")?), @@ -875,46 +1155,12 @@ impl PySampler { let state = std::mem::replace(slot, SamplerState::Empty); - let SamplerState::Running(mut control) = state else { - let _ = std::mem::replace(slot, state); - return Ok(()); - }; - - let start_time = Instant::now(); - let step = Duration::from_millis(100); - - let (final_state, retval) = loop { - let time_so_far = Instant::now().saturating_duration_since(start_time); - let next_timeout = match timeout { - Some(timeout) => { - let Some(remaining) = timeout.checked_sub(time_so_far) else { - break ( - SamplerState::Running(control), - Err(PyTimeoutError::new_err( - "Timeout while waiting for sampler to finish", - )), - ); - }; - remaining.min(step) - } - None => step, - }; - - match control.wait_timeout(next_timeout) { - SamplerWaitResult::Trace(trace) => { - break (SamplerState::Finished(Some(trace)), Ok(())) - } - SamplerWaitResult::Timeout(new_control) => { - control = new_control; - } - SamplerWaitResult::Err(err, trace) => { - break (SamplerState::Finished(trace), Err(err.into())) - } - } - - if let Err(err) = Python::with_gil(|py| py.check_signals()) { - break (SamplerState::Running(control), Err(err)); - } + let (retval, final_state) = match state { + SamplerState::FinishedZarr + | SamplerState::FinishedArrow(_) + | SamplerState::Empty => (Ok(()), state), + SamplerState::RunningZarr(control) => self.wait_inner_zarr(control, timeout), + SamplerState::RunningArrow(control) => self.wait_inner_arrow(control, timeout), }; let _ = std::mem::replace(slot, final_state); @@ -923,104 +1169,163 @@ impl PySampler { } fn abort(&mut self, py: Python<'_>) -> PyResult<()> { - py.allow_threads(|| { + py.detach(|| { let guard = &mut self.0.lock().expect("Poisond sampler state mutex"); let slot = guard.deref_mut(); + let slot = &mut slot.0; let state = std::mem::replace(slot, SamplerState::Empty); - let SamplerState::Running(control) = state else { - let _ = std::mem::replace(slot, state); - return Ok(()); - }; - - let (result, trace) = control.abort(); - let _ = std::mem::replace(slot, SamplerState::Finished(trace)); - result?; - Ok(()) + match state { + SamplerState::FinishedZarr + | SamplerState::FinishedArrow(_) + | SamplerState::Empty => { + let _ = std::mem::replace(slot, state); + return Ok(()); + } + SamplerState::RunningZarr(control) => { + let (result, _) = control.abort()?; + let _ = std::mem::replace(slot, SamplerState::FinishedZarr); + if let Some(err) = result { + Err(err)?; + } + Ok(()) + } + SamplerState::RunningArrow(control) => { + let (result, trace) = control.abort()?; + let _ = std::mem::replace(slot, SamplerState::FinishedArrow(trace)); + if let Some(err) = result { + Err(err)?; + } + Ok(()) + } + } }) } - fn extract_results<'py>(&mut self, py: Python<'py>) -> PyResult> { - let guard = &mut self.0.lock().expect("Poisond sampler state mutex"); - let slot = guard.deref_mut(); - - let state = std::mem::replace(slot, SamplerState::Empty); - - let SamplerState::Finished(trace) = state else { - let _ = std::mem::replace(slot, state); - return Err(anyhow::anyhow!("Sampler is not finished"))?; - }; - - let Some(trace) = trace else { - return Err(anyhow::anyhow!( - "Sampler failed and did not produce a trace" - ))?; - }; + fn is_empty(&self) -> bool { + matches!( + self.0.lock().expect("Poisoned sampler state lock").deref(), + (SamplerState::Empty, _) + ) + } - trace_to_list(trace, py) + fn flush<'py>(&mut self, py: Python<'py>) -> PyResult<()> { + match self + .0 + .lock() + .expect("Poisond sampler state mutex") + .deref_mut() + .0 + { + SamplerState::FinishedZarr => Ok(()), + SamplerState::FinishedArrow(_) => Ok(()), + SamplerState::Empty => Ok(()), + SamplerState::RunningZarr(ref mut control) => { + py.detach(|| control.flush())?; + Ok(()) + } + SamplerState::RunningArrow(ref mut control) => { + py.detach(|| control.flush())?; + Ok(()) + } + } } - fn is_empty(&self) -> bool { - match self.0.lock().expect("Poisoned sampler state lock").deref() { - SamplerState::Running(_) => false, - SamplerState::Finished(_) => false, - SamplerState::Empty => true, + fn inspect<'py>(&self, py: Python<'py>) -> PyResult> { + match &mut self + .0 + .lock() + .expect("Poisond sampler state mutex") + .deref_mut() + .0 + { + SamplerState::FinishedZarr => Ok(Some(PyTrace(InnerPyTrace::Zarr))), + SamplerState::FinishedArrow(trace) => { + Ok(Some(PyTrace(InnerPyTrace::Arrow(Some(trace.clone()))))) + } + SamplerState::Empty => Ok(None), + SamplerState::RunningZarr(control) => { + let (res, _) = py.detach(|| control.inspect())?; + if let Some(err) = res { + return Err(err.into()); + } + Ok(Some(PyTrace(InnerPyTrace::Zarr))) + } + SamplerState::RunningArrow(control) => { + let (res, trace) = py.detach(|| control.inspect())?; + if let Some(err) = res { + return Err(err.into()); + } + Ok(Some(PyTrace(InnerPyTrace::Arrow(Some(trace))))) + } } } - fn inspect<'py>(&mut self, py: Python<'py>) -> PyResult> { - let trace = py.allow_threads(|| { - let mut guard = self.0.lock().unwrap(); - let SamplerState::Running(ref mut sampler) = guard.deref_mut() else { - return Err(anyhow::anyhow!("Sampler is not running"))?; - }; + fn take_results(&mut self) -> PyResult { + let state = &mut self.0.lock().expect("Poisond sampler state mutex"); - sampler.inspect_trace() - })?; - trace_to_list(trace, py) + match &state.0 { + SamplerState::FinishedZarr => { + let _ = std::mem::replace(&mut state.0, SamplerState::Empty); + Ok(PyTrace(InnerPyTrace::Zarr)) + } + SamplerState::FinishedArrow(_) => { + let state = std::mem::replace(&mut state.0, SamplerState::Empty); + let SamplerState::FinishedArrow(trace) = state else { + unreachable!(); + }; + Ok(PyTrace(InnerPyTrace::Arrow(Some(trace)))) + } + SamplerState::Empty => Err(PyErr::new::( + "Sampler has no results to take", + )), + SamplerState::RunningZarr(_) => Err(PyErr::new::( + "Sampler is still running, can only take results after it has finished", + )), + SamplerState::RunningArrow(_) => Err(PyErr::new::( + "Sampler is still running, can only take results after it has finished", + )), + } } } -fn trace_to_list(trace: Trace, py: Python<'_>) -> PyResult> { - let list = PyList::new( - py, - trace - .chains - .into_iter() - .map(|chain| { - Ok(PyTuple::new( - py, - [ - export_array(py, chain.draws)?, - export_array(py, chain.stats)?, - ] - .into_iter(), - )?) - }) - .collect::>>()?, - )?; - Ok(list) +enum InnerPyTrace { + Zarr, + Arrow(Option>), } -fn export_array(py: Python<'_>, data: Arc) -> PyResult { - let pa = py.import("pyarrow")?; - let array = pa.getattr("Array")?; - - let data = data.into_data(); - - let (data, schema) = arrow::ffi::to_ffi(&data).context("Could not convert to arrow ffi")?; +#[pyclass] +pub struct PyTrace(InnerPyTrace); - let data = array - .call_method1( - "_import_from_c", - ( - (&data as *const _ as Py_uintptr_t).into_pyobject(py)?, - (&schema as *const _ as Py_uintptr_t).into_pyobject(py)?, - ), - ) - .context("Could not import arrow trace in python")?; - Ok(data.unbind()) +#[pymethods] +impl PyTrace { + fn is_zarr(&self) -> bool { + matches!(self.0, InnerPyTrace::Zarr) + } + + fn is_arrow(&self) -> bool { + matches!(self.0, InnerPyTrace::Arrow(_)) + } + + fn get_arrow_trace(&mut self) -> PyResult<(Vec, Vec)> { + match &mut self.0 { + InnerPyTrace::Zarr => Err(PyErr::new::( + "Trace is not stored in Arrow format", + )), + InnerPyTrace::Arrow(trace) => Ok(trace + .take() + .ok_or_else(|| PyValueError::new_err("The trace was already taken"))? + .into_iter() + .map(|array| { + ( + PyRecordBatch::new(array.posterior), + PyRecordBatch::new(array.sample_stats), + ) + }) + .collect()), + } + } } #[pyclass] @@ -1044,7 +1349,7 @@ impl PyTransformAdapt { transformed_position: &mut [f64], transformed_gradient: &mut [f64], ) -> Result { - Python::with_gil(|py| { + Python::attach(|py| { let untransformed_position = PyArray1::from_slice(py, untransformed_position); let untransformed_gradient = PyArray1::from_slice(py, untransformed_gradient); @@ -1099,7 +1404,7 @@ impl PyTransformAdapt { transformed_position: &[f64], transformed_gradient: &mut [f64], ) -> Result<(f64, f64)> { - Python::with_gil(|py| { + Python::attach(|py| { let transformed_position = PyArray1::from_slice(py, transformed_position); let output = params @@ -1132,7 +1437,7 @@ impl PyTransformAdapt { untransformed_position: &mut [f64], transformed_position: &[f64], ) -> Result> { - Python::with_gil(|py| { + Python::attach(|py| { let transformed_position = PyArray1::from_slice(py, transformed_position); let output = params @@ -1153,7 +1458,7 @@ impl PyTransformAdapt { untransformed_gradient: &[f64], transformed_gradient: &mut [f64], ) -> Result { - Python::with_gil(|py| { + Python::attach(|py| { let untransformed_gradient = PyArray1::from_slice(py, untransformed_gradient); let output = params @@ -1175,7 +1480,7 @@ impl PyTransformAdapt { transformed_position: &mut [f64], transformed_gradient: &mut [f64], ) -> Result<(f64, f64)> { - Python::with_gil(|py| { + Python::attach(|py| { let untransformed_position = PyArray1::from_slice(py, untransformed_position); let output = params @@ -1214,7 +1519,7 @@ impl PyTransformAdapt { untransformed_logp: impl ExactSizeIterator, params: &'a mut Py, ) -> Result<()> { - Python::with_gil(|py| { + Python::attach(|py| { let positions = PyList::new( py, untransformed_positions.map(|pos| PyArray1::from_slice(py, pos)), @@ -1241,7 +1546,7 @@ impl PyTransformAdapt { untransformed_gradient: &[f64], chain: u64, ) -> Result> { - Python::with_gil(|py| { + Python::attach(|py| { let position = PyArray1::from_slice(py, untransformed_position); let gradient = PyArray1::from_slice(py, untransformed_gradient); @@ -1254,7 +1559,7 @@ impl PyTransformAdapt { } pub fn transformation_id(&self, params: &Py) -> Result { - Python::with_gil(|py| { + Python::attach(|py| { let id: i64 = params .getattr(py, intern!(py, "transformation_id"))? .extract(py)?; @@ -1264,7 +1569,7 @@ impl PyTransformAdapt { } /// A Python module implemented in Rust. -#[pymodule] +#[pymodule(gil_used = false)] pub fn _lib(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; @@ -1275,10 +1580,12 @@ pub fn _lib(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add("__version__", env!("CARGO_PKG_VERSION"))?; + pyo3_object_store::register_store_module(m.py(), m, "_lib", "store")?; + pyo3_object_store::register_exceptions_module(m.py(), m, "_lib", "exceptions")?; Ok(()) } diff --git a/tests/reference/test_deterministic_sampling_jax.txt b/tests/reference/test_deterministic_sampling_jax.txt index 114966e..0c6237a 100644 --- a/tests/reference/test_deterministic_sampling_jax.txt +++ b/tests/reference/test_deterministic_sampling_jax.txt @@ -1,200 +1,200 @@ -0.941959 -0.559649 -0.534203 -0.561444 -0.561444 -0.418685 -0.827896 -0.847014 -0.738508 -0.961291 -0.923931 -1.00584 -1.16386 -1.10065 -1.6348 -1.13139 -0.993458 -0.993458 -0.966241 -1.10922 -1.10922 -1.05723 -1.05723 -2.32492 -0.0700824 -0.0860656 -1.36431 -0.829624 -0.584658 -0.531506 -0.507961 -0.543701 -0.510104 -2.46898 -0.820341 -0.490474 -0.343958 -0.300549 -2.60267 -0.588131 -0.430013 -0.618032 -1.27527 -1.80449 -1.80449 -0.855217 -0.556106 -1.77619 -2.03761 -1.02106 -0.774811 -1.78438 -1.61398 -0.712683 -1.04966 -1.17936 -1.5425 -1.5425 -1.26262 -1.39659 -0.337024 -0.177694 -0.0424286 -0.180403 -0.140553 -0.367095 -0.348732 -0.341436 -1.82764 -0.692738 -0.629186 -0.245706 -0.732305 -0.56873 -0.498757 -0.204131 -0.417031 -0.184895 -0.208768 -0.238139 -1.95089 -1.95089 -0.593379 -0.593379 -0.750063 -0.69929 -0.490359 -0.478709 -0.361632 -0.346159 -0.728965 -1.58228 -0.985676 -1.58468 -0.709012 -0.700483 -0.805006 -1.70347 -1.26293 -1.24837 -0.23989 -0.881025 -1.39084 -1.37812 -0.969265 -0.969265 -0.938487 -0.846447 -1.61945 -0.108473 -0.173496 -0.897353 -0.455899 -0.571886 -0.891672 -0.891672 -0.864419 -0.739099 -1.49009 -1.49009 -0.385499 -0.228701 -1.83156 -1.83156 -0.947635 -0.805623 -0.714762 -0.853477 -1.45906 -0.908818 -0.540951 -1.40995 -1.22564 -0.26496 -0.159994 -0.423836 -0.350158 -0.388884 -1.39507 -0.727701 -1.80674 -0.466389 -1.61574 -1.61574 -0.42774 -0.217983 -0.14579 -1.01321 -1.01321 -1.19713 -0.390791 -0.223687 -0.149019 -0.103866 -0.153768 -0.12942 -0.346371 -0.814553 -2.41042 -0.42739 -0.322291 -0.248911 -0.854404 -1.35372 -1.35372 -2.00546 -0.0457881 -0.0415644 -0.0797551 -0.0913076 -0.070948 -0.00993872 -0.421448 -0.550377 -0.609387 -0.490487 -2.6607 -0.32804 -0.385999 -0.497294 -1.67109 -1.14328 -1.14328 -0.903063 -0.903063 -0.903063 -0.691269 -2.00151 -0.587672 -0.79679 -1.35563 -0.598471 -0.681826 -0.818296 -1.14265 -0.113094 -0.250861 -0.284491 -0.00420445 -0.00566936 +0.00293185 +0.00896809 +0.00768812 +0.00190588 +0.00320792 +0.00740496 +0.0958083 +0.109615 +0.0105327 +0.0192266 +0.0214682 +0.0218331 +0.0585783 +0.0251717 +0.0280682 +0.0729859 +0.425133 +0.457443 +2.52983 +1.15484 +1.15484 +1.18416 +0.104273 +1.20247 +1.1064 +1.67005 +1.05586 +2.55089 +1.83339 +0.971751 +0.470398 +0.284519 +0.253759 +2.29193 +1.29672 +1.29672 +0.432495 +0.411462 +1.10822 +1.10822 +0.698466 +1.01384 +0.422528 +0.471828 +0.354965 +0.370006 +0.932942 +0.924415 +0.821473 +2.34528 +1.8362 +0.329965 +0.427145 +0.995745 +1.17653 +0.937676 +0.937676 +0.71568 +0.916428 +1.05491 +0.479239 +0.488732 +1.07755 +1.05904 +0.269731 +0.197423 +0.303258 +0.0738098 +0.0535444 +0.0704248 +0.083286 +0.158385 +0.149845 +0.416708 +0.349628 +0.31117 +0.304837 +0.0724371 +1.5569 +1.20564 +2.12525 +0.303531 +0.712031 +0.844468 +0.434198 +0.277141 +0.593882 +0.648409 +1.02533 +0.692478 +0.367875 +0.316403 +0.351662 +0.117319 +1.85435 +0.413934 +0.409025 +0.661536 +0.650092 +0.766712 +0.594595 +0.501872 +0.515377 +0.236945 +0.689338 +2.99054 +0.172018 +0.0528735 +0.0579658 +0.0581689 +0.0497977 +0.063146 +0.311101 +0.347411 +0.763051 +0.734721 +1.17926 +1.02504 +1.02504 +0.645771 +0.970169 +1.20163 +1.1179 +0.385697 +0.410691 +0.471671 +0.540587 +0.250604 +0.254267 +0.220907 +0.673968 +0.265055 +0.766607 +1.50436 +1.58131 +0.719291 +0.958127 +0.546963 +1.60432 +1.60432 +1.45897 +0.717682 +0.668208 +0.71339 +0.276479 +0.255967 +0.799242 +1.32658 +0.724295 +0.36085 +0.217894 +0.254816 +0.125993 +1.31909 +1.56969 +0.750499 +1.11993 +1.87465 +1.472 +0.950422 +0.754906 +0.270587 +0.231469 +1.19634 +1.19634 +1.19634 +1.51182 +1.34804 +1.42657 +0.544703 +1.66443 +1.66443 +1.14928 +1.10046 +1.16557 +1.5537 +0.629914 +0.880496 +0.525169 +0.312335 +0.797038 +0.733363 +1.6496 +0.0602699 +0.0840557 +0.107319 +0.0324205 +0.0929894 +0.226149 +0.202803 +0.217807 +0.366175 +0.158146 +0.160235 +0.175013 +0.148804 +0.526506 +0.785313 +1.23336 +0.733001 diff --git a/tests/reference/test_deterministic_sampling_numba.txt b/tests/reference/test_deterministic_sampling_numba.txt index 6426e8c..5bea297 100644 --- a/tests/reference/test_deterministic_sampling_numba.txt +++ b/tests/reference/test_deterministic_sampling_numba.txt @@ -1,200 +1,200 @@ -0.862203 -0.743827 -0.985284 -0.864159 -1.11537 -1.46228 -1.46228 -0.731645 -0.618394 -0.70658 -1.58816 -1.58816 -1.58816 -1.58816 -1.02597 -1.02597 -2.38965 -0.0442154 -0.0556998 -1.20147 -0.878239 -0.595919 -0.542086 -0.520452 -0.56279 -0.539904 -0.129453 -0.136407 -0.408806 -0.34263 -0.929525 -0.947864 -0.947864 -1.94444 -0.911973 -0.429576 -0.776378 -0.452981 -0.985476 -1.74745 -1.74095 -1.74095 -0.9855 -0.886535 -0.617313 -0.86405 -2.00577 -0.839407 -0.745118 -1.49611 -1.74491 -1.40854 -0.631877 -1.95302 -1.01379 -1.1063 -0.930275 -0.315935 -0.225544 -0.136821 -0.180021 -0.498635 -0.462448 -0.445633 -0.0878991 -0.105731 -0.355683 -0.750934 -0.750934 -0.874486 -1.15119 -0.657067 -0.500027 -1.28332 -1.28332 -0.919994 -1.09658 -1.73803 -1.13439 -1.21956 -0.643106 -0.329788 -0.456239 -0.596018 -0.180103 -0.388767 -1.03772 -1.03192 -1.03192 -1.04759 -1.04759 -1.13558 -0.673716 -0.871073 -0.50739 -0.625146 -0.999657 -1.00779 -2.06182 -0.707917 -0.107437 -0.0772623 -0.10719 -0.36616 -0.14863 -0.0333724 -0.0295763 -0.0205304 -0.127619 -0.164319 -0.241143 -0.376838 -0.87369 -1.64165 -0.106128 -0.170459 -0.916833 -0.458599 -0.575215 -0.894488 -0.894488 -0.865427 -0.739365 -0.681649 -0.72888 -1.38352 -1.38352 -2.28238 -2.28238 -2.28238 -0.567775 -0.41864 -1.41709 -1.41709 -1.41709 -1.41709 -0.600311 -0.598689 -0.627731 -0.460137 -1.86219 -1.81783 -1.78092 -1.78092 -1.78092 -0.492732 -1.37953 -1.16762 -0.597573 -0.627465 -0.617661 -0.649115 -0.608255 -0.685365 -0.685365 -0.685365 -0.685365 -0.685365 -2.2227 -0.971606 -0.4219 -0.879055 -0.74434 -2.08679 -1.34952 -1.34952 -1.34952 -1.34952 -0.513284 -0.16734 -0.174037 -0.626756 -0.913504 -0.271423 -0.200176 -0.132462 -0.465497 -0.406755 -0.493296 -0.0175891 -0.0234891 -0.0220327 -0.132404 -0.0788943 -0.0949265 -0.103031 -0.0760492 -0.377155 -1.90599 -1.58063 -1.58063 -1.17038 -0.556726 -0.55085 -0.24632 -0.375951 -0.339243 -0.747524 -1.82921 -0.794344 +0.00293185 +0.00896808 +0.00768811 +0.00190587 +0.00320792 +0.00740495 +0.0958081 +0.109615 +0.0105327 +0.0192265 +0.0214681 +0.021833 +0.0585782 +0.0251717 +0.0280681 +0.0729857 +0.425132 +0.457442 +2.52983 +1.15483 +1.15483 +1.18416 +0.104274 +1.20247 +1.1064 +1.67005 +1.05586 +2.55089 +1.8334 +0.971751 +0.470398 +0.284519 +0.25376 +2.29193 +1.29672 +1.29672 +0.432495 +0.411463 +1.10822 +1.10822 +0.698467 +1.01384 +0.422528 +0.471828 +0.354965 +0.370006 +0.932942 +0.924415 +0.821473 +2.34528 +1.8362 +0.329965 +0.427145 +0.995744 +1.17653 +0.937677 +0.937677 +0.71568 +0.916428 +1.05491 +0.479239 +0.488732 +1.07755 +1.05904 +0.269731 +0.197423 +0.303257 +0.0738098 +0.0535443 +0.0704248 +0.083286 +0.158385 +0.149844 +0.416707 +0.349628 +0.31117 +0.304836 +0.072437 +1.5569 +1.20564 +2.12525 +0.303531 +0.712031 +0.844469 +0.434198 +0.277141 +0.593882 +0.648409 +1.02533 +0.692478 +0.367875 +0.316403 +0.351662 +0.117319 +1.85435 +0.413932 +0.409023 +0.661534 +0.650092 +0.766712 +0.594595 +0.501872 +0.515377 +0.236945 +0.689338 +2.99054 +0.172018 +0.0528735 +0.0579658 +0.0581689 +0.0497977 +0.063146 +0.311101 +0.347411 +0.763051 +0.734721 +1.17926 +1.02504 +1.02504 +0.645771 +0.970169 +1.20163 +1.1179 +0.385697 +0.410691 +0.471671 +0.540587 +0.250604 +0.254267 +0.220907 +0.673968 +0.265055 +0.766607 +1.50436 +1.58131 +0.719291 +0.958127 +0.546963 +1.60432 +1.60432 +1.45897 +0.717682 +0.668208 +0.71339 +0.276479 +0.255967 +0.799242 +1.32658 +0.724295 +0.36085 +0.217894 +0.254816 +0.125993 +1.31909 +1.56969 +0.750499 +1.11993 +1.87465 +1.472 +0.950422 +0.754906 +0.270587 +0.231469 +1.19634 +1.19634 +1.19634 +1.51182 +1.34804 +1.42657 +0.544703 +1.66443 +1.66443 +1.14928 +1.10046 +1.16557 +1.5537 +0.629914 +0.880496 +0.525169 +0.312335 +0.797038 +0.733363 +1.6496 +0.0602699 +0.0840557 +0.107319 +0.0324205 +0.0929894 +0.226149 +0.202803 +0.217807 +0.366175 +0.158146 +0.160235 +0.175013 +0.148804 +0.526506 +0.785313 +1.23336 +0.733001 diff --git a/tests/reference/test_deterministic_sampling_stan.txt b/tests/reference/test_deterministic_sampling_stan.txt index 3bed2a2..dd85d53 100644 --- a/tests/reference/test_deterministic_sampling_stan.txt +++ b/tests/reference/test_deterministic_sampling_stan.txt @@ -1,2 +1,2 @@ -1.21572 1.03376 1.60518 1.60518 1.59553 1.35023 0.761056 1.41688 1.41688 1.41688 -0.252389 0.999663 0.999663 0.999663 0.740026 0.387763 0.944247 0.289785 1.52909 0.683129 +0.754944 0.746804 0.687211 1.56984 2.15413 2.15413 0.186138 1.19976 1.19976 0.818806 +0.185979 1.20179 0.236474 0.240597 0.416886 0.529295 0.574728 0.59912 1.02193 0.902788 diff --git a/tests/test_pymc.py b/tests/test_pymc.py index e710aae..8d697b0 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -88,7 +88,12 @@ def test_low_rank(backend, gradient_backend): model, backend=backend, gradient_backend=gradient_backend ) trace = nutpie.sample(compiled, chains=1, low_rank_modified_mass_matrix=True) - trace.posterior.a # noqa: B018 + + assert "mass_matrix_eigvals" not in trace.sample_stats + trace = nutpie.sample( + compiled, chains=1, low_rank_modified_mass_matrix=True, store_mass_matrix=True + ) + assert "mass_matrix_eigvals" in trace.sample_stats @pytest.mark.pymc @@ -421,7 +426,7 @@ def test_missing(backend, gradient_backend): @pytest.mark.pymc -@pytest.mark.array_compare +@pytest.mark.array_compare(atol=1e-4, rtol=1e-4) def test_deterministic_sampling_numba(): with pm.Model() as model: pm.HalfNormal("a") @@ -432,7 +437,7 @@ def test_deterministic_sampling_numba(): @pytest.mark.pymc -@pytest.mark.array_compare +@pytest.mark.array_compare(atol=1e-4, rtol=1e-4) def test_deterministic_sampling_jax(): with pm.Model() as model: pm.HalfNormal("a") @@ -440,3 +445,28 @@ def test_deterministic_sampling_jax(): compiled = nutpie.compile_pymc_model(model, backend="jax", gradient_backend="jax") trace = nutpie.sample(compiled, chains=2, seed=123, draws=100, tune=100) return trace.posterior.a.values.ravel() + + +@pytest.mark.pymc +def test_zarr_store(tmp_path): + with pm.Model() as model: + pm.HalfNormal("a") + + compiled = nutpie.compile_pymc_model(model, backend="numba") + + path = tmp_path / "trace.zarr" + path.mkdir() + store = nutpie.zarr_store.LocalStore(str(path)) + trace = nutpie.sample( + compiled, chains=2, seed=123, draws=100, tune=100, zarr_store=store + ) + trace.load().posterior.a # noqa: B018 + + +@pytest.fixture +def tmp_path(): + import tempfile + from pathlib import Path + + with tempfile.TemporaryDirectory() as tmpdirname: + yield Path(tmpdirname) diff --git a/tests/test_stan.py b/tests/test_stan.py index 53b6b40..66cc4d5 100644 --- a/tests/test_stan.py +++ b/tests/test_stan.py @@ -278,7 +278,7 @@ def test_stan_flow(): # TODO: There are small numerical differences between linux and windows. # We should figure out if they originate in stan or in nutpie. -@pytest.mark.array_compare(atol=1e-4) +@pytest.mark.array_compare(atol=1e-4, rtol=1e-4) @pytest.mark.stan def test_deterministic_sampling_stan(): model = """ @@ -296,6 +296,6 @@ def test_deterministic_sampling_stan(): compiled_model = nutpie.compile_stan_model(code=model) trace = nutpie.sample(compiled_model, chains=2, seed=123, draws=100, tune=100) trace2 = nutpie.sample(compiled_model, chains=2, seed=123, draws=100, tune=100) - np.testing.assert_allclose(trace.posterior.a.values, trace2.posterior.a.values) - np.testing.assert_allclose(trace.posterior.b.values, trace2.posterior.b.values) + np.testing.assert_array_max_ulp(trace.posterior.a.values, trace2.posterior.a.values) + np.testing.assert_array_max_ulp(trace.posterior.b.values, trace2.posterior.b.values) return trace.posterior.a.isel(draw=slice(None, 10)).values