Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 11 additions & 17 deletions pkg/sqlreplay/conn/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,9 @@ func (c *conn) Run(ctx context.Context) {
if command == nil {
break
}
if c.readonly {
if !c.isReadOnly(command.Value) {
c.replayStats.FilteredCmds.Add(1)
continue
}
if c.readonly && !c.isReadOnly(command.Value) {
c.replayStats.FilteredCmds.Add(1)
continue
}
// Quit the connection in the next round no matter what exception happens (like disconnection).
if command.Value.Type == pnet.ComQuit {
Expand Down Expand Up @@ -234,16 +232,14 @@ func (c *conn) Run(ctx context.Context) {

func (c *conn) isReadOnly(command *cmd.Command) bool {
switch command.Type {
case pnet.ComQuery:
case pnet.ComQuery, pnet.ComStmtPrepare:
// If the statement is not readonly, it won't be prepared.
return lex.IsReadOnly(hack.String(command.Payload[1:]))
case pnet.ComStmtExecute, pnet.ComStmtSendLongData, pnet.ComStmtReset, pnet.ComStmtFetch:
stmtID := binary.LittleEndian.Uint32(command.Payload[1:5])
ps := c.preparedStmts[stmtID]
if len(ps.text) == 0 {
// Maybe the connection is reconnected after disconnection and the prepared statements are lost.
return false
}
return lex.IsReadOnly(ps.text)
case pnet.ComStmtExecute, pnet.ComStmtSendLongData, pnet.ComStmtReset, pnet.ComStmtFetch, pnet.ComStmtClose:
// If the statement is prepared successfully, then it's readonly.
captureStmtID := binary.LittleEndian.Uint32(command.Payload[1:5])
_, ok := c.psIDMapping[captureStmtID]
return ok
case pnet.ComCreateDB, pnet.ComDropDB, pnet.ComDelayedInsert:
return false
}
Expand All @@ -254,9 +250,7 @@ func (c *conn) isReadOnly(command *cmd.Command) bool {
return true
}

// maintain prepared statement info so that we can find its info when:
// - Judge whether an EXECUTE command is readonly
// - Get the error message when an EXECUTE command fails
// Maintain prepared statement info so that we can find its info when getting the failed statement and params.
func (c *conn) updatePreparedStmts(capturedPsID uint32, request []byte, resp ExecuteResp) {
switch request[0] {
case pnet.ComStmtPrepare.Byte():
Expand Down
39 changes: 27 additions & 12 deletions pkg/sqlreplay/conn/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,27 +223,31 @@ func TestSkipReadOnly(t *testing.T) {
readonly: false,
},
{
cmd: &cmd.Command{Type: pnet.ComStmtPrepare, Payload: append([]byte{pnet.ComStmtPrepare.Byte()}, []byte("select ?")...)},
cmd: &cmd.Command{Type: pnet.ComStmtPrepare, CapturedPsID: 1, Payload: append([]byte{pnet.ComStmtPrepare.Byte()}, []byte("select ?")...)},
readonly: true,
},
{
cmd: &cmd.Command{Type: pnet.ComStmtExecute, Payload: []byte{pnet.ComStmtExecute.Byte(), 1, 0, 0, 0, 0, 0, 0, 0}},
cmd: &cmd.Command{Type: pnet.ComStmtExecute, CapturedPsID: 1, Payload: []byte{pnet.ComStmtExecute.Byte(), 1, 0, 0, 0, 0, 0, 0, 0}},
readonly: true,
},
{
cmd: &cmd.Command{Type: pnet.ComStmtFetch, Payload: []byte{pnet.ComStmtFetch.Byte(), 1, 0, 0, 0}},
cmd: &cmd.Command{Type: pnet.ComStmtFetch, CapturedPsID: 1, Payload: []byte{pnet.ComStmtFetch.Byte(), 1, 0, 0, 0}},
readonly: true,
},
{
cmd: &cmd.Command{Type: pnet.ComStmtPrepare, Payload: append([]byte{pnet.ComStmtPrepare.Byte()}, []byte("insert into t value(?)")...)},
readonly: true,
cmd: &cmd.Command{Type: pnet.ComStmtPrepare, CapturedPsID: 2, Payload: append([]byte{pnet.ComStmtPrepare.Byte()}, []byte("insert into t value(?)")...)},
readonly: false,
},
{
cmd: &cmd.Command{Type: pnet.ComStmtExecute, CapturedPsID: 2, Payload: []byte{pnet.ComStmtExecute.Byte(), 2, 0, 0, 0}},
readonly: false,
},
{
cmd: &cmd.Command{Type: pnet.ComStmtExecute, Payload: []byte{pnet.ComStmtExecute.Byte(), 2, 0, 0, 0}},
cmd: &cmd.Command{Type: pnet.ComStmtSendLongData, CapturedPsID: 2, Payload: []byte{pnet.ComStmtFetch.Byte(), 2, 0, 0, 0, 0, 0, 0, 0}},
readonly: false,
},
{
cmd: &cmd.Command{Type: pnet.ComStmtSendLongData, Payload: []byte{pnet.ComStmtFetch.Byte(), 2, 0, 0, 0, 0, 0, 0, 0}},
cmd: &cmd.Command{Type: pnet.ComStmtClose, CapturedPsID: 2, Payload: []byte{pnet.ComStmtClose.Byte(), 2, 0, 0, 0}},
readonly: false,
},
{
Expand Down Expand Up @@ -315,7 +319,7 @@ func TestReadOnly(t *testing.T) {
{
cmd: pnet.ComStmtPrepare,
stmt: "insert into t value(?)",
readOnly: true,
readOnly: false,
},
{
cmd: pnet.ComStmtExecute,
Expand All @@ -327,10 +331,15 @@ func TestReadOnly(t *testing.T) {
stmt: "insert into t value(?)",
readOnly: false,
},
{
cmd: pnet.ComStmtExecute,
stmt: "",
readOnly: false,
},
{
cmd: pnet.ComStmtClose,
stmt: "insert into t value(?)",
readOnly: true,
readOnly: false,
},
{
cmd: pnet.ComQuit,
Expand All @@ -346,13 +355,19 @@ func TestReadOnly(t *testing.T) {
backendConn := newMockBackendConn()
conn.backendConn = backendConn
for i, test := range tests {
clear(conn.psIDMapping)
var payload []byte
switch test.cmd {
case pnet.ComQuery:
case pnet.ComQuery, pnet.ComStmtPrepare:
payload = append([]byte{test.cmd.Byte()}, []byte(test.stmt)...)
default:
conn.preparedStmts[1] = preparedStmt{text: test.stmt}
case pnet.ComStmtExecute, pnet.ComStmtClose, pnet.ComStmtFetch, pnet.ComStmtReset, pnet.ComStmtSendLongData:
prepare := cmd.NewCommand(append([]byte{pnet.ComStmtPrepare.Byte()}, []byte(test.stmt)...), time.Time{}, 100)
if conn.isReadOnly(prepare) {
conn.psIDMapping[1] = 1
}
payload = []byte{test.cmd.Byte(), 1, 0, 0, 0}
default:
payload = []byte{test.cmd.Byte()}
}
command := cmd.NewCommand(payload, time.Time{}, 100)
require.Equal(t, test.readOnly, conn.isReadOnly(command), "case %d", i)
Expand Down
32 changes: 25 additions & 7 deletions pkg/sqlreplay/replay/dry_run.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/pingcap/tiproxy/pkg/sqlreplay/cmd"
"github.com/pingcap/tiproxy/pkg/sqlreplay/conn"
"github.com/pingcap/tiproxy/pkg/sqlreplay/report"
"github.com/pingcap/tiproxy/pkg/util/waitgroup"
)

type nopConn struct {
Expand All @@ -32,24 +33,41 @@ func (c *nopConn) Stop() {
c.closeCh <- c.connID
}

var _ report.Report = (*mockReport)(nil)
var _ report.Report = (*nopReport)(nil)

type mockReport struct {
type nopReport struct {
exceptionCh chan conn.Exception
wg waitgroup.WaitGroup
cancel context.CancelFunc
}

func newMockReport(exceptionCh chan conn.Exception) *mockReport {
return &mockReport{
func newMockReport(exceptionCh chan conn.Exception) *nopReport {
return &nopReport{
exceptionCh: exceptionCh,
}
}

func (mr *mockReport) Start(ctx context.Context, cfg report.ReportConfig) error {
func (mr *nopReport) Start(ctx context.Context, cfg report.ReportConfig) error {
childCtx, cancel := context.WithCancel(ctx)
mr.cancel = cancel
mr.wg.RunWithRecover(func() { mr.loop(childCtx) }, nil, nil)
return nil
}

func (mr *mockReport) Stop(err error) {
func (mr *nopReport) loop(ctx context.Context) {
for ctx.Err() == nil {
select {
case <-ctx.Done():
return
case <-mr.exceptionCh:
}
}
}

func (mr *mockReport) Close() {
func (mr *nopReport) Close() {
if mr.cancel != nil {
mr.cancel()
mr.cancel = nil
}
mr.wg.Wait()
}
48 changes: 15 additions & 33 deletions pkg/sqlreplay/replay/replay.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,8 @@ func (r *replay) Start(cfg ReplayConfig, backendTLSConfig *tls.Config, hsHandler
}
r.report = cfg.report
if r.report == nil {
if cfg.DryRun {
r.report = &mockReport{exceptionCh: r.exceptionCh}
if cfg.DryRun || cfg.ReadOnly {
r.report = &nopReport{exceptionCh: r.exceptionCh}
} else {
backendConnCreator := func() conn.BackendConn {
return conn.NewBackendConn(r.lg.Named("be"), r.idMgr.NewID(), hsHandler, bcConfig, backendTLSConfig, r.cfg.Username, r.cfg.Password)
Expand Down Expand Up @@ -513,8 +513,9 @@ func (r *replay) readCommands(ctx context.Context) {
zap.Int("alive_conns", connCount),
zap.Time("last_cmd_start_ts", time.Unix(0, r.replayStats.CurCmdTs.Load())),
zap.Time("last_cmd_end_ts", time.Unix(0, r.replayStats.CurCmdEndTs.Load())),
zap.NamedError("ctx_err", ctx.Err()),
zap.Bool("graceful_stop", r.gracefulStop.Load()))
zap.Bool("graceful_stop", r.gracefulStop.Load()),
zap.Error(err),
zap.NamedError("ctx_err", ctx.Err()))

// Notify the connections that the commands are finished.
for _, conn := range conns {
Expand Down Expand Up @@ -816,24 +817,22 @@ func (r *replay) saveCheckpointLoop(ctx context.Context) {
}
defer file.Close()

for {
// Add an interval here to avoid printing too many logs when error occurs.
if err != nil {
time.Sleep(stateSaveRetryInterval)
}

for ctx.Err() == nil {
select {
case <-ctx.Done():
return
break
case <-ticker.C:
err = r.saveCheckpointToFile(file)
if err != nil {
r.lg.Error("save current checkpoint failed", zap.Error(err))
// Add an interval here to avoid printing too many logs when error occurs.
time.Sleep(stateSaveRetryInterval)
continue
}
}
}
if err = r.saveCheckpointToFile(file); err != nil {
r.lg.Error("save current state failed on close", zap.Error(err))
}
}

func (r *replay) saveCheckpointToFile(file *os.File) error {
Expand All @@ -855,16 +854,6 @@ func (r *replay) saveCheckpointToFile(file *os.File) error {
return nil
}

func (r *replay) saveCurrentStateToFilePath(filePath string) error {
file, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE, 0644)
if err != nil {
return errors.Wrapf(err, "open state file %s", filePath)
}
defer file.Close()

return r.saveCheckpointToFile(file)
}

func (r *replay) fetchCurrentCheckpoint() replayCheckpoint {
return replayCheckpoint{
CurCmdTs: r.replayStats.CurCmdTs.Load(),
Expand Down Expand Up @@ -923,6 +912,10 @@ func (r *replay) stop(err error) {
r.cancel = nil
}
close(r.execInfoCh)
if r.report != nil {
r.report.Close()
r.report = nil
}
r.endTime = time.Now()
// decodedCmds - pendingCmds may be greater than replayedCmds because if a connection is closed unexpectedly,
// the pending commands of that connection are discarded. We calculate the progress based on decodedCmds - pendingCmds.
Expand Down Expand Up @@ -1006,17 +999,6 @@ func (r *replay) Stop(err error, graceful bool) {

func (r *replay) Close() {
r.Stop(errors.New("shutting down"), false)
if r.report != nil {
r.report.Close()
}
// at this time, the save checkpoint loop and replay loop have exited. It's safe to update the latest
// checkpoint file.
if len(r.cfg.CheckPointFilePath) > 0 {
err := r.saveCurrentStateToFilePath(r.cfg.CheckPointFilePath)
if err != nil {
r.lg.Error("save current state failed on close", zap.Error(err))
}
}
}

func getDirForInput(input string) (string, error) {
Expand Down
67 changes: 50 additions & 17 deletions pkg/util/lex/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@

package lex

import (
"strings"

"github.com/pingcap/tidb/pkg/parser"
)

func startsWithKeyword(sql string, keywords [][]string) bool {
lexer := NewLexer(sql)
tokens := make([]string, 0, 2)
Expand Down Expand Up @@ -47,24 +53,51 @@ func IsSensitiveSQL(sql string) bool {
// include SELECT FOR UPDATE because it doesn't require write privilege
// include SET because SET SESSION_STATES and SET session variables should be executed
// include BEGIN / COMMIT in case the user sets autocommit to false, either in SET SESSION_STATES or SET @@autocommit
var readOnlyKeywords = [][]string{
{"SELECT"},
{"SHOW"},
{"WITH"},
{"SET"},
{"USE"},
{"DESC"},
{"DESCRIBE"},
{"TABLE"},
{"DO"},
{"BEGIN"},
{"COMMIT"},
{"ROLLBACK"},
{"START", "TRANSACTION"},
}

func IsReadOnly(sql string) bool {
return startsWithKeyword(sql, readOnlyKeywords)
lexer := NewLexer(sql)
switch lexer.NextToken() {
case "SELECT":
for {
token := lexer.NextToken()
if token == "" {
break
}
if token == "FOR" && lexer.NextToken() == "UPDATE" {
return false
}
}
return true
case "SHOW", "WITH", "USE", "DESC", "DESCRIBE", "TABLE", "DO", "BEGIN", "COMMIT", "ROLLBACK":
return true
case "START":
return lexer.NextToken() == "TRANSACTION"
case "SET":
// Filter `set global`, `set @@global.`, `set password`, and other unknown statements.
normalized := parser.Normalize(sql, "ON")
switch {
case strings.HasPrefix(normalized, "set session_states "):
return true
case strings.HasPrefix(normalized, "set session "):
return true
case strings.HasPrefix(normalized, "set names "):
return true
case strings.HasPrefix(normalized, "set char "):
return true
case strings.HasPrefix(normalized, "set charset "):
return true
case strings.HasPrefix(normalized, "set character "):
return true
case strings.HasPrefix(normalized, "set transaction "):
return true
case strings.HasPrefix(normalized, "set @@global."):
return false
case strings.HasPrefix(normalized, "set @"):
return true
}
return false

}
return false
}

var startTxnKeywords = [][]string{
Expand Down
10 changes: 9 additions & 1 deletion pkg/util/lex/filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,17 @@ func TestReadOnlySQL(t *testing.T) {
{`SELECT ? FROM table_name`, true},
{`(select * from t1) union (select * from t2)`, true},
{`WITH cte AS (SELECT 1, 2) SELECT * FROM cte t1, cte t2`, true},
{`SELECT ? FROM table_name for update`, false},
{`SELECT "for update"`, true},
{`SET session_States ''`, true},
{`SET @@session_variable=true`, true},
{`set GLOBAL variable=false`, true},
{`SET @@global.variable=true`, false},
{`set GLOBAL variable=false`, false},
{`set password = 'hello'`, false},
{`set NAMES utf8`, true},
{`set character utf8`, true},
{`set transaction isolation_level = 'read committed`, true},
{`SET @variable=true`, true},
{`insert into table t value(1)`, false},
{`desc table t`, true},
{`describe select * from t`, true},
Expand Down
Loading