Commit 363a592
Merge pull request #42 from wolfv/dotbugs
many tests for dot, and bug fixes around transpose + dot
SylvainCorlay committed Nov 14, 2017
2 parents e9be213 + dabd36c commit 363a592
Showing 7 changed files with 265 additions and 76 deletions.
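
For context, the commit title names transpose + dot interactions. Below is a minimal C++ sketch of the kind of case the new tests exercise, going through the public `xt::linalg::dot` entry point; the headers and values are illustrative assumptions, not taken from this commit.

    #include "xtensor/xarray.hpp"
    #include "xtensor-blas/xlinalg.hpp"

    int main()
    {
        // 2 x 3 row-major matrix
        xt::xarray<double> a = {{1., 2., 3.},
                                {4., 5., 6.}};
        // A transposed view has reversed strides; before this commit the
        // second operand of the blas wrappers was evaluated with the first
        // operand's layout (E1 instead of E2), which such inputs expose.
        auto at = xt::transpose(a);                     // 3 x 2 view
        xt::xarray<double> r = xt::linalg::dot(at, a);  // 3 x 3 result
        return 0;
    }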
2 changes: 1 addition & 1 deletion .appveyor.yml
@@ -21,7 +21,7 @@ install:
 - conda config --set always_yes yes --set changeps1 no
 - conda update -q conda
 - conda info -a
-- conda install xtensor=0.13.0 -c conda-forge
+- conda install xtensor=0.13.2 -c conda-forge
 - conda install gtest cmake -c conda-forge
 - conda install m2w64-openblas -c msys2
 # Patch OpenBLASConfig.cmake
2 changes: 1 addition & 1 deletion .travis.yml
@@ -124,7 +124,7 @@ install:
 - conda info -a
 - conda install gtest cmake -c conda-forge
 # Install xtensor and BLAS
-- conda install xtensor=0.13.0 -c conda-forge
+- conda install xtensor=0.13.2 -c conda-forge
 - if [[ "$BLAS" == "OpenBLAS" ]]; then
     conda install openblas -c conda-forge;
   elif [[ "$BLAS" == "mkl" ]]; then
2 changes: 1 addition & 1 deletion README.md
@@ -52,7 +52,7 @@ which are also available on conda-forge.
 
 | `xtensor-blas` | `xtensor` |
 |-----------------|-----------|
-| master          | ^0.13.0   |
+| master          | ^0.13.2   |
 | 0.6.0           | ^0.13.0   |
 | 0.5.0           | ^0.11.0   |
 | 0.3.1           | ^0.10.2   |
99 changes: 50 additions & 49 deletions include/xtensor-blas/xblas.hpp
@@ -28,92 +28,92 @@ namespace xt
 namespace blas
 {
     /**
-     * Calculate the dot product between two vectors, conjugating
-     * the first argument \em a in the case of complex vectors.
+     * Calculate the 1-norm of a vector
      *
      * @param a vector of n elements
-     * @param b vector of n elements
     * @returns scalar result
      */
-    template <class E1, class E2, class R>
-    void dot(const xexpression<E1>& a, const xexpression<E2>& b,
-             R& result)
+    template <class E, class R>
+    void asum(const xexpression<E>& a, R& result)
     {
-        auto&& ad = view_eval<E1::static_layout>(a.derived_cast());
-        auto&& bd = view_eval<E1::static_layout>(b.derived_cast());
+        auto&& ad = view_eval<E::static_layout>(a.derived_cast());
         XTENSOR_ASSERT(ad.dimension() == 1);
 
-        cxxblas::dot<BLAS_IDX>(
+        cxxblas::asum<BLAS_IDX>(
             (BLAS_IDX) ad.shape()[0],
             ad.raw_data() + ad.raw_data_offset(),
             (BLAS_IDX) ad.strides().front(),
-            bd.raw_data() + bd.raw_data_offset(),
-            (BLAS_IDX) bd.strides().front(),
             result
         );
     }
 
     /**
-     * Calculate the dot product between two complex vectors, not conjugating the
-     * first argument \em a.
+     * Calculate the 2-norm of a vector
      *
      * @param a vector of n elements
-     * @param b vector of n elements
     * @returns scalar result
      */
-    template <class E1, class E2, class R>
-    void dotu(const xexpression<E1>& a, const xexpression<E2>& b, R& result)
+    template <class E, class R>
+    void nrm2(const xexpression<E>& a, R& result)
     {
-        auto&& ad = view_eval<E1::static_layout>(a.derived_cast());
-        auto&& bd = view_eval<E1::static_layout>(b.derived_cast());
+        auto&& ad = view_eval<E::static_layout>(a.derived_cast());
         XTENSOR_ASSERT(ad.dimension() == 1);
 
-        cxxblas::dotu<BLAS_IDX>(
+        cxxblas::nrm2<BLAS_IDX>(
             (BLAS_IDX) ad.shape()[0],
             ad.raw_data() + ad.raw_data_offset(),
             (BLAS_IDX) ad.strides().front(),
-            bd.raw_data() + bd.raw_data_offset(),
-            (BLAS_IDX) bd.strides().front(),
             result
         );
     }
 
     /**
-     * Calculate the 1-norm of a vector
+     * Calculate the dot product between two vectors, conjugating
+     * the first argument \em a in the case of complex vectors.
      *
      * @param a vector of n elements
+     * @param b vector of n elements
     * @returns scalar result
      */
-    template <class E, class R>
-    void asum(const xexpression<E>& a, R& result)
+    template <class E1, class E2, class R>
+    void dot(const xexpression<E1>& a, const xexpression<E2>& b,
+             R& result)
     {
-        auto&& ad = view_eval<E::static_layout>(a.derived_cast());
+        auto&& ad = view_eval<E1::static_layout>(a.derived_cast());
+        auto&& bd = view_eval<E2::static_layout>(b.derived_cast());
         XTENSOR_ASSERT(ad.dimension() == 1);
 
-        cxxblas::asum<BLAS_IDX>(
+        cxxblas::dot<BLAS_IDX>(
             (BLAS_IDX) ad.shape()[0],
             ad.raw_data() + ad.raw_data_offset(),
             (BLAS_IDX) ad.strides().front(),
+            bd.raw_data() + bd.raw_data_offset(),
+            (BLAS_IDX) bd.strides().front(),
             result
         );
     }
 
     /**
-     * Calculate the 2-norm of a vector
+     * Calculate the dot product between two complex vectors, not conjugating the
+     * first argument \em a.
      *
      * @param a vector of n elements
+     * @param b vector of n elements
     * @returns scalar result
      */
-    template <class E, class R>
-    void nrm2(const xexpression<E>& a, R& result)
+    template <class E1, class E2, class R>
+    void dotu(const xexpression<E1>& a, const xexpression<E2>& b, R& result)
     {
-        auto&& ad = view_eval<E::static_layout>(a.derived_cast());
+        auto&& ad = view_eval<E1::static_layout>(a.derived_cast());
+        auto&& bd = view_eval<E2::static_layout>(b.derived_cast());
         XTENSOR_ASSERT(ad.dimension() == 1);
 
-        cxxblas::nrm2<BLAS_IDX>(
+        cxxblas::dotu<BLAS_IDX>(
             (BLAS_IDX) ad.shape()[0],
             ad.raw_data() + ad.raw_data_offset(),
             (BLAS_IDX) ad.strides().front(),
+            bd.raw_data() + bd.raw_data_offset(),
+            (BLAS_IDX) bd.strides().front(),
             result
         );
     }
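
For reference, these level-1 wrappers write their scalar result through an out-parameter rather than returning it. A hedged usage sketch based only on the signatures visible above (user code would normally go through `xt::linalg` instead):

    xt::xtensor<double, 1> a = {1., 2., 3.};
    xt::xtensor<double, 1> b = {4., 5., 6.};

    double d = 0.;
    xt::blas::dot(a, b, d);   // d == 32.0  (1*4 + 2*5 + 3*6)

    double s = 0.;
    xt::blas::asum(a, s);     // 1-norm: s == 6.0

    double n = 0.;
    xt::blas::nrm2(a, n);     // 2-norm: n == sqrt(14)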
@@ -136,10 +136,10 @@ namespace blas
               const value_type& beta = value_type(0.0))
     {
         auto&& dA = view_eval<E1::static_layout>(A.derived_cast());
-        auto&& dx = view_eval<E1::static_layout>(x.derived_cast());
+        auto&& dx = view_eval<E2::static_layout>(x.derived_cast());
 
         cxxblas::gemv<BLAS_IDX>(
-            get_blas_storage_order(dA),
+            get_blas_storage_order(result),
             transpose_A ? cxxblas::Transpose::Trans : cxxblas::Transpose::NoTrans,
             (BLAS_IDX) dA.shape()[0],
             (BLAS_IDX) dA.shape()[1],
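
The `gemv` fix passes the storage order of `result` (the array BLAS actually writes) instead of that of `dA`, which may have been evaluated with a different layout. A matrix-vector product through the public API; that 2-D-times-1-D products dispatch to this `gemv` path is an assumption from context, not shown in this diff:

    xt::xarray<double> A = {{1., 2.},
                            {3., 4.}};
    xt::xarray<double> x = {1., 1.};
    xt::xarray<double> y = xt::linalg::dot(A, x);  // {3., 7.}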
@@ -168,31 +168,32 @@
      */
     template <class E, class F, class R, class value_type = typename E::value_type>
     void gemm(const xexpression<E>& A, const xexpression<F>& B, R& result,
-              bool transpose_A = false,
-              bool transpose_B = false,
+              char transpose_A = false,
+              char transpose_B = false,
               const value_type& alpha = value_type(1.0),
               const value_type& beta = value_type(0.0))
     {
-        auto&& da = view_eval<E::static_layout>(A.derived_cast());
-        auto&& db = view_eval<E::static_layout>(B.derived_cast());
+        static_assert(R::static_layout != layout_type::dynamic, "GEMM result layout cannot be dynamic.");
+        auto&& dA = view_eval<R::static_layout>(A.derived_cast());
+        auto&& dB = view_eval<R::static_layout>(B.derived_cast());
 
-        XTENSOR_ASSERT(da.layout() == db.layout());
-        XTENSOR_ASSERT(result.layout() == da.layout());
-        XTENSOR_ASSERT(da.dimension() == 2);
-        XTENSOR_ASSERT(db.dimension() == 2);
+        XTENSOR_ASSERT(dA.layout() == dB.layout());
+        XTENSOR_ASSERT(result.layout() == dA.layout());
+        XTENSOR_ASSERT(dA.dimension() == 2);
+        XTENSOR_ASSERT(dB.dimension() == 2);
 
         cxxblas::gemm<BLAS_IDX>(
-            get_blas_storage_order(da),
+            get_blas_storage_order(result),
             transpose_A ? cxxblas::Transpose::Trans : cxxblas::Transpose::NoTrans,
             transpose_B ? cxxblas::Transpose::Trans : cxxblas::Transpose::NoTrans,
-            (BLAS_IDX) (transpose_A ? da.shape()[1] : da.shape()[0]),
-            (BLAS_IDX) (transpose_B ? db.shape()[0] : db.shape()[1]),
-            (BLAS_IDX) (transpose_B ? db.shape()[1] : db.shape()[0]),
+            (BLAS_IDX) (transpose_A ? dA.shape()[1] : dA.shape()[0]),
+            (BLAS_IDX) (transpose_B ? dB.shape()[0] : dB.shape()[1]),
+            (BLAS_IDX) (transpose_B ? dB.shape()[1] : dB.shape()[0]),
             alpha,
-            da.raw_data() + da.raw_data_offset(),
-            get_leading_stride(da),
-            db.raw_data() + db.raw_data_offset(),
-            get_leading_stride(db),
+            dA.raw_data() + dA.raw_data_offset(),
+            get_leading_stride(dA),
+            dB.raw_data() + dB.raw_data_offset(),
+            get_leading_stride(dB),
             beta,
             result.raw_data() + result.raw_data_offset(),
             get_leading_stride(result)
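
The shape arguments follow the BLAS convention C(m, n) = alpha * op(A)(m, k) * op(B)(k, n) + beta * C, which is why m and k swap roles under `transpose_A`, and n and k under `transpose_B`, in the ternaries above. A small worked example of the shapes involved (values are illustrative only):

    xt::xarray<double> A = {{1., 2., 3.},
                            {4., 5., 6.}};          // 2 x 3: m = 2, k = 3
    xt::xarray<double> B = {{1., 0.},
                            {0., 1.},
                            {1., 1.}};              // 3 x 2: n = 2
    xt::xarray<double> C = xt::linalg::dot(A, B);   // 2 x 2: {{4., 5.}, {10., 11.}}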
10 changes: 5 additions & 5 deletions include/xtensor-blas/xblas_utils.hpp
@@ -18,7 +18,7 @@ namespace xt
     inline auto view_eval(T&& t)
         -> std::enable_if_t<has_raw_data_interface<T>::value && std::decay_t<T>::static_layout == L, T&&>
     {
-        return t;
+        return std::forward<T>(t);
     }
 
     template <layout_type L = layout_type::row_major, class T, class I = std::decay_t<T>>
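
The `return t;` form loses the value category: inside the function the parameter `t` is an lvalue even when `T&&` was bound to an rvalue, so returning it through a `T&&` return type fails to compile in the rvalue case. A minimal standalone illustration of the same pattern (not the library code itself):

    #include <utility>

    template <class T>
    T&& pass_through(T&& t)
    {
        // return t;                 // ill-formed for rvalue arguments:
        //                           // an rvalue reference cannot bind to the lvalue t
        return std::forward<T>(t);   // restores the caller's value category
    }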
@@ -82,16 +82,16 @@ namespace xt
     * Get leading stride
     */
 
-    template <class A, std::enable_if_t<A::static_layout == layout_type::column_major>* = nullptr>
+    template <class A, std::enable_if_t<A::static_layout == layout_type::row_major>* = nullptr>
     inline BLAS_IDX get_leading_stride(const A& a)
     {
-        return (BLAS_IDX) (a.strides().back() == 0 ? a.shape().front() : a.strides().back());
+        return (BLAS_IDX) (a.strides().front() == 0 ? a.shape().back() : a.strides().front());
     }
 
-    template <class A, std::enable_if_t<A::static_layout == layout_type::row_major>* = nullptr>
+    template <class A, std::enable_if_t<A::static_layout == layout_type::column_major>* = nullptr>
     inline BLAS_IDX get_leading_stride(const A& a)
     {
-        return (BLAS_IDX) (a.strides().front() == 0 ? a.shape().back() : a.strides().front());
+        return (BLAS_IDX) (a.strides().back() == 0 ? a.shape().front() : a.strides().back());
     }
 
     template <class A, std::enable_if_t<A::static_layout != layout_type::row_major && A::static_layout != layout_type::column_major>* = nullptr>
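
For reference, these overloads report the BLAS "leading dimension": the first stride for a row-major matrix (the row pitch, normally the column count) and the last stride for a column-major one (the column pitch, normally the row count); the zero-stride fallback covers broadcast views. A quick illustration assuming xtensor's default row-major layout:

    xt::xtensor<double, 2> a = xt::zeros<double>({2, 3});  // strides {3, 1}
    // get_leading_stride(a) == 3: advancing one row skips a full 3-element row.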
