diff --git a/selfservice/flow/registration/request.go b/selfservice/flow/registration/request.go index 89bfa171da6..09fd151c4c0 100644 --- a/selfservice/flow/registration/request.go +++ b/selfservice/flow/registration/request.go @@ -73,11 +73,9 @@ func NewRequest(exp time.Duration, csrf string, r *http.Request) *Request { source := urlx.Copy(r.URL) source.Host = r.Host - if len(source.Scheme) == 0 { - source.Scheme = "http" - if r.TLS != nil { - source.Scheme = "https" - } + source.Scheme = "http" + if r.TLS != nil { + source.Scheme = "https" } return &Request{ diff --git a/selfservice/flow/registration/request_test.go b/selfservice/flow/registration/request_test.go index ca160b186cc..70fac134e49 100644 --- a/selfservice/flow/registration/request_test.go +++ b/selfservice/flow/registration/request_test.go @@ -1,6 +1,8 @@ package registration_test import ( + "crypto/tls" + "net/http" "testing" "time" @@ -8,6 +10,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/ory/x/urlx" + "github.com/ory/kratos/selfservice/flow/registration" "github.com/ory/kratos/x" ) @@ -28,6 +32,34 @@ func TestFakeRequestData(t *testing.T) { } } +func TestNewRequest(t *testing.T) { + t.Run("case=0", func(t *testing.T) { + r := registration.NewRequest(0, "csrf", &http.Request{ + URL: urlx.ParseOrPanic("/"), + Host: "ory.sh", TLS: &tls.ConnectionState{}, + }) + assert.Equal(t, r.IssuedAt, r.ExpiresAt) + // assert.Equal(t, flow.TypeBrowser, r.Type) + assert.Equal(t, "https://ory.sh/", r.RequestURL) + }) + + t.Run("case=1", func(t *testing.T) { + r := registration.NewRequest(0, "csrf", &http.Request{ + URL: urlx.ParseOrPanic("/"), + Host: "ory.sh"}) + assert.Equal(t, r.IssuedAt, r.ExpiresAt) + // assert.Equal(t, flow.TypeBrowser, r.Type) + assert.Equal(t, "http://ory.sh/", r.RequestURL) + }) + + t.Run("case=2", func(t *testing.T) { + r := registration.NewRequest(0, "csrf", &http.Request{ + URL: urlx.ParseOrPanic("https://ory.sh/"), + Host: "ory.sh"}) + assert.Equal(t, "http://ory.sh/", r.RequestURL) + }) +} + func TestRequest(t *testing.T) { r := ®istration.Request{ID: x.NewUUID()} assert.Equal(t, r.ID, r.GetID())