-
Notifications
You must be signed in to change notification settings - Fork 182
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add FAISS with RAFT enabled Benchmarking to raft-ann-bench #2026
Changes from 115 commits
cc9cbd3
28484ef
e39ee56
68bf927
78d6380
49a8834
897338e
c1d80f5
2a2ee51
834dd2c
6013429
ab6345a
ed80d1a
cdff9e1
7412272
700ea82
27451c6
9d742ef
d1ef8a1
e187147
4f233a6
4ee99e3
22f4f80
9f4e22c
c95d1e0
da78c66
5cc6dc9
15db0c6
154dc6d
47d6421
0f1d106
e2e1308
3f470c8
41a49b2
1d2a5b0
8ce8115
171215b
135d973
d7a9b4e
5738cca
62b39cf
8702b92
5b2a7e0
4139c7e
e846352
2ab3da2
eb493a7
28b7125
4b3b3bb
3da5265
d546d89
b6e3de9
4b94c45
ec11fd8
8a41330
86f1aa4
0baee4a
9d66a8f
3be7afd
fd01442
1b4fd0e
7d760e9
aaff0bf
c4bc220
6a5443a
bca8f40
8edc7a1
93eebab
140701e
0b88ca4
d67fe8d
a68d7a7
41ac27f
5073ea3
889bbdd
3dbf3a7
30bdee5
55fa0ef
8eb07f8
f8956d5
228e997
bdd75cf
6adcb98
1893963
91e17c2
a2d4575
1efd28f
9841e6c
3f8baaa
11a681f
3cd2d4a
633ad86
09bcbd8
ab442b3
a3acb5d
8bc00aa
e539fd2
2b089bb
87b3eb5
5057525
bdf7196
a045f8e
72b7e00
3bbf67a
22b6754
9d9a078
1385cf8
651ea18
5847a09
77f9366
31f444d
dfb2c2c
9be5ecc
02bdc23
27bf943
9012267
0c714d5
df10536
395402c
4b8843d
b1e7495
95dcd10
e25acf1
8975a81
8188767
f697549
f0aa1db
7a429e5
a54408f
757e07a
65f096f
d472c06
dbb773d
66baf65
ccc8056
2d223dd
092b9b9
b9c64be
981a730
507ce25
dc14d8b
af37e68
e5170a8
9df0d73
bd1fe4c
7295308
f2f2e3b
1838102
fe389db
9c5cf50
09d2422
d56089d
ba2cdd8
921eadd
c91a94b
29e08cb
cca3927
b0ce3ee
fc5c2b3
0fa20a9
a2c1d7f
bd6d5b5
36c97e8
04a3342
60d5927
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -43,6 +43,11 @@ void parse_build_param(const nlohmann::json& conf, | |||||||||||||||||||||||||||||||||
typename raft::bench::ann::FaissGpuIVFFlat<T>::BuildParam& param) | ||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||
parse_base_build_param<T>(conf, param); | ||||||||||||||||||||||||||||||||||
if (conf.contains("use_raft")) { | ||||||||||||||||||||||||||||||||||
param.use_raft = conf.at("use_raft"); | ||||||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||||||
param.use_raft = false; | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
template <typename T> | ||||||||||||||||||||||||||||||||||
|
@@ -61,6 +66,16 @@ void parse_build_param(const nlohmann::json& conf, | |||||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||||||
param.useFloat16 = false; | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
if (conf.contains("use_raft")) { | ||||||||||||||||||||||||||||||||||
param.use_raft = conf.at("use_raft"); | ||||||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||||||
param.use_raft = false; | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
if (conf.contains("bitsPerCode")) { | ||||||||||||||||||||||||||||||||||
param.bitsPerCode = conf.at("bitsPerCode"); | ||||||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||||||
param.bitsPerCode = 8; | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
template <typename T> | ||||||||||||||||||||||||||||||||||
|
@@ -77,6 +92,12 @@ void parse_search_param(const nlohmann::json& conf, | |||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||
param.nprobe = conf.at("nprobe"); | ||||||||||||||||||||||||||||||||||
if (conf.contains("refine_ratio")) { param.refine_ratio = conf.at("refine_ratio"); } | ||||||||||||||||||||||||||||||||||
if (conf.contains("raft_refinement")) { | ||||||||||||||||||||||||||||||||||
RAFT_LOG_INFO("found raft_refinement"); | ||||||||||||||||||||||||||||||||||
param.raft_refinement = conf.at("raft_refinement"); | ||||||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||||||
param.raft_refinement = false; | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
template <typename T, template <typename> class Algo> | ||||||||||||||||||||||||||||||||||
|
@@ -158,5 +179,15 @@ REGISTER_ALGO_INSTANCE(std::uint8_t); | |||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
#ifdef ANN_BENCH_BUILD_MAIN | ||||||||||||||||||||||||||||||||||
#include "../common/benchmark.hpp" | ||||||||||||||||||||||||||||||||||
int main(int argc, char** argv) { return raft::bench::ann::run_main(argc, argv); } | ||||||||||||||||||||||||||||||||||
int main(int argc, char** argv) | ||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||
rmm::mr::cuda_memory_resource cuda_mr; | ||||||||||||||||||||||||||||||||||
// Construct a resource that uses a coalescing best-fit pool allocator | ||||||||||||||||||||||||||||||||||
rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource> pool_mr{&cuda_mr}; | ||||||||||||||||||||||||||||||||||
rmm::mr::set_current_device_resource( | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For single threaded benchmarks this is fine and also for multi-threaded benchmarks, the pool will be correctly shared. We shall consider how the RAFT handle is shared in a multi-threaded environment. (Commenting here because the other relevant code path in Notes:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, yes, I've added the note on multithreading here raft/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h Lines 171 to 186 in 27bf943
The gist is that, in the current state, we share a single faiss handle among multiple threads; this is in contrast to a new raft handle being created for every thread in raft algorithms. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, there is just one handle for each device. https://github.com/facebookresearch/faiss/blob/5e3eae4fccb20723dbc674b3ffa005ce09afcd8d/faiss/gpu/StandardGpuResources.cpp#L432 For benchmarking purposes I have just been using a single CPU thread to prevent differences due to number of threads. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a method getAlternateStreams https://github.com/facebookresearch/faiss/blob/5e3eae4fccb20723dbc674b3ffa005ce09afcd8d/faiss/gpu/StandardGpuResources.cpp#L450 that can be used to get a vector of alternate streams for the device. Perhaps we can have something similar for raft handles. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
@tarang-jain unfortunately we don't yet have automated testing for the benchmarks tool, which means we need to be vigilant about manually testing changes in the meantime. At a minimum, the tests should be run with both There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can just create a separate There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @tfeher but that would mean that the whole index would have to be copied for each thread. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @tfeher @achirkin There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes In practice we use the |
||||||||||||||||||||||||||||||||||
&pool_mr); // Updates the current device resource pointer to `pool_mr` | ||||||||||||||||||||||||||||||||||
rmm::mr::device_memory_resource* mr = | ||||||||||||||||||||||||||||||||||
rmm::mr::get_current_device_resource(); // Points to `pool_mr` | ||||||||||||||||||||||||||||||||||
return raft::bench::ann::run_main(argc, argv); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,10 +17,15 @@ | |
#define FAISS_WRAPPER_H_ | ||
|
||
#include "../common/ann_types.hpp" | ||
#include <raft/core/device_mdarray.hpp> | ||
#include <raft/core/host_mdarray.hpp> | ||
#include <raft/core/host_mdspan.hpp> | ||
#include <raft/distance/distance_types.hpp> | ||
|
||
#include <raft/core/logger.hpp> | ||
#include <raft/util/cudart_utils.hpp> | ||
|
||
#include <faiss/MetricType.h> | ||
#include <faiss/IndexFlat.h> | ||
#include <faiss/IndexIVFFlat.h> | ||
#include <faiss/IndexIVFPQ.h> | ||
|
@@ -37,6 +42,10 @@ | |
|
||
#include <raft/core/device_resources.hpp> | ||
#include <raft/core/resource/stream_view.hpp> | ||
#include <raft_runtime/neighbors/refine.hpp> | ||
#include <rmm/cuda_device.hpp> | ||
#include <rmm/mr/device/device_memory_resource.hpp> | ||
#include <rmm/mr/device/per_device_resource.hpp> | ||
|
||
#include <cassert> | ||
#include <memory> | ||
|
@@ -99,7 +108,8 @@ class FaissGpu : public ANN<T> { | |
using typename ANN<T>::AnnSearchParam; | ||
struct SearchParam : public AnnSearchParam { | ||
int nprobe; | ||
float refine_ratio = 1.0; | ||
float refine_ratio = 1.0; | ||
bool raft_refinement = false; | ||
auto needs_dataset() const -> bool override { return refine_ratio > 1.0f; } | ||
}; | ||
|
||
|
@@ -143,6 +153,8 @@ class FaissGpu : public ANN<T> { | |
return property; | ||
} | ||
|
||
auto metric_faiss_to_raft(faiss::MetricType metric) const -> raft::distance::DistanceType; | ||
|
||
protected: | ||
template <typename GpuIndex, typename CpuIndex> | ||
void save_(const std::string& file) const; | ||
|
@@ -181,13 +193,27 @@ class FaissGpu : public ANN<T> { | |
copyable_event sync_{}; | ||
double training_sample_fraction_; | ||
std::shared_ptr<faiss::SearchParameters> search_params_; | ||
std::shared_ptr<faiss::IndexRefineSearchParameters> refine_search_params_{nullptr}; | ||
const T* dataset_; | ||
float refine_ratio_ = 1.0; | ||
float refine_ratio_ = 1.0; | ||
bool raft_refinement_ = false; | ||
}; | ||
|
||
template <typename T> | ||
auto FaissGpu<T>::metric_faiss_to_raft(faiss::MetricType metric) const | ||
-> raft::distance::DistanceType | ||
{ | ||
switch (metric) { | ||
case faiss::MetricType::METRIC_L2: return raft::distance::DistanceType::L2Expanded; | ||
case faiss::MetricType::METRIC_INNER_PRODUCT: | ||
default: throw std::runtime_error("FAISS supports only metric type of inner product and L2"); | ||
} | ||
} | ||
|
||
template <typename T> | ||
void FaissGpu<T>::build(const T* dataset, size_t nrow, cudaStream_t stream) | ||
{ | ||
// raft::print_host_vector("faiss dataset", dataset, 100, std::cout); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove. |
||
OmpSingleThreadScope omp_single_thread; | ||
auto index_ivf = dynamic_cast<faiss::gpu::GpuIndexIVF*>(index_.get()); | ||
if (index_ivf != nullptr) { | ||
|
@@ -208,7 +234,7 @@ void FaissGpu<T>::build(const T* dataset, size_t nrow, cudaStream_t stream) | |
nlist_, | ||
index_ivf->cp.min_points_per_centroid); | ||
} | ||
index_ivf->cp.max_points_per_centroid = max_ppc; | ||
index_ivf->cp.max_points_per_centroid = 300; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are you changing this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes these were all parts of debugging experiments. Changed it back. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you change this back to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you change it back to |
||
index_ivf->cp.min_points_per_centroid = min_ppc; | ||
} | ||
index_->train(nrow, dataset); // faiss::gpu::GpuIndexFlat::train() will do nothing | ||
|
@@ -225,19 +251,79 @@ void FaissGpu<T>::search(const T* queries, | |
float* distances, | ||
cudaStream_t stream) const | ||
{ | ||
using IdxT = faiss::idx_t; | ||
static_assert(sizeof(size_t) == sizeof(faiss::idx_t), | ||
"sizes of size_t and faiss::idx_t are different"); | ||
|
||
if (this->refine_ratio_ > 1.0) { | ||
// TODO: FAISS changed their search APIs to accept the search parameters as a struct object | ||
// but their refine API doesn't allow the struct to be passed in. Once this is fixed, we | ||
// need to re-enable refinement below | ||
// index_refine_->search(batch_size, queries, k, distances, | ||
// reinterpret_cast<faiss::idx_t*>(neighbors), this->search_params_.get()); Related FAISS issue: | ||
// https://github.com/facebookresearch/faiss/issues/3118 | ||
throw std::runtime_error( | ||
"FAISS doesn't support refinement in their new APIs so this feature is disabled in the " | ||
"benchmarks for the time being."); | ||
if (refine_ratio_ > 1.0) { | ||
if (raft_refinement_) { | ||
uint32_t k0 = static_cast<uint32_t>(refine_ratio_ * k); | ||
// auto distances_tmp = raft::make_host_matrix<float, IdxT>(batch_size, k0); | ||
// auto candidates = raft::make_host_matrix<IdxT, IdxT>(batch_size, k0); | ||
auto distances_tmp = raft::make_device_matrix<float, IdxT>(gpu_resource_->getRaftHandle(device_), batch_size, k0); | ||
auto candidates = raft::make_device_matrix<IdxT, IdxT>(gpu_resource_->getRaftHandle(device_), batch_size, k0); | ||
index_->search(batch_size, | ||
queries, | ||
k0, | ||
distances_tmp.data_handle(), | ||
candidates.data_handle(), | ||
this->search_params_.get()); | ||
// auto queries_v = raft::make_host_matrix_view<const T, IdxT>(queries, batch_size, index_->d); | ||
|
||
|
||
// auto dataset_v = raft::make_host_matrix_view<const T, faiss::idx_t>( | ||
// this->dataset_, index_->ntotal, index_->d); | ||
|
||
// auto neighbors_v = | ||
// raft::make_host_matrix_view<IdxT, IdxT>(reinterpret_cast<IdxT*>(neighbors), batch_size, k); | ||
// auto distances_v = raft::make_host_matrix_view<float, IdxT>(distances, batch_size, k); | ||
|
||
// raft::runtime::neighbors::refine(gpu_resource_->getRaftHandle(device_), | ||
// dataset_v, | ||
// queries_v, | ||
// candidates.view(), | ||
// neighbors_v, | ||
// distances_v, | ||
// metric_faiss_to_raft(index_->metric_type)); | ||
|
||
auto queries_host = raft::make_host_matrix<T, IdxT>(batch_size, index_->d); | ||
auto candidates_host = raft::make_host_matrix<IdxT, IdxT>(batch_size, k0); | ||
auto neighbors_host = raft::make_host_matrix<IdxT, IdxT>(batch_size, k); | ||
auto distances_host = raft::make_host_matrix<float, IdxT>(batch_size, k); | ||
auto dataset_v = raft::make_host_matrix_view<const T, faiss::idx_t>( | ||
this->dataset_, index_->ntotal, index_->d); | ||
|
||
auto handle_ = gpu_resource_->getRaftHandle(device_); | ||
|
||
raft::copy(queries_host.data_handle(), queries, queries_host.size(), stream); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Aren't these already on host? We have the following property defined for our wrapper class: property.query_memory_type = MemoryType::Host; I would expect that our benchmark provides queries to be on the host. If I understand correctly, we could also provide the candidates array as a host array to FAISS. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am allowing queries to be on both -- host and device. When the queries are on host and we have refinement ratios > 1, use FAISS' refinement methods. When they are on device and we have refinement enabled, we use raft's refinement API because FAISS'
|
||
raft::copy(candidates_host.data_handle(), | ||
candidates.data_handle(), | ||
candidates_host.size(), | ||
resource::get_cuda_stream(handle_)); | ||
|
||
// wait for the queries to copy to host in 'stream` and for IVF-PQ::search to finish | ||
// RAFT_CUDA_TRY(cudaEventRecord(handle_.get_sync_event(), resource::get_cuda_stream(handle_))); | ||
// RAFT_CUDA_TRY(cudaEventRecord(handle_.get_sync_event(), stream)); | ||
// RAFT_CUDA_TRY(cudaEventSynchronize(handle_.get_sync_event())); | ||
handle_.sync_stream(); | ||
raft::runtime::neighbors::refine(handle_, | ||
dataset_v, | ||
queries_host.view(), | ||
candidates_host.view(), | ||
neighbors_host.view(), | ||
distances_host.view(), | ||
metric_faiss_to_raft(index_->metric_type)); | ||
|
||
raft::copy(neighbors, (size_t*)neighbors_host.data_handle(), neighbors_host.size(), stream); | ||
raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream); | ||
} else { | ||
index_refine_->search(batch_size, | ||
queries, | ||
k, | ||
distances, | ||
reinterpret_cast<faiss::idx_t*>(neighbors), | ||
this->refine_search_params_.get()); | ||
} | ||
} else { | ||
index_->search(batch_size, | ||
queries, | ||
|
@@ -280,13 +366,16 @@ void FaissGpu<T>::load_(const std::string& file) | |
template <typename T> | ||
class FaissGpuIVFFlat : public FaissGpu<T> { | ||
public: | ||
using typename FaissGpu<T>::BuildParam; | ||
struct BuildParam : public FaissGpu<T>::BuildParam { | ||
bool use_raft; | ||
}; | ||
|
||
FaissGpuIVFFlat(Metric metric, int dim, const BuildParam& param) : FaissGpu<T>(metric, dim, param) | ||
{ | ||
faiss::gpu::GpuIndexIVFFlatConfig config; | ||
config.device = this->device_; | ||
this->index_ = std::make_shared<faiss::gpu::GpuIndexIVFFlat>( | ||
config.device = this->device_; | ||
config.use_raft = param.use_raft; | ||
this->index_ = std::make_shared<faiss::gpu::GpuIndexIVFFlat>( | ||
this->gpu_resource_.get(), dim, param.nlist, this->metric_type_, config); | ||
} | ||
|
||
|
@@ -320,21 +409,25 @@ class FaissGpuIVFPQ : public FaissGpu<T> { | |
int M; | ||
bool useFloat16; | ||
bool usePrecomputed; | ||
bool use_raft; | ||
int bitsPerCode; | ||
}; | ||
|
||
FaissGpuIVFPQ(Metric metric, int dim, const BuildParam& param) : FaissGpu<T>(metric, dim, param) | ||
{ | ||
faiss::gpu::GpuIndexIVFPQConfig config; | ||
config.useFloat16LookupTables = param.useFloat16; | ||
config.usePrecomputedTables = param.usePrecomputed; | ||
config.use_raft = param.use_raft; | ||
config.interleavedLayout = param.use_raft; | ||
config.device = this->device_; | ||
|
||
this->index_ = | ||
std::make_shared<faiss::gpu::GpuIndexIVFPQ>(this->gpu_resource_.get(), | ||
dim, | ||
param.nlist, | ||
param.M, | ||
8, // FAISS only supports bitsPerCode=8 | ||
param.bitsPerCode, | ||
this->metric_type_, | ||
config); | ||
} | ||
|
@@ -354,7 +447,14 @@ class FaissGpuIVFPQ : public FaissGpu<T> { | |
this->index_refine_ = | ||
std::make_shared<faiss::IndexRefineFlat>(this->index_.get(), this->dataset_); | ||
this->index_refine_.get()->k_factor = search_param.refine_ratio; | ||
faiss::IndexRefineSearchParameters faiss_refine_search_params; | ||
faiss_refine_search_params.k_factor = this->index_refine_.get()->k_factor; | ||
faiss_refine_search_params.base_index_params = this->search_params_.get(); | ||
this->refine_search_params_ = | ||
std::make_unique<faiss::IndexRefineSearchParameters>(faiss_refine_search_params); | ||
} | ||
this->raft_refinement_ = search_param.raft_refinement; | ||
RAFT_LOG_INFO("refine_ratio %f raft_refinement %d", this->refine_ratio_, this->raft_refinement_); | ||
} | ||
|
||
void save(const std::string& file) const override | ||
|
@@ -410,6 +510,11 @@ class FaissGpuIVFSQ : public FaissGpu<T> { | |
this->index_refine_ = | ||
std::make_shared<faiss::IndexRefineFlat>(this->index_.get(), this->dataset_); | ||
this->index_refine_.get()->k_factor = search_param.refine_ratio; | ||
faiss::IndexRefineSearchParameters faiss_refine_search_params; | ||
faiss_refine_search_params.k_factor = this->index_refine_.get()->k_factor; | ||
faiss_refine_search_params.base_index_params = this->search_params_.get(); | ||
this->refine_search_params_ = | ||
std::make_unique<faiss::IndexRefineSearchParameters>(faiss_refine_search_params); | ||
} | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -78,7 +78,7 @@ class RaftIvfPQ : public ANN<T> { | |
{ | ||
AlgoProperty property; | ||
property.dataset_memory_type = MemoryType::Host; | ||
property.query_memory_type = MemoryType::Device; | ||
property.query_memory_type = MemoryType::Host; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am surprised that this does not break the code. RAFT There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It does break I think. These changes and the debug (print) statements were all part of some experiments that I was doing. Updating all of those now. |
||
return property; | ||
} | ||
void save(const std::string& file) const override; | ||
|
@@ -209,7 +209,7 @@ void RaftIvfPQ<T, IdxT>::search(const T* queries, | |
raft::make_device_matrix_view<const T, IdxT>(queries, batch_size, index_->dim()); | ||
auto neighbors_v = raft::make_device_matrix_view<IdxT, IdxT>((IdxT*)neighbors, batch_size, k); | ||
auto distances_v = raft::make_device_matrix_view<float, IdxT>(distances, batch_size, k); | ||
|
||
raft::runtime::neighbors::ivf_pq::search( | ||
handle_, search_params_, *index_, queries_v, neighbors_v, distances_v); | ||
handle_.stream_wait(stream); // RAFT stream -> bench stream | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove