diff --git a/go.mod b/go.mod index 6fd0f071..d22544c5 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/prometheus/client_model v0.6.2 github.com/prometheus/common v0.63.0 github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726 + github.com/soheilhy/cmux v0.1.5 github.com/spf13/cobra v1.9.1 github.com/stretchr/testify v1.11.1 github.com/tidwall/btree v1.7.0 @@ -232,7 +233,6 @@ require ( github.com/shurcooL/httpfs v0.0.0-20230704072500-f1e31cf0ba5c // indirect github.com/shurcooL/vfsgen v0.0.0-20181202132449-6a9ea43bcacd // indirect github.com/sirupsen/logrus v1.9.3 // indirect - github.com/soheilhy/cmux v0.1.5 // indirect github.com/spf13/pflag v1.0.7 // indirect github.com/spkg/bom v1.0.0 // indirect github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2 // indirect diff --git a/pkg/server/api/debug_test.go b/pkg/server/api/debug_test.go index dabc40dc..9e1da29f 100644 --- a/pkg/server/api/debug_test.go +++ b/pkg/server/api/debug_test.go @@ -93,3 +93,21 @@ func TestDebugHealthManualOverride(t *testing.T) { }) assertHealth(http.StatusBadGateway, "server is not ready") } + +func TestDebugHealthAllowsHTTPWithHTTPTLS(t *testing.T) { + _, doHTTP, doHTTPS := createServerWithConfig(t, `security.server-http-tls.auto-certs = true`) + + doHTTP(t, http.MethodGet, "/api/debug/health", httpOpts{}, func(t *testing.T, r *http.Response) { + require.Equal(t, http.StatusOK, r.StatusCode) + }) + doHTTP(t, http.MethodGet, "/debug/health", httpOpts{}, func(t *testing.T, r *http.Response) { + require.Equal(t, http.StatusOK, r.StatusCode) + }) + + doHTTPS(t, http.MethodGet, "/api/debug/health", httpOpts{}, func(t *testing.T, r *http.Response) { + require.Equal(t, http.StatusOK, r.StatusCode) + }) + doHTTPS(t, http.MethodGet, "/api/metrics", httpOpts{}, func(t *testing.T, r *http.Response) { + require.Equal(t, http.StatusOK, r.StatusCode) + }) +} diff --git a/pkg/server/api/server.go b/pkg/server/api/server.go index 640c0ac3..b2cc8c25 100644 --- a/pkg/server/api/server.go +++ b/pkg/server/api/server.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/tiproxy/pkg/proxy/proxyprotocol" mgrrp "github.com/pingcap/tiproxy/pkg/sqlreplay/manager" "github.com/pingcap/tiproxy/pkg/util/waitgroup" + "github.com/soheilhy/cmux" "go.uber.org/atomic" "go.uber.org/ratelimit" "go.uber.org/zap" @@ -130,20 +131,47 @@ func NewServer(cfg config.API, lg *zap.Logger, mgr Managers, handler HTTPHandler } if tlscfg := mgr.CertMgr.ServerHTTPTLS(); tlscfg != nil { - h.listener = tls.NewListener(h.listener, tlscfg) + mux := cmux.New(h.listener) + mux.SetReadTimeout(DefConnTimeout) + plainHealthListener := mux.Match(cmux.HTTP1Fast()) + tlsListener := tls.NewListener(mux.Match(cmux.TLS()), tlscfg) + + h.serveHTTP("HTTP health", plainHealthListener, h.newHTTPHealthHandler()) + h.serveHTTP("HTTPS", tlsListener, engine.Handler()) + h.wg.RunWithRecover(func() { + lg.Info("HTTP mux closed", zap.Error(mux.Serve())) + }, nil, h.lg) + return h, nil } + h.serveHTTP("HTTP", h.listener, engine.Handler()) + return h, nil +} + +func (h *Server) serveHTTP(name string, listener net.Listener, handler http.Handler) { hsrv := http.Server{ - Handler: engine.Handler(), + Handler: handler, ReadHeaderTimeout: DefConnTimeout, IdleTimeout: DefConnTimeout, } h.wg.RunWithRecover(func() { - lg.Info("HTTP closed", zap.Error(hsrv.Serve(h.listener))) + h.lg.Info(name+" closed", zap.Error(hsrv.Serve(listener))) }, nil, h.lg) +} - return h, nil +func (h *Server) newHTTPHealthHandler() http.Handler { + engine := gin.New() + engine.Use( + gin.Recovery(), + h.rateLimit, + h.readyState, + h.attachLogger, + ) + // Keep the plaintext health routes consistent with the main server. + engine.GET("/api/debug/health", h.DebugHealth) + engine.GET("/debug/health", h.DebugHealth) + return engine.Handler() } func (h *Server) rateLimit(c *gin.Context) { diff --git a/pkg/server/api/server_test.go b/pkg/server/api/server_test.go index ad87675b..5ff43168 100644 --- a/pkg/server/api/server_test.go +++ b/pkg/server/api/server_test.go @@ -5,6 +5,7 @@ package api import ( "context" + "crypto/tls" "fmt" "io" "net/http" @@ -28,10 +29,18 @@ type httpOpts struct { type doHTTPFunc func(t *testing.T, method string, path string, opts httpOpts, f func(*testing.T, *http.Response)) func createServer(t *testing.T) (*Server, doHTTPFunc) { + srv, doHTTP, _ := createServerWithConfig(t, "") + return srv, doHTTP +} + +func createServerWithConfig(t *testing.T, tomlConfig string) (*Server, doHTTPFunc, doHTTPFunc) { lg, _ := logger.CreateLoggerForTest(t) ready := atomic.NewBool(true) cfgmgr := mgrcfg.NewConfigManager() require.NoError(t, cfgmgr.Init(context.Background(), "", "")) + if tomlConfig != "" { + require.NoError(t, cfgmgr.SetTOMLConfig([]byte(tomlConfig))) + } crtmgr := mgrcrt.NewCertManager() require.NoError(t, crtmgr.Init(cfgmgr.GetConfig(), lg, cfgmgr.WatchConfig())) nsMgr := newMockNamespaceManager() @@ -49,21 +58,31 @@ func createServer(t *testing.T) (*Server, doHTTPFunc) { require.NoError(t, srv.Close()) }) - addr := fmt.Sprintf("http://%s", srv.listener.Addr().String()) - return srv, func(t *testing.T, method, pa string, opts httpOpts, f func(*testing.T, *http.Response)) { - if pa[0] != '/' { - pa = "/" + pa - } - req, err := http.NewRequest(method, fmt.Sprintf("%s%s", addr, pa), opts.reader) - require.NoError(t, err) - for key, value := range opts.header { - req.Header.Set(key, value) + addr := srv.listener.Addr().String() + httpsClient := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + } + do := func(scheme string, client *http.Client) doHTTPFunc { + return func(t *testing.T, method, pa string, opts httpOpts, f func(*testing.T, *http.Response)) { + if pa[0] != '/' { + pa = "/" + pa + } + req, err := http.NewRequest(method, fmt.Sprintf("%s://%s%s", scheme, addr, pa), opts.reader) + require.NoError(t, err) + for key, value := range opts.header { + req.Header.Set(key, value) + } + resp, err := client.Do(req) + require.NoError(t, err) + f(t, resp) + require.NoError(t, resp.Body.Close()) } - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - f(t, resp) - require.NoError(t, resp.Body.Close()) } + return srv, do("http", http.DefaultClient), do("https", httpsClient) } func TestGrpc(t *testing.T) {