Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FAISS with RAFT enabled Benchmarking to raft-ann-bench #2026

Merged
merged 165 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
165 commits
Select commit Hold shift + click to select a range
cc9cbd3
Unpack list data kernel
tarang-jain Jul 1, 2023
28484ef
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 1, 2023
e39ee56
update packing and unpacking functions
tarang-jain Jul 5, 2023
68bf927
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 5, 2023
78d6380
Update codepacker
tarang-jain Jul 14, 2023
49a8834
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 14, 2023
897338e
refactor codepacker (does not build)
tarang-jain Jul 17, 2023
c1d80f5
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 17, 2023
2a2ee51
Undo deletions
tarang-jain Jul 17, 2023
834dd2c
undo yaml changes
tarang-jain Jul 17, 2023
6013429
style
tarang-jain Jul 17, 2023
ab6345a
Update tests, correct make_list_extents
tarang-jain Jul 18, 2023
ed80d1a
More changes
tarang-jain Jul 19, 2023
cdff9e1
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 19, 2023
7412272
debugging
tarang-jain Jul 20, 2023
700ea82
Working build
tarang-jain Jul 21, 2023
27451c6
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 21, 2023
9d742ef
rename codepacking api
tarang-jain Jul 21, 2023
d1ef8a1
Updated gtest
tarang-jain Jul 27, 2023
e187147
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 27, 2023
4f233a6
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 27, 2023
4ee99e3
updates
tarang-jain Jul 27, 2023
22f4f80
update testing
tarang-jain Jul 28, 2023
9f4e22c
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 28, 2023
c95d1e0
updates
tarang-jain Jul 28, 2023
da78c66
Update testing, pow2
tarang-jain Jul 31, 2023
5cc6dc9
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 31, 2023
15db0c6
remove unneccessary changes
tarang-jain Jul 31, 2023
154dc6d
Delete log.txt
tarang-jain Jul 31, 2023
47d6421
updates
tarang-jain Jul 31, 2023
0f1d106
Merge branch 'faiss-ivf' of https://github.com/tarang-jain/raft into …
tarang-jain Jul 31, 2023
e2e1308
ore cleanup
tarang-jain Jul 31, 2023
3f470c8
Merge branch 'branch-23.08' of https://github.com/rapidsai/raft into …
tarang-jain Jul 31, 2023
41a49b2
style
tarang-jain Jul 31, 2023
1d2a5b0
Merge branch 'branch-23.10' of https://github.com/rapidsai/raft into …
tarang-jain Aug 9, 2023
8ce8115
Merge branch 'branch-23.10' of https://github.com/rapidsai/raft into …
tarang-jain Aug 23, 2023
171215b
Merge branch 'branch-23.10' of https://github.com/rapidsai/raft into …
tarang-jain Sep 11, 2023
135d973
Initial commit
tarang-jain Sep 13, 2023
d7a9b4e
Merge branch 'branch-23.10' of https://github.com/rapidsai/raft into …
tarang-jain Sep 13, 2023
5738cca
im
tarang-jain Sep 21, 2023
62b39cf
host pq codepacker
tarang-jain Sep 22, 2023
8702b92
refactored codepacker
tarang-jain Sep 22, 2023
5b2a7e0
Merge branch 'branch-23.10' of https://github.com/rapidsai/raft into …
tarang-jain Sep 22, 2023
4139c7e
updated CP
tarang-jain Sep 22, 2023
e846352
undo some diffs
tarang-jain Sep 22, 2023
2ab3da2
undo some diffs
tarang-jain Sep 22, 2023
eb493a7
undo some diffs
tarang-jain Sep 22, 2023
28b7125
update docs
tarang-jain Sep 22, 2023
4b3b3bb
Merge branch 'branch-23.10' into faiss-ivf
tarang-jain Sep 25, 2023
3da5265
Merge branch 'branch-23.12' into faiss-ivf
cjnolet Oct 5, 2023
d546d89
initial efforts for compress/decompress codepacker
tarang-jain Oct 6, 2023
b6e3de9
Merge branch 'branch-23.12' of https://github.com/rapidsai/raft into …
tarang-jain Oct 6, 2023
4b94c45
Merge branch 'branch-23.12' into faiss-ivf
cjnolet Oct 11, 2023
ec11fd8
Merge branch 'branch-23.12' into faiss-ivf
cjnolet Oct 12, 2023
8a41330
Update codepacker and helpers
tarang-jain Oct 17, 2023
86f1aa4
Merge branch 'branch-23.12' of https://github.com/rapidsai/raft into …
tarang-jain Oct 17, 2023
0baee4a
Merge branch 'faiss-ivf' of https://github.com/tarang-jain/raft into …
tarang-jain Oct 17, 2023
9d66a8f
more helpers and debugging
tarang-jain Oct 26, 2023
3be7afd
Merge branch 'branch-23.12' of https://github.com/rapidsai/raft into …
tarang-jain Oct 26, 2023
fd01442
Update tests
tarang-jain Oct 26, 2023
1b4fd0e
action struct correction
tarang-jain Nov 2, 2023
7d760e9
Merge branch 'branch-23.12' of https://github.com/rapidsai/raft into …
tarang-jain Nov 2, 2023
aaff0bf
testing
tarang-jain Nov 3, 2023
c4bc220
Merge branch 'branch-23.12' of https://github.com/rapidsai/raft into …
tarang-jain Nov 3, 2023
6a5443a
remove unneeded funcs
tarang-jain Nov 3, 2023
bca8f40
Merge branch 'branch-23.12' into faiss-ivf
cjnolet Nov 7, 2023
8edc7a1
Add helper for extracting cluster centers
tarang-jain Nov 7, 2023
93eebab
Merge branch 'branch-23.12' of https://github.com/rapidsai/raft into …
tarang-jain Nov 7, 2023
140701e
Merge branch 'faiss-ivf' of https://github.com/tarang-jain/raft into …
tarang-jain Nov 7, 2023
0b88ca4
Update docs
tarang-jain Nov 9, 2023
d67fe8d
Merge branch 'branch-23.12' of https://github.com/rapidsai/raft into …
tarang-jain Nov 9, 2023
a68d7a7
Add test
tarang-jain Nov 9, 2023
41ac27f
correction
tarang-jain Nov 9, 2023
5073ea3
Update docs
tarang-jain Nov 16, 2023
889bbdd
Merge branch 'branch-23.12' of https://github.com/rapidsai/raft into …
tarang-jain Nov 16, 2023
3dbf3a7
more updates to docs
tarang-jain Nov 16, 2023
30bdee5
style
tarang-jain Nov 16, 2023
55fa0ef
more docs
tarang-jain Nov 16, 2023
8eb07f8
undo small docstring change
tarang-jain Nov 16, 2023
f8956d5
style
tarang-jain Nov 16, 2023
228e997
more doc updates
tarang-jain Nov 16, 2023
bdd75cf
small doc fix
tarang-jain Nov 16, 2023
6adcb98
resource docs
tarang-jain Nov 16, 2023
1893963
Update docs for ivf_flat::helpers::reset_index
tarang-jain Nov 16, 2023
91e17c2
Merge branch 'branch-23.12' of https://github.com/rapidsai/raft into …
tarang-jain Nov 16, 2023
a2d4575
update reset_index
tarang-jain Nov 16, 2023
1efd28f
change helpers name to contiguous
tarang-jain Nov 17, 2023
9841e6c
move get_list_size to index struct
tarang-jain Nov 17, 2023
3f8baaa
change test name
tarang-jain Nov 17, 2023
11a681f
raft enabled BM
tarang-jain Nov 29, 2023
3cd2d4a
Merge branch 'branch-24.02' of https://github.com/rapidsai/raft into …
tarang-jain Nov 29, 2023
633ad86
raft enabled IVF-Flat BM
tarang-jain Nov 29, 2023
09bcbd8
style
tarang-jain Nov 29, 2023
ab442b3
remove hardcoded pool size
tarang-jain Nov 29, 2023
a3acb5d
update faiss::gpu::benchmark main, revert pool MR in constructor
tarang-jain Dec 1, 2023
8bc00aa
Merge branch 'branch-24.02' of https://github.com/rapidsai/raft into …
tarang-jain Dec 1, 2023
e539fd2
Merge branch 'branch-24.02' of https://github.com/rapidsai/raft into …
tarang-jain Dec 2, 2023
2b089bb
Merge branch 'branch-24.02' of https://github.com/rapidsai/raft into …
tarang-jain Dec 4, 2023
87b3eb5
updated yaml
tarang-jain Dec 6, 2023
5057525
Merge branch 'branch-24.02' of https://github.com/rapidsai/raft into …
tarang-jain Dec 6, 2023
bdf7196
update config, faiss bm
tarang-jain Dec 12, 2023
a045f8e
Merge branch 'branch-24.02' of https://github.com/rapidsai/raft into …
tarang-jain Dec 12, 2023
72b7e00
debug
tarang-jain Dec 16, 2023
3bbf67a
merge changes
tarang-jain Dec 18, 2023
22b6754
raft refinement for faiss index
tarang-jain Dec 25, 2023
9d9a078
merge
tarang-jain Dec 25, 2023
1385cf8
dbg
tarang-jain Dec 26, 2023
651ea18
Merge branch 'faiss-ivf' of https://github.com/tarang-jain/raft into …
tarang-jain Dec 26, 2023
5847a09
Merge branch 'branch-24.02' into faiss-ivf
cjnolet Jan 10, 2024
77f9366
changes
tarang-jain Jan 11, 2024
31f444d
Merge branch 'faiss-ivf' of https://github.com/tarang-jain/raft into …
tarang-jain Jan 11, 2024
dfb2c2c
changes
tarang-jain Jan 16, 2024
9be5ecc
cleanup
tarang-jain Jan 17, 2024
02bdc23
Merge branch 'branch-24.02' of https://github.com/rapidsai/raft into …
tarang-jain Jan 17, 2024
27bf943
cleanup
tarang-jain Jan 17, 2024
9012267
Merge branch 'branch-24.02' of https://github.com/rapidsai/raft into …
tarang-jain Jan 18, 2024
0c714d5
updates,cleanup,style
tarang-jain Jan 18, 2024
df10536
updates,cleanup
tarang-jain Jan 18, 2024
395402c
updates,changes,style
tarang-jain Jan 18, 2024
4b8843d
Merge branch 'branch-24.02' into faiss-ivf
tarang-jain Feb 2, 2024
b1e7495
Merge branch 'branch-24.04' into faiss-ivf
tarang-jain Feb 2, 2024
95dcd10
Remove unnecessary copyright date changes
tfeher Feb 4, 2024
e25acf1
Merge branch 'branch-24.04' into faiss-ivf
tfeher Feb 4, 2024
8975a81
add 100M params,remove debug statements
tarang-jain Feb 6, 2024
8188767
merge
tarang-jain Feb 6, 2024
f697549
Merge branch 'faiss-ivf' of https://github.com/tarang-jain/raft into …
tarang-jain Feb 6, 2024
f0aa1db
small correction in 100M params
tarang-jain Feb 6, 2024
7a429e5
Merge branch 'branch-24.04' into faiss-ivf
cjnolet Feb 13, 2024
a54408f
Merge branch 'branch-24.02' of https://github.com/rapidsai/raft into …
tarang-jain Feb 14, 2024
757e07a
adding faiss cpu configs
tarang-jain Feb 14, 2024
65f096f
Merge branch 'branch-24.04' of https://github.com/rapidsai/raft into …
tarang-jain Feb 14, 2024
d472c06
Merge branch 'faiss-ivf' of https://github.com/tarang-jain/raft into …
tarang-jain Feb 14, 2024
dbb773d
merge
tarang-jain Feb 21, 2024
66baf65
update get_faiss.cmake
tarang-jain Feb 21, 2024
ccc8056
Merge branch 'branch-24.06' into faiss-ivf
tarang-jain Apr 3, 2024
2d223dd
Merge branch 'branch-24.06' into faiss-ivf
cjnolet Apr 10, 2024
092b9b9
Merge branch 'branch-24.06' into faiss-ivf
tarang-jain Apr 11, 2024
b9c64be
style
tarang-jain Apr 11, 2024
981a730
Merge branch 'branch-24.06' of https://github.com/rapidsai/raft into …
tarang-jain May 10, 2024
507ce25
undo copyright change
tarang-jain May 10, 2024
dc14d8b
remove debug statements
tarang-jain May 10, 2024
af37e68
match func signature
tarang-jain May 11, 2024
e5170a8
make build
tarang-jain May 11, 2024
9df0d73
add metric conversion func
tarang-jain May 12, 2024
bd1fe4c
remove metric parsing bugs
tarang-jain May 12, 2024
7295308
include utils header
tarang-jain May 12, 2024
f2f2e3b
Merge branch 'branch-24.06' of https://github.com/rapidsai/raft into …
tarang-jain May 13, 2024
1838102
Merge branch 'branch-24.06' into faiss-ivf
tarang-jain May 13, 2024
fe389db
bm configs for ivfflat
tarang-jain May 13, 2024
9c5cf50
update docs to keep track of FAISS issue
tarang-jain May 13, 2024
09d2422
rm name
tarang-jain May 13, 2024
d56089d
Merge branch 'faiss-ivf' of https://github.com/tarang-jain/raft into …
tarang-jain May 13, 2024
ba2cdd8
Merge branch 'branch-24.06' of https://github.com/rapidsai/raft into …
tarang-jain May 13, 2024
921eadd
Update python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/faiss_…
tarang-jain May 14, 2024
c91a94b
revert comment, final changes
tarang-jain May 14, 2024
29e08cb
merge
tarang-jain May 14, 2024
cca3927
merge 24.06
tarang-jain May 21, 2024
b0ce3ee
merge 24.08
tarang-jain Jun 7, 2024
fc5c2b3
add warning when throughput mode is enabled
tarang-jain Jun 7, 2024
0fa20a9
make compile
tarang-jain Jun 10, 2024
a2c1d7f
Merge branch 'branch-24.08' of https://github.com/rapidsai/raft into …
tarang-jain Jun 10, 2024
bd6d5b5
make compile
tarang-jain Jun 10, 2024
36c97e8
style
tarang-jain Jun 10, 2024
04a3342
corrections
tarang-jain Jun 13, 2024
60d5927
Merge branch 'branch-24.08' into faiss-ivf
tarang-jain Jun 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions cpp/bench/ann/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,6 @@ if(BUILD_CPU_ONLY)
set(RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE OFF)
set(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB OFF)
set(RAFT_ANN_BENCH_USE_GGNN OFF)
elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0.0)
# Disable faiss benchmarks on CUDA 12 since faiss is not yet CUDA 12-enabled.
# https://github.com/rapidsai/raft/issues/1627
set(RAFT_FAISS_ENABLE_GPU OFF)
endif()

set(RAFT_ANN_BENCH_USE_RAFT OFF)
Expand Down
10 changes: 8 additions & 2 deletions cpp/bench/ann/src/common/benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -459,8 +459,14 @@ void register_search(std::shared_ptr<const Dataset<T>> dataset,
*/
->MeasureProcessCPUTime()
->UseRealTime();

if (metric_objective == Objective::THROUGHPUT) { b->ThreadRange(threads[0], threads[1]); }
if (metric_objective == Objective::THROUGHPUT) {
if (index.algo.find("faiss_gpu") != std::string::npos) {
log_warn(
"FAISS GPU does not work in throughput mode because the underlying "
"StandardGpuResources object is not thread-safe. This will cause unexpected results");
}
b->ThreadRange(threads[0], threads[1]);
}
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions cpp/bench/ann/src/faiss/faiss_cpu_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ void parse_build_param(const nlohmann::json& conf,
{
parse_base_build_param<T>(conf, param);
param.M = conf.at("M");
if (conf.contains("usePrecomputed")) {
param.usePrecomputed = conf.at("usePrecomputed");
if (conf.contains("use_precomputed_table")) {
param.use_precomputed_table = conf.at("use_precomputed_table");
} else {
param.usePrecomputed = false;
param.use_precomputed_table = false;
}
if (conf.contains("bitsPerCode")) {
param.bitsPerCode = conf.at("bitsPerCode");
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ class FaissCpuIVFPQ : public FaissCpu<T> {
struct BuildParam : public FaissCpu<T>::BuildParam {
int M;
int bitsPerCode;
bool usePrecomputed;
bool use_precomputed_table;
};

FaissCpuIVFPQ(Metric metric, int dim, const BuildParam& param) : FaissCpu<T>(metric, dim, param)
Expand Down
30 changes: 29 additions & 1 deletion cpp/bench/ann/src/faiss/faiss_gpu_benchmark.cu
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ void parse_build_param(const nlohmann::json& conf,
typename raft::bench::ann::FaissGpuIVFFlat<T>::BuildParam& param)
{
parse_base_build_param<T>(conf, param);
if (conf.contains("use_raft")) {
param.use_raft = conf.at("use_raft");
} else {
param.use_raft = false;
}
}

template <typename T>
Expand All @@ -63,6 +68,16 @@ void parse_build_param(const nlohmann::json& conf,
} else {
param.useFloat16 = false;
}
if (conf.contains("use_raft")) {
param.use_raft = conf.at("use_raft");
} else {
param.use_raft = false;
}
if (conf.contains("bitsPerCode")) {
param.bitsPerCode = conf.at("bitsPerCode");
} else {
param.bitsPerCode = 8;
}
}

template <typename T>
Expand Down Expand Up @@ -160,5 +175,18 @@ REGISTER_ALGO_INSTANCE(std::uint8_t);

#ifdef ANN_BENCH_BUILD_MAIN
#include "../common/benchmark.hpp"
int main(int argc, char** argv) { return raft::bench::ann::run_main(argc, argv); }
int main(int argc, char** argv)
{
rmm::mr::cuda_memory_resource cuda_mr;
// Construct a resource that uses a coalescing best-fit pool allocator
// and is initially sized to half of free device memory.
rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource> pool_mr{
&cuda_mr, rmm::percent_of_free_device_memory(50)};
// Updates the current device resource pointer to `pool_mr`
auto old_mr = rmm::mr::set_current_device_resource(&pool_mr);
auto ret = raft::bench::ann::run_main(argc, argv);
// Restores the current device resource pointer to its previous value
rmm::mr::set_current_device_resource(old_mr);
return ret;
}
#endif
126 changes: 102 additions & 24 deletions cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,29 @@
#define FAISS_WRAPPER_H_

#include "../common/ann_types.hpp"
#include "../raft/raft_ann_bench_utils.h"

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/resource/stream_view.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/util/cudart_utils.hpp>

#include <raft_runtime/neighbors/refine.hpp>

#include <rmm/cuda_device.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>

#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFFlat.h>
#include <faiss/IndexIVFPQ.h>
#include <faiss/IndexRefine.h>
#include <faiss/IndexScalarQuantizer.h>
#include <faiss/MetricType.h>
#include <faiss/gpu/GpuIndexFlat.h>
#include <faiss/gpu/GpuIndexIVFFlat.h>
#include <faiss/gpu/GpuIndexIVFPQ.h>
Expand All @@ -43,7 +57,7 @@

namespace {

faiss::MetricType parse_metric_type(raft::bench::ann::Metric metric)
faiss::MetricType parse_metric_faiss(raft::bench::ann::Metric metric)
{
if (metric == raft::bench::ann::Metric::kInnerProduct) {
return faiss::METRIC_INNER_PRODUCT;
Expand Down Expand Up @@ -95,7 +109,7 @@ class FaissGpu : public ANN<T>, public AnnGPU {
FaissGpu(Metric metric, int dim, const BuildParam& param)
: ANN<T>(metric, dim),
gpu_resource_{std::make_shared<faiss::gpu::StandardGpuResources>()},
metric_type_(parse_metric_type(metric)),
metric_type_(parse_metric_faiss(metric)),
nlist_{param.nlist},
training_sample_fraction_{1.0 / double(param.ratio)}
{
Expand Down Expand Up @@ -127,7 +141,7 @@ class FaissGpu : public ANN<T>, public AnnGPU {
AlgoProperty property;
// to enable building big dataset which is larger than GPU memory
property.dataset_memory_type = MemoryType::Host;
property.query_memory_type = MemoryType::Host;
property.query_memory_type = MemoryType::Device;
return property;
}

Expand Down Expand Up @@ -162,8 +176,10 @@ class FaissGpu : public ANN<T>, public AnnGPU {
int device_;
double training_sample_fraction_;
std::shared_ptr<faiss::SearchParameters> search_params_;
std::shared_ptr<faiss::IndexRefineSearchParameters> refine_search_params_{nullptr};
const T* dataset_;
float refine_ratio_ = 1.0;
Objective metric_objective_;
};

template <typename T>
Expand Down Expand Up @@ -201,19 +217,65 @@ template <typename T>
void FaissGpu<T>::search(
const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const
{
ASSERT(Objective::LATENCY, "l2Knn: rowMajorIndex and rowMajorQuery should have same layout");
using IdxT = faiss::idx_t;
static_assert(sizeof(size_t) == sizeof(faiss::idx_t),
"sizes of size_t and faiss::idx_t are different");

if (this->refine_ratio_ > 1.0) {
// TODO: FAISS changed their search APIs to accept the search parameters as a struct object
// but their refine API doesn't allow the struct to be passed in. Once this is fixed, we
// need to re-enable refinement below
// index_refine_->search(batch_size, queries, k, distances,
// reinterpret_cast<faiss::idx_t*>(neighbors), this->search_params_.get()); Related FAISS issue:
// https://github.com/facebookresearch/faiss/issues/3118
throw std::runtime_error(
"FAISS doesn't support refinement in their new APIs so this feature is disabled in the "
"benchmarks for the time being.");
if (refine_ratio_ > 1.0) {
if (raft::get_device_for_address(queries) >= 0) {
uint32_t k0 = static_cast<uint32_t>(refine_ratio_ * k);
auto distances_tmp = raft::make_device_matrix<float, IdxT>(
gpu_resource_->getRaftHandle(device_), batch_size, k0);
auto candidates =
raft::make_device_matrix<IdxT, IdxT>(gpu_resource_->getRaftHandle(device_), batch_size, k0);
index_->search(batch_size,
queries,
k0,
distances_tmp.data_handle(),
candidates.data_handle(),
this->search_params_.get());

auto queries_host = raft::make_host_matrix<T, IdxT>(batch_size, index_->d);
auto candidates_host = raft::make_host_matrix<IdxT, IdxT>(batch_size, k0);
auto neighbors_host = raft::make_host_matrix<IdxT, IdxT>(batch_size, k);
auto distances_host = raft::make_host_matrix<float, IdxT>(batch_size, k);
auto dataset_v = raft::make_host_matrix_view<const T, faiss::idx_t>(
this->dataset_, index_->ntotal, index_->d);

raft::device_resources handle_ = gpu_resource_->getRaftHandle(device_);

raft::copy(queries_host.data_handle(), queries, queries_host.size(), handle_.get_stream());
raft::copy(candidates_host.data_handle(),
candidates.data_handle(),
candidates_host.size(),
handle_.get_stream());

// wait for the queries to copy to host in 'stream`
handle_.sync_stream();

raft::runtime::neighbors::refine(handle_,
dataset_v,
queries_host.view(),
candidates_host.view(),
neighbors_host.view(),
distances_host.view(),
parse_metric_type(this->metric_));

raft::copy(neighbors,
(size_t*)neighbors_host.data_handle(),
neighbors_host.size(),
handle_.get_stream());
raft::copy(
distances, distances_host.data_handle(), distances_host.size(), handle_.get_stream());
} else {
index_refine_->search(batch_size,
queries,
k,
distances,
reinterpret_cast<faiss::idx_t*>(neighbors),
this->refine_search_params_.get());
}
} else {
index_->search(batch_size,
queries,
Expand Down Expand Up @@ -255,13 +317,16 @@ void FaissGpu<T>::load_(const std::string& file)
template <typename T>
class FaissGpuIVFFlat : public FaissGpu<T> {
public:
using typename FaissGpu<T>::BuildParam;
struct BuildParam : public FaissGpu<T>::BuildParam {
bool use_raft;
};

FaissGpuIVFFlat(Metric metric, int dim, const BuildParam& param) : FaissGpu<T>(metric, dim, param)
{
faiss::gpu::GpuIndexIVFFlatConfig config;
config.device = this->device_;
this->index_ = std::make_shared<faiss::gpu::GpuIndexIVFFlat>(
config.device = this->device_;
config.use_raft = param.use_raft;
this->index_ = std::make_shared<faiss::gpu::GpuIndexIVFFlat>(
this->gpu_resource_.get(), dim, param.nlist, this->metric_type_, config);
}

Expand Down Expand Up @@ -295,23 +360,26 @@ class FaissGpuIVFPQ : public FaissGpu<T> {
int M;
bool useFloat16;
bool usePrecomputed;
bool use_raft;
int bitsPerCode;
};

FaissGpuIVFPQ(Metric metric, int dim, const BuildParam& param) : FaissGpu<T>(metric, dim, param)
{
faiss::gpu::GpuIndexIVFPQConfig config;
config.useFloat16LookupTables = param.useFloat16;
config.usePrecomputedTables = param.usePrecomputed;
config.use_raft = param.use_raft;
config.interleavedLayout = param.use_raft;
config.device = this->device_;

this->index_ =
std::make_shared<faiss::gpu::GpuIndexIVFPQ>(this->gpu_resource_.get(),
dim,
param.nlist,
param.M,
8, // FAISS only supports bitsPerCode=8
this->metric_type_,
config);
this->index_ = std::make_shared<faiss::gpu::GpuIndexIVFPQ>(this->gpu_resource_.get(),
dim,
param.nlist,
param.M,
param.bitsPerCode,
this->metric_type_,
config);
}

void set_search_param(const typename FaissGpu<T>::AnnSearchParam& param) override
Expand All @@ -329,6 +397,11 @@ class FaissGpuIVFPQ : public FaissGpu<T> {
this->index_refine_ =
std::make_shared<faiss::IndexRefineFlat>(this->index_.get(), this->dataset_);
this->index_refine_.get()->k_factor = search_param.refine_ratio;
faiss::IndexRefineSearchParameters faiss_refine_search_params;
faiss_refine_search_params.k_factor = this->index_refine_.get()->k_factor;
faiss_refine_search_params.base_index_params = this->search_params_.get();
this->refine_search_params_ =
std::make_unique<faiss::IndexRefineSearchParameters>(faiss_refine_search_params);
}
}

Expand Down Expand Up @@ -385,6 +458,11 @@ class FaissGpuIVFSQ : public FaissGpu<T> {
this->index_refine_ =
std::make_shared<faiss::IndexRefineFlat>(this->index_.get(), this->dataset_);
this->index_refine_.get()->k_factor = search_param.refine_ratio;
faiss::IndexRefineSearchParameters faiss_refine_search_params;
faiss_refine_search_params.k_factor = this->index_refine_.get()->k_factor;
faiss_refine_search_params.base_index_params = this->search_params_.get();
this->refine_search_params_ =
std::make_unique<faiss::IndexRefineSearchParameters>(faiss_refine_search_params);
}
}

Expand Down
3 changes: 2 additions & 1 deletion cpp/cmake/thirdparty/get_faiss.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ function(find_and_configure_faiss)
EXCLUDE_FROM_ALL ${exclude}
OPTIONS
"FAISS_ENABLE_GPU ${PKG_ENABLE_GPU}"
"FAISS_ENABLE_RAFT ${PKG_ENABLE_GPU}"
"FAISS_ENABLE_PYTHON OFF"
"FAISS_OPT_LEVEL ${RAFT_FAISS_OPT_LEVEL}"
"FAISS_USE_CUDA_TOOLKIT_STATIC ${CUDA_STATIC_RUNTIME}"
Expand Down Expand Up @@ -115,4 +116,4 @@ endfunction()
find_and_configure_faiss(
BUILD_STATIC_LIBS ${RAFT_USE_FAISS_STATIC}
ENABLE_GPU ${RAFT_FAISS_ENABLE_GPU}
)
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed several files in this PR that do not end with a newline.

In practice I've found that this rarely matters, but it can occasionally cause issues for tools doing line-by-line parsing of files. In a future PR, you may want to consider automatically fixing these with pre-commit, like cudf does: https://github.com/rapidsai/cudf/blob/107753ccaacdb62287c4dd4351e5caf3bf8bc62a/.pre-commit-config.yaml#L13

9 changes: 5 additions & 4 deletions cpp/include/raft_runtime/neighbors/refine.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,8 +17,9 @@
#pragma once

#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resources.hpp>
// #include <raft/core/host_mdspan.hpp>
#include <raft/distance/distance_types.hpp>

namespace raft::runtime::neighbors {

Expand All @@ -29,15 +30,15 @@ namespace raft::runtime::neighbors {
raft::device_matrix_view<const IDX_T, int64_t, row_major> neighbor_candidates, \
raft::device_matrix_view<IDX_T, int64_t, row_major> indices, \
raft::device_matrix_view<float, int64_t, row_major> distances, \
distance::DistanceType metric); \
raft::distance::DistanceType metric); \
\
void refine(raft::resources const& handle, \
raft::host_matrix_view<const DATA_T, int64_t, row_major> dataset, \
raft::host_matrix_view<const DATA_T, int64_t, row_major> queries, \
raft::host_matrix_view<const IDX_T, int64_t, row_major> neighbor_candidates, \
raft::host_matrix_view<IDX_T, int64_t, row_major> indices, \
raft::host_matrix_view<float, int64_t, row_major> distances, \
distance::DistanceType metric);
raft::distance::DistanceType metric);

RAFT_INST_REFINE(int64_t, float);
RAFT_INST_REFINE(int64_t, uint8_t);
Expand Down
9 changes: 6 additions & 3 deletions docs/source/ann_benchmarks_build.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,12 @@ You can limit the algorithms that are built by providing a semicolon-delimited l
```

Available targets to use with `--limit-bench-ann` are:
- FAISS_IVF_FLAT_ANN_BENCH
- FAISS_IVF_PQ_ANN_BENCH
- FAISS_BFKNN_ANN_BENCH
- FAISS_GPU_IVF_FLAT_ANN_BENCH
- FAISS_GPU_IVF_PQ_ANN_BENCH
- FAISS_CPU_IVF_FLAT_ANN_BENCH
- FAISS_CPU_IVF_PQ_ANN_BENCH
- FAISS_GPU_FLAT_ANN_BENCH
- FAISS_CPU_FLAT_ANN_BENCH
- GGNN_ANN_BENCH
- HNSWLIB_ANN_BENCH
- RAFT_CAGRA_ANN_BENCH
Expand Down
Loading
Loading