Skip to content

Commit

Permalink
fix: aws session creation is failing for s3 manager when roles are us…
Browse files Browse the repository at this point in the history
…ed (#2799)
  • Loading branch information
koladilip committed Dec 13, 2022
1 parent ba9832c commit 1534d64
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 82 deletions.
8 changes: 2 additions & 6 deletions services/filemanager/filemanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,9 @@ func init() {
func (*FileManagerFactoryT) New(settings *SettingsT) (FileManager, error) {
switch settings.Provider {
case "S3_DATALAKE":
return &S3Manager{
Config: GetS3Config(settings.Config),
}, nil
return NewS3Manager(settings.Config)
case "S3":
return &S3Manager{
Config: GetS3Config(settings.Config),
}, nil
return NewS3Manager(settings.Config)
case "GCS":
return &GCSManager{
Config: GetGCSConfig(settings.Config),
Expand Down
50 changes: 17 additions & 33 deletions services/filemanager/s3manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,7 @@ func (manager *S3Manager) DeleteObjects(ctx context.Context, keys []string) (err
_, err := svc.DeleteObjectsWithContext(_ctx, input)
if err != nil {
if aerr, ok := err.(awserr.Error); ok {
switch aerr.Code() {
default:
pkgLogger.Errorf(`Error while deleting S3 objects: %v, error code: %v`, aerr.Error(), aerr.Code())
}
pkgLogger.Errorf(`Error while deleting S3 objects: %v, error code: %v`, aerr.Error(), aerr.Code())
} else {
// Print the error, cast err to awserr.Error to get the Code and
// Message from an error.
Expand All @@ -155,22 +152,6 @@ func (manager *S3Manager) DeleteObjects(ctx context.Context, keys []string) (err
return nil
}

func (manager *S3Manager) getSessionConfig() *awsutils.SessionConfig {
sessionConfig := &awsutils.SessionConfig{
Region: *manager.Config.Region,
Endpoint: manager.Config.Endpoint,
S3ForcePathStyle: manager.Config.S3ForcePathStyle,
DisableSSL: manager.Config.DisableSSL,
AccessKeyID: manager.Config.AccessKeyID,
AccessKey: manager.Config.AccessKey,
IAMRoleARN: manager.Config.IAMRoleARN,
ExternalID: manager.Config.ExternalID,
Service: s3.ServiceName,
}

return sessionConfig
}

func (manager *S3Manager) getSession(ctx context.Context) (*session.Session, error) {
if manager.session != nil {
return manager.session, nil
Expand All @@ -194,10 +175,11 @@ func (manager *S3Manager) getSession(ctx context.Context) (*session.Session, err
/// Failed to Get Region probably due to VPC restrictions, Will proceed to try with AccessKeyID and AccessKey
}
manager.Config.Region = aws.String(region)
manager.SessionConfig.Region = region
}

var err error
manager.session, err = awsutils.CreateSession(manager.getSessionConfig())
manager.session, err = awsutils.CreateSession(manager.SessionConfig)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -258,9 +240,10 @@ func (manager *S3Manager) GetConfiguredPrefix() string {
}

type S3Manager struct {
Config *S3Config
session *session.Session
timeout time.Duration
Config *S3Config
SessionConfig *awsutils.SessionConfig
session *session.Session
timeout time.Duration
}

func (manager *S3Manager) SetTimeout(timeout time.Duration) {
Expand All @@ -275,27 +258,28 @@ func (manager *S3Manager) getTimeout() time.Duration {
return getBatchRouterTimeoutConfig("S3")
}

func GetS3Config(config map[string]interface{}) *S3Config {
func NewS3Manager(config map[string]interface{}) (*S3Manager, error) {
var s3Config S3Config
if err := mapstructure.Decode(config, &s3Config); err != nil {
pkgLogger.Errorf("unable to code config into S3Config: %w", err)
s3Config = S3Config{}
return nil, err
}
regionHint := appConfig.GetString("AWS_S3_REGION_HINT", "us-east-1")
s3Config.RegionHint = regionHint
s3Config.IsTruncated = true

return &s3Config
sessionConfig, err := awsutils.NewSimpleSessionConfig(config, s3.ServiceName)
if err != nil {
return nil, err
}
return &S3Manager{
Config: &s3Config,
SessionConfig: sessionConfig,
}, nil
}

type S3Config struct {
Bucket string `mapstructure:"bucketName"`
Prefix string `mapstructure:"Prefix"`
Region *string `mapstructure:"region"`
AccessKeyID string `mapstructure:"accessKeyID"`
AccessKey string `mapstructure:"accessKey"`
IAMRoleARN string `mapstructure:"iamRoleARN"`
ExternalID string `mapstructure:"externalID"`
Endpoint *string `mapstructure:"endpoint"`
S3ForcePathStyle *bool `mapstructure:"s3ForcePathStyle"`
DisableSSL *bool `mapstructure:"disableSSL"`
Expand Down
124 changes: 90 additions & 34 deletions services/filemanager/s3manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,99 @@ import (
"testing"

"github.com/aws/aws-sdk-go/aws"
"github.com/rudderlabs/rudder-server/utils/awsutils"
"github.com/stretchr/testify/assert"
)

func TestNewS3ManagerWithNil(t *testing.T) {
s3Manager, err := NewS3Manager(nil)
assert.EqualError(t, err, "config should not be nil")
assert.Nil(t, s3Manager)
}

func TestNewS3ManagerWithAccessKeys(t *testing.T) {
s3Manager, err := NewS3Manager(map[string]interface{}{
"bucketName": "someBucket",
"region": "someRegion",
"accessKeyID": "someAccessKeyId",
"accessKey": "someSecretAccessKey",
})
assert.Nil(t, err)
assert.NotNil(t, s3Manager)
assert.Equal(t, "someBucket", s3Manager.Config.Bucket)
assert.Equal(t, aws.String("someRegion"), s3Manager.Config.Region)
assert.Equal(t, "someAccessKeyId", s3Manager.SessionConfig.AccessKeyID)
assert.Equal(t, "someSecretAccessKey", s3Manager.SessionConfig.AccessKey)
assert.Equal(t, false, s3Manager.SessionConfig.RoleBasedAuth)
}

func TestNewS3ManagerWithRole(t *testing.T) {
s3Manager, err := NewS3Manager(map[string]interface{}{
"bucketName": "someBucket",
"region": "someRegion",
"iamRoleARN": "someIAMRole",
"externalID": "someExternalID",
})
assert.Nil(t, err)
assert.NotNil(t, s3Manager)
assert.Equal(t, "someBucket", s3Manager.Config.Bucket)
assert.Equal(t, aws.String("someRegion"), s3Manager.Config.Region)
assert.Equal(t, "someIAMRole", s3Manager.SessionConfig.IAMRoleARN)
assert.Equal(t, "someExternalID", s3Manager.SessionConfig.ExternalID)
assert.Equal(t, true, s3Manager.SessionConfig.RoleBasedAuth)
}

func TestNewS3ManagerWithBothAccessKeysAndRole(t *testing.T) {
s3Manager, err := NewS3Manager(map[string]interface{}{
"bucketName": "someBucket",
"region": "someRegion",
"iamRoleARN": "someIAMRole",
"externalID": "someExternalID",
"accessKeyID": "someAccessKeyId",
"accessKey": "someSecretAccessKey",
})
assert.Nil(t, err)
assert.NotNil(t, s3Manager)
assert.Equal(t, "someBucket", s3Manager.Config.Bucket)
assert.Equal(t, aws.String("someRegion"), s3Manager.Config.Region)
assert.Equal(t, "someAccessKeyId", s3Manager.SessionConfig.AccessKeyID)
assert.Equal(t, "someSecretAccessKey", s3Manager.SessionConfig.AccessKey)
assert.Equal(t, "someIAMRole", s3Manager.SessionConfig.IAMRoleARN)
assert.Equal(t, "someExternalID", s3Manager.SessionConfig.ExternalID)
assert.Equal(t, true, s3Manager.SessionConfig.RoleBasedAuth)
}

func TestNewS3ManagerWithBothAccessKeysAndRoleButRoleBasedAuthFalse(t *testing.T) {
s3Manager, err := NewS3Manager(map[string]interface{}{
"bucketName": "someBucket",
"region": "someRegion",
"iamRoleARN": "someIAMRole",
"externalID": "someExternalID",
"accessKeyID": "someAccessKeyId",
"accessKey": "someSecretAccessKey",
"roleBasedAuth": false,
})
assert.Nil(t, err)
assert.NotNil(t, s3Manager)
assert.Equal(t, "someBucket", s3Manager.Config.Bucket)
assert.Equal(t, aws.String("someRegion"), s3Manager.Config.Region)
assert.Equal(t, "someAccessKeyId", s3Manager.SessionConfig.AccessKeyID)
assert.Equal(t, "someSecretAccessKey", s3Manager.SessionConfig.AccessKey)
assert.Equal(t, "someIAMRole", s3Manager.SessionConfig.IAMRoleARN)
assert.Equal(t, "someExternalID", s3Manager.SessionConfig.ExternalID)
assert.Equal(t, false, s3Manager.SessionConfig.RoleBasedAuth)
}

func TestGetSessionWithAccessKeys(t *testing.T) {
s3Manager := S3Manager{
Config: &S3Config{
Bucket: "someBucket",
Bucket: "someBucket",
Region: aws.String("someRegion"),
},
SessionConfig: &awsutils.SessionConfig{
AccessKeyID: "someAccessKeyId",
AccessKey: "someSecretAccessKey",
Region: aws.String("someRegion"),
Region: "someRegion",
},
}
awsSession, err := s3Manager.getSession(context.TODO())
Expand All @@ -26,44 +109,17 @@ func TestGetSessionWithAccessKeys(t *testing.T) {
func TestGetSessionWithIAMRole(t *testing.T) {
s3Manager := S3Manager{
Config: &S3Config{
Bucket: "someBucket",
Bucket: "someBucket",
Region: aws.String("someRegion"),
},
SessionConfig: &awsutils.SessionConfig{
IAMRoleARN: "someIAMRole",
ExternalID: "someExternalID",
Region: aws.String("someRegion"),
Region: "someRegion",
},
}
awsSession, err := s3Manager.getSession(context.TODO())
assert.Nil(t, err)
assert.NotNil(t, awsSession)
assert.NotNil(t, s3Manager.session)
}

func TestGetSessionConfigWithAccessKeys(t *testing.T) {
s3Manager := S3Manager{
Config: &S3Config{
Bucket: "someBucket",
AccessKeyID: "someAccessKeyId",
AccessKey: "someSecretAccessKey",
Region: aws.String("someRegion"),
},
}
awsSessionConfig := s3Manager.getSessionConfig()
assert.NotNil(t, awsSessionConfig)
assert.Equal(t, s3Manager.Config.AccessKey, awsSessionConfig.AccessKey)
assert.Equal(t, s3Manager.Config.AccessKeyID, awsSessionConfig.AccessKeyID)
}

func TestGetSessionConfigWithIAMRole(t *testing.T) {
s3Manager := S3Manager{
Config: &S3Config{
Bucket: "someBucket",
IAMRoleARN: "someIAMRole",
ExternalID: "someExternalID",
Region: aws.String("someRegion"),
},
}
awsSessionConfig := s3Manager.getSessionConfig()
assert.NotNil(t, awsSessionConfig)
assert.Equal(t, s3Manager.Config.IAMRoleARN, awsSessionConfig.IAMRoleARN)
assert.Equal(t, s3Manager.Config.ExternalID, awsSessionConfig.ExternalID)
}
30 changes: 21 additions & 9 deletions utils/awsutils/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type SessionConfig struct {
RoleBasedAuth bool `mapstructure:"roleBasedAuth"`
IAMRoleARN string `mapstructure:"iamRoleARN"`
ExternalID string `mapstructure:"externalID"`
WorkspaceID string `mapstructure:"workspaceID"`
Endpoint *string `mapstructure:"endpoint"`
S3ForcePathStyle *bool `mapstructure:"s3ForcePathStyle"`
DisableSSL *bool `mapstructure:"disableSSL"`
Expand Down Expand Up @@ -52,7 +53,7 @@ func createDefaultSession(config *SessionConfig) (*session.Session, error) {
})
}

func createCredentailsForRole(config *SessionConfig) (*credentials.Credentials, error) {
func createCredentialsForRole(config *SessionConfig) (*credentials.Credentials, error) {
if config.ExternalID == "" {
return nil, errors.New("externalID is required for IAM role")
}
Expand All @@ -73,7 +74,7 @@ func CreateSession(config *SessionConfig) (*session.Session, error) {
err error
)
if config.RoleBasedAuth {
awsCredentials, err = createCredentailsForRole(config)
awsCredentials, err = createCredentialsForRole(config)
} else if config.AccessKey != "" && config.AccessKeyID != "" {
awsCredentials, err = credentials.NewStaticCredentials(config.AccessKeyID, config.AccessKey, ""), nil
}
Expand All @@ -96,20 +97,20 @@ func isRoleBasedAuthFieldExist(config map[string]interface{}) bool {
return ok
}

func NewSimpleSessionConfigForDestination(destination *backendconfig.DestinationT, serviceName string) (*SessionConfig, error) {
if destination == nil {
return nil, errors.New("destination should not be nil")
func NewSimpleSessionConfig(config map[string]interface{}, serviceName string) (*SessionConfig, error) {
if config == nil {
return nil, errors.New("config should not be nil")
}
sessionConfig := SessionConfig{}
if err := mapstructure.Decode(destination.Config, &sessionConfig); err != nil {
if err := mapstructure.Decode(config, &sessionConfig); err != nil {
return nil, fmt.Errorf("unable to populate session config using destinationConfig: %w", err)
}

if sessionConfig.RoleBasedAuth && sessionConfig.IAMRoleARN == "" {
return nil, errors.New("incompatible role configuration")
}

if !isRoleBasedAuthFieldExist(destination.Config) {
if !isRoleBasedAuthFieldExist(config) {
sessionConfig.RoleBasedAuth = sessionConfig.IAMRoleARN != ""
}

Expand All @@ -118,15 +119,26 @@ func NewSimpleSessionConfigForDestination(destination *backendconfig.Destination
sessionConfig.AccessKey = sessionConfig.SecretAccessKey
}
sessionConfig.Service = serviceName
if sessionConfig.IAMRoleARN != "" {
return &sessionConfig, nil
}

func NewSimpleSessionConfigForDestination(destination *backendconfig.DestinationT, serviceName string) (*SessionConfig, error) {
if destination == nil {
return nil, errors.New("destination should not be nil")
}
sessionConfig, err := NewSimpleSessionConfig(destination.Config, serviceName)
if err != nil {
return nil, err
}
if sessionConfig.IAMRoleARN != "" && sessionConfig.ExternalID == "" {
/**
In order prevent confused deputy problem, we are using
workspace token as external ID.
Ref: https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html
*/
sessionConfig.ExternalID = destination.WorkspaceID
}
return &sessionConfig, nil
return sessionConfig, nil
}

func NewSessionConfigForDestination(destination *backendconfig.DestinationT, timeout time.Duration, serviceName string) (*SessionConfig, error) {
Expand Down
7 changes: 7 additions & 0 deletions utils/awsutils/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ var (
httpTimeout time.Duration = 10 * time.Second
)

func TestNewSessionConfigWithNilDestConfig(t *testing.T) {
serviceName := "kinesis"
sessionConfig, err := NewSessionConfigForDestination(&backendconfig.DestinationT{}, httpTimeout, serviceName)
assert.EqualError(t, err, "config should not be nil")
assert.Nil(t, sessionConfig)
}

func TestNewSessionConfigWithAccessKey(t *testing.T) {
serviceName := "kinesis"
sessionConfig, err := NewSessionConfigForDestination(&destinationWithAccessKey, httpTimeout, serviceName)
Expand Down
10 changes: 10 additions & 0 deletions warehouse/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,16 @@ func GetTemporaryS3Cred(destination *backendconfig.DestinationT) (string, string
return "", "", "", err
}

// Role already provides temporary credentials
// so we shouldn't call sts.GetSessionToken again
if sessionConfig.RoleBasedAuth {
creds, err := awsSession.Config.Credentials.Get()
if err != nil {
return "", "", "", err
}
return creds.AccessKeyID, creds.SecretAccessKey, creds.SessionToken, nil
}

// Create an STS client from just a session.
svc := sts.New(awsSession)

Expand Down

0 comments on commit 1534d64

Please sign in to comment.