Skip to content

Commit

Permalink
Add trace_dot in linalg (#4176)
Browse files Browse the repository at this point in the history
* Implemented trace_dot in linalg
* Use trace_dot in LMNN
  • Loading branch information
vinx13 authored and karlnapf committed Feb 28, 2018
1 parent eb794d6 commit 9197c72
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 5 deletions.
13 changes: 13 additions & 0 deletions src/shogun/mathematics/linalg/LinalgNamespace.h
Original file line number Diff line number Diff line change
Expand Up @@ -1690,6 +1690,19 @@ namespace shogun
return infer_backend(A)->transpose_matrix(A);
}

/**
* Method that computes the trace of \f$AB\f$ as \f$sum(A.*B')$\f
*
* @param A The matrix A
* @param B The matrix B
* @return The trace of the product of A and B
*/
template <typename T>
T trace_dot(const SGMatrix<T>& A, const SGMatrix<T>& B)
{
return sum(element_prod(A, B, false, true));
}

/**
* Solve the linear equations \f$Lx=b\f$,
* where \f$L\f$ is a triangular matrix.
Expand Down
7 changes: 2 additions & 5 deletions src/shogun/metric/LMNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,10 @@ void CLMNN::train(SGMatrix<float64_t> init_transform)

// Compute the objective, trace of Mahalanobis distance matrix (L squared) times the gradient
// plus the number of current impostors to account for the margin
// trace of the product of two matrices computed fast using
// trace(A*B)=sum(A.*B')
SG_DEBUG("Computing objective.\n")
obj[iter] = m_regularization * cur_impostors.size();
obj[iter] += linalg::trace(
linalg::element_prod(
linalg::matrix_prod(L, L, true, false), gradient));
obj[iter] +=
linalg::trace_dot(linalg::matrix_prod(L, L, true, false), gradient);

// Correct step size
CLMNNImpl::correct_stepsize(stepsize, obj, iter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1734,6 +1734,23 @@ TEST(LinalgBackendEigen, SGMatrix_trace)
EXPECT_NEAR(trace(A), tr, 1e-15);
}

TEST(LinalgBackendEigen, SGMatrix_trace_dot)
{
const index_t m = 2;
float64_t data_A[] = {0.68764958, 0.11456779, 0.75164207, 0.50436194};
float64_t data_B[] = {0.30786772, 0.25503552, 0.34367041, 0.66491478};

SGMatrix<float64_t> A(data_A, m, m, false);
SGMatrix<float64_t> B(data_B, m, m, false);

auto C = matrix_prod(A, B);
auto tr = 0.0;
for (auto i : range(m))
tr += C(i, i);

EXPECT_NEAR(tr, trace_dot(A, B), 1e-15);
}

TEST(LinalgBackendEigen, SGMatrix_transpose_matrix)
{
const index_t m = 5, n = 3;
Expand Down

0 comments on commit 9197c72

Please sign in to comment.