Skip to content

Commit

Permalink
feat: refresh s3 client
Browse files Browse the repository at this point in the history
  • Loading branch information
dskart committed May 24, 2024
1 parent 501331c commit b7c4c65
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 11 deletions.
14 changes: 4 additions & 10 deletions app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@ import (
"net/http"
"time"

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/project-n-oss/sidekick/app/aws"
sidekickAws "github.com/project-n-oss/sidekick/app/aws"

"go.uber.org/zap"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
Expand All @@ -19,7 +18,6 @@ type App struct {

standardHttpClient *http.Client
gcpHttpClient *http.Client
s3Client *s3.Client
}

func New(ctx context.Context, logger *zap.Logger, cfg Config) (*App, error) {
Expand All @@ -39,12 +37,8 @@ func New(ctx context.Context, logger *zap.Logger, cfg Config) (*App, error) {

switch cfg.CloudPlatform {
case AwsCloudPlatform.String():
aws.RefreshCredentialsPeriodically(ctx, logger)
awsConfig, err := config.LoadDefaultConfig(ctx)
if err != nil {
return nil, err
}
ret.s3Client = s3.NewFromConfig(awsConfig)
sidekickAws.RefreshCredentialsPeriodically(ctx, logger)
sidekickAws.RefreshS3ClientPeriodically(ctx, logger)

case GcpCloudPlatform.String():
creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/devstorage.read_write")
Expand Down
63 changes: 63 additions & 0 deletions app/aws/s3_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package aws

import (
"context"
"fmt"
"sync"
"time"

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/s3"
"go.uber.org/zap"
)

var s3ClientMap = sync.Map{}

func GetS3ClientFromRegion(ctx context.Context, region string) (*s3.Client, error) {
if client, ok := s3ClientMap.Load(region); ok {
return client.(*s3.Client), nil
}
return newS3ClientFromRegion(ctx, region)
}

func newS3ClientFromRegion(ctx context.Context, region string) (*s3.Client, error) {
awsConfig, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
if err != nil {
return nil, err
}

client := s3.NewFromConfig(awsConfig)
return client, nil
}

func refreshS3Client(ctx context.Context, logger *zap.Logger) {
s3ClientMap.Range(func(key, value interface{}) bool {
region := key.(string)
awsConfig, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
if err != nil {
logger.Error(fmt.Sprintf("failed to load aws config for region %s", region), zap.Error(err))
return true
}

client := s3.NewFromConfig(awsConfig)
s3ClientMap.Store(region, client)
return true
})
}

func RefreshS3ClientPeriodically(ctx context.Context, logger *zap.Logger) {
refreshS3Client(ctx, logger)
ticker := time.NewTicker(30 * time.Minute)
go func() {
for {
select {
case <-ctx.Done():
ticker.Stop()
return
case <-ticker.C:
refreshS3Client(ctx, logger)
}
}
}()

}
8 changes: 7 additions & 1 deletion app/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,14 @@ func (sess *Session) DoAwsRequest(req *http.Request) (*http.Response, bool, erro
// if the source file is not already a crunched file, check if the crunched file exists
if !sess.app.cfg.NoCrunchErr && !isCrunchedFile(cloudRequest.URL.Path) {
objectKey := makeCrunchFilePath(cloudRequest.URL.Path)

// ignore errors, we only want to check if the object exists
headResp, _ := sess.app.s3Client.HeadObject(sess.Context(), &s3.HeadObjectInput{
s3Client, err := sidekickAws.GetS3ClientFromRegion(sess.Context(), sourceBucket.Region)
if err != nil {
return nil, false, fmt.Errorf("failed to get s3 client for region '%s': %w", sourceBucket.Region, err)
}

headResp, _ := s3Client.HeadObject(sess.Context(), &s3.HeadObjectInput{
Bucket: aws.String(sourceBucket.Bucket),
Key: aws.String(objectKey),
})
Expand Down

0 comments on commit b7c4c65

Please sign in to comment.