/
s3client.go
75 lines (64 loc) · 2.16 KB
/
s3client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
package utils
import (
"context"
"testing"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/retry"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/stretchr/testify/require"
)
// GetSidekickS3Client returns an S3 client connected to Bolt through Sidekick.
//
// Only S3 requests are routed to SidekickURL, using path-style addressing;
// every other service uses the SDK's default endpoint resolution. Retries are
// pinned to the SDK default (retry.DefaultMaxAttempts) regardless of local
// CLI configuration.
//
// If region is empty, the region is looked up via the EC2 instance metadata
// service (IMDS); otherwise region is used as given. Any error loading the AWS
// config or looking up the region fails the test immediately.
func GetSidekickS3Client(t *testing.T, ctx context.Context, region string) *s3.Client {
	t.Helper()

	// Send only S3 traffic to Sidekick.
	// NOTE(review): aws.EndpointResolverWithOptions is deprecated in newer
	// aws-sdk-go-v2 releases in favor of s3.Options.BaseEndpoint; migrate once
	// the module's SDK version is confirmed to support it.
	customResolver := aws.EndpointResolverWithOptionsFunc(func(service, signRegion string, options ...interface{}) (aws.Endpoint, error) {
		if service == s3.ServiceID {
			return aws.Endpoint{
				PartitionID:   "aws",
				URL:           SidekickURL,
				SigningRegion: signRegion,
			}, nil
		}
		// Returning EndpointNotFoundError lets the service fall back to its
		// default endpoint resolution.
		return aws.Endpoint{}, &aws.EndpointNotFoundError{}
	})

	cfg, err := config.LoadDefaultConfig(ctx, config.WithEndpointResolverWithOptions(customResolver))
	require.NoError(t, err, "loading default AWS config for sidekick client")

	// Local CLI config may increase this to 5 or 6; tests should exercise the
	// SDK default of 3.
	cfg.RetryMaxAttempts = retry.DefaultMaxAttempts

	// Resolve the region before building the client rather than inside the
	// s3.Options callback, so a failed lookup aborts the test from this helper
	// instead of from inside SDK client construction.
	if region == "" {
		region = awsRegion(t, ctx, cfg)
	}

	return s3.NewFromConfig(cfg, func(o *s3.Options) {
		o.Region = region
		// Path-style addressing (endpoint/bucket/key) instead of
		// virtual-hosted bucket subdomains.
		o.UsePathStyle = true
	})
}
// GetAwsS3Client returns an S3 client built from the default AWS config,
// talking to AWS S3 directly (no custom endpoint).
//
// If region is empty, the region is looked up via the EC2 instance metadata
// service (IMDS); otherwise region is used as given. Any error loading the AWS
// config or looking up the region fails the test immediately.
func GetAwsS3Client(t *testing.T, ctx context.Context, region string) *s3.Client {
	t.Helper()

	cfg, err := config.LoadDefaultConfig(ctx)
	require.NoError(t, err, "loading default AWS config")

	// Resolve the region before building the client rather than inside the
	// s3.Options callback, so a failed lookup aborts the test from this helper
	// instead of from inside SDK client construction.
	if region == "" {
		region = awsRegion(t, ctx, cfg)
	}

	return s3.NewFromConfig(cfg, func(o *s3.Options) {
		o.Region = region
	})
}
// awsRegion returns the region reported by the EC2 instance metadata service
// (IMDS), using an IMDS client built from cfg. It fails the test immediately
// if the lookup returns an error.
func awsRegion(t *testing.T, ctx context.Context, cfg aws.Config) string {
	t.Helper()

	client := imds.NewFromConfig(cfg)
	output, err := client.GetRegion(ctx, &imds.GetRegionInput{})
	require.NoError(t, err, "looking up region from EC2 instance metadata (IMDS)")
	return output.Region
}
// GetRegionForBucket returns the region of the given S3 bucket, as determined
// by manager.GetBucketRegion using the package-level AwsS3c client. It fails
// the test immediately if the region cannot be determined.
func GetRegionForBucket(t *testing.T, ctx context.Context, bucket string) string {
	t.Helper()

	region, err := manager.GetBucketRegion(ctx, AwsS3c, bucket)
	require.NoError(t, err, "looking up region for bucket %q", bucket)
	return region
}