Skip to content

Commit

Permalink
resolving pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
aamijar committed Aug 24, 2024
1 parent aba9c4e commit ba16aca
Showing 1 changed file with 14 additions and 30 deletions.
44 changes: 14 additions & 30 deletions cpp/include/raft/sparse/solver/detail/lanczos.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1441,19 +1441,6 @@ RAFT_KERNEL kernel_subtract_and_scale(T* u, T* vec, T* scalar, int n)
if (idx < n) { u[idx] -= (*scalar) * vec[idx]; }
}

/**
 * @brief Copy the last row of a column-major matrix into a contiguous array.
 *
 * Each thread handles one column: thread `col` reads element
 * (numRows - 1, col) of `M` and writes it to `S[col]`.
 *
 * @tparam T element type
 * @param[in]  M       column-major matrix with `numRows` x `numCols` elements
 * @param[out] S       output array; receives `numCols` elements
 * @param[in]  numRows number of rows in `M` (also used as the column stride)
 * @param[in]  numCols number of columns in `M`
 */
template <typename T>
RAFT_KERNEL kernel_get_last_row(const T* M, T* S, int numRows, int numCols)
{
  const int col = threadIdx.x + blockIdx.x * blockDim.x;

  // Threads past the last column have nothing to do; an empty matrix has no
  // last row, so reading it would index M at offset -1.
  if (col >= numCols || numRows <= 0) { return; }

  // Column-major offset of element (numRows - 1, col). The product is formed
  // in size_t so it cannot overflow a 32-bit int for large matrices.
  const size_t index = static_cast<size_t>(col) * numRows + (numRows - 1);
  S[col] = M[index];
}

template <typename T>
RAFT_KERNEL kernel_triangular_populate(T* M, const T* beta, int n)
{
Expand Down Expand Up @@ -1786,21 +1773,16 @@ int lanczos_smallest(raft::resources const& handle,
raft::linalg::gemm<value_type_t, uint32_t, raft::col_major, raft::col_major, raft::col_major>(
handle, V_T, eigenvectors_k, ritz_eigenvectors);

int blockSize = 256; // Number of threads per block
int numBlocks = (nEigVecs + blockSize - 1) / blockSize;

auto s = raft::make_device_vector<value_type_t>(handle, nEigVecs);

// raft::matrix::slice_coordinates<index_type_t> coords(eigenvectors_k.extent(0) - 1, 0,
// eigenvectors_k.extent(0), eigenvectors_k.extent(1));

// auto S_matrix = raft::make_device_matrix_view<value_type_t, uint32_t,
// raft::col_major>(s.data_handle(), 1, nEigVecs);

// raft::matrix::slice(handle, make_const_mdspan(eigenvectors_k), s, coords);
auto eigenvectors_k_slice =
raft::make_device_matrix_view<value_type_t, index_type_t, raft::col_major>(
eigenvectors.data_handle(), ncv, nEigVecs);
auto S_matrix = raft::make_device_matrix_view<value_type_t, index_type_t, raft::col_major>(
s.data_handle(), 1, nEigVecs);

kernel_get_last_row<<<numBlocks, blockSize, 0, stream>>>(
eigenvectors_k.data_handle(), s.data_handle(), ncv, nEigVecs);
raft::matrix::slice_coordinates<index_type_t> coords(ncv - 1, 0, ncv, nEigVecs);
raft::matrix::slice(handle, make_const_mdspan(eigenvectors_k_slice), S_matrix, coords);

auto beta_k = raft::make_device_vector<value_type_t>(handle, nEigVecs);
raft::matrix::fill(handle, beta_k.view(), zero);
Expand Down Expand Up @@ -2027,12 +2009,14 @@ int lanczos_smallest(raft::resources const& handle,
raft::linalg::gemm<value_type_t, uint32_t, raft::col_major, raft::col_major, raft::col_major>(
handle, V_T, eigenvectors_k, ritz_eigenvectors);

int blockSize = 256; // Number of threads per block
int numBlocks = (nEigVecs + blockSize - 1) / blockSize;
auto eigenvectors_k_slice =
raft::make_device_matrix_view<value_type_t, index_type_t, raft::col_major>(
eigenvectors.data_handle(), ncv, nEigVecs);
auto S_matrix = raft::make_device_matrix_view<value_type_t, index_type_t, raft::col_major>(
s.data_handle(), 1, nEigVecs);

auto s = raft::make_device_vector<value_type_t>(handle, nEigVecs);
kernel_get_last_row<<<numBlocks, blockSize, 0, stream>>>(
eigenvectors_k.data_handle(), s.data_handle(), ncv, nEigVecs);
raft::matrix::slice_coordinates<index_type_t> coords(ncv - 1, 0, ncv, nEigVecs);
raft::matrix::slice(handle, make_const_mdspan(eigenvectors_k_slice), S_matrix, coords);

raft::matrix::fill(handle, beta_k.view(), zero);

Expand Down

0 comments on commit ba16aca

Please sign in to comment.