diff --git a/cmd/sso.go b/cmd/sso.go index a9b4db7d3..cb3d0b83b 100644 --- a/cmd/sso.go +++ b/cmd/sso.go @@ -12,6 +12,7 @@ import ( "github.com/supabase/cli/internal/sso/update" "github.com/supabase/cli/internal/utils" "github.com/supabase/cli/internal/utils/flags" + "github.com/supabase/cli/pkg/api" ) var ( @@ -25,6 +26,16 @@ var ( Allowed: []string{"saml"}, // intentionally no default value so users have to specify --type saml explicitly } + + ssoNameIDFormat = utils.EnumFlag{ + Allowed: []string{ + string(api.CreateProviderBodyNameIdFormatUrnOasisNamesTcSAML11NameidFormatEmailAddress), + string(api.CreateProviderBodyNameIdFormatUrnOasisNamesTcSAML11NameidFormatUnspecified), + string(api.CreateProviderBodyNameIdFormatUrnOasisNamesTcSAML20NameidFormatPersistent), + string(api.CreateProviderBodyNameIdFormatUrnOasisNamesTcSAML20NameidFormatTransient), + }, + } + ssoMetadataFile string ssoMetadataURL string ssoSkipURLValidation bool @@ -48,6 +59,7 @@ var ( MetadataURL: ssoMetadataURL, SkipURLValidation: ssoSkipURLValidation, AttributeMapping: ssoAttributeMappingFile, + NameIDFormat: ssoNameIDFormat.String(), Domains: ssoDomains, }) }, @@ -88,6 +100,7 @@ var ( MetadataURL: ssoMetadataURL, SkipURLValidation: ssoSkipURLValidation, AttributeMapping: ssoAttributeMappingFile, + NameIDFormat: ssoNameIDFormat.String(), Domains: ssoDomains, AddDomains: ssoAddDomains, RemoveDomains: ssoRemoveDomains, @@ -146,6 +159,7 @@ func init() { ssoAddFlags.StringVar(&ssoMetadataURL, "metadata-url", "", "URL pointing to a SAML 2.0 Metadata XML document describing the identity provider.") ssoAddFlags.BoolVar(&ssoSkipURLValidation, "skip-url-validation", false, "Whether local validation of the SAML 2.0 Metadata URL should not be performed.") ssoAddFlags.StringVar(&ssoAttributeMappingFile, "attribute-mapping-file", "", "File containing a JSON mapping between SAML attributes to custom JWT claims.") + ssoAddFlags.Var(&ssoNameIDFormat, "name-id-format", "URI reference representing the classification of string-based identifier information.") ssoAddCmd.MarkFlagsMutuallyExclusive("metadata-file", "metadata-url") cobra.CheckErr(ssoAddCmd.MarkFlagRequired("type")) cobra.CheckErr(ssoAddCmd.MarkFlagFilename("metadata-file", "xml")) @@ -159,6 +173,7 @@ func init() { ssoUpdateFlags.StringVar(&ssoMetadataURL, "metadata-url", "", "URL pointing to a SAML 2.0 Metadata XML document describing the identity provider.") ssoUpdateFlags.BoolVar(&ssoSkipURLValidation, "skip-url-validation", false, "Whether local validation of the SAML 2.0 Metadata URL should not be performed.") ssoUpdateFlags.StringVar(&ssoAttributeMappingFile, "attribute-mapping-file", "", "File containing a JSON mapping between SAML attributes to custom JWT claims.") + ssoUpdateFlags.Var(&ssoNameIDFormat, "name-id-format", "URI reference representing the classification of string-based identifier information.") ssoUpdateCmd.MarkFlagsMutuallyExclusive("metadata-file", "metadata-url") ssoUpdateCmd.MarkFlagsMutuallyExclusive("domains", "add-domains") ssoUpdateCmd.MarkFlagsMutuallyExclusive("domains", "remove-domains") diff --git a/internal/sso/create/create.go b/internal/sso/create/create.go index 15dea6cc7..68babc742 100644 --- a/internal/sso/create/create.go +++ b/internal/sso/create/create.go @@ -11,6 +11,7 @@ import ( "github.com/supabase/cli/internal/sso/internal/saml" "github.com/supabase/cli/internal/utils" "github.com/supabase/cli/pkg/api" + "github.com/supabase/cli/pkg/cast" ) var Fs = afero.NewOsFs() @@ -25,6 +26,7 @@ type RunParams struct { MetadataURL string SkipURLValidation bool AttributeMapping string + NameIDFormat string } func Run(ctx context.Context, params RunParams) error { @@ -66,6 +68,10 @@ func Run(ctx context.Context, params RunParams) error { body.Domains = ¶ms.Domains } + if params.NameIDFormat != "" { + body.NameIdFormat = cast.Ptr(api.CreateProviderBodyNameIdFormat(params.NameIDFormat)) + } + resp, err := utils.GetSupabase().V1CreateASsoProviderWithResponse(ctx, params.ProjectRef, body) if err != nil { return errors.Errorf("failed to create sso provider: %w", err) diff --git a/internal/sso/internal/render/render.go b/internal/sso/internal/render/render.go index 64e73a2bf..3939672ec 100644 --- a/internal/sso/internal/render/render.go +++ b/internal/sso/internal/render/render.go @@ -65,6 +65,15 @@ func formatEntityID(provider api.GetProviderResponse) string { return entityID } +func formatNameIDFormat(provider api.GetProviderResponse) string { + nameIDFormat := "-" + if provider.Saml != nil && provider.Saml.NameIdFormat != nil && *provider.Saml.NameIdFormat != "" { + nameIDFormat = *provider.Saml.NameIdFormat + } + + return nameIDFormat +} + func ListMarkdown(providers api.ListProvidersResponse) error { markdownTable := []string{ "|TYPE|IDENTITY PROVIDER ID|DOMAINS|SAML 2.0 `EntityID`|CREATED AT (UTC)|UPDATED AT (UTC)|\n|-|-|-|-|-|-|\n", @@ -116,6 +125,11 @@ func SingleMarkdown(provider api.GetProviderResponse) error { "|SAML 2.0 `EntityID`|`%s`|", formatEntityID(provider), )) + + markdownTable = append(markdownTable, fmt.Sprintf( + "|NAMEID FORMAT|`%s`|", + formatNameIDFormat(provider), + )) } markdownTable = append(markdownTable, fmt.Sprintf( diff --git a/internal/sso/update/update.go b/internal/sso/update/update.go index 316a1360b..941b528bc 100644 --- a/internal/sso/update/update.go +++ b/internal/sso/update/update.go @@ -12,6 +12,7 @@ import ( "github.com/supabase/cli/internal/sso/internal/saml" "github.com/supabase/cli/internal/utils" "github.com/supabase/cli/pkg/api" + "github.com/supabase/cli/pkg/cast" ) var Fs = afero.NewOsFs() @@ -25,6 +26,7 @@ type RunParams struct { MetadataURL string SkipURLValidation bool AttributeMapping string + NameIDFormat string Domains []string AddDomains []string @@ -111,6 +113,10 @@ func Run(ctx context.Context, params RunParams) error { body.Domains = &domains } + if params.NameIDFormat != "" { + body.NameIdFormat = cast.Ptr(api.UpdateProviderBodyNameIdFormat(params.NameIDFormat)) + } + putResp, err := utils.GetSupabase().V1UpdateASsoProviderWithResponse(ctx, params.ProjectRef, parsed, body) if err != nil { return errors.Errorf("failed to update sso provider: %w", err)