Skip to content

Commit 262afad

Browse files
authored
Merge branch 'master' into improve-trait-definition
2 parents 2cab0ba + 9ae1ed9 commit 262afad

File tree

18 files changed

+570
-71
lines changed

18 files changed

+570
-71
lines changed

Diff for: .github/workflows/ci.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ jobs:
4848
export AF_PATH=${GITHUB_WORKSPACE}/afbin
4949
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${AF_PATH}/lib64
5050
echo "Using cargo version: $(cargo --version)"
51-
cargo build --all
52-
cargo test --no-fail-fast
51+
cargo build --all --all-features
52+
cargo test --no-fail-fast --all-features
5353
5454
format:
5555
name: Format Check

Diff for: Cargo.toml

+4
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,19 @@ statistics = []
4646
vision = []
4747
default = ["algorithm", "arithmetic", "blas", "data", "indexing", "graphics", "image", "lapack",
4848
"ml", "macros", "random", "signal", "sparse", "statistics", "vision"]
49+
afserde = ["serde"]
4950

5051
[dependencies]
5152
libc = "0.2"
5253
num = "0.2"
5354
lazy_static = "1.0"
5455
half = "1.5.0"
56+
serde = { version = "1.0", features = ["derive"], optional = true }
5557

5658
[dev-dependencies]
5759
half = "1.5.0"
60+
serde_json = "1.0"
61+
bincode = "1.3"
5862

5963
[build-dependencies]
6064
serde_json = "1.0"

Diff for: README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Only, Major(M) & Minor(m) version numbers need to match. *p1* and *p2* are patch
1616

1717
## Supported platforms
1818

19-
Linux, Windows and OSX. Rust 1.15.1 or higher is required.
19+
Linux, Windows and OSX. Rust 1.31 or newer is required.
2020

2121
## Use from Crates.io [![][6]][7] [![][8]][9]
2222

Diff for: cuda-interop/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ rustacuda = "0.1"
1919
rustacuda_core = "0.1"
2020

2121
[[example]]
22-
name = "custom_kernel"
22+
name = "afcuda_custom_kernel"
2323
path = "examples/custom_kernel.rs"
2424

2525
[[example]]

Diff for: cuda-interop/examples/cuda_af_app.rs

-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
use arrayfire::{af_print, dim4, info, set_device, Array};
22
use rustacuda::prelude::*;
3-
use rustacuda::*;
43

54
fn main() {
65
// MAKE SURE to do all rustacuda initialization before arrayfire APIs

Diff for: examples/helloworld.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ fn main() {
1313
);
1414
println!("Revision: {}", get_revision());
1515

16-
let num_rows: u64 = 5;
17-
let num_cols: u64 = 3;
16+
let num_rows: i64 = 5;
17+
let num_cols: i64 = 3;
1818
let values: [f32; 3] = [1.0, 2.0, 3.0];
1919
let indices = Array::new(&values, Dim4::new(&[3, 1, 1, 1]));
2020

2121
af_print!("Indices ", indices);
2222

23-
let dims = Dim4::new(&[num_rows, num_cols, 1, 1]);
23+
let dims = Dim4::new(&[num_rows as u64, num_cols as u64, 1, 1]);
2424

2525
let mut a = randu::<f32>(dims);
2626
af_print!("Create a 5-by-3 float matrix on the GPU", a);

Diff for: opencl-interop/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ cl-sys = "0.4.2"
1818
ocl-core = "0.11.2"
1919

2020
[[example]]
21-
name = "custom_kernel"
21+
name = "afocl_custom_kernel"
2222
path = "examples/custom_kernel.rs"
2323

2424
[[example]]

Diff for: opencl-interop/examples/custom_kernel.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ fn main() {
2222
let af_ctx = afcl::get_context(false);
2323
let af_que = afcl::get_queue(false);
2424

25-
let devid = unsafe { ocl_core::DeviceId::from_raw(af_did) };
25+
let _devid = unsafe { ocl_core::DeviceId::from_raw(af_did) };
2626
let contx = unsafe { ocl_core::Context::from_raw_copied_ptr(af_ctx) };
2727
let queue = unsafe { ocl_core::CommandQueue::from_raw_copied_ptr(af_que) };
2828

Diff for: src/algorithm/mod.rs

+115-25
Original file line numberDiff line numberDiff line change
@@ -518,12 +518,17 @@ where
518518
}
519519

520520
macro_rules! all_reduce_func_def {
521-
($doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type:ty) => {
521+
($doc_str: expr, $fn_name: ident, $ffi_name: ident, $assoc_type:ident) => {
522522
#[doc=$doc_str]
523-
pub fn $fn_name<T>(input: &Array<T>) -> ($out_type, $out_type)
523+
pub fn $fn_name<T>(input: &Array<T>)
524+
-> (
525+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType,
526+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType
527+
)
524528
where
525529
T: HasAfEnum,
526-
$out_type: HasAfEnum + Fromf64
530+
<T as HasAfEnum>::$assoc_type: HasAfEnum,
531+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType: HasAfEnum + Fromf64,
527532
{
528533
let mut real: f64 = 0.0;
529534
let mut imag: f64 = 0.0;
@@ -533,7 +538,10 @@ macro_rules! all_reduce_func_def {
533538
);
534539
HANDLE_ERROR(AfError::from(err_val));
535540
}
536-
(<$out_type>::fromf64(real), <$out_type>::fromf64(imag))
541+
(
542+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType::fromf64(real),
543+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType::fromf64(imag),
544+
)
537545
}
538546
};
539547
}
@@ -564,7 +572,7 @@ all_reduce_func_def!(
564572
",
565573
sum_all,
566574
af_sum_all,
567-
T::AggregateOutType
575+
AggregateOutType
568576
);
569577

570578
all_reduce_func_def!(
@@ -594,7 +602,7 @@ all_reduce_func_def!(
594602
",
595603
product_all,
596604
af_product_all,
597-
T::ProductOutType
605+
ProductOutType
598606
);
599607

600608
all_reduce_func_def!(
@@ -623,7 +631,7 @@ all_reduce_func_def!(
623631
",
624632
min_all,
625633
af_min_all,
626-
T::InType
634+
InType
627635
);
628636

629637
all_reduce_func_def!(
@@ -652,10 +660,31 @@ all_reduce_func_def!(
652660
",
653661
max_all,
654662
af_max_all,
655-
T::InType
663+
InType
656664
);
657665

658-
all_reduce_func_def!(
666+
// Generates a public whole-array reduction function whose result type is
// passed in explicitly as `$out_type`.
//
// Macro arguments:
//   $doc_str  - expression used as the generated function's `#[doc]` text.
//   $fn_name  - identifier of the generated public function.
//   $ffi_name - identifier of the FFI routine the generated function calls.
//   $out_type - type of each element of the returned pair; it must implement
//               `HasAfEnum` and `Fromf64`.
macro_rules! all_reduce_func_def2 {
667+
($doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type:ty) => {
668+
#[doc=$doc_str]
669+
pub fn $fn_name<T>(input: &Array<T>) -> ($out_type, $out_type)
670+
where
671+
T: HasAfEnum,
672+
$out_type: HasAfEnum + Fromf64
673+
{
// Two f64 out-parameters filled in by the FFI call below; the names suggest
// the real and imaginary parts of the result (confirm against the FFI docs).
674+
let mut real: f64 = 0.0;
675+
let mut imag: f64 = 0.0;
676+
unsafe {
// The FFI routine receives raw pointers to both out-parameters plus the
// handle returned by `input.get()`.
677+
let err_val = $ffi_name(
678+
&mut real as *mut c_double, &mut imag as *mut c_double, input.get(),
679+
);
// The returned status code is converted to `AfError` and handed to
// `HANDLE_ERROR`.
680+
HANDLE_ERROR(AfError::from(err_val));
681+
}
// Both f64 results are converted to `$out_type` through `Fromf64::fromf64`.
682+
(<$out_type>::fromf64(real), <$out_type>::fromf64(imag))
683+
}
684+
};
685+
}
686+
687+
all_reduce_func_def2!(
659688
"
660689
Find if all values of Array are non-zero
661690
@@ -682,7 +711,7 @@ all_reduce_func_def!(
682711
bool
683712
);
684713

685-
all_reduce_func_def!(
714+
all_reduce_func_def2!(
686715
"
687716
Find if any value of Array is non-zero
688717
@@ -709,7 +738,7 @@ all_reduce_func_def!(
709738
bool
710739
);
711740

712-
all_reduce_func_def!(
741+
all_reduce_func_def2!(
713742
"
714743
Count number of non-zero values in the Array
715744
@@ -751,10 +780,17 @@ all_reduce_func_def!(
751780
/// A tuple of summation result.
752781
///
753782
/// Note: For non-complex data type Arrays, second value of tuple is zero.
754-
pub fn sum_nan_all<T>(input: &Array<T>, val: f64) -> (T::AggregateOutType, T::AggregateOutType)
783+
pub fn sum_nan_all<T>(
784+
input: &Array<T>,
785+
val: f64,
786+
) -> (
787+
<<T as HasAfEnum>::AggregateOutType as HasAfEnum>::BaseType,
788+
<<T as HasAfEnum>::AggregateOutType as HasAfEnum>::BaseType,
789+
)
755790
where
756791
T: HasAfEnum,
757-
T::AggregateOutType: HasAfEnum + Fromf64,
792+
<T as HasAfEnum>::AggregateOutType: HasAfEnum,
793+
<<T as HasAfEnum>::AggregateOutType as HasAfEnum>::BaseType: HasAfEnum + Fromf64,
758794
{
759795
let mut real: f64 = 0.0;
760796
let mut imag: f64 = 0.0;
@@ -768,8 +804,8 @@ where
768804
HANDLE_ERROR(AfError::from(err_val));
769805
}
770806
(
771-
<T::AggregateOutType>::fromf64(real),
772-
<T::AggregateOutType>::fromf64(imag),
807+
<<T as HasAfEnum>::AggregateOutType as HasAfEnum>::BaseType::fromf64(real),
808+
<<T as HasAfEnum>::AggregateOutType as HasAfEnum>::BaseType::fromf64(imag),
773809
)
774810
}
775811

@@ -788,10 +824,17 @@ where
788824
/// A tuple of product result.
789825
///
790826
/// Note: For non-complex data type Arrays, second value of tuple is zero.
791-
pub fn product_nan_all<T>(input: &Array<T>, val: f64) -> (T::ProductOutType, T::ProductOutType)
827+
pub fn product_nan_all<T>(
828+
input: &Array<T>,
829+
val: f64,
830+
) -> (
831+
<<T as HasAfEnum>::ProductOutType as HasAfEnum>::BaseType,
832+
<<T as HasAfEnum>::ProductOutType as HasAfEnum>::BaseType,
833+
)
792834
where
793835
T: HasAfEnum,
794-
T::ProductOutType: HasAfEnum + Fromf64,
836+
<T as HasAfEnum>::ProductOutType: HasAfEnum,
837+
<<T as HasAfEnum>::ProductOutType as HasAfEnum>::BaseType: HasAfEnum + Fromf64,
795838
{
796839
let mut real: f64 = 0.0;
797840
let mut imag: f64 = 0.0;
@@ -805,8 +848,8 @@ where
805848
HANDLE_ERROR(AfError::from(err_val));
806849
}
807850
(
808-
<T::ProductOutType>::fromf64(real),
809-
<T::ProductOutType>::fromf64(imag),
851+
<<T as HasAfEnum>::ProductOutType as HasAfEnum>::BaseType::fromf64(real),
852+
<<T as HasAfEnum>::ProductOutType as HasAfEnum>::BaseType::fromf64(imag),
810853
)
811854
}
812855

@@ -858,12 +901,18 @@ dim_ireduce_func_def!("
858901
", imax, af_imax, InType);
859902

860903
macro_rules! all_ireduce_func_def {
861-
($doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type:ty) => {
904+
($doc_str: expr, $fn_name: ident, $ffi_name: ident, $assoc_type:ident) => {
862905
#[doc=$doc_str]
863-
pub fn $fn_name<T>(input: &Array<T>) -> ($out_type, $out_type, u32)
906+
pub fn $fn_name<T>(input: &Array<T>)
907+
-> (
908+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType,
909+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType,
910+
u32
911+
)
864912
where
865913
T: HasAfEnum,
866-
$out_type: HasAfEnum + Fromf64
914+
<T as HasAfEnum>::$assoc_type: HasAfEnum,
915+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType: HasAfEnum + Fromf64,
867916
{
868917
let mut real: f64 = 0.0;
869918
let mut imag: f64 = 0.0;
@@ -875,7 +924,11 @@ macro_rules! all_ireduce_func_def {
875924
);
876925
HANDLE_ERROR(AfError::from(err_val));
877926
}
878-
(<$out_type>::fromf64(real), <$out_type>::fromf64(imag), temp)
927+
(
928+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType::fromf64(real),
929+
<<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType::fromf64(imag),
930+
temp,
931+
)
879932
}
880933
};
881934
}
@@ -898,7 +951,7 @@ all_ireduce_func_def!(
898951
",
899952
imin_all,
900953
af_imin_all,
901-
T::InType
954+
InType
902955
);
903956
all_ireduce_func_def!(
904957
"
@@ -918,7 +971,7 @@ all_ireduce_func_def!(
918971
",
919972
imax_all,
920973
af_imax_all,
921-
T::InType
974+
InType
922975
);
923976

924977
/// Locate the indices of non-zero elements.
@@ -1386,3 +1439,40 @@ dim_reduce_by_key_nan_func_def!(
13861439
af_product_by_key_nan,
13871440
ValueType::ProductOutType
13881441
);
1442+
1443+
// Smoke tests for the whole-array reduction API: each test calls the
// reduction functions on randomly generated arrays and prints the results.
// Nothing is asserted on the values; the tests check that the calls compile
// with the associated-type based return types and run to completion.
#[cfg(test)]
1444+
mod tests {
1445+
use super::super::core::c32;
1446+
use super::{imax_all, imin_all, product_nan_all, sum_all, sum_nan_all};
1447+
use crate::randu;
1448+
1449+
// Exercises `sum_all`, `product_nan_all` and `sum_nan_all` on a complex
// (c32) array and on a bool array.
#[test]
1450+
fn all_reduce_api() {
1451+
let a = randu!(c32; 10, 10);
1452+
println!("Reduction of complex f32 matrix: {:?}", sum_all(&a));
1453+
1454+
let b = randu!(bool; 10, 10);
1455+
println!("reduction of bool matrix: {:?}", sum_all(&b));
1456+
1457+
println!(
1458+
"reduction of complex f32 matrix after replacing nan with {}: {:?}",
1459+
1.0,
1460+
product_nan_all(&a, 1.0)
1461+
);
1462+
1463+
println!(
1464+
"reduction of bool matrix after replacing nan with {}: {:?}",
1465+
0.0,
1466+
sum_nan_all(&b, 0.0)
1467+
);
1468+
}
1469+
1470+
// Exercises `imin_all` on a complex (c32) array and `imax_all` on a u32
// array.
#[test]
1471+
fn all_ireduce_api() {
1472+
let a = randu!(c32; 10);
1473+
println!("Reduction of complex f32 matrix: {:?}", imin_all(&a));
1474+
1475+
let b = randu!(u32; 10);
1476+
// `b` holds u32 values, so the label says u32 (it previously said "bool").
println!("reduction of u32 matrix: {:?}", imax_all(&b));
1477+
}
1478+
}

0 commit comments

Comments
 (0)