Skip to content

Commit

Permalink
martian: replace ConnectPassthrough with ConnectFunc
Browse files Browse the repository at this point in the history
ConnectFunc gives much more flexibility to callers. It also allows to dynamically fall back to martian connect code without exporting it - it does not fit well to public API.
  • Loading branch information
mmatczuk committed Nov 17, 2023
1 parent c965cd9 commit 0cf71d4
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 78 deletions.
31 changes: 31 additions & 0 deletions internal/martian/connect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright 2023 Sauce Labs Inc. All rights reserved.
//
// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package martian

import (
"errors"
"io"
"net/http"
)

// ErrConnectFallback is returned by a ConnectFunc to indicate
// that the CONNECT request should be handled by martian.
var ErrConnectFallback = errors.New("martian: connect fallback")

// ConnectFunc dials a network connection for a CONNECT request.
// If the returned net.Conn is not nil, the response must be not nil.
type ConnectFunc func(req *http.Request) (*http.Response, io.ReadWriteCloser, error)
47 changes: 16 additions & 31 deletions internal/martian/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ package martian
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
"strings"

"github.com/saucelabs/forwarder/internal/martian/log"
"github.com/saucelabs/forwarder/internal/martian/proxyutil"
)

func copyHeader(dst, src http.Header) {
Expand Down Expand Up @@ -117,40 +117,25 @@ func (p proxyHandler) handleConnectRequest(ctx *Context, rw http.ResponseWriter,
log.Debugf(req.Context(), "attempting to establish CONNECT tunnel: %s", req.URL.Host)
var (
res *http.Response
cr io.Reader
cw io.WriteCloser
crw io.ReadWriteCloser
cerr error
)
if p.ConnectPassthrough { //nolint:nestif // to be fixed in #445
pr, pw := io.Pipe()
req.Body = pr
defer req.Body.Close()

// perform the HTTP roundtrip
res, cerr = p.roundTrip(ctx, req)
if res != nil {
cr = res.Body
cw = pw

if res.StatusCode/100 == 2 {
res = proxyutil.NewResponse(200, http.NoBody, req)
}
}
} else {
if p.ConnectFunc != nil {
res, crw, cerr = p.ConnectFunc(req)
}
if p.ConnectFunc == nil || errors.Is(cerr, ErrConnectFallback) {
var cconn net.Conn
res, cconn, cerr = p.connect(req)

if cconn != nil {
defer cconn.Close()
cr = cconn
cw = cconn
crw = cconn

if shouldTerminateTLS(req) {
log.Debugf(req.Context(), "attempting to terminate TLS on CONNECT tunnel: %s", req.URL.Host)
tconn := tls.Client(cconn, p.clientTLSConfig())
if err := tconn.Handshake(); err == nil {
cr = tconn
cw = tconn
crw = tconn
} else {
log.Errorf(req.Context(), "failed to terminate TLS on CONNECT tunnel: %v", err)
cerr = err
Expand Down Expand Up @@ -187,7 +172,7 @@ func (p proxyHandler) handleConnectRequest(ctx *Context, rw http.ResponseWriter,
res.ContentLength = -1
}

if err := p.tunnel("CONNECT", rw, req, res, cw, cr); err != nil {
if err := p.tunnel("CONNECT", rw, req, res, crw); err != nil {
log.Errorf(req.Context(), "CONNECT tunnel: %v", err)
panic(http.ErrAbortHandler)
}
Expand All @@ -204,13 +189,13 @@ func (p proxyHandler) handleUpgradeResponse(rw http.ResponseWriter, req *http.Re

res.Body = nil

if err := p.tunnel(resUpType, rw, req, res, uconn, uconn); err != nil {
if err := p.tunnel(resUpType, rw, req, res, uconn); err != nil {
log.Errorf(req.Context(), "%s tunnel: %w", resUpType, err)
panic(http.ErrAbortHandler)
}
}

func (p proxyHandler) tunnel(name string, rw http.ResponseWriter, req *http.Request, res *http.Response, cw io.WriteCloser, cr io.Reader) error {
func (p proxyHandler) tunnel(name string, rw http.ResponseWriter, req *http.Request, res *http.Response, crw io.ReadWriteCloser) error {
var (
rc = http.NewResponseController(rw)
donec = make(chan bool, 2)
Expand All @@ -229,12 +214,12 @@ func (p proxyHandler) tunnel(name string, rw http.ResponseWriter, req *http.Requ
if err := brw.Flush(); err != nil {
return fmt.Errorf("got error while flushing response back to client: %w", err)
}
if err := drainBuffer(cw, brw.Reader); err != nil {
if err := drainBuffer(crw, brw.Reader); err != nil {
return fmt.Errorf("got error while draining buffer: %w", err)
}

go copySync(req.Context(), "outbound "+name, cw, conn, donec)
go copySync(req.Context(), "inbound "+name, conn, cr, donec)
go copySync(req.Context(), "outbound "+name, crw, conn, donec)
go copySync(req.Context(), "inbound "+name, conn, crw, donec)
case 2:
copyHeader(rw.Header(), res.Header)
rw.WriteHeader(res.StatusCode)
Expand All @@ -243,8 +228,8 @@ func (p proxyHandler) tunnel(name string, rw http.ResponseWriter, req *http.Requ
return fmt.Errorf("got error while flushing response back to client: %w", err)
}

go copySync(req.Context(), "outbound "+name, cw, req.Body, donec)
go copySync(req.Context(), "inbound "+name, writeFlusher{rw, rc}, cr, donec)
go copySync(req.Context(), "outbound "+name, crw, req.Body, donec)
go copySync(req.Context(), "inbound "+name, writeFlusher{rw, rc}, crw, donec)
default:
return fmt.Errorf("unsupported protocol version: %d", req.ProtoMajor)
}
Expand Down
49 changes: 17 additions & 32 deletions internal/martian/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,14 @@ type Proxy struct {
// If empty, no action is taken, and the proxy will generate a new request ID.
RequestIDHeader string

// ConnectPassthrough passes CONNECT requests to the RoundTripper,
// and uses the response body as the connection.
ConnectPassthrough bool

// ConnectRequestModifier modifies CONNECT requests to upstream proxy.
// If ConnectPassthrough is enabled, this is ignored.
ConnectRequestModifier func(*http.Request) error

// ConnectFunc specifies a function to dial network connections for CONNECT requests.
// Implementations can return ErrConnectFallback to indicate that the CONNECT request should be handled by martian.
ConnectFunc ConnectFunc

// MITMFilter specifies a function to determine whether a CONNECT request should be MITMed.
MITMFilter func(*http.Request) bool

Expand Down Expand Up @@ -548,40 +548,25 @@ func (p *Proxy) handleConnectRequest(ctx *Context, req *http.Request, session *S
log.Debugf(req.Context(), "attempting to establish CONNECT tunnel: %s", req.URL.Host)
var (
res *http.Response
cr io.Reader
cw io.WriteCloser
crw io.ReadWriteCloser
cerr error
)
if p.ConnectPassthrough { //nolint:nestif // to be fixed in #445
pr, pw := io.Pipe()
req.Body = pr
defer req.Body.Close()

// perform the HTTP roundtrip
res, cerr = p.roundTrip(ctx, req)
if res != nil {
cr = res.Body
cw = pw

if res.StatusCode/100 == 2 {
res = proxyutil.NewResponse(200, nil, req)
}
}
} else {
if p.ConnectFunc != nil {
res, crw, cerr = p.ConnectFunc(req)
}
if p.ConnectFunc == nil || errors.Is(cerr, ErrConnectFallback) {
var cconn net.Conn
res, cconn, cerr = p.connect(req)

if cconn != nil {
defer cconn.Close()
cr = cconn
cw = cconn
crw = cconn

if shouldTerminateTLS(req) {
log.Debugf(req.Context(), "attempting to terminate TLS on CONNECT tunnel: %s", req.URL.Host)
tconn := tls.Client(cconn, p.clientTLSConfig())
if err := tconn.Handshake(); err == nil {
cr = tconn
cw = tconn
crw = tconn
} else {
log.Errorf(req.Context(), "failed to terminate TLS on CONNECT tunnel: %v", err)
cerr = err
Expand Down Expand Up @@ -622,7 +607,7 @@ func (p *Proxy) handleConnectRequest(ctx *Context, req *http.Request, session *S

res.ContentLength = -1

if err := p.tunnel("CONNECT", res, brw, conn, cw, cr); err != nil {
if err := p.tunnel("CONNECT", res, brw, conn, crw); err != nil {
log.Errorf(req.Context(), "CONNECT tunnel: %w", err)
}

Expand All @@ -640,28 +625,28 @@ func (p *Proxy) handleUpgradeResponse(res *http.Response, brw *bufio.ReadWriter,

res.Body = nil

if err := p.tunnel(resUpType, res, brw, conn, uconn, uconn); err != nil {
if err := p.tunnel(resUpType, res, brw, conn, uconn); err != nil {
log.Errorf(res.Request.Context(), "%s tunnel: %w", resUpType, err)
}

return errClose
}

func (p *Proxy) tunnel(name string, res *http.Response, brw *bufio.ReadWriter, conn net.Conn, cw io.Writer, cr io.Reader) error {
func (p *Proxy) tunnel(name string, res *http.Response, brw *bufio.ReadWriter, conn net.Conn, crw io.ReadWriteCloser) error {
if err := res.Write(brw); err != nil {
return fmt.Errorf("got error while writing response back to client: %w", err)
}
if err := brw.Flush(); err != nil {
return fmt.Errorf("got error while flushing response back to client: %w", err)
}
if err := drainBuffer(cw, brw.Reader); err != nil {
if err := drainBuffer(crw, brw.Reader); err != nil {
return fmt.Errorf("got error while draining read buffer: %w", err)
}

ctx := res.Request.Context()
donec := make(chan bool, 2)
go copySync(ctx, "outbound "+name, cw, conn, donec)
go copySync(ctx, "inbound "+name, conn, cr, donec)
go copySync(ctx, "outbound "+name, crw, conn, donec)
go copySync(ctx, "inbound "+name, conn, crw, donec)

log.Debugf(ctx, "switched protocols, proxying %s traffic", name)
<-donec
Expand Down
37 changes: 22 additions & 15 deletions internal/martian/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"github.com/saucelabs/forwarder/internal/martian/martiantest"
"github.com/saucelabs/forwarder/internal/martian/mitm"
"github.com/saucelabs/forwarder/internal/martian/proxyutil"
"go.uber.org/multierr"
)

type tempError struct{}
Expand Down Expand Up @@ -1009,27 +1010,33 @@ func TestIntegrationConnectUpstreamProxy(t *testing.T) {
}
}

func TestIntegrationConnectPassthrough(t *testing.T) {
type pipeConn struct {
*io.PipeReader
*io.PipeWriter
}

func (conn pipeConn) CloseWrite() error {
return conn.PipeWriter.Close()
}

func (conn pipeConn) Close() error {
return multierr.Combine(
conn.PipeReader.Close(),
conn.PipeWriter.Close(),
)
}

func TestIntegrationConnectFunc(t *testing.T) {
t.Parallel()

l := newListener(t)
p := NewProxy()
p.ConnectPassthrough = true
defer p.Close()

tr := martiantest.NewTransport()
tr.Func(func(req *http.Request) (*http.Response, error) {
p.ConnectFunc = func(req *http.Request) (*http.Response, io.ReadWriteCloser, error) {
pr, pw := io.Pipe()
go func() {
if _, err := io.Copy(pw, req.Body); err != nil {
t.Errorf("io.Copy(): got %v, want no error", err)
}
pw.Close()
}()
return proxyutil.NewResponse(200, pr, req), nil
})
p.SetRoundTripper(tr)
return proxyutil.NewResponse(200, nil, req), pipeConn{pr, pw}, nil
}
p.SetTimeout(200 * time.Millisecond)
defer p.Close()

go serve(p, l)

Expand Down

0 comments on commit 0cf71d4

Please sign in to comment.