Skip to content

Commit

Permalink
Merge b6d06fc into bd4f950
Browse files Browse the repository at this point in the history
  • Loading branch information
leohahn committed Jun 28, 2018
2 parents bd4f950 + b6d06fc commit 1bfcaf0
Show file tree
Hide file tree
Showing 6 changed files with 202 additions and 41 deletions.
24 changes: 18 additions & 6 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ package agent

import (
"context"
"encoding/json"
gojson "encoding/json"
"errors"
"fmt"
"net"
Expand All @@ -38,6 +38,8 @@ import (
"github.com/topfreegames/pitaya/metrics"
"github.com/topfreegames/pitaya/protos"
"github.com/topfreegames/pitaya/serialize"
"github.com/topfreegames/pitaya/serialize/json"
"github.com/topfreegames/pitaya/serialize/protobuf"
"github.com/topfreegames/pitaya/session"
"github.com/topfreegames/pitaya/tracing"
"github.com/topfreegames/pitaya/util"
Expand Down Expand Up @@ -99,7 +101,16 @@ func NewAgent(
) *Agent {
// initialize heartbeat and handshake data on first player connection
once.Do(func() {
hbdEncode(heartbeatTime, packetEncoder, messageEncoder.IsCompressionEnabled())
var serializerName string
switch serializer.(type) {
case *json.Serializer:
serializerName = "json"
case *protobuf.Serializer:
serializerName = "protobuf"
default:
serializerName = "unknown"
}
hbdEncode(heartbeatTime, packetEncoder, messageEncoder.IsCompressionEnabled(), serializerName)
})

a := &Agent{
Expand Down Expand Up @@ -407,15 +418,16 @@ func (a *Agent) AnswerWithError(ctx context.Context, mid uint, err error) {
}
}

func hbdEncode(heartbeatTimeout time.Duration, packetEncoder codec.PacketEncoder, dataCompression bool) {
func hbdEncode(heartbeatTimeout time.Duration, packetEncoder codec.PacketEncoder, dataCompression bool, serializerName string) {
hData := map[string]interface{}{
"code": 200,
"sys": map[string]interface{}{
"heartbeat": heartbeatTimeout.Seconds(),
"dict": message.GetDictionary(),
"heartbeat": heartbeatTimeout.Seconds(),
"dict": message.GetDictionary(),
"serializer": serializerName,
},
}
data, err := json.Marshal(hData)
data, err := gojson.Marshal(hData)
if err != nil {
panic(err)
}
Expand Down
23 changes: 13 additions & 10 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,24 @@ import (
var (
handshakeBuffer = `
{
"sys": {
"type": "golang-tcp",
"version": "0.0.1",
"rsa": {}
},
"user": {
}
};
"sys": {
"platform": "mac",
"lib_version": "0.3.5-release",
"client_build_number":"20",
"client_version":"2.1"
},
"user": {
"age": 30
}
}
`
)

// HandshakeSys struct
type HandshakeSys struct {
Dict map[string]uint16 `json:"dict"`
Heartbeat int `json:"heartbeat"`
Dict map[string]uint16 `json:"dict"`
Heartbeat int `json:"heartbeat"`
Serializer string `json:"serializer"`
}

// HandshakeData struct
Expand Down
17 changes: 16 additions & 1 deletion service/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ package service

import (
"context"
"encoding/json"
"fmt"
"net"
"strings"
Expand All @@ -41,6 +42,7 @@ import (
"github.com/topfreegames/pitaya/metrics"
"github.com/topfreegames/pitaya/route"
"github.com/topfreegames/pitaya/serialize"
"github.com/topfreegames/pitaya/session"
"github.com/topfreegames/pitaya/timer"
"github.com/topfreegames/pitaya/tracing"
)
Expand Down Expand Up @@ -207,14 +209,27 @@ func (h *HandlerService) Handle(conn net.Conn) {
}

func (h *HandlerService) processPacket(a *agent.Agent, p *packet.Packet) error {
fmt.Println("PROCESSING PAKCET MAMN")
switch p.Type {
case packet.Handshake:
if err := a.SendHandshakeResponse(); err != nil {
return err
}
a.SetStatus(constants.StatusHandshake)
logger.Log.Debugf("Session handshake Id=%d, Remote=%s", a.Session.ID(), a.RemoteAddr())

// Parse the json sent with the handshake by the client
handshakeData := &session.HandshakeData{}
err := json.Unmarshal(p.Data, handshakeData)

if err != nil {
a.SetStatus(constants.StatusClosed)
println("INVALID HANDSHAKE DATA BOI")
return fmt.Errorf("Invalid handshake data. Id=%d", a.Session.ID())
}

a.Session.SetHandshakeData(handshakeData)
a.SetStatus(constants.StatusHandshake)

case packet.HandshakeAck:
a.SetStatus(constants.StatusWorking)
logger.Log.Debugf("Receive handshake ACK Id=%d, Remote=%s", a.Session.ID(), a.RemoteAddr())
Expand Down
62 changes: 38 additions & 24 deletions service/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,23 +234,43 @@ func TestHandlerServiceLocalProcess(t *testing.T) {
}

func TestHandlerServiceProcessPacketHandshake(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
tables := []struct {
name string
packet *packet.Packet
socketStatus int32
errStr string
}{
{"invalid_handshake_data", &packet.Packet{Type: packet.Handshake, Data: []byte("asiodjasd")}, constants.StatusClosed, "Invalid handshake data"},
{"valid_handshake_data", &packet.Packet{Type: packet.Handshake, Data: []byte(`{"sys":{"platform":"mac"}}`)}, constants.StatusHandshake, ""},
}
for _, table := range tables {
t.Run(table.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockSerializer := serializemocks.NewMockSerializer(ctrl)
mockConn := connmock.NewMockConn(ctrl)
packetEncoder := codec.NewPomeloPacketEncoder()
messageEncoder := message.NewMessagesEncoder(false)
svc := NewHandlerService(nil, nil, nil, nil, 1*time.Second, 1, 1, 1, nil, nil, nil, nil)
mockSerializer := serializemocks.NewMockSerializer(ctrl)
mockConn := connmock.NewMockConn(ctrl)
packetEncoder := codec.NewPomeloPacketEncoder()
messageEncoder := message.NewMessagesEncoder(false)
svc := NewHandlerService(nil, nil, nil, nil, 1*time.Second, 1, 1, 1, nil, nil, nil, nil)

mockConn.EXPECT().RemoteAddr().Return(&mockAddr{})
mockConn.EXPECT().Write(gomock.Any()).Do(func(d []byte) {
assert.Contains(t, string(d), "heartbeat")
})
ag := agent.NewAgent(mockConn, nil, packetEncoder, mockSerializer, 1*time.Second, 1, nil, messageEncoder, nil)
err := svc.processPacket(ag, &packet.Packet{Type: packet.Handshake})
assert.NoError(t, err)
assert.Equal(t, constants.StatusHandshake, ag.GetStatus())
mockConn.EXPECT().RemoteAddr().Return(&mockAddr{})
mockConn.EXPECT().Write(gomock.Any()).Do(func(d []byte) {
assert.Contains(t, string(d), "heartbeat")
})

ag := agent.NewAgent(mockConn, nil, packetEncoder, mockSerializer, 1*time.Second, 1, nil, messageEncoder, nil)

err := svc.processPacket(ag, table.packet)
if table.errStr == "" {
assert.Nil(t, err)
} else {
assert.NotNil(t, err)
assert.Contains(t, err.Error(), table.errStr)
}
assert.Equal(t, table.socketStatus, ag.GetStatus())
})
}
}

func TestHandlerServiceProcessPacketHandshakeAck(t *testing.T) {
Expand Down Expand Up @@ -341,21 +361,15 @@ func TestHandlerServiceHandle(t *testing.T) {
svc := NewHandlerService(nil, packetDecoder, packetEncoder, mockSerializer, 1*time.Second, 1, 1, 1, nil, nil, messageEncoder, nil)
var wg sync.WaitGroup
firstCall := mockConn.EXPECT().Read(gomock.Any()).Do(func(b []byte) {
handshakeBuffer := `{
"sys": {
"type": "golang-tcp",
"version": "0.0.1",
"rsa": {}
},
"user": {}
};`
handshakeBuffer := `{"sys":{"platform":"mac","lib_version":"0.3.5-release","client_build_number":"20","client_version":"2.1"},"user":{"age":30}}`
bbb, err := packetEncoder.Encode(packet.Handshake, []byte(handshakeBuffer))
for i, c := range bbb {
b[i] = c
}

assert.NoError(t, err)
wg.Done()
}).Return(101, nil)
}).Return(128, nil)

mockConn.EXPECT().Read(gomock.Any()).Return(0, errors.New("die")).Do(func(b []byte) {
wg.Done()
Expand Down
31 changes: 31 additions & 0 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,22 @@ var (
SessionCount int64
)

// HandshakeClientData represents information about the client sent on the handshake.
type HandshakeClientData struct {
Platform string `json:"platform"`
LibVersion string `json:"lib_version"`
BuildNumber string `json:"client_build_number"`
Version string `json:"client_version"`
}

// HandshakeData represents information about the handshake sent by the client.
// `sys` corresponds to information independent from the app and `user` information
// that depends on the app and is customized by the user.
type HandshakeData struct {
Sys HandshakeClientData `json:"sys"`
User map[string]interface{} `json:"user,omitempty"`
}

// Session represents a client session which could storage temp data during low-level
// keep connected, all data will be released when the low-level connection was broken.
// Session instance related to the client will be passed to Handler method as the first
Expand All @@ -69,6 +85,7 @@ type Session struct {
lastTime int64 // last heartbeat time
entity NetworkEntity // low-level network entity
data map[string]interface{} // session data store
handshakeData *HandshakeData // handshake data received by the client
encodedData []byte // session data encoded as a byte array
OnCloseCallbacks []func() //onClose callbacks
IsFrontend bool // if session is a frontend session
Expand Down Expand Up @@ -99,6 +116,7 @@ func New(entity NetworkEntity, frontend bool, UID ...string) *Session {
id: sessionIDSvc.sessionID(),
entity: entity,
data: make(map[string]interface{}),
handshakeData: nil,
lastTime: time.Now().Unix(),
OnCloseCallbacks: []func(){},
IsFrontend: frontend,
Expand Down Expand Up @@ -616,6 +634,19 @@ func (s *Session) Clear() {
s.updateEncodedData()
}

// SetHandshakeData sets the handshake data received by the client.
func (s *Session) SetHandshakeData(data *HandshakeData) {
s.Lock()
defer s.Unlock()

s.handshakeData = data
}

// GetHandshakeData gets the handshake data received by the client.
func (s *Session) GetHandshakeData() *HandshakeData {
return s.handshakeData
}

func (s *Session) sendRequestToFront(ctx context.Context, route string, includeData bool) error {
sessionData := &protos.Session{
Id: s.frontendSessionID,
Expand Down
86 changes: 86 additions & 0 deletions session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1261,3 +1261,89 @@ func TestSessionClear(t *testing.T) {
expectedEncoded := getEncodedEmptyMap()
assert.Equal(t, expectedEncoded, ss.encodedData)
}

func TestSessionGetHandshakeData(t *testing.T) {
t.Parallel()

data1 := &HandshakeData{
Sys: HandshakeClientData{
Platform: "macos",
LibVersion: "2.3.2",
BuildNumber: "20",
Version: "14.0.2",
},
User: make(map[string]interface{}),
}
data2 := &HandshakeData{
Sys: HandshakeClientData{
Platform: "windows",
LibVersion: "2.3.10",
BuildNumber: "",
Version: "ahaha",
},
User: map[string]interface{}{
"ababa": make(map[string]interface{}),
"pepe": 1,
},
}
tables := []struct {
name string
data *HandshakeData
}{
{"test_1", data1},
{"test_2", data2},
}

for _, table := range tables {
t.Run(table.name, func(t *testing.T) {
ss := New(nil, false)

assert.Nil(t, ss.GetHandshakeData())

ss.handshakeData = table.data

assert.Equal(t, ss.GetHandshakeData(), table.data)
})
}
}

func TestSessionSetHandshakeData(t *testing.T) {
t.Parallel()

data1 := &HandshakeData{
Sys: HandshakeClientData{
Platform: "macos",
LibVersion: "2.3.2",
BuildNumber: "20",
Version: "14.0.2",
},
User: make(map[string]interface{}),
}
data2 := &HandshakeData{
Sys: HandshakeClientData{
Platform: "windows",
LibVersion: "2.3.10",
BuildNumber: "",
Version: "ahaha",
},
User: map[string]interface{}{
"ababa": make(map[string]interface{}),
"pepe": 1,
},
}
tables := []struct {
name string
data *HandshakeData
}{
{"testSessionSetData_1", data1},
{"testSessionSetData_2", data2},
}

for _, table := range tables {
t.Run(table.name, func(t *testing.T) {
ss := New(nil, false)
ss.SetHandshakeData(table.data)
assert.Equal(t, table.data, ss.handshakeData)
})
}
}

0 comments on commit 1bfcaf0

Please sign in to comment.