/
session.go
255 lines (231 loc) · 6.91 KB
/
session.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
package lock
import (
"context"
"fmt"
"math/rand"
"strconv"
"sync"
"time"
"github.com/promoboxx/go-metric-client/metrics"
otext "github.com/opentracing/opentracing-go/ext"
)
// Tasker can do the work associated with the tasks passed to it.
// It should return any completed tasks so they can be flagged as "finished".
type Tasker func(ctx context.Context, tasks []Task) ([]Task, error)
// Runner loops on a ticker, fetching the tasks assigned to its session and
// handing them to a Tasker.
type Runner struct {
	// Identity and collaborators supplied to NewRunner.
	name     string
	dbFinder DBFinder
	scanTask ScanTask
	tasker   Tasker
	client   metrics.Client
	logger   Logger

	// Tuning: how often to look for work and how many tasks to ask for.
	loopTick        time.Duration
	tasksPerSession int64

	// Session state; sessionMutex guards sessionID.
	sessionMutex sync.RWMutex
	sessionID    int64

	// Shutdown signalling: stop is closed by Stop, stopGroup is handed back
	// to the caller of Stop.
	stop      chan bool
	stopGroup *sync.WaitGroup
}
// NewRunner builds a Runner that handles one type of task.
//
//   - dbFinder returns an instance of the Database interface on demand.
//   - scanTask reads a row into a Task.
//   - tasker completes Tasks.
//   - loopTick is how often to check for tasks to complete.
//   - tasksPerSession is passed to Database.GetWork on every fetch.
//   - logger is optional; when nil a no-op logger is used.
//   - name labels the metrics the runner emits.
//   - client is a go-metric-client that also starts spans; it is required.
//
// NewRunner returns nil when client is nil.
func NewRunner(dbFinder DBFinder, scanTask ScanTask, tasker Tasker, loopTick time.Duration, tasksPerSession int64, logger Logger, name string, client metrics.Client) *Runner {
	if client == nil {
		return nil
	}

	runner := &Runner{
		name:            name,
		dbFinder:        dbFinder,
		scanTask:        scanTask,
		tasker:          tasker,
		client:          client,
		logger:          logger,
		loopTick:        loopTick,
		tasksPerSession: tasksPerSession,
		stopGroup:       new(sync.WaitGroup),
	}
	if logger == nil {
		// Fall back to a logger that discards everything.
		runner.logger = &noopLogger{}
	}
	return runner
}
// Run starts a session and launches two background goroutines: one that
// fetches and processes tasks every loopTick, and one that bumps the session
// every 30 seconds so it stays active during long-running work.
//
// It returns an error if the database cannot be found or the session cannot be
// started. Otherwise it returns nil immediately and the goroutines keep running
// until Stop is called, at which point the work goroutine ends the session and
// both goroutines exit.
//
// Do not call Run more than once.
func (r *Runner) Run() error {
	db, err := r.dbFinder()
	if err != nil {
		return err
	}

	r.sessionMutex.Lock()
	r.sessionID, err = r.startSession(context.Background(), db)
	r.sessionMutex.Unlock()
	if err != nil {
		return err
	}

	// The goroutines watch this local copy of the channel so they never read
	// the struct field themselves.
	stop := make(chan bool)
	r.stop = stop

	// stopped reports, without blocking, whether Stop() has been called.
	stopped := func() bool {
		select {
		case <-stop:
			return true
		default:
			return false
		}
	}

	go func() {
		// This goroutine only returns once Stop() was observed, so ending the
		// session here runs exactly once, on shutdown.
		defer func() {
			if err := r.endSession(context.Background()); err != nil {
				r.logger.Printf("Error ending session: %v", err)
			}
		}()

		// Sleep up to 10 seconds to break up services that start at the same
		// time, but wake up early if Stop() is called meanwhile.
		jitter := time.NewTimer(time.Duration(rand.Int63n(10)) * time.Second)
		select {
		case <-stop:
			jitter.Stop()
			return
		case <-jitter.C:
		}

		// time.NewTicker panics on a non-positive interval. In that case keep
		// a nil channel, which never fires, so the loop still reacts to stop.
		var tick <-chan time.Time
		if r.loopTick > 0 {
			ticker := time.NewTicker(r.loopTick)
			defer ticker.Stop()
			tick = ticker.C
		}

		for {
			select {
			case <-stop:
				return
			case <-tick:
				// doWork until no tasks remain, an error occurs, or Stop()
				// is called.
				for !stopped() {
					// The wait group is held only while a batch is in flight.
					r.stopGroup.Add(1)
					tasks, err := r.doWork(context.Background())
					r.stopGroup.Done()
					if err != nil {
						r.logger.Printf("Error doing work: %v", err)
						break
					}
					if len(tasks) == 0 {
						break
					}
				}
			}
		}
	}()

	go func() {
		// Bump the session every 30 seconds. This keeps the session active
		// even when working on tasks for a long time. Once the runner is
		// stopped (or the service shuts down) bumping ends, the session
		// eventually expires, and other services pick up new work.
		ticker := time.NewTicker(30 * time.Second)
		defer ticker.Stop()

		for {
			select {
			case <-stop:
				return
			case <-ticker.C:
				// Copy the ID under the read lock so the lock is not held
				// across the database call.
				r.sessionMutex.RLock()
				id := r.sessionID
				r.sessionMutex.RUnlock()

				err := db.BumpSession(context.Background(), id)
				if err != nil {
					r.logger.Printf("Error bumping session: %v", err)
				}
			}
		}
	}()

	return nil
}
// startSession opens a new session through db inside a "runner start session"
// span and returns the new session ID. The span is tagged with the session ID
// and, when starting fails, marked as an error carrying the failure.
func (r *Runner) startSession(ctx context.Context, db Database) (sessionID int64, err error) {
	sp, tracedCtx := r.client.StartSpanWithContext(ctx, "runner start session")
	defer func() {
		// Finish runs last, after any error tags have been attached.
		defer sp.Finish()
		if err == nil {
			return
		}
		otext.Error.Set(sp, true)
		sp.SetTag("inner-error", err)
	}()

	id, startErr := db.StartSession(tracedCtx)
	sp.SetTag("session_id", id)
	return id, startErr
}
// endSession ends the runner's current session inside a "runner end session"
// span. It returns the dbFinder error unchanged when no database is available,
// and a wrapped error when ending the session itself fails. Either failure is
// recorded on the span.
func (r *Runner) endSession(ctx context.Context) (err error) {
	sp, tracedCtx := r.client.StartSpanWithContext(ctx, "runner end session")
	defer func() {
		// Finish runs last, after any error tags have been attached.
		defer sp.Finish()
		if err == nil {
			return
		}
		otext.Error.Set(sp, true)
		sp.SetTag("inner-error", err)
	}()

	db, findErr := r.dbFinder()
	if findErr != nil {
		return findErr
	}

	r.sessionMutex.Lock()
	err = db.EndSession(tracedCtx, r.sessionID)
	r.sessionMutex.Unlock()

	if err == nil {
		return nil
	}
	return fmt.Errorf("Error ending session: %v", err)
}
// doWork fetches one batch of tasks for the current session, runs it through
// the tasker, and marks the tasks the tasker returned as finished.
//
// It returns the batch that was fetched (so the caller can tell whether more
// work may remain) and a non-nil error when finding the database, fetching
// work, running the tasks, or finishing them fails. Every failure is reported
// through handleError and recorded on the "doing work" span.
//
// When the database reports that the session was not found, a new session is
// started and stored on the runner before continuing with whatever GetWork
// returned.
func (r *Runner) doWork(ctx context.Context) (tasks []Task, err error) {
	span, spanCtx := r.client.StartSpanWithContext(ctx, "doing work")
	defer func() {
		if err != nil {
			otext.Error.Set(span, true)
			span.SetTag("inner-error", err)
		}
		span.Finish()
	}()

	start := time.Now()
	name := r.name
	params := make(map[string]string)

	// Snapshot the session ID under the read lock; it is used both as the
	// metrics label and as the session to fetch work for.
	r.sessionMutex.RLock()
	currentSession := r.sessionID
	r.sessionMutex.RUnlock()
	sessionID := strconv.FormatInt(currentSession, 10)

	r.client.BackgroundRate(sessionID, name, params, 1)

	// get work and process
	db, err := r.dbFinder()
	if err != nil {
		r.handleError(start, sessionID, name, "Failed to find DB", err.Error(), params)
		return tasks, fmt.Errorf("Error finding DB: %v", err)
	}

	tasks, dbErr := db.GetWork(spanCtx, currentSession, r.tasksPerSession, r.scanTask)
	if dbErr != nil {
		switch dbErr.Code() {
		case SQLErrorSessionNotFound:
			r.logger.Printf("Session expired. Getting new one")
			r.sessionMutex.Lock()
			r.sessionID, err = db.StartSession(spanCtx)
			r.sessionMutex.Unlock()
			if err != nil {
				r.handleError(start, sessionID, name, "Failed to start session", err.Error()+" with dbError: "+dbErr.Error(), params)
				// Report the error that actually stopped us: the failed restart.
				return tasks, fmt.Errorf("Error starting new session: %v", err)
			}
		default:
			// err is nil on this path (dbFinder succeeded), so only dbErr
			// may be dereferenced here.
			r.handleError(start, sessionID, name, "Failed getting work from db", dbErr.Error(), params)
			return tasks, fmt.Errorf("Error getting work from db: %v", dbErr)
		}
	}

	completedTasks, err := r.tasker(spanCtx, tasks)
	if err != nil {
		r.handleError(start, sessionID, name, "Error running tasks", err.Error(), params)
		return tasks, fmt.Errorf("Error running tasks: %v", err)
	}

	taskIDs := make([]string, len(completedTasks))
	for i, t := range completedTasks {
		taskIDs[i] = t.GetID()
	}

	dbErr = db.FinishTasks(spanCtx, taskIDs)
	if dbErr != nil {
		r.handleError(start, sessionID, name, "Error finishing tasks", dbErr.Error(), params)
		return tasks, fmt.Errorf("Error finishing tasks: %v", dbErr)
	}

	r.client.BackgroundDuration(sessionID, name, params, time.Since(start))
	return tasks, nil
}
// handleError records the metrics shared by every failure path: first the time
// elapsed since start, then a single error event carrying code and message.
func (r *Runner) handleError(start time.Time, sessionID, name, code, message string, params map[string]string) {
	metricsClient := r.client
	metricsClient.BackgroundDuration(sessionID, name, params, time.Since(start))
	metricsClient.BackgroundError(sessionID, name, params, code, message, 1)
}
// Stop stops the runner from looping by closing the stop channel that the
// work loop started by Run watches.
// Stop returns a WaitGroup which you can wait on to ensure all work is finished.
//
// Calling Stop before Run, or calling it again after the runner was already
// stopped, is a no-op that still returns the WaitGroup. Stop is not meant to
// be called from several goroutines at the same time.
func (r *Runner) Stop() *sync.WaitGroup {
	if r.stop == nil {
		// Run never created the channel; closing a nil channel would panic.
		return r.stopGroup
	}

	select {
	case <-r.stop:
		// Nothing ever sends on stop, so a successful receive means it is
		// already closed; closing it again would panic.
	default:
		close(r.stop)
	}
	return r.stopGroup
}