Skip to content

Commit

Permalink
Introduces NativeEngineKNNQuery which executes ANN on rewrite (opense…
Browse files Browse the repository at this point in the history
…arch-project#1877)

Signed-off-by: Tejas Shah <shatejas@amazon.com>
(cherry picked from commit df7627c)
  • Loading branch information
shatejas committed Aug 8, 2024
1 parent 4b5e210 commit 368ed00
Show file tree
Hide file tree
Showing 15 changed files with 950 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Refactor method structure and definitions [#1920](https://github.com/opensearch-project/k-NN/pull/1920)
* Refactor KNNVectorFieldType from KNNVectorFieldMapper to a separate class for better readability. [#1931](https://github.com/opensearch-project/k-NN/pull/1931)
* Generalize lib interface to return context objects [#1925](https://github.com/opensearch-project/k-NN/pull/1925)
* Move k search k-NN query to re-write phase of vector search query for Native Engines [#1877](https://github.com/opensearch-project/k-NN/pull/1877)
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
*/

package org.opensearch.knn.common.featureflags;

import com.google.common.annotations.VisibleForTesting;
import lombok.experimental.UtilityClass;
import org.opensearch.common.settings.Setting;
import org.opensearch.knn.index.KNNSettings;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.opensearch.common.settings.Setting.Property.Dynamic;
import static org.opensearch.common.settings.Setting.Property.NodeScope;

/**
* Class to manage KNN feature flags
*/
@UtilityClass
public class KNNFeatureFlags {

// Feature flags
private static final String KNN_LAUNCH_QUERY_REWRITE_ENABLED = "knn.feature.query.rewrite.enabled";
private static final boolean KNN_LAUNCH_QUERY_REWRITE_ENABLED_DEFAULT = true;

@VisibleForTesting
public static final Setting<Boolean> KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING = Setting.boolSetting(
KNN_LAUNCH_QUERY_REWRITE_ENABLED,
KNN_LAUNCH_QUERY_REWRITE_ENABLED_DEFAULT,
NodeScope,
Dynamic
);

public static List<Setting<?>> getFeatureFlags() {
return Stream.of(KNN_LAUNCH_QUERY_REWRITE_ENABLED_SETTING).collect(Collectors.toUnmodifiableList());
}

public static boolean isKnnQueryRewriteEnabled() {
return Boolean.parseBoolean(KNNSettings.state().getSettingValue(KNN_LAUNCH_QUERY_REWRITE_ENABLED).toString());
}
}
20 changes: 15 additions & 5 deletions src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchParseException;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.core.action.ActionListener;
import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsRequest;
import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.unit.ByteSizeUnit;
import org.opensearch.core.common.unit.ByteSizeValue;
import org.opensearch.index.IndexModule;
Expand All @@ -28,20 +28,22 @@
import org.opensearch.monitor.os.OsProbe;

import java.security.InvalidParameterException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static java.util.stream.Collectors.toUnmodifiableMap;
import static org.opensearch.common.settings.Setting.Property.Dynamic;
import static org.opensearch.common.settings.Setting.Property.IndexScope;
import static org.opensearch.common.settings.Setting.Property.NodeScope;
import static org.opensearch.common.unit.MemorySizeValue.parseBytesSizeValueOrHeapRatio;
import static org.opensearch.core.common.unit.ByteSizeValue.parseBytesSizeValue;
import static org.opensearch.knn.common.featureflags.KNNFeatureFlags.getFeatureFlags;

/**
* This class defines
Expand Down Expand Up @@ -289,6 +291,9 @@ public class KNNSettings {
}
};

private final static Map<String, Setting<?>> FEATURE_FLAGS = getFeatureFlags().stream()
.collect(toUnmodifiableMap(Setting::getKey, Function.identity()));

private ClusterService clusterService;
private Client client;

Expand Down Expand Up @@ -326,7 +331,7 @@ private void setSettingsUpdateConsumers() {
);

NativeMemoryCacheManager.getInstance().rebuildCache(builder.build());
}, new ArrayList<>(dynamicCacheSettings.values()));
}, Stream.concat(dynamicCacheSettings.values().stream(), FEATURE_FLAGS.values().stream()).collect(Collectors.toUnmodifiableList()));
}

/**
Expand All @@ -346,6 +351,10 @@ private Setting<?> getSetting(String key) {
return dynamicCacheSettings.get(key);
}

if (FEATURE_FLAGS.containsKey(key)) {
return FEATURE_FLAGS.get(key);
}

if (KNN_CIRCUIT_BREAKER_TRIGGERED.equals(key)) {
return KNN_CIRCUIT_BREAKER_TRIGGERED_SETTING;
}
Expand Down Expand Up @@ -390,7 +399,8 @@ public List<Setting<?>> getSettings() {
KNN_FAISS_AVX2_DISABLED_SETTING,
KNN_VECTOR_STREAMING_MEMORY_LIMIT_PCT_SETTING
);
return Stream.concat(settings.stream(), dynamicCacheSettings.values().stream()).collect(Collectors.toList());
return Stream.concat(settings.stream(), Stream.concat(getFeatureFlags().stream(), dynamicCacheSettings.values().stream()))
.collect(Collectors.toList());
}

public static boolean isKNNPluginEnabled() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery;

import java.util.Locale;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.common.featureflags.KNNFeatureFlags.isKnnQueryRewriteEnabled;
import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES;

/**
Expand Down Expand Up @@ -98,9 +100,10 @@ public static Query create(CreateQueryRequest createQueryRequest) {
methodParameters
);

KNNQuery knnQuery = null;
switch (vectorDataType) {
case BINARY:
return KNNQuery.builder()
knnQuery = KNNQuery.builder()
.field(fieldName)
.byteQueryVector(byteVector)
.indexName(indexName)
Expand All @@ -110,8 +113,9 @@ public static Query create(CreateQueryRequest createQueryRequest) {
.filterQuery(validatedFilterQuery)
.vectorDataType(vectorDataType)
.build();
break;
default:
return KNNQuery.builder()
knnQuery = KNNQuery.builder()
.field(fieldName)
.queryVector(vector)
.indexName(indexName)
Expand All @@ -122,6 +126,7 @@ public static Query create(CreateQueryRequest createQueryRequest) {
.vectorDataType(vectorDataType)
.build();
}
return isKnnQueryRewriteEnabled() ? new NativeEngineKnnVectorQuery(knnQuery) : knnQuery;
}

Integer requestEfSearch = null;
Expand Down
7 changes: 7 additions & 0 deletions src/main/java/org/opensearch/knn/index/query/KNNScorer.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ public float score() throws IOException {
public int docID() {
return docIdsIter.docID();
}

@Override
public boolean equals(Object obj) {
if (!(obj instanceof Scorer)) return false;
return getWeight().equals(((Scorer) obj).getWeight());
}
};

}
}
24 changes: 20 additions & 4 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,30 @@ public Explanation explain(LeafReaderContext context, int doc) {

@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
final Map<Integer, Float> docIdToScoreMap = searchLeaf(context);
if (docIdToScoreMap.isEmpty()) {
return KNNScorer.emptyScorer(this);
}

return convertSearchResponseToScorer(docIdToScoreMap);
}

/**
* Executes k nearest neighbor search for a segment to get the top K results
* This is made public purely to be able to be reused in {@link org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery}
*
* @param context LeafReaderContext
* @return A Map of docId to scores for top k results
*/
public Map<Integer, Float> searchLeaf(LeafReaderContext context) throws IOException {

final BitSet filterBitSet = getFilteredDocsBitSet(context);
int cardinality = filterBitSet.cardinality();
// We don't need to go to JNI layer if no documents are found which satisfy the filters
// We should give this condition a deeper look that where it should be placed. For now I feel this is a good
// place,
if (filterWeight != null && cardinality == 0) {
return KNNScorer.emptyScorer(this);
return Collections.emptyMap();
}
final Map<Integer, Float> docIdsToScoreMap = new HashMap<>();

Expand All @@ -129,7 +145,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
} else {
Map<Integer, Float> annResults = doANNSearch(context, filterBitSet, cardinality);
if (annResults == null) {
return null;
return Collections.emptyMap();
}
if (canDoExactSearchAfterANNSearch(cardinality, annResults.size())) {
log.debug(
Expand All @@ -144,9 +160,9 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
docIdsToScoreMap.putAll(annResults);
}
if (docIdsToScoreMap.isEmpty()) {
return KNNScorer.emptyScorer(this);
return Collections.emptyMap();
}
return convertSearchResponseToScorer(docIdsToScoreMap);
return docIdsToScoreMap;
}

private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException {
Expand Down
Loading

0 comments on commit 368ed00

Please sign in to comment.