diff --git a/association.go b/association.go index f0f1d7be..507e3aad 100644 --- a/association.go +++ b/association.go @@ -264,6 +264,7 @@ func createClientWithContext(ctx context.Context, config Config) (*Association, select { case <-ctx.Done(): a.log.Errorf("[%s] client handshake canceled: state=%s", a.name, getAssociationStateString(a.getState())) + a.Close() // nolint:errcheck,gosec return nil, ctx.Err() case err := <-a.handshakeCompletedCh: if err != nil { diff --git a/association_test.go b/association_test.go index 25f2c895..82e590fb 100644 --- a/association_test.go +++ b/association_test.go @@ -2564,16 +2564,62 @@ func TestAssocMaxMessageSize(t *testing.T) { }) } +// udpConnWrapper wraps a *net.UDPConn and implements net.Conn interface. +type udpConnWrapper struct { + conn *net.UDPConn + remoteAddr net.Addr +} + +func newUDPConnWrapper(conn *net.UDPConn, remoteAddr net.Addr) net.Conn { + return &udpConnWrapper{ + conn: conn, + remoteAddr: remoteAddr, + } +} + +// Implement the net.Conn interface methods +func (w *udpConnWrapper) Read(b []byte) (n int, err error) { + // w.conn.ReadFrom(b) + n, _, err = w.conn.ReadFrom(b) + return n, err +} + +func (w *udpConnWrapper) Write(b []byte) (n int, err error) { + return w.conn.WriteTo(b, w.remoteAddr) +} + +func (w *udpConnWrapper) Close() error { + return w.conn.Close() +} + +func (w *udpConnWrapper) LocalAddr() net.Addr { + return w.conn.LocalAddr() +} + +func (w *udpConnWrapper) RemoteAddr() net.Addr { + return w.remoteAddr +} + +func (w *udpConnWrapper) SetDeadline(t time.Time) error { + return w.conn.SetDeadline(t) +} + +func (w *udpConnWrapper) SetReadDeadline(t time.Time) error { + return w.conn.SetReadDeadline(t) +} + +func (w *udpConnWrapper) SetWriteDeadline(t time.Time) error { + return w.conn.SetWriteDeadline(t) +} + // crateUDPConnPair creates a pair of net.UDPConn objects that are connected with each other -func createUDPConnPair(t *testing.T) (*net.UDPConn, *net.UDPConn, error) { +func createUDPConnPair(t *testing.T) (net.Conn, net.Conn, error) { udp1, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}) if err != nil { return nil, nil, err } addr1, ok := udp1.LocalAddr().(*net.UDPAddr) require.True(t, ok) - err = udp1.Close() - require.NoError(t, err) udp2, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}) if err != nil { @@ -2581,20 +2627,18 @@ func createUDPConnPair(t *testing.T) (*net.UDPConn, *net.UDPConn, error) { } addr2, ok := udp2.LocalAddr().(*net.UDPAddr) require.True(t, ok) - err = udp2.Close() - require.NoError(t, err) - udp1, err = net.DialUDP("udp", addr1, addr2) + conn1 := newUDPConnWrapper(udp1, addr2) if err != nil { return nil, nil, err } - udp2, err = net.DialUDP("udp", addr2, addr1) + conn2 := newUDPConnWrapper(udp2, addr1) if err != nil { return nil, nil, err } - return udp1, udp2, nil + return conn1, conn2, nil } func createAssocs(t *testing.T) (*Association, *Association, error) { @@ -2952,3 +2996,64 @@ func TestAssociation_Abort(t *testing.T) { assert.Equal(t, i, 0, "expected no data read") assert.Error(t, err, "User Initiated Abort: 1234", "expected abort reason") } + +// TestAssociation_createClientWithContext tests that the client is closed when the context is canceled. +func TestAssociation_createClientWithContext(t *testing.T) { + checkGoroutineLeaks(t) + + udp1, udp2, err := createUDPConnPair(t) + require.NoError(t, err) + + loggerFactory := logging.NewDefaultLoggerFactory() + + errCh1 := make(chan error) + errCh2 := make(chan error) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + + go func() { + _, err2 := createClientWithContext(ctx, Config{ + NetConn: udp1, + LoggerFactory: loggerFactory, + }) + if err2 != nil { + errCh1 <- err2 + } else { + errCh1 <- nil + } + }() + + go func() { + _, err2 := createClientWithContext(ctx, Config{ + NetConn: udp2, + LoggerFactory: loggerFactory, + }) + if err2 != nil { + errCh2 <- err2 + } else { + errCh2 <- nil + } + }() + + // Cancel the context immediately + cancel() + + var err1 error + var err2 error +loop: + for { + select { + case err1 = <-errCh1: + if err1 != nil && err2 != nil { + break loop + } + case err2 = <-errCh2: + if err1 != nil && err2 != nil { + break loop + } + } + } + + assert.Error(t, err1, "context canceled") + assert.Error(t, err2, "context canceled") +}