Skip to content

Commit

Permalink
feat: Multi region support (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
dskart committed May 15, 2023
1 parent 1f09f9f commit bb62e1e
Show file tree
Hide file tree
Showing 13 changed files with 395 additions and 118 deletions.
37 changes: 37 additions & 0 deletions boltrouter/aws_credentials_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package boltrouter

import (
"context"
"fmt"
"sync"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
)

// awsCredentialsMap is used to cache aws credentials for a given region.
var awsCredentialsMap = sync.Map{}

// GetAwsCredentialsFromRegion returns the aws credentials for the given region.
func getAwsCredentialsFromRegion(ctx context.Context, region string) (aws.Credentials, error) {
if awsCred, ok := awsCredentialsMap.Load(region); ok {
return awsCred.(aws.Credentials), nil
}

return newAwsCredentialsFromRegion(ctx, region)
}

// newAwsCredentialsFromRegion creates a new aws credentials from the given region.
func newAwsCredentialsFromRegion(ctx context.Context, region string) (aws.Credentials, error) {
awsConfig, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
if err != nil {
return aws.Credentials{}, err
}

cred, err := awsConfig.Credentials.Retrieve(ctx)
if err != nil {
return aws.Credentials{}, fmt.Errorf("could not retrieve aws credentials: %w", err)
}
awsCredentialsMap.Store(awsConfig.Region, cred)
return cred, nil
}
62 changes: 19 additions & 43 deletions boltrouter/bolt_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"io"
"net/http"
"strings"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
Expand All @@ -22,14 +21,23 @@ type BoltRequest struct {
// a new http.Request Ready to be sent to Bolt.
// This new http.Request is routed to the correct Bolt endpoint and signed correctly.
func (br *BoltRouter) NewBoltRequest(ctx context.Context, req *http.Request) (*BoltRequest, error) {
sourceBucket := extractSourceBucket(req)
failoverRequest, err := newFailoverAwsRequest(ctx, req.Clone(ctx), br.awsCred, sourceBucket, br.boltVars.Region.Get())
sourceBucket, err := extractSourceBucket(ctx, req)
if err != nil {
return nil, fmt.Errorf("could not extract source bucket: %w", err)
}

awsCred, err := getAwsCredentialsFromRegion(ctx, sourceBucket.Region)
if err != nil {
return nil, fmt.Errorf("could not get aws credentials: %w", err)
}

failoverRequest, err := newFailoverAwsRequest(ctx, req.Clone(ctx), awsCred, sourceBucket)
if err != nil {
return nil, fmt.Errorf("failed to make failover request: %w", err)
}

authPrefix := randString(4)
headReq, err := signedAwsHeadRequest(ctx, req, br.awsCred, sourceBucket.bucket, br.boltVars.Region.Get(), authPrefix)
headReq, err := signedAwsHeadRequest(ctx, req, awsCred, sourceBucket.Bucket, sourceBucket.Region, authPrefix)
if err != nil {
return nil, fmt.Errorf("could not make signed aws head request: %w", err)
}
Expand All @@ -42,8 +50,8 @@ func (br *BoltRouter) NewBoltRequest(ctx context.Context, req *http.Request) (*B
// RequestURI is the unmodified request-target of the Request-Line (RFC 7230, Section 3.1.1) as sent by the client to a server.
// It is an error to set this field in an HTTP client request.
req.RequestURI = ""
if sourceBucket.style == virtualHostedStyle {
BoltURL = BoltURL.JoinPath(sourceBucket.bucket, req.URL.EscapedPath())
if sourceBucket.Style == virtualHostedStyle {
BoltURL = BoltURL.JoinPath(sourceBucket.Bucket, req.URL.EscapedPath())

} else {
BoltURL = BoltURL.JoinPath(req.URL.Path)
Expand Down Expand Up @@ -109,14 +117,14 @@ func signedAwsHeadRequest(ctx context.Context, req *http.Request, awsCred aws.Cr
}

// newFailoverAwsRequest creates a standard aws s3 request that can be used as a failover if the Bolt request fails.
func newFailoverAwsRequest(ctx context.Context, req *http.Request, awsCred aws.Credentials, sourceBucket SourceBucket, region string) (*http.Request, error) {
func newFailoverAwsRequest(ctx context.Context, req *http.Request, awsCred aws.Credentials, sourceBucket SourceBucket) (*http.Request, error) {
var host string
switch sourceBucket.style {
switch sourceBucket.Style {
case virtualHostedStyle:
host = fmt.Sprintf("%s.s3.%s.amazonaws.com", sourceBucket.bucket, region)
host = fmt.Sprintf("%s.s3.%s.amazonaws.com", sourceBucket.Bucket, sourceBucket.Region)
// default to path style
default:
host = fmt.Sprintf("s3.%s.amazonaws.com", region)
host = fmt.Sprintf("s3.%s.amazonaws.com", sourceBucket.Region)

}

Expand All @@ -134,45 +142,13 @@ func newFailoverAwsRequest(ctx context.Context, req *http.Request, awsCred aws.C
payloadHash := req.Header.Get("X-Amz-Content-Sha256")

awsSigner := v4.NewSigner()
if err := awsSigner.SignHTTP(ctx, awsCred, req, payloadHash, "s3", region, time.Now()); err != nil {
if err := awsSigner.SignHTTP(ctx, awsCred, req, payloadHash, "s3", sourceBucket.Region, time.Now()); err != nil {
return nil, err
}

return req.Clone(ctx), nil
}

type s3RequestStyle string

const (
virtualHostedStyle s3RequestStyle = "virtual-hosted-style"
pathStyle s3RequestStyle = "path-style"
nAuthDummy s3RequestStyle = "n-auth-dummy"
)

type SourceBucket struct {
bucket string
style s3RequestStyle
}

// extractSourceBucket extracts the aws request bucket using Path-style or Virtual-hosted-style requests.
// https://docs.aws.amazon.com/AmazonS3/latest/userguide/VirtualHosting.html
// This method will "n-auth-dummy" if nothing is found
func extractSourceBucket(req *http.Request) SourceBucket {
// virtual-hosted-style
if split := strings.Split(req.Host, "."); len(split) > 1 {
bucket := split[0]
return SourceBucket{bucket: bucket, style: virtualHostedStyle}
}

// path-style request
if paths := strings.Split(req.URL.EscapedPath(), "/"); len(paths) > 1 {
bucket := paths[1]
return SourceBucket{bucket: bucket, style: pathStyle}
}

return SourceBucket{bucket: "n-auth-dummy", style: nAuthDummy}
}

// DoBoltRequest sends an HTTP Bolt request and returns an HTTP response, following policy (such as redirects, cookies, auth) as configured on the client.
// DoBoltRequest will failover to AWS if the Bolt request fails and the config.Failover is set to true.
// DoboltRequest will return a bool indicating if the request was a failover.
Expand Down
1 change: 1 addition & 0 deletions boltrouter/bolt_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func TestBoltRequest(t *testing.T) {
boltVars := br.boltVars

req, err := http.NewRequest(tt.httpMethod, "test.projectn.co", nil)
req.Header.Set("Authorization", "AWS4-HMAC-SHA256 Credential=AKIA3Y7DLM2EYWSYCN5P/20230511/us-west-2/s3/aws4_request, SignedHeaders=accept-encoding;amz-sdk-invocation-id;amz-sdk-request;host;x-amz-content-sha256;x-amz-date, Signature=6447287d46d333a010e224191d64c31b9738cc37886aadb7753a0a579a30edc6")
require.NoError(t, err)

boltReq, err := br.NewBoltRequest(ctx, req)
Expand Down
13 changes: 0 additions & 13 deletions boltrouter/bolt_router.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ import (
"net/http"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"go.uber.org/zap"
)

Expand All @@ -19,7 +17,6 @@ type BoltRouter struct {
boltHttpClient *http.Client
standardHttpClient *http.Client
boltVars *BoltVars
awsCred aws.Credentials
}

// NewBoltRouter creates a new BoltRouter.
Expand All @@ -43,22 +40,12 @@ func NewBoltRouter(ctx context.Context, logger *zap.Logger, cfg Config) (*BoltRo
Timeout: time.Duration(90) * time.Second,
}

awsCfg, err := config.LoadDefaultConfig(ctx)
if err != nil {
return nil, fmt.Errorf("could not load aws default config: %w", err)
}
cred, err := awsCfg.Credentials.Retrieve(ctx)
if err != nil {
return nil, fmt.Errorf("could not retrieve aws credentials: %w", err)
}

br := &BoltRouter{
config: cfg,

boltHttpClient: &boltHttpClient,
standardHttpClient: &standardHttpClient,
boltVars: boltVars,
awsCred: cred,
}

return br, nil
Expand Down
54 changes: 54 additions & 0 deletions boltrouter/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@ import (
"net/http"
"net/http/httptest"
"os"
"sync"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
)
Expand Down Expand Up @@ -53,3 +57,53 @@ func SetupQuickSilverMock(t *testing.T, ctx context.Context, logger *zap.Logger)
require.NoError(t, err)
boltVars.QuicksilverURL.Set(quicksilver.URL)
}

type TestS3Client struct {
req *http.Request
lock sync.RWMutex

S3Client *s3.Client
}

func NewTestS3Client(t *testing.T, ctx context.Context, requestStyle s3RequestStyle, region string) *TestS3Client {
ret := &TestS3Client{}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ret.lock.Lock()
ret.req = r
ret.lock.Unlock()

sc := http.StatusOK
w.WriteHeader(sc)
}))

customResolver := aws.EndpointResolverWithOptionsFunc(func(service, signRegion string, options ...interface{}) (aws.Endpoint, error) {
if service == s3.ServiceID {
return aws.Endpoint{
PartitionID: "aws",
URL: server.URL,
SigningRegion: signRegion,
}, nil
}
return aws.Endpoint{}, &aws.EndpointNotFoundError{}
})
cfg, err := config.LoadDefaultConfig(ctx,
config.WithEndpointResolverWithOptions(customResolver),
config.WithRegion(region),
)
require.NoError(t, err)

ret.S3Client = s3.NewFromConfig(cfg, func(o *s3.Options) {
if requestStyle == pathStyle {
o.UsePathStyle = true
}
})

return ret
}

func (c *TestS3Client) GetRequest(t *testing.T, ctx context.Context) *http.Request {
c.lock.RLock()
defer c.lock.RUnlock()
require.NotNil(t, c.req)
return c.req.Clone(ctx)
}
84 changes: 84 additions & 0 deletions boltrouter/source_bucket.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package boltrouter

import (
"context"
"fmt"
"net/http"
"regexp"
"strconv"
"strings"
)

type s3RequestStyle string

const (
virtualHostedStyle s3RequestStyle = "virtual-hosted-style"
pathStyle s3RequestStyle = "path-style"
nAuthDummy s3RequestStyle = "n-auth-dummy"
)

type SourceBucket struct {
Bucket string
Region string
Style s3RequestStyle
}

// extractSourceBucket extracts the aws request bucket using Path-style or Virtual-hosted-style requests.
// https://docs.aws.amazon.com/AmazonS3/latest/userguide/VirtualHosting.html
// This method will "n-auth-dummy" if nothing is found
func extractSourceBucket(ctx context.Context, req *http.Request) (SourceBucket, error) {
region, err := getRegionForBucket(ctx, req.Header.Get("Authorization"))
if err != nil {
return SourceBucket{}, fmt.Errorf("could not get region for bucket: %w", err)
}

ret := SourceBucket{
Region: region,
Bucket: "n-auth-dummy",
Style: nAuthDummy,
}

isVirtualHostedStyle := false
split := strings.Split(req.Host, ".")
if len(split) > 1 {
if _, err := strconv.Atoi(split[0]); err != nil {
// is not a number, so it is a bucket name
isVirtualHostedStyle = true
}
}

if isVirtualHostedStyle {
bucket := split[0]
ret.Bucket = bucket
ret.Style = virtualHostedStyle
} else if paths := strings.Split(req.URL.EscapedPath(), "/"); len(paths) > 1 {
// path-style request
bucket := paths[1]
ret.Bucket = bucket
ret.Style = pathStyle
}

return ret, nil
}

var credentialRegexp = regexp.MustCompile(`Credential=([^,]*)`)

func getRegionForBucket(ctx context.Context, authHeader string) (string, error) {
if authHeader == "" {
return "", fmt.Errorf("no auth header in request, cannot extract region")
}

matches := credentialRegexp.FindStringSubmatch(authHeader)
if len(matches) != 2 {
return "", fmt.Errorf("could not extract credential from auth header, matches: %v", matches)
}

// format: AKIA3Y7DLM2EYWSYCN5P/20230511/us-east-1/s3/aws4_request
credentialStr := matches[1]
credentialSplit := strings.Split(credentialStr, "/")
if len(credentialSplit) != 5 {
return "", fmt.Errorf("could not extract region from credential, credential: %v", credentialStr)
}
region := credentialSplit[2]
return region, nil
}
43 changes: 43 additions & 0 deletions boltrouter/source_bucket_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package boltrouter

import (
"context"
"fmt"
"testing"

"github.com/Pallinder/go-randomdata"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go/aws"
"github.com/stretchr/testify/assert"
)

func TestExtractSourceBucket(t *testing.T) {
ctx := context.Background()
testCase := []struct {
requestStyle s3RequestStyle
region string
}{
{requestStyle: pathStyle, region: "us-east-1"},
{requestStyle: pathStyle, region: "us-east-2"},
{requestStyle: pathStyle, region: "us-west-1"},
{requestStyle: pathStyle, region: "us-west-2"},
}

for _, tc := range testCase {
t.Run(fmt.Sprintf("%v_%v", tc.region, tc.requestStyle), func(t *testing.T) {
bucketName := randomdata.SillyName()

testS3Client := NewTestS3Client(t, ctx, tc.requestStyle, tc.region)
testS3Client.S3Client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{
Bucket: aws.String(bucketName),
})

req := testS3Client.GetRequest(t, ctx)
sourceBucket, err := extractSourceBucket(ctx, req)
assert.NoError(t, err)
assert.Equal(t, bucketName, sourceBucket.Bucket)
assert.Equal(t, tc.requestStyle, sourceBucket.Style)
assert.Equal(t, tc.region, sourceBucket.Region)
})
}
}
Loading

0 comments on commit bb62e1e

Please sign in to comment.