Skip to content

Commit

Permalink
protocol: fix header timeout when the user supplied deadline(s)
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniomika authored and pires committed Sep 8, 2021
1 parent 094c0b6 commit 0c5719a
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 62 deletions.
54 changes: 45 additions & 9 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,15 @@ import (
"time"
)

var DEFAULT_TIMEOUT = 200 * time.Millisecond

// Listener is used to wrap an underlying listener,
// whose connections may be using the HAProxy Proxy Protocol.
// If the connection is using the protocol, the RemoteAddr() will return
// the correct client address.
// the correct client address. ReadHeaderTimeout will be applied to all
// connections in order to prevent blocking operations. If no ReadHeaderTimeout
// is set, a default of 200ms will be used. This can be disabled by setting the
// timeout to < 0.
type Listener struct {
Listener net.Listener
Policy PolicyFunc
Expand All @@ -21,7 +26,8 @@ type Listener struct {

// Conn is used to wrap and underlying connection which
// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
// return the address of the client instead of the proxy address.
// return the address of the client instead of the proxy address. Each connection
// will have its own readHeaderTimeout and readDeadline set by the Accept() call.
type Conn struct {
bufReader *bufio.Reader
conn net.Conn
Expand All @@ -30,6 +36,8 @@ type Conn struct {
ProxyHeaderPolicy Policy
Validate Validator
readErr error
readHeaderTimeout time.Duration
readDeadline time.Time
}

// Validator receives a header and decides whether it is a valid one
Expand All @@ -53,12 +61,6 @@ func (p *Listener) Accept() (net.Conn, error) {
return nil, err
}

if d := p.ReadHeaderTimeout; d != 0 {
// The deadline will be reset after parsing the header.
// Otherwise, future p.conn.Read() will timeout.
conn.SetReadDeadline(time.Now().Add(d))
}

proxyHeaderPolicy := USE
if p.Policy != nil {
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
Expand All @@ -74,6 +76,15 @@ func (p *Listener) Accept() (net.Conn, error) {
WithPolicy(proxyHeaderPolicy),
ValidateHeader(p.ValidateHeader),
)

// If the ReadHeaderTimeout for the listener is 0, set a default of 200ms
if p.ReadHeaderTimeout == 0 {
p.ReadHeaderTimeout = DEFAULT_TIMEOUT
}

// Set the readHeaderTimeout of the new conn to the value of the listener
newConn.readHeaderTimeout = p.ReadHeaderTimeout

return newConn, nil
}

Expand Down Expand Up @@ -108,7 +119,6 @@ func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn {
func (p *Conn) Read(b []byte) (int, error) {
p.once.Do(func() {
p.readErr = p.readHeader()
p.conn.SetReadDeadline(time.Time{})
})
if p.readErr != nil {
return 0, p.readErr
Expand Down Expand Up @@ -201,11 +211,16 @@ func (p *Conn) UDPConn() (conn *net.UDPConn, ok bool) {

// SetDeadline wraps original conn.SetDeadline
func (p *Conn) SetDeadline(t time.Time) error {
p.readDeadline = t
return p.conn.SetDeadline(t)
}

// SetReadDeadline wraps original conn.SetReadDeadline
func (p *Conn) SetReadDeadline(t time.Time) error {
// Set a local var that tells us the desired deadline. This is
// needed in order to reset the read deadline to the one that is
// desired by the user, rather than an empty deadline.
p.readDeadline = t
return p.conn.SetReadDeadline(t)
}

Expand All @@ -215,7 +230,28 @@ func (p *Conn) SetWriteDeadline(t time.Time) error {
}

func (p *Conn) readHeader() error {
// If the connection's readHeaderTimeout is more than 0,
// push our deadline back to now plus the timeout. This should only
// run on the connection, as we don't want to override the previous
// read deadline the user may have used.
if p.readHeaderTimeout > 0 {
p.conn.SetReadDeadline(time.Now().Add(p.readHeaderTimeout))
}

header, err := Read(p.bufReader)

// If the connection's readHeaderTimeout is more than 0, undo the change to the
// deadline that we made above. Because we retain the readDeadline as part of our
// SetReadDeadline override, we know the user's desired deadline so we use that.
// Therefore, we check whether the error is a net.Timeout and if it is, we decide
// the proxy proto does not exist and set the error accordingly.
if p.readHeaderTimeout > 0 {
p.conn.SetReadDeadline(p.readDeadline)
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
err = ErrNoProxyProtocol
}
}

// For the purpose of this wrapper shamefully stolen from armon/go-proxyproto
// let's act as if there was no error when PROXY protocol is not present.
if err == ErrNoProxyProtocol {
Expand Down

0 comments on commit 0c5719a

Please sign in to comment.