diff --git a/cmd/daemon/serve.go b/cmd/daemon/serve.go index 2745884677f..6a86d80126e 100644 --- a/cmd/daemon/serve.go +++ b/cmd/daemon/serve.go @@ -1,7 +1,6 @@ package daemon import ( - stdctx "context" "net/http" "strings" "sync" @@ -165,9 +164,10 @@ func sqa(cmd *cobra.Command, d driver.Driver) *metricsx.Service { func bgTasks(d driver.Driver, wg *sync.WaitGroup, cmd *cobra.Command, args []string) { defer wg.Done() - if err := d.Registry().Courier().Work(stdctx.Background()); err != nil { + if err := graceful.Graceful(d.Registry().Courier().Work, d.Registry().Courier().Shutdown); err != nil { d.Logger().WithError(err).Fatalf("Failed to run courier worker.") } + d.Logger().Println("courier worker was shutdown gracefully") } func ServeAll(d driver.Driver) func(cmd *cobra.Command, args []string) { diff --git a/courier/courier.go b/courier/courier.go index d8c5b6e501a..b909501a424 100644 --- a/courier/courier.go +++ b/courier/courier.go @@ -27,6 +27,9 @@ type ( dialer *gomail.Dialer d smtpDependencies c configuration.Provider + // graceful shutdown handling + ctx context.Context + shutdown context.CancelFunc } Provider interface { Courier() *Courier @@ -38,9 +41,12 @@ func NewSMTP(d smtpDependencies, c configuration.Provider) *Courier { sslSkipVerify, _ := strconv.ParseBool(uri.Query().Get("skip_ssl_verify")) password, _ := uri.User.Password() port, _ := strconv.ParseInt(uri.Port(), 10, 64) + ctx, cancel := context.WithCancel(context.Background()) return &Courier{ - d: d, - c: c, + d: d, + c: c, + ctx: ctx, + shutdown: cancel, dialer: &gomail.Dialer{ Host: uri.Hostname(), Port: int(port), @@ -82,20 +88,28 @@ func (m *Courier) QueueEmail(ctx context.Context, t EmailTemplate) (uuid.UUID, e return message.ID, nil } -func (m *Courier) Work(ctx context.Context) error { +func (m *Courier) Work() error { errChan := make(chan error) defer close(errChan) - go m.watchMessages(ctx, errChan) + go m.watchMessages(m.ctx, errChan) select { - case <-ctx.Done(): - return ctx.Err() + case <-m.ctx.Done(): + if m.ctx.Err() == context.Canceled { + return nil + } + return m.ctx.Err() case err := <-errChan: return err } } +func (m *Courier) Shutdown(ctx context.Context) error { + m.shutdown() + return nil +} + func (m *Courier) watchMessages(ctx context.Context, errChan chan error) { for { if err := backoff.Retry(func() error { diff --git a/courier/courier_test.go b/courier/courier_test.go index 9de5c02f255..dfd80920800 100644 --- a/courier/courier_test.go +++ b/courier/courier_test.go @@ -104,7 +104,7 @@ func TestSMTP(t *testing.T) { c := reg.Courier() go func() { - require.NoError(t, c.Work(context.Background())) + require.NoError(t, c.Work()) }() t.Run("case=queue messages", func(t *testing.T) {