@@ -518,12 +518,17 @@ where
518
518
}
519
519
520
520
macro_rules! all_reduce_func_def {
521
- ( $doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type : ty ) => {
521
+ ( $doc_str: expr, $fn_name: ident, $ffi_name: ident, $assoc_type : ident ) => {
522
522
#[ doc=$doc_str]
523
- pub fn $fn_name<T >( input: & Array <T >) -> ( $out_type, $out_type)
523
+ pub fn $fn_name<T >( input: & Array <T >)
524
+ -> (
525
+ <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType ,
526
+ <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType
527
+ )
524
528
where
525
529
T : HasAfEnum ,
526
- $out_type: HasAfEnum + Fromf64
530
+ <T as HasAfEnum >:: $assoc_type: HasAfEnum ,
531
+ <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType : HasAfEnum + Fromf64 ,
527
532
{
528
533
let mut real: f64 = 0.0 ;
529
534
let mut imag: f64 = 0.0 ;
@@ -533,7 +538,10 @@ macro_rules! all_reduce_func_def {
533
538
) ;
534
539
HANDLE_ERROR ( AfError :: from( err_val) ) ;
535
540
}
536
- ( <$out_type>:: fromf64( real) , <$out_type>:: fromf64( imag) )
541
+ (
542
+ <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType :: fromf64( real) ,
543
+ <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType :: fromf64( imag) ,
544
+ )
537
545
}
538
546
} ;
539
547
}
@@ -564,7 +572,7 @@ all_reduce_func_def!(
564
572
" ,
565
573
sum_all,
566
574
af_sum_all,
567
- T :: AggregateOutType
575
+ AggregateOutType
568
576
) ;
569
577
570
578
all_reduce_func_def ! (
@@ -594,7 +602,7 @@ all_reduce_func_def!(
594
602
" ,
595
603
product_all,
596
604
af_product_all,
597
- T :: ProductOutType
605
+ ProductOutType
598
606
) ;
599
607
600
608
all_reduce_func_def ! (
@@ -623,7 +631,7 @@ all_reduce_func_def!(
623
631
" ,
624
632
min_all,
625
633
af_min_all,
626
- T :: InType
634
+ InType
627
635
) ;
628
636
629
637
all_reduce_func_def ! (
@@ -652,10 +660,31 @@ all_reduce_func_def!(
652
660
" ,
653
661
max_all,
654
662
af_max_all,
655
- T :: InType
663
+ InType
656
664
) ;
657
665
658
- all_reduce_func_def ! (
666
+ macro_rules! all_reduce_func_def2 {
667
+ ( $doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ty) => {
668
+ #[ doc=$doc_str]
669
+ pub fn $fn_name<T >( input: & Array <T >) -> ( $out_type, $out_type)
670
+ where
671
+ T : HasAfEnum ,
672
+ $out_type: HasAfEnum + Fromf64
673
+ {
674
+ let mut real: f64 = 0.0 ;
675
+ let mut imag: f64 = 0.0 ;
676
+ unsafe {
677
+ let err_val = $ffi_name(
678
+ & mut real as * mut c_double, & mut imag as * mut c_double, input. get( ) ,
679
+ ) ;
680
+ HANDLE_ERROR ( AfError :: from( err_val) ) ;
681
+ }
682
+ ( <$out_type>:: fromf64( real) , <$out_type>:: fromf64( imag) )
683
+ }
684
+ } ;
685
+ }
686
+
687
+ all_reduce_func_def2 ! (
659
688
"
660
689
Find if all values of Array are non-zero
661
690
@@ -682,7 +711,7 @@ all_reduce_func_def!(
682
711
bool
683
712
) ;
684
713
685
- all_reduce_func_def ! (
714
+ all_reduce_func_def2 ! (
686
715
"
687
716
Find if any value of Array is non-zero
688
717
@@ -709,7 +738,7 @@ all_reduce_func_def!(
709
738
bool
710
739
) ;
711
740
712
- all_reduce_func_def ! (
741
+ all_reduce_func_def2 ! (
713
742
"
714
743
Count number of non-zero values in the Array
715
744
@@ -751,10 +780,17 @@ all_reduce_func_def!(
751
780
/// A tuple of summation result.
752
781
///
753
782
/// Note: For non-complex data type Arrays, second value of tuple is zero.
754
- pub fn sum_nan_all < T > ( input : & Array < T > , val : f64 ) -> ( T :: AggregateOutType , T :: AggregateOutType )
783
+ pub fn sum_nan_all < T > (
784
+ input : & Array < T > ,
785
+ val : f64 ,
786
+ ) -> (
787
+ <<T as HasAfEnum >:: AggregateOutType as HasAfEnum >:: BaseType ,
788
+ <<T as HasAfEnum >:: AggregateOutType as HasAfEnum >:: BaseType ,
789
+ )
755
790
where
756
791
T : HasAfEnum ,
757
- T :: AggregateOutType : HasAfEnum + Fromf64 ,
792
+ <T as HasAfEnum >:: AggregateOutType : HasAfEnum ,
793
+ <<T as HasAfEnum >:: AggregateOutType as HasAfEnum >:: BaseType : HasAfEnum + Fromf64 ,
758
794
{
759
795
let mut real: f64 = 0.0 ;
760
796
let mut imag: f64 = 0.0 ;
@@ -768,8 +804,8 @@ where
768
804
HANDLE_ERROR ( AfError :: from ( err_val) ) ;
769
805
}
770
806
(
771
- <T :: AggregateOutType > :: fromf64 ( real) ,
772
- <T :: AggregateOutType > :: fromf64 ( imag) ,
807
+ << T as HasAfEnum > :: AggregateOutType as HasAfEnum > :: BaseType :: fromf64 ( real) ,
808
+ << T as HasAfEnum > :: AggregateOutType as HasAfEnum > :: BaseType :: fromf64 ( imag) ,
773
809
)
774
810
}
775
811
@@ -788,10 +824,17 @@ where
788
824
/// A tuple of product result.
789
825
///
790
826
/// Note: For non-complex data type Arrays, second value of tuple is zero.
791
- pub fn product_nan_all < T > ( input : & Array < T > , val : f64 ) -> ( T :: ProductOutType , T :: ProductOutType )
827
+ pub fn product_nan_all < T > (
828
+ input : & Array < T > ,
829
+ val : f64 ,
830
+ ) -> (
831
+ <<T as HasAfEnum >:: ProductOutType as HasAfEnum >:: BaseType ,
832
+ <<T as HasAfEnum >:: ProductOutType as HasAfEnum >:: BaseType ,
833
+ )
792
834
where
793
835
T : HasAfEnum ,
794
- T :: ProductOutType : HasAfEnum + Fromf64 ,
836
+ <T as HasAfEnum >:: ProductOutType : HasAfEnum ,
837
+ <<T as HasAfEnum >:: ProductOutType as HasAfEnum >:: BaseType : HasAfEnum + Fromf64 ,
795
838
{
796
839
let mut real: f64 = 0.0 ;
797
840
let mut imag: f64 = 0.0 ;
@@ -805,8 +848,8 @@ where
805
848
HANDLE_ERROR ( AfError :: from ( err_val) ) ;
806
849
}
807
850
(
808
- <T :: ProductOutType > :: fromf64 ( real) ,
809
- <T :: ProductOutType > :: fromf64 ( imag) ,
851
+ << T as HasAfEnum > :: ProductOutType as HasAfEnum > :: BaseType :: fromf64 ( real) ,
852
+ << T as HasAfEnum > :: ProductOutType as HasAfEnum > :: BaseType :: fromf64 ( imag) ,
810
853
)
811
854
}
812
855
@@ -858,12 +901,18 @@ dim_ireduce_func_def!("
858
901
" , imax, af_imax, InType ) ;
859
902
860
903
macro_rules! all_ireduce_func_def {
861
- ( $doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type : ty ) => {
904
+ ( $doc_str: expr, $fn_name: ident, $ffi_name: ident, $assoc_type : ident ) => {
862
905
#[ doc=$doc_str]
863
- pub fn $fn_name<T >( input: & Array <T >) -> ( $out_type, $out_type, u32 )
906
+ pub fn $fn_name<T >( input: & Array <T >)
907
+ -> (
908
+ <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType ,
909
+ <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType ,
910
+ u32
911
+ )
864
912
where
865
913
T : HasAfEnum ,
866
- $out_type: HasAfEnum + Fromf64
914
+ <T as HasAfEnum >:: $assoc_type: HasAfEnum ,
915
+ <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType : HasAfEnum + Fromf64 ,
867
916
{
868
917
let mut real: f64 = 0.0 ;
869
918
let mut imag: f64 = 0.0 ;
@@ -875,7 +924,11 @@ macro_rules! all_ireduce_func_def {
875
924
) ;
876
925
HANDLE_ERROR ( AfError :: from( err_val) ) ;
877
926
}
878
- ( <$out_type>:: fromf64( real) , <$out_type>:: fromf64( imag) , temp)
927
+ (
928
+ <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType :: fromf64( real) ,
929
+ <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType :: fromf64( imag) ,
930
+ temp,
931
+ )
879
932
}
880
933
} ;
881
934
}
@@ -898,7 +951,7 @@ all_ireduce_func_def!(
898
951
" ,
899
952
imin_all,
900
953
af_imin_all,
901
- T :: InType
954
+ InType
902
955
) ;
903
956
all_ireduce_func_def ! (
904
957
"
@@ -918,7 +971,7 @@ all_ireduce_func_def!(
918
971
" ,
919
972
imax_all,
920
973
af_imax_all,
921
- T :: InType
974
+ InType
922
975
) ;
923
976
924
977
/// Locate the indices of non-zero elements.
@@ -1386,3 +1439,40 @@ dim_reduce_by_key_nan_func_def!(
1386
1439
af_product_by_key_nan,
1387
1440
ValueType :: ProductOutType
1388
1441
) ;
1442
+
1443
+ #[ cfg( test) ]
1444
+ mod tests {
1445
+ use super :: super :: core:: c32;
1446
+ use super :: { imax_all, imin_all, product_nan_all, sum_all, sum_nan_all} ;
1447
+ use crate :: randu;
1448
+
1449
+ #[ test]
1450
+ fn all_reduce_api ( ) {
1451
+ let a = randu ! ( c32; 10 , 10 ) ;
1452
+ println ! ( "Reduction of complex f32 matrix: {:?}" , sum_all( & a) ) ;
1453
+
1454
+ let b = randu ! ( bool ; 10 , 10 ) ;
1455
+ println ! ( "reduction of bool matrix: {:?}" , sum_all( & b) ) ;
1456
+
1457
+ println ! (
1458
+ "reduction of complex f32 matrix after replacing nan with {}: {:?}" ,
1459
+ 1.0 ,
1460
+ product_nan_all( & a, 1.0 )
1461
+ ) ;
1462
+
1463
+ println ! (
1464
+ "reduction of bool matrix after replacing nan with {}: {:?}" ,
1465
+ 0.0 ,
1466
+ sum_nan_all( & b, 0.0 )
1467
+ ) ;
1468
+ }
1469
+
1470
+ #[ test]
1471
+ fn all_ireduce_api ( ) {
1472
+ let a = randu ! ( c32; 10 ) ;
1473
+ println ! ( "Reduction of complex f32 matrix: {:?}" , imin_all( & a) ) ;
1474
+
1475
+ let b = randu ! ( u32 ; 10 ) ;
1476
+ println ! ( "reduction of bool matrix: {:?}" , imax_all( & b) ) ;
1477
+ }
1478
+ }
0 commit comments