Skip to content

Commit

Permalink
DATAES-766 - Replace CloseableIterator with SearchHitsIterator in str…
Browse files Browse the repository at this point in the history
…eam operations.
  • Loading branch information
xhaggi committed Mar 20, 2020
1 parent f103bdb commit 5ada297
Show file tree
Hide file tree
Showing 16 changed files with 491 additions and 251 deletions.
Expand Up @@ -208,20 +208,25 @@ public long count(Query query, Class<?> clazz) {

@Override
public <T> CloseableIterator<T> stream(Query query, Class<T> clazz, IndexCoordinates index) {

long scrollTimeInMillis = TimeValue.timeValueMinutes(1).millis();
return (CloseableIterator<T>) SearchHitSupport.unwrapSearchHits(searchForStream(query, clazz, index));
}

@Override
public <T> CloseableIterator<SearchHit<T>> searchForStream(Query query, Class<T> clazz) {
public <T> SearchHitsIterator<T> searchForStream(Query query, Class<T> clazz) {
return searchForStream(query, clazz, getIndexCoordinatesFor(clazz));
}

@Override
public <T> CloseableIterator<SearchHit<T>> searchForStream(Query query, Class<T> clazz, IndexCoordinates index) {
public <T> SearchHitsIterator<T> searchForStream(Query query, Class<T> clazz, IndexCoordinates index) {

long scrollTimeInMillis = TimeValue.timeValueMinutes(1).millis();
return StreamQueries.streamResults(searchScrollStart(scrollTimeInMillis, query, clazz, index),
scrollId -> searchScrollContinue(scrollId, scrollTimeInMillis, clazz), this::searchScrollClear);

return StreamQueries.streamResults( //
searchScrollStart(scrollTimeInMillis, query, clazz, index), //
scrollId -> searchScrollContinue(scrollId, scrollTimeInMillis, clazz), //
this::searchScrollClear);
}

@Override
Expand Down Expand Up @@ -283,13 +288,13 @@ public <T> SearchHits<T> search(Query query, Class<T> clazz) {
/*
* internal use only, not for public API
*/
abstract protected <T> ScrolledPage<SearchHit<T>> searchScrollStart(long scrollTimeInMillis, Query query,
abstract protected <T> SearchScrollHits<T> searchScrollStart(long scrollTimeInMillis, Query query,
Class<T> clazz, IndexCoordinates index);

/*
* internal use only, not for public API
*/
abstract protected <T> ScrolledPage<SearchHit<T>> searchScrollContinue(@Nullable String scrollId,
abstract protected <T> SearchScrollHits<T> searchScrollContinue(@Nullable String scrollId,
long scrollTimeInMillis, Class<T> clazz);

/*
Expand Down
Expand Up @@ -39,7 +39,6 @@
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.search.suggest.SuggestBuilder;
import org.springframework.data.domain.Pageable;
import org.springframework.data.elasticsearch.core.convert.ElasticsearchConverter;
import org.springframework.data.elasticsearch.core.document.DocumentAdapters;
import org.springframework.data.elasticsearch.core.document.SearchDocumentResponse;
Expand Down Expand Up @@ -257,24 +256,28 @@ public <T> SearchHits<T> search(Query query, Class<T> clazz, IndexCoordinates in
}

@Override
public <T> ScrolledPage<SearchHit<T>> searchScrollStart(long scrollTimeInMillis, Query query, Class<T> clazz,
public <T> SearchScrollHits<T> searchScrollStart(long scrollTimeInMillis, Query query, Class<T> clazz,
IndexCoordinates index) {

Assert.notNull(query.getPageable(), "Query.pageable is required for scan & scroll");
Assert.notNull(query.getPageable(), "pageable of query must not be null.");

SearchRequest searchRequest = requestFactory.searchRequest(query, clazz, index);
searchRequest.scroll(TimeValue.timeValueMillis(scrollTimeInMillis));
SearchResponse result = execute(client -> client.search(searchRequest, RequestOptions.DEFAULT));
return elasticsearchConverter.mapResults(SearchDocumentResponse.from(result), clazz, null);

SearchResponse response = execute(client -> client.search(searchRequest, RequestOptions.DEFAULT));

return elasticsearchConverter.readScroll(clazz, SearchDocumentResponse.from(response));
}

@Override
public <T> ScrolledPage<SearchHit<T>> searchScrollContinue(@Nullable String scrollId, long scrollTimeInMillis,
Class<T> clazz) {
public <T> SearchScrollHits<T> searchScrollContinue(@Nullable String scrollId, long scrollTimeInMillis, Class<T> clazz) {

SearchScrollRequest request = new SearchScrollRequest(scrollId);
request.scroll(TimeValue.timeValueMillis(scrollTimeInMillis));
SearchResponse response = execute(client -> client.searchScroll(request, RequestOptions.DEFAULT));
return elasticsearchConverter.mapResults(SearchDocumentResponse.from(response), clazz, Pageable.unpaged());

SearchResponse response = execute(client -> client.scroll(request, RequestOptions.DEFAULT));

return elasticsearchConverter.readScroll(clazz, SearchDocumentResponse.from(response));
}

@Override
Expand Down
Expand Up @@ -27,6 +27,7 @@
import org.elasticsearch.action.search.MultiSearchResponse;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchScrollRequestBuilder;
import org.elasticsearch.action.update.UpdateRequestBuilder;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.unit.TimeValue;
Expand Down Expand Up @@ -260,22 +261,32 @@ public <T> SearchHits<T> search(Query query, Class<T> clazz, IndexCoordinates in
}

@Override
public <T> ScrolledPage<SearchHit<T>> searchScrollStart(long scrollTimeInMillis, Query query, Class<T> clazz,
public <T> SearchScrollHits<T> searchScrollStart(long scrollTimeInMillis, Query query, Class<T> clazz,
IndexCoordinates index) {
Assert.notNull(query.getPageable(), "Query.pageable is required for scan & scroll");

SearchRequestBuilder searchRequestBuilder = requestFactory.searchRequestBuilder(client, query, clazz, index);
searchRequestBuilder.setScroll(TimeValue.timeValueMillis(scrollTimeInMillis));
SearchResponse response = getSearchResponse(searchRequestBuilder);
return elasticsearchConverter.mapResults(SearchDocumentResponse.from(response), clazz, null);
Assert.notNull(query.getPageable(), "pageable of query must not be null.");

ActionFuture<SearchResponse> action = requestFactory //
.searchRequestBuilder(client, query, clazz, index) //
.setScroll(TimeValue.timeValueMillis(scrollTimeInMillis)) //
.execute();

SearchResponse response = getSearchResponseWithTimeout(action);

return elasticsearchConverter.readScroll(clazz, SearchDocumentResponse.from(response));
}

@Override
public <T> ScrolledPage<SearchHit<T>> searchScrollContinue(@Nullable String scrollId, long scrollTimeInMillis,
Class<T> clazz) {
SearchResponse response = getSearchResponseWithTimeout(
client.prepareSearchScroll(scrollId).setScroll(TimeValue.timeValueMillis(scrollTimeInMillis)).execute());
return elasticsearchConverter.mapResults(SearchDocumentResponse.from(response), clazz, Pageable.unpaged());
public <T> SearchScrollHits<T> searchScrollContinue(@Nullable String scrollId, long scrollTimeInMillis, Class<T> clazz) {

ActionFuture<SearchResponse> action = client //
.prepareSearchScroll(scrollId) //
.setScroll(TimeValue.timeValueMillis(scrollTimeInMillis)) //
.execute();

SearchResponse response = getSearchResponseWithTimeout(action);

return elasticsearchConverter.readScroll(clazz, SearchDocumentResponse.from(response));
}

@Override
Expand Down
Expand Up @@ -2,13 +2,14 @@
package org.springframework.data.elasticsearch.core;

import org.springframework.data.domain.Page;
import org.springframework.lang.Nullable;

/**
* @author Artur Konczak
* @author Peter-Josef Meisch
* @author Sascha Woo
* @deprecated will be removed in a future version.
*/
@Deprecated
public interface ScrolledPage<T> extends Page<T> {

String getScrollId();
Expand Down
Expand Up @@ -32,6 +32,7 @@
* Utility class with helper methods for working with {@link SearchHit}.
*
* @author Peter-Josef Meisch
* @author Sascha Woo
* @since 4.0
*/
public final class SearchHitSupport {
Expand Down Expand Up @@ -97,8 +98,13 @@ public static Object unwrapSearchHits(Object result) {
* @return the created Page
*/
public static <T> AggregatedPage<SearchHit<T>> page(SearchHits<T> searchHits, Pageable pageable) {
return new AggregatedPageImpl<>(searchHits.getSearchHits(), pageable, searchHits.getTotalHits(),
searchHits.getAggregations(), searchHits.getScrollId(), searchHits.getMaxScore());
return new AggregatedPageImpl<>( //
searchHits.getSearchHits(), //
pageable, //
searchHits.getTotalHits(), //
searchHits.getAggregations(), //
null, //
searchHits.getMaxScore());
}

public static <T> SearchPage<T> searchPageFor(SearchHits<T> searchHits, @Nullable Pageable pageable) {
Expand Down
@@ -1,5 +1,5 @@
/*
* Copyright 2019-2020 the original author or authors.
* Copyright 2020 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,143 +15,74 @@
*/
package org.springframework.data.elasticsearch.core;

import java.util.Collections;
import java.util.Iterator;
import java.util.List;

import org.elasticsearch.search.aggregations.Aggregations;
import org.springframework.data.util.Streamable;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
* Encapsulates a list of {@link SearchHit}s with additional information from the search.
*
*
* @param <T> the result data class.
* @author Peter-Josef Meisch
* @author Sascha Woo
* @since 4.0
*/
public class SearchHits<T> implements Streamable<SearchHit<T>> {

private final long totalHits;
private final TotalHitsRelation totalHitsRelation;
private final float maxScore;
private final String scrollId;
private final List<? extends SearchHit<T>> searchHits;
private final Aggregations aggregations;

/**
* @param totalHits the number of total hits for the search
* @param totalHitsRelation the relation {@see TotalHitsRelation}, must not be {@literal null}
* @param maxScore the maximum score
* @param scrollId the scroll id if available
* @param searchHits must not be {@literal null}
* @param aggregations the aggregations if available
*/
public SearchHits(long totalHits, TotalHitsRelation totalHitsRelation, float maxScore, @Nullable String scrollId,
List<? extends SearchHit<T>> searchHits, @Nullable Aggregations aggregations) {

Assert.notNull(searchHits, "searchHits must not be null");

this.totalHits = totalHits;
this.totalHitsRelation = totalHitsRelation;
this.maxScore = maxScore;
this.scrollId = scrollId;
this.searchHits = searchHits;
this.aggregations = aggregations;
}

@SuppressWarnings("unchecked")
@Override
public Iterator<SearchHit<T>> iterator() {
return (Iterator<SearchHit<T>>) searchHits.iterator();
}
public interface SearchHits<T> extends Streamable<SearchHit<T>> {

// region getter
/**
* @return the number of total hits.
*/
public long getTotalHits() {
return totalHits;
}

/**
* @return the relation for the total hits
* @return the aggregations.
*/
public TotalHitsRelation getTotalHitsRelation() {
return totalHitsRelation;
}
@Nullable
Aggregations getAggregations();

/**
* @return the maximum score
*/
public float getMaxScore() {
return maxScore;
}
float getMaxScore();

/**
* @return the scroll id
* @param index position in List.
* @return the {@link SearchHit} at position {index}
* @throws IndexOutOfBoundsException on invalid index
*/
@Nullable
public String getScrollId() {
return scrollId;
}
SearchHit<T> getSearchHit(int index);

/**
* @return the contained {@link SearchHit}s.
*/
public List<SearchHit<T>> getSearchHits() {
return Collections.unmodifiableList(searchHits);
}
// endregion
List<SearchHit<T>> getSearchHits();

// region SearchHit access
/**
* @param index position in List.
* @return the {@link SearchHit} at position {index}
* @throws IndexOutOfBoundsException on invalid index
* @return the number of total hits.
*/
public SearchHit<T> getSearchHit(int index) {
return searchHits.get(index);
}
// endregion
long getTotalHits();

@Override
public String toString() {
return "SearchHits{" + //
"totalHits=" + totalHits + //
", totalHitsRelation=" + totalHitsRelation + //
", maxScore=" + maxScore + //
", scrollId='" + scrollId + '\'' + //
", searchHits={" + searchHits.size() + " elements}" + //
", aggregations=" + aggregations + //
'}';
}
/**
* @return the relation for the total hits
*/
TotalHitsRelation getTotalHitsRelation();

// region aggregations
/**
* @return true if aggregations are available
*/
public boolean hasAggregations() {
return aggregations != null;
default boolean hasAggregations() {
return getAggregations() != null;
}

/**
* @return the aggregations.
* @return whether the {@link SearchHits} has search hits.
*/
@Nullable
public Aggregations getAggregations() {
return aggregations;
default boolean hasSearchHits() {
return !getSearchHits().isEmpty();
}
// endregion

/**
* Enum to represent the relation that Elasticsearch returns for the totalHits value {@see <a href=
* "https://www.elastic.co/guide/en/elasticsearch/reference/7.5/search-request-body.html#request-body-search-track-total-hits">Ekasticsearch
* docs</a>}
* @return an iterator for {@link SearchHit}
*/
public enum TotalHitsRelation {
EQUAL_TO, GREATER_THAN_OR_EQUAL_TO
default Iterator<SearchHit<T>> iterator() {
return getSearchHits().iterator();
}

}

0 comments on commit 5ada297

Please sign in to comment.