Skip to content

Commit

Permalink
Fix recorder HTTPS, close gatling#884
Browse files Browse the repository at this point in the history
  • Loading branch information
Stephane Landelle committed Dec 28, 2012
1 parent cdddb1c commit 5ce59e0
Show file tree
Hide file tree
Showing 6 changed files with 376 additions and 376 deletions.
Expand Up @@ -17,7 +17,7 @@ package com.excilys.ebi.gatling.recorder.http.channel

import org.jboss.netty.bootstrap.{ ServerBootstrap, ClientBootstrap }
import org.jboss.netty.channel.{ ChannelPipelineFactory, ChannelPipeline, ChannelHandlerContext }
import org.jboss.netty.channel.Channels.pipeline
import org.jboss.netty.channel.Channels
import org.jboss.netty.channel.socket.nio.{ NioServerSocketChannelFactory, NioClientSocketChannelFactory }
import org.jboss.netty.handler.codec.http.{ HttpResponseEncoder, HttpRequestDecoder, HttpRequest, HttpContentDecompressor, HttpContentCompressor, HttpClientCodec, HttpChunkAggregator }
import org.jboss.netty.handler.ssl.SslHandler
Expand All @@ -29,26 +29,28 @@ import com.excilys.ebi.gatling.recorder.http.ssl.{ SSLEngineFactory, FirstEventI

object BootstrapFactory {

val SSL_HANDLER_NAME = "ssl"

private val CHUNK_MAX_SIZE = 100 * 1024 * 1024; // 100Mo

private val clientChannelFactory = new NioClientSocketChannelFactory

private val serverChannelFactory = new NioServerSocketChannelFactory

def newClientBootstrap(controller: RecorderController, browserCtx: ChannelHandlerContext, browserRequest: HttpRequest, ssl: Boolean): ClientBootstrap = {
def newClientBootstrap(controller: RecorderController, requestContext: ChannelHandlerContext, browserRequest: HttpRequest, ssl: Boolean): ClientBootstrap = {
val bootstrap = new ClientBootstrap(clientChannelFactory)
bootstrap.setPipelineFactory(new ChannelPipelineFactory() {
def getPipeline: ChannelPipeline = {
val tmpPipeline = pipeline()
val pipeline = Channels.pipeline

if (ssl)
tmpPipeline.addLast("ssl", new SslHandler(SSLEngineFactory.newClientSSLEngine))
tmpPipeline.addLast("codec", new HttpClientCodec)
tmpPipeline.addLast("inflater", new HttpContentDecompressor)
tmpPipeline.addLast("aggregator", new HttpChunkAggregator(CHUNK_MAX_SIZE))
tmpPipeline.addLast("gatling", new ServerHttpResponseHandler(controller, browserCtx, browserRequest))
pipeline.addLast(SSL_HANDLER_NAME, new SslHandler(SSLEngineFactory.newClientSSLEngine))
pipeline.addLast("codec", new HttpClientCodec)
pipeline.addLast("inflater", new HttpContentDecompressor)
pipeline.addLast("aggregator", new HttpChunkAggregator(CHUNK_MAX_SIZE))
pipeline.addLast("gatling", new ServerHttpResponseHandler(controller, requestContext, browserRequest))

tmpPipeline
pipeline
}
})

Expand All @@ -64,19 +66,19 @@ object BootstrapFactory {

bootstrap.setPipelineFactory(new ChannelPipelineFactory() {
def getPipeline: ChannelPipeline = {
val tmpPipeline = pipeline()
val pipeline = Channels.pipeline
if (ssl)
tmpPipeline.addLast("ssl", new FirstEventIsUnsecuredConnectSslHandler(SSLEngineFactory.newServerSSLEngine))
tmpPipeline.addLast("decoder", new HttpRequestDecoder)
tmpPipeline.addLast("aggregator", new HttpChunkAggregator(CHUNK_MAX_SIZE))
tmpPipeline.addLast("encoder", new HttpResponseEncoder)
tmpPipeline.addLast("deflater", new HttpContentCompressor)
pipeline.addLast(SSL_HANDLER_NAME, new FirstEventIsUnsecuredConnectSslHandler(SSLEngineFactory.newServerSSLEngine))
pipeline.addLast("decoder", new HttpRequestDecoder)
pipeline.addLast("aggregator", new HttpChunkAggregator(CHUNK_MAX_SIZE))
pipeline.addLast("encoder", new HttpResponseEncoder)
pipeline.addLast("deflater", new HttpContentCompressor)
if (ssl)
tmpPipeline.addLast("gatling", new BrowserHttpsRequestHandler(controller, proxyConfig))
pipeline.addLast("gatling", new BrowserHttpsRequestHandler(controller, proxyConfig))
else
tmpPipeline.addLast("gatling", new BrowserHttpRequestHandler(controller, proxyConfig))
pipeline.addLast("gatling", new BrowserHttpRequestHandler(controller, proxyConfig))

tmpPipeline
pipeline
}
})

Expand Down
Expand Up @@ -31,6 +31,10 @@ import grizzled.slf4j.Logging

abstract class AbstractBrowserRequestHandler(controller: RecorderController, proxyConfig: ProxyConfig) extends SimpleChannelHandler with Logging {

implicit def function2ChannelFutureListener(thunk: ChannelFuture => Any) = new ChannelFutureListener {
def operationComplete(future: ChannelFuture) { thunk(future) }
}

override def messageReceived(ctx: ChannelHandlerContext, event: MessageEvent) {

event.getMessage match {
Expand All @@ -45,39 +49,26 @@ abstract class AbstractBrowserRequestHandler(controller: RecorderController, pro
}
}.getOrElse(request.removeHeader("Proxy-Connection")) // remove Proxy-Connection header if it's not significant

val future = connectToServerOnBrowserRequestReceived(ctx, request)
propagateRequest(ctx, request)

controller.receiveRequest(request)

sendRequestToServerAfterConnection(future, request);

case unknown => warn("Received unknown message: " + unknown)
}
}

def connectToServerOnBrowserRequestReceived(ctx: ChannelHandlerContext, request: HttpRequest): ChannelFuture
def propagateRequest(requestContext: ChannelHandlerContext, request: HttpRequest)

override def exceptionCaught(ctx: ChannelHandlerContext, e: ExceptionEvent) {
error("Exception caught", e.getCause)

// Properly closing
val future = ctx.getChannel.getCloseFuture
future.addListener(new ChannelFutureListener {
def operationComplete(future: ChannelFuture) = future.getChannel.close
})
future.addListener(ChannelFutureListener.CLOSE)
ctx.sendUpstream(e)
}

private def sendRequestToServerAfterConnection(future: ChannelFuture, request: HttpRequest) {

Option(future).map { future =>
future.addListener(new ChannelFutureListener {
def operationComplete(future: ChannelFuture) = future.getChannel.write(buildRequestWithRelativeURI(request))
})
}
}

private def buildRequestWithRelativeURI(request: HttpRequest) = {
def buildRequestWithRelativeURI(request: HttpRequest) = {
val uri = new URI(request.getUri)
val newUri = new URI(null, null, null, -1, uri.getPath, uri.getQuery, uri.getFragment).toString
val newRequest = new DefaultHttpRequest(request.getProtocolVersion, request.getMethod, newUri)
Expand Down
Expand Up @@ -26,9 +26,9 @@ import com.excilys.ebi.gatling.recorder.http.channel.BootstrapFactory.newClientB

class BrowserHttpRequestHandler(controller: RecorderController, proxyConfig: ProxyConfig) extends AbstractBrowserRequestHandler(controller, proxyConfig) {

def connectToServerOnBrowserRequestReceived(ctx: ChannelHandlerContext, request: HttpRequest): ChannelFuture = {
def propagateRequest(requestContext: ChannelHandlerContext, request: HttpRequest) {

val bootstrap = newClientBootstrap(controller, ctx, request, false)
val bootstrap = newClientBootstrap(controller, requestContext, request, false)

val (proxyHost, proxyPort) = (for {
host <- proxyConfig.host
Expand All @@ -40,6 +40,8 @@ class BrowserHttpRequestHandler(controller: RecorderController, proxyConfig: Pro
(uri.getHost, port)
}

bootstrap.connect(new InetSocketAddress(proxyHost, proxyPort))
bootstrap
.connect(new InetSocketAddress(proxyHost, proxyPort))
.addListener { future: ChannelFuture => future.getChannel.write(buildRequestWithRelativeURI(request)) }
}
}
Expand Up @@ -19,9 +19,11 @@ import java.net.{ InetSocketAddress, URI }

import org.jboss.netty.channel.{ ChannelFuture, ChannelHandlerContext }
import org.jboss.netty.handler.codec.http.{ DefaultHttpResponse, HttpMethod, HttpRequest, HttpResponseStatus, HttpVersion }
import org.jboss.netty.handler.ssl.SslHandler

import com.excilys.ebi.gatling.recorder.config.ProxyConfig
import com.excilys.ebi.gatling.recorder.controller.RecorderController
import com.excilys.ebi.gatling.recorder.http.channel.BootstrapFactory
import com.excilys.ebi.gatling.recorder.http.channel.BootstrapFactory.newClientBootstrap

import grizzled.slf4j.Logging
Expand All @@ -30,34 +32,37 @@ class BrowserHttpsRequestHandler(controller: RecorderController, proxyConfig: Pr

@volatile var targetHostURI: URI = _

def connectToServerOnBrowserRequestReceived(ctx: ChannelHandlerContext, request: HttpRequest): ChannelFuture = {

info("Received " + request.getMethod + " on " + request.getUri)

if (request.getMethod == HttpMethod.CONNECT) {

targetHostURI = new URI("https://" + request.getUri());
def propagateRequest(requestContext: ChannelHandlerContext, request: HttpRequest) {

def handleConnect {
targetHostURI = new URI("https://" + request.getUri)
warn("Trying to connect to " + targetHostURI + ", make sure you've accepted the recorder certificate for this site")

controller.secureConnection(targetHostURI)
requestContext.getChannel.write(new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK))
}

ctx.getChannel.write(new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK))

null

} else {
// set full uri so that it's correctly recorded
def handlePropagatableRequest {
// set full uri so that it's correctly recorded FIXME ugly
request.setUri(targetHostURI + request.getUri)

val bootstrap = newClientBootstrap(controller, ctx, request, true)
val bootstrap = newClientBootstrap(controller, requestContext, request, true)

val (host, port) = (for {
host <- proxyConfig.host
port <- proxyConfig.port
} yield (host, port)).getOrElse(targetHostURI.getHost, targetHostURI.getPort)

bootstrap.connect(new InetSocketAddress(host, port))
bootstrap
.connect(new InetSocketAddress(host, port))
.addListener { connectFuture: ChannelFuture =>
connectFuture.getChannel.getPipeline.get(BootstrapFactory.SSL_HANDLER_NAME).asInstanceOf[SslHandler].handshake.addListener { handshakeFuture: ChannelFuture =>
handshakeFuture.getChannel.write(buildRequestWithRelativeURI(request))
}
}
}

info("Received " + request.getMethod + " on " + request.getUri)
if (request.getMethod == HttpMethod.CONNECT) handleConnect
else handlePropagatableRequest
}
}

0 comments on commit 5ce59e0

Please sign in to comment.