Decouple request retry logic
regadas committed Apr 9, 2019
1 parent 8c0ebf5 commit b7f6f64
Showing 3 changed files with 203 additions and 163 deletions.
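At a glance, the change replaces the inline do/while retry loop in ElasticsearchWriter.processElement with two composable ProcessFunctions: request(...), which issues a single bulk request and routes bulk failures to the user-supplied error handler, and retry(...), which wraps any request function in a backoff loop driven by Beam's FluentBackoff. Below is a minimal runnable sketch of that wrapper pattern; the generic types, RetryDecouplingSketch, and the flaky example are illustrative stand-ins, not code from the commit.

import org.apache.beam.sdk.transforms.ProcessFunction;
import org.apache.beam.sdk.util.BackOff;
import org.apache.beam.sdk.util.BackOffUtils;
import org.apache.beam.sdk.util.FluentBackoff;
import org.apache.beam.sdk.util.Sleeper;
import org.joda.time.Duration;

public class RetryDecouplingSketch {

  // Wrap any request function in a retry loop driven by a FluentBackoff
  // policy, mirroring the retry(...) factory introduced by this commit.
  static <ReqT, RespT> ProcessFunction<ReqT, RespT> retry(
      ProcessFunction<ReqT, RespT> requestFn, FluentBackoff backoffConfig) {
    return request -> {
      final BackOff backoff = backoffConfig.backoff();
      int attempt = 0;
      RespT response = null;
      Exception exception = null;
      // BackOffUtils.next(...) sleeps for the next backoff interval and
      // returns false once the policy is exhausted.
      while (response == null && BackOffUtils.next(Sleeper.DEFAULT, backoff)) {
        attempt++;
        try {
          response = requestFn.apply(request);
          exception = null;
        } catch (Exception e) {
          exception = e;
        }
      }
      if (exception != null) {
        throw new Exception("Retries exhausted after " + attempt + " attempt(s)", exception);
      }
      return response;
    };
  }

  public static void main(String[] args) throws Exception {
    final FluentBackoff backoffConfig =
        FluentBackoff.DEFAULT
            .withMaxRetries(3)
            .withInitialBackoff(Duration.standardSeconds(1));

    // A request function that fails twice before succeeding.
    final int[] calls = {0};
    final ProcessFunction<String, String> flaky = s -> {
      if (calls[0]++ < 2) {
        throw new RuntimeException("transient failure");
      }
      return s.toUpperCase();
    };

    // As in the new processElement: try once inline, and only failures
    // take the retry path.
    String result;
    try {
      result = flaky.apply("doc");
    } catch (Exception e) {
      result = retry(flaky, backoffConfig).apply("doc");
    }
    System.out.println(result); // prints "DOC" after two retry attempts
  }
}

One detail visible in the diff: request(...) routes bulk failures through the error callback, so a callback that swallows a BulkExecutionException makes the wrapped call look successful and ends the retry loop; that interaction is worth keeping in mind when combining withError and withMaxRetries.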
@@ -37,14 +37,17 @@
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.ProcessFunction;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.windowing.AfterProcessingTime;
import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
import org.apache.beam.sdk.transforms.windowing.Repeatedly;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.util.BackOff;
+import org.apache.beam.sdk.util.BackOffUtils;
import org.apache.beam.sdk.util.FluentBackoff;
+import org.apache.beam.sdk.util.Sleeper;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PDone;
@@ -77,7 +80,7 @@ public static <T> Bound<T> withClusterName(String clusterName) {
*
* @param servers endpoints for the Elasticsearch cluster
*/
-  public static<T> Bound<T> withServers(InetSocketAddress[] servers) {
+  public static <T> Bound<T> withServers(InetSocketAddress[] servers) {
return new Bound<T>().withServers(servers);
}

@@ -87,7 +90,7 @@ public static<T> Bound<T> withServers(InetSocketAddress[] servers) {
*
* @param flushInterval delay applied to buffered elements. Defaults to 1 second.
*/
-  public static<T> Bound withFlushInterval(Duration flushInterval) {
+  public static <T> Bound withFlushInterval(Duration flushInterval) {
return new Bound<T>().withFlushInterval(flushInterval);
}

@@ -96,7 +99,8 @@ public static<T> Bound withFlushInterval(Duration flushInterval) {
*
* @param function creates IndexRequest required by Elasticsearch client
*/
-  public static<T> Bound withFunction(SerializableFunction<T, Iterable<ActionRequest<?>>> function) {
+  public static <T> Bound withFunction(
+      SerializableFunction<T, Iterable<ActionRequest<?>>> function) {
return new Bound<T>().withFunction(function);
}

@@ -106,7 +110,7 @@ public static<T> Bound withFunction(SerializableFunction<T, Iterable<ActionRequest<?>>> function) {
*
* @param numOfShard number of shards used to construct batches for bulk writes to Elasticsearch.
*/
-  public static<T> Bound withNumOfShard(long numOfShard) {
+  public static <T> Bound withNumOfShard(long numOfShard) {
return new Bound<>().withNumOfShard(numOfShard);
}

@@ -116,11 +120,11 @@ public static<T> Bound withNumOfShard(long numOfShard) {
* @param error handler applied, if specified, when a bulk write to
*     Elasticsearch fails. The default behavior throws an IOException.
*/
-  public static<T> Bound withError(ThrowingConsumer<BulkExecutionException> error) {
+  public static <T> Bound withError(ThrowingConsumer<BulkExecutionException> error) {
return new Bound<>().withError(error);
}

-  public static<T> Bound withMaxBulkRequestSize(int maxBulkRequestSize) {
+  public static <T> Bound withMaxBulkRequestSize(int maxBulkRequestSize) {
return new Bound<>().withMaxBulkRequestSize(maxBulkRequestSize);
}

@@ -180,61 +184,62 @@ private Bound(final String clusterName,

Bound() {
this(null, null, null, null, 0, CHUNK_SIZE, DEFAULT_RETRIES, DEFAULT_RETRY_PAUSE,
-            defaultErrorHandler());
+        defaultErrorHandler());
}

public Bound<T> withClusterName(String clusterName) {
return new Bound<>(clusterName, servers, flushInterval, toActionRequests, numOfShard,
-            maxBulkRequestSize,
-            maxRetries, retryPause, error);
+          maxBulkRequestSize,
+          maxRetries, retryPause, error);
}

public Bound<T> withServers(InetSocketAddress[] servers) {
return new Bound<>(clusterName, servers, flushInterval, toActionRequests, numOfShard,
-            maxBulkRequestSize,
-            maxRetries, retryPause, error);
+          maxBulkRequestSize,
+          maxRetries, retryPause, error);
}

public Bound<T> withFlushInterval(Duration flushInterval) {
return new Bound<>(clusterName, servers, flushInterval, toActionRequests, numOfShard,
-            maxBulkRequestSize,
-            maxRetries, retryPause, error);
+          maxBulkRequestSize,
+          maxRetries, retryPause, error);
}

-    public Bound<T> withFunction(SerializableFunction<T, Iterable<ActionRequest<?>>> toIndexRequest) {
+    public Bound<T> withFunction(
+        SerializableFunction<T, Iterable<ActionRequest<?>>> toIndexRequest) {
return new Bound<>(clusterName, servers, flushInterval, toIndexRequest, numOfShard,
-            maxBulkRequestSize,
-            maxRetries, retryPause, error);
+          maxBulkRequestSize,
+          maxRetries, retryPause, error);
}

public Bound<T> withNumOfShard(long numOfShard) {
return new Bound<>(clusterName, servers, flushInterval, toActionRequests, numOfShard,
-            maxBulkRequestSize,
-            maxRetries, retryPause, error);
+          maxBulkRequestSize,
+          maxRetries, retryPause, error);
}

public Bound<T> withError(ThrowingConsumer<BulkExecutionException> error) {
return new Bound<>(clusterName, servers, flushInterval, toActionRequests, numOfShard,
-            maxBulkRequestSize,
-            maxRetries, retryPause, error);
+          maxBulkRequestSize,
+          maxRetries, retryPause, error);
}

public Bound<T> withMaxBulkRequestSize(int maxBulkRequestSize) {
return new Bound<>(clusterName, servers, flushInterval, toActionRequests, numOfShard,
-            maxBulkRequestSize,
-            maxRetries, retryPause, error);
+          maxBulkRequestSize,
+          maxRetries, retryPause, error);
}

public Bound<T> withMaxRetries(int maxRetries) {
return new Bound<>(clusterName, servers, flushInterval, toActionRequests, numOfShard,
-            maxBulkRequestSize,
-            maxRetries, retryPause, error);
+          maxBulkRequestSize,
+          maxRetries, retryPause, error);
}

public Bound<T> withRetryPause(int retryPause) {
return new Bound<>(clusterName, servers, flushInterval, toActionRequests, numOfShard,
-            maxBulkRequestSize,
-            maxRetries, retryPause, error);
+          maxBulkRequestSize,
+          maxRetries, retryPause, error);
}

@Override
@@ -255,17 +260,18 @@ public PDone expand(final PCollection<T> input) {
.pastFirstElementInPane()
.plusDelayOf(flushInterval)))
.discardingFiredPanes()
-                  .withTimestampCombiner(TimestampCombiner.END_OF_WINDOW))
+              .withTimestampCombiner(TimestampCombiner.END_OF_WINDOW))
.apply(GroupByKey.create())
.apply("Write to Elasticesearch",
ParDo.of(new ElasticsearchWriter<>
-                  (clusterName, servers, maxBulkRequestSize, toActionRequests, error,
-                      maxRetries, retryPause)));
+              (clusterName, servers, maxBulkRequestSize, toActionRequests, error,
+                  maxRetries, retryPause)));
return PDone.in(input.getPipeline());
}
}

private static class AssignToShard<T> extends DoFn<T, KV<Long, T>> {

private final long numOfShard;

public AssignToShard(long numOfShard) {
@@ -281,7 +287,13 @@ public void processElement(ProcessContext c) throws Exception {
}

private static class ElasticsearchWriter<T> extends DoFn<KV<Long, Iterable<T>>, Void> {

private static final Logger LOG = LoggerFactory.getLogger(ElasticsearchWriter.class);
+    private static final String RETRY_ATTEMPT_LOG =
+        "Error writing to Elasticsearch. Retry attempt[%d]";
+    private static final String RETRY_FAILED_LOG =
+        "Error writing to ES after %d attempt(s). No more attempts allowed";

private final ClientSupplier clientSupplier;
private final SerializableFunction<T, Iterable<ActionRequest<?>>> toActionRequests;
private final ThrowingConsumer<BulkExecutionException> error;
@@ -291,12 +303,12 @@ private static class ElasticsearchWriter<T> extends DoFn<KV<Long, Iterable<T>>, Void> {
private final int retryPause;

public ElasticsearchWriter(String clusterName,
-                               InetSocketAddress[] servers,
-                               int maxBulkRequestSize,
-                               SerializableFunction<T, Iterable<ActionRequest<?>>> toActionRequests,
-                               ThrowingConsumer<BulkExecutionException> error,
-                               int maxRetries,
-                               int retryPause) {
+        InetSocketAddress[] servers,
+        int maxBulkRequestSize,
+        SerializableFunction<T, Iterable<ActionRequest<?>>> toActionRequests,
+        ThrowingConsumer<BulkExecutionException> error,
+        int maxRetries,
+        int retryPause) {
this.maxBulkRequestSize = maxBulkRequestSize;
this.clientSupplier = new ClientSupplier(clusterName, servers);
this.toActionRequests = toActionRequests;
@@ -332,60 +344,66 @@ public void processElement(ProcessContext c) throws Exception {
final Iterable<List<ActionRequest>> chunks =
Iterables.partition(actionRequests::iterator, maxBulkRequestSize);

-      chunks.forEach(chunk -> {
-        Exception exception;
+      final ProcessFunction<List<ActionRequest>, BulkResponse> requestFn =
+          request(clientSupplier, error);
+      final ProcessFunction<List<ActionRequest>, BulkResponse> retryFn =
+          retry(requestFn, backoffConfig);
+
+      for (final List<ActionRequest> chunk : chunks) {
+        try {
+          requestFn.apply(chunk);
+        } catch (Exception e) {
+          retryFn.apply(chunk);
+        }
+      }
    }

+    private static ProcessFunction<List<ActionRequest>, BulkResponse> request(
+        final ClientSupplier clientSupplier,
+        final ThrowingConsumer<BulkExecutionException> bulkErrorHandler) {
+      return chunk -> {
+        final BulkRequest bulkRequest = new BulkRequest().add(chunk);
+        final BulkResponse bulkItemResponse = clientSupplier.get().bulk(bulkRequest).get();
+
+        if (bulkItemResponse.hasFailures()) {
+          bulkErrorHandler.accept(new BulkExecutionException(bulkItemResponse));
+        }
+
+        return bulkItemResponse;
+      };
+    }

+    private static ProcessFunction<List<ActionRequest>, BulkResponse> retry(
+        final ProcessFunction<List<ActionRequest>, BulkResponse> requestFn,
+        final FluentBackoff backoffConfig) {
+      return chunk -> {
        final BackOff backoff = backoffConfig.backoff();
+        int attempt = 0;
+        BulkResponse response = null;
+        Exception exception = null;

-        do {
+        while (response == null && BackOffUtils.next(Sleeper.DEFAULT, backoff)) {
+          LOG.warn(String.format(RETRY_ATTEMPT_LOG, ++attempt));
          try {
-            final BulkRequest bulkRequest = new BulkRequest().add(chunk).refresh(false);
-            final BulkResponse bulkItemResponse = clientSupplier.get().bulk(bulkRequest).get();
-            if (bulkItemResponse.hasFailures()) {
-              throw new BulkExecutionException(bulkItemResponse);
-            } else {
-              exception = null;
-              break;
-            }
+            response = requestFn.apply(chunk);
+            exception = null;
          } catch (Exception e) {
            exception = e;
-
-            LOG.error(
-                "ElasticsearchWriter: Failed to bulk save chunk of " +
-                    Objects.toString(chunk.size()),
-                exception);
-
-            // Backoff
-            try {
-              final long sleepTime = backoff.nextBackOffMillis();
-              if (sleepTime == BackOff.STOP) {
-                break;
-              }
-              Thread.sleep(sleepTime);
-            } catch (InterruptedException | IOException e1) {
-              LOG.error("Interrupt during backoff", e1);
-              break;
-            }
          }
-        } while (true);
+        }

-        try {
-          if (exception != null) {
-            if (exception instanceof BulkExecutionException) {
-              // This may result in no exception being thrown, depending on callback.
-              error.accept((BulkExecutionException) exception);
-            } else {
-              throw exception;
-            }
-          }
-        } catch (Exception e) {
-          throw new RuntimeException(e);
+        if (exception != null) {
+          throw new Exception(String.format(RETRY_FAILED_LOG, attempt), exception);
        }
-      });

+        return response;
+      };
+    }

}

private static class ClientSupplier implements Supplier<Client>, Serializable {

private final AtomicReference<Client> CLIENT = new AtomicReference<>();
private final String clusterName;
private final InetSocketAddress[] addresses;
@@ -394,6 +412,7 @@ public ClientSupplier(final String clusterName, final InetSocketAddress[] addresses) {
this.clusterName = clusterName;
this.addresses = addresses;
}

@Override
public Client get() {
if (CLIENT.get() == null) {
@@ -432,6 +451,7 @@ private static ThrowingConsumer<BulkExecutionException> defaultErrorHandler() {
* An exception carrying information about the failures in a bulk execution.
*/
public static class BulkExecutionException extends IOException {

private final Iterable<Throwable> failures;

BulkExecutionException(BulkResponse bulkResponse) {
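For orientation, the retry knobs slot into the existing builder chain. The fragment below is a hypothetical wiring sketch: the ElasticsearchIO entry-point name, the MyDoc type, its toJson() method, and the index/type names are assumptions for illustration, since only the individual withX builders appear in this diff.

// Hypothetical wiring; ElasticsearchIO (entry-point name), MyDoc, toJson(),
// and the index/type names are assumptions -- only the withX builders are
// from the diff above.
input.apply(
    ElasticsearchIO.<MyDoc>withClusterName("my-cluster")
        .withServers(new InetSocketAddress[] {
            new InetSocketAddress("localhost", 9300)})
        .withFunction(doc -> Collections.<ActionRequest<?>>singletonList(
            new IndexRequest("my-index", "doc").source(doc.toJson())))
        .withNumOfShard(8)
        .withFlushInterval(Duration.standardSeconds(1))
        .withMaxRetries(3)   // retry each failed bulk request up to 3 times
        .withRetryPause(5)   // pause between attempts
        .withError(e -> System.err.println("Bulk write failed: " + e.getMessage())));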
