Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

executor,distsql: refactor the base executor in tableReader #51397

Merged
merged 2 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/ddl/index_cop.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ func fetchTableScanResult(
}
err = table.FillVirtualColumnValue(
copCtx.VirtualColumnsFieldTypes, copCtx.VirtualColumnsOutputOffsets,
copCtx.ExprColumnInfos, copCtx.ColumnInfos, copCtx.SessionContext, chk)
copCtx.ExprColumnInfos, copCtx.ColumnInfos, copCtx.SessionContext.GetExprCtx(), chk)
return false, err
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/distsql/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ go_library(
deps = [
"//pkg/config",
"//pkg/ddl/placement",
"//pkg/distsql/context",
"//pkg/errctx",
"//pkg/errno",
"//pkg/expression",
Expand All @@ -21,7 +22,6 @@ go_library(
"//pkg/parser/mysql",
"//pkg/parser/terror",
"//pkg/planner/util",
"//pkg/sessionctx",
"//pkg/sessionctx/stmtctx",
"//pkg/sessionctx/variable",
"//pkg/store/copr",
Expand Down
12 changes: 12 additions & 0 deletions pkg/distsql/context/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")

go_library(
name = "context",
srcs = ["context.go"],
importpath = "github.com/pingcap/tidb/pkg/distsql/context",
visibility = ["//visibility:public"],
deps = [
"//pkg/kv",
"//pkg/sessionctx/variable",
],
)
28 changes: 28 additions & 0 deletions pkg/distsql/context/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package context

import (
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
)

// DistSQLContext gives the interface
type DistSQLContext interface {
// GetSessionVars gets the session variables.
GetSessionVars() *variable.SessionVars
// GetClient gets a kv.Client.
GetClient() kv.Client
Copy link
Member Author

@YangKeao YangKeao Feb 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's possible to remove GetClient() from the session context. It's only used in executor and distsql, but I'm not sure whether it's appropriate (or whether it's helpful). Therefore, I kept it, and the distsql context doesn't need *Extend.

}
60 changes: 30 additions & 30 deletions pkg/distsql/distsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ import (

"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/config"
distsqlctx "github.com/pingcap/tidb/pkg/distsql/context"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/metrics"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/types"
Expand All @@ -37,14 +37,14 @@ import (
)

// GenSelectResultFromMPPResponse generates an iterator from response.
func GenSelectResultFromMPPResponse(sctx sessionctx.Context, fieldTypes []*types.FieldType, planIDs []int, rootID int, resp kv.Response) SelectResult {
func GenSelectResultFromMPPResponse(dctx distsqlctx.DistSQLContext, fieldTypes []*types.FieldType, planIDs []int, rootID int, resp kv.Response) SelectResult {
// TODO: Add metric label and set open tracing.
return &selectResult{
label: "mpp",
resp: resp,
rowLen: len(fieldTypes),
fieldTypes: fieldTypes,
ctx: sctx,
ctx: dctx,
copPlanIDs: planIDs,
rootPlanID: rootID,
storeType: kv.TiFlash,
Expand All @@ -53,7 +53,7 @@ func GenSelectResultFromMPPResponse(sctx sessionctx.Context, fieldTypes []*types

// Select sends a DAG request, returns SelectResult.
// In kvReq, KeyRanges is required, Concurrency/KeepOrder/Desc/IsolationLevel/Priority are optional.
func Select(ctx context.Context, sctx sessionctx.Context, kvReq *kv.Request, fieldTypes []*types.FieldType) (SelectResult, error) {
func Select(ctx context.Context, dctx distsqlctx.DistSQLContext, kvReq *kv.Request, fieldTypes []*types.FieldType) (SelectResult, error) {
r, ctx := tracing.StartRegionEx(ctx, "distsql.Select")
defer r.End()

Expand All @@ -62,8 +62,8 @@ func Select(ctx context.Context, sctx sessionctx.Context, kvReq *kv.Request, fie
hook.(func(*kv.Request))(kvReq)
}

enabledRateLimitAction := sctx.GetSessionVars().EnabledRateLimitAction
originalSQL := sctx.GetSessionVars().StmtCtx.OriginalSQL
enabledRateLimitAction := dctx.GetSessionVars().EnabledRateLimitAction
originalSQL := dctx.GetSessionVars().StmtCtx.OriginalSQL
eventCb := func(event trxevents.TransactionEvent) {
// Note: Do not assume this callback will be invoked within the same goroutine.
if copMeetLock := event.GetCopMeetLock(); copMeetLock != nil {
Expand All @@ -74,27 +74,27 @@ func Select(ctx context.Context, sctx sessionctx.Context, kvReq *kv.Request, fie
}
}

ctx = WithSQLKvExecCounterInterceptor(ctx, sctx.GetSessionVars().StmtCtx)
ctx = WithSQLKvExecCounterInterceptor(ctx, dctx.GetSessionVars().StmtCtx)
option := &kv.ClientSendOption{
SessionMemTracker: sctx.GetSessionVars().MemTracker,
SessionMemTracker: dctx.GetSessionVars().MemTracker,
EnabledRateLimitAction: enabledRateLimitAction,
EventCb: eventCb,
EnableCollectExecutionInfo: config.GetGlobalConfig().Instance.EnableCollectExecutionInfo.Load(),
}

if kvReq.StoreType == kv.TiFlash {
ctx = SetTiFlashConfVarsInContext(ctx, sctx)
option.TiFlashReplicaRead = sctx.GetSessionVars().TiFlashReplicaRead
option.AppendWarning = sctx.GetSessionVars().StmtCtx.AppendWarning
ctx = SetTiFlashConfVarsInContext(ctx, dctx.GetSessionVars())
option.TiFlashReplicaRead = dctx.GetSessionVars().TiFlashReplicaRead
option.AppendWarning = dctx.GetSessionVars().StmtCtx.AppendWarning
}

resp := sctx.GetClient().Send(ctx, kvReq, sctx.GetSessionVars().KVVars, option)
resp := dctx.GetClient().Send(ctx, kvReq, dctx.GetSessionVars().KVVars, option)
if resp == nil {
return nil, errors.New("client returns nil response")
}

label := metrics.LblGeneral
if sctx.GetSessionVars().InRestrictedSQL {
if dctx.GetSessionVars().InRestrictedSQL {
label = metrics.LblInternal
}

Expand All @@ -106,7 +106,7 @@ func Select(ctx context.Context, sctx sessionctx.Context, kvReq *kv.Request, fie
resp: resp,
rowLen: len(fieldTypes),
fieldTypes: fieldTypes,
ctx: sctx,
ctx: dctx,
sqlType: label,
memTracker: kvReq.MemTracker,
storeType: kvReq.StoreType,
Expand All @@ -116,34 +116,34 @@ func Select(ctx context.Context, sctx sessionctx.Context, kvReq *kv.Request, fie
}

// SetTiFlashConfVarsInContext set some TiFlash config variables in context.
func SetTiFlashConfVarsInContext(ctx context.Context, sctx sessionctx.Context) context.Context {
if sctx.GetSessionVars().TiFlashMaxThreads != -1 {
ctx = metadata.AppendToOutgoingContext(ctx, variable.TiDBMaxTiFlashThreads, strconv.FormatInt(sctx.GetSessionVars().TiFlashMaxThreads, 10))
func SetTiFlashConfVarsInContext(ctx context.Context, vars *variable.SessionVars) context.Context {
if vars.TiFlashMaxThreads != -1 {
ctx = metadata.AppendToOutgoingContext(ctx, variable.TiDBMaxTiFlashThreads, strconv.FormatInt(vars.TiFlashMaxThreads, 10))
}
if sctx.GetSessionVars().TiFlashMaxBytesBeforeExternalJoin != -1 {
ctx = metadata.AppendToOutgoingContext(ctx, variable.TiDBMaxBytesBeforeTiFlashExternalJoin, strconv.FormatInt(sctx.GetSessionVars().TiFlashMaxBytesBeforeExternalJoin, 10))
if vars.TiFlashMaxBytesBeforeExternalJoin != -1 {
ctx = metadata.AppendToOutgoingContext(ctx, variable.TiDBMaxBytesBeforeTiFlashExternalJoin, strconv.FormatInt(vars.TiFlashMaxBytesBeforeExternalJoin, 10))
}
if sctx.GetSessionVars().TiFlashMaxBytesBeforeExternalGroupBy != -1 {
ctx = metadata.AppendToOutgoingContext(ctx, variable.TiDBMaxBytesBeforeTiFlashExternalGroupBy, strconv.FormatInt(sctx.GetSessionVars().TiFlashMaxBytesBeforeExternalGroupBy, 10))
if vars.TiFlashMaxBytesBeforeExternalGroupBy != -1 {
ctx = metadata.AppendToOutgoingContext(ctx, variable.TiDBMaxBytesBeforeTiFlashExternalGroupBy, strconv.FormatInt(vars.TiFlashMaxBytesBeforeExternalGroupBy, 10))
}
if sctx.GetSessionVars().TiFlashMaxBytesBeforeExternalSort != -1 {
ctx = metadata.AppendToOutgoingContext(ctx, variable.TiDBMaxBytesBeforeTiFlashExternalSort, strconv.FormatInt(sctx.GetSessionVars().TiFlashMaxBytesBeforeExternalSort, 10))
if vars.TiFlashMaxBytesBeforeExternalSort != -1 {
ctx = metadata.AppendToOutgoingContext(ctx, variable.TiDBMaxBytesBeforeTiFlashExternalSort, strconv.FormatInt(vars.TiFlashMaxBytesBeforeExternalSort, 10))
}
if sctx.GetSessionVars().TiFlashMaxQueryMemoryPerNode <= 0 {
if vars.TiFlashMaxQueryMemoryPerNode <= 0 {
ctx = metadata.AppendToOutgoingContext(ctx, variable.TiFlashMemQuotaQueryPerNode, "0")
} else {
ctx = metadata.AppendToOutgoingContext(ctx, variable.TiFlashMemQuotaQueryPerNode, strconv.FormatInt(sctx.GetSessionVars().TiFlashMaxQueryMemoryPerNode, 10))
ctx = metadata.AppendToOutgoingContext(ctx, variable.TiFlashMemQuotaQueryPerNode, strconv.FormatInt(vars.TiFlashMaxQueryMemoryPerNode, 10))
}
ctx = metadata.AppendToOutgoingContext(ctx, variable.TiFlashQuerySpillRatio, strconv.FormatFloat(sctx.GetSessionVars().TiFlashQuerySpillRatio, 'f', -1, 64))
ctx = metadata.AppendToOutgoingContext(ctx, variable.TiFlashQuerySpillRatio, strconv.FormatFloat(vars.TiFlashQuerySpillRatio, 'f', -1, 64))
return ctx
}

// SelectWithRuntimeStats sends a DAG request, returns SelectResult.
// The difference from Select is that SelectWithRuntimeStats will set copPlanIDs into selectResult,
// which can help selectResult to collect runtime stats.
func SelectWithRuntimeStats(ctx context.Context, sctx sessionctx.Context, kvReq *kv.Request,
func SelectWithRuntimeStats(ctx context.Context, dctx distsqlctx.DistSQLContext, kvReq *kv.Request,
fieldTypes []*types.FieldType, copPlanIDs []int, rootPlanID int) (SelectResult, error) {
sr, err := Select(ctx, sctx, kvReq, fieldTypes)
sr, err := Select(ctx, dctx, kvReq, fieldTypes)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -198,7 +198,7 @@ func Checksum(ctx context.Context, client kv.Client, kvReq *kv.Request, vars any
// methods are:
// 1. TypeChunk: the result is encoded using the Chunk format, refer util/chunk/chunk.go
// 2. TypeDefault: the result is encoded row by row
func SetEncodeType(ctx sessionctx.Context, dagReq *tipb.DAGRequest) {
func SetEncodeType(ctx distsqlctx.DistSQLContext, dagReq *tipb.DAGRequest) {
if canUseChunkRPC(ctx) {
dagReq.EncodeType = tipb.EncodeType_TypeChunk
setChunkMemoryLayout(dagReq)
Expand All @@ -207,7 +207,7 @@ func SetEncodeType(ctx sessionctx.Context, dagReq *tipb.DAGRequest) {
}
}

func canUseChunkRPC(ctx sessionctx.Context) bool {
func canUseChunkRPC(ctx distsqlctx.DistSQLContext) bool {
if !ctx.GetSessionVars().EnableChunkRPC {
return false
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/distsql/select_result.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/pkg/config"
dcontext "github.com/pingcap/tidb/pkg/distsql/context"
"github.com/pingcap/tidb/pkg/errno"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/metrics"
"github.com/pingcap/tidb/pkg/parser/terror"
"github.com/pingcap/tidb/pkg/planner/util"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
"github.com/pingcap/tidb/pkg/store/copr"
"github.com/pingcap/tidb/pkg/types"
Expand Down Expand Up @@ -286,7 +286,7 @@ type selectResult struct {

rowLen int
fieldTypes []*types.FieldType
ctx sessionctx.Context
ctx dcontext.DistSQLContext

selectResp *tipb.SelectResponse
selectRespSize int64 // record the selectResp.Size() when it is initialized.
Expand Down
3 changes: 3 additions & 0 deletions pkg/executor/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ go_library(
"//pkg/ddl/placement",
"//pkg/ddl/schematracker",
"//pkg/distsql",
"//pkg/distsql/context",
"//pkg/disttask/framework/handle",
"//pkg/disttask/framework/proto",
"//pkg/disttask/framework/storage",
Expand Down Expand Up @@ -131,6 +132,7 @@ go_library(
"//pkg/executor/sortexec",
"//pkg/expression",
"//pkg/expression/aggregation",
"//pkg/expression/context",
"//pkg/infoschema",
"//pkg/keyspace",
"//pkg/kv",
Expand Down Expand Up @@ -378,6 +380,7 @@ go_test(
"//pkg/ddl/placement",
"//pkg/ddl/util",
"//pkg/distsql",
"//pkg/distsql/context",
"//pkg/domain",
"//pkg/domain/infosync",
"//pkg/errctx",
Expand Down
2 changes: 1 addition & 1 deletion pkg/executor/analyze_col_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ func (e *AnalyzeColumnsExecV2) decodeSampleDataWithVirtualColumn(
}
}
}
err := table.FillVirtualColumnValue(fieldTps, virtualColIdx, schema.Columns, e.colsInfo, e.ctx, chk)
err := table.FillVirtualColumnValue(fieldTps, virtualColIdx, schema.Columns, e.colsInfo, e.ctx.GetExprCtx(), chk)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/executor/batch_point_get.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ func (e *BatchPointGetExec) Next(ctx context.Context, req *chunk.Chunk) error {
e.index++
}

err := table.FillVirtualColumnValue(e.virtualColumnRetFieldTypes, e.virtualColumnIndex, e.Schema().Columns, e.columns, e.Ctx(), req)
err := table.FillVirtualColumnValue(e.virtualColumnRetFieldTypes, e.virtualColumnIndex, e.Schema().Columns, e.columns, e.Ctx().GetExprCtx(), req)
if err != nil {
return err
}
Expand Down
Loading