diff --git a/cmd/replayer/main.go b/cmd/replayer/main.go index 0bf1da1ad..75deda7fa 100644 --- a/cmd/replayer/main.go +++ b/cmd/replayer/main.go @@ -61,6 +61,7 @@ func main() { logLevel := rootCmd.PersistentFlags().String("log-level", "info", "the log level: debug, info, warn, error, dpanic, panic, fatal") startTime := rootCmd.PersistentFlags().Time("start-time", time.Now(), []string{time.RFC3339, time.RFC3339Nano}, "the time to start the replay. Format is RFC3339. Default is the current time.") filterCommandWithRetry := rootCmd.PersistentFlags().Bool("filter-command-with-retry", false, "filter out commands that are retries according to the audit log.") + userAllowlist := rootCmd.PersistentFlags().StringSlice("user-allowlist", nil, "for audit log format only: USER values to replay (case-insensitive, normalized to lowercase); if set, lines whose [USER=...] is not listed are skipped (repeat flag or use comma-separated values).") waitOnEOF := rootCmd.PersistentFlags().Bool("wait-on-eof", false, "wait for the next file when all the files are read.") rootCmd.RunE = func(cmd *cobra.Command, _ []string) error { @@ -150,6 +151,7 @@ func main() { ReplayerIndex: *replayerIndex, OutputPath: *outputPath, FilterCommandWithRetry: *filterCommandWithRetry, + UserAllowlist: *userAllowlist, WaitOnEOF: *waitOnEOF, } if err := r.StartReplay(replayCfg); err != nil { diff --git a/pkg/server/api/traffic.go b/pkg/server/api/traffic.go index ed5d6279d..a5250ad03 100644 --- a/pkg/server/api/traffic.go +++ b/pkg/server/api/traffic.go @@ -157,6 +157,13 @@ func (h *Server) TrafficReplay(c *gin.Context) { cfg.Addr = c.PostForm("addr") cfg.DryRun = strings.EqualFold(c.PostForm("dryrun"), "true") cfg.FilterCommandWithRetry = strings.EqualFold(c.PostForm("filtercommandwithretry"), "true") + if wl := strings.TrimSpace(c.PostForm("user-allowlist")); wl != "" { + for _, part := range strings.Split(wl, ",") { + if u := strings.TrimSpace(part); u != "" { + cfg.UserAllowlist = append(cfg.UserAllowlist, strings.ToLower(u)) + } + } + } cfg.WaitOnEOF = strings.EqualFold(c.PostForm("wait-on-eof"), "true") h.lg.Info("request: traffic replay", zap.Any("cfg", cfg)) diff --git a/pkg/server/api/traffic_test.go b/pkg/server/api/traffic_test.go index c86c8f82b..266f9c002 100644 --- a/pkg/server/api/traffic_test.go +++ b/pkg/server/api/traffic_test.go @@ -101,7 +101,7 @@ func TestTraffic(t *testing.T) { }) // replay succeeds doHTTP(t, http.MethodPost, "/api/traffic/replay", httpOpts{ - reader: cli.GetFormReader(map[string]string{"input": "/tmp", "speed": "2.0", "username": "u1", "password": "p1"}), + reader: cli.GetFormReader(map[string]string{"input": "/tmp", "speed": "2.0", "username": "u1", "password": "p1", "user-allowlist": " Root, APP "}), header: map[string]string{"Content-Type": "application/x-www-form-urlencoded"}, }, func(t *testing.T, r *http.Response) { require.Equal(t, http.StatusOK, r.StatusCode) @@ -111,7 +111,7 @@ func TestTraffic(t *testing.T) { require.Equal(t, "replay", mgr.curJob) startTime := mgr.replayCfg.StartTime require.False(t, startTime.IsZero()) - require.Equal(t, replay.ReplayConfig{Input: "/tmp", Username: "u1", Password: "p1", Speed: 2.0, StartTime: startTime, PSCloseStrategy: cmd.PSCloseStrategyDirected}, mgr.replayCfg) + require.Equal(t, replay.ReplayConfig{Input: "/tmp", Username: "u1", Password: "p1", Speed: 2.0, StartTime: startTime, PSCloseStrategy: cmd.PSCloseStrategyDirected, UserAllowlist: []string{"root", "app"}}, mgr.replayCfg) }) // show succeeds doHTTP(t, http.MethodGet, "/api/traffic/show", httpOpts{}, func(t *testing.T, r *http.Response) { diff --git a/pkg/sqlreplay/cmd/audit_log_plugin.go b/pkg/sqlreplay/cmd/audit_log_plugin.go index db9a2f675..a3ec8e7ac 100644 --- a/pkg/sqlreplay/cmd/audit_log_plugin.go +++ b/pkg/sqlreplay/cmd/audit_log_plugin.go @@ -30,6 +30,7 @@ const ( auditPluginKeyCostTime = "COST_TIME" auditPluginKeyPreparedStmtID = "PREPARED_STMT_ID" auditPluginKeyRetry = "RETRY" + auditPluginKeyUser = "USER" auditPluginClassGeneral = "GENERAL" auditPluginClassTableAccess = "TABLE_ACCESS" @@ -105,9 +106,11 @@ type AuditLogPluginDecoder struct { pendingCmds []*Command psCloseStrategy PSCloseStrategy filterCommandWithRetry bool - idAllocator *ConnIDAllocator - dedup *DeDup - lg *zap.Logger + // userAllowlist, when non-empty, causes Decode to skip lines whose [USER=...] is not in the set. + userAllowlist map[string]struct{} + idAllocator *ConnIDAllocator + dedup *DeDup + lg *zap.Logger } // ConnIDAllocator allocates connection IDs for new connections. @@ -171,6 +174,12 @@ func (decoder *AuditLogPluginDecoder) Decode(reader LineReader) (*Command, error // Ignore the commands before CommandEndTime. continue } + if decoder.userAllowlist != nil { + user := strings.ToLower(strings.TrimSpace(kvs[auditPluginKeyUser])) + if _, ok := decoder.userAllowlist[user]; !ok { + continue + } + } var connID uint64 if connCtx, ok := decoder.connInfo[upstreamConnID]; ok { @@ -657,3 +666,26 @@ func (decoder *AuditLogPluginDecoder) isDuplicatedWrite(lastCmd *Command, kvs ma func (decoder *AuditLogPluginDecoder) EnableFilterCommandWithRetry() { decoder.filterCommandWithRetry = true } + +// SetUserAllowlist restricts decoding to audit log lines whose USER field is in the list. +// Matching is case-insensitive; names are stored in lowercase. Empty or all-blank entries are ignored. +// When users is empty after trimming, filtering is disabled. +func (decoder *AuditLogPluginDecoder) SetUserAllowlist(users []string) { + if len(users) == 0 { + decoder.userAllowlist = nil + return + } + m := make(map[string]struct{}) + for _, u := range users { + u = strings.ToLower(strings.TrimSpace(u)) + if u == "" { + continue + } + m[u] = struct{}{} + } + if len(m) == 0 { + decoder.userAllowlist = nil + } else { + decoder.userAllowlist = m + } +} diff --git a/pkg/sqlreplay/cmd/audit_log_plugin_test.go b/pkg/sqlreplay/cmd/audit_log_plugin_test.go index e642a0004..9e71fad60 100644 --- a/pkg/sqlreplay/cmd/audit_log_plugin_test.go +++ b/pkg/sqlreplay/cmd/audit_log_plugin_test.go @@ -4,6 +4,7 @@ package cmd import ( + "fmt" "io" "testing" "time" @@ -676,6 +677,57 @@ func TestDecodeSingleLine(t *testing.T) { } } +func TestDecodeUserAllowlist(t *testing.T) { + mkLine := func(user string, connID int, sel string) string { + return fmt.Sprintf(`[2025/09/08 21:16:29.585 +08:00] [INFO] [logger.go:77] [ID=17573373891] [TIMESTAMP=2025/09/06 16:16:29.585 +08:10] [EVENT_CLASS=GENERAL] [EVENT_SUBCLASS=] [STATUS_CODE=0] [COST_TIME=1057.834] [HOST=127.0.0.1] [CLIENT_IP=127.0.0.1] [USER=%s] [DATABASES="[]"] [TABLES="[]"] [SQL_TEXT="select %s"] [ROWS=0] [CONNECTION_ID=%d] [CLIENT_PORT=52611] [PID=89967] [COMMAND=Query] [SQL_STATEMENTS=Set] [EXECUTE_PARAMS="[]"] [CURRENT_DB=] [EVENT=COMPLETED]`, user, sel, connID) + } + lines := mkLine("admin", 1, "1") + "\n" + mkLine("root", 2, "2") + "\n" + + t.Run("no filter returns both", func(t *testing.T) { + decoder := NewAuditLogPluginDecoder(NewDeDup(), zap.NewNop()) + decoder.SetPSCloseStrategy(PSCloseStrategyAlways) + mr := mockReader{data: []byte(lines), filename: "f"} + cmds, err := decodeCmds(decoder, &mr) + require.ErrorIs(t, err, io.EOF) + require.Len(t, cmds, 2) + require.Contains(t, string(cmds[1].Payload), "select 2") + }) + + t.Run("allowlist match is case insensitive", func(t *testing.T) { + decoder := NewAuditLogPluginDecoder(NewDeDup(), zap.NewNop()) + decoder.SetPSCloseStrategy(PSCloseStrategyAlways) + decoder.SetUserAllowlist([]string{"ROOT"}) + one := mkLine("root", 9, "9") + "\n" + mr := mockReader{data: []byte(one), filename: "f"} + cmds, err := decodeCmds(decoder, &mr) + require.ErrorIs(t, err, io.EOF) + require.Len(t, cmds, 1) + require.Contains(t, string(cmds[0].Payload), "select 9") + }) + + t.Run("allowlist root skips admin", func(t *testing.T) { + decoder := NewAuditLogPluginDecoder(NewDeDup(), zap.NewNop()) + decoder.SetPSCloseStrategy(PSCloseStrategyAlways) + decoder.SetUserAllowlist([]string{"root"}) + mr := mockReader{data: []byte(lines), filename: "f"} + cmds, err := decodeCmds(decoder, &mr) + require.ErrorIs(t, err, io.EOF) + require.Len(t, cmds, 1) + require.Contains(t, string(cmds[0].Payload), "select 2") + require.Equal(t, uint64(2), cmds[0].UpstreamConnID) + }) + + t.Run("blank-only allowlist disables filter", func(t *testing.T) { + decoder := NewAuditLogPluginDecoder(NewDeDup(), zap.NewNop()) + decoder.SetPSCloseStrategy(PSCloseStrategyAlways) + decoder.SetUserAllowlist([]string{"", " ", "\t"}) + mr := mockReader{data: []byte(lines), filename: "f"} + cmds, err := decodeCmds(decoder, &mr) + require.ErrorIs(t, err, io.EOF) + require.Len(t, cmds, 2) + }) +} + func TestDecodeMultiLines(t *testing.T) { tests := []struct { lines string diff --git a/pkg/sqlreplay/replay/replay.go b/pkg/sqlreplay/replay/replay.go index 697b6ae4d..45ec91548 100644 --- a/pkg/sqlreplay/replay/replay.go +++ b/pkg/sqlreplay/replay/replay.go @@ -124,6 +124,10 @@ type ReplayConfig struct { Addr string // FilterCommandWithRetry indicates whether to filter out commands that are retries according to the audit log. FilterCommandWithRetry bool + // UserAllowlist is only used for audit log plugin format. When non-empty, lines whose [USER=...] value + // is not in this list are ignored (comma-separated in HTTP form; repeated or comma-separated CLI flag). + // Matching is case-insensitive; HTTP form values are stored lowercased, and the audit decoder lowercases for lookup. + UserAllowlist []string // WaitOnEOF indicates whether the replayer waits for the next file when no more files. WaitOnEOF bool // the following fields are for testing @@ -620,6 +624,9 @@ func (r *replay) constructDecoderForReader(ctx context.Context, reader cmd.LineR if r.cfg.FilterCommandWithRetry { auditLogDecoder.EnableFilterCommandWithRetry() } + if len(r.cfg.UserAllowlist) > 0 { + auditLogDecoder.SetUserAllowlist(r.cfg.UserAllowlist) + } } var decoder decoder