Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
268 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
// Copyright (C) 2017 ScyllaDB | ||
|
||
package fsm | ||
|
||
import ( | ||
"context" | ||
|
||
"github.com/pkg/errors" | ||
) | ||
|
||
// ErrEventRejected is the error returned when the state machine cannot process | ||
// an event in the state that it is in. | ||
var ErrEventRejected = errors.New("event rejected") | ||
|
||
const ( | ||
// NoOp represents a no-op event. State machine stops when this event is emitted. | ||
NoOp Event = "NoOp" | ||
) | ||
|
||
// State represents an extensible state type in the state machine. | ||
type State string | ||
|
||
// Event represents an extensible event type in the state machine. | ||
type Event string | ||
|
||
// Action represents the action to be executed in a given state. | ||
type Action func(ctx context.Context) (Event, error) | ||
|
||
// Events represents a mapping of events and states. | ||
type Events map[Event]State | ||
|
||
// Transition binds a state with an action and a set of events it can handle. | ||
type Transition struct { | ||
Action Action | ||
Events Events | ||
} | ||
|
||
// Hook is called on each state machine transition. | ||
type Hook func(ctx context.Context, currentState, nextState State, event Event) error | ||
|
||
// StateTransitions represents a mapping of states and their implementations. | ||
type StateTransitions map[State]Transition | ||
|
||
// StateMachine represents the state machine. | ||
type StateMachine struct { | ||
// Current represents the current state. | ||
current State | ||
|
||
// StateTransitions holds the configuration of states and events handled by the state machine. | ||
stateTransitions StateTransitions | ||
|
||
// TransitionHook is called on every state transition. | ||
transitionHook Hook | ||
} | ||
|
||
// New returns initialized state machine. | ||
func New(state State, stateTransitions StateTransitions, hook Hook) *StateMachine { | ||
return &StateMachine{ | ||
current: state, | ||
stateTransitions: stateTransitions, | ||
transitionHook: hook, | ||
} | ||
} | ||
|
||
// getNextState returns the next state for the event given the machine's current | ||
// state, or an error if the event can't be handled in the given state. | ||
func (s *StateMachine) getNextState(event Event) (State, error) { | ||
if transition, ok := s.stateTransitions[s.current]; ok { | ||
if transition.Events != nil { | ||
if next, ok := transition.Events[event]; ok { | ||
return next, nil | ||
} | ||
} | ||
} | ||
return s.current, ErrEventRejected | ||
} | ||
|
||
// Transition triggers current state action and sends event to the state machine. | ||
func (s *StateMachine) Transition(ctx context.Context) error { | ||
// Pick next transition according to current state. | ||
transition := s.stateTransitions[s.current] | ||
event, err := transition.Action(ctx) | ||
if err != nil { | ||
return err | ||
} | ||
if event == NoOp { | ||
return nil | ||
} | ||
|
||
for { | ||
// Determine the next state for the event given the machine's current state. | ||
nextState, err := s.getNextState(event) | ||
if err != nil { | ||
return errors.Wrapf(ErrEventRejected, "rejected %s", err.Error()) | ||
} | ||
|
||
// Identify the state definition for the next state. | ||
nextTransition, ok := s.stateTransitions[nextState] | ||
if !ok || nextTransition.Action == nil { | ||
return errors.Wrapf(ErrEventRejected, "unknown transition %q for event %q", nextState, event) | ||
} | ||
|
||
if s.transitionHook != nil { | ||
if err := s.transitionHook(ctx, s.current, nextState, event); err != nil { | ||
return err | ||
} | ||
} | ||
// Transition over to the next state. | ||
s.current = nextState | ||
|
||
// Execute the next state's action and loop over again if the event returned | ||
// is not a no-op. | ||
nextEvent, err := nextTransition.Action(ctx) | ||
if err != nil { | ||
return err | ||
} | ||
if nextEvent == NoOp { | ||
return nil | ||
} | ||
event = nextEvent | ||
} | ||
} | ||
|
||
// Current return current state machine state. | ||
func (s *StateMachine) Current() State { | ||
return s.current | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
// Copyright (C) 2017 ScyllaDB | ||
|
||
package fsm_test | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
|
||
. "github.com/onsi/ginkgo" | ||
. "github.com/onsi/gomega" | ||
"github.com/pkg/errors" | ||
"github.com/scylladb/scylla-operator/pkg/util/fsm" | ||
"go.uber.org/atomic" | ||
"sigs.k8s.io/controller-runtime/pkg/envtest/printer" | ||
) | ||
|
||
const ( | ||
Registration fsm.State = "registration" | ||
DoctorAppointment fsm.State = "doctor_appointment" | ||
ApplyMedicine fsm.State = "medicine" | ||
RequirePayment fsm.State = "require_payment" | ||
|
||
ActionSuccess fsm.Event = "success" | ||
) | ||
|
||
var _ = Describe("FSM tests", func() { | ||
type patient struct { | ||
registrationDone bool | ||
doctorAnalyzed bool | ||
medicineApplied bool | ||
paymentDone bool | ||
} | ||
|
||
var ( | ||
ctx context.Context | ||
hookCalled *atomic.Int64 | ||
|
||
countingHook = func(ctx context.Context, currentState, nextState fsm.State, event fsm.Event) error { | ||
hookCalled.Inc() | ||
return nil | ||
} | ||
) | ||
|
||
BeforeEach(func() { | ||
ctx = context.Background() | ||
hookCalled = atomic.NewInt64(0) | ||
}) | ||
|
||
It("Full transition", func() { | ||
p := patient{} | ||
|
||
fsm := fsm.New(Registration, fsm.StateTransitions{ | ||
Registration: fsm.Transition{ | ||
Action: func(ctx context.Context) (fsm.Event, error) { | ||
p.registrationDone = true | ||
return ActionSuccess, nil | ||
}, | ||
Events: map[fsm.Event]fsm.State{ | ||
ActionSuccess: DoctorAppointment, | ||
}, | ||
}, | ||
DoctorAppointment: fsm.Transition{ | ||
Action: func(ctx context.Context) (fsm.Event, error) { | ||
p.doctorAnalyzed = true | ||
return ActionSuccess, nil | ||
}, | ||
Events: map[fsm.Event]fsm.State{ | ||
ActionSuccess: ApplyMedicine, | ||
}, | ||
}, | ||
ApplyMedicine: fsm.Transition{ | ||
Action: func(ctx context.Context) (fsm.Event, error) { | ||
p.medicineApplied = true | ||
return ActionSuccess, nil | ||
}, | ||
Events: map[fsm.Event]fsm.State{ | ||
ActionSuccess: RequirePayment, | ||
}, | ||
}, | ||
RequirePayment: fsm.Transition{ | ||
Action: func(ctx context.Context) (fsm.Event, error) { | ||
p.paymentDone = true | ||
return fsm.NoOp, nil | ||
}, | ||
Events: map[fsm.Event]fsm.State{ | ||
ActionSuccess: DoctorAppointment, | ||
}, | ||
}, | ||
}, countingHook) | ||
|
||
Expect(fsm.Transition(ctx)).To(Succeed()) | ||
|
||
Expect(hookCalled.Load()).To(Equal(int64(3))) | ||
|
||
Expect(p.registrationDone).To(BeTrue()) | ||
Expect(p.doctorAnalyzed).To(BeTrue()) | ||
Expect(p.medicineApplied).To(BeTrue()) | ||
Expect(p.paymentDone).To(BeTrue()) | ||
}) | ||
|
||
It("action failure", func() { | ||
fsm := fsm.New(Registration, fsm.StateTransitions{ | ||
Registration: fsm.Transition{ | ||
Action: func(ctx context.Context) (fsm.Event, error) { | ||
return fsm.NoOp, errors.New("fail!") | ||
}, | ||
}, | ||
}, countingHook) | ||
|
||
Expect(fsm.Transition(ctx)).To(HaveOccurred()) | ||
Expect(hookCalled.Load()).To(Equal(int64(0))) | ||
Expect(fsm.Current()).To(Equal(Registration)) | ||
}) | ||
|
||
It("hook failure interrupts machine", func() { | ||
fsm := fsm.New(Registration, fsm.StateTransitions{ | ||
Registration: fsm.Transition{ | ||
Action: func(ctx context.Context) (fsm.Event, error) { | ||
return ActionSuccess, nil | ||
}, | ||
Events: map[fsm.Event]fsm.State{ | ||
ActionSuccess: DoctorAppointment, | ||
}, | ||
}, | ||
DoctorAppointment: fsm.Transition{}, | ||
}, func(ctx context.Context, currentState, nextState fsm.State, event fsm.Event) error { | ||
return errors.New("fail!") | ||
}) | ||
|
||
Expect(fsm.Transition(ctx)).To(HaveOccurred()) | ||
Expect(fsm.Current()).To(Equal(Registration)) | ||
}) | ||
}) | ||
|
||
func TestFSM(t *testing.T) { | ||
RegisterFailHandler(Fail) | ||
|
||
RunSpecsWithDefaultAndCustomReporters(t, | ||
"FSM Suite", | ||
[]Reporter{printer.NewlineReporter{}}) | ||
} |