Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/*
* Copyright (c) 2017, salesforce.com, inc.
* All rights reserved.
* Licensed under the BSD 3-Clause license.
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
*/

package com.salesforce.grpc.contrib.interceptor;

import com.google.common.annotations.VisibleForTesting;
import io.grpc.*;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Function;

import static com.google.common.base.Preconditions.checkNotNull;

/**
* {@code DefaultCallOptionsClientInterceptor} applies specified gRPC {@code CallOptions} to every outbound request.
* By default, {@code DefaultCallOptionsClientInterceptor} will not overwrite {@code CallOptions} already set on the
* outbound request.
*
* <p>Example uses include:
* <ul>
* <li>Applying a set of {@code CallCredentials} to every request from any stub.</li>
* <li>Applying a compression strategy to every request.</li>
* <li>Attaching a custom {@code CallOptions.Key<T>} to every request.</li>
* <li>Setting the {@code WaitForReady} bit on every request.</li>
* <li>Preventing upstream users from tweaking {@code CallOptions} values by forcibly overwriting the value with a
* specific default.</li>
* </ul>
*/
public class DefaultCallOptionsClientInterceptor implements ClientInterceptor {
private static final Field CUSTOM_OPTIONS_FIELD = getCustomOptionsField();

private static Field getCustomOptionsField() {
try {
Field f;
f = CallOptions.class.getDeclaredField("customOptions");
f.setAccessible(true);
return f;
} catch (NoSuchFieldException e) {
throw new RuntimeException(e);
}
}

private CallOptions defaultOptions;
private boolean overwrite = false;

/**
* Constructs a {@code DefaultCallOptionsClientInterceptor}.
* @param options the set of {@code CallOptions} to apply to every call
*/
public DefaultCallOptionsClientInterceptor(CallOptions options) {
this.defaultOptions = checkNotNull(options, "defaultOptions");
}

/**
* Instructs the interceptor to overwrite {@code CallOptions} values even if they are already present on the
* outbound request.
*
* @return this
*/
public DefaultCallOptionsClientInterceptor overwriteExistingValues() {
this.overwrite = true;
return this;
}

public CallOptions getDefaultOptions() {
return defaultOptions;
}

public void setDefaultOptions(CallOptions options) {
this.defaultOptions = options;
}

@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
return next.newCall(method, patchOptions(callOptions));
}

@VisibleForTesting
CallOptions patchOptions(CallOptions baseOptions) {
CallOptions patchedOptions = baseOptions;

patchedOptions = patchOption(patchedOptions, CallOptions::getAuthority, CallOptions::withAuthority);
patchedOptions = patchOption(patchedOptions, CallOptions::getCredentials, CallOptions::withCallCredentials);
patchedOptions = patchOption(patchedOptions, CallOptions::getCompressor, CallOptions::withCompression);
patchedOptions = patchOption(patchedOptions, CallOptions::getDeadline, CallOptions::withDeadline);
patchedOptions = patchOption(patchedOptions, CallOptions::isWaitForReady, (callOptions, waitForReady) -> waitForReady ? callOptions.withWaitForReady() : callOptions.withoutWaitForReady());
patchedOptions = patchOption(patchedOptions, CallOptions::getMaxInboundMessageSize, CallOptions::withMaxInboundMessageSize);
patchedOptions = patchOption(patchedOptions, CallOptions::getMaxOutboundMessageSize, CallOptions::withMaxOutboundMessageSize);
patchedOptions = patchOption(patchedOptions, CallOptions::getExecutor, CallOptions::withExecutor);

for (ClientStreamTracer.Factory factory : defaultOptions.getStreamTracerFactories()) {
patchedOptions = patchedOptions.withStreamTracerFactory(factory);
}

for (CallOptions.Key<Object> key : customOptionKeys(defaultOptions)) {
patchedOptions = patchOption(patchedOptions, co -> co.getOption(key), (co, o) -> co.withOption(key, o));
}

return patchedOptions;
}

private <T> CallOptions patchOption(CallOptions baseOptions, Function<CallOptions, T> getter, BiFunction<CallOptions, T, CallOptions> setter) {
T baseValue = getter.apply(baseOptions);
if (baseValue == null || overwrite) {
T patchValue = getter.apply(defaultOptions);
if (patchValue != null) {
return setter.apply(baseOptions, patchValue);
}
}

return baseOptions;
}

@SuppressWarnings("unchecked")
private List<CallOptions.Key<Object>> customOptionKeys(CallOptions callOptions) {
try {
Object[][] customOptions = (Object[][]) CUSTOM_OPTIONS_FIELD.get(callOptions);
List<CallOptions.Key<Object>> keys = new ArrayList<>(customOptions.length);
for (Object[] arr : customOptions) {
keys.add((CallOptions.Key<Object>) arr[0]);
}
return keys;
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ public boolean equals(Object o) {
}
}

@Test
public void jsonMarshallerPrimitiveRoundtrip() {
Metadata.AsciiMarshaller<Integer> marshaller = MoreMetadata.JSON_MARSHALLER(Integer.class);
String s = marshaller.toAsciiString(42);
assertThat(s).isEqualTo("42");

Integer l = marshaller.parseAsciiString(s);
assertThat(l).isEqualTo(42);
}

@Test
public void protobufMarshallerRoundtrip() {
HelloRequest request = HelloRequest.newBuilder().setName("World").build();
Expand Down Expand Up @@ -134,4 +144,18 @@ public void rawJsonToTypedJson() {
assertThat(bar.cheese).isEqualTo("swiss");
assertThat(bar.age).isEqualTo(42);
}

@Test
public void rawBytesToTypedProto() {
Metadata.Key<byte[]> byteKey = Metadata.Key.of("key-bin", Metadata.BINARY_BYTE_MARSHALLER);
Metadata.Key<HelloRequest> protoKey = Metadata.Key.of("key-bin", MoreMetadata.PROTOBUF_MARSHALLER(HelloRequest.class));

HelloRequest request = HelloRequest.newBuilder().setName("World").build();
Metadata metadata = new Metadata();
metadata.put(byteKey, request.toByteArray());

HelloRequest request2 = metadata.get(protoKey);
assertThat(request2).isNotNull();
assertThat(request2.getName()).isEqualTo("World");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright (c) 2017, salesforce.com, inc.
* All rights reserved.
* Licensed under the BSD 3-Clause license.
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
*/

package com.salesforce.grpc.contrib.interceptor;

import io.grpc.CallOptions;
import io.grpc.ClientStreamTracer;
import org.junit.Test;

import static org.assertj.core.api.Java6Assertions.assertThat;

public class DefaultCallOptionsClientInterceptorTest {
@Test
public void simpleValueTransfers() {
CallOptions baseOptions = CallOptions.DEFAULT;
CallOptions defaultOptions = CallOptions.DEFAULT.withAuthority("FOO");

DefaultCallOptionsClientInterceptor interceptor = new DefaultCallOptionsClientInterceptor(defaultOptions);

CallOptions patchedOptions = interceptor.patchOptions(baseOptions);

assertThat(patchedOptions.getAuthority()).isEqualTo("FOO");
}

@Test
public void clientStreamTracerTransfers() {
ClientStreamTracer.Factory factory1 = new ClientStreamTracer.Factory() {};
ClientStreamTracer.Factory factory2 = new ClientStreamTracer.Factory() {};

CallOptions baseOptions = CallOptions.DEFAULT.withStreamTracerFactory(factory1);
CallOptions defaultOptions = CallOptions.DEFAULT.withStreamTracerFactory(factory2);

DefaultCallOptionsClientInterceptor interceptor = new DefaultCallOptionsClientInterceptor(defaultOptions);

CallOptions patchedOptions = interceptor.patchOptions(baseOptions);

assertThat(patchedOptions.getStreamTracerFactories()).containsExactly(factory1, factory2);
}

@Test
public void customKeyTransfers() {
CallOptions.Key<String> k1 = CallOptions.Key.of("k1", null);
CallOptions.Key<String> k2 = CallOptions.Key.of("k2", null);

CallOptions baseOptions = CallOptions.DEFAULT.withOption(k1, "FOO");
CallOptions defaultOptions = CallOptions.DEFAULT.withOption(k2, "BAR");

DefaultCallOptionsClientInterceptor interceptor = new DefaultCallOptionsClientInterceptor(defaultOptions);

CallOptions patchedOptions = interceptor.patchOptions(baseOptions);

assertThat(patchedOptions.getOption(k1)).isEqualTo("FOO");
assertThat(patchedOptions.getOption(k2)).isEqualTo("BAR");
}

@Test
public void noOverwriteWorks() {
CallOptions baseOptions = CallOptions.DEFAULT.withAuthority("FOO");
CallOptions defaultOptions = CallOptions.DEFAULT.withAuthority("BAR");

DefaultCallOptionsClientInterceptor interceptor = new DefaultCallOptionsClientInterceptor(defaultOptions);

CallOptions patchedOptions = interceptor.patchOptions(baseOptions);

assertThat(patchedOptions.getAuthority()).isEqualTo("FOO");
}

@Test
public void noOverwriteWorksCustomKeys() {
CallOptions.Key<String> k1 = CallOptions.Key.of("k1", null);
CallOptions.Key<String> k2 = CallOptions.Key.of("k2", null);
CallOptions.Key<String> k3 = CallOptions.Key.of("k3", null);

CallOptions baseOptions = CallOptions.DEFAULT.withOption(k1, "FOO").withOption(k3, "BAZ");
CallOptions defaultOptions = CallOptions.DEFAULT.withOption(k2, "BAR").withOption(k3, "BOP");

DefaultCallOptionsClientInterceptor interceptor = new DefaultCallOptionsClientInterceptor(defaultOptions);

CallOptions patchedOptions = interceptor.patchOptions(baseOptions);

assertThat(patchedOptions.getOption(k1)).isEqualTo("FOO");
assertThat(patchedOptions.getOption(k2)).isEqualTo("BAR");
assertThat(patchedOptions.getOption(k3)).isEqualTo("BAZ");
}

@Test
public void overwriteWorks() {
CallOptions baseOptions = CallOptions.DEFAULT.withAuthority("FOO");
CallOptions defaultOptions = CallOptions.DEFAULT.withAuthority("BAR");

DefaultCallOptionsClientInterceptor interceptor = new DefaultCallOptionsClientInterceptor(defaultOptions)
.overwriteExistingValues();

CallOptions patchedOptions = interceptor.patchOptions(baseOptions);

assertThat(patchedOptions.getAuthority()).isEqualTo("BAR");
}

@Test
public void overwriteWorksCustomKeys() {
CallOptions.Key<String> k1 = CallOptions.Key.of("k1", null);
CallOptions.Key<String> k2 = CallOptions.Key.of("k2", null);
CallOptions.Key<String> k3 = CallOptions.Key.of("k3", null);

CallOptions baseOptions = CallOptions.DEFAULT.withOption(k1, "FOO").withOption(k3, "BAZ");
CallOptions defaultOptions = CallOptions.DEFAULT.withOption(k2, "BAR").withOption(k3, "BOP");

DefaultCallOptionsClientInterceptor interceptor = new DefaultCallOptionsClientInterceptor(defaultOptions)
.overwriteExistingValues();

CallOptions patchedOptions = interceptor.patchOptions(baseOptions);

assertThat(patchedOptions.getOption(k1)).isEqualTo("FOO");
assertThat(patchedOptions.getOption(k2)).isEqualTo("BAR");
assertThat(patchedOptions.getOption(k3)).isEqualTo("BOP");
}
}