@@ -527,11 +527,12 @@ where
527
527
macro_rules! all_reduce_func_def {
528
528
( $doc_str: expr, $fn_name: ident, $ffi_name: ident, $assoc_type: ident) => {
529
529
#[ doc=$doc_str]
530
- pub fn $fn_name<T >( input: & Array <T >)
531
- -> (
532
- <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType ,
533
- <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType
534
- )
530
+ pub fn $fn_name<T >(
531
+ input: & Array <T >,
532
+ ) -> (
533
+ <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType ,
534
+ <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType ,
535
+ )
535
536
where
536
537
T : HasAfEnum ,
537
538
<T as HasAfEnum >:: $assoc_type: HasAfEnum ,
@@ -541,7 +542,9 @@ macro_rules! all_reduce_func_def {
541
542
let mut imag: f64 = 0.0 ;
542
543
unsafe {
543
544
let err_val = $ffi_name(
544
- & mut real as * mut c_double, & mut imag as * mut c_double, input. get( ) ,
545
+ & mut real as * mut c_double,
546
+ & mut imag as * mut c_double,
547
+ input. get( ) ,
545
548
) ;
546
549
HANDLE_ERROR ( AfError :: from( err_val) ) ;
547
550
}
@@ -676,13 +679,15 @@ macro_rules! all_reduce_func_def2 {
676
679
pub fn $fn_name<T >( input: & Array <T >) -> ( $out_type, $out_type)
677
680
where
678
681
T : HasAfEnum ,
679
- $out_type: HasAfEnum + Fromf64
682
+ $out_type: HasAfEnum + Fromf64 ,
680
683
{
681
684
let mut real: f64 = 0.0 ;
682
685
let mut imag: f64 = 0.0 ;
683
686
unsafe {
684
687
let err_val = $ffi_name(
685
- & mut real as * mut c_double, & mut imag as * mut c_double, input. get( ) ,
688
+ & mut real as * mut c_double,
689
+ & mut imag as * mut c_double,
690
+ input. get( ) ,
686
691
) ;
687
692
HANDLE_ERROR ( AfError :: from( err_val) ) ;
688
693
}
@@ -869,13 +874,16 @@ macro_rules! dim_ireduce_func_def {
869
874
T :: $out_type: HasAfEnum ,
870
875
{
871
876
unsafe {
872
- let mut temp: af_array = std:: ptr:: null_mut( ) ;
873
- let mut idx: af_array = std:: ptr:: null_mut( ) ;
877
+ let mut temp: af_array = std:: ptr:: null_mut( ) ;
878
+ let mut idx: af_array = std:: ptr:: null_mut( ) ;
874
879
let err_val = $ffi_name(
875
- & mut temp as * mut af_array, & mut idx as * mut af_array, input. get( ) , dim,
880
+ & mut temp as * mut af_array,
881
+ & mut idx as * mut af_array,
882
+ input. get( ) ,
883
+ dim,
876
884
) ;
877
885
HANDLE_ERROR ( AfError :: from( err_val) ) ;
878
- ( temp. into( ) , idx. into( ) )
886
+ ( temp. into( ) , idx. into( ) )
879
887
}
880
888
}
881
889
} ;
@@ -910,12 +918,13 @@ dim_ireduce_func_def!("
910
918
macro_rules! all_ireduce_func_def {
911
919
( $doc_str: expr, $fn_name: ident, $ffi_name: ident, $assoc_type: ident) => {
912
920
#[ doc=$doc_str]
913
- pub fn $fn_name<T >( input: & Array <T >)
914
- -> (
915
- <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType ,
916
- <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType ,
917
- u32
918
- )
921
+ pub fn $fn_name<T >(
922
+ input: & Array <T >,
923
+ ) -> (
924
+ <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType ,
925
+ <<T as HasAfEnum >:: $assoc_type as HasAfEnum >:: BaseType ,
926
+ u32 ,
927
+ )
919
928
where
920
929
T : HasAfEnum ,
921
930
<T as HasAfEnum >:: $assoc_type: HasAfEnum ,
@@ -926,8 +935,10 @@ macro_rules! all_ireduce_func_def {
926
935
let mut temp: u32 = 0 ;
927
936
unsafe {
928
937
let err_val = $ffi_name(
929
- & mut real as * mut c_double, & mut imag as * mut c_double,
930
- & mut temp as * mut c_uint, input. get( ) ,
938
+ & mut real as * mut c_double,
939
+ & mut imag as * mut c_double,
940
+ & mut temp as * mut c_uint,
941
+ input. get( ) ,
931
942
) ;
932
943
HANDLE_ERROR ( AfError :: from( err_val) ) ;
933
944
}
@@ -1277,23 +1288,28 @@ macro_rules! dim_reduce_by_key_func_def {
1277
1288
/// Tuple of Arrays, with output keys and values after reduction
1278
1289
///
1279
1290
#[ doc=$ex_str]
1280
- pub fn $fn_name<KeyType , ValueType >( keys: & Array <KeyType >, vals: & Array <ValueType >,
1281
- dim: i32
1291
+ pub fn $fn_name<KeyType , ValueType >(
1292
+ keys: & Array <KeyType >,
1293
+ vals: & Array <ValueType >,
1294
+ dim: i32 ,
1282
1295
) -> ( Array <KeyType >, Array <$out_type>)
1283
1296
where
1284
1297
KeyType : ReduceByKeyInput ,
1285
1298
ValueType : HasAfEnum ,
1286
1299
$out_type: HasAfEnum ,
1287
1300
{
1288
1301
unsafe {
1289
- let mut out_keys: af_array = std:: ptr:: null_mut( ) ;
1290
- let mut out_vals: af_array = std:: ptr:: null_mut( ) ;
1302
+ let mut out_keys: af_array = std:: ptr:: null_mut( ) ;
1303
+ let mut out_vals: af_array = std:: ptr:: null_mut( ) ;
1291
1304
let err_val = $ffi_name(
1292
- & mut out_keys as * mut af_array, & mut out_vals as * mut af_array,
1293
- keys. get( ) , vals. get( ) , dim,
1305
+ & mut out_keys as * mut af_array,
1306
+ & mut out_vals as * mut af_array,
1307
+ keys. get( ) ,
1308
+ vals. get( ) ,
1309
+ dim,
1294
1310
) ;
1295
1311
HANDLE_ERROR ( AfError :: from( err_val) ) ;
1296
- ( out_keys. into( ) , out_vals. into( ) )
1312
+ ( out_keys. into( ) , out_vals. into( ) )
1297
1313
}
1298
1314
}
1299
1315
} ;
@@ -1408,24 +1424,30 @@ macro_rules! dim_reduce_by_key_nan_func_def {
1408
1424
/// Tuple of Arrays, with output keys and values after reduction
1409
1425
///
1410
1426
#[ doc=$ex_str]
1411
- pub fn $fn_name<KeyType , ValueType >( keys: & Array <KeyType >, vals: & Array <ValueType >,
1412
- dim: i32 , replace_value: f64
1427
+ pub fn $fn_name<KeyType , ValueType >(
1428
+ keys: & Array <KeyType >,
1429
+ vals: & Array <ValueType >,
1430
+ dim: i32 ,
1431
+ replace_value: f64 ,
1413
1432
) -> ( Array <KeyType >, Array <$out_type>)
1414
1433
where
1415
1434
KeyType : ReduceByKeyInput ,
1416
1435
ValueType : HasAfEnum ,
1417
1436
$out_type: HasAfEnum ,
1418
1437
{
1419
1438
unsafe {
1420
- let mut out_keys: af_array = std:: ptr:: null_mut( ) ;
1421
- let mut out_vals: af_array = std:: ptr:: null_mut( ) ;
1439
+ let mut out_keys: af_array = std:: ptr:: null_mut( ) ;
1440
+ let mut out_vals: af_array = std:: ptr:: null_mut( ) ;
1422
1441
let err_val = $ffi_name(
1423
1442
& mut out_keys as * mut af_array,
1424
1443
& mut out_vals as * mut af_array,
1425
- keys. get( ) , vals. get( ) , dim, replace_value,
1444
+ keys. get( ) ,
1445
+ vals. get( ) ,
1446
+ dim,
1447
+ replace_value,
1426
1448
) ;
1427
1449
HANDLE_ERROR ( AfError :: from( err_val) ) ;
1428
- ( out_keys. into( ) , out_vals. into( ) )
1450
+ ( out_keys. into( ) , out_vals. into( ) )
1429
1451
}
1430
1452
}
1431
1453
} ;
0 commit comments