diff --git a/credentials_getter.go b/credentials_getter.go index a43679b..e325e9d 100644 --- a/credentials_getter.go +++ b/credentials_getter.go @@ -13,7 +13,9 @@ import ( ) const ( - arnPrefix = "arn:aws:iam::" + arnPrefix = "arn:aws:iam::" + roleARNSuffix = ":role" + roleSessionNameMaxSize = 64 ) // CredentialsGetter can get credentials. @@ -52,7 +54,11 @@ func (c *STSCredentialsGetter) Get(role string, sessionDuration time.Duration) ( if strings.HasPrefix(role, arnPrefix) { roleARN = role } - roleSessionName := normalizeRoleARN(roleARN) + "-session" + + roleSessionName, err := normalizeRoleARN(roleARN) + if err != nil { + return nil, err + } params := &sts.AssumeRoleInput{ RoleArn: aws.String(roleARN), @@ -95,8 +101,22 @@ func GetBaseRoleARN(sess *session.Session) (string, error) { // normalizeRoleARN normalizes a role ARN by substituting special characters // with characters allowed for a RoleSessionName according to: // https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRole.html -func normalizeRoleARN(roleARN string) string { - roleARN = strings.Replace(roleARN, ":", "_", -1) - roleARN = strings.Replace(roleARN, "/", ".", -1) - return roleARN +func normalizeRoleARN(roleARN string) (string, error) { + parts := strings.Split(roleARN, "/") + if len(parts) != 2 { + return "", fmt.Errorf("invalid roleARN: %s", roleARN) + } + + accountID := strings.TrimPrefix(parts[0], arnPrefix) + accountID = strings.TrimSuffix(accountID, roleARNSuffix) + + roleName := strings.Replace(parts[1], ":", "_", -1) + roleName = strings.Replace(roleName, "/", ".", -1) + + roleNameMaxSize := roleSessionNameMaxSize - 1 - len(accountID) + + if len(roleName) > roleNameMaxSize { + roleName = roleName[:roleNameMaxSize] + } + return accountID + "." + roleName, nil } diff --git a/credentials_getter_test.go b/credentials_getter_test.go index 9474a50..772c740 100644 --- a/credentials_getter_test.go +++ b/credentials_getter_test.go @@ -40,7 +40,8 @@ func TestGet(t *testing.T) { }, } - creds, err := getter.Get("role", 3600*time.Second) + roleARN := "arn:aws:iam::012345678910:role/role-name" + creds, err := getter.Get(roleARN, 3600*time.Second) require.NoError(t, err) require.Equal(t, "access_key_id", creds.AccessKeyID) require.Equal(t, "secret_access_key", creds.SecretAccessKey) @@ -58,3 +59,18 @@ func TestGet(t *testing.T) { // sess := &session.Session{} // baseRole, err := GetBaseRoleARN(sess) // } + +func TestNormalizeRoleARN(t *testing.T) { + roleARN := "arn:aws:iam::012345678910:role/role-name" + expectedARN := "012345678910.role-name" + normalized, err := normalizeRoleARN(roleARN) + require.NoError(t, err) + require.Equal(t, expectedARN, normalized) + + // truncate long role names + roleARN = "arn:aws:iam::012345678910:role/role-name-very-very-very-very-very-very-very-very-long" + expectedARN = "012345678910.role-name-very-very-very-very-very-very-very-very-l" + normalized, err = normalizeRoleARN(roleARN) + require.NoError(t, err) + require.Equal(t, expectedARN, normalized) +}