Skip to content

Commit

Permalink
Merge pull request #1209 from kujon/aws_auth_fixes
Browse files Browse the repository at this point in the history
fix: AWS auth fixes
  • Loading branch information
mialinx committed Feb 11, 2022
2 parents ac0ab66 + d5145a5 commit 2edca7e
Showing 1 changed file with 30 additions and 30 deletions.
60 changes: 30 additions & 30 deletions pkg/storages/s3/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,20 @@ func getFirstSettingOf(settings map[string]string, keys []string) string {
return ""
}

func getDefaultConfig(settings map[string]string, maxRetries int) *aws.Config {
func configWithSettings(config *aws.Config, bucket string, settings map[string]string) (*aws.Config, error) {
// DefaultRetryer implements basic retry logic using exponential backoff for
// most services. If you want to implement custom retry logic, you can implement the
// request.Retryer interface.
config := defaults.Get().Config.WithRegion(settings[RegionSetting])
config = request.WithRetryer(config, NewConnResetRetryer(client.DefaultRetryer{NumMaxRetries: maxRetries}))
maxRetriesCount := MaxRetriesDefault
if maxRetriesRaw, ok := settings[MaxRetriesSetting]; ok {
maxRetriesInt, err := strconv.Atoi(maxRetriesRaw)
if err != nil {
return nil, errors.Wrapf(err, "failed to parse %s", MaxRetriesSetting)
}

maxRetriesCount = maxRetriesInt
}
config = request.WithRetryer(config, NewConnResetRetryer(client.DefaultRetryer{NumMaxRetries: maxRetriesCount}))

accessKeyId := getFirstSettingOf(settings, []string{AccessKeyIdSetting, AccessKeySetting})
secretAccessKey := getFirstSettingOf(settings, []string{SecretAccessKeySetting, SecretKeySetting})
Expand Down Expand Up @@ -133,25 +141,6 @@ func getDefaultConfig(settings map[string]string, maxRetries int) *aws.Config {
if endpoint, ok := settings[EndpointSetting]; ok {
config = config.WithEndpoint(endpoint)
}
return config
}

// TODO : unit tests
func createSession(bucket string, settings map[string]string) (*session.Session, error) {
maxRetriesCount := MaxRetriesDefault
if maxRetriesRaw, ok := settings[MaxRetriesSetting]; ok {
maxRetriesInt, err := strconv.Atoi(maxRetriesRaw)
if err != nil {
return nil, errors.Wrapf(err, "failed to parse %s", MaxRetriesSetting)
}

maxRetriesCount = maxRetriesInt
}
config := getDefaultConfig(settings, maxRetriesCount)
config.MaxRetries = aws.Int(maxRetriesCount)
if _, err := config.Credentials.Get(); err != nil {
return nil, errors.Wrapf(err, "failed to get AWS credentials; please specify %s and %s", AccessKeyIdSetting, SecretAccessKeySetting)
}

if s3ForcePathStyleStr, ok := settings[ForcePathStyleSetting]; ok {
s3ForcePathStyle, err := strconv.ParseBool(s3ForcePathStyleStr)
Expand All @@ -167,34 +156,45 @@ func createSession(bucket string, settings map[string]string) (*session.Session,
}
config = config.WithRegion(region)

return config, nil
}

// TODO : unit tests
func createSession(bucket string, settings map[string]string) (*session.Session, error) {
s, err := session.NewSession()
if err != nil {
return nil, err
}

c, err := configWithSettings(s.Config, bucket, settings)
if err != nil {
return nil, err
}
s.Config = c

filePath := settings[s3CertFile]
if filePath != "" {
if file, err := os.Open(filePath); err == nil {
defer file.Close()
s, err := session.NewSessionWithOptions(session.Options{Config: *config, CustomCABundle: file})
s, err := session.NewSessionWithOptions(session.Options{Config: *s.Config, CustomCABundle: file})
return s, err
} else {
return nil, err
}
}

s, err := session.NewSession(config)

if err != nil {
return nil, err
}
if endpointSource, ok := settings[EndpointSourceSetting]; ok {
s.Handlers.Validate.PushBack(func(request *request.Request) {
src := setupReqProxy(endpointSource, getEndpointPort(settings))
if src != nil {
tracelog.DebugLogger.Printf("using endpoint %s", *src)
host := strings.TrimPrefix(*config.Endpoint, "https://")
host := strings.TrimPrefix(*s.Config.Endpoint, "https://")
request.HTTPRequest.Host = host
request.HTTPRequest.Header.Add("Host", host)
request.HTTPRequest.URL.Host = *src
request.HTTPRequest.URL.Scheme = HTTP
} else {
tracelog.DebugLogger.Printf("using endpoint %s", *config.Endpoint)
tracelog.DebugLogger.Printf("using endpoint %s", *s.Config.Endpoint)
}
})
}
Expand Down

0 comments on commit 2edca7e

Please sign in to comment.