From e0b126bdfb530244c578e83a3cc3e8480d9dc108 Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Mon, 15 Nov 2021 14:21:11 +0100 Subject: [PATCH] fix(dot/network): fix memory allocations with `sizedBufferPool` (#1963) --- dot/network/notifications.go | 2 +- dot/network/pool.go | 25 ++++---- dot/network/pool_test.go | 111 +++++++++++++++++++++++++++++++++++ dot/network/service.go | 15 +++-- 4 files changed, 131 insertions(+), 22 deletions(-) create mode 100644 dot/network/pool_test.go diff --git a/dot/network/notifications.go b/dot/network/notifications.go index 6022703d82..3bf9615ccc 100644 --- a/dot/network/notifications.go +++ b/dot/network/notifications.go @@ -389,7 +389,7 @@ func (s *Service) readHandshake(stream libp2pnetwork.Stream, decoder HandshakeDe go func() { msgBytes := s.bufPool.get() defer func() { - s.bufPool.put(&msgBytes) + s.bufPool.put(msgBytes) close(hsC) }() diff --git a/dot/network/pool.go b/dot/network/pool.go index 6f864795ce..c517acf88d 100644 --- a/dot/network/pool.go +++ b/dot/network/pool.go @@ -5,15 +5,15 @@ package network // sizedBufferPool is a pool of buffers used for reading from streams type sizedBufferPool struct { - c chan *[maxMessageSize]byte + c chan []byte } -func newSizedBufferPool(min, max int) (bp *sizedBufferPool) { - bufferCh := make(chan *[maxMessageSize]byte, max) +func newSizedBufferPool(preAllocate, size int) (bp *sizedBufferPool) { + bufferCh := make(chan []byte, size) - for i := 0; i < min; i++ { - buf := [maxMessageSize]byte{} - bufferCh <- &buf + for i := 0; i < preAllocate; i++ { + buf := make([]byte, maxMessageSize) + bufferCh <- buf } return &sizedBufferPool{ @@ -23,20 +23,19 @@ func newSizedBufferPool(min, max int) (bp *sizedBufferPool) { // get gets a buffer from the sizedBufferPool, or creates a new one if none are // available in the pool. Buffers have a pre-allocated capacity. -func (bp *sizedBufferPool) get() [maxMessageSize]byte { - var buff *[maxMessageSize]byte +func (bp *sizedBufferPool) get() (b []byte) { select { - case buff = <-bp.c: - // reuse existing buffer + case b = <-bp.c: + // reuse existing buffer + return b default: // create new buffer - buff = &[maxMessageSize]byte{} + return make([]byte, maxMessageSize) } - return *buff } // put returns the given buffer to the sizedBufferPool. -func (bp *sizedBufferPool) put(b *[maxMessageSize]byte) { +func (bp *sizedBufferPool) put(b []byte) { select { case bp.c <- b: default: // Discard the buffer if the pool is full. diff --git a/dot/network/pool_test.go b/dot/network/pool_test.go new file mode 100644 index 0000000000..4b9927274e --- /dev/null +++ b/dot/network/pool_test.go @@ -0,0 +1,111 @@ +package network + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func Benchmark_sizedBufferPool(b *testing.B) { + const preAllocate = 100 + const poolSize = 200 + sbp := newSizedBufferPool(preAllocate, poolSize) + + b.RunParallel(func(p *testing.PB) { + for p.Next() { + buffer := sbp.get() + buffer[0] = 1 + buffer[len(buffer)-1] = 1 + sbp.put(buffer) + } + }) +} + +// Before: 104853 11119 ns/op 65598 B/op 1 allocs/op +// Array ptr: 2742781 438.3 ns/op 2 B/op 0 allocs/op +// Slices: 2560960 463.8 ns/op 2 B/op 0 allocs/op +// Slice pointer: 2683528 460.8 ns/op 2 B/op 0 allocs/op + +func Test_sizedBufferPool(t *testing.T) { + t.Parallel() + + const preAlloc = 1 + const poolSize = 2 + const maxIndex = maxMessageSize - 1 + + pool := newSizedBufferPool(preAlloc, poolSize) + + first := pool.get() // pre-allocated one + first[maxIndex] = 1 + + second := pool.get() // new one + second[maxIndex] = 2 + + third := pool.get() // new one + third[maxIndex] = 3 + + fourth := pool.get() // new one + fourth[maxIndex] = 4 + + pool.put(fourth) + pool.put(third) + pool.put(second) // discarded + pool.put(first) // discarded + + b := pool.get() // fourth + assert.Equal(t, byte(4), b[maxIndex]) + + b = pool.get() // third + assert.Equal(t, byte(3), b[maxIndex]) +} + +func Test_sizedBufferPool_race(t *testing.T) { + t.Parallel() + + const preAlloc = 1 + const poolSize = 2 + + pool := newSizedBufferPool(preAlloc, poolSize) + + const parallelism = 4 + + readyWait := new(sync.WaitGroup) + readyWait.Add(parallelism) + + doneWait := new(sync.WaitGroup) + doneWait.Add(parallelism) + + // run for 50ms + ctxTimerStarted := make(chan struct{}) + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + go func() { + const timeout = 50 * time.Millisecond + readyWait.Wait() + ctx, cancel = context.WithTimeout(ctx, timeout) + close(ctxTimerStarted) + }() + defer cancel() + + for i := 0; i < parallelism; i++ { + go func() { + defer doneWait.Done() + readyWait.Done() + readyWait.Wait() + <-ctxTimerStarted + + for ctx.Err() != nil { + // test relies on the -race detector + // to detect concurrent writes to the buffer. + b := pool.get() + b[0] = 1 + pool.put(b) + } + }() + } + + doneWait.Wait() +} diff --git a/dot/network/service.go b/dot/network/service.go index d538ccfb84..0f51a2f049 100644 --- a/dot/network/service.go +++ b/dot/network/service.go @@ -135,14 +135,13 @@ func NewService(cfg *Config) (*Service, error) { // pre-allocate pool of buffers used to read from streams. // initially allocate as many buffers as liekly necessary which is the number inbound streams we will have, // which should equal average number of peers times the number of notifications protocols, which is currently 3. - var bufPool *sizedBufferPool - if cfg.noPreAllocate { - bufPool = &sizedBufferPool{ - c: make(chan *[maxMessageSize]byte, cfg.MinPeers*3), - } - } else { - bufPool = newSizedBufferPool(cfg.MinPeers*3, cfg.MaxPeers*3) + preAllocateInPool := cfg.MinPeers * 3 + poolSize := cfg.MaxPeers * 3 + if cfg.noPreAllocate { // testing + preAllocateInPool = 0 + poolSize = cfg.MinPeers * 3 } + bufPool := newSizedBufferPool(preAllocateInPool, poolSize) network := &Service{ ctx: ctx, @@ -550,7 +549,7 @@ func (s *Service) readStream(stream libp2pnetwork.Stream, decoder messageDecoder peer := stream.Conn().RemotePeer() msgBytes := s.bufPool.get() - defer s.bufPool.put(&msgBytes) + defer s.bufPool.put(msgBytes) for { tot, err := readStream(stream, msgBytes[:])