Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmd/replayer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ func main() {
logLevel := rootCmd.PersistentFlags().String("log-level", "info", "the log level: debug, info, warn, error, dpanic, panic, fatal")
startTime := rootCmd.PersistentFlags().Time("start-time", time.Now(), []string{time.RFC3339, time.RFC3339Nano}, "the time to start the replay. Format is RFC3339. Default is the current time.")
filterCommandWithRetry := rootCmd.PersistentFlags().Bool("filter-command-with-retry", false, "filter out commands that are retries according to the audit log.")
userAllowlist := rootCmd.PersistentFlags().StringSlice("user-allowlist", nil, "for audit log format only: USER values to replay (case-insensitive, normalized to lowercase); if set, lines whose [USER=...] is not listed are skipped (repeat flag or use comma-separated values).")
waitOnEOF := rootCmd.PersistentFlags().Bool("wait-on-eof", false, "wait for the next file when all the files are read.")

rootCmd.RunE = func(cmd *cobra.Command, _ []string) error {
Expand Down Expand Up @@ -150,6 +151,7 @@ func main() {
ReplayerIndex: *replayerIndex,
OutputPath: *outputPath,
FilterCommandWithRetry: *filterCommandWithRetry,
UserAllowlist: *userAllowlist,
WaitOnEOF: *waitOnEOF,
}
if err := r.StartReplay(replayCfg); err != nil {
Expand Down
7 changes: 7 additions & 0 deletions pkg/server/api/traffic.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,13 @@ func (h *Server) TrafficReplay(c *gin.Context) {
cfg.Addr = c.PostForm("addr")
cfg.DryRun = strings.EqualFold(c.PostForm("dryrun"), "true")
cfg.FilterCommandWithRetry = strings.EqualFold(c.PostForm("filtercommandwithretry"), "true")
if wl := strings.TrimSpace(c.PostForm("user-allowlist")); wl != "" {
for _, part := range strings.Split(wl, ",") {
if u := strings.TrimSpace(part); u != "" {
cfg.UserAllowlist = append(cfg.UserAllowlist, strings.ToLower(u))
}
}
}
cfg.WaitOnEOF = strings.EqualFold(c.PostForm("wait-on-eof"), "true")
h.lg.Info("request: traffic replay", zap.Any("cfg", cfg))

Expand Down
4 changes: 2 additions & 2 deletions pkg/server/api/traffic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func TestTraffic(t *testing.T) {
})
// replay succeeds
doHTTP(t, http.MethodPost, "/api/traffic/replay", httpOpts{
reader: cli.GetFormReader(map[string]string{"input": "/tmp", "speed": "2.0", "username": "u1", "password": "p1"}),
reader: cli.GetFormReader(map[string]string{"input": "/tmp", "speed": "2.0", "username": "u1", "password": "p1", "user-allowlist": " Root, APP "}),
header: map[string]string{"Content-Type": "application/x-www-form-urlencoded"},
}, func(t *testing.T, r *http.Response) {
require.Equal(t, http.StatusOK, r.StatusCode)
Expand All @@ -111,7 +111,7 @@ func TestTraffic(t *testing.T) {
require.Equal(t, "replay", mgr.curJob)
startTime := mgr.replayCfg.StartTime
require.False(t, startTime.IsZero())
require.Equal(t, replay.ReplayConfig{Input: "/tmp", Username: "u1", Password: "p1", Speed: 2.0, StartTime: startTime, PSCloseStrategy: cmd.PSCloseStrategyDirected}, mgr.replayCfg)
require.Equal(t, replay.ReplayConfig{Input: "/tmp", Username: "u1", Password: "p1", Speed: 2.0, StartTime: startTime, PSCloseStrategy: cmd.PSCloseStrategyDirected, UserAllowlist: []string{"root", "app"}}, mgr.replayCfg)
})
// show succeeds
doHTTP(t, http.MethodGet, "/api/traffic/show", httpOpts{}, func(t *testing.T, r *http.Response) {
Expand Down
38 changes: 35 additions & 3 deletions pkg/sqlreplay/cmd/audit_log_plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ const (
auditPluginKeyCostTime = "COST_TIME"
auditPluginKeyPreparedStmtID = "PREPARED_STMT_ID"
auditPluginKeyRetry = "RETRY"
auditPluginKeyUser = "USER"

auditPluginClassGeneral = "GENERAL"
auditPluginClassTableAccess = "TABLE_ACCESS"
Expand Down Expand Up @@ -105,9 +106,11 @@ type AuditLogPluginDecoder struct {
pendingCmds []*Command
psCloseStrategy PSCloseStrategy
filterCommandWithRetry bool
idAllocator *ConnIDAllocator
dedup *DeDup
lg *zap.Logger
// userAllowlist, when non-empty, causes Decode to skip lines whose [USER=...] is not in the set.
userAllowlist map[string]struct{}
idAllocator *ConnIDAllocator
dedup *DeDup
lg *zap.Logger
}

// ConnIDAllocator allocates connection IDs for new connections.
Expand Down Expand Up @@ -171,6 +174,12 @@ func (decoder *AuditLogPluginDecoder) Decode(reader LineReader) (*Command, error
// Ignore the commands before CommandEndTime.
continue
}
if decoder.userAllowlist != nil {
user := strings.ToLower(strings.TrimSpace(kvs[auditPluginKeyUser]))
if _, ok := decoder.userAllowlist[user]; !ok {
continue
}
}

var connID uint64
if connCtx, ok := decoder.connInfo[upstreamConnID]; ok {
Expand Down Expand Up @@ -657,3 +666,26 @@ func (decoder *AuditLogPluginDecoder) isDuplicatedWrite(lastCmd *Command, kvs ma
func (decoder *AuditLogPluginDecoder) EnableFilterCommandWithRetry() {
decoder.filterCommandWithRetry = true
}

// SetUserAllowlist restricts decoding to audit log lines whose USER field is in the list.
// Matching is case-insensitive; names are stored in lowercase. Empty or all-blank entries are ignored.
// When users is empty after trimming, filtering is disabled.
func (decoder *AuditLogPluginDecoder) SetUserAllowlist(users []string) {
if len(users) == 0 {
decoder.userAllowlist = nil
return
}
m := make(map[string]struct{})
for _, u := range users {
u = strings.ToLower(strings.TrimSpace(u))
if u == "" {
continue
}
m[u] = struct{}{}
}
if len(m) == 0 {
decoder.userAllowlist = nil
} else {
decoder.userAllowlist = m
}
}
52 changes: 52 additions & 0 deletions pkg/sqlreplay/cmd/audit_log_plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package cmd

import (
"fmt"
"io"
"testing"
"time"
Expand Down Expand Up @@ -676,6 +677,57 @@ func TestDecodeSingleLine(t *testing.T) {
}
}

func TestDecodeUserAllowlist(t *testing.T) {
mkLine := func(user string, connID int, sel string) string {
return fmt.Sprintf(`[2025/09/08 21:16:29.585 +08:00] [INFO] [logger.go:77] [ID=17573373891] [TIMESTAMP=2025/09/06 16:16:29.585 +08:10] [EVENT_CLASS=GENERAL] [EVENT_SUBCLASS=] [STATUS_CODE=0] [COST_TIME=1057.834] [HOST=127.0.0.1] [CLIENT_IP=127.0.0.1] [USER=%s] [DATABASES="[]"] [TABLES="[]"] [SQL_TEXT="select %s"] [ROWS=0] [CONNECTION_ID=%d] [CLIENT_PORT=52611] [PID=89967] [COMMAND=Query] [SQL_STATEMENTS=Set] [EXECUTE_PARAMS="[]"] [CURRENT_DB=] [EVENT=COMPLETED]`, user, sel, connID)
}
lines := mkLine("admin", 1, "1") + "\n" + mkLine("root", 2, "2") + "\n"

t.Run("no filter returns both", func(t *testing.T) {
decoder := NewAuditLogPluginDecoder(NewDeDup(), zap.NewNop())
decoder.SetPSCloseStrategy(PSCloseStrategyAlways)
mr := mockReader{data: []byte(lines), filename: "f"}
cmds, err := decodeCmds(decoder, &mr)
require.ErrorIs(t, err, io.EOF)
require.Len(t, cmds, 2)
require.Contains(t, string(cmds[1].Payload), "select 2")
})

t.Run("allowlist match is case insensitive", func(t *testing.T) {
decoder := NewAuditLogPluginDecoder(NewDeDup(), zap.NewNop())
decoder.SetPSCloseStrategy(PSCloseStrategyAlways)
decoder.SetUserAllowlist([]string{"ROOT"})
one := mkLine("root", 9, "9") + "\n"
mr := mockReader{data: []byte(one), filename: "f"}
cmds, err := decodeCmds(decoder, &mr)
require.ErrorIs(t, err, io.EOF)
require.Len(t, cmds, 1)
require.Contains(t, string(cmds[0].Payload), "select 9")
})

t.Run("allowlist root skips admin", func(t *testing.T) {
decoder := NewAuditLogPluginDecoder(NewDeDup(), zap.NewNop())
decoder.SetPSCloseStrategy(PSCloseStrategyAlways)
decoder.SetUserAllowlist([]string{"root"})
mr := mockReader{data: []byte(lines), filename: "f"}
cmds, err := decodeCmds(decoder, &mr)
require.ErrorIs(t, err, io.EOF)
require.Len(t, cmds, 1)
require.Contains(t, string(cmds[0].Payload), "select 2")
require.Equal(t, uint64(2), cmds[0].UpstreamConnID)
})

t.Run("blank-only allowlist disables filter", func(t *testing.T) {
decoder := NewAuditLogPluginDecoder(NewDeDup(), zap.NewNop())
decoder.SetPSCloseStrategy(PSCloseStrategyAlways)
decoder.SetUserAllowlist([]string{"", " ", "\t"})
mr := mockReader{data: []byte(lines), filename: "f"}
cmds, err := decodeCmds(decoder, &mr)
require.ErrorIs(t, err, io.EOF)
require.Len(t, cmds, 2)
})
}

func TestDecodeMultiLines(t *testing.T) {
tests := []struct {
lines string
Expand Down
7 changes: 7 additions & 0 deletions pkg/sqlreplay/replay/replay.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ type ReplayConfig struct {
Addr string
// FilterCommandWithRetry indicates whether to filter out commands that are retries according to the audit log.
FilterCommandWithRetry bool
// UserAllowlist is only used for audit log plugin format. When non-empty, lines whose [USER=...] value
// is not in this list are ignored (comma-separated in HTTP form; repeated or comma-separated CLI flag).
// Matching is case-insensitive; HTTP form values are stored lowercased, and the audit decoder lowercases for lookup.
UserAllowlist []string
// WaitOnEOF indicates whether the replayer waits for the next file when no more files.
WaitOnEOF bool
// the following fields are for testing
Expand Down Expand Up @@ -620,6 +624,9 @@ func (r *replay) constructDecoderForReader(ctx context.Context, reader cmd.LineR
if r.cfg.FilterCommandWithRetry {
auditLogDecoder.EnableFilterCommandWithRetry()
}
if len(r.cfg.UserAllowlist) > 0 {
auditLogDecoder.SetUserAllowlist(r.cfg.UserAllowlist)
}
}

var decoder decoder
Expand Down