-
Notifications
You must be signed in to change notification settings - Fork 351
/
client_cache.go
100 lines (85 loc) · 3.08 KB
/
client_cache.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
package s3
import (
"context"
"fmt"
"sync"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/go-openapi/swag"
"github.com/treeverse/lakefs/pkg/logging"
"github.com/treeverse/lakefs/pkg/stats"
)
type (
	// clientFactory builds an S3 API client on top of the shared session.
	// cfgs are applied on top of the session configuration; Get uses this
	// to pin each client to a specific region.
	clientFactory func(sess *session.Session, cfgs ...*aws.Config) s3iface.S3API

	// s3RegionGetter resolves the region that should be used to access the
	// named bucket.
	s3RegionGetter func(ctx context.Context, sess *session.Session, bucket string) (string, error)
)
// ClientCache hands out S3 clients keyed by the region of the bucket being
// accessed, memoizing both bucket→region and region→client lookups.
//
// The memo tables are sync.Maps, so concurrent Get calls are safe. The Set*
// and DiscoverBucketRegion methods write plain fields without locking and are
// presumably meant to be called during setup, before the cache is shared.
type ClientCache struct {
	// regionToS3Client maps region (string) -> s3iface.S3API.
	regionToS3Client sync.Map
	// bucketToRegion maps bucket name (string) -> region (string).
	bucketToRegion sync.Map

	// awsSession is the base session every client and region lookup uses.
	awsSession *session.Session
	// clientFactory builds a client for a region (default: newS3Client).
	clientFactory clientFactory
	// s3RegionGetter resolves a bucket's region (default: getBucketRegionFromS3).
	s3RegionGetter s3RegionGetter
	// collector, when non-nil, receives an event per client created.
	collector stats.Collector
}
// getBucketRegionFromS3 discovers where bucket lives by asking S3 through
// s3manager.GetBucketRegion, passing no region hint.
func getBucketRegionFromS3(ctx context.Context, sess *session.Session, bucket string) (string, error) {
	const noRegionHint = ""
	region, err := s3manager.GetBucketRegion(ctx, sess, bucket, noRegionHint)
	return region, err
}
// getBucketRegionFromSession skips discovery entirely and reports the region
// configured on the session for every bucket (empty if none is set). It never
// returns an error.
func getBucketRegionFromSession(_ context.Context, sess *session.Session, _ string) (string, error) {
	return aws.StringValue(sess.Config.Region), nil
}
// newS3Client is the default clientFactory: a plain s3.New over the shared
// session with the given config overrides.
func newS3Client(sess *session.Session, cfgs ...*aws.Config) s3iface.S3API {
	client := s3.New(sess, cfgs...)
	return client
}
// NewClientCache returns a ClientCache over awsSession. By default it resolves
// bucket regions by querying S3 and builds clients with s3.New; no stats
// collector is attached. Use the Set* methods or DiscoverBucketRegion to
// change these defaults.
func NewClientCache(awsSession *session.Session) *ClientCache {
	cache := &ClientCache{awsSession: awsSession}
	cache.clientFactory = newS3Client
	cache.s3RegionGetter = getBucketRegionFromS3
	return cache
}
// SetClientFactory replaces the function used to build per-region clients.
// Clients that are already cached are kept; only future misses use factory.
func (c *ClientCache) SetClientFactory(factory clientFactory) {
	c.clientFactory = factory
}
// SetS3RegionGetter replaces the function used to resolve a bucket's region.
// Buckets whose region is already cached are not re-resolved.
func (c *ClientCache) SetS3RegionGetter(getter s3RegionGetter) {
	c.s3RegionGetter = getter
}
// SetStatsCollector sets the collector that is notified each time Get creates
// a new regional client. Passing nil disables reporting.
func (c *ClientCache) SetStatsCollector(collector stats.Collector) {
	c.collector = collector
}
// getBucketRegion returns the region of bucket. On first use it resolves the
// region via c.s3RegionGetter and memoizes the answer in c.bucketToRegion.
//
// If resolution fails, the region configured on the session is used instead.
// When the session has no region, that fallback is the empty string rather
// than a nil-pointer panic. The fallback is memoized like a real answer, with
// one exception: when the failure coincides with ctx being canceled or past its
// deadline. Such an error says nothing about the bucket, and caching it would
// pin the bucket to a possibly wrong region for the lifetime of the cache.
func (c *ClientCache) getBucketRegion(ctx context.Context, bucket string) string {
	if region, hasRegion := c.bucketToRegion.Load(bucket); hasRegion {
		return region.(string)
	}
	logger := logging.FromContext(ctx).WithField("bucket", bucket)
	logger.Debug("requesting region for bucket")
	region, err := c.s3RegionGetter(ctx, c.awsSession, bucket)
	if err != nil {
		logger.WithError(err).Error("failed to get region for bucket, falling back to default region")
		region = ""
		if sessionRegion := c.awsSession.Config.Region; sessionRegion != nil {
			region = *sessionRegion
		}
		if ctx.Err() != nil {
			// The lookup was aborted by the caller: use the fallback for this
			// call only, so the next request retries discovery.
			return region
		}
	}
	c.bucketToRegion.Store(bucket, region)
	return region
}
// Get returns an S3 client configured for the region of the given bucket.
//
// The bucket's region is resolved (and memoized) by getBucketRegion. Clients
// are cached per region, so buckets in the same region share one client. When
// a new client is created and a stats collector is set, a
// "created_aws_client_<region>" event is reported.
func (c *ClientCache) Get(ctx context.Context, bucket string) s3iface.S3API {
	region := c.getBucketRegion(ctx, bucket)
	if cached, ok := c.regionToS3Client.Load(region); ok {
		return cached.(s3iface.S3API)
	}
	logging.FromContext(ctx).WithField("bucket", bucket).WithField("region", region).Debug("creating client for region")
	client := c.clientFactory(c.awsSession, &aws.Config{Region: swag.String(region)})
	// Concurrent callers may race to build a client for the same region.
	// LoadOrStore keeps exactly one and hands it to everybody, so all callers
	// share the same client and the creation event is reported only once.
	actual, loaded := c.regionToS3Client.LoadOrStore(region, client)
	if !loaded && c.collector != nil {
		c.collector.CollectEvent("s3_block_adapter", fmt.Sprintf("created_aws_client_%s", region))
	}
	return actual.(s3iface.S3API)
}
// DiscoverBucketRegion selects how bucket regions are resolved. With b set,
// each bucket's region is looked up from S3. Otherwise the session's
// configured region is used for every bucket. Regions that are already cached
// are not affected.
func (c *ClientCache) DiscoverBucketRegion(b bool) {
	var getter s3RegionGetter = getBucketRegionFromSession
	if b {
		getter = getBucketRegionFromS3
	}
	c.s3RegionGetter = getter
}