Skip to content

Commit

Permalink
use mdspan
Browse files Browse the repository at this point in the history
  • Loading branch information
aamijar committed May 31, 2024
1 parent 9b56d67 commit 69b9d42
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 12 deletions.
30 changes: 23 additions & 7 deletions cpp/src/tsne/tsne_runner.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,16 @@
#include <cuml/common/logger.hpp>
#include <cuml/manifold/common.hpp>

#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/core/mdspan_types.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/linalg/divide.cuh>
#include <raft/linalg/multiply.cuh>
#include <raft/stats/mean.cuh>
#include <raft/stats/stddev.cuh>
#include <raft/util/cudart_utils.hpp>

Expand Down Expand Up @@ -124,21 +130,31 @@ class TSNE_runner {
noise_vars.data(),
prms,
stream);
handle.sync_stream(stream);

rmm::device_uvector<float> mean_result(dim, stream);
rmm::device_uvector<float> std_result(dim, stream);
std::vector<float> h_std_result(dim);
float multiplier = 1e-4;
const float multiplier = 1e-4;

raft::stats::mean(mean_result.data(), Y, dim, n, false, false, stream);
raft::stats::stddev(std_result.data(), Y, mean_result.data(), dim, n, true, false, stream);
auto Y_view = raft::make_device_matrix_view<float, int>(Y, n, dim);
auto Y_view_const = raft::make_device_matrix_view<const float, int>(Y, n, dim);

auto mean_result_view = raft::make_device_vector_view<float, int>(mean_result.data(), dim);
auto mean_result_view_const =
raft::make_device_vector_view<const float, int>(mean_result.data(), dim);

auto std_result_view = raft::make_device_vector_view<float, int>(std_result.data(), dim);

auto h_multiplier_view_const = raft::make_host_scalar_view<const float>(&multiplier);
auto h_std_result_view_const = raft::make_host_scalar_view<const float>(&h_std_result[0]);

raft::stats::mean(handle_, Y_view_const, mean_result_view, false);
raft::stats::stddev(handle_, Y_view_const, mean_result_view_const, std_result_view, false);

raft::update_host(h_std_result.data(), std_result.data(), dim, stream);
handle.sync_stream(stream);

raft::linalg::divideScalar(Y, Y, h_std_result[0], n * dim, stream);
raft::linalg::multiplyScalar(Y, Y, multiplier, n * dim, stream);
raft::linalg::divide_scalar(handle_, Y_view_const, Y_view, h_std_result_view_const);
raft::linalg::multiply_scalar(handle_, Y_view_const, Y_view, h_multiplier_view_const);
}
}
}
Expand Down
7 changes: 2 additions & 5 deletions cpp/test/sg/tsne_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,8 @@ class TSNETest : public ::testing::TestWithParam<TSNEInput> {
raft::update_device(X_d.data(), dataset.data(), n * p, stream);

rmm::device_uvector<float> Xtranspose(n * p, stream);

raft::update_device(Xtranspose.data(), X_d.data(), n * p, stream);
raft::copy_async(Xtranspose.data(), X_d.data(), n * p, stream);
raft::linalg::transpose(handle, Xtranspose.data(), X_d.data(), p, n, stream);
handle.sync_stream(stream);

rmm::device_uvector<float> Y_d(n * model_params.dim, stream);
rmm::device_uvector<int64_t> input_indices(0, stream);
Expand Down Expand Up @@ -191,9 +189,8 @@ class TSNETest : public ::testing::TestWithParam<TSNEInput> {
handle.sync_stream(stream);
free(embeddings_h);

raft::update_device(Xtranspose.data(), X_d.data(), n * p, stream);
raft::copy_async(Xtranspose.data(), X_d.data(), n * p, stream);
raft::linalg::transpose(handle, Xtranspose.data(), X_d.data(), n, p, stream);
handle.sync_stream(stream);

// Produce trustworthiness score
results.trustworthiness =
Expand Down

0 comments on commit 69b9d42

Please sign in to comment.