Skip to content

Commit

Permalink
Enable matrix block elementwise product
Browse files Browse the repository at this point in the history
and replace old lining::elementwise_square
  • Loading branch information
OXPHOS committed Feb 7, 2017
1 parent f30a150 commit be8d60b
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 430 deletions.
5 changes: 3 additions & 2 deletions src/shogun/kernel/CustomKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <shogun/features/DummyFeatures.h>
#include <shogun/features/IndexFeatures.h>
#include <shogun/io/SGIO.h>
#include <shogun/mathematics/linalg/LinalgNamespace.h>

using namespace shogun;

Expand Down Expand Up @@ -293,9 +294,9 @@ SGMatrix<float64_t> CCustomKernel::row_wise_sum_squared_sum_symmetric_block(
SGVector<float32_t> sum=rowwise_sum<Backend::EIGEN3>(block(kmatrix,
block_begin, block_begin, block_size, block_size), no_diag);

auto kmatrix_block = block(kmatrix, block_begin, block_begin, block_size, block_size);
SGVector<float32_t> sq_sum=rowwise_sum<Backend::EIGEN3>(
elementwise_square<Backend::EIGEN3>(block(kmatrix,
block_begin, block_begin, block_size, block_size)), no_diag);
element_prod(kmatrix_block, kmatrix_block), no_diag);

for (index_t i=0; i<sum.vlen; ++i)
row_sum(i, 0)=sum[i];
Expand Down
14 changes: 14 additions & 0 deletions src/shogun/mathematics/linalg/LinalgBackendBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,20 @@ class LinalgBackendBase
DEFINE_FOR_ALL_PTYPE(BACKEND_GENERIC_IN_PLACE_ELEMENT_PROD, SGMatrix)
#undef BACKEND_GENERIC_IN_PLACE_ELEMENT_PROD

/**
* Wrapper method of in-place matrix block elementwise product.
*
* @see linalg::element_prod
*/
#define BACKEND_GENERIC_IN_PLACE_BLOCK_ELEMENT_PROD(Type, Container) \
virtual void element_prod(linalg::Block<Container<Type>>& a, \
linalg::Block<Container<Type>>& b, Container<Type>& result) const \
{ \
SG_SNOTIMPLEMENTED; \
}
DEFINE_FOR_ALL_PTYPE(BACKEND_GENERIC_IN_PLACE_BLOCK_ELEMENT_PROD, SGMatrix)
#undef BACKEND_GENERIC_IN_PLACE_BLOCK_ELEMENT_PROD

/**
* Wrapper method of matrix product method.
*
Expand Down
27 changes: 27 additions & 0 deletions src/shogun/mathematics/linalg/LinalgBackendEigen.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,16 @@ class LinalgBackendEigen : public LinalgBackendBase
DEFINE_FOR_ALL_PTYPE(BACKEND_GENERIC_IN_PLACE_ELEMENT_PROD, SGMatrix)
#undef BACKEND_GENERIC_IN_PLACE_ELEMENT_PROD

/** Implementation of @see LinalgBackendBase::element_prod */
#define BACKEND_GENERIC_IN_PLACE_BLOCK_ELEMENT_PROD(Type, Container) \
virtual void element_prod(linalg::Block<Container<Type>>& a, \
linalg::Block<Container<Type>>& b, Container<Type>& result) const \
{ \
element_prod_impl(a, b, result); \
}
DEFINE_FOR_ALL_PTYPE(BACKEND_GENERIC_IN_PLACE_BLOCK_ELEMENT_PROD, SGMatrix)
#undef BACKEND_GENERIC_IN_PLACE_BLOCK_ELEMENT_PROD

/** Implementation of @see LinalgBackendBase::matrix_prod */
#define BACKEND_GENERIC_IN_PLACE_MATRIX_PROD(Type, Container) \
virtual void matrix_prod(Container<Type>& a, Container<Type>& b,\
Expand Down Expand Up @@ -322,6 +332,23 @@ class LinalgBackendEigen : public LinalgBackendBase
result_eig = a_eig.array() * b_eig.array();
}

/** Eigen3 matrix block in-place elementwise product method */
template <typename T>
void element_prod_impl(linalg::Block<SGMatrix<T>>& a,
linalg::Block<SGMatrix<T>>& b, SGMatrix<T>& result) const
{
typename SGMatrix<T>::EigenMatrixXtMap a_eig = a.m_matrix;
typename SGMatrix<T>::EigenMatrixXtMap b_eig = b.m_matrix;
typename SGMatrix<T>::EigenMatrixXtMap result_eig = result;

Eigen::Block<typename SGMatrix<T>::EigenMatrixXtMap> a_block =
a_eig.block(a.m_row_begin, a.m_col_begin, a.m_row_size, a.m_col_size);
Eigen::Block<typename SGMatrix<T>::EigenMatrixXtMap> b_block =
b_eig.block(b.m_row_begin, b.m_col_begin, b.m_row_size, b.m_col_size);

result_eig = a_block.array() * b_block.array();
}

/** Eigen3 matrix in-place product method */
template <typename T>
void matrix_prod_impl(SGMatrix<T>& a, SGMatrix<T>& b, SGMatrix<T>& result,
Expand Down
50 changes: 50 additions & 0 deletions src/shogun/mathematics/linalg/LinalgNamespace.h
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,56 @@ T dot(const SGVector<T>& a, const SGVector<T>& b)
return infer_backend(a, b)->dot(a, b);
}

/** Performs the operation C = A .* B where ".*" denotes elementwise multiplication
* on matrix blocks.
*
* This version returns the result in-place.
* User should pass an appropriately allocate memory matrix
* Or pass one of the operands arguments (A or B) as a result
*
* @param a First matrix block
* @param b Second matrix block
* @param c Result matrix
*/
template <typename T>
void element_prod(Block<SGMatrix<T>>& a, Block<SGMatrix<T>>& b, SGMatrix<T>& result)
{
REQUIRE(a.m_row_size == b.m_row_size && a.m_col_size == b.m_col_size,
"Dimension mismatch! A(%d x %d) vs B(%d x %d)\n",
a.m_row_size, a.m_col_size, b.m_row_size, b.m_col_size);
REQUIRE(a.m_row_size == result.num_rows && a.m_col_size == result.num_cols,
"Dimension mismatch! A(%d x %d) vs result(%d x %d)\n",
a.m_row_size, a.m_col_size, result.num_rows, result.num_cols);

REQUIRE(!result.on_gpu(), "Cannot operate with matrix result on_gpu (%d) \
as matrix blocks are on CPU.\n", result.on_gpu());
sg_linalg->get_cpu_backend()->element_prod(a, b, result);
}

/** Performs the operation C = A .* B where ".*" denotes elementwise multiplication
* on matrix blocks.
*
* This version returns the result in a newly created matrix.
*
* @param A First matrix block
* @param B Second matrix block
* @return The result of the operation
*/
template <typename T>
SGMatrix<T> element_prod(Block<SGMatrix<T>>& a, Block<SGMatrix<T>>& b)
{
REQUIRE(a.m_row_size == b.m_row_size && a.m_col_size == b.m_col_size,
"Dimension mismatch! A(%d x %d) vs B(%d x %d)\n",
a.m_row_size, a.m_col_size, b.m_row_size, b.m_col_size);

SGMatrix<T> result(a.m_row_size, a.m_col_size);
result.zero();

element_prod(a, b, result);

return result;
}

/** Performs the operation C = A .* B where ".*" denotes elementwise multiplication.
*
* This version returns the result in-place.
Expand Down

This file was deleted.

0 comments on commit be8d60b

Please sign in to comment.