Skip to content

Commit

Permalink
add more tests, fix dot broadcasting, and warnings (#126)
Browse files Browse the repository at this point in the history
add more tests, fix dot broadcasting, and warnings
  • Loading branch information
wolfv committed Jul 11, 2019
1 parent bfc5e86 commit a743e89
Show file tree
Hide file tree
Showing 8 changed files with 577 additions and 194 deletions.
4 changes: 2 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ matrix:
- os: osx
osx_image: xcode8
compiler: clang
env: BLAS=OpenBLAS
env: BLAS=mkl
env:
global:
- MINCONDA_VERSION="latest"
Expand Down Expand Up @@ -147,7 +147,7 @@ install:
# Install xtensor and BLAS
- conda install xtensor=0.20.6 -c conda-forge
- if [[ "$BLAS" == "OpenBLAS" ]]; then
conda install openblas -c conda-forge;
conda install openblas "libblas * *openblas" -c conda-forge;
elif [[ "$BLAS" == "mkl" ]]; then
conda install mkl;
fi
Expand Down
8 changes: 4 additions & 4 deletions include/xtensor-blas/xlapack.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ namespace lapack
std::size_t n = A.shape()[1];

xtype1 s;
s.resize({ std::max(std::size_t(1), std::min(m, n)) });
s.resize({ std::max(static_cast<std::size_t>(1), std::min(m, n)) });

xtype2 u, vt;

Expand Down Expand Up @@ -348,7 +348,7 @@ namespace lapack
}

xtype1 s;
s.resize({ std::max(std::size_t(1), std::min(m, n)) });
s.resize({ std::max(static_cast<std::size_t>(1), std::min(m, n)) });

xtype2 u, vt;

Expand Down Expand Up @@ -446,7 +446,7 @@ namespace lapack
throw std::runtime_error("Could not find workspace size for getri.");
}

work.resize(std::size_t(std::real(work[0])));
work.resize(static_cast<std::size_t>(std::real(work[0])));

info = cxxlapack::getri<blas_index_t>(
static_cast<blas_index_t>(A.shape()[0]),
Expand Down Expand Up @@ -496,7 +496,7 @@ namespace lapack
throw std::runtime_error("Could not find workspace size for geev.");
}

work.resize(std::size_t(work[0]));
work.resize(static_cast<std::size_t>(work[0]));

info = cxxlapack::geev<blas_index_t>(
jobvl,
Expand Down
21 changes: 14 additions & 7 deletions include/xtensor-blas/xlinalg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ namespace linalg
{
result += std::abs(std::pow(v(i), ord));
}
result = std::pow(result, 1./ static_cast<double>(ord));
result = std::pow(result, 1. / static_cast<double>(ord));
}
return result;
}
Expand Down Expand Up @@ -583,9 +583,11 @@ namespace linalg
: m_a(a), m_axis(axis)
{
resize_container(m_idx, a.dimension());
std::fill(m_idx.begin(), m_idx.end(), 0);
m_offset = 0;
}

#define SC(X) static_cast<std::size_t>(X)
inline bool next()
{
size_type dim = static_cast<size_type>(m_a.dimension());
Expand All @@ -596,9 +598,9 @@ namespace linalg
{
// skip
}
else if (m_idx[i] == m_a.shape()[i] - 1)
else if (m_idx[SC(i)] == m_a.shape()[SC(i)] - 1)
{
m_offset -= m_idx[i] * static_cast<size_type>(m_a.strides()[i]);
m_offset -= static_cast<std::ptrdiff_t>(m_idx[i]) * m_a.strides()[i];
m_idx[i] = size_type(0);
if (i == 0 || m_axis == 0 && i == 1)
{
Expand All @@ -608,14 +610,14 @@ namespace linalg
else
{
++m_idx[i];
m_offset += static_cast<size_type>(m_a.strides()[i]);
m_offset += m_a.strides()[i];
return true;
}
}
return false;
}

inline size_type offset() const
#undef SC
inline std::ptrdiff_t offset() const
{
return m_offset;
}
Expand All @@ -624,7 +626,7 @@ namespace linalg
const A& m_a;
index_type m_idx;
size_type m_axis;
size_type m_offset;
std::ptrdiff_t m_offset;
};
}

Expand Down Expand Up @@ -656,6 +658,11 @@ namespace linalg
auto&& t = view_eval<T::static_layout>(xt.derived_cast());
auto&& o = view_eval<O::static_layout>(xo.derived_cast());

// is either operand a scalar? just multiply
if (t.dimension() == 0 || o.dimension() == 0)
{
return return_type(t * o);
}
if (t.dimension() == 1 && o.dimension() == 1)
{
result.resize(std::vector<std::size_t>{1});
Expand Down
4 changes: 4 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ set(XTENSOR_BLAS_TESTS
test_qr.cpp
test_dot.cpp
test_tensordot.cpp
test_lstsq.cpp
test_dot.cpp
test_dot_extended.cpp
test_qr.cpp
)

add_executable(test_xtensor_blas ${XTENSOR_BLAS_TESTS} ${XTENSOR_BLAS_HEADERS} ${XTENSOR_HEADERS})
Expand Down

0 comments on commit a743e89

Please sign in to comment.