From 218d0f3746c902f3be64585835b092282a8ec58c Mon Sep 17 00:00:00 2001 From: Patrick Hemmer Date: Thu, 6 Apr 2023 22:31:11 -0400 Subject: [PATCH] fix(inputs.socket_listener): fix tracking of unix sockets This fixes an issue where the code would lose track of unix sockets. The remote end of a unix socket does not have a unique address representation, thus multiple entries may overwrite each other in the map. This changes the map to key off the net.Conn object itself, basically using the map as a set. Fixes #13058 --- .../inputs/socket_listener/socket_listener_test.go | 12 ++++++++++++ plugins/inputs/socket_listener/stream_listener.go | 10 +++++----- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/plugins/inputs/socket_listener/socket_listener_test.go b/plugins/inputs/socket_listener/socket_listener_test.go index 032c86dde2ffa..11b7b80ac09d4 100644 --- a/plugins/inputs/socket_listener/socket_listener_test.go +++ b/plugins/inputs/socket_listener/socket_listener_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net" "os" "path/filepath" @@ -15,6 +16,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/influxdata/telegraf" @@ -181,6 +183,16 @@ func TestSocketListener(t *testing.T) { }, time.Second, 100*time.Millisecond, "did not receive metrics (%d)", acc.NMetrics()) actual := acc.GetTelegrafMetrics() testutil.RequireMetricsEqual(t, expected, actual, testutil.SortMetrics()) + + plugin.Stop() + + if _, ok := plugin.listener.(*streamListener); ok { + // Verify that plugin.Stop() closed the client's connection + _ = client.SetReadDeadline(time.Now().Add(time.Second)) + buf := []byte{1} + _, err = client.Read(buf) + assert.Equal(t, err, io.EOF) + } }) } } diff --git a/plugins/inputs/socket_listener/stream_listener.go b/plugins/inputs/socket_listener/stream_listener.go index 0c765afdae81f..6afec834b85e2 100644 --- a/plugins/inputs/socket_listener/stream_listener.go +++ b/plugins/inputs/socket_listener/stream_listener.go @@ -32,7 +32,7 @@ type streamListener struct { Log telegraf.Logger listener net.Listener - connections map[string]net.Conn + connections map[net.Conn]struct{} path string wg sync.WaitGroup @@ -123,7 +123,7 @@ func (l *streamListener) setupConnection(conn net.Conn) error { // Store the connection mapped to its address l.Lock() - l.connections[addr] = conn + l.connections[conn] = struct{}{} l.Unlock() return nil @@ -134,7 +134,7 @@ func (l *streamListener) closeConnection(conn net.Conn) { if err := conn.Close(); err != nil { l.Log.Errorf("Cannot close connection to %q: %v", addr, err) } - delete(l.connections, addr) + delete(l.connections, conn) } func (l *streamListener) addr() net.Addr { @@ -147,7 +147,7 @@ func (l *streamListener) close() error { } l.Lock() - for _, conn := range l.connections { + for conn := range l.connections { l.closeConnection(conn) } l.Unlock() @@ -164,7 +164,7 @@ func (l *streamListener) close() error { } func (l *streamListener) listen(acc telegraf.Accumulator) { - l.connections = make(map[string]net.Conn) + l.connections = make(map[net.Conn]struct{}) l.wg.Add(1) defer l.wg.Done()