diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 540815919fa9..8685eb43f2da 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -246,6 +246,12 @@ type TxnCtxNoNeedToRestore struct { // fair locking mode, and it takes effect (which is determined according to whether lock-with-conflict // has occurred during execution of any statement). FairLockingEffective bool + + // CurrentStmtPessimisticLockCache is the cache for pessimistic locked keys in the current statement. + // It is merged into `pessimisticLockCache` after a statement finishes. + // Read results cannot be directly written into pessimisticLockCache because failed statement need to rollback + // its pessimistic locks. + CurrentStmtPessimisticLockCache map[string][]byte } // SavepointRecord indicates a transaction's savepoint record. @@ -317,22 +323,32 @@ func (tc *TransactionContext) UpdateDeltaForTable(physicalTableID int64, delta i // GetKeyInPessimisticLockCache gets a key in pessimistic lock cache. func (tc *TransactionContext) GetKeyInPessimisticLockCache(key kv.Key) (val []byte, ok bool) { - if tc.pessimisticLockCache == nil { + if tc.pessimisticLockCache == nil && tc.CurrentStmtPessimisticLockCache == nil { return nil, false } - val, ok = tc.pessimisticLockCache[string(key)] - if ok { - tc.PessimisticCacheHit++ + if tc.CurrentStmtPessimisticLockCache != nil { + val, ok = tc.CurrentStmtPessimisticLockCache[string(key)] + if ok { + tc.PessimisticCacheHit++ + return + } + } + if tc.pessimisticLockCache != nil { + val, ok = tc.pessimisticLockCache[string(key)] + if ok { + tc.PessimisticCacheHit++ + } } return } -// SetPessimisticLockCache sets a key value pair into pessimistic lock cache. +// SetPessimisticLockCache sets a key value pair in pessimistic lock cache. +// The value is buffered in the statement cache until the current statement finishes. func (tc *TransactionContext) SetPessimisticLockCache(key kv.Key, val []byte) { - if tc.pessimisticLockCache == nil { - tc.pessimisticLockCache = map[string][]byte{} + if tc.CurrentStmtPessimisticLockCache == nil { + tc.CurrentStmtPessimisticLockCache = make(map[string][]byte) } - tc.pessimisticLockCache[string(key)] = val + tc.CurrentStmtPessimisticLockCache[string(key)] = val } // Cleanup clears up transaction info that no longer use. @@ -345,6 +361,7 @@ func (tc *TransactionContext) Cleanup() { tc.relatedTableForMDL = nil tc.tdmLock.Unlock() tc.pessimisticLockCache = nil + tc.CurrentStmtPessimisticLockCache = nil tc.IsStaleness = false tc.Savepoints = nil tc.EnableMDL = false @@ -380,6 +397,8 @@ func (tc *TransactionContext) GetCurrentSavepoint() TxnCtxNeedToRestore { } pessimisticLockCache := make(map[string][]byte, len(tc.pessimisticLockCache)) maps.Copy(pessimisticLockCache, tc.pessimisticLockCache) + CurrentStmtPessimisticLockCache := make(map[string][]byte, len(tc.CurrentStmtPessimisticLockCache)) + maps.Copy(CurrentStmtPessimisticLockCache, tc.CurrentStmtPessimisticLockCache) cachedTables := make(map[int64]interface{}, len(tc.CachedTables)) maps.Copy(cachedTables, tc.CachedTables) return TxnCtxNeedToRestore{ @@ -448,6 +467,21 @@ func (tc *TransactionContext) RollbackToSavepoint(name string) *SavepointRecord return nil } +// FlushStmtPessimisticLockCache merges the current statement pessimistic lock cache into transaction pessimistic lock +// cache. The caller may need to clear the stmt cache itself. +func (tc *TransactionContext) FlushStmtPessimisticLockCache() { + if tc.CurrentStmtPessimisticLockCache == nil { + return + } + if tc.pessimisticLockCache == nil { + tc.pessimisticLockCache = make(map[string][]byte) + } + for key, val := range tc.CurrentStmtPessimisticLockCache { + tc.pessimisticLockCache[key] = val + } + tc.CurrentStmtPessimisticLockCache = nil +} + // WriteStmtBufs can be used by insert/replace/delete/update statement. // TODO: use a common memory pool to replace this. type WriteStmtBufs struct { diff --git a/sessionctx/variable/session_test.go b/sessionctx/variable/session_test.go index d43bfd274ded..39d647b942bf 100644 --- a/sessionctx/variable/session_test.go +++ b/sessionctx/variable/session_test.go @@ -353,6 +353,7 @@ func TestTransactionContextSavepoint(t *testing.T) { }, } tc.SetPessimisticLockCache([]byte{'a'}, []byte{'a'}) + tc.FlushStmtPessimisticLockCache() tc.AddSavepoint("S1", nil) require.Equal(t, 1, len(tc.Savepoints)) @@ -372,6 +373,7 @@ func TestTransactionContextSavepoint(t *testing.T) { TableID: 9, } tc.SetPessimisticLockCache([]byte{'b'}, []byte{'b'}) + tc.FlushStmtPessimisticLockCache() tc.AddSavepoint("S2", nil) require.Equal(t, 2, len(tc.Savepoints)) @@ -389,6 +391,7 @@ func TestTransactionContextSavepoint(t *testing.T) { TableID: 13, } tc.SetPessimisticLockCache([]byte{'c'}, []byte{'c'}) + tc.FlushStmtPessimisticLockCache() tc.AddSavepoint("s2", nil) require.Equal(t, 2, len(tc.Savepoints)) diff --git a/sessiontxn/isolation/BUILD.bazel b/sessiontxn/isolation/BUILD.bazel index dc4f2673578c..491318a67d72 100644 --- a/sessiontxn/isolation/BUILD.bazel +++ b/sessiontxn/isolation/BUILD.bazel @@ -47,7 +47,7 @@ go_test( "serializable_test.go", ], flaky = True, - shard_count = 27, + shard_count = 29, deps = [ ":isolation", "//config", diff --git a/sessiontxn/isolation/base.go b/sessiontxn/isolation/base.go index 975b3e7266be..0c0c1be98aca 100644 --- a/sessiontxn/isolation/base.go +++ b/sessiontxn/isolation/base.go @@ -222,6 +222,7 @@ func (p *baseTxnContextProvider) OnPessimisticStmtEnd(_ context.Context, _ bool) // OnStmtRetry is the hook that should be called when a statement is retried internally. func (p *baseTxnContextProvider) OnStmtRetry(ctx context.Context) error { p.ctx = ctx + p.sctx.GetSessionVars().TxnCtx.CurrentStmtPessimisticLockCache = nil return nil } @@ -550,6 +551,12 @@ func (p *basePessimisticTxnContextProvider) OnPessimisticStmtEnd(ctx context.Con } } } + + if isSuccessful { + p.sctx.GetSessionVars().TxnCtx.FlushStmtPessimisticLockCache() + } else { + p.sctx.GetSessionVars().TxnCtx.CurrentStmtPessimisticLockCache = nil + } return nil } diff --git a/sessiontxn/isolation/readcommitted_test.go b/sessiontxn/isolation/readcommitted_test.go index 0c727bef3f25..eb70c94a58a9 100644 --- a/sessiontxn/isolation/readcommitted_test.go +++ b/sessiontxn/isolation/readcommitted_test.go @@ -17,6 +17,7 @@ package isolation_test import ( "context" "fmt" + "sync" "testing" "time" @@ -562,3 +563,60 @@ func initializePessimisticRCProvider(t testing.TB, tk *testkit.TestKit) *isolati tk.MustExec("begin pessimistic") return assert.CheckAndGetProvider(t) } + +func TestFailedDMLConsistency1(t *testing.T) { + store := testkit.CreateMockStore(t) + + tk1 := testkit.NewTestKit(t, store) + tk1.MustExec("use test") + tk1.MustExec("CREATE TABLE t(id INT primary key, v int not null);") + tk1.MustExec("insert into t values (1, 1)") + tk1.MustExec("set @@tidb_txn_assertion_level = \"strict\";") + tk1.MustExec("set transaction isolation level read committed;") + tk1.MustExec("begin pessimistic") + tk1.MustExec("insert into t values (0, 0)") + tk2 := testkit.NewTestKit(t, store) + tk2.MustExec("use test") + tk2.MustExec("begin pessimistic") + tk1.Exec("update t set v = null where id in (1);") + tk2.MustExec("delete from t where id = 1;") + var wg sync.WaitGroup + wg.Add(1) + go func() { + println("@@ -- begin delete") + tk1.MustExec("delete from t where id in (1);") + println("@@ -- end delete") + wg.Done() + }() + time.Sleep(100 * time.Millisecond) + tk2.MustExec("commit") + wg.Wait() + tk1.MustExec("commit") +} + +func TestFailedDMLConsistency2(t *testing.T) { + store := testkit.CreateMockStore(t) + tk1 := testkit.NewTestKit(t, store) + tk1.MustExec("set @@tidb_txn_assertion_level=strict") + tk1.MustExec("use test") + tk1.MustExec("CREATE TABLE t(id INT primary key, v int not null, v2 int, index (id), unique index (v2));") + tk1.MustExec("INSERT INTO t VALUES (1, 1, 1);") + tk1.MustExec("set transaction isolation level read committed;") + tk1.MustExec("begin pessimistic") + tk1.MustExec("insert into t values (0, 0, 0)") + tk2 := testkit.NewTestKit(t, store) + tk2.MustExec("use test") + tk2.MustExec("begin pessimistic") + tk1.Exec("update t set v = null where id in (1);") + tk2.MustExec("update t set id = 10 where id = 1;") + var wg sync.WaitGroup + wg.Add(1) + go func() { + tk1.MustExec("delete from t where id in (1, 2);") + wg.Done() + }() + tk2.MustExec("commit") + wg.Wait() + tk1.MustExec("commit") + tk1.MustExec("admin check table t") +}