Skip to content

Commit

Permalink
Merge pull request #78 from wolfv/fix_inv_single_el
Browse files Browse the repository at this point in the history
Fix inverse on 1x1 matrix
  • Loading branch information
wolfv committed Aug 11, 2018
2 parents 161d74b + cbef1d6 commit 693b3e4
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 59 deletions.
12 changes: 6 additions & 6 deletions include/xtensor-blas/xblas.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ namespace blas
cxxblas::asum<blas_index_t>(
static_cast<blas_index_t>(ad.shape()[0]),
ad.data() + ad.data_offset(),
static_cast<blas_index_t>(ad.strides().front()),
stride_front(ad),
result
);
}
Expand All @@ -62,7 +62,7 @@ namespace blas
cxxblas::nrm2<blas_index_t>(
static_cast<blas_index_t>(ad.shape()[0]),
ad.data() + ad.data_offset(),
static_cast<blas_index_t>(ad.strides().front()),
stride_front(ad),
result
);
}
Expand All @@ -86,9 +86,9 @@ namespace blas
cxxblas::dot<blas_index_t>(
static_cast<blas_index_t>(ad.shape()[0]),
ad.data() + ad.data_offset(),
static_cast<blas_index_t>(ad.strides().front()),
stride_front(ad),
bd.data() + bd.data_offset(),
static_cast<blas_index_t>(bd.strides().front()),
stride_front(bd),
result
);
}
Expand All @@ -111,9 +111,9 @@ namespace blas
cxxblas::dotu<blas_index_t>(
static_cast<blas_index_t>(ad.shape()[0]),
ad.data() + ad.data_offset(),
static_cast<blas_index_t>(ad.strides().front()),
stride_front(ad),
bd.data() + bd.data_offset(),
static_cast<blas_index_t>(bd.strides().front()),
stride_front(bd),
result
);
}
Expand Down
34 changes: 34 additions & 0 deletions include/xtensor-blas/xblas_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,40 @@ namespace xt
DEFAULT_LEADING_STRIDE_BEHAVIOR;
}

/********************
* Get front stride *
********************/

template <class E>
inline blas_index_t stride_front(const E& e)
{
if (E::static_layout == layout_type::column_major)
{
return blas_index_t(1);
}
else
{
return static_cast<blas_index_t>(e.strides().front() == 0 ? 1 : e.strides().front());
}
}

/*******************
* Get back stride *
*******************/

template <class E>
inline blas_index_t stride_back(const E& e)
{
if (E::static_layout == layout_type::row_major)
{
return blas_index_t(1);
}
else
{
return static_cast<blas_index_t>(e.strides().back() == 0 ? 1 : e.strides().back());
}
}

/*******************************
* is_xfunction implementation *
*******************************/
Expand Down

0 comments on commit 693b3e4

Please sign in to comment.