forked from plgd-dev/go-coap
/
tlslistener.go
82 lines (74 loc) · 1.89 KB
/
tlslistener.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
package net
import (
"context"
"crypto/tls"
"fmt"
"net"
"sync/atomic"
"time"
)
// TLSListener wraps a TCP listener with TLS and offers an accept that
// honours a context.
type TLSListener struct {
	// tcp is the underlying TCP listener; SetDeadline is applied to it.
	tcp *net.TCPListener
	// listener is the TLS listener layered on top of tcp; Accept, Close and
	// Addr are served by it.
	listener net.Listener
	// heartBeat is how far ahead of "now" AcceptWithContext sets the accept
	// deadline on every attempt.
	heartBeat time.Duration
	// closed is set to 1 by Close and read atomically by AcceptWithContext.
	closed uint32
}
// NewTLSListener creates a TLS listener bound to addr.
//
// Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only).
// cfg is handed to tls.NewListener unchanged, and heartBeat is stored as the
// interval AcceptWithContext uses for its accept deadline.
//
// If the underlying TCP listener cannot be created, the cause is returned
// wrapped, so callers can inspect it with errors.Is / errors.As.
func NewTLSListener(network string, addr string, cfg *tls.Config, heartBeat time.Duration) (*TLSListener, error) {
	tcp, err := newNetTCPListen(network, addr)
	if err != nil {
		return nil, fmt.Errorf("cannot create new tls listener: %w", err)
	}
	// Named tlsListener rather than tls so the crypto/tls package is not
	// shadowed for the rest of the function.
	tlsListener := tls.NewListener(tcp, cfg)
	return &TLSListener{
		tcp:       tcp,
		listener:  tlsListener,
		heartBeat: heartBeat,
	}, nil
}
// AcceptWithContext waits for the next connection until ctx is done or the
// listener is closed.
//
// Every attempt sets an accept deadline of heartBeat from now, so the loop
// regularly gets a chance to re-check ctx and the closed flag. Errors that
// isTemporary reports as temporary (presumably including the deadline
// expiring; confirm against isTemporary) restart the loop.
//
// It returns ErrServerClosed once Close has been called, ctx.Err() once ctx
// is done, and otherwise the failure of SetDeadline or Accept wrapped so the
// cause stays reachable through errors.Is / errors.As.
func (l *TLSListener) AcceptWithContext(ctx context.Context) (net.Conn, error) {
	// fail maps a non-temporary failure to the error reported to the caller.
	// Close may run while SetDeadline or Accept is in progress; in that case
	// report ErrServerClosed, consistent with the check at the top of the
	// loop, instead of a generic "closed network connection" error.
	fail := func(err error) error {
		if atomic.LoadUint32(&l.closed) == 1 {
			return ErrServerClosed
		}
		return fmt.Errorf("cannot accept connections: %w", err)
	}

	for {
		if atomic.LoadUint32(&l.closed) == 1 {
			return nil, ErrServerClosed
		}
		// Non-blocking check: give up as soon as the context is done.
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		default:
		}
		if err := l.SetDeadline(time.Now().Add(l.heartBeat)); err != nil {
			return nil, fail(err)
		}
		rw, err := l.listener.Accept()
		if err != nil {
			if isTemporary(err) {
				continue
			}
			return nil, fail(err)
		}
		return rw, nil
	}
}
// SetDeadline applies t as the deadline of the underlying TCP listener and
// returns whatever error that listener reports.
func (l *TLSListener) SetDeadline(t time.Time) error {
	err := l.tcp.SetDeadline(t)
	return err
}
// Accept waits for and returns the next connection. It is AcceptWithContext
// called with context.Background(), so it cannot be cancelled through a
// context.
func (l *TLSListener) Accept() (net.Conn, error) {
	ctx := context.Background()
	return l.AcceptWithContext(ctx)
}
// Close shuts the listener down. Only the first call closes the wrapped
// listener and returns its error; every later call is a no-op returning nil.
func (l *TLSListener) Close() error {
	// The compare-and-swap succeeds for exactly one caller, which makes the
	// close of the wrapped listener happen at most once.
	firstCall := atomic.CompareAndSwapUint32(&l.closed, 0, 1)
	if firstCall {
		return l.listener.Close()
	}
	return nil
}
// Addr returns the listener's network address, as reported by the wrapped
// TLS listener.
func (l *TLSListener) Addr() net.Addr {
	addr := l.listener.Addr()
	return addr
}