Skip to content

Commit

Permalink
Improved UDPConn pair and log messages
Browse files Browse the repository at this point in the history
Relates to #270
  • Loading branch information
enobufs authored and Sean-Der committed Jan 6, 2024
1 parent daee3aa commit d69aa98
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 35 deletions.
7 changes: 7 additions & 0 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,17 @@ func Server(config Config) (*Association, error) {

// Client opens a SCTP stream over a conn
func Client(config Config) (*Association, error) {
return createClientWithContext(context.Background(), config)
}

func createClientWithContext(ctx context.Context, config Config) (*Association, error) {
a := createAssociation(config)
a.init(true)

select {
case <-ctx.Done():
a.log.Errorf("[%s] client handshake canceled: state=%s", a.name, getAssociationStateString(a.getState()))
return nil, ctx.Err()
case err := <-a.handshakeCompletedCh:
if err != nil {
return nil, err
Expand Down
114 changes: 79 additions & 35 deletions association_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2570,65 +2570,106 @@ func TestAssocMaxMessageSize(t *testing.T) {
})
}

func createAssocs(t *testing.T) (a1, a2 *Association) {
addr1 := &net.UDPAddr{
IP: net.IP{127, 0, 0, 1},
Port: 1234,
// crateUDPConnPair creates a pair of net.UDPConn objects that are connected with each other
func createUDPConnPair(t *testing.T) (*net.UDPConn, *net.UDPConn, error) {
udp1, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")})
if err != nil {
return nil, nil, err
}
addr1, ok := udp1.LocalAddr().(*net.UDPAddr)
require.True(t, ok)
err = udp1.Close()
require.NoError(t, err)

addr2 := &net.UDPAddr{
IP: net.IP{127, 0, 0, 1},
Port: 5678,
udp2, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")})
if err != nil {
return nil, nil, err
}
addr2, ok := udp2.LocalAddr().(*net.UDPAddr)
require.True(t, ok)
err = udp2.Close()
require.NoError(t, err)

udp1, err := net.DialUDP("udp", addr1, addr2)
udp1, err = net.DialUDP("udp", addr1, addr2)
if err != nil {
panic(err)
return nil, nil, err
}

udp2, err := net.DialUDP("udp", addr2, addr1)
udp2, err = net.DialUDP("udp", addr2, addr1)
if err != nil {
panic(err)
return nil, nil, err
}

return udp1, udp2, nil
}

func createAssocs(t *testing.T) (*Association, *Association, error) {
udp1, udp2, err := createUDPConnPair(t)
if err != nil {
return nil, nil, err
}

loggerFactory := logging.NewDefaultLoggerFactory()

a1Chan := make(chan *Association)
a2Chan := make(chan *Association)
a1Chan := make(chan interface{})
a2Chan := make(chan interface{})

ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

go func() {
a, err := Client(Config{
a, err2 := createClientWithContext(ctx, Config{
NetConn: udp1,
LoggerFactory: loggerFactory,
})
require.NoError(t, err)

a1Chan <- a
if err2 != nil {
a1Chan <- err2
} else {
a1Chan <- a
}
}()

go func() {
a, err := Client(Config{
a, err2 := createClientWithContext(ctx, Config{
NetConn: udp2,
LoggerFactory: loggerFactory,
})
require.NoError(t, err)

a2Chan <- a
if err2 != nil {
a2Chan <- err2
} else {
a2Chan <- a
}
}()

select {
case a1 = <-a1Chan:
case <-time.After(time.Second):
assert.Fail(t, "timed out waiting for a1")
}
var a1 *Association
var a2 *Association

select {
case a2 = <-a2Chan:
case <-time.After(time.Second):
assert.Fail(t, "timed out waiting for a2")
loop:
for {
select {
case v1 := <-a1Chan:
switch v := v1.(type) {
case *Association:
a1 = v
if a2 != nil {
break loop
}
case error:
return nil, nil, v
}
case v2 := <-a2Chan:
switch v := v2.(type) {
case *Association:
a2 = v
if a1 != nil {
break loop
}
case error:
return nil, nil, v
}
}
}

return a1, a2
return a1, a2, nil
}

func TestAssociation_Shutdown(t *testing.T) {
Expand All @@ -2640,7 +2681,8 @@ func TestAssociation_Shutdown(t *testing.T) {
assert.Equal(t, n0, runtime.NumGoroutine(), "goroutine is leaked")
}()

a1, a2 := createAssocs(t)
a1, a2, err := createAssocs(t)
require.NoError(t, err)

s11, err := a1.OpenStream(1, PayloadTypeWebRTCString)
require.NoError(t, err)
Expand Down Expand Up @@ -2683,7 +2725,8 @@ func TestAssociation_ShutdownDuringWrite(t *testing.T) {
assert.Equal(t, n0, runtime.NumGoroutine(), "goroutine is leaked")
}()

a1, a2 := createAssocs(t)
a1, a2, err := createAssocs(t)
require.NoError(t, err)

s11, err := a1.OpenStream(1, PayloadTypeWebRTCString)
require.NoError(t, err)
Expand Down Expand Up @@ -2899,7 +2942,8 @@ func TestAssociation_Abort(t *testing.T) {
assert.Equal(t, n0, runtime.NumGoroutine(), "goroutine is leaked")
}()

a1, a2 := createAssocs(t)
a1, a2, err := createAssocs(t)
require.NoError(t, err)

s11, err := a1.OpenStream(1, PayloadTypeWebRTCString)
require.NoError(t, err)
Expand Down

0 comments on commit d69aa98

Please sign in to comment.