Skip to content

Commit

Permalink
Polish netty client support
Browse files Browse the repository at this point in the history
  • Loading branch information
rstoyanchev committed Oct 27, 2014
1 parent e120757 commit 083dece
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,15 @@ class Netty4ClientHttpRequest extends AbstractAsyncClientHttpRequest implements

private final ByteBufOutputStream body;


Netty4ClientHttpRequest(Bootstrap bootstrap, URI uri, HttpMethod method, int maxRequestSize) {
this.bootstrap = bootstrap;
this.uri = uri;
this.method = method;
this.body = new ByteBufOutputStream(Unpooled.buffer(maxRequestSize));
}


@Override
public HttpMethod getMethod() {
return this.method;
Expand All @@ -83,8 +85,7 @@ protected OutputStream getBodyInternal(HttpHeaders headers) throws IOException {
}

@Override
protected ListenableFuture<ClientHttpResponse> executeInternal(final HttpHeaders headers)
throws IOException {
protected ListenableFuture<ClientHttpResponse> executeInternal(final HttpHeaders headers) throws IOException {
final SettableListenableFuture<ClientHttpResponse> responseFuture =
new SettableListenableFuture<ClientHttpResponse>();

Expand All @@ -93,42 +94,19 @@ protected ListenableFuture<ClientHttpResponse> executeInternal(final HttpHeaders
public void operationComplete(ChannelFuture future) throws Exception {
if (future.isSuccess()) {
Channel channel = future.channel();
channel.pipeline()
.addLast(new SimpleChannelInboundHandler<FullHttpResponse>() {

@Override
protected void channelRead0(
ChannelHandlerContext ctx,
FullHttpResponse msg) throws Exception {
responseFuture
.set(new Netty4ClientHttpResponse(ctx,
msg));
}

@Override
public void exceptionCaught(
ChannelHandlerContext ctx,
Throwable cause) throws Exception {
responseFuture.setException(cause);
}
});

FullHttpRequest nettyRequest =
createFullHttpRequest(headers);

channel.pipeline().addLast(new RequestExecuteHandler(responseFuture));
FullHttpRequest nettyRequest = createFullHttpRequest(headers);
channel.writeAndFlush(nettyRequest);
}
else {
responseFuture.setException(future.cause());
}

}
};

bootstrap.connect(uri.getHost(), getPort(uri)).addListener(connectionListener);
this.bootstrap.connect(this.uri.getHost(), getPort(this.uri)).addListener(connectionListener);

return responseFuture;

}

@Override
Expand All @@ -142,7 +120,8 @@ public ClientHttpResponse execute() throws IOException {
catch (ExecutionException ex) {
if (ex.getCause() instanceof IOException) {
throw (IOException) ex.getCause();
} else {
}
else {
throw new IOException(ex.getMessage(), ex);
}
}
Expand All @@ -163,17 +142,13 @@ else if ("https".equalsIgnoreCase(uri.getScheme())) {

private FullHttpRequest createFullHttpRequest(HttpHeaders headers) {
io.netty.handler.codec.http.HttpMethod nettyMethod =
io.netty.handler.codec.http.HttpMethod.valueOf(method.name());
io.netty.handler.codec.http.HttpMethod.valueOf(this.method.name());

FullHttpRequest nettyRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1,
nettyMethod, this.uri.getRawPath(),
this.body.buffer());
nettyMethod, this.uri.getRawPath(), this.body.buffer());

nettyRequest.headers()
.set(io.netty.handler.codec.http.HttpHeaders.Names.HOST, uri.getHost());
nettyRequest.headers()
.set(io.netty.handler.codec.http.HttpHeaders.Names.CONNECTION,
io.netty.handler.codec.http.HttpHeaders.Values.CLOSE);
nettyRequest.headers().set(HttpHeaders.HOST, uri.getHost());
nettyRequest.headers().set(HttpHeaders.CONNECTION, io.netty.handler.codec.http.HttpHeaders.Values.CLOSE);

for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
nettyRequest.headers().add(entry.getKey(), entry.getValue());
Expand All @@ -183,4 +158,26 @@ private FullHttpRequest createFullHttpRequest(HttpHeaders headers) {
}


/**
* A SimpleChannelInboundHandler to update the given SettableListenableFuture.
*/
private static class RequestExecuteHandler extends SimpleChannelInboundHandler<FullHttpResponse> {

private final SettableListenableFuture<ClientHttpResponse> responseFuture;

public RequestExecuteHandler(SettableListenableFuture<ClientHttpResponse> responseFuture) {
this.responseFuture = responseFuture;
}

@Override
protected void channelRead0(ChannelHandlerContext context, FullHttpResponse response) throws Exception {
this.responseFuture.set(new Netty4ClientHttpResponse(context, response));
}

@Override
public void exceptionCaught(ChannelHandlerContext context, Throwable cause) throws Exception {
this.responseFuture.setException(cause);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,16 @@
* @author Arjen Poutsma
* @since 4.2
*/
public class Netty4ClientHttpRequestFactory
implements ClientHttpRequestFactory, AsyncClientHttpRequestFactory,
InitializingBean, DisposableBean {
public class Netty4ClientHttpRequestFactory implements ClientHttpRequestFactory,
AsyncClientHttpRequestFactory, InitializingBean, DisposableBean {

/**
* The default maximum request size.
* @see #setMaxRequestSize(int)
*/
public static final int DEFAULT_MAX_REQUEST_SIZE = 1024 * 1024 * 10;


private final EventLoopGroup eventLoopGroup;

private final boolean defaultEventLoopGroup;
Expand All @@ -65,18 +65,19 @@ public class Netty4ClientHttpRequestFactory

private Bootstrap bootstrap;


/**
* Creates a new {@code Netty4ClientHttpRequestFactory} with a default
* Create a new {@code Netty4ClientHttpRequestFactory} with a default
* {@link NioEventLoopGroup}.
*/
public Netty4ClientHttpRequestFactory() {
int ioWorkerCount = Runtime.getRuntime().availableProcessors() * 2;
eventLoopGroup = new NioEventLoopGroup(ioWorkerCount);
defaultEventLoopGroup = true;
this.eventLoopGroup = new NioEventLoopGroup(ioWorkerCount);
this.defaultEventLoopGroup = true;
}

/**
* Creates a new {@code Netty4ClientHttpRequestFactory} with the given
* Create a new {@code Netty4ClientHttpRequestFactory} with the given
* {@link EventLoopGroup}.
*
* <p><b>NOTE:</b> the given group will <strong>not</strong> be
Expand All @@ -89,17 +90,20 @@ public Netty4ClientHttpRequestFactory(EventLoopGroup eventLoopGroup) {
this.defaultEventLoopGroup = false;
}


/**
* Sets the default maximum request size. The default is
* {@link #DEFAULT_MAX_REQUEST_SIZE}.
* Set the default maximum request size.
* <p>By default this is set to {@link #DEFAULT_MAX_REQUEST_SIZE}.
* @see HttpObjectAggregator#HttpObjectAggregator(int)
*/
public void setMaxRequestSize(int maxRequestSize) {
this.maxRequestSize = maxRequestSize;
}

/**
* Sets the SSL context.
* Set the SSL context. When configured it is used to create and insert an
* {@link io.netty.handler.ssl.SslHandler} in the channel pipeline.
* <p>By default this is not set.
*/
public void setSslContext(SslContext sslContext) {
this.sslContext = sslContext;
Expand All @@ -108,14 +112,14 @@ public void setSslContext(SslContext sslContext) {
private Bootstrap getBootstrap() {
if (this.bootstrap == null) {
Bootstrap bootstrap = new Bootstrap();
bootstrap.group(eventLoopGroup).channel(NioSocketChannel.class)
bootstrap.group(this.eventLoopGroup).channel(NioSocketChannel.class)
.handler(new ChannelInitializer<SocketChannel>() {
@Override
protected void initChannel(SocketChannel ch) throws Exception {
ChannelPipeline pipeline = ch.pipeline();
protected void initChannel(SocketChannel channel) throws Exception {
ChannelPipeline pipeline = channel.pipeline();

if (sslContext != null) {
pipeline.addLast(sslContext.newHandler(ch.alloc()));
pipeline.addLast(sslContext.newHandler(channel.alloc()));
}
pipeline.addLast(new HttpClientCodec());
pipeline.addLast(new HttpObjectAggregator(maxRequestSize));
Expand All @@ -131,29 +135,26 @@ public void afterPropertiesSet() throws Exception {
getBootstrap();
}

private Netty4ClientHttpRequest createRequestInternal(URI uri, HttpMethod httpMethod) {
return new Netty4ClientHttpRequest(getBootstrap(), uri, httpMethod, maxRequestSize);
}

@Override
public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod)
throws IOException {
public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException {
return createRequestInternal(uri, httpMethod);
}

@Override
public AsyncClientHttpRequest createAsyncRequest(URI uri, HttpMethod httpMethod)
throws IOException {
public AsyncClientHttpRequest createAsyncRequest(URI uri, HttpMethod httpMethod) throws IOException {
return createRequestInternal(uri, httpMethod);
}

private Netty4ClientHttpRequest createRequestInternal(URI uri, HttpMethod httpMethod) {
return new Netty4ClientHttpRequest(getBootstrap(), uri, httpMethod, this.maxRequestSize);
}

@Override
public void destroy() throws InterruptedException {
if (defaultEventLoopGroup) {
if (this.defaultEventLoopGroup) {
// clean up the EventLoopGroup if we created it in the constructor
eventLoopGroup.shutdownGracefully().sync();
this.eventLoopGroup.shutdownGracefully().sync();
}
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ class Netty4ClientHttpResponse extends AbstractClientHttpResponse {
private HttpHeaders headers;


Netty4ClientHttpResponse(ChannelHandlerContext context,
FullHttpResponse nettyResponse) {
Netty4ClientHttpResponse(ChannelHandlerContext context, FullHttpResponse nettyResponse) {
Assert.notNull(context, "'context' must not be null");
Assert.notNull(nettyResponse, "'nettyResponse' must not be null");
this.context = context;
Expand All @@ -55,6 +54,7 @@ class Netty4ClientHttpResponse extends AbstractClientHttpResponse {
this.nettyResponse.retain();
}


@Override
public int getRawStatusCode() throws IOException {
return this.nettyResponse.getStatus().code();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,19 @@

package org.springframework.http.client;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import java.io.IOException;
import java.net.URI;
import java.util.Arrays;
import java.util.Locale;
import java.util.concurrent.Future;

import org.junit.After;
import static org.junit.Assert.*;
import org.junit.Before;
import org.junit.Test;

import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.http.HttpMethod;
Expand All @@ -44,16 +46,16 @@ public abstract class AbstractAsyncHttpRequestFactoryTestCase extends AbstractJe

@Before
public final void createFactory() throws Exception {
factory = createRequestFactory();
if (factory instanceof InitializingBean) {
((InitializingBean) factory).afterPropertiesSet();
this.factory = createRequestFactory();
if (this.factory instanceof InitializingBean) {
((InitializingBean) this.factory).afterPropertiesSet();
}
}

@After
public final void destroyFactory() throws Exception {
if (factory instanceof DisposableBean) {
((DisposableBean) factory).destroy();
if (this.factory instanceof DisposableBean) {
((DisposableBean) this.factory).destroy();
}
}

Expand All @@ -63,22 +65,23 @@ public final void destroyFactory() throws Exception {
@Test
public void status() throws Exception {
URI uri = new URI(baseUrl + "/status/notfound");
AsyncClientHttpRequest request = factory.createAsyncRequest(uri, HttpMethod.GET);
AsyncClientHttpRequest request = this.factory.createAsyncRequest(uri, HttpMethod.GET);
assertEquals("Invalid HTTP method", HttpMethod.GET, request.getMethod());
assertEquals("Invalid HTTP URI", uri, request.getURI());
Future<ClientHttpResponse> futureResponse = request.executeAsync();
ClientHttpResponse response = futureResponse.get();
try {
assertEquals("Invalid status code", HttpStatus.NOT_FOUND, response.getStatusCode());
} finally {
}
finally {
response.close();
}
}

@Test
public void statusCallback() throws Exception {
URI uri = new URI(baseUrl + "/status/notfound");
AsyncClientHttpRequest request = factory.createAsyncRequest(uri, HttpMethod.GET);
AsyncClientHttpRequest request = this.factory.createAsyncRequest(uri, HttpMethod.GET);
assertEquals("Invalid HTTP method", HttpMethod.GET, request.getMethod());
assertEquals("Invalid HTTP URI", uri, request.getURI());
ListenableFuture<ClientHttpResponse> listenableFuture = request.executeAsync();
Expand Down Expand Up @@ -108,7 +111,7 @@ public void onFailure(Throwable ex) {

@Test
public void echo() throws Exception {
AsyncClientHttpRequest request = factory.createAsyncRequest(new URI(baseUrl + "/echo"), HttpMethod.PUT);
AsyncClientHttpRequest request = this.factory.createAsyncRequest(new URI(baseUrl + "/echo"), HttpMethod.PUT);
assertEquals("Invalid HTTP method", HttpMethod.PUT, request.getMethod());
String headerName = "MyHeader";
String headerValue1 = "value1";
Expand Down Expand Up @@ -143,7 +146,7 @@ public void echo() throws Exception {

@Test
public void multipleWrites() throws Exception {
AsyncClientHttpRequest request = factory.createAsyncRequest(new URI(baseUrl + "/echo"), HttpMethod.POST);
AsyncClientHttpRequest request = this.factory.createAsyncRequest(new URI(baseUrl + "/echo"), HttpMethod.POST);
final byte[] body = "Hello World".getBytes("UTF-8");

if (request instanceof StreamingHttpOutputMessage) {
Expand All @@ -170,7 +173,7 @@ public void multipleWrites() throws Exception {

@Test
public void headersAfterExecute() throws Exception {
AsyncClientHttpRequest request = factory.createAsyncRequest(new URI(baseUrl + "/echo"), HttpMethod.POST);
AsyncClientHttpRequest request = this.factory.createAsyncRequest(new URI(baseUrl + "/echo"), HttpMethod.POST);
request.getHeaders().add("MyHeader", "value");
byte[] body = "Hello World".getBytes("UTF-8");
FileCopyUtils.copy(body, request.getBody());
Expand Down Expand Up @@ -202,7 +205,7 @@ public void httpMethods() throws Exception {
protected void assertHttpMethod(String path, HttpMethod method) throws Exception {
ClientHttpResponse response = null;
try {
AsyncClientHttpRequest request = factory.createAsyncRequest(new URI(baseUrl + "/methods/" + path), method);
AsyncClientHttpRequest request = this.factory.createAsyncRequest(new URI(baseUrl + "/methods/" + path), method);
Future<ClientHttpResponse> futureResponse = request.executeAsync();
response = futureResponse.get();
assertEquals("Invalid response status", HttpStatus.OK, response.getStatusCode());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ public void status() throws Exception {
ClientHttpResponse response = request.execute();
try {
assertEquals("Invalid status code", HttpStatus.NOT_FOUND, response.getStatusCode());
} finally {
}
finally {
response.close();
}
}
Expand Down
Loading

0 comments on commit 083dece

Please sign in to comment.