Skip to content

Commit

Permalink
Fix memo/search attribute size validation when upserting (#3353)
Browse files Browse the repository at this point in the history
* Fix memo size validation for UpsertMemo

* Fix search attributes size validation
  • Loading branch information
rodrigozhou committed Sep 9, 2022
1 parent 2925d4f commit 33263db
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 45 deletions.
43 changes: 43 additions & 0 deletions common/payload/payload.go
Expand Up @@ -25,12 +25,17 @@
package payload

import (
"github.com/gogo/protobuf/proto"
commonpb "go.temporal.io/api/common/v1"
"go.temporal.io/sdk/converter"
"golang.org/x/exp/maps"
)

var (
defaultDataConverter = converter.GetDefaultDataConverter()

nilPayload, _ = Encode(nil)
emptySlicePayload, _ = Encode([]string{})
)

func EncodeString(str string) *commonpb.Payload {
Expand All @@ -56,3 +61,41 @@ func Decode(p *commonpb.Payload, valuePtr interface{}) error {
func ToString(p *commonpb.Payload) string {
return defaultDataConverter.ToString(p)
}

// MergeMapOfPayload returns a new map resulting from merging map m2 into m1.
// If a key in m2 already exists in m1, then the value in m2 replaces the value in m1.
// If the new payload have nil data or an empty slice data, then it deletes the key.
// For example:
//
// m1 := map[string]*commonpb.Payload{
// "key1": EncodeString("value1"),
// "key2": EncodeString("value2"),
// }
// m2 := map[string]*commonpb.Payload{
// "key1": EncodeString("newValue1"),
// "key2": nilPayload,
// }
// m3 := MergeMapOfPayload(m1, m2)
//
// The resulting map `m3` is:
//
// m1 := map[string]*commonpb.Payload{
// "key1": EncodeString("newValue1"),
// }
func MergeMapOfPayload(
m1 map[string]*commonpb.Payload,
m2 map[string]*commonpb.Payload,
) map[string]*commonpb.Payload {
if len(m1) == 0 {
return maps.Clone(m2)
}
ret := maps.Clone(m1)
for k, v := range m2 {
if proto.Equal(v, nilPayload) || proto.Equal(v, emptySlicePayload) {
delete(ret, k)
} else {
ret[k] = v
}
}
return ret
}
31 changes: 31 additions & 0 deletions common/payload/payload_test.go
Expand Up @@ -28,6 +28,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
commonpb "go.temporal.io/api/common/v1"
)

type testStruct struct {
Expand Down Expand Up @@ -71,3 +72,33 @@ func TestToString(t *testing.T) {
result = ToString(nil)
assert.Equal("", result)
}

func TestMergeMapOfPayload(t *testing.T) {
assert := assert.New(t)

var currentMap map[string]*commonpb.Payload
var newMap map[string]*commonpb.Payload
resultMap := MergeMapOfPayload(currentMap, newMap)
assert.Equal(make(map[string]*commonpb.Payload), resultMap)

newMap = map[string]*commonpb.Payload{"key": EncodeString("val")}
resultMap = MergeMapOfPayload(currentMap, newMap)
assert.Equal(newMap, resultMap)

currentMap = map[string]*commonpb.Payload{"number": EncodeString("1")}
resultMap = MergeMapOfPayload(currentMap, newMap)
assert.Equal(
map[string]*commonpb.Payload{"number": EncodeString("1"), "key": EncodeString("val")},
resultMap,
)

newValue, _ := Encode(nil)
newMap = map[string]*commonpb.Payload{"number": newValue}
resultMap = MergeMapOfPayload(currentMap, newMap)
assert.Equal(0, len(resultMap))

newValue, _ = Encode([]int{})
newMap = map[string]*commonpb.Payload{"number": newValue}
resultMap = MergeMapOfPayload(currentMap, newMap)
assert.Equal(0, len(resultMap))
}
4 changes: 2 additions & 2 deletions service/history/commandChecker.go
Expand Up @@ -173,15 +173,15 @@ func (c *workflowSizeChecker) failWorkflowIfPayloadSizeExceedsLimit(
}

func (c *workflowSizeChecker) failWorkflowIfMemoSizeExceedsLimit(
memo *commonpb.Memo,
commandTypeTag metrics.Tag,
memoSize int,
message string,
) (bool, error) {

executionInfo := c.mutableState.GetExecutionInfo()
executionState := c.mutableState.GetExecutionState()
err := common.CheckEventBlobSizeLimit(
memoSize,
memo.Size(),
c.memoSizeLimitWarn,
c.memoSizeLimitError,
executionInfo.NamespaceId,
Expand Down
26 changes: 3 additions & 23 deletions service/history/workflow/mutable_state_impl.go
Expand Up @@ -59,6 +59,7 @@ import (
"go.temporal.io/server/common/log/tag"
"go.temporal.io/server/common/metrics"
"go.temporal.io/server/common/namespace"
"go.temporal.io/server/common/payload"
"go.temporal.io/server/common/persistence"
"go.temporal.io/server/common/persistence/versionhistory"
"go.temporal.io/server/common/persistence/visibility"
Expand Down Expand Up @@ -2789,11 +2790,8 @@ func (e *MutableStateImpl) AddUpsertWorkflowSearchAttributesEvent(
func (e *MutableStateImpl) ReplicateUpsertWorkflowSearchAttributesEvent(
event *historypb.HistoryEvent,
) {

upsertSearchAttr := event.GetUpsertWorkflowSearchAttributesEventAttributes().GetSearchAttributes().GetIndexedFields()
currentSearchAttr := e.GetExecutionInfo().SearchAttributes

e.executionInfo.SearchAttributes = mergeMapOfPayload(currentSearchAttr, upsertSearchAttr)
e.executionInfo.SearchAttributes = payload.MergeMapOfPayload(e.executionInfo.SearchAttributes, upsertSearchAttr)
}

func (e *MutableStateImpl) AddWorkflowPropertiesModifiedEvent(
Expand Down Expand Up @@ -2822,26 +2820,8 @@ func (e *MutableStateImpl) ReplicateWorkflowPropertiesModifiedEvent(
attr := event.GetWorkflowPropertiesModifiedEventAttributes()
if attr.UpsertedMemo != nil {
upsertMemo := attr.GetUpsertedMemo().GetFields()
currentMemo := e.GetExecutionInfo().Memo
e.executionInfo.Memo = mergeMapOfPayload(currentMemo, upsertMemo)
}
}

func mergeMapOfPayload(
current map[string]*commonpb.Payload,
upsert map[string]*commonpb.Payload,
) map[string]*commonpb.Payload {
if current == nil {
current = make(map[string]*commonpb.Payload)
}
for k, v := range upsert {
if v.Data == nil {
delete(current, k)
} else {
current[k] = v
}
e.executionInfo.Memo = payload.MergeMapOfPayload(e.executionInfo.Memo, upsertMemo)
}
return current
}

func (e *MutableStateImpl) AddExternalWorkflowExecutionSignaled(
Expand Down
16 changes: 0 additions & 16 deletions service/history/workflow/mutable_state_impl_test.go
Expand Up @@ -47,7 +47,6 @@ import (
"go.temporal.io/server/common/failure"
"go.temporal.io/server/common/log"
"go.temporal.io/server/common/namespace"
"go.temporal.io/server/common/payload"
"go.temporal.io/server/common/payloads"
"go.temporal.io/server/common/persistence"
"go.temporal.io/server/common/persistence/versionhistory"
Expand Down Expand Up @@ -321,21 +320,6 @@ func (s *mutableStateSuite) TestChecksumShouldInvalidate() {
s.False(s.mutableState.shouldInvalidateCheckum())
}

func (s *mutableStateSuite) TestMergeMapOfPayload() {
var currentMap map[string]*commonpb.Payload
var newMap map[string]*commonpb.Payload
resultMap := mergeMapOfPayload(currentMap, newMap)
s.Equal(make(map[string]*commonpb.Payload), resultMap)

newMap = map[string]*commonpb.Payload{"key": payload.EncodeString("val")}
resultMap = mergeMapOfPayload(currentMap, newMap)
s.Equal(newMap, resultMap)

currentMap = map[string]*commonpb.Payload{"number": payload.EncodeString("1")}
resultMap = mergeMapOfPayload(currentMap, newMap)
s.Equal(2, len(resultMap))
}

func (s *mutableStateSuite) TestEventReapplied() {
runID := uuid.New()
eventID := int64(1)
Expand Down
18 changes: 14 additions & 4 deletions service/history/workflowTaskHandler.go
Expand Up @@ -46,6 +46,7 @@ import (
"go.temporal.io/server/common/log/tag"
"go.temporal.io/server/common/metrics"
"go.temporal.io/server/common/namespace"
"go.temporal.io/server/common/payload"
"go.temporal.io/server/common/payloads"
"go.temporal.io/server/common/primitives/timestamp"
"go.temporal.io/server/common/searchattribute"
Expand Down Expand Up @@ -794,8 +795,8 @@ func (handler *workflowTaskHandlerImpl) handleCommandContinueAsNewWorkflow(
}

failWorkflow, err = handler.sizeLimitChecker.failWorkflowIfMemoSizeExceedsLimit(
attr.GetMemo(),
metrics.CommandTypeTag(enumspb.COMMAND_TYPE_CONTINUE_AS_NEW_WORKFLOW_EXECUTION.String()),
attr.GetMemo().Size(),
"ContinueAsNewWorkflowExecutionCommandAttributes. Memo exceeds size limit.",
)
if err != nil || failWorkflow {
Expand Down Expand Up @@ -915,8 +916,8 @@ func (handler *workflowTaskHandlerImpl) handleCommandStartChildWorkflow(
}

failWorkflow, err = handler.sizeLimitChecker.failWorkflowIfMemoSizeExceedsLimit(
attr.GetMemo(),
metrics.CommandTypeTag(enumspb.COMMAND_TYPE_START_CHILD_WORKFLOW_EXECUTION.String()),
attr.GetMemo().Size(),
"StartChildWorkflowExecutionCommandAttributes. Memo exceeds size limit.",
)
if err != nil || failWorkflow {
Expand Down Expand Up @@ -1058,8 +1059,14 @@ func (handler *workflowTaskHandlerImpl) handleCommandUpsertWorkflowSearchAttribu
return err
}

// new search attributes size limit check
failWorkflow, err = handler.sizeLimitChecker.failWorkflowIfSearchAttributesSizeExceedsLimit(
attr.GetSearchAttributes(),
&commonpb.SearchAttributes{
IndexedFields: payload.MergeMapOfPayload(
executionInfo.SearchAttributes,
attr.GetSearchAttributes().GetIndexedFields(),
),
},
namespace,
metrics.CommandTypeTag(enumspb.COMMAND_TYPE_UPSERT_WORKFLOW_SEARCH_ATTRIBUTES.String()),
)
Expand Down Expand Up @@ -1124,9 +1131,12 @@ func (handler *workflowTaskHandlerImpl) handleCommandModifyWorkflowProperties(
return err
}

// new memo size limit check
failWorkflow, err = handler.sizeLimitChecker.failWorkflowIfMemoSizeExceedsLimit(
&commonpb.Memo{
Fields: payload.MergeMapOfPayload(executionInfo.Memo, attr.GetUpsertedMemo().GetFields()),
},
metrics.CommandTypeTag(enumspb.COMMAND_TYPE_MODIFY_WORKFLOW_PROPERTIES.String()),
attr.GetUpsertedMemo().Size(),
"ModifyWorkflowPropertiesCommandAttributes. Memo exceeds size limit.",
)
if err != nil || failWorkflow {
Expand Down

0 comments on commit 33263db

Please sign in to comment.