Skip to content

Commit

Permalink
[split] Add client support for proxies that support HTTP CONNECT, suc…
Browse files Browse the repository at this point in the history
…h as squid.

Signed-off-by: marius a. eriksen <marius@twitter.com>

RB_ID=135062
  • Loading branch information
cooper bethea authored and CI committed Mar 26, 2013
1 parent 3f487db commit ca52719
Show file tree
Hide file tree
Showing 4 changed files with 251 additions and 2 deletions.
Expand Up @@ -167,6 +167,7 @@ private[builder] final case class ClientConfig[Req, Rep, HasCluster, HasCodec, H
private val _logger : Option[Logger] = None,
private val _channelFactory : Option[ChannelFactory] = None,
private val _tls : Option[(() => Engine, Option[String])] = None,
private val _httpProxy : Option[SocketAddress] = None,
private val _socksProxy : Option[SocketAddress] = None,
private val _failureAccrual : Option[Timer => ServiceFactoryWrapper] = Some(FailureAccrualFactory.wrapper(5, 5.seconds)),
private val _tracer : Tracer = NullTracer,
Expand Down Expand Up @@ -210,6 +211,7 @@ private[builder] final case class ClientConfig[Req, Rep, HasCluster, HasCodec, H
val logger = _logger
val channelFactory = _channelFactory
val tls = _tls
val httpProxy = _httpProxy
val socksProxy = _socksProxy
val failureAccrual = _failureAccrual
val tracer = _tracer
Expand Down Expand Up @@ -243,6 +245,7 @@ private[builder] final case class ClientConfig[Req, Rep, HasCluster, HasCodec, H
"logger" -> _logger,
"channelFactory" -> _channelFactory,
"tls" -> _tls,
"httpProxy" -> _httpProxy,
"socksProxy" -> _socksProxy,
"failureAccrual" -> _failureAccrual,
"tracer" -> Some(_tracer),
Expand Down Expand Up @@ -564,7 +567,15 @@ class ClientBuilder[Req, Rep, HasCluster, HasCodec, HasHostConnectionLimit] priv
withConfig(_.copy(_tls = Some({ () => Ssl.clientWithoutCertificateValidation()}, None)))

/**
* Make connections via the given SOCKS proxy
* Make connections via the given HTTP proxy.
* If this is defined concurrently with socksProxy, the order in which they are applied is undefined.
*/
def httpProxy(httpProxy: SocketAddress): This =
withConfig(_.copy(_httpProxy = Some(httpProxy)))

/**
* Make connections via the given SOCKS proxy.
* If this is defined concurrently with httpProxy, the order in which they are applied is undefined.
*/
def socksProxy(socksProxy: SocketAddress): This =
withConfig(_.copy(_socksProxy = Some(socksProxy)))
Expand Down Expand Up @@ -685,6 +696,7 @@ class ClientBuilder[Req, Rep, HasCluster, HasCodec, HasHostConnectionLimit] priv
newChannel = newChannel,
newTransport = codec.newClientTransport(_, statsReceiver),
tlsConfig = config.tls map { case (e, v) => Netty3TransporterTLSConfig(e, v) },
httpProxy = config.httpProxy,
socksProxy = config.socksProxy,
channelReaderTimeout = config.readerIdleTimeout getOrElse Duration.Top,
channelWriterTimeout = config.writerIdleTimeout getOrElse Duration.Top,
Expand Down
@@ -0,0 +1,109 @@
package com.twitter.finagle.httpproxy

import java.net.{InetSocketAddress, SocketAddress}
import java.util.concurrent.atomic.AtomicReference

import org.jboss.netty.buffer.{ChannelBuffer, ChannelBuffers}
import org.jboss.netty.channel._
import org.jboss.netty.handler.codec.http.{DefaultHttpRequest, DefaultHttpResponse, HttpClientCodec, HttpMethod, HttpResponseStatus, HttpVersion}
import org.jboss.netty.util.CharsetUtil

import com.twitter.finagle.{ChannelClosedException, ConnectionFailedException, InconsistentStateException}

/**
* Handle SSL connections through a proxy that accepts HTTP CONNECT.
*
* See http://www.w3.org/Protocols/rfc2616/rfc2616-sec9.html#9.9
*
*/
class HttpConnectHandler(proxyAddr: SocketAddress, addr: InetSocketAddress, pipeline: ChannelPipeline)
extends SimpleChannelHandler
{
private[this] val clientCodec = new HttpClientCodec()
pipeline.addFirst("httpProxyCodec", clientCodec)
private[this] val connectFuture = new AtomicReference[ChannelFuture](null)

private[this] def fail(c: Channel, t: Throwable) {
Option(connectFuture.get) foreach { _.setFailure(t) }
Channels.close(c)
}

private[this] def writeRequest(ctx: ChannelHandlerContext, e: ChannelStateEvent) {
val req = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.CONNECT, addr.getAddress.getHostName + ":" + addr.getPort)
Channels.write(ctx, Channels.future(ctx.getChannel), req, null)
}

override def connectRequested(ctx: ChannelHandlerContext, e: ChannelStateEvent) {
e match {
case de: DownstreamChannelStateEvent =>
if (!connectFuture.compareAndSet(null, e.getFuture)) {
fail(ctx.getChannel, new InconsistentStateException(addr))
return
}

// proxy cancellation
val wrappedConnectFuture = Channels.future(de.getChannel, true)
de.getFuture.addListener(new ChannelFutureListener {
def operationComplete(f: ChannelFuture) {
if (f.isCancelled)
wrappedConnectFuture.cancel()
}
})
// Proxy failures here so that if the connect fails, it is
// propagated to the listener, not just on the channel.
wrappedConnectFuture.addListener(new ChannelFutureListener {
def operationComplete(f: ChannelFuture) {
if (f.isSuccess || f.isCancelled)
return

fail(f.getChannel, f.getCause)
}
})

val wrappedEvent = new DownstreamChannelStateEvent(
de.getChannel, wrappedConnectFuture,
de.getState, proxyAddr)

super.connectRequested(ctx, wrappedEvent)

case _ =>
fail(ctx.getChannel, new InconsistentStateException(addr))
}
}

// we delay propagating connection upstream until we've completed the proxy connection.
override def channelConnected(ctx: ChannelHandlerContext, e: ChannelStateEvent) {
if (connectFuture.get eq null) {
fail(ctx.getChannel, new InconsistentStateException(addr))
return
}

// proxy cancellations again.
connectFuture.get.addListener(new ChannelFutureListener {
def operationComplete(f: ChannelFuture) {
if (f.isSuccess)
HttpConnectHandler.super.channelConnected(ctx, e)

else if (f.isCancelled)
fail(ctx.getChannel, new ChannelClosedException(addr))
}
})

writeRequest(ctx, e)
}

override def messageReceived(ctx: ChannelHandlerContext, e: MessageEvent) {
if (connectFuture.get eq null) {
fail(ctx.getChannel, new InconsistentStateException(addr))
return
}
val resp = e.getMessage.asInstanceOf[DefaultHttpResponse]
if (resp.getStatus == HttpResponseStatus.OK) {
ctx.getPipeline.remove(clientCodec)
ctx.getPipeline.remove(this)
connectFuture.get.setSuccess()
} else {
fail(e.getChannel, new ConnectionFailedException(null, addr))
}
}
}
Expand Up @@ -4,6 +4,7 @@ import com.twitter.finagle._
import com.twitter.finagle.channel.{
ChannelRequestStatsHandler, ChannelStatsHandler, IdleChannelHandler
}
import com.twitter.finagle.httpproxy.HttpConnectHandler
import com.twitter.finagle.socks.SocksConnectHandler
import com.twitter.finagle.ssl.{Engine, SslConnectHandler}
import com.twitter.finagle.stats.{ClientStatsReceiver, StatsReceiver}
Expand Down Expand Up @@ -125,14 +126,15 @@ case class Netty3Transporter[In, Out](
newChannel: ChannelPipeline => Channel = Netty3Transporter.channelFactory.newChannel(_),
newTransport: Channel => Transport[In, Out] = new ChannelTransport[In, Out](_),
tlsConfig: Option[Netty3TransporterTLSConfig] = None,
httpProxy: Option[SocketAddress] = None,
socksProxy: Option[SocketAddress] = None,
channelReaderTimeout: Duration = Duration.Top,
channelWriterTimeout: Duration = Duration.Top,
channelSnooper: Option[ChannelSnooper] = None,
channelOptions: Map[String, Object] = Netty3Transporter.defaultChannelOptions
) extends ((SocketAddress, StatsReceiver) => Future[Transport[In, Out]]) {
private[this] val statsHandlers = new IdentityHashMap[StatsReceiver, ChannelHandler]

// TODO: These gauges will stay around forever. It's
// fine, but it would be nice to clean them up.
def channelStatsHandler(statsReceiver: StatsReceiver) = synchronized {
Expand Down Expand Up @@ -192,6 +194,12 @@ case class Netty3Transporter[In, Out](
case _ =>
}

(httpProxy, addr) match {
case (Some(proxyAddr), (inetAddr : InetSocketAddress)) =>
pipeline.addFirst("httpConnect", new HttpConnectHandler(proxyAddr, inetAddr, pipeline))
case _ =>
}

for (snooper <- channelSnooper)
pipeline.addFirst("channelSnooper", snooper)

Expand Down
@@ -0,0 +1,120 @@
package com.twitter.finagle.httpproxy

import java.net.{InetAddress, InetSocketAddress, SocketAddress}
import java.util.Arrays
import org.jboss.netty.channel._
import org.jboss.netty.handler.codec.http.{DefaultHttpRequest, DefaultHttpResponse, HttpMethod,
HttpResponseStatus, HttpVersion}
import org.jboss.netty.buffer.{ChannelBuffers, ChannelBuffer}
import org.mockito.ArgumentCaptor
import org.specs.SpecificationWithJUnit
import org.specs.mock.Mockito
import org.specs.matcher.Matcher

class HttpConnectHandlerSpec extends SpecificationWithJUnit with Mockito {
"HttpConnectHandler" should {
val ctx = mock[ChannelHandlerContext]
val channel = mock[Channel]
ctx.getChannel returns channel
val pipeline = mock[ChannelPipeline]
ctx.getPipeline returns pipeline
channel.getPipeline returns pipeline
val closeFuture = Channels.future(channel)
channel.getCloseFuture returns closeFuture
val remoteAddress = new InetSocketAddress("localhost", 443)
channel.getRemoteAddress returns remoteAddress
val proxyAddress = mock[SocketAddress]
val connectFuture = Channels.future(channel, true)
val connectRequested = new DownstreamChannelStateEvent(
channel, connectFuture, ChannelState.CONNECTED, remoteAddress)
val ch = new HttpConnectHandler(proxyAddress, remoteAddress, pipeline)
ch.handleDownstream(ctx, connectRequested)

def checkDidClose() {
val ec = ArgumentCaptor.forClass(classOf[DownstreamChannelStateEvent])
there was one(pipeline).sendDownstream(ec.capture)
val e = ec.getValue
e.getChannel must be(channel)
e.getFuture must be(closeFuture)
e.getState must be(ChannelState.OPEN)
e.getValue must be(java.lang.Boolean.FALSE)
}

"upon connect" in {
val ec = ArgumentCaptor.forClass(classOf[DownstreamChannelStateEvent])
there was one(ctx).sendDownstream(ec.capture)
val e = ec.getValue

"wrap the downstream connect request" in {
e.getChannel must be(channel)
e.getFuture must notBe(connectFuture) // this is proxied
e.getState must be(ChannelState.CONNECTED)
e.getValue must be(proxyAddress)
}

"propagate cancellation" in {
e.getFuture.isCancelled must beFalse
connectFuture.cancel()
e.getFuture.isCancelled must beTrue
}
}

"when connect is succesful" in {
ch.handleUpstream(ctx, new UpstreamChannelStateEvent(
channel, ChannelState.CONNECTED, remoteAddress))
connectFuture.isDone must beFalse
there was no(ctx).sendUpstream(any)

"not propagate success" in {
there was no(ctx).sendUpstream(any)
}

"propagate connection cancellation" in {
connectFuture.cancel()
checkDidClose()
}

"do HTTP CONNECT" in {
{ // send connect request
val ec = ArgumentCaptor.forClass(classOf[DownstreamMessageEvent])
there was atLeastOne(ctx).sendDownstream(ec.capture)
val e = ec.getValue
val req = e.getMessage.asInstanceOf[DefaultHttpRequest]
req.getMethod must_== HttpMethod.CONNECT
req.getUri must_== "localhost:443"
}

{ // when connect response is received, propagate the connect and remove the handler
ch.handleUpstream(ctx, new UpstreamMessageEvent(
channel,
new DefaultHttpResponse(HttpVersion.HTTP_1_0, HttpResponseStatus.OK),
null))

connectFuture.isDone must beTrue
there was one(pipeline).remove(ch)

// we propagated the connect
val ec = ArgumentCaptor.forClass(classOf[UpstreamChannelStateEvent])
there was one(ctx).sendUpstream(ec.capture)
val e = ec.getValue

e.getChannel must be(channel)
e.getState must be(ChannelState.CONNECTED)
e.getValue must be(remoteAddress)
}
}
}

"propagate connection failure" in {
val ec = ArgumentCaptor.forClass(classOf[DownstreamChannelStateEvent])
there was one(ctx).sendDownstream(ec.capture)
val e = ec.getValue
val exc = new Exception("failed to connect")

connectFuture.isDone must beFalse
e.getFuture.setFailure(exc)
connectFuture.isDone must beTrue
connectFuture.getCause must be_==(exc)
}
}
}

0 comments on commit ca52719

Please sign in to comment.