From 0cf71d45b1e206ce88fc3d1c2e2d5b446e8de4dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Matczuk?= Date: Fri, 17 Nov 2023 14:34:30 +0100 Subject: [PATCH] martian: replace ConnectPassthrough with ConnectFunc ConnectFunc gives much more flexibility to callers. It also allows to dynamically fall back to martian connect code without exporting it - it does not fit well to public API. --- internal/martian/connect.go | 31 +++++++++++++++++++++ internal/martian/handler.go | 47 +++++++++++--------------------- internal/martian/proxy.go | 49 ++++++++++++---------------------- internal/martian/proxy_test.go | 37 ++++++++++++++----------- 4 files changed, 86 insertions(+), 78 deletions(-) create mode 100644 internal/martian/connect.go diff --git a/internal/martian/connect.go b/internal/martian/connect.go new file mode 100644 index 00000000..1866afdd --- /dev/null +++ b/internal/martian/connect.go @@ -0,0 +1,31 @@ +// Copyright 2023 Sauce Labs Inc. All rights reserved. +// +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package martian + +import ( + "errors" + "io" + "net/http" +) + +// ErrConnectFallback is returned by a ConnectFunc to indicate +// that the CONNECT request should be handled by martian. +var ErrConnectFallback = errors.New("martian: connect fallback") + +// ConnectFunc dials a network connection for a CONNECT request. +// If the returned net.Conn is not nil, the response must be not nil. +type ConnectFunc func(req *http.Request) (*http.Response, io.ReadWriteCloser, error) diff --git a/internal/martian/handler.go b/internal/martian/handler.go index 16984054..bf397f7c 100644 --- a/internal/martian/handler.go +++ b/internal/martian/handler.go @@ -19,6 +19,7 @@ package martian import ( "context" "crypto/tls" + "errors" "fmt" "io" "net" @@ -26,7 +27,6 @@ import ( "strings" "github.com/saucelabs/forwarder/internal/martian/log" - "github.com/saucelabs/forwarder/internal/martian/proxyutil" ) func copyHeader(dst, src http.Header) { @@ -117,40 +117,25 @@ func (p proxyHandler) handleConnectRequest(ctx *Context, rw http.ResponseWriter, log.Debugf(req.Context(), "attempting to establish CONNECT tunnel: %s", req.URL.Host) var ( res *http.Response - cr io.Reader - cw io.WriteCloser + crw io.ReadWriteCloser cerr error ) - if p.ConnectPassthrough { //nolint:nestif // to be fixed in #445 - pr, pw := io.Pipe() - req.Body = pr - defer req.Body.Close() - - // perform the HTTP roundtrip - res, cerr = p.roundTrip(ctx, req) - if res != nil { - cr = res.Body - cw = pw - - if res.StatusCode/100 == 2 { - res = proxyutil.NewResponse(200, http.NoBody, req) - } - } - } else { + if p.ConnectFunc != nil { + res, crw, cerr = p.ConnectFunc(req) + } + if p.ConnectFunc == nil || errors.Is(cerr, ErrConnectFallback) { var cconn net.Conn res, cconn, cerr = p.connect(req) if cconn != nil { defer cconn.Close() - cr = cconn - cw = cconn + crw = cconn if shouldTerminateTLS(req) { log.Debugf(req.Context(), "attempting to terminate TLS on CONNECT tunnel: %s", req.URL.Host) tconn := tls.Client(cconn, p.clientTLSConfig()) if err := tconn.Handshake(); err == nil { - cr = tconn - cw = tconn + crw = tconn } else { log.Errorf(req.Context(), "failed to terminate TLS on CONNECT tunnel: %v", err) cerr = err @@ -187,7 +172,7 @@ func (p proxyHandler) handleConnectRequest(ctx *Context, rw http.ResponseWriter, res.ContentLength = -1 } - if err := p.tunnel("CONNECT", rw, req, res, cw, cr); err != nil { + if err := p.tunnel("CONNECT", rw, req, res, crw); err != nil { log.Errorf(req.Context(), "CONNECT tunnel: %v", err) panic(http.ErrAbortHandler) } @@ -204,13 +189,13 @@ func (p proxyHandler) handleUpgradeResponse(rw http.ResponseWriter, req *http.Re res.Body = nil - if err := p.tunnel(resUpType, rw, req, res, uconn, uconn); err != nil { + if err := p.tunnel(resUpType, rw, req, res, uconn); err != nil { log.Errorf(req.Context(), "%s tunnel: %w", resUpType, err) panic(http.ErrAbortHandler) } } -func (p proxyHandler) tunnel(name string, rw http.ResponseWriter, req *http.Request, res *http.Response, cw io.WriteCloser, cr io.Reader) error { +func (p proxyHandler) tunnel(name string, rw http.ResponseWriter, req *http.Request, res *http.Response, crw io.ReadWriteCloser) error { var ( rc = http.NewResponseController(rw) donec = make(chan bool, 2) @@ -229,12 +214,12 @@ func (p proxyHandler) tunnel(name string, rw http.ResponseWriter, req *http.Requ if err := brw.Flush(); err != nil { return fmt.Errorf("got error while flushing response back to client: %w", err) } - if err := drainBuffer(cw, brw.Reader); err != nil { + if err := drainBuffer(crw, brw.Reader); err != nil { return fmt.Errorf("got error while draining buffer: %w", err) } - go copySync(req.Context(), "outbound "+name, cw, conn, donec) - go copySync(req.Context(), "inbound "+name, conn, cr, donec) + go copySync(req.Context(), "outbound "+name, crw, conn, donec) + go copySync(req.Context(), "inbound "+name, conn, crw, donec) case 2: copyHeader(rw.Header(), res.Header) rw.WriteHeader(res.StatusCode) @@ -243,8 +228,8 @@ func (p proxyHandler) tunnel(name string, rw http.ResponseWriter, req *http.Requ return fmt.Errorf("got error while flushing response back to client: %w", err) } - go copySync(req.Context(), "outbound "+name, cw, req.Body, donec) - go copySync(req.Context(), "inbound "+name, writeFlusher{rw, rc}, cr, donec) + go copySync(req.Context(), "outbound "+name, crw, req.Body, donec) + go copySync(req.Context(), "inbound "+name, writeFlusher{rw, rc}, crw, donec) default: return fmt.Errorf("unsupported protocol version: %d", req.ProtoMajor) } diff --git a/internal/martian/proxy.go b/internal/martian/proxy.go index 9240540b..2bc6cd01 100644 --- a/internal/martian/proxy.go +++ b/internal/martian/proxy.go @@ -114,14 +114,14 @@ type Proxy struct { // If empty, no action is taken, and the proxy will generate a new request ID. RequestIDHeader string - // ConnectPassthrough passes CONNECT requests to the RoundTripper, - // and uses the response body as the connection. - ConnectPassthrough bool - // ConnectRequestModifier modifies CONNECT requests to upstream proxy. // If ConnectPassthrough is enabled, this is ignored. ConnectRequestModifier func(*http.Request) error + // ConnectFunc specifies a function to dial network connections for CONNECT requests. + // Implementations can return ErrConnectFallback to indicate that the CONNECT request should be handled by martian. + ConnectFunc ConnectFunc + // MITMFilter specifies a function to determine whether a CONNECT request should be MITMed. MITMFilter func(*http.Request) bool @@ -548,40 +548,25 @@ func (p *Proxy) handleConnectRequest(ctx *Context, req *http.Request, session *S log.Debugf(req.Context(), "attempting to establish CONNECT tunnel: %s", req.URL.Host) var ( res *http.Response - cr io.Reader - cw io.WriteCloser + crw io.ReadWriteCloser cerr error ) - if p.ConnectPassthrough { //nolint:nestif // to be fixed in #445 - pr, pw := io.Pipe() - req.Body = pr - defer req.Body.Close() - - // perform the HTTP roundtrip - res, cerr = p.roundTrip(ctx, req) - if res != nil { - cr = res.Body - cw = pw - - if res.StatusCode/100 == 2 { - res = proxyutil.NewResponse(200, nil, req) - } - } - } else { + if p.ConnectFunc != nil { + res, crw, cerr = p.ConnectFunc(req) + } + if p.ConnectFunc == nil || errors.Is(cerr, ErrConnectFallback) { var cconn net.Conn res, cconn, cerr = p.connect(req) if cconn != nil { defer cconn.Close() - cr = cconn - cw = cconn + crw = cconn if shouldTerminateTLS(req) { log.Debugf(req.Context(), "attempting to terminate TLS on CONNECT tunnel: %s", req.URL.Host) tconn := tls.Client(cconn, p.clientTLSConfig()) if err := tconn.Handshake(); err == nil { - cr = tconn - cw = tconn + crw = tconn } else { log.Errorf(req.Context(), "failed to terminate TLS on CONNECT tunnel: %v", err) cerr = err @@ -622,7 +607,7 @@ func (p *Proxy) handleConnectRequest(ctx *Context, req *http.Request, session *S res.ContentLength = -1 - if err := p.tunnel("CONNECT", res, brw, conn, cw, cr); err != nil { + if err := p.tunnel("CONNECT", res, brw, conn, crw); err != nil { log.Errorf(req.Context(), "CONNECT tunnel: %w", err) } @@ -640,28 +625,28 @@ func (p *Proxy) handleUpgradeResponse(res *http.Response, brw *bufio.ReadWriter, res.Body = nil - if err := p.tunnel(resUpType, res, brw, conn, uconn, uconn); err != nil { + if err := p.tunnel(resUpType, res, brw, conn, uconn); err != nil { log.Errorf(res.Request.Context(), "%s tunnel: %w", resUpType, err) } return errClose } -func (p *Proxy) tunnel(name string, res *http.Response, brw *bufio.ReadWriter, conn net.Conn, cw io.Writer, cr io.Reader) error { +func (p *Proxy) tunnel(name string, res *http.Response, brw *bufio.ReadWriter, conn net.Conn, crw io.ReadWriteCloser) error { if err := res.Write(brw); err != nil { return fmt.Errorf("got error while writing response back to client: %w", err) } if err := brw.Flush(); err != nil { return fmt.Errorf("got error while flushing response back to client: %w", err) } - if err := drainBuffer(cw, brw.Reader); err != nil { + if err := drainBuffer(crw, brw.Reader); err != nil { return fmt.Errorf("got error while draining read buffer: %w", err) } ctx := res.Request.Context() donec := make(chan bool, 2) - go copySync(ctx, "outbound "+name, cw, conn, donec) - go copySync(ctx, "inbound "+name, conn, cr, donec) + go copySync(ctx, "outbound "+name, crw, conn, donec) + go copySync(ctx, "inbound "+name, conn, crw, donec) log.Debugf(ctx, "switched protocols, proxying %s traffic", name) <-donec diff --git a/internal/martian/proxy_test.go b/internal/martian/proxy_test.go index d18456cb..d8a40cb2 100644 --- a/internal/martian/proxy_test.go +++ b/internal/martian/proxy_test.go @@ -36,6 +36,7 @@ import ( "github.com/saucelabs/forwarder/internal/martian/martiantest" "github.com/saucelabs/forwarder/internal/martian/mitm" "github.com/saucelabs/forwarder/internal/martian/proxyutil" + "go.uber.org/multierr" ) type tempError struct{} @@ -1009,27 +1010,33 @@ func TestIntegrationConnectUpstreamProxy(t *testing.T) { } } -func TestIntegrationConnectPassthrough(t *testing.T) { +type pipeConn struct { + *io.PipeReader + *io.PipeWriter +} + +func (conn pipeConn) CloseWrite() error { + return conn.PipeWriter.Close() +} + +func (conn pipeConn) Close() error { + return multierr.Combine( + conn.PipeReader.Close(), + conn.PipeWriter.Close(), + ) +} + +func TestIntegrationConnectFunc(t *testing.T) { t.Parallel() l := newListener(t) p := NewProxy() - p.ConnectPassthrough = true - defer p.Close() - - tr := martiantest.NewTransport() - tr.Func(func(req *http.Request) (*http.Response, error) { + p.ConnectFunc = func(req *http.Request) (*http.Response, io.ReadWriteCloser, error) { pr, pw := io.Pipe() - go func() { - if _, err := io.Copy(pw, req.Body); err != nil { - t.Errorf("io.Copy(): got %v, want no error", err) - } - pw.Close() - }() - return proxyutil.NewResponse(200, pr, req), nil - }) - p.SetRoundTripper(tr) + return proxyutil.NewResponse(200, nil, req), pipeConn{pr, pw}, nil + } p.SetTimeout(200 * time.Millisecond) + defer p.Close() go serve(p, l)