Skip to content

Commit

Permalink
Optimise payload queue for write path
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Apr 1, 2024
1 parent c04adbf commit 278dfbb
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 196 deletions.
2 changes: 1 addition & 1 deletion association.go
Original file line number Diff line number Diff line change
Expand Up @@ -2184,7 +2184,7 @@ func (a *Association) movePendingDataChunkToInflightQueue(c *chunkPayloadData) {
a.log.Tracef("[%s] sending ppi=%d tsn=%d ssn=%d sent=%d len=%d (%v,%v)",
a.name, c.payloadType, c.tsn, c.streamSequenceNumber, c.nSent, len(c.userData), c.beginningFragment, c.endingFragment)

a.inflightQueue.pushNoCheck(c)
a.inflightQueue.push(c)
}

// popPendingDataChunksToSend pops chunks from the pending queues as many as
Expand Down
8 changes: 4 additions & 4 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1191,7 +1191,7 @@ func TestCreateForwardTSN(t *testing.T) {

a.cumulativeTSNAckPoint = 9
a.advancedPeerTSNAckPoint = 10
a.inflightQueue.pushNoCheck(&chunkPayloadData{
a.inflightQueue.push(&chunkPayloadData{
beginningFragment: true,
endingFragment: true,
tsn: 10,
Expand All @@ -1218,7 +1218,7 @@ func TestCreateForwardTSN(t *testing.T) {

a.cumulativeTSNAckPoint = 9
a.advancedPeerTSNAckPoint = 12
a.inflightQueue.pushNoCheck(&chunkPayloadData{
a.inflightQueue.push(&chunkPayloadData{
beginningFragment: true,
endingFragment: true,
tsn: 10,
Expand All @@ -1228,7 +1228,7 @@ func TestCreateForwardTSN(t *testing.T) {
nSent: 1,
_abandoned: true,
})
a.inflightQueue.pushNoCheck(&chunkPayloadData{
a.inflightQueue.push(&chunkPayloadData{
beginningFragment: true,
endingFragment: true,
tsn: 11,
Expand All @@ -1238,7 +1238,7 @@ func TestCreateForwardTSN(t *testing.T) {
nSent: 1,
_abandoned: true,
})
a.inflightQueue.pushNoCheck(&chunkPayloadData{
a.inflightQueue.push(&chunkPayloadData{
beginningFragment: true,
endingFragment: true,
tsn: 12,
Expand Down
140 changes: 20 additions & 120 deletions payload_queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,83 +3,45 @@

package sctp

import (
"fmt"
"sort"
)

type payloadQueue struct {
chunkMap map[uint32]*chunkPayloadData
sorted []uint32
dupTSN []uint32
tsns []uint32
nBytes int
}

func newPayloadQueue() *payloadQueue {
return &payloadQueue{chunkMap: map[uint32]*chunkPayloadData{}}
}

func (q *payloadQueue) updateSortedKeys() {
if q.sorted != nil {
func (q *payloadQueue) push(p *chunkPayloadData) {
if _, ok := q.chunkMap[p.tsn]; ok {
return
}

q.sorted = make([]uint32, len(q.chunkMap))
i := 0
for k := range q.chunkMap {
q.sorted[i] = k
i++
}

sort.Slice(q.sorted, func(i, j int) bool {
return sna32LT(q.sorted[i], q.sorted[j])
})
}

func (q *payloadQueue) canPush(p *chunkPayloadData, cumulativeTSN uint32) bool {
_, ok := q.chunkMap[p.tsn]
if ok || sna32LTE(p.tsn, cumulativeTSN) {
return false
}
return true
}

func (q *payloadQueue) pushNoCheck(p *chunkPayloadData) {
q.chunkMap[p.tsn] = p
q.nBytes += len(p.userData)
q.sorted = nil
}

// push pushes a payload data. If the payload data is already in our queue or
// older than our cumulativeTSN marker, it will be recored as duplications,
// which can later be retrieved using popDuplicates.
func (q *payloadQueue) push(p *chunkPayloadData, cumulativeTSN uint32) bool {
_, ok := q.chunkMap[p.tsn]
if ok || sna32LTE(p.tsn, cumulativeTSN) {
// Found the packet, log in dups
q.dupTSN = append(q.dupTSN, p.tsn)
return false
var pos int
for pos = len(q.tsns) - 1; pos >= 0; pos-- {
if q.tsns[pos] < p.tsn {
break
}
}

q.chunkMap[p.tsn] = p
q.nBytes += len(p.userData)
q.sorted = nil
return true
pos++
q.tsns = append(q.tsns, 0)
copy(q.tsns[pos+1:], q.tsns[pos:])
q.tsns[pos] = p.tsn
}

// pop pops only if the oldest chunk's TSN matches the given TSN.
func (q *payloadQueue) pop(tsn uint32) (*chunkPayloadData, bool) {
q.updateSortedKeys()

if len(q.chunkMap) > 0 && tsn == q.sorted[0] {
q.sorted = q.sorted[1:]
if c, ok := q.chunkMap[tsn]; ok {
delete(q.chunkMap, tsn)
q.nBytes -= len(c.userData)
return c, true
}
if len(q.tsns) == 0 || q.tsns[0] != tsn {
return nil, false
}

Check warning on line 38 in payload_queue.go

View check run for this annotation

Codecov / codecov/patch

payload_queue.go#L37-L38

Added lines #L37 - L38 were not covered by tests
q.tsns = q.tsns[1:]
if c, ok := q.chunkMap[tsn]; ok {
delete(q.chunkMap, tsn)
q.nBytes -= len(c.userData)
return c, true
}

return nil, false
}

Expand All @@ -89,58 +51,6 @@ func (q *payloadQueue) get(tsn uint32) (*chunkPayloadData, bool) {
return c, ok
}

// popDuplicates returns an array of TSN values that were found duplicate.
func (q *payloadQueue) popDuplicates() []uint32 {
dups := q.dupTSN
q.dupTSN = []uint32{}
return dups
}

func (q *payloadQueue) getGapAckBlocks(cumulativeTSN uint32) (gapAckBlocks []gapAckBlock) {
var b gapAckBlock

if len(q.chunkMap) == 0 {
return []gapAckBlock{}
}

q.updateSortedKeys()

for i, tsn := range q.sorted {
if i == 0 {
b.start = uint16(tsn - cumulativeTSN)
b.end = b.start
continue
}
diff := uint16(tsn - cumulativeTSN)
if b.end+1 == diff {
b.end++
} else {
gapAckBlocks = append(gapAckBlocks, gapAckBlock{
start: b.start,
end: b.end,
})
b.start = diff
b.end = diff
}
}

gapAckBlocks = append(gapAckBlocks, gapAckBlock{
start: b.start,
end: b.end,
})

return gapAckBlocks
}

func (q *payloadQueue) getGapAckBlocksString(cumulativeTSN uint32) string {
gapAckBlocks := q.getGapAckBlocks(cumulativeTSN)
str := fmt.Sprintf("cumTSN=%d", cumulativeTSN)
for _, b := range gapAckBlocks {
str += fmt.Sprintf(",%d-%d", b.start, b.end)
}
return str
}

func (q *payloadQueue) markAsAcked(tsn uint32) int {
var nBytesAcked int
if c, ok := q.chunkMap[tsn]; ok {
Expand All @@ -154,16 +64,6 @@ func (q *payloadQueue) markAsAcked(tsn uint32) int {
return nBytesAcked
}

func (q *payloadQueue) getLastTSNReceived() (uint32, bool) {
q.updateSortedKeys()

qlen := len(q.sorted)
if qlen == 0 {
return 0, false
}
return q.sorted[qlen-1], true
}

func (q *payloadQueue) markAllToRetrasmit() {
for _, c := range q.chunkMap {
if c.acked || c.abandoned() {
Expand Down
83 changes: 12 additions & 71 deletions payload_queue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ func makePayload(tsn uint32, nBytes int) *chunkPayloadData {
}

func TestPayloadQueue(t *testing.T) {
t.Run("pushNoCheck", func(t *testing.T) {
t.Run("push", func(t *testing.T) {
pq := newPayloadQueue()
pq.pushNoCheck(makePayload(0, 10))
pq.push(makePayload(0, 10))
assert.Equal(t, 10, pq.getNumBytes(), "total bytes mismatch")
assert.Equal(t, 1, pq.size(), "item count mismatch")
pq.pushNoCheck(makePayload(1, 11))
pq.push(makePayload(1, 11))
assert.Equal(t, 21, pq.getNumBytes(), "total bytes mismatch")
assert.Equal(t, 2, pq.size(), "item count mismatch")
pq.pushNoCheck(makePayload(2, 12))
pq.push(makePayload(2, 12))
assert.Equal(t, 33, pq.getNumBytes(), "total bytes mismatch")
assert.Equal(t, 3, pq.size(), "item count mismatch")

Expand All @@ -31,96 +31,37 @@ func TestPayloadQueue(t *testing.T) {
assert.True(t, ok, "pop should succeed")
if ok {
assert.Equal(t, i, c.tsn, "TSN should match")
assert.NotNil(t, pq.sorted, "should not be nil")
assert.NotNil(t, pq.tsns, "should not be nil")
}
}

assert.Equal(t, 0, pq.getNumBytes(), "total bytes mismatch")
assert.Equal(t, 0, pq.size(), "item count mismatch")

pq.pushNoCheck(makePayload(3, 13))
assert.Nil(t, pq.sorted, "should be nil")
pq.push(makePayload(3, 13))
assert.Len(t, pq.tsns, 1)
assert.Equal(t, 13, pq.getNumBytes(), "total bytes mismatch")
pq.pushNoCheck(makePayload(4, 14))
assert.Nil(t, pq.sorted, "should be nil")
pq.push(makePayload(4, 14))
assert.Len(t, pq.tsns, 2)
assert.Equal(t, 27, pq.getNumBytes(), "total bytes mismatch")

for i := uint32(3); i < 5; i++ {
c, ok := pq.pop(i)
assert.True(t, ok, "pop should succeed")
if ok {
assert.Equal(t, i, c.tsn, "TSN should match")
assert.NotNil(t, pq.sorted, "should not be nil")
assert.NotNil(t, pq.tsns, "should not be nil")
}
}

assert.Equal(t, 0, pq.getNumBytes(), "total bytes mismatch")
assert.Equal(t, 0, pq.size(), "item count mismatch")
})

t.Run("getGapAckBlocks", func(t *testing.T) {
pq := newPayloadQueue()
pq.push(makePayload(1, 0), 0)
pq.push(makePayload(2, 0), 0)
pq.push(makePayload(3, 0), 0)
pq.push(makePayload(4, 0), 0)
pq.push(makePayload(5, 0), 0)
pq.push(makePayload(6, 0), 0)

gab1 := []*gapAckBlock{{start: 1, end: 6}}
gab2 := pq.getGapAckBlocks(0)
assert.NotNil(t, gab2)
assert.Len(t, gab2, 1)

assert.Equal(t, gab1[0].start, gab2[0].start)
assert.Equal(t, gab1[0].end, gab2[0].end)

pq.push(makePayload(8, 0), 0)
pq.push(makePayload(9, 0), 0)

gab1 = []*gapAckBlock{{start: 1, end: 6}, {start: 8, end: 9}}
gab2 = pq.getGapAckBlocks(0)
assert.NotNil(t, gab2)
assert.Len(t, gab2, 2)

assert.Equal(t, gab1[0].start, gab2[0].start)
assert.Equal(t, gab1[0].end, gab2[0].end)
assert.Equal(t, gab1[1].start, gab2[1].start)
assert.Equal(t, gab1[1].end, gab2[1].end)
})

t.Run("getLastTSNReceived", func(t *testing.T) {
pq := newPayloadQueue()

// empty queie should return false
_, ok := pq.getLastTSNReceived()
assert.False(t, ok, "should be false")

ok = pq.push(makePayload(20, 0), 0)
assert.True(t, ok, "should be true")
tsn, ok := pq.getLastTSNReceived()
assert.True(t, ok, "should be false")
assert.Equal(t, uint32(20), tsn, "should match")

// append should work
ok = pq.push(makePayload(21, 0), 0)
assert.True(t, ok, "should be true")
tsn, ok = pq.getLastTSNReceived()
assert.True(t, ok, "should be false")
assert.Equal(t, uint32(21), tsn, "should match")

// check if sorting applied
ok = pq.push(makePayload(19, 0), 0)
assert.True(t, ok, "should be true")
tsn, ok = pq.getLastTSNReceived()
assert.True(t, ok, "should be false")
assert.Equal(t, uint32(21), tsn, "should match")
})

t.Run("markAllToRetrasmit", func(t *testing.T) {
pq := newPayloadQueue()
for i := 0; i < 3; i++ {
pq.push(makePayload(uint32(i+1), 10), 0)
pq.push(makePayload(uint32(i+1), 10))
}
pq.markAsAcked(2)
pq.markAllToRetrasmit()
Expand All @@ -139,7 +80,7 @@ func TestPayloadQueue(t *testing.T) {
t.Run("reset retransmit flag on ack", func(t *testing.T) {
pq := newPayloadQueue()
for i := 0; i < 4; i++ {
pq.push(makePayload(uint32(i+1), 10), 0)
pq.push(makePayload(uint32(i+1), 10))
}

pq.markAllToRetrasmit()
Expand Down

0 comments on commit 278dfbb

Please sign in to comment.