From 6e148386a01492cd1df8e7db875428c6e12875c7 Mon Sep 17 00:00:00 2001 From: zyguan Date: Fri, 22 Mar 2024 15:07:37 +0800 Subject: [PATCH] tikvrpc: avoid data race on `XxxRequest.Context` (#1242) * tikvrpc: avoid data race on `XxxRequest.Context` Signed-off-by: zyguan * fix grammar of codegen comment Signed-off-by: zyguan * address comments Signed-off-by: zyguan * check diff of go generate Signed-off-by: zyguan * fix a typo Signed-off-by: zyguan --------- Signed-off-by: zyguan --- .github/workflows/test.yml | 5 + tikvrpc/cmds_generated.go | 366 +++++++++++++++++++++++++++++++++++++ tikvrpc/gen.sh | 97 ++++++++++ tikvrpc/tikvrpc.go | 113 +++--------- tikvrpc/tikvrpc_test.go | 88 +++++++++ 5 files changed, 586 insertions(+), 83 deletions(-) create mode 100644 tikvrpc/cmds_generated.go create mode 100755 tikvrpc/gen.sh diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 50c490973..b37c111f5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -44,6 +44,11 @@ jobs: with: go-version: 1.21.0 + - name: Go generate and check diff + run: | + go generate ./... + git diff --exit-code + - name: Lint uses: golangci/golangci-lint-action@v3 with: diff --git a/tikvrpc/cmds_generated.go b/tikvrpc/cmds_generated.go new file mode 100644 index 000000000..a8923b5e4 --- /dev/null +++ b/tikvrpc/cmds_generated.go @@ -0,0 +1,366 @@ +// Code generated by gen.sh. DO NOT EDIT. + +package tikvrpc + +import ( + "github.com/pingcap/kvproto/pkg/kvrpcpb" +) + +func patchCmdCtx(req *Request, cmd CmdType, ctx *kvrpcpb.Context) bool { + switch cmd { + case CmdGet: + if req.rev == 0 { + req.Get().Context = ctx + } else { + cmd := *req.Get() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdScan: + if req.rev == 0 { + req.Scan().Context = ctx + } else { + cmd := *req.Scan() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdPrewrite: + if req.rev == 0 { + req.Prewrite().Context = ctx + } else { + cmd := *req.Prewrite() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdPessimisticLock: + if req.rev == 0 { + req.PessimisticLock().Context = ctx + } else { + cmd := *req.PessimisticLock() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdPessimisticRollback: + if req.rev == 0 { + req.PessimisticRollback().Context = ctx + } else { + cmd := *req.PessimisticRollback() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdCommit: + if req.rev == 0 { + req.Commit().Context = ctx + } else { + cmd := *req.Commit() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdCleanup: + if req.rev == 0 { + req.Cleanup().Context = ctx + } else { + cmd := *req.Cleanup() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdBatchGet: + if req.rev == 0 { + req.BatchGet().Context = ctx + } else { + cmd := *req.BatchGet() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdBatchRollback: + if req.rev == 0 { + req.BatchRollback().Context = ctx + } else { + cmd := *req.BatchRollback() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdScanLock: + if req.rev == 0 { + req.ScanLock().Context = ctx + } else { + cmd := *req.ScanLock() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdResolveLock: + if req.rev == 0 { + req.ResolveLock().Context = ctx + } else { + cmd := *req.ResolveLock() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdGC: + if req.rev == 0 { + req.GC().Context = ctx + } else { + cmd := *req.GC() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdDeleteRange: + if req.rev == 0 { + req.DeleteRange().Context = ctx + } else { + cmd := *req.DeleteRange() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRawGet: + if req.rev == 0 { + req.RawGet().Context = ctx + } else { + cmd := *req.RawGet() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRawBatchGet: + if req.rev == 0 { + req.RawBatchGet().Context = ctx + } else { + cmd := *req.RawBatchGet() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRawPut: + if req.rev == 0 { + req.RawPut().Context = ctx + } else { + cmd := *req.RawPut() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRawBatchPut: + if req.rev == 0 { + req.RawBatchPut().Context = ctx + } else { + cmd := *req.RawBatchPut() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRawDelete: + if req.rev == 0 { + req.RawDelete().Context = ctx + } else { + cmd := *req.RawDelete() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRawBatchDelete: + if req.rev == 0 { + req.RawBatchDelete().Context = ctx + } else { + cmd := *req.RawBatchDelete() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRawDeleteRange: + if req.rev == 0 { + req.RawDeleteRange().Context = ctx + } else { + cmd := *req.RawDeleteRange() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRawScan: + if req.rev == 0 { + req.RawScan().Context = ctx + } else { + cmd := *req.RawScan() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRawGetKeyTTL: + if req.rev == 0 { + req.RawGetKeyTTL().Context = ctx + } else { + cmd := *req.RawGetKeyTTL() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRawCompareAndSwap: + if req.rev == 0 { + req.RawCompareAndSwap().Context = ctx + } else { + cmd := *req.RawCompareAndSwap() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRawChecksum: + if req.rev == 0 { + req.RawChecksum().Context = ctx + } else { + cmd := *req.RawChecksum() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdUnsafeDestroyRange: + if req.rev == 0 { + req.UnsafeDestroyRange().Context = ctx + } else { + cmd := *req.UnsafeDestroyRange() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRegisterLockObserver: + if req.rev == 0 { + req.RegisterLockObserver().Context = ctx + } else { + cmd := *req.RegisterLockObserver() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdCheckLockObserver: + if req.rev == 0 { + req.CheckLockObserver().Context = ctx + } else { + cmd := *req.CheckLockObserver() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdRemoveLockObserver: + if req.rev == 0 { + req.RemoveLockObserver().Context = ctx + } else { + cmd := *req.RemoveLockObserver() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdPhysicalScanLock: + if req.rev == 0 { + req.PhysicalScanLock().Context = ctx + } else { + cmd := *req.PhysicalScanLock() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdCop: + if req.rev == 0 { + req.Cop().Context = ctx + } else { + cmd := *req.Cop() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdBatchCop: + if req.rev == 0 { + req.BatchCop().Context = ctx + } else { + cmd := *req.BatchCop() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdMvccGetByKey: + if req.rev == 0 { + req.MvccGetByKey().Context = ctx + } else { + cmd := *req.MvccGetByKey() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdMvccGetByStartTs: + if req.rev == 0 { + req.MvccGetByStartTs().Context = ctx + } else { + cmd := *req.MvccGetByStartTs() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdSplitRegion: + if req.rev == 0 { + req.SplitRegion().Context = ctx + } else { + cmd := *req.SplitRegion() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdTxnHeartBeat: + if req.rev == 0 { + req.TxnHeartBeat().Context = ctx + } else { + cmd := *req.TxnHeartBeat() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdCheckTxnStatus: + if req.rev == 0 { + req.CheckTxnStatus().Context = ctx + } else { + cmd := *req.CheckTxnStatus() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdCheckSecondaryLocks: + if req.rev == 0 { + req.CheckSecondaryLocks().Context = ctx + } else { + cmd := *req.CheckSecondaryLocks() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdFlashbackToVersion: + if req.rev == 0 { + req.FlashbackToVersion().Context = ctx + } else { + cmd := *req.FlashbackToVersion() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + case CmdPrepareFlashbackToVersion: + if req.rev == 0 { + req.PrepareFlashbackToVersion().Context = ctx + } else { + cmd := *req.PrepareFlashbackToVersion() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ + default: + return false + } + return true +} diff --git a/tikvrpc/gen.sh b/tikvrpc/gen.sh new file mode 100755 index 000000000..4f414a35b --- /dev/null +++ b/tikvrpc/gen.sh @@ -0,0 +1,97 @@ +#!/bin/bash + +# Copyright 2024 TiKV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +output="cmds_generated.go" + +cat < $output +// Code generated by gen.sh. DO NOT EDIT. + +package tikvrpc + +import ( + "github.com/pingcap/kvproto/pkg/kvrpcpb" +) +EOF + +cmds=( + Get + Scan + Prewrite + PessimisticLock + PessimisticRollback + Commit + Cleanup + BatchGet + BatchRollback + ScanLock + ResolveLock + GC + DeleteRange + RawGet + RawBatchGet + RawPut + RawBatchPut + RawDelete + RawBatchDelete + RawDeleteRange + RawScan + RawGetKeyTTL + RawCompareAndSwap + RawChecksum + UnsafeDestroyRange + RegisterLockObserver + CheckLockObserver + RemoveLockObserver + PhysicalScanLock + Cop + BatchCop + MvccGetByKey + MvccGetByStartTs + SplitRegion + TxnHeartBeat + CheckTxnStatus + CheckSecondaryLocks + FlashbackToVersion + PrepareFlashbackToVersion +) + +cat <> $output + +func patchCmdCtx(req *Request, cmd CmdType, ctx *kvrpcpb.Context) bool { + switch cmd { +EOF + +for cmd in "${cmds[@]}"; do +cat <> $output + case Cmd${cmd}: + if req.rev == 0 { + req.${cmd}().Context = ctx + } else { + cmd := *req.${cmd}() + cmd.Context = ctx + req.Req = &cmd + } + req.rev++ +EOF +done + +cat <> $output + default: + return false + } + return true +} +EOF diff --git a/tikvrpc/tikvrpc.go b/tikvrpc/tikvrpc.go index 7b2a65de1..61512e001 100644 --- a/tikvrpc/tikvrpc.go +++ b/tikvrpc/tikvrpc.go @@ -83,7 +83,7 @@ const ( CmdRawBatchDelete CmdRawDeleteRange CmdRawScan - CmdGetKeyTTL + CmdRawGetKeyTTL CmdRawCompareAndSwap CmdRawChecksum @@ -116,6 +116,11 @@ const ( CmdEmpty CmdType = 3072 + iota ) +// CmdType aliases. +const ( + CmdGetKeyTTL = CmdRawGetKeyTTL +) + func (t CmdType) String() string { switch t { case CmdGet: @@ -162,6 +167,10 @@ func (t CmdType) String() string { return "RawScan" case CmdRawChecksum: return "RawChecksum" + case CmdRawGetKeyTTL: + return "RawGetKeyTTL" + case CmdRawCompareAndSwap: + return "RawCompareAndSwap" case CmdUnsafeDestroyRange: return "UnsafeDestroyRange" case CmdRegisterLockObserver: @@ -219,7 +228,11 @@ func (t CmdType) String() string { // Request wraps all kv/coprocessor requests. type Request struct { Type CmdType - Req interface{} + // Req is one of the request type defined in kvrpcpb. + // + // WARN: It may be read concurrently in batch-send-loop, so you should ONLY modify it via `AttachContext`, + // otherwise there could be a risk of data race. + Req interface{} kvrpcpb.Context ReadReplicaScope string // remove txnScope after tidb removed txnScope @@ -238,6 +251,9 @@ type Request struct { ReadType string // InputRequestSource is the input source of the request, if it's not empty, the final RequestSource sent to store will be attached with the retry info. InputRequestSource string + + // rev represents the revision of the request, it's increased when `Req.Context` gets patched. + rev uint32 } // NewRequest returns new kv rpc request. @@ -707,100 +723,31 @@ type MPPStreamResponse struct { Lease } +//go:generate bash gen.sh + // AttachContext sets the request context to the request, // return false if encounter unknown request type. // Parameter `rpcCtx` use `kvrpcpb.Context` instead of `*kvrpcpb.Context` to avoid concurrent modification by shallow copy. func AttachContext(req *Request, rpcCtx kvrpcpb.Context) bool { ctx := &rpcCtx + cmd := req.Type + // CmdCopStream and CmdCop share the same request type. + if cmd == CmdCopStream { + cmd = CmdCop + } + if patchCmdCtx(req, cmd, ctx) { + return true + } switch req.Type { - case CmdGet: - req.Get().Context = ctx - case CmdScan: - req.Scan().Context = ctx - case CmdPrewrite: - req.Prewrite().Context = ctx - case CmdPessimisticLock: - req.PessimisticLock().Context = ctx - case CmdPessimisticRollback: - req.PessimisticRollback().Context = ctx - case CmdCommit: - req.Commit().Context = ctx - case CmdCleanup: - req.Cleanup().Context = ctx - case CmdBatchGet: - req.BatchGet().Context = ctx - case CmdBatchRollback: - req.BatchRollback().Context = ctx - case CmdScanLock: - req.ScanLock().Context = ctx - case CmdResolveLock: - req.ResolveLock().Context = ctx - case CmdGC: - req.GC().Context = ctx - case CmdDeleteRange: - req.DeleteRange().Context = ctx - case CmdRawGet: - req.RawGet().Context = ctx - case CmdRawBatchGet: - req.RawBatchGet().Context = ctx - case CmdRawPut: - req.RawPut().Context = ctx - case CmdRawBatchPut: - req.RawBatchPut().Context = ctx - case CmdRawDelete: - req.RawDelete().Context = ctx - case CmdRawBatchDelete: - req.RawBatchDelete().Context = ctx - case CmdRawDeleteRange: - req.RawDeleteRange().Context = ctx - case CmdRawScan: - req.RawScan().Context = ctx - case CmdGetKeyTTL: - req.RawGetKeyTTL().Context = ctx - case CmdRawCompareAndSwap: - req.RawCompareAndSwap().Context = ctx - case CmdRawChecksum: - req.RawChecksum().Context = ctx - case CmdUnsafeDestroyRange: - req.UnsafeDestroyRange().Context = ctx - case CmdRegisterLockObserver: - req.RegisterLockObserver().Context = ctx - case CmdCheckLockObserver: - req.CheckLockObserver().Context = ctx - case CmdRemoveLockObserver: - req.RemoveLockObserver().Context = ctx - case CmdPhysicalScanLock: - req.PhysicalScanLock().Context = ctx - case CmdCop: - req.Cop().Context = ctx - case CmdCopStream: - req.Cop().Context = ctx - case CmdBatchCop: - req.BatchCop().Context = ctx // Dispatching MPP tasks don't need a region context, because it's a request for store but not region. case CmdMPPTask: case CmdMPPConn: case CmdMPPCancel: case CmdMPPAlive: - case CmdMvccGetByKey: - req.MvccGetByKey().Context = ctx - case CmdMvccGetByStartTs: - req.MvccGetByStartTs().Context = ctx - case CmdSplitRegion: - req.SplitRegion().Context = ctx + // Empty command doesn't need a region context. case CmdEmpty: - req.SplitRegion().Context = ctx - case CmdTxnHeartBeat: - req.TxnHeartBeat().Context = ctx - case CmdCheckTxnStatus: - req.CheckTxnStatus().Context = ctx - case CmdCheckSecondaryLocks: - req.CheckSecondaryLocks().Context = ctx - case CmdFlashbackToVersion: - req.FlashbackToVersion().Context = ctx - case CmdPrepareFlashbackToVersion: - req.PrepareFlashbackToVersion().Context = ctx + default: return false } diff --git a/tikvrpc/tikvrpc_test.go b/tikvrpc/tikvrpc_test.go index e3d5e25fb..5dbcd31dc 100644 --- a/tikvrpc/tikvrpc_test.go +++ b/tikvrpc/tikvrpc_test.go @@ -35,8 +35,14 @@ package tikvrpc import ( + "fmt" + "math/rand" + "sync" "testing" + "time" + "github.com/pingcap/kvproto/pkg/coprocessor" + "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/kvproto/pkg/tikvpb" "github.com/stretchr/testify/assert" ) @@ -47,3 +53,85 @@ func TestBatchResponse(t *testing.T) { assert.Nil(t, batchResp) assert.NotNil(t, err) } + +// https://github.com/pingcap/tidb/issues/51921 +func TestTiDB51921(t *testing.T) { + for _, r := range []*Request{ + NewRequest(CmdGet, &kvrpcpb.GetRequest{}), + NewRequest(CmdScan, &kvrpcpb.ScanRequest{}), + NewRequest(CmdPrewrite, &kvrpcpb.PrewriteRequest{}), + NewRequest(CmdPessimisticLock, &kvrpcpb.PessimisticLockRequest{}), + NewRequest(CmdPessimisticRollback, &kvrpcpb.PessimisticRollbackRequest{}), + NewRequest(CmdCommit, &kvrpcpb.CommitRequest{}), + NewRequest(CmdCleanup, &kvrpcpb.CleanupRequest{}), + NewRequest(CmdBatchGet, &kvrpcpb.BatchGetRequest{}), + NewRequest(CmdBatchRollback, &kvrpcpb.BatchRollbackRequest{}), + NewRequest(CmdScanLock, &kvrpcpb.ScanLockRequest{}), + NewRequest(CmdResolveLock, &kvrpcpb.ResolveLockRequest{}), + NewRequest(CmdGC, &kvrpcpb.GCRequest{}), + NewRequest(CmdDeleteRange, &kvrpcpb.DeleteRangeRequest{}), + NewRequest(CmdRawGet, &kvrpcpb.RawGetRequest{}), + NewRequest(CmdRawBatchGet, &kvrpcpb.RawBatchGetRequest{}), + NewRequest(CmdRawPut, &kvrpcpb.RawPutRequest{}), + NewRequest(CmdRawBatchPut, &kvrpcpb.RawBatchPutRequest{}), + NewRequest(CmdRawDelete, &kvrpcpb.RawDeleteRequest{}), + NewRequest(CmdRawBatchDelete, &kvrpcpb.RawBatchDeleteRequest{}), + NewRequest(CmdRawDeleteRange, &kvrpcpb.RawDeleteRangeRequest{}), + NewRequest(CmdRawScan, &kvrpcpb.RawScanRequest{}), + NewRequest(CmdRawGetKeyTTL, &kvrpcpb.RawGetKeyTTLRequest{}), + NewRequest(CmdRawCompareAndSwap, &kvrpcpb.RawCASRequest{}), + NewRequest(CmdRawChecksum, &kvrpcpb.RawChecksumRequest{}), + NewRequest(CmdUnsafeDestroyRange, &kvrpcpb.UnsafeDestroyRangeRequest{}), + NewRequest(CmdRegisterLockObserver, &kvrpcpb.RegisterLockObserverRequest{}), + NewRequest(CmdCheckLockObserver, &kvrpcpb.CheckLockObserverRequest{}), + NewRequest(CmdRemoveLockObserver, &kvrpcpb.RemoveLockObserverRequest{}), + NewRequest(CmdPhysicalScanLock, &kvrpcpb.PhysicalScanLockRequest{}), + NewRequest(CmdCop, &coprocessor.Request{}), + NewRequest(CmdCopStream, &coprocessor.Request{}), + NewRequest(CmdBatchCop, &coprocessor.BatchRequest{}), + NewRequest(CmdMvccGetByKey, &kvrpcpb.MvccGetByKeyRequest{}), + NewRequest(CmdMvccGetByStartTs, &kvrpcpb.MvccGetByStartTsRequest{}), + NewRequest(CmdSplitRegion, &kvrpcpb.SplitRegionRequest{}), + NewRequest(CmdTxnHeartBeat, &kvrpcpb.TxnHeartBeatRequest{}), + NewRequest(CmdCheckTxnStatus, &kvrpcpb.CheckTxnStatusRequest{}), + NewRequest(CmdCheckSecondaryLocks, &kvrpcpb.CheckSecondaryLocksRequest{}), + NewRequest(CmdFlashbackToVersion, &kvrpcpb.FlashbackToVersionRequest{}), + NewRequest(CmdPrepareFlashbackToVersion, &kvrpcpb.PrepareFlashbackToVersionRequest{}), + } { + req := r + t.Run(fmt.Sprintf("%s#%d", req.Type.String(), req.Type), func(t *testing.T) { + if req.ToBatchCommandsRequest() == nil { + t.Skipf("%s doesn't support batch commands", req.Type.String()) + } + done := make(chan struct{}) + cmds := make(chan *tikvpb.BatchCommandsRequest_Request, 8) + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + defer wg.Done() + for { + select { + case <-done: + close(cmds) + return + default: + // mock relocate and retry + AttachContext(req, kvrpcpb.Context{RegionId: rand.Uint64()}) + cmds <- req.ToBatchCommandsRequest() + } + } + }() + go func() { + defer wg.Done() + for cmd := range cmds { + // mock send and marshal in batch-send-loop + cmd.Marshal() + } + }() + + time.Sleep(time.Second / 4) + close(done) + wg.Wait() + }) + } +}