Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ func main() {
var passwordHash string
var proxyBearerToken string
var proxyHeaders string
var trustedProxies string

rootCmd := &cobra.Command{
Use: "mcp-warp",
Expand Down Expand Up @@ -107,6 +108,14 @@ func main() {
oidcScopesList = []string{"openid", "profile", "email"}
}

var trustedProxiesList []string
if trustedProxies != "" {
trustedProxiesList = strings.Split(trustedProxies, ",")
for i := range trustedProxiesList {
trustedProxiesList[i] = strings.TrimSpace(trustedProxiesList[i])
}
}

// Parse proxy headers into slice
var proxyHeadersList []string
if proxyHeaders != "" {
Expand Down Expand Up @@ -142,6 +151,7 @@ func main() {
oidcAllowedUsersList,
password,
passwordHash,
trustedProxiesList,
proxyHeadersList,
proxyBearerToken,
args,
Expand Down Expand Up @@ -187,6 +197,7 @@ func main() {

// Proxy headers configuration
rootCmd.Flags().StringVar(&proxyBearerToken, "proxy-bearer-token", getEnvWithDefault("PROXY_BEARER_TOKEN", ""), "Bearer token to add to Authorization header when proxying requests")
rootCmd.Flags().StringVar(&trustedProxies, "trusted-proxies", getEnvWithDefault("TRUSTED_PROXIES", ""), "Comma-separated list of trusted proxies (IP addresses or CIDR ranges)")
rootCmd.Flags().StringVar(&proxyHeaders, "proxy-headers", getEnvWithDefault("PROXY_HEADERS", ""), "Comma-separated list of headers to add when proxying requests (format: Header1:Value1,Header2:Value2)")

if err := rootCmd.Execute(); err != nil {
Expand Down
12 changes: 12 additions & 0 deletions pkg/backend/interface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package backend

import (
"context"
"net/http"
)

type Backend interface {
Run(context.Context) (http.Handler, error)
Wait() error
Close() error
}
13 changes: 13 additions & 0 deletions pkg/backend/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package backend

import (
"os"
"testing"

"github.com/gin-gonic/gin"
)

func TestMain(m *testing.M) {
gin.SetMode(gin.TestMode)
os.Exit(m.Run())
}
2 changes: 1 addition & 1 deletion pkg/backend/backend.go → pkg/backend/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type ProxyBackend struct {
client *client.Client
}

func NewProxyBackend(logger *zap.Logger, cmd []string) *ProxyBackend {
func NewProxyBackend(logger *zap.Logger, cmd []string) Backend {
return &ProxyBackend{
logger: logger,
cmd: cmd,
Expand Down
25 changes: 24 additions & 1 deletion pkg/backend/backend_test.go → pkg/backend/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,32 @@ func TestProxyBackendRun(t *testing.T) {
defer pb.Close()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

handler, err := pb.Run(ctx)
require.NoError(t, err, "Run should not return error")
require.NotNil(t, handler, "handler should not be nil")

checkCh := make(chan struct{})
go func() {
<-ctx.Done()
close(checkCh)
}()

timeout := time.After(10 * time.Millisecond)
select {
case <-checkCh:
t.Error("Test completed too early")
case <-timeout:
// Test timed out
}

cancel()

timeout = time.After(10 * time.Second)
select {
case <-checkCh:
// Test completed successfully
case <-timeout:
t.Error("Test timed out")
}
}
100 changes: 100 additions & 0 deletions pkg/backend/transparent.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package backend

import (
"context"
"fmt"
"net"
"net/http"
"net/http/httputil"
"net/netip"
"net/url"
"sync"

"go.uber.org/zap"
)

type TransparentBackend struct {
logger *zap.Logger
url *url.URL
trusted []netip.Prefix
ctx context.Context
ctxLock sync.Mutex
}

func NewTransparentBackend(logger *zap.Logger, u *url.URL, trusted []string) (Backend, error) {
trn := make([]netip.Prefix, 0, len(trusted))
for _, c := range trusted {
p, err := netip.ParsePrefix(c)
if err != nil {
return nil, err
}
trn = append(trn, p)
}

return &TransparentBackend{
logger: logger,
url: u,
trusted: trn,
}, nil
}

func (p *TransparentBackend) Run(ctx context.Context) (http.Handler, error) {
p.ctxLock.Lock()
defer p.ctxLock.Unlock()
if p.ctx != nil {
return nil, fmt.Errorf("transparent backend is already running")
}
p.ctx = ctx
rp := httputil.ReverseProxy{
Rewrite: func(pr *httputil.ProxyRequest) {
pr.SetURL(p.url)
if p.isTrusted(pr.In.RemoteAddr) {
pr.Out.Header["X-Forwarded-For"] = pr.In.Header["X-Forwarded-For"]
}
pr.SetXForwarded()
if p.isTrusted(pr.In.RemoteAddr) {
if v := pr.In.Header.Get("X-Forwarded-Host"); v != "" {
pr.Out.Header.Set("X-Forwarded-Host", v)
}
if v := pr.In.Header.Get("X-Forwarded-Proto"); v != "" {
pr.Out.Header.Set("X-Forwarded-Proto", v)
}
if v := pr.In.Header.Get("X-Forwarded-Port"); v != "" {
pr.Out.Header.Set("X-Forwarded-Port", v)
}
}
},
}
return &rp, nil
}

func (p *TransparentBackend) isTrusted(hostport string) bool {
if host, _, err := net.SplitHostPort(hostport); err == nil {
hostport = host
}
ip, err := netip.ParseAddr(hostport)
if err != nil {
return false
}
if ip.Is4In6() {
ip = ip.Unmap()
}
for _, p := range p.trusted {
if p.Contains(ip) {
return true
}
}
return false
}

func (p *TransparentBackend) Wait() error {
if p.ctx == nil {
return nil
}
<-p.ctx.Done()
Comment on lines +91 to +94
Copy link

Copilot AI Aug 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Wait() method accesses p.ctx without proper synchronization. Consider using the ctxLock mutex to safely read p.ctx, similar to how it's protected in the Run() method.

Suggested change
if p.ctx == nil {
return nil
}
<-p.ctx.Done()
p.ctxLock.Lock()
ctx := p.ctx
p.ctxLock.Unlock()
if ctx == nil {
return nil
}
<-ctx.Done()

Copilot uses AI. Check for mistakes.
return nil
}

func (p *TransparentBackend) Close() error {
return nil
}
138 changes: 138 additions & 0 deletions pkg/backend/transparent_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package backend

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"

"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)

func TestTransparentBackend(t *testing.T) {
r := gin.New()
r.GET("/", func(c *gin.Context) {
c.JSON(http.StatusOK, c.Request.Header)
})
ts := httptest.NewServer(r)
u, _ := url.Parse(ts.URL)

be, err := NewTransparentBackend(zap.NewNop(), u, []string{})
require.NoError(t, err)
handler, err := be.Run(context.Background())
require.NoError(t, err)
require.NotNil(t, handler)

req := httptest.NewRequest(http.MethodGet, "/", nil)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)

require.Equal(t, http.StatusOK, rr.Code)
var header http.Header
require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &header))
require.Equal(t, "192.0.2.1", header.Get(("X-Forwarded-For")))
require.Equal(t, "example.com", header.Get(("X-Forwarded-Host")))
require.Equal(t, "http", header.Get(("X-Forwarded-Proto")))
}

func TestTransparentBackendWithProxy(t *testing.T) {
r := gin.New()
r.GET("/", func(c *gin.Context) {
c.JSON(http.StatusOK, c.Request.Header)
})
ts := httptest.NewServer(r)
u, _ := url.Parse(ts.URL)

be, err := NewTransparentBackend(zap.NewNop(), u, []string{"0.0.0.0/0"})
require.NoError(t, err)
handler, err := be.Run(context.Background())
require.NoError(t, err)
require.NotNil(t, handler)

req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Forwarded-For", "192.0.3.1")
req.Header.Set("X-Forwarded-Host", "example.org")
req.Header.Set("X-Forwarded-Proto", "https")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)

require.Equal(t, http.StatusOK, rr.Code)
var header http.Header
require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &header))
require.Equal(t, "192.0.3.1, 192.0.2.1", header.Get(("X-Forwarded-For")))
require.Equal(t, "example.org", header.Get(("X-Forwarded-Host")))
require.Equal(t, "https", header.Get(("X-Forwarded-Proto")))
}

func TestTransparentBackendWithInvalidProxy(t *testing.T) {
r := gin.New()
r.GET("/", func(c *gin.Context) {
c.JSON(http.StatusOK, c.Request.Header)
})
ts := httptest.NewServer(r)
u, _ := url.Parse(ts.URL)

be, err := NewTransparentBackend(zap.NewNop(), u, []string{"1.1.1.1/32"})
require.NoError(t, err)
handler, err := be.Run(context.Background())
require.NoError(t, err)
require.NotNil(t, handler)

req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Forwarded-For", "192.0.3.1")
req.Header.Set("X-Forwarded-Host", "example.org")
req.Header.Set("X-Forwarded-Proto", "https")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)

require.Equal(t, http.StatusOK, rr.Code)
var header http.Header
require.NoError(t, json.Unmarshal(rr.Body.Bytes(), &header))
require.Equal(t, "192.0.2.1", header.Get(("X-Forwarded-For")))
require.Equal(t, "example.com", header.Get(("X-Forwarded-Host")))
require.Equal(t, "http", header.Get(("X-Forwarded-Proto")))
}

func TestTransparentBackendRun(t *testing.T) {
r := gin.New()
r.GET("/", func(c *gin.Context) {
c.JSON(http.StatusOK, c.Request.Header)
})
ts := httptest.NewServer(r)
u, _ := url.Parse(ts.URL)

be, err := NewTransparentBackend(zap.NewNop(), u, []string{})
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
_, err = be.Run(ctx)
require.NoError(t, err)

checkCh := make(chan struct{})
go func() {
<-ctx.Done()
close(checkCh)
}()

timeout := time.After(10 * time.Millisecond)
select {
case <-checkCh:
t.Error("Test completed too early")
case <-timeout:
// Test timed out
}

cancel()

timeout = time.After(10 * time.Second)
select {
case <-checkCh:
// Test completed successfully
case <-timeout:
t.Error("Test timed out")
}
}
15 changes: 12 additions & 3 deletions pkg/mcp-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"net/http"
"net/http/httputil"
"net/url"
"os"
"os/signal"
Expand Down Expand Up @@ -59,6 +58,7 @@ func Run(
oidcAllowedUsers []string,
password string,
passwordHash string,
trustedProxy []string,
proxyHeaders []string,
proxyBearerToken string,
proxyTarget []string,
Expand Down Expand Up @@ -98,10 +98,18 @@ func Run(
if len(proxyTarget) == 0 {
return fmt.Errorf("proxy target must be specified")
}
var be *backend.ProxyBackend
var be backend.Backend
var beHandler http.Handler
if proxyURL, err := url.Parse(proxyTarget[0]); err == nil && (proxyURL.Scheme == "http" || proxyURL.Scheme == "https") {
beHandler = httputil.NewSingleHostReverseProxy(proxyURL)
var err error
be, err = backend.NewTransparentBackend(logger, proxyURL, trustedProxy)
if err != nil {
return fmt.Errorf("failed to create transparent backend: %w", err)
}
beHandler, err = be.Run(ctx)
if err != nil {
return fmt.Errorf("failed to create transparent backend: %w", err)
}
} else {
be = backend.NewProxyBackend(logger, proxyTarget)
beHandler, err = be.Run(ctx)
Expand Down Expand Up @@ -205,6 +213,7 @@ func Run(
}

router := gin.New()
router.SetTrustedProxies(trustedProxy)

router.Use(ginzap.Ginzap(logger, time.RFC3339, true))
router.Use(ginzap.RecoveryWithZap(logger, true))
Expand Down