diff --git a/agent/agent.go b/agent/agent.go index 7fc1b5f9..7bb71440 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -22,7 +22,7 @@ package agent import ( "context" - "encoding/json" + gojson "encoding/json" "errors" "fmt" "net" @@ -38,6 +38,8 @@ import ( "github.com/topfreegames/pitaya/metrics" "github.com/topfreegames/pitaya/protos" "github.com/topfreegames/pitaya/serialize" + "github.com/topfreegames/pitaya/serialize/json" + "github.com/topfreegames/pitaya/serialize/protobuf" "github.com/topfreegames/pitaya/session" "github.com/topfreegames/pitaya/tracing" "github.com/topfreegames/pitaya/util" @@ -99,7 +101,16 @@ func NewAgent( ) *Agent { // initialize heartbeat and handshake data on first player connection once.Do(func() { - hbdEncode(heartbeatTime, packetEncoder, messageEncoder.IsCompressionEnabled()) + var serializerName string + switch serializer.(type) { + case *json.Serializer: + serializerName = "json" + case *protobuf.Serializer: + serializerName = "protobuf" + default: + serializerName = "unknown" + } + hbdEncode(heartbeatTime, packetEncoder, messageEncoder.IsCompressionEnabled(), serializerName) }) a := &Agent{ @@ -407,15 +418,16 @@ func (a *Agent) AnswerWithError(ctx context.Context, mid uint, err error) { } } -func hbdEncode(heartbeatTimeout time.Duration, packetEncoder codec.PacketEncoder, dataCompression bool) { +func hbdEncode(heartbeatTimeout time.Duration, packetEncoder codec.PacketEncoder, dataCompression bool, serializerName string) { hData := map[string]interface{}{ "code": 200, "sys": map[string]interface{}{ - "heartbeat": heartbeatTimeout.Seconds(), - "dict": message.GetDictionary(), + "heartbeat": heartbeatTimeout.Seconds(), + "dict": message.GetDictionary(), + "serializer": serializerName, }, } - data, err := json.Marshal(hData) + data, err := gojson.Marshal(hData) if err != nil { panic(err) } diff --git a/client/client.go b/client/client.go index 341a491a..bd8bf419 100644 --- a/client/client.go +++ b/client/client.go @@ -41,21 +41,24 @@ import ( var ( handshakeBuffer = ` { - "sys": { - "type": "golang-tcp", - "version": "0.0.1", - "rsa": {} - }, - "user": { - } - }; + "sys": { + "platform": "mac", + "lib_version": "0.3.5-release", + "client_build_number":"20", + "client_version":"2.1" + }, + "user": { + "age": 30 + } +} ` ) // HandshakeSys struct type HandshakeSys struct { - Dict map[string]uint16 `json:"dict"` - Heartbeat int `json:"heartbeat"` + Dict map[string]uint16 `json:"dict"` + Heartbeat int `json:"heartbeat"` + Serializer string `json:"serializer"` } // HandshakeData struct diff --git a/service/handler.go b/service/handler.go index 888f8238..1604e91c 100644 --- a/service/handler.go +++ b/service/handler.go @@ -22,6 +22,7 @@ package service import ( "context" + "encoding/json" "fmt" "net" "strings" @@ -41,6 +42,7 @@ import ( "github.com/topfreegames/pitaya/metrics" "github.com/topfreegames/pitaya/route" "github.com/topfreegames/pitaya/serialize" + "github.com/topfreegames/pitaya/session" "github.com/topfreegames/pitaya/timer" "github.com/topfreegames/pitaya/tracing" ) @@ -207,14 +209,27 @@ func (h *HandlerService) Handle(conn net.Conn) { } func (h *HandlerService) processPacket(a *agent.Agent, p *packet.Packet) error { + fmt.Println("PROCESSING PAKCET MAMN") switch p.Type { case packet.Handshake: if err := a.SendHandshakeResponse(); err != nil { return err } - a.SetStatus(constants.StatusHandshake) logger.Log.Debugf("Session handshake Id=%d, Remote=%s", a.Session.ID(), a.RemoteAddr()) + // Parse the json sent with the handshake by the client + handshakeData := &session.HandshakeData{} + err := json.Unmarshal(p.Data, handshakeData) + + if err != nil { + a.SetStatus(constants.StatusClosed) + println("INVALID HANDSHAKE DATA BOI") + return fmt.Errorf("Invalid handshake data. Id=%d", a.Session.ID()) + } + + a.Session.SetHandshakeData(handshakeData) + a.SetStatus(constants.StatusHandshake) + case packet.HandshakeAck: a.SetStatus(constants.StatusWorking) logger.Log.Debugf("Receive handshake ACK Id=%d, Remote=%s", a.Session.ID(), a.RemoteAddr()) diff --git a/service/handler_test.go b/service/handler_test.go index d3039fda..304c6952 100644 --- a/service/handler_test.go +++ b/service/handler_test.go @@ -234,23 +234,43 @@ func TestHandlerServiceLocalProcess(t *testing.T) { } func TestHandlerServiceProcessPacketHandshake(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() + tables := []struct { + name string + packet *packet.Packet + socketStatus int32 + errStr string + }{ + {"invalid_handshake_data", &packet.Packet{Type: packet.Handshake, Data: []byte("asiodjasd")}, constants.StatusClosed, "Invalid handshake data"}, + {"valid_handshake_data", &packet.Packet{Type: packet.Handshake, Data: []byte(`{"sys":{"platform":"mac"}}`)}, constants.StatusHandshake, ""}, + } + for _, table := range tables { + t.Run(table.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() - mockSerializer := serializemocks.NewMockSerializer(ctrl) - mockConn := connmock.NewMockConn(ctrl) - packetEncoder := codec.NewPomeloPacketEncoder() - messageEncoder := message.NewMessagesEncoder(false) - svc := NewHandlerService(nil, nil, nil, nil, 1*time.Second, 1, 1, 1, nil, nil, nil, nil) + mockSerializer := serializemocks.NewMockSerializer(ctrl) + mockConn := connmock.NewMockConn(ctrl) + packetEncoder := codec.NewPomeloPacketEncoder() + messageEncoder := message.NewMessagesEncoder(false) + svc := NewHandlerService(nil, nil, nil, nil, 1*time.Second, 1, 1, 1, nil, nil, nil, nil) - mockConn.EXPECT().RemoteAddr().Return(&mockAddr{}) - mockConn.EXPECT().Write(gomock.Any()).Do(func(d []byte) { - assert.Contains(t, string(d), "heartbeat") - }) - ag := agent.NewAgent(mockConn, nil, packetEncoder, mockSerializer, 1*time.Second, 1, nil, messageEncoder, nil) - err := svc.processPacket(ag, &packet.Packet{Type: packet.Handshake}) - assert.NoError(t, err) - assert.Equal(t, constants.StatusHandshake, ag.GetStatus()) + mockConn.EXPECT().RemoteAddr().Return(&mockAddr{}) + mockConn.EXPECT().Write(gomock.Any()).Do(func(d []byte) { + assert.Contains(t, string(d), "heartbeat") + }) + + ag := agent.NewAgent(mockConn, nil, packetEncoder, mockSerializer, 1*time.Second, 1, nil, messageEncoder, nil) + + err := svc.processPacket(ag, table.packet) + if table.errStr == "" { + assert.Nil(t, err) + } else { + assert.NotNil(t, err) + assert.Contains(t, err.Error(), table.errStr) + } + assert.Equal(t, table.socketStatus, ag.GetStatus()) + }) + } } func TestHandlerServiceProcessPacketHandshakeAck(t *testing.T) { @@ -341,21 +361,15 @@ func TestHandlerServiceHandle(t *testing.T) { svc := NewHandlerService(nil, packetDecoder, packetEncoder, mockSerializer, 1*time.Second, 1, 1, 1, nil, nil, messageEncoder, nil) var wg sync.WaitGroup firstCall := mockConn.EXPECT().Read(gomock.Any()).Do(func(b []byte) { - handshakeBuffer := `{ - "sys": { - "type": "golang-tcp", - "version": "0.0.1", - "rsa": {} - }, - "user": {} -};` + handshakeBuffer := `{"sys":{"platform":"mac","lib_version":"0.3.5-release","client_build_number":"20","client_version":"2.1"},"user":{"age":30}}` bbb, err := packetEncoder.Encode(packet.Handshake, []byte(handshakeBuffer)) for i, c := range bbb { b[i] = c } + assert.NoError(t, err) wg.Done() - }).Return(101, nil) + }).Return(128, nil) mockConn.EXPECT().Read(gomock.Any()).Return(0, errors.New("die")).Do(func(b []byte) { wg.Done() diff --git a/session/session.go b/session/session.go index 17fbc7f2..fac1f773 100644 --- a/session/session.go +++ b/session/session.go @@ -58,6 +58,22 @@ var ( SessionCount int64 ) +// HandshakeClientData represents information about the client sent on the handshake. +type HandshakeClientData struct { + Platform string `json:"platform"` + LibVersion string `json:"lib_version"` + BuildNumber string `json:"client_build_number"` + Version string `json:"client_version"` +} + +// HandshakeData represents information about the handshake sent by the client. +// `sys` corresponds to information independent from the app and `user` information +// that depends on the app and is customized by the user. +type HandshakeData struct { + Sys HandshakeClientData `json:"sys"` + User map[string]interface{} `json:"user,omitempty"` +} + // Session represents a client session which could storage temp data during low-level // keep connected, all data will be released when the low-level connection was broken. // Session instance related to the client will be passed to Handler method as the first @@ -69,6 +85,7 @@ type Session struct { lastTime int64 // last heartbeat time entity NetworkEntity // low-level network entity data map[string]interface{} // session data store + handshakeData *HandshakeData // handshake data received by the client encodedData []byte // session data encoded as a byte array OnCloseCallbacks []func() //onClose callbacks IsFrontend bool // if session is a frontend session @@ -99,6 +116,7 @@ func New(entity NetworkEntity, frontend bool, UID ...string) *Session { id: sessionIDSvc.sessionID(), entity: entity, data: make(map[string]interface{}), + handshakeData: nil, lastTime: time.Now().Unix(), OnCloseCallbacks: []func(){}, IsFrontend: frontend, @@ -616,6 +634,19 @@ func (s *Session) Clear() { s.updateEncodedData() } +// SetHandshakeData sets the handshake data received by the client. +func (s *Session) SetHandshakeData(data *HandshakeData) { + s.Lock() + defer s.Unlock() + + s.handshakeData = data +} + +// GetHandshakeData gets the handshake data received by the client. +func (s *Session) GetHandshakeData() *HandshakeData { + return s.handshakeData +} + func (s *Session) sendRequestToFront(ctx context.Context, route string, includeData bool) error { sessionData := &protos.Session{ Id: s.frontendSessionID, diff --git a/session/session_test.go b/session/session_test.go index be1220c1..09069144 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -1261,3 +1261,89 @@ func TestSessionClear(t *testing.T) { expectedEncoded := getEncodedEmptyMap() assert.Equal(t, expectedEncoded, ss.encodedData) } + +func TestSessionGetHandshakeData(t *testing.T) { + t.Parallel() + + data1 := &HandshakeData{ + Sys: HandshakeClientData{ + Platform: "macos", + LibVersion: "2.3.2", + BuildNumber: "20", + Version: "14.0.2", + }, + User: make(map[string]interface{}), + } + data2 := &HandshakeData{ + Sys: HandshakeClientData{ + Platform: "windows", + LibVersion: "2.3.10", + BuildNumber: "", + Version: "ahaha", + }, + User: map[string]interface{}{ + "ababa": make(map[string]interface{}), + "pepe": 1, + }, + } + tables := []struct { + name string + data *HandshakeData + }{ + {"test_1", data1}, + {"test_2", data2}, + } + + for _, table := range tables { + t.Run(table.name, func(t *testing.T) { + ss := New(nil, false) + + assert.Nil(t, ss.GetHandshakeData()) + + ss.handshakeData = table.data + + assert.Equal(t, ss.GetHandshakeData(), table.data) + }) + } +} + +func TestSessionSetHandshakeData(t *testing.T) { + t.Parallel() + + data1 := &HandshakeData{ + Sys: HandshakeClientData{ + Platform: "macos", + LibVersion: "2.3.2", + BuildNumber: "20", + Version: "14.0.2", + }, + User: make(map[string]interface{}), + } + data2 := &HandshakeData{ + Sys: HandshakeClientData{ + Platform: "windows", + LibVersion: "2.3.10", + BuildNumber: "", + Version: "ahaha", + }, + User: map[string]interface{}{ + "ababa": make(map[string]interface{}), + "pepe": 1, + }, + } + tables := []struct { + name string + data *HandshakeData + }{ + {"testSessionSetData_1", data1}, + {"testSessionSetData_2", data2}, + } + + for _, table := range tables { + t.Run(table.name, func(t *testing.T) { + ss := New(nil, false) + ss.SetHandshakeData(table.data) + assert.Equal(t, table.data, ss.handshakeData) + }) + } +}