Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proxy proto tls fix #387

Merged
merged 2 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions acceptor/proxyprotowrapper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package acceptor

import (
"net"
"sync"

"github.com/mailgun/proxyproto"
"github.com/topfreegames/pitaya/v2/logger"
)

// Listener is used to wrap an underlying listener,
// whose connections may be using the HAProxy Proxy Protocol.
// If the connection is using the protocol, the RemoteAddr() will return
// the correct client address.
type ProxyProtocolListener struct {
net.Listener
proxyProtocolEnabled *bool
}

// Accept waits for and returns the next connection to the listener.
func (p *ProxyProtocolListener) Accept() (net.Conn, error) {
// Get the underlying connection
conn, err := p.Listener.Accept()
if err != nil {
return nil, err
}
connP := &Conn{Conn: conn, proxyProtocolEnabled: p.proxyProtocolEnabled}
if *p.proxyProtocolEnabled {
err = connP.checkPrefix()
if err != nil {
return connP, err
}
}
return connP, nil
}

// Conn is used to wrap and underlying connection which
// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
// return the address of the client instead of the proxy address.
type Conn struct {
net.Conn
dstAddr *net.Addr
srcAddr *net.Addr
once sync.Once
proxyProtocolEnabled *bool
}

func (p *Conn) LocalAddr() net.Addr {
if p.dstAddr != nil {
return *p.dstAddr
}
return p.Conn.LocalAddr()
}

// RemoteAddr returns the address of the client if the proxy
// protocol is being used, otherwise just returns the address of
// the socket peer. If there is an error parsing the header, the
// address of the client is not returned, and the socket is closed.
// Once implication of this is that the call could block if the
// client is slow. Using a Deadline is recommended if this is called
// before Read()
func (p *Conn) RemoteAddr() net.Addr {
if p.srcAddr != nil {
return *p.srcAddr
}
return p.Conn.RemoteAddr()
}

func (p *Conn) checkPrefix() error {

h, err := proxyproto.ReadHeader(p)
if err != nil {
logger.Log.Errorf("Failed to read Proxy Protocol TCP header: %s", err.Error())
p.Close()
return err

} else if h.Source == nil {
p.Close()
} else {
p.srcAddr = &h.Source
p.dstAddr = &h.Destination
}

return nil
}
72 changes: 33 additions & 39 deletions acceptor/tcp_acceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@ package acceptor

import (
"crypto/tls"
"fmt"
"io"
"io/ioutil"
"net"
"fmt"

"github.com/mailgun/proxyproto"
"github.com/topfreegames/pitaya/v2/conn/codec"
"github.com/topfreegames/pitaya/v2/constants"
"github.com/topfreegames/pitaya/v2/logger"
Expand Down Expand Up @@ -81,10 +80,10 @@ func NewTCPAcceptor(addr string, certs ...string) *TCPAcceptor {
certificates := []tls.Certificate{}
if len(certs) != 2 && len(certs) != 0 {
panic(constants.ErrIncorrectNumberOfCertificates)
} else if ( len(certs) == 2 && certs[0] != "" && certs[1] != "") {
} else if len(certs) == 2 && certs[0] != "" && certs[1] != "" {
cert, err := tls.LoadX509KeyPair(certs[0], certs[1])
if err != nil {
panic(fmt.Errorf("%w: %v",constants.ErrInvalidCertificates,err))
panic(fmt.Errorf("%w: %v", constants.ErrInvalidCertificates, err))
}
certificates = append(certificates, cert)
}
Expand Down Expand Up @@ -127,43 +126,55 @@ func (a *TCPAcceptor) hasTLSCertificates() bool {

// ListenAndServe using tcp acceptor
func (a *TCPAcceptor) ListenAndServe() {

listener := a.createBaseListener()

if a.hasTLSCertificates() {
a.listenAndServeTLS()
return
listener = a.listenAndServeTLS(listener)
}

listener, err := net.Listen("tcp", a.addr)
if err != nil {
logger.Log.Fatalf("Failed to listen: %s", err.Error())
}
a.listener = listener
a.running = true
a.serve()
}

// ListenAndServeTLS listens using tls
func (a *TCPAcceptor) ListenAndServeTLS(cert, key string) {
listener := a.createBaseListener()

crt, err := tls.LoadX509KeyPair(cert, key)
if err != nil {
logger.Log.Fatalf("Failed to listen: %s", err.Error())
}

a.certs = append(a.certs, crt)

a.listenAndServeTLS()
a.listener = a.listenAndServeTLS(listener)
a.running = true
a.serve()
}

// ListenAndServeTLS listens using tls
func (a *TCPAcceptor) listenAndServeTLS() {
tlsCfg := &tls.Config{Certificates: a.certs}

listener, err := tls.Listen("tcp", a.addr, tlsCfg)
// Create base listener
func (a *TCPAcceptor) createBaseListener() net.Listener {
// Create raw listener
baseListener, err := net.Listen("tcp", a.addr)
if err != nil {
logger.Log.Fatalf("Failed to listen: %s", err.Error())
}
a.listener = listener
a.running = true
a.serve()

// Wrap listener in ProxyProto
baseListener = &ProxyProtocolListener{Listener: baseListener, proxyProtocolEnabled: &a.proxyProtocol}

return baseListener
}

// ListenAndServeTLS listens using tls
func (a *TCPAcceptor) listenAndServeTLS(listener net.Listener) net.Listener {

tlsCfg := &tls.Config{Certificates: a.certs}
tlsListener := tls.NewListener(listener, tlsCfg)

return tlsListener
}

func (a *TCPAcceptor) EnableProxyProtocol() {
Expand All @@ -178,35 +189,18 @@ func (a *TCPAcceptor) serve() {
logger.Log.Errorf("Failed to accept TCP connection: %s", err.Error())
continue
}
var remoteAddr net.Addr
if a.proxyProtocol == true {
h, err := proxyproto.ReadHeader(conn)
if err != nil {
logger.Log.Errorf("Failed to read Proxy Protocol TCP header: %s", err.Error())
conn.Close()
continue
} else if h.Source == nil {
conn.Close()
continue
} else {
remoteAddr = h.Source
}
} else {

remoteAddr = conn.RemoteAddr()

}
a.connChan <- &tcpPlayerConn{
Conn: conn,
remoteAddr: remoteAddr,
remoteAddr: conn.RemoteAddr(),
}
}
}

func (a *TCPAcceptor) IsRunning() bool {
return a.running
return a.running
}

func (a *TCPAcceptor) GetConfiguredAddress() string {
return a.addr
return a.addr
}
4 changes: 2 additions & 2 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ func (app *App) listen() {
}
}()
if app.config.Acceptor.ProxyProtocol {
logger.Log.Info("Enabling PROXY protocol for inbond connections")
logger.Log.Info("Enabling PROXY protocol for inbound connections")
a.EnableProxyProtocol()
} else {
logger.Log.Debug("PROXY protocol is disabled for inbound connections")
Expand All @@ -391,7 +391,7 @@ func (app *App) listen() {
}()
logger.Log.Infof("Waiting for Acceptor %s to start on addr %s", reflect.TypeOf(a), a.GetConfiguredAddress())

for a.IsRunning() == false {
for !a.IsRunning() {
}

logger.Log.Infof("Acceptor %s on addr %s is now accepting connections", reflect.TypeOf(a), a.GetAddr())
Expand Down
Loading