From 76afb57ec750443eb5e9a22f6991d6c542258771 Mon Sep 17 00:00:00 2001 From: djshow832 Date: Tue, 21 Apr 2026 14:23:57 +0800 Subject: [PATCH 1/3] This is an automated cherry-pick of #1120 Signed-off-by: ti-chi-bot --- lib/config/proxy.go | 169 +++++++++++++++- lib/config/proxy_test.go | 43 ++++ pkg/manager/memory/memory.go | 256 +++++++++++++++++++++++ pkg/manager/memory/memory_test.go | 181 +++++++++++++++++ pkg/metrics/metrics.go | 1 + pkg/metrics/server.go | 8 + pkg/proxy/backend/cmd_processor_test.go | 3 +- pkg/proxy/net/packetio.go | 3 + pkg/proxy/proxy.go | 96 +++++++++ pkg/proxy/proxy_test.go | 259 +++++++++++++++++++++++- pkg/server/server.go | 4 + 11 files changed, 1018 insertions(+), 5 deletions(-) create mode 100644 pkg/manager/memory/memory.go create mode 100644 pkg/manager/memory/memory_test.go diff --git a/lib/config/proxy.go b/lib/config/proxy.go index 69b1a3c71..c034ecf25 100644 --- a/lib/config/proxy.go +++ b/lib/config/proxy.go @@ -47,9 +47,10 @@ type KeepAlive struct { } type ProxyServerOnline struct { - MaxConnections uint64 `yaml:"max-connections,omitempty" toml:"max-connections,omitempty" json:"max-connections,omitempty" reloadable:"true"` - ConnBufferSize int `yaml:"conn-buffer-size,omitempty" toml:"conn-buffer-size,omitempty" json:"conn-buffer-size,omitempty" reloadable:"true"` - FrontendKeepalive KeepAlive `yaml:"frontend-keepalive" toml:"frontend-keepalive" json:"frontend-keepalive"` + MaxConnections uint64 `yaml:"max-connections,omitempty" toml:"max-connections,omitempty" json:"max-connections,omitempty" reloadable:"true"` + HighMemoryUsageRejectThreshold float64 `yaml:"high-memory-usage-reject-threshold,omitempty" toml:"high-memory-usage-reject-threshold,omitempty" json:"high-memory-usage-reject-threshold,omitempty" reloadable:"true"` + ConnBufferSize int `yaml:"conn-buffer-size,omitempty" toml:"conn-buffer-size,omitempty" json:"conn-buffer-size,omitempty" reloadable:"true"` + FrontendKeepalive KeepAlive `yaml:"frontend-keepalive" toml:"frontend-keepalive" json:"frontend-keepalive"` // BackendHealthyKeepalive applies when the observer treats the backend as healthy. // The config values should be conservative to save CPU and tolerate network fluctuation. BackendHealthyKeepalive KeepAlive `yaml:"backend-healthy-keepalive" toml:"backend-healthy-keepalive" json:"backend-healthy-keepalive"` @@ -149,6 +150,7 @@ func NewConfig() *Config { cfg.Proxy.Addr = "0.0.0.0:6000" cfg.Proxy.FrontendKeepalive, cfg.Proxy.BackendHealthyKeepalive, cfg.Proxy.BackendUnhealthyKeepalive = DefaultKeepAlive() + cfg.Proxy.HighMemoryUsageRejectThreshold = 0.9 cfg.Proxy.PDAddrs = "127.0.0.1:2379" cfg.Proxy.GracefulCloseConnTimeout = 15 cfg.Proxy.FailoverTimeout = 60 @@ -272,3 +274,164 @@ func (cfg *Config) GetIPPort() (ip, port, statusPort string, err error) { } return } +<<<<<<< HEAD +======= + +// GetBackendClusters returns configured backend clusters. +// It keeps backward compatibility for the legacy `proxy.pd-addrs` setting. 
+func (cfg *Config) GetBackendClusters() []BackendCluster { + if len(cfg.Proxy.BackendClusters) > 0 { + return cfg.Proxy.BackendClusters + } + if strings.TrimSpace(cfg.Proxy.PDAddrs) == "" { + return nil + } + return []BackendCluster{{ + Name: DefaultBackendClusterName, + PDAddrs: cfg.Proxy.PDAddrs, + }} +} + +func (ps *ProxyServer) Check() error { + if ps.HighMemoryUsageRejectThreshold < 0 || ps.HighMemoryUsageRejectThreshold > 1 { + return errors.Wrapf(ErrInvalidConfigValue, "invalid proxy.high-memory-usage-reject-threshold") + } + if ps.HighMemoryUsageRejectThreshold > 0 && ps.HighMemoryUsageRejectThreshold < 0.5 { + ps.HighMemoryUsageRejectThreshold = 0.5 + } + if _, err := ps.GetSQLAddrs(); err != nil { + return errors.Wrapf(ErrInvalidConfigValue, "invalid proxy.addr or proxy.port-range: %s", err.Error()) + } + clusterNames := make(map[string]struct{}, len(ps.BackendClusters)) + for i, cluster := range ps.BackendClusters { + name := strings.TrimSpace(cluster.Name) + if name == "" { + return errors.Wrapf(ErrInvalidConfigValue, "proxy.backend-clusters[%d].name is empty", i) + } + if _, ok := clusterNames[name]; ok { + return errors.Wrapf(ErrInvalidConfigValue, "duplicate proxy.backend-clusters name %s", name) + } + clusterNames[name] = struct{}{} + if err := validateAddrList(cluster.PDAddrs, "proxy.backend-clusters.pd-addrs"); err != nil { + return err + } + if _, err := ParseNSServers(cluster.NSServers); err != nil { + return errors.Wrapf(ErrInvalidConfigValue, "invalid proxy.backend-clusters.ns-servers: %s", err.Error()) + } + } + + if ps.FailoverTimeout < 0 { + return errors.Wrapf(ErrInvalidConfigValue, "proxy.failover-timeout must be greater than or equal to 0") + } + failBackends := ps.FailBackendList[:0] + failBackendSet := make(map[string]struct{}, len(ps.FailBackendList)) + for i, backendName := range ps.FailBackendList { + backendName = strings.TrimSpace(backendName) + if backendName == "" { + return errors.Wrapf(ErrInvalidConfigValue, "proxy.fail-backend-list[%d] is empty", i) + } + if _, ok := failBackendSet[backendName]; ok { + continue + } + failBackendSet[backendName] = struct{}{} + failBackends = append(failBackends, backendName) + } + ps.FailBackendList = failBackends + return nil +} + +// SplitAddrList splits a comma-separated address list, trims each address, and drops empty entries. 
+func SplitAddrList(addrs string) []string { + parts := strings.Split(addrs, ",") + trimmed := make([]string, 0, len(parts)) + for _, part := range parts { + addr := strings.TrimSpace(part) + if addr != "" { + trimmed = append(trimmed, addr) + } + } + return trimmed +} + +func validateAddrList(addrs, field string) error { + parts := SplitAddrList(addrs) + if len(parts) == 0 { + return errors.Wrapf(ErrInvalidConfigValue, "%s is empty", field) + } + for _, addr := range parts { + if _, _, err := net.SplitHostPort(addr); err != nil { + return errors.Wrapf(ErrInvalidConfigValue, "invalid %s address %s", field, addr) + } + } + return nil +} + +func ParseNSServers(nsServers []string) ([]string, error) { + if len(nsServers) == 0 { + return nil, nil + } + normalized := make([]string, 0, len(nsServers)) + for _, server := range nsServers { + addr, err := normalizeNSServer(server) + if err != nil { + return nil, err + } + normalized = append(normalized, addr) + } + return normalized, nil +} + +func normalizeNSServer(server string) (string, error) { + host, port, err := net.SplitHostPort(server) + if err == nil { + if host == "" { + return "", errors.Wrapf(ErrInvalidConfigValue, "host is empty") + } + portNum, err := strconv.Atoi(port) + if err != nil || portNum < 1 || portNum > 65535 { + return "", errors.Wrapf(ErrInvalidConfigValue, "port is invalid") + } + return net.JoinHostPort(host, strconv.Itoa(portNum)), nil + } + + if server == "" { + return "", errors.Wrapf(ErrInvalidConfigValue, "host is empty") + } + if strings.ContainsAny(server, "[]") { + return "", errors.Wrapf(ErrInvalidConfigValue, "host is invalid") + } + return net.JoinHostPort(server, "53"), nil +} + +func (ps *ProxyServer) GetSQLAddrs() ([]string, error) { + addrs := SplitAddrList(ps.Addr) + if len(addrs) == 0 { + if len(ps.PortRange) == 0 { + return []string{""}, nil + } + return nil, errors.Wrapf(ErrInvalidConfigValue, "proxy.addr is empty") + } + if len(ps.PortRange) == 0 { + return addrs, nil + } + if len(ps.PortRange) != 2 { + return nil, errors.Wrapf(ErrInvalidConfigValue, "proxy.port-range must contain exactly two ports") + } + start, end := ps.PortRange[0], ps.PortRange[1] + if start < 1 || start > 65535 || end < 1 || end > 65535 || start > end { + return nil, errors.Wrapf(ErrInvalidConfigValue, "proxy.port-range is invalid") + } + if len(addrs) != 1 { + return nil, errors.Wrapf(ErrInvalidConfigValue, "proxy.addr must contain exactly one host when proxy.port-range is set") + } + host, _, err := net.SplitHostPort(addrs[0]) + if err != nil { + return nil, errors.Wrapf(ErrInvalidConfigValue, "invalid proxy.addr: %s", err.Error()) + } + sqlAddrs := make([]string, 0, end-start+1) + for port := start; port <= end; port++ { + sqlAddrs = append(sqlAddrs, net.JoinHostPort(host, strconv.Itoa(port))) + } + return sqlAddrs, nil +} +>>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) diff --git a/lib/config/proxy_test.go b/lib/config/proxy_test.go index e908af9e6..baa281bef 100644 --- a/lib/config/proxy_test.go +++ b/lib/config/proxy_test.go @@ -20,6 +20,7 @@ var testProxyConfig = Config{ Addr: "0.0.0.0:4000", PDAddrs: "127.0.0.1:4089", ProxyServerOnline: ProxyServerOnline{ +<<<<<<< HEAD MaxConnections: 1, FrontendKeepalive: KeepAlive{Enabled: true}, ProxyProtocol: "v2", @@ -27,6 +28,28 @@ var testProxyConfig = Config{ FailBackendList: []string{"db-tidb-0", "db-tidb-1"}, FailoverTimeout: 60, ConnBufferSize: 32 * 1024, +======= + MaxConnections: 1, + HighMemoryUsageRejectThreshold: 0.9, + 
FrontendKeepalive: KeepAlive{Enabled: true}, + ProxyProtocol: "v2", + GracefulWaitBeforeShutdown: 10, + FailBackendList: []string{"db-tidb-0", "db-tidb-1"}, + FailoverTimeout: 60, + ConnBufferSize: 32 * 1024, + BackendClusters: []BackendCluster{ + { + Name: "cluster-a", + PDAddrs: "127.0.0.1:12379,127.0.0.1:22379", + NSServers: []string{"10.0.0.2", "10.0.0.3"}, + }, + { + Name: "cluster-b", + PDAddrs: "127.0.0.1:32379", + NSServers: []string{"10.0.0.4"}, + }, + }, +>>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) }, }, API: API{ @@ -92,6 +115,26 @@ func TestProxyCheck(t *testing.T) { post func(*testing.T, *Config) err error }{ + { + pre: func(t *testing.T, c *Config) { + c.Proxy.HighMemoryUsageRejectThreshold = -0.1 + }, + err: ErrInvalidConfigValue, + }, + { + pre: func(t *testing.T, c *Config) { + c.Proxy.HighMemoryUsageRejectThreshold = 1.1 + }, + err: ErrInvalidConfigValue, + }, + { + pre: func(t *testing.T, c *Config) { + c.Proxy.HighMemoryUsageRejectThreshold = 0.4 + }, + post: func(t *testing.T, c *Config) { + require.Equal(t, 0.5, c.Proxy.HighMemoryUsageRejectThreshold) + }, + }, { pre: func(t *testing.T, c *Config) { c.Workdir = "" diff --git a/pkg/manager/memory/memory.go b/pkg/manager/memory/memory.go new file mode 100644 index 000000000..5c293e8f3 --- /dev/null +++ b/pkg/manager/memory/memory.go @@ -0,0 +1,256 @@ +// Copyright 2025 PingCAP, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package memory + +import ( + "context" + "os" + "path/filepath" + "runtime" + "runtime/pprof" + "sync/atomic" + "time" + + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tiproxy/lib/config" + "github.com/pingcap/tiproxy/pkg/util/waitgroup" + "go.uber.org/zap" +) + +const ( + // Refresh the memory usage every 5 seconds. + refreshInterval = 5 * time.Second + // No need to record too frequently. + recordMinInterval = 5 * time.Minute + // Record the profiles when the memory usage is higher than 60%. + alarmThreshold = 0.6 + // Remove the oldest profiles when the number of profiles exceeds this limit. + maxSavedProfiles = 20 + // Fail open if the latest sampled usage is too old. + snapshotExpireInterval = 3 * refreshInterval +) + +type UsageSnapshot struct { + Used uint64 + Limit uint64 + Usage float64 + UpdateTime time.Time + Valid bool +} + +// MemManager is a manager for memory usage. +// Although the continuous profiling collects profiles periodically, when TiProxy runs in the replayer mode, +// the profiles are not collected. +type MemManager struct { + lg *zap.Logger + cancel context.CancelFunc + wg waitgroup.WaitGroup + cfgGetter config.ConfigGetter + savedProfileNames []string + lastRecordTime time.Time + refreshInterval time.Duration // used for test + recordMinInterval time.Duration // used for test + maxSavedProfiles int // used for test + snapshotExpire time.Duration // used for test + memoryLimit uint64 + latestUsage atomic.Value + // connBufferMemDelta tracks the estimated buffer memory change since the latest refreshUsage. 
+ connBufferMemDelta atomic.Int64 +} + +func NewMemManager(lg *zap.Logger, cfgGetter config.ConfigGetter) *MemManager { + mgr := &MemManager{ + lg: lg, + cfgGetter: cfgGetter, + refreshInterval: refreshInterval, + recordMinInterval: recordMinInterval, + maxSavedProfiles: maxSavedProfiles, + snapshotExpire: snapshotExpireInterval, + } + mgr.latestUsage.Store(UsageSnapshot{}) + return mgr +} + +func (m *MemManager) Start(ctx context.Context) { + // Call the memory.MemTotal and memory.MemUsed in TiDB repo because they have considered cgroup. + limit, err := memory.MemTotal() + if err != nil || limit == 0 { + m.lg.Warn("get memory limit failed", zap.Uint64("limit", limit), zap.Error(err)) + return + } + m.memoryLimit = limit + if _, err = m.refreshUsage(); err != nil { + return + } + childCtx, cancel := context.WithCancel(ctx) + m.cancel = cancel + m.wg.RunWithRecover(func() { + m.refreshLoop(childCtx) + }, nil, m.lg) +} + +func (m *MemManager) refreshLoop(ctx context.Context) { + ticker := time.NewTicker(m.refreshInterval) + defer ticker.Stop() + for ctx.Err() == nil { + select { + case <-ctx.Done(): + return + case <-ticker.C: + m.refreshAndAlarm() + } + } +} + +func (m *MemManager) refreshAndAlarm() { + snapshot, err := m.refreshUsage() + if err != nil || !snapshot.Valid { + return + } + if snapshot.Usage < alarmThreshold { + return + } + if time.Since(m.lastRecordTime) < m.recordMinInterval { + return + } + // The filename is hot-reloadable. + cfg := m.cfgGetter.GetConfig() + if cfg == nil { + return + } + logPath := cfg.Log.LogFile.Filename + if logPath == "" { + return + } + recordDir := filepath.Dir(logPath) + + m.lastRecordTime = snapshot.UpdateTime + m.lg.Warn("memory usage alarm", zap.Uint64("limit", snapshot.Limit), zap.Uint64("used", snapshot.Used), zap.Float64("usage", snapshot.Usage)) + now := time.Now().Format(time.RFC3339) + m.recordHeap(filepath.Join(recordDir, "heap_"+now)) + m.recordGoroutine(filepath.Join(recordDir, "goroutine_"+now)) + m.rmExpiredProfiles() +} + +func (m *MemManager) refreshUsage() (UsageSnapshot, error) { + if m.memoryLimit == 0 { + return UsageSnapshot{}, nil + } + used, err := memory.MemUsed() + if err != nil || used == 0 { + m.lg.Warn("get used memory failed", zap.Uint64("used", used), zap.Error(err)) + return UsageSnapshot{}, err + } + // Start a new delta window from this sampled snapshot. Later connection create/close + // events only adjust the in-memory estimate relative to this refresh result. + m.connBufferMemDelta.Swap(0) + snapshot := UsageSnapshot{ + Used: used, + Limit: m.memoryLimit, + Usage: float64(used) / float64(m.memoryLimit), + UpdateTime: time.Now(), + Valid: true, + } + m.latestUsage.Store(snapshot) + return snapshot, nil +} + +func (m *MemManager) LatestUsage() UsageSnapshot { + snapshot, _ := m.latestUsage.Load().(UsageSnapshot) + return snapshot +} + +func (m *MemManager) UpdateConnBufferMemory(delta int64) { + if m == nil || delta == 0 { + return + } + m.connBufferMemDelta.Add(delta) +} + +// adjustUsageByConnBuffer applies the connection buffer delta accumulated after the +// latest refreshUsage, so ShouldRejectNewConn can react before the next memory sample. 
+func (m *MemManager) adjustUsageByConnBuffer(snapshot UsageSnapshot) UsageSnapshot { + delta := m.connBufferMemDelta.Load() + if delta == 0 { + return snapshot + } + if delta > 0 { + snapshot.Used += uint64(delta) + } else { + released := uint64(-delta) + if released >= snapshot.Used { + snapshot.Used = 0 + } else { + snapshot.Used -= released + } + } + if snapshot.Limit > 0 { + snapshot.Usage = float64(snapshot.Used) / float64(snapshot.Limit) + } + return snapshot +} + +func (m *MemManager) ShouldRejectNewConn() (bool, UsageSnapshot, float64) { + if m == nil || m.cfgGetter == nil { + return false, UsageSnapshot{}, 0 + } + cfg := m.cfgGetter.GetConfig() + if cfg == nil { + return false, UsageSnapshot{}, 0 + } + threshold := cfg.Proxy.HighMemoryUsageRejectThreshold + if threshold == 0 { + return false, UsageSnapshot{}, 0 + } + snapshot := m.LatestUsage() + if !snapshot.Valid || time.Since(snapshot.UpdateTime) > m.snapshotExpire { + return false, snapshot, threshold + } + snapshot = m.adjustUsageByConnBuffer(snapshot) + return snapshot.Usage >= threshold, snapshot, threshold +} + +func (m *MemManager) recordHeap(fileName string) { + f, err := os.Create(fileName) + if err != nil { + m.lg.Error("failed to create heap profile file", zap.Error(err)) + return + } + defer f.Close() + p := pprof.Lookup("heap") + if err = p.WriteTo(f, 0); err != nil { + m.lg.Error("failed to write heap profile file", zap.Error(err)) + } + m.savedProfileNames = append(m.savedProfileNames, fileName) +} + +func (m *MemManager) recordGoroutine(fileName string) { + buf := make([]byte, 1<<26) // 64MB buffer + n := runtime.Stack(buf, true) + if n >= len(buf) { + m.lg.Warn("goroutine stack trace is too large, truncating", zap.Int("size", n)) + } + //nolint: gosec + if err := os.WriteFile(fileName, buf[:n], 0644); err != nil { + m.lg.Error("failed to write goroutine profile file", zap.Error(err)) + } + m.savedProfileNames = append(m.savedProfileNames, fileName) +} + +func (m *MemManager) rmExpiredProfiles() { + for len(m.savedProfileNames) > m.maxSavedProfiles { + if err := os.Remove(m.savedProfileNames[0]); err != nil { + m.lg.Warn("failed to remove expired profile file", zap.String("file", m.savedProfileNames[0]), zap.Error(err)) + } + copy(m.savedProfileNames[0:], m.savedProfileNames[1:]) + m.savedProfileNames = m.savedProfileNames[:len(m.savedProfileNames)-1] + } +} + +func (m *MemManager) Close() { + if m.cancel != nil { + m.cancel() + } + m.wg.Wait() +} diff --git a/pkg/manager/memory/memory_test.go b/pkg/manager/memory/memory_test.go new file mode 100644 index 000000000..dec324ac3 --- /dev/null +++ b/pkg/manager/memory/memory_test.go @@ -0,0 +1,181 @@ +// Copyright 2025 PingCAP, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package memory + +import ( + "context" + "os" + "path" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tiproxy/lib/config" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +type mockCfgGetter struct { + cfg *config.Config +} + +func (c *mockCfgGetter) GetConfig() *config.Config { + return c.cfg +} + +func TestRecordProfile(t *testing.T) { + oldMemUsed, oldMemTotal := memory.MemUsed, memory.MemTotal + defer func() { + memory.MemUsed = oldMemUsed + memory.MemTotal = oldMemTotal + }() + + dir := t.TempDir() + cfg := &config.Config{} + cfg.Log.LogFile.Filename = path.Join(dir, "proxy.log") + cfgGetter := mockCfgGetter{cfg: cfg} + memory.MemUsed = func() (uint64, error) { + return 9 * (1 << 30), nil + } + memory.MemTotal = func() (uint64, error) { + return 10 * (1 << 30), nil + } + m := NewMemManager(zap.NewNop(), &cfgGetter) + // The timestamp in file names are in seconds instead of milliseconds, so recording too frequently is useless. + // Instead, it may overwrite the previous files. + m.refreshInterval = 100 * time.Millisecond + m.recordMinInterval = 1200 * time.Millisecond + m.maxSavedProfiles = 2 + m.Start(context.Background()) + + // The profiles are recorded. + require.Eventually(t, func() bool { + entries, err := os.ReadDir(dir) + require.NoError(t, err) + prefixes := []string{"heap_", "goroutine_"} + for _, entry := range entries { + if entry.IsDir() { + continue + } + for i, prefix := range prefixes { + if strings.HasPrefix(entry.Name(), prefix) { + info, err := os.Stat(path.Join(dir, entry.Name())) + require.NoError(t, err) + if info.Size() == 0 { + return false + } + prefixes = append(prefixes[:i], prefixes[i+1:]...) + break + } + } + } + return len(prefixes) == 0 + }, 3*time.Second, 100*time.Millisecond) + + // The expired profiles are removed. 
+ time.Sleep(2 * time.Second) + m.Close() + entries, err := os.ReadDir(dir) + require.NoError(t, err) + require.Len(t, entries, m.maxSavedProfiles) +} + +func TestShouldRejectNewConn(t *testing.T) { + oldMemUsed, oldMemTotal := memory.MemUsed, memory.MemTotal + defer func() { + memory.MemUsed = oldMemUsed + memory.MemTotal = oldMemTotal + }() + + cfg := config.NewConfig() + cfg.Proxy.HighMemoryUsageRejectThreshold = 0.9 + cfgGetter := mockCfgGetter{cfg: cfg} + memory.MemUsed = func() (uint64, error) { + return 9 * (1 << 30), nil + } + memory.MemTotal = func() (uint64, error) { + return 10 * (1 << 30), nil + } + m := NewMemManager(zap.NewNop(), &cfgGetter) + m.refreshInterval = 50 * time.Millisecond + m.snapshotExpire = 200 * time.Millisecond + m.Start(context.Background()) + defer m.Close() + + require.Eventually(t, func() bool { + reject, snapshot, threshold := m.ShouldRejectNewConn() + return reject && snapshot.Valid && threshold == 0.9 + }, time.Second, 10*time.Millisecond) + m.Close() + + cfg.Proxy.HighMemoryUsageRejectThreshold = 0 + reject, _, threshold := m.ShouldRejectNewConn() + require.False(t, reject) + require.Zero(t, threshold) + + staleSnapshot := m.LatestUsage() + staleSnapshot.UpdateTime = time.Now().Add(-m.snapshotExpire - time.Second) + m.latestUsage.Store(staleSnapshot) + cfg.Proxy.HighMemoryUsageRejectThreshold = 0.9 + reject, _, threshold = m.ShouldRejectNewConn() + require.False(t, reject) + require.Equal(t, 0.9, threshold) +} + +func TestShouldRejectNewConnTracksConnBufferMemory(t *testing.T) { + oldMemUsed, oldMemTotal := memory.MemUsed, memory.MemTotal + defer func() { + memory.MemUsed = oldMemUsed + memory.MemTotal = oldMemTotal + }() + + cfg := config.NewConfig() + cfg.Proxy.HighMemoryUsageRejectThreshold = 0.9 + cfgGetter := mockCfgGetter{cfg: cfg} + var used atomic.Uint64 + used.Store(890) + memory.MemUsed = func() (uint64, error) { + return used.Load(), nil + } + memory.MemTotal = func() (uint64, error) { + return 1000, nil + } + m := NewMemManager(zap.NewNop(), &cfgGetter) + m.refreshInterval = 50 * time.Millisecond + m.snapshotExpire = time.Second + m.Start(context.Background()) + defer m.Close() + + require.Eventually(t, func() bool { + reject, snapshot, threshold := m.ShouldRejectNewConn() + return !reject && snapshot.Valid && threshold == 0.9 && snapshot.Used == 890 + }, time.Second, 10*time.Millisecond) + + m.UpdateConnBufferMemory(20) + reject, snapshot, threshold := m.ShouldRejectNewConn() + require.True(t, reject) + require.Equal(t, 0.9, threshold) + require.Equal(t, uint64(910), snapshot.Used) + require.InDelta(t, 0.91, snapshot.Usage, 0.0001) + + used.Store(910) + snapshot, err := m.refreshUsage() + require.NoError(t, err) + require.Equal(t, uint64(910), snapshot.Used) + require.InDelta(t, 0.91, snapshot.Usage, 0.0001) + reject, snapshot, threshold = m.ShouldRejectNewConn() + require.True(t, reject) + require.Equal(t, 0.9, threshold) + require.Equal(t, uint64(910), snapshot.Used) + require.InDelta(t, 0.91, snapshot.Usage, 0.0001) + + m.UpdateConnBufferMemory(-20) + reject, snapshot, threshold = m.ShouldRejectNewConn() + require.False(t, reject) + require.Equal(t, 0.9, threshold) + require.Equal(t, uint64(890), snapshot.Used) + require.InDelta(t, 0.89, snapshot.Usage, 0.0001) +} diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index b16818598..74703154d 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -98,6 +98,7 @@ func init() { colls = []prometheus.Collector{ ConnGauge, CreateConnCounter, + RejectConnCounter, 
DisConnCounter, MaxProcsGauge, OwnerGauge, diff --git a/pkg/metrics/server.go b/pkg/metrics/server.go index b569a9910..53a288216 100644 --- a/pkg/metrics/server.go +++ b/pkg/metrics/server.go @@ -34,6 +34,14 @@ var ( Help: "Number of create connections.", }) + RejectConnCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: ModuleProxy, + Subsystem: LabelServer, + Name: "reject_connection_total", + Help: "Number of rejected connections.", + }, []string{LblType}) + DisConnCounter = prometheus.NewCounterVec( prometheus.CounterOpts{ Namespace: ModuleProxy, diff --git a/pkg/proxy/backend/cmd_processor_test.go b/pkg/proxy/backend/cmd_processor_test.go index fbf32551d..25ff34806 100644 --- a/pkg/proxy/backend/cmd_processor_test.go +++ b/pkg/proxy/backend/cmd_processor_test.go @@ -1033,7 +1033,8 @@ func TestNetworkError(t *testing.T) { } backendErrChecker := func(t *testing.T, ts *testSuite) { require.True(t, pnet.IsDisconnectError(ts.mp.err)) - require.True(t, pnet.IsDisconnectError(ts.mb.err)) + // The backend mock may finish writing the error packet before the proxy actively closes the backend side. + require.True(t, ts.mb.err == nil || pnet.IsDisconnectError(ts.mb.err)) } proxyErrChecker := func(t *testing.T, ts *testSuite) { require.True(t, pnet.IsDisconnectError(ts.mp.err)) diff --git a/pkg/proxy/net/packetio.go b/pkg/proxy/net/packetio.go index 8e27b2755..e74abd403 100644 --- a/pkg/proxy/net/packetio.go +++ b/pkg/proxy/net/packetio.go @@ -239,6 +239,9 @@ type packetIO struct { } func NewPacketIO(conn net.Conn, lg *zap.Logger, bufferSize int, opts ...PacketIOption) *packetIO { + if bufferSize == 0 { + bufferSize = DefaultConnBufferSize + } p := &packetIO{ rawConn: conn, logger: lg, diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index 4e3215ffa..e12b4dd12 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -15,6 +15,7 @@ import ( "github.com/pingcap/tiproxy/lib/util/waitgroup" "github.com/pingcap/tiproxy/pkg/manager/cert" "github.com/pingcap/tiproxy/pkg/manager/id" + mgrmem "github.com/pingcap/tiproxy/pkg/manager/memory" "github.com/pingcap/tiproxy/pkg/metrics" "github.com/pingcap/tiproxy/pkg/proxy/backend" "github.com/pingcap/tiproxy/pkg/proxy/client" @@ -44,6 +45,7 @@ type SQLServer struct { logger *zap.Logger certMgr *cert.CertManager idMgr *id.IDManager + memUsage memoryStateProvider hsHandler backend.HandshakeHandler cpt capture.Capture wg waitgroup.WaitGroup @@ -52,13 +54,35 @@ type SQLServer struct { mu serverState } +type memoryStateProvider interface { + ShouldRejectNewConn() (bool, mgrmem.UsageSnapshot, float64) +} + +type connBufferMemoryUpdater interface { + UpdateConnBufferMemory(delta int64) +} + +func estimateConnBufferMemDelta(bufferSize int) int64 { + if bufferSize == 0 { + bufferSize = pnet.DefaultConnBufferSize + } + // write buffer + read buffer + return int64(bufferSize * 2) +} + // NewSQLServer creates a new SQLServer. 
+<<<<<<< HEAD func NewSQLServer(logger *zap.Logger, cfg *config.Config, certMgr *cert.CertManager, idMgr *id.IDManager, cpt capture.Capture, hsHandler backend.HandshakeHandler) (*SQLServer, error) { +======= +func NewSQLServer(logger *zap.Logger, cfg *config.Config, certMgr *cert.CertManager, idMgr *id.IDManager, cpt capture.Capture, + meter backend.Meter, hsHandler backend.HandshakeHandler, memUsage memoryStateProvider) (*SQLServer, error) { +>>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) var err error s := &SQLServer{ logger: logger, certMgr: certMgr, idMgr: idMgr, + memUsage: memUsage, hsHandler: hsHandler, cpt: cpt, mu: serverState{ @@ -139,6 +163,18 @@ func (s *SQLServer) Run(ctx context.Context, cfgch <-chan *config.Config) { } func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) { + if s.rejectConnByMemory(conn) { + return + } + + var ( + connBufferUpdater connBufferMemoryUpdater + connBufferMemDelta int64 + ) + if s.memUsage != nil { + connBufferUpdater, _ = s.memUsage.(connBufferMemoryUpdater) + } + tcpKeepAlive, logger, connID, clientConn := func() (bool, *zap.Logger, uint64, *client.ClientConnection) { s.mu.Lock() defer s.mu.Unlock() @@ -147,7 +183,12 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) { maxConns := s.mu.maxConnections // 'maxConns == 0' => unlimited connections if maxConns != 0 && conns >= maxConns { +<<<<<<< HEAD s.logger.Warn("too many connections", zap.Uint64("max connections", maxConns), zap.String("client_addr", conn.RemoteAddr().Network()), zap.Error(conn.Close())) +======= + metrics.RejectConnCounter.WithLabelValues("max_connections").Inc() + s.logger.Warn("too many connections", zap.Uint64("max connections", maxConns), zap.Stringer("client_addr", conn.RemoteAddr()), zap.Error(conn.Close())) +>>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) return false, nil, 0, nil } @@ -163,6 +204,10 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) { ConnBufferSize: s.mu.connBufferSize, }) s.mu.clients[connID] = clientConn + connBufferMemDelta = estimateConnBufferMemDelta(s.mu.connBufferSize) + if connBufferUpdater != nil { + connBufferUpdater.UpdateConnBufferMemory(connBufferMemDelta) + } logger.Debug("new connection", zap.Bool("proxy-protocol", s.mu.proxyProtocol), zap.Bool("require_backend_tls", s.mu.requireBackendTLS)) return s.mu.tcpKeepAlive, logger, connID, clientConn }() @@ -178,6 +223,9 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) { s.mu.Lock() delete(s.mu.clients, connID) s.mu.Unlock() + if connBufferUpdater != nil { + connBufferUpdater.UpdateConnBufferMemory(-connBufferMemDelta) + } if err := clientConn.Close(); err != nil && !pnet.IsDisconnectError(err) { logger.Error("close connection fails", zap.Error(err)) @@ -194,6 +242,54 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) { clientConn.Run(ctx) } +<<<<<<< HEAD +======= +func (s *SQLServer) rejectConnByMemory(conn net.Conn) bool { + if s.memUsage == nil { + return false + } + reject, snapshot, threshold := s.memUsage.ShouldRejectNewConn() + if !reject { + return false + } + metrics.RejectConnCounter.WithLabelValues("memory").Inc() + s.logger.Warn("reject connection due to high memory usage", + zap.Stringer("client_addr", conn.RemoteAddr()), + zap.Float64("threshold", threshold), + zap.Float64("usage", snapshot.Usage), + zap.Uint64("used", snapshot.Used), + zap.Uint64("limit", 
snapshot.Limit), + zap.Time("last_update", snapshot.UpdateTime), + zap.Error(conn.Close())) + return true +} + +func (s *SQLServer) fromPublicEndpoint(addr net.Addr) bool { + if addr == nil || reflect.ValueOf(addr).IsNil() { + return false + } + s.mu.RLock() + publicEndpoints := s.mu.publicEndpoints + s.mu.RUnlock() + ip, err := netutil.NetAddr2IP(addr) + if err != nil { + s.logger.Warn("failed to check public endpoint", zap.Any("addr", addr), zap.Error(err)) + return false + } + contains, err := netutil.CIDRContainsIP(publicEndpoints, ip) + if err != nil { + s.logger.Warn("failed to check public endpoint", zap.Any("ip", ip), zap.Error(err)) + return false + } + if contains { + return true + } + // The public NLB may enable preserveIP, and the incoming address is the client address, which may be a public address. + // Even if the private NLB enables preserveIP, the client address is still a private address. + return !netutil.IsPrivate(ip) +} + +>>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) func (s *SQLServer) PreClose() { // Step 1: HTTP status returns unhealthy so that NLB takes this instance offline and then new connections won't come. s.mu.Lock() diff --git a/pkg/proxy/proxy_test.go b/pkg/proxy/proxy_test.go index d0d3711e4..b78b3c49e 100644 --- a/pkg/proxy/proxy_test.go +++ b/pkg/proxy/proxy_test.go @@ -9,6 +9,11 @@ import ( "fmt" "net" "strings" +<<<<<<< HEAD +======= + "sync" + "sync/atomic" +>>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) "testing" "time" @@ -20,6 +25,7 @@ import ( "github.com/pingcap/tiproxy/pkg/balance/router" "github.com/pingcap/tiproxy/pkg/manager/cert" "github.com/pingcap/tiproxy/pkg/manager/id" + mgrmem "github.com/pingcap/tiproxy/pkg/manager/memory" "github.com/pingcap/tiproxy/pkg/metrics" "github.com/pingcap/tiproxy/pkg/proxy/backend" "github.com/pingcap/tiproxy/pkg/proxy/client" @@ -32,7 +38,11 @@ func TestCreateConn(t *testing.T) { cfg := &config.Config{} certManager := cert.NewCertManager() require.NoError(t, certManager.Init(cfg, lg, nil)) +<<<<<<< HEAD server, err := NewSQLServer(lg, cfg, certManager, id.NewIDManager(), nil, &mockHsHandler{}) +======= + server, err := NewSQLServer(lg, cfg, certManager, id.NewIDManager(), nil, nil, &mockHsHandler{}, nil) +>>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) require.NoError(t, err) server.Run(context.Background(), nil) defer func() { @@ -69,6 +79,92 @@ func TestCreateConn(t *testing.T) { checkMetrics(0, 2) } +func TestRejectConnByMemory(t *testing.T) { + lg, _ := logger.CreateLoggerForTest(t) + certManager := cert.NewCertManager() + require.NoError(t, certManager.Init(&config.Config{}, lg, nil)) + server, err := NewSQLServer(lg, &config.Config{}, certManager, id.NewIDManager(), nil, nil, &mockHsHandler{}, &mockMemUsageProvider{ + reject: true, + snapshot: mgrmem.UsageSnapshot{ + Used: 9 * (1 << 30), + Limit: 10 * (1 << 30), + Usage: 0.9, + UpdateTime: time.Now(), + Valid: true, + }, + threshold: 0.9, + }) + require.NoError(t, err) + server.Run(context.Background(), nil) + defer func() { + server.PreClose() + require.NoError(t, server.Close()) + certManager.Close() + }() + + rejectBefore, err := metrics.ReadCounter(metrics.RejectConnCounter.WithLabelValues("memory")) + require.NoError(t, err) + createBefore, err := metrics.ReadCounter(metrics.CreateConnCounter) + require.NoError(t, err) + connGaugeBefore, err := metrics.ReadGauge(metrics.ConnGauge) + require.NoError(t, err) + + conn, err := 
net.Dial("tcp", server.listeners[0].Addr().String()) + require.NoError(t, err) + defer func() { _ = conn.Close() }() + + require.Eventually(t, func() bool { + _ = conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond)) + var buf [1]byte + _, err := conn.Read(buf[:]) + return err != nil + }, time.Second, 10*time.Millisecond) + + require.Eventually(t, func() bool { + rejectAfter, err := metrics.ReadCounter(metrics.RejectConnCounter.WithLabelValues("memory")) + require.NoError(t, err) + createAfter, err := metrics.ReadCounter(metrics.CreateConnCounter) + require.NoError(t, err) + connGaugeAfter, err := metrics.ReadGauge(metrics.ConnGauge) + require.NoError(t, err) + return rejectAfter == rejectBefore+1 && createAfter == createBefore && connGaugeAfter == connGaugeBefore + }, time.Second, 10*time.Millisecond) +} + +func TestTrackConnBufferMemDelta(t *testing.T) { + lg, _ := logger.CreateLoggerForTest(t) + certManager := cert.NewCertManager() + cfg := &config.Config{ + Proxy: config.ProxyServer{ + ProxyServerOnline: config.ProxyServerOnline{ + ConnBufferSize: 4096, + }, + }, + } + require.NoError(t, certManager.Init(cfg, lg, nil)) + memUsage := &mockMemUsageProvider{} + server, err := NewSQLServer(lg, cfg, certManager, id.NewIDManager(), nil, nil, &mockHsHandler{}, memUsage) + require.NoError(t, err) + server.Run(context.Background(), nil) + defer func() { + server.PreClose() + require.NoError(t, server.Close()) + certManager.Close() + }() + + conn, err := net.Dial("tcp", server.listeners[0].Addr().String()) + require.NoError(t, err) + + require.Eventually(t, func() bool { + return memUsage.connBufferMemDelta.Load() == int64(cfg.Proxy.ConnBufferSize*2) + }, time.Second, 10*time.Millisecond) + + require.NoError(t, conn.Close()) + require.Eventually(t, func() bool { + return memUsage.connBufferMemDelta.Load() == 0 + }, time.Second, 10*time.Millisecond) +} + func TestGracefulCloseConn(t *testing.T) { // Graceful shutdown finishes immediately if there's no connection. lg, _ := logger.CreateLoggerForTest(t) @@ -80,7 +176,11 @@ func TestGracefulCloseConn(t *testing.T) { }, }, } +<<<<<<< HEAD server, err := NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, hsHandler) +======= + server, err := NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, nil, hsHandler, nil) +>>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) require.NoError(t, err) finish := make(chan struct{}) go func() { @@ -110,7 +210,11 @@ func TestGracefulCloseConn(t *testing.T) { } // Graceful shutdown will be blocked if there are alive connections. +<<<<<<< HEAD server, err = NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, hsHandler) +======= + server, err = NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, nil, hsHandler, nil) +>>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) require.NoError(t, err) clientConn := createClientConn() go func() { @@ -136,7 +240,11 @@ func TestGracefulCloseConn(t *testing.T) { // Graceful shutdown will shut down after GracefulCloseConnTimeout. 
cfg.Proxy.GracefulCloseConnTimeout = 1 +<<<<<<< HEAD server, err = NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, hsHandler) +======= + server, err = NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, nil, hsHandler, nil) +>>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) require.NoError(t, err) createClientConn() go func() { @@ -164,7 +272,11 @@ func TestGracefulShutDown(t *testing.T) { }, }, } +<<<<<<< HEAD server, err := NewSQLServer(lg, cfg, certManager, id.NewIDManager(), nil, &mockHsHandler{}) +======= + server, err := NewSQLServer(lg, cfg, certManager, id.NewIDManager(), nil, nil, &mockHsHandler{}, nil) +>>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) require.NoError(t, err) server.Run(context.Background(), nil) @@ -202,7 +314,11 @@ func TestMultiAddr(t *testing.T) { Proxy: config.ProxyServer{ Addr: "0.0.0.0:0,0.0.0.0:0", }, +<<<<<<< HEAD }, certManager, id.NewIDManager(), nil, &mockHsHandler{}) +======= + }, certManager, id.NewIDManager(), nil, nil, &mockHsHandler{}, nil) +>>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) require.NoError(t, err) server.Run(context.Background(), nil) @@ -218,11 +334,96 @@ func TestMultiAddr(t *testing.T) { certManager.Close() } +<<<<<<< HEAD +======= +func TestPortRange(t *testing.T) { + lg, _ := logger.CreateLoggerForTest(t) + certManager := cert.NewCertManager() + err := certManager.Init(&config.Config{}, lg, nil) + require.NoError(t, err) + start, end := findFreePortRange(t, 3) + server, err := NewSQLServer(lg, &config.Config{ + Proxy: config.ProxyServer{ + Addr: fmt.Sprintf("127.0.0.1:%d", start), + PortRange: []int{start, end}, + }, + }, certManager, id.NewIDManager(), nil, nil, &mockHsHandler{}, nil) + require.NoError(t, err) + server.Run(context.Background(), nil) + + require.Len(t, server.listeners, 3) + ports := make([]int, 0, len(server.listeners)) + for _, listener := range server.listeners { + tcpAddr, ok := listener.Addr().(*net.TCPAddr) + require.True(t, ok) + ports = append(ports, tcpAddr.Port) + + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + require.NoError(t, conn.Close()) + } + slices.Sort(ports) + require.Equal(t, []int{start, start + 1, end}, ports) + + server.PreClose() + require.NoError(t, server.Close()) + certManager.Close() +} + +func TestConnAddrUsesActualListenerAddr(t *testing.T) { + lg, _ := logger.CreateLoggerForTest(t) + certManager := cert.NewCertManager() + require.NoError(t, certManager.Init(&config.Config{}, lg, nil)) + + var ( + addrMu sync.Mutex + connAddr string + ) + handler := &mockHsHandler{ + getRouter: func(ctx backend.ConnContext, _ *pnet.HandshakeResp) (router.Router, error) { + addrMu.Lock() + connAddr, _ = ctx.Value(backend.ConnContextKeyConnAddr).(string) + addrMu.Unlock() + return nil, errors.New("no router") + }, + } + server, err := NewSQLServer(lg, &config.Config{ + Proxy: config.ProxyServer{ + Addr: "127.0.0.1:0", + }, + }, certManager, id.NewIDManager(), nil, nil, handler, nil) + require.NoError(t, err) + server.Run(context.Background(), nil) + defer func() { + server.PreClose() + require.NoError(t, server.Close()) + certManager.Close() + }() + + _, port, err := net.SplitHostPort(server.listeners[0].Addr().String()) + require.NoError(t, err) + mdb, err := sql.Open("mysql", fmt.Sprintf("root@tcp(127.0.0.1:%s)/test", port)) + require.NoError(t, err) + defer func() { require.NoError(t, mdb.Close()) }() + + require.ErrorContains(t, mdb.Ping(), 
"no router") + require.Eventually(t, func() bool { + addrMu.Lock() + defer addrMu.Unlock() + return connAddr == server.listeners[0].Addr().String() + }, 3*time.Second, 10*time.Millisecond) +} + +>>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) func TestWatchCfg(t *testing.T) { lg, _ := logger.CreateLoggerForTest(t) hsHandler := backend.NewDefaultHandshakeHandler(nil) cfgch := make(chan *config.Config) +<<<<<<< HEAD server, err := NewSQLServer(lg, &config.Config{}, nil, id.NewIDManager(), nil, hsHandler) +======= + server, err := NewSQLServer(lg, &config.Config{}, nil, id.NewIDManager(), nil, nil, hsHandler, nil) +>>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) require.NoError(t, err) server.Run(context.Background(), cfgch) cfg := &config.Config{ @@ -264,7 +465,7 @@ func TestRecoverPanic(t *testing.T) { } return nil }, - }) + }, nil) require.NoError(t, err) server.Run(context.Background(), nil) @@ -285,11 +486,67 @@ func TestRecoverPanic(t *testing.T) { certManager.Close() } +<<<<<<< HEAD +======= +func TestPublicEndpoint(t *testing.T) { + tests := []struct { + publicEndpoints []string + publicIps []string + privateIps []string + }{ + { + publicIps: []string{"137.84.2.178"}, + privateIps: []string{"10.10.10.10"}, + }, + { + publicEndpoints: []string{"10.10.10.0/24"}, + publicIps: []string{"137.84.2.178", "10.10.10.10"}, + privateIps: []string{"10.10.20.10"}, + }, + { + publicEndpoints: []string{"10.10.10.0/24", "10.10.20.10"}, + publicIps: []string{"137.84.2.178", "10.10.10.10", "10.10.20.10"}, + privateIps: []string{"10.10.20.11"}, + }, + } + + server, err := NewSQLServer(zap.NewNop(), &config.Config{}, nil, id.NewIDManager(), nil, nil, backend.NewDefaultHandshakeHandler(nil), nil) + require.NoError(t, err) + for i, test := range tests { + cfg := &config.Config{} + cfg.Proxy.PublicEndpoints = test.publicEndpoints + server.reset(cfg) + for j, ip := range test.publicIps { + require.True(t, server.fromPublicEndpoint(&net.TCPAddr{IP: net.ParseIP(ip), Port: 1000}), "test %d %d", i, j) + } + for j, ip := range test.privateIps { + require.False(t, server.fromPublicEndpoint(&net.TCPAddr{IP: net.ParseIP(ip), Port: 1000}), "test %d %d", i, j) + } + require.False(t, server.fromPublicEndpoint(nil)) + } +} + +>>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) type mockHsHandler struct { backend.DefaultHandshakeHandler handshakeResp func(ctx backend.ConnContext, _ *pnet.HandshakeResp) error } +type mockMemUsageProvider struct { + reject bool + snapshot mgrmem.UsageSnapshot + threshold float64 + connBufferMemDelta atomic.Int64 +} + +func (m *mockMemUsageProvider) ShouldRejectNewConn() (bool, mgrmem.UsageSnapshot, float64) { + return m.reject, m.snapshot, m.threshold +} + +func (m *mockMemUsageProvider) UpdateConnBufferMemory(delta int64) { + m.connBufferMemDelta.Add(delta) +} + // HandleHandshakeResp only panics for the first connections. 
func (handler *mockHsHandler) HandleHandshakeResp(ctx backend.ConnContext, resp *pnet.HandshakeResp) error { if handler.handshakeResp != nil { diff --git a/pkg/server/server.go b/pkg/server/server.go index b700e5329..fefd2ff7b 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -171,7 +171,11 @@ func NewServer(ctx context.Context, sctx *sctx.Context) (srv *Server, err error) // setup proxy server { +<<<<<<< HEAD srv.proxy, err = proxy.NewSQLServer(lg.Named("proxy"), cfg, srv.certManager, idMgr, srv.replay.GetCapture(), hsHandler) +======= + srv.proxy, err = proxy.NewSQLServer(lg.Named("proxy"), cfg, srv.certManager, idMgr, srv.replay.GetCapture(), srv.meter, hsHandler, srv.memManager) +>>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) if err != nil { return } From 6fbb0b1a694897ef785de4f6ee25cb5b7a09359a Mon Sep 17 00:00:00 2001 From: djshow832 Date: Tue, 21 Apr 2026 16:42:48 +0800 Subject: [PATCH 2/3] fix some conflicts --- lib/config/proxy.go | 161 ------------------------------------ lib/config/proxy_test.go | 22 ----- pkg/proxy/proxy.go | 38 +-------- pkg/proxy/proxy_test.go | 172 ++------------------------------------- pkg/server/server.go | 6 +- 5 files changed, 11 insertions(+), 388 deletions(-) diff --git a/lib/config/proxy.go b/lib/config/proxy.go index c034ecf25..168aa9304 100644 --- a/lib/config/proxy.go +++ b/lib/config/proxy.go @@ -274,164 +274,3 @@ func (cfg *Config) GetIPPort() (ip, port, statusPort string, err error) { } return } -<<<<<<< HEAD -======= - -// GetBackendClusters returns configured backend clusters. -// It keeps backward compatibility for the legacy `proxy.pd-addrs` setting. -func (cfg *Config) GetBackendClusters() []BackendCluster { - if len(cfg.Proxy.BackendClusters) > 0 { - return cfg.Proxy.BackendClusters - } - if strings.TrimSpace(cfg.Proxy.PDAddrs) == "" { - return nil - } - return []BackendCluster{{ - Name: DefaultBackendClusterName, - PDAddrs: cfg.Proxy.PDAddrs, - }} -} - -func (ps *ProxyServer) Check() error { - if ps.HighMemoryUsageRejectThreshold < 0 || ps.HighMemoryUsageRejectThreshold > 1 { - return errors.Wrapf(ErrInvalidConfigValue, "invalid proxy.high-memory-usage-reject-threshold") - } - if ps.HighMemoryUsageRejectThreshold > 0 && ps.HighMemoryUsageRejectThreshold < 0.5 { - ps.HighMemoryUsageRejectThreshold = 0.5 - } - if _, err := ps.GetSQLAddrs(); err != nil { - return errors.Wrapf(ErrInvalidConfigValue, "invalid proxy.addr or proxy.port-range: %s", err.Error()) - } - clusterNames := make(map[string]struct{}, len(ps.BackendClusters)) - for i, cluster := range ps.BackendClusters { - name := strings.TrimSpace(cluster.Name) - if name == "" { - return errors.Wrapf(ErrInvalidConfigValue, "proxy.backend-clusters[%d].name is empty", i) - } - if _, ok := clusterNames[name]; ok { - return errors.Wrapf(ErrInvalidConfigValue, "duplicate proxy.backend-clusters name %s", name) - } - clusterNames[name] = struct{}{} - if err := validateAddrList(cluster.PDAddrs, "proxy.backend-clusters.pd-addrs"); err != nil { - return err - } - if _, err := ParseNSServers(cluster.NSServers); err != nil { - return errors.Wrapf(ErrInvalidConfigValue, "invalid proxy.backend-clusters.ns-servers: %s", err.Error()) - } - } - - if ps.FailoverTimeout < 0 { - return errors.Wrapf(ErrInvalidConfigValue, "proxy.failover-timeout must be greater than or equal to 0") - } - failBackends := ps.FailBackendList[:0] - failBackendSet := make(map[string]struct{}, len(ps.FailBackendList)) - for i, backendName := range ps.FailBackendList { - 
backendName = strings.TrimSpace(backendName) - if backendName == "" { - return errors.Wrapf(ErrInvalidConfigValue, "proxy.fail-backend-list[%d] is empty", i) - } - if _, ok := failBackendSet[backendName]; ok { - continue - } - failBackendSet[backendName] = struct{}{} - failBackends = append(failBackends, backendName) - } - ps.FailBackendList = failBackends - return nil -} - -// SplitAddrList splits a comma-separated address list, trims each address, and drops empty entries. -func SplitAddrList(addrs string) []string { - parts := strings.Split(addrs, ",") - trimmed := make([]string, 0, len(parts)) - for _, part := range parts { - addr := strings.TrimSpace(part) - if addr != "" { - trimmed = append(trimmed, addr) - } - } - return trimmed -} - -func validateAddrList(addrs, field string) error { - parts := SplitAddrList(addrs) - if len(parts) == 0 { - return errors.Wrapf(ErrInvalidConfigValue, "%s is empty", field) - } - for _, addr := range parts { - if _, _, err := net.SplitHostPort(addr); err != nil { - return errors.Wrapf(ErrInvalidConfigValue, "invalid %s address %s", field, addr) - } - } - return nil -} - -func ParseNSServers(nsServers []string) ([]string, error) { - if len(nsServers) == 0 { - return nil, nil - } - normalized := make([]string, 0, len(nsServers)) - for _, server := range nsServers { - addr, err := normalizeNSServer(server) - if err != nil { - return nil, err - } - normalized = append(normalized, addr) - } - return normalized, nil -} - -func normalizeNSServer(server string) (string, error) { - host, port, err := net.SplitHostPort(server) - if err == nil { - if host == "" { - return "", errors.Wrapf(ErrInvalidConfigValue, "host is empty") - } - portNum, err := strconv.Atoi(port) - if err != nil || portNum < 1 || portNum > 65535 { - return "", errors.Wrapf(ErrInvalidConfigValue, "port is invalid") - } - return net.JoinHostPort(host, strconv.Itoa(portNum)), nil - } - - if server == "" { - return "", errors.Wrapf(ErrInvalidConfigValue, "host is empty") - } - if strings.ContainsAny(server, "[]") { - return "", errors.Wrapf(ErrInvalidConfigValue, "host is invalid") - } - return net.JoinHostPort(server, "53"), nil -} - -func (ps *ProxyServer) GetSQLAddrs() ([]string, error) { - addrs := SplitAddrList(ps.Addr) - if len(addrs) == 0 { - if len(ps.PortRange) == 0 { - return []string{""}, nil - } - return nil, errors.Wrapf(ErrInvalidConfigValue, "proxy.addr is empty") - } - if len(ps.PortRange) == 0 { - return addrs, nil - } - if len(ps.PortRange) != 2 { - return nil, errors.Wrapf(ErrInvalidConfigValue, "proxy.port-range must contain exactly two ports") - } - start, end := ps.PortRange[0], ps.PortRange[1] - if start < 1 || start > 65535 || end < 1 || end > 65535 || start > end { - return nil, errors.Wrapf(ErrInvalidConfigValue, "proxy.port-range is invalid") - } - if len(addrs) != 1 { - return nil, errors.Wrapf(ErrInvalidConfigValue, "proxy.addr must contain exactly one host when proxy.port-range is set") - } - host, _, err := net.SplitHostPort(addrs[0]) - if err != nil { - return nil, errors.Wrapf(ErrInvalidConfigValue, "invalid proxy.addr: %s", err.Error()) - } - sqlAddrs := make([]string, 0, end-start+1) - for port := start; port <= end; port++ { - sqlAddrs = append(sqlAddrs, net.JoinHostPort(host, strconv.Itoa(port))) - } - return sqlAddrs, nil -} ->>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) diff --git a/lib/config/proxy_test.go b/lib/config/proxy_test.go index baa281bef..0a353942c 100644 --- a/lib/config/proxy_test.go +++ 
b/lib/config/proxy_test.go @@ -20,15 +20,6 @@ var testProxyConfig = Config{ Addr: "0.0.0.0:4000", PDAddrs: "127.0.0.1:4089", ProxyServerOnline: ProxyServerOnline{ -<<<<<<< HEAD - MaxConnections: 1, - FrontendKeepalive: KeepAlive{Enabled: true}, - ProxyProtocol: "v2", - GracefulWaitBeforeShutdown: 10, - FailBackendList: []string{"db-tidb-0", "db-tidb-1"}, - FailoverTimeout: 60, - ConnBufferSize: 32 * 1024, -======= MaxConnections: 1, HighMemoryUsageRejectThreshold: 0.9, FrontendKeepalive: KeepAlive{Enabled: true}, @@ -37,19 +28,6 @@ var testProxyConfig = Config{ FailBackendList: []string{"db-tidb-0", "db-tidb-1"}, FailoverTimeout: 60, ConnBufferSize: 32 * 1024, - BackendClusters: []BackendCluster{ - { - Name: "cluster-a", - PDAddrs: "127.0.0.1:12379,127.0.0.1:22379", - NSServers: []string{"10.0.0.2", "10.0.0.3"}, - }, - { - Name: "cluster-b", - PDAddrs: "127.0.0.1:32379", - NSServers: []string{"10.0.0.4"}, - }, - }, ->>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) }, }, API: API{ diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index e12b4dd12..cc6cbfb18 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -71,12 +71,8 @@ func estimateConnBufferMemDelta(bufferSize int) int64 { } // NewSQLServer creates a new SQLServer. -<<<<<<< HEAD -func NewSQLServer(logger *zap.Logger, cfg *config.Config, certMgr *cert.CertManager, idMgr *id.IDManager, cpt capture.Capture, hsHandler backend.HandshakeHandler) (*SQLServer, error) { -======= func NewSQLServer(logger *zap.Logger, cfg *config.Config, certMgr *cert.CertManager, idMgr *id.IDManager, cpt capture.Capture, - meter backend.Meter, hsHandler backend.HandshakeHandler, memUsage memoryStateProvider) (*SQLServer, error) { ->>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) + hsHandler backend.HandshakeHandler, memUsage memoryStateProvider) (*SQLServer, error) { var err error s := &SQLServer{ logger: logger, @@ -183,12 +179,8 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) { maxConns := s.mu.maxConnections // 'maxConns == 0' => unlimited connections if maxConns != 0 && conns >= maxConns { -<<<<<<< HEAD - s.logger.Warn("too many connections", zap.Uint64("max connections", maxConns), zap.String("client_addr", conn.RemoteAddr().Network()), zap.Error(conn.Close())) -======= metrics.RejectConnCounter.WithLabelValues("max_connections").Inc() s.logger.Warn("too many connections", zap.Uint64("max connections", maxConns), zap.Stringer("client_addr", conn.RemoteAddr()), zap.Error(conn.Close())) ->>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) return false, nil, 0, nil } @@ -242,8 +234,6 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) { clientConn.Run(ctx) } -<<<<<<< HEAD -======= func (s *SQLServer) rejectConnByMemory(conn net.Conn) bool { if s.memUsage == nil { return false @@ -264,32 +254,6 @@ func (s *SQLServer) rejectConnByMemory(conn net.Conn) bool { return true } -func (s *SQLServer) fromPublicEndpoint(addr net.Addr) bool { - if addr == nil || reflect.ValueOf(addr).IsNil() { - return false - } - s.mu.RLock() - publicEndpoints := s.mu.publicEndpoints - s.mu.RUnlock() - ip, err := netutil.NetAddr2IP(addr) - if err != nil { - s.logger.Warn("failed to check public endpoint", zap.Any("addr", addr), zap.Error(err)) - return false - } - contains, err := netutil.CIDRContainsIP(publicEndpoints, ip) - if err != nil { - s.logger.Warn("failed to check public endpoint", zap.Any("ip", 
ip), zap.Error(err)) - return false - } - if contains { - return true - } - // The public NLB may enable preserveIP, and the incoming address is the client address, which may be a public address. - // Even if the private NLB enables preserveIP, the client address is still a private address. - return !netutil.IsPrivate(ip) -} - ->>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) func (s *SQLServer) PreClose() { // Step 1: HTTP status returns unhealthy so that NLB takes this instance offline and then new connections won't come. s.mu.Lock() diff --git a/pkg/proxy/proxy_test.go b/pkg/proxy/proxy_test.go index b78b3c49e..2dda7957c 100644 --- a/pkg/proxy/proxy_test.go +++ b/pkg/proxy/proxy_test.go @@ -9,11 +9,7 @@ import ( "fmt" "net" "strings" -<<<<<<< HEAD -======= - "sync" "sync/atomic" ->>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) "testing" "time" @@ -38,11 +34,7 @@ func TestCreateConn(t *testing.T) { cfg := &config.Config{} certManager := cert.NewCertManager() require.NoError(t, certManager.Init(cfg, lg, nil)) -<<<<<<< HEAD - server, err := NewSQLServer(lg, cfg, certManager, id.NewIDManager(), nil, &mockHsHandler{}) -======= - server, err := NewSQLServer(lg, cfg, certManager, id.NewIDManager(), nil, nil, &mockHsHandler{}, nil) ->>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) + server, err := NewSQLServer(lg, cfg, certManager, id.NewIDManager(), nil, &mockHsHandler{}, nil) require.NoError(t, err) server.Run(context.Background(), nil) defer func() { @@ -83,7 +75,7 @@ func TestRejectConnByMemory(t *testing.T) { lg, _ := logger.CreateLoggerForTest(t) certManager := cert.NewCertManager() require.NoError(t, certManager.Init(&config.Config{}, lg, nil)) - server, err := NewSQLServer(lg, &config.Config{}, certManager, id.NewIDManager(), nil, nil, &mockHsHandler{}, &mockMemUsageProvider{ + server, err := NewSQLServer(lg, &config.Config{}, certManager, id.NewIDManager(), nil, &mockHsHandler{}, &mockMemUsageProvider{ reject: true, snapshot: mgrmem.UsageSnapshot{ Used: 9 * (1 << 30), @@ -143,7 +135,7 @@ func TestTrackConnBufferMemDelta(t *testing.T) { } require.NoError(t, certManager.Init(cfg, lg, nil)) memUsage := &mockMemUsageProvider{} - server, err := NewSQLServer(lg, cfg, certManager, id.NewIDManager(), nil, nil, &mockHsHandler{}, memUsage) + server, err := NewSQLServer(lg, cfg, certManager, id.NewIDManager(), nil, &mockHsHandler{}, memUsage) require.NoError(t, err) server.Run(context.Background(), nil) defer func() { @@ -176,11 +168,7 @@ func TestGracefulCloseConn(t *testing.T) { }, }, } -<<<<<<< HEAD - server, err := NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, hsHandler) -======= - server, err := NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, nil, hsHandler, nil) ->>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) + server, err := NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, hsHandler, nil) require.NoError(t, err) finish := make(chan struct{}) go func() { @@ -210,11 +198,7 @@ func TestGracefulCloseConn(t *testing.T) { } // Graceful shutdown will be blocked if there are alive connections. 
-<<<<<<< HEAD - server, err = NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, hsHandler) -======= - server, err = NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, nil, hsHandler, nil) ->>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) + server, err = NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, hsHandler, nil) require.NoError(t, err) clientConn := createClientConn() go func() { @@ -240,11 +224,7 @@ func TestGracefulCloseConn(t *testing.T) { // Graceful shutdown will shut down after GracefulCloseConnTimeout. cfg.Proxy.GracefulCloseConnTimeout = 1 -<<<<<<< HEAD - server, err = NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, hsHandler) -======= - server, err = NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, nil, hsHandler, nil) ->>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) + server, err = NewSQLServer(lg, cfg, nil, id.NewIDManager(), nil, hsHandler, nil) require.NoError(t, err) createClientConn() go func() { @@ -272,11 +252,7 @@ func TestGracefulShutDown(t *testing.T) { }, }, } -<<<<<<< HEAD - server, err := NewSQLServer(lg, cfg, certManager, id.NewIDManager(), nil, &mockHsHandler{}) -======= - server, err := NewSQLServer(lg, cfg, certManager, id.NewIDManager(), nil, nil, &mockHsHandler{}, nil) ->>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) + server, err := NewSQLServer(lg, cfg, certManager, id.NewIDManager(), nil, &mockHsHandler{}, nil) require.NoError(t, err) server.Run(context.Background(), nil) @@ -314,11 +290,7 @@ func TestMultiAddr(t *testing.T) { Proxy: config.ProxyServer{ Addr: "0.0.0.0:0,0.0.0.0:0", }, -<<<<<<< HEAD - }, certManager, id.NewIDManager(), nil, &mockHsHandler{}) -======= - }, certManager, id.NewIDManager(), nil, nil, &mockHsHandler{}, nil) ->>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) + }, certManager, id.NewIDManager(), nil, &mockHsHandler{}, nil) require.NoError(t, err) server.Run(context.Background(), nil) @@ -334,96 +306,11 @@ func TestMultiAddr(t *testing.T) { certManager.Close() } -<<<<<<< HEAD -======= -func TestPortRange(t *testing.T) { - lg, _ := logger.CreateLoggerForTest(t) - certManager := cert.NewCertManager() - err := certManager.Init(&config.Config{}, lg, nil) - require.NoError(t, err) - start, end := findFreePortRange(t, 3) - server, err := NewSQLServer(lg, &config.Config{ - Proxy: config.ProxyServer{ - Addr: fmt.Sprintf("127.0.0.1:%d", start), - PortRange: []int{start, end}, - }, - }, certManager, id.NewIDManager(), nil, nil, &mockHsHandler{}, nil) - require.NoError(t, err) - server.Run(context.Background(), nil) - - require.Len(t, server.listeners, 3) - ports := make([]int, 0, len(server.listeners)) - for _, listener := range server.listeners { - tcpAddr, ok := listener.Addr().(*net.TCPAddr) - require.True(t, ok) - ports = append(ports, tcpAddr.Port) - - conn, err := net.Dial("tcp", listener.Addr().String()) - require.NoError(t, err) - require.NoError(t, conn.Close()) - } - slices.Sort(ports) - require.Equal(t, []int{start, start + 1, end}, ports) - - server.PreClose() - require.NoError(t, server.Close()) - certManager.Close() -} - -func TestConnAddrUsesActualListenerAddr(t *testing.T) { - lg, _ := logger.CreateLoggerForTest(t) - certManager := cert.NewCertManager() - require.NoError(t, certManager.Init(&config.Config{}, lg, nil)) - - var ( - addrMu sync.Mutex - connAddr string - ) - handler := &mockHsHandler{ - getRouter: func(ctx backend.ConnContext, _ 
*pnet.HandshakeResp) (router.Router, error) { - addrMu.Lock() - connAddr, _ = ctx.Value(backend.ConnContextKeyConnAddr).(string) - addrMu.Unlock() - return nil, errors.New("no router") - }, - } - server, err := NewSQLServer(lg, &config.Config{ - Proxy: config.ProxyServer{ - Addr: "127.0.0.1:0", - }, - }, certManager, id.NewIDManager(), nil, nil, handler, nil) - require.NoError(t, err) - server.Run(context.Background(), nil) - defer func() { - server.PreClose() - require.NoError(t, server.Close()) - certManager.Close() - }() - - _, port, err := net.SplitHostPort(server.listeners[0].Addr().String()) - require.NoError(t, err) - mdb, err := sql.Open("mysql", fmt.Sprintf("root@tcp(127.0.0.1:%s)/test", port)) - require.NoError(t, err) - defer func() { require.NoError(t, mdb.Close()) }() - - require.ErrorContains(t, mdb.Ping(), "no router") - require.Eventually(t, func() bool { - addrMu.Lock() - defer addrMu.Unlock() - return connAddr == server.listeners[0].Addr().String() - }, 3*time.Second, 10*time.Millisecond) -} - ->>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) func TestWatchCfg(t *testing.T) { lg, _ := logger.CreateLoggerForTest(t) hsHandler := backend.NewDefaultHandshakeHandler(nil) cfgch := make(chan *config.Config) -<<<<<<< HEAD - server, err := NewSQLServer(lg, &config.Config{}, nil, id.NewIDManager(), nil, hsHandler) -======= - server, err := NewSQLServer(lg, &config.Config{}, nil, id.NewIDManager(), nil, nil, hsHandler, nil) ->>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) + server, err := NewSQLServer(lg, &config.Config{}, nil, id.NewIDManager(), nil, hsHandler, nil) require.NoError(t, err) server.Run(context.Background(), cfgch) cfg := &config.Config{ @@ -486,47 +373,6 @@ func TestRecoverPanic(t *testing.T) { certManager.Close() } -<<<<<<< HEAD -======= -func TestPublicEndpoint(t *testing.T) { - tests := []struct { - publicEndpoints []string - publicIps []string - privateIps []string - }{ - { - publicIps: []string{"137.84.2.178"}, - privateIps: []string{"10.10.10.10"}, - }, - { - publicEndpoints: []string{"10.10.10.0/24"}, - publicIps: []string{"137.84.2.178", "10.10.10.10"}, - privateIps: []string{"10.10.20.10"}, - }, - { - publicEndpoints: []string{"10.10.10.0/24", "10.10.20.10"}, - publicIps: []string{"137.84.2.178", "10.10.10.10", "10.10.20.10"}, - privateIps: []string{"10.10.20.11"}, - }, - } - - server, err := NewSQLServer(zap.NewNop(), &config.Config{}, nil, id.NewIDManager(), nil, nil, backend.NewDefaultHandshakeHandler(nil), nil) - require.NoError(t, err) - for i, test := range tests { - cfg := &config.Config{} - cfg.Proxy.PublicEndpoints = test.publicEndpoints - server.reset(cfg) - for j, ip := range test.publicIps { - require.True(t, server.fromPublicEndpoint(&net.TCPAddr{IP: net.ParseIP(ip), Port: 1000}), "test %d %d", i, j) - } - for j, ip := range test.privateIps { - require.False(t, server.fromPublicEndpoint(&net.TCPAddr{IP: net.ParseIP(ip), Port: 1000}), "test %d %d", i, j) - } - require.False(t, server.fromPublicEndpoint(nil)) - } -} - ->>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) type mockHsHandler struct { backend.DefaultHandshakeHandler handshakeResp func(ctx backend.ConnContext, _ *pnet.HandshakeResp) error diff --git a/pkg/server/server.go b/pkg/server/server.go index fefd2ff7b..18489b636 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -171,11 +171,7 @@ func NewServer(ctx context.Context, sctx *sctx.Context) (srv *Server, 
err error) // setup proxy server { -<<<<<<< HEAD - srv.proxy, err = proxy.NewSQLServer(lg.Named("proxy"), cfg, srv.certManager, idMgr, srv.replay.GetCapture(), hsHandler) -======= - srv.proxy, err = proxy.NewSQLServer(lg.Named("proxy"), cfg, srv.certManager, idMgr, srv.replay.GetCapture(), srv.meter, hsHandler, srv.memManager) ->>>>>>> 92599f29 (memory, config: reject connections when memory usage is high (#1120)) + srv.proxy, err = proxy.NewSQLServer(lg.Named("proxy"), cfg, srv.certManager, idMgr, srv.replay.GetCapture(), hsHandler, srv.memManager) if err != nil { return } From e8f057c78c742110b2845e07a4423da011482724 Mon Sep 17 00:00:00 2001 From: djshow832 Date: Wed, 22 Apr 2026 20:09:16 +0800 Subject: [PATCH 3/3] fix conflict --- lib/config/proxy.go | 6 ++++++ pkg/manager/memory/memory.go | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/lib/config/proxy.go b/lib/config/proxy.go index 168aa9304..443fa9d3e 100644 --- a/lib/config/proxy.go +++ b/lib/config/proxy.go @@ -213,6 +213,12 @@ func (cfg *Config) Check() error { } func (ps *ProxyServer) Check() error { + if ps.HighMemoryUsageRejectThreshold < 0 || ps.HighMemoryUsageRejectThreshold > 1 { + return errors.Wrapf(ErrInvalidConfigValue, "invalid proxy.high-memory-usage-reject-threshold") + } + if ps.HighMemoryUsageRejectThreshold > 0 && ps.HighMemoryUsageRejectThreshold < 0.5 { + ps.HighMemoryUsageRejectThreshold = 0.5 + } if ps.FailoverTimeout < 0 { return errors.Wrapf(ErrInvalidConfigValue, "proxy.failover-timeout must be greater than or equal to 0") } diff --git a/pkg/manager/memory/memory.go b/pkg/manager/memory/memory.go index 5c293e8f3..bfe4b650e 100644 --- a/pkg/manager/memory/memory.go +++ b/pkg/manager/memory/memory.go @@ -14,7 +14,7 @@ import ( "github.com/pingcap/tidb/pkg/util/memory" "github.com/pingcap/tiproxy/lib/config" - "github.com/pingcap/tiproxy/pkg/util/waitgroup" + "github.com/pingcap/tiproxy/lib/util/waitgroup" "go.uber.org/zap" )
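
For reviewers of this cherry-pick, a minimal standalone sketch of the config validation the series introduces: proxy.high-memory-usage-reject-threshold (default 0.9) must lie in [0, 1], and any non-zero value below 0.5 is raised to 0.5 by ProxyServer.Check. The normalizeThreshold helper below is hypothetical and only mirrors that logic outside the tiproxy codebase; treating 0 as "leave rejection disabled" is an assumption, since the diff only shows that 0 is passed through unchanged.

package main

import (
	"errors"
	"fmt"
)

// normalizeThreshold is a hypothetical helper that mirrors the checks
// ProxyServer.Check applies to proxy.high-memory-usage-reject-threshold
// in this series: values outside [0, 1] are rejected, non-zero values
// below 0.5 are raised to 0.5, and 0 is passed through unchanged.
func normalizeThreshold(v float64) (float64, error) {
	if v < 0 || v > 1 {
		return 0, errors.New("invalid proxy.high-memory-usage-reject-threshold")
	}
	if v > 0 && v < 0.5 {
		return 0.5, nil
	}
	return v, nil
}

func main() {
	// -0.1 and 1.2 are rejected, 0.3 is raised to 0.5,
	// and 0 and 0.9 (the default) pass through unchanged.
	for _, v := range []float64{-0.1, 0, 0.3, 0.9, 1.2} {
		got, err := normalizeThreshold(v)
		fmt.Printf("%.1f -> %.1f, err=%v\n", v, got, err)
	}
}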