diff --git a/agent/agent.go b/agent/agent.go index a6a89585..83451ab7 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -52,7 +52,9 @@ var ( // hbd contains the heartbeat packet data hbd []byte // hrd contains the handshake response data - hrd []byte + hrd []byte + // herd contains the handshake error response data + herd []byte once sync.Once ) @@ -110,6 +112,7 @@ type ( Handle() IPVersion() string SendHandshakeResponse() error + SendHandshakeErrorResponse() error SendRequest(ctx context.Context, serverID, route string, v interface{}) (*protos.Response, error) AnswerWithError(ctx context.Context, mid uint, err error) } @@ -180,6 +183,7 @@ func newAgent( once.Do(func() { hbdEncode(heartbeatTime, packetEncoder, messageEncoder.IsCompressionEnabled(), serializerName) + herdEncode(heartbeatTime, packetEncoder, messageEncoder.IsCompressionEnabled(), serializerName) }) a := &agentImpl{ @@ -475,6 +479,14 @@ func (a *agentImpl) onSessionClosed(s session.Session) { // SendHandshakeResponse sends a handshake response func (a *agentImpl) SendHandshakeResponse() error { _, err := a.conn.Write(hrd) + + return err +} + +func (a *agentImpl) SendHandshakeErrorResponse() error { + a.SetStatus(constants.StatusClosed) + _, err := a.conn.Write(herd) + return err } @@ -543,33 +555,63 @@ func hbdEncode(heartbeatTimeout time.Duration, packetEncoder codec.PacketEncoder "serializer": serializerName, }, } - data, err := gojson.Marshal(hData) + + data, err := encodeAndCompress(hData, dataCompression) if err != nil { panic(err) } - if dataCompression { - compressedData, err := compression.DeflateData(data) - if err != nil { - panic(err) - } + hrd, err = packetEncoder.Encode(packet.Handshake, data) + if err != nil { + panic(err) + } - if len(compressedData) < len(data) { - data = compressedData - } + hbd, err = packetEncoder.Encode(packet.Heartbeat, nil) + if err != nil { + panic(err) } +} - hrd, err = packetEncoder.Encode(packet.Handshake, data) +func herdEncode(heartbeatTimeout time.Duration, packetEncoder codec.PacketEncoder, dataCompression bool, serializerName string) { + hErrData := map[string]interface{}{ + "code": 400, + "sys": map[string]interface{}{ + "heartbeat": heartbeatTimeout.Seconds(), + "dict": message.GetDictionary(), + "serializer": serializerName, + }, + } + + errData, err := encodeAndCompress(hErrData, dataCompression) if err != nil { panic(err) } - hbd, err = packetEncoder.Encode(packet.Heartbeat, nil) + herd, err = packetEncoder.Encode(packet.Handshake, errData) if err != nil { panic(err) } } +func encodeAndCompress(data interface{}, dataCompression bool) ([]byte, error) { + encData, err := gojson.Marshal(data) + if err != nil { + return nil, err + } + + if dataCompression { + compressedData, err := compression.DeflateData(encData) + if err != nil { + return nil, err + } + + if len(compressedData) < len(encData) { + encData = compressedData + } + } + return encData, nil +} + func (a *agentImpl) reportChannelSize() { chSendCapacity := a.messagesBufferSize - len(a.chSend) if chSendCapacity == 0 { diff --git a/agent/agent_test.go b/agent/agent_test.go index 3f76037a..e3318906 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -84,7 +84,7 @@ func TestNewAgent(t *testing.T) { func(typ packet.Type, d []byte) { // cannot compare inside the expect because they are equivalent but not equal assert.EqualValues(t, packet.Handshake, typ) - }) + }).Times(2) mockEncoder.EXPECT().Encode(gomock.Any(), gomock.Nil()).Do( func(typ packet.Type, d []byte) { assert.EqualValues(t, packet.Heartbeat, typ) diff --git a/agent/mocks/agent.go b/agent/mocks/agent.go index de567a4c..256f7fd2 100644 --- a/agent/mocks/agent.go +++ b/agent/mocks/agent.go @@ -6,50 +6,51 @@ package mocks import ( context "context" + net "net" + reflect "reflect" + gomock "github.com/golang/mock/gomock" agent "github.com/topfreegames/pitaya/v2/agent" protos "github.com/topfreegames/pitaya/v2/protos" session "github.com/topfreegames/pitaya/v2/session" - net "net" - reflect "reflect" ) -// MockAgent is a mock of Agent interface +// MockAgent is a mock of Agent interface. type MockAgent struct { ctrl *gomock.Controller recorder *MockAgentMockRecorder } -// MockAgentMockRecorder is the mock recorder for MockAgent +// MockAgentMockRecorder is the mock recorder for MockAgent. type MockAgentMockRecorder struct { mock *MockAgent } -// NewMockAgent creates a new mock instance +// NewMockAgent creates a new mock instance. func NewMockAgent(ctrl *gomock.Controller) *MockAgent { mock := &MockAgent{ctrl: ctrl} mock.recorder = &MockAgentMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockAgent) EXPECT() *MockAgentMockRecorder { return m.recorder } -// AnswerWithError mocks base method +// AnswerWithError mocks base method. func (m *MockAgent) AnswerWithError(arg0 context.Context, arg1 uint, arg2 error) { m.ctrl.T.Helper() m.ctrl.Call(m, "AnswerWithError", arg0, arg1, arg2) } -// AnswerWithError indicates an expected call of AnswerWithError +// AnswerWithError indicates an expected call of AnswerWithError. func (mr *MockAgentMockRecorder) AnswerWithError(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AnswerWithError", reflect.TypeOf((*MockAgent)(nil).AnswerWithError), arg0, arg1, arg2) } -// Close mocks base method +// Close mocks base method. func (m *MockAgent) Close() error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Close") @@ -57,13 +58,13 @@ func (m *MockAgent) Close() error { return ret0 } -// Close indicates an expected call of Close +// Close indicates an expected call of Close. func (mr *MockAgentMockRecorder) Close() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockAgent)(nil).Close)) } -// GetSession mocks base method +// GetSession mocks base method. func (m *MockAgent) GetSession() session.Session { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetSession") @@ -71,13 +72,13 @@ func (m *MockAgent) GetSession() session.Session { return ret0 } -// GetSession indicates an expected call of GetSession +// GetSession indicates an expected call of GetSession. func (mr *MockAgentMockRecorder) GetSession() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSession", reflect.TypeOf((*MockAgent)(nil).GetSession)) } -// GetStatus mocks base method +// GetStatus mocks base method. func (m *MockAgent) GetStatus() int32 { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetStatus") @@ -85,25 +86,25 @@ func (m *MockAgent) GetStatus() int32 { return ret0 } -// GetStatus indicates an expected call of GetStatus +// GetStatus indicates an expected call of GetStatus. func (mr *MockAgentMockRecorder) GetStatus() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStatus", reflect.TypeOf((*MockAgent)(nil).GetStatus)) } -// Handle mocks base method +// Handle mocks base method. func (m *MockAgent) Handle() { m.ctrl.T.Helper() m.ctrl.Call(m, "Handle") } -// Handle indicates an expected call of Handle +// Handle indicates an expected call of Handle. func (mr *MockAgentMockRecorder) Handle() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Handle", reflect.TypeOf((*MockAgent)(nil).Handle)) } -// IPVersion mocks base method +// IPVersion mocks base method. func (m *MockAgent) IPVersion() string { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "IPVersion") @@ -111,13 +112,13 @@ func (m *MockAgent) IPVersion() string { return ret0 } -// IPVersion indicates an expected call of IPVersion +// IPVersion indicates an expected call of IPVersion. func (mr *MockAgentMockRecorder) IPVersion() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IPVersion", reflect.TypeOf((*MockAgent)(nil).IPVersion)) } -// Kick mocks base method +// Kick mocks base method. func (m *MockAgent) Kick(arg0 context.Context) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Kick", arg0) @@ -125,13 +126,13 @@ func (m *MockAgent) Kick(arg0 context.Context) error { return ret0 } -// Kick indicates an expected call of Kick +// Kick indicates an expected call of Kick. func (mr *MockAgentMockRecorder) Kick(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Kick", reflect.TypeOf((*MockAgent)(nil).Kick), arg0) } -// Push mocks base method +// Push mocks base method. func (m *MockAgent) Push(arg0 string, arg1 interface{}) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Push", arg0, arg1) @@ -139,13 +140,13 @@ func (m *MockAgent) Push(arg0 string, arg1 interface{}) error { return ret0 } -// Push indicates an expected call of Push +// Push indicates an expected call of Push. func (mr *MockAgentMockRecorder) Push(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Push", reflect.TypeOf((*MockAgent)(nil).Push), arg0, arg1) } -// RemoteAddr mocks base method +// RemoteAddr mocks base method. func (m *MockAgent) RemoteAddr() net.Addr { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RemoteAddr") @@ -153,13 +154,13 @@ func (m *MockAgent) RemoteAddr() net.Addr { return ret0 } -// RemoteAddr indicates an expected call of RemoteAddr +// RemoteAddr indicates an expected call of RemoteAddr. func (mr *MockAgentMockRecorder) RemoteAddr() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockAgent)(nil).RemoteAddr)) } -// ResponseMID mocks base method +// ResponseMID mocks base method. func (m *MockAgent) ResponseMID(arg0 context.Context, arg1 uint, arg2 interface{}, arg3 ...bool) error { m.ctrl.T.Helper() varargs := []interface{}{arg0, arg1, arg2} @@ -171,14 +172,28 @@ func (m *MockAgent) ResponseMID(arg0 context.Context, arg1 uint, arg2 interface{ return ret0 } -// ResponseMID indicates an expected call of ResponseMID +// ResponseMID indicates an expected call of ResponseMID. func (mr *MockAgentMockRecorder) ResponseMID(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResponseMID", reflect.TypeOf((*MockAgent)(nil).ResponseMID), varargs...) } -// SendHandshakeResponse mocks base method +// SendHandshakeErrorResponse mocks base method. +func (m *MockAgent) SendHandshakeErrorResponse() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendHandshakeErrorResponse") + ret0, _ := ret[0].(error) + return ret0 +} + +// SendHandshakeErrorResponse indicates an expected call of SendHandshakeErrorResponse. +func (mr *MockAgentMockRecorder) SendHandshakeErrorResponse() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendHandshakeErrorResponse", reflect.TypeOf((*MockAgent)(nil).SendHandshakeErrorResponse)) +} + +// SendHandshakeResponse mocks base method. func (m *MockAgent) SendHandshakeResponse() error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SendHandshakeResponse") @@ -186,13 +201,13 @@ func (m *MockAgent) SendHandshakeResponse() error { return ret0 } -// SendHandshakeResponse indicates an expected call of SendHandshakeResponse +// SendHandshakeResponse indicates an expected call of SendHandshakeResponse. func (mr *MockAgentMockRecorder) SendHandshakeResponse() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendHandshakeResponse", reflect.TypeOf((*MockAgent)(nil).SendHandshakeResponse)) } -// SendRequest mocks base method +// SendRequest mocks base method. func (m *MockAgent) SendRequest(arg0 context.Context, arg1, arg2 string, arg3 interface{}) (*protos.Response, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SendRequest", arg0, arg1, arg2, arg3) @@ -201,37 +216,37 @@ func (m *MockAgent) SendRequest(arg0 context.Context, arg1, arg2 string, arg3 in return ret0, ret1 } -// SendRequest indicates an expected call of SendRequest +// SendRequest indicates an expected call of SendRequest. func (mr *MockAgentMockRecorder) SendRequest(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendRequest", reflect.TypeOf((*MockAgent)(nil).SendRequest), arg0, arg1, arg2, arg3) } -// SetLastAt mocks base method +// SetLastAt mocks base method. func (m *MockAgent) SetLastAt() { m.ctrl.T.Helper() m.ctrl.Call(m, "SetLastAt") } -// SetLastAt indicates an expected call of SetLastAt +// SetLastAt indicates an expected call of SetLastAt. func (mr *MockAgentMockRecorder) SetLastAt() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLastAt", reflect.TypeOf((*MockAgent)(nil).SetLastAt)) } -// SetStatus mocks base method +// SetStatus mocks base method. func (m *MockAgent) SetStatus(arg0 int32) { m.ctrl.T.Helper() m.ctrl.Call(m, "SetStatus", arg0) } -// SetStatus indicates an expected call of SetStatus +// SetStatus indicates an expected call of SetStatus. func (mr *MockAgentMockRecorder) SetStatus(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetStatus", reflect.TypeOf((*MockAgent)(nil).SetStatus), arg0) } -// String mocks base method +// String mocks base method. func (m *MockAgent) String() string { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "String") @@ -239,36 +254,36 @@ func (m *MockAgent) String() string { return ret0 } -// String indicates an expected call of String +// String indicates an expected call of String. func (mr *MockAgentMockRecorder) String() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "String", reflect.TypeOf((*MockAgent)(nil).String)) } -// MockAgentFactory is a mock of AgentFactory interface +// MockAgentFactory is a mock of AgentFactory interface. type MockAgentFactory struct { ctrl *gomock.Controller recorder *MockAgentFactoryMockRecorder } -// MockAgentFactoryMockRecorder is the mock recorder for MockAgentFactory +// MockAgentFactoryMockRecorder is the mock recorder for MockAgentFactory. type MockAgentFactoryMockRecorder struct { mock *MockAgentFactory } -// NewMockAgentFactory creates a new mock instance +// NewMockAgentFactory creates a new mock instance. func NewMockAgentFactory(ctrl *gomock.Controller) *MockAgentFactory { mock := &MockAgentFactory{ctrl: ctrl} mock.recorder = &MockAgentFactoryMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockAgentFactory) EXPECT() *MockAgentFactoryMockRecorder { return m.recorder } -// CreateAgent mocks base method +// CreateAgent mocks base method. func (m *MockAgentFactory) CreateAgent(arg0 net.Conn) agent.Agent { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CreateAgent", arg0) @@ -276,7 +291,7 @@ func (m *MockAgentFactory) CreateAgent(arg0 net.Conn) agent.Agent { return ret0 } -// CreateAgent indicates an expected call of CreateAgent +// CreateAgent indicates an expected call of CreateAgent. func (mr *MockAgentFactoryMockRecorder) CreateAgent(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateAgent", reflect.TypeOf((*MockAgentFactory)(nil).CreateAgent), arg0) diff --git a/builder.go b/builder.go index 96906673..3d96ff6e 100644 --- a/builder.go +++ b/builder.go @@ -25,6 +25,7 @@ import ( // Builder holds dependency instances for a pitaya App type Builder struct { acceptors []acceptor.Acceptor + postBuildHooks []func(app Pitaya) Config config.BuilderConfig DieChan chan bool PacketDecoder codec.PacketDecoder @@ -46,6 +47,8 @@ type Builder struct { // PitayaBuilder Builder interface type PitayaBuilder interface { + // AddPostBuildHook adds a post-build hook to the builder, a function receiving a Pitaya instance as parameter. + AddPostBuildHook(hook func(app Pitaya)) Build() Pitaya } @@ -185,6 +188,7 @@ func NewBuilder(isFrontend bool, return &Builder{ acceptors: []acceptor.Acceptor{}, + postBuildHooks: make([]func(app Pitaya), 0), Config: config, DieChan: dieChan, PacketDecoder: codec.NewPomeloPacketDecoder(), @@ -214,6 +218,11 @@ func (builder *Builder) AddAcceptor(ac acceptor.Acceptor) { builder.acceptors = append(builder.acceptors, ac) } +// AddPostBuildHook adds a post-build hook to the builder, a function receiving a Pitaya instance as parameter. +func (builder *Builder) AddPostBuildHook(hook func(app Pitaya)) { + builder.postBuildHooks = append(builder.postBuildHooks, hook) +} + // Build returns a valid App instance func (builder *Builder) Build() Pitaya { handlerPool := service.NewHandlerPool() @@ -270,7 +279,7 @@ func (builder *Builder) Build() Pitaya { handlerPool, ) - return NewApp( + app := NewApp( builder.ServerMode, builder.Serializer, builder.acceptors, @@ -288,6 +297,12 @@ func (builder *Builder) Build() Pitaya { builder.MetricsReporters, builder.Config.Pitaya, ) + + for _, postBuildHook := range builder.postBuildHooks { + postBuildHook(app) + } + + return app } // NewDefaultApp returns a default pitaya app instance diff --git a/builder_test.go b/builder_test.go new file mode 100644 index 00000000..da07c097 --- /dev/null +++ b/builder_test.go @@ -0,0 +1,58 @@ +// Copyright (c) nano Author and TFG Co. All Rights Reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package pitaya + +import ( + "github.com/stretchr/testify/assert" + "github.com/topfreegames/pitaya/v2/acceptor" + "github.com/topfreegames/pitaya/v2/config" + "testing" +) + +func TestPostBuildHooks(t *testing.T) { + acc := acceptor.NewTCPAcceptor("0.0.0.0:0") + for _, table := range tables { + builderConfig := config.NewDefaultBuilderConfig() + + t.Run("with_post_build_hooks", func(t *testing.T) { + called := false + builder := NewDefaultBuilder(table.isFrontend, table.serverType, table.serverMode, table.serverMetadata, *builderConfig) + builder.AddAcceptor(acc) + builder.AddPostBuildHook(func(app Pitaya) { + called = true + }) + app := builder.Build() + + assert.True(t, called) + assert.NotNil(t, app) + }) + + t.Run("without_post_build_hooks", func(t *testing.T) { + called := false + builder := NewDefaultBuilder(table.isFrontend, table.serverType, table.serverMode, table.serverMetadata, *builderConfig) + builder.AddAcceptor(acc) + app := builder.Build() + + assert.False(t, called) + assert.NotNil(t, app) + }) + } +} diff --git a/docs/builder.md b/docs/builder.md new file mode 100644 index 00000000..a3de1cbb --- /dev/null +++ b/docs/builder.md @@ -0,0 +1,41 @@ +Builder +=== + +Pitaya offers a [`Builder`](../builder.go) object which can be utilized to define a sort of pitaya properties. + +### PostBuildHooks + +Post-build hooks can be used to execute additional actions automatically after the build process. It also allows you to interact with the built pitaya app. + +A common use case is where it becomes necessary to perform configuration steps in both the pitaya builder and the pitaya app being built. In such cases, an effective approach is to internalize these configurations, enabling you to handle them collectively in a single operation or process. It simplifies the overall configuration process, reducing the need for separate and potentially repetitive steps. + +```go +// main.go +cfg := config.NewDefaultBuilderConfig() +builder := pitaya.NewDefaultBuilder(isFrontEnd, "my-server-type", pitaya.Cluster, map[string]string{}, *cfg) + +customModule := NewCustomModule(builder) +customModule.ConfigurePitaya(builder) + +app := builder.Build() + +// custom_object.go +type CustomObject struct { + builder *pitaya.Builder +} + +func NewCustomObject(builder *pitaya.Builder) *CustomObject { + return &CustomObject{ + builder: builder, + } +} + +func (object *CustomObject) ConfigurePitaya() { + object.builder.AddAcceptor(...) + object.builder.AddPostBuildHook(func (app pitaya.App) { + app.Register(...) + }) +} +``` + +In the above example the `ConfigurePitaya` method of the `CustomObject` is adding an `Acceptor` to the pitaya app being built, and also adding a post-build function which will register a handler `Component` that will expose endpoints to receive calls. diff --git a/docs/communication.md b/docs/communication.md index 3ce191d9..3a34ed34 100644 --- a/docs/communication.md +++ b/docs/communication.md @@ -47,7 +47,11 @@ The application can define a dictionary of compressed routes before starting, th ### Handshake -The first operation that happens when a client connects is the handshake. The handshake is initiated by the client, who sends informations about the client, such as platform, version of the client library, and others, and can also send user data in this step. This data is stored in the client's session and can be accessed later. The server replies with heartbeat interval, name of the serializer and the dictionary of compressed routes. +The first operation that happens when a client connects is the handshake. The handshake is initiated by the client, who sends information about the client, such as platform, version of the client library, and others, and can also send user data in this step. This data is stored in the client's session and can be accessed later. The server replies with heartbeat interval, name of the serializer and the dictionary of compressed routes. + +In order to enforce specific requirements, validations can be performed on the data submitted by the client. These validations server as a means to verify that the client is adherent to predefined server rules. By that if the client does not comply with the specified criteria, access to the server capabilities can be restricted. + +You can find more about the handshake validation [here](./handshake-validators.md). ### Remote service diff --git a/docs/handshake-validators.md b/docs/handshake-validators.md new file mode 100644 index 00000000..884f448b --- /dev/null +++ b/docs/handshake-validators.md @@ -0,0 +1,29 @@ +Handshake Validators +===== + +Pitaya allows to defined Handshake Validators.
+ +The primary purpose of these validators is to perform validation checks on the data transmitted by the client. The validators play a crucial role in verifying the integrity and reliability of the client's input before establishing a connection. + +In addition to data validation, handshake validators can also execute other custom logic to assess the client's compliance with the server-defined requirements. This additional logic may involve evaluating factors such as authenticating credentials, permissions, or any other criteria necessary to determine the client's eligibility to access the server. + +### Adding handshake validators + +To ensure the effective utilization of these validators, they should be added to the `SessionPool` component. As a result, each newly created session within the `SessionPool` will automatically incorporate the designated validators. + +Once the handshake process is initiated, the validators will be invoked to execute their validation routines. + +```go +cfg := config.NewDefaultBuilderConfig() +builder := pitaya.NewDefaultBuilder(isFrontEnd, "my-server-type", pitaya.Cluster, map[string]string{}, *cfg) +builder.SessionPool.AddHandshakeValidator("MyCustomValidator", func (data *session.HandshakeData) error { + if data.Sys.Version != "1.0.0" { + return errors.New("Unknown client version") + } + + return nil +}) +``` + +As a result of the validation process, if an error is encountered, the server will transmit a message to client within the code 400. This code emulates the widely recognized HTTP Bad Request status code, indicating that the client's request could not be fulfilled due to invalid data. Otherwise, if the validation process succeeds, the server will dispatch a message to client containing a code 200, mirroring the HTTP Ok status code.
+**Is important to mention that, when there are many validator functions, the validation will stop as soon it encounters the first error.** diff --git a/metrics/mocks/reporter.go b/metrics/mocks/reporter.go index 81e186ac..1b7ee61f 100644 --- a/metrics/mocks/reporter.go +++ b/metrics/mocks/reporter.go @@ -5,34 +5,35 @@ package mocks import ( - gomock "github.com/golang/mock/gomock" reflect "reflect" + + gomock "github.com/golang/mock/gomock" ) -// MockReporter is a mock of Reporter interface +// MockReporter is a mock of Reporter interface. type MockReporter struct { ctrl *gomock.Controller recorder *MockReporterMockRecorder } -// MockReporterMockRecorder is the mock recorder for MockReporter +// MockReporterMockRecorder is the mock recorder for MockReporter. type MockReporterMockRecorder struct { mock *MockReporter } -// NewMockReporter creates a new mock instance +// NewMockReporter creates a new mock instance. func NewMockReporter(ctrl *gomock.Controller) *MockReporter { mock := &MockReporter{ctrl: ctrl} mock.recorder = &MockReporterMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockReporter) EXPECT() *MockReporterMockRecorder { return m.recorder } -// ReportCount mocks base method +// ReportCount mocks base method. func (m *MockReporter) ReportCount(arg0 string, arg1 map[string]string, arg2 float64) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ReportCount", arg0, arg1, arg2) @@ -40,13 +41,13 @@ func (m *MockReporter) ReportCount(arg0 string, arg1 map[string]string, arg2 flo return ret0 } -// ReportCount indicates an expected call of ReportCount +// ReportCount indicates an expected call of ReportCount. func (mr *MockReporterMockRecorder) ReportCount(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReportCount", reflect.TypeOf((*MockReporter)(nil).ReportCount), arg0, arg1, arg2) } -// ReportGauge mocks base method +// ReportGauge mocks base method. func (m *MockReporter) ReportGauge(arg0 string, arg1 map[string]string, arg2 float64) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ReportGauge", arg0, arg1, arg2) @@ -54,13 +55,13 @@ func (m *MockReporter) ReportGauge(arg0 string, arg1 map[string]string, arg2 flo return ret0 } -// ReportGauge indicates an expected call of ReportGauge +// ReportGauge indicates an expected call of ReportGauge. func (mr *MockReporterMockRecorder) ReportGauge(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReportGauge", reflect.TypeOf((*MockReporter)(nil).ReportGauge), arg0, arg1, arg2) } -// ReportHistogram mocks base method +// ReportHistogram mocks base method. func (m *MockReporter) ReportHistogram(arg0 string, arg1 map[string]string, arg2 float64) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ReportHistogram", arg0, arg1, arg2) @@ -68,13 +69,13 @@ func (m *MockReporter) ReportHistogram(arg0 string, arg1 map[string]string, arg2 return ret0 } -// ReportHistogram indicates an expected call of ReportHistogram +// ReportHistogram indicates an expected call of ReportHistogram. func (mr *MockReporterMockRecorder) ReportHistogram(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReportHistogram", reflect.TypeOf((*MockReporter)(nil).ReportHistogram), arg0, arg1, arg2) } -// ReportSummary mocks base method +// ReportSummary mocks base method. func (m *MockReporter) ReportSummary(arg0 string, arg1 map[string]string, arg2 float64) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ReportSummary", arg0, arg1, arg2) @@ -82,7 +83,7 @@ func (m *MockReporter) ReportSummary(arg0 string, arg1 map[string]string, arg2 f return ret0 } -// ReportSummary indicates an expected call of ReportSummary +// ReportSummary indicates an expected call of ReportSummary. func (mr *MockReporterMockRecorder) ReportSummary(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReportSummary", reflect.TypeOf((*MockReporter)(nil).ReportSummary), arg0, arg1, arg2) diff --git a/metrics/mocks/statsd_reporter.go b/metrics/mocks/statsd_reporter.go index 2303005f..94c16d4d 100644 --- a/metrics/mocks/statsd_reporter.go +++ b/metrics/mocks/statsd_reporter.go @@ -5,34 +5,35 @@ package mocks import ( - gomock "github.com/golang/mock/gomock" reflect "reflect" + + gomock "github.com/golang/mock/gomock" ) -// MockClient is a mock of Client interface +// MockClient is a mock of Client interface. type MockClient struct { ctrl *gomock.Controller recorder *MockClientMockRecorder } -// MockClientMockRecorder is the mock recorder for MockClient +// MockClientMockRecorder is the mock recorder for MockClient. type MockClientMockRecorder struct { mock *MockClient } -// NewMockClient creates a new mock instance +// NewMockClient creates a new mock instance. func NewMockClient(ctrl *gomock.Controller) *MockClient { mock := &MockClient{ctrl: ctrl} mock.recorder = &MockClientMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockClient) EXPECT() *MockClientMockRecorder { return m.recorder } -// Count mocks base method +// Count mocks base method. func (m *MockClient) Count(arg0 string, arg1 int64, arg2 []string, arg3 float64) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Count", arg0, arg1, arg2, arg3) @@ -40,13 +41,13 @@ func (m *MockClient) Count(arg0 string, arg1 int64, arg2 []string, arg3 float64) return ret0 } -// Count indicates an expected call of Count +// Count indicates an expected call of Count. func (mr *MockClientMockRecorder) Count(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Count", reflect.TypeOf((*MockClient)(nil).Count), arg0, arg1, arg2, arg3) } -// Gauge mocks base method +// Gauge mocks base method. func (m *MockClient) Gauge(arg0 string, arg1 float64, arg2 []string, arg3 float64) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Gauge", arg0, arg1, arg2, arg3) @@ -54,13 +55,13 @@ func (m *MockClient) Gauge(arg0 string, arg1 float64, arg2 []string, arg3 float6 return ret0 } -// Gauge indicates an expected call of Gauge +// Gauge indicates an expected call of Gauge. func (mr *MockClientMockRecorder) Gauge(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Gauge", reflect.TypeOf((*MockClient)(nil).Gauge), arg0, arg1, arg2, arg3) } -// TimeInMilliseconds mocks base method +// TimeInMilliseconds mocks base method. func (m *MockClient) TimeInMilliseconds(arg0 string, arg1 float64, arg2 []string, arg3 float64) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "TimeInMilliseconds", arg0, arg1, arg2, arg3) @@ -68,7 +69,7 @@ func (m *MockClient) TimeInMilliseconds(arg0 string, arg1 float64, arg2 []string return ret0 } -// TimeInMilliseconds indicates an expected call of TimeInMilliseconds +// TimeInMilliseconds indicates an expected call of TimeInMilliseconds. func (mr *MockClientMockRecorder) TimeInMilliseconds(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TimeInMilliseconds", reflect.TypeOf((*MockClient)(nil).TimeInMilliseconds), arg0, arg1, arg2, arg3) diff --git a/mocks/acceptor.go b/mocks/acceptor.go index 1e8d5c35..3d77703a 100644 --- a/mocks/acceptor.go +++ b/mocks/acceptor.go @@ -1,7 +1,7 @@ // Code generated by MockGen. DO NOT EDIT. // Source: github.com/topfreegames/pitaya/v2/acceptor (interfaces: PlayerConn,Acceptor) -// Package mock_acceptor is a generated GoMock package. +// Package mocks is a generated GoMock package. package mocks import ( @@ -80,18 +80,18 @@ func (mr *MockPlayerConnMockRecorder) LocalAddr() *gomock.Call { } // Read mocks base method. -func (m *MockPlayerConn) Read(b []byte) (int, error) { +func (m *MockPlayerConn) Read(arg0 []byte) (int, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Read", b) + ret := m.ctrl.Call(m, "Read", arg0) ret0, _ := ret[0].(int) ret1, _ := ret[1].(error) return ret0, ret1 } // Read indicates an expected call of Read. -func (mr *MockPlayerConnMockRecorder) Read(b interface{}) *gomock.Call { +func (mr *MockPlayerConnMockRecorder) Read(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockPlayerConn)(nil).Read), b) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Read", reflect.TypeOf((*MockPlayerConn)(nil).Read), arg0) } // RemoteAddr mocks base method. @@ -109,60 +109,60 @@ func (mr *MockPlayerConnMockRecorder) RemoteAddr() *gomock.Call { } // SetDeadline mocks base method. -func (m *MockPlayerConn) SetDeadline(t time.Time) error { +func (m *MockPlayerConn) SetDeadline(arg0 time.Time) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetDeadline", t) + ret := m.ctrl.Call(m, "SetDeadline", arg0) ret0, _ := ret[0].(error) return ret0 } // SetDeadline indicates an expected call of SetDeadline. -func (mr *MockPlayerConnMockRecorder) SetDeadline(t interface{}) *gomock.Call { +func (mr *MockPlayerConnMockRecorder) SetDeadline(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockPlayerConn)(nil).SetDeadline), t) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDeadline", reflect.TypeOf((*MockPlayerConn)(nil).SetDeadline), arg0) } // SetReadDeadline mocks base method. -func (m *MockPlayerConn) SetReadDeadline(t time.Time) error { +func (m *MockPlayerConn) SetReadDeadline(arg0 time.Time) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetReadDeadline", t) + ret := m.ctrl.Call(m, "SetReadDeadline", arg0) ret0, _ := ret[0].(error) return ret0 } // SetReadDeadline indicates an expected call of SetReadDeadline. -func (mr *MockPlayerConnMockRecorder) SetReadDeadline(t interface{}) *gomock.Call { +func (mr *MockPlayerConnMockRecorder) SetReadDeadline(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockPlayerConn)(nil).SetReadDeadline), t) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetReadDeadline", reflect.TypeOf((*MockPlayerConn)(nil).SetReadDeadline), arg0) } // SetWriteDeadline mocks base method. -func (m *MockPlayerConn) SetWriteDeadline(t time.Time) error { +func (m *MockPlayerConn) SetWriteDeadline(arg0 time.Time) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetWriteDeadline", t) + ret := m.ctrl.Call(m, "SetWriteDeadline", arg0) ret0, _ := ret[0].(error) return ret0 } // SetWriteDeadline indicates an expected call of SetWriteDeadline. -func (mr *MockPlayerConnMockRecorder) SetWriteDeadline(t interface{}) *gomock.Call { +func (mr *MockPlayerConnMockRecorder) SetWriteDeadline(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockPlayerConn)(nil).SetWriteDeadline), t) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWriteDeadline", reflect.TypeOf((*MockPlayerConn)(nil).SetWriteDeadline), arg0) } // Write mocks base method. -func (m *MockPlayerConn) Write(b []byte) (int, error) { +func (m *MockPlayerConn) Write(arg0 []byte) (int, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Write", b) + ret := m.ctrl.Call(m, "Write", arg0) ret0, _ := ret[0].(int) ret1, _ := ret[1].(error) return ret0, ret1 } // Write indicates an expected call of Write. -func (mr *MockPlayerConnMockRecorder) Write(b interface{}) *gomock.Call { +func (mr *MockPlayerConnMockRecorder) Write(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockPlayerConn)(nil).Write), b) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Write", reflect.TypeOf((*MockPlayerConn)(nil).Write), arg0) } // MockAcceptor is a mock of Acceptor interface. diff --git a/mocks/app.go b/mocks/app.go index b7fad59c..ed98eafc 100644 --- a/mocks/app.go +++ b/mocks/app.go @@ -6,6 +6,9 @@ package mocks import ( context "context" + reflect "reflect" + time "time" + gomock "github.com/golang/mock/gomock" cluster "github.com/topfreegames/pitaya/v2/cluster" component "github.com/topfreegames/pitaya/v2/component" @@ -16,34 +19,32 @@ import ( session "github.com/topfreegames/pitaya/v2/session" worker "github.com/topfreegames/pitaya/v2/worker" protoiface "google.golang.org/protobuf/runtime/protoiface" - reflect "reflect" - time "time" ) -// MockPitaya is a mock of Pitaya interface +// MockPitaya is a mock of Pitaya interface. type MockPitaya struct { ctrl *gomock.Controller recorder *MockPitayaMockRecorder } -// MockPitayaMockRecorder is the mock recorder for MockPitaya +// MockPitayaMockRecorder is the mock recorder for MockPitaya. type MockPitayaMockRecorder struct { mock *MockPitaya } -// NewMockPitaya creates a new mock instance +// NewMockPitaya creates a new mock instance. func NewMockPitaya(ctrl *gomock.Controller) *MockPitaya { mock := &MockPitaya{ctrl: ctrl} mock.recorder = &MockPitayaMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockPitaya) EXPECT() *MockPitayaMockRecorder { return m.recorder } -// AddRoute mocks base method +// AddRoute mocks base method. func (m *MockPitaya) AddRoute(arg0 string, arg1 router.RoutingFunc) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AddRoute", arg0, arg1) @@ -51,13 +52,13 @@ func (m *MockPitaya) AddRoute(arg0 string, arg1 router.RoutingFunc) error { return ret0 } -// AddRoute indicates an expected call of AddRoute +// AddRoute indicates an expected call of AddRoute. func (mr *MockPitayaMockRecorder) AddRoute(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRoute", reflect.TypeOf((*MockPitaya)(nil).AddRoute), arg0, arg1) } -// Documentation mocks base method +// Documentation mocks base method. func (m *MockPitaya) Documentation(arg0 bool) (map[string]interface{}, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Documentation", arg0) @@ -66,13 +67,13 @@ func (m *MockPitaya) Documentation(arg0 bool) (map[string]interface{}, error) { return ret0, ret1 } -// Documentation indicates an expected call of Documentation +// Documentation indicates an expected call of Documentation. func (mr *MockPitayaMockRecorder) Documentation(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Documentation", reflect.TypeOf((*MockPitaya)(nil).Documentation), arg0) } -// GetDieChan mocks base method +// GetDieChan mocks base method. func (m *MockPitaya) GetDieChan() chan bool { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetDieChan") @@ -80,13 +81,13 @@ func (m *MockPitaya) GetDieChan() chan bool { return ret0 } -// GetDieChan indicates an expected call of GetDieChan +// GetDieChan indicates an expected call of GetDieChan. func (mr *MockPitayaMockRecorder) GetDieChan() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDieChan", reflect.TypeOf((*MockPitaya)(nil).GetDieChan)) } -// GetMetricsReporters mocks base method +// GetMetricsReporters mocks base method. func (m *MockPitaya) GetMetricsReporters() []metrics.Reporter { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetMetricsReporters") @@ -94,13 +95,13 @@ func (m *MockPitaya) GetMetricsReporters() []metrics.Reporter { return ret0 } -// GetMetricsReporters indicates an expected call of GetMetricsReporters +// GetMetricsReporters indicates an expected call of GetMetricsReporters. func (mr *MockPitayaMockRecorder) GetMetricsReporters() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMetricsReporters", reflect.TypeOf((*MockPitaya)(nil).GetMetricsReporters)) } -// GetModule mocks base method +// GetModule mocks base method. func (m *MockPitaya) GetModule(arg0 string) (interfaces.Module, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetModule", arg0) @@ -109,13 +110,13 @@ func (m *MockPitaya) GetModule(arg0 string) (interfaces.Module, error) { return ret0, ret1 } -// GetModule indicates an expected call of GetModule +// GetModule indicates an expected call of GetModule. func (mr *MockPitayaMockRecorder) GetModule(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetModule", reflect.TypeOf((*MockPitaya)(nil).GetModule), arg0) } -// GetServer mocks base method +// GetServer mocks base method. func (m *MockPitaya) GetServer() *cluster.Server { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetServer") @@ -123,13 +124,13 @@ func (m *MockPitaya) GetServer() *cluster.Server { return ret0 } -// GetServer indicates an expected call of GetServer +// GetServer indicates an expected call of GetServer. func (mr *MockPitayaMockRecorder) GetServer() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServer", reflect.TypeOf((*MockPitaya)(nil).GetServer)) } -// GetServerByID mocks base method +// GetServerByID mocks base method. func (m *MockPitaya) GetServerByID(arg0 string) (*cluster.Server, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetServerByID", arg0) @@ -138,13 +139,13 @@ func (m *MockPitaya) GetServerByID(arg0 string) (*cluster.Server, error) { return ret0, ret1 } -// GetServerByID indicates an expected call of GetServerByID +// GetServerByID indicates an expected call of GetServerByID. func (mr *MockPitayaMockRecorder) GetServerByID(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServerByID", reflect.TypeOf((*MockPitaya)(nil).GetServerByID), arg0) } -// GetServerID mocks base method +// GetServerID mocks base method. func (m *MockPitaya) GetServerID() string { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetServerID") @@ -152,13 +153,13 @@ func (m *MockPitaya) GetServerID() string { return ret0 } -// GetServerID indicates an expected call of GetServerID +// GetServerID indicates an expected call of GetServerID. func (mr *MockPitayaMockRecorder) GetServerID() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServerID", reflect.TypeOf((*MockPitaya)(nil).GetServerID)) } -// GetServers mocks base method +// GetServers mocks base method. func (m *MockPitaya) GetServers() []*cluster.Server { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetServers") @@ -166,13 +167,13 @@ func (m *MockPitaya) GetServers() []*cluster.Server { return ret0 } -// GetServers indicates an expected call of GetServers +// GetServers indicates an expected call of GetServers. func (mr *MockPitayaMockRecorder) GetServers() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServers", reflect.TypeOf((*MockPitaya)(nil).GetServers)) } -// GetServersByType mocks base method +// GetServersByType mocks base method. func (m *MockPitaya) GetServersByType(arg0 string) (map[string]*cluster.Server, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetServersByType", arg0) @@ -181,13 +182,13 @@ func (m *MockPitaya) GetServersByType(arg0 string) (map[string]*cluster.Server, return ret0, ret1 } -// GetServersByType indicates an expected call of GetServersByType +// GetServersByType indicates an expected call of GetServersByType. func (mr *MockPitayaMockRecorder) GetServersByType(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServersByType", reflect.TypeOf((*MockPitaya)(nil).GetServersByType), arg0) } -// GetSessionFromCtx mocks base method +// GetSessionFromCtx mocks base method. func (m *MockPitaya) GetSessionFromCtx(arg0 context.Context) session.Session { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetSessionFromCtx", arg0) @@ -195,13 +196,13 @@ func (m *MockPitaya) GetSessionFromCtx(arg0 context.Context) session.Session { return ret0 } -// GetSessionFromCtx indicates an expected call of GetSessionFromCtx +// GetSessionFromCtx indicates an expected call of GetSessionFromCtx. func (mr *MockPitayaMockRecorder) GetSessionFromCtx(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionFromCtx", reflect.TypeOf((*MockPitaya)(nil).GetSessionFromCtx), arg0) } -// GroupAddMember mocks base method +// GroupAddMember mocks base method. func (m *MockPitaya) GroupAddMember(arg0 context.Context, arg1, arg2 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GroupAddMember", arg0, arg1, arg2) @@ -209,13 +210,13 @@ func (m *MockPitaya) GroupAddMember(arg0 context.Context, arg1, arg2 string) err return ret0 } -// GroupAddMember indicates an expected call of GroupAddMember +// GroupAddMember indicates an expected call of GroupAddMember. func (mr *MockPitayaMockRecorder) GroupAddMember(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GroupAddMember", reflect.TypeOf((*MockPitaya)(nil).GroupAddMember), arg0, arg1, arg2) } -// GroupBroadcast mocks base method +// GroupBroadcast mocks base method. func (m *MockPitaya) GroupBroadcast(arg0 context.Context, arg1, arg2, arg3 string, arg4 interface{}) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GroupBroadcast", arg0, arg1, arg2, arg3, arg4) @@ -223,13 +224,13 @@ func (m *MockPitaya) GroupBroadcast(arg0 context.Context, arg1, arg2, arg3 strin return ret0 } -// GroupBroadcast indicates an expected call of GroupBroadcast +// GroupBroadcast indicates an expected call of GroupBroadcast. func (mr *MockPitayaMockRecorder) GroupBroadcast(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GroupBroadcast", reflect.TypeOf((*MockPitaya)(nil).GroupBroadcast), arg0, arg1, arg2, arg3, arg4) } -// GroupContainsMember mocks base method +// GroupContainsMember mocks base method. func (m *MockPitaya) GroupContainsMember(arg0 context.Context, arg1, arg2 string) (bool, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GroupContainsMember", arg0, arg1, arg2) @@ -238,13 +239,13 @@ func (m *MockPitaya) GroupContainsMember(arg0 context.Context, arg1, arg2 string return ret0, ret1 } -// GroupContainsMember indicates an expected call of GroupContainsMember +// GroupContainsMember indicates an expected call of GroupContainsMember. func (mr *MockPitayaMockRecorder) GroupContainsMember(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GroupContainsMember", reflect.TypeOf((*MockPitaya)(nil).GroupContainsMember), arg0, arg1, arg2) } -// GroupCountMembers mocks base method +// GroupCountMembers mocks base method. func (m *MockPitaya) GroupCountMembers(arg0 context.Context, arg1 string) (int, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GroupCountMembers", arg0, arg1) @@ -253,13 +254,13 @@ func (m *MockPitaya) GroupCountMembers(arg0 context.Context, arg1 string) (int, return ret0, ret1 } -// GroupCountMembers indicates an expected call of GroupCountMembers +// GroupCountMembers indicates an expected call of GroupCountMembers. func (mr *MockPitayaMockRecorder) GroupCountMembers(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GroupCountMembers", reflect.TypeOf((*MockPitaya)(nil).GroupCountMembers), arg0, arg1) } -// GroupCreate mocks base method +// GroupCreate mocks base method. func (m *MockPitaya) GroupCreate(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GroupCreate", arg0, arg1) @@ -267,13 +268,13 @@ func (m *MockPitaya) GroupCreate(arg0 context.Context, arg1 string) error { return ret0 } -// GroupCreate indicates an expected call of GroupCreate +// GroupCreate indicates an expected call of GroupCreate. func (mr *MockPitayaMockRecorder) GroupCreate(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GroupCreate", reflect.TypeOf((*MockPitaya)(nil).GroupCreate), arg0, arg1) } -// GroupCreateWithTTL mocks base method +// GroupCreateWithTTL mocks base method. func (m *MockPitaya) GroupCreateWithTTL(arg0 context.Context, arg1 string, arg2 time.Duration) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GroupCreateWithTTL", arg0, arg1, arg2) @@ -281,13 +282,13 @@ func (m *MockPitaya) GroupCreateWithTTL(arg0 context.Context, arg1 string, arg2 return ret0 } -// GroupCreateWithTTL indicates an expected call of GroupCreateWithTTL +// GroupCreateWithTTL indicates an expected call of GroupCreateWithTTL. func (mr *MockPitayaMockRecorder) GroupCreateWithTTL(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GroupCreateWithTTL", reflect.TypeOf((*MockPitaya)(nil).GroupCreateWithTTL), arg0, arg1, arg2) } -// GroupDelete mocks base method +// GroupDelete mocks base method. func (m *MockPitaya) GroupDelete(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GroupDelete", arg0, arg1) @@ -295,13 +296,13 @@ func (m *MockPitaya) GroupDelete(arg0 context.Context, arg1 string) error { return ret0 } -// GroupDelete indicates an expected call of GroupDelete +// GroupDelete indicates an expected call of GroupDelete. func (mr *MockPitayaMockRecorder) GroupDelete(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GroupDelete", reflect.TypeOf((*MockPitaya)(nil).GroupDelete), arg0, arg1) } -// GroupMembers mocks base method +// GroupMembers mocks base method. func (m *MockPitaya) GroupMembers(arg0 context.Context, arg1 string) ([]string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GroupMembers", arg0, arg1) @@ -310,13 +311,13 @@ func (m *MockPitaya) GroupMembers(arg0 context.Context, arg1 string) ([]string, return ret0, ret1 } -// GroupMembers indicates an expected call of GroupMembers +// GroupMembers indicates an expected call of GroupMembers. func (mr *MockPitayaMockRecorder) GroupMembers(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GroupMembers", reflect.TypeOf((*MockPitaya)(nil).GroupMembers), arg0, arg1) } -// GroupRemoveAll mocks base method +// GroupRemoveAll mocks base method. func (m *MockPitaya) GroupRemoveAll(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GroupRemoveAll", arg0, arg1) @@ -324,13 +325,13 @@ func (m *MockPitaya) GroupRemoveAll(arg0 context.Context, arg1 string) error { return ret0 } -// GroupRemoveAll indicates an expected call of GroupRemoveAll +// GroupRemoveAll indicates an expected call of GroupRemoveAll. func (mr *MockPitayaMockRecorder) GroupRemoveAll(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GroupRemoveAll", reflect.TypeOf((*MockPitaya)(nil).GroupRemoveAll), arg0, arg1) } -// GroupRemoveMember mocks base method +// GroupRemoveMember mocks base method. func (m *MockPitaya) GroupRemoveMember(arg0 context.Context, arg1, arg2 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GroupRemoveMember", arg0, arg1, arg2) @@ -338,13 +339,13 @@ func (m *MockPitaya) GroupRemoveMember(arg0 context.Context, arg1, arg2 string) return ret0 } -// GroupRemoveMember indicates an expected call of GroupRemoveMember +// GroupRemoveMember indicates an expected call of GroupRemoveMember. func (mr *MockPitayaMockRecorder) GroupRemoveMember(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GroupRemoveMember", reflect.TypeOf((*MockPitaya)(nil).GroupRemoveMember), arg0, arg1, arg2) } -// GroupRenewTTL mocks base method +// GroupRenewTTL mocks base method. func (m *MockPitaya) GroupRenewTTL(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GroupRenewTTL", arg0, arg1) @@ -352,13 +353,13 @@ func (m *MockPitaya) GroupRenewTTL(arg0 context.Context, arg1 string) error { return ret0 } -// GroupRenewTTL indicates an expected call of GroupRenewTTL +// GroupRenewTTL indicates an expected call of GroupRenewTTL. func (mr *MockPitayaMockRecorder) GroupRenewTTL(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GroupRenewTTL", reflect.TypeOf((*MockPitaya)(nil).GroupRenewTTL), arg0, arg1) } -// IsRunning mocks base method +// IsRunning mocks base method. func (m *MockPitaya) IsRunning() bool { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "IsRunning") @@ -366,13 +367,13 @@ func (m *MockPitaya) IsRunning() bool { return ret0 } -// IsRunning indicates an expected call of IsRunning +// IsRunning indicates an expected call of IsRunning. func (mr *MockPitayaMockRecorder) IsRunning() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsRunning", reflect.TypeOf((*MockPitaya)(nil).IsRunning)) } -// RPC mocks base method +// RPC mocks base method. func (m *MockPitaya) RPC(arg0 context.Context, arg1 string, arg2, arg3 protoiface.MessageV1) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RPC", arg0, arg1, arg2, arg3) @@ -380,13 +381,13 @@ func (m *MockPitaya) RPC(arg0 context.Context, arg1 string, arg2, arg3 protoifac return ret0 } -// RPC indicates an expected call of RPC +// RPC indicates an expected call of RPC. func (mr *MockPitayaMockRecorder) RPC(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RPC", reflect.TypeOf((*MockPitaya)(nil).RPC), arg0, arg1, arg2, arg3) } -// RPCTo mocks base method +// RPCTo mocks base method. func (m *MockPitaya) RPCTo(arg0 context.Context, arg1, arg2 string, arg3, arg4 protoiface.MessageV1) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RPCTo", arg0, arg1, arg2, arg3, arg4) @@ -394,13 +395,13 @@ func (m *MockPitaya) RPCTo(arg0 context.Context, arg1, arg2 string, arg3, arg4 p return ret0 } -// RPCTo indicates an expected call of RPCTo +// RPCTo indicates an expected call of RPCTo. func (mr *MockPitayaMockRecorder) RPCTo(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RPCTo", reflect.TypeOf((*MockPitaya)(nil).RPCTo), arg0, arg1, arg2, arg3, arg4) } -// Register mocks base method +// Register mocks base method. func (m *MockPitaya) Register(arg0 component.Component, arg1 ...component.Option) { m.ctrl.T.Helper() varargs := []interface{}{arg0} @@ -410,14 +411,14 @@ func (m *MockPitaya) Register(arg0 component.Component, arg1 ...component.Option m.ctrl.Call(m, "Register", varargs...) } -// Register indicates an expected call of Register +// Register indicates an expected call of Register. func (mr *MockPitayaMockRecorder) Register(arg0 interface{}, arg1 ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() varargs := append([]interface{}{arg0}, arg1...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Register", reflect.TypeOf((*MockPitaya)(nil).Register), varargs...) } -// RegisterModule mocks base method +// RegisterModule mocks base method. func (m *MockPitaya) RegisterModule(arg0 interfaces.Module, arg1 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RegisterModule", arg0, arg1) @@ -425,13 +426,13 @@ func (m *MockPitaya) RegisterModule(arg0 interfaces.Module, arg1 string) error { return ret0 } -// RegisterModule indicates an expected call of RegisterModule +// RegisterModule indicates an expected call of RegisterModule. func (mr *MockPitayaMockRecorder) RegisterModule(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterModule", reflect.TypeOf((*MockPitaya)(nil).RegisterModule), arg0, arg1) } -// RegisterModuleAfter mocks base method +// RegisterModuleAfter mocks base method. func (m *MockPitaya) RegisterModuleAfter(arg0 interfaces.Module, arg1 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RegisterModuleAfter", arg0, arg1) @@ -439,13 +440,13 @@ func (m *MockPitaya) RegisterModuleAfter(arg0 interfaces.Module, arg1 string) er return ret0 } -// RegisterModuleAfter indicates an expected call of RegisterModuleAfter +// RegisterModuleAfter indicates an expected call of RegisterModuleAfter. func (mr *MockPitayaMockRecorder) RegisterModuleAfter(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterModuleAfter", reflect.TypeOf((*MockPitaya)(nil).RegisterModuleAfter), arg0, arg1) } -// RegisterModuleBefore mocks base method +// RegisterModuleBefore mocks base method. func (m *MockPitaya) RegisterModuleBefore(arg0 interfaces.Module, arg1 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RegisterModuleBefore", arg0, arg1) @@ -453,13 +454,13 @@ func (m *MockPitaya) RegisterModuleBefore(arg0 interfaces.Module, arg1 string) e return ret0 } -// RegisterModuleBefore indicates an expected call of RegisterModuleBefore +// RegisterModuleBefore indicates an expected call of RegisterModuleBefore. func (mr *MockPitayaMockRecorder) RegisterModuleBefore(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterModuleBefore", reflect.TypeOf((*MockPitaya)(nil).RegisterModuleBefore), arg0, arg1) } -// RegisterRPCJob mocks base method +// RegisterRPCJob mocks base method. func (m *MockPitaya) RegisterRPCJob(arg0 worker.RPCJob) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RegisterRPCJob", arg0) @@ -467,13 +468,13 @@ func (m *MockPitaya) RegisterRPCJob(arg0 worker.RPCJob) error { return ret0 } -// RegisterRPCJob indicates an expected call of RegisterRPCJob +// RegisterRPCJob indicates an expected call of RegisterRPCJob. func (mr *MockPitayaMockRecorder) RegisterRPCJob(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterRPCJob", reflect.TypeOf((*MockPitaya)(nil).RegisterRPCJob), arg0) } -// RegisterRemote mocks base method +// RegisterRemote mocks base method. func (m *MockPitaya) RegisterRemote(arg0 component.Component, arg1 ...component.Option) { m.ctrl.T.Helper() varargs := []interface{}{arg0} @@ -483,14 +484,14 @@ func (m *MockPitaya) RegisterRemote(arg0 component.Component, arg1 ...component. m.ctrl.Call(m, "RegisterRemote", varargs...) } -// RegisterRemote indicates an expected call of RegisterRemote +// RegisterRemote indicates an expected call of RegisterRemote. func (mr *MockPitayaMockRecorder) RegisterRemote(arg0 interface{}, arg1 ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() varargs := append([]interface{}{arg0}, arg1...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterRemote", reflect.TypeOf((*MockPitaya)(nil).RegisterRemote), varargs...) } -// ReliableRPC mocks base method +// ReliableRPC mocks base method. func (m *MockPitaya) ReliableRPC(arg0 string, arg1 map[string]interface{}, arg2, arg3 protoiface.MessageV1) (string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ReliableRPC", arg0, arg1, arg2, arg3) @@ -499,13 +500,13 @@ func (m *MockPitaya) ReliableRPC(arg0 string, arg1 map[string]interface{}, arg2, return ret0, ret1 } -// ReliableRPC indicates an expected call of ReliableRPC +// ReliableRPC indicates an expected call of ReliableRPC. func (mr *MockPitayaMockRecorder) ReliableRPC(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReliableRPC", reflect.TypeOf((*MockPitaya)(nil).ReliableRPC), arg0, arg1, arg2, arg3) } -// ReliableRPCWithOptions mocks base method +// ReliableRPCWithOptions mocks base method. func (m *MockPitaya) ReliableRPCWithOptions(arg0 string, arg1 map[string]interface{}, arg2, arg3 protoiface.MessageV1, arg4 *config.EnqueueOpts) (string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ReliableRPCWithOptions", arg0, arg1, arg2, arg3, arg4) @@ -514,13 +515,13 @@ func (m *MockPitaya) ReliableRPCWithOptions(arg0 string, arg1 map[string]interfa return ret0, ret1 } -// ReliableRPCWithOptions indicates an expected call of ReliableRPCWithOptions +// ReliableRPCWithOptions indicates an expected call of ReliableRPCWithOptions. func (mr *MockPitayaMockRecorder) ReliableRPCWithOptions(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReliableRPCWithOptions", reflect.TypeOf((*MockPitaya)(nil).ReliableRPCWithOptions), arg0, arg1, arg2, arg3, arg4) } -// SendKickToUsers mocks base method +// SendKickToUsers mocks base method. func (m *MockPitaya) SendKickToUsers(arg0 []string, arg1 string) ([]string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SendKickToUsers", arg0, arg1) @@ -529,13 +530,13 @@ func (m *MockPitaya) SendKickToUsers(arg0 []string, arg1 string) ([]string, erro return ret0, ret1 } -// SendKickToUsers indicates an expected call of SendKickToUsers +// SendKickToUsers indicates an expected call of SendKickToUsers. func (mr *MockPitayaMockRecorder) SendKickToUsers(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendKickToUsers", reflect.TypeOf((*MockPitaya)(nil).SendKickToUsers), arg0, arg1) } -// SendPushToUsers mocks base method +// SendPushToUsers mocks base method. func (m *MockPitaya) SendPushToUsers(arg0 string, arg1 interface{}, arg2 []string, arg3 string) ([]string, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SendPushToUsers", arg0, arg1, arg2, arg3) @@ -544,25 +545,25 @@ func (m *MockPitaya) SendPushToUsers(arg0 string, arg1 interface{}, arg2 []strin return ret0, ret1 } -// SendPushToUsers indicates an expected call of SendPushToUsers +// SendPushToUsers indicates an expected call of SendPushToUsers. func (mr *MockPitayaMockRecorder) SendPushToUsers(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendPushToUsers", reflect.TypeOf((*MockPitaya)(nil).SendPushToUsers), arg0, arg1, arg2, arg3) } -// SetDebug mocks base method +// SetDebug mocks base method. func (m *MockPitaya) SetDebug(arg0 bool) { m.ctrl.T.Helper() m.ctrl.Call(m, "SetDebug", arg0) } -// SetDebug indicates an expected call of SetDebug +// SetDebug indicates an expected call of SetDebug. func (mr *MockPitayaMockRecorder) SetDebug(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDebug", reflect.TypeOf((*MockPitaya)(nil).SetDebug), arg0) } -// SetDictionary mocks base method +// SetDictionary mocks base method. func (m *MockPitaya) SetDictionary(arg0 map[string]uint16) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SetDictionary", arg0) @@ -570,55 +571,55 @@ func (m *MockPitaya) SetDictionary(arg0 map[string]uint16) error { return ret0 } -// SetDictionary indicates an expected call of SetDictionary +// SetDictionary indicates an expected call of SetDictionary. func (mr *MockPitayaMockRecorder) SetDictionary(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDictionary", reflect.TypeOf((*MockPitaya)(nil).SetDictionary), arg0) } -// SetHeartbeatTime mocks base method +// SetHeartbeatTime mocks base method. func (m *MockPitaya) SetHeartbeatTime(arg0 time.Duration) { m.ctrl.T.Helper() m.ctrl.Call(m, "SetHeartbeatTime", arg0) } -// SetHeartbeatTime indicates an expected call of SetHeartbeatTime +// SetHeartbeatTime indicates an expected call of SetHeartbeatTime. func (mr *MockPitayaMockRecorder) SetHeartbeatTime(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHeartbeatTime", reflect.TypeOf((*MockPitaya)(nil).SetHeartbeatTime), arg0) } -// Shutdown mocks base method +// Shutdown mocks base method. func (m *MockPitaya) Shutdown() { m.ctrl.T.Helper() m.ctrl.Call(m, "Shutdown") } -// Shutdown indicates an expected call of Shutdown +// Shutdown indicates an expected call of Shutdown. func (mr *MockPitayaMockRecorder) Shutdown() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Shutdown", reflect.TypeOf((*MockPitaya)(nil).Shutdown)) } -// Start mocks base method +// Start mocks base method. func (m *MockPitaya) Start() { m.ctrl.T.Helper() m.ctrl.Call(m, "Start") } -// Start indicates an expected call of Start +// Start indicates an expected call of Start. func (mr *MockPitayaMockRecorder) Start() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockPitaya)(nil).Start)) } -// StartWorker mocks base method +// StartWorker mocks base method. func (m *MockPitaya) StartWorker() { m.ctrl.T.Helper() m.ctrl.Call(m, "StartWorker") } -// StartWorker indicates an expected call of StartWorker +// StartWorker indicates an expected call of StartWorker. func (mr *MockPitayaMockRecorder) StartWorker() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartWorker", reflect.TypeOf((*MockPitaya)(nil).StartWorker)) diff --git a/networkentity/mocks/networkentity.go b/networkentity/mocks/networkentity.go index 5405436a..f30c28d9 100644 --- a/networkentity/mocks/networkentity.go +++ b/networkentity/mocks/networkentity.go @@ -6,36 +6,37 @@ package mocks import ( context "context" - gomock "github.com/golang/mock/gomock" - protos "github.com/topfreegames/pitaya/v2/protos" net "net" reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protos "github.com/topfreegames/pitaya/v2/protos" ) -// MockNetworkEntity is a mock of NetworkEntity interface +// MockNetworkEntity is a mock of NetworkEntity interface. type MockNetworkEntity struct { ctrl *gomock.Controller recorder *MockNetworkEntityMockRecorder } -// MockNetworkEntityMockRecorder is the mock recorder for MockNetworkEntity +// MockNetworkEntityMockRecorder is the mock recorder for MockNetworkEntity. type MockNetworkEntityMockRecorder struct { mock *MockNetworkEntity } -// NewMockNetworkEntity creates a new mock instance +// NewMockNetworkEntity creates a new mock instance. func NewMockNetworkEntity(ctrl *gomock.Controller) *MockNetworkEntity { mock := &MockNetworkEntity{ctrl: ctrl} mock.recorder = &MockNetworkEntityMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockNetworkEntity) EXPECT() *MockNetworkEntityMockRecorder { return m.recorder } -// Close mocks base method +// Close mocks base method. func (m *MockNetworkEntity) Close() error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Close") @@ -43,13 +44,13 @@ func (m *MockNetworkEntity) Close() error { return ret0 } -// Close indicates an expected call of Close +// Close indicates an expected call of Close. func (mr *MockNetworkEntityMockRecorder) Close() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockNetworkEntity)(nil).Close)) } -// Kick mocks base method +// Kick mocks base method. func (m *MockNetworkEntity) Kick(arg0 context.Context) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Kick", arg0) @@ -57,13 +58,13 @@ func (m *MockNetworkEntity) Kick(arg0 context.Context) error { return ret0 } -// Kick indicates an expected call of Kick +// Kick indicates an expected call of Kick. func (mr *MockNetworkEntityMockRecorder) Kick(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Kick", reflect.TypeOf((*MockNetworkEntity)(nil).Kick), arg0) } -// Push mocks base method +// Push mocks base method. func (m *MockNetworkEntity) Push(arg0 string, arg1 interface{}) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Push", arg0, arg1) @@ -71,13 +72,13 @@ func (m *MockNetworkEntity) Push(arg0 string, arg1 interface{}) error { return ret0 } -// Push indicates an expected call of Push +// Push indicates an expected call of Push. func (mr *MockNetworkEntityMockRecorder) Push(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Push", reflect.TypeOf((*MockNetworkEntity)(nil).Push), arg0, arg1) } -// RemoteAddr mocks base method +// RemoteAddr mocks base method. func (m *MockNetworkEntity) RemoteAddr() net.Addr { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RemoteAddr") @@ -85,13 +86,13 @@ func (m *MockNetworkEntity) RemoteAddr() net.Addr { return ret0 } -// RemoteAddr indicates an expected call of RemoteAddr +// RemoteAddr indicates an expected call of RemoteAddr. func (mr *MockNetworkEntityMockRecorder) RemoteAddr() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoteAddr", reflect.TypeOf((*MockNetworkEntity)(nil).RemoteAddr)) } -// ResponseMID mocks base method +// ResponseMID mocks base method. func (m *MockNetworkEntity) ResponseMID(arg0 context.Context, arg1 uint, arg2 interface{}, arg3 ...bool) error { m.ctrl.T.Helper() varargs := []interface{}{arg0, arg1, arg2} @@ -103,14 +104,14 @@ func (m *MockNetworkEntity) ResponseMID(arg0 context.Context, arg1 uint, arg2 in return ret0 } -// ResponseMID indicates an expected call of ResponseMID +// ResponseMID indicates an expected call of ResponseMID. func (mr *MockNetworkEntityMockRecorder) ResponseMID(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResponseMID", reflect.TypeOf((*MockNetworkEntity)(nil).ResponseMID), varargs...) } -// SendRequest mocks base method +// SendRequest mocks base method. func (m *MockNetworkEntity) SendRequest(arg0 context.Context, arg1, arg2 string, arg3 interface{}) (*protos.Response, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SendRequest", arg0, arg1, arg2, arg3) @@ -119,7 +120,7 @@ func (m *MockNetworkEntity) SendRequest(arg0 context.Context, arg1, arg2 string, return ret0, ret1 } -// SendRequest indicates an expected call of SendRequest +// SendRequest indicates an expected call of SendRequest. func (mr *MockNetworkEntityMockRecorder) SendRequest(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendRequest", reflect.TypeOf((*MockNetworkEntity)(nil).SendRequest), arg0, arg1, arg2, arg3) diff --git a/serialize/mocks/serializer.go b/serialize/mocks/serializer.go index 9ebd4b4f..b1c69b75 100644 --- a/serialize/mocks/serializer.go +++ b/serialize/mocks/serializer.go @@ -5,34 +5,35 @@ package mocks import ( - gomock "github.com/golang/mock/gomock" reflect "reflect" + + gomock "github.com/golang/mock/gomock" ) -// MockSerializer is a mock of Serializer interface +// MockSerializer is a mock of Serializer interface. type MockSerializer struct { ctrl *gomock.Controller recorder *MockSerializerMockRecorder } -// MockSerializerMockRecorder is the mock recorder for MockSerializer +// MockSerializerMockRecorder is the mock recorder for MockSerializer. type MockSerializerMockRecorder struct { mock *MockSerializer } -// NewMockSerializer creates a new mock instance +// NewMockSerializer creates a new mock instance. func NewMockSerializer(ctrl *gomock.Controller) *MockSerializer { mock := &MockSerializer{ctrl: ctrl} mock.recorder = &MockSerializerMockRecorder{mock} return mock } -// EXPECT returns an object that allows the caller to indicate expected use +// EXPECT returns an object that allows the caller to indicate expected use. func (m *MockSerializer) EXPECT() *MockSerializerMockRecorder { return m.recorder } -// GetName mocks base method +// GetName mocks base method. func (m *MockSerializer) GetName() string { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetName") @@ -40,13 +41,13 @@ func (m *MockSerializer) GetName() string { return ret0 } -// GetName indicates an expected call of GetName +// GetName indicates an expected call of GetName. func (mr *MockSerializerMockRecorder) GetName() *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetName", reflect.TypeOf((*MockSerializer)(nil).GetName)) } -// Marshal mocks base method +// Marshal mocks base method. func (m *MockSerializer) Marshal(arg0 interface{}) ([]byte, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Marshal", arg0) @@ -55,13 +56,13 @@ func (m *MockSerializer) Marshal(arg0 interface{}) ([]byte, error) { return ret0, ret1 } -// Marshal indicates an expected call of Marshal +// Marshal indicates an expected call of Marshal. func (mr *MockSerializerMockRecorder) Marshal(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Marshal", reflect.TypeOf((*MockSerializer)(nil).Marshal), arg0) } -// Unmarshal mocks base method +// Unmarshal mocks base method. func (m *MockSerializer) Unmarshal(arg0 []byte, arg1 interface{}) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Unmarshal", arg0, arg1) @@ -69,7 +70,7 @@ func (m *MockSerializer) Unmarshal(arg0 []byte, arg1 interface{}) error { return ret0 } -// Unmarshal indicates an expected call of Unmarshal +// Unmarshal indicates an expected call of Unmarshal. func (mr *MockSerializerMockRecorder) Unmarshal(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unmarshal", reflect.TypeOf((*MockSerializer)(nil).Unmarshal), arg0, arg1) diff --git a/service/handler.go b/service/handler.go index c96d7ef5..108db657 100644 --- a/service/handler.go +++ b/service/handler.go @@ -213,23 +213,38 @@ func (h *HandlerService) processPacket(a agent.Agent, p *packet.Packet) error { switch p.Type { case packet.Handshake: logger.Log.Debug("Received handshake packet") + + // Parse the json sent with the handshake by the client + handshakeData := &session.HandshakeData{} + if err := json.Unmarshal(p.Data, handshakeData); err != nil { + logger.Log.Errorf("Failed to unmarshal handshake data: %s", err.Error()) + if serr := a.SendHandshakeErrorResponse(); serr != nil { + logger.Log.Errorf("Error sending handshake error response: %s", err.Error()) + return err + } + + return fmt.Errorf("invalid handshake data. Id=%d", a.GetSession().ID()) + } + + if err := a.GetSession().ValidateHandshake(handshakeData); err != nil { + logger.Log.Errorf("Handshake validation failed: %s", err.Error()) + if serr := a.SendHandshakeErrorResponse(); serr != nil { + logger.Log.Errorf("Error sending handshake error response: %s", err.Error()) + return err + } + + return fmt.Errorf("handshake validation failed: %w. SessionId=%d", err, a.GetSession().ID()) + } + if err := a.SendHandshakeResponse(); err != nil { logger.Log.Errorf("Error sending handshake response: %s", err.Error()) return err } logger.Log.Debugf("Session handshake Id=%d, Remote=%s", a.GetSession().ID(), a.RemoteAddr()) - // Parse the json sent with the handshake by the client - handshakeData := &session.HandshakeData{} - err := json.Unmarshal(p.Data, handshakeData) - if err != nil { - a.SetStatus(constants.StatusClosed) - return fmt.Errorf("Invalid handshake data. Id=%d", a.GetSession().ID()) - } - a.GetSession().SetHandshakeData(handshakeData) a.SetStatus(constants.StatusHandshake) - err = a.GetSession().Set(constants.IPVersionKey, a.IPVersion()) + err := a.GetSession().Set(constants.IPVersionKey, a.IPVersion()) if err != nil { logger.Log.Warnf("failed to save ip version on session: %q\n", err) } diff --git a/service/handler_test.go b/service/handler_test.go index ee050545..ba2d5a9d 100644 --- a/service/handler_test.go +++ b/service/handler_test.go @@ -252,10 +252,12 @@ func TestHandlerServiceProcessPacketHandshake(t *testing.T) { name string packet *packet.Packet socketStatus int32 + validator func(data *session.HandshakeData) error errStr string }{ - {"invalid_handshake_data", &packet.Packet{Type: packet.Handshake, Data: []byte("asiodjasd")}, constants.StatusClosed, "Invalid handshake data"}, - {"valid_handshake_data", &packet.Packet{Type: packet.Handshake, Data: []byte(`{"sys":{"platform":"mac"}}`)}, constants.StatusHandshake, ""}, + {"invalid_handshake_data", &packet.Packet{Type: packet.Handshake, Data: []byte("asiodjasd")}, constants.StatusClosed, nil, "invalid handshake data"}, + {"validator_error", &packet.Packet{Type: packet.Handshake, Data: []byte(`{"sys":{"platform":"mac"}}`)}, constants.StatusClosed, func(data *session.HandshakeData) error { return errors.New("validation failed") }, "handshake validation failed"}, + {"valid_handshake_data", &packet.Packet{Type: packet.Handshake, Data: []byte(`{"sys":{"platform":"mac"}}`)}, constants.StatusHandshake, func(data *session.HandshakeData) error { return nil }, ""}, } for _, table := range tables { t.Run(table.name, func(t *testing.T) { @@ -267,21 +269,28 @@ func TestHandlerServiceProcessPacketHandshake(t *testing.T) { mockAgent := agentmocks.NewMockAgent(ctrl) mockAgent.EXPECT().GetSession().Return(mockSession).Times(1) - mockAgent.EXPECT().RemoteAddr().Return(&mockAddr{}) - mockAgent.EXPECT().SetStatus(table.socketStatus).Times(1) - mockAgent.EXPECT().SendHandshakeResponse().Return(nil).Times(1) + + if table.validator != nil { + mockAgent.EXPECT().GetSession().Return(mockSession).Times(1) + mockSession.EXPECT().ValidateHandshake(gomock.Any()).DoAndReturn(func(data *session.HandshakeData) error { + return table.validator(data) + }).Times(1) + } if table.errStr == "" { handshakeData := &session.HandshakeData{} _ = encjson.Unmarshal(table.packet.Data, handshakeData) mockAgent.EXPECT().GetSession().Return(mockSession).Times(2) mockAgent.EXPECT().IPVersion().Return(constants.IPv4).Times(1) + mockAgent.EXPECT().RemoteAddr().Return(&mockAddr{}) + mockAgent.EXPECT().SetStatus(table.socketStatus).Times(1) + mockAgent.EXPECT().SendHandshakeResponse().Return(nil).Times(1) + mockAgent.EXPECT().SetLastAt().Times(1) + mockSession.EXPECT().SetHandshakeData(handshakeData).Times(1) mockSession.EXPECT().Set(constants.IPVersionKey, constants.IPv4).Times(1) - mockAgent.EXPECT().SetLastAt().Times(1) } else { - mockAgent.EXPECT().GetSession().Return(mockSession).Times(1) - mockSession.EXPECT().ID().Return(int64(1)).Times(1) + mockAgent.EXPECT().SendHandshakeErrorResponse().Times(1) } handlerPool := NewHandlerPool() @@ -409,6 +418,7 @@ func TestHandlerServiceHandle(t *testing.T) { mockSession := mocks.NewMockSession(ctrl) mockSession.EXPECT().SetHandshakeData(gomock.Any()).Times(1) + mockSession.EXPECT().ValidateHandshake(gomock.Any()).Times(1) mockSession.EXPECT().UID().Return("uid").Times(1) mockSession.EXPECT().ID().Return(int64(1)).Times(2) mockSession.EXPECT().Set(constants.IPVersionKey, constants.IPv4) @@ -416,7 +426,7 @@ func TestHandlerServiceHandle(t *testing.T) { mockAgent.EXPECT().String().Return("") mockAgent.EXPECT().SetStatus(constants.StatusHandshake) - mockAgent.EXPECT().GetSession().Return(mockSession).Times(6) + mockAgent.EXPECT().GetSession().Return(mockSession).Times(7) mockAgent.EXPECT().IPVersion().Return(constants.IPv4) mockAgent.EXPECT().RemoteAddr().Return(&mockAddr{}).AnyTimes() mockAgent.EXPECT().SetLastAt().Do(func() { diff --git a/session/mocks/session.go b/session/mocks/session.go index 4a35593a..a5b06b81 100644 --- a/session/mocks/session.go +++ b/session/mocks/session.go @@ -15,152 +15,6 @@ import ( session "github.com/topfreegames/pitaya/v2/session" ) -// MockSessionPool is a mock of SessionPool interface. -type MockSessionPool struct { - ctrl *gomock.Controller - recorder *MockSessionPoolMockRecorder -} - -// MockSessionPoolMockRecorder is the mock recorder for MockSessionPool. -type MockSessionPoolMockRecorder struct { - mock *MockSessionPool -} - -// NewMockSessionPool creates a new mock instance. -func NewMockSessionPool(ctrl *gomock.Controller) *MockSessionPool { - mock := &MockSessionPool{ctrl: ctrl} - mock.recorder = &MockSessionPoolMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSessionPool) EXPECT() *MockSessionPoolMockRecorder { - return m.recorder -} - -// CloseAll mocks base method. -func (m *MockSessionPool) CloseAll() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "CloseAll") -} - -// CloseAll indicates an expected call of CloseAll. -func (mr *MockSessionPoolMockRecorder) CloseAll() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseAll", reflect.TypeOf((*MockSessionPool)(nil).CloseAll)) -} - -// GetSessionByID mocks base method. -func (m *MockSessionPool) GetSessionByID(id int64) session.Session { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSessionByID", id) - ret0, _ := ret[0].(session.Session) - return ret0 -} - -// GetSessionByID indicates an expected call of GetSessionByID. -func (mr *MockSessionPoolMockRecorder) GetSessionByID(id interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionByID", reflect.TypeOf((*MockSessionPool)(nil).GetSessionByID), id) -} - -// GetSessionByUID mocks base method. -func (m *MockSessionPool) GetSessionByUID(uid string) session.Session { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSessionByUID", uid) - ret0, _ := ret[0].(session.Session) - return ret0 -} - -// GetSessionByUID indicates an expected call of GetSessionByUID. -func (mr *MockSessionPoolMockRecorder) GetSessionByUID(uid interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionByUID", reflect.TypeOf((*MockSessionPool)(nil).GetSessionByUID), uid) -} - -// GetSessionCloseCallbacks mocks base method. -func (m *MockSessionPool) GetSessionCloseCallbacks() []func(session.Session) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSessionCloseCallbacks") - ret0, _ := ret[0].([]func(session.Session)) - return ret0 -} - -// GetSessionCloseCallbacks indicates an expected call of GetSessionCloseCallbacks. -func (mr *MockSessionPoolMockRecorder) GetSessionCloseCallbacks() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionCloseCallbacks", reflect.TypeOf((*MockSessionPool)(nil).GetSessionCloseCallbacks)) -} - -// GetSessionCount mocks base method. -func (m *MockSessionPool) GetSessionCount() int64 { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSessionCount") - ret0, _ := ret[0].(int64) - return ret0 -} - -// GetSessionCount indicates an expected call of GetSessionCount. -func (mr *MockSessionPoolMockRecorder) GetSessionCount() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionCount", reflect.TypeOf((*MockSessionPool)(nil).GetSessionCount)) -} - -// NewSession mocks base method. -func (m *MockSessionPool) NewSession(entity networkentity.NetworkEntity, frontend bool, UID ...string) session.Session { - m.ctrl.T.Helper() - varargs := []interface{}{entity, frontend} - for _, a := range UID { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "NewSession", varargs...) - ret0, _ := ret[0].(session.Session) - return ret0 -} - -// NewSession indicates an expected call of NewSession. -func (mr *MockSessionPoolMockRecorder) NewSession(entity, frontend interface{}, UID ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{entity, frontend}, UID...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewSession", reflect.TypeOf((*MockSessionPool)(nil).NewSession), varargs...) -} - -// OnAfterSessionBind mocks base method. -func (m *MockSessionPool) OnAfterSessionBind(f func(context.Context, session.Session) error) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnAfterSessionBind", f) -} - -// OnAfterSessionBind indicates an expected call of OnAfterSessionBind. -func (mr *MockSessionPoolMockRecorder) OnAfterSessionBind(f interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnAfterSessionBind", reflect.TypeOf((*MockSessionPool)(nil).OnAfterSessionBind), f) -} - -// OnSessionBind mocks base method. -func (m *MockSessionPool) OnSessionBind(f func(context.Context, session.Session) error) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnSessionBind", f) -} - -// OnSessionBind indicates an expected call of OnSessionBind. -func (mr *MockSessionPoolMockRecorder) OnSessionBind(f interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnSessionBind", reflect.TypeOf((*MockSessionPool)(nil).OnSessionBind), f) -} - -// OnSessionClose mocks base method. -func (m *MockSessionPool) OnSessionClose(f func(session.Session)) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnSessionClose", f) -} - -// OnSessionClose indicates an expected call of OnSessionClose. -func (mr *MockSessionPoolMockRecorder) OnSessionClose(f interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnSessionClose", reflect.TypeOf((*MockSessionPool)(nil).OnSessionClose), f) -} - // MockSession is a mock of Session interface. type MockSession struct { ctrl *gomock.Controller @@ -185,17 +39,17 @@ func (m *MockSession) EXPECT() *MockSessionMockRecorder { } // Bind mocks base method. -func (m *MockSession) Bind(ctx context.Context, uid string) error { +func (m *MockSession) Bind(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Bind", ctx, uid) + ret := m.ctrl.Call(m, "Bind", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // Bind indicates an expected call of Bind. -func (mr *MockSessionMockRecorder) Bind(ctx, uid interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) Bind(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Bind", reflect.TypeOf((*MockSession)(nil).Bind), ctx, uid) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Bind", reflect.TypeOf((*MockSession)(nil).Bind), arg0, arg1) } // Clear mocks base method. @@ -223,45 +77,45 @@ func (mr *MockSessionMockRecorder) Close() *gomock.Call { } // Float32 mocks base method. -func (m *MockSession) Float32(key string) float32 { +func (m *MockSession) Float32(arg0 string) float32 { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Float32", key) + ret := m.ctrl.Call(m, "Float32", arg0) ret0, _ := ret[0].(float32) return ret0 } // Float32 indicates an expected call of Float32. -func (mr *MockSessionMockRecorder) Float32(key interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) Float32(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Float32", reflect.TypeOf((*MockSession)(nil).Float32), key) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Float32", reflect.TypeOf((*MockSession)(nil).Float32), arg0) } // Float64 mocks base method. -func (m *MockSession) Float64(key string) float64 { +func (m *MockSession) Float64(arg0 string) float64 { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Float64", key) + ret := m.ctrl.Call(m, "Float64", arg0) ret0, _ := ret[0].(float64) return ret0 } // Float64 indicates an expected call of Float64. -func (mr *MockSessionMockRecorder) Float64(key interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) Float64(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Float64", reflect.TypeOf((*MockSession)(nil).Float64), key) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Float64", reflect.TypeOf((*MockSession)(nil).Float64), arg0) } // Get mocks base method. -func (m *MockSession) Get(key string) interface{} { +func (m *MockSession) Get(arg0 string) interface{} { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Get", key) + ret := m.ctrl.Call(m, "Get", arg0) ret0, _ := ret[0].(interface{}) return ret0 } // Get indicates an expected call of Get. -func (mr *MockSessionMockRecorder) Get(key interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) Get(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockSession)(nil).Get), key) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockSession)(nil).Get), arg0) } // GetData mocks base method. @@ -306,6 +160,20 @@ func (mr *MockSessionMockRecorder) GetHandshakeData() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHandshakeData", reflect.TypeOf((*MockSession)(nil).GetHandshakeData)) } +// GetHandshakeValidators mocks base method. +func (m *MockSession) GetHandshakeValidators() map[string]func(*session.HandshakeData) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetHandshakeValidators") + ret0, _ := ret[0].(map[string]func(*session.HandshakeData) error) + return ret0 +} + +// GetHandshakeValidators indicates an expected call of GetHandshakeValidators. +func (mr *MockSessionMockRecorder) GetHandshakeValidators() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHandshakeValidators", reflect.TypeOf((*MockSession)(nil).GetHandshakeValidators)) +} + // GetIsFrontend mocks base method. func (m *MockSession) GetIsFrontend() bool { m.ctrl.T.Helper() @@ -363,17 +231,17 @@ func (mr *MockSessionMockRecorder) GetSubscriptions() *gomock.Call { } // HasKey mocks base method. -func (m *MockSession) HasKey(key string) bool { +func (m *MockSession) HasKey(arg0 string) bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HasKey", key) + ret := m.ctrl.Call(m, "HasKey", arg0) ret0, _ := ret[0].(bool) return ret0 } // HasKey indicates an expected call of HasKey. -func (mr *MockSessionMockRecorder) HasKey(key interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) HasKey(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasKey", reflect.TypeOf((*MockSession)(nil).HasKey), key) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasKey", reflect.TypeOf((*MockSession)(nil).HasKey), arg0) } // HasRequestsInFlight mocks base method. @@ -405,129 +273,129 @@ func (mr *MockSessionMockRecorder) ID() *gomock.Call { } // Int mocks base method. -func (m *MockSession) Int(key string) int { +func (m *MockSession) Int(arg0 string) int { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Int", key) + ret := m.ctrl.Call(m, "Int", arg0) ret0, _ := ret[0].(int) return ret0 } // Int indicates an expected call of Int. -func (mr *MockSessionMockRecorder) Int(key interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) Int(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Int", reflect.TypeOf((*MockSession)(nil).Int), key) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Int", reflect.TypeOf((*MockSession)(nil).Int), arg0) } // Int16 mocks base method. -func (m *MockSession) Int16(key string) int16 { +func (m *MockSession) Int16(arg0 string) int16 { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Int16", key) + ret := m.ctrl.Call(m, "Int16", arg0) ret0, _ := ret[0].(int16) return ret0 } // Int16 indicates an expected call of Int16. -func (mr *MockSessionMockRecorder) Int16(key interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) Int16(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Int16", reflect.TypeOf((*MockSession)(nil).Int16), key) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Int16", reflect.TypeOf((*MockSession)(nil).Int16), arg0) } // Int32 mocks base method. -func (m *MockSession) Int32(key string) int32 { +func (m *MockSession) Int32(arg0 string) int32 { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Int32", key) + ret := m.ctrl.Call(m, "Int32", arg0) ret0, _ := ret[0].(int32) return ret0 } // Int32 indicates an expected call of Int32. -func (mr *MockSessionMockRecorder) Int32(key interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) Int32(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Int32", reflect.TypeOf((*MockSession)(nil).Int32), key) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Int32", reflect.TypeOf((*MockSession)(nil).Int32), arg0) } // Int64 mocks base method. -func (m *MockSession) Int64(key string) int64 { +func (m *MockSession) Int64(arg0 string) int64 { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Int64", key) + ret := m.ctrl.Call(m, "Int64", arg0) ret0, _ := ret[0].(int64) return ret0 } // Int64 indicates an expected call of Int64. -func (mr *MockSessionMockRecorder) Int64(key interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) Int64(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Int64", reflect.TypeOf((*MockSession)(nil).Int64), key) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Int64", reflect.TypeOf((*MockSession)(nil).Int64), arg0) } // Int8 mocks base method. -func (m *MockSession) Int8(key string) int8 { +func (m *MockSession) Int8(arg0 string) int8 { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Int8", key) + ret := m.ctrl.Call(m, "Int8", arg0) ret0, _ := ret[0].(int8) return ret0 } // Int8 indicates an expected call of Int8. -func (mr *MockSessionMockRecorder) Int8(key interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) Int8(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Int8", reflect.TypeOf((*MockSession)(nil).Int8), key) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Int8", reflect.TypeOf((*MockSession)(nil).Int8), arg0) } // Kick mocks base method. -func (m *MockSession) Kick(ctx context.Context) error { +func (m *MockSession) Kick(arg0 context.Context) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Kick", ctx) + ret := m.ctrl.Call(m, "Kick", arg0) ret0, _ := ret[0].(error) return ret0 } // Kick indicates an expected call of Kick. -func (mr *MockSessionMockRecorder) Kick(ctx interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) Kick(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Kick", reflect.TypeOf((*MockSession)(nil).Kick), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Kick", reflect.TypeOf((*MockSession)(nil).Kick), arg0) } // OnClose mocks base method. -func (m *MockSession) OnClose(c func()) error { +func (m *MockSession) OnClose(arg0 func()) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OnClose", c) + ret := m.ctrl.Call(m, "OnClose", arg0) ret0, _ := ret[0].(error) return ret0 } // OnClose indicates an expected call of OnClose. -func (mr *MockSessionMockRecorder) OnClose(c interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) OnClose(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnClose", reflect.TypeOf((*MockSession)(nil).OnClose), c) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnClose", reflect.TypeOf((*MockSession)(nil).OnClose), arg0) } // Push mocks base method. -func (m *MockSession) Push(route string, v interface{}) error { +func (m *MockSession) Push(arg0 string, arg1 interface{}) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Push", route, v) + ret := m.ctrl.Call(m, "Push", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // Push indicates an expected call of Push. -func (mr *MockSessionMockRecorder) Push(route, v interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) Push(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Push", reflect.TypeOf((*MockSession)(nil).Push), route, v) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Push", reflect.TypeOf((*MockSession)(nil).Push), arg0, arg1) } // PushToFront mocks base method. -func (m *MockSession) PushToFront(ctx context.Context) error { +func (m *MockSession) PushToFront(arg0 context.Context) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "PushToFront", ctx) + ret := m.ctrl.Call(m, "PushToFront", arg0) ret0, _ := ret[0].(error) return ret0 } // PushToFront indicates an expected call of PushToFront. -func (mr *MockSessionMockRecorder) PushToFront(ctx interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) PushToFront(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PushToFront", reflect.TypeOf((*MockSession)(nil).PushToFront), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PushToFront", reflect.TypeOf((*MockSession)(nil).PushToFront), arg0) } // RemoteAddr mocks base method. @@ -545,24 +413,24 @@ func (mr *MockSessionMockRecorder) RemoteAddr() *gomock.Call { } // Remove mocks base method. -func (m *MockSession) Remove(key string) error { +func (m *MockSession) Remove(arg0 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Remove", key) + ret := m.ctrl.Call(m, "Remove", arg0) ret0, _ := ret[0].(error) return ret0 } // Remove indicates an expected call of Remove. -func (mr *MockSessionMockRecorder) Remove(key interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) Remove(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockSession)(nil).Remove), key) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockSession)(nil).Remove), arg0) } // ResponseMID mocks base method. -func (m *MockSession) ResponseMID(ctx context.Context, mid uint, v interface{}, err ...bool) error { +func (m *MockSession) ResponseMID(arg0 context.Context, arg1 uint, arg2 interface{}, arg3 ...bool) error { m.ctrl.T.Helper() - varargs := []interface{}{ctx, mid, v} - for _, a := range err { + varargs := []interface{}{arg0, arg1, arg2} + for _, a := range arg3 { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "ResponseMID", varargs...) @@ -571,138 +439,138 @@ func (m *MockSession) ResponseMID(ctx context.Context, mid uint, v interface{}, } // ResponseMID indicates an expected call of ResponseMID. -func (mr *MockSessionMockRecorder) ResponseMID(ctx, mid, v interface{}, err ...interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) ResponseMID(arg0, arg1, arg2 interface{}, arg3 ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{ctx, mid, v}, err...) + varargs := append([]interface{}{arg0, arg1, arg2}, arg3...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResponseMID", reflect.TypeOf((*MockSession)(nil).ResponseMID), varargs...) } // Set mocks base method. -func (m *MockSession) Set(key string, value interface{}) error { +func (m *MockSession) Set(arg0 string, arg1 interface{}) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Set", key, value) + ret := m.ctrl.Call(m, "Set", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // Set indicates an expected call of Set. -func (mr *MockSessionMockRecorder) Set(key, value interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) Set(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockSession)(nil).Set), key, value) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockSession)(nil).Set), arg0, arg1) } // SetData mocks base method. -func (m *MockSession) SetData(data map[string]interface{}) error { +func (m *MockSession) SetData(arg0 map[string]interface{}) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetData", data) + ret := m.ctrl.Call(m, "SetData", arg0) ret0, _ := ret[0].(error) return ret0 } // SetData indicates an expected call of SetData. -func (mr *MockSessionMockRecorder) SetData(data interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) SetData(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetData", reflect.TypeOf((*MockSession)(nil).SetData), data) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetData", reflect.TypeOf((*MockSession)(nil).SetData), arg0) } // SetDataEncoded mocks base method. -func (m *MockSession) SetDataEncoded(encodedData []byte) error { +func (m *MockSession) SetDataEncoded(arg0 []byte) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetDataEncoded", encodedData) + ret := m.ctrl.Call(m, "SetDataEncoded", arg0) ret0, _ := ret[0].(error) return ret0 } // SetDataEncoded indicates an expected call of SetDataEncoded. -func (mr *MockSessionMockRecorder) SetDataEncoded(encodedData interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) SetDataEncoded(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDataEncoded", reflect.TypeOf((*MockSession)(nil).SetDataEncoded), encodedData) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDataEncoded", reflect.TypeOf((*MockSession)(nil).SetDataEncoded), arg0) } // SetFrontendData mocks base method. -func (m *MockSession) SetFrontendData(frontendID string, frontendSessionID int64) { +func (m *MockSession) SetFrontendData(arg0 string, arg1 int64) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SetFrontendData", frontendID, frontendSessionID) + m.ctrl.Call(m, "SetFrontendData", arg0, arg1) } // SetFrontendData indicates an expected call of SetFrontendData. -func (mr *MockSessionMockRecorder) SetFrontendData(frontendID, frontendSessionID interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) SetFrontendData(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFrontendData", reflect.TypeOf((*MockSession)(nil).SetFrontendData), frontendID, frontendSessionID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFrontendData", reflect.TypeOf((*MockSession)(nil).SetFrontendData), arg0, arg1) } // SetHandshakeData mocks base method. -func (m *MockSession) SetHandshakeData(data *session.HandshakeData) { +func (m *MockSession) SetHandshakeData(arg0 *session.HandshakeData) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SetHandshakeData", data) + m.ctrl.Call(m, "SetHandshakeData", arg0) } // SetHandshakeData indicates an expected call of SetHandshakeData. -func (mr *MockSessionMockRecorder) SetHandshakeData(data interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) SetHandshakeData(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeData", reflect.TypeOf((*MockSession)(nil).SetHandshakeData), data) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHandshakeData", reflect.TypeOf((*MockSession)(nil).SetHandshakeData), arg0) } // SetIsFrontend mocks base method. -func (m *MockSession) SetIsFrontend(isFrontend bool) { +func (m *MockSession) SetIsFrontend(arg0 bool) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SetIsFrontend", isFrontend) + m.ctrl.Call(m, "SetIsFrontend", arg0) } // SetIsFrontend indicates an expected call of SetIsFrontend. -func (mr *MockSessionMockRecorder) SetIsFrontend(isFrontend interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) SetIsFrontend(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetIsFrontend", reflect.TypeOf((*MockSession)(nil).SetIsFrontend), isFrontend) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetIsFrontend", reflect.TypeOf((*MockSession)(nil).SetIsFrontend), arg0) } // SetOnCloseCallbacks mocks base method. -func (m *MockSession) SetOnCloseCallbacks(callbacks []func()) { +func (m *MockSession) SetOnCloseCallbacks(arg0 []func()) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SetOnCloseCallbacks", callbacks) + m.ctrl.Call(m, "SetOnCloseCallbacks", arg0) } // SetOnCloseCallbacks indicates an expected call of SetOnCloseCallbacks. -func (mr *MockSessionMockRecorder) SetOnCloseCallbacks(callbacks interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) SetOnCloseCallbacks(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetOnCloseCallbacks", reflect.TypeOf((*MockSession)(nil).SetOnCloseCallbacks), callbacks) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetOnCloseCallbacks", reflect.TypeOf((*MockSession)(nil).SetOnCloseCallbacks), arg0) } // SetRequestInFlight mocks base method. -func (m *MockSession) SetRequestInFlight(reqID, reqData string, inFlight bool) { +func (m *MockSession) SetRequestInFlight(arg0, arg1 string, arg2 bool) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SetRequestInFlight", reqID, reqData, inFlight) + m.ctrl.Call(m, "SetRequestInFlight", arg0, arg1, arg2) } // SetRequestInFlight indicates an expected call of SetRequestInFlight. -func (mr *MockSessionMockRecorder) SetRequestInFlight(reqID, reqData, inFlight interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) SetRequestInFlight(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetRequestInFlight", reflect.TypeOf((*MockSession)(nil).SetRequestInFlight), reqID, reqData, inFlight) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetRequestInFlight", reflect.TypeOf((*MockSession)(nil).SetRequestInFlight), arg0, arg1, arg2) } // SetSubscriptions mocks base method. -func (m *MockSession) SetSubscriptions(subscriptions []*nats.Subscription) { +func (m *MockSession) SetSubscriptions(arg0 []*nats.Subscription) { m.ctrl.T.Helper() - m.ctrl.Call(m, "SetSubscriptions", subscriptions) + m.ctrl.Call(m, "SetSubscriptions", arg0) } // SetSubscriptions indicates an expected call of SetSubscriptions. -func (mr *MockSessionMockRecorder) SetSubscriptions(subscriptions interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) SetSubscriptions(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetSubscriptions", reflect.TypeOf((*MockSession)(nil).SetSubscriptions), subscriptions) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetSubscriptions", reflect.TypeOf((*MockSession)(nil).SetSubscriptions), arg0) } // String mocks base method. -func (m *MockSession) String(key string) string { +func (m *MockSession) String(arg0 string) string { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "String", key) + ret := m.ctrl.Call(m, "String", arg0) ret0, _ := ret[0].(string) return ret0 } // String indicates an expected call of String. -func (mr *MockSessionMockRecorder) String(key interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) String(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "String", reflect.TypeOf((*MockSession)(nil).String), key) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "String", reflect.TypeOf((*MockSession)(nil).String), arg0) } // UID mocks base method. @@ -720,85 +588,257 @@ func (mr *MockSessionMockRecorder) UID() *gomock.Call { } // Uint mocks base method. -func (m *MockSession) Uint(key string) uint { +func (m *MockSession) Uint(arg0 string) uint { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Uint", key) + ret := m.ctrl.Call(m, "Uint", arg0) ret0, _ := ret[0].(uint) return ret0 } // Uint indicates an expected call of Uint. -func (mr *MockSessionMockRecorder) Uint(key interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) Uint(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Uint", reflect.TypeOf((*MockSession)(nil).Uint), key) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Uint", reflect.TypeOf((*MockSession)(nil).Uint), arg0) } // Uint16 mocks base method. -func (m *MockSession) Uint16(key string) uint16 { +func (m *MockSession) Uint16(arg0 string) uint16 { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Uint16", key) + ret := m.ctrl.Call(m, "Uint16", arg0) ret0, _ := ret[0].(uint16) return ret0 } // Uint16 indicates an expected call of Uint16. -func (mr *MockSessionMockRecorder) Uint16(key interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) Uint16(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Uint16", reflect.TypeOf((*MockSession)(nil).Uint16), key) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Uint16", reflect.TypeOf((*MockSession)(nil).Uint16), arg0) } // Uint32 mocks base method. -func (m *MockSession) Uint32(key string) uint32 { +func (m *MockSession) Uint32(arg0 string) uint32 { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Uint32", key) + ret := m.ctrl.Call(m, "Uint32", arg0) ret0, _ := ret[0].(uint32) return ret0 } // Uint32 indicates an expected call of Uint32. -func (mr *MockSessionMockRecorder) Uint32(key interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) Uint32(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Uint32", reflect.TypeOf((*MockSession)(nil).Uint32), key) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Uint32", reflect.TypeOf((*MockSession)(nil).Uint32), arg0) } // Uint64 mocks base method. -func (m *MockSession) Uint64(key string) uint64 { +func (m *MockSession) Uint64(arg0 string) uint64 { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Uint64", key) + ret := m.ctrl.Call(m, "Uint64", arg0) ret0, _ := ret[0].(uint64) return ret0 } // Uint64 indicates an expected call of Uint64. -func (mr *MockSessionMockRecorder) Uint64(key interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) Uint64(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Uint64", reflect.TypeOf((*MockSession)(nil).Uint64), key) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Uint64", reflect.TypeOf((*MockSession)(nil).Uint64), arg0) } // Uint8 mocks base method. -func (m *MockSession) Uint8(key string) uint8 { +func (m *MockSession) Uint8(arg0 string) byte { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Uint8", key) - ret0, _ := ret[0].(uint8) + ret := m.ctrl.Call(m, "Uint8", arg0) + ret0, _ := ret[0].(byte) return ret0 } // Uint8 indicates an expected call of Uint8. -func (mr *MockSessionMockRecorder) Uint8(key interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) Uint8(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Uint8", reflect.TypeOf((*MockSession)(nil).Uint8), key) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Uint8", reflect.TypeOf((*MockSession)(nil).Uint8), arg0) +} + +// ValidateHandshake mocks base method. +func (m *MockSession) ValidateHandshake(arg0 *session.HandshakeData) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidateHandshake", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// ValidateHandshake indicates an expected call of ValidateHandshake. +func (mr *MockSessionMockRecorder) ValidateHandshake(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateHandshake", reflect.TypeOf((*MockSession)(nil).ValidateHandshake), arg0) } // Value mocks base method. -func (m *MockSession) Value(key string) interface{} { +func (m *MockSession) Value(arg0 string) interface{} { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Value", key) + ret := m.ctrl.Call(m, "Value", arg0) ret0, _ := ret[0].(interface{}) return ret0 } // Value indicates an expected call of Value. -func (mr *MockSessionMockRecorder) Value(key interface{}) *gomock.Call { +func (mr *MockSessionMockRecorder) Value(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Value", reflect.TypeOf((*MockSession)(nil).Value), arg0) +} + +// MockSessionPool is a mock of SessionPool interface. +type MockSessionPool struct { + ctrl *gomock.Controller + recorder *MockSessionPoolMockRecorder +} + +// MockSessionPoolMockRecorder is the mock recorder for MockSessionPool. +type MockSessionPoolMockRecorder struct { + mock *MockSessionPool +} + +// NewMockSessionPool creates a new mock instance. +func NewMockSessionPool(ctrl *gomock.Controller) *MockSessionPool { + mock := &MockSessionPool{ctrl: ctrl} + mock.recorder = &MockSessionPoolMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSessionPool) EXPECT() *MockSessionPoolMockRecorder { + return m.recorder +} + +// AddHandshakeValidator mocks base method. +func (m *MockSessionPool) AddHandshakeValidator(arg0 string, arg1 func(*session.HandshakeData) error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AddHandshakeValidator", arg0, arg1) +} + +// AddHandshakeValidator indicates an expected call of AddHandshakeValidator. +func (mr *MockSessionPoolMockRecorder) AddHandshakeValidator(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddHandshakeValidator", reflect.TypeOf((*MockSessionPool)(nil).AddHandshakeValidator), arg0, arg1) +} + +// CloseAll mocks base method. +func (m *MockSessionPool) CloseAll() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "CloseAll") +} + +// CloseAll indicates an expected call of CloseAll. +func (mr *MockSessionPoolMockRecorder) CloseAll() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseAll", reflect.TypeOf((*MockSessionPool)(nil).CloseAll)) +} + +// GetSessionByID mocks base method. +func (m *MockSessionPool) GetSessionByID(arg0 int64) session.Session { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSessionByID", arg0) + ret0, _ := ret[0].(session.Session) + return ret0 +} + +// GetSessionByID indicates an expected call of GetSessionByID. +func (mr *MockSessionPoolMockRecorder) GetSessionByID(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionByID", reflect.TypeOf((*MockSessionPool)(nil).GetSessionByID), arg0) +} + +// GetSessionByUID mocks base method. +func (m *MockSessionPool) GetSessionByUID(arg0 string) session.Session { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSessionByUID", arg0) + ret0, _ := ret[0].(session.Session) + return ret0 +} + +// GetSessionByUID indicates an expected call of GetSessionByUID. +func (mr *MockSessionPoolMockRecorder) GetSessionByUID(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionByUID", reflect.TypeOf((*MockSessionPool)(nil).GetSessionByUID), arg0) +} + +// GetSessionCloseCallbacks mocks base method. +func (m *MockSessionPool) GetSessionCloseCallbacks() []func(session.Session) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSessionCloseCallbacks") + ret0, _ := ret[0].([]func(session.Session)) + return ret0 +} + +// GetSessionCloseCallbacks indicates an expected call of GetSessionCloseCallbacks. +func (mr *MockSessionPoolMockRecorder) GetSessionCloseCallbacks() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionCloseCallbacks", reflect.TypeOf((*MockSessionPool)(nil).GetSessionCloseCallbacks)) +} + +// GetSessionCount mocks base method. +func (m *MockSessionPool) GetSessionCount() int64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSessionCount") + ret0, _ := ret[0].(int64) + return ret0 +} + +// GetSessionCount indicates an expected call of GetSessionCount. +func (mr *MockSessionPoolMockRecorder) GetSessionCount() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSessionCount", reflect.TypeOf((*MockSessionPool)(nil).GetSessionCount)) +} + +// NewSession mocks base method. +func (m *MockSessionPool) NewSession(arg0 networkentity.NetworkEntity, arg1 bool, arg2 ...string) session.Session { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "NewSession", varargs...) + ret0, _ := ret[0].(session.Session) + return ret0 +} + +// NewSession indicates an expected call of NewSession. +func (mr *MockSessionPoolMockRecorder) NewSession(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewSession", reflect.TypeOf((*MockSessionPool)(nil).NewSession), varargs...) +} + +// OnAfterSessionBind mocks base method. +func (m *MockSessionPool) OnAfterSessionBind(arg0 func(context.Context, session.Session) error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnAfterSessionBind", arg0) +} + +// OnAfterSessionBind indicates an expected call of OnAfterSessionBind. +func (mr *MockSessionPoolMockRecorder) OnAfterSessionBind(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnAfterSessionBind", reflect.TypeOf((*MockSessionPool)(nil).OnAfterSessionBind), arg0) +} + +// OnSessionBind mocks base method. +func (m *MockSessionPool) OnSessionBind(arg0 func(context.Context, session.Session) error) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnSessionBind", arg0) +} + +// OnSessionBind indicates an expected call of OnSessionBind. +func (mr *MockSessionPoolMockRecorder) OnSessionBind(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnSessionBind", reflect.TypeOf((*MockSessionPool)(nil).OnSessionBind), arg0) +} + +// OnSessionClose mocks base method. +func (m *MockSessionPool) OnSessionClose(arg0 func(session.Session)) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnSessionClose", arg0) +} + +// OnSessionClose indicates an expected call of OnSessionClose. +func (mr *MockSessionPoolMockRecorder) OnSessionClose(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Value", reflect.TypeOf((*MockSession)(nil).Value), key) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnSessionClose", reflect.TypeOf((*MockSessionPool)(nil).OnSessionClose), arg0) } diff --git a/session/session.go b/session/session.go index cd567dc0..60307cf9 100644 --- a/session/session.go +++ b/session/session.go @@ -23,6 +23,7 @@ package session import ( "context" "encoding/json" + "fmt" "net" "reflect" "sync" @@ -40,6 +41,8 @@ import ( type sessionPoolImpl struct { sessionBindCallbacks []func(ctx context.Context, s Session) error afterBindCallbacks []func(ctx context.Context, s Session) error + handshakeValidators map[string]func(data *HandshakeData) error + // SessionCloseCallbacks contains global session close callbacks SessionCloseCallbacks []func(s Session) sessionsByUID sync.Map @@ -60,6 +63,7 @@ type SessionPool interface { OnAfterSessionBind(f func(ctx context.Context, s Session) error) OnSessionClose(f func(s Session)) CloseAll() + AddHandshakeValidator(name string, f func(data *HandshakeData) error) } // HandshakeClientData represents information about the client sent on the handshake. @@ -79,25 +83,26 @@ type HandshakeData struct { } type sessionImpl struct { - sync.RWMutex // protect data - id int64 // session global unique id - uid string // binding user id - lastTime int64 // last heartbeat time - entity networkentity.NetworkEntity // low-level network entity - data map[string]interface{} // session data store - handshakeData *HandshakeData // handshake data received by the client - encodedData []byte // session data encoded as a byte array - OnCloseCallbacks []func() //onClose callbacks - IsFrontend bool // if session is a frontend session - frontendID string // the id of the frontend that owns the session - frontendSessionID int64 // the id of the session on the frontend server - Subscriptions []*nats.Subscription // subscription created on bind when using nats rpc server - requestsInFlight ReqInFlight // whether the session is waiting from a response from a remote - pool *sessionPoolImpl + sync.RWMutex // protect data + id int64 // session global unique id + uid string // binding user id + lastTime int64 // last heartbeat time + entity networkentity.NetworkEntity // low-level network entity + data map[string]interface{} // session data store + handshakeData *HandshakeData // handshake data received by the client + handshakeValidators map[string]func(*HandshakeData) error // validations to run on handshake + encodedData []byte // session data encoded as a byte array + OnCloseCallbacks []func() //onClose callbacks + IsFrontend bool // if session is a frontend session + frontendID string // the id of the frontend that owns the session + frontendSessionID int64 // the id of the session on the frontend server + Subscriptions []*nats.Subscription // subscription created on bind when using nats rpc server + requestsInFlight ReqInFlight // whether the session is waiting from a response from a remote + pool *sessionPoolImpl } type ReqInFlight struct { - m map[string]string + m map[string]string mu sync.RWMutex } @@ -152,6 +157,8 @@ type Session interface { Clear() SetHandshakeData(data *HandshakeData) GetHandshakeData() *HandshakeData + ValidateHandshake(data *HandshakeData) error + GetHandshakeValidators() map[string]func(data *HandshakeData) error } type sessionIDService struct { @@ -173,15 +180,16 @@ func (c *sessionIDService) sessionID() int64 { // a networkentity.NetworkEntity is a low-level network instance func (pool *sessionPoolImpl) NewSession(entity networkentity.NetworkEntity, frontend bool, UID ...string) Session { s := &sessionImpl{ - id: pool.sessionIDSvc.sessionID(), - entity: entity, - data: make(map[string]interface{}), - handshakeData: nil, - lastTime: time.Now().Unix(), - OnCloseCallbacks: []func(){}, - IsFrontend: frontend, - pool: pool, - requestsInFlight: ReqInFlight{m: make(map[string]string)}, + id: pool.sessionIDSvc.sessionID(), + entity: entity, + data: make(map[string]interface{}), + handshakeData: nil, + handshakeValidators: pool.handshakeValidators, + lastTime: time.Now().Unix(), + OnCloseCallbacks: []func(){}, + IsFrontend: frontend, + pool: pool, + requestsInFlight: ReqInFlight{m: make(map[string]string)}, } if frontend { pool.sessionsByID.Store(s.id, s) @@ -198,6 +206,7 @@ func NewSessionPool() SessionPool { return &sessionPoolImpl{ sessionBindCallbacks: make([]func(ctx context.Context, s Session) error, 0), afterBindCallbacks: make([]func(ctx context.Context, s Session) error, 0), + handshakeValidators: make(map[string]func(data *HandshakeData) error, 0), SessionCloseCallbacks: make([]func(s Session), 0), sessionIDSvc: newSessionIDService(), } @@ -277,7 +286,7 @@ func (pool *sessionPoolImpl) CloseAll() { if s.HasRequestsInFlight() { reqsInFlight := s.GetRequestsInFlight() reqsInFlight.mu.RLock() - for _,route := range reqsInFlight.m { + for _, route := range reqsInFlight.m { logger.Log.Debugf("Session for user %s is waiting on a response for route %s from a remote server. Delaying session close.", s.UID(), route) } reqsInFlight.mu.RUnlock() @@ -295,6 +304,12 @@ func (pool *sessionPoolImpl) CloseAll() { logger.Log.Info("finished closing sessions") } +// AddHandshakeValidator allows adds validation functions that will run when +// handshake packets are processed. Errors will be raised with the given name. +func (pool *sessionPoolImpl) AddHandshakeValidator(name string, f func(data *HandshakeData) error) { + pool.handshakeValidators[name] = f +} + func (s *sessionImpl) updateEncodedData() error { var b []byte b, err := json.Marshal(s.data) @@ -788,6 +803,21 @@ func (s *sessionImpl) GetHandshakeData() *HandshakeData { return s.handshakeData } +// GetHandshakeValidators return the handshake validators associated with the session. +func (s *sessionImpl) GetHandshakeValidators() map[string]func(data *HandshakeData) error { + return s.handshakeValidators +} + +func (s *sessionImpl) ValidateHandshake(data *HandshakeData) error { + for name, fun := range s.handshakeValidators { + if err := fun(data); err != nil { + return fmt.Errorf("failed to run '%s' validator: %w. SessionId=%d", name, err, s.ID()) + } + } + + return nil +} + func (s *sessionImpl) sendRequestToFront(ctx context.Context, route string, includeData bool) error { sessionData := &protos.Session{ Id: s.frontendSessionID, diff --git a/session/session_test.go b/session/session_test.go index b9ef4f50..6d288317 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -1448,3 +1448,62 @@ func TestSessionSetHandshakeData(t *testing.T) { }) } } + +func TestSessionPoolAddHandshakeValidator(t *testing.T) { + fa := func(data *HandshakeData) error { return nil } + fb := func(data *HandshakeData) error { return errors.New("error") } + + tables := []struct { + name string + validators map[string]func(data *HandshakeData) error + result int + }{ + {"add validator", map[string]func(data *HandshakeData) error{"fa": fa}, 1}, + {"add many validators", map[string]func(data *HandshakeData) error{"fa": fa, "fb": fb}, 2}, + } + for _, table := range tables { + t.Run(table.name, func(t *testing.T) { + sessionPool := NewSessionPool() + for name, fun := range table.validators { + sessionPool.AddHandshakeValidator(name, fun) + } + session := sessionPool.NewSession(nil, false).(*sessionImpl) + validators := session.GetHandshakeValidators() + assert.Equal(t, len(validators), table.result) + }) + } +} + +func TestSessionValidateHandshake(t *testing.T) { + fa := func(data *HandshakeData) error { return nil } + fb := func(data *HandshakeData) error { return errors.New("error") } + + tables := []struct { + name string + validators map[string]func(data *HandshakeData) error + errStr string + }{ + {"one passing validator", map[string]func(data *HandshakeData) error{"fa": fa}, ""}, + {"one failing validator", map[string]func(data *HandshakeData) error{"fb": fb}, "failed to run 'fb'"}, + {"many validators all passing", map[string]func(data *HandshakeData) error{"fa": fa, "fb": fa}, ""}, + {"many validators one failing", map[string]func(data *HandshakeData) error{"fa": fa, "fb": fb}, "failed to run 'fb'"}, + } + + for _, table := range tables { + t.Run(table.name, func(t *testing.T) { + sessionPool := NewSessionPool() + for name, fun := range table.validators { + sessionPool.AddHandshakeValidator(name, fun) + } + session := sessionPool.NewSession(nil, false).(*sessionImpl) + err := session.ValidateHandshake(nil) + + if table.errStr != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), table.errStr) + } else { + assert.NoError(t, err) + } + }) + } +}