/
webserver.hooks.go
272 lines (233 loc) · 8.3 KB
/
webserver.hooks.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
// Copyright 2021 The searKing Author. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package webserver
import (
"context"
"errors"
"fmt"
"log/slog"
"net/http"
"os"
"runtime/debug"
"slices"
"github.com/searKing/golang/go/runtime"
"github.com/searKing/golang/pkg/webserver/healthz"
"golang.org/x/sync/errgroup"
)
// PostStartHookFunc is a function that is called after the server has started.
// It must properly handle cases like:
//  1. asynchronous start in multiple API server processes
//  2. conflicts between the different processes all trying to perform the same action
//  3. partially complete work (API server crashes while running your hook)
//  4. API server access **BEFORE** your hook has completed
//
// Think of it like a mini-controller that is super privileged and gets to run in-process.
// If you use this feature, tag @deads2k on github who has promised to review code for anyone's PostStartHook
// until it becomes easier to use.
// NOTE(review): the sentence above reads as if it was inherited from an upstream project;
// confirm that it still applies to this package.
//
// ctx will be cancelled when the WebServer is closed or when any other PostStartHookFunc fails.
// NOTE(review): RunPostStartHooks hands every hook a context created by errgroup.WithContext,
// which errgroup also cancels once Wait returns; hooks that keep using ctx after they return
// should confirm that this is the intended lifetime.
//
// A non-nil return value is reported by RunPostStartHooks as a failure of the hook.
type PostStartHookFunc func(ctx context.Context) error
// PreShutdownHookFunc is a function that can be added to the shutdown logic.
// It takes no arguments; a non-nil return value is collected by
// RunPreShutdownHooks and returned to its caller joined with the errors of the
// other hooks.
type PreShutdownHookFunc func() error
// PostStartHookProvider is an additional interface a component can implement
// to provide a post start hook for the server.
type PostStartHookProvider interface {
	// PostStartHook returns the hook to register together with a string
	// (presumably the name to register it under — confirm against callers)
	// and an error describing why no hook could be provided.
	PostStartHook() (string, PostStartHookFunc, error)
}
// postStartHookEntry is the bookkeeping record stored for every registered
// PostStartHook.
type postStartHookEntry struct {
	// hook is the function to run after the server has started.
	hook PostStartHookFunc
	// originatingStack holds the stack that registered postStartHooks. This allows us to show a more helpful message
	// for duplicate registration.
	originatingStack string
	// done will be closed when the postHook is finished. It is shared with the
	// hook's postStartHookHealthz check, which reports success once it is closed.
	done chan struct{}
}
// preShutdownHookEntry is the bookkeeping record stored for every registered
// PreShutdownHook.
type preShutdownHookEntry struct {
	// hook is the function to run before the server shuts down.
	hook PreShutdownHookFunc
}
// AddBootSequencePostStartHook registers hook under name as part of the
// ordered boot sequence: RunPostStartHooks runs such hooks one at a time, in
// the order they were added, before it starts the unordered hooks.
// It returns the registration error, if any.
func (s *WebServer) AddBootSequencePostStartHook(name string, hook PostStartHookFunc) error {
	const ordered = true
	return s.addPostStartHook(name, hook, ordered)
}
// AddPostStartHook registers hook under name without any ordering guarantee:
// RunPostStartHooks starts such hooks concurrently once the ordered boot
// sequence has completed. It returns the registration error, if any.
func (s *WebServer) AddPostStartHook(name string, hook PostStartHookFunc) error {
	const ordered = false
	return s.addPostStartHook(name, hook, ordered)
}
// AddPostStartHookOrDie registers hook under name like AddPostStartHook, but
// treats a registration failure as fatal: the error is logged and the process
// exits with status 1.
func (s *WebServer) AddPostStartHookOrDie(name string, hook PostStartHookFunc) {
	err := s.AddPostStartHook(name, hook)
	if err == nil {
		return
	}
	slog.Error(fmt.Sprintf("Error registering PostStartHook %q: %s", name, err.Error()))
	os.Exit(1)
}
// AddBootSequencePreShutdownHook registers hook under name as part of the
// ordered shutdown sequence: RunPreShutdownHooks runs such hooks in the
// reverse of the order they were added. It returns the registration error, if
// any.
func (s *WebServer) AddBootSequencePreShutdownHook(name string, hook PreShutdownHookFunc) error {
	const ordered = true
	return s.addPreShutdownHook(name, hook, ordered)
}
// AddPreShutdownHook registers hook under name without any ordering
// guarantee relative to other unordered hooks. It returns the registration
// error, if any.
func (s *WebServer) AddPreShutdownHook(name string, hook PreShutdownHookFunc) error {
	const ordered = false
	return s.addPreShutdownHook(name, hook, ordered)
}
// AddPreShutdownHookOrDie registers a PreShutdownHook under name like
// AddPreShutdownHook, but treats a registration failure as fatal: the error is
// logged and the process exits with status 1.
func (s *WebServer) AddPreShutdownHookOrDie(name string, hook PreShutdownHookFunc) {
	err := s.AddPreShutdownHook(name, hook)
	if err == nil {
		return
	}
	slog.Error(fmt.Sprintf("Error registering PreShutdownHook %q: %s", name, err.Error()))
	os.Exit(1)
}
// RunPostStartHooks runs the PostStartHooks for the server and marks them as
// called, so that later registrations are rejected.
//
// Hooks added through AddBootSequencePostStartHook run first, one at a time
// and in registration order; the first failure aborts the sequence and is
// returned immediately. The remaining hooks are then started concurrently and
// waited for. Every hook receives a context derived from ctx that is cancelled
// as soon as any hook fails and, at the latest, when this method returns.
//
// The result is the first hook failure, or nil when every hook succeeded.
//
// postStartHookLock is held for the whole call, so a hook must not call back
// into methods that take the same lock (assuming it is a non-reentrant mutex).
func (s *WebServer) RunPostStartHooks(ctx context.Context) error {
	s.postStartHookLock.Lock()
	defer s.postStartHookLock.Unlock()
	s.postStartHooksCalled = true

	g, gCtx := errgroup.WithContext(ctx)

	// Clone the ordered names so that appending the unordered ones below can
	// never write into the backing array of s.postStartHookOrderedKeys.
	ordered := len(s.postStartHookOrderedKeys)
	keys := slices.Clone(s.postStartHookOrderedKeys)
	for k := range s.postStartHooks {
		if !slices.Contains(s.postStartHookOrderedKeys, k) {
			keys = append(keys, k)
		}
	}

	for i, k := range keys {
		// Per-iteration copy: the closure passed to g.Go must not share the
		// range variable on toolchains older than Go 1.22.
		hookName := k
		hookEntry, has := s.postStartHooks[hookName]
		if !has { // never happen
			slog.Warn(fmt.Sprintf("unknown PostStartHook %q", hookName))
			continue
		}
		if i < ordered {
			if err := runPostStartHook(gCtx, hookName, hookEntry); err != nil {
				// Ordered hooks all precede the first g.Go call, so no
				// goroutine is running and Wait returns nil at once. It is
				// called only to cancel gCtx: without it the context handed to
				// the earlier hooks would stay alive although a hook failed.
				_ = g.Wait()
				return err
			}
			continue
		}
		g.Go(func() error {
			return runPostStartHook(gCtx, hookName, hookEntry)
		})
	}
	return g.Wait()
}
// RunPreShutdownHooks runs the PreShutdownHooks for the server and marks them
// as called, so that later registrations are rejected.
//
// Hooks run sequentially. Unordered hooks (AddPreShutdownHook) run first, in
// no particular order, followed by the ordered hooks
// (AddBootSequencePreShutdownHook) in the reverse of their registration order.
// A failing hook does not stop the remaining ones; all failures are returned
// together via errors.Join, and nil is returned when every hook succeeded.
//
// The registered order itself is left untouched, so calling this method again
// runs the hooks in the same order.
func (s *WebServer) RunPreShutdownHooks() error {
	s.preShutdownHookLock.Lock()
	defer s.preShutdownHookLock.Unlock()
	s.preShutdownHooksCalled = true

	// Work on a copy: both append and slices.Reverse would otherwise operate
	// on the backing array of s.preShutdownHookOrderedKeys and scramble the
	// order the hooks were registered in.
	keys := slices.Clone(s.preShutdownHookOrderedKeys)
	for k := range s.preShutdownHooks {
		if !slices.Contains(s.preShutdownHookOrderedKeys, k) {
			keys = append(keys, k)
		}
	}
	slices.Reverse(keys)

	var errs []error
	for _, hookName := range keys {
		hookEntry, has := s.preShutdownHooks[hookName]
		if !has { // an ordered key without a registered hook; not expected
			slog.Warn(fmt.Sprintf("unknown PreShutdownHook %q", hookName))
			continue
		}
		if err := runPreShutdownHook(hookName, hookEntry); err != nil {
			errs = append(errs, err)
		}
	}
	return errors.Join(errs...)
}
// isPostStartHookRegistered reports whether a PostStartHook has been
// registered under name. The lookup is done while holding postStartHookLock.
func (s *WebServer) isPostStartHookRegistered(name string) bool {
	s.postStartHookLock.Lock()
	defer s.postStartHookLock.Unlock()

	_, registered := s.postStartHooks[name]
	return registered
}
// addPostStartHook registers hook under name and installs a matching
// "poststarthook/<name>" health check that stays failing until the hook has
// finished. When order is true the hook additionally joins the ordered boot
// sequence.
//
// It returns an error when name is empty, hook is nil, the PostStartHooks have
// already been run, name is already taken, or the health check cannot be
// added.
func (s *WebServer) addPostStartHook(name string, hook PostStartHookFunc, order bool) error {
	switch {
	case name == "":
		return fmt.Errorf("missing name")
	case hook == nil:
		return fmt.Errorf("hook func may not be nil: %q", name)
	}

	s.postStartHookLock.Lock()
	defer s.postStartHookLock.Unlock()

	if s.postStartHooksCalled {
		return fmt.Errorf("unable to add %q because PostStartHooks have already been called", name)
	}
	if previous, duplicate := s.postStartHooks[name]; duplicate {
		// A duplicate is a programming error; reporting the stack of the
		// first registration makes it much easier to track down.
		return fmt.Errorf("unable to add %q because it was already registered by: %s", name, previous.originatingStack)
	}

	// finished is closed once the hook has completed; the health check
	// watches it to tell callers whether the hook is done.
	finished := make(chan struct{})
	check := postStartHookHealthz{name: "poststarthook/" + name, done: finished}
	if err := s.AddBootSequenceHealthChecks(check); err != nil {
		return err
	}

	if order {
		s.postStartHookOrderedKeys = append(s.postStartHookOrderedKeys, name)
	}
	s.postStartHooks[name] = postStartHookEntry{
		hook:             hook,
		originatingStack: string(debug.Stack()),
		done:             finished,
	}
	return nil
}
// addPreShutdownHook registers hook under name. When order is true the hook
// additionally joins the ordered shutdown sequence.
//
// A nil hook is accepted and silently ignored. An error is returned when name
// is empty, the PreShutdownHooks have already been run, or name is already
// taken.
func (s *WebServer) addPreShutdownHook(name string, hook PreShutdownHookFunc, order bool) error {
	switch {
	case name == "":
		return fmt.Errorf("missing name")
	case hook == nil:
		// Nothing to register; this is deliberately not an error.
		return nil
	}

	s.preShutdownHookLock.Lock()
	defer s.preShutdownHookLock.Unlock()

	if s.preShutdownHooksCalled {
		return fmt.Errorf("unable to add %q because PreShutdownHooks have already been called", name)
	}
	if _, duplicate := s.preShutdownHooks[name]; duplicate {
		return fmt.Errorf("unable to add %q because it is already registered", name)
	}

	if order {
		s.preShutdownHookOrderedKeys = append(s.preShutdownHookOrderedKeys, name)
	}
	s.preShutdownHooks[name] = preShutdownHookEntry{hook: hook}
	return nil
}
// runPostStartHook runs the hook stored in entry with ctx and, unless the hook
// returned an error, closes entry.done to signal that it has finished.
//
// A returned error is wrapped with the hook name and passed on, leaving
// entry.done open. A panic inside the hook is handed to
// runtime.NeverPanicButLog.Recover (presumably recovered and logged — confirm
// against that package); such a hook is then treated as having returned nil.
func runPostStartHook(ctx context.Context, name string, entry postStartHookEntry) error {
	// guarded keeps an *accidental* panic in the hook from killing the server.
	guarded := func() error {
		defer runtime.NeverPanicButLog.Recover()
		return entry.hook(ctx)
	}

	// A hook that returns an error intentionally wants to stop the server, so
	// the failure is propagated.
	if hookErr := guarded(); hookErr != nil {
		return fmt.Errorf("PostStartHook %q failed: %w", name, hookErr)
	}
	close(entry.done)
	return nil
}
// runPreShutdownHook runs the hook stored in entry and returns its error, if
// any, wrapped with the hook name.
//
// A panic inside the hook is handed to runtime.NeverPanicButLog.Recover
// (presumably recovered and logged — confirm against that package); such a
// hook is then treated as having returned nil.
func runPreShutdownHook(name string, entry preShutdownHookEntry) error {
	// guarded keeps an *accidental* panic in the hook from killing the server.
	guarded := func() error {
		defer runtime.NeverPanicButLog.Recover()
		return entry.hook()
	}

	if hookErr := guarded(); hookErr != nil {
		return fmt.Errorf("PreShutdownHook %q failed: %w", name, hookErr)
	}
	return nil
}
// postStartHookHealthz implements a healthz check for poststarthooks. It will return a "hookNotFinished"
// error until the poststarthook is finished.
type postStartHookHealthz struct {
	// name is the name reported for this check; addPostStartHook sets it to
	// "poststarthook/" followed by the hook name.
	name string
	// done will be closed when the postStartHook is finished
	done chan struct{}
}
// Compile-time assertion that postStartHookHealthz satisfies healthz.HealthChecker.
var _ healthz.HealthChecker = postStartHookHealthz{}
// Name returns the name this health check was created with.
func (h postStartHookHealthz) Name() string {
	return h.name
}
// hookNotFinished is the error postStartHookHealthz.Check returns while the
// hook's done channel has not been closed yet.
var hookNotFinished = errors.New("not finished")
// Check reports whether the post start hook has finished: it returns nil once
// the done channel is closed and hookNotFinished before that. The call never
// blocks and req is not inspected.
func (h postStartHookHealthz) Check(req *http.Request) error {
	select {
	case <-h.done:
		// The hook has finished; fall through to report success.
	default:
		return hookNotFinished
	}
	return nil
}