Skip to content

Commit

Permalink
feat: crunch lock (#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
dskart committed May 28, 2024
1 parent 6a2ed90 commit 1f00e44
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 48 deletions.
23 changes: 12 additions & 11 deletions app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ import (
"net/http"
"time"

"github.com/project-n-oss/sidekick/app/aws"
sidekickAws "github.com/project-n-oss/sidekick/app/aws"

"go.uber.org/zap"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
Expand All @@ -24,22 +25,28 @@ func New(ctx context.Context, logger *zap.Logger, cfg Config) (*App, error) {
return nil, err
}

var gcpHttpClient http.Client
standardHttpClient := http.Client{
Timeout: time.Duration(90) * time.Second,
}

ret := &App{
cfg: cfg,
logger: logger,
standardHttpClient: &standardHttpClient,
}

switch cfg.CloudPlatform {
case AwsCloudPlatform.String():
aws.RefreshCredentialsPeriodically(ctx, logger)
sidekickAws.RefreshCredentialsPeriodically(ctx, logger)
sidekickAws.RefreshS3ClientPeriodically(ctx, logger)

case GcpCloudPlatform.String():
creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/devstorage.read_write")
if err != nil {
return nil, err
}
ts := oauth2.TokenSource(creds.TokenSource)
gcpHttpClient = http.Client{
ret.gcpHttpClient = &http.Client{
Timeout: time.Duration(90) * time.Second,
Transport: &oauth2.Transport{
Base: http.DefaultTransport,
Expand All @@ -49,13 +56,7 @@ func New(ctx context.Context, logger *zap.Logger, cfg Config) (*App, error) {
}

logger.Sugar().Infof("Cloud Platform: %s, CrunchErr: %v", cfg.CloudPlatform, !cfg.NoCrunchErr)
return &App{
cfg: cfg,
logger: logger,

standardHttpClient: &standardHttpClient,
gcpHttpClient: &gcpHttpClient,
}, nil
return ret, nil
}

func (a *App) Close(ctx context.Context) error {
Expand Down
63 changes: 63 additions & 0 deletions app/aws/s3_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package aws

import (
"context"
"fmt"
"sync"
"time"

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/s3"
"go.uber.org/zap"
)

var s3ClientMap = sync.Map{}

func GetS3ClientFromRegion(ctx context.Context, region string) (*s3.Client, error) {
if client, ok := s3ClientMap.Load(region); ok {
return client.(*s3.Client), nil
}
return newS3ClientFromRegion(ctx, region)
}

func newS3ClientFromRegion(ctx context.Context, region string) (*s3.Client, error) {
awsConfig, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
if err != nil {
return nil, err
}

client := s3.NewFromConfig(awsConfig)
return client, nil
}

func refreshS3Client(ctx context.Context, logger *zap.Logger) {
s3ClientMap.Range(func(key, value interface{}) bool {
region := key.(string)
awsConfig, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
if err != nil {
logger.Error(fmt.Sprintf("failed to load aws config for region %s", region), zap.Error(err))
return true
}

client := s3.NewFromConfig(awsConfig)
s3ClientMap.Store(region, client)
return true
})
}

func RefreshS3ClientPeriodically(ctx context.Context, logger *zap.Logger) {
refreshS3Client(ctx, logger)
ticker := time.NewTicker(30 * time.Minute)
go func() {
for {
select {
case <-ctx.Done():
ticker.Stop()
return
case <-ticker.C:
refreshS3Client(ctx, logger)
}
}
}()

}
60 changes: 28 additions & 32 deletions app/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,11 @@ import (
"net/http"
"strings"

"github.com/project-n-oss/sidekick/app/aws"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
sidekickAws "github.com/project-n-oss/sidekick/app/aws"
)

func statusCodeIs2xx(statusCode int) bool {
return statusCode >= 200 && statusCode < 300
}

// DoRequest makes a request to the cloud platform
// Does a request to the source bucket and if it returns 404, tries the crunched bucket
// Returns the response and a boolean indicating if the response is from the crunched bucket
Expand All @@ -24,18 +22,20 @@ func (sess *Session) DoRequest(req *http.Request) (*http.Response, bool, error)
}
}

const crunchFileFoundErrStatus = "500 Src file not found, but crunched file found"
const crunchFileFoundErrStatus = "409 Src file not found, but crunched file found"
const crunchFileFoundStatusCode = 409

// DoAwsRequest makes a request to AWS
// Does a request to the source bucket and if it returns 404, tries the crunched bucket
// Returns the response and a boolean indicating if the response is from the crunched bucket
// If a crunched version of the source file exists, returns a 500 response
// Returns the response and a boolean indicating if a crunched file was found
// You can disable this behavior by setting NoCrunchErr to true in the config
func (sess *Session) DoAwsRequest(req *http.Request) (*http.Response, bool, error) {
sourceBucket, err := aws.ExtractSourceBucket(req)
sourceBucket, err := sidekickAws.ExtractSourceBucket(req)
if err != nil {
return nil, false, fmt.Errorf("failed to extract source bucket from request: %w", err)
}

cloudRequest, err := aws.NewRequest(sess.Context(), sess.Logger(), req, sourceBucket)
cloudRequest, err := sidekickAws.NewRequest(sess.Context(), sess.Logger(), req, sourceBucket)
if err != nil {
return nil, false, fmt.Errorf("failed to make aws request: %w", err)
}
Expand All @@ -45,45 +45,41 @@ func (sess *Session) DoAwsRequest(req *http.Request) (*http.Response, bool, erro
return nil, false, fmt.Errorf("failed to do aws request: %w", err)
}

statusCode := -1
if resp != nil {
statusCode = resp.StatusCode
}
// if the source file is not already a crunched file, check if the crunched file exists
if !sess.app.cfg.NoCrunchErr && !isCrunchedFile(cloudRequest.URL.Path) {
objectKey := makeCrunchFilePath(sourceBucket.Bucket, cloudRequest.URL.Path)

if statusCode == 404 && !isCrunchedFile(req.URL.Path) && !sess.app.cfg.NoCrunchErr {
crunchedFilePath := makeCrunchFilePath(req.URL.Path)
crunchedRequest, err := aws.NewRequest(sess.Context(), sess.Logger(), req, sourceBucket, aws.WithPath(crunchedFilePath))
s3Client, err := sidekickAws.GetS3ClientFromRegion(sess.Context(), sourceBucket.Region)
if err != nil {
return nil, false, fmt.Errorf("failed to make aws request: %w", err)
return nil, false, fmt.Errorf("failed to get s3 client for region '%s': %w", sourceBucket.Region, err)
}

resp, err := http.DefaultClient.Do(crunchedRequest)
if err != nil {
return nil, false, fmt.Errorf("failed to do crunched aws request: %w", err)
}
crunchedStatusCode := -1
if resp != nil {
crunchedStatusCode = resp.StatusCode
}
// ignore errors, we only want to check if the object exists
headResp, _ := s3Client.HeadObject(sess.Context(), &s3.HeadObjectInput{
Bucket: aws.String(sourceBucket.Bucket),
Key: aws.String(objectKey),
})

// return 500 to client if there is a crunch version of the file
if statusCodeIs2xx(crunchedStatusCode) {
resp.StatusCode = 500
// found crunched file, return 500 to client
if headResp != nil && headResp.ETag != nil {
resp.StatusCode = crunchFileFoundStatusCode
resp.Status = crunchFileFoundErrStatus
}

return resp, true, err
return resp, true, nil
}

return resp, false, err
}

func makeCrunchFilePath(filePath string) string {
func makeCrunchFilePath(bucketName, filePath string) string {
splitS := strings.SplitAfterN(filePath, ".", 2)
ret := strings.TrimSuffix(splitS[0], ".") + ".gr"
if len(splitS) > 1 {
ret += "." + splitS[1]
}
ret = strings.TrimPrefix(ret, "/")
ret = strings.TrimPrefix(ret, bucketName)
ret = strings.TrimPrefix(ret, "/")
return ret
}

Expand Down
11 changes: 6 additions & 5 deletions app/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,18 @@ func TestApp_RequestIsCrunchedFile(t *testing.T) {
func TestApp_RequestMakeCrunchFilePath(t *testing.T) {
testCases := []struct {
path string
bucket string
expected string
}{
{path: "/foo/bar", expected: "/foo/bar.gr"},
{path: "/foo/bar/", expected: "/foo/bar/.gr"},
{path: "/foo/bar/myfile.parquet", expected: "/foo/bar/myfile.gr.parquet"},
{path: "/foo/bar/myfile.c00.zstd.parquet", expected: "/foo/bar/myfile.gr.c00.zstd.parquet"},
{path: "/foo/bar", bucket: "foo", expected: "bar.gr"},
{path: "/foo/bar/", bucket: "foo", expected: "bar/.gr"},
{path: "/foo/bar/myfile.parquet", bucket: "foo", expected: "bar/myfile.gr.parquet"},
{path: "/foo/bar/myfile.c00.zstd.parquet", bucket: "foo", expected: "bar/myfile.gr.c00.zstd.parquet"},
}

for _, tc := range testCases {
t.Run(tc.path, func(t *testing.T) {
assert.Equal(t, tc.expected, makeCrunchFilePath(tc.path))
assert.Equal(t, tc.expected, makeCrunchFilePath(tc.bucket, tc.path))
})
}
}

0 comments on commit 1f00e44

Please sign in to comment.