Skip to content

Commit

Permalink
add unsafe user-defined reduction operations
Browse files Browse the repository at this point in the history
  • Loading branch information
bsteinb committed Feb 8, 2018
1 parent 7ca36f8 commit c08dca8
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 2 deletions.
29 changes: 28 additions & 1 deletion examples/immediate_reduce.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
#![deny(warnings)]
extern crate mpi;

use std::os::raw::{c_int, c_void};

use mpi::traits::*;
use mpi::topology::Rank;
use mpi::collective::SystemOperation;
use mpi::collective::{SystemOperation, UnsafeUserOperation};
#[cfg(feature = "user-operations")]
use mpi::collective::UserOperation;
use mpi::ffi::MPI_Datatype;

#[cfg(feature = "user-operations")]
fn test_user_operations<C: Communicator>(comm: C) {
Expand All @@ -29,6 +32,21 @@ fn test_user_operations<C: Communicator>(comm: C) {
#[cfg(not(feature = "user-operations"))]
fn test_user_operations<C: Communicator>(_: C) {}

unsafe extern "C" fn unsafe_add(
invec: *mut c_void,
inoutvec: *mut c_void,
len: *mut c_int,
_datatype: *mut MPI_Datatype,
) {
use std::slice;

let x: &[Rank] = slice::from_raw_parts(invec as *const Rank, *len as usize);
let y: &mut [Rank] = slice::from_raw_parts_mut(inoutvec as *mut Rank, *len as usize);
for (&x_i, y_i) in x.iter().zip(y) {
*y_i += x_i;
}
}

fn main() {
let universe = mpi::initialize().unwrap();
let world = universe.world();
Expand Down Expand Up @@ -74,4 +92,13 @@ fn main() {
assert_eq!(b, rank.pow(size as u32));

test_user_operations(universe.world());

let mut d = 0;
let op = unsafe { UnsafeUserOperation::commutative(unsafe_add) };
mpi::request::scope(|scope| {
world
.immediate_all_reduce_into(scope, &rank, &mut d, &op)
.wait();
});
assert_eq!(d, size * (size - 1) / 2);
}
25 changes: 24 additions & 1 deletion examples/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
#![cfg_attr(feature = "cargo-clippy", allow(many_single_char_names))]
extern crate mpi;

use std::os::raw::{c_int, c_void};

use mpi::traits::*;
use mpi::topology::Rank;
use mpi::collective::{self, SystemOperation};
use mpi::collective::{self, SystemOperation, UnsafeUserOperation};
#[cfg(feature = "user-operations")]
use mpi::collective::UserOperation;
use mpi::ffi::MPI_Datatype;

#[cfg(feature = "user-operations")]
fn test_user_operations<C: Communicator>(comm: C) {
Expand All @@ -30,6 +33,21 @@ fn test_user_operations<C: Communicator>(comm: C) {
#[cfg(not(feature = "user-operations"))]
fn test_user_operations<C: Communicator>(_: C) {}

unsafe extern "C" fn unsafe_add(
invec: *mut c_void,
inoutvec: *mut c_void,
len: *mut c_int,
_datatype: *mut MPI_Datatype,
) {
use std::slice;

let x: &[Rank] = slice::from_raw_parts(invec as *const Rank, *len as usize);
let y: &mut [Rank] = slice::from_raw_parts_mut(inoutvec as *mut Rank, *len as usize);
for (&x_i, y_i) in x.iter().zip(y) {
*y_i += x_i;
}
}

fn main() {
let universe = mpi::initialize().unwrap();
let world = universe.world();
Expand Down Expand Up @@ -76,4 +94,9 @@ fn main() {
assert_eq!(g, rank.pow(size as u32));

test_user_operations(universe.world());

let mut i = 0;
let op = unsafe { UnsafeUserOperation::commutative(unsafe_add) };
world.all_reduce_into(&(rank + 1), &mut i, &op);
assert_eq!(i, size * (size + 1) / 2);
}
80 changes: 80 additions & 0 deletions src/collective.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1663,6 +1663,86 @@ unsafe fn user_operation_landing_pad<F>(
);
}

/// An unsafe user-defined operation.
///
/// Unsafe user-defined operations are created from pointers to functions that have the unsafe
/// signatures of user functions defined in the MPI C bindings, `UnsafeUserFunction`.
///
/// The recommended way to create user-defined operations is through the safer `UserOperation`
/// type. This type can be used as a work-around in situations where the `libffi` dependency is not
/// available.
pub struct UnsafeUserOperation {
op: MPI_Op,
}

impl fmt::Debug for UnsafeUserOperation {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_tuple("UnsafeUserOperation")
.field(&self.op)
.finish()
}
}

impl Drop for UnsafeUserOperation {
fn drop(&mut self) {
unsafe {
ffi::MPI_Op_free(&mut self.op);
}
}
}

unsafe impl AsRaw for UnsafeUserOperation {
type Raw = MPI_Op;
fn as_raw(&self) -> Self::Raw {
self.op
}
}

impl<'a> Operation for &'a UnsafeUserOperation {}

/// A raw pointer to a function that can be used to define an `UnsafeUserOperation`.
pub type UnsafeUserFunction =
unsafe extern "C" fn(*mut c_void, *mut c_void, *mut c_int, *mut ffi::MPI_Datatype);

impl UnsafeUserOperation {
/// Define an unsafe operation using a function pointer. The operation must be associative.
///
/// This is a more readable shorthand for the `new` method. Refer to [`new`](#method.new) for
/// more information.
pub unsafe fn associative(function: UnsafeUserFunction) -> Self {
Self::new(false, function)
}

/// Define an unsafe operation using a function pointer. The operation must be both associative
/// and commutative.
///
/// This is a more readable shorthand for the `new` method. Refer to [`new`](#method.new) for
/// more information.
pub unsafe fn commutative(function: UnsafeUserFunction) -> Self {
Self::new(true, function)
}

/// Creates an associative and possibly commutative unsafe operation using a function pointer.
///
/// The function receives raw `*mut c_void` as `invec` and `inoutvec` and the number of elemnts
/// of those two vectors as a `*mut c_int` `len`. It shall set `inoutvec`
/// to the value of `f(invec, inoutvec)`, where `f` is a binary associative operation.
///
/// If the operation is also commutative, setting `commute` to `true` may yield performance
/// benefits.
///
/// **Note:** The user function is not allowed to panic.
///
/// # Standard section(s)
///
/// 5.9.5
pub unsafe fn new(commute: bool, function: UnsafeUserFunction) -> Self {
let mut op = mem::uninitialized();
ffi::MPI_Op_create(Some(function), commute as _, &mut op);
UnsafeUserOperation { op }
}
}

/// Perform a local reduction.
///
/// # Examples
Expand Down

0 comments on commit c08dca8

Please sign in to comment.