From 090601873f19f663cd5edefb1e6f33dfc7ab53cc Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Fri, 2 Dec 2022 11:42:56 -0700 Subject: [PATCH] urlutil: add time validation functions (#3776) --- internal/urlutil/query_params.go | 6 ++- internal/urlutil/time.go | 43 ++++++++++++++++++++++ internal/urlutil/time_test.go | 63 ++++++++++++++++++++++++++++++++ 3 files changed, 111 insertions(+), 1 deletion(-) create mode 100644 internal/urlutil/time.go create mode 100644 internal/urlutil/time_test.go diff --git a/internal/urlutil/query_params.go b/internal/urlutil/query_params.go index 964cad605a0..9992dee8039 100644 --- a/internal/urlutil/query_params.go +++ b/internal/urlutil/query_params.go @@ -8,12 +8,16 @@ const ( QueryDeviceCredentialID = "pomerium_device_credential_id" QueryDeviceType = "pomerium_device_type" QueryEnrollmentToken = "pomerium_enrollment_token" //nolint + QueryExpiry = "pomerium_expiry" + QueryIdentityProfile = "pomerium_identity_profile" QueryIdentityProviderID = "pomerium_idp_id" QueryIsProgrammatic = "pomerium_programmatic" + QueryIssued = "pomerium_issued" QueryPomeriumJWT = "pomerium_jwt" + QueryRedirectURI = "pomerium_redirect_uri" QuerySession = "pomerium_session" QuerySessionEncrypted = "pomerium_session_encrypted" - QueryRedirectURI = "pomerium_redirect_uri" + QuerySessionState = "pomerium_session_state" ) // URL signature based query params used for verifying the authenticity of a URL. diff --git a/internal/urlutil/time.go b/internal/urlutil/time.go new file mode 100644 index 00000000000..4e0d179a892 --- /dev/null +++ b/internal/urlutil/time.go @@ -0,0 +1,43 @@ +package urlutil + +import ( + "fmt" + "net/url" + "strconv" + "time" +) + +// BuildTimeParameters adds the issued and expiry timestamps to the query parameters. +func BuildTimeParameters(params url.Values, expiry time.Duration) { + now := time.Now() + + params.Set(QueryIssued, fmt.Sprint(now.UnixMilli())) + params.Set(QueryExpiry, fmt.Sprint(now.Add(expiry).UnixMilli())) +} + +// ValidateTimeParameters validates that the issued and expiry timestamps in the query parameters are valid. +func ValidateTimeParameters(params url.Values) error { + now := time.Now() + + issuedMS, err := strconv.ParseInt(params.Get(QueryIssued), 10, 64) + if err != nil { + return fmt.Errorf("invalid issued timestamp: %w", err) + } + issued := time.UnixMilli(issuedMS) + + if now.Add(DefaultLeeway).Before(issued) { + return ErrIssuedInTheFuture + } + + expiryMS, err := strconv.ParseInt(params.Get(QueryExpiry), 10, 64) + if err != nil { + return fmt.Errorf("invalid expiry timestamp: %w", err) + } + expiry := time.UnixMilli(expiryMS) + + if now.Add(-DefaultLeeway).After(expiry) { + return ErrExpired + } + + return nil +} diff --git a/internal/urlutil/time_test.go b/internal/urlutil/time_test.go new file mode 100644 index 00000000000..e1d2f27371b --- /dev/null +++ b/internal/urlutil/time_test.go @@ -0,0 +1,63 @@ +package urlutil + +import ( + "fmt" + "net/url" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestBuildTimeParameters(t *testing.T) { + t.Parallel() + + params := make(url.Values) + BuildTimeParameters(params, time.Minute) + assert.True(t, params.Has(QueryIssued)) + assert.True(t, params.Has(QueryExpiry)) + + ms1, _ := strconv.Atoi(params.Get(QueryIssued)) + ms2, _ := strconv.Atoi(params.Get(QueryExpiry)) + assert.Equal(t, 60000, ms2-ms1) +} + +func TestValidateTimeParameters(t *testing.T) { + t.Parallel() + + msNow := time.Now().UnixMilli() + for _, tc := range []struct { + name string + params url.Values + err string + }{ + {"empty", url.Values{}, "invalid issued timestamp"}, + {"missing issued", url.Values{QueryExpiry: {fmt.Sprint(msNow + 10000)}}, "invalid issued timestamp"}, + {"missing expiry", url.Values{QueryIssued: {fmt.Sprint(msNow + 10000)}}, "invalid expiry timestamp"}, + {"invalid issued", url.Values{ + QueryIssued: {fmt.Sprint(msNow + 120000)}, + QueryExpiry: {fmt.Sprint(msNow + 240000)}, + }, "issued in the future"}, + {"invalid expiry", url.Values{ + QueryIssued: {fmt.Sprint(msNow - 120000)}, + QueryExpiry: {fmt.Sprint(msNow - 240000)}, + }, "expired"}, + {"valid", url.Values{ + QueryIssued: {fmt.Sprint(msNow)}, + QueryExpiry: {fmt.Sprint(msNow)}, + }, ""}, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + err := ValidateTimeParameters(tc.params) + if tc.err == "" { + assert.NoError(t, err) + } else { + assert.ErrorContains(t, err, tc.err) + } + }) + } +}