diff --git a/.golangci.yml b/.golangci.yml index 595fdc3..7f69ce9 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -36,3 +36,4 @@ issues: - errcheck - unparam - prealloc + - funlen diff --git a/README.md b/README.md index 4fce730..0c86771 100644 --- a/README.md +++ b/README.md @@ -13,8 +13,10 @@ Inspired by the [Quartz](https://github.com/quartz-scheduler/quartz) Java schedu Scheduler interface ```go type Scheduler interface { - // Start starts the scheduler. - Start() + // Start starts the scheduler. The scheduler will run until + // the Stop method is called or the context is canceled. Use + // the Wait method to block until all running jobs have completed. + Start(context.Context) // IsStarted determines whether the scheduler has been started. IsStarted() bool // ScheduleJob schedules a job using a specified trigger. @@ -29,6 +31,11 @@ type Scheduler interface { Clear() // Stop shutdowns the scheduler. Stop() + // Wait blocks until the scheduler stops running and all jobs + // have returned. Wait will return when the context passed to + // it has expired. Until the context passed to start is + // cancelled or Stop is called directly. + Wait(context.Context) } ``` Implemented Schedulers @@ -52,7 +59,7 @@ Job interface. Any type that implements it can be scheduled. ```go type Job interface { // Execute is called by a Scheduler when the Trigger associated with this job fires. - Execute() + Execute(context.Context) // Description returns the description of the Job. Description() string // Key returns the unique key for the Job. @@ -77,16 +84,18 @@ Implemented Jobs ## Examples ```go +ctx := context.Background() sched := quartz.NewStdScheduler() -sched.Start() +sched.Start(ctx) cronTrigger, _ := quartz.NewCronTrigger("1/5 * * * * *") shellJob := quartz.NewShellJob("ls -la") curlJob, _ := quartz.NewCurlJob(http.MethodGet, "http://worldclockapi.com/api/json/est/now", "", nil) -functionJob := quartz.NewFunctionJob(func() (int, error) { return 42, nil }) +functionJob := quartz.NewFunctionJob(func(_ context.Context) (int, error) { return 42, nil }) sched.ScheduleJob(shellJob, cronTrigger) sched.ScheduleJob(curlJob, quartz.NewSimpleTrigger(time.Second*7)) sched.ScheduleJob(functionJob, quartz.NewSimpleTrigger(time.Second*5)) sched.Stop() +sched.Wait(ctx) ``` More code samples can be found in the examples directory. diff --git a/examples/main.go b/examples/main.go index 278de4f..ee9c15e 100644 --- a/examples/main.go +++ b/examples/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "net/http" "sync" @@ -10,16 +11,20 @@ import ( ) func main() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() wg := new(sync.WaitGroup) wg.Add(2) - go sampleJobs(wg) - go sampleScheduler(wg) + go sampleJobs(ctx, wg) + go sampleScheduler(ctx, wg) wg.Wait() } -func sampleScheduler(wg *sync.WaitGroup) { +func sampleScheduler(ctx context.Context, wg *sync.WaitGroup) { + defer wg.Done() + sched := quartz.NewStdScheduler() cronTrigger, err := quartz.NewCronTrigger("1/3 * * * * *") if err != nil { @@ -28,7 +33,8 @@ func sampleScheduler(wg *sync.WaitGroup) { } cronJob := PrintJob{"Cron job"} - sched.Start() + sched.Start(ctx) + sched.ScheduleJob(&PrintJob{"Ad hoc Job"}, quartz.NewRunOnceTrigger(time.Second*5)) sched.ScheduleJob(&PrintJob{"First job"}, quartz.NewSimpleTrigger(time.Second*12)) sched.ScheduleJob(&PrintJob{"Second job"}, quartz.NewSimpleTrigger(time.Second*6)) @@ -50,12 +56,13 @@ func sampleScheduler(wg *sync.WaitGroup) { time.Sleep(time.Second * 2) sched.Stop() - wg.Done() + sched.Wait(ctx) } -func sampleJobs(wg *sync.WaitGroup) { +func sampleJobs(ctx context.Context, wg *sync.WaitGroup) { + defer wg.Done() sched := quartz.NewStdScheduler() - sched.Start() + sched.Start(ctx) cronTrigger, err := quartz.NewCronTrigger("1/5 * * * * *") if err != nil { @@ -69,7 +76,7 @@ func sampleJobs(wg *sync.WaitGroup) { fmt.Println(err) return } - functionJob := quartz.NewFunctionJobWithDesc("42", func() (int, error) { return 42, nil }) + functionJob := quartz.NewFunctionJobWithDesc("42", func(_ context.Context) (int, error) { return 42, nil }) sched.ScheduleJob(shellJob, cronTrigger) sched.ScheduleJob(curlJob, quartz.NewSimpleTrigger(time.Second*7)) @@ -84,5 +91,5 @@ func sampleJobs(wg *sync.WaitGroup) { time.Sleep(time.Second * 2) sched.Stop() - wg.Done() + sched.Wait(ctx) } diff --git a/examples/print_job.go b/examples/print_job.go index 72acf9e..654e0f8 100644 --- a/examples/print_job.go +++ b/examples/print_job.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "github.com/reugn/go-quartz/quartz" @@ -22,6 +23,6 @@ func (pj *PrintJob) Key() int { } // Execute is called by a Scheduler when the Trigger associated with this job fires. -func (pj *PrintJob) Execute() { +func (pj *PrintJob) Execute(_ context.Context) { fmt.Println("Executing " + pj.Description()) } diff --git a/quartz/function_job.go b/quartz/function_job.go index 43a6fa1..f369261 100644 --- a/quartz/function_job.go +++ b/quartz/function_job.go @@ -1,11 +1,12 @@ package quartz import ( + "context" "fmt" ) // Function represents an argument-less function which returns a generic type R and a possible error. -type Function[R any] func() (R, error) +type Function[R any] func(context.Context) (R, error) // FunctionJob represents a Job that invokes the passed Function, implements the quartz.Job interface. type FunctionJob[R any] struct { @@ -50,8 +51,8 @@ func (f *FunctionJob[R]) Key() int { // Execute is called by a Scheduler when the Trigger associated with this job fires. // It invokes the held function, setting the results in Result and Error members. -func (f *FunctionJob[R]) Execute() { - result, err := (*f.function)() +func (f *FunctionJob[R]) Execute(ctx context.Context) { + result, err := (*f.function)(ctx) if err != nil { f.JobStatus = FAILURE f.Result = nil diff --git a/quartz/function_job_test.go b/quartz/function_job_test.go index 59c1853..86e397c 100644 --- a/quartz/function_job_test.go +++ b/quartz/function_job_test.go @@ -1,6 +1,8 @@ package quartz_test import ( + "context" + "errors" "testing" "time" @@ -8,19 +10,22 @@ import ( ) func TestFunctionJob(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + var n = 2 - funcJob1 := quartz.NewFunctionJob(func() (string, error) { + funcJob1 := quartz.NewFunctionJob(func(_ context.Context) (string, error) { n += 2 return "fired1", nil }) - funcJob2 := quartz.NewFunctionJob(func() (int, error) { + funcJob2 := quartz.NewFunctionJob(func(_ context.Context) (int, error) { n += 2 return 42, nil }) sched := quartz.NewStdScheduler() - sched.Start() + sched.Start(ctx) sched.ScheduleJob(funcJob1, quartz.NewRunOnceTrigger(time.Millisecond*300)) sched.ScheduleJob(funcJob2, quartz.NewRunOnceTrigger(time.Millisecond*800)) time.Sleep(time.Second) @@ -37,3 +42,40 @@ func TestFunctionJob(t *testing.T) { assertEqual(t, n, 6) } + +func TestFunctionJobRespectsContext(t *testing.T) { + var n int + funcJob2 := quartz.NewFunctionJob(func(ctx context.Context) (bool, error) { + timer := time.NewTimer(time.Hour) + defer timer.Stop() + select { + case <-ctx.Done(): + n-- + return false, ctx.Err() + case <-timer.C: + n++ + return true, nil + } + }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + sig := make(chan struct{}) + go func() { defer close(sig); funcJob2.Execute(ctx) }() + + if n != 0 { + t.Fatal("job should not have run yet") + } + cancel() + <-sig + + if n != -1 { + t.Fatal("job side effect should have reflected cancelation:", n) + } + if !errors.Is(funcJob2.Error, context.Canceled) { + t.Fatal("unexpected error function", funcJob2.Error) + } + if funcJob2.Result != nil { + t.Fatal("errored jobs should not return values") + } +} diff --git a/quartz/job.go b/quartz/job.go index 6c6105d..ee82c12 100644 --- a/quartz/job.go +++ b/quartz/job.go @@ -2,6 +2,7 @@ package quartz import ( "bytes" + "context" "fmt" "io" "net/http" @@ -12,7 +13,7 @@ import ( // to be performed. type Job interface { // Execute is called by a Scheduler when the Trigger associated with this job fires. - Execute() + Execute(context.Context) // Description returns the description of the Job. Description() string @@ -63,8 +64,8 @@ func (sh *ShellJob) Key() int { } // Execute is called by a Scheduler when the Trigger associated with this job fires. -func (sh *ShellJob) Execute() { - out, err := exec.Command("sh", "-c", sh.Cmd).Output() +func (sh *ShellJob) Execute(ctx context.Context) { + out, err := exec.CommandContext(ctx, "sh", "-c", sh.Cmd).Output() if err != nil { sh.JobStatus = FAILURE sh.Result = err.Error() @@ -128,8 +129,9 @@ func (cu *CurlJob) Key() int { } // Execute is called by a Scheduler when the Trigger associated with this job fires. -func (cu *CurlJob) Execute() { +func (cu *CurlJob) Execute(ctx context.Context) { client := &http.Client{} + cu.request = cu.request.WithContext(ctx) resp, err := client.Do(cu.request) if err != nil { cu.JobStatus = FAILURE diff --git a/quartz/scheduler.go b/quartz/scheduler.go index daf05a9..cc1dddc 100644 --- a/quartz/scheduler.go +++ b/quartz/scheduler.go @@ -2,6 +2,7 @@ package quartz import ( "container/heap" + "context" "errors" "log" "sync" @@ -19,8 +20,10 @@ type ScheduledJob struct { // Schedulers are responsible for executing Jobs when their associated // Triggers fire (when their scheduled time arrives). type Scheduler interface { - // Start starts the scheduler. - Start() + // Start starts the scheduler. The scheduler will run until + // the Stop method is called or the context is canceled. Use + // the Wait method to block until all running jobs have completed. + Start(context.Context) // IsStarted determines whether the scheduler has been started. IsStarted() bool @@ -40,18 +43,25 @@ type Scheduler interface { // Clear removes all of the scheduled jobs. Clear() + // Wait blocks until the scheduler stops running and all jobs + // have returned. Wait will return when the context passed to + // it has expired. Until the context passed to start is + // cancelled or Stop is called directly. + Wait(context.Context) + // Stop shutdowns the scheduler. Stop() } // StdScheduler implements the quartz.Scheduler interface. type StdScheduler struct { - sync.Mutex + mtx sync.Mutex queue *priorityQueue interrupt chan struct{} - exit chan struct{} + signal chan struct{} feeder chan *item started bool + cancel context.CancelFunc } // Verify StdScheduler satisfies the Scheduler interface. @@ -62,8 +72,9 @@ func NewStdScheduler() *StdScheduler { return &StdScheduler{ queue: &priorityQueue{}, interrupt: make(chan struct{}, 1), - exit: nil, + cancel: func() {}, feeder: make(chan *item), + signal: make(chan struct{}), } } @@ -74,35 +85,46 @@ func (sched *StdScheduler) ScheduleJob(job Job, trigger Trigger) error { return err } - sched.feeder <- &item{ + select { + case sched.feeder <- &item{ Job: job, Trigger: trigger, priority: nextRunTime, index: 0, + }: + return nil + case <-sched.signal: + return context.Canceled } - - return nil } // Start starts the StdScheduler execution loop. -func (sched *StdScheduler) Start() { - sched.Lock() - defer sched.Unlock() +func (sched *StdScheduler) Start(ctx context.Context) { + sched.mtx.Lock() + defer sched.mtx.Unlock() if sched.started { return } - // reset the exit channel - sched.exit = make(chan struct{}) - + ctx, sched.cancel = context.WithCancel(ctx) + go func() { <-ctx.Done(); sched.Stop() }() // start the feed reader - go sched.startFeedReader() + go sched.startFeedReader(ctx) // start scheduler execution loop - go sched.startExecutionLoop() + go sched.startExecutionLoop(ctx) sched.started = true + sched.signal = make(chan struct{}) +} + +// Wait blocks until the scheduler shuts down. +func (sched *StdScheduler) Wait(ctx context.Context) { + select { + case <-ctx.Done(): + case <-sched.signal: + } } // IsStarted determines whether the scheduler has been started. @@ -112,8 +134,8 @@ func (sched *StdScheduler) IsStarted() bool { // GetJobKeys returns the keys of all of the scheduled jobs. func (sched *StdScheduler) GetJobKeys() []int { - sched.Lock() - defer sched.Unlock() + sched.mtx.Lock() + defer sched.mtx.Unlock() keys := make([]int, 0, sched.queue.Len()) for _, item := range *sched.queue { @@ -125,8 +147,8 @@ func (sched *StdScheduler) GetJobKeys() []int { // GetScheduledJob returns the ScheduledJob with the specified key. func (sched *StdScheduler) GetScheduledJob(key int) (*ScheduledJob, error) { - sched.Lock() - defer sched.Unlock() + sched.mtx.Lock() + defer sched.mtx.Unlock() for _, item := range *sched.queue { if item.Job.Key() == key { @@ -143,8 +165,8 @@ func (sched *StdScheduler) GetScheduledJob(key int) (*ScheduledJob, error) { // DeleteJob removes the Job with the specified key if present. func (sched *StdScheduler) DeleteJob(key int) error { - sched.Lock() - defer sched.Unlock() + sched.mtx.Lock() + defer sched.mtx.Unlock() for i, item := range *sched.queue { if item.Job.Key() == key { @@ -158,8 +180,8 @@ func (sched *StdScheduler) DeleteJob(key int) error { // Clear removes all of the scheduled jobs. func (sched *StdScheduler) Clear() { - sched.Lock() - defer sched.Unlock() + sched.mtx.Lock() + defer sched.mtx.Unlock() // reset the job queue sched.queue = &priorityQueue{} @@ -167,25 +189,26 @@ func (sched *StdScheduler) Clear() { // Stop exits the StdScheduler execution loop. func (sched *StdScheduler) Stop() { - sched.Lock() - defer sched.Unlock() + sched.mtx.Lock() + defer sched.mtx.Unlock() if !sched.started { return } log.Printf("Closing the StdScheduler.") - close(sched.exit) - + sched.cancel() sched.started = false + close(sched.signal) } -func (sched *StdScheduler) startExecutionLoop() { +func (sched *StdScheduler) startExecutionLoop(ctx context.Context) { + for { if sched.queueLen() == 0 { select { case <-sched.interrupt: - case <-sched.exit: + case <-ctx.Done(): log.Printf("Exit the empty execution loop.") return } @@ -193,12 +216,12 @@ func (sched *StdScheduler) startExecutionLoop() { t := time.NewTimer(sched.calculateNextTick()) select { case <-t.C: - sched.executeAndReschedule() + sched.executeAndReschedule(ctx) case <-sched.interrupt: t.Stop() - case <-sched.exit: + case <-ctx.Done(): log.Printf("Exit the execution loop.") t.Stop() return @@ -208,69 +231,79 @@ func (sched *StdScheduler) startExecutionLoop() { } func (sched *StdScheduler) queueLen() int { - sched.Lock() - defer sched.Unlock() + sched.mtx.Lock() + defer sched.mtx.Unlock() return sched.queue.Len() } func (sched *StdScheduler) calculateNextTick() time.Duration { - sched.Lock() var interval int64 + + sched.mtx.Lock() + defer sched.mtx.Unlock() if sched.queue.Len() > 0 { interval = parkTime(sched.queue.Head().priority) } - sched.Unlock() return time.Duration(interval) } -func (sched *StdScheduler) executeAndReschedule() { +func (sched *StdScheduler) executeAndReschedule(ctx context.Context) { // return if the job queue is empty if sched.queueLen() == 0 { return } // fetch an item - sched.Lock() - item := heap.Pop(sched.queue).(*item) - sched.Unlock() + var it *item + func() { + sched.mtx.Lock() + defer sched.mtx.Unlock() + it = heap.Pop(sched.queue).(*item) + }() // execute the Job - if !isOutdated(item.priority) { - go item.Job.Execute() + if !isOutdated(it.priority) { + go it.Job.Execute(ctx) } // reschedule the Job - nextRunTime, err := item.Trigger.NextFireTime(item.priority) + nextRunTime, err := it.Trigger.NextFireTime(it.priority) if err != nil { - log.Printf("The Job '%s' got out the execution loop.", item.Job.Description()) + log.Printf("The Job '%s' got out the execution loop: %q", it.Job.Description(), err.Error()) return } - item.priority = nextRunTime - sched.feeder <- item + it.priority = nextRunTime + select { + case <-ctx.Done(): + case sched.feeder <- it: + } } -func (sched *StdScheduler) startFeedReader() { +func (sched *StdScheduler) startFeedReader(ctx context.Context) { for { select { case item := <-sched.feeder: - sched.Lock() - heap.Push(sched.queue, item) - sched.reset() - sched.Unlock() - - case <-sched.exit: + func() { + sched.mtx.Lock() + defer sched.mtx.Unlock() + + heap.Push(sched.queue, item) + sched.reset(ctx) + }() + case <-ctx.Done(): log.Printf("Exit the feed reader.") return } } } -func (sched *StdScheduler) reset() { +func (sched *StdScheduler) reset(ctx context.Context) { select { case sched.interrupt <- struct{}{}: + case <-ctx.Done(): default: } } diff --git a/quartz/scheduler_test.go b/quartz/scheduler_test.go index d610c42..ccbf264 100644 --- a/quartz/scheduler_test.go +++ b/quartz/scheduler_test.go @@ -1,7 +1,9 @@ package quartz_test import ( + "context" "net/http" + "runtime" "testing" "time" @@ -9,6 +11,9 @@ import ( ) func TestScheduler(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + sched := quartz.NewStdScheduler() var jobKeys [4]int @@ -28,7 +33,7 @@ func TestScheduler(t *testing.T) { assertEqual(t, err, nil) jobKeys[3] = errCurlJob.Key() - sched.Start() + sched.Start(ctx) sched.ScheduleJob(shellJob, quartz.NewSimpleTrigger(time.Millisecond*800)) sched.ScheduleJob(curlJob, quartz.NewRunOnceTrigger(time.Millisecond)) sched.ScheduleJob(errShellJob, quartz.NewRunOnceTrigger(time.Millisecond)) @@ -59,3 +64,93 @@ func TestScheduler(t *testing.T) { assertEqual(t, errShellJob.JobStatus, quartz.FAILURE) assertEqual(t, errCurlJob.JobStatus, quartz.FAILURE) } + +func TestSchedulerCancel(t *testing.T) { + hourJob := func(ctx context.Context) (bool, error) { + timer := time.NewTimer(time.Hour) + defer timer.Stop() + select { + case <-ctx.Done(): + return false, ctx.Err() + case <-timer.C: + return true, nil + } + } + for _, tt := range []string{"context", "stop"} { + // give the go runtime to exit many threads + // before the second case. + time.Sleep(time.Millisecond) + t.Run("CloseMethod_"+tt, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + waitCtx, waitCancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer waitCancel() + + startingRoutines := runtime.NumGoroutine() + + sched := quartz.NewStdScheduler() + + sched.Start(ctx) + noopRoutines := runtime.NumGoroutine() + if startingRoutines > noopRoutines { + t.Error("should have started more threads", + startingRoutines, + noopRoutines, + ) + } + + for i := 0; i < 100; i++ { + if err := sched.ScheduleJob( + quartz.NewFunctionJob(hourJob), + quartz.NewSimpleTrigger(100*time.Millisecond), + ); err != nil { + t.Errorf("could not add job %d, %s", i, err.Error()) + } + } + + runningRoutines := runtime.NumGoroutine() + if runningRoutines < noopRoutines { + t.Error("number of running routines should not decrease", + noopRoutines, + runningRoutines, + ) + } + switch tt { + case "context": + cancel() + case "stop": + sched.Stop() + time.Sleep(time.Millisecond) // trigger context switch + default: + t.Fatal("unknown test", tt) + } + + // should not have timed out before we get to this point + if err := waitCtx.Err(); err != nil { + t.Fatal("test took too long") + } + + sched.Wait(waitCtx) + if err := waitCtx.Err(); err != nil { + t.Fatal("waiting timed out before resources were released") + } + + endingRoutines := runtime.NumGoroutine() + if endingRoutines >= runningRoutines { + t.Error("number of routines should decrease after wait", + runningRoutines, + endingRoutines, + ) + } + + if t.Failed() { + t.Log("starting", startingRoutines, + "noop", noopRoutines, + "running", runningRoutines, + "ending", endingRoutines, + ) + } + }) + } +}