diff --git a/acceptor/acceptor.go b/acceptor/acceptor.go index 7d1d7dff..d4d1f142 100644 --- a/acceptor/acceptor.go +++ b/acceptor/acceptor.go @@ -22,10 +22,16 @@ package acceptor import "net" +// PlayerConn iface +type PlayerConn interface { + GetNextMessage() (b []byte, err error) + net.Conn +} + // Acceptor type interface type Acceptor interface { ListenAndServe() Stop() GetAddr() string - GetConnChan() chan net.Conn + GetConnChan() chan PlayerConn } diff --git a/acceptor/tcp_acceptor.go b/acceptor/tcp_acceptor.go index 66b576f8..4f54ccd0 100644 --- a/acceptor/tcp_acceptor.go +++ b/acceptor/tcp_acceptor.go @@ -21,9 +21,16 @@ package acceptor import ( + "bytes" "crypto/tls" + "errors" + "fmt" + "io" + "io/ioutil" "net" + "github.com/topfreegames/pitaya/conn/codec" + "github.com/topfreegames/pitaya/conn/packet" "github.com/topfreegames/pitaya/constants" "github.com/topfreegames/pitaya/logger" ) @@ -31,13 +38,44 @@ import ( // TCPAcceptor struct type TCPAcceptor struct { addr string - connChan chan net.Conn + connChan chan PlayerConn listener net.Listener running bool certFile string keyFile string } +type tcpPlayerConn struct { + net.Conn +} + +// GetNextMessage reads the next message available in the stream +func (t *tcpPlayerConn) GetNextMessage() (b []byte, err error) { + msgBuffer := bytes.NewBuffer(nil) + header, err := ioutil.ReadAll(io.LimitReader(t.Conn, codec.HeadLength)) + if err != nil { + return nil, err + } + if len(header) == 0 { + return nil, errors.New("EOF") + } + typ := header[0] + if typ < packet.Handshake || typ > packet.Kick { + return nil, fmt.Errorf("Invalid packet type received: %x (maybe the header is corrupted?)", header[0]) + } + msgBuffer.Write(header) + remainingSize := codec.BytesToInt(header[1:codec.HeadLength]) + msgData, err := ioutil.ReadAll(io.LimitReader(t.Conn, int64(remainingSize))) + if err != nil { + return nil, err + } + if len(msgData) < remainingSize { + return nil, errors.New("Received less data than expected, EOF?") + } + msgBuffer.Write(msgData) + return msgBuffer.Bytes(), nil +} + // NewTCPAcceptor creates a new instance of tcp acceptor func NewTCPAcceptor(addr string, certs ...string) *TCPAcceptor { keyFile := "" @@ -51,7 +89,7 @@ func NewTCPAcceptor(addr string, certs ...string) *TCPAcceptor { return &TCPAcceptor{ addr: addr, - connChan: make(chan net.Conn), + connChan: make(chan PlayerConn), running: false, certFile: certFile, keyFile: keyFile, @@ -67,7 +105,7 @@ func (a *TCPAcceptor) GetAddr() string { } // GetConnChan gets a connection channel -func (a *TCPAcceptor) GetConnChan() chan net.Conn { +func (a *TCPAcceptor) GetConnChan() chan PlayerConn { return a.connChan } @@ -121,6 +159,8 @@ func (a *TCPAcceptor) serve() { continue } - a.connChan <- conn + a.connChan <- &tcpPlayerConn{ + Conn: conn, + } } } diff --git a/acceptor/tcp_acceptor_test.go b/acceptor/tcp_acceptor_test.go index 4619256c..134b8126 100644 --- a/acceptor/tcp_acceptor_test.go +++ b/acceptor/tcp_acceptor_test.go @@ -21,6 +21,7 @@ package acceptor import ( + "errors" "net" "testing" "time" @@ -145,3 +146,132 @@ func TestStop(t *testing.T) { }) } } + +func TestGetNextMessage(t *testing.T) { + tables := []struct { + name string + data []byte + err error + }{ + {"invalid_header", []byte{0x00, 0x00, 0x00, 0x00}, errors.New("Invalid packet type received: 0 (maybe the header is corrupted?)")}, + {"valid_message", []byte{0x02, 0x00, 0x00, 0x01, 0x00}, nil}, + } + + for _, table := range tables { + t.Run(table.name, func(t *testing.T) { + a := NewTCPAcceptor("0.0.0.0:0") + go a.ListenAndServe() + defer a.Stop() + c := a.GetConnChan() + // should be able to connect within 100 milliseconds + var conn net.Conn + var err error + helpers.ShouldEventuallyReturn(t, func() error { + conn, err = net.Dial("tcp", a.GetAddr()) + return err + }, nil, 10*time.Millisecond, 100*time.Millisecond) + + defer conn.Close() + playerConn := helpers.ShouldEventuallyReceive(t, c, 100*time.Millisecond).(PlayerConn) + _, err = conn.Write(table.data) + assert.NoError(t, err) + + msg, err := playerConn.GetNextMessage() + if table.err != nil { + assert.EqualError(t, err, table.err.Error()) + } else { + assert.Equal(t, table.data, msg) + assert.NoError(t, err) + } + }) + } +} + +func TestGetNextMessageTwoMessagesInBuffer(t *testing.T) { + a := NewTCPAcceptor("0.0.0.0:0") + go a.ListenAndServe() + defer a.Stop() + c := a.GetConnChan() + // should be able to connect within 100 milliseconds + var conn net.Conn + var err error + helpers.ShouldEventuallyReturn(t, func() error { + conn, err = net.Dial("tcp", a.GetAddr()) + return err + }, nil, 10*time.Millisecond, 100*time.Millisecond) + defer conn.Close() + + playerConn := helpers.ShouldEventuallyReceive(t, c, 100*time.Millisecond).(PlayerConn) + msg1 := []byte{0x01, 0x00, 0x00, 0x01, 0x02} + msg2 := []byte{0x02, 0x00, 0x00, 0x02, 0x01, 0x01} + buffer := append(msg1, msg2...) + _, err = conn.Write(buffer) + assert.NoError(t, err) + + msg, err := playerConn.GetNextMessage() + assert.NoError(t, err) + assert.Equal(t, msg1, msg) + + msg, err = playerConn.GetNextMessage() + assert.NoError(t, err) + assert.Equal(t, msg2, msg) +} + +func TestGetNextMessageEOF(t *testing.T) { + a := NewTCPAcceptor("0.0.0.0:0") + go a.ListenAndServe() + defer a.Stop() + c := a.GetConnChan() + // should be able to connect within 100 milliseconds + var conn net.Conn + var err error + helpers.ShouldEventuallyReturn(t, func() error { + conn, err = net.Dial("tcp", a.GetAddr()) + return err + }, nil, 10*time.Millisecond, 100*time.Millisecond) + + playerConn := helpers.ShouldEventuallyReceive(t, c, 100*time.Millisecond).(PlayerConn) + buffer := []byte{0x02, 0x00, 0x00, 0x02, 0x01} + _, err = conn.Write(buffer) + assert.NoError(t, err) + + go func() { + time.Sleep(100 * time.Millisecond) + conn.Close() + }() + + _, err = playerConn.GetNextMessage() + assert.EqualError(t, err, "Received less data than expected, EOF?") + +} + +func TestGetNextMessageInParts(t *testing.T) { + a := NewTCPAcceptor("0.0.0.0:0") + go a.ListenAndServe() + defer a.Stop() + c := a.GetConnChan() + // should be able to connect within 100 milliseconds + var conn net.Conn + var err error + helpers.ShouldEventuallyReturn(t, func() error { + conn, err = net.Dial("tcp", a.GetAddr()) + return err + }, nil, 10*time.Millisecond, 100*time.Millisecond) + + defer conn.Close() + playerConn := helpers.ShouldEventuallyReceive(t, c, 100*time.Millisecond).(PlayerConn) + part1 := []byte{0x02, 0x00, 0x00, 0x03, 0x01} + part2 := []byte{0x01, 0x02} + _, err = conn.Write(part1) + assert.NoError(t, err) + + go func() { + time.Sleep(200 * time.Millisecond) + _, err = conn.Write(part2) + }() + + msg, err := playerConn.GetNextMessage() + assert.NoError(t, err) + assert.Equal(t, msg, append(part1, part2...)) + +} diff --git a/acceptor/ws_acceptor.go b/acceptor/ws_acceptor.go index 753a9b91..870af775 100644 --- a/acceptor/ws_acceptor.go +++ b/acceptor/ws_acceptor.go @@ -22,12 +22,16 @@ package acceptor import ( "crypto/tls" + "errors" + "fmt" "io" "net" "net/http" "time" "github.com/gorilla/websocket" + "github.com/topfreegames/pitaya/conn/codec" + "github.com/topfreegames/pitaya/conn/packet" "github.com/topfreegames/pitaya/constants" "github.com/topfreegames/pitaya/logger" ) @@ -35,7 +39,7 @@ import ( // WSAcceptor struct type WSAcceptor struct { addr string - connChan chan net.Conn + connChan chan PlayerConn listener net.Listener certFile string keyFile string @@ -54,7 +58,7 @@ func NewWSAcceptor(addr string, certs ...string) *WSAcceptor { w := &WSAcceptor{ addr: addr, - connChan: make(chan net.Conn), + connChan: make(chan PlayerConn), certFile: certFile, keyFile: keyFile, } @@ -70,13 +74,13 @@ func (w *WSAcceptor) GetAddr() string { } // GetConnChan gets a connection channel -func (w *WSAcceptor) GetConnChan() chan net.Conn { +func (w *WSAcceptor) GetConnChan() chan PlayerConn { return w.connChan } type connHandler struct { upgrader *websocket.Upgrader - connChan chan net.Conn + connChan chan PlayerConn } func (h *connHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { @@ -86,7 +90,7 @@ func (h *connHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { return } - c, err := newWSConn(conn) + c, err := NewWSConn(conn) if err != nil { logger.Log.Errorf("Failed to create new ws connection: %s", err.Error()) return @@ -160,33 +164,54 @@ func (w *WSAcceptor) Stop() { } } -// wsConn is an adapter to t.Conn, which implements all t.Conn +// WSConn is an adapter to t.Conn, which implements all t.Conn // interface base on *websocket.Conn -type wsConn struct { +type WSConn struct { conn *websocket.Conn typ int // message type reader io.Reader } -// newWSConn return an initialized *wsConn -func newWSConn(conn *websocket.Conn) (*wsConn, error) { - c := &wsConn{conn: conn} - - t, r, err := conn.NextReader() - if err != nil { - return nil, err - } - - c.typ = t - c.reader = r +// NewWSConn return an initialized *WSConn +func NewWSConn(conn *websocket.Conn) (*WSConn, error) { + c := &WSConn{conn: conn} return c, nil } +// GetNextMessage reads the next message available in the stream +func (c *WSConn) GetNextMessage() (b []byte, err error) { + _, msgBytes, err := c.conn.ReadMessage() + header := msgBytes[:codec.HeadLength] + typ := header[0] + if typ < packet.Handshake || typ > packet.Kick { + return nil, fmt.Errorf("Invalid packet type received: %x (maybe the header is corrupted?)", header[0]) + } + if len(msgBytes) < codec.HeadLength { + return nil, errors.New("Received invalid message, len(header) < 4") + } + msgSize := codec.BytesToInt(header[1:codec.HeadLength]) + dataLen := len(msgBytes[codec.HeadLength:]) + if dataLen < msgSize { + return nil, errors.New("Received less data than expected, EOF?") + } else if dataLen > msgSize { + return nil, errors.New("Received more data than expected, seems like an error in ws logic") + } + return msgBytes, err +} + // Read reads data from the connection. // Read can be made to time out and return an Error with Timeout() == true // after a fixed time limit; see SetDeadline and SetReadDeadline. -func (c *wsConn) Read(b []byte) (int, error) { +func (c *WSConn) Read(b []byte) (int, error) { + if c.reader == nil { + t, r, err := c.conn.NextReader() + if err != nil { + return 0, err + } + c.typ = t + c.reader = r + } n, err := c.reader.Read(b) if err != nil && err != io.EOF { return n, err @@ -204,7 +229,7 @@ func (c *wsConn) Read(b []byte) (int, error) { // Write writes data to the connection. // Write can be made to time out and return an Error with Timeout() == true // after a fixed time limit; see SetDeadline and SetWriteDeadline. -func (c *wsConn) Write(b []byte) (int, error) { +func (c *WSConn) Write(b []byte) (int, error) { err := c.conn.WriteMessage(websocket.BinaryMessage, b) if err != nil { return 0, err @@ -215,17 +240,17 @@ func (c *wsConn) Write(b []byte) (int, error) { // Close closes the connection. // Any blocked Read or Write operations will be unblocked and return errors. -func (c *wsConn) Close() error { +func (c *WSConn) Close() error { return c.conn.Close() } // LocalAddr returns the local network address. -func (c *wsConn) LocalAddr() net.Addr { +func (c *WSConn) LocalAddr() net.Addr { return c.conn.LocalAddr() } // RemoteAddr returns the remote network address. -func (c *wsConn) RemoteAddr() net.Addr { +func (c *WSConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } @@ -244,7 +269,7 @@ func (c *wsConn) RemoteAddr() net.Addr { // the deadline after successful Read or Write calls. // // A zero value for t means I/O operations will not time out. -func (c *wsConn) SetDeadline(t time.Time) error { +func (c *WSConn) SetDeadline(t time.Time) error { if err := c.SetReadDeadline(t); err != nil { return err } @@ -255,7 +280,7 @@ func (c *wsConn) SetDeadline(t time.Time) error { // SetReadDeadline sets the deadline for future Read calls // and any currently-blocked Read call. // A zero value for t means Read will not time out. -func (c *wsConn) SetReadDeadline(t time.Time) error { +func (c *WSConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } @@ -264,6 +289,6 @@ func (c *wsConn) SetReadDeadline(t time.Time) error { // Even if write times out, it may return n > 0, indicating that // some of the data was successfully written. // A zero value for t means Write will not time out. -func (c *wsConn) SetWriteDeadline(t time.Time) error { +func (c *WSConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) } diff --git a/acceptor/ws_acceptor_test.go b/acceptor/ws_acceptor_test.go index e7bf5b5e..dcca2d6a 100644 --- a/acceptor/ws_acceptor_test.go +++ b/acceptor/ws_acceptor_test.go @@ -2,6 +2,7 @@ package acceptor import ( "crypto/tls" + "errors" "fmt" "testing" "time" @@ -96,7 +97,7 @@ func TestWSAcceptorListenAndServe(t *testing.T) { defer w.Stop() go w.ListenAndServe() mustConnectToWS(t, table.write, w, "ws") - conn := helpers.ShouldEventuallyReceive(t, c, 100*time.Millisecond).(*wsConn) + conn := helpers.ShouldEventuallyReceive(t, c, 100*time.Millisecond).(*WSConn) defer conn.Close() assert.NotNil(t, conn) }) @@ -111,7 +112,7 @@ func TestWSAcceptorListenAndServeTLS(t *testing.T) { defer w.Stop() go w.ListenAndServeTLS("./fixtures/server.crt", "./fixtures/server.key") mustConnectToWS(t, table.write, w, "wss") - conn := helpers.ShouldEventuallyReceive(t, c, 100*time.Millisecond).(*wsConn) + conn := helpers.ShouldEventuallyReceive(t, c, 100*time.Millisecond).(*WSConn) defer conn.Close() assert.NotNil(t, conn) }) @@ -140,7 +141,7 @@ func TestWSConnRead(t *testing.T) { defer w.Stop() go w.ListenAndServe() mustConnectToWS(t, table.write, w, "ws") - conn := helpers.ShouldEventuallyReceive(t, c, 100*time.Millisecond).(*wsConn) + conn := helpers.ShouldEventuallyReceive(t, c, 100*time.Millisecond).(*WSConn) defer conn.Close() b := make([]byte, len(table.write)) n, err := conn.Read(b) @@ -159,7 +160,7 @@ func TestWSConnWrite(t *testing.T) { defer w.Stop() go w.ListenAndServe() mustConnectToWS(t, table.write, w, "ws") - conn := helpers.ShouldEventuallyReceive(t, c, 100*time.Millisecond).(*wsConn) + conn := helpers.ShouldEventuallyReceive(t, c, 100*time.Millisecond).(*WSConn) defer conn.Close() b := make([]byte, len(table.write)) n, err := conn.Write(b) @@ -177,7 +178,7 @@ func TestWSConnLocalAddr(t *testing.T) { defer w.Stop() go w.ListenAndServe() mustConnectToWS(t, table.write, w, "ws") - conn := helpers.ShouldEventuallyReceive(t, c, 100*time.Millisecond).(*wsConn) + conn := helpers.ShouldEventuallyReceive(t, c, 100*time.Millisecond).(*WSConn) defer conn.Close() a := conn.LocalAddr().String() assert.NotEmpty(t, a) @@ -193,7 +194,7 @@ func TestWSConnRemoteAddr(t *testing.T) { defer w.Stop() go w.ListenAndServe() mustConnectToWS(t, table.write, w, "ws") - conn := helpers.ShouldEventuallyReceive(t, c, 100*time.Millisecond).(*wsConn) + conn := helpers.ShouldEventuallyReceive(t, c, 100*time.Millisecond).(*WSConn) defer conn.Close() a := conn.RemoteAddr().String() assert.NotEmpty(t, a) @@ -209,7 +210,7 @@ func TestWSConnSetDeadline(t *testing.T) { defer w.Stop() go w.ListenAndServe() mustConnectToWS(t, table.write, w, "ws") - conn := helpers.ShouldEventuallyReceive(t, c, 100*time.Millisecond).(*wsConn) + conn := helpers.ShouldEventuallyReceive(t, c, 100*time.Millisecond).(*WSConn) defer conn.Close() conn.SetDeadline(time.Now().Add(5 * time.Millisecond)) time.Sleep(10 * time.Millisecond) @@ -218,3 +219,77 @@ func TestWSConnSetDeadline(t *testing.T) { }) } } + +func TestWSGetNextMessage(t *testing.T) { + tables := []struct { + name string + data []byte + err error + }{ + {"invalid_header", []byte{0x00, 0x00, 0x00, 0x00}, errors.New("Invalid packet type received: 0 (maybe the header is corrupted?)")}, + {"valid_message", []byte{0x02, 0x00, 0x00, 0x01, 0x00}, nil}, + {"invalid_message", []byte{0x02, 0x00, 0x00, 0x02, 0x00}, errors.New("Received less data than expected, EOF?")}, + {"invalid_header", []byte{0x02, 0x00}, errors.New("Received invalid message, len(header) < 4")}, + } + + for _, table := range tables { + t.Run(table.name, func(t *testing.T) { + w := NewWSAcceptor("0.0.0.0:0") + c := w.GetConnChan() + defer w.Stop() + go w.ListenAndServe() + + var conn *websocket.Conn + var err error + helpers.ShouldEventuallyReturn(t, func() error { + addr := fmt.Sprintf("%s://%s", "ws", w.GetAddr()) + dialer := websocket.DefaultDialer + conn, _, err = dialer.Dial(addr, nil) + return err + }, nil, 10*time.Millisecond, 100*time.Millisecond) + + playerConn := helpers.ShouldEventuallyReceive(t, c, 100*time.Millisecond).(*WSConn) + defer playerConn.Close() + err = conn.WriteMessage(websocket.BinaryMessage, table.data) + assert.NoError(t, err) + msg, err := playerConn.GetNextMessage() + if table.err != nil { + assert.EqualError(t, err, table.err.Error()) + } else { + assert.NoError(t, err) + assert.Equal(t, table.data, msg) + } + }) + } +} + +func TestWSGetNextMessageSequentially(t *testing.T) { + w := NewWSAcceptor("0.0.0.0:0") + c := w.GetConnChan() + defer w.Stop() + go w.ListenAndServe() + + var conn *websocket.Conn + var err error + helpers.ShouldEventuallyReturn(t, func() error { + addr := fmt.Sprintf("%s://%s", "ws", w.GetAddr()) + dialer := websocket.DefaultDialer + conn, _, err = dialer.Dial(addr, nil) + return err + }, nil, 10*time.Millisecond, 100*time.Millisecond) + + playerConn := helpers.ShouldEventuallyReceive(t, c, 100*time.Millisecond).(*WSConn) + defer playerConn.Close() + msg1 := []byte{0x01, 0x00, 0x00, 0x02, 0x01, 0x01} + msg2 := []byte{0x02, 0x00, 0x00, 0x02, 0x05, 0x04} + err = conn.WriteMessage(websocket.BinaryMessage, msg1) + assert.NoError(t, err) + err = conn.WriteMessage(websocket.BinaryMessage, msg2) + assert.NoError(t, err) + msg, err := playerConn.GetNextMessage() + assert.NoError(t, err) + assert.Equal(t, msg1, msg) + msg, err = playerConn.GetNextMessage() + assert.NoError(t, err) + assert.Equal(t, msg2, msg) +} diff --git a/acceptorwrapper/base.go b/acceptorwrapper/base.go index bc622fe3..55345b08 100644 --- a/acceptorwrapper/base.go +++ b/acceptorwrapper/base.go @@ -21,8 +21,6 @@ package acceptorwrapper import ( - "net" - "github.com/topfreegames/pitaya/acceptor" ) @@ -32,14 +30,14 @@ import ( // Any new wrapper can inherit from BaseWrapper and just implement wrapConn. type BaseWrapper struct { acceptor.Acceptor - connChan chan net.Conn - wrapConn func(net.Conn) net.Conn + connChan chan acceptor.PlayerConn + wrapConn func(acceptor.PlayerConn) acceptor.PlayerConn } // NewBaseWrapper returns an instance of BaseWrapper. -func NewBaseWrapper(wrapConn func(net.Conn) net.Conn) BaseWrapper { +func NewBaseWrapper(wrapConn func(acceptor.PlayerConn) acceptor.PlayerConn) BaseWrapper { return BaseWrapper{ - connChan: make(chan net.Conn), + connChan: make(chan acceptor.PlayerConn), wrapConn: wrapConn, } } @@ -52,7 +50,7 @@ func (b *BaseWrapper) ListenAndServe() { } // GetConnChan returns the wrapper conn chan -func (b *BaseWrapper) GetConnChan() chan net.Conn { +func (b *BaseWrapper) GetConnChan() chan acceptor.PlayerConn { return b.connChan } diff --git a/acceptorwrapper/base_test.go b/acceptorwrapper/base_test.go index a0e6306c..6ccc0d07 100644 --- a/acceptorwrapper/base_test.go +++ b/acceptorwrapper/base_test.go @@ -21,9 +21,10 @@ package acceptorwrapper import ( - "net" "testing" + "github.com/topfreegames/pitaya/acceptor" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/topfreegames/pitaya/mocks" @@ -36,9 +37,9 @@ func TestListenAndServe(t *testing.T) { defer ctrl.Finish() mockAcceptor := mocks.NewMockAcceptor(ctrl) - mockConn := mocks.NewMockConn(ctrl) + mockConn := mocks.NewMockPlayerConn(ctrl) - conns := make(chan net.Conn) + conns := make(chan acceptor.PlayerConn) exit := make(chan struct{}) reads := 3 go func() { @@ -51,8 +52,8 @@ func TestListenAndServe(t *testing.T) { mockAcceptor.EXPECT().GetConnChan().Return(conns) wrapper := &BaseWrapper{ Acceptor: mockAcceptor, - connChan: make(chan net.Conn), - wrapConn: func(c net.Conn) net.Conn { + connChan: make(chan acceptor.PlayerConn), + wrapConn: func(c acceptor.PlayerConn) acceptor.PlayerConn { _, err := c.Read([]byte{}) assert.NoError(t, err) return c diff --git a/acceptorwrapper/rate_limiter.go b/acceptorwrapper/rate_limiter.go index ba4f451e..4e0f20af 100644 --- a/acceptorwrapper/rate_limiter.go +++ b/acceptorwrapper/rate_limiter.go @@ -22,10 +22,10 @@ package acceptorwrapper import ( "container/list" - "net" "time" "github.com/topfreegames/pitaya" + "github.com/topfreegames/pitaya/acceptor" "github.com/topfreegames/pitaya/constants" "github.com/topfreegames/pitaya/logger" "github.com/topfreegames/pitaya/metrics" @@ -41,7 +41,7 @@ import ( // On the client side, this will yield a timeout error and the client must // be prepared to handle it. type RateLimiter struct { - net.Conn + acceptor.PlayerConn limit int interval time.Duration times list.List @@ -50,13 +50,13 @@ type RateLimiter struct { // NewRateLimiter returns an initialized *RateLimiting func NewRateLimiter( - conn net.Conn, + conn acceptor.PlayerConn, limit int, interval time.Duration, forceDisable bool, ) *RateLimiter { r := &RateLimiter{ - Conn: conn, + PlayerConn: conn, limit: limit, interval: interval, forceDisable: forceDisable, @@ -67,25 +67,26 @@ func NewRateLimiter( return r } -func (r *RateLimiter) Read(b []byte) (n int, err error) { +// GetNextMessage gets the next message in the connection +func (r *RateLimiter) GetNextMessage() (msg []byte, err error) { if r.forceDisable { - return r.Conn.Read(b) + return r.PlayerConn.GetNextMessage() } for { - n, err = r.Conn.Read(b) + msg, err := r.PlayerConn.GetNextMessage() if err != nil { - return n, err + return nil, err } now := time.Now() if r.shouldRateLimit(now) { - logger.Log.Errorf("Data=%s, Error=%s", b, constants.ErrRateLimitExceeded) + logger.Log.Errorf("Data=%s, Error=%s", msg, constants.ErrRateLimitExceeded) metrics.ReportExceededRateLimiting(pitaya.GetMetricsReporters()) continue } - return n, err + return msg, err } } diff --git a/acceptorwrapper/rate_limiter_test.go b/acceptorwrapper/rate_limiter_test.go index 09dfa223..83c7fe35 100644 --- a/acceptorwrapper/rate_limiter_test.go +++ b/acceptorwrapper/rate_limiter_test.go @@ -30,40 +30,40 @@ import ( "github.com/topfreegames/pitaya/mocks" ) -func TestRateLimiterRead(t *testing.T) { +func TestRateLimiterGetNextMessage(t *testing.T) { t.Parallel() var ( limit = 3 interval = time.Second - buf = []byte{} + ret = []byte{0x01, 0x00, 0x00, 0x01, 0x01} errTest = errors.New("error") - mockConn *mocks.MockConn + mockConn *mocks.MockPlayerConn r *RateLimiter ) tables := map[string]struct { forceDisable bool mock func() - expected int + expected []byte err error }{ "test_can_read_on_first_time": { forceDisable: false, mock: func() { - mockConn.EXPECT().Read(buf).Return(10, nil) + mockConn.EXPECT().GetNextMessage().Return(ret, nil) }, - expected: 10, + expected: ret, err: nil, }, "test_read_return_error": { forceDisable: false, mock: func() { - mockConn.EXPECT().Read(buf).Return(0, errTest) + mockConn.EXPECT().GetNextMessage().Return(nil, errTest) }, - expected: 0, + expected: nil, err: errTest, }, @@ -71,17 +71,17 @@ func TestRateLimiterRead(t *testing.T) { forceDisable: false, mock: func() { for i := 0; i < limit; i++ { - mockConn.EXPECT().Read(buf).Return(10, nil) - _, err := r.Read(buf) + mockConn.EXPECT().GetNextMessage().Return(ret, nil) + _, err := r.GetNextMessage() assert.NoError(t, err) } // exceed after this call - mockConn.EXPECT().Read(buf).Return(10, nil) + mockConn.EXPECT().GetNextMessage().Return(ret, nil) // back to for begin, return error to leave for loop - mockConn.EXPECT().Read(buf).Return(0, errTest) + mockConn.EXPECT().GetNextMessage().Return(ret, errTest) }, - expected: 0, + expected: nil, err: errTest, }, @@ -89,14 +89,14 @@ func TestRateLimiterRead(t *testing.T) { forceDisable: true, mock: func() { for i := 0; i < limit; i++ { - mockConn.EXPECT().Read(buf).Return(10, nil) - _, err := r.Read(buf) + mockConn.EXPECT().GetNextMessage().Return(ret, nil) + _, err := r.GetNextMessage() assert.NoError(t, err) } - mockConn.EXPECT().Read(buf).Return(10, nil) + mockConn.EXPECT().GetNextMessage().Return(ret, nil) }, - expected: 10, // exceed but ignored, so return the value of read + expected: ret, // exceed but ignored, so return the value of read err: nil, }, } @@ -105,14 +105,14 @@ func TestRateLimiterRead(t *testing.T) { t.Run(name, func(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockConn = mocks.NewMockConn(ctrl) + mockConn = mocks.NewMockPlayerConn(ctrl) r = NewRateLimiter(mockConn, limit, interval, table.forceDisable) table.mock() - n, err := r.Read(buf) + buf, err := r.GetNextMessage() assert.Equal(t, table.err, err) - assert.Equal(t, table.expected, n) + assert.Equal(t, table.expected, buf) }) } } diff --git a/acceptorwrapper/rate_limiting_wrapper.go b/acceptorwrapper/rate_limiting_wrapper.go index 3171e4ce..eefbbc94 100644 --- a/acceptorwrapper/rate_limiting_wrapper.go +++ b/acceptorwrapper/rate_limiting_wrapper.go @@ -21,8 +21,6 @@ package acceptorwrapper import ( - "net" - "github.com/topfreegames/pitaya/acceptor" "github.com/topfreegames/pitaya/config" ) @@ -37,7 +35,7 @@ type RateLimitingWrapper struct { func NewRateLimitingWrapper(c *config.Config) *RateLimitingWrapper { r := &RateLimitingWrapper{} - r.BaseWrapper = NewBaseWrapper(func(conn net.Conn) net.Conn { + r.BaseWrapper = NewBaseWrapper(func(conn acceptor.PlayerConn) acceptor.PlayerConn { var ( limit = c.GetInt("pitaya.conn.ratelimiting.limit") interval = c.GetDuration("pitaya.conn.ratelimiting.interval") diff --git a/agent/agent.go b/agent/agent.go index 2f4e4d36..3b6b5403 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -103,8 +103,14 @@ func NewAgent( metricsReporters []metrics.Reporter, ) *Agent { // initialize heartbeat and handshake data on first user connection + serializerName := "" + + if serializer != nil { //should never be true, only during testing + serializerName = serializer.GetName() + } + once.Do(func() { - hbdEncode(heartbeatTime, packetEncoder, messageEncoder.IsCompressionEnabled(), serializer.GetName()) + hbdEncode(heartbeatTime, packetEncoder, messageEncoder.IsCompressionEnabled(), serializerName) }) a := &Agent{ diff --git a/agent/agent_test.go b/agent/agent_test.go index 24778216..14c7e176 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -77,7 +77,7 @@ func TestNewAgent(t *testing.T) { dieChan := make(chan bool) hbTime := time.Second - mockConn := mocks.NewMockConn(ctrl) + mockConn := mocks.NewMockPlayerConn(ctrl) mockEncoder.EXPECT().Encode(gomock.Any(), gomock.Not(gomock.Nil())).Do( func(typ packet.Type, d []byte) { @@ -128,7 +128,7 @@ func TestKick(t *testing.T) { dieChan := make(chan bool) hbTime := time.Second - mockConn := mocks.NewMockConn(ctrl) + mockConn := mocks.NewMockPlayerConn(ctrl) mockEncoder.EXPECT().Encode(gomock.Any(), gomock.Nil()).Do( func(typ packet.Type, d []byte) { assert.EqualValues(t, packet.Kick, typ) @@ -136,6 +136,7 @@ func TestKick(t *testing.T) { mockConn.EXPECT().Write(gomock.Any()).Return(0, nil) messageEncoder := message.NewMessagesEncoder(false) + mockSerializer.EXPECT().GetName() ag := NewAgent(mockConn, mockDecoder, mockEncoder, mockSerializer, hbTime, 10, dieChan, messageEncoder, nil) c := context.Background() err := ag.Kick(c) @@ -164,7 +165,8 @@ func TestAgentSend(t *testing.T) { hbTime := time.Second messageEncoder := message.NewMessagesEncoder(false) - mockConn := mocks.NewMockConn(ctrl) + mockConn := mocks.NewMockPlayerConn(ctrl) + mockSerializer.EXPECT().GetName() ag := NewAgent(mockConn, mockDecoder, mockEncoder, mockSerializer, hbTime, 10, dieChan, messageEncoder, nil) assert.NotNil(t, ag) @@ -222,9 +224,10 @@ func TestAgentPush(t *testing.T) { hbTime := time.Second messageEncoder := message.NewMessagesEncoder(false) mockMetricsReporter := metricsmocks.NewMockReporter(ctrl) - mockConn := mocks.NewMockConn(ctrl) + mockConn := mocks.NewMockPlayerConn(ctrl) mockMetricsReporters := []metrics.Reporter{mockMetricsReporter} mockMetricsReporter.EXPECT().ReportGauge(metrics.ConnectedClients, gomock.Any(), gomock.Any()) + mockSerializer.EXPECT().GetName() ag := NewAgent(mockConn, mockDecoder, mockEncoder, mockSerializer, hbTime, 10, dieChan, messageEncoder, mockMetricsReporters) assert.NotNil(t, ag) @@ -262,9 +265,10 @@ func TestAgentPushFullChannel(t *testing.T) { hbTime := time.Second messageEncoder := message.NewMessagesEncoder(false) mockMetricsReporter := metricsmocks.NewMockReporter(ctrl) - mockConn := mocks.NewMockConn(ctrl) + mockConn := mocks.NewMockPlayerConn(ctrl) mockMetricsReporters := []metrics.Reporter{mockMetricsReporter} mockMetricsReporter.EXPECT().ReportGauge(metrics.ConnectedClients, gomock.Any(), gomock.Any()) + mockSerializer.EXPECT().GetName() ag := NewAgent(mockConn, mockDecoder, mockEncoder, mockSerializer, hbTime, 0, dieChan, messageEncoder, mockMetricsReporters) assert.NotNil(t, ag) @@ -327,9 +331,10 @@ func TestAgentResponseMID(t *testing.T) { hbTime := time.Second messageEncoder := message.NewMessagesEncoder(false) - mockConn := mocks.NewMockConn(ctrl) + mockConn := mocks.NewMockPlayerConn(ctrl) mockMetricsReporters := []metrics.Reporter{mockMetricsReporter} mockMetricsReporter.EXPECT().ReportGauge(metrics.ConnectedClients, gomock.Any(), gomock.Any()) + mockSerializer.EXPECT().GetName() ag := NewAgent(mockConn, mockDecoder, mockEncoder, mockSerializer, hbTime, 10, dieChan, messageEncoder, mockMetricsReporters) assert.NotNil(t, ag) @@ -381,9 +386,10 @@ func TestAgentResponseMIDFullChannel(t *testing.T) { hbTime := time.Second messageEncoder := message.NewMessagesEncoder(false) mockMetricsReporter := metricsmocks.NewMockReporter(ctrl) - mockConn := mocks.NewMockConn(ctrl) + mockConn := mocks.NewMockPlayerConn(ctrl) mockMetricsReporters := []metrics.Reporter{mockMetricsReporter} mockMetricsReporter.EXPECT().ReportGauge(metrics.ConnectedClients, gomock.Any(), gomock.Any()) + mockSerializer.EXPECT().GetName() ag := NewAgent(mockConn, mockDecoder, mockEncoder, mockSerializer, hbTime, 0, dieChan, messageEncoder, mockMetricsReporters) assert.NotNil(t, ag) mockMetricsReporters[0].(*metricsmocks.MockReporter).EXPECT().ReportGauge(metrics.ChannelCapacity, gomock.Any(), float64(0)) @@ -413,7 +419,7 @@ func TestAgentClose(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockConn := mocks.NewMockConn(ctrl) + mockConn := mocks.NewMockPlayerConn(ctrl) mockEncoder := codecmocks.NewMockPacketEncoder(ctrl) heartbeatAndHandshakeMocks(mockEncoder) mockMessageEncoder := messagemocks.NewMockEncoder(ctrl) @@ -457,7 +463,7 @@ func TestAgentRemoteAddr(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockConn := mocks.NewMockConn(ctrl) + mockConn := mocks.NewMockPlayerConn(ctrl) mockEncoder := codecmocks.NewMockPacketEncoder(ctrl) heartbeatAndHandshakeMocks(mockEncoder) mockMessageEncoder := messagemocks.NewMockEncoder(ctrl) @@ -474,7 +480,7 @@ func TestAgentString(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockConn := mocks.NewMockConn(ctrl) + mockConn := mocks.NewMockPlayerConn(ctrl) mockEncoder := codecmocks.NewMockPacketEncoder(ctrl) heartbeatAndHandshakeMocks(mockEncoder) mockMessageEncoder := messagemocks.NewMockEncoder(ctrl) @@ -501,7 +507,7 @@ func TestAgentGetStatus(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockConn := mocks.NewMockConn(ctrl) + mockConn := mocks.NewMockPlayerConn(ctrl) mockEncoder := codecmocks.NewMockPacketEncoder(ctrl) heartbeatAndHandshakeMocks(mockEncoder) mockMessageEncoder := messagemocks.NewMockEncoder(ctrl) @@ -598,7 +604,7 @@ func TestAgentSendHandshakeResponse(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockConn := mocks.NewMockConn(ctrl) + mockConn := mocks.NewMockPlayerConn(ctrl) mockEncoder := codecmocks.NewMockPacketEncoder(ctrl) heartbeatAndHandshakeMocks(mockEncoder) mockMessageEncoder := messagemocks.NewMockEncoder(ctrl) @@ -633,6 +639,7 @@ func TestAnswerWithError(t *testing.T) { mockEncoder := codecmocks.NewMockPacketEncoder(ctrl) heartbeatAndHandshakeMocks(mockEncoder) messageEncoder := message.NewMessagesEncoder(false) + mockSerializer.EXPECT().GetName() ag := NewAgent(nil, nil, mockEncoder, mockSerializer, time.Second, 1, nil, messageEncoder, nil) assert.NotNil(t, ag) @@ -652,8 +659,9 @@ func TestAgentHeartbeat(t *testing.T) { mockSerializer := serializemocks.NewMockSerializer(ctrl) mockEncoder := codecmocks.NewMockPacketEncoder(ctrl) heartbeatAndHandshakeMocks(mockEncoder) - mockConn := mocks.NewMockConn(ctrl) + mockConn := mocks.NewMockPlayerConn(ctrl) mockMessageEncoder := messagemocks.NewMockEncoder(ctrl) + mockSerializer.EXPECT().GetName() ag := NewAgent(mockConn, nil, mockEncoder, mockSerializer, 1*time.Second, 1, nil, mockMessageEncoder, nil) assert.NotNil(t, ag) @@ -683,8 +691,9 @@ func TestAgentHeartbeatExitsIfConnError(t *testing.T) { mockSerializer := serializemocks.NewMockSerializer(ctrl) mockEncoder := codecmocks.NewMockPacketEncoder(ctrl) heartbeatAndHandshakeMocks(mockEncoder) - mockConn := mocks.NewMockConn(ctrl) + mockConn := mocks.NewMockPlayerConn(ctrl) mockMessageEncoder := messagemocks.NewMockEncoder(ctrl) + mockSerializer.EXPECT().GetName() ag := NewAgent(mockConn, nil, mockEncoder, mockSerializer, 1*time.Second, 1, nil, mockMessageEncoder, nil) assert.NotNil(t, ag) @@ -713,12 +722,13 @@ func TestAgentHeartbeatExitsOnStopHeartbeat(t *testing.T) { mockSerializer := serializemocks.NewMockSerializer(ctrl) mockEncoder := codecmocks.NewMockPacketEncoder(ctrl) heartbeatAndHandshakeMocks(mockEncoder) - mockConn := mocks.NewMockConn(ctrl) + mockConn := mocks.NewMockPlayerConn(ctrl) messageEncoder := message.NewMessagesEncoder(false) mockConn.EXPECT().RemoteAddr().MaxTimes(1) mockConn.EXPECT().Close().MaxTimes(1) + mockSerializer.EXPECT().GetName() ag := NewAgent(mockConn, nil, mockEncoder, mockSerializer, 1*time.Second, 1, nil, messageEncoder, nil) assert.NotNil(t, ag) @@ -736,7 +746,7 @@ func TestAgentWriteChSend(t *testing.T) { mockSerializer := serializemocks.NewMockSerializer(ctrl) mockEncoder := codecmocks.NewMockPacketEncoder(ctrl) - mockConn := mocks.NewMockConn(ctrl) + mockConn := mocks.NewMockPlayerConn(ctrl) messageEncoder := message.NewMessagesEncoder(false) mockMetricsReporter := metricsmocks.NewMockReporter(ctrl) mockMetricsReporters := []metrics.Reporter{mockMetricsReporter} @@ -788,7 +798,7 @@ func TestAgentWriteChSendSerializeErr(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockConn := mocks.NewMockConn(ctrl) + mockConn := mocks.NewMockPlayerConn(ctrl) mockSerializer := serializemocks.NewMockSerializer(ctrl) mockEncoder := codecmocks.NewMockPacketEncoder(ctrl) messageEncoder := message.NewMessagesEncoder(false) @@ -854,8 +864,9 @@ func TestAgentHandle(t *testing.T) { mockSerializer := serializemocks.NewMockSerializer(ctrl) mockEncoder := codecmocks.NewMockPacketEncoder(ctrl) heartbeatAndHandshakeMocks(mockEncoder) - mockConn := mocks.NewMockConn(ctrl) + mockConn := mocks.NewMockPlayerConn(ctrl) messageEncoder := message.NewMessagesEncoder(false) + mockSerializer.EXPECT().GetName() ag := NewAgent(mockConn, nil, mockEncoder, mockSerializer, 1*time.Second, 1, nil, messageEncoder, nil) assert.NotNil(t, ag) @@ -909,8 +920,9 @@ func TestNatsRPCServerReportMetrics(t *testing.T) { messageEncoder := message.NewMessagesEncoder(false) mockMetricsReporter := metricsmocks.NewMockReporter(ctrl) mockMetricsReporters := []metrics.Reporter{mockMetricsReporter} - mockConn := mocks.NewMockConn(ctrl) + mockConn := mocks.NewMockPlayerConn(ctrl) mockMetricsReporter.EXPECT().ReportGauge(metrics.ConnectedClients, gomock.Any(), gomock.Any()) + mockSerializer.EXPECT().GetName() ag := NewAgent(mockConn, mockDecoder, mockEncoder, mockSerializer, hbTime, 10, dieChan, messageEncoder, mockMetricsReporters) assert.NotNil(t, ag) @@ -942,7 +954,7 @@ func TestIPVersion(t *testing.T) { t.Run("test_"+table.addr, func(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockConn := mocks.NewMockConn(ctrl) + mockConn := mocks.NewMockPlayerConn(ctrl) mockAddr := &customMockAddr{str: table.addr} mockConn.EXPECT().RemoteAddr().Return(mockAddr) diff --git a/client/client.go b/client/client.go index 7c89bce4..25272f6d 100644 --- a/client/client.go +++ b/client/client.go @@ -27,10 +27,14 @@ import ( "errors" "fmt" "net" + "net/url" "sync" "sync/atomic" "time" + "github.com/topfreegames/pitaya/acceptor" + + "github.com/gorilla/websocket" "github.com/sirupsen/logrus" "github.com/topfreegames/pitaya" "github.com/topfreegames/pitaya/conn/codec" @@ -125,6 +129,7 @@ func New(logLevel logrus.Level, requestTimeout ...time.Duration) *Client { } } +// SetClientHandshakeData sets the data to send inside handshake func (c *Client) SetClientHandshakeData(data *session.HandshakeData) { c.clientHandshakeData = data } @@ -139,6 +144,7 @@ func (c *Client) sendHandshakeRequest() error { if err != nil { return err } + _, err = c.conn.Write(p) return err } @@ -334,12 +340,16 @@ func (c *Client) Disconnect() { } } -// ConnectToTLS connects to the server at addr using TLS, for now the only supported protocol is tcp -// this methods blocks as it also handles the messages from the server -func (c *Client) ConnectToTLS(addr string, skipVerify bool) error { - conn, err := tls.Dial("tcp", addr, &tls.Config{ - InsecureSkipVerify: skipVerify, - }) +// ConnectTo connects to the server at addr, for now the only supported protocol is tcp +// if tlsConfig is sent, it connects using TLS +func (c *Client) ConnectTo(addr string, tlsConfig ...*tls.Config) error { + var conn net.Conn + var err error + if len(tlsConfig) > 0 { + conn, err = tls.Dial("tcp", addr, tlsConfig[0]) + } else { + conn, err = net.Dial("tcp", addr) + } if err != nil { return err } @@ -351,17 +361,30 @@ func (c *Client) ConnectToTLS(addr string, skipVerify bool) error { } c.closeChan = make(chan struct{}) + return nil } -// ConnectTo connects to the server at addr, for now the only supported protocol is tcp -// this methods blocks as it also handles the messages from the server -func (c *Client) ConnectTo(addr string) error { - conn, err := net.Dial("tcp", addr) +// ConnectToWS connects using webshocket protocol +func (c *Client) ConnectToWS(addr string, path string, tlsConfig ...*tls.Config) error { + u := url.URL{Scheme: "ws", Host: addr, Path: path} + dialer := websocket.DefaultDialer + + if len(tlsConfig) > 0 { + dialer.TLSClientConfig = tlsConfig[0] + u.Scheme = "wss" + } + + conn, _, err := dialer.Dial(u.String(), nil) if err != nil { return err } - c.conn = conn + + c.conn, err = acceptor.NewWSConn(conn) + if err != nil { + return err + } + c.IncomingMsgChan = make(chan *message.Message, 10) if err = c.handleHandshake(); err != nil { diff --git a/client/client_test.go b/client/client_test.go index ec5779af..2f66dfe6 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -17,7 +17,7 @@ func TestSendRequestShouldTimeout(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockConn := mocks.NewMockConn(ctrl) + mockConn := mocks.NewMockPlayerConn(ctrl) c.conn = mockConn go c.pendingRequestsReaper() diff --git a/client/pitayaclient.go b/client/pitayaclient.go index b97d9440..6d942abd 100644 --- a/client/pitayaclient.go +++ b/client/pitayaclient.go @@ -21,13 +21,16 @@ package client import ( + "crypto/tls" + "github.com/topfreegames/pitaya/conn/message" "github.com/topfreegames/pitaya/session" ) +// PitayaClient iface type PitayaClient interface { - ConnectTo(addr string) error - ConnectToTLS(addr string, skipVerify bool) error + ConnectTo(addr string, tlsConfig ...*tls.Config) error + ConnectToWS(addr string, path string, tlsConfig ...*tls.Config) error ConnectedStatus() bool Disconnect() MsgChannel() chan *message.Message diff --git a/client/protoclient.go b/client/protoclient.go index 4a04c454..83fe732e 100644 --- a/client/protoclient.go +++ b/client/protoclient.go @@ -23,6 +23,7 @@ package client import ( "bytes" "compress/gzip" + "crypto/tls" "encoding/json" "errors" "io/ioutil" @@ -317,13 +318,17 @@ func NewWithDescriptor(descriptorsRoute string, docsRoute string, docslogLevel l func (pc *ProtoClient) LoadServerInfo(addr string) error { pc.ready = false - if err := pc.Client.ConnectToTLS(addr, true); err != nil { - if err.Error() == "EOF" { - if err := pc.Client.ConnectTo(addr); err != nil { - return err + if err := pc.Client.ConnectToWS(addr, "", &tls.Config{ + InsecureSkipVerify: true, + }); err != nil { + if err := pc.Client.ConnectToWS(addr, ""); err != nil { + if err := pc.Client.ConnectTo(addr, &tls.Config{ + InsecureSkipVerify: true, + }); err != nil { + if err := pc.Client.ConnectTo(addr); err != nil { + return err + } } - } else { - return err } } @@ -415,31 +420,10 @@ func (pc *ProtoClient) waitForData() { } } -// ConnectToTLS connects to the server at addr using TLS, for now the only supported protocol is tcp -// this methods blocks as it also handles the messages from the server -func (pc *ProtoClient) ConnectToTLS(addr string, skipVerify bool) error { - err := pc.Client.ConnectToTLS(addr, skipVerify) - if err != nil { - return err - } - - if !pc.ready { - err = pc.LoadServerInfo(addr) - if err != nil { - return err - } - } - - if pc.ready { - go pc.waitForData() - } - return nil -} - // ConnectTo connects to the server at addr, for now the only supported protocol is tcp // this methods blocks as it also handles the messages from the server -func (pc *ProtoClient) ConnectTo(addr string) error { - err := pc.Client.ConnectTo(addr) +func (pc *ProtoClient) ConnectTo(addr string, tlsConfig ...*tls.Config) error { + err := pc.Client.ConnectTo(addr, tlsConfig...) if err != nil { return err } diff --git a/conn/codec/pomelo_packet_decoder.go b/conn/codec/pomelo_packet_decoder.go index 5f4f4141..b7e3e79d 100644 --- a/conn/codec/pomelo_packet_decoder.go +++ b/conn/codec/pomelo_packet_decoder.go @@ -40,7 +40,7 @@ func (c *PomeloPacketDecoder) forward(buf *bytes.Buffer) (int, packet.Type, erro if typ < packet.Handshake || typ > packet.Kick { return 0, 0x00, packet.ErrWrongPomeloPacketType } - size := bytesToInt(header[1:]) + size := BytesToInt(header[1:]) // packet length limitation if size > MaxPacketSize { @@ -73,7 +73,7 @@ func (c *PomeloPacketDecoder) Decode(data []byte) ([]*packet.Packet, error) { p := &packet.Packet{Type: typ, Length: size, Data: buf.Next(size)} packets = append(packets, p) - // more packet + // if no more packets, break if buf.Len() < HeadLength { break } @@ -86,12 +86,3 @@ func (c *PomeloPacketDecoder) Decode(data []byte) ([]*packet.Packet, error) { return packets, nil } - -// Decode packet data length byte to int(Big end) -func bytesToInt(b []byte) int { - result := 0 - for _, v := range b { - result = result<<8 + int(v) - } - return result -} diff --git a/conn/codec/pomelo_packet_encoder.go b/conn/codec/pomelo_packet_encoder.go index 2b1b5af2..90f1ed07 100644 --- a/conn/codec/pomelo_packet_encoder.go +++ b/conn/codec/pomelo_packet_encoder.go @@ -52,17 +52,8 @@ func (e *PomeloPacketEncoder) Encode(typ packet.Type, data []byte) ([]byte, erro buf := make([]byte, p.Length+HeadLength) buf[0] = byte(p.Type) - copy(buf[1:HeadLength], intToBytes(p.Length)) + copy(buf[1:HeadLength], IntToBytes(p.Length)) copy(buf[HeadLength:], data) return buf, nil } - -// Encode packet data length to bytes(Big end) -func intToBytes(n int) []byte { - buf := make([]byte, 3) - buf[0] = byte((n >> 16) & 0xFF) - buf[1] = byte((n >> 8) & 0xFF) - buf[2] = byte(n & 0xFF) - return buf -} diff --git a/conn/codec/utils.go b/conn/codec/utils.go new file mode 100644 index 00000000..75f0a1c7 --- /dev/null +++ b/conn/codec/utils.go @@ -0,0 +1,19 @@ +package codec + +// BytesToInt decode packet data length byte to int(Big end) +func BytesToInt(b []byte) int { + result := 0 + for _, v := range b { + result = result<<8 + int(v) + } + return result +} + +// IntToBytes encode packet data length to bytes(Big end) +func IntToBytes(n int) []byte { + buf := make([]byte, 3) + buf[0] = byte((n >> 16) & 0xFF) + buf[1] = byte((n >> 8) & 0xFF) + buf[2] = byte(n & 0xFF) + return buf +} diff --git a/examples/demo/chat/main.go b/examples/demo/chat/main.go index 6c4caf5c..8b400ad2 100644 --- a/examples/demo/chat/main.go +++ b/examples/demo/chat/main.go @@ -129,7 +129,9 @@ func main() { go http.ListenAndServe(":3251", nil) t := acceptor.NewTCPAcceptor(":3250") + w := acceptor.NewWSAcceptor(":3252") pitaya.AddAcceptor(t) + pitaya.AddAcceptor(w) pitaya.Configure(true, "chat", pitaya.Cluster, map[string]string{}, conf) pitaya.Start() diff --git a/go.mod b/go.mod index 9dbb7b09..e0d165d9 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/coreos/go-systemd v0.0.0-20180202092358-40e2722dffea // indirect github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf // indirect github.com/customerio/gospec v0.0.0-20130710230057-a5cc0e48aa39 // indirect - github.com/davecgh/go-spew v1.1.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgrijalva/jwt-go v3.2.0+incompatible // indirect github.com/garyburd/redigo v1.6.0 // indirect github.com/ghodss/yaml v1.0.0 // indirect @@ -23,7 +23,7 @@ require ( github.com/go-playground/universal-translator v0.16.0 // indirect github.com/gogo/protobuf v1.3.0 // indirect github.com/golang/groupcache v0.0.0-20180203143532-66deaeb636df // indirect - github.com/golang/mock v1.1.1 + github.com/golang/mock v1.3.1 github.com/golang/protobuf v1.3.1 github.com/google/btree v0.0.0-20180124185431-e89373fe6b4a // indirect github.com/google/uuid v1.0.0 @@ -48,7 +48,6 @@ require ( github.com/orfjackal/nanospec.go v0.0.0-20120727230329-de4694c1d701 // indirect github.com/pelletier/go-toml v1.2.0 // indirect github.com/pkg/errors v0.8.0 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_golang v0.8.0 github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910 // indirect github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e // indirect @@ -60,7 +59,7 @@ require ( github.com/spf13/jwalterweatherman v0.0.0-20180109140146-7c0cea34c8ec // indirect github.com/spf13/pflag v1.0.1 // indirect github.com/spf13/viper v1.0.2 - github.com/stretchr/testify v1.2.1 + github.com/stretchr/testify v1.3.0 github.com/tmc/grpc-websocket-proxy v0.0.0-20171017195756-830351dc03c6 // indirect github.com/topfreegames/go-workers v1.0.0 github.com/uber-go/atomic v1.3.2 // indirect diff --git a/go.sum b/go.sum index 0bc79ecb..7f697107 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,8 @@ github.com/customerio/gospec v0.0.0-20130710230057-a5cc0e48aa39 h1:O0YTztXI3XeJX github.com/customerio/gospec v0.0.0-20130710230057-a5cc0e48aa39/go.mod h1:OzYUFhPuL2JbjwFwrv6CZs23uBawekc6OZs+g19F0mY= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= @@ -48,6 +50,8 @@ github.com/golang/groupcache v0.0.0-20180203143532-66deaeb636df h1:Sf/EWTqecLGj5 github.com/golang/groupcache v0.0.0-20180203143532-66deaeb636df/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1 h1:G5FRp8JnTd7RQH5kemVNlMeyXQAztQ3mOWV95KxsXH8= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.3.1 h1:qGJ6qTW+x6xX/my+8YUVl4WNpX9B7+/l2tRsHGZ7f2s= +github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -138,8 +142,10 @@ github.com/spf13/pflag v1.0.1 h1:aCvUg6QPl3ibpQUxyLkrEkCHtPqYJL4x9AuhqVqFis4= github.com/spf13/pflag v1.0.1/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/spf13/viper v1.0.2 h1:Ncr3ZIuJn322w2k1qmzXDnkLAdQMlJqBa9kfAH+irso= github.com/spf13/viper v1.0.2/go.mod h1:A8kyI5cUJhb8N+3pkfONlcEcZbueH6nhAm0Fq7SrnBM= -github.com/stretchr/testify v1.2.1 h1:52QO5WkIUcHGIR7EnGagH88x1bUzqGXTC5/1bDTUQ7U= -github.com/stretchr/testify v1.2.1/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/tmc/grpc-websocket-proxy v0.0.0-20171017195756-830351dc03c6 h1:lYIiVDtZnyTWlNwiAxLj0bbpTcx1BWCFhXjfsvmPdNc= github.com/tmc/grpc-websocket-proxy v0.0.0-20171017195756-830351dc03c6/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/topfreegames/go-workers v1.0.0 h1:R53uIT6nwlT45WBm79ZDnxG8W2ec9lJk3uJhZUmd3GI= @@ -169,6 +175,7 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3 h1:0GoQqolDA55aaLxZyTzK/Y2eP golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -181,6 +188,8 @@ golang.org/x/time v0.0.0-20180314180208-26559e0f760e h1:aUMCDtB7fbxaw60p2ngy69FC golang.org/x/time v0.0.0-20180314180208-26559e0f760e/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262 h1:qsl9y/CJx34tuA7QCPNp86JNJe4spst6Ff8MjvPUdPg= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/genproto v0.0.0-20170818010345-ee236bd376b0/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8 h1:Nw54tB0rB7hY/N0NQvRW8DG4Yk3Q6T9cu9RcFQDu1tc= diff --git a/mocks/acceptor.go b/mocks/acceptor.go index 6cc62ec6..07b618b7 100644 --- a/mocks/acceptor.go +++ b/mocks/acceptor.go @@ -1,15 +1,170 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/topfreegames/pitaya/acceptor (interfaces: Acceptor) +// Source: acceptor/acceptor.go -// Package mocks is a generated GoMock package. +// Package mock_acceptor is a generated GoMock package. package mocks import ( - gomock "github.com/golang/mock/gomock" net "net" reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + acceptor "github.com/topfreegames/pitaya/acceptor" ) +// MockPlayerConn is a mock of PlayerConn interface +type MockPlayerConn struct { + ctrl *gomock.Controller + recorder *MockPlayerConnMockRecorder +} + +// MockPlayerConnMockRecorder is the mock recorder for MockPlayerConn +type MockPlayerConnMockRecorder struct { + mock *MockPlayerConn +} + +// NewMockPlayerConn creates a new mock instance +func NewMockPlayerConn(ctrl *gomock.Controller) *MockPlayerConn { + mock := &MockPlayerConn{ctrl: ctrl} + mock.recorder = &MockPlayerConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockPlayerConn) EXPECT() *MockPlayerConnMockRecorder { + return m.recorder +} + +// GetNextMessage mocks base method +func (m *MockPlayerConn) GetNextMessage() ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetNextMessage") + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetNextMessage indicates an expected call of GetNextMessage +func (mr *MockPlayerConnMockRecorder) GetNextMessage() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNextMessage", reflect.TypeOf((*MockPlayerConn)(nil).GetNextMessage)) +} + +// Read mocks base method +func (m *MockPlayerConn) Read(b []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Read", b) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Read indicates an expected call of Read +func (mr *MockPlayerConnMockRecorder) Read(b interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockPlayerConn)(nil).Read), b) +} + +// Write mocks base method +func (m *MockPlayerConn) Write(b []byte) (int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Write", b) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Write indicates an expected call of Write +func (mr *MockPlayerConnMockRecorder) Write(b interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockPlayerConn)(nil).Write), b) +} + +// Close mocks base method +func (m *MockPlayerConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *MockPlayerConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPlayerConn)(nil).Close)) +} + +// LocalAddr mocks base method +func (m *MockPlayerConn) LocalAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LocalAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// LocalAddr indicates an expected call of LocalAddr +func (mr *MockPlayerConnMockRecorder) LocalAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockPlayerConn)(nil).LocalAddr)) +} + +// RemoteAddr mocks base method +func (m *MockPlayerConn) RemoteAddr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RemoteAddr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// RemoteAddr indicates an expected call of RemoteAddr +func (mr *MockPlayerConnMockRecorder) RemoteAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockPlayerConn)(nil).RemoteAddr)) +} + +// SetDeadline mocks base method +func (m *MockPlayerConn) SetDeadline(t time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetDeadline", t) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetDeadline indicates an expected call of SetDeadline +func (mr *MockPlayerConnMockRecorder) SetDeadline(t interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockPlayerConn)(nil).SetDeadline), t) +} + +// SetReadDeadline mocks base method +func (m *MockPlayerConn) SetReadDeadline(t time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetReadDeadline", t) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetReadDeadline indicates an expected call of SetReadDeadline +func (mr *MockPlayerConnMockRecorder) SetReadDeadline(t interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockPlayerConn)(nil).SetReadDeadline), t) +} + +// SetWriteDeadline mocks base method +func (m *MockPlayerConn) SetWriteDeadline(t time.Time) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetWriteDeadline", t) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetWriteDeadline indicates an expected call of SetWriteDeadline +func (mr *MockPlayerConnMockRecorder) SetWriteDeadline(t interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockPlayerConn)(nil).SetWriteDeadline), t) +} + // MockAcceptor is a mock of Acceptor interface type MockAcceptor struct { ctrl *gomock.Controller @@ -33,8 +188,33 @@ func (m *MockAcceptor) EXPECT() *MockAcceptorMockRecorder { return m.recorder } +// ListenAndServe mocks base method +func (m *MockAcceptor) ListenAndServe() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ListenAndServe") +} + +// ListenAndServe indicates an expected call of ListenAndServe +func (mr *MockAcceptorMockRecorder) ListenAndServe() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListenAndServe", reflect.TypeOf((*MockAcceptor)(nil).ListenAndServe)) +} + +// Stop mocks base method +func (m *MockAcceptor) Stop() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Stop") +} + +// Stop indicates an expected call of Stop +func (mr *MockAcceptorMockRecorder) Stop() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockAcceptor)(nil).Stop)) +} + // GetAddr mocks base method func (m *MockAcceptor) GetAddr() string { + m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetAddr") ret0, _ := ret[0].(string) return ret0 @@ -42,37 +222,20 @@ func (m *MockAcceptor) GetAddr() string { // GetAddr indicates an expected call of GetAddr func (mr *MockAcceptorMockRecorder) GetAddr() *gomock.Call { + mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAddr", reflect.TypeOf((*MockAcceptor)(nil).GetAddr)) } // GetConnChan mocks base method -func (m *MockAcceptor) GetConnChan() chan net.Conn { +func (m *MockAcceptor) GetConnChan() chan acceptor.PlayerConn { + m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetConnChan") - ret0, _ := ret[0].(chan net.Conn) + ret0, _ := ret[0].(chan acceptor.PlayerConn) return ret0 } // GetConnChan indicates an expected call of GetConnChan func (mr *MockAcceptorMockRecorder) GetConnChan() *gomock.Call { + mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConnChan", reflect.TypeOf((*MockAcceptor)(nil).GetConnChan)) } - -// ListenAndServe mocks base method -func (m *MockAcceptor) ListenAndServe() { - m.ctrl.Call(m, "ListenAndServe") -} - -// ListenAndServe indicates an expected call of ListenAndServe -func (mr *MockAcceptorMockRecorder) ListenAndServe() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListenAndServe", reflect.TypeOf((*MockAcceptor)(nil).ListenAndServe)) -} - -// Stop mocks base method -func (m *MockAcceptor) Stop() { - m.ctrl.Call(m, "Stop") -} - -// Stop indicates an expected call of Stop -func (mr *MockAcceptorMockRecorder) Stop() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockAcceptor)(nil).Stop)) -} diff --git a/mocks/net.go b/mocks/net.go deleted file mode 100644 index cd6cfb56..00000000 --- a/mocks/net.go +++ /dev/null @@ -1,181 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: nn.go - -// Package mocks is a generated GoMock package. -package mocks - -import ( - "net" - reflect "reflect" - time "time" - - gomock "github.com/golang/mock/gomock" -) - -// MockAddr is a mock of Addr interface -type MockAddr struct { - ctrl *gomock.Controller - recorder *MockAddrMockRecorder -} - -// MockAddrMockRecorder is the mock recorder for MockAddr -type MockAddrMockRecorder struct { - mock *MockAddr -} - -// NewMockAddr creates a new mock instance -func NewMockAddr(ctrl *gomock.Controller) *MockAddr { - mock := &MockAddr{ctrl: ctrl} - mock.recorder = &MockAddrMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use -func (m *MockAddr) EXPECT() *MockAddrMockRecorder { - return m.recorder -} - -// Network mocks base method -func (m *MockAddr) Network() string { - ret := m.ctrl.Call(m, "Network") - ret0, _ := ret[0].(string) - return ret0 -} - -// Network indicates an expected call of Network -func (mr *MockAddrMockRecorder) Network() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Network", reflect.TypeOf((*MockAddr)(nil).Network)) -} - -// String mocks base method -func (m *MockAddr) String() string { - ret := m.ctrl.Call(m, "String") - ret0, _ := ret[0].(string) - return ret0 -} - -// String indicates an expected call of String -func (mr *MockAddrMockRecorder) String() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "String", reflect.TypeOf((*MockAddr)(nil).String)) -} - -// MockConn is a mock of Conn interface -type MockConn struct { - ctrl *gomock.Controller - recorder *MockConnMockRecorder -} - -// MockConnMockRecorder is the mock recorder for MockConn -type MockConnMockRecorder struct { - mock *MockConn -} - -// NewMockConn creates a new mock instance -func NewMockConn(ctrl *gomock.Controller) *MockConn { - mock := &MockConn{ctrl: ctrl} - mock.recorder = &MockConnMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use -func (m *MockConn) EXPECT() *MockConnMockRecorder { - return m.recorder -} - -// Read mocks base method -func (m *MockConn) Read(b []byte) (int, error) { - ret := m.ctrl.Call(m, "Read", b) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Read indicates an expected call of Read -func (mr *MockConnMockRecorder) Read(b interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockConn)(nil).Read), b) -} - -// Write mocks base method -func (m *MockConn) Write(b []byte) (int, error) { - ret := m.ctrl.Call(m, "Write", b) - ret0, _ := ret[0].(int) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// Write indicates an expected call of Write -func (mr *MockConnMockRecorder) Write(b interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockConn)(nil).Write), b) -} - -// Close mocks base method -func (m *MockConn) Close() error { - ret := m.ctrl.Call(m, "Close") - ret0, _ := ret[0].(error) - return ret0 -} - -// Close indicates an expected call of Close -func (mr *MockConnMockRecorder) Close() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockConn)(nil).Close)) -} - -// LocalAddr mocks base method -func (m *MockConn) LocalAddr() net.Addr { - ret := m.ctrl.Call(m, "LocalAddr") - ret0, _ := ret[0].(net.Addr) - return ret0 -} - -// LocalAddr indicates an expected call of LocalAddr -func (mr *MockConnMockRecorder) LocalAddr() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LocalAddr", reflect.TypeOf((*MockConn)(nil).LocalAddr)) -} - -// RemoteAddr mocks base method -func (m *MockConn) RemoteAddr() net.Addr { - ret := m.ctrl.Call(m, "RemoteAddr") - ret0, _ := ret[0].(net.Addr) - return ret0 -} - -// RemoteAddr indicates an expected call of RemoteAddr -func (mr *MockConnMockRecorder) RemoteAddr() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockConn)(nil).RemoteAddr)) -} - -// SetDeadline mocks base method -func (m *MockConn) SetDeadline(t time.Time) error { - ret := m.ctrl.Call(m, "SetDeadline", t) - ret0, _ := ret[0].(error) - return ret0 -} - -// SetDeadline indicates an expected call of SetDeadline -func (mr *MockConnMockRecorder) SetDeadline(t interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockConn)(nil).SetDeadline), t) -} - -// SetReadDeadline mocks base method -func (m *MockConn) SetReadDeadline(t time.Time) error { - ret := m.ctrl.Call(m, "SetReadDeadline", t) - ret0, _ := ret[0].(error) - return ret0 -} - -// SetReadDeadline indicates an expected call of SetReadDeadline -func (mr *MockConnMockRecorder) SetReadDeadline(t interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockConn)(nil).SetReadDeadline), t) -} - -// SetWriteDeadline mocks base method -func (m *MockConn) SetWriteDeadline(t time.Time) error { - ret := m.ctrl.Call(m, "SetWriteDeadline", t) - ret0, _ := ret[0].(error) - return ret0 -} - -// SetWriteDeadline indicates an expected call of SetWriteDeadline -func (mr *MockConnMockRecorder) SetWriteDeadline(t interface{}) *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockConn)(nil).SetWriteDeadline), t) -} diff --git a/service/handler.go b/service/handler.go index 88651447..89f3ca71 100644 --- a/service/handler.go +++ b/service/handler.go @@ -21,14 +21,14 @@ package service import ( - "bytes" "context" "encoding/json" "fmt" - "net" "strings" "time" + "github.com/topfreegames/pitaya/acceptor" + "github.com/google/uuid" opentracing "github.com/opentracing/opentracing-go" "github.com/topfreegames/pitaya/agent" @@ -164,7 +164,7 @@ func (h *HandlerService) Register(comp component.Component, opts []component.Opt } // Handle handles messages from a conn -func (h *HandlerService) Handle(conn net.Conn) { +func (h *HandlerService) Handle(conn acceptor.PlayerConn) { // create a client agent and startup write goroutine a := agent.NewAgent(conn, h.decoder, h.encoder, h.serializer, h.heartbeatTimeout, h.messagesBufferSize, h.appDieChan, h.messageEncoder, h.metricsReporters) @@ -179,37 +179,22 @@ func (h *HandlerService) Handle(conn net.Conn) { logger.Log.Debugf("Session read goroutine exit, SessionID=%d, UID=%d", a.Session.ID(), a.Session.UID()) }() - // read loop - data := make([]byte, constants.IOBufferBytesSize) - buf := bytes.NewBuffer(nil) for { - totalLen := 0 - n := len(data) - var err error - for n == len(data) { - n, err = conn.Read(data) - if err != nil { - logger.Log.Debugf("Read message error: '%s', session will be closed immediately", err.Error()) - return - } - buf.Write(data[:n]) - totalLen += n - if totalLen > codec.MaxPacketSize { - logger.Log.Warn("received chunk > MaxPacketSize, disconnecting client...") - return - } - } + msg, err := conn.GetNextMessage() - logger.Log.Debugf("Received data on connection with len %d", totalLen) + if err != nil { + logger.Log.Errorf("Error reading next available message: %s", err.Error()) + return + } - packets, err := h.decoder.Decode(buf.Bytes()) + packets, err := h.decoder.Decode(msg) if err != nil { logger.Log.Errorf("Failed to decode message: %s", err.Error()) return } if len(packets) < 1 { - logger.Log.Warnf("Read no packets, data: %v", buf.Bytes()) + logger.Log.Warnf("Read no packets, data: %v", msg) continue } @@ -220,7 +205,7 @@ func (h *HandlerService) Handle(conn net.Conn) { return } } - buf.Reset() + //buf.Reset() } } diff --git a/service/handler_test.go b/service/handler_test.go index 9649b210..5eb2ca8f 100644 --- a/service/handler_test.go +++ b/service/handler_test.go @@ -173,11 +173,8 @@ func TestHandlerServiceProcessMessage(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() mockSerializer := serializemocks.NewMockSerializer(ctrl) - once.Do(func() { - mockSerializer.EXPECT().GetName() - }) - mockConn := connmock.NewMockConn(ctrl) + mockConn := connmock.NewMockPlayerConn(ctrl) sv := &cluster.Server{} svc := NewHandlerService(nil, nil, nil, nil, 1*time.Second, 1, 1, 1, sv, &RemoteService{}, nil, nil) @@ -186,6 +183,7 @@ func TestHandlerServiceProcessMessage(t *testing.T) { } messageEncoder := message.NewMessagesEncoder(false) + mockSerializer.EXPECT().GetName() ag := agent.NewAgent(mockConn, nil, packetEncoder, mockSerializer, 1*time.Second, 1, nil, messageEncoder, nil) svc.processMessage(ag, table.msg) @@ -227,7 +225,7 @@ func TestHandlerServiceLocalProcess(t *testing.T) { defer ctrl.Finish() mockSerializer := serializemocks.NewMockSerializer(ctrl) - mockConn := connmock.NewMockConn(ctrl) + mockConn := connmock.NewMockPlayerConn(ctrl) packetEncoder := codec.NewPomeloPacketEncoder() messageEncoder := message.NewMessagesEncoder(false) svc := NewHandlerService(nil, nil, nil, nil, 1*time.Second, 1, 1, 1, nil, nil, nil, nil) @@ -235,6 +233,8 @@ func TestHandlerServiceLocalProcess(t *testing.T) { if table.err != nil { mockSerializer.EXPECT().Marshal(table.err) } + + mockSerializer.EXPECT().GetName() ag := agent.NewAgent(mockConn, nil, packetEncoder, mockSerializer, 1*time.Second, 1, nil, messageEncoder, nil) svc.localProcess(nil, ag, table.rt, table.msg) }) @@ -257,11 +257,8 @@ func TestHandlerServiceProcessPacketHandshake(t *testing.T) { defer ctrl.Finish() mockSerializer := serializemocks.NewMockSerializer(ctrl) - once.Do(func() { - mockSerializer.EXPECT().GetName() - }) - mockConn := connmock.NewMockConn(ctrl) + mockConn := connmock.NewMockPlayerConn(ctrl) packetEncoder := codec.NewPomeloPacketEncoder() messageEncoder := message.NewMessagesEncoder(false) svc := NewHandlerService(nil, nil, nil, nil, 1*time.Second, 1, 1, 1, nil, nil, nil, nil) @@ -274,6 +271,7 @@ func TestHandlerServiceProcessPacketHandshake(t *testing.T) { mockConn.EXPECT().RemoteAddr().Return(&mockAddr{}) } + mockSerializer.EXPECT().GetName() ag := agent.NewAgent(mockConn, nil, packetEncoder, mockSerializer, 1*time.Second, 1, nil, messageEncoder, nil) err := svc.processPacket(ag, table.packet) @@ -292,7 +290,7 @@ func TestHandlerServiceProcessPacketHandshakeAck(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockConn := connmock.NewMockConn(ctrl) + mockConn := connmock.NewMockPlayerConn(ctrl) packetEncoder := codec.NewPomeloPacketEncoder() messageEncoder := message.NewMessagesEncoder(false) svc := NewHandlerService(nil, nil, nil, nil, 1*time.Second, 1, 1, 1, nil, nil, nil, nil) @@ -308,7 +306,7 @@ func TestHandlerServiceProcessPacketHeartbeat(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mockConn := connmock.NewMockConn(ctrl) + mockConn := connmock.NewMockPlayerConn(ctrl) packetEncoder := codec.NewPomeloPacketEncoder() messageEncoder := message.NewMessagesEncoder(false) svc := NewHandlerService(nil, nil, nil, nil, 1*time.Second, 1, 1, 1, nil, nil, nil, nil) @@ -343,13 +341,14 @@ func TestHandlerServiceProcessPacketData(t *testing.T) { defer ctrl.Finish() mockSerializer := serializemocks.NewMockSerializer(ctrl) - mockConn := connmock.NewMockConn(ctrl) + mockConn := connmock.NewMockPlayerConn(ctrl) packetEncoder := codec.NewPomeloPacketEncoder() messageEncoder := message.NewMessagesEncoder(false) svc := NewHandlerService(nil, nil, nil, nil, 1*time.Second, 1, 1, 1, &cluster.Server{}, nil, nil, nil) if table.socketStatus < constants.StatusWorking { mockConn.EXPECT().RemoteAddr().Return(&mockAddr{}) } + mockSerializer.EXPECT().GetName() ag := agent.NewAgent(mockConn, nil, packetEncoder, mockSerializer, 1*time.Second, 1, nil, messageEncoder, nil) ag.SetStatus(table.socketStatus) @@ -369,11 +368,9 @@ func TestHandlerServiceHandle(t *testing.T) { defer ctrl.Finish() mockSerializer := serializemocks.NewMockSerializer(ctrl) - once.Do(func() { - mockSerializer.EXPECT().GetName() - }) + mockSerializer.EXPECT().GetName() - mockConn := connmock.NewMockConn(ctrl) + mockConn := connmock.NewMockPlayerConn(ctrl) packetEncoder := codec.NewPomeloPacketEncoder() packetDecoder := codec.NewPomeloPacketDecoder() messageEncoder := message.NewMessagesEncoder(false) @@ -384,15 +381,11 @@ func TestHandlerServiceHandle(t *testing.T) { bbb, err := packetEncoder.Encode(packet.Handshake, []byte(handshakeBuffer)) assert.NoError(t, err) - firstCall := mockConn.EXPECT().Read(gomock.Any()).Do(func(b []byte) { - for i, c := range bbb { - b[i] = c - } - + firstCall := mockConn.EXPECT().GetNextMessage().Return(bbb, nil).Do(func() { wg.Done() - }).Return(len(bbb), nil) + }) - mockConn.EXPECT().Read(gomock.Any()).Return(0, errors.New("die")).Do(func(b []byte) { + mockConn.EXPECT().GetNextMessage().Return(nil, errors.New("die")).Do(func() { wg.Done() }).After(firstCall) diff --git a/service/remote_test.go b/service/remote_test.go index 49f2e085..f34d3702 100644 --- a/service/remote_test.go +++ b/service/remote_test.go @@ -437,7 +437,8 @@ func TestRemoteServiceRemoteProcess(t *testing.T) { } encoder := codec.NewPomeloPacketEncoder() - mockConn := connmock.NewMockConn(ctrl) + mockConn := connmock.NewMockPlayerConn(ctrl) + mockSerializer.EXPECT().GetName() ag := agent.NewAgent(mockConn, nil, encoder, mockSerializer, 1*time.Second, 1, nil, messageEncoder, nil) if table.responseMIDErr {