Skip to content

Commit f3b6c03

Browse files
committed
Enable thread-safety marker traits for structs
- Array: Send, Sync - Features: Send, Sync - Event: Send - RandomEngine: Send - Indexer: Send Added a new threading tutorial with code examples illustrating how to share Array across threads. Added unit tests in corresponding modules
1 parent 653ef75 commit f3b6c03

File tree

7 files changed

+421
-0
lines changed

7 files changed

+421
-0
lines changed

src/core/array.rs

+249
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,17 @@ extern "C" {
144144
///
145145
/// Currently, Array objects can store only data until four dimensions
146146
///
147+
/// ## Sharing Across Threads
148+
///
149+
/// While sharing an Array with other threads, there is no need to wrap
150+
/// this in an Arc object unless only one such object is required to exist.
151+
/// The reason being that ArrayFire's internal Array is appropriately reference
152+
/// counted in thread safe manner. However, if you need to modify Array object,
153+
/// then please do wrap the object using a Mutex or Read-Write lock.
154+
///
155+
/// Examples on how to share Array across threads is illustrated in our
156+
/// [book](http://arrayfire.org/arrayfire-rust/book/multi-threading.html)
157+
///
147158
/// ### NOTE
148159
///
149160
/// All operators(traits) from std::ops module implemented for Array object
@@ -156,6 +167,11 @@ pub struct Array<T: HasAfEnum> {
156167
_marker: PhantomData<T>,
157168
}
158169

170+
/// Enable safely moving Array objects across threads
171+
unsafe impl<T: HasAfEnum> Send for Array<T> {}
172+
173+
unsafe impl<T: HasAfEnum> Sync for Array<T> {}
174+
159175
macro_rules! is_func {
160176
($doc_str: expr, $fn_name: ident, $ffi_fn: ident) => (
161177
#[doc=$doc_str]
@@ -834,3 +850,236 @@ pub fn is_eval_manual() -> bool {
834850
ret_val > 0
835851
}
836852
}
853+
854+
#[cfg(test)]
855+
mod tests {
856+
use super::super::array::print;
857+
use super::super::data::constant;
858+
use super::super::device::{info, set_device, sync};
859+
use crate::dim4;
860+
use std::sync::{mpsc, Arc, RwLock};
861+
use std::thread;
862+
863+
#[test]
864+
fn thread_move_array() {
865+
// ANCHOR: move_array_to_thread
866+
set_device(0);
867+
info();
868+
let mut a = constant(1, dim4!(3, 3));
869+
870+
let handle = thread::spawn(move || {
871+
//set_device to appropriate device id is required in each thread
872+
set_device(0);
873+
874+
println!("\nFrom thread {:?}", thread::current().id());
875+
876+
a += constant(2, dim4!(3, 3));
877+
print(&a);
878+
});
879+
880+
//Need to join other threads as main thread holds arrayfire context
881+
handle.join().unwrap();
882+
// ANCHOR_END: move_array_to_thread
883+
}
884+
885+
#[test]
886+
fn thread_borrow_array() {
887+
set_device(0);
888+
info();
889+
let a = constant(1i32, dim4!(3, 3));
890+
891+
let handle = thread::spawn(move || {
892+
set_device(0); //set_device to appropriate device id is required in each thread
893+
println!("\nFrom thread {:?}", thread::current().id());
894+
print(&a);
895+
});
896+
//Need to join other threads as main thread holds arrayfire context
897+
handle.join().unwrap();
898+
}
899+
900+
// ANCHOR: multiple_threads_enum_def
901+
#[derive(Debug, Copy, Clone)]
902+
enum Op {
903+
Add,
904+
Sub,
905+
Div,
906+
Mul,
907+
}
908+
// ANCHOR_END: multiple_threads_enum_def
909+
910+
#[test]
911+
fn read_from_multiple_threads() {
912+
// ANCHOR: read_from_multiple_threads
913+
let ops: Vec<_> = vec![Op::Add, Op::Sub, Op::Div, Op::Mul, Op::Add, Op::Div];
914+
915+
// Set active GPU/device on main thread on which
916+
// subsequent Array objects are created
917+
set_device(0);
918+
919+
// ArrayFire Array's are internally maintained via atomic reference counting
920+
// Thus, they need no Arc wrapping while moving to another thread.
921+
// Just call clone method on the object and share the resulting clone object
922+
let a = constant(1.0f32, dim4!(3, 3));
923+
let b = constant(2.0f32, dim4!(3, 3));
924+
925+
let threads: Vec<_> = ops
926+
.into_iter()
927+
.map(|op| {
928+
let x = a.clone();
929+
let y = b.clone();
930+
thread::spawn(move || {
931+
set_device(0); //Both of objects are created on device 0 earlier
932+
match op {
933+
Op::Add => {
934+
let _c = x + y;
935+
}
936+
Op::Sub => {
937+
let _c = x - y;
938+
}
939+
Op::Div => {
940+
let _c = x / y;
941+
}
942+
Op::Mul => {
943+
let _c = x * y;
944+
}
945+
}
946+
sync(0);
947+
thread::sleep(std::time::Duration::new(1, 0));
948+
})
949+
})
950+
.collect();
951+
for child in threads {
952+
let _ = child.join();
953+
}
954+
// ANCHOR_END: read_from_multiple_threads
955+
}
956+
957+
#[test]
958+
fn access_using_rwlock() {
959+
// ANCHOR: access_using_rwlock
960+
let ops: Vec<_> = vec![Op::Add, Op::Sub, Op::Div, Op::Mul, Op::Add, Op::Div];
961+
962+
// Set active GPU/device on main thread on which
963+
// subsequent Array objects are created
964+
set_device(0);
965+
966+
let c = constant(0.0f32, dim4!(3, 3));
967+
let a = constant(1.0f32, dim4!(3, 3));
968+
let b = constant(2.0f32, dim4!(3, 3));
969+
970+
// Move ownership to RwLock and wrap in Arc since same object is to be modified
971+
let c_lock = Arc::new(RwLock::new(c));
972+
973+
// a and b are internally reference counted by ArrayFire. Unless there
974+
// is prior known need that they may be modified, you can simply clone
975+
// the objects pass them to threads
976+
977+
let threads: Vec<_> = ops
978+
.into_iter()
979+
.map(|op| {
980+
let x = a.clone();
981+
let y = b.clone();
982+
983+
let wlock = c_lock.clone();
984+
thread::spawn(move || {
985+
//Both of objects are created on device 0 in main thread
986+
//Every thread needs to set the device that it is going to
987+
//work on. Note that all Array objects must have been created
988+
//on same device as of date this is written on.
989+
set_device(0);
990+
if let Ok(mut c_guard) = wlock.write() {
991+
match op {
992+
Op::Add => {
993+
*c_guard += x + y;
994+
}
995+
Op::Sub => {
996+
*c_guard += x - y;
997+
}
998+
Op::Div => {
999+
*c_guard += x / y;
1000+
}
1001+
Op::Mul => {
1002+
*c_guard += x * y;
1003+
}
1004+
}
1005+
}
1006+
})
1007+
})
1008+
.collect();
1009+
1010+
for child in threads {
1011+
let _ = child.join();
1012+
}
1013+
1014+
//let read_guard = c_lock.read().unwrap();
1015+
//af_print!("C after threads joined", *read_guard);
1016+
//C after threads joined
1017+
//[3 3 1 1]
1018+
// 8.0000 8.0000 8.0000
1019+
// 8.0000 8.0000 8.0000
1020+
// 8.0000 8.0000 8.0000
1021+
// ANCHOR_END: access_using_rwlock
1022+
}
1023+
1024+
#[test]
1025+
fn accum_using_channel() {
1026+
// ANCHOR: accum_using_channel
1027+
let ops: Vec<_> = vec![Op::Add, Op::Sub, Op::Div, Op::Mul, Op::Add, Op::Div];
1028+
let ops_len: usize = ops.len();
1029+
1030+
// Set active GPU/device on main thread on which
1031+
// subsequent Array objects are created
1032+
set_device(0);
1033+
1034+
let mut c = constant(0.0f32, dim4!(3, 3));
1035+
let a = constant(1.0f32, dim4!(3, 3));
1036+
let b = constant(2.0f32, dim4!(3, 3));
1037+
1038+
let (tx, rx) = mpsc::channel();
1039+
1040+
let threads: Vec<_> = ops
1041+
.into_iter()
1042+
.map(|op| {
1043+
// a and b are internally reference counted by ArrayFire. Unless there
1044+
// is prior known need that they may be modified, you can simply clone
1045+
// the objects pass them to threads
1046+
let x = a.clone();
1047+
let y = b.clone();
1048+
1049+
let tx_clone = tx.clone();
1050+
1051+
thread::spawn(move || {
1052+
//Both of objects are created on device 0 in main thread
1053+
//Every thread needs to set the device that it is going to
1054+
//work on. Note that all Array objects must have been created
1055+
//on same device as of date this is written on.
1056+
set_device(0);
1057+
1058+
let c = match op {
1059+
Op::Add => x + y,
1060+
Op::Sub => x - y,
1061+
Op::Div => x / y,
1062+
Op::Mul => x * y,
1063+
};
1064+
tx_clone.send(c).unwrap();
1065+
})
1066+
})
1067+
.collect();
1068+
1069+
for _i in 0..ops_len {
1070+
c += rx.recv().unwrap();
1071+
}
1072+
1073+
//Need to join other threads as main thread holds arrayfire context
1074+
for child in threads {
1075+
let _ = child.join();
1076+
}
1077+
1078+
//af_print!("C after accumulating results", &c);
1079+
//[3 3 1 1]
1080+
// 8.0000 8.0000 8.0000
1081+
// 8.0000 8.0000 8.0000
1082+
// 8.0000 8.0000 8.0000
1083+
// ANCHOR_END: accum_using_channel
1084+
}
1085+
}

src/core/event.rs

+85
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,17 @@ extern "C" {
1414
}
1515

1616
/// RAII construct to manage ArrayFire events
17+
///
18+
/// ## Sharing Across Threads
19+
///
20+
/// While sharing an Event with other threads, just move it across threads.
1721
pub struct Event {
1822
event_handle: af_event,
1923
}
2024

25+
unsafe impl Send for Event {}
26+
// No borrowed references are to be shared for Events, hence no sync trait
27+
2128
impl Default for Event {
2229
fn default() -> Self {
2330
let mut temp: af_event = std::ptr::null_mut();
@@ -74,3 +81,81 @@ impl Drop for Event {
7481
}
7582
}
7683
}
84+
85+
#[cfg(test)]
86+
mod tests {
87+
use super::super::arith::pow;
88+
use super::super::device::{info, set_device};
89+
use super::super::event::Event;
90+
use crate::{af_print, randu};
91+
use std::sync::mpsc;
92+
use std::thread;
93+
94+
#[test]
95+
fn event_block() {
96+
// This code example will try to compute the following expression
97+
// using data-graph approach using threads, evens for illustration.
98+
//
99+
// (a * (b + c))^(d - 2)
100+
//
101+
// ANCHOR: event_block
102+
103+
// Set active GPU/device on main thread on which
104+
// subsequent Array objects are created
105+
set_device(0);
106+
info();
107+
108+
let a = randu!(10, 10);
109+
let b = randu!(10, 10);
110+
let c = randu!(10, 10);
111+
let d = randu!(10, 10);
112+
113+
let (tx, rx) = mpsc::channel();
114+
115+
thread::spawn(move || {
116+
set_device(0);
117+
118+
let add_event = Event::default();
119+
120+
let add_res = b + c;
121+
122+
add_event.mark();
123+
tx.send((add_res, add_event)).unwrap();
124+
125+
thread::sleep(std::time::Duration::new(10, 0));
126+
127+
let sub_event = Event::default();
128+
129+
let sub_res = d - 2;
130+
131+
sub_event.mark();
132+
tx.send((sub_res, sub_event)).unwrap();
133+
});
134+
135+
let (add_res, add_event) = rx.recv().unwrap();
136+
137+
println!("Got first message, waiting for addition result ...");
138+
thread::sleep(std::time::Duration::new(5, 0));
139+
// Perhaps, do some other tasks
140+
add_event.block();
141+
142+
println!("Got addition result, now scaling it ... ");
143+
let scaled = a * add_res;
144+
145+
let (sub_res, sub_event) = rx.recv().unwrap();
146+
147+
println!("Got message, waiting for subtraction result ...");
148+
thread::sleep(std::time::Duration::new(5, 0));
149+
// Perhaps, do some other tasks
150+
sub_event.block();
151+
152+
let fin_res = pow(&scaled, &sub_res, false);
153+
154+
af_print!(
155+
"Final result of the expression: ((a * (b + c))^(d - 2))",
156+
&fin_res
157+
);
158+
159+
// ANCHOR_END: event_block
160+
}
161+
}

src/core/index.rs

+7
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ extern "C" {
5151

5252
/// Struct to manage an array of resources of type `af_indexer_t`(ArrayFire C struct)
5353
///
54+
/// ## Sharing Across Threads
55+
///
56+
/// While sharing an Indexer object with other threads, just move it across threads. At the
57+
/// moment, one cannot share borrowed references across threads.
58+
///
5459
/// # Examples
5560
///
5661
/// Given below are examples illustrating correct and incorrect usage of Indexer struct.
@@ -108,6 +113,8 @@ pub struct Indexer<'object> {
108113
marker: PhantomData<&'object ()>,
109114
}
110115

116+
unsafe impl<'object> Send for Indexer<'object> {}
117+
111118
/// Trait bound indicating indexability
112119
///
113120
/// Any object to be able to be passed on to [Indexer::set_index()](./struct.Indexer.html#method.set_index) method should implement this trait with appropriate implementation of `set` method.

0 commit comments

Comments
 (0)