Skip to content

Commit

Permalink
Improve message sequence validator (#4037)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexshtin committed Mar 14, 2023
1 parent 122bb36 commit b520a5b
Showing 1 changed file with 26 additions and 16 deletions.
42 changes: 26 additions & 16 deletions service/history/commandChecker.go
Original file line number Diff line number Diff line change
Expand Up @@ -1001,33 +1001,43 @@ func (v *commandAttrValidator) commandTypes(
return result
}

// TODO (alex-update): move to messageValidator.
// TODO (alex-update): move to messageValidator and dedicated package.
func (v *commandAttrValidator) validateMessages(
messages []*protocolpb.Message,
) error {
// Validates messages:
// 1. Sequence: Response (i.e. complete) must go after Acceptance and Rejection.
// 2. Only update.Acceptance, update.Response, and update.Rejection messages are allowed.

seenCompleteMessage := false
// 1. Sequence: Response message for the same protocol_instance_id must go after Acceptance.
// 2. Rejection can't be paired with Response or Acceptance.
// 3. Only Acceptance, Response,and Rejection messages are allowed.

seenAcceptance := make(map[string]struct{})
seenRejection := make(map[string]struct{})
seenResponse := make(map[string]struct{})
const messageSequence = "invalid message sequence for %s: %s message must be before %s message"
const messagePairing = "invalid message pairing for %s: %s message is not allowed because %s message was already received"
for _, message := range messages {
//nolint:revive // early-return
if types.Is(message.GetBody(), (*updatepb.Acceptance)(nil)) {
if seenCompleteMessage {
return serviceerror.NewInvalidArgument(
fmt.Sprintf(
"invalid message sequence: %s message must be before %s message",
proto.MessageName((*updatepb.Acceptance)(nil)), proto.MessageName((*updatepb.Response)(nil))))
if _, ok := seenResponse[message.GetProtocolInstanceId()]; ok {
return serviceerror.NewInvalidArgument(fmt.Sprintf(messageSequence, message.GetProtocolInstanceId(), proto.MessageName((*updatepb.Acceptance)(nil)), proto.MessageName((*updatepb.Response)(nil))))
}
if _, ok := seenRejection[message.GetProtocolInstanceId()]; ok {
return serviceerror.NewInvalidArgument(fmt.Sprintf(messagePairing, message.GetProtocolInstanceId(), proto.MessageName((*updatepb.Acceptance)(nil)), proto.MessageName((*updatepb.Rejection)(nil))))
}
seenAcceptance[message.GetProtocolInstanceId()] = struct{}{}
} else if types.Is(message.GetBody(), (*updatepb.Response)(nil)) {
seenCompleteMessage = true
if _, ok := seenRejection[message.GetProtocolInstanceId()]; ok {
return serviceerror.NewInvalidArgument(fmt.Sprintf(messagePairing, message.GetProtocolInstanceId(), proto.MessageName((*updatepb.Response)(nil)), proto.MessageName((*updatepb.Rejection)(nil))))
}
seenResponse[message.GetProtocolInstanceId()] = struct{}{}
} else if types.Is(message.GetBody(), (*updatepb.Rejection)(nil)) {
if seenCompleteMessage {
return serviceerror.NewInvalidArgument(
fmt.Sprintf(
"invalid message sequence: %s message must be before %s message",
proto.MessageName((*updatepb.Rejection)(nil)), proto.MessageName((*updatepb.Response)(nil))))
if _, ok := seenAcceptance[message.GetProtocolInstanceId()]; ok {
return serviceerror.NewInvalidArgument(fmt.Sprintf(messagePairing, message.GetProtocolInstanceId(), proto.MessageName((*updatepb.Rejection)(nil)), proto.MessageName((*updatepb.Acceptance)(nil))))
}
if _, ok := seenResponse[message.GetProtocolInstanceId()]; ok {
return serviceerror.NewInvalidArgument(fmt.Sprintf(messagePairing, message.GetProtocolInstanceId(), proto.MessageName((*updatepb.Rejection)(nil)), proto.MessageName((*updatepb.Response)(nil))))
}
seenRejection[message.GetProtocolInstanceId()] = struct{}{}
} else {
return serviceerror.NewInvalidArgument(fmt.Sprintf("unknown message type: %v", message.GetBody().GetTypeUrl()))
}
Expand Down

0 comments on commit b520a5b

Please sign in to comment.