diff --git a/src/main/java/org/springframework/data/elasticsearch/core/AbstractElasticsearchTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/AbstractElasticsearchTemplate.java index 553a68781c..f17774a1ce 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/AbstractElasticsearchTemplate.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/AbstractElasticsearchTemplate.java @@ -208,20 +208,25 @@ public long count(Query query, Class clazz) { @Override public CloseableIterator stream(Query query, Class clazz, IndexCoordinates index) { + long scrollTimeInMillis = TimeValue.timeValueMinutes(1).millis(); return (CloseableIterator) SearchHitSupport.unwrapSearchHits(searchForStream(query, clazz, index)); } @Override - public CloseableIterator> searchForStream(Query query, Class clazz) { + public SearchHitsIterator searchForStream(Query query, Class clazz) { return searchForStream(query, clazz, getIndexCoordinatesFor(clazz)); } @Override - public CloseableIterator> searchForStream(Query query, Class clazz, IndexCoordinates index) { + public SearchHitsIterator searchForStream(Query query, Class clazz, IndexCoordinates index) { + long scrollTimeInMillis = TimeValue.timeValueMinutes(1).millis(); - return StreamQueries.streamResults(searchScrollStart(scrollTimeInMillis, query, clazz, index), - scrollId -> searchScrollContinue(scrollId, scrollTimeInMillis, clazz), this::searchScrollClear); + + return StreamQueries.streamResults( // + searchScrollStart(scrollTimeInMillis, query, clazz, index), // + scrollId -> searchScrollContinue(scrollId, scrollTimeInMillis, clazz), // + this::searchScrollClear); } @Override @@ -283,13 +288,13 @@ public SearchHits search(Query query, Class clazz) { /* * internal use only, not for public API */ - abstract protected ScrolledPage> searchScrollStart(long scrollTimeInMillis, Query query, + abstract protected SearchScrollHits searchScrollStart(long scrollTimeInMillis, Query query, Class clazz, IndexCoordinates index); /* * internal use only, not for public API */ - abstract protected ScrolledPage> searchScrollContinue(@Nullable String scrollId, + abstract protected SearchScrollHits searchScrollContinue(@Nullable String scrollId, long scrollTimeInMillis, Class clazz); /* diff --git a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java index 147d66b7e3..76ecd7c1d0 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchRestTemplate.java @@ -39,7 +39,6 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.search.suggest.SuggestBuilder; -import org.springframework.data.domain.Pageable; import org.springframework.data.elasticsearch.core.convert.ElasticsearchConverter; import org.springframework.data.elasticsearch.core.document.DocumentAdapters; import org.springframework.data.elasticsearch.core.document.SearchDocumentResponse; @@ -257,24 +256,28 @@ public SearchHits search(Query query, Class clazz, IndexCoordinates in } @Override - public ScrolledPage> searchScrollStart(long scrollTimeInMillis, Query query, Class clazz, + public SearchScrollHits searchScrollStart(long scrollTimeInMillis, Query query, Class clazz, IndexCoordinates index) { - Assert.notNull(query.getPageable(), "Query.pageable is required for scan & scroll"); + Assert.notNull(query.getPageable(), "pageable of query must not be null."); SearchRequest searchRequest = requestFactory.searchRequest(query, clazz, index); searchRequest.scroll(TimeValue.timeValueMillis(scrollTimeInMillis)); - SearchResponse result = execute(client -> client.search(searchRequest, RequestOptions.DEFAULT)); - return elasticsearchConverter.mapResults(SearchDocumentResponse.from(result), clazz, null); + + SearchResponse response = execute(client -> client.search(searchRequest, RequestOptions.DEFAULT)); + + return elasticsearchConverter.readScroll(clazz, SearchDocumentResponse.from(response)); } @Override - public ScrolledPage> searchScrollContinue(@Nullable String scrollId, long scrollTimeInMillis, - Class clazz) { + public SearchScrollHits searchScrollContinue(@Nullable String scrollId, long scrollTimeInMillis, Class clazz) { + SearchScrollRequest request = new SearchScrollRequest(scrollId); request.scroll(TimeValue.timeValueMillis(scrollTimeInMillis)); - SearchResponse response = execute(client -> client.searchScroll(request, RequestOptions.DEFAULT)); - return elasticsearchConverter.mapResults(SearchDocumentResponse.from(response), clazz, Pageable.unpaged()); + + SearchResponse response = execute(client -> client.scroll(request, RequestOptions.DEFAULT)); + + return elasticsearchConverter.readScroll(clazz, SearchDocumentResponse.from(response)); } @Override diff --git a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java index d8ef076503..22394d3ec3 100755 --- a/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplate.java @@ -27,6 +27,7 @@ import org.elasticsearch.action.search.MultiSearchResponse; import org.elasticsearch.action.search.SearchRequestBuilder; import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.search.SearchScrollRequestBuilder; import org.elasticsearch.action.update.UpdateRequestBuilder; import org.elasticsearch.client.Client; import org.elasticsearch.common.unit.TimeValue; @@ -260,22 +261,32 @@ public SearchHits search(Query query, Class clazz, IndexCoordinates in } @Override - public ScrolledPage> searchScrollStart(long scrollTimeInMillis, Query query, Class clazz, + public SearchScrollHits searchScrollStart(long scrollTimeInMillis, Query query, Class clazz, IndexCoordinates index) { - Assert.notNull(query.getPageable(), "Query.pageable is required for scan & scroll"); - SearchRequestBuilder searchRequestBuilder = requestFactory.searchRequestBuilder(client, query, clazz, index); - searchRequestBuilder.setScroll(TimeValue.timeValueMillis(scrollTimeInMillis)); - SearchResponse response = getSearchResponse(searchRequestBuilder); - return elasticsearchConverter.mapResults(SearchDocumentResponse.from(response), clazz, null); + Assert.notNull(query.getPageable(), "pageable of query must not be null."); + + ActionFuture action = requestFactory // + .searchRequestBuilder(client, query, clazz, index) // + .setScroll(TimeValue.timeValueMillis(scrollTimeInMillis)) // + .execute(); + + SearchResponse response = getSearchResponseWithTimeout(action); + + return elasticsearchConverter.readScroll(clazz, SearchDocumentResponse.from(response)); } @Override - public ScrolledPage> searchScrollContinue(@Nullable String scrollId, long scrollTimeInMillis, - Class clazz) { - SearchResponse response = getSearchResponseWithTimeout( - client.prepareSearchScroll(scrollId).setScroll(TimeValue.timeValueMillis(scrollTimeInMillis)).execute()); - return elasticsearchConverter.mapResults(SearchDocumentResponse.from(response), clazz, Pageable.unpaged()); + public SearchScrollHits searchScrollContinue(@Nullable String scrollId, long scrollTimeInMillis, Class clazz) { + + ActionFuture action = client // + .prepareSearchScroll(scrollId) // + .setScroll(TimeValue.timeValueMillis(scrollTimeInMillis)) // + .execute(); + + SearchResponse response = getSearchResponseWithTimeout(action); + + return elasticsearchConverter.readScroll(clazz, SearchDocumentResponse.from(response)); } @Override diff --git a/src/main/java/org/springframework/data/elasticsearch/core/ScrolledPage.java b/src/main/java/org/springframework/data/elasticsearch/core/ScrolledPage.java index 0e0d23158a..92ac39ce45 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/ScrolledPage.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/ScrolledPage.java @@ -2,13 +2,14 @@ package org.springframework.data.elasticsearch.core; import org.springframework.data.domain.Page; -import org.springframework.lang.Nullable; /** * @author Artur Konczak * @author Peter-Josef Meisch * @author Sascha Woo + * @deprecated will be removed in a future version. */ +@Deprecated public interface ScrolledPage extends Page { String getScrollId(); diff --git a/src/main/java/org/springframework/data/elasticsearch/core/SearchHitSupport.java b/src/main/java/org/springframework/data/elasticsearch/core/SearchHitSupport.java index 4947e62d1b..168930b3a6 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/SearchHitSupport.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/SearchHitSupport.java @@ -32,6 +32,7 @@ * Utility class with helper methods for working with {@link SearchHit}. * * @author Peter-Josef Meisch + * @author Sascha Woo * @since 4.0 */ public final class SearchHitSupport { @@ -97,8 +98,13 @@ public static Object unwrapSearchHits(Object result) { * @return the created Page */ public static AggregatedPage> page(SearchHits searchHits, Pageable pageable) { - return new AggregatedPageImpl<>(searchHits.getSearchHits(), pageable, searchHits.getTotalHits(), - searchHits.getAggregations(), searchHits.getScrollId(), searchHits.getMaxScore()); + return new AggregatedPageImpl<>( // + searchHits.getSearchHits(), // + pageable, // + searchHits.getTotalHits(), // + searchHits.getAggregations(), // + null, // + searchHits.getMaxScore()); } public static SearchPage searchPageFor(SearchHits searchHits, @Nullable Pageable pageable) { diff --git a/src/main/java/org/springframework/data/elasticsearch/core/SearchHits.java b/src/main/java/org/springframework/data/elasticsearch/core/SearchHits.java index 57c18c9aea..b0a9cb4b1a 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/SearchHits.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/SearchHits.java @@ -1,5 +1,5 @@ /* - * Copyright 2019-2020 the original author or authors. + * Copyright 2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,143 +15,74 @@ */ package org.springframework.data.elasticsearch.core; -import java.util.Collections; import java.util.Iterator; import java.util.List; import org.elasticsearch.search.aggregations.Aggregations; import org.springframework.data.util.Streamable; import org.springframework.lang.Nullable; -import org.springframework.util.Assert; -import org.springframework.util.StringUtils; /** * Encapsulates a list of {@link SearchHit}s with additional information from the search. - * + * * @param the result data class. - * @author Peter-Josef Meisch + * @author Sascha Woo * @since 4.0 */ -public class SearchHits implements Streamable> { - - private final long totalHits; - private final TotalHitsRelation totalHitsRelation; - private final float maxScore; - private final String scrollId; - private final List> searchHits; - private final Aggregations aggregations; - - /** - * @param totalHits the number of total hits for the search - * @param totalHitsRelation the relation {@see TotalHitsRelation}, must not be {@literal null} - * @param maxScore the maximum score - * @param scrollId the scroll id if available - * @param searchHits must not be {@literal null} - * @param aggregations the aggregations if available - */ - public SearchHits(long totalHits, TotalHitsRelation totalHitsRelation, float maxScore, @Nullable String scrollId, - List> searchHits, @Nullable Aggregations aggregations) { - - Assert.notNull(searchHits, "searchHits must not be null"); - - this.totalHits = totalHits; - this.totalHitsRelation = totalHitsRelation; - this.maxScore = maxScore; - this.scrollId = scrollId; - this.searchHits = searchHits; - this.aggregations = aggregations; - } - - @SuppressWarnings("unchecked") - @Override - public Iterator> iterator() { - return (Iterator>) searchHits.iterator(); - } +public interface SearchHits extends Streamable> { - // region getter /** - * @return the number of total hits. - */ - public long getTotalHits() { - return totalHits; - } - - /** - * @return the relation for the total hits + * @return the aggregations. */ - public TotalHitsRelation getTotalHitsRelation() { - return totalHitsRelation; - } + @Nullable + Aggregations getAggregations(); /** * @return the maximum score */ - public float getMaxScore() { - return maxScore; - } + float getMaxScore(); /** - * @return the scroll id + * @param index position in List. + * @return the {@link SearchHit} at position {index} + * @throws IndexOutOfBoundsException on invalid index */ - @Nullable - public String getScrollId() { - return scrollId; - } + SearchHit getSearchHit(int index); /** * @return the contained {@link SearchHit}s. */ - public List> getSearchHits() { - return Collections.unmodifiableList(searchHits); - } - // endregion + List> getSearchHits(); - // region SearchHit access /** - * @param index position in List. - * @return the {@link SearchHit} at position {index} - * @throws IndexOutOfBoundsException on invalid index + * @return the number of total hits. */ - public SearchHit getSearchHit(int index) { - return searchHits.get(index); - } - // endregion + long getTotalHits(); - @Override - public String toString() { - return "SearchHits{" + // - "totalHits=" + totalHits + // - ", totalHitsRelation=" + totalHitsRelation + // - ", maxScore=" + maxScore + // - ", scrollId='" + scrollId + '\'' + // - ", searchHits={" + searchHits.size() + " elements}" + // - ", aggregations=" + aggregations + // - '}'; - } + /** + * @return the relation for the total hits + */ + TotalHitsRelation getTotalHitsRelation(); - // region aggregations /** * @return true if aggregations are available */ - public boolean hasAggregations() { - return aggregations != null; + default boolean hasAggregations() { + return getAggregations() != null; } /** - * @return the aggregations. + * @return whether the {@link SearchHits} has search hits. */ - @Nullable - public Aggregations getAggregations() { - return aggregations; + default boolean hasSearchHits() { + return !getSearchHits().isEmpty(); } - // endregion /** - * Enum to represent the relation that Elasticsearch returns for the totalHits value {@see Ekasticsearch - * docs} + * @return an iterator for {@link SearchHit} */ - public enum TotalHitsRelation { - EQUAL_TO, GREATER_THAN_OR_EQUAL_TO + default Iterator> iterator() { + return getSearchHits().iterator(); } + } diff --git a/src/main/java/org/springframework/data/elasticsearch/core/SearchHitsImpl.java b/src/main/java/org/springframework/data/elasticsearch/core/SearchHitsImpl.java new file mode 100644 index 0000000000..e8bf454523 --- /dev/null +++ b/src/main/java/org/springframework/data/elasticsearch/core/SearchHitsImpl.java @@ -0,0 +1,117 @@ +/* + * Copyright 2019-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.elasticsearch.core; + +import java.util.Collections; +import java.util.List; + +import org.elasticsearch.search.aggregations.Aggregations; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Basic implementation of {@link SearchScrollHits} + * + * @param the result data class. + * @author Peter-Josef Meisch + * @author Sascha Woo + * @since 4.0 + */ +public class SearchHitsImpl implements SearchScrollHits { + + private final long totalHits; + private final TotalHitsRelation totalHitsRelation; + private final float maxScore; + private final String scrollId; + private final List> searchHits; + private final Aggregations aggregations; + + /** + * @param totalHits the number of total hits for the search + * @param totalHitsRelation the relation {@see TotalHitsRelation}, must not be {@literal null} + * @param maxScore the maximum score + * @param scrollId the scroll id if available + * @param searchHits must not be {@literal null} + * @param aggregations the aggregations if available + */ + public SearchHitsImpl(long totalHits, TotalHitsRelation totalHitsRelation, float maxScore, @Nullable String scrollId, + List> searchHits, @Nullable Aggregations aggregations) { + + Assert.notNull(searchHits, "searchHits must not be null"); + + this.totalHits = totalHits; + this.totalHitsRelation = totalHitsRelation; + this.maxScore = maxScore; + this.scrollId = scrollId; + this.searchHits = searchHits; + this.aggregations = aggregations; + } + + // region getter + @Override + public long getTotalHits() { + return totalHits; + } + + @Override + public TotalHitsRelation getTotalHitsRelation() { + return totalHitsRelation; + } + + @Override + public float getMaxScore() { + return maxScore; + } + + @Override + @Nullable + public String getScrollId() { + return scrollId; + } + + @Override + public List> getSearchHits() { + return Collections.unmodifiableList(searchHits); + } + // endregion + + // region SearchHit access + @Override + public SearchHit getSearchHit(int index) { + return searchHits.get(index); + } + // endregion + + @Override + public String toString() { + return "SearchHits{" + // + "totalHits=" + totalHits + // + ", totalHitsRelation=" + totalHitsRelation + // + ", maxScore=" + maxScore + // + ", scrollId='" + scrollId + '\'' + // + ", searchHits={" + searchHits.size() + " elements}" + // + ", aggregations=" + aggregations + // + '}'; + } + + // region aggregations + @Override + @Nullable + public Aggregations getAggregations() { + return aggregations; + } + // endregion +} diff --git a/src/main/java/org/springframework/data/elasticsearch/core/SearchHitsIterator.java b/src/main/java/org/springframework/data/elasticsearch/core/SearchHitsIterator.java new file mode 100644 index 0000000000..a5045f4be5 --- /dev/null +++ b/src/main/java/org/springframework/data/elasticsearch/core/SearchHitsIterator.java @@ -0,0 +1,60 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.elasticsearch.core; + +import org.elasticsearch.search.aggregations.Aggregations; +import org.springframework.data.util.CloseableIterator; +import org.springframework.lang.Nullable; + +/** + * A {@link SearchHitsIterator} encapsulates {@link SearchHit} results that can be wrapped in a Java 8 + * {@link java.util.stream.Stream}. + * + * @author Sascha Woo + * @param + * @since 4.0 + */ +public interface SearchHitsIterator extends CloseableIterator> { + + /** + * @return the aggregations. + */ + @Nullable + Aggregations getAggregations(); + + /** + * @return the maximum score + */ + float getMaxScore(); + + /** + * @return the number of total hits. + */ + long getTotalHits(); + + /** + * @return the relation for the total hits + */ + TotalHitsRelation getTotalHitsRelation(); + + /** + * @return true if aggregations are available + */ + default boolean hasAggregations() { + return getAggregations() != null; + } + +} diff --git a/src/main/java/org/springframework/data/elasticsearch/core/SearchOperations.java b/src/main/java/org/springframework/data/elasticsearch/core/SearchOperations.java index 9346d7e595..4f229be168 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/SearchOperations.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/SearchOperations.java @@ -36,6 +36,7 @@ * APIs. * * @author Peter-Josef Meisch + * @author Sascha Woo * @since 4.0 */ public interface SearchOperations { @@ -155,8 +156,9 @@ default List> queryForPage(List queries, List * @param query the query to execute * @param clazz the entity clazz used for property mapping * @param index the index to run the query against - * @return a {@link CloseableIterator} that wraps an Elasticsearch scroll context that needs to be closed in case of * - * error. + * @return a {@link CloseableIterator} that wraps an Elasticsearch scroll context that needs to be closed. The + * try-with-resources construct should be used to ensure that the close method is invoked after the operations + * are completed. * @deprecated since 4.0, use {@link #searchForStream(Query, Class, IndexCoordinates)}. */ @Deprecated @@ -237,7 +239,6 @@ default AggregatedPage moreLikeThis(MoreLikeThisQuery query, Class cla return (AggregatedPage) SearchHitSupport.unwrapSearchHits(aggregatedPage); } - // endregion /** @@ -340,27 +341,29 @@ default SearchHit searchOne(Query query, Class clazz, IndexCoordinates SearchHits search(MoreLikeThisQuery query, Class clazz, IndexCoordinates index); /** - * Executes the given {@link Query} against elasticsearch and return result as {@link CloseableIterator}. + * Executes the given {@link Query} against elasticsearch and return result as {@link SearchHitsIterator}. *

* * @param element return type * @param query the query to execute * @param clazz the entity clazz used for property mapping and index name extraction - * @return a {@link CloseableIterator} that wraps an Elasticsearch scroll context that needs to be closed in case of * - * error. + * @return a {@link SearchHitsIterator} that wraps an Elasticsearch scroll context that needs to be closed. The + * try-with-resources construct should be used to ensure that the close method is invoked after the operations + * are completed. */ - CloseableIterator> searchForStream(Query query, Class clazz); + SearchHitsIterator searchForStream(Query query, Class clazz); /** - * Executes the given {@link Query} against elasticsearch and return result as {@link CloseableIterator}. + * Executes the given {@link Query} against elasticsearch and return result as {@link SearchHitsIterator}. *

* * @param element return type * @param query the query to execute * @param clazz the entity clazz used for property mapping * @param index the index to run the query against - * @return a {@link CloseableIterator} that wraps an Elasticsearch scroll context that needs to be closed in case of * - * error. + * @return a {@link SearchHitsIterator} that wraps an Elasticsearch scroll context that needs to be closed. The + * try-with-resources construct should be used to ensure that the close method is invoked after the operations + * are completed. */ - CloseableIterator> searchForStream(Query query, Class clazz, IndexCoordinates index); + SearchHitsIterator searchForStream(Query query, Class clazz, IndexCoordinates index); } diff --git a/src/main/java/org/springframework/data/elasticsearch/core/SearchScrollHits.java b/src/main/java/org/springframework/data/elasticsearch/core/SearchScrollHits.java new file mode 100644 index 0000000000..cb7d076948 --- /dev/null +++ b/src/main/java/org/springframework/data/elasticsearch/core/SearchScrollHits.java @@ -0,0 +1,34 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.elasticsearch.core; + +/** + * This interface is used to expose the current {@code scrollId} from the underlying scroll context. + *

+ * Internal use only. + * + * @author Sascha Woo + * @param + * @since 4.0 + */ +public interface SearchScrollHits extends SearchHits { + + /** + * @return the scroll id + */ + String getScrollId(); + +} diff --git a/src/main/java/org/springframework/data/elasticsearch/core/StreamQueries.java b/src/main/java/org/springframework/data/elasticsearch/core/StreamQueries.java index 62104b8395..a42dd767ee 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/StreamQueries.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/StreamQueries.java @@ -20,7 +20,8 @@ import java.util.function.Consumer; import java.util.function.Function; -import org.springframework.data.util.CloseableIterator; +import org.elasticsearch.search.aggregations.Aggregations; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; /** @@ -35,25 +36,30 @@ abstract class StreamQueries { /** * Stream query results using {@link ScrolledPage}. * - * @param page the initial scrolled page. + * @param searchHits the initial hits * @param continueScrollFunction function to continue scrolling applies to the current scrollId. * @param clearScrollConsumer consumer to clear the scroll context by accepting the current scrollId. * @param - * @return the {@link CloseableIterator}. + * @return the {@link SearchHitsIterator}. */ - static CloseableIterator streamResults(ScrolledPage page, - Function> continueScrollFunction, Consumer clearScrollConsumer) { + static SearchHitsIterator streamResults(SearchScrollHits searchHits, + Function> continueScrollFunction, Consumer clearScrollConsumer) { - Assert.notNull(page, "page must not be null."); - Assert.notNull(page.getScrollId(), "scrollId must not be null."); + Assert.notNull(searchHits, "searchHits must not be null."); + Assert.notNull(searchHits.getScrollId(), "scrollId of searchHits must not be null."); Assert.notNull(continueScrollFunction, "continueScrollFunction must not be null."); Assert.notNull(clearScrollConsumer, "clearScrollConsumer must not be null."); - return new CloseableIterator() { + Aggregations aggregations = searchHits.getAggregations(); + float maxScore = searchHits.getMaxScore(); + long totalHits = searchHits.getTotalHits(); + TotalHitsRelation totalHitsRelation = searchHits.getTotalHitsRelation(); + + return new SearchHitsIterator() { // As we couldn't retrieve single result with scroll, store current hits. - private volatile Iterator scrollHits = page.iterator(); - private volatile String scrollId = page.getScrollId(); + private volatile Iterator> scrollHits = searchHits.iterator(); + private volatile String scrollId = searchHits.getScrollId(); private volatile boolean continueScroll = scrollHits.hasNext(); @Override @@ -67,6 +73,27 @@ public void close() { } } + @Override + @Nullable + public Aggregations getAggregations() { + return aggregations; + } + + @Override + public float getMaxScore() { + return maxScore; + } + + @Override + public long getTotalHits() { + return totalHits; + } + + @Override + public TotalHitsRelation getTotalHitsRelation() { + return totalHitsRelation; + } + @Override public boolean hasNext() { @@ -75,7 +102,7 @@ public boolean hasNext() { } if (!scrollHits.hasNext()) { - ScrolledPage nextPage = continueScrollFunction.apply(scrollId); + SearchScrollHits nextPage = continueScrollFunction.apply(scrollId); scrollHits = nextPage.iterator(); scrollId = nextPage.getScrollId(); continueScroll = scrollHits.hasNext(); @@ -85,7 +112,7 @@ public boolean hasNext() { } @Override - public T next() { + public SearchHit next() { if (hasNext()) { return scrollHits.next(); } diff --git a/src/main/java/org/springframework/data/elasticsearch/core/TotalHitsRelation.java b/src/main/java/org/springframework/data/elasticsearch/core/TotalHitsRelation.java new file mode 100644 index 0000000000..14f069de0d --- /dev/null +++ b/src/main/java/org/springframework/data/elasticsearch/core/TotalHitsRelation.java @@ -0,0 +1,30 @@ +/* + * Copyright 2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.elasticsearch.core; + +/** + * Enum to represent the relation that Elasticsearch returns for the totalHits value {@see Ekasticsearch + * docs} + * + * @author Peter-Josef Meisch + * @author Sascha Woo + * @since 4.0 + */ +public enum TotalHitsRelation { + EQUAL_TO, // + GREATER_THAN_OR_EQUAL_TO +} diff --git a/src/main/java/org/springframework/data/elasticsearch/core/convert/ElasticsearchConverter.java b/src/main/java/org/springframework/data/elasticsearch/core/convert/ElasticsearchConverter.java index d1666ee70f..eeac391238 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/convert/ElasticsearchConverter.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/convert/ElasticsearchConverter.java @@ -19,10 +19,9 @@ import java.util.stream.Collectors; import org.springframework.data.convert.EntityConverter; -import org.springframework.data.domain.Pageable; import org.springframework.data.elasticsearch.core.SearchHit; import org.springframework.data.elasticsearch.core.SearchHits; -import org.springframework.data.elasticsearch.core.aggregation.AggregatedPage; +import org.springframework.data.elasticsearch.core.SearchScrollHits; import org.springframework.data.elasticsearch.core.document.Document; import org.springframework.data.elasticsearch.core.document.SearchDocument; import org.springframework.data.elasticsearch.core.document.SearchDocumentResponse; @@ -39,6 +38,7 @@ * @author Mohsin Husen * @author Christoph Strobl * @author Peter-Josef Meisch + * @author Sasch Woo */ public interface ElasticsearchConverter extends EntityConverter, ElasticsearchPersistentProperty, Object, Document> { @@ -89,6 +89,17 @@ default List mapDocuments(List documents, Class type) { * @since 4.0 */ SearchHits read(Class type, SearchDocumentResponse searchDocumentResponse); + + /** + * builds a {@link SearchScrollHits} from a {@link SearchDocumentResponse}. + * + * @param the clazz of the type, must not be {@literal null}. + * @param type the type of the returned data, must not be {@literal null}. + * @param searchDocumentResponse the response to read from, must not be {@literal null}. + * @return a {@link SearchScrollHits} object + * @since 4.0 + */ + SearchScrollHits readScroll(Class type, SearchDocumentResponse searchDocumentResponse); /** * builds a {@link SearchHit} from a {@link SearchDocument}. @@ -101,9 +112,6 @@ default List mapDocuments(List documents, Class type) { */ SearchHit read(Class type, SearchDocument searchDocument); - AggregatedPage> mapResults(SearchDocumentResponse response, Class clazz, - @Nullable Pageable pageable); - // endregion // region write diff --git a/src/main/java/org/springframework/data/elasticsearch/core/convert/MappingElasticsearchConverter.java b/src/main/java/org/springframework/data/elasticsearch/core/convert/MappingElasticsearchConverter.java index 28d74c2f60..b3b440ff4c 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/convert/MappingElasticsearchConverter.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/convert/MappingElasticsearchConverter.java @@ -29,13 +29,13 @@ import org.springframework.core.convert.support.DefaultConversionService; import org.springframework.core.convert.support.GenericConversionService; import org.springframework.data.convert.CustomConversions; -import org.springframework.data.domain.Pageable; import org.springframework.data.elasticsearch.ElasticsearchException; import org.springframework.data.elasticsearch.annotations.ScriptedField; +import org.springframework.data.elasticsearch.core.SearchScrollHits; import org.springframework.data.elasticsearch.core.SearchHit; import org.springframework.data.elasticsearch.core.SearchHits; -import org.springframework.data.elasticsearch.core.aggregation.AggregatedPage; -import org.springframework.data.elasticsearch.core.aggregation.impl.AggregatedPageImpl; +import org.springframework.data.elasticsearch.core.SearchHitsImpl; +import org.springframework.data.elasticsearch.core.TotalHitsRelation; import org.springframework.data.elasticsearch.core.document.Document; import org.springframework.data.elasticsearch.core.document.SearchDocument; import org.springframework.data.elasticsearch.core.document.SearchDocumentResponse; @@ -138,18 +138,31 @@ public void afterPropertiesSet() { // region read @Override - public AggregatedPage> mapResults(SearchDocumentResponse response, Class type, - @Nullable Pageable pageable) { + public SearchHits read(Class type, SearchDocumentResponse searchDocumentResponse) { + return readResponse(type, searchDocumentResponse); + } - List> results = response.getSearchDocuments().stream() // - .map(searchDocument -> read(type, searchDocument)) // - .collect(Collectors.toList()); + @Override + public SearchHit read(Class type, SearchDocument searchDocument) { + + Assert.notNull(type, "type must not be null"); + Assert.notNull(searchDocument, "searchDocument must not be null"); - return new AggregatedPageImpl<>(results, pageable, response); + String id = searchDocument.hasId() ? searchDocument.getId() : null; + float score = searchDocument.getScore(); + Object[] sortValues = searchDocument.getSortValues(); + Map> highlightFields = getHighlightsAndRemapFieldNames(type, searchDocument); + T content = mapDocument(searchDocument, type); + + return new SearchHit(id, score, sortValues, highlightFields, content); } @Override - public SearchHits read(Class type, SearchDocumentResponse searchDocumentResponse) { + public SearchScrollHits readScroll(Class type, SearchDocumentResponse searchDocumentResponse) { + return readResponse(type, searchDocumentResponse); + } + + private SearchHitsImpl readResponse(Class type, SearchDocumentResponse searchDocumentResponse) { Assert.notNull(type, "type must not be null"); Assert.notNull(searchDocumentResponse, "searchDocumentResponse must not be null"); @@ -161,25 +174,10 @@ public SearchHits read(Class type, SearchDocumentResponse searchDocume .map(searchDocument -> read(type, searchDocument)) // .collect(Collectors.toList()); Aggregations aggregations = searchDocumentResponse.getAggregations(); - SearchHits.TotalHitsRelation totalHitsRelation = SearchHits.TotalHitsRelation + TotalHitsRelation totalHitsRelation = TotalHitsRelation .valueOf(searchDocumentResponse.getTotalHitsRelation()); - return new SearchHits<>(totalHits, totalHitsRelation, maxScore, scrollId, searchHits, aggregations); - } - - @Override - public SearchHit read(Class type, SearchDocument searchDocument) { - - Assert.notNull(type, "type must not be null"); - Assert.notNull(searchDocument, "searchDocument must not be null"); - - String id = searchDocument.hasId() ? searchDocument.getId() : null; - float score = searchDocument.getScore(); - Object[] sortValues = searchDocument.getSortValues(); - Map> highlightFields = getHighlightsAndRemapFieldNames(type, searchDocument); - T content = mapDocument(searchDocument, type); - - return new SearchHit(id, score, sortValues, highlightFields, content); + return new SearchHitsImpl<>(totalHits, totalHitsRelation, maxScore, scrollId, searchHits, aggregations); } @Nullable diff --git a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java index 4c2d7a5410..4b8a631d41 100755 --- a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java +++ b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java @@ -28,17 +28,12 @@ import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; -import java.lang.Double; -import java.lang.Integer; -import java.lang.Long; -import java.lang.Object; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.UUID; import java.util.stream.Collectors; @@ -290,7 +285,7 @@ public void shouldReturnSearchHitsForGivenSearchQuery() { // then assertThat(searchHits).isNotNull(); assertThat(searchHits.getTotalHits()).isEqualTo(1); - assertThat(searchHits.getTotalHitsRelation()).isEqualByComparingTo(SearchHits.TotalHitsRelation.EQUAL_TO); + assertThat(searchHits.getTotalHitsRelation()).isEqualByComparingTo(TotalHitsRelation.EQUAL_TO); } @Test // DATAES-595 @@ -1055,11 +1050,11 @@ public void shouldReturnResultsWithScanAndScrollForGivenCriteriaQuery() { CriteriaQuery criteriaQuery = new CriteriaQuery(new Criteria()); criteriaQuery.setPageable(PageRequest.of(0, 10)); - ScrolledPage> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, + SearchScrollHits scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, criteriaQuery, SampleEntity.class, index); List> sampleEntities = new ArrayList<>(); - while (scroll.hasContent()) { - sampleEntities.addAll(scroll.getContent()); + while (scroll.hasSearchHits()) { + sampleEntities.addAll(scroll.getSearchHits()); scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scroll.getScrollId(), 1000, SampleEntity.class); } @@ -1082,11 +1077,11 @@ public void shouldReturnResultsWithScanAndScrollForGivenSearchQuery() { NativeSearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(matchAllQuery()) .withPageable(PageRequest.of(0, 10)).build(); - ScrolledPage> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, + SearchScrollHits scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, searchQuery, SampleEntity.class, index); List> sampleEntities = new ArrayList<>(); - while (scroll.hasContent()) { - sampleEntities.addAll(scroll.getContent()); + while (scroll.hasSearchHits()) { + sampleEntities.addAll(scroll.getSearchHits()); scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scroll.getScrollId(), 1000, SampleEntity.class); } @@ -1109,12 +1104,12 @@ public void shouldReturnResultsWithScanAndScrollForSpecifiedFieldsForCriteriaQue criteriaQuery.addFields("message"); criteriaQuery.setPageable(PageRequest.of(0, 10)); - ScrolledPage> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, + SearchScrollHits scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, criteriaQuery, SampleEntity.class, index); String scrollId = scroll.getScrollId(); List> sampleEntities = new ArrayList<>(); - while (scroll.hasContent()) { - sampleEntities.addAll(scroll.getContent()); + while (scroll.hasSearchHits()) { + sampleEntities.addAll(scroll.getSearchHits()); scrollId = scroll.getScrollId(); scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scrollId, 1000, SampleEntity.class); } @@ -1136,12 +1131,12 @@ public void shouldReturnResultsWithScanAndScrollForSpecifiedFieldsForSearchCrite NativeSearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(matchAllQuery()).withFields("message") .withQuery(matchAllQuery()).withPageable(PageRequest.of(0, 10)).build(); - ScrolledPage> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, + SearchScrollHits scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, searchQuery, SampleEntity.class, index); String scrollId = scroll.getScrollId(); List> sampleEntities = new ArrayList<>(); - while (scroll.hasContent()) { - sampleEntities.addAll(scroll.getContent()); + while (scroll.hasSearchHits()) { + sampleEntities.addAll(scroll.getSearchHits()); scrollId = scroll.getScrollId(); scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scrollId, 1000, SampleEntity.class); } @@ -1163,12 +1158,12 @@ public void shouldReturnResultsForScanAndScrollWithCustomResultMapperForGivenCri CriteriaQuery criteriaQuery = new CriteriaQuery(new Criteria()); criteriaQuery.setPageable(PageRequest.of(0, 10)); - ScrolledPage> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, + SearchScrollHits scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, criteriaQuery, SampleEntity.class, index); String scrollId = scroll.getScrollId(); List> sampleEntities = new ArrayList<>(); - while (scroll.hasContent()) { - sampleEntities.addAll(scroll.getContent()); + while (scroll.hasSearchHits()) { + sampleEntities.addAll(scroll.getSearchHits()); scrollId = scroll.getScrollId(); scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scrollId, 1000, SampleEntity.class); } @@ -1190,12 +1185,12 @@ public void shouldReturnResultsForScanAndScrollWithCustomResultMapperForGivenSea NativeSearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(matchAllQuery()) .withPageable(PageRequest.of(0, 10)).build(); - ScrolledPage> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, + SearchScrollHits scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, searchQuery, SampleEntity.class, index); String scrollId = scroll.getScrollId(); List> sampleEntities = new ArrayList<>(); - while (scroll.hasContent()) { - sampleEntities.addAll(scroll.getContent()); + while (scroll.hasSearchHits()) { + sampleEntities.addAll(scroll.getSearchHits()); scrollId = scroll.getScrollId(); scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scrollId, 1000, SampleEntity.class); } @@ -1217,12 +1212,12 @@ public void shouldReturnResultsWithScanAndScrollForGivenCriteriaQueryAndClass() CriteriaQuery criteriaQuery = new CriteriaQuery(new Criteria()); criteriaQuery.setPageable(PageRequest.of(0, 10)); - ScrolledPage> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, + SearchScrollHits scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, criteriaQuery, SampleEntity.class, index); String scrollId = scroll.getScrollId(); List> sampleEntities = new ArrayList<>(); - while (scroll.hasContent()) { - sampleEntities.addAll(scroll.getContent()); + while (scroll.hasSearchHits()) { + sampleEntities.addAll(scroll.getSearchHits()); scrollId = scroll.getScrollId(); scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scrollId, 1000, SampleEntity.class); } @@ -1244,12 +1239,12 @@ public void shouldReturnResultsWithScanAndScrollForGivenSearchQueryAndClass() { NativeSearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(matchAllQuery()) .withPageable(PageRequest.of(0, 10)).build(); - ScrolledPage> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, + SearchScrollHits scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, searchQuery, SampleEntity.class, index); String scrollId = scroll.getScrollId(); List> sampleEntities = new ArrayList<>(); - while (scroll.hasContent()) { - sampleEntities.addAll(scroll.getContent()); + while (scroll.hasSearchHits()) { + sampleEntities.addAll(scroll.getSearchHits()); scrollId = scroll.getScrollId(); scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scrollId, 1000, SampleEntity.class); } @@ -1529,16 +1524,16 @@ public void shouldPassIndicesOptionsForGivenSearchScrollQuery() { NativeSearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(matchAllQuery()) .withIndicesOptions(IndicesOptions.lenientExpandOpen()).build(); - ScrolledPage> scroll = ((AbstractElasticsearchTemplate) operations) + SearchScrollHits scroll = ((AbstractElasticsearchTemplate) operations) .searchScrollStart(scrollTimeInMillis, searchQuery, SampleEntity.class, index); - List> entities = new ArrayList<>(scroll.getContent()); + List> entities = new ArrayList<>(scroll.getSearchHits()); - while (scroll.hasContent()) { + while (scroll.hasSearchHits()) { scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scroll.getScrollId(), scrollTimeInMillis, SampleEntity.class); - entities.addAll(scroll.getContent()); + entities.addAll(scroll.getSearchHits()); } // then @@ -2431,11 +2426,11 @@ public void shouldApplyCriteriaQueryToScanAndScrollForGivenCriteriaQuery() { CriteriaQuery criteriaQuery = new CriteriaQuery(new Criteria("message").contains("message")); criteriaQuery.setPageable(PageRequest.of(0, 10)); - ScrolledPage> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, + SearchScrollHits scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, criteriaQuery, SampleEntity.class, index); List> sampleEntities = new ArrayList<>(); - while (scroll.hasContent()) { - sampleEntities.addAll(scroll.getContent()); + while (scroll.hasSearchHits()) { + sampleEntities.addAll(scroll.getSearchHits()); scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scroll.getScrollId(), 1000, SampleEntity.class); } @@ -2469,11 +2464,11 @@ public void shouldApplySearchQueryToScanAndScrollForGivenSearchQuery() { NativeSearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(matchQuery("message", "message")) .withPageable(PageRequest.of(0, 10)).build(); - ScrolledPage> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, + SearchScrollHits scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, searchQuery, SampleEntity.class, index); List> sampleEntities = new ArrayList<>(); - while (scroll.hasContent()) { - sampleEntities.addAll(scroll.getContent()); + while (scroll.hasSearchHits()) { + sampleEntities.addAll(scroll.getSearchHits()); scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scroll.getScrollId(), 1000, SampleEntity.class); } @@ -2502,11 +2497,11 @@ public void shouldRespectSourceFilterWithScanAndScrollForGivenSearchQuery() { NativeSearchQuery searchQuery = new NativeSearchQueryBuilder().withQuery(matchAllQuery()) .withPageable(PageRequest.of(0, 10)).withSourceFilter(sourceFilter).build(); - ScrolledPage> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, + SearchScrollHits scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, searchQuery, SampleEntity.class, index); List> sampleEntities = new ArrayList<>(); - while (scroll.hasContent()) { - sampleEntities.addAll(scroll.getContent()); + while (scroll.hasSearchHits()) { + sampleEntities.addAll(scroll.getSearchHits()); scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scroll.getScrollId(), 1000, SampleEntity.class); } @@ -2549,11 +2544,11 @@ public void shouldSortResultsGivenSortCriteriaWithScanAndScroll() { .withSort(new FieldSortBuilder("message").order(SortOrder.DESC)).withPageable(PageRequest.of(0, 10)).build(); // when - ScrolledPage> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, + SearchScrollHits scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, searchQuery, SampleEntity.class, index); List> sampleEntities = new ArrayList<>(); - while (scroll.hasContent()) { - sampleEntities.addAll(scroll.getContent()); + while (scroll.hasSearchHits()) { + sampleEntities.addAll(scroll.getSearchHits()); scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scroll.getScrollId(), 1000, SampleEntity.class); } @@ -2598,11 +2593,11 @@ public void shouldSortResultsGivenSortCriteriaFromPageableWithScanAndScroll() { .build(); // when - ScrolledPage> scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, + SearchScrollHits scroll = ((AbstractElasticsearchTemplate) operations).searchScrollStart(1000, searchQuery, SampleEntity.class, index); List> sampleEntities = new ArrayList<>(); - while (scroll.hasContent()) { - sampleEntities.addAll(scroll.getContent()); + while (scroll.hasSearchHits()) { + sampleEntities.addAll(scroll.getSearchHits()); scroll = ((AbstractElasticsearchTemplate) operations).searchScrollContinue(scroll.getScrollId(), 1000, SampleEntity.class); } diff --git a/src/test/java/org/springframework/data/elasticsearch/core/StreamQueriesTest.java b/src/test/java/org/springframework/data/elasticsearch/core/StreamQueriesTest.java index 0508ee488a..9aac33fd0d 100644 --- a/src/test/java/org/springframework/data/elasticsearch/core/StreamQueriesTest.java +++ b/src/test/java/org/springframework/data/elasticsearch/core/StreamQueriesTest.java @@ -19,12 +19,14 @@ import java.util.ArrayList; import java.util.Collections; +import java.util.Iterator; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; +import org.elasticsearch.search.aggregations.Aggregations; import org.junit.jupiter.api.Test; import org.springframework.data.domain.PageImpl; -import org.springframework.data.util.CloseableIterator; +import org.springframework.data.domain.Pageable; import org.springframework.lang.Nullable; /** @@ -36,42 +38,51 @@ public class StreamQueriesTest { public void shouldCallClearScrollOnIteratorClose() { // given - List results = new ArrayList<>(); - results.add("one"); + List> hits = new ArrayList<>(); + hits.add(new SearchHit(null, 0, null, null, "one")); - ScrolledPage page = new ScrolledPageImpl("1234", results); + SearchScrollHits searchHits = newSearchScrollHits(hits); AtomicBoolean clearScrollCalled = new AtomicBoolean(false); // when - CloseableIterator closeableIterator = StreamQueries.streamResults( // - page, // - scrollId -> new ScrolledPageImpl(scrollId, Collections.emptyList()), // + SearchHitsIterator iterator = StreamQueries.streamResults( // + searchHits, // + scrollId -> newSearchScrollHits(Collections.emptyList()), // scrollId -> clearScrollCalled.set(true)); - while (closeableIterator.hasNext()) { - closeableIterator.next(); + while (iterator.hasNext()) { + iterator.next(); } - closeableIterator.close(); + iterator.close(); // then assertThat(clearScrollCalled).isTrue(); } - private static class ScrolledPageImpl extends PageImpl implements ScrolledPage { + @Test // DATAES-766 + public void shouldReturnTotalHits() { - private String scrollId; + // given + List> hits = new ArrayList<>(); + hits.add(new SearchHit(null, 0, null, null, "one")); - public ScrolledPageImpl(String scrollId, List content) { - super(content); - this.scrollId = scrollId; - } + SearchScrollHits searchHits = newSearchScrollHits(hits); - @Override - @Nullable - public String getScrollId() { - return scrollId; - } + // when + SearchHitsIterator iterator = StreamQueries.streamResults( // + searchHits, // + scrollId -> newSearchScrollHits(Collections.emptyList()), // + scrollId -> { + }); + + // then + assertThat(iterator.getTotalHits()).isEqualTo(1); + + } + + private SearchScrollHits newSearchScrollHits(List> hits) { + return new SearchHitsImpl(1, TotalHitsRelation.EQUAL_TO, 0, "1234", hits, null); } }