Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,28 @@ public final class Servers {
* @throws InterruptedException if waiting for termination is interrupted
*/
public static Server shutdownGracefully(Server server, long maxWaitTimeInMillis) throws InterruptedException {
return shutdownGracefully(server, maxWaitTimeInMillis, TimeUnit.MILLISECONDS);
}

/**
* Attempt to {@link Server#shutdown()} the {@link Server} gracefully. If the max wait time is exceeded, give up and
* perform a hard {@link Server#shutdownNow()}.
*
* @param server the server to be shutdown
* @param timeout the max amount of time to wait for graceful shutdown to occur
* @param unit the time unit denominating the shutdown timeout
* @return the given server
* @throws InterruptedException if waiting for termination is interrupted
*/
public static Server shutdownGracefully(Server server, long timeout, TimeUnit unit) throws InterruptedException {
Preconditions.checkNotNull(server, "server");
Preconditions.checkArgument(maxWaitTimeInMillis > 0, "maxWaitTimeInMillis must be greater than 0");
Preconditions.checkArgument(timeout > 0, "timeout must be greater than 0");
Preconditions.checkNotNull(unit, "unit");

server.shutdown();

try {
server.awaitTermination(maxWaitTimeInMillis, TimeUnit.MILLISECONDS);
server.awaitTermination(timeout, unit);
} finally {
server.shutdownNow();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public void shutdownGracefullyThrowsIfMaxWaitTimeInMillisIsZero() {

assertThatThrownBy(() -> Servers.shutdownGracefully(server, 0))
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("maxWaitTimeInMillis");
.hasMessageContaining("timeout must be greater than 0");
}

@Test
Expand All @@ -44,7 +44,7 @@ public void shutdownGracefullyThrowsIfMaxWaitTimeInMillisIsLessThanZero() {

assertThatThrownBy(() -> Servers.shutdownGracefully(server, maxWaitTimeInMillis))
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("maxWaitTimeInMillis");
.hasMessageContaining("timeout must be greater than 0");
}

@Test
Expand Down
54 changes: 54 additions & 0 deletions grpc-testing-contrib/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ Copyright (c) 2018, salesforce.com, inc.
~ All rights reserved.
~ Licensed under the BSD 3-Clause license.
~ For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
-->

<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>grpc-contrib-parent</artifactId>
<groupId>com.salesforce.servicelibs</groupId>
<version>0.7.1-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>

<artifactId>grpc-testing-contrib</artifactId>

<dependencies>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-core</artifactId>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-stub</artifactId>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-netty</artifactId>
</dependency>
<dependency>
<groupId>com.salesforce.servicelibs</groupId>
<artifactId>grpc-contrib</artifactId>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-testing-proto</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright (c) 2017, salesforce.com, inc.
* All rights reserved.
* Licensed under the BSD 3-Clause license.
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
*/

package com.salesforce.grpc.testing.contrib;

import io.grpc.Context;
import org.junit.Assert;
import org.junit.rules.TestRule;
import org.junit.runner.Description;
import org.junit.runners.model.Statement;

/**
* {@code GrpcContextRule} is a JUnit {@link TestRule} that forcibly resets the gRPC
* {@link Context} to {@link Context#ROOT} between every unit test.
*
* <p>This rule makes it easier to correctly implement correct unit tests by preventing the
* accidental leakage of context state between tests.
*/
public class GrpcContextRule implements TestRule {
@Override
public Statement apply(final Statement base, final Description description) {
return new Statement() {
@Override
public void evaluate() throws Throwable {
// Reset the gRPC context between test executions
Context prev = Context.ROOT.attach();
try {
base.evaluate();
if (Context.current() != Context.ROOT) {
Assert.fail("Test is leaking context state between tests! Ensure proper " +
"attach()/detach() pairing.");
}
} finally {
Context.ROOT.detach(prev);
}
}
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Copyright (c) 2017, salesforce.com, inc.
* All rights reserved.
* Licensed under the BSD 3-Clause license.
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
*/

package com.salesforce.grpc.testing.contrib;

import com.salesforce.grpc.contrib.Servers;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.NettyServerBuilder;
import io.grpc.util.MutableHandlerRegistry;
import org.junit.rules.ExternalResource;

import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;

/**
* {@code NettyGrpcServerRule} is a JUnit {@link org.junit.rules.TestRule} that starts a gRPC Netty service with
* a {@link MutableHandlerRegistry} for adding services. It is particularly useful for testing middleware and
* interceptors using the "real" gRPC wire protocol instead of the InProcess protocol. While InProcess testing works
* 99% of the time, the Netty and InProcess transports have different flow control and serialization semantics that
* can have an affect on low-level gRPC integrations.
*
* <p>An {@link io.grpc.stub.AbstractStub} can be created against this service by using the
* {@link ManagedChannel} provided by {@link NettyGrpcServerRule#getChannel()}.
*/
public class NettyGrpcServerRule extends ExternalResource {

private ManagedChannel channel;
private Server server;
private MutableHandlerRegistry serviceRegistry;
private boolean useDirectExecutor;
private int port = 0;

private Consumer<NettyServerBuilder> configureServerBuilder = sb -> { };
private Consumer<NettyChannelBuilder> configureChannelBuilder = cb -> { };

/**
* Provides a way to configure the {@code NettyServerBuilder} used for testing.
*/
public final NettyGrpcServerRule configureServerBuilder(Consumer<NettyServerBuilder> configureServerBuilder) {
checkState(port == 0, "configureServerBuilder() can only be called at the rule instantiation");
this.configureServerBuilder = checkNotNull(configureServerBuilder, "configureServerBuilder");
return this;
}

/**
* Provides a way to configure the {@code NettyChannelBuilder} used for testing.
*/
public final NettyGrpcServerRule configureChannelBuilder(Consumer<NettyChannelBuilder> configureChannelBuilder) {
checkState(port == 0, "configureChannelBuilder() can only be called at the rule instantiation");
this.configureChannelBuilder = checkNotNull(configureChannelBuilder, "configureChannelBuilder");
return this;
}

/**
* Returns a {@link ManagedChannel} connected to this service.
*/
public final ManagedChannel getChannel() {
return channel;
}

/**
* Returns the underlying gRPC {@link Server} for this service.
*/
public final Server getServer() {
return server;
}

/**
* Returns the randomly generated TCP port for this service.
*/
public final int getPort() {
return port;
}

/**
* Returns the service registry for this service. The registry is used to add service instances
* (e.g. {@link io.grpc.BindableService} or {@link io.grpc.ServerServiceDefinition} to the server.
*/
public final MutableHandlerRegistry getServiceRegistry() {
return serviceRegistry;
}

/**
* Before the test has started, create the server and channel.
*/
@Override
protected void before() throws Throwable {
serviceRegistry = new MutableHandlerRegistry();

NettyServerBuilder serverBuilder = NettyServerBuilder
.forPort(0)
.fallbackHandlerRegistry(serviceRegistry);

if (useDirectExecutor) {
serverBuilder.directExecutor();
}

configureServerBuilder.accept(serverBuilder);
server = serverBuilder.build().start();
port = server.getPort();

NettyChannelBuilder channelBuilder = NettyChannelBuilder.forAddress("localhost", port).usePlaintext(true);
configureChannelBuilder.accept(channelBuilder);
channel = channelBuilder.build();
}

/**
* After the test has completed, clean up the channel and server.
*/
@Override
protected void after() {
serviceRegistry = null;

channel.shutdown();
channel = null;
port = 0;

try {
Servers.shutdownGracefully(server, 1, TimeUnit.MINUTES);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
} finally {
server = null;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright (c) 2018, salesforce.com, inc.
* All rights reserved.
* Licensed under the BSD 3-Clause license.
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
*/

package com.salesforce.grpc.testing.contrib;

import io.grpc.Context;
import org.junit.Test;
import org.junit.runner.Description;
import org.junit.runners.model.Statement;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.Assert.fail;

public class GrpcContextRuleTest {
@Test
public void ruleSetsContextToRoot() {
Context.current().withValue(Context.key("foo"), "bar").run(() -> {
assertThat(Context.current()).isNotEqualTo(Context.ROOT);

try {
GrpcContextRule rule = new GrpcContextRule();
rule.apply(new Statement() {
@Override
public void evaluate() {
assertThat(Context.current()).isEqualTo(Context.ROOT);
}
}, Description.createTestDescription(GrpcContextRuleTest.class, "ruleSetsContextToRoot"))
.evaluate();
} catch (Throwable throwable) {
fail(throwable.getMessage());
}
});
}

@Test
public void ruleFailsIfContextLeaks() {
Context.current().withValue(Context.key("foo"), "bar").run(() -> {
assertThat(Context.current()).isNotEqualTo(Context.ROOT);

assertThatThrownBy(() -> {
GrpcContextRule rule = new GrpcContextRule();
rule.apply(new Statement() {
@Override
public void evaluate() {
// Leak context
Context.current().withValue(Context.key("cheese"), "baz").attach();
}
}, Description.createTestDescription(GrpcContextRuleTest.class, "ruleSetsContextToRoot"))
.evaluate();
}).isInstanceOf(AssertionError.class).hasMessageContaining("Test is leaking context");
});
}
}
Loading