Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .github/workflows/cuda.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -150,5 +150,6 @@ jobs:
- name: Download and run cudf-test-harness
run: |
curl -fsSL https://github.com/vortex-data/cudf-test-harness/releases/latest/download/cudf-test-harness-x86_64.tar.gz | tar -xz
cd cudf-test-harness-x86_64
compute-sanitizer --tool memcheck --error-exitcode 1 ./cudf-test-harness check $GITHUB_WORKSPACE/target/x86_64-unknown-linux-gnu/ci/libvortex_test_e2e_cuda.so
$GITHUB_WORKSPACE/target/x86_64-unknown-linux-gnu/ci/cudf_harness_runner \
./cudf-test-harness-x86_64/cudf-test-harness \
$GITHUB_WORKSPACE/target/x86_64-unknown-linux-gnu/ci/libvortex_test_e2e_cuda.so
48 changes: 48 additions & 0 deletions vortex-test/e2e-cuda/src/bin/cudf_harness_runner.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::env;
use std::process::Command;
use std::process::ExitCode;

const PRIMITIVE_DTYPES: &[&str] = &[
"u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64", "f32", "f64",
];
const PRIMITIVE_DTYPE_ENV: &str = "VORTEX_CUDF_PRIMITIVE_DTYPE";

fn main() -> ExitCode {
let args = env::args().collect::<Vec<_>>();
let [program, harness, library] = args.as_slice() else {
eprintln!(
"Usage: {} <cudf-test-harness> <library.so>",
args.first().map_or("cudf_harness_runner", String::as_str)
);
return ExitCode::from(2);
};

for primitive_dtype in PRIMITIVE_DTYPES {
eprintln!("running {program} with {PRIMITIVE_DTYPE_ENV}={primitive_dtype}");

let status = Command::new("compute-sanitizer")
.args(["--tool", "memcheck", "--error-exitcode", "1"])
.arg(harness)
.arg("check")
.arg(library)
.env(PRIMITIVE_DTYPE_ENV, primitive_dtype)
.status();

match status {
Ok(status) if status.success() => {}
Ok(status) => {
eprintln!("cudf-test-harness failed with {status}");
return ExitCode::from(1);
}
Err(err) => {
eprintln!("failed to run cudf-test-harness: {err}");
return ExitCode::from(1);
}
}
}

ExitCode::SUCCESS
}
62 changes: 55 additions & 7 deletions vortex-test/e2e-cuda/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@

#![expect(clippy::unwrap_used, clippy::expect_used)]

use std::env;
use std::mem;
use std::sync::Arc;
use std::sync::LazyLock;

use arrow_array::Array;
use arrow_array::ArrayRef;
use arrow_array::ArrayRef as ArrowArrayRef;
use arrow_array::Date32Array;
use arrow_array::Decimal128Array;
use arrow_array::StringArray;
use arrow_array::UInt32Array;
use arrow_array::cast::AsArray;
use arrow_array::ffi::FFI_ArrowArray;
use arrow_array::ffi::from_ffi;
Expand All @@ -31,16 +31,20 @@ use arrow_schema::Field;
use arrow_schema::Fields;
use arrow_schema::ffi::FFI_ArrowSchema;
use futures::executor::block_on;
use vortex::array::ArrayRef as VortexArrayRef;
use vortex::array::IntoArray;
use vortex::array::VortexSessionExecute;
use vortex::array::arrays::DecimalArray;
use vortex::array::arrays::PrimitiveArray;
use vortex::array::arrays::StructArray;
use vortex::array::arrays::TemporalArray;
use vortex::array::arrays::VarBinViewArray;
use vortex::array::arrow::ArrowSessionExt;
use vortex::array::session::ArraySession;
use vortex::array::validity::Validity;
use vortex::dtype::DecimalDType;
use vortex::dtype::FieldNames;
use vortex::dtype::NativePType;
use vortex::extension::datetime::TimeUnit;
use vortex::io::session::RuntimeSession;
use vortex::layout::session::LayoutSession;
Expand All @@ -50,6 +54,8 @@ use vortex_cuda::CudaSession;
use vortex_cuda::arrow::ArrowDeviceArray;
use vortex_cuda::arrow::DeviceArrayExt;

const PRIMITIVE_DTYPE_ENV: &str = "VORTEX_CUDF_PRIMITIVE_DTYPE";

static SESSION: LazyLock<VortexSession> = LazyLock::new(|| {
VortexSession::empty()
.with::<ArraySession>()
Expand All @@ -59,6 +65,35 @@ static SESSION: LazyLock<VortexSession> = LazyLock::new(|| {
.with::<CudaSession>()
});

fn primitive_dtype_case() -> String {
env::var(PRIMITIVE_DTYPE_ENV).unwrap_or_else(|_| "u32".to_string())
}

fn nullable_primitive<T: NativePType>(first: T, second: T, third: T) -> VortexArrayRef {
PrimitiveArray::from_option_iter([Some(first), None, Some(second), Some(third), None])
.into_array()
}

fn primitive_array() -> Result<VortexArrayRef, String> {
Ok(match primitive_dtype_case().as_str() {
"u8" => nullable_primitive(0u8, 2, 3),
"u16" => nullable_primitive(10u16, 12, 13),
"u32" => nullable_primitive(20u32, 22, 23),
"u64" => nullable_primitive(30u64, 32, 33),
"i8" => nullable_primitive(-4i8, -2, 3),
"i16" => nullable_primitive(-14i16, -12, 13),
"i32" => nullable_primitive(-24i32, -22, 23),
"i64" => nullable_primitive(-34i64, -32, 33),
"f32" => nullable_primitive(1.25f32, -2.5, 3.75),
"f64" => nullable_primitive(10.25f64, -20.5, 30.75),
other => {
return Err(format!(
"unsupported {PRIMITIVE_DTYPE_ENV}={other}; expected one of u8,u16,u32,u64,i8,i16,i32,i64,f32,f64"
));
}
})
}

/// # Safety
/// called by C++ code.
#[unsafe(no_mangle)]
Expand All @@ -68,7 +103,13 @@ pub unsafe extern "C" fn export_array(
) -> i32 {
let mut ctx = CudaSession::create_execution_ctx(&SESSION).unwrap();

let primitive = PrimitiveArray::from_option_iter([Some(0u32), None, Some(2), Some(3), None]);
let primitive = match primitive_array() {
Ok(array) => array,
Err(err) => {
eprintln!("error in export_array: {err}");
return 1;
}
};
let decimal = DecimalArray::from_option_iter(
[Some(0i128), Some(1), None, Some(3), Some(4)],
DecimalDType::new(38, 2),
Expand All @@ -89,7 +130,7 @@ pub unsafe extern "C" fn export_array(
let array = StructArray::new(
FieldNames::from_iter(["prims", "decimals", "strings", "dates"]),
vec![
primitive.into_array(),
primitive,
decimal.into_array(),
strings.into_array(),
dates.into_array(),
Expand Down Expand Up @@ -128,7 +169,14 @@ pub unsafe extern "C" fn validate_array(
let array = make_array(array_data);
let struct_array = array.as_struct();

let primitive = UInt32Array::from_iter([Some(0), None, Some(2), Some(3), None]);
let primitive = SESSION
.arrow()
.execute_arrow(
primitive_array().expect("expected primitive array"),
None,
&mut SESSION.create_execution_ctx(),
)
.expect("expected primitive Arrow array");
let decimal = Decimal128Array::from_iter([Some(0i128), Some(1), None, Some(3), Some(4)])
.with_precision_and_scale(38, 2)
.expect("with_precision_and_scale");
Expand All @@ -155,8 +203,8 @@ pub unsafe extern "C" fn validate_array(
struct_array.fields()
);

let expected_fields: [ArrayRef; _] = [
Arc::new(primitive),
let expected_fields: [ArrowArrayRef; _] = [
primitive,
Arc::new(decimal),
Arc::new(string),
Arc::new(date),
Expand Down
Loading