Skip to content

Commit

Permalink
Reuse time.Timer in ack and rtx timer utilities
Browse files Browse the repository at this point in the history
  • Loading branch information
paulwe committed Apr 16, 2024
1 parent 45982cd commit 608e170
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 51 deletions.
76 changes: 34 additions & 42 deletions ack_timer.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package sctp

import (
"math"
"sync"
"time"
)
Expand All @@ -17,58 +18,51 @@ type ackTimerObserver interface {
onAckTimeout()
}

type ackTimerState int

const (
ackTimerStopped ackTimerState = iota
ackTimerStarted
ackTimerClosed
)

// ackTimer provides the retnransmission timer conforms with RFC 4960 Sec 6.3.1
type ackTimer struct {
observer ackTimerObserver
interval time.Duration
stopFunc stopAckTimerLoop
closed bool
mutex sync.RWMutex
state ackTimerState
timer *time.Timer
}

type stopAckTimerLoop func()

// newAckTimer creates a new acknowledgement timer used to enable delayed ack.
func newAckTimer(observer ackTimerObserver) *ackTimer {
return &ackTimer{
observer: observer,
interval: ackInterval,
t := &ackTimer{observer: observer}
t.timer = time.AfterFunc(math.MaxInt64, t.timeout)
t.timer.Stop()
return t
}

func (t *ackTimer) timeout() {
t.mutex.Lock()
if t.state == ackTimerStarted {
t.state = ackTimerStopped
defer t.observer.onAckTimeout()
}
t.mutex.Unlock()
}

// start starts the timer.
func (t *ackTimer) start() bool {
t.mutex.Lock()
defer t.mutex.Unlock()

// this timer is already closed
if t.closed {
// this timer is already closed or already running
if t.state != ackTimerStopped {
return false
}

// this is a noop if the timer is already running
if t.stopFunc != nil {
return false
}

cancelCh := make(chan struct{})

go func() {
timer := time.NewTimer(t.interval)

select {
case <-timer.C:
t.stop()
t.observer.onAckTimeout()
case <-cancelCh:
timer.Stop()
}
}()

t.stopFunc = func() {
close(cancelCh)
}

t.state = ackTimerStarted
t.timer.Reset(ackInterval)
return true
}

Expand All @@ -78,9 +72,9 @@ func (t *ackTimer) stop() {
t.mutex.Lock()
defer t.mutex.Unlock()

if t.stopFunc != nil {
t.stopFunc()
t.stopFunc = nil
if t.state == ackTimerStarted {
t.timer.Stop()
t.state = ackTimerStopped
}
}

Expand All @@ -90,12 +84,10 @@ func (t *ackTimer) close() {
t.mutex.Lock()
defer t.mutex.Unlock()

if t.stopFunc != nil {
t.stopFunc()
t.stopFunc = nil
if t.state == ackTimerStarted {
t.timer.Stop()
}

t.closed = true
t.state = ackTimerClosed
}

// isRunning tests if the timer is running.
Expand All @@ -104,5 +96,5 @@ func (t *ackTimer) isRunning() bool {
t.mutex.RLock()
defer t.mutex.RUnlock()

return (t.stopFunc != nil)
return t.state == ackTimerStarted
}
18 changes: 10 additions & 8 deletions ack_timer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,16 @@ func TestAckTimer(t *testing.T) {
},
})

// should start ok
ok := rt.start()
assert.True(t, ok, "start() should succeed")
assert.True(t, rt.isRunning(), "should be running")
for i := 0; i < 2; i++ {
// should start ok
ok := rt.start()
assert.True(t, ok, "start() should succeed")
assert.True(t, rt.isRunning(), "should be running")

// stop immedidately
rt.stop()
assert.False(t, rt.isRunning(), "should not be running")
// stop immedidately
rt.stop()
assert.False(t, rt.isRunning(), "should not be running")
}

// Sleep more than 200msec of interval to test if it never times out
time.Sleep(ackInterval + 50*time.Millisecond)
Expand All @@ -86,7 +88,7 @@ func TestAckTimer(t *testing.T) {
"should not be timed out (actual: %d)", atomic.LoadUint32(&nCbs))

// can start again
ok = rt.start()
ok := rt.start()
assert.True(t, ok, "start() should succeed again")
assert.True(t, rt.isRunning(), "should be running")

Expand Down
5 changes: 4 additions & 1 deletion rtx_timer.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,12 @@ func (t *rtxTimer) start(rto float64) bool {
go func() {
canceling := false

timer := time.NewTimer(math.MaxInt64)
timer.Stop()

for !canceling {
timeout := calculateNextTimeout(rto, nRtos, t.rtoMax)
timer := time.NewTimer(time.Duration(timeout) * time.Millisecond)
timer.Reset(time.Duration(timeout) * time.Millisecond)

select {
case <-timer.C:
Expand Down

0 comments on commit 608e170

Please sign in to comment.