diff --git a/pkg/provider/identityprovider.go b/pkg/provider/identityprovider.go index 073eaa8..3a7c024 100644 --- a/pkg/provider/identityprovider.go +++ b/pkg/provider/identityprovider.go @@ -153,7 +153,7 @@ func (p *IdentityProvider) GetMetadata(ctx context.Context) (*md.IDPSSODescripto return nil, nil, err } - metadata, aaMetadata := p.conf.getMetadata(ctx, p.GetEntityID(ctx), cert) + metadata, aaMetadata := p.conf.getMetadata(ctx, p.GetEntityID(ctx), cert, p.timeFormat) return metadata, aaMetadata, nil } diff --git a/pkg/provider/logout.go b/pkg/provider/logout.go index f810a0e..1807947 100644 --- a/pkg/provider/logout.go +++ b/pkg/provider/logout.go @@ -69,6 +69,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque checkIfRequestTimeIsStillValid( func() string { return logoutRequest.IssueInstant }, func() string { return logoutRequest.NotOnOrAfter }, + p.timeFormat, ), func() { response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to validate request: %w", err).Error(), p.timeFormat)) diff --git a/pkg/provider/metadata.go b/pkg/provider/metadata.go index 4bfd969..59952ce 100644 --- a/pkg/provider/metadata.go +++ b/pkg/provider/metadata.go @@ -39,6 +39,7 @@ func (p *IdentityProviderConfig) getMetadata( ctx context.Context, entityID string, idpCertData []byte, + timeFormat string, ) (*md.IDPSSODescriptorType, *md.AttributeAuthorityDescriptorType) { endpoints := endpointConfigToEndpoints(p.Endpoints) @@ -80,7 +81,7 @@ func (p *IdentityProviderConfig) getMetadata( } validUntil := "" if p.MetadataIDPConfig.ValidUntil != 0 { - validUntil = time.Now().Add(p.MetadataIDPConfig.ValidUntil).UTC().Format(defaultTimeLayout) + validUntil = time.Now().Add(p.MetadataIDPConfig.ValidUntil).UTC().Format(timeFormat) } cacheDuration := "" if p.MetadataIDPConfig.CacheDuration != "" { diff --git a/pkg/provider/sso.go b/pkg/provider/sso.go index 5c21840..78f2c24 100644 --- a/pkg/provider/sso.go +++ b/pkg/provider/sso.go @@ -287,6 +287,7 @@ func checkRequestRequiredContent( if err := checkIfRequestTimeIsStillValid( func() string { return authNRequest.Conditions.NotBefore }, func() string { return authNRequest.Conditions.NotOnOrAfter }, + DefaultTimeFormat, )(); err != nil { return err } diff --git a/pkg/provider/time.go b/pkg/provider/time.go index 3c4bca9..40bbaf1 100644 --- a/pkg/provider/time.go +++ b/pkg/provider/time.go @@ -5,13 +5,11 @@ import ( "time" ) -const defaultTimeLayout = "2006-01-02T15:04:05.999999Z" - -func checkIfRequestTimeIsStillValid(notBefore func() string, notOnOrAfter func() string) func() error { +func checkIfRequestTimeIsStillValid(notBefore func() string, notOnOrAfter func() string, timeFormat string) func() error { return func() error { now := time.Now().UTC() if notBefore() != "" { - t, err := time.Parse(defaultTimeLayout, notBefore()) + t, err := time.Parse(timeFormat, notBefore()) if err != nil { return fmt.Errorf("failed to parse NotBefore: %w", err) } @@ -21,7 +19,7 @@ func checkIfRequestTimeIsStillValid(notBefore func() string, notOnOrAfter func() } if notOnOrAfter() != "" { - t, err := time.Parse(defaultTimeLayout, notOnOrAfter()) + t, err := time.Parse(timeFormat, notOnOrAfter()) if err != nil { return fmt.Errorf("failed to parse NotOnOrAfter: %w", err) } diff --git a/pkg/provider/time_test.go b/pkg/provider/time_test.go index c888c5a..e447057 100644 --- a/pkg/provider/time_test.go +++ b/pkg/provider/time_test.go @@ -5,6 +5,10 @@ import ( "time" ) +const ( + otherTimeFormat = "2006-01-02T15:04:05.999Z" +) + func TestTime_checkIfRequestTimeIsStillValid(t *testing.T) { type args struct { notBefore string @@ -20,40 +24,56 @@ func TestTime_checkIfRequestTimeIsStillValid(t *testing.T) { { "check ok 1", args{ - notBefore: now.Add(-1 * time.Minute).Format(defaultTimeLayout), - notOnOrAfter: now.Add(1 * time.Minute).Format(defaultTimeLayout), + notBefore: now.Add(-1 * time.Minute).Format(DefaultTimeFormat), + notOnOrAfter: now.Add(1 * time.Minute).Format(DefaultTimeFormat), }, false, }, { "check ok 2", args{ - notBefore: now.Add(-1 * time.Minute).Format(defaultTimeLayout), - notOnOrAfter: now.Add(5 * time.Minute).Format(defaultTimeLayout), + notBefore: now.Add(-1 * time.Minute).Format(DefaultTimeFormat), + notOnOrAfter: now.Add(5 * time.Minute).Format(DefaultTimeFormat), }, false, }, { "check ok 3", args{ - notBefore: now.Add(-5 * time.Minute).Format(defaultTimeLayout), - notOnOrAfter: now.Add(5 * time.Minute).Format(defaultTimeLayout), + notBefore: now.Add(-5 * time.Minute).Format(DefaultTimeFormat), + notOnOrAfter: now.Add(5 * time.Minute).Format(DefaultTimeFormat), + }, + false, + }, + { + "check ok otherformat", + args{ + notBefore: now.Add(-5 * time.Minute).Format(otherTimeFormat), + notOnOrAfter: now.Add(5 * time.Minute).Format(otherTimeFormat), }, false, }, { "check not ok 1", args{ - notBefore: now.Add(1 * time.Minute).Format(defaultTimeLayout), - notOnOrAfter: now.Add(5 * time.Minute).Format(defaultTimeLayout), + notBefore: now.Add(1 * time.Minute).Format(DefaultTimeFormat), + notOnOrAfter: now.Add(5 * time.Minute).Format(DefaultTimeFormat), }, true, }, { "check not ok 2", args{ - notBefore: now.Add(-5 * time.Minute).Format(defaultTimeLayout), - notOnOrAfter: now.Add(-1 * time.Minute).Format(defaultTimeLayout), + notBefore: now.Add(-5 * time.Minute).Format(DefaultTimeFormat), + notOnOrAfter: now.Add(-1 * time.Minute).Format(DefaultTimeFormat), + }, + true, + }, + { + "check not ok otherFormat", + args{ + notBefore: now.Add(-5 * time.Minute).Format(otherTimeFormat), + notOnOrAfter: now.Add(-1 * time.Minute).Format(otherTimeFormat), }, true, }, @@ -69,7 +89,7 @@ func TestTime_checkIfRequestTimeIsStillValid(t *testing.T) { "check ok only notOnOrAfter", args{ notBefore: "", - notOnOrAfter: now.Add(1 * time.Minute).Format(defaultTimeLayout), + notOnOrAfter: now.Add(1 * time.Minute).Format(DefaultTimeFormat), }, false, }, @@ -77,14 +97,14 @@ func TestTime_checkIfRequestTimeIsStillValid(t *testing.T) { "check not ok only notOnOrAfter", args{ notBefore: "", - notOnOrAfter: now.Add(-1 * time.Minute).Format(defaultTimeLayout), + notOnOrAfter: now.Add(-1 * time.Minute).Format(DefaultTimeFormat), }, true, }, { "check not ok only notBefore", args{ - notBefore: now.Add(1 * time.Minute).Format(defaultTimeLayout), + notBefore: now.Add(1 * time.Minute).Format(DefaultTimeFormat), notOnOrAfter: "", }, true, @@ -92,7 +112,7 @@ func TestTime_checkIfRequestTimeIsStillValid(t *testing.T) { { "check ok only notBefore", args{ - notBefore: now.Add(-1 * time.Minute).Format(defaultTimeLayout), + notBefore: now.Add(-1 * time.Minute).Format(DefaultTimeFormat), notOnOrAfter: "", }, false, @@ -124,7 +144,7 @@ func TestTime_checkIfRequestTimeIsStillValid(t *testing.T) { return tt.args.notOnOrAfter } - errF := checkIfRequestTimeIsStillValid(notBeforeF, notOnOrAfterF) + errF := checkIfRequestTimeIsStillValid(notBeforeF, notOnOrAfterF, DefaultTimeFormat) err := errF() if (err != nil) != tt.res { t.Errorf("ParseCertificates() got = %v, want %v", err != nil, tt.res)