diff --git a/Cargo.lock b/Cargo.lock index 2ae1becd1..57fb7f4e4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6042,6 +6042,7 @@ dependencies = [ "ort-sys", "rand", "reqwest 0.12.4", + "sb_ai_v8_utilities", "scopeguard", "serde", "tokenizers", @@ -6052,6 +6053,13 @@ dependencies = [ "xxhash-rust", ] +[[package]] +name = "sb_ai_v8_utilities" +version = "0.1.0" +dependencies = [ + "deno_core", +] + [[package]] name = "sb_core" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 090d643ed..d7c3b9c25 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,8 @@ members = [ "ext/env", "ext/core", "ext/os", - "ext/ai" + "ext/ai", + "ext/ai/utilities" ] [workspace.dependencies] @@ -75,6 +76,7 @@ sb_env = { version = "0.1.0", path = "./ext/env"} sb_core = { version = "0.1.0", path = "./ext/core" } sb_os = { version = "0.1.0", path = "./ext/os" } sb_ai = { version = "0.1.0", path = "./ext/ai" } +sb_ai_v8_utilities = { version = "0.1.0", path = "./ext/ai/utilities" } # crypto hkdf = "0.12.3" diff --git a/ext/ai/Cargo.toml b/ext/ai/Cargo.toml index a0c70575e..06c157ee8 100644 --- a/ext/ai/Cargo.toml +++ b/ext/ai/Cargo.toml @@ -12,6 +12,7 @@ path = "lib.rs" deno_core.workspace = true base_rt.workspace = true +sb_ai_v8_utilities.workspace = true anyhow.workspace = true log.workspace = true diff --git a/ext/ai/onnxruntime/tensor.rs b/ext/ai/onnxruntime/tensor.rs index 3dac07d23..1cfea1d6f 100644 --- a/ext/ai/onnxruntime/tensor.rs +++ b/ext/ai/onnxruntime/tensor.rs @@ -108,6 +108,15 @@ impl JsTensor { pub fn extract_ort_tensor_ref<'a, T: IntoTensorElementType + Debug>( mut self, ) -> anyhow::Result> { + let expected_length = self.dims.iter().product::() as usize; + let current_length = self.data.len() / size_of::(); + + if current_length != expected_length { + return Err(anyhow!( + "invalid tensor length! got '{current_length}' expect '{expected_length}'" + )); + }; + // Same impl. as the Tensor::from_array() // https://github.com/pykeio/ort/blob/abd527b6a1df8f566c729a9c4398bdfd185d652f/src/value/impl_tensor/create.rs#L170 let memory_info = MemoryInfo::new( @@ -118,6 +127,7 @@ impl JsTensor { )?; // Zero-Copying Data to an ORT Tensor based on JS type + // SAFETY: we did check tensor size above let tensor = unsafe { TensorRefMut::::from_raw( memory_info, @@ -200,3 +210,64 @@ impl ToJsTensor { }) } } + +#[cfg(test)] +mod tests { + use sb_ai_v8_utilities::v8_do; + + use super::*; + + #[test] + fn test_ort_tensor_extract_ref() { + v8_do(|| { + // region: v8-init + // ref: https://github.com/denoland/deno_core/blob/490079f6b5c9233f476b0a529eace1f5b2c4ed07/serde_v8/tests/magic.rs#L23 + let isolate = &mut v8::Isolate::new(v8::CreateParams::default()); + let handle_scope = &mut v8::HandleScope::new(isolate); + let context = v8::Context::new(handle_scope); + let scope = &mut v8::ContextScope::new(handle_scope, context); + // endregion: v8-init + + // Bad Tensor Scenario: + let tensor_script = r#"({ + type: 'float32', + data: new Float32Array([]), + dims: [1, 1], + size: 300 + })"#; + + let js_tensor = { + let code = v8::String::new(scope, tensor_script).unwrap(); + let script = v8::Script::compile(scope, code, None).unwrap(); + script.run(scope).unwrap() + }; + + let tensor: JsTensor = deno_core::serde_v8::from_v8(scope, js_tensor).unwrap(); + + let tensor_ref_result = tensor.extract_ort_tensor_ref::(); + assert!( + tensor_ref_result.is_err(), + "Since `data.len()` doesn't reflect `dims` it must return Error" + ); + + // Good Tensor Scenario: + let tensor_script = r#"({ + type: 'float32', + data: new Float32Array([0.1, 0.2]), + dims: [1, 2], + size: 2 + })"#; + + let js_tensor = { + let code = v8::String::new(scope, tensor_script).unwrap(); + let script = v8::Script::compile(scope, code, None).unwrap(); + script.run(scope).unwrap() + }; + + let tensor: JsTensor = deno_core::serde_v8::from_v8(scope, js_tensor).unwrap(); + + let tensor_ref_result = tensor.extract_ort_tensor_ref::(); + assert!(tensor_ref_result.is_ok()); + }); + } +} diff --git a/ext/ai/utilities/Cargo.toml b/ext/ai/utilities/Cargo.toml new file mode 100644 index 000000000..51704d41b --- /dev/null +++ b/ext/ai/utilities/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "sb_ai_v8_utilities" +version = "0.1.0" +authors = ["Supabase "] +edition = "2021" +license = "MIT" +publish = false + +[lib] +path = "lib.rs" + +[dependencies] +deno_core.workspace = true \ No newline at end of file diff --git a/ext/ai/utilities/lib.rs b/ext/ai/utilities/lib.rs new file mode 100644 index 000000000..b9a95944f --- /dev/null +++ b/ext/ai/utilities/lib.rs @@ -0,0 +1,26 @@ +use std::sync::Once; + +use deno_core::v8; + +pub fn v8_init() { + let platform = v8::new_unprotected_default_platform(0, false).make_shared(); + v8::V8::initialize_platform(platform); + v8::V8::initialize(); +} + +pub fn v8_shutdown() { + // SAFETY: this is safe, because all isolates have been shut down already. + unsafe { + v8::V8::dispose(); + } + v8::V8::dispose_platform(); +} + +pub fn v8_do(f: impl FnOnce()) { + static V8_INIT: Once = Once::new(); + V8_INIT.call_once(|| { + v8_init(); + }); + f(); + // v8_shutdown(); +}