diff --git a/loader/loader.go b/loader/loader.go index 8e9b853169..5f108476c8 100644 --- a/loader/loader.go +++ b/loader/loader.go @@ -104,6 +104,7 @@ func NewWorker(loader *Loader, id int) (worker *Worker, err error) { // Close closes worker func (w *Worker) Close() { + // simulate the case that doesn't wait all doJob goroutine exit failpoint.Inject("workerCantClose", func(_ failpoint.Value) { w.tctx.L().Info("", zap.String("failpoint", "workerCantClose")) failpoint.Return() @@ -121,14 +122,14 @@ func (w *Worker) Close() { w.tctx.L().Info("closed !!!") } -func (w *Worker) run(ctx context.Context, fileJobQueue chan *fileJob, workerWg *sync.WaitGroup, runFatalChan chan *pb.ProcessError) { +func (w *Worker) run(ctx context.Context, fileJobQueue chan *fileJob, runFatalChan chan *pb.ProcessError) { atomic.StoreInt64(&w.closed, 0) newCtx, cancel := context.WithCancel(ctx) defer func() { cancel() + // make sure all doJob goroutines exit w.Close() - workerWg.Done() }() ctctx := w.tctx.WithContext(newCtx) @@ -363,7 +364,6 @@ type Loader struct { bwList *filter.Filter columnMapping *cm.Mapping - pool []*Worker closed sync2.AtomicBool toDB *conn.BaseDB @@ -385,7 +385,6 @@ func NewLoader(cfg *config.SubTaskConfig) *Loader { db2Tables: make(map[string]Tables2DataFiles), tableInfos: make(map[string]*tableInfo), workerWg: new(sync.WaitGroup), - pool: make([]*Worker, 0, cfg.PoolSize), tctx: tcontext.Background().WithLogger(log.With(zap.String("task", cfg.Name), zap.String("unit", "load"))), } loader.fileJobQueueClosed.Set(true) // not open yet @@ -479,7 +478,13 @@ func (l *Loader) Process(ctx context.Context, pr chan pb.ProcessResult) { err := l.Restore(newCtx) close(l.runFatalChan) // Restore returned, all potential fatal sent to l.runFatalChan - wg.Wait() // wait for receive all fatal from l.runFatalChan + + failpoint.Inject("dontWaitWorkerExit", func(_ failpoint.Value) { + l.tctx.L().Info("", zap.String("failpoint", "dontWaitWorkerExit")) + l.workerWg.Wait() + }) + + wg.Wait() // wait for receive all fatal from l.runFatalChan if err != nil { loaderExitWithErrorCounter.WithLabelValues(l.cfg.Name).Inc() @@ -566,13 +571,24 @@ func (l *Loader) Restore(ctx context.Context) error { go l.PrintStatus(ctx) - if err := l.restoreData(ctx); err != nil { - if errors.Cause(err) == context.Canceled { - return nil - } + begin := time.Now() + err = l.restoreData(ctx) + if err != nil && errors.Cause(err) != context.Canceled { return err } + failpoint.Inject("dontWaitWorkerExit", func(_ failpoint.Value) { + l.tctx.L().Info("", zap.String("failpoint", "dontWaitWorkerExit")) + failpoint.Return(nil) + }) + + // make sure all workers exit + l.closeFileJobQueue() // all data file dispatched, close it + l.workerWg.Wait() + if err == nil { + l.tctx.L().Info("all data files have been finished", zap.Duration("cost time", time.Since(begin))) + } + return nil } @@ -611,10 +627,6 @@ func (l *Loader) stopLoad() { l.closeFileJobQueue() l.workerWg.Wait() - for _, worker := range l.pool { - worker.Close() - } - l.pool = l.pool[:0] l.tctx.L().Debug("all workers have been closed") } @@ -741,9 +753,10 @@ func (l *Loader) initAndStartWorkerPool(ctx context.Context) error { } l.workerWg.Add(1) // for every worker goroutine, Add(1) - go worker.run(ctx, l.fileJobQueue, l.workerWg, l.runFatalChan) - - l.pool = append(l.pool, worker) + go func() { + defer l.workerWg.Done() + worker.run(ctx, l.fileJobQueue, l.runFatalChan) + }() } return nil } @@ -1138,17 +1151,12 @@ func (l *Loader) restoreData(ctx context.Context) error { select { case <-ctx.Done(): l.tctx.L().Warn("stop dispatch data file job", log.ShortError(ctx.Err())) - l.closeFileJobQueue() return ctx.Err() case l.fileJobQueue <- j: } } - l.closeFileJobQueue() // all data file dispatched, close it l.tctx.L().Info("all data files have been dispatched, waiting for them finished") - l.workerWg.Wait() - - l.tctx.L().Info("all data files have been finished", zap.Duration("cost time", time.Since(begin))) return nil } diff --git a/tests/import_goroutine_leak/run.sh b/tests/import_goroutine_leak/run.sh index a10be10146..384d85b46c 100644 --- a/tests/import_goroutine_leak/run.sh +++ b/tests/import_goroutine_leak/run.sh @@ -27,7 +27,8 @@ function run() { run_sql_file $WORK_DIR/db2.prepare.sql $MYSQL_HOST2 $MYSQL_PORT2 - # check workers of import unit exit + echo "dm-worker paninc, doJob of import unit workers don't exit" + # check doJobs of import unit worker exit inject_points=("github.com/pingcap/dm/loader/dispatchError=return(1)" "github.com/pingcap/dm/loader/LoadDataSlowDown=sleep(1000)" "github.com/pingcap/dm/loader/executeSQLError=return(1)" @@ -54,6 +55,26 @@ function run() { exit 2 fi + echo "dm-workers paninc again, workers of import unit don't exit" + # check workers of import unit exit + inject_points=("github.com/pingcap/dm/loader/dontWaitWorkerExit=return(1)" + "github.com/pingcap/dm/loader/LoadDataSlowDown=sleep(1000)" + "github.com/pingcap/dm/loader/executeSQLError=return(1)" + ) + export GO_FAILPOINTS="$(join_string \; ${inject_points[@]})" + run_dm_worker $WORK_DIR/worker1 $WORKER1_PORT $cur/conf/dm-worker1.toml + run_dm_worker $WORK_DIR/worker2 $WORKER2_PORT $cur/conf/dm-worker2.toml + sleep 2s + check_port_offline $WORKER1_PORT 20 + check_port_offline $WORKER2_PORT 20 + + # dm-worker1 panics + err_cnt=`grep "panic" $WORK_DIR/worker1/log/stdout.log | wc -l` + if [ $err_cnt -ne 2 ]; then + echo "dm-worker1 doesn't panic again, panic count ${err_cnt}" + exit 2 + fi + # check workers of import unit exit inject_points=("github.com/pingcap/dm/loader/dispatchError=return(1)" "github.com/pingcap/dm/loader/LoadDataSlowDown=sleep(1000)"