Skip to content

Commit

Permalink
Proxy proto tls fix (#387)
Browse files Browse the repository at this point in the history
* Fix proxy protocol usage with TLS listeners

* remove debugging string
  • Loading branch information
rsafonseca committed Feb 22, 2024
1 parent aa18118 commit 7a573b7
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 41 deletions.
85 changes: 85 additions & 0 deletions acceptor/proxyprotowrapper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package acceptor

import (
"net"
"sync"

"github.com/mailgun/proxyproto"
"github.com/topfreegames/pitaya/v2/logger"
)

// Listener is used to wrap an underlying listener,
// whose connections may be using the HAProxy Proxy Protocol.
// If the connection is using the protocol, the RemoteAddr() will return
// the correct client address.
type ProxyProtocolListener struct {
net.Listener
proxyProtocolEnabled *bool
}

// Accept waits for and returns the next connection to the listener.
func (p *ProxyProtocolListener) Accept() (net.Conn, error) {
// Get the underlying connection
conn, err := p.Listener.Accept()
if err != nil {
return nil, err
}
connP := &Conn{Conn: conn, proxyProtocolEnabled: p.proxyProtocolEnabled}
if *p.proxyProtocolEnabled {
err = connP.checkPrefix()
if err != nil {
return connP, err
}
}
return connP, nil
}

// Conn is used to wrap and underlying connection which
// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
// return the address of the client instead of the proxy address.
type Conn struct {
net.Conn
dstAddr *net.Addr
srcAddr *net.Addr
once sync.Once
proxyProtocolEnabled *bool
}

func (p *Conn) LocalAddr() net.Addr {
if p.dstAddr != nil {
return *p.dstAddr
}
return p.Conn.LocalAddr()
}

// RemoteAddr returns the address of the client if the proxy
// protocol is being used, otherwise just returns the address of
// the socket peer. If there is an error parsing the header, the
// address of the client is not returned, and the socket is closed.
// Once implication of this is that the call could block if the
// client is slow. Using a Deadline is recommended if this is called
// before Read()
func (p *Conn) RemoteAddr() net.Addr {
if p.srcAddr != nil {
return *p.srcAddr
}
return p.Conn.RemoteAddr()
}

func (p *Conn) checkPrefix() error {

h, err := proxyproto.ReadHeader(p)
if err != nil {
logger.Log.Errorf("Failed to read Proxy Protocol TCP header: %s", err.Error())
p.Close()
return err

} else if h.Source == nil {
p.Close()
} else {
p.srcAddr = &h.Source
p.dstAddr = &h.Destination
}

return nil
}
72 changes: 33 additions & 39 deletions acceptor/tcp_acceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@ package acceptor

import (
"crypto/tls"
"fmt"
"io"
"io/ioutil"
"net"
"fmt"

"github.com/mailgun/proxyproto"
"github.com/topfreegames/pitaya/v2/conn/codec"
"github.com/topfreegames/pitaya/v2/constants"
"github.com/topfreegames/pitaya/v2/logger"
Expand Down Expand Up @@ -81,10 +80,10 @@ func NewTCPAcceptor(addr string, certs ...string) *TCPAcceptor {
certificates := []tls.Certificate{}
if len(certs) != 2 && len(certs) != 0 {
panic(constants.ErrIncorrectNumberOfCertificates)
} else if ( len(certs) == 2 && certs[0] != "" && certs[1] != "") {
} else if len(certs) == 2 && certs[0] != "" && certs[1] != "" {
cert, err := tls.LoadX509KeyPair(certs[0], certs[1])
if err != nil {
panic(fmt.Errorf("%w: %v",constants.ErrInvalidCertificates,err))
panic(fmt.Errorf("%w: %v", constants.ErrInvalidCertificates, err))
}
certificates = append(certificates, cert)
}
Expand Down Expand Up @@ -127,43 +126,55 @@ func (a *TCPAcceptor) hasTLSCertificates() bool {

// ListenAndServe using tcp acceptor
func (a *TCPAcceptor) ListenAndServe() {

listener := a.createBaseListener()

if a.hasTLSCertificates() {
a.listenAndServeTLS()
return
listener = a.listenAndServeTLS(listener)
}

listener, err := net.Listen("tcp", a.addr)
if err != nil {
logger.Log.Fatalf("Failed to listen: %s", err.Error())
}
a.listener = listener
a.running = true
a.serve()
}

// ListenAndServeTLS listens using tls
func (a *TCPAcceptor) ListenAndServeTLS(cert, key string) {
listener := a.createBaseListener()

crt, err := tls.LoadX509KeyPair(cert, key)
if err != nil {
logger.Log.Fatalf("Failed to listen: %s", err.Error())
}

a.certs = append(a.certs, crt)

a.listenAndServeTLS()
a.listener = a.listenAndServeTLS(listener)
a.running = true
a.serve()
}

// ListenAndServeTLS listens using tls
func (a *TCPAcceptor) listenAndServeTLS() {
tlsCfg := &tls.Config{Certificates: a.certs}

listener, err := tls.Listen("tcp", a.addr, tlsCfg)
// Create base listener
func (a *TCPAcceptor) createBaseListener() net.Listener {
// Create raw listener
baseListener, err := net.Listen("tcp", a.addr)
if err != nil {
logger.Log.Fatalf("Failed to listen: %s", err.Error())
}
a.listener = listener
a.running = true
a.serve()

// Wrap listener in ProxyProto
baseListener = &ProxyProtocolListener{Listener: baseListener, proxyProtocolEnabled: &a.proxyProtocol}

return baseListener
}

// ListenAndServeTLS listens using tls
func (a *TCPAcceptor) listenAndServeTLS(listener net.Listener) net.Listener {

tlsCfg := &tls.Config{Certificates: a.certs}
tlsListener := tls.NewListener(listener, tlsCfg)

return tlsListener
}

func (a *TCPAcceptor) EnableProxyProtocol() {
Expand All @@ -178,35 +189,18 @@ func (a *TCPAcceptor) serve() {
logger.Log.Errorf("Failed to accept TCP connection: %s", err.Error())
continue
}
var remoteAddr net.Addr
if a.proxyProtocol == true {
h, err := proxyproto.ReadHeader(conn)
if err != nil {
logger.Log.Errorf("Failed to read Proxy Protocol TCP header: %s", err.Error())
conn.Close()
continue
} else if h.Source == nil {
conn.Close()
continue
} else {
remoteAddr = h.Source
}
} else {

remoteAddr = conn.RemoteAddr()

}
a.connChan <- &tcpPlayerConn{
Conn: conn,
remoteAddr: remoteAddr,
remoteAddr: conn.RemoteAddr(),
}
}
}

func (a *TCPAcceptor) IsRunning() bool {
return a.running
return a.running
}

func (a *TCPAcceptor) GetConfiguredAddress() string {
return a.addr
return a.addr
}
4 changes: 2 additions & 2 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ func (app *App) listen() {
}
}()
if app.config.Acceptor.ProxyProtocol {
logger.Log.Info("Enabling PROXY protocol for inbond connections")
logger.Log.Info("Enabling PROXY protocol for inbound connections")
a.EnableProxyProtocol()
} else {
logger.Log.Debug("PROXY protocol is disabled for inbound connections")
Expand All @@ -391,7 +391,7 @@ func (app *App) listen() {
}()
logger.Log.Infof("Waiting for Acceptor %s to start on addr %s", reflect.TypeOf(a), a.GetConfiguredAddress())

for a.IsRunning() == false {
for !a.IsRunning() {
}

logger.Log.Infof("Acceptor %s on addr %s is now accepting connections", reflect.TypeOf(a), a.GetAddr())
Expand Down

0 comments on commit 7a573b7

Please sign in to comment.