-
Notifications
You must be signed in to change notification settings - Fork 111
/
s3.go
434 lines (372 loc) · 13.8 KB
/
s3.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
package s3
import (
"context"
"errors"
"fmt"
"net/http"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/bmatcuk/doublestar/v4"
"github.com/c2h5oh/datasize"
"github.com/mitchellh/mapstructure"
"github.com/rilldata/rill/runtime/drivers"
rillblob "github.com/rilldata/rill/runtime/drivers/blob"
"github.com/rilldata/rill/runtime/pkg/activity"
"github.com/rilldata/rill/runtime/pkg/globutil"
"github.com/rilldata/rill/runtime/pkg/observability"
"go.uber.org/zap"
"gocloud.dev/blob"
"gocloud.dev/blob/s3blob"
)
// spec is the static description of the S3 driver returned by driver.Spec.
// It lists the properties a source may set (path, region, and an informational
// credentials hint) and the connector config keys that are marked as secrets.
//
// NOTE(review): the "path" property is an S3 URI, yet its Description reads
// "Path to file on the disk." — this looks copied from a local-file driver; confirm the intended text.
var spec = drivers.Spec{
DisplayName: "Amazon S3",
Description: "Connect to AWS S3 Storage.",
ServiceAccountDocs: "https://docs.rilldata.com/deploy/credentials/s3",
// Properties that can be set on an individual source.
SourceProperties: []drivers.PropertySchema{
{
Key: "path",
DisplayName: "S3 URI",
Description: "Path to file on the disk.",
Placeholder: "s3://bucket-name/path/to/file.csv",
Type: drivers.StringPropertyType,
Required: true,
Hint: "Glob patterns are supported",
},
{
Key: "region",
DisplayName: "AWS region",
Description: "AWS Region for the bucket.",
Placeholder: "us-east-1",
Type: drivers.StringPropertyType,
Required: false,
Hint: "Rill will use the default region in your local AWS config, unless set here.",
},
{
Key: "aws.credentials",
DisplayName: "AWS credentials",
Description: "AWS credentials inferred from your local environment.",
Type: drivers.InformationalPropertyType,
Hint: "Set your local credentials: <code>aws configure</code> Click to learn more.",
Href: "https://docs.rilldata.com/develop/import-data#configure-credentials-for-s3",
},
},
// Connector-level config keys; both are flagged as secrets.
ConfigProperties: []drivers.PropertySchema{
{
Key: "aws_access_key_id",
Secret: true,
},
{
Key: "aws_secret_access_key",
Secret: true,
},
},
}
// defaultPageSize is a page-size constant with value 20.
// NOTE(review): it is not referenced anywhere in the visible part of this file — confirm it is still used.
const defaultPageSize = 20
func init() {
drivers.Register("s3", driver{})
drivers.RegisterAsConnector("s3", driver{})
}
// driver is the stateless S3 implementation of drivers.Driver.
type driver struct{}
// Compile-time check that driver satisfies drivers.Driver.
var _ drivers.Driver = driver{}
// configProperties is the connector-level configuration decoded (via mapstructure)
// from the config map passed to driver.Open.
type configProperties struct {
// AccessKeyID is read from the "aws_access_key_id" config key.
AccessKeyID string `mapstructure:"aws_access_key_id"`
// SecretAccessKey is read from the "aws_secret_access_key" config key.
SecretAccessKey string `mapstructure:"aws_secret_access_key"`
// SessionToken is read from the "aws_access_token" config key (note: the key is not "aws_session_token").
SessionToken string `mapstructure:"aws_access_token"`
// AllowHostAccess is read from "allow_host_access". When true, getCredentials also tries
// environment and shared-credentials-file providers, and getAwsSessionConfig enables the shared AWS config.
AllowHostAccess bool `mapstructure:"allow_host_access"`
}
// Open implements drivers.Driver.
//
// It decodes cfgMap into configProperties and returns a *Connection that holds
// the decoded config and the logger. Shared handles are rejected with an error.
// The activity client is accepted for interface compatibility and is not used.
func (d driver) Open(cfgMap map[string]any, shared bool, client activity.Client, logger *zap.Logger) (drivers.Handle, error) {
	if shared {
		return nil, fmt.Errorf("s3 driver can't be shared")
	}

	var props configProperties
	if err := mapstructure.Decode(cfgMap, &props); err != nil {
		return nil, err
	}

	return &Connection{config: &props, logger: logger}, nil
}
// Drop implements drivers.Driver.
// Dropping is not supported by the S3 driver, so it always returns drivers.ErrDropNotSupported.
func (d driver) Drop(config map[string]any, logger *zap.Logger) error {
return drivers.ErrDropNotSupported
}
// Spec returns the package-level spec value describing the S3 driver.
func (d driver) Spec() drivers.Spec {
return spec
}
// HasAnonymousSourceAccess reports whether the bucket named in the source's path
// can be accessed with anonymous credentials.
//
// It parses props as source properties, opens a throwaway connection with an empty
// config, opens the bucket using credentials.AnonymousCredentials and returns the
// result of the bucket's IsAccessible check. The bucket is closed before returning.
func (d driver) HasAnonymousSourceAccess(ctx context.Context, props map[string]any, logger *zap.Logger) (bool, error) {
	srcProps, err := parseSourceProperties(props)
	if err != nil {
		return false, fmt.Errorf("failed to parse config: %w", err)
	}

	handle, err := d.Open(map[string]any{}, false, activity.NewNoopClient(), logger)
	if err != nil {
		return false, err
	}

	// Open always returns a *Connection on success, so the assertion is safe here.
	s3Conn := handle.(*Connection)
	bucketName := srcProps.url.Host

	bkt, err := s3Conn.openBucket(ctx, srcProps, bucketName, credentials.AnonymousCredentials)
	if err != nil {
		return false, fmt.Errorf("failed to open bucket %q, %w", bucketName, err)
	}
	defer bkt.Close()

	return bkt.IsAccessible(ctx)
}
// TertiarySourceConnectors returns no additional connectors for an S3 source:
// it always returns (nil, nil) and ignores its arguments.
func (d driver) TertiarySourceConnectors(ctx context.Context, src map[string]any, logger *zap.Logger) ([]string, error) {
return nil, nil
}
// Connection is the handle returned by driver.Open for the S3 driver.
type Connection struct {
// config holds the decoded connector configuration that was passed to driver.Open.
config *configProperties
// logger is the logger that was passed to driver.Open.
logger *zap.Logger
}
// Compile-time check that *Connection satisfies drivers.Handle.
var _ drivers.Handle = &Connection{}
// Driver implements drivers.Handle.
// It returns the driver name, "s3".
func (c *Connection) Driver() string {
return "s3"
}
// Config implements drivers.Handle.
// It returns the connection's configuration re-encoded as a generic map.
func (c *Connection) Config() map[string]any {
	asMap := map[string]any{}
	// A decode error is deliberately ignored (best effort); whatever was decoded is returned.
	_ = mapstructure.Decode(c.config, &asMap)
	return asMap
}
// Close implements drivers.Handle.
// The connection holds nothing that needs releasing, so it is a no-op that returns nil.
func (c *Connection) Close() error {
return nil
}
// AsRegistry implements drivers.Handle.
// An S3 connection is not a registry store, so it always returns (nil, false).
func (c *Connection) AsRegistry() (drivers.RegistryStore, bool) {
return nil, false
}
// AsCatalogStore implements drivers.Handle.
// An S3 connection is not a catalog store, so it always returns (nil, false).
func (c *Connection) AsCatalogStore(instanceID string) (drivers.CatalogStore, bool) {
return nil, false
}
// AsRepoStore implements drivers.Handle.
// An S3 connection is not a repo store, so it always returns (nil, false).
func (c *Connection) AsRepoStore(instanceID string) (drivers.RepoStore, bool) {
return nil, false
}
// AsOLAP implements drivers.Handle.
// An S3 connection is not an OLAP store, so it always returns (nil, false).
func (c *Connection) AsOLAP(instanceID string) (drivers.OLAPStore, bool) {
return nil, false
}
// Migrate implements drivers.Handle.
// There is nothing to migrate for S3, so it is a no-op that returns nil.
func (c *Connection) Migrate(ctx context.Context) (err error) {
return nil
}
// MigrationStatus implements drivers.Handle.
// It always reports current and desired versions of 0 with a nil error.
func (c *Connection) MigrationStatus(ctx context.Context) (current, desired int, err error) {
return 0, 0, nil
}
// AsObjectStore implements drivers.Handle.
// The connection itself acts as the object store, so it returns (c, true).
func (c *Connection) AsObjectStore() (drivers.ObjectStore, bool) {
return c, true
}
// AsTransporter implements drivers.Handle.
// An S3 connection provides no transporter, so it always returns (nil, false) regardless of from and to.
func (c *Connection) AsTransporter(from, to drivers.Handle) (drivers.Transporter, bool) {
return nil, false
}
// AsFileStore implements drivers.Handle.
// An S3 connection is not a file store, so it always returns (nil, false).
func (c *Connection) AsFileStore() (drivers.FileStore, bool) {
return nil, false
}
// AsSQLStore implements drivers.Handle.
// An S3 connection is not a SQL store, so it always returns (nil, false).
func (c *Connection) AsSQLStore() (drivers.SQLStore, bool) {
return nil, false
}
// sourceProperties holds the per-source settings decoded from a source's property map
// by parseSourceProperties, plus two fields derived during parsing.
type sourceProperties struct {
// Path is the s3:// location of the object(s); it is validated as a glob pattern.
Path string `mapstructure:"path"`
// URI is the legacy name for Path; when set it overrides Path.
URI string `mapstructure:"uri"`
// AWSRegion, when set, is used as the session region instead of inferring it.
AWSRegion string `mapstructure:"region"`
// The Glob* fields are passed through unchanged to the blob iterator options in DownloadFiles.
GlobMaxTotalSize int64 `mapstructure:"glob.max_total_size"`
GlobMaxObjectsMatched int `mapstructure:"glob.max_objects_matched"`
GlobMaxObjectsListed int64 `mapstructure:"glob.max_objects_listed"`
GlobPageSize int `mapstructure:"glob.page_size"`
// S3Endpoint, when non-empty, selects an S3-compatible (non-AWS) endpoint with path-style addressing.
S3Endpoint string `mapstructure:"endpoint"`
// Extract is the raw extract configuration; it is parsed into extractPolicy.
Extract map[string]any `mapstructure:"extract"`
// BatchSize is a human-readable size string parsed with datasize.ParseString in DownloadFiles.
BatchSize string `mapstructure:"batch_size"`
// url is the parsed form of Path (set by parseSourceProperties); its Host is used as the bucket name.
url *globutil.URL
// extractPolicy is the parsed form of Extract (set by parseSourceProperties).
extractPolicy *rillblob.ExtractPolicy
}
// parseSourceProperties decodes and validates a source's property map.
//
// It weakly decodes props into sourceProperties, applies the legacy "uri" alias,
// checks that the path is a valid glob pattern and an s3:// URL, and parses the
// extract policy. On success the returned value has url and extractPolicy populated.
func parseSourceProperties(props map[string]any) (*sourceProperties, error) {
	var sp sourceProperties
	if err := mapstructure.WeakDecode(props, &sp); err != nil {
		return nil, err
	}

	// "uri" was renamed to "path"; keep honoring the old key.
	if sp.URI != "" {
		sp.Path = sp.URI
	}

	if !doublestar.ValidatePattern(sp.Path) {
		return nil, fmt.Errorf("glob pattern %s is invalid", sp.Path)
	}

	bucketURL, err := globutil.ParseBucketURL(sp.Path)
	if err != nil {
		return nil, fmt.Errorf("failed to parse path %q, %w", sp.Path, err)
	}
	sp.url = bucketURL

	if bucketURL.Scheme != "s3" {
		return nil, fmt.Errorf("invalid s3 path %q, should start with s3://", sp.Path)
	}

	policy, err := rillblob.ParseExtractPolicy(sp.Extract)
	if err != nil {
		return nil, fmt.Errorf("failed to parse extract config: %w", err)
	}
	sp.extractPolicy = policy

	return &sp, nil
}
// DownloadFiles returns a file iterator over objects stored in s3.
//
// The credentials are read from the following connector configs:
//   - aws_access_key_id
//   - aws_secret_access_key
//   - aws_access_token (used as the session token)
//
// Additionally, when allow_host_access is true, credentials stored on the host
// machine (environment variables and the shared credentials file) are tried as well.
//
// If creating the iterator fails with HTTP 403 or 400 and non-anonymous credentials
// were used, the call is retried once with anonymous credentials in case the bucket
// is public. A 403/400 that survives the retry is returned as a permission-denied error.
// On any error the returned iterator is nil.
func (c *Connection) DownloadFiles(ctx context.Context, src map[string]any) (drivers.FileIterator, error) {
	conf, err := parseSourceProperties(src)
	if err != nil {
		return nil, fmt.Errorf("failed to parse config: %w", err)
	}

	// Validate the batch size before any bucket is opened, so that an invalid
	// value can never return with an opened-but-unclosed bucket.
	batchSize, err := datasize.ParseString(conf.BatchSize)
	if err != nil {
		return nil, err
	}

	creds, err := c.getCredentials()
	if err != nil {
		return nil, err
	}

	bucketObj, err := c.openBucket(ctx, conf, conf.url.Host, creds)
	if err != nil {
		return nil, fmt.Errorf("failed to open bucket %q, %w", conf.url.Host, err)
	}

	// prepare fetch configs
	opts := rillblob.Options{
		GlobMaxTotalSize:      conf.GlobMaxTotalSize,
		GlobMaxObjectsMatched: conf.GlobMaxObjectsMatched,
		GlobMaxObjectsListed:  conf.GlobMaxObjectsListed,
		GlobPageSize:          conf.GlobPageSize,
		GlobPattern:           conf.url.Path,
		ExtractPolicy:         conf.extractPolicy,
		BatchSizeBytes:        int64(batchSize.Bytes()),
	}

	// NOTE(review): whether rillblob.NewIterator closes the bucket when it fails is not
	// visible from this file — confirm, otherwise the bucket should be closed on the error paths below.
	it, err := rillblob.NewIterator(ctx, bucketObj, opts, c.logger)
	if err == nil {
		return it, nil
	}

	// TODO :: fix this for single file access. for single file first call only happens during download
	var failureErr awserr.RequestFailure
	if !errors.As(err, &failureErr) {
		return nil, err
	}

	// aws returns StatusForbidden in cases like no creds passed, wrong creds passed and incorrect bucket
	// r2 returns StatusBadRequest in all cases above
	accessDenied := func(status int) bool {
		return status == http.StatusForbidden || status == http.StatusBadRequest
	}

	// we try again with anonymous credentials in case bucket is public
	if accessDenied(failureErr.StatusCode()) && creds != credentials.AnonymousCredentials {
		c.logger.Info("s3 list objects failed, re-trying with anonymous credential", zap.Error(err), observability.ZapCtx(ctx))
		anonBucket, bucketErr := c.openBucket(ctx, conf, conf.url.Host, credentials.AnonymousCredentials)
		if bucketErr != nil {
			return nil, fmt.Errorf("failed to open bucket %q, %w", conf.url.Host, bucketErr)
		}
		it, err = rillblob.NewIterator(ctx, anonBucket, opts, c.logger)
		if err == nil {
			return it, nil
		}
	}

	// check again: err is either the original failure or the failure of the anonymous retry
	if errors.As(err, &failureErr) && accessDenied(failureErr.StatusCode()) {
		return nil, drivers.NewPermissionDeniedError(fmt.Sprintf("can't access remote err: %v", failureErr))
	}
	return nil, err
}
// openBucket opens the named S3 bucket using an AWS session built for the given
// source properties and credentials. Session-creation failures are wrapped;
// errors from opening the bucket are returned as-is.
func (c *Connection) openBucket(ctx context.Context, conf *sourceProperties, bucket string, creds *credentials.Credentials) (*blob.Bucket, error) {
	awsSession, sessErr := c.getAwsSessionConfig(ctx, conf, bucket, creds)
	if sessErr != nil {
		return nil, fmt.Errorf("failed to start session: %w", sessErr)
	}

	return s3blob.OpenBucket(ctx, awsSession, bucket, nil)
}
// getAwsSessionConfig builds the AWS session used to talk to the given bucket.
//
// Resolution order:
//  1. If an endpoint is configured, an S3-compatible (non-AWS) API is assumed: path-style
//     addressing is forced and the region is the configured one or "us-east-1".
//  2. Otherwise, if a region is configured explicitly, it is used as-is.
//  3. Otherwise the region is inferred from the environment (shared AWS config is consulted
//     only when host access is allowed), falling back to "us-east-1", and then replaced by
//     the bucket's actual region as reported by a bucket-region lookup.
func (c *Connection) getAwsSessionConfig(ctx context.Context, conf *sourceProperties, bucket string, creds *credentials.Credentials) (*session.Session, error) {
	const fallbackRegion = "us-east-1"

	// S3-compatible endpoint: none of the AWS-specific region logic below applies.
	if conf.S3Endpoint != "" {
		// Default region kept for backwards compatibility.
		// Cloudflare and minio ignore us-east-1 when set; not tested for other providers.
		region := fallbackRegion
		if conf.AWSRegion != "" {
			region = conf.AWSRegion
		}
		return session.NewSession(&aws.Config{
			Region:           aws.String(region),
			Endpoint:         &conf.S3Endpoint,
			S3ForcePathStyle: aws.Bool(true),
			Credentials:      creds,
		})
	}

	// AWS is strict about regions (probably to avoid unexpected cross-region traffic).
	// An explicitly configured region always wins.
	if conf.AWSRegion != "" {
		return session.NewSession(&aws.Config{
			Region:      aws.String(conf.AWSRegion),
			Credentials: creds,
		})
	}

	// Only look at the default region set with `aws configure` when host access is allowed.
	sharedState := session.SharedConfigDisable
	if c.config.AllowHostAccess {
		sharedState = session.SharedConfigEnable
	}

	// Create a session that tries to infer the region from the environment.
	sess, err := session.NewSessionWithOptions(session.Options{
		SharedConfigState: sharedState,
		Config:            aws.Config{Credentials: creds},
	})
	if err != nil {
		return nil, err
	}

	// No region found: use the fallback, which is only needed to resolve the lookup below.
	if aws.StringValue(sess.Config.Region) == "" {
		sess = sess.Copy(&aws.Config{Region: aws.String(fallbackRegion)})
	}

	// Bucket names are globally unique, but requests fail if the session region does not
	// match the bucket's region, so look it up and reconfigure the session accordingly.
	bucketRegion, err := s3manager.GetBucketRegion(ctx, sess, bucket, "")
	if err != nil {
		return nil, err
	}
	if bucketRegion == "" {
		return sess, nil
	}
	return sess.Copy(&aws.Config{Region: aws.String(bucketRegion)}), nil
}
// getCredentials resolves the AWS credentials to use for this connection.
//
// The statically configured key/secret/token are tried first. When host access is
// allowed, environment variables and the shared credentials file are tried next.
// If no provider in the chain yields credentials, anonymous credentials are returned
// so that public objects can still be fetched; any other lookup error is returned.
func (c *Connection) getCredentials() (*credentials.Credentials, error) {
	// If the access key id and secret are not configured, retrieval from the static
	// provider fails and the lookup proceeds to the next provider in the chain.
	chain := []credentials.Provider{
		&credentials.StaticProvider{
			Value: credentials.Value{
				AccessKeyID:     c.config.AccessKeyID,
				SecretAccessKey: c.config.SecretAccessKey,
				SessionToken:    c.config.SessionToken,
				ProviderName:    credentials.StaticProviderName,
			},
		},
	}

	if c.config.AllowHostAccess {
		// Host credentials are allowed, so add them to the chain.
		// This duplicates defaults.CredProviders() minus the remote credentials lookup (which resolves too slowly).
		chain = append(chain, &credentials.EnvProvider{}, &credentials.SharedCredentialsProvider{})
	}

	resolved := credentials.NewChainCredentials(chain)
	_, err := resolved.Get()
	switch {
	case err == nil:
		return resolved, nil
	case errors.Is(err, credentials.ErrNoValidProvidersFoundInChain):
		// AnonymousCredentials can't be part of a chain, so it is selected explicitly
		// when no local credentials were found; it is required to fetch public objects.
		return credentials.AnonymousCredentials, nil
	default:
		return nil, err
	}
}