diff --git a/services/filemanager/filemanager.go b/services/filemanager/filemanager.go index 62b43c36ae..8276dd835f 100644 --- a/services/filemanager/filemanager.go +++ b/services/filemanager/filemanager.go @@ -64,13 +64,9 @@ func init() { func (*FileManagerFactoryT) New(settings *SettingsT) (FileManager, error) { switch settings.Provider { case "S3_DATALAKE": - return &S3Manager{ - Config: GetS3Config(settings.Config), - }, nil + return NewS3Manager(settings.Config) case "S3": - return &S3Manager{ - Config: GetS3Config(settings.Config), - }, nil + return NewS3Manager(settings.Config) case "GCS": return &GCSManager{ Config: GetGCSConfig(settings.Config), diff --git a/services/filemanager/s3manager.go b/services/filemanager/s3manager.go index ab3f27b083..79d9b0d19e 100644 --- a/services/filemanager/s3manager.go +++ b/services/filemanager/s3manager.go @@ -140,10 +140,7 @@ func (manager *S3Manager) DeleteObjects(ctx context.Context, keys []string) (err _, err := svc.DeleteObjectsWithContext(_ctx, input) if err != nil { if aerr, ok := err.(awserr.Error); ok { - switch aerr.Code() { - default: - pkgLogger.Errorf(`Error while deleting S3 objects: %v, error code: %v`, aerr.Error(), aerr.Code()) - } + pkgLogger.Errorf(`Error while deleting S3 objects: %v, error code: %v`, aerr.Error(), aerr.Code()) } else { // Print the error, cast err to awserr.Error to get the Code and // Message from an error. @@ -155,22 +152,6 @@ func (manager *S3Manager) DeleteObjects(ctx context.Context, keys []string) (err return nil } -func (manager *S3Manager) getSessionConfig() *awsutils.SessionConfig { - sessionConfig := &awsutils.SessionConfig{ - Region: *manager.Config.Region, - Endpoint: manager.Config.Endpoint, - S3ForcePathStyle: manager.Config.S3ForcePathStyle, - DisableSSL: manager.Config.DisableSSL, - AccessKeyID: manager.Config.AccessKeyID, - AccessKey: manager.Config.AccessKey, - IAMRoleARN: manager.Config.IAMRoleARN, - ExternalID: manager.Config.ExternalID, - Service: s3.ServiceName, - } - - return sessionConfig -} - func (manager *S3Manager) getSession(ctx context.Context) (*session.Session, error) { if manager.session != nil { return manager.session, nil @@ -194,10 +175,11 @@ func (manager *S3Manager) getSession(ctx context.Context) (*session.Session, err /// Failed to Get Region probably due to VPC restrictions, Will proceed to try with AccessKeyID and AccessKey } manager.Config.Region = aws.String(region) + manager.SessionConfig.Region = region } var err error - manager.session, err = awsutils.CreateSession(manager.getSessionConfig()) + manager.session, err = awsutils.CreateSession(manager.SessionConfig) if err != nil { return nil, err } @@ -258,9 +240,10 @@ func (manager *S3Manager) GetConfiguredPrefix() string { } type S3Manager struct { - Config *S3Config - session *session.Session - timeout time.Duration + Config *S3Config + SessionConfig *awsutils.SessionConfig + session *session.Session + timeout time.Duration } func (manager *S3Manager) SetTimeout(timeout time.Duration) { @@ -275,27 +258,28 @@ func (manager *S3Manager) getTimeout() time.Duration { return getBatchRouterTimeoutConfig("S3") } -func GetS3Config(config map[string]interface{}) *S3Config { +func NewS3Manager(config map[string]interface{}) (*S3Manager, error) { var s3Config S3Config if err := mapstructure.Decode(config, &s3Config); err != nil { - pkgLogger.Errorf("unable to code config into S3Config: %w", err) - s3Config = S3Config{} + return nil, err } regionHint := appConfig.GetString("AWS_S3_REGION_HINT", "us-east-1") s3Config.RegionHint = regionHint s3Config.IsTruncated = true - - return &s3Config + sessionConfig, err := awsutils.NewSimpleSessionConfig(config, s3.ServiceName) + if err != nil { + return nil, err + } + return &S3Manager{ + Config: &s3Config, + SessionConfig: sessionConfig, + }, nil } type S3Config struct { Bucket string `mapstructure:"bucketName"` Prefix string `mapstructure:"Prefix"` Region *string `mapstructure:"region"` - AccessKeyID string `mapstructure:"accessKeyID"` - AccessKey string `mapstructure:"accessKey"` - IAMRoleARN string `mapstructure:"iamRoleARN"` - ExternalID string `mapstructure:"externalID"` Endpoint *string `mapstructure:"endpoint"` S3ForcePathStyle *bool `mapstructure:"s3ForcePathStyle"` DisableSSL *bool `mapstructure:"disableSSL"` diff --git a/services/filemanager/s3manager_test.go b/services/filemanager/s3manager_test.go index 19fd787504..dd272dc9f6 100644 --- a/services/filemanager/s3manager_test.go +++ b/services/filemanager/s3manager_test.go @@ -5,16 +5,99 @@ import ( "testing" "github.com/aws/aws-sdk-go/aws" + "github.com/rudderlabs/rudder-server/utils/awsutils" "github.com/stretchr/testify/assert" ) +func TestNewS3ManagerWithNil(t *testing.T) { + s3Manager, err := NewS3Manager(nil) + assert.EqualError(t, err, "config should not be nil") + assert.Nil(t, s3Manager) +} + +func TestNewS3ManagerWithAccessKeys(t *testing.T) { + s3Manager, err := NewS3Manager(map[string]interface{}{ + "bucketName": "someBucket", + "region": "someRegion", + "accessKeyID": "someAccessKeyId", + "accessKey": "someSecretAccessKey", + }) + assert.Nil(t, err) + assert.NotNil(t, s3Manager) + assert.Equal(t, "someBucket", s3Manager.Config.Bucket) + assert.Equal(t, aws.String("someRegion"), s3Manager.Config.Region) + assert.Equal(t, "someAccessKeyId", s3Manager.SessionConfig.AccessKeyID) + assert.Equal(t, "someSecretAccessKey", s3Manager.SessionConfig.AccessKey) + assert.Equal(t, false, s3Manager.SessionConfig.RoleBasedAuth) +} + +func TestNewS3ManagerWithRole(t *testing.T) { + s3Manager, err := NewS3Manager(map[string]interface{}{ + "bucketName": "someBucket", + "region": "someRegion", + "iamRoleARN": "someIAMRole", + "externalID": "someExternalID", + }) + assert.Nil(t, err) + assert.NotNil(t, s3Manager) + assert.Equal(t, "someBucket", s3Manager.Config.Bucket) + assert.Equal(t, aws.String("someRegion"), s3Manager.Config.Region) + assert.Equal(t, "someIAMRole", s3Manager.SessionConfig.IAMRoleARN) + assert.Equal(t, "someExternalID", s3Manager.SessionConfig.ExternalID) + assert.Equal(t, true, s3Manager.SessionConfig.RoleBasedAuth) +} + +func TestNewS3ManagerWithBothAccessKeysAndRole(t *testing.T) { + s3Manager, err := NewS3Manager(map[string]interface{}{ + "bucketName": "someBucket", + "region": "someRegion", + "iamRoleARN": "someIAMRole", + "externalID": "someExternalID", + "accessKeyID": "someAccessKeyId", + "accessKey": "someSecretAccessKey", + }) + assert.Nil(t, err) + assert.NotNil(t, s3Manager) + assert.Equal(t, "someBucket", s3Manager.Config.Bucket) + assert.Equal(t, aws.String("someRegion"), s3Manager.Config.Region) + assert.Equal(t, "someAccessKeyId", s3Manager.SessionConfig.AccessKeyID) + assert.Equal(t, "someSecretAccessKey", s3Manager.SessionConfig.AccessKey) + assert.Equal(t, "someIAMRole", s3Manager.SessionConfig.IAMRoleARN) + assert.Equal(t, "someExternalID", s3Manager.SessionConfig.ExternalID) + assert.Equal(t, true, s3Manager.SessionConfig.RoleBasedAuth) +} + +func TestNewS3ManagerWithBothAccessKeysAndRoleButRoleBasedAuthFalse(t *testing.T) { + s3Manager, err := NewS3Manager(map[string]interface{}{ + "bucketName": "someBucket", + "region": "someRegion", + "iamRoleARN": "someIAMRole", + "externalID": "someExternalID", + "accessKeyID": "someAccessKeyId", + "accessKey": "someSecretAccessKey", + "roleBasedAuth": false, + }) + assert.Nil(t, err) + assert.NotNil(t, s3Manager) + assert.Equal(t, "someBucket", s3Manager.Config.Bucket) + assert.Equal(t, aws.String("someRegion"), s3Manager.Config.Region) + assert.Equal(t, "someAccessKeyId", s3Manager.SessionConfig.AccessKeyID) + assert.Equal(t, "someSecretAccessKey", s3Manager.SessionConfig.AccessKey) + assert.Equal(t, "someIAMRole", s3Manager.SessionConfig.IAMRoleARN) + assert.Equal(t, "someExternalID", s3Manager.SessionConfig.ExternalID) + assert.Equal(t, false, s3Manager.SessionConfig.RoleBasedAuth) +} + func TestGetSessionWithAccessKeys(t *testing.T) { s3Manager := S3Manager{ Config: &S3Config{ - Bucket: "someBucket", + Bucket: "someBucket", + Region: aws.String("someRegion"), + }, + SessionConfig: &awsutils.SessionConfig{ AccessKeyID: "someAccessKeyId", AccessKey: "someSecretAccessKey", - Region: aws.String("someRegion"), + Region: "someRegion", }, } awsSession, err := s3Manager.getSession(context.TODO()) @@ -26,10 +109,13 @@ func TestGetSessionWithAccessKeys(t *testing.T) { func TestGetSessionWithIAMRole(t *testing.T) { s3Manager := S3Manager{ Config: &S3Config{ - Bucket: "someBucket", + Bucket: "someBucket", + Region: aws.String("someRegion"), + }, + SessionConfig: &awsutils.SessionConfig{ IAMRoleARN: "someIAMRole", ExternalID: "someExternalID", - Region: aws.String("someRegion"), + Region: "someRegion", }, } awsSession, err := s3Manager.getSession(context.TODO()) @@ -37,33 +123,3 @@ func TestGetSessionWithIAMRole(t *testing.T) { assert.NotNil(t, awsSession) assert.NotNil(t, s3Manager.session) } - -func TestGetSessionConfigWithAccessKeys(t *testing.T) { - s3Manager := S3Manager{ - Config: &S3Config{ - Bucket: "someBucket", - AccessKeyID: "someAccessKeyId", - AccessKey: "someSecretAccessKey", - Region: aws.String("someRegion"), - }, - } - awsSessionConfig := s3Manager.getSessionConfig() - assert.NotNil(t, awsSessionConfig) - assert.Equal(t, s3Manager.Config.AccessKey, awsSessionConfig.AccessKey) - assert.Equal(t, s3Manager.Config.AccessKeyID, awsSessionConfig.AccessKeyID) -} - -func TestGetSessionConfigWithIAMRole(t *testing.T) { - s3Manager := S3Manager{ - Config: &S3Config{ - Bucket: "someBucket", - IAMRoleARN: "someIAMRole", - ExternalID: "someExternalID", - Region: aws.String("someRegion"), - }, - } - awsSessionConfig := s3Manager.getSessionConfig() - assert.NotNil(t, awsSessionConfig) - assert.Equal(t, s3Manager.Config.IAMRoleARN, awsSessionConfig.IAMRoleARN) - assert.Equal(t, s3Manager.Config.ExternalID, awsSessionConfig.ExternalID) -} diff --git a/utils/awsutils/session.go b/utils/awsutils/session.go index af4e836120..a46cdfe06b 100644 --- a/utils/awsutils/session.go +++ b/utils/awsutils/session.go @@ -24,6 +24,7 @@ type SessionConfig struct { RoleBasedAuth bool `mapstructure:"roleBasedAuth"` IAMRoleARN string `mapstructure:"iamRoleARN"` ExternalID string `mapstructure:"externalID"` + WorkspaceID string `mapstructure:"workspaceID"` Endpoint *string `mapstructure:"endpoint"` S3ForcePathStyle *bool `mapstructure:"s3ForcePathStyle"` DisableSSL *bool `mapstructure:"disableSSL"` @@ -52,7 +53,7 @@ func createDefaultSession(config *SessionConfig) (*session.Session, error) { }) } -func createCredentailsForRole(config *SessionConfig) (*credentials.Credentials, error) { +func createCredentialsForRole(config *SessionConfig) (*credentials.Credentials, error) { if config.ExternalID == "" { return nil, errors.New("externalID is required for IAM role") } @@ -73,7 +74,7 @@ func CreateSession(config *SessionConfig) (*session.Session, error) { err error ) if config.RoleBasedAuth { - awsCredentials, err = createCredentailsForRole(config) + awsCredentials, err = createCredentialsForRole(config) } else if config.AccessKey != "" && config.AccessKeyID != "" { awsCredentials, err = credentials.NewStaticCredentials(config.AccessKeyID, config.AccessKey, ""), nil } @@ -96,12 +97,12 @@ func isRoleBasedAuthFieldExist(config map[string]interface{}) bool { return ok } -func NewSimpleSessionConfigForDestination(destination *backendconfig.DestinationT, serviceName string) (*SessionConfig, error) { - if destination == nil { - return nil, errors.New("destination should not be nil") +func NewSimpleSessionConfig(config map[string]interface{}, serviceName string) (*SessionConfig, error) { + if config == nil { + return nil, errors.New("config should not be nil") } sessionConfig := SessionConfig{} - if err := mapstructure.Decode(destination.Config, &sessionConfig); err != nil { + if err := mapstructure.Decode(config, &sessionConfig); err != nil { return nil, fmt.Errorf("unable to populate session config using destinationConfig: %w", err) } @@ -109,7 +110,7 @@ func NewSimpleSessionConfigForDestination(destination *backendconfig.Destination return nil, errors.New("incompatible role configuration") } - if !isRoleBasedAuthFieldExist(destination.Config) { + if !isRoleBasedAuthFieldExist(config) { sessionConfig.RoleBasedAuth = sessionConfig.IAMRoleARN != "" } @@ -118,7 +119,18 @@ func NewSimpleSessionConfigForDestination(destination *backendconfig.Destination sessionConfig.AccessKey = sessionConfig.SecretAccessKey } sessionConfig.Service = serviceName - if sessionConfig.IAMRoleARN != "" { + return &sessionConfig, nil +} + +func NewSimpleSessionConfigForDestination(destination *backendconfig.DestinationT, serviceName string) (*SessionConfig, error) { + if destination == nil { + return nil, errors.New("destination should not be nil") + } + sessionConfig, err := NewSimpleSessionConfig(destination.Config, serviceName) + if err != nil { + return nil, err + } + if sessionConfig.IAMRoleARN != "" && sessionConfig.ExternalID == "" { /** In order prevent confused deputy problem, we are using workspace token as external ID. @@ -126,7 +138,7 @@ func NewSimpleSessionConfigForDestination(destination *backendconfig.Destination */ sessionConfig.ExternalID = destination.WorkspaceID } - return &sessionConfig, nil + return sessionConfig, nil } func NewSessionConfigForDestination(destination *backendconfig.DestinationT, timeout time.Duration, serviceName string) (*SessionConfig, error) { diff --git a/utils/awsutils/session_test.go b/utils/awsutils/session_test.go index 7531ace4ca..e8dda71d0b 100644 --- a/utils/awsutils/session_test.go +++ b/utils/awsutils/session_test.go @@ -26,6 +26,13 @@ var ( httpTimeout time.Duration = 10 * time.Second ) +func TestNewSessionConfigWithNilDestConfig(t *testing.T) { + serviceName := "kinesis" + sessionConfig, err := NewSessionConfigForDestination(&backendconfig.DestinationT{}, httpTimeout, serviceName) + assert.EqualError(t, err, "config should not be nil") + assert.Nil(t, sessionConfig) +} + func TestNewSessionConfigWithAccessKey(t *testing.T) { serviceName := "kinesis" sessionConfig, err := NewSessionConfigForDestination(&destinationWithAccessKey, httpTimeout, serviceName) diff --git a/warehouse/utils/utils.go b/warehouse/utils/utils.go index de9d72903a..3065e5bf88 100644 --- a/warehouse/utils/utils.go +++ b/warehouse/utils/utils.go @@ -807,6 +807,16 @@ func GetTemporaryS3Cred(destination *backendconfig.DestinationT) (string, string return "", "", "", err } + // Role already provides temporary credentials + // so we shouldn't call sts.GetSessionToken again + if sessionConfig.RoleBasedAuth { + creds, err := awsSession.Config.Credentials.Get() + if err != nil { + return "", "", "", err + } + return creds.AccessKeyID, creds.SecretAccessKey, creds.SessionToken, nil + } + // Create an STS client from just a session. svc := sts.New(awsSession)