@@ -144,6 +144,17 @@ extern "C" {
144
144
///
145
145
/// Currently, Array objects can only store data of up to four dimensions
146
146
///
147
+ /// ## Sharing Across Threads
148
+ ///
149
+ /// When sharing an Array with other threads, there is no need to wrap
150
+ /// it in an Arc object unless only one such object is required to exist.
151
+ /// This is because ArrayFire's internal Array is appropriately reference
152
+ /// counted in a thread-safe manner. However, if you need to modify an Array
153
+ /// object, then please wrap it in a Mutex or Read-Write lock.
154
+ ///
155
+ /// Examples of how to share an Array across threads are illustrated in our
156
+ /// [book](http://arrayfire.org/arrayfire-rust/book/multi-threading.html)
157
+ ///
147
158
/// ### NOTE
148
159
///
149
160
/// All operators(traits) from std::ops module implemented for Array object
@@ -156,6 +167,11 @@ pub struct Array<T: HasAfEnum> {
156
167
_marker : PhantomData < T > ,
157
168
}
158
169
170
+ /// Enable safely moving Array objects across threads
171
+ unsafe impl < T : HasAfEnum > Send for Array < T > { }
172
+
173
+ unsafe impl < T : HasAfEnum > Sync for Array < T > { }
174
+
159
175
macro_rules! is_func {
160
176
( $doc_str: expr, $fn_name: ident, $ffi_fn: ident) => (
161
177
#[ doc=$doc_str]
@@ -834,3 +850,236 @@ pub fn is_eval_manual() -> bool {
834
850
ret_val > 0
835
851
}
836
852
}
853
+
854
+ #[ cfg( test) ]
855
+ mod tests {
856
+ use super :: super :: array:: print;
857
+ use super :: super :: data:: constant;
858
+ use super :: super :: device:: { info, set_device, sync} ;
859
+ use crate :: dim4;
860
+ use std:: sync:: { mpsc, Arc , RwLock } ;
861
+ use std:: thread;
862
+
863
/// Creates an `Array` on the main thread, moves it into a spawned thread and
/// updates and prints it there.
#[test]
fn thread_move_array() {
    // ANCHOR: move_array_to_thread
    set_device(0);
    info();
    let mut arr = constant(1, dim4!(3, 3));

    let worker = thread::spawn(move || {
        // Every thread has to select the device it is going to work on.
        set_device(0);

        println!("\n From thread {:?}", thread::current().id());

        arr += constant(2, dim4!(3, 3));
        print(&arr);
    });

    // Need to join other threads as main thread holds arrayfire context
    worker.join().unwrap();
    // ANCHOR_END: move_array_to_thread
}
884
+
885
/// Creates an `Array` on the main thread and hands it to a spawned thread
/// that only reads (prints) it.
#[test]
fn thread_borrow_array() {
    set_device(0);
    info();
    let arr = constant(1i32, dim4!(3, 3));

    // The closure takes ownership of `arr`; it is never modified inside.
    let job = move || {
        // Every thread has to select the device it is going to work on.
        set_device(0);
        println!("\n From thread {:?}", thread::current().id());
        print(&arr);
    };
    let worker = thread::spawn(job);

    // Need to join other threads as main thread holds arrayfire context
    worker.join().unwrap();
}
899
+
900
+ // ANCHOR: multiple_threads_enum_def
901
/// Element-wise binary operation that a worker thread applies to its two
/// input arrays in the multi-threading tests.
#[derive(Clone, Copy, Debug)]
enum Op {
    /// `x + y`
    Add,
    /// `x - y`
    Sub,
    /// `x / y`
    Div,
    /// `x * y`
    Mul,
}
908
+ // ANCHOR_END: multiple_threads_enum_def
909
+
910
/// Shares two read-only arrays with several worker threads, each of which
/// applies one `Op` to its own clones of the inputs.
///
/// A panic inside any worker is propagated to the test when the worker is
/// joined, so a failing worker fails the test instead of being ignored.
#[test]
fn read_from_multiple_threads() {
    // ANCHOR: read_from_multiple_threads
    let ops: Vec<_> = vec![Op::Add, Op::Sub, Op::Div, Op::Mul, Op::Add, Op::Div];

    // Set active GPU/device on main thread on which
    // subsequent Array objects are created
    set_device(0);

    // ArrayFire Arrays are internally maintained via atomic reference counting.
    // Thus, they need no Arc wrapping while moving to another thread.
    // Just call the clone method on the object and share the resulting clone.
    let a = constant(1.0f32, dim4!(3, 3));
    let b = constant(2.0f32, dim4!(3, 3));

    let threads: Vec<_> = ops
        .into_iter()
        .map(|op| {
            let x = a.clone();
            let y = b.clone();
            thread::spawn(move || {
                set_device(0); // Both objects were created on device 0 earlier
                match op {
                    Op::Add => {
                        let _c = x + y;
                    }
                    Op::Sub => {
                        let _c = x - y;
                    }
                    Op::Div => {
                        let _c = x / y;
                    }
                    Op::Mul => {
                        let _c = x * y;
                    }
                }
                sync(0);
                thread::sleep(std::time::Duration::from_secs(1));
            })
        })
        .collect();
    for child in threads {
        // Surface a worker panic instead of silently discarding it.
        child.join().expect("worker thread panicked");
    }
    // ANCHOR_END: read_from_multiple_threads
}
956
+
957
/// Accumulates the results of several worker threads into one shared array
/// that is protected by an `Arc<RwLock<_>>`.
///
/// A poisoned lock or a panicking worker fails the test; neither is skipped
/// silently.
#[test]
fn access_using_rwlock() {
    // ANCHOR: access_using_rwlock
    let ops: Vec<_> = vec![Op::Add, Op::Sub, Op::Div, Op::Mul, Op::Add, Op::Div];

    // Set active GPU/device on main thread on which
    // subsequent Array objects are created
    set_device(0);

    let c = constant(0.0f32, dim4!(3, 3));
    let a = constant(1.0f32, dim4!(3, 3));
    let b = constant(2.0f32, dim4!(3, 3));

    // Move ownership to RwLock and wrap in Arc since same object is to be modified
    let c_lock = Arc::new(RwLock::new(c));

    // a and b are internally reference counted by ArrayFire. Unless there
    // is a prior known need to modify them, you can simply clone the
    // objects and pass the clones to threads.

    let threads: Vec<_> = ops
        .into_iter()
        .map(|op| {
            let x = a.clone();
            let y = b.clone();

            let wlock = Arc::clone(&c_lock);
            thread::spawn(move || {
                // Both objects were created on device 0 in the main thread.
                // Every thread needs to set the device that it is going to
                // work on. Note that, at the time of writing, all Array
                // objects must have been created on the same device.
                set_device(0);

                // Build the operand before taking the lock so the write
                // guard is held only for the accumulation itself.
                let delta = match op {
                    Op::Add => x + y,
                    Op::Sub => x - y,
                    Op::Div => x / y,
                    Op::Mul => x * y,
                };

                // A poisoned lock means another worker panicked while
                // holding it; fail loudly rather than dropping this update.
                let mut c_guard = wlock.write().expect("accumulator lock poisoned");
                *c_guard += delta;
            })
        })
        .collect();

    for child in threads {
        // Surface a worker panic instead of silently discarding it.
        child.join().expect("worker thread panicked");
    }

    //let read_guard = c_lock.read().unwrap();
    //af_print!("C after threads joined", *read_guard);
    //C after threads joined
    //[3 3 1 1]
    //    8.0000     8.0000     8.0000
    //    8.0000     8.0000     8.0000
    //    8.0000     8.0000     8.0000
    // ANCHOR_END: access_using_rwlock
}
1023
+
1024
/// Accumulates the results of several worker threads on the main thread by
/// receiving each partial result over an mpsc channel.
///
/// The main thread drops its own sender once all workers are spawned, so a
/// worker that dies before sending makes `recv` return an error (failing the
/// test) instead of blocking forever.
#[test]
fn accum_using_channel() {
    // ANCHOR: accum_using_channel
    let ops: Vec<_> = vec![Op::Add, Op::Sub, Op::Div, Op::Mul, Op::Add, Op::Div];
    let ops_len: usize = ops.len();

    // Set active GPU/device on main thread on which
    // subsequent Array objects are created
    set_device(0);

    let mut c = constant(0.0f32, dim4!(3, 3));
    let a = constant(1.0f32, dim4!(3, 3));
    let b = constant(2.0f32, dim4!(3, 3));

    let (tx, rx) = mpsc::channel();

    let threads: Vec<_> = ops
        .into_iter()
        .map(|op| {
            // a and b are internally reference counted by ArrayFire. Unless there
            // is a prior known need to modify them, you can simply clone the
            // objects and pass the clones to threads.
            let x = a.clone();
            let y = b.clone();

            let tx_clone = tx.clone();

            thread::spawn(move || {
                // Both objects were created on device 0 in the main thread.
                // Every thread needs to set the device that it is going to
                // work on. Note that, at the time of writing, all Array
                // objects must have been created on the same device.
                set_device(0);

                let partial = match op {
                    Op::Add => x + y,
                    Op::Sub => x - y,
                    Op::Div => x / y,
                    Op::Mul => x * y,
                };
                tx_clone.send(partial).unwrap();
            })
        })
        .collect();

    // Only the workers' clones may keep the channel open; otherwise `recv`
    // below could never observe that a worker exited without sending.
    drop(tx);

    for _ in 0..ops_len {
        c += rx
            .recv()
            .expect("a worker thread exited without sending its result");
    }

    // Need to join other threads as main thread holds arrayfire context
    for child in threads {
        // Surface a worker panic instead of silently discarding it.
        child.join().expect("worker thread panicked");
    }

    //af_print!("C after accumulating results", &c);
    //[3 3 1 1]
    //    8.0000     8.0000     8.0000
    //    8.0000     8.0000     8.0000
    //    8.0000     8.0000     8.0000
    // ANCHOR_END: accum_using_channel
}
1085
+ }
0 commit comments