Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 56 additions & 10 deletions internal/impl/aws/output_sqs.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package aws

import (
"bytes"
"context"
"errors"
"fmt"
Expand All @@ -24,21 +25,45 @@ import (

const (
// SQS Output Fields
sqsoFieldURL = "url"
sqsoFieldMessageGroupID = "message_group_id"
sqsoFieldMessageDedupeID = "message_deduplication_id"
sqsoFieldDelaySeconds = "delay_seconds"
sqsoFieldMetadata = "metadata"
sqsoFieldBatching = "batching"
sqsoFieldURL = "url"
sqsoFieldMessageGroupID = "message_group_id"
sqsoFieldMessageDedupeID = "message_deduplication_id"
sqsoFieldDelaySeconds = "delay_seconds"
sqsoFieldMetadata = "metadata"
sqsoFieldBatching = "batching"
sqsoFieldRemoveInvalidCodePoints = "remove_invalid_codepoints"

sqsMaxRecordsCount = 10
)

// sQSCodepointMap is a rune-mapping function, in the shape expected by
// strings.Map and bytes.Map, that keeps only the characters SQS accepts and
// returns -1 for every other code point so that the mapper drops it.
//
// The accepted set, as listed in the SendMessage API reference, is:
// #x9 | #xA | #xD | #x20 to #xD7FF | #xE000 to #xFFFD | #x10000 to #x10FFFF
// https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_SendMessage.html
var sQSCodepointMap = func(r rune) rune {
	// Tab, line feed and carriage return are the only values below 0x20 that are kept.
	if r == 0x0009 || r == 0x000A || r == 0x000D {
		return r
	}

	permitted := (r >= 0x0020 && r <= 0xD7FF) || // up to the last code point before 0xD800
		(r >= 0xE000 && r <= 0xFFFD) || // 0xFFFE and 0xFFFF are excluded
		(r >= 0x10000 && r <= 0x10FFFF) // everything above the 16-bit range
	if !permitted {
		return -1
	}
	return r
}

type sqsoConfig struct {
URL string
MessageGroupID *service.InterpolatedString
MessageDeduplicationID *service.InterpolatedString
DelaySeconds *service.InterpolatedString
URL string
MessageGroupID *service.InterpolatedString
MessageDeduplicationID *service.InterpolatedString
DelaySeconds *service.InterpolatedString
RemoveInvalidCodepoints bool

Metadata *service.MetadataExcludeFilter
aconf aws.Config
Expand All @@ -64,6 +89,9 @@ func sqsoConfigFromParsed(pConf *service.ParsedConfig) (conf sqsoConfig, err err
return
}
}
if conf.RemoveInvalidCodepoints, err = pConf.FieldBool(sqsoFieldRemoveInvalidCodePoints); err != nil {
return
}
if conf.Metadata, err = pConf.FieldMetadataExcludeFilter(sqsoFieldMetadata); err != nil {
return
}
Expand Down Expand Up @@ -101,6 +129,13 @@ By default Bento will use a shared credentials file when connecting to AWS servi
service.NewInterpolatedStringField(sqsoFieldDelaySeconds).
Description("An optional delay time in seconds for message. Value between 0 and 900").
Optional(),
service.NewBoolField(sqsoFieldRemoveInvalidCodePoints).
Description(`:::caution
AWS SQS rejects any message containing unicode characters outside of the set: #x9 | #xA | #xD | #x20 to #xD7FF | #xE000 to #xFFFD | #x10000 to #x10FFFF
:::

Setting this field to true will remove any unicode characters outside of the allowed set from both the messsage and metadata values before attempting to send to SQS`).
Default(false),
service.NewOutputMaxInFlightField().
Description("The maximum number of parallel message batches to have in flight at any given time."),
service.NewMetadataExcludeFilterField(snsoFieldMetadata).
Expand Down Expand Up @@ -239,6 +274,17 @@ func (a *sqsWriter) getSQSAttributes(batch service.MessageBatch, i int) (sqsAttr
return sqsAttributes{}, err
}

if a.conf.RemoveInvalidCodepoints {
msgBytes = bytes.Map(sQSCodepointMap, msgBytes)

for k, v := range values {
if v.StringValue != nil {
v.StringValue = aws.String(strings.Map(sQSCodepointMap, *v.StringValue))
values[k] = v
}
}
}

return sqsAttributes{
attrMap: values,
groupID: groupID,
Expand Down
75 changes: 75 additions & 0 deletions internal/impl/aws/output_sqs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"testing"
"unsafe"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
Expand All @@ -15,6 +16,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/warpstreamlabs/bento/internal/metadata"
"github.com/warpstreamlabs/bento/public/service"
)

Expand Down Expand Up @@ -267,3 +269,76 @@ func TestSQSSendLimit(t *testing.T) {
},
}, in)
}
// TestSQSRemoveInvalidCodepoints verifies that getSQSAttributes strips U+FFFE
// from both the message content and the metadata attribute value when
// RemoveInvalidCodepoints is enabled, and leaves both untouched when it is
// disabled.
func TestSQSRemoveInvalidCodepoints(t *testing.T) {
	awsConf, err := config.LoadDefaultConfig(
		context.Background(),
		config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider("xxxxx", "xxxxx", "xxxxx")),
	)
	require.NoError(t, err)

	// testCase describes one message/metadata pair and what is expected to
	// come back out of getSQSAttributes for it.
	type testCase struct {
		removeInvalidCodepoints bool
		msgBytes                []byte
		expectedContent         string
		metadataKey             string
		metadataValue           string
		expectedMetadataValue   string
	}

	cases := map[string]testCase{
		"remove_invalid_codepoints_enabled": {
			removeInvalidCodepoints: true,
			msgBytes:                []byte("{\"hello\":\"world\"\uFFFE}"),
			expectedContent:         "{\"hello\":\"world\"}",
			metadataKey:             "mykey",
			metadataValue:           "valid-prefix\uFFFEvalid-suffix",
			expectedMetadataValue:   "valid-prefixvalid-suffix",
		},
		"remove_invalid_codepoints_disabled": {
			removeInvalidCodepoints: false,
			msgBytes:                []byte("{\"hello\":\"world\"\uFFFE}"),
			expectedContent:         "{\"hello\":\"world\"\uFFFE}",
			metadataKey:             "mykey",
			metadataValue:           "valid-prefix\uFFFEvalid-suffix",
			expectedMetadataValue:   "valid-prefix\uFFFEvalid-suffix",
		},
	}

	// NOTE(review): the parallel subtest closure captures the loop variable
	// directly; this relies on per-iteration loop variables (Go 1.22+) —
	// confirm against the module's go directive.
	for name, tc := range cases {
		t.Run(name, func(t *testing.T) {
			t.Parallel()

			writer, err := newSQSWriter(sqsoConfig{
				URL:                     "http://foo.example.com",
				RemoveInvalidCodepoints: tc.removeInvalidCodepoints,
				aconf:                   awsConf,
				Metadata:                newPassAllMetadataExcludeFilter(t),
			}, service.MockResources())
			require.NoError(t, err)

			msg := service.NewMessage(tc.msgBytes)
			msg.MetaSet(tc.metadataKey, tc.metadataValue)

			attrs, err := writer.getSQSAttributes(service.MessageBatch{msg}, 0)
			require.NoError(t, err)

			assert.Equal(t, tc.expectedContent, *attrs.content)

			// Cases without a metadata key have no attribute to inspect.
			if tc.metadataKey == "" {
				return
			}

			metaAttr, found := attrs.attrMap[tc.metadataKey]
			require.True(t, found, "expected metadata key %q to be present in attrMap", tc.metadataKey)
			require.NotNil(t, metaAttr.StringValue)
			assert.Equal(t, tc.expectedMetadataValue, *metaAttr.StringValue)
		})
	}
}

// newPassAllMetadataExcludeFilter returns a *service.MetadataExcludeFilter
// wrapping the filter built from metadata.NewExcludeFilterConfig() with no
// further settings applied (presumably a filter that excludes nothing —
// verify against the metadata package).
//
// The wrapper is populated by reinterpreting its memory through unsafe.Pointer
// as a struct with a single *metadata.ExcludeFilter field named f.
func newPassAllMetadataExcludeFilter(t *testing.T) *service.MetadataExcludeFilter {
	t.Helper()

	filterConf := metadata.NewExcludeFilterConfig()
	inner, err := filterConf.Filter()
	require.NoError(t, err)

	wrapped := new(service.MetadataExcludeFilter)

	// NOTE(review): assumes service.MetadataExcludeFilter's memory layout
	// begins with a *metadata.ExcludeFilter field — confirm against the
	// service package; a layout change there would silently break this.
	overlay := (*struct{ f *metadata.ExcludeFilter })(unsafe.Pointer(wrapped))
	overlay.f = inner

	return wrapped
}
14 changes: 14 additions & 0 deletions website/docs/components/outputs/aws_sqs.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ output:
message_group_id: "" # No default (optional)
message_deduplication_id: "" # No default (optional)
delay_seconds: "" # No default (optional)
remove_invalid_codepoints: false
max_in_flight: 64
metadata:
exclude_prefixes: []
Expand All @@ -59,6 +60,7 @@ output:
message_group_id: "" # No default (optional)
message_deduplication_id: "" # No default (optional)
delay_seconds: "" # No default (optional)
remove_invalid_codepoints: false
max_in_flight: 64
metadata:
exclude_prefixes: []
Expand Down Expand Up @@ -137,6 +139,18 @@ This field supports [interpolation functions](/docs/configuration/interpolation#

Type: `string`

### `remove_invalid_codepoints`

:::caution
AWS SQS rejects any message containing unicode characters outside of the set: #x9 | #xA | #xD | #x20 to #xD7FF | #xE000 to #xFFFD | #x10000 to #x10FFFF
:::

Setting this field to true will remove any unicode characters outside of the allowed set from both the message and metadata values before attempting to send to SQS.


Type: `bool`
Default: `false`

### `max_in_flight`

The maximum number of parallel message batches to have in flight at any given time.
Expand Down