diff --git a/grpc-spring/src/main/java/com/salesforce/grpc/contrib/spring/GrpcServerFactory.java b/grpc-spring/src/main/java/com/salesforce/grpc/contrib/spring/GrpcServerFactory.java index 8ec12077..4a702abb 100644 --- a/grpc-spring/src/main/java/com/salesforce/grpc/contrib/spring/GrpcServerFactory.java +++ b/grpc-spring/src/main/java/com/salesforce/grpc/contrib/spring/GrpcServerFactory.java @@ -7,10 +7,13 @@ package com.salesforce.grpc.contrib.spring; +import com.google.common.collect.ImmutableList; import io.grpc.BindableService; import io.grpc.Server; +import java.lang.annotation.Annotation; import java.util.Collection; +import java.util.List; /** * Implement this interface in a bean to override how {@link GrpcServerHost} initializes a {@link Server} from a @@ -25,4 +28,15 @@ public interface GrpcServerFactory { * @return A new grpc {@link Server} */ Server buildServerForServices(int port, Collection services); + + /** + * The {@link Annotation}s this GrpcServerFactory will match on when discovering gRPC service implementations. + * Override this method to provide your own set of annotations instead of the default + * {@code {@literal @}GrpcService} annotation. + * + * @return a set of java annotations to match on. + */ + default List> forAnnotations() { + return ImmutableList.of(GrpcService.class); + } } diff --git a/grpc-spring/src/main/java/com/salesforce/grpc/contrib/spring/GrpcServerHost.java b/grpc-spring/src/main/java/com/salesforce/grpc/contrib/spring/GrpcServerHost.java index c653106e..d9d4edf7 100644 --- a/grpc-spring/src/main/java/com/salesforce/grpc/contrib/spring/GrpcServerHost.java +++ b/grpc-spring/src/main/java/com/salesforce/grpc/contrib/spring/GrpcServerHost.java @@ -21,7 +21,9 @@ import javax.annotation.Nonnull; import java.io.IOException; +import java.lang.annotation.Annotation; import java.util.Collection; +import java.util.HashMap; import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; @@ -123,15 +125,16 @@ public final int getPort() { * @throws IllegalStateException if any non-{@link BindableService} classes are annotated with {@link GrpcService} */ public void start() throws IOException { + if (serverFactory == null) { + serverFactory = findServerFactory(); + } + final Collection services = getServicesFromApplicationContext(); if (services.isEmpty()) { throw new IOException("gRPC server not started because no services were found in the application context."); } - if (serverFactory == null) { - serverFactory = findServerFactory(); - } server = serverFactory.buildServerForServices(port, services); server.start(); } @@ -157,8 +160,6 @@ public void close() throws Exception { final Server server = server(); if (server != null) { - final int port = getPort(); - server.shutdown(); try { @@ -173,7 +174,11 @@ public void close() throws Exception { } private Collection getServicesFromApplicationContext() { - Map possibleServices = applicationContext.getBeansWithAnnotation(GrpcService.class); + Map possibleServices = new HashMap<>(); + + for (Class annotation : serverFactory.forAnnotations()) { + possibleServices.putAll(applicationContext.getBeansWithAnnotation(annotation)); + } Collection invalidServiceNames = possibleServices.entrySet().stream() .filter(e -> !(e.getValue() instanceof BindableService)) diff --git a/grpc-spring/src/test/java/com/salesforce/grpc/contrib/spring/AlsoAGrpcService.java b/grpc-spring/src/test/java/com/salesforce/grpc/contrib/spring/AlsoAGrpcService.java new file mode 100644 index 00000000..3e7bfd10 --- /dev/null +++ b/grpc-spring/src/test/java/com/salesforce/grpc/contrib/spring/AlsoAGrpcService.java @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2017, salesforce.com, inc. + * All rights reserved. + * Licensed under the BSD 3-Clause license. + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + */ + +package com.salesforce.grpc.contrib.spring; + +import org.springframework.stereotype.Service; + +import java.lang.annotation.*; + +/** + * {@code GrpcService} is an annotation that is used to mark a gRPC service implementation for automatic inclusion in + * your server. + */ +@Service +@Documented +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) +public @interface AlsoAGrpcService { + +} diff --git a/grpc-spring/src/test/java/com/salesforce/grpc/contrib/spring/GrpcServerHostEndToEndTest.java b/grpc-spring/src/test/java/com/salesforce/grpc/contrib/spring/GrpcServerHostEndToEndTest.java index 6f082861..50741e32 100644 --- a/grpc-spring/src/test/java/com/salesforce/grpc/contrib/spring/GrpcServerHostEndToEndTest.java +++ b/grpc-spring/src/test/java/com/salesforce/grpc/contrib/spring/GrpcServerHostEndToEndTest.java @@ -7,6 +7,7 @@ package com.salesforce.grpc.contrib.spring; +import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; @@ -23,7 +24,9 @@ import javax.annotation.Nullable; import java.io.IOException; +import java.lang.annotation.Annotation; import java.util.Collection; +import java.util.List; import java.util.UUID; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; @@ -31,6 +34,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.awaitility.Awaitility.await; +@SuppressWarnings("Duplicates") @ContextConfiguration @RunWith(SpringJUnit4ClassRunner.class) public class GrpcServerHostEndToEndTest { @@ -101,6 +105,11 @@ public Server buildServerForServices(int port, Collection servi System.out.println("Building a service for " + services.size() + " services"); return super.buildServerForServices(port, services); } + + @Override + public List> forAnnotations() { + return ImmutableList.of(GrpcService.class, AlsoAGrpcService.class); + } }; } @@ -111,6 +120,7 @@ public GrpcServerHost serverHost() throws IOException { } @GrpcService + @AlsoAGrpcService private static class GreeterImpl extends GreeterGrpc.GreeterImplBase { @Autowired private GreetingComposer composer; diff --git a/grpc-spring/src/test/java/com/salesforce/grpc/contrib/spring/GrpcServerHostInProcessEndToEndTest.java b/grpc-spring/src/test/java/com/salesforce/grpc/contrib/spring/GrpcServerHostInProcessEndToEndTest.java new file mode 100644 index 00000000..50ba1bb8 --- /dev/null +++ b/grpc-spring/src/test/java/com/salesforce/grpc/contrib/spring/GrpcServerHostInProcessEndToEndTest.java @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2017, salesforce.com, inc. + * All rights reserved. + * Licensed under the BSD 3-Clause license. + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + */ + +package com.salesforce.grpc.contrib.spring; + +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.MoreExecutors; +import io.grpc.*; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.stub.StreamObserver; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; + +import javax.annotation.Nullable; +import java.io.IOException; +import java.lang.annotation.Annotation; +import java.util.Collection; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; + +@SuppressWarnings("Duplicates") +@ContextConfiguration +@RunWith(SpringJUnit4ClassRunner.class) +public class GrpcServerHostInProcessEndToEndTest { + private static String SERVER_NAME = "GrpcServerHostInProcessEndToEndTest"; + + @Autowired + private GrpcServerHost grpcServerHost; + + @Test + public void serverIsRunningAndSayHelloReturnsExpectedResponse() throws Exception { + final String name = UUID.randomUUID().toString(); + grpcServerHost.start(); + + ManagedChannel channel = InProcessChannelBuilder + .forName(SERVER_NAME) + .usePlaintext(true) + .build(); + + GreeterGrpc.GreeterFutureStub stub = GreeterGrpc.newFutureStub(channel); + + ListenableFuture responseFuture = stub.sayHello(HelloRequest.newBuilder().setName(name).build()); + AtomicReference response = new AtomicReference<>(); + + Futures.addCallback( + responseFuture, + new FutureCallback() { + @Override + public void onSuccess(@Nullable HelloResponse result) { + response.set(result); + } + + @Override + public void onFailure(Throwable t) { + + } + }, + MoreExecutors.directExecutor()); + + await().atMost(10, TimeUnit.SECONDS).until(responseFuture::isDone); + + channel.shutdownNow(); + + assertThat(response.get()).isNotNull(); + assertThat(response.get().getMessage()).contains(name); + } + + private interface GreetingComposer { + String greet(String name); + } + + @Configuration + static class TestConfiguration { + + @Bean + public GreeterImpl greeter() { + return new GreeterImpl(); + } + + @Bean + public GreetingComposer greetingComposer() { + return name -> "Hello, " + name; + } + + @Bean + public GrpcServerFactory factory() { + return new SimpleGrpcServerFactory() { + @Override + public Server buildServerForServices(int port, Collection services) { + System.out.println("Building an IN-PROC service for " + services.size() + " services"); + + ServerBuilder builder = InProcessServerBuilder.forName(SERVER_NAME); + services.forEach(builder::addService); + return builder.build(); + } + + @Override + public List> forAnnotations() { + return ImmutableList.of(InProcessGrpcService.class); + } + }; + } + + @Bean + public GrpcServerHost serverHost() throws IOException { + return new GrpcServerHost(9999); + } + } + + @InProcessGrpcService + private static class GreeterImpl extends GreeterGrpc.GreeterImplBase { + @Autowired + private GreetingComposer composer; + + @Override + public void sayHello(HelloRequest request, StreamObserver responseObserver) { + responseObserver.onNext(HelloResponse.newBuilder().setMessage(composer.greet(request.getName())).build()); + responseObserver.onCompleted(); + } + } +} diff --git a/grpc-spring/src/test/java/com/salesforce/grpc/contrib/spring/GrpcServerHostTest.java b/grpc-spring/src/test/java/com/salesforce/grpc/contrib/spring/GrpcServerHostTest.java index 20cb26ad..15ae4f27 100644 --- a/grpc-spring/src/test/java/com/salesforce/grpc/contrib/spring/GrpcServerHostTest.java +++ b/grpc-spring/src/test/java/com/salesforce/grpc/contrib/spring/GrpcServerHostTest.java @@ -167,7 +167,7 @@ public void startDoesNotStartServerWithoutServices() throws Exception { assertThatThrownBy(runner::start).isInstanceOf(IOException.class); // Make sure the server builder was not used. - verifyZeroInteractions(factory); + verify(factory, never()).buildServerForServices(anyInt(), any()); assertThat(runner.server()).isNull(); } diff --git a/grpc-spring/src/test/java/com/salesforce/grpc/contrib/spring/InProcessGrpcService.java b/grpc-spring/src/test/java/com/salesforce/grpc/contrib/spring/InProcessGrpcService.java new file mode 100644 index 00000000..7d55c9c2 --- /dev/null +++ b/grpc-spring/src/test/java/com/salesforce/grpc/contrib/spring/InProcessGrpcService.java @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2017, salesforce.com, inc. + * All rights reserved. + * Licensed under the BSD 3-Clause license. + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + */ + +package com.salesforce.grpc.contrib.spring; + +import org.springframework.stereotype.Service; + +import java.lang.annotation.*; + +/** + * {@code GrpcService} is an annotation that is used to mark a gRPC service implementation for automatic inclusion in + * your server. + */ +@Service +@Documented +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) +public @interface InProcessGrpcService { + +}