Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ require (
github.com/prometheus/client_model v0.6.2
github.com/prometheus/common v0.63.0
github.com/siddontang/go v0.0.0-20180604090527-bdc77568d726
github.com/soheilhy/cmux v0.1.5
github.com/spf13/cobra v1.9.1
github.com/stretchr/testify v1.11.1
github.com/tidwall/btree v1.7.0
Expand Down Expand Up @@ -232,7 +233,6 @@ require (
github.com/shurcooL/httpfs v0.0.0-20230704072500-f1e31cf0ba5c // indirect
github.com/shurcooL/vfsgen v0.0.0-20181202132449-6a9ea43bcacd // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/soheilhy/cmux v0.1.5 // indirect
github.com/spf13/pflag v1.0.7 // indirect
github.com/spkg/bom v1.0.0 // indirect
github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2 // indirect
Expand Down
18 changes: 18 additions & 0 deletions pkg/server/api/debug_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,21 @@ func TestDebugHealthManualOverride(t *testing.T) {
})
assertHealth(http.StatusBadGateway, "server is not ready")
}

func TestDebugHealthAllowsHTTPWithHTTPTLS(t *testing.T) {
_, doHTTP, doHTTPS := createServerWithConfig(t, `security.server-http-tls.auto-certs = true`)

doHTTP(t, http.MethodGet, "/api/debug/health", httpOpts{}, func(t *testing.T, r *http.Response) {
require.Equal(t, http.StatusOK, r.StatusCode)
})
doHTTP(t, http.MethodGet, "/debug/health", httpOpts{}, func(t *testing.T, r *http.Response) {
require.Equal(t, http.StatusOK, r.StatusCode)
})

doHTTPS(t, http.MethodGet, "/api/debug/health", httpOpts{}, func(t *testing.T, r *http.Response) {
require.Equal(t, http.StatusOK, r.StatusCode)
})
doHTTPS(t, http.MethodGet, "/api/metrics", httpOpts{}, func(t *testing.T, r *http.Response) {
require.Equal(t, http.StatusOK, r.StatusCode)
})
}
36 changes: 32 additions & 4 deletions pkg/server/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/pingcap/tiproxy/pkg/proxy/proxyprotocol"
mgrrp "github.com/pingcap/tiproxy/pkg/sqlreplay/manager"
"github.com/pingcap/tiproxy/pkg/util/waitgroup"
"github.com/soheilhy/cmux"
"go.uber.org/atomic"
"go.uber.org/ratelimit"
"go.uber.org/zap"
Expand Down Expand Up @@ -130,20 +131,47 @@ func NewServer(cfg config.API, lg *zap.Logger, mgr Managers, handler HTTPHandler
}

if tlscfg := mgr.CertMgr.ServerHTTPTLS(); tlscfg != nil {
h.listener = tls.NewListener(h.listener, tlscfg)
mux := cmux.New(h.listener)
mux.SetReadTimeout(DefConnTimeout)
plainHealthListener := mux.Match(cmux.HTTP1Fast())
tlsListener := tls.NewListener(mux.Match(cmux.TLS()), tlscfg)

h.serveHTTP("HTTP health", plainHealthListener, h.newHTTPHealthHandler())
h.serveHTTP("HTTPS", tlsListener, engine.Handler())
h.wg.RunWithRecover(func() {
lg.Info("HTTP mux closed", zap.Error(mux.Serve()))
}, nil, h.lg)
return h, nil
}

h.serveHTTP("HTTP", h.listener, engine.Handler())
return h, nil
}

func (h *Server) serveHTTP(name string, listener net.Listener, handler http.Handler) {
hsrv := http.Server{
Handler: engine.Handler(),
Handler: handler,
ReadHeaderTimeout: DefConnTimeout,
IdleTimeout: DefConnTimeout,
}

h.wg.RunWithRecover(func() {
lg.Info("HTTP closed", zap.Error(hsrv.Serve(h.listener)))
h.lg.Info(name+" closed", zap.Error(hsrv.Serve(listener)))
}, nil, h.lg)
}

return h, nil
func (h *Server) newHTTPHealthHandler() http.Handler {
engine := gin.New()
engine.Use(
gin.Recovery(),
h.rateLimit,
h.readyState,
h.attachLogger,
)
// Keep the plaintext health routes consistent with the main server.
engine.GET("/api/debug/health", h.DebugHealth)
engine.GET("/debug/health", h.DebugHealth)
return engine.Handler()
}

func (h *Server) rateLimit(c *gin.Context) {
Expand Down
45 changes: 32 additions & 13 deletions pkg/server/api/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package api

import (
"context"
"crypto/tls"
"fmt"
"io"
"net/http"
Expand All @@ -28,10 +29,18 @@ type httpOpts struct {
type doHTTPFunc func(t *testing.T, method string, path string, opts httpOpts, f func(*testing.T, *http.Response))

func createServer(t *testing.T) (*Server, doHTTPFunc) {
srv, doHTTP, _ := createServerWithConfig(t, "")
return srv, doHTTP
}

func createServerWithConfig(t *testing.T, tomlConfig string) (*Server, doHTTPFunc, doHTTPFunc) {
lg, _ := logger.CreateLoggerForTest(t)
ready := atomic.NewBool(true)
cfgmgr := mgrcfg.NewConfigManager()
require.NoError(t, cfgmgr.Init(context.Background(), "", ""))
if tomlConfig != "" {
require.NoError(t, cfgmgr.SetTOMLConfig([]byte(tomlConfig)))
}
crtmgr := mgrcrt.NewCertManager()
require.NoError(t, crtmgr.Init(cfgmgr.GetConfig(), lg, cfgmgr.WatchConfig()))
nsMgr := newMockNamespaceManager()
Expand All @@ -49,21 +58,31 @@ func createServer(t *testing.T) (*Server, doHTTPFunc) {
require.NoError(t, srv.Close())
})

addr := fmt.Sprintf("http://%s", srv.listener.Addr().String())
return srv, func(t *testing.T, method, pa string, opts httpOpts, f func(*testing.T, *http.Response)) {
if pa[0] != '/' {
pa = "/" + pa
}
req, err := http.NewRequest(method, fmt.Sprintf("%s%s", addr, pa), opts.reader)
require.NoError(t, err)
for key, value := range opts.header {
req.Header.Set(key, value)
addr := srv.listener.Addr().String()
httpsClient := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
}
do := func(scheme string, client *http.Client) doHTTPFunc {
return func(t *testing.T, method, pa string, opts httpOpts, f func(*testing.T, *http.Response)) {
if pa[0] != '/' {
pa = "/" + pa
}
req, err := http.NewRequest(method, fmt.Sprintf("%s://%s%s", scheme, addr, pa), opts.reader)
require.NoError(t, err)
for key, value := range opts.header {
req.Header.Set(key, value)
}
resp, err := client.Do(req)
require.NoError(t, err)
f(t, resp)
require.NoError(t, resp.Body.Close())
}
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
f(t, resp)
require.NoError(t, resp.Body.Close())
}
return srv, do("http", http.DefaultClient), do("https", httpsClient)
}

func TestGrpc(t *testing.T) {
Expand Down