Skip to content

Commit

Permalink
Merge pull request #68 from qmuntal/asyncqueued
Browse files Browse the repository at this point in the history
Fix race conditions in queued Fire
  • Loading branch information
qmuntal committed Sep 6, 2023
2 parents 23039c6 + 8af1ab7 commit aa77393
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 53 deletions.
88 changes: 88 additions & 0 deletions modes.go
@@ -0,0 +1,88 @@
package stateless

import (
"context"
"sync"
"sync/atomic"
)

type fireMode interface {
Fire(ctx context.Context, trigger Trigger, args ...any) error
Firing() bool
}

type fireModeImmediate struct {
ops atomic.Uint64
sm *StateMachine
}

func (f *fireModeImmediate) Firing() bool {
return f.ops.Load() > 0
}

func (f *fireModeImmediate) Fire(ctx context.Context, trigger Trigger, args ...any) error {
f.ops.Add(1)
defer f.ops.Add(^uint64(0))
return f.sm.internalFireOne(ctx, trigger, args...)
}

type queuedTrigger struct {
Context context.Context
Trigger Trigger
Args []any
}

type fireModeQueued struct {
firing atomic.Bool
sm *StateMachine

triggers []queuedTrigger
mu sync.Mutex // guards triggers
}

func (f *fireModeQueued) Firing() bool {
return f.firing.Load()
}

func (f *fireModeQueued) Fire(ctx context.Context, trigger Trigger, args ...any) error {
f.enqueue(ctx, trigger, args...)
for {
et, ok := f.fetch()
if !ok {
break
}
err := f.execute(et)
if err != nil {
return err
}
}
return nil
}

func (f *fireModeQueued) enqueue(ctx context.Context, trigger Trigger, args ...any) {
f.mu.Lock()
defer f.mu.Unlock()

f.triggers = append(f.triggers, queuedTrigger{Context: ctx, Trigger: trigger, Args: args})
}

func (f *fireModeQueued) fetch() (et queuedTrigger, ok bool) {
f.mu.Lock()
defer f.mu.Unlock()

if len(f.triggers) == 0 {
return queuedTrigger{}, false
}

if !f.firing.CompareAndSwap(false, true) {
return queuedTrigger{}, false
}

et, f.triggers = f.triggers[0], f.triggers[1:]
return et, true
}

func (f *fireModeQueued) execute(et queuedTrigger) error {
defer f.firing.Swap(false)
return f.sm.internalFireOne(et.Context, et.Trigger, et.Args...)
}
66 changes: 13 additions & 53 deletions statemachine.go
@@ -1,12 +1,10 @@
package stateless

import (
"container/list"
"context"
"fmt"
"reflect"
"sync"
"sync/atomic"
)

// State is used to to represent the possible machine states.
Expand Down Expand Up @@ -64,26 +62,29 @@ func callEvents(events []TransitionFunc, ctx context.Context, transition Transit
// It is safe to use the StateMachine concurrently, but non of the callbacks (state manipulation, actions, events, ...) are guarded,
// so it is up to the client to protect them against race conditions.
type StateMachine struct {
ops atomic.Uint64
stateConfig map[State]*stateRepresentation
triggerConfig map[Trigger]triggerWithParameters
stateAccessor func(context.Context) (State, error)
stateMutator func(context.Context, State) error
unhandledTriggerAction UnhandledTriggerActionFunc
onTransitioningEvents []TransitionFunc
onTransitionedEvents []TransitionFunc
eventQueue list.List
firingMode FiringMode
firingMutex sync.Mutex
stateMutex sync.RWMutex
mode fireMode
}

func newStateMachine() *StateMachine {
return &StateMachine{
func newStateMachine(firingMode FiringMode) *StateMachine {
sm := &StateMachine{
stateConfig: make(map[State]*stateRepresentation),
triggerConfig: make(map[Trigger]triggerWithParameters),
unhandledTriggerAction: UnhandledTriggerActionFunc(DefaultUnhandledTriggerAction),
}
if firingMode == FiringImmediate {
sm.mode = &fireModeImmediate{sm: sm}
} else {
sm.mode = &fireModeQueued{sm: sm}
}
return sm
}

// NewStateMachine returns a queued state machine.
Expand All @@ -94,7 +95,7 @@ func NewStateMachine(initialState State) *StateMachine {
// NewStateMachineWithMode returns a state machine with the desired firing mode
func NewStateMachineWithMode(initialState State, firingMode FiringMode) *StateMachine {
var stateMutex sync.Mutex
sm := newStateMachine()
sm := newStateMachine(firingMode)
reference := &struct {
State State
}{State: initialState}
Expand All @@ -109,16 +110,14 @@ func NewStateMachineWithMode(initialState State, firingMode FiringMode) *StateMa
reference.State = state
return nil
}
sm.firingMode = firingMode
return sm
}

// NewStateMachineWithExternalStorage returns a state machine with external state storage.
func NewStateMachineWithExternalStorage(stateAccessor func(context.Context) (State, error), stateMutator func(context.Context, State) error, firingMode FiringMode) *StateMachine {
sm := newStateMachine()
sm := newStateMachine(firingMode)
sm.stateAccessor = stateAccessor
sm.stateMutator = stateMutator
sm.firingMode = firingMode
return sm
}

Expand Down Expand Up @@ -276,7 +275,7 @@ func (sm *StateMachine) Configure(state State) *StateConfiguration {

// Firing returns true when the state machine is processing a trigger.
func (sm *StateMachine) Firing() bool {
return sm.ops.Load() != 0
return sm.mode.Firing()
}

// String returns a human-readable representation of the state machine.
Expand Down Expand Up @@ -321,49 +320,10 @@ func (sm *StateMachine) stateRepresentation(state State) *stateRepresentation {
}

func (sm *StateMachine) internalFire(ctx context.Context, trigger Trigger, args ...any) error {
switch sm.firingMode {
case FiringImmediate:
return sm.internalFireOne(ctx, trigger, args...)
case FiringQueued:
fallthrough
default:
return sm.internalFireQueued(ctx, trigger, args...)
}
}

type queuedTrigger struct {
Context context.Context
Trigger Trigger
Args []any
}

func (sm *StateMachine) internalFireQueued(ctx context.Context, trigger Trigger, args ...any) error {
sm.firingMutex.Lock()
sm.eventQueue.PushBack(queuedTrigger{Context: ctx, Trigger: trigger, Args: args})
sm.firingMutex.Unlock()
if sm.Firing() {
return nil
}

for {
sm.firingMutex.Lock()
e := sm.eventQueue.Front()
if e == nil {
sm.firingMutex.Unlock()
break
}
et := sm.eventQueue.Remove(e).(queuedTrigger)
sm.firingMutex.Unlock()
if err := sm.internalFireOne(et.Context, et.Trigger, et.Args...); err != nil {
return err
}
}
return nil
return sm.mode.Fire(ctx, trigger, args...)
}

func (sm *StateMachine) internalFireOne(ctx context.Context, trigger Trigger, args ...any) error {
sm.ops.Add(1)
defer sm.ops.Add(^uint64(0))
var (
config triggerWithParameters
ok bool
Expand Down

0 comments on commit aa77393

Please sign in to comment.