Skip to content

Commit

Permalink
Adding name to the handshake validators
Browse files Browse the repository at this point in the history
  • Loading branch information
reinaldooli committed May 22, 2023
1 parent 911b018 commit b3d8ff5
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 77 deletions.
66 changes: 54 additions & 12 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ var (
// hbd contains the heartbeat packet data
hbd []byte
// hrd contains the handshake response data
hrd []byte
hrd []byte
// herd contains the handshake error response data
herd []byte
once sync.Once
)

Expand Down Expand Up @@ -110,6 +112,7 @@ type (
Handle()
IPVersion() string
SendHandshakeResponse() error
SendHandshakeErrorResponse() error
SendRequest(ctx context.Context, serverID, route string, v interface{}) (*protos.Response, error)
AnswerWithError(ctx context.Context, mid uint, err error)
}
Expand Down Expand Up @@ -180,6 +183,7 @@ func newAgent(

once.Do(func() {
hbdEncode(heartbeatTime, packetEncoder, messageEncoder.IsCompressionEnabled(), serializerName)
herdEncode(heartbeatTime, packetEncoder, messageEncoder.IsCompressionEnabled(), serializerName)
})

a := &agentImpl{
Expand Down Expand Up @@ -475,6 +479,14 @@ func (a *agentImpl) onSessionClosed(s session.Session) {
// SendHandshakeResponse sends a handshake response
func (a *agentImpl) SendHandshakeResponse() error {
_, err := a.conn.Write(hrd)

return err
}

func (a *agentImpl) SendHandshakeErrorResponse() error {
a.SetStatus(constants.StatusClosed)
_, err := a.conn.Write(herd)

return err
}

Expand Down Expand Up @@ -543,33 +555,63 @@ func hbdEncode(heartbeatTimeout time.Duration, packetEncoder codec.PacketEncoder
"serializer": serializerName,
},
}
data, err := gojson.Marshal(hData)

data, err := encodeAndCompress(hData, dataCompression)
if err != nil {
panic(err)
}

if dataCompression {
compressedData, err := compression.DeflateData(data)
if err != nil {
panic(err)
}
hrd, err = packetEncoder.Encode(packet.Handshake, data)
if err != nil {
panic(err)
}

if len(compressedData) < len(data) {
data = compressedData
}
hbd, err = packetEncoder.Encode(packet.Heartbeat, nil)
if err != nil {
panic(err)
}
}

hrd, err = packetEncoder.Encode(packet.Handshake, data)
func herdEncode(heartbeatTimeout time.Duration, packetEncoder codec.PacketEncoder, dataCompression bool, serializerName string) {
hErrData := map[string]interface{}{
"code": 400,
"sys": map[string]interface{}{
"heartbeat": heartbeatTimeout.Seconds(),
"dict": message.GetDictionary(),
"serializer": serializerName,
},
}

errData, err := encodeAndCompress(hErrData, dataCompression)
if err != nil {
panic(err)
}

hbd, err = packetEncoder.Encode(packet.Heartbeat, nil)
herd, err = packetEncoder.Encode(packet.Handshake, errData)
if err != nil {
panic(err)
}
}

func encodeAndCompress(data interface{}, dataCompression bool) ([]byte, error) {
encData, err := gojson.Marshal(data)
if err != nil {
return nil, err
}

if dataCompression {
compressedData, err := compression.DeflateData(encData)
if err != nil {
return nil, err
}

if len(compressedData) < len(encData) {
encData = compressedData
}
}
return encData, nil
}

func (a *agentImpl) reportChannelSize() {
chSendCapacity := a.messagesBufferSize - len(a.chSend)
if chSendCapacity == 0 {
Expand Down
2 changes: 1 addition & 1 deletion agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func TestNewAgent(t *testing.T) {
func(typ packet.Type, d []byte) {
// cannot compare inside the expect because they are equivalent but not equal
assert.EqualValues(t, packet.Handshake, typ)
})
}).Times(2)
mockEncoder.EXPECT().Encode(gomock.Any(), gomock.Nil()).Do(
func(typ packet.Type, d []byte) {
assert.EqualValues(t, packet.Heartbeat, typ)
Expand Down
14 changes: 14 additions & 0 deletions agent/mocks/agent.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

35 changes: 21 additions & 14 deletions service/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,31 +213,38 @@ func (h *HandlerService) processPacket(a agent.Agent, p *packet.Packet) error {
switch p.Type {
case packet.Handshake:
logger.Log.Debug("Received handshake packet")
if err := a.SendHandshakeResponse(); err != nil {
logger.Log.Errorf("Error sending handshake response: %s", err.Error())
return err
}
logger.Log.Debugf("Session handshake Id=%d, Remote=%s", a.GetSession().ID(), a.RemoteAddr())

// Parse the json sent with the handshake by the client
handshakeData := &session.HandshakeData{}
err := json.Unmarshal(p.Data, handshakeData)
if err != nil {
a.SetStatus(constants.StatusClosed)
if err := json.Unmarshal(p.Data, handshakeData); err != nil {
logger.Log.Errorf("Failed to unmarshal handshake data: %s", err.Error())
if serr := a.SendHandshakeErrorResponse(); serr != nil {
logger.Log.Errorf("Error sending handshake error response: %s", err.Error())
return err
}

return fmt.Errorf("invalid handshake data. Id=%d", a.GetSession().ID())
}

for name, fun := range a.GetSession().GetHandshakeValidators() {
if err := fun(handshakeData); err != nil {
logger.Log.Error("Handshake validation failed '%s': %w", name, err)
a.SetStatus(constants.StatusClosed)
return fmt.Errorf("failed to run '%s' validator: %w. SessionId=%d", name, err, a.GetSession().ID())
if err := a.GetSession().ValidateHandshake(handshakeData); err != nil {
logger.Log.Errorf("Handshake validation failed: %s", err.Error())
if serr := a.SendHandshakeErrorResponse(); serr != nil {
logger.Log.Errorf("Error sending handshake error response: %s", err.Error())
return err
}

return fmt.Errorf("handshake validation failed: %w. SessionId=%d", err, a.GetSession().ID())
}

if err := a.SendHandshakeResponse(); err != nil {
logger.Log.Errorf("Error sending handshake response: %s", err.Error())
return err
}
logger.Log.Debugf("Session handshake Id=%d, Remote=%s", a.GetSession().ID(), a.RemoteAddr())

a.GetSession().SetHandshakeData(handshakeData)
a.SetStatus(constants.StatusHandshake)
err = a.GetSession().Set(constants.IPVersionKey, a.IPVersion())
err := a.GetSession().Set(constants.IPVersionKey, a.IPVersion())
if err != nil {
logger.Log.Warnf("failed to save ip version on session: %q\n", err)
}
Expand Down
57 changes: 19 additions & 38 deletions service/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,12 @@ func TestHandlerServiceProcessPacketHandshake(t *testing.T) {
name string
packet *packet.Packet
socketStatus int32
validator func(data *session.HandshakeData) error
errStr string
}{
{"invalid_handshake_data", &packet.Packet{Type: packet.Handshake, Data: []byte("asiodjasd")}, constants.StatusClosed, "Invalid handshake data"},
{"valid_handshake_data", &packet.Packet{Type: packet.Handshake, Data: []byte(`{"sys":{"platform":"mac"}}`)}, constants.StatusHandshake, ""},
{"invalid_handshake_data", &packet.Packet{Type: packet.Handshake, Data: []byte("asiodjasd")}, constants.StatusClosed, nil, "invalid handshake data"},
{"validator_error", &packet.Packet{Type: packet.Handshake, Data: []byte(`{"sys":{"platform":"mac"}}`)}, constants.StatusClosed, func(data *session.HandshakeData) error { return errors.New("validation failed") }, "handshake validation failed"},
{"valid_handshake_data", &packet.Packet{Type: packet.Handshake, Data: []byte(`{"sys":{"platform":"mac"}}`)}, constants.StatusHandshake, func(data *session.HandshakeData) error { return nil }, ""},
}
for _, table := range tables {
t.Run(table.name, func(t *testing.T) {
Expand All @@ -267,22 +269,28 @@ func TestHandlerServiceProcessPacketHandshake(t *testing.T) {

mockAgent := agentmocks.NewMockAgent(ctrl)
mockAgent.EXPECT().GetSession().Return(mockSession).Times(1)
mockAgent.EXPECT().RemoteAddr().Return(&mockAddr{})
mockAgent.EXPECT().SetStatus(table.socketStatus).Times(1)
mockAgent.EXPECT().SendHandshakeResponse().Return(nil).Times(1)

if table.validator != nil {
mockAgent.EXPECT().GetSession().Return(mockSession).Times(1)
mockSession.EXPECT().ValidateHandshake(gomock.Any()).DoAndReturn(func(data *session.HandshakeData) error {
return table.validator(data)
}).Times(1)
}

if table.errStr == "" {
handshakeData := &session.HandshakeData{}
_ = encjson.Unmarshal(table.packet.Data, handshakeData)
mockAgent.EXPECT().GetSession().Return(mockSession).Times(3)
mockAgent.EXPECT().GetSession().Return(mockSession).Times(2)
mockAgent.EXPECT().IPVersion().Return(constants.IPv4).Times(1)
mockSession.EXPECT().GetHandshakeValidators().Return(map[string]func(*session.HandshakeData) error{}).Times(1)
mockAgent.EXPECT().RemoteAddr().Return(&mockAddr{})
mockAgent.EXPECT().SetStatus(table.socketStatus).Times(1)
mockAgent.EXPECT().SendHandshakeResponse().Return(nil).Times(1)
mockAgent.EXPECT().SetLastAt().Times(1)

mockSession.EXPECT().SetHandshakeData(handshakeData).Times(1)
mockSession.EXPECT().Set(constants.IPVersionKey, constants.IPv4).Times(1)
mockAgent.EXPECT().SetLastAt().Times(1)
} else {
mockAgent.EXPECT().GetSession().Return(mockSession).Times(1)
mockSession.EXPECT().ID().Return(int64(1)).Times(1)
mockAgent.EXPECT().SendHandshakeErrorResponse().Times(1)
}

handlerPool := NewHandlerPool()
Expand All @@ -298,33 +306,6 @@ func TestHandlerServiceProcessPacketHandshake(t *testing.T) {
}
}

func TestHandlerServiceProcessPacketHandshakeValidators(t *testing.T) {
tables := []struct {
name: string
packet: *packet.Packet
socketStatus: int32,
errStr: string,
}{
{"without handshake validator"},
{"with one handshake validator"},
{"with many handshake validators"},
{"one passing validator"},
{"one failing validator"}
{"many validators all pass"},
{"many validators one fail"},
}

for _, table := range tables {
t.Run(name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockSession := mocks.NewMockSession(ctrl)
mockSession.EXPECT().ID().Return(int64(1)).Times(1)
})
}
}

func TestHandlerServiceProcessPacketHandshakeAck(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
Expand Down Expand Up @@ -437,7 +418,7 @@ func TestHandlerServiceHandle(t *testing.T) {

mockSession := mocks.NewMockSession(ctrl)
mockSession.EXPECT().SetHandshakeData(gomock.Any()).Times(1)
mockSession.EXPECT().GetHandshakeValidators().Return([]func(data *session.HandshakeData) error{}).Times(1)
mockSession.EXPECT().ValidateHandshake(gomock.Any()).Times(1)
mockSession.EXPECT().UID().Return("uid").Times(1)
mockSession.EXPECT().ID().Return(int64(1)).Times(2)
mockSession.EXPECT().Set(constants.IPVersionKey, constants.IPv4)
Expand Down
26 changes: 20 additions & 6 deletions session/mocks/session.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ package session
import (
"context"
"encoding/json"
"fmt"
"net"
"reflect"
"sync"
Expand Down Expand Up @@ -156,6 +157,7 @@ type Session interface {
Clear()
SetHandshakeData(data *HandshakeData)
GetHandshakeData() *HandshakeData
ValidateHandshake(data *HandshakeData) error
GetHandshakeValidators() map[string]func(data *HandshakeData) error
}

Expand Down Expand Up @@ -806,6 +808,16 @@ func (s *sessionImpl) GetHandshakeValidators() map[string]func(data *HandshakeDa
return s.handshakeValidators
}

func (s *sessionImpl) ValidateHandshake(data *HandshakeData) error {
for name, fun := range s.handshakeValidators {
if err := fun(data); err != nil {
return fmt.Errorf("failed to run '%s' validator: %w. SessionId=%d", name, err, s.ID())
}
}

return nil
}

func (s *sessionImpl) sendRequestToFront(ctx context.Context, route string, includeData bool) error {
sessionData := &protos.Session{
Id: s.frontendSessionID,
Expand Down
Loading

0 comments on commit b3d8ff5

Please sign in to comment.