// saga_recovery.go
package saga
import "fmt"
// SagaRecoveryType selects how a SagaState is interpreted when a saga is
// recovered in RecoveryMode.
type SagaRecoveryType int

// Recovery strategies. The numeric values are fixed: RollbackRecovery is 0
// (the zero value of SagaRecoveryType) and ForwardRecovery is 1.
const (
	// RollbackRecovery: if the saga was aborted or is in an unsafe state, the
	// compensating tasks for all started tasks need to be executed, so
	// compensating tasks MUST BE idempotent.
	RollbackRecovery SagaRecoveryType = 0

	// ForwardRecovery: all tasks in the saga must be executed at least once,
	// so tasks MUST BE idempotent.
	ForwardRecovery SagaRecoveryType = 1
)
// recoverState rebuilds the SagaState for sagaId by replaying the messages
// that saga's log holds for it.
//
// It returns (nil, nil) when the log has no messages for sagaId. It returns an
// error if reading the log fails, if the first logged message is not a
// StartSaga message, or if makeSagaState or updateSagaState fails.
func recoverState(sagaId string, saga SagaCoordinator) (*SagaState, error) {
	// Get the logged messages for this saga from the log.
	msgs, err := saga.log.GetMessages(sagaId)
	if err != nil {
		return nil, err
	}
	// len reports 0 for a nil slice, so no separate nil check is needed.
	if len(msgs) == 0 {
		return nil, nil
	}

	// The log must begin with StartSaga; its Data seeds the initial state.
	startMsg := msgs[0]
	if startMsg.MsgType != StartSaga {
		return nil, fmt.Errorf("InvalidMessages: first message must be StartSaga")
	}
	state, err := makeSagaState(sagaId, startMsg.Data)
	if err != nil {
		return nil, err
	}

	// Replay the remaining messages in log order. Every StartSaga message is
	// skipped because the first one was already applied by makeSagaState.
	// The original design treats duplicate messages as harmless because
	// messages are idempotent. NOTE(review): that property depends on
	// updateSagaState; confirm there.
	for _, msg := range msgs {
		if msg.MsgType == StartSaga {
			continue
		}
		if err := updateSagaState(state, msg); err != nil {
			return nil, err
		}
	}
	return state, nil
}
// isSagaInSafeState reports whether execution of the saga can pick up where it
// left off. It is only used in RollbackRecovery.
//
// The saga is considered safe when either:
//   - it has been aborted (state.IsSagaAborted() is true); or
//   - no task in state.taskState is started without also being completed.
func isSagaInSafeState(state *SagaState) bool {
	if state.IsSagaAborted() {
		return true
	}
	for id := range state.taskState {
		if !state.IsTaskStarted(id) {
			continue
		}
		// A task that started but has no completion recorded leaves the saga
		// in an unsafe state.
		if !state.IsTaskCompleted(id) {
			return false
		}
	}
	return true
}