diff --git a/services/streammanager/lambda/lambdamanager.go b/services/streammanager/lambda/lambdamanager.go index 1b02e7196e..d05ef566a5 100644 --- a/services/streammanager/lambda/lambdamanager.go +++ b/services/streammanager/lambda/lambdamanager.go @@ -7,6 +7,7 @@ import ( "github.com/aws/aws-sdk-go/service/lambda" jsoniter "github.com/json-iterator/go" + "github.com/mitchellh/mapstructure" backendconfig "github.com/rudderlabs/rudder-server/config/backend-config" "github.com/rudderlabs/rudder-server/services/streammanager/common" "github.com/rudderlabs/rudder-server/utils/awsutils" @@ -15,14 +16,13 @@ import ( // Config is the config that is required to send data to Lambda type destinationConfig struct { - InvocationType string - ClientContext string - Lambda string + InvocationType string `json:"invocationType"` + ClientContext string `json:"clientContext"` + Lambda string `json:"lambda"` } -type inputConfig struct { - Payload string `json:"payload"` - DestinationConfig *destinationConfig `json:"destConfig"` +type inputData struct { + Payload string `json:"payload"` } type LambdaProducer struct { @@ -56,13 +56,13 @@ func NewProducer(destination *backendconfig.DestinationT, o common.Opts) (*Lambd } // Produce creates a producer and send data to Lambda. -func (producer *LambdaProducer) Produce(jsonData json.RawMessage, _ interface{}) (int, string, string) { +func (producer *LambdaProducer) Produce(jsonData json.RawMessage, destConfig interface{}) (int, string, string) { client := producer.client if client == nil { return 400, "Failure", "[Lambda] error :: Could not create client" } - var input inputConfig + var input inputData err := jsonfast.Unmarshal(jsonData, &input) if err != nil { returnMessage := "[Lambda] error while unmarshalling jsonData :: " + err.Error() @@ -71,9 +71,14 @@ func (producer *LambdaProducer) Produce(jsonData json.RawMessage, _ interface{}) if input.Payload == "" { return 400, "Failure", "[Lambda] error :: Invalid payload" } - config := input.DestinationConfig - if config == nil { - return 400, "Failure", "[Lambda] error :: Invalid destination config" + var config destinationConfig + err = mapstructure.Decode(destConfig, &config) + if err != nil { + returnMessage := "[Lambda] error while unmarshalling destConfig :: " + err.Error() + return 400, "Failure", returnMessage + } + if config.InvocationType == "" { + config.InvocationType = "Event" } var invokeInput lambda.InvokeInput diff --git a/services/streammanager/lambda/lambdamanager_test.go b/services/streammanager/lambda/lambdamanager_test.go index 9401d179d0..cce97d25a9 100644 --- a/services/streammanager/lambda/lambdamanager_test.go +++ b/services/streammanager/lambda/lambdamanager_test.go @@ -17,9 +17,10 @@ import ( ) var ( - sampleMessage = "sample payload" - sampleFunction = "sample function" - invocationType = "Event" + sampleMessage = "sample payload" + sampleFunction = "sample function" + sampleClientContext = "sample client context" + invocationType = "Event" ) func TestNewProducer(t *testing.T) { @@ -81,26 +82,6 @@ func TestProduceWithInvalidData(t *testing.T) { assert.Equal(t, 400, statusCode) assert.Equal(t, "Failure", statusMsg) assert.Contains(t, respMsg, "[Lambda] error :: Invalid payload") - - // Destination Config not present - sampleEventJson, _ = json.Marshal(map[string]interface{}{ - "payload": sampleMessage, - }) - statusCode, statusMsg, respMsg = producer.Produce(sampleEventJson, map[string]string{}) - assert.Equal(t, 400, statusCode) - assert.Equal(t, "Failure", statusMsg) - assert.Contains(t, respMsg, "[Lambda] error :: Invalid destination config") - - // Invalid Destination Config - sampleDestConfig := map[string]interface{}{} - sampleEventJson, _ = json.Marshal(map[string]interface{}{ - "payload": sampleMessage, - "destConfig": "invalid dest config", - }) - statusCode, statusMsg, respMsg = producer.Produce(sampleEventJson, sampleDestConfig) - assert.Equal(t, 400, statusCode) - assert.Equal(t, "Failure", statusMsg) - assert.Contains(t, respMsg, "[Lambda] error while unmarshalling jsonData") } func TestProduceWithServiceResponse(t *testing.T) { @@ -110,26 +91,26 @@ func TestProduceWithServiceResponse(t *testing.T) { mockLogger := mock_logger.NewMockLogger(ctrl) pkgLogger = mockLogger - sampleDestConfig := map[string]interface{}{ - "Lambda": sampleFunction, - "InvocationType": invocationType, - } - sampleEventJson, _ := json.Marshal(map[string]interface{}{ - "payload": sampleMessage, - "destConfig": sampleDestConfig, + "payload": sampleMessage, }) + destConfig := map[string]string{ + "lambda": sampleFunction, + "clientContext": sampleClientContext, + } + var sampleInput lambda.InvokeInput sampleInput.SetFunctionName(sampleFunction) sampleInput.SetPayload([]byte(sampleMessage)) sampleInput.SetInvocationType(invocationType) + sampleInput.SetClientContext(sampleClientContext) mockClient. EXPECT(). Invoke(&sampleInput). Return(&lambda.InvokeOutput{}, nil) - statusCode, statusMsg, respMsg := producer.Produce(sampleEventJson, map[string]string{}) + statusCode, statusMsg, respMsg := producer.Produce(sampleEventJson, destConfig) assert.Equal(t, 200, statusCode) assert.Equal(t, "Success", statusMsg) assert.NotEmpty(t, respMsg) @@ -141,7 +122,7 @@ func TestProduceWithServiceResponse(t *testing.T) { Invoke(&sampleInput). Return(nil, errors.New(errorCode)) mockLogger.EXPECT().Errorf(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1) - statusCode, statusMsg, respMsg = producer.Produce(sampleEventJson, map[string]string{}) + statusCode, statusMsg, respMsg = producer.Produce(sampleEventJson, destConfig) assert.Equal(t, 500, statusCode) assert.Equal(t, "Failure", statusMsg) assert.NotEmpty(t, respMsg) @@ -154,7 +135,7 @@ func TestProduceWithServiceResponse(t *testing.T) { awserr.New(errorCode, errorCode, errors.New(errorCode)), 400, "request-id", )) mockLogger.EXPECT().Errorf(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(1) - statusCode, statusMsg, respMsg = producer.Produce(sampleEventJson, map[string]string{}) + statusCode, statusMsg, respMsg = producer.Produce(sampleEventJson, destConfig) assert.Equal(t, 400, statusCode) assert.Equal(t, errorCode, statusMsg) assert.NotEmpty(t, respMsg)