Permalink
Browse files

Fix recorder HTTPS, close #884

  • Loading branch information...
1 parent cdddb1c commit 5ce59e00bfe7d6a7f6ece82e0dff91c7538fe9b2 Stephane Landelle committed Dec 28, 2012
@@ -17,7 +17,7 @@ package com.excilys.ebi.gatling.recorder.http.channel
import org.jboss.netty.bootstrap.{ ServerBootstrap, ClientBootstrap }
import org.jboss.netty.channel.{ ChannelPipelineFactory, ChannelPipeline, ChannelHandlerContext }
-import org.jboss.netty.channel.Channels.pipeline
+import org.jboss.netty.channel.Channels
import org.jboss.netty.channel.socket.nio.{ NioServerSocketChannelFactory, NioClientSocketChannelFactory }
import org.jboss.netty.handler.codec.http.{ HttpResponseEncoder, HttpRequestDecoder, HttpRequest, HttpContentDecompressor, HttpContentCompressor, HttpClientCodec, HttpChunkAggregator }
import org.jboss.netty.handler.ssl.SslHandler
@@ -29,26 +29,28 @@ import com.excilys.ebi.gatling.recorder.http.ssl.{ SSLEngineFactory, FirstEventI
object BootstrapFactory {
+ val SSL_HANDLER_NAME = "ssl"
+
private val CHUNK_MAX_SIZE = 100 * 1024 * 1024; // 100Mo
private val clientChannelFactory = new NioClientSocketChannelFactory
private val serverChannelFactory = new NioServerSocketChannelFactory
- def newClientBootstrap(controller: RecorderController, browserCtx: ChannelHandlerContext, browserRequest: HttpRequest, ssl: Boolean): ClientBootstrap = {
+ def newClientBootstrap(controller: RecorderController, requestContext: ChannelHandlerContext, browserRequest: HttpRequest, ssl: Boolean): ClientBootstrap = {
val bootstrap = new ClientBootstrap(clientChannelFactory)
bootstrap.setPipelineFactory(new ChannelPipelineFactory() {
def getPipeline: ChannelPipeline = {
- val tmpPipeline = pipeline()
+ val pipeline = Channels.pipeline
if (ssl)
- tmpPipeline.addLast("ssl", new SslHandler(SSLEngineFactory.newClientSSLEngine))
- tmpPipeline.addLast("codec", new HttpClientCodec)
- tmpPipeline.addLast("inflater", new HttpContentDecompressor)
- tmpPipeline.addLast("aggregator", new HttpChunkAggregator(CHUNK_MAX_SIZE))
- tmpPipeline.addLast("gatling", new ServerHttpResponseHandler(controller, browserCtx, browserRequest))
+ pipeline.addLast(SSL_HANDLER_NAME, new SslHandler(SSLEngineFactory.newClientSSLEngine))
+ pipeline.addLast("codec", new HttpClientCodec)
+ pipeline.addLast("inflater", new HttpContentDecompressor)
+ pipeline.addLast("aggregator", new HttpChunkAggregator(CHUNK_MAX_SIZE))
+ pipeline.addLast("gatling", new ServerHttpResponseHandler(controller, requestContext, browserRequest))
- tmpPipeline
+ pipeline
}
})
@@ -64,19 +66,19 @@ object BootstrapFactory {
bootstrap.setPipelineFactory(new ChannelPipelineFactory() {
def getPipeline: ChannelPipeline = {
- val tmpPipeline = pipeline()
+ val pipeline = Channels.pipeline
if (ssl)
- tmpPipeline.addLast("ssl", new FirstEventIsUnsecuredConnectSslHandler(SSLEngineFactory.newServerSSLEngine))
- tmpPipeline.addLast("decoder", new HttpRequestDecoder)
- tmpPipeline.addLast("aggregator", new HttpChunkAggregator(CHUNK_MAX_SIZE))
- tmpPipeline.addLast("encoder", new HttpResponseEncoder)
- tmpPipeline.addLast("deflater", new HttpContentCompressor)
+ pipeline.addLast(SSL_HANDLER_NAME, new FirstEventIsUnsecuredConnectSslHandler(SSLEngineFactory.newServerSSLEngine))
+ pipeline.addLast("decoder", new HttpRequestDecoder)
+ pipeline.addLast("aggregator", new HttpChunkAggregator(CHUNK_MAX_SIZE))
+ pipeline.addLast("encoder", new HttpResponseEncoder)
+ pipeline.addLast("deflater", new HttpContentCompressor)
if (ssl)
- tmpPipeline.addLast("gatling", new BrowserHttpsRequestHandler(controller, proxyConfig))
+ pipeline.addLast("gatling", new BrowserHttpsRequestHandler(controller, proxyConfig))
else
- tmpPipeline.addLast("gatling", new BrowserHttpRequestHandler(controller, proxyConfig))
+ pipeline.addLast("gatling", new BrowserHttpRequestHandler(controller, proxyConfig))
- tmpPipeline
+ pipeline
}
})
@@ -31,6 +31,10 @@ import grizzled.slf4j.Logging
abstract class AbstractBrowserRequestHandler(controller: RecorderController, proxyConfig: ProxyConfig) extends SimpleChannelHandler with Logging {
+ implicit def function2ChannelFutureListener(thunk: ChannelFuture => Any) = new ChannelFutureListener {
+ def operationComplete(future: ChannelFuture) { thunk(future) }
+ }
+
override def messageReceived(ctx: ChannelHandlerContext, event: MessageEvent) {
event.getMessage match {
@@ -45,39 +49,26 @@ abstract class AbstractBrowserRequestHandler(controller: RecorderController, pro
}
}.getOrElse(request.removeHeader("Proxy-Connection")) // remove Proxy-Connection header if it's not significant
- val future = connectToServerOnBrowserRequestReceived(ctx, request)
+ propagateRequest(ctx, request)
controller.receiveRequest(request)
- sendRequestToServerAfterConnection(future, request);
-
case unknown => warn("Received unknown message: " + unknown)
}
}
- def connectToServerOnBrowserRequestReceived(ctx: ChannelHandlerContext, request: HttpRequest): ChannelFuture
+ def propagateRequest(requestContext: ChannelHandlerContext, request: HttpRequest)
override def exceptionCaught(ctx: ChannelHandlerContext, e: ExceptionEvent) {
error("Exception caught", e.getCause)
// Properly closing
val future = ctx.getChannel.getCloseFuture
- future.addListener(new ChannelFutureListener {
- def operationComplete(future: ChannelFuture) = future.getChannel.close
- })
+ future.addListener(ChannelFutureListener.CLOSE)
ctx.sendUpstream(e)
}
- private def sendRequestToServerAfterConnection(future: ChannelFuture, request: HttpRequest) {
-
- Option(future).map { future =>
- future.addListener(new ChannelFutureListener {
- def operationComplete(future: ChannelFuture) = future.getChannel.write(buildRequestWithRelativeURI(request))
- })
- }
- }
-
- private def buildRequestWithRelativeURI(request: HttpRequest) = {
+ def buildRequestWithRelativeURI(request: HttpRequest) = {
val uri = new URI(request.getUri)
val newUri = new URI(null, null, null, -1, uri.getPath, uri.getQuery, uri.getFragment).toString
val newRequest = new DefaultHttpRequest(request.getProtocolVersion, request.getMethod, newUri)
@@ -26,9 +26,9 @@ import com.excilys.ebi.gatling.recorder.http.channel.BootstrapFactory.newClientB
class BrowserHttpRequestHandler(controller: RecorderController, proxyConfig: ProxyConfig) extends AbstractBrowserRequestHandler(controller, proxyConfig) {
- def connectToServerOnBrowserRequestReceived(ctx: ChannelHandlerContext, request: HttpRequest): ChannelFuture = {
+ def propagateRequest(requestContext: ChannelHandlerContext, request: HttpRequest) {
- val bootstrap = newClientBootstrap(controller, ctx, request, false)
+ val bootstrap = newClientBootstrap(controller, requestContext, request, false)
val (proxyHost, proxyPort) = (for {
host <- proxyConfig.host
@@ -40,6 +40,8 @@ class BrowserHttpRequestHandler(controller: RecorderController, proxyConfig: Pro
(uri.getHost, port)
}
- bootstrap.connect(new InetSocketAddress(proxyHost, proxyPort))
+ bootstrap
+ .connect(new InetSocketAddress(proxyHost, proxyPort))
+ .addListener { future: ChannelFuture => future.getChannel.write(buildRequestWithRelativeURI(request)) }
}
}
@@ -19,9 +19,11 @@ import java.net.{ InetSocketAddress, URI }
import org.jboss.netty.channel.{ ChannelFuture, ChannelHandlerContext }
import org.jboss.netty.handler.codec.http.{ DefaultHttpResponse, HttpMethod, HttpRequest, HttpResponseStatus, HttpVersion }
+import org.jboss.netty.handler.ssl.SslHandler
import com.excilys.ebi.gatling.recorder.config.ProxyConfig
import com.excilys.ebi.gatling.recorder.controller.RecorderController
+import com.excilys.ebi.gatling.recorder.http.channel.BootstrapFactory
import com.excilys.ebi.gatling.recorder.http.channel.BootstrapFactory.newClientBootstrap
import grizzled.slf4j.Logging
@@ -30,34 +32,37 @@ class BrowserHttpsRequestHandler(controller: RecorderController, proxyConfig: Pr
@volatile var targetHostURI: URI = _
- def connectToServerOnBrowserRequestReceived(ctx: ChannelHandlerContext, request: HttpRequest): ChannelFuture = {
-
- info("Received " + request.getMethod + " on " + request.getUri)
-
- if (request.getMethod == HttpMethod.CONNECT) {
-
- targetHostURI = new URI("https://" + request.getUri());
+ def propagateRequest(requestContext: ChannelHandlerContext, request: HttpRequest) {
+ def handleConnect {
+ targetHostURI = new URI("https://" + request.getUri)
warn("Trying to connect to " + targetHostURI + ", make sure you've accepted the recorder certificate for this site")
-
controller.secureConnection(targetHostURI)
+ requestContext.getChannel.write(new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK))
+ }
- ctx.getChannel.write(new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK))
-
- null
-
- } else {
- // set full uri so that it's correctly recorded
+ def handlePropagatableRequest {
+ // set full uri so that it's correctly recorded FIXME ugly
request.setUri(targetHostURI + request.getUri)
- val bootstrap = newClientBootstrap(controller, ctx, request, true)
+ val bootstrap = newClientBootstrap(controller, requestContext, request, true)
val (host, port) = (for {
host <- proxyConfig.host
port <- proxyConfig.port
} yield (host, port)).getOrElse(targetHostURI.getHost, targetHostURI.getPort)
- bootstrap.connect(new InetSocketAddress(host, port))
+ bootstrap
+ .connect(new InetSocketAddress(host, port))
+ .addListener { connectFuture: ChannelFuture =>
+ connectFuture.getChannel.getPipeline.get(BootstrapFactory.SSL_HANDLER_NAME).asInstanceOf[SslHandler].handshake.addListener { handshakeFuture: ChannelFuture =>
+ handshakeFuture.getChannel.write(buildRequestWithRelativeURI(request))
+ }
+ }
}
+
+ info("Received " + request.getMethod + " on " + request.getUri)
+ if (request.getMethod == HttpMethod.CONNECT) handleConnect
+ else handlePropagatableRequest
}
}
Oops, something went wrong.

0 comments on commit 5ce59e0

Please sign in to comment.