Skip to content

Commit

Permalink
Merge pull request #79 from wolfv/fix_scal_initialization
Browse files Browse the repository at this point in the history
fix scal initialization
  • Loading branch information
wolfv committed Aug 11, 2018
2 parents 693b3e4 + 47790b3 commit f2e9109
Show file tree
Hide file tree
Showing 21 changed files with 113 additions and 19 deletions.
26 changes: 26 additions & 0 deletions include/xflens/cxxblas/level1/scal.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,22 @@ scal_generic(IndexType n, const ALPHA &alpha, Y *y, IndexType incY)
}
}

// Scales the n strided elements y[0], y[incY], ... by alpha.
// Unlike scal_generic, a zero alpha stores 0 into every element instead of
// multiplying, so whatever y held before (presumably NaN/Inf or uninitialised
// data -- see the nan_result test) does not leak into the result.
template <typename IndexType, typename ALPHA, typename Y>
void
scal_init_generic(IndexType n, const ALPHA &alpha, Y *y, IndexType incY)
{
    CXXBLAS_DEBUG_OUT("scal_init_generic");

    // Any non-zero factor is an ordinary scaling.
    if (!(alpha == ALPHA(0))) {
        scal_generic(n, alpha, y, incY);
        return;
    }

    // Zero factor: overwrite each element, stepping by the stride.
    IndexType offset = 0;
    for (IndexType k = 0; k < n; ++k) {
        y[offset] = 0;
        offset += incY;
    }
}

template <typename IndexType, typename ALPHA, typename Y>
void
scal(IndexType n, const ALPHA &alpha, Y *y, IndexType incY)
Expand All @@ -58,6 +74,16 @@ scal(IndexType n, const ALPHA &alpha, Y *y, IndexType incY)
scal_generic(n, alpha, y, incY);
}

// Entry point for the initialising scale: y <- alpha*y, with alpha == 0
// handled by scal_init_generic as a plain overwrite with zeros.
template <typename IndexType, typename ALPHA, typename Y>
void
scal_init(IndexType n, const ALPHA &alpha, Y *y, IndexType incY)
{
    // With a negative stride the generic kernel indexes y[0], y[incY], ...
    // downwards, so hand it a pointer shifted by -incY*(n-1) elements.
    Y *first = (incY<0) ? y - incY*(n-1) : y;

    scal_init_generic(n, alpha, first, incY);
}

#ifdef HAVE_CBLAS

// sscal
Expand Down
2 changes: 1 addition & 1 deletion include/xflens/cxxblas/level1extensions/acxpby.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ acxpby(IndexType n, const ALPHA &alpha, const X *x,
{
CXXBLAS_DEBUG_OUT("acxpby_generic");

scal(n, beta, y, incY);
scal_init(n, beta, y, incY);
acxpy(n, alpha, x, incX, y, incY);
}

Expand Down
2 changes: 1 addition & 1 deletion include/xflens/cxxblas/level1extensions/axpby.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ axpby(IndexType n, const ALPHA &alpha, const X *x, IndexType incX,
{
CXXBLAS_DEBUG_OUT("axpby_generic");

scal(n, beta, y, incY);
scal_init(n, beta, y, incY);
axpy(n, alpha, x, incX, y, incY);
}

Expand Down
6 changes: 6 additions & 0 deletions include/xflens/cxxblas/level1extensions/gescal.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@

namespace cxxblas {

// gescal_init: scales the general matrix stored at A by alpha.
//   order  - storage order of A (the definition swaps m and n for ColMajor)
//   m, n   - matrix dimensions
//   alpha  - scale factor
//   A, ldA - matrix data and its leading dimension
// The definition forwards to scal_init, so a zero alpha stores zeros into A
// instead of multiplying the existing contents.
template <typename IndexType, typename ALPHA, typename MA>
void
gescal_init(StorageOrder order,
IndexType m, IndexType n,
const ALPHA &alpha, MA *A, IndexType ldA);

template <typename IndexType, typename ALPHA, typename MA>
void
gescal(StorageOrder order,
Expand Down
23 changes: 23 additions & 0 deletions include/xflens/cxxblas/level1extensions/gescal.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,29 @@

namespace cxxblas {

// Scales the general matrix A (leading dimension ldA) by alpha using
// scal_init, so that alpha == 0 overwrites A with zeros rather than
// multiplying its previous contents.
//   order  - storage order of A
//   m, n   - matrix dimensions
//   alpha  - scale factor
//   A, ldA - matrix data and its leading dimension
template <typename IndexType, typename ALPHA, typename MA>
void
gescal_init(StorageOrder order,
            IndexType m, IndexType n,
            const ALPHA &alpha, MA *A, IndexType ldA)
{
    // Label identifies this routine (not gescal) in debug traces.
    CXXBLAS_DEBUG_OUT("gescal_init_generic");

    // After the swap, m counts the contiguous runs and n their length, so
    // the code below can treat both storage orders the same way.
    if (order==ColMajor) {
        std::swap(m,n);
    }

    if (ldA==n) {
        // No gap between consecutive runs: process all m*n elements at once.
        scal_init(m*n, alpha, A, IndexType(1));
        return;
    }

    // Runs are separated by ldA elements: process them one by one.
    for (IndexType i=0; i<m; ++i) {
        scal_init(n, alpha, A+i*ldA, IndexType(1));
    }
}


template <typename IndexType, typename ALPHA, typename MA>
void
gescal(StorageOrder order,
Expand Down
25 changes: 25 additions & 0 deletions include/xflens/cxxblas/level1extensions/syscal.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,31 @@ syscal(StorageOrder order, StorageUpLo upLo,
}
}

// Scales one triangle of the n-by-n matrix A (leading dimension ldA) by alpha
// using scal_init, so that alpha == 0 overwrites the triangle with zeros
// rather than multiplying its previous contents.
//   order  - storage order of A
//   upLo   - which triangle (Upper or Lower) is touched
//   n      - matrix dimension
//   alpha  - scale factor; alpha == 1 leaves A untouched
//   A, ldA - matrix data and its leading dimension
template <typename IndexType, typename ALPHA, typename MA>
void
syscal_init(StorageOrder order, StorageUpLo upLo,
            IndexType n,
            const ALPHA &alpha, MA *A, IndexType ldA)
{
    // Label identifies this routine (not syscal) in debug traces.
    CXXBLAS_DEBUG_OUT("syscal_init_generic");

    // Scaling by one is a no-op.
    if (alpha==ALPHA(1)) {
        return;
    }
    // For ColMajor the opposite triangle is selected so the loops below can
    // always walk the storage in runs of stride 1.
    if (order==ColMajor) {
        upLo = (upLo==Upper) ? Lower : Upper;
    }
    if (upLo==Upper) {
        // Run i starts on the diagonal element and holds n-i entries.
        for (IndexType i=0; i<n; ++i) {
            scal_init(n-i, alpha, A+i*(ldA+1), IndexType(1));
        }
    } else {
        // Run i starts at the beginning of line i and holds i+1 entries,
        // i.e. it ends on the diagonal element.
        for (IndexType i=0; i<n; ++i) {
            scal_init(i+1, alpha, A+i*ldA, IndexType(1));
        }
    }
}

} // namespace cxxblas

#endif // CXXBLAS_LEVEL1EXTENSIONS_SYSCAL_TCC
4 changes: 2 additions & 2 deletions include/xflens/cxxblas/level2/gbmv.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ gbmv_generic(StorageOrder order, Transpose trans, Transpose conjX,
if (incY<0) {
y -= incY*(m-1);
}
scal_generic(m, beta, y, incY);
scal_init_generic(m, beta, y, incY);
if (trans==NoTrans) {
if (conjX==NoTrans) {
for (IndexType i=0, iY=0; i<m; ++i, iY+=incY) {
Expand Down Expand Up @@ -122,7 +122,7 @@ gbmv_generic(StorageOrder order, Transpose trans, Transpose conjX,
if (incY<0) {
y -= incY*(n-1);
}
scal_generic(n, beta, y, incY);
scal_init_generic(n, beta, y, incY);
if (trans==Trans) {
if (conjX == NoTrans) {
for (IndexType i=0, iX=0; i<m; ++i, iX+=incX) {
Expand Down
5 changes: 2 additions & 3 deletions include/xflens/cxxblas/level2/gemv.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ gemv_generic(StorageOrder order, Transpose transA, Transpose conjX,
if (incY<0) {
y -= incY*(m-1);
}

scal_generic(m, beta, y, incY);
scal_init_generic(m, beta, y, incY);
if (conjX==NoTrans) {
if (transA==Conj) {
for (IndexType i=0, iY=0; i<m; ++i, iY+=incY) {
Expand Down Expand Up @@ -101,7 +100,7 @@ gemv_generic(StorageOrder order, Transpose transA, Transpose conjX,
y -= incY*(n-1);
}

scal_generic(n, beta, y, incY);
scal_init_generic(n, beta, y, incY);
if (conjX==NoTrans) {
if (transA==ConjTrans) {
for (IndexType i=0, iY=0; i<n; ++i, iY+=incY) {
Expand Down
2 changes: 1 addition & 1 deletion include/xflens/cxxblas/level2/hbmv.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ hbmv_generic(StorageOrder order, StorageUpLo upLo, Transpose conjugateA,
upLo = (upLo==Upper) ? Lower : Upper;
conjugateA = Transpose(conjugateA^Conj);
}
scal_generic(n, beta, y, incY);
scal_init_generic(n, beta, y, incY);
if (upLo==Upper) {
if (conjugateA==Conj) {
for (IndexType i=0, iX=0, iY=0; i<n; ++i, iX+=incX, iY+=incY) {
Expand Down
2 changes: 1 addition & 1 deletion include/xflens/cxxblas/level2/hemv.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ hemv_generic(StorageOrder order, StorageUpLo upLo, Transpose conjugateA,
upLo = (upLo==Upper) ? Lower : Upper;
conjugateA = Transpose(conjugateA^Conj);
}
scal_generic(n, beta, y, incY);
scal_init_generic(n, beta, y, incY);
if (upLo==Upper) {
if (conjugateA==Conj) {
for (IndexType i=0, iX=0, iY=0; i<n; ++i, iX+=incX, iY+=incY) {
Expand Down
2 changes: 1 addition & 1 deletion include/xflens/cxxblas/level2/hpmv.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ hpmv_generic(StorageOrder order, StorageUpLo upLo, Transpose conjugateA,
upLo = (upLo==Upper) ? Lower : Upper;
conjugateA = Transpose(conjugateA^Conj);
}
scal_generic(n, beta, y, incY);
scal_init_generic(n, beta, y, incY);
if (upLo==Upper) {
if (conjugateA==Conj) {
for (IndexType i=0, iY=0, iX=0; i<n; ++i, iX+=incX, iY+=incY) {
Expand Down
2 changes: 1 addition & 1 deletion include/xflens/cxxblas/level2/sbmv.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ sbmv_generic(StorageOrder order, StorageUpLo upLo,
return;
}

scal_generic(n, beta, y, incY);
scal_init_generic(n, beta, y, incY);
if (upLo==Upper) {
for (IndexType i=0, iX=0, iY=0; i<n; ++i, iX+=incX, iY+=incY) {
IndexType len = min(k+1, n-i);
Expand Down
2 changes: 1 addition & 1 deletion include/xflens/cxxblas/level2/spmv.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ spmv_generic(StorageOrder order, StorageUpLo upLo,
if (order==ColMajor) {
upLo = (upLo==Upper) ? Lower : Upper;
}
scal_generic(n, beta, y, incY);
scal_init_generic(n, beta, y, incY);
if (upLo==Upper) {
for (IndexType i=0, iY=0, iX=0; i<n; ++i, iX+=incX, iY+=incY) {
VY y_ = VY(0);
Expand Down
2 changes: 1 addition & 1 deletion include/xflens/cxxblas/level2/symv.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ symv_generic(StorageOrder order, StorageUpLo upLo,
if (order==ColMajor) {
upLo = (upLo==Upper) ? Lower : Upper;
}
scal_generic(n, beta, y, incY);
scal_init_generic(n, beta, y, incY);
if (upLo==Upper) {
for (IndexType i=0, iY=0; i<n; ++i, iY+=incY) {
VY y_ = VY(0);
Expand Down
2 changes: 1 addition & 1 deletion include/xflens/cxxblas/level3/gemm.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ gemm_generic(StorageOrder order,
return;
}

gescal(order, m, n, beta, C, ldC);
gescal_init(order, m, n, beta, C, ldC);
if (alpha==ALPHA(0)) {
return;
}
Expand Down
2 changes: 1 addition & 1 deletion include/xflens/cxxblas/level3/hemm.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ hemm_generic(StorageOrder order,
C, ldC);
return;
}
gescal(order, m, n, beta, C, ldC);
gescal_init(order, m, n, beta, C, ldC);
if (sideA==Right) {
for (IndexType i=0; i<m; ++i) {
hemv(order, upLoA, Conj, n, alpha, A, ldA, B+i*ldB, IndexType(1),
Expand Down
2 changes: 1 addition & 1 deletion include/xflens/cxxblas/level3/symm.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ symm_generic(StorageOrder order, Side sideA, StorageUpLo upLoA,
C, ldC);
return;
}
gescal(order, m, n, beta, C, ldC);
gescal_init(order, m, n, beta, C, ldC);
if (sideA==Right) {
for (IndexType i=0; i<m; ++i) {
symv(order, upLoA, n, alpha, A, ldA, B+i*ldB, IndexType(1),
Expand Down
2 changes: 1 addition & 1 deletion include/xflens/cxxblas/level3/syr2k.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ syr2k_generic(StorageOrder order, StorageUpLo upLoC,
alpha, A, ldA, B, ldB, beta, C, ldC);
return;
}
syscal(order, upLoC, n, beta, C, ldC);
syscal_init(order, upLoC, n, beta, C, ldC);
if (k==0) {
return;
}
Expand Down
2 changes: 1 addition & 1 deletion include/xflens/cxxblas/level3/syrk.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ syrk_generic(StorageOrder order, StorageUpLo upLoC,
alpha, A, ldA, beta, C, ldC);
return;
}
syscal(order, upLoC, n, beta, C, ldC);
syscal_init(order, upLoC, n, beta, C, ldC);
if (k==0) {
return;
}
Expand Down
2 changes: 1 addition & 1 deletion include/xflens/cxxblas/level3/trmm.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ trmm_generic(StorageOrder order, Side sideA, StorageUpLo upLoA,
trmv(order, upLoA, transA, diagA, m, A, ldA, B+j, ldB);
}
}
gescal(order, m, n, alpha, B, ldB);
gescal_init(order, m, n, alpha, B, ldB);
}

template <typename IndexType, typename ALPHA, typename MA, typename MB>
Expand Down
15 changes: 15 additions & 0 deletions test/test_blas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,21 @@ namespace xt
EXPECT_TRUE(all(equal(expected, t3)));
}

TEST(xblas, nan_result)
{
    // Regression check: NaNs already present in the output array must not
    // survive a gemm call whose last scalar argument (presumably beta) is 0.
    const double quiet_nan = std::numeric_limits<double>::quiet_NaN();

    xt::xarray<double> X = {{1, 2, 3},
                            {1, 2, 3}};

    // Output buffer seeded with NaN in two entries.
    auto M = xt::xarray<double>::from_shape({3, 3});
    M(0, 0) = quiet_nan;
    M(0, 1) = quiet_nan;

    xt::blas::gemm(X, X, M, true, false, 1.0, 0.0);

    // Every stored element of the result has to be a real number.
    for (const double value : M.storage())
    {
        EXPECT_FALSE(std::isnan(value));
    }
}

TEST(xblas, gemm_transpose)
{
xt::xarray<double> X = {{1, 2, 3},
Expand Down

0 comments on commit f2e9109

Please sign in to comment.