Skip to content
This repository
tree: e914c03997
Fetching contributors…

Cannot retrieve contributors at this time

file 91 lines (75 sloc) 5.516 kb
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
#include <unittest/unittest.h>
#include <thrust/functional.h>
#include <thrust/transform.h>

#include <functional>
#include <algorithm>
    
// Length of every operand and result vector used by TestBinaryFunctional.
const size_t NUM_SAMPLES = 10000;

template <class InputVector, class OutputVector, class Operator, class ReferenceOperator>
void TestBinaryFunctional(void)
{
    typedef typename InputVector::value_type InputType;
    typedef typename OutputVector::value_type OutputType;
    
    thrust::host_vector<InputType> std_input1 = unittest::random_samples<InputType>(NUM_SAMPLES);
    thrust::host_vector<InputType> std_input2 = unittest::random_samples<InputType>(NUM_SAMPLES);
    thrust::host_vector<OutputType> std_output(NUM_SAMPLES);

    // Replace zeros to avoid divide by zero exceptions
    std::replace(std_input2.begin(), std_input2.end(), (InputType) 0, (InputType) 1);

    InputVector input1 = std_input1;
    InputVector input2 = std_input2;
    OutputVector output(NUM_SAMPLES);

    thrust::transform( input1.begin(), input1.end(), input2.begin(), output.begin(), Operator());
    thrust::transform(std_input1.begin(), std_input1.end(), std_input2.begin(), std_output.begin(), ReferenceOperator());

    // Note: FP division is not bit-equal, even when nvcc is invoked with --prec-div
    ASSERT_ALMOST_EQUAL(output, std_output);
}


// TODO: bool is not part of this list yet.
//
// Expands to APPLY(vector_kind, functor_name, T) once per integer-like type T,
// in this order: char, unsigned char, short, unsigned short, int, unsigned int,
// long, unsigned long. The first two arguments are forwarded untouched.
#define INSTANTIATE_INTEGER_TYPES(APPLY, vector_kind, functor_name) \
    APPLY(vector_kind, functor_name, char)                          \
    APPLY(vector_kind, functor_name, unsigned char)                 \
    APPLY(vector_kind, functor_name, short)                         \
    APPLY(vector_kind, functor_name, unsigned short)                \
    APPLY(vector_kind, functor_name, int)                           \
    APPLY(vector_kind, functor_name, unsigned int)                  \
    APPLY(vector_kind, functor_name, long)                          \
    APPLY(vector_kind, functor_name, unsigned long)

// Expands to APPLY(vector_kind, functor_name, T) for every type produced by
// INSTANTIATE_INTEGER_TYPES, followed by one more expansion with T = float.
#define INSTANTIATE_ALL_TYPES(APPLY, vector_kind, functor_name)     \
    INSTANTIATE_INTEGER_TYPES(APPLY, vector_kind, functor_name)     \
    APPLY(vector_kind, functor_name, float)


// op(T,T) -> T
//
// Expands to a single call statement (terminating ';' included) of
// TestBinaryFunctional in which
//   - the input and output containers are both thrust::<vector_kind><T>,
//   - the functor under test is thrust::<functor_name><T>,
//   - the reference functor is std::<functor_name><T>.
#define INSTANTIATE_BINARY_ARITHMETIC_FUNCTIONAL_TEST(vector_kind, functor_name, T) \
    TestBinaryFunctional< thrust::vector_kind<T>,                                   \
                          thrust::vector_kind<T>,                                   \
                          thrust::functor_name<T>,                                  \
                          std::functor_name<T> >();
// op(T,T) -> T
//
// Defines the unit tests for one binary arithmetic functor over every type
// listed by INSTANTIATE_ALL_TYPES.
//
//   operator_name - functor template name, looked up in both thrust:: and std::
//                   (e.g. plus)
//   OperatorName  - fragment pasted into the generated function names
//                   (e.g. Plus)
//
// Expands to two functions, each passed to DECLARE_UNITTEST:
//   Test<OperatorName>FunctionalHost   - instantiates the tests with host_vector
//   Test<OperatorName>FunctionalDevice - instantiates the tests with device_vector
//
// Invoke as a statement, with a trailing ';'. The expansion intentionally does
// not end in a semicolon of its own: the one written after the invocation
// terminates the final DECLARE_UNITTEST, so no stray empty declaration is left
// behind at namespace scope.
#define DECLARE_BINARY_ARITHMETIC_FUNCTIONAL_UNITTEST(operator_name, OperatorName) \
void Test##OperatorName##FunctionalHost(void) \
{ \
    INSTANTIATE_ALL_TYPES( INSTANTIATE_BINARY_ARITHMETIC_FUNCTIONAL_TEST, host_vector, operator_name); \
} \
DECLARE_UNITTEST(Test##OperatorName##FunctionalHost); \
void Test##OperatorName##FunctionalDevice(void) \
{ \
    INSTANTIATE_ALL_TYPES( INSTANTIATE_BINARY_ARITHMETIC_FUNCTIONAL_TEST, device_vector, operator_name); \
} \
DECLARE_UNITTEST(Test##OperatorName##FunctionalDevice)

// op(T,T) -> T (for integer T only)
//
// Defines the unit tests for one binary arithmetic functor over the types
// listed by INSTANTIATE_INTEGER_TYPES (no floating point instantiation).
//
//   operator_name - functor template name, looked up in both thrust:: and std::
//                   (e.g. modulus)
//   OperatorName  - fragment pasted into the generated function names
//                   (e.g. Modulus)
//
// Expands to two functions, each passed to DECLARE_UNITTEST:
//   Test<OperatorName>FunctionalHost   - instantiates the tests with host_vector
//   Test<OperatorName>FunctionalDevice - instantiates the tests with device_vector
//
// Invoke as a statement, with a trailing ';'. The expansion intentionally does
// not end in a semicolon of its own: the one written after the invocation
// terminates the final DECLARE_UNITTEST, so no stray empty declaration is left
// behind at namespace scope.
#define DECLARE_BINARY_INTEGER_ARITHMETIC_FUNCTIONAL_UNITTEST(operator_name, OperatorName) \
void Test##OperatorName##FunctionalHost(void) \
{ \
    INSTANTIATE_INTEGER_TYPES( INSTANTIATE_BINARY_ARITHMETIC_FUNCTIONAL_TEST, host_vector, operator_name); \
} \
DECLARE_UNITTEST(Test##OperatorName##FunctionalHost); \
void Test##OperatorName##FunctionalDevice(void) \
{ \
    INSTANTIATE_INTEGER_TYPES( INSTANTIATE_BINARY_ARITHMETIC_FUNCTIONAL_TEST, device_vector, operator_name); \
} \
DECLARE_UNITTEST(Test##OperatorName##FunctionalDevice)

// Create the unit tests.
// Each invocation defines a Test<Name>FunctionalHost / Test<Name>FunctionalDevice
// pair. The four operators below are instantiated for the integer types and float.
DECLARE_BINARY_ARITHMETIC_FUNCTIONAL_UNITTEST(plus, Plus );
DECLARE_BINARY_ARITHMETIC_FUNCTIONAL_UNITTEST(minus, Minus );
DECLARE_BINARY_ARITHMETIC_FUNCTIONAL_UNITTEST(multiplies, Multiplies);
DECLARE_BINARY_ARITHMETIC_FUNCTIONAL_UNITTEST(divides, Divides );

// modulus is instantiated for the integer types only.
DECLARE_BINARY_INTEGER_ARITHMETIC_FUNCTIONAL_UNITTEST(modulus, Modulus);
Something went wrong with that request. Please try again.