Skip to content

Commit

Permalink
Change Lambdas to create and update in parallel (#3976)
Browse files Browse the repository at this point in the history
The root cause behind why the serialization of Lambda creation/update
was introduced upstream is excessive memory usage (see
hashicorp/terraform#9364).
After investigation we found that this is caused by the HTTP request
logging middleware. It logs the lambda archive as a base64 encoded
string. In order to do so, multiple copies of the body are created in
memory, which leads to memory bloating.
This change fixes that by redacting the body in the logs for the
Create/Update Lambda calls.

The PR introduces two patches. One removes the Lambda serialization and
the other fixes the HTTP request logging middleware for the Lambda
`CreateFunction` and `UpdateFunctionCode` operations.
After this, Lambdas are created/updated in parallel and don't suffer
from excessive memory usage. Users can still limit the parallelism with
the CLI flag `--parallel` if they wish so.

Relates to #2206
  • Loading branch information
flostadler committed Jun 7, 2024
1 parent a5d58e9 commit 42cf551
Show file tree
Hide file tree
Showing 9 changed files with 394 additions and 15 deletions.
47 changes: 47 additions & 0 deletions patches/0060-Parallelize-Lambda-Function-resource-operations.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Florian Stadler <florian@pulumi.com>
Date: Wed, 22 May 2024 17:01:32 +0200
Subject: [PATCH] Parallelize Lambda Function resource operations

Upstream introduced serialization of Lambda Function resource
operations to fight high memory usage when managing a lot of
Lambda functions.
We think this was an optimization for a special edge case that
drastically worsens the UX for the majority of users.

diff --git a/internal/service/lambda/function.go b/internal/service/lambda/function.go
index fb1d412f18..8e3529cc26 100644
--- a/internal/service/lambda/function.go
+++ b/internal/service/lambda/function.go
@@ -36,7 +36,6 @@ import (

const (
FunctionVersionLatest = "$LATEST"
- mutexKey = `aws_lambda_function`
listVersionsMaxItems = 10000
)

@@ -482,11 +481,6 @@ func resourceFunctionCreate(ctx context.Context, d *schema.ResourceData, meta in
}

if v, ok := d.GetOk("filename"); ok {
- // Grab an exclusive lock so that we're only reading one function into memory at a time.
- // See https://github.com/hashicorp/terraform/issues/9364.
- conns.GlobalMutexKV.Lock(mutexKey)
- defer conns.GlobalMutexKV.Unlock(mutexKey)
-
zipFile, err := readFileContents(v.(string))

if err != nil {
@@ -944,11 +938,6 @@ func resourceFunctionUpdate(ctx context.Context, d *schema.ResourceData, meta in
}

if v, ok := d.GetOk("filename"); ok {
- // Grab an exclusive lock so that we're only reading one function into memory at a time.
- // See https://github.com/hashicorp/terraform/issues/9364
- conns.GlobalMutexKV.Lock(mutexKey)
- defer conns.GlobalMutexKV.Unlock(mutexKey)
-
zipFile, err := readFileContents(v.(string))

if err != nil {
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Florian Stadler <florian@pulumi.com>
Date: Fri, 31 May 2024 12:29:36 +0200
Subject: [PATCH] Create Logging Middleware for Lambda service that does not
log the lambda code archive

When creating lambda functions and directly uploading the code, then the whole archive
is being logged as a base64 encoded string as part of the HTTP request logger.
In order to do so, multiple copies of the body are created in memory, which leads
to memory bloating.
This change fixes that by redacting the body in the logs for the Create/Update Lambda
calls.

diff --git a/internal/service/lambda/request_response_logger.go b/internal/service/lambda/request_response_logger.go
new file mode 100644
index 0000000000..737faef4a7
--- /dev/null
+++ b/internal/service/lambda/request_response_logger.go
@@ -0,0 +1,109 @@
+package lambda
+
+import (
+ "context"
+ "fmt"
+ "github.com/aws/aws-sdk-go-v2/aws"
+ awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware"
+ "github.com/aws/smithy-go/middleware"
+ smithyhttp "github.com/aws/smithy-go/transport/http"
+ "github.com/hashicorp/aws-sdk-go-base/v2/logging"
+ "io"
+ "net/http"
+ "strings"
+ "time"
+ _ "unsafe"
+)
+
+const (
+ lambdaCreateOperation = "CreateFunction"
+ lambdaUpdateFunctionCodeOperation = "UpdateFunctionCode"
+)
+
+// Replaces the upstream logging middleware from https://github.com/hashicorp/aws-sdk-go-base/blob/main/logger.go#L107
+// We do not want to log the Lambda Archive that is part of the request body because this leads to bloating memory
+type wrappedRequestResponseLogger struct {
+ wrapped middleware.DeserializeMiddleware
+}
+
+// ID is the middleware identifier.
+func (r *wrappedRequestResponseLogger) ID() string {
+ return "PULUMI_AWS_RequestResponseLogger"
+}
+
+func NewWrappedRequestResponseLogger(wrapped middleware.DeserializeMiddleware) middleware.DeserializeMiddleware {
+ return &wrappedRequestResponseLogger{wrapped: wrapped}
+}
+
+//go:linkname decomposeHTTPResponse github.com/hashicorp/aws-sdk-go-base/v2.decomposeHTTPResponse
+func decomposeHTTPResponse(ctx context.Context, resp *http.Response, elapsed time.Duration) (map[string]any, error)
+
+func (r *wrappedRequestResponseLogger) HandleDeserialize(ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler,
+) (
+ out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
+) {
+ if awsmiddleware.GetServiceID(ctx) == "Lambda" {
+ if op := awsmiddleware.GetOperationName(ctx); op != lambdaCreateOperation && op != lambdaUpdateFunctionCodeOperation {
+ // pass through to the wrapped response logger for all other lambda operations that do not send the code as part of the request body
+ return r.wrapped.HandleDeserialize(ctx, in, next)
+ }
+ }
+
+ // Inlined the logging middleware from https://github.com/hashicorp/aws-sdk-go-base/blob/main/logger.go and patching
+ // out the request body logging
+ logger := logging.RetrieveLogger(ctx)
+ region := awsmiddleware.GetRegion(ctx)
+
+ if signingRegion := awsmiddleware.GetSigningRegion(ctx); signingRegion != region { //nolint:staticcheck // Not retrievable elsewhere
+ ctx = logger.SetField(ctx, string(logging.SigningRegionKey), signingRegion)
+ }
+ if awsmiddleware.GetEndpointSource(ctx) == aws.EndpointSourceCustom {
+ ctx = logger.SetField(ctx, string(logging.CustomEndpointKey), true)
+ }
+
+ req, ok := in.Request.(*smithyhttp.Request)
+ if !ok {
+ return out, metadata, fmt.Errorf("unexpected request middleware type %T", in.Request)
+ }
+
+ rc := req.Build(ctx)
+
+ originalBody := rc.Body
+ // remove the body from the logging output. This is the main change compared to the upstream logging middleware
+ redactedBody := strings.NewReader("[Redacted]")
+ rc.Body = io.NopCloser(redactedBody)
+ rc.ContentLength = redactedBody.Size()
+
+ requestFields, err := logging.DecomposeHTTPRequest(ctx, rc)
+ if err != nil {
+ return out, metadata, fmt.Errorf("decomposing request: %w", err)
+ }
+ logger.Debug(ctx, "HTTP Request Sent", requestFields)
+
+ // reconstruct the original request
+ req, err = req.SetStream(originalBody)
+ if err != nil {
+ return out, metadata, err
+ }
+ in.Request = req
+
+ start := time.Now()
+ out, metadata, err = next.HandleDeserialize(ctx, in)
+ duration := time.Since(start)
+
+ if err != nil {
+ return out, metadata, err
+ }
+
+ if res, ok := out.RawResponse.(*smithyhttp.Response); !ok {
+ return out, metadata, fmt.Errorf("unknown response type: %T", out.RawResponse)
+ } else {
+ responseFields, err := decomposeHTTPResponse(ctx, res.Response, duration)
+ if err != nil {
+ return out, metadata, fmt.Errorf("decomposing response: %w", err)
+ }
+ logger.Debug(ctx, "HTTP Response Received", responseFields)
+ }
+
+ return out, metadata, err
+}
diff --git a/internal/service/lambda/service_package_extra.go b/internal/service/lambda/service_package_extra.go
index 54f6aac15a..1f2440d3e3 100644
--- a/internal/service/lambda/service_package_extra.go
+++ b/internal/service/lambda/service_package_extra.go
@@ -6,6 +6,7 @@ import (
aws_sdkv2 "github.com/aws/aws-sdk-go-v2/aws"
retry_sdkv2 "github.com/aws/aws-sdk-go-v2/aws/retry"
lambda_sdkv2 "github.com/aws/aws-sdk-go-v2/service/lambda"
+ "github.com/aws/smithy-go/middleware"
tfawserr_sdkv2 "github.com/hashicorp/aws-sdk-go-base/v2/tfawserr"
"github.com/hashicorp/terraform-provider-aws/internal/conns"
"github.com/hashicorp/terraform-provider-aws/names"
@@ -34,6 +35,19 @@ func (p *servicePackage) NewClient(ctx context.Context, config map[string]any) (
if endpoint := config[names.AttrEndpoint].(string); endpoint != "" {
o.BaseEndpoint = aws_sdkv2.String(endpoint)
}
+
+ // Switch out the terraform http logging middleware with a custom logging middleware that does not log the
+ // lambda code. Logging the lambda code leads to memory bloating because it allocates a lot of copies of the
+ // body
+ o.APIOptions = append(o.APIOptions, func(stack *middleware.Stack) error {
+ loggingMiddleware, err := stack.Deserialize.Remove("TF_AWS_RequestResponseLogger")
+ if err != nil {
+ return err
+ }
+
+ err = stack.Deserialize.Add(NewWrappedRequestResponseLogger(loggingMiddleware), middleware.After)
+ return err
+ })
o.Retryer = conns.AddIsErrorRetryables(cfg.Retryer().(aws_sdkv2.RetryerV2), retry)
}), nil
}
88 changes: 88 additions & 0 deletions provider/provider_nodejs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@
package provider

import (
"archive/zip"
"crypto/rand"
"io"
"os"
"path/filepath"
"testing"
"time"

"github.com/pulumi/providertest/pulumitest"
"github.com/pulumi/providertest/pulumitest/opttest"
"github.com/pulumi/pulumi/pkg/v3/testing/integration"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -128,3 +133,86 @@ func TestRegressAttributeMustBeWholeNumber(t *testing.T) {
result := test.Preview()
t.Logf("#%v", result.ChangeSummary)
}

func TestParallelLambdaCreation(t *testing.T) {
if testing.Short() {
t.Skipf("Skipping test in -short mode because it needs cloud credentials")
return
}

tempFile, err := createLambdaArchive(25 * 1024 * 1024)
require.NoError(t, err)
defer os.Remove(tempFile)

maxDuration(5*time.Minute, t, func(t *testing.T) {
test := getJSBaseOptions(t).
With(integration.ProgramTestOptions{
Dir: filepath.Join("test-programs", "parallel-lambdas"),
Config: map[string]string{
"lambda:archivePath": tempFile,
},
// Lambdas have diffs on every update (source code hash)
AllowEmptyPreviewChanges: true,
SkipRefresh: true,
})

integration.ProgramTest(t, &test)
})
}

func getJSBaseOptions(t *testing.T) integration.ProgramTestOptions {
envRegion := getEnvRegion(t)
baseJS := integration.ProgramTestOptions{
Config: map[string]string{
"aws:region": "INVALID_REGION",
"aws:envRegion": envRegion,
},
Dependencies: []string{
"@pulumi/aws",
},
}

return baseJS
}

func createLambdaArchive(size int64) (string, error) {
// Create a temporary file to save the zip archive
tempFile, err := os.CreateTemp("", "archive-*.zip")
if err != nil {
return "", err
}
defer tempFile.Close()

// Create a new zip archive
zipWriter := zip.NewWriter(tempFile)
defer zipWriter.Close()

randomDataReader := io.LimitReader(rand.Reader, size)

// Create the index.js file for the lambda
indexWriter, err := zipWriter.Create("index.js")
if err != nil {
return "", err
}
_, err = indexWriter.Write([]byte("const { version } = require(\"@aws-sdk/client-s3/package.json\");\n\nexports.handler = async () => ({ version });\n"))
if err != nil {
return "", err
}

randomDataWriter, err := zipWriter.Create("random.txt")
if err != nil {
return "", err
}
_, err = io.Copy(randomDataWriter, randomDataReader)
if err != nil {
return "", err
}

// Get the path of the temporary file
archivePath, err := filepath.Abs(tempFile.Name())
if err != nil {
return "", err
}

return archivePath, nil
}
15 changes: 0 additions & 15 deletions provider/provider_python_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,18 +105,3 @@ func getPythonBaseOptions(t *testing.T) integration.ProgramTestOptions {

return pythonBase
}

func maxDuration(dur time.Duration, t *testing.T, test func(t *testing.T)) {
t.Helper()
timeout := time.After(dur)
done := make(chan bool)
go func() {
test(t)
done <- true
}()
select {
case <-timeout:
t.Fatalf("Test timed out after %v", dur)
case <-done:
}
}
16 changes: 16 additions & 0 deletions provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"os"
"path/filepath"
"testing"
"time"

"github.com/pulumi/providertest"
"github.com/pulumi/providertest/optproviderupgrade"
Expand Down Expand Up @@ -94,3 +95,18 @@ func pulumiTest(t *testing.T, dir string, opts ...opttest.Option) *pulumitest.Pu
ptest := pulumitest.NewPulumiTest(t, dir, opts...)
return ptest
}

func maxDuration(dur time.Duration, t *testing.T, test func(t *testing.T)) {
t.Helper()
timeout := time.After(dur)
done := make(chan bool)
go func() {
test(t)
done <- true
}()
select {
case <-timeout:
t.Fatalf("Test timed out after %v", dur)
case <-done:
}
}
3 changes: 3 additions & 0 deletions provider/test-programs/parallel-lambdas/Pulumi.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
name: parallel-lambdas
runtime: nodejs
description: Parallel Lambdas example
Loading

0 comments on commit 42cf551

Please sign in to comment.