Skip to content

Commit

Permalink
Preliminary repair of the trojan outbound
Browse files Browse the repository at this point in the history
  • Loading branch information
selcarpa committed Mar 13, 2024
1 parent b27e6a1 commit ec5074c
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 36 deletions.
4 changes: 2 additions & 2 deletions src/main/kotlin/netty/Common.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ private val logger = KotlinLogging.logger {}
* Auto exec handler
*/
class AutoExecHandler(private val exec: (ChannelHandlerContext) -> Unit) : ChannelInboundHandlerAdapter() {
override fun channelActive(ctx: ChannelHandlerContext) {
override fun handlerAdded(ctx: ChannelHandlerContext) {
exec(ctx)
ctx.pipeline().remove(this)
super.channelActive(ctx)
super.handlerAdded(ctx)
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/main/kotlin/netty/ProxyChannelInitializer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import mu.KotlinLogging
import inbounds.TrojanInboundHandler
import netty.NettyServer.portInboundBinds
import stream.SslActiveHandler
import stream.WebsocketDuplexHandler
import stream.WebSocketDuplexHandler
import utils.closeOnFlush
import java.io.File
import java.net.InetSocketAddress
Expand Down Expand Up @@ -169,7 +169,7 @@ class ProxyChannelInitializer : ChannelInitializer<NioSocketChannel>() {
HttpServerCodec(),
HttpObjectAggregator(Int.MAX_VALUE),
WebSocketServerProtocolHandler(wsInboundSetting.path),
WebsocketDuplexHandler(handleShakePromise)
WebSocketDuplexHandler(handleShakePromise)
)
}
}
Expand Down
55 changes: 40 additions & 15 deletions src/main/kotlin/protocol/Trojan.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import io.netty.handler.codec.http.DefaultHttpHeaders
import io.netty.handler.codec.http.HttpClientCodec
import io.netty.handler.codec.http.HttpObjectAggregator
import io.netty.handler.codec.http.websocketx.WebSocketClientHandshakerFactory
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler
import io.netty.handler.codec.http.websocketx.WebSocketVersion
import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketClientCompressionHandler
import io.netty.handler.codec.socksx.v5.Socks5CommandType
Expand All @@ -28,9 +27,9 @@ import model.protocol.TrojanPackage
import model.protocol.TrojanRequest
import mu.KotlinLogging
import netty.AutoExecHandler
import netty.EventTriggerHandler
import netty.ExceptionCaughtHandler
import stream.WebsocketDuplexHandler
import stream.WebSocketDuplexHandler
import stream.WebSocketHandshakeHandler
import utils.toAddressType
import utils.toSha224
import utils.toUUid
Expand Down Expand Up @@ -173,14 +172,14 @@ class TrojanProxy(
ctx.pipeline().addBefore(ctx.name(), "HttpObjectAggregator", HttpObjectAggregator(8192))
ctx.pipeline()
.addBefore(ctx.name(), "WebSocketClientCompressionHandler", WebSocketClientCompressionHandler.INSTANCE)

ctx.pipeline().addBefore(
ctx.name(), "websocket_client_handshaker", WebSocketClientProtocolHandler(
WebSocketClientHandshakerFactory.newHandshaker(
uri, WebSocketVersion.V13, null, true, DefaultHttpHeaders()
)
)
)
//
// ctx.pipeline().addBefore(
// ctx.name(), "websocket_client_handshaker", WebSocketClientProtocolHandler(
// WebSocketClientHandshakerFactory.newHandshaker(
// uri, WebSocketVersion.V13, null, true, DefaultHttpHeaders()
// )
// )
// )

val newPromise = ctx.channel().eventLoop().newPromise<Channel>()
newPromise.addListener {
Expand All @@ -191,7 +190,7 @@ class TrojanProxy(
ctx.pipeline().addBefore(ctx.name(), TROJAN_PROXY_OUTBOUND, trojanOutboundHandler)
}
}
ctx.pipeline().addBefore(ctx.name(), "websocket_duplex_handler", WebsocketDuplexHandler(newPromise))
ctx.pipeline().addBefore(ctx.name(), "websocket_duplex_handler", WebSocketDuplexHandler(newPromise))
ctx.pipeline().addLast(ExceptionCaughtHandler())
}

Expand Down Expand Up @@ -244,13 +243,39 @@ class TrojanProxy(
}

override fun newInitialMessage(ctx: ChannelHandlerContext): Any? {
return null
return when (streamBy) {
Protocol.WS, Protocol.WSS -> {
val newHandshaker = WebSocketClientHandshakerFactory.newHandshaker(
URI(
"${
if (streamBy == Protocol.WS) {
"ws"
} else {
"wss"
}
}://${outboundStreamBy.wsOutboundSetting!!.host}:${outboundStreamBy.wsOutboundSetting.port}/${
outboundStreamBy.wsOutboundSetting.path.removePrefix(
"/"
)
}"
), WebSocketVersion.V13, null, true, DefaultHttpHeaders()
)
ctx.pipeline().addBefore(
"websocket_duplex_handler",
"WebSocketHandshakeHandler",
WebSocketHandshakeHandler(newHandshaker)
)
newHandshaker.javaClass.declaredMethods.find { it.name == "newHandshakeRequest" }!!.also {
it.isAccessible = true
}.invoke(newHandshaker)
}

else -> null
}
}

override fun handleResponse(ctx: ChannelHandlerContext, response: Any): Boolean {
return true
}


}

Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@ import io.netty.channel.Channel
import io.netty.channel.ChannelDuplexHandler
import io.netty.channel.ChannelHandlerContext
import io.netty.channel.ChannelPromise
import io.netty.channel.SimpleChannelInboundHandler
import io.netty.handler.codec.http.FullHttpRequest
import io.netty.handler.codec.http.FullHttpResponse
import io.netty.handler.codec.http.websocketx.*
import io.netty.util.ReferenceCountUtil
import io.netty.util.concurrent.FutureListener
import io.netty.util.concurrent.Promise
import mu.KotlinLogging

private val logger = KotlinLogging.logger {}
class WebsocketDuplexHandler(private val handleShakePromise: Promise<Channel>? = null) :
ChannelDuplexHandler() {

class WebSocketDuplexHandler(private val handleShakePromise: Promise<Channel>? = null) : ChannelDuplexHandler() {


private var continuationBuffer: ByteBuf? = null
Expand Down Expand Up @@ -134,21 +136,37 @@ class WebsocketDuplexHandler(private val handleShakePromise: Promise<Channel>? =
}
}

fun websocketEvent(ctx: ChannelHandlerContext, evt: Any, handleShakePromise: Promise<Channel>? = null) {
//when surfer as a websocket server, we need to handle handshake complete event to determine whether the handshake is successful, and start the relay operation
if (evt is WebSocketServerProtocolHandler.HandshakeComplete) {
logger.trace { "[${ctx.channel().id()}] WebsocketInbound handshake complete" }
handleShakePromise?.setSuccess(ctx.channel())
}
//when surfer as a websocket client, we also need to handle handshake complete event to determine whether the handshake is successful, and start the relay operation
if (evt is WebSocketClientProtocolHandler.ClientHandshakeStateEvent) {
if (evt == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
logger.trace { "[${ctx.channel().id()}] WebsocketInbound handshake complete" }
handleShakePromise?.setSuccess(ctx.channel())
} else if (evt == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_TIMEOUT) {
logger.error { "[${ctx.channel().id()}] WebsocketInbound handshake timeout" }
handleShakePromise?.setFailure(Throwable("websocket handshake failed"))
//fun websocketEvent(ctx: ChannelHandlerContext, evt: Any, handleShakePromise: Promise<Channel>? = null) {
// //when surfer as a websocket server, we need to handle handshake complete event to determine whether the handshake is successful, and start the relay operation
// if (evt is WebSocketServerProtocolHandler.HandshakeComplete) {
// logger.trace { "[${ctx.channel().id()}] WebsocketInbound handshake complete" }
// handleShakePromise?.setSuccess(ctx.channel())
// }
// //when surfer as a websocket client, we also need to handle handshake complete event to determine whether the handshake is successful, and start the relay operation
// if (evt is WebSocketClientProtocolHandler.ClientHandshakeStateEvent) {
// if (evt == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) {
// logger.trace { "[${ctx.channel().id()}] WebsocketInbound handshake complete" }
// handleShakePromise?.setSuccess(ctx.channel())
// } else if (evt == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_TIMEOUT) {
// logger.error { "[${ctx.channel().id()}] WebsocketInbound handshake timeout" }
// handleShakePromise?.setFailure(Throwable("websocket handshake failed"))
// }
// }
// logger.trace { "[${ctx.channel().id()}] userEventTriggered: $evt" }
//}


class WebSocketHandshakeHandler(val handshaker: WebSocketClientHandshaker) :
SimpleChannelInboundHandler<FullHttpResponse>() {
override fun channelRead0(ctx: ChannelHandlerContext, msg: FullHttpResponse) {
if (!handshaker.isHandshakeComplete) {
handshaker.finishHandshake(ctx.channel(), msg);
ctx.fireUserEventTriggered(
WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE
);
ctx.pipeline().remove(this)
return;
}
}
logger.trace { "[${ctx.channel().id()}] userEventTriggered: $evt" }

}

0 comments on commit ec5074c

Please sign in to comment.