From 48ddf44012b836e8193a97aa498218b0a72deff6 Mon Sep 17 00:00:00 2001 From: djshow832 Date: Fri, 28 Nov 2025 19:42:26 +0800 Subject: [PATCH 1/3] readonly --- pkg/sqlreplay/conn/conn.go | 24 +++++------- pkg/sqlreplay/conn/conn_test.go | 11 ++++-- pkg/sqlreplay/replay/dry_run.go | 32 ++++++++++++---- pkg/sqlreplay/replay/replay.go | 38 ++++++++----------- pkg/util/lex/filter.go | 67 ++++++++++++++++++++++++--------- pkg/util/lex/filter_test.go | 10 ++++- pkg/util/lex/lex_test.go | 8 ++++ 7 files changed, 125 insertions(+), 65 deletions(-) diff --git a/pkg/sqlreplay/conn/conn.go b/pkg/sqlreplay/conn/conn.go index 69c5a6b2a..072cc9e06 100644 --- a/pkg/sqlreplay/conn/conn.go +++ b/pkg/sqlreplay/conn/conn.go @@ -164,11 +164,9 @@ func (c *conn) Run(ctx context.Context) { if command == nil { break } - if c.readonly { - if !c.isReadOnly(command.Value) { - c.replayStats.FilteredCmds.Add(1) - continue - } + if c.readonly && !c.isReadOnly(command.Value) { + c.replayStats.FilteredCmds.Add(1) + continue } // Quit the connection in the next round no matter what exception happens (like disconnection). if command.Value.Type == pnet.ComQuit { @@ -234,16 +232,14 @@ func (c *conn) Run(ctx context.Context) { func (c *conn) isReadOnly(command *cmd.Command) bool { switch command.Type { - case pnet.ComQuery: + case pnet.ComQuery, pnet.ComStmtPrepare: + // If the statement is not readonly, it won't be prepared. return lex.IsReadOnly(hack.String(command.Payload[1:])) - case pnet.ComStmtExecute, pnet.ComStmtSendLongData, pnet.ComStmtReset, pnet.ComStmtFetch: - stmtID := binary.LittleEndian.Uint32(command.Payload[1:5]) - ps := c.preparedStmts[stmtID] - if len(ps.text) == 0 { - // Maybe the connection is reconnected after disconnection and the prepared statements are lost. - return false - } - return lex.IsReadOnly(ps.text) + case pnet.ComStmtExecute, pnet.ComStmtSendLongData, pnet.ComStmtReset, pnet.ComStmtFetch, pnet.ComStmtClose: + // If the statement is prepared successfully, then it's readonly. + captureStmtID := binary.LittleEndian.Uint32(command.Payload[1:5]) + _, ok := c.psIDMapping[captureStmtID] + return ok case pnet.ComCreateDB, pnet.ComDropDB, pnet.ComDelayedInsert: return false } diff --git a/pkg/sqlreplay/conn/conn_test.go b/pkg/sqlreplay/conn/conn_test.go index d2f516102..2841d73dd 100644 --- a/pkg/sqlreplay/conn/conn_test.go +++ b/pkg/sqlreplay/conn/conn_test.go @@ -315,7 +315,7 @@ func TestReadOnly(t *testing.T) { { cmd: pnet.ComStmtPrepare, stmt: "insert into t value(?)", - readOnly: true, + readOnly: false, }, { cmd: pnet.ComStmtExecute, @@ -327,10 +327,15 @@ func TestReadOnly(t *testing.T) { stmt: "insert into t value(?)", readOnly: false, }, + { + cmd: pnet.ComStmtExecute, + stmt: "", + readOnly: false, + }, { cmd: pnet.ComStmtClose, stmt: "insert into t value(?)", - readOnly: true, + readOnly: false, }, { cmd: pnet.ComQuit, @@ -348,7 +353,7 @@ func TestReadOnly(t *testing.T) { for i, test := range tests { var payload []byte switch test.cmd { - case pnet.ComQuery: + case pnet.ComQuery, pnet.ComStmtPrepare: payload = append([]byte{test.cmd.Byte()}, []byte(test.stmt)...) default: conn.preparedStmts[1] = preparedStmt{text: test.stmt} diff --git a/pkg/sqlreplay/replay/dry_run.go b/pkg/sqlreplay/replay/dry_run.go index c525b5263..b786ed62e 100644 --- a/pkg/sqlreplay/replay/dry_run.go +++ b/pkg/sqlreplay/replay/dry_run.go @@ -9,6 +9,7 @@ import ( "github.com/pingcap/tiproxy/pkg/sqlreplay/cmd" "github.com/pingcap/tiproxy/pkg/sqlreplay/conn" "github.com/pingcap/tiproxy/pkg/sqlreplay/report" + "github.com/pingcap/tiproxy/pkg/util/waitgroup" ) type nopConn struct { @@ -32,24 +33,41 @@ func (c *nopConn) Stop() { c.closeCh <- c.connID } -var _ report.Report = (*mockReport)(nil) +var _ report.Report = (*nopReport)(nil) -type mockReport struct { +type nopReport struct { exceptionCh chan conn.Exception + wg waitgroup.WaitGroup + cancel context.CancelFunc } -func newMockReport(exceptionCh chan conn.Exception) *mockReport { - return &mockReport{ +func newMockReport(exceptionCh chan conn.Exception) *nopReport { + return &nopReport{ exceptionCh: exceptionCh, } } -func (mr *mockReport) Start(ctx context.Context, cfg report.ReportConfig) error { +func (mr *nopReport) Start(ctx context.Context, cfg report.ReportConfig) error { + childCtx, cancel := context.WithCancel(ctx) + mr.cancel = cancel + mr.wg.RunWithRecover(func() { mr.loop(childCtx) }, nil, nil) return nil } -func (mr *mockReport) Stop(err error) { +func (mr *nopReport) loop(ctx context.Context) { + for ctx.Err() == nil { + select { + case <-ctx.Done(): + return + case <-mr.exceptionCh: + } + } } -func (mr *mockReport) Close() { +func (mr *nopReport) Close() { + if mr.cancel != nil { + mr.cancel() + mr.cancel = nil + } + mr.wg.Wait() } diff --git a/pkg/sqlreplay/replay/replay.go b/pkg/sqlreplay/replay/replay.go index 1c1eed917..d7a284eca 100644 --- a/pkg/sqlreplay/replay/replay.go +++ b/pkg/sqlreplay/replay/replay.go @@ -360,8 +360,8 @@ func (r *replay) Start(cfg ReplayConfig, backendTLSConfig *tls.Config, hsHandler } r.report = cfg.report if r.report == nil { - if cfg.DryRun { - r.report = &mockReport{exceptionCh: r.exceptionCh} + if cfg.DryRun || cfg.ReadOnly { + r.report = &nopReport{exceptionCh: r.exceptionCh} } else { backendConnCreator := func() conn.BackendConn { return conn.NewBackendConn(r.lg.Named("be"), r.idMgr.NewID(), hsHandler, bcConfig, backendTLSConfig, r.cfg.Username, r.cfg.Password) @@ -513,8 +513,9 @@ func (r *replay) readCommands(ctx context.Context) { zap.Int("alive_conns", connCount), zap.Time("last_cmd_start_ts", time.Unix(0, r.replayStats.CurCmdTs.Load())), zap.Time("last_cmd_end_ts", time.Unix(0, r.replayStats.CurCmdEndTs.Load())), - zap.NamedError("ctx_err", ctx.Err()), - zap.Bool("graceful_stop", r.gracefulStop.Load())) + zap.Bool("graceful_stop", r.gracefulStop.Load()), + zap.Error(err), + zap.NamedError("ctx_err", ctx.Err())) // Notify the connections that the commands are finished. for _, conn := range conns { @@ -816,24 +817,22 @@ func (r *replay) saveCheckpointLoop(ctx context.Context) { } defer file.Close() - for { - // Add an interval here to avoid printing too many logs when error occurs. - if err != nil { - time.Sleep(stateSaveRetryInterval) - } - + for ctx.Err() == nil { select { case <-ctx.Done(): - return + break case <-ticker.C: err = r.saveCheckpointToFile(file) if err != nil { r.lg.Error("save current checkpoint failed", zap.Error(err)) + // Add an interval here to avoid printing too many logs when error occurs. time.Sleep(stateSaveRetryInterval) - continue } } } + if err = r.saveCheckpointToFile(file); err != nil { + r.lg.Error("save current state failed on close", zap.Error(err)) + } } func (r *replay) saveCheckpointToFile(file *os.File) error { @@ -923,6 +922,10 @@ func (r *replay) stop(err error) { r.cancel = nil } close(r.execInfoCh) + if r.report != nil { + r.report.Close() + r.report = nil + } r.endTime = time.Now() // decodedCmds - pendingCmds may be greater than replayedCmds because if a connection is closed unexpectedly, // the pending commands of that connection are discarded. We calculate the progress based on decodedCmds - pendingCmds. @@ -1006,17 +1009,6 @@ func (r *replay) Stop(err error, graceful bool) { func (r *replay) Close() { r.Stop(errors.New("shutting down"), false) - if r.report != nil { - r.report.Close() - } - // at this time, the save checkpoint loop and replay loop have exited. It's safe to update the latest - // checkpoint file. - if len(r.cfg.CheckPointFilePath) > 0 { - err := r.saveCurrentStateToFilePath(r.cfg.CheckPointFilePath) - if err != nil { - r.lg.Error("save current state failed on close", zap.Error(err)) - } - } } func getDirForInput(input string) (string, error) { diff --git a/pkg/util/lex/filter.go b/pkg/util/lex/filter.go index 658a2ac48..38a2194d0 100644 --- a/pkg/util/lex/filter.go +++ b/pkg/util/lex/filter.go @@ -3,6 +3,12 @@ package lex +import ( + "strings" + + "github.com/pingcap/tidb/pkg/parser" +) + func startsWithKeyword(sql string, keywords [][]string) bool { lexer := NewLexer(sql) tokens := make([]string, 0, 2) @@ -47,24 +53,51 @@ func IsSensitiveSQL(sql string) bool { // include SELECT FOR UPDATE because it doesn't require write privilege // include SET because SET SESSION_STATES and SET session variables should be executed // include BEGIN / COMMIT in case the user sets autocommit to false, either in SET SESSION_STATES or SET @@autocommit -var readOnlyKeywords = [][]string{ - {"SELECT"}, - {"SHOW"}, - {"WITH"}, - {"SET"}, - {"USE"}, - {"DESC"}, - {"DESCRIBE"}, - {"TABLE"}, - {"DO"}, - {"BEGIN"}, - {"COMMIT"}, - {"ROLLBACK"}, - {"START", "TRANSACTION"}, -} - func IsReadOnly(sql string) bool { - return startsWithKeyword(sql, readOnlyKeywords) + lexer := NewLexer(sql) + switch lexer.NextToken() { + case "SELECT": + for { + token := lexer.NextToken() + if token == "" { + break + } + if token == "FOR" && lexer.NextToken() == "UPDATE" { + return false + } + } + return true + case "SHOW", "WITH", "USE", "DESC", "DESCRIBE", "TABLE", "DO", "BEGIN", "COMMIT", "ROLLBACK": + return true + case "START": + return lexer.NextToken() == "TRANSACTION" + case "SET": + // Filter `set global`, `set @@global.`, `set password`, and other unknown statements. + normalized := parser.Normalize(sql, "ON") + switch { + case strings.HasPrefix(normalized, "set session_states "): + return true + case strings.HasPrefix(normalized, "set session "): + return true + case strings.HasPrefix(normalized, "set names "): + return true + case strings.HasPrefix(normalized, "set char "): + return true + case strings.HasPrefix(normalized, "set charset "): + return true + case strings.HasPrefix(normalized, "set character "): + return true + case strings.HasPrefix(normalized, "set transaction "): + return true + case strings.HasPrefix(normalized, "set @@global."): + return false + case strings.HasPrefix(normalized, "set @"): + return true + } + return false + + } + return false } var startTxnKeywords = [][]string{ diff --git a/pkg/util/lex/filter_test.go b/pkg/util/lex/filter_test.go index 809f58870..a8ceba00f 100644 --- a/pkg/util/lex/filter_test.go +++ b/pkg/util/lex/filter_test.go @@ -37,9 +37,17 @@ func TestReadOnlySQL(t *testing.T) { {`SELECT ? FROM table_name`, true}, {`(select * from t1) union (select * from t2)`, true}, {`WITH cte AS (SELECT 1, 2) SELECT * FROM cte t1, cte t2`, true}, + {`SELECT ? FROM table_name for update`, false}, + {`SELECT "for update"`, true}, {`SET session_States ''`, true}, {`SET @@session_variable=true`, true}, - {`set GLOBAL variable=false`, true}, + {`SET @@global.variable=true`, false}, + {`set GLOBAL variable=false`, false}, + {`set password = 'hello'`, false}, + {`set NAMES utf8`, true}, + {`set character utf8`, true}, + {`set transaction isolation_level = 'read committed`, true}, + {`SET @variable=true`, true}, {`insert into table t value(1)`, false}, {`desc table t`, true}, {`describe select * from t`, true}, diff --git a/pkg/util/lex/lex_test.go b/pkg/util/lex/lex_test.go index 33ced4f84..12343b53e 100644 --- a/pkg/util/lex/lex_test.go +++ b/pkg/util/lex/lex_test.go @@ -42,6 +42,14 @@ func TestNextToken(t *testing.T) { sql: `sEleCt ** from; t5ble_name`, tokens: []string{"SELECT", "FROM", "T", "BLE_NAME"}, }, + { + sql: `set @@session.autocommit = 0`, + tokens: []string{"SET", "SESSION", "AUTOCOMMIT"}, + }, + { + sql: `select "for update"`, + tokens: []string{"SELECT"}, + }, } for i, test := range tests { From 3c5b122a2c031d59a973b15bbb472f6cb45b3815 Mon Sep 17 00:00:00 2001 From: djshow832 Date: Fri, 28 Nov 2025 21:13:09 +0800 Subject: [PATCH 2/3] fix tests --- pkg/sqlreplay/conn/conn.go | 4 +--- pkg/sqlreplay/conn/conn_test.go | 28 +++++++++++++++++++--------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/pkg/sqlreplay/conn/conn.go b/pkg/sqlreplay/conn/conn.go index 072cc9e06..5dafb36ad 100644 --- a/pkg/sqlreplay/conn/conn.go +++ b/pkg/sqlreplay/conn/conn.go @@ -250,9 +250,7 @@ func (c *conn) isReadOnly(command *cmd.Command) bool { return true } -// maintain prepared statement info so that we can find its info when: -// - Judge whether an EXECUTE command is readonly -// - Get the error message when an EXECUTE command fails +// Maintain prepared statement info so that we can find its info when getting the failed statement and params. func (c *conn) updatePreparedStmts(capturedPsID uint32, request []byte, resp ExecuteResp) { switch request[0] { case pnet.ComStmtPrepare.Byte(): diff --git a/pkg/sqlreplay/conn/conn_test.go b/pkg/sqlreplay/conn/conn_test.go index 2841d73dd..9b969b8b7 100644 --- a/pkg/sqlreplay/conn/conn_test.go +++ b/pkg/sqlreplay/conn/conn_test.go @@ -223,27 +223,31 @@ func TestSkipReadOnly(t *testing.T) { readonly: false, }, { - cmd: &cmd.Command{Type: pnet.ComStmtPrepare, Payload: append([]byte{pnet.ComStmtPrepare.Byte()}, []byte("select ?")...)}, + cmd: &cmd.Command{Type: pnet.ComStmtPrepare, CapturedPsID: 1, Payload: append([]byte{pnet.ComStmtPrepare.Byte()}, []byte("select ?")...)}, readonly: true, }, { - cmd: &cmd.Command{Type: pnet.ComStmtExecute, Payload: []byte{pnet.ComStmtExecute.Byte(), 1, 0, 0, 0, 0, 0, 0, 0}}, + cmd: &cmd.Command{Type: pnet.ComStmtExecute, CapturedPsID: 1, Payload: []byte{pnet.ComStmtExecute.Byte(), 1, 0, 0, 0, 0, 0, 0, 0}}, readonly: true, }, { - cmd: &cmd.Command{Type: pnet.ComStmtFetch, Payload: []byte{pnet.ComStmtFetch.Byte(), 1, 0, 0, 0}}, + cmd: &cmd.Command{Type: pnet.ComStmtFetch, CapturedPsID: 1, Payload: []byte{pnet.ComStmtFetch.Byte(), 1, 0, 0, 0}}, readonly: true, }, { - cmd: &cmd.Command{Type: pnet.ComStmtPrepare, Payload: append([]byte{pnet.ComStmtPrepare.Byte()}, []byte("insert into t value(?)")...)}, - readonly: true, + cmd: &cmd.Command{Type: pnet.ComStmtPrepare, CapturedPsID: 2, Payload: append([]byte{pnet.ComStmtPrepare.Byte()}, []byte("insert into t value(?)")...)}, + readonly: false, + }, + { + cmd: &cmd.Command{Type: pnet.ComStmtExecute, CapturedPsID: 2, Payload: []byte{pnet.ComStmtExecute.Byte(), 2, 0, 0, 0}}, + readonly: false, }, { - cmd: &cmd.Command{Type: pnet.ComStmtExecute, Payload: []byte{pnet.ComStmtExecute.Byte(), 2, 0, 0, 0}}, + cmd: &cmd.Command{Type: pnet.ComStmtSendLongData, CapturedPsID: 2, Payload: []byte{pnet.ComStmtFetch.Byte(), 2, 0, 0, 0, 0, 0, 0, 0}}, readonly: false, }, { - cmd: &cmd.Command{Type: pnet.ComStmtSendLongData, Payload: []byte{pnet.ComStmtFetch.Byte(), 2, 0, 0, 0, 0, 0, 0, 0}}, + cmd: &cmd.Command{Type: pnet.ComStmtClose, CapturedPsID: 2, Payload: []byte{pnet.ComStmtClose.Byte(), 2, 0, 0, 0}}, readonly: false, }, { @@ -351,13 +355,19 @@ func TestReadOnly(t *testing.T) { backendConn := newMockBackendConn() conn.backendConn = backendConn for i, test := range tests { + clear(conn.psIDMapping) var payload []byte switch test.cmd { case pnet.ComQuery, pnet.ComStmtPrepare: payload = append([]byte{test.cmd.Byte()}, []byte(test.stmt)...) - default: - conn.preparedStmts[1] = preparedStmt{text: test.stmt} + case pnet.ComStmtExecute, pnet.ComStmtClose, pnet.ComStmtFetch, pnet.ComStmtReset, pnet.ComStmtSendLongData: + prepare := cmd.NewCommand(append([]byte{pnet.ComStmtPrepare.Byte()}, []byte(test.stmt)...), time.Time{}, 100) + if conn.isReadOnly(prepare) { + conn.psIDMapping[1] = 1 + } payload = []byte{test.cmd.Byte(), 1, 0, 0, 0} + default: + payload = []byte{test.cmd.Byte()} } command := cmd.NewCommand(payload, time.Time{}, 100) require.Equal(t, test.readOnly, conn.isReadOnly(command), "case %d", i) From 1ffe027e9832cb9765cb83b06d18c76e62315e51 Mon Sep 17 00:00:00 2001 From: djshow832 Date: Fri, 28 Nov 2025 21:59:10 +0800 Subject: [PATCH 3/3] remove save --- pkg/sqlreplay/replay/replay.go | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/pkg/sqlreplay/replay/replay.go b/pkg/sqlreplay/replay/replay.go index d7a284eca..6058a1b57 100644 --- a/pkg/sqlreplay/replay/replay.go +++ b/pkg/sqlreplay/replay/replay.go @@ -854,16 +854,6 @@ func (r *replay) saveCheckpointToFile(file *os.File) error { return nil } -func (r *replay) saveCurrentStateToFilePath(filePath string) error { - file, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE, 0644) - if err != nil { - return errors.Wrapf(err, "open state file %s", filePath) - } - defer file.Close() - - return r.saveCheckpointToFile(file) -} - func (r *replay) fetchCurrentCheckpoint() replayCheckpoint { return replayCheckpoint{ CurCmdTs: r.replayStats.CurCmdTs.Load(),