Skip to content

Commit

Permalink
Improve hijack error handling.
Browse files Browse the repository at this point in the history
  • Loading branch information
wi1dcard committed Mar 16, 2024
1 parent 633573e commit b243ddb
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 25 deletions.
60 changes: 38 additions & 22 deletions pkg/hack/hajack_clienthello_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package hack

import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"net"
"time"
Expand All @@ -12,6 +14,10 @@ const (
recordHeaderLen = 5
)

var (
ErrIncompleteClientHello = errors.New("incomplete client hello")
)

type HijackClientHelloConn struct {
// internal tls.Conn
tlsConn net.Conn
Expand All @@ -36,10 +42,9 @@ func (c *HijackClientHelloConn) Read(b []byte) (int, error) {
n, err := c.tlsConn.Read(b)
if err == nil {
if c.hasCompleteClientHello() {
c.vlogf("got %d bytes, but client hello is already mature, skipping", n)
c.vlogf("got %d bytes, but client hello is already mature, skipping hijack", n)
} else {
// ignores the error which should be impossible
_ = c.hijackClientHello(b[:n])
c.hijackClientHello(b[:n])
}
}
return n, err
Expand All @@ -62,43 +67,54 @@ func (c *HijackClientHelloConn) hasCompleteClientHello() bool {
return true
}

func (c *HijackClientHelloConn) hijackClientHello(b []byte) error {
func (c *HijackClientHelloConn) hijackClientHello(b []byte) {
c.buf.Write(b)
c.vlogf("wrote %d bytes, total %d bytes", len(b), c.buf.Len())

// ignores the error which should be impossible
_ = c.tryParseClientHello()
}

func (c *HijackClientHelloConn) tryParseClientHello() error {
if c.hasCompleteClientHello() {
c.vlogf("client hello is mature after wrote")
c.vlogf("client hello is mature, skipping parse")
return nil
}

bufBytes := c.buf.Bytes()
bufLen := c.buf.Len()
if bufBytes[0] != recordTypeHandshake {
err := fmt.Errorf("tls record type is not a handshake")
c.vlogf("%s", err)
return err
if bufLen < 5 {
c.vlogf("buffer too short (%d bytes), skipping parse", bufLen)
return ErrIncompleteClientHello
}

if bufLen >= 5 {
// vers := uint16(bufBytes[1])<<8 | uint16(bufBytes[2])
handshakeLen := uint16(bufBytes[3])<<8 | uint16(bufBytes[4])
c.expectedLen = recordHeaderLen + handshakeLen
recType := bufBytes[0]
if recType != recordTypeHandshake {
return fmt.Errorf("tls record type 0x%x is not a handshake", recType)
}

// call hasCompleteClientHello to truncate the buffer if possible
if c.hasCompleteClientHello() {
c.vlogf("client hello is mature after got record length")
}
vers := uint16(bufBytes[1])<<8 | uint16(bufBytes[2])
if vers < tls.VersionSSL30 || vers > tls.VersionTLS13 {
return fmt.Errorf("unknown tls version: 0x%x", vers)
}

return nil
}
handshakeLen := uint16(bufBytes[3])<<8 | uint16(bufBytes[4])
c.expectedLen = recordHeaderLen + handshakeLen

func (c *HijackClientHelloConn) GetClientHello() []byte {
// call hasCompleteClientHello to truncate the buffer if possible
if c.hasCompleteClientHello() {
return c.buf.Bytes()
} else {
c.vlogf("client hello is mature after got record length")
return nil
} else {
return ErrIncompleteClientHello
}
}

func (c *HijackClientHelloConn) GetClientHello() ([]byte, error) {
if err := c.tryParseClientHello(); err != nil {
return nil, err
}
return c.buf.Bytes(), nil
}

func (c *HijackClientHelloConn) vlogf(format string, args ...any) {
Expand Down
6 changes: 3 additions & 3 deletions pkg/proxyserver/proxyserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ func (server *Server) serveConn(conn net.Conn) {
}

// client hello stored in hajackedConn while reading for real handshake
rec := hijackedConn.GetClientHello()
if len(rec) == 0 {
server.logf("could not read client hello from: %s", conn.RemoteAddr())
rec, err := hijackedConn.GetClientHello()
if err != nil {
server.logf("could not read client hello from: %s", err)
server.metricsRequestsTotalInc("0", "")
return
}
Expand Down

0 comments on commit b243ddb

Please sign in to comment.