diff --git a/lib/config/proxy.go b/lib/config/proxy.go index 69b1a3c71..443fa9d3e 100644 --- a/lib/config/proxy.go +++ b/lib/config/proxy.go @@ -47,9 +47,10 @@ type KeepAlive struct { } type ProxyServerOnline struct { - MaxConnections uint64 `yaml:"max-connections,omitempty" toml:"max-connections,omitempty" json:"max-connections,omitempty" reloadable:"true"` - ConnBufferSize int `yaml:"conn-buffer-size,omitempty" toml:"conn-buffer-size,omitempty" json:"conn-buffer-size,omitempty" reloadable:"true"` - FrontendKeepalive KeepAlive `yaml:"frontend-keepalive" toml:"frontend-keepalive" json:"frontend-keepalive"` + MaxConnections uint64 `yaml:"max-connections,omitempty" toml:"max-connections,omitempty" json:"max-connections,omitempty" reloadable:"true"` + HighMemoryUsageRejectThreshold float64 `yaml:"high-memory-usage-reject-threshold,omitempty" toml:"high-memory-usage-reject-threshold,omitempty" json:"high-memory-usage-reject-threshold,omitempty" reloadable:"true"` + ConnBufferSize int `yaml:"conn-buffer-size,omitempty" toml:"conn-buffer-size,omitempty" json:"conn-buffer-size,omitempty" reloadable:"true"` + FrontendKeepalive KeepAlive `yaml:"frontend-keepalive" toml:"frontend-keepalive" json:"frontend-keepalive"` // BackendHealthyKeepalive applies when the observer treats the backend as healthy. // The config values should be conservative to save CPU and tolerate network fluctuation. 
BackendHealthyKeepalive KeepAlive `yaml:"backend-healthy-keepalive" toml:"backend-healthy-keepalive" json:"backend-healthy-keepalive"` @@ -149,6 +150,7 @@ func NewConfig() *Config { cfg.Proxy.Addr = "0.0.0.0:6000" cfg.Proxy.FrontendKeepalive, cfg.Proxy.BackendHealthyKeepalive, cfg.Proxy.BackendUnhealthyKeepalive = DefaultKeepAlive() + cfg.Proxy.HighMemoryUsageRejectThreshold = 0.9 cfg.Proxy.PDAddrs = "127.0.0.1:2379" cfg.Proxy.GracefulCloseConnTimeout = 15 cfg.Proxy.FailoverTimeout = 60 @@ -211,6 +213,12 @@ func (cfg *Config) Check() error { } func (ps *ProxyServer) Check() error { + if ps.HighMemoryUsageRejectThreshold < 0 || ps.HighMemoryUsageRejectThreshold > 1 { + return errors.Wrapf(ErrInvalidConfigValue, "invalid proxy.high-memory-usage-reject-threshold") + } + if ps.HighMemoryUsageRejectThreshold > 0 && ps.HighMemoryUsageRejectThreshold < 0.5 { + ps.HighMemoryUsageRejectThreshold = 0.5 + } if ps.FailoverTimeout < 0 { return errors.Wrapf(ErrInvalidConfigValue, "proxy.failover-timeout must be greater than or equal to 0") } diff --git a/lib/config/proxy_test.go b/lib/config/proxy_test.go index e908af9e6..0a353942c 100644 --- a/lib/config/proxy_test.go +++ b/lib/config/proxy_test.go @@ -20,13 +20,14 @@ var testProxyConfig = Config{ Addr: "0.0.0.0:4000", PDAddrs: "127.0.0.1:4089", ProxyServerOnline: ProxyServerOnline{ - MaxConnections: 1, - FrontendKeepalive: KeepAlive{Enabled: true}, - ProxyProtocol: "v2", - GracefulWaitBeforeShutdown: 10, - FailBackendList: []string{"db-tidb-0", "db-tidb-1"}, - FailoverTimeout: 60, - ConnBufferSize: 32 * 1024, + MaxConnections: 1, + HighMemoryUsageRejectThreshold: 0.9, + FrontendKeepalive: KeepAlive{Enabled: true}, + ProxyProtocol: "v2", + GracefulWaitBeforeShutdown: 10, + FailBackendList: []string{"db-tidb-0", "db-tidb-1"}, + FailoverTimeout: 60, + ConnBufferSize: 32 * 1024, }, }, API: API{ @@ -92,6 +93,26 @@ func TestProxyCheck(t *testing.T) { post func(*testing.T, *Config) err error }{ + { + pre: func(t *testing.T, c 
*Config) { + c.Proxy.HighMemoryUsageRejectThreshold = -0.1 + }, + err: ErrInvalidConfigValue, + }, + { + pre: func(t *testing.T, c *Config) { + c.Proxy.HighMemoryUsageRejectThreshold = 1.1 + }, + err: ErrInvalidConfigValue, + }, + { + pre: func(t *testing.T, c *Config) { + c.Proxy.HighMemoryUsageRejectThreshold = 0.4 + }, + post: func(t *testing.T, c *Config) { + require.Equal(t, 0.5, c.Proxy.HighMemoryUsageRejectThreshold) + }, + }, { pre: func(t *testing.T, c *Config) { c.Workdir = "" diff --git a/pkg/manager/memory/memory.go b/pkg/manager/memory/memory.go index f6355bfbf..bfe4b650e 100644 --- a/pkg/manager/memory/memory.go +++ b/pkg/manager/memory/memory.go @@ -9,6 +9,7 @@ import ( "path/filepath" "runtime" "runtime/pprof" + "sync/atomic" "time" "github.com/pingcap/tidb/pkg/util/memory" @@ -18,16 +19,26 @@ import ( ) const ( - // Check the memory usage every 30 seconds. - checkInterval = 30 * time.Second + // Refresh the memory usage every 5 seconds. + refreshInterval = 5 * time.Second // No need to record too frequently. recordMinInterval = 5 * time.Minute // Record the profiles when the memory usage is higher than 60%. alarmThreshold = 0.6 // Remove the oldest profiles when the number of profiles exceeds this limit. maxSavedProfiles = 20 + // Fail open if the latest sampled usage is too old. + snapshotExpireInterval = 3 * refreshInterval ) +type UsageSnapshot struct { + Used uint64 + Limit uint64 + Usage float64 + UpdateTime time.Time + Valid bool +} + // MemManager is a manager for memory usage. // Although the continuous profiling collects profiles periodically, when TiProxy runs in the replayer mode, // the profiles are not collected. 
@@ -38,77 +49,166 @@ type MemManager struct { cfgGetter config.ConfigGetter savedProfileNames []string lastRecordTime time.Time - checkInterval time.Duration // used for test + refreshInterval time.Duration // used for test recordMinInterval time.Duration // used for test maxSavedProfiles int // used for test + snapshotExpire time.Duration // used for test memoryLimit uint64 + latestUsage atomic.Value + // connBufferMemDelta tracks the estimated buffer memory change since the latest refreshUsage. + connBufferMemDelta atomic.Int64 } func NewMemManager(lg *zap.Logger, cfgGetter config.ConfigGetter) *MemManager { - return &MemManager{ + mgr := &MemManager{ lg: lg, cfgGetter: cfgGetter, - checkInterval: checkInterval, + refreshInterval: refreshInterval, recordMinInterval: recordMinInterval, maxSavedProfiles: maxSavedProfiles, + snapshotExpire: snapshotExpireInterval, } + mgr.latestUsage.Store(UsageSnapshot{}) + return mgr } func (m *MemManager) Start(ctx context.Context) { // Call the memory.MemTotal and memory.MemUsed in TiDB repo because they have considered cgroup. 
limit, err := memory.MemTotal() if err != nil || limit == 0 { - m.lg.Error("get memory limit failed", zap.Uint64("limit", limit), zap.Error(err)) + m.lg.Warn("get memory limit failed", zap.Uint64("limit", limit), zap.Error(err)) return } m.memoryLimit = limit + if _, err = m.refreshUsage(); err != nil { + return + } childCtx, cancel := context.WithCancel(ctx) m.cancel = cancel m.wg.RunWithRecover(func() { - m.alarmLoop(childCtx) + m.refreshLoop(childCtx) }, nil, m.lg) } -func (m *MemManager) alarmLoop(ctx context.Context) { - ticker := time.NewTicker(m.checkInterval) +func (m *MemManager) refreshLoop(ctx context.Context) { + ticker := time.NewTicker(m.refreshInterval) defer ticker.Stop() for ctx.Err() == nil { select { case <-ctx.Done(): return case <-ticker.C: - m.checkAndAlarm() + m.refreshAndAlarm() } } } -func (m *MemManager) checkAndAlarm() { +func (m *MemManager) refreshAndAlarm() { + snapshot, err := m.refreshUsage() + if err != nil || !snapshot.Valid { + return + } + if snapshot.Usage < alarmThreshold { + return + } if time.Since(m.lastRecordTime) < m.recordMinInterval { return } // The filename is hot-reloadable. 
- logPath := m.cfgGetter.GetConfig().Log.LogFile.Filename + cfg := m.cfgGetter.GetConfig() + if cfg == nil { + return + } + logPath := cfg.Log.LogFile.Filename if logPath == "" { return } recordDir := filepath.Dir(logPath) + m.lastRecordTime = snapshot.UpdateTime + m.lg.Warn("memory usage alarm", zap.Uint64("limit", snapshot.Limit), zap.Uint64("used", snapshot.Used), zap.Float64("usage", snapshot.Usage)) + now := time.Now().Format(time.RFC3339) + m.recordHeap(filepath.Join(recordDir, "heap_"+now)) + m.recordGoroutine(filepath.Join(recordDir, "goroutine_"+now)) + m.rmExpiredProfiles() +} + +func (m *MemManager) refreshUsage() (UsageSnapshot, error) { + if m.memoryLimit == 0 { + return UsageSnapshot{}, nil + } used, err := memory.MemUsed() if err != nil || used == 0 { - m.lg.Error("get used memory failed", zap.Uint64("used", used), zap.Error(err)) - return + m.lg.Warn("get used memory failed", zap.Uint64("used", used), zap.Error(err)) + return UsageSnapshot{}, err + } + // Start a new delta window from this sampled snapshot. Later connection create/close + // events only adjust the in-memory estimate relative to this refresh result. 
+ m.connBufferMemDelta.Swap(0) + snapshot := UsageSnapshot{ + Used: used, + Limit: m.memoryLimit, + Usage: float64(used) / float64(m.memoryLimit), + UpdateTime: time.Now(), + Valid: true, } - memoryUsage := float64(used) / float64(m.memoryLimit) - if memoryUsage < alarmThreshold { + m.latestUsage.Store(snapshot) + return snapshot, nil +} + +func (m *MemManager) LatestUsage() UsageSnapshot { + snapshot, _ := m.latestUsage.Load().(UsageSnapshot) + return snapshot +} + +func (m *MemManager) UpdateConnBufferMemory(delta int64) { + if m == nil || delta == 0 { return } + m.connBufferMemDelta.Add(delta) +} - m.lastRecordTime = time.Now() - m.lg.Warn("memory usage alarm", zap.Uint64("limit", m.memoryLimit), zap.Uint64("used", used), zap.Float64("usage", memoryUsage)) - now := time.Now().Format(time.RFC3339) - m.recordHeap(filepath.Join(recordDir, "heap_"+now)) - m.recordGoroutine(filepath.Join(recordDir, "goroutine_"+now)) - m.rmExpiredProfiles() +// adjustUsageByConnBuffer applies the connection buffer delta accumulated after the +// latest refreshUsage, so ShouldRejectNewConn can react before the next memory sample. 
+func (m *MemManager) adjustUsageByConnBuffer(snapshot UsageSnapshot) UsageSnapshot { + delta := m.connBufferMemDelta.Load() + if delta == 0 { + return snapshot + } + if delta > 0 { + snapshot.Used += uint64(delta) + } else { + released := uint64(-delta) + if released >= snapshot.Used { + snapshot.Used = 0 + } else { + snapshot.Used -= released + } + } + if snapshot.Limit > 0 { + snapshot.Usage = float64(snapshot.Used) / float64(snapshot.Limit) + } + return snapshot +} + +func (m *MemManager) ShouldRejectNewConn() (bool, UsageSnapshot, float64) { + if m == nil || m.cfgGetter == nil { + return false, UsageSnapshot{}, 0 + } + cfg := m.cfgGetter.GetConfig() + if cfg == nil { + return false, UsageSnapshot{}, 0 + } + threshold := cfg.Proxy.HighMemoryUsageRejectThreshold + if threshold == 0 { + return false, UsageSnapshot{}, 0 + } + snapshot := m.LatestUsage() + if !snapshot.Valid || time.Since(snapshot.UpdateTime) > m.snapshotExpire { + return false, snapshot, threshold + } + snapshot = m.adjustUsageByConnBuffer(snapshot) + return snapshot.Usage >= threshold, snapshot, threshold } func (m *MemManager) recordHeap(fileName string) { diff --git a/pkg/manager/memory/memory_test.go b/pkg/manager/memory/memory_test.go index 9d0d8fee8..dec324ac3 100644 --- a/pkg/manager/memory/memory_test.go +++ b/pkg/manager/memory/memory_test.go @@ -8,6 +8,7 @@ import ( "os" "path" "strings" + "sync/atomic" "testing" "time" @@ -26,6 +27,12 @@ func (c *mockCfgGetter) GetConfig() *config.Config { } func TestRecordProfile(t *testing.T) { + oldMemUsed, oldMemTotal := memory.MemUsed, memory.MemTotal + defer func() { + memory.MemUsed = oldMemUsed + memory.MemTotal = oldMemTotal + }() + dir := t.TempDir() cfg := &config.Config{} cfg.Log.LogFile.Filename = path.Join(dir, "proxy.log") @@ -39,7 +46,7 @@ func TestRecordProfile(t *testing.T) { m := NewMemManager(zap.NewNop(), &cfgGetter) // The timestamp in file names are in seconds instead of milliseconds, so recording too frequently is useless. 
// Instead, it may overwrite the previous files. - m.checkInterval = 100 * time.Millisecond + m.refreshInterval = 100 * time.Millisecond m.recordMinInterval = 1200 * time.Millisecond m.maxSavedProfiles = 2 m.Start(context.Background()) @@ -75,3 +82,100 @@ func TestRecordProfile(t *testing.T) { require.NoError(t, err) require.Len(t, entries, m.maxSavedProfiles) } + +func TestShouldRejectNewConn(t *testing.T) { + oldMemUsed, oldMemTotal := memory.MemUsed, memory.MemTotal + defer func() { + memory.MemUsed = oldMemUsed + memory.MemTotal = oldMemTotal + }() + + cfg := config.NewConfig() + cfg.Proxy.HighMemoryUsageRejectThreshold = 0.9 + cfgGetter := mockCfgGetter{cfg: cfg} + memory.MemUsed = func() (uint64, error) { + return 9 * (1 << 30), nil + } + memory.MemTotal = func() (uint64, error) { + return 10 * (1 << 30), nil + } + m := NewMemManager(zap.NewNop(), &cfgGetter) + m.refreshInterval = 50 * time.Millisecond + m.snapshotExpire = 200 * time.Millisecond + m.Start(context.Background()) + defer m.Close() + + require.Eventually(t, func() bool { + reject, snapshot, threshold := m.ShouldRejectNewConn() + return reject && snapshot.Valid && threshold == 0.9 + }, time.Second, 10*time.Millisecond) + m.Close() + + cfg.Proxy.HighMemoryUsageRejectThreshold = 0 + reject, _, threshold := m.ShouldRejectNewConn() + require.False(t, reject) + require.Zero(t, threshold) + + staleSnapshot := m.LatestUsage() + staleSnapshot.UpdateTime = time.Now().Add(-m.snapshotExpire - time.Second) + m.latestUsage.Store(staleSnapshot) + cfg.Proxy.HighMemoryUsageRejectThreshold = 0.9 + reject, _, threshold = m.ShouldRejectNewConn() + require.False(t, reject) + require.Equal(t, 0.9, threshold) +} + +func TestShouldRejectNewConnTracksConnBufferMemory(t *testing.T) { + oldMemUsed, oldMemTotal := memory.MemUsed, memory.MemTotal + defer func() { + memory.MemUsed = oldMemUsed + memory.MemTotal = oldMemTotal + }() + + cfg := config.NewConfig() + cfg.Proxy.HighMemoryUsageRejectThreshold = 0.9 + cfgGetter := 
mockCfgGetter{cfg: cfg} + var used atomic.Uint64 + used.Store(890) + memory.MemUsed = func() (uint64, error) { + return used.Load(), nil + } + memory.MemTotal = func() (uint64, error) { + return 1000, nil + } + m := NewMemManager(zap.NewNop(), &cfgGetter) + m.refreshInterval = time.Hour // Start() samples once synchronously; disable the background refresh so it cannot reset connBufferMemDelta between the assertions below + m.snapshotExpire = time.Hour + m.Start(context.Background()) + defer m.Close() + + require.Eventually(t, func() bool { + reject, snapshot, threshold := m.ShouldRejectNewConn() + return !reject && snapshot.Valid && threshold == 0.9 && snapshot.Used == 890 + }, time.Second, 10*time.Millisecond) + + m.UpdateConnBufferMemory(20) + reject, snapshot, threshold := m.ShouldRejectNewConn() + require.True(t, reject) + require.Equal(t, 0.9, threshold) + require.Equal(t, uint64(910), snapshot.Used) + require.InDelta(t, 0.91, snapshot.Usage, 0.0001) + + used.Store(910) + snapshot, err := m.refreshUsage() + require.NoError(t, err) + require.Equal(t, uint64(910), snapshot.Used) + require.InDelta(t, 0.91, snapshot.Usage, 0.0001) + reject, snapshot, threshold = m.ShouldRejectNewConn() + require.True(t, reject) + require.Equal(t, 0.9, threshold) + require.Equal(t, uint64(910), snapshot.Used) + require.InDelta(t, 0.91, snapshot.Usage, 0.0001) + + m.UpdateConnBufferMemory(-20) + reject, snapshot, threshold = m.ShouldRejectNewConn() + require.False(t, reject) + require.Equal(t, 0.9, threshold) + require.Equal(t, uint64(890), snapshot.Used) + require.InDelta(t, 0.89, snapshot.Usage, 0.0001) +} diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index 39c5eca1b..8501721e5 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -98,6 +98,7 @@ func init() { colls = []prometheus.Collector{ ConnGauge, CreateConnCounter, + RejectConnCounter, DisConnCounter, MaxProcsGauge, OwnerGauge, diff --git a/pkg/metrics/server.go b/pkg/metrics/server.go index b569a9910..53a288216 100644 --- a/pkg/metrics/server.go +++ b/pkg/metrics/server.go @@ -34,6 +34,14 @@ var ( Help: "Number of 
create connections.", }) + RejectConnCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: ModuleProxy, + Subsystem: LabelServer, + Name: "reject_connection_total", + Help: "Number of rejected connections.", + }, []string{LblType}) + DisConnCounter = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: ModuleProxy, diff --git a/pkg/proxy/backend/cmd_processor_test.go b/pkg/proxy/backend/cmd_processor_test.go index fbf32551d..25ff34806 100644 --- a/pkg/proxy/backend/cmd_processor_test.go +++ b/pkg/proxy/backend/cmd_processor_test.go @@ -1033,7 +1033,8 @@ func TestNetworkError(t *testing.T) { } backendErrChecker := func(t *testing.T, ts *testSuite) { require.True(t, pnet.IsDisconnectError(ts.mp.err)) - require.True(t, pnet.IsDisconnectError(ts.mb.err)) + // The backend mock may finish writing the error packet before the proxy actively closes the backend side. + require.True(t, ts.mb.err == nil || pnet.IsDisconnectError(ts.mb.err)) } proxyErrChecker := func(t *testing.T, ts *testSuite) { require.True(t, pnet.IsDisconnectError(ts.mp.err)) diff --git a/pkg/proxy/net/packetio.go b/pkg/proxy/net/packetio.go index 021376125..d159e5d40 100644 --- a/pkg/proxy/net/packetio.go +++ b/pkg/proxy/net/packetio.go @@ -239,6 +239,9 @@ type packetIO struct { } func NewPacketIO(conn net.Conn, lg *zap.Logger, bufferSize int, opts ...PacketIOption) *packetIO { + if bufferSize == 0 { + bufferSize = DefaultConnBufferSize + } p := &packetIO{ rawConn: conn, logger: lg, diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index 4e3215ffa..cc6cbfb18 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -15,6 +15,7 @@ import ( "github.com/pingcap/tiproxy/lib/util/waitgroup" "github.com/pingcap/tiproxy/pkg/manager/cert" "github.com/pingcap/tiproxy/pkg/manager/id" + mgrmem "github.com/pingcap/tiproxy/pkg/manager/memory" "github.com/pingcap/tiproxy/pkg/metrics" "github.com/pingcap/tiproxy/pkg/proxy/backend" "github.com/pingcap/tiproxy/pkg/proxy/client" @@ -44,6 
+45,7 @@ type SQLServer struct { logger *zap.Logger certMgr *cert.CertManager idMgr *id.IDManager + memUsage memoryStateProvider hsHandler backend.HandshakeHandler cpt capture.Capture wg waitgroup.WaitGroup @@ -52,13 +54,31 @@ type SQLServer struct { mu serverState } +type memoryStateProvider interface { + ShouldRejectNewConn() (bool, mgrmem.UsageSnapshot, float64) +} + +type connBufferMemoryUpdater interface { + UpdateConnBufferMemory(delta int64) +} + +func estimateConnBufferMemDelta(bufferSize int) int64 { + if bufferSize == 0 { + bufferSize = pnet.DefaultConnBufferSize + } + // write buffer + read buffer + return int64(bufferSize * 2) +} + // NewSQLServer creates a new SQLServer. -func NewSQLServer(logger *zap.Logger, cfg *config.Config, certMgr *cert.CertManager, idMgr *id.IDManager, cpt capture.Capture, hsHandler backend.HandshakeHandler) (*SQLServer, error) { +func NewSQLServer(logger *zap.Logger, cfg *config.Config, certMgr *cert.CertManager, idMgr *id.IDManager, cpt capture.Capture, + hsHandler backend.HandshakeHandler, memUsage memoryStateProvider) (*SQLServer, error) { var err error s := &SQLServer{ logger: logger, certMgr: certMgr, idMgr: idMgr, + memUsage: memUsage, hsHandler: hsHandler, cpt: cpt, mu: serverState{ @@ -139,6 +159,18 @@ func (s *SQLServer) Run(ctx context.Context, cfgch <-chan *config.Config) { } func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) { + if s.rejectConnByMemory(conn) { + return + } + + var ( + connBufferUpdater connBufferMemoryUpdater + connBufferMemDelta int64 + ) + if s.memUsage != nil { + connBufferUpdater, _ = s.memUsage.(connBufferMemoryUpdater) + } + tcpKeepAlive, logger, connID, clientConn := func() (bool, *zap.Logger, uint64, *client.ClientConnection) { s.mu.Lock() defer s.mu.Unlock() @@ -147,7 +179,8 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) { maxConns := s.mu.maxConnections // 'maxConns == 0' => unlimited connections if maxConns != 0 && conns >= maxConns 
{ - s.logger.Warn("too many connections", zap.Uint64("max connections", maxConns), zap.String("client_addr", conn.RemoteAddr().Network()), zap.Error(conn.Close())) + metrics.RejectConnCounter.WithLabelValues("max_connections").Inc() + s.logger.Warn("too many connections", zap.Uint64("max connections", maxConns), zap.Stringer("client_addr", conn.RemoteAddr()), zap.Error(conn.Close())) return false, nil, 0, nil } @@ -163,6 +196,10 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) { ConnBufferSize: s.mu.connBufferSize, }) s.mu.clients[connID] = clientConn + connBufferMemDelta = estimateConnBufferMemDelta(s.mu.connBufferSize) + if connBufferUpdater != nil { + connBufferUpdater.UpdateConnBufferMemory(connBufferMemDelta) + } logger.Debug("new connection", zap.Bool("proxy-protocol", s.mu.proxyProtocol), zap.Bool("require_backend_tls", s.mu.requireBackendTLS)) return s.mu.tcpKeepAlive, logger, connID, clientConn }() @@ -178,6 +215,9 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) { s.mu.Lock() delete(s.mu.clients, connID) s.mu.Unlock() + if connBufferUpdater != nil { + connBufferUpdater.UpdateConnBufferMemory(-connBufferMemDelta) + } if err := clientConn.Close(); err != nil && !pnet.IsDisconnectError(err) { logger.Error("close connection fails", zap.Error(err)) @@ -194,6 +234,26 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) { clientConn.Run(ctx) } +func (s *SQLServer) rejectConnByMemory(conn net.Conn) bool { + if s.memUsage == nil { + return false + } + reject, snapshot, threshold := s.memUsage.ShouldRejectNewConn() + if !reject { + return false + } + metrics.RejectConnCounter.WithLabelValues("memory").Inc() + s.logger.Warn("reject connection due to high memory usage", + zap.Stringer("client_addr", conn.RemoteAddr()), + zap.Float64("threshold", threshold), + zap.Float64("usage", snapshot.Usage), + zap.Uint64("used", snapshot.Used), + zap.Uint64("limit", snapshot.Limit), + 
zap.Time("last_update", snapshot.UpdateTime), + zap.Error(conn.Close())) + return true +} + func (s *SQLServer) PreClose() { // Step 1: HTTP status returns unhealthy so that NLB takes this instance offline and then new connections won't come. s.mu.Lock() diff --git a/pkg/proxy/proxy_test.go b/pkg/proxy/proxy_test.go index d0d3711e4..2dda7957c 100644 --- a/pkg/proxy/proxy_test.go +++ b/pkg/proxy/proxy_test.go @@ -9,6 +9,7 @@ import ( "fmt" "net" "strings" + "sync/atomic" "testing" "time" @@ -20,6 +21,7 @@ import ( "github.com/pingcap/tiproxy/pkg/balance/router" "github.com/pingcap/tiproxy/pkg/manager/cert" "github.com/pingcap/tiproxy/pkg/manager/id" + mgrmem "github.com/pingcap/tiproxy/pkg/manager/memory" "github.com/pingcap/tiproxy/pkg/metrics" "github.com/pingcap/tiproxy/pkg/proxy/backend" "github.com/pingcap/tiproxy/pkg/proxy/client" @@ -32,7 +34,7 @@ func TestCreateConn(t *testing.T) { cfg := &config.Config{} certManager := cert.NewCertManager() require.NoError(t, certManager.Init(cfg, lg, nil)) - server, err := NewSQLServer(lg, cfg, certManager, id.NewIDManager(), nil, &mockHsHandler{}) + server, err := NewSQLServer(lg, cfg, certManager, id.NewIDManager(), nil, &mockHsHandler{}, nil) require.NoError(t, err) server.Run(context.Background(), nil) defer func() { @@ -69,6 +71,92 @@ func TestCreateConn(t *testing.T) { checkMetrics(0, 2) } +func TestRejectConnByMemory(t *testing.T) { + lg, _ := logger.CreateLoggerForTest(t) + certManager := cert.NewCertManager() + require.NoError(t, certManager.Init(&config.Config{}, lg, nil)) + server, err := NewSQLServer(lg, &config.Config{}, certManager, id.NewIDManager(), nil, &mockHsHandler{}, &mockMemUsageProvider{ + reject: true, + snapshot: mgrmem.UsageSnapshot{ + Used: 9 * (1 << 30), + Limit: 10 * (1 << 30), + Usage: 0.9, + UpdateTime: time.Now(), + Valid: true, + }, + threshold: 0.9, + }) + require.NoError(t, err) + server.Run(context.Background(), nil) + defer func() { + server.PreClose() + require.NoError(t, 
server.Close()) + certManager.Close() + }() + + rejectBefore, err := metrics.ReadCounter(metrics.RejectConnCounter.WithLabelValues("memory")) + require.NoError(t, err) + createBefore, err := metrics.ReadCounter(metrics.CreateConnCounter) + require.NoError(t, err) + connGaugeBefore, err := metrics.ReadGauge(metrics.ConnGauge) + require.NoError(t, err) + + conn, err := net.Dial("tcp", server.listeners[0].Addr().String()) + require.NoError(t, err) + defer func() { _ = conn.Close() }() + + require.Eventually(t, func() bool { + _ = conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond)) + var buf [1]byte + _, err := conn.Read(buf[:]) + return err != nil + }, time.Second, 10*time.Millisecond) + + require.Eventually(t, func() bool { + rejectAfter, err := metrics.ReadCounter(metrics.RejectConnCounter.WithLabelValues("memory")) + require.NoError(t, err) + createAfter, err := metrics.ReadCounter(metrics.CreateConnCounter) + require.NoError(t, err) + connGaugeAfter, err := metrics.ReadGauge(metrics.ConnGauge) + require.NoError(t, err) + return rejectAfter == rejectBefore+1 && createAfter == createBefore && connGaugeAfter == connGaugeBefore + }, time.Second, 10*time.Millisecond) +} + +func TestTrackConnBufferMemDelta(t *testing.T) { + lg, _ := logger.CreateLoggerForTest(t) + certManager := cert.NewCertManager() + cfg := &config.Config{ + Proxy: config.ProxyServer{ + ProxyServerOnline: config.ProxyServerOnline{ + ConnBufferSize: 4096, + }, + }, + } + require.NoError(t, certManager.Init(cfg, lg, nil)) + memUsage := &mockMemUsageProvider{} + server, err := NewSQLServer(lg, cfg, certManager, id.NewIDManager(), nil, &mockHsHandler{}, memUsage) + require.NoError(t, err) + server.Run(context.Background(), nil) + defer func() { + server.PreClose() + require.NoError(t, server.Close()) + certManager.Close() + }() + + conn, err := net.Dial("tcp", server.listeners[0].Addr().String()) + require.NoError(t, err) + + require.Eventually(t, func() bool { + return 
memUsage.connBufferMemDelta.Load() == int64(cfg.Proxy.ConnBufferSize*2) + }, time.Second, 10*time.Millisecond) + + require.NoError(t, conn.Close()) + require.Eventually(t, func() bool { + return memUsage.connBufferMemDelta.Load() == 0 + }, time.Second, 10*time.Millisecond) +} + func TestGracefulCloseConn(t *testing.T) { // Graceful shutdown finishes immediately if there's no connection. lg, _ := logger.CreateLoggerForTest(t) @@ -80,7 +168,7 @@ func TestGracefulCloseConn(t *testing.T) { }, }, } - server, err := NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, hsHandler) + server, err := NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, hsHandler, nil) require.NoError(t, err) finish := make(chan struct{}) go func() { @@ -110,7 +198,7 @@ func TestGracefulCloseConn(t *testing.T) { } // Graceful shutdown will be blocked if there are alive connections. - server, err = NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, hsHandler) + server, err = NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, hsHandler, nil) require.NoError(t, err) clientConn := createClientConn() go func() { @@ -136,7 +224,7 @@ func TestGracefulCloseConn(t *testing.T) { // Graceful shutdown will shut down after GracefulCloseConnTimeout. 
cfg.Proxy.GracefulCloseConnTimeout = 1 - server, err = NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, hsHandler) + server, err = NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, hsHandler, nil) require.NoError(t, err) createClientConn() go func() { @@ -164,7 +252,7 @@ func TestGracefulShutDown(t *testing.T) { }, }, } - server, err := NewSQLServer(lg, cfg, certManager, id.NewIDManager(), nil, &mockHsHandler{}) + server, err := NewSQLServer(lg, cfg, certManager, id.NewIDManager(), nil, &mockHsHandler{}, nil) require.NoError(t, err) server.Run(context.Background(), nil) @@ -202,7 +290,7 @@ func TestMultiAddr(t *testing.T) { Proxy: config.ProxyServer{ Addr: "0.0.0.0:0,0.0.0.0:0", }, - }, certManager, id.NewIDManager(), nil, &mockHsHandler{}) + }, certManager, id.NewIDManager(), nil, &mockHsHandler{}, nil) require.NoError(t, err) server.Run(context.Background(), nil) @@ -222,7 +310,7 @@ func TestWatchCfg(t *testing.T) { lg, _ := logger.CreateLoggerForTest(t) hsHandler := backend.NewDefaultHandshakeHandler(nil) cfgch := make(chan *config.Config) - server, err := NewSQLServer(lg, &config.Config{}, nil, id.NewIDManager(), nil, hsHandler) + server, err := NewSQLServer(lg, &config.Config{}, nil, id.NewIDManager(), nil, hsHandler, nil) require.NoError(t, err) server.Run(context.Background(), cfgch) cfg := &config.Config{ @@ -264,7 +352,7 @@ func TestRecoverPanic(t *testing.T) { } return nil }, - }) + }, nil) require.NoError(t, err) server.Run(context.Background(), nil) @@ -290,6 +378,21 @@ type mockHsHandler struct { handshakeResp func(ctx backend.ConnContext, _ *pnet.HandshakeResp) error } +type mockMemUsageProvider struct { + reject bool + snapshot mgrmem.UsageSnapshot + threshold float64 + connBufferMemDelta atomic.Int64 +} + +func (m *mockMemUsageProvider) ShouldRejectNewConn() (bool, mgrmem.UsageSnapshot, float64) { + return m.reject, m.snapshot, m.threshold +} + +func (m *mockMemUsageProvider) UpdateConnBufferMemory(delta int64) { + m.connBufferMemDelta.Add(delta) 
+} + // HandleHandshakeResp only panics for the first connections. func (handler *mockHsHandler) HandleHandshakeResp(ctx backend.ConnContext, resp *pnet.HandshakeResp) error { if handler.handshakeResp != nil { diff --git a/pkg/server/server.go b/pkg/server/server.go index ec159ba34..eac0cb460 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -176,7 +176,7 @@ func NewServer(ctx context.Context, sctx *sctx.Context) (srv *Server, err error) // setup proxy server { - srv.proxy, err = proxy.NewSQLServer(lg.Named("proxy"), cfg, srv.certManager, idMgr, srv.replay.GetCapture(), hsHandler) + srv.proxy, err = proxy.NewSQLServer(lg.Named("proxy"), cfg, srv.certManager, idMgr, srv.replay.GetCapture(), hsHandler, srv.memManager) if err != nil { return }