Skip to content

Commit

Permalink
disttask: fix subtask finished immediately and mark success when enco…
Browse files Browse the repository at this point in the history
…untering network partition (#48660)

ref #46258, close pingcap/tidb#48649
  • Loading branch information
ywqzzy committed Nov 17, 2023
1 parent 6b3df66 commit 844ba42
Show file tree
Hide file tree
Showing 46 changed files with 1,044 additions and 809 deletions.
20 changes: 9 additions & 11 deletions pkg/ddl/backfilling_dispatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,8 @@ import (
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/testkit"
"github.com/pingcap/tidb/pkg/util/logutil"
"github.com/stretchr/testify/require"
"github.com/tikv/client-go/v2/util"
"go.uber.org/zap"
)

func TestBackfillingDispatcherLocalMode(t *testing.T) {
Expand Down Expand Up @@ -156,7 +154,8 @@ func TestBackfillingDispatcherGlobalSortMode(t *testing.T) {
}, 1, 1, time.Second)
defer pool.Close()
ctx := context.WithValue(context.Background(), "etcd", true)
mgr := storage.NewTaskManager(util.WithInternalSourceType(ctx, "taskManager"), pool)
ctx = util.WithInternalSourceType(ctx, "handle")
mgr := storage.NewTaskManager(pool)
storage.SetTaskManager(mgr)
dspManager, err := dispatcher.NewManager(util.WithInternalSourceType(ctx, "dispatcher"), mgr, "host:port")
require.NoError(t, err)
Expand All @@ -175,7 +174,7 @@ func TestBackfillingDispatcherGlobalSortMode(t *testing.T) {
ext.(*ddl.BackfillingDispatcherExt).GlobalSort = true
dsp.Extension = ext

taskID, err := mgr.AddNewGlobalTask(task.Key, proto.Backfill, 1, task.Meta)
taskID, err := mgr.AddNewGlobalTask(ctx, task.Key, proto.Backfill, 1, task.Meta)
require.NoError(t, err)
task.ID = taskID
serverInfos, _, err := dsp.GetEligibleInstances(context.Background(), task)
Expand All @@ -192,11 +191,10 @@ func TestBackfillingDispatcherGlobalSortMode(t *testing.T) {
for _, m := range subtaskMetas {
subtasks = append(subtasks, proto.NewSubtask(task.Step, task.ID, task.Type, "", m))
}
_, err = mgr.UpdateGlobalTaskAndAddSubTasks(task, subtasks, proto.TaskStatePending)
_, err = mgr.UpdateGlobalTaskAndAddSubTasks(ctx, task, subtasks, proto.TaskStatePending)
require.NoError(t, err)
gotSubtasks, err := mgr.GetSubtasksForImportInto(taskID, ddl.StepReadIndex)
gotSubtasks, err := mgr.GetSubtasksForImportInto(ctx, taskID, ddl.StepReadIndex)
require.NoError(t, err)
logutil.BgLogger().Info("ywq test", zap.Any("len", len(gotSubtasks)))

// update meta, same as import into.
sortStepMeta := &ddl.BackfillSubTaskMeta{
Expand All @@ -216,7 +214,7 @@ func TestBackfillingDispatcherGlobalSortMode(t *testing.T) {
sortStepMetaBytes, err := json.Marshal(sortStepMeta)
require.NoError(t, err)
for _, s := range gotSubtasks {
require.NoError(t, mgr.FinishSubtask(s.ID, sortStepMetaBytes))
require.NoError(t, mgr.FinishSubtask(ctx, s.SchedulerID, s.ID, sortStepMetaBytes))
}
// 2. to merge-sort stage.
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/ddl/forceMergeSort", `return()`))
Expand All @@ -234,9 +232,9 @@ func TestBackfillingDispatcherGlobalSortMode(t *testing.T) {
for _, m := range subtaskMetas {
subtasks = append(subtasks, proto.NewSubtask(task.Step, task.ID, task.Type, "", m))
}
_, err = mgr.UpdateGlobalTaskAndAddSubTasks(task, subtasks, proto.TaskStatePending)
_, err = mgr.UpdateGlobalTaskAndAddSubTasks(ctx, task, subtasks, proto.TaskStatePending)
require.NoError(t, err)
gotSubtasks, err = mgr.GetSubtasksForImportInto(taskID, task.Step)
gotSubtasks, err = mgr.GetSubtasksForImportInto(ctx, taskID, task.Step)
require.NoError(t, err)
mergeSortStepMeta := &ddl.BackfillSubTaskMeta{
SortedKVMeta: external.SortedKVMeta{
Expand All @@ -255,7 +253,7 @@ func TestBackfillingDispatcherGlobalSortMode(t *testing.T) {
mergeSortStepMetaBytes, err := json.Marshal(mergeSortStepMeta)
require.NoError(t, err)
for _, s := range gotSubtasks {
require.NoError(t, mgr.FinishSubtask(s.ID, mergeSortStepMetaBytes))
require.NoError(t, mgr.FinishSubtask(ctx, s.SchedulerID, s.ID, mergeSortStepMetaBytes))
}
// 3. to write&ingest stage.
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/ddl/mockWriteIngest", "return(true)"))
Expand Down
14 changes: 8 additions & 6 deletions pkg/ddl/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -2060,6 +2060,8 @@ func (w *worker) executeDistGlobalTask(reorgInfo *reorgInfo) error {
taskType := proto.Backfill
taskKey := fmt.Sprintf("ddl/%s/%d", taskType, reorgInfo.Job.ID)
g, ctx := errgroup.WithContext(context.Background())
ctx = kv.WithInternalSourceType(ctx, kv.InternalDistTask)

done := make(chan struct{})

// generate taskKey for multi schema change.
Expand All @@ -2076,7 +2078,7 @@ func (w *worker) executeDistGlobalTask(reorgInfo *reorgInfo) error {
if err != nil {
return err
}
task, err := taskManager.GetGlobalTaskByKeyWithHistory(taskKey)
task, err := taskManager.GetGlobalTaskByKeyWithHistory(w.ctx, taskKey)
if err != nil {
return err
}
Expand All @@ -2095,7 +2097,7 @@ func (w *worker) executeDistGlobalTask(reorgInfo *reorgInfo) error {
backoffer := backoff.NewExponential(dispatcher.RetrySQLInterval, 2, dispatcher.RetrySQLMaxInterval)
err := handle.RunWithRetry(ctx, dispatcher.RetrySQLTimes, backoffer, logutil.BgLogger(),
func(ctx context.Context) (bool, error) {
return true, handle.ResumeTask(taskKey)
return true, handle.ResumeTask(w.ctx, taskKey)
},
)
if err != nil {
Expand Down Expand Up @@ -2158,7 +2160,7 @@ func (w *worker) executeDistGlobalTask(reorgInfo *reorgInfo) error {
case <-checkFinishTk.C:
if err = w.isReorgRunnable(reorgInfo.Job.ID, true); err != nil {
if dbterror.ErrPausedDDLJob.Equal(err) {
if err = handle.PauseTask(taskKey); err != nil {
if err = handle.PauseTask(w.ctx, taskKey); err != nil {
logutil.BgLogger().Error("pause global task error", zap.String("category", "ddl"), zap.String("task_key", taskKey), zap.Error(err))
continue
}
Expand All @@ -2170,7 +2172,7 @@ func (w *worker) executeDistGlobalTask(reorgInfo *reorgInfo) error {
if !dbterror.ErrCancelledDDLJob.Equal(err) {
return errors.Trace(err)
}
if err = handle.CancelGlobalTask(taskKey); err != nil {
if err = handle.CancelGlobalTask(w.ctx, taskKey); err != nil {
logutil.BgLogger().Error("cancel global task error", zap.String("category", "ddl"), zap.String("task_key", taskKey), zap.Error(err))
// continue to cancel global task.
continue
Expand All @@ -2191,12 +2193,12 @@ func (w *worker) updateJobRowCount(taskKey string, jobID int64) {
logutil.BgLogger().Warn("cannot get task manager", zap.String("category", "ddl"), zap.String("task_key", taskKey), zap.Error(err))
return
}
gTask, err := taskMgr.GetGlobalTaskByKey(taskKey)
gTask, err := taskMgr.GetGlobalTaskByKey(w.ctx, taskKey)
if err != nil || gTask == nil {
logutil.BgLogger().Warn("cannot get global task", zap.String("category", "ddl"), zap.String("task_key", taskKey), zap.Error(err))
return
}
rowCount, err := taskMgr.GetSubtaskRowCount(gTask.ID, proto.StepOne)
rowCount, err := taskMgr.GetSubtaskRowCount(w.ctx, gTask.ID, proto.StepOne)
if err != nil {
logutil.BgLogger().Warn("cannot get subtask row count", zap.String("category", "ddl"), zap.String("task_key", taskKey), zap.Error(err))
return
Expand Down
1 change: 1 addition & 0 deletions pkg/disttask/framework/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ go_test(
"//pkg/testkit",
"@com_github_pingcap_failpoint//:failpoint",
"@com_github_stretchr_testify//require",
"@com_github_tikv_client_go_v2//util",
"@org_uber_go_mock//gomock",
],
)
42 changes: 21 additions & 21 deletions pkg/disttask/framework/dispatcher/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func (*BaseDispatcher) Close() {

// refreshTask fetch task state from tidb_global_task table.
func (d *BaseDispatcher) refreshTask() error {
newTask, err := d.taskMgr.GetGlobalTaskByID(d.Task.ID)
newTask, err := d.taskMgr.GetGlobalTaskByID(d.ctx, d.Task.ID)
if err != nil {
logutil.Logger(d.logCtx).Error("refresh task failed", zap.Error(err))
return err
Expand Down Expand Up @@ -166,7 +166,7 @@ func (d *BaseDispatcher) scheduleTask() {
}
failpoint.Inject("cancelTaskAfterRefreshTask", func(val failpoint.Value) {
if val.(bool) && d.Task.State == proto.TaskStateRunning {
err := d.taskMgr.CancelGlobalTask(d.Task.ID)
err := d.taskMgr.CancelGlobalTask(d.ctx, d.Task.ID)
if err != nil {
logutil.Logger(d.logCtx).Error("cancel task failed", zap.Error(err))
}
Expand All @@ -175,7 +175,7 @@ func (d *BaseDispatcher) scheduleTask() {

failpoint.Inject("pausePendingTask", func(val failpoint.Value) {
if val.(bool) && d.Task.State == proto.TaskStatePending {
_, err := d.taskMgr.PauseTask(d.Task.Key)
_, err := d.taskMgr.PauseTask(d.ctx, d.Task.Key)
if err != nil {
logutil.Logger(d.logCtx).Error("pause task failed", zap.Error(err))
}
Expand All @@ -185,7 +185,7 @@ func (d *BaseDispatcher) scheduleTask() {

failpoint.Inject("pauseTaskAfterRefreshTask", func(val failpoint.Value) {
if val.(bool) && d.Task.State == proto.TaskStateRunning {
_, err := d.taskMgr.PauseTask(d.Task.Key)
_, err := d.taskMgr.PauseTask(d.ctx, d.Task.Key)
if err != nil {
logutil.Logger(d.logCtx).Error("pause task failed", zap.Error(err))
}
Expand Down Expand Up @@ -243,7 +243,7 @@ func (d *BaseDispatcher) onCancelling() error {
// handle task in pausing state, cancel all running subtasks.
func (d *BaseDispatcher) onPausing() error {
logutil.Logger(d.logCtx).Info("on pausing state", zap.Stringer("state", d.Task.State), zap.Int64("stage", int64(d.Task.Step)))
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.Task.ID, proto.TaskStateRunning, proto.TaskStatePending)
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.ctx, d.Task.ID, proto.TaskStateRunning, proto.TaskStatePending)
if err != nil {
logutil.Logger(d.logCtx).Warn("check task failed", zap.Error(err))
return err
Expand Down Expand Up @@ -276,7 +276,7 @@ var TestSyncChan = make(chan struct{})
// handle task in resuming state
func (d *BaseDispatcher) onResuming() error {
logutil.Logger(d.logCtx).Info("on resuming state", zap.Stringer("state", d.Task.State), zap.Int64("stage", int64(d.Task.Step)))
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.Task.ID, proto.TaskStatePaused)
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.ctx, d.Task.ID, proto.TaskStatePaused)
if err != nil {
logutil.Logger(d.logCtx).Warn("check task failed", zap.Error(err))
return err
Expand All @@ -291,13 +291,13 @@ func (d *BaseDispatcher) onResuming() error {
return err
}

return d.taskMgr.ResumeSubtasks(d.Task.ID)
return d.taskMgr.ResumeSubtasks(d.ctx, d.Task.ID)
}

// handle task in reverting state, check all revert subtasks finished.
func (d *BaseDispatcher) onReverting() error {
logutil.Logger(d.logCtx).Debug("on reverting state", zap.Stringer("state", d.Task.State), zap.Int64("stage", int64(d.Task.Step)))
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.Task.ID, proto.TaskStateRevertPending, proto.TaskStateReverting)
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.ctx, d.Task.ID, proto.TaskStateRevertPending, proto.TaskStateReverting)
if err != nil {
logutil.Logger(d.logCtx).Warn("check task failed", zap.Error(err))
return err
Expand All @@ -323,7 +323,7 @@ func (d *BaseDispatcher) onPending() error {
// If subtasks finished, run into the next stage.
func (d *BaseDispatcher) onRunning() error {
logutil.Logger(d.logCtx).Debug("on running state", zap.Stringer("state", d.Task.State), zap.Int64("stage", int64(d.Task.Step)))
subTaskErrs, err := d.taskMgr.CollectSubTaskError(d.Task.ID)
subTaskErrs, err := d.taskMgr.CollectSubTaskError(d.ctx, d.Task.ID)
if err != nil {
logutil.Logger(d.logCtx).Warn("collect subtask error failed", zap.Error(err))
return err
Expand All @@ -333,7 +333,7 @@ func (d *BaseDispatcher) onRunning() error {
return d.onErrHandlingStage(subTaskErrs)
}
// check current stage finished.
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.Task.ID, proto.TaskStatePending, proto.TaskStateRunning)
cnt, err := d.taskMgr.GetSubtaskInStatesCnt(d.ctx, d.Task.ID, proto.TaskStatePending, proto.TaskStateRunning)
if err != nil {
logutil.Logger(d.logCtx).Warn("check task failed", zap.Error(err))
return err
Expand All @@ -355,13 +355,13 @@ func (d *BaseDispatcher) onRunning() error {
func (d *BaseDispatcher) onFinished() error {
metrics.UpdateMetricsForFinishTask(d.Task)
logutil.Logger(d.logCtx).Debug("schedule task, task is finished", zap.Stringer("state", d.Task.State))
return d.taskMgr.TransferSubTasks2History(d.Task.ID)
return d.taskMgr.TransferSubTasks2History(d.ctx, d.Task.ID)
}

func (d *BaseDispatcher) replaceDeadNodesIfAny() error {
if len(d.taskNodes) == 0 {
var err error
d.taskNodes, err = d.taskMgr.GetSchedulerIDsByTaskIDAndStep(d.Task.ID, d.Task.Step)
d.taskNodes, err = d.taskMgr.GetSchedulerIDsByTaskIDAndStep(d.ctx, d.Task.ID, d.Task.Step)
if err != nil {
return err
}
Expand Down Expand Up @@ -411,10 +411,10 @@ func (d *BaseDispatcher) replaceDeadNodesIfAny() error {
}
if len(replaceNodes) > 0 {
logutil.Logger(d.logCtx).Info("reschedule subtasks to other nodes", zap.Int("node-cnt", len(replaceNodes)))
if err := d.taskMgr.UpdateFailedSchedulerIDs(d.Task.ID, replaceNodes); err != nil {
if err := d.taskMgr.UpdateFailedSchedulerIDs(d.ctx, d.Task.ID, replaceNodes); err != nil {
return err
}
if err := d.taskMgr.CleanUpMeta(cleanNodes); err != nil {
if err := d.taskMgr.CleanUpMeta(d.ctx, cleanNodes); err != nil {
return err
}
// replace local cache.
Expand All @@ -441,15 +441,15 @@ func (d *BaseDispatcher) updateTask(taskState proto.TaskState, newSubTasks []*pr
}

failpoint.Inject("cancelBeforeUpdate", func() {
err := d.taskMgr.CancelGlobalTask(d.Task.ID)
err := d.taskMgr.CancelGlobalTask(d.ctx, d.Task.ID)
if err != nil {
logutil.Logger(d.logCtx).Error("cancel task failed", zap.Error(err))
}
})

var retryable bool
for i := 0; i < retryTimes; i++ {
retryable, err = d.taskMgr.UpdateGlobalTaskAndAddSubTasks(d.Task, newSubTasks, prevState)
retryable, err = d.taskMgr.UpdateGlobalTaskAndAddSubTasks(d.ctx, d.Task, newSubTasks, prevState)
if err == nil || !retryable {
break
}
Expand Down Expand Up @@ -658,13 +658,13 @@ func GenerateSchedulerNodes(ctx context.Context) (serverNodes []*infosync.Server
}

func (d *BaseDispatcher) filterByRole(infos []*infosync.ServerInfo) ([]*infosync.ServerInfo, error) {
nodes, err := d.taskMgr.GetNodesByRole("background")
nodes, err := d.taskMgr.GetNodesByRole(d.ctx, "background")
if err != nil {
return nil, err
}

if len(nodes) == 0 {
nodes, err = d.taskMgr.GetNodesByRole("")
nodes, err = d.taskMgr.GetNodesByRole(d.ctx, "")
}

if err != nil {
Expand Down Expand Up @@ -693,7 +693,7 @@ func (d *BaseDispatcher) GetAllSchedulerIDs(ctx context.Context, task *proto.Tas
return nil, nil
}

schedulerIDs, err := d.taskMgr.GetSchedulerIDsByTaskID(task.ID)
schedulerIDs, err := d.taskMgr.GetSchedulerIDsByTaskID(d.ctx, task.ID)
if err != nil {
return nil, err
}
Expand All @@ -708,7 +708,7 @@ func (d *BaseDispatcher) GetAllSchedulerIDs(ctx context.Context, task *proto.Tas

// GetPreviousSubtaskMetas get subtask metas from specific step.
func (d *BaseDispatcher) GetPreviousSubtaskMetas(taskID int64, step proto.Step) ([][]byte, error) {
previousSubtasks, err := d.taskMgr.GetSucceedSubtasksByStep(taskID, step)
previousSubtasks, err := d.taskMgr.GetSucceedSubtasksByStep(d.ctx, taskID, step)
if err != nil {
logutil.Logger(d.logCtx).Warn("get previous succeed subtask failed", zap.Int64("step", int64(step)))
return nil, err
Expand All @@ -722,7 +722,7 @@ func (d *BaseDispatcher) GetPreviousSubtaskMetas(taskID int64, step proto.Step)

// GetPreviousSchedulerIDs gets scheduler IDs that run previous step.
func (d *BaseDispatcher) GetPreviousSchedulerIDs(_ context.Context, taskID int64, step proto.Step) ([]string, error) {
return d.taskMgr.GetSchedulerIDsByTaskIDAndStep(taskID, step)
return d.taskMgr.GetSchedulerIDsByTaskIDAndStep(d.ctx, taskID, step)
}

// WithNewSession executes the function with a new session.
Expand Down
12 changes: 7 additions & 5 deletions pkg/disttask/framework/dispatcher/dispatcher_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ func (dm *Manager) dispatchTaskLoop() {

// TODO: Consider getting these tasks, in addition to the task being worked on..
tasks, err := dm.taskMgr.GetGlobalTasksInStates(
dm.ctx,
proto.TaskStatePending,
proto.TaskStateRunning,
proto.TaskStateReverting,
Expand Down Expand Up @@ -223,7 +224,7 @@ func (dm *Manager) failTask(task *proto.Task, err error) {
prevState := task.State
task.State = proto.TaskStateFailed
task.Error = err
if _, err2 := dm.taskMgr.UpdateGlobalTaskAndAddSubTasks(task, nil, prevState); err2 != nil {
if _, err2 := dm.taskMgr.UpdateGlobalTaskAndAddSubTasks(dm.ctx, task, nil, prevState); err2 != nil {
logutil.BgLogger().Warn("failed to update task state to failed",
zap.Int64("task-id", task.ID), zap.Error(err2))
}
Expand All @@ -248,7 +249,7 @@ func (dm *Manager) gcSubtaskHistoryTableLoop() {
logutil.BgLogger().Info("subtask history table gc loop exits", zap.Error(dm.ctx.Err()))
return
case <-ticker.C:
err := dm.taskMgr.GCSubtasks()
err := dm.taskMgr.GCSubtasks(dm.ctx)
if err != nil {
logutil.BgLogger().Warn("subtask history table gc failed", zap.Error(err))
} else {
Expand Down Expand Up @@ -318,6 +319,7 @@ func (dm *Manager) doCleanUpRoutine() {
logutil.BgLogger().Info("clean up nodes in framework meta since nodes shutdown", zap.Int("cnt", cnt))
}
tasks, err := dm.taskMgr.GetGlobalTasksInStates(
dm.ctx,
proto.TaskStateFailed,
proto.TaskStateReverted,
proto.TaskStateSucceed,
Expand Down Expand Up @@ -350,7 +352,7 @@ func (dm *Manager) CleanUpMeta() int {
return 0
}

oldNodes, err := dm.taskMgr.GetAllNodes()
oldNodes, err := dm.taskMgr.GetAllNodes(dm.ctx)
if err != nil {
logutil.BgLogger().Warn("get all nodes met error")
return 0
Expand All @@ -366,7 +368,7 @@ func (dm *Manager) CleanUpMeta() int {
return 0
}
logutil.BgLogger().Info("start to clean up dist_framework_meta")
err = dm.taskMgr.CleanUpMeta(cleanNodes)
err = dm.taskMgr.CleanUpMeta(dm.ctx, cleanNodes)
if err != nil {
logutil.BgLogger().Warn("clean up dist_framework_meta met error")
return 0
Expand Down Expand Up @@ -396,7 +398,7 @@ func (dm *Manager) cleanUpFinishedTasks(tasks []*proto.Task) error {
logutil.BgLogger().Warn("cleanUp routine failed", zap.Error(errors.Trace(firstErr)))
}

return dm.taskMgr.TransferTasks2History(cleanedTasks)
return dm.taskMgr.TransferTasks2History(dm.ctx, cleanedTasks)
}

// MockDispatcher mock one dispatcher for one task, only used for tests.
Expand Down

0 comments on commit 844ba42

Please sign in to comment.