Skip to content

Commit

Permalink
Refactor player connection
Browse files Browse the repository at this point in the history
Those commits decouple the logic of reading messages from the
service/handler and delegates the responsibility for the acceptor itself
by implementing a GetNextMessage method. TCP Acceptor now parses whole
messages one by one from the stream, this way we ensure not to read
partial messages anymore.
  • Loading branch information
felipejfc committed Dec 25, 2019
1 parent 046993f commit 4f14f01
Show file tree
Hide file tree
Showing 27 changed files with 704 additions and 430 deletions.
8 changes: 7 additions & 1 deletion acceptor/acceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,16 @@ package acceptor

import "net"

// PlayerConn iface
type PlayerConn interface {
GetNextMessage() (b []byte, err error)
net.Conn
}

// Acceptor type interface
type Acceptor interface {
ListenAndServe()
Stop()
GetAddr() string
GetConnChan() chan net.Conn
GetConnChan() chan PlayerConn
}
48 changes: 44 additions & 4 deletions acceptor/tcp_acceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,61 @@
package acceptor

import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"io/ioutil"
"net"

"github.com/topfreegames/pitaya/conn/codec"
"github.com/topfreegames/pitaya/conn/packet"
"github.com/topfreegames/pitaya/constants"
"github.com/topfreegames/pitaya/logger"
)

// TCPAcceptor struct
type TCPAcceptor struct {
addr string
connChan chan net.Conn
connChan chan PlayerConn
listener net.Listener
running bool
certFile string
keyFile string
}

type tcpPlayerConn struct {
net.Conn
}

// GetNextMessage reads the next message available in the stream
func (t *tcpPlayerConn) GetNextMessage() (b []byte, err error) {
msgBuffer := bytes.NewBuffer(nil)
header, err := ioutil.ReadAll(io.LimitReader(t.Conn, codec.HeadLength))
if err != nil {
return nil, err
}
if len(header) == 0 {
return nil, errors.New("EOF")
}
typ := header[0]
if typ < packet.Handshake || typ > packet.Kick {
return nil, fmt.Errorf("Invalid packet type received: %x (maybe the header is corrupted?)", header[0])
}
msgBuffer.Write(header)
remainingSize := codec.BytesToInt(header[1:codec.HeadLength])
msgData, err := ioutil.ReadAll(io.LimitReader(t.Conn, int64(remainingSize)))
if err != nil {
return nil, err
}
if len(msgData) < remainingSize {
return nil, errors.New("Received less data than expected, EOF?")
}
msgBuffer.Write(msgData)
return msgBuffer.Bytes(), nil
}

// NewTCPAcceptor creates a new instance of tcp acceptor
func NewTCPAcceptor(addr string, certs ...string) *TCPAcceptor {
keyFile := ""
Expand All @@ -51,7 +89,7 @@ func NewTCPAcceptor(addr string, certs ...string) *TCPAcceptor {

return &TCPAcceptor{
addr: addr,
connChan: make(chan net.Conn),
connChan: make(chan PlayerConn),
running: false,
certFile: certFile,
keyFile: keyFile,
Expand All @@ -67,7 +105,7 @@ func (a *TCPAcceptor) GetAddr() string {
}

// GetConnChan gets a connection channel
func (a *TCPAcceptor) GetConnChan() chan net.Conn {
func (a *TCPAcceptor) GetConnChan() chan PlayerConn {
return a.connChan
}

Expand Down Expand Up @@ -121,6 +159,8 @@ func (a *TCPAcceptor) serve() {
continue
}

a.connChan <- conn
a.connChan <- &tcpPlayerConn{
Conn: conn,
}
}
}
130 changes: 130 additions & 0 deletions acceptor/tcp_acceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package acceptor

import (
"errors"
"net"
"testing"
"time"
Expand Down Expand Up @@ -145,3 +146,132 @@ func TestStop(t *testing.T) {
})
}
}

func TestGetNextMessage(t *testing.T) {
tables := []struct {
name string
data []byte
err error
}{
{"invalid_header", []byte{0x00, 0x00, 0x00, 0x00}, errors.New("Invalid packet type received: 0 (maybe the header is corrupted?)")},
{"valid_message", []byte{0x02, 0x00, 0x00, 0x01, 0x00}, nil},
}

for _, table := range tables {
t.Run(table.name, func(t *testing.T) {
a := NewTCPAcceptor("0.0.0.0:0")
go a.ListenAndServe()
defer a.Stop()
c := a.GetConnChan()
// should be able to connect within 100 milliseconds
var conn net.Conn
var err error
helpers.ShouldEventuallyReturn(t, func() error {
conn, err = net.Dial("tcp", a.GetAddr())
return err
}, nil, 10*time.Millisecond, 100*time.Millisecond)

defer conn.Close()
playerConn := helpers.ShouldEventuallyReceive(t, c, 100*time.Millisecond).(PlayerConn)
_, err = conn.Write(table.data)
assert.NoError(t, err)

msg, err := playerConn.GetNextMessage()
if table.err != nil {
assert.EqualError(t, err, table.err.Error())
} else {
assert.Equal(t, table.data, msg)
assert.NoError(t, err)
}
})
}
}

func TestGetNextMessageTwoMessagesInBuffer(t *testing.T) {
a := NewTCPAcceptor("0.0.0.0:0")
go a.ListenAndServe()
defer a.Stop()
c := a.GetConnChan()
// should be able to connect within 100 milliseconds
var conn net.Conn
var err error
helpers.ShouldEventuallyReturn(t, func() error {
conn, err = net.Dial("tcp", a.GetAddr())
return err
}, nil, 10*time.Millisecond, 100*time.Millisecond)
defer conn.Close()

playerConn := helpers.ShouldEventuallyReceive(t, c, 100*time.Millisecond).(PlayerConn)
msg1 := []byte{0x01, 0x00, 0x00, 0x01, 0x02}
msg2 := []byte{0x02, 0x00, 0x00, 0x02, 0x01, 0x01}
buffer := append(msg1, msg2...)
_, err = conn.Write(buffer)
assert.NoError(t, err)

msg, err := playerConn.GetNextMessage()
assert.NoError(t, err)
assert.Equal(t, msg1, msg)

msg, err = playerConn.GetNextMessage()
assert.NoError(t, err)
assert.Equal(t, msg2, msg)
}

func TestGetNextMessageEOF(t *testing.T) {
a := NewTCPAcceptor("0.0.0.0:0")
go a.ListenAndServe()
defer a.Stop()
c := a.GetConnChan()
// should be able to connect within 100 milliseconds
var conn net.Conn
var err error
helpers.ShouldEventuallyReturn(t, func() error {
conn, err = net.Dial("tcp", a.GetAddr())
return err
}, nil, 10*time.Millisecond, 100*time.Millisecond)

playerConn := helpers.ShouldEventuallyReceive(t, c, 100*time.Millisecond).(PlayerConn)
buffer := []byte{0x02, 0x00, 0x00, 0x02, 0x01}
_, err = conn.Write(buffer)
assert.NoError(t, err)

go func() {
time.Sleep(100 * time.Millisecond)
conn.Close()
}()

_, err = playerConn.GetNextMessage()
assert.EqualError(t, err, "Received less data than expected, EOF?")

}

func TestGetNextMessageInParts(t *testing.T) {
a := NewTCPAcceptor("0.0.0.0:0")
go a.ListenAndServe()
defer a.Stop()
c := a.GetConnChan()
// should be able to connect within 100 milliseconds
var conn net.Conn
var err error
helpers.ShouldEventuallyReturn(t, func() error {
conn, err = net.Dial("tcp", a.GetAddr())
return err
}, nil, 10*time.Millisecond, 100*time.Millisecond)

defer conn.Close()
playerConn := helpers.ShouldEventuallyReceive(t, c, 100*time.Millisecond).(PlayerConn)
part1 := []byte{0x02, 0x00, 0x00, 0x03, 0x01}
part2 := []byte{0x01, 0x02}
_, err = conn.Write(part1)
assert.NoError(t, err)

go func() {
time.Sleep(200 * time.Millisecond)
_, err = conn.Write(part2)
}()

msg, err := playerConn.GetNextMessage()
assert.NoError(t, err)
assert.Equal(t, msg, append(part1, part2...))

}

0 comments on commit 4f14f01

Please sign in to comment.