diff --git a/pkg/util/fsm/fsm.go b/pkg/util/fsm/fsm.go new file mode 100644 index 0000000000..986d2cc401 --- /dev/null +++ b/pkg/util/fsm/fsm.go @@ -0,0 +1,127 @@ +// Copyright (C) 2017 ScyllaDB + +package fsm + +import ( + "context" + + "github.com/pkg/errors" +) + +// ErrEventRejected is the error returned when the state machine cannot process +// an event in the state that it is in. +var ErrEventRejected = errors.New("event rejected") + +const ( + // NoOp represents a no-op event. State machine stops when this event is emitted. + NoOp Event = "NoOp" +) + +// State represents an extensible state type in the state machine. +type State string + +// Event represents an extensible event type in the state machine. +type Event string + +// Action represents the action to be executed in a given state. +type Action func(ctx context.Context) (Event, error) + +// Events represents a mapping of events and states. +type Events map[Event]State + +// Transition binds a state with an action and a set of events it can handle. +type Transition struct { + Action Action + Events Events +} + +// Hook is called on each state machine transition. +type Hook func(ctx context.Context, currentState, nextState State, event Event) error + +// StateTransitions represents a mapping of states and their implementations. +type StateTransitions map[State]Transition + +// StateMachine represents the state machine. +type StateMachine struct { + // Current represents the current state. + current State + + // StateTransitions holds the configuration of states and events handled by the state machine. + stateTransitions StateTransitions + + // TransitionHook is called on every state transition. + transitionHook Hook +} + +// New returns initialized state machine. +func New(state State, stateTransitions StateTransitions, hook Hook) *StateMachine { + return &StateMachine{ + current: state, + stateTransitions: stateTransitions, + transitionHook: hook, + } +} + +// getNextState returns the next state for the event given the machine's current +// state, or an error if the event can't be handled in the given state. +func (s *StateMachine) getNextState(event Event) (State, error) { + if transition, ok := s.stateTransitions[s.current]; ok { + if transition.Events != nil { + if next, ok := transition.Events[event]; ok { + return next, nil + } + } + } + return s.current, ErrEventRejected +} + +// Transition triggers current state action and sends event to the state machine. +func (s *StateMachine) Transition(ctx context.Context) error { + // Pick next transition according to current state. + transition := s.stateTransitions[s.current] + event, err := transition.Action(ctx) + if err != nil { + return err + } + if event == NoOp { + return nil + } + + for { + // Determine the next state for the event given the machine's current state. + nextState, err := s.getNextState(event) + if err != nil { + return errors.Wrapf(ErrEventRejected, "rejected %s", err.Error()) + } + + // Identify the state definition for the next state. + nextTransition, ok := s.stateTransitions[nextState] + if !ok || nextTransition.Action == nil { + return errors.Wrapf(ErrEventRejected, "unknown transition %q for event %q", nextState, event) + } + + if s.transitionHook != nil { + if err := s.transitionHook(ctx, s.current, nextState, event); err != nil { + return err + } + } + // Transition over to the next state. + s.current = nextState + + // Execute the next state's action and loop over again if the event returned + // is not a no-op. + nextEvent, err := nextTransition.Action(ctx) + if err != nil { + return err + } + if nextEvent == NoOp { + return nil + } + event = nextEvent + } +} + +// Current return current state machine state. +func (s *StateMachine) Current() State { + return s.current +} diff --git a/pkg/util/fsm/fsm_test.go b/pkg/util/fsm/fsm_test.go new file mode 100644 index 0000000000..077784b5ca --- /dev/null +++ b/pkg/util/fsm/fsm_test.go @@ -0,0 +1,141 @@ +// Copyright (C) 2017 ScyllaDB + +package fsm_test + +import ( + "context" + "testing" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "github.com/pkg/errors" + "github.com/scylladb/scylla-operator/pkg/util/fsm" + "go.uber.org/atomic" + "sigs.k8s.io/controller-runtime/pkg/envtest/printer" +) + +const ( + Registration fsm.State = "registration" + DoctorAppointment fsm.State = "doctor_appointment" + ApplyMedicine fsm.State = "medicine" + RequirePayment fsm.State = "require_payment" + + ActionSuccess fsm.Event = "success" +) + +var _ = Describe("FSM tests", func() { + type patient struct { + registrationDone bool + doctorAnalyzed bool + medicineApplied bool + paymentDone bool + } + + var ( + ctx context.Context + hookCalled *atomic.Int64 + + countingHook = func(ctx context.Context, currentState, nextState fsm.State, event fsm.Event) error { + hookCalled.Inc() + return nil + } + ) + + BeforeEach(func() { + ctx = context.Background() + hookCalled = atomic.NewInt64(0) + }) + + It("Full transition", func() { + p := patient{} + + fsm := fsm.New(Registration, fsm.StateTransitions{ + Registration: fsm.Transition{ + Action: func(ctx context.Context) (fsm.Event, error) { + p.registrationDone = true + return ActionSuccess, nil + }, + Events: map[fsm.Event]fsm.State{ + ActionSuccess: DoctorAppointment, + }, + }, + DoctorAppointment: fsm.Transition{ + Action: func(ctx context.Context) (fsm.Event, error) { + p.doctorAnalyzed = true + return ActionSuccess, nil + }, + Events: map[fsm.Event]fsm.State{ + ActionSuccess: ApplyMedicine, + }, + }, + ApplyMedicine: fsm.Transition{ + Action: func(ctx context.Context) (fsm.Event, error) { + p.medicineApplied = true + return ActionSuccess, nil + }, + Events: map[fsm.Event]fsm.State{ + ActionSuccess: RequirePayment, + }, + }, + RequirePayment: fsm.Transition{ + Action: func(ctx context.Context) (fsm.Event, error) { + p.paymentDone = true + return fsm.NoOp, nil + }, + Events: map[fsm.Event]fsm.State{ + ActionSuccess: DoctorAppointment, + }, + }, + }, countingHook) + + Expect(fsm.Transition(ctx)).To(Succeed()) + + Expect(hookCalled.Load()).To(Equal(int64(3))) + + Expect(p.registrationDone).To(BeTrue()) + Expect(p.doctorAnalyzed).To(BeTrue()) + Expect(p.medicineApplied).To(BeTrue()) + Expect(p.paymentDone).To(BeTrue()) + }) + + It("action failure", func() { + fsm := fsm.New(Registration, fsm.StateTransitions{ + Registration: fsm.Transition{ + Action: func(ctx context.Context) (fsm.Event, error) { + return fsm.NoOp, errors.New("fail!") + }, + }, + }, countingHook) + + Expect(fsm.Transition(ctx)).To(HaveOccurred()) + Expect(hookCalled.Load()).To(Equal(int64(0))) + Expect(fsm.Current()).To(Equal(Registration)) + }) + + It("hook failure interrupts machine", func() { + fsm := fsm.New(Registration, fsm.StateTransitions{ + Registration: fsm.Transition{ + Action: func(ctx context.Context) (fsm.Event, error) { + return ActionSuccess, nil + }, + Events: map[fsm.Event]fsm.State{ + ActionSuccess: DoctorAppointment, + }, + }, + DoctorAppointment: fsm.Transition{}, + }, func(ctx context.Context, currentState, nextState fsm.State, event fsm.Event) error { + return errors.New("fail!") + }) + + Expect(fsm.Transition(ctx)).To(HaveOccurred()) + Expect(fsm.Current()).To(Equal(Registration)) + }) +}) + +func TestFSM(t *testing.T) { + RegisterFailHandler(Fail) + + RunSpecsWithDefaultAndCustomReporters(t, + "FSM Suite", + []Reporter{printer.NewlineReporter{}}) +}