diff --git a/pkg/repository/config/aws.go b/pkg/repository/config/aws.go index 6cb87f0a62..ab05d5d5f5 100644 --- a/pkg/repository/config/aws.go +++ b/pkg/repository/config/aws.go @@ -19,7 +19,12 @@ package config import ( "context" + "fmt" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/sts" "os" + "time" "github.com/aws/aws-sdk-go-v2/aws" awsconfig "github.com/aws/aws-sdk-go-v2/config" @@ -31,15 +36,17 @@ import ( const ( // AWS specific environment variable awsProfileEnvVar = "AWS_PROFILE" - awsRoleEnvVar = "AWS_ROLE_ARN" awsKeyIDEnvVar = "AWS_ACCESS_KEY_ID" awsSecretKeyEnvVar = "AWS_SECRET_ACCESS_KEY" awsSessTokenEnvVar = "AWS_SESSION_TOKEN" awsProfileKey = "profile" awsCredentialsFileEnvVar = "AWS_SHARED_CREDENTIALS_FILE" awsConfigFileEnvVar = "AWS_CONFIG_FILE" + awsDefaultProfile = "default" ) +var emptyAWSCredentials = aws.Credentials{} + // GetS3ResticEnvVars gets the environment variables that restic // relies on (AWS_PROFILE) based on info in the provided object // storage location config map. @@ -72,10 +79,6 @@ func GetS3ResticEnvVars(config map[string]string) (map[string]string, error) { // GetS3Credentials gets the S3 credential values according to the information // of the provided config or the system's environment variables func GetS3Credentials(config map[string]string) (*aws.Credentials, error) { - if os.Getenv(awsRoleEnvVar) != "" { - return nil, nil - } - var opts []func(*awsconfig.LoadOptions) error credentialsFile := config[CredentialsFileKey] if credentialsFile == "" { @@ -86,6 +89,7 @@ func GetS3Credentials(config map[string]string) (*aws.Credentials, error) { // To support the existing use case where config file is passed // as credentials of a BSL awsconfig.WithSharedConfigFiles([]string{credentialsFile})) + } opts = append(opts, awsconfig.WithSharedConfigProfile(config[awsProfileKey])) @@ -93,6 +97,23 @@ func GetS3Credentials(config map[string]string) (*aws.Credentials, error) { if err != nil { return nil, err } + + if credentialsFile != "" && os.Getenv("AWS_WEB_IDENTITY_TOKEN_FILE") != "" && os.Getenv("AWS_ROLE_ARN") != "" { + // Reset the config to use the credentials from the credentials/config file + profile := config[awsProfileKey] + if profile == "" { + profile = awsDefaultProfile + } + sfp, err := awsconfig.LoadSharedConfigProfile(context.Background(), profile, func(o *awsconfig.LoadSharedConfigOptions) { + o.ConfigFiles = []string{credentialsFile} + o.CredentialsFiles = []string{credentialsFile} + }) + if err != nil { + return nil, fmt.Errorf("error loading config profile '%s': %v", profile, err) + } + resolveCredsFromProfile(context.Background(), &cfg, &sfp) + } + creds, err := cfg.Credentials.Retrieve(context.Background()) return &creds, err @@ -115,3 +136,44 @@ func GetAWSBucketRegion(bucket string) (string, error) { } return region, nil } + +func resolveCredsFromProfile(ctx context.Context, cfg *aws.Config, sharedConfig *awsconfig.SharedConfig) error { + var err error + switch { + case sharedConfig.Source != nil: + // Assume IAM role with credentials source from a different profile. + err = resolveCredsFromProfile(ctx, cfg, sharedConfig.Source) + case sharedConfig.Credentials.HasKeys(): + // Static Credentials from Shared Config/Credentials file. + cfg.Credentials = credentials.StaticCredentialsProvider{ + Value: sharedConfig.Credentials, + } + } + if err != nil { + return err + } + if len(sharedConfig.RoleARN) > 0 { + credsFromAssumeRole(cfg, sharedConfig) + } + return nil +} + +func credsFromAssumeRole(cfg *aws.Config, sharedCfg *awsconfig.SharedConfig) { + optFns := []func(*stscreds.AssumeRoleOptions){ + func(options *stscreds.AssumeRoleOptions) { + options.RoleSessionName = sharedCfg.RoleSessionName + if sharedCfg.RoleDurationSeconds != nil { + if *sharedCfg.RoleDurationSeconds/time.Minute > 15 { + options.Duration = *sharedCfg.RoleDurationSeconds + } + } + if len(sharedCfg.ExternalID) > 0 { + options.ExternalID = aws.String(sharedCfg.ExternalID) + } + if len(sharedCfg.MFASerial) != 0 { + options.SerialNumber = aws.String(sharedCfg.MFASerial) + } + }, + } + cfg.Credentials = stscreds.NewAssumeRoleProvider(sts.NewFromConfig(*cfg), sharedCfg.RoleARN, optFns...) +} diff --git a/pkg/repository/provider/unified_repo.go b/pkg/repository/provider/unified_repo.go index 76ae36351f..e3422e693b 100644 --- a/pkg/repository/provider/unified_repo.go +++ b/pkg/repository/provider/unified_repo.go @@ -505,7 +505,7 @@ func getStorageVariables(backupLocation *velerov1api.BackupStorageLocation, repo } s3URL = url.Host - disableTLS = (url.Scheme == "http") + disableTLS = url.Scheme == "http" } result[udmrepo.StoreOptionS3Endpoint] = strings.Trim(s3URL, "/")