Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ members = [
"ext/env",
"ext/core",
"ext/os",
"ext/ai"
"ext/ai",
"ext/ai/utilities"
]

[workspace.dependencies]
Expand Down Expand Up @@ -75,6 +76,7 @@ sb_env = { version = "0.1.0", path = "./ext/env"}
sb_core = { version = "0.1.0", path = "./ext/core" }
sb_os = { version = "0.1.0", path = "./ext/os" }
sb_ai = { version = "0.1.0", path = "./ext/ai" }
sb_ai_v8_utilities = { version = "0.1.0", path = "./ext/ai/utilities" }

# crypto
hkdf = "0.12.3"
Expand Down
1 change: 1 addition & 0 deletions ext/ai/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ path = "lib.rs"
deno_core.workspace = true

base_rt.workspace = true
sb_ai_v8_utilities.workspace = true

anyhow.workspace = true
log.workspace = true
Expand Down
71 changes: 71 additions & 0 deletions ext/ai/onnxruntime/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,15 @@ impl JsTensor {
pub fn extract_ort_tensor_ref<'a, T: IntoTensorElementType + Debug>(
mut self,
) -> anyhow::Result<ValueRefMut<'a, DynValueTypeMarker>> {
let expected_length = self.dims.iter().product::<i64>() as usize;
let current_length = self.data.len() / size_of::<T>();

if current_length != expected_length {
return Err(anyhow!(
"invalid tensor length! got '{current_length}' expect '{expected_length}'"
));
};

// Same impl. as the Tensor::from_array()
// https://github.com/pykeio/ort/blob/abd527b6a1df8f566c729a9c4398bdfd185d652f/src/value/impl_tensor/create.rs#L170
let memory_info = MemoryInfo::new(
Expand All @@ -118,6 +127,7 @@ impl JsTensor {
)?;

// Zero-Copying Data to an ORT Tensor based on JS type
// SAFETY: we did check tensor size above
let tensor = unsafe {
TensorRefMut::<T>::from_raw(
memory_info,
Expand Down Expand Up @@ -200,3 +210,64 @@ impl ToJsTensor {
})
}
}

#[cfg(test)]
mod tests {
use sb_ai_v8_utilities::v8_do;

use super::*;

#[test]
fn test_ort_tensor_extract_ref() {
v8_do(|| {
// region: v8-init
// ref: https://github.com/denoland/deno_core/blob/490079f6b5c9233f476b0a529eace1f5b2c4ed07/serde_v8/tests/magic.rs#L23
let isolate = &mut v8::Isolate::new(v8::CreateParams::default());
let handle_scope = &mut v8::HandleScope::new(isolate);
let context = v8::Context::new(handle_scope);
let scope = &mut v8::ContextScope::new(handle_scope, context);
// endregion: v8-init

// Bad Tensor Scenario:
let tensor_script = r#"({
type: 'float32',
data: new Float32Array([]),
dims: [1, 1],
size: 300
})"#;

let js_tensor = {
let code = v8::String::new(scope, tensor_script).unwrap();
let script = v8::Script::compile(scope, code, None).unwrap();
script.run(scope).unwrap()
};

let tensor: JsTensor = deno_core::serde_v8::from_v8(scope, js_tensor).unwrap();

let tensor_ref_result = tensor.extract_ort_tensor_ref::<f32>();
assert!(
tensor_ref_result.is_err(),
"Since `data.len()` doesn't reflect `dims` it must return Error"
);

// Good Tensor Scenario:
let tensor_script = r#"({
type: 'float32',
data: new Float32Array([0.1, 0.2]),
dims: [1, 2],
size: 2
})"#;

let js_tensor = {
let code = v8::String::new(scope, tensor_script).unwrap();
let script = v8::Script::compile(scope, code, None).unwrap();
script.run(scope).unwrap()
};

let tensor: JsTensor = deno_core::serde_v8::from_v8(scope, js_tensor).unwrap();

let tensor_ref_result = tensor.extract_ort_tensor_ref::<f32>();
assert!(tensor_ref_result.is_ok());
});
}
}
13 changes: 13 additions & 0 deletions ext/ai/utilities/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[package]
name = "sb_ai_v8_utilities"
version = "0.1.0"
authors = ["Supabase <team@supabase.com>"]
edition = "2021"
license = "MIT"
publish = false

[lib]
path = "lib.rs"

[dependencies]
deno_core.workspace = true
26 changes: 26 additions & 0 deletions ext/ai/utilities/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use std::sync::Once;

use deno_core::v8;

pub fn v8_init() {
let platform = v8::new_unprotected_default_platform(0, false).make_shared();
v8::V8::initialize_platform(platform);
v8::V8::initialize();
}

pub fn v8_shutdown() {
// SAFETY: this is safe, because all isolates have been shut down already.
unsafe {
v8::V8::dispose();
}
v8::V8::dispose_platform();
}

pub fn v8_do(f: impl FnOnce()) {
static V8_INIT: Once = Once::new();
V8_INIT.call_once(|| {
v8_init();
});
f();
// v8_shutdown();
}