Skip to content

Commit

Permalink
Conn: synchronize Conn.Close() and Conn.Listen() workers
Browse files Browse the repository at this point in the history
With the 'new' runtime poller introduced in Go 1.12, closing a Conn now
unblocks any blocked calls to netlink.Conn.Receive().

This patch makes Conn.Listen() workers exit silently when Receive() is interrupted
by the Conn being closed, and adds a WaitGroup to conntrack.Conn so that
Conn.Close() can wait for all workers to exit.

See mdlayher/netlink#119.
  • Loading branch information
ti-mo committed Dec 19, 2022
1 parent ba3b291 commit 02ccf83
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 38 deletions.
38 changes: 32 additions & 6 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package conntrack

import (
"fmt"
"sync"

"github.com/mdlayher/netlink"
"github.com/pkg/errors"
Expand All @@ -12,6 +13,8 @@ import (
// subsystem and implements all Conntrack actions.
type Conn struct {
// conn is the underlying netfilter connection. Close closes it and
// eventWorker receives Netlink messages from it.
conn *netfilter.Conn

// workers tracks running event workers. Each eventWorker registers itself
// with Add(1) on entry and calls Done on exit; Close calls Wait after the
// connection has been closed so it only returns once all workers are gone.
workers sync.WaitGroup
}

// DumpOptions is passed as an option to `Dump`-related methods to modify their behaviour.
Expand All @@ -28,12 +31,21 @@ func Dial(config *netlink.Config) (*Conn, error) {
return nil, err
}

return &Conn{c}, nil
return &Conn{conn: c}, nil
}

// Close closes a Conn.
//
// If any workers were started using [Conn.Listen], blocks until all have
// terminated.
func (c *Conn) Close() error {
return c.conn.Close()
if err := c.conn.Close(); err != nil {
return err
}

c.workers.Wait()

return nil
}

// SetOption enables or disables a netlink socket option for the Conn.
Expand Down Expand Up @@ -72,8 +84,9 @@ func (c *Conn) SetWriteBuffer(bytes int) error {
// evChan consumers need to be able to keep up with the Event producers. When the channel is full,
// messages will pile up in the Netlink socket's buffer, putting the socket at risk of being closed
// by the kernel when it eventually fills up.
//
// Closing the Conn makes all workers terminate silently.
func (c *Conn) Listen(evChan chan<- Event, numWorkers uint8, groups []netfilter.NetlinkGroup) (chan error, error) {

if numWorkers == 0 {
return nil, errors.Errorf(errWorkerCount, numWorkers)
}
Expand Down Expand Up @@ -101,16 +114,29 @@ func (c *Conn) Listen(evChan chan<- Event, numWorkers uint8, groups []netfilter.

// eventWorker is a worker function that decodes Netlink messages into Events.
func (c *Conn) eventWorker(workerID uint8, evChan chan<- Event, errChan chan<- error) {

var err error
var recv []netlink.Message
var ev Event

c.workers.Add(1)
defer c.workers.Done()

for {
// Receive data from the Netlink socket
// Receive data from the Netlink socket.
recv, err = c.conn.Receive()

// If the Conn gets closed while a worker is blocked in Receive(), Go's runtime
// poller returns internal/poll.ErrFileClosing. That error lives in an internal
// package and cannot be matched using errors.Is(), so extract the
// *netlink.OpError and compare the message of its wrapped error instead.
var opErr *netlink.OpError
if errors.As(err, &opErr) {
if opErr.Err.Error() == "use of closed file" {
return
}
}

if err != nil {
errChan <- errors.Wrap(err, fmt.Sprintf(errWorkerReceive, workerID))
errChan <- fmt.Errorf("Receive() netlink event, closing worker %d: %w", workerID, err)
return
}

Expand Down
1 change: 0 additions & 1 deletion errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,5 @@ var (
const (
errUnknownEventType = "unknown event type %d"
errWorkerCount = "invalid worker count %d"
errWorkerReceive = "netlink.Receive error in listenWorker %d, exiting"
errAttributeChild = "unknown attribute child Type '%d'"
)
47 changes: 16 additions & 31 deletions event_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (
"net"
"testing"

"github.com/pkg/errors"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand All @@ -16,40 +14,26 @@ import (
)

func TestConnListen(t *testing.T) {

// Dial a send connection to Netlink in a new namespace
// Dial a send connection to Netlink in a new namespace.
sc, nsid, err := makeNSConn()
require.NoError(t, err)

// Create a listener connection in the same namespace
// Create a listener connection in the same namespace.
lc, err := Dial(&netlink.Config{NetNS: nsid})
require.NoError(t, err)

// This needs to be an unbuffered channel with a single producer worker. Multicast connections
// currently cannot be terminated gracefully when stuck in Receive(), so we have to inject an event
// ourselves, while making sure the worker exits before re-entering Receive().
// Subscribe to new/update conntrack events using a single worker.
ev := make(chan Event)
errChan, err := lc.Listen(ev, 1, []netfilter.NetlinkGroup{netfilter.GroupCTNew, netfilter.GroupCTUpdate})
require.NoError(t, err)

// Watch for listen channel errors in the background
go func() {
err, ok := <-errChan
if ok {
opErr, ok := errors.Cause(err).(*netlink.OpError)
require.True(t, ok)
require.EqualError(t, opErr.Err, "recvmsg: bad file descriptor")
}
}()

numFlows := 100

var f Flow
var warn bool

for i := 1; i <= numFlows; i++ {

// Create the Flow
// Create the Flow.
f = NewFlow(
17, 0,
net.ParseIP("2a00:1450:400e:804::200e"),
Expand All @@ -59,7 +43,7 @@ func TestConnListen(t *testing.T) {
err = sc.Create(f)
require.NoError(t, err, "creating IPv6 flow", i)

// Read a new event from the channel
// Read a new event from the channel.
re := <-ev

// Validate new event attributes
Expand All @@ -75,31 +59,32 @@ func TestConnListen(t *testing.T) {
}
assert.Equal(t, f.TupleOrig.Proto.DestinationPort, re.Flow.TupleOrig.Proto.DestinationPort)

// Update the flow
// Update the Flow.
f.Timeout = 240
err = sc.Update(f)
require.NoError(t, err)

// Read an update event from the channel
// Read an update event from the channel.
re = <-ev

// Validate update event attributes
// Validate update event attributes.
assert.Equal(t, EventUpdate, re.Type)
assert.Equal(t, f.TupleOrig.Proto.DestinationPort, re.Flow.TupleOrig.Proto.DestinationPort)

// Compare the timeout on the connection, but within a 2-second window.
assert.GreaterOrEqual(t, re.Flow.Timeout, f.Timeout-2, "timeout")
}

// Generate an event to unblock the listen worker goroutine
go func() {
f.Timeout = 1
_ = sc.Update(f)
}()

// Close the sockets
// Close the sockets, interrupting any blocked listeners.
assert.NoError(t, lc.Close())
assert.NoError(t, sc.Close())

// Non-blocking read on errChan. No messages should appear when workers die.
select {
case err := <-errChan:
assert.NoError(t, err)
default:
}
}

func TestConnListenError(t *testing.T) {
Expand Down

0 comments on commit 02ccf83

Please sign in to comment.