-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
395 additions
and
118 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
package boltrouter | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"sync" | ||
|
||
"github.com/aws/aws-sdk-go-v2/aws" | ||
"github.com/aws/aws-sdk-go-v2/config" | ||
) | ||
|
||
// awsCredentialsMap is used to cache aws credentials for a given region. | ||
var awsCredentialsMap = sync.Map{} | ||
|
||
// GetAwsCredentialsFromRegion returns the aws credentials for the given region. | ||
func getAwsCredentialsFromRegion(ctx context.Context, region string) (aws.Credentials, error) { | ||
if awsCred, ok := awsCredentialsMap.Load(region); ok { | ||
return awsCred.(aws.Credentials), nil | ||
} | ||
|
||
return newAwsCredentialsFromRegion(ctx, region) | ||
} | ||
|
||
// newAwsCredentialsFromRegion creates a new aws credentials from the given region. | ||
func newAwsCredentialsFromRegion(ctx context.Context, region string) (aws.Credentials, error) { | ||
awsConfig, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) | ||
if err != nil { | ||
return aws.Credentials{}, err | ||
} | ||
|
||
cred, err := awsConfig.Credentials.Retrieve(ctx) | ||
if err != nil { | ||
return aws.Credentials{}, fmt.Errorf("could not retrieve aws credentials: %w", err) | ||
} | ||
awsCredentialsMap.Store(awsConfig.Region, cred) | ||
return cred, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
package boltrouter | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"net/http" | ||
"regexp" | ||
"strconv" | ||
"strings" | ||
) | ||
|
||
type s3RequestStyle string | ||
|
||
const ( | ||
virtualHostedStyle s3RequestStyle = "virtual-hosted-style" | ||
pathStyle s3RequestStyle = "path-style" | ||
nAuthDummy s3RequestStyle = "n-auth-dummy" | ||
) | ||
|
||
type SourceBucket struct { | ||
Bucket string | ||
Region string | ||
Style s3RequestStyle | ||
} | ||
|
||
// extractSourceBucket extracts the aws request bucket using Path-style or Virtual-hosted-style requests. | ||
// https://docs.aws.amazon.com/AmazonS3/latest/userguide/VirtualHosting.html | ||
// This method will "n-auth-dummy" if nothing is found | ||
func extractSourceBucket(ctx context.Context, req *http.Request) (SourceBucket, error) { | ||
region, err := getRegionForBucket(ctx, req.Header.Get("Authorization")) | ||
if err != nil { | ||
return SourceBucket{}, fmt.Errorf("could not get region for bucket: %w", err) | ||
} | ||
|
||
ret := SourceBucket{ | ||
Region: region, | ||
Bucket: "n-auth-dummy", | ||
Style: nAuthDummy, | ||
} | ||
|
||
isVirtualHostedStyle := false | ||
split := strings.Split(req.Host, ".") | ||
if len(split) > 1 { | ||
if _, err := strconv.Atoi(split[0]); err != nil { | ||
// is not a number, so it is a bucket name | ||
isVirtualHostedStyle = true | ||
} | ||
} | ||
|
||
if isVirtualHostedStyle { | ||
bucket := split[0] | ||
ret.Bucket = bucket | ||
ret.Style = virtualHostedStyle | ||
} else if paths := strings.Split(req.URL.EscapedPath(), "/"); len(paths) > 1 { | ||
// path-style request | ||
bucket := paths[1] | ||
ret.Bucket = bucket | ||
ret.Style = pathStyle | ||
} | ||
|
||
return ret, nil | ||
} | ||
|
||
var credentialRegexp = regexp.MustCompile(`Credential=([^,]*)`) | ||
|
||
func getRegionForBucket(ctx context.Context, authHeader string) (string, error) { | ||
if authHeader == "" { | ||
return "", fmt.Errorf("no auth header in request, cannot extract region") | ||
} | ||
|
||
matches := credentialRegexp.FindStringSubmatch(authHeader) | ||
if len(matches) != 2 { | ||
return "", fmt.Errorf("could not extract credential from auth header, matches: %v", matches) | ||
} | ||
|
||
// format: AKIA3Y7DLM2EYWSYCN5P/20230511/us-east-1/s3/aws4_request | ||
credentialStr := matches[1] | ||
credentialSplit := strings.Split(credentialStr, "/") | ||
if len(credentialSplit) != 5 { | ||
return "", fmt.Errorf("could not extract region from credential, credential: %v", credentialStr) | ||
} | ||
region := credentialSplit[2] | ||
return region, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
package boltrouter | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"testing" | ||
|
||
"github.com/Pallinder/go-randomdata" | ||
"github.com/aws/aws-sdk-go-v2/service/s3" | ||
"github.com/aws/aws-sdk-go/aws" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestExtractSourceBucket(t *testing.T) { | ||
ctx := context.Background() | ||
testCase := []struct { | ||
requestStyle s3RequestStyle | ||
region string | ||
}{ | ||
{requestStyle: pathStyle, region: "us-east-1"}, | ||
{requestStyle: pathStyle, region: "us-east-2"}, | ||
{requestStyle: pathStyle, region: "us-west-1"}, | ||
{requestStyle: pathStyle, region: "us-west-2"}, | ||
} | ||
|
||
for _, tc := range testCase { | ||
t.Run(fmt.Sprintf("%v_%v", tc.region, tc.requestStyle), func(t *testing.T) { | ||
bucketName := randomdata.SillyName() | ||
|
||
testS3Client := NewTestS3Client(t, ctx, tc.requestStyle, tc.region) | ||
testS3Client.S3Client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{ | ||
Bucket: aws.String(bucketName), | ||
}) | ||
|
||
req := testS3Client.GetRequest(t, ctx) | ||
sourceBucket, err := extractSourceBucket(ctx, req) | ||
assert.NoError(t, err) | ||
assert.Equal(t, bucketName, sourceBucket.Bucket) | ||
assert.Equal(t, tc.requestStyle, sourceBucket.Style) | ||
assert.Equal(t, tc.region, sourceBucket.Region) | ||
}) | ||
} | ||
} |
Oops, something went wrong.