Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ThibHrrd committed Aug 31, 2022
1 parent c9d7693 commit ffe5337
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 24 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ require (
github.com/pkg/errors v0.9.1
github.com/pquerna/otp v1.3.0
github.com/rs/cors v1.8.2
github.com/russellhaering/goxmldsig v1.1.1
github.com/sirupsen/logrus v1.8.1
github.com/slack-go/slack v0.7.4
github.com/spf13/cobra v1.5.0
Expand Down Expand Up @@ -266,7 +267,6 @@ require (
github.com/rivo/uniseg v0.2.0 // indirect
github.com/rjeczalik/notify v0.0.0-20181126183243-629144ba06a1 // indirect
github.com/rogpeppe/go-internal v1.8.0 // indirect
github.com/russellhaering/goxmldsig v1.1.1 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/seatgeek/logrus-gelf-formatter v0.0.0-20210414080842-5b05eb8ff761 // indirect
github.com/segmentio/backo-go v0.0.0-20200129164019-23eae7c10bd3 // indirect
Expand Down
20 changes: 10 additions & 10 deletions selfservice/flow/saml/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ func (h *Handler) RegisterPublicRoutes(router *x.RouterPublic) {

// Handle /selfservice/methods/saml/metadata
func (h *Handler) serveMetadata(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
config := h.d.Config(r.Context())
config := h.d.Config()
if samlMiddleware == nil {
if err := h.instantiateMiddleware(*config); err != nil {
if err := h.instantiateMiddleware(r.Context(), *config); err != nil {
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err)
}
}
Expand All @@ -108,13 +108,13 @@ func (h *Handler) serveMetadata(w http.ResponseWriter, r *http.Request, ps httpr
func (h *Handler) loginWithIdp(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
// Middleware is a singleton so we have to verify that it exists
if samlMiddleware == nil {
config := h.d.Config(r.Context())
if err := h.instantiateMiddleware(*config); err != nil {
config := h.d.Config()
if err := h.instantiateMiddleware(r.Context(), *config); err != nil {
h.d.SelfServiceErrorManager().Forward(r.Context(), w, r, err)
}
}

conf := h.d.Config(r.Context())
conf := h.d.Config()

// We have to get the SessionID from the cookie to inject it into the context to ensure continuity
cookie, err := r.Cookie(continuity.CookieName)
Expand All @@ -132,7 +132,7 @@ func (h *Handler) loginWithIdp(w http.ResponseWriter, r *http.Request, ps httpro
samlMiddleware.HandleStartAuthFlow(w, r)
} else {
// A session already exist, we redirect to the main page
http.Redirect(w, r, conf.SelfServiceBrowserDefaultReturnTo().Path, http.StatusTemporaryRedirect)
http.Redirect(w, r, conf.SelfServiceBrowserDefaultReturnTo(r.Context()).Path, http.StatusTemporaryRedirect)
}
}

Expand All @@ -142,10 +142,10 @@ func DestroyMiddlewareIfExists() {
}
}

func (h *Handler) instantiateMiddleware(config config.Config) error {
func (h *Handler) instantiateMiddleware(ctx context.Context, config config.Config) error {
// Create a SAMLProvider object from the config file
var c samlstrategy.ConfigurationCollection
conf := config.SelfServiceStrategy("saml").Config
conf := config.SelfServiceStrategy(ctx, "saml").Config
if err := jsonx.
NewStrictDecoder(bytes.NewBuffer(conf)).
Decode(&c); err != nil {
Expand Down Expand Up @@ -242,7 +242,7 @@ func (h *Handler) instantiateMiddleware(config config.Config) error {
}

// The main URL
rootURL, err := url.Parse(config.SelfServiceBrowserDefaultReturnTo().String())
rootURL, err := url.Parse(config.SelfServiceBrowserDefaultReturnTo(ctx).String())
if err != nil {
return err
}
Expand Down Expand Up @@ -275,7 +275,7 @@ func (h *Handler) instantiateMiddleware(config config.Config) error {
// It's better to use SHA256 than SHA1
samlMiddleWare.ServiceProvider.SignatureMethod = dsig.RSASHA256SignatureMethod

var publicUrlString = config.SelfPublicURL().String()
var publicUrlString = config.SelfPublicURL(ctx).String()

// Sometimes there is an issue with double slash into the url so we prevent it
// Crewjam library use default route for ACS and metadat but we want to overwrite them
Expand Down
12 changes: 6 additions & 6 deletions selfservice/flow/saml/helpertest/helpertest.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package helpertest

import (
"context"
"crypto"
"crypto/rand"
"crypto/rsa"
Expand Down Expand Up @@ -29,7 +30,6 @@ import (
"github.com/ory/kratos/internal"
"github.com/ory/kratos/internal/testhelpers"
samlhandler "github.com/ory/kratos/selfservice/flow/saml"
"github.com/ory/kratos/selfservice/strategy/saml"
samlstrategy "github.com/ory/kratos/selfservice/strategy/saml"
samlstrat "github.com/ory/kratos/selfservice/strategy/saml/strategy"
"github.com/ory/kratos/x"
Expand All @@ -56,8 +56,8 @@ func NewSAMLProvider(
}

func ViperSetProviderConfig(t *testing.T, conf *config.Config, SAMLProvider ...samlstrategy.Configuration) {
conf.MustSet(config.ViperKeySelfServiceStrategyConfig+"."+string(identity.CredentialsTypeSAML)+".config", &samlstrategy.ConfigurationCollection{SAMLProviders: SAMLProvider})
conf.MustSet(config.ViperKeySelfServiceStrategyConfig+"."+string(identity.CredentialsTypeSAML)+".enabled", true)
conf.MustSet(context.Background(), config.ViperKeySelfServiceStrategyConfig+"."+string(identity.CredentialsTypeSAML)+".config", &samlstrategy.ConfigurationCollection{SAMLProviders: SAMLProvider})
conf.MustSet(context.Background(), config.ViperKeySelfServiceStrategyConfig+"."+string(identity.CredentialsTypeSAML)+".enabled", true)
}

func NewClient(t *testing.T, jar *cookiejar.Jar) *http.Client {
Expand Down Expand Up @@ -132,7 +132,7 @@ func InitMiddleware(t *testing.T, idpInformation map[string]string) (*samlsp.Mid
t,
conf,
NewSAMLProvider(t, ts, "samlProviderTestID", "samlProviderTestLabel"),
saml.Configuration{
samlstrategy.Configuration{
ID: "samlProviderTestID",
Label: "samlProviderTestLabel",
PublicCertPath: "file://testdata/myservice.cert",
Expand All @@ -143,9 +143,9 @@ func InitMiddleware(t *testing.T, idpInformation map[string]string) (*samlsp.Mid
},
)

conf.MustSet(config.ViperKeySelfServiceRegistrationEnabled, true)
conf.MustSet(context.Background(), config.ViperKeySelfServiceRegistrationEnabled, true)
testhelpers.SetDefaultIdentitySchema(conf, "file://testdata/registration.schema.json")
conf.MustSet(config.HookStrategyKey(config.ViperKeySelfServiceRegistrationAfter,
conf.MustSet(context.Background(), config.HookStrategyKey(config.ViperKeySelfServiceRegistrationAfter,
identity.CredentialsTypeSAML.String()), []config.SelfServiceHook{{Name: "session"}})

t.Logf("Kratos Public URL: %s", ts.URL)
Expand Down
2 changes: 1 addition & 1 deletion selfservice/strategy/saml/provider_saml.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (d *ProviderSAML) Claims(ctx context.Context, config *config.Config, attrib

var c ConfigurationCollection

conf := config.SelfServiceStrategy("saml").Config
conf := config.SelfServiceStrategy(ctx, "saml").Config
if err := jsonx.
NewStrictDecoder(bytes.NewBuffer(conf)).
Decode(&c); err != nil {
Expand Down
8 changes: 4 additions & 4 deletions selfservice/strategy/saml/strategy/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ func (s *Strategy) alreadyAuthenticated(w http.ResponseWriter, r *http.Request,
if _, ok := req.(*settings.Flow); ok {
// ignore this if it's a settings flow
} else if !isForced(req) {
http.Redirect(w, r, s.d.Config(r.Context()).SelfServiceBrowserDefaultReturnTo().String(), http.StatusSeeOther)
http.Redirect(w, r, s.d.Config().SelfServiceBrowserDefaultReturnTo(r.Context()).String(), http.StatusSeeOther)
return true
}
}
Expand Down Expand Up @@ -306,7 +306,7 @@ func (s *Strategy) handleCallback(w http.ResponseWriter, r *http.Request, ps htt
}

// We translate SAML Attributes into claims (To create an identity we need these claims)
claims, err := provider.Claims(r.Context(), s.d.Config(r.Context()), attributes)
claims, err := provider.Claims(r.Context(), s.d.Config(), attributes)
if err != nil {
s.forwardError(w, r, err)
return
Expand Down Expand Up @@ -348,7 +348,7 @@ func (s *Strategy) Provider(ctx context.Context) (samlstrategy.Provider, error)
func (s *Strategy) Config(ctx context.Context) (*samlstrategy.ConfigurationCollection, error) {
var c samlstrategy.ConfigurationCollection

conf := s.d.Config(ctx).SelfServiceStrategy(string(s.ID())).Config
conf := s.d.Config().SelfServiceStrategy(ctx, string(s.ID())).Config
if err := jsonx.
NewStrictDecoder(bytes.NewBuffer(conf)).
Decode(&c); err != nil {
Expand Down Expand Up @@ -387,7 +387,7 @@ func (s *Strategy) handleError(w http.ResponseWriter, r *http.Request, f flow.Fl
AddProvider(rf.UI, provider, text.NewInfoRegistrationContinue())

if traits != nil {
ds, err := s.d.Config(r.Context()).DefaultIdentityTraitsSchemaURL()
ds, err := s.d.Config().DefaultIdentityTraitsSchemaURL(r.Context())
if err != nil {
return err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func (s *Strategy) PopulateRegistrationMethod(r *http.Request, f *registration.F
}

func (s *Strategy) newLinkDecoder(p interface{}, r *http.Request) error {
ds, err := s.d.Config(r.Context()).DefaultIdentityTraitsSchemaURL()
ds, err := s.d.Config().DefaultIdentityTraitsSchemaURL(r.Context())
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion selfservice/strategy/saml/strategy/test/strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func TestGetRegistrationIdentity(t *testing.T) {
provider, _ := strategy.Provider(context.Background())
assertion, _ := helpertest.GetAndDecryptAssertion(t, "./testdata/SP_SamlResponse.xml", middleware.ServiceProvider.Key)
attributes, _ := strategy.GetAttributesFromAssertion(assertion)
claims, _ := provider.Claims(context.Background(), strategy.D().Config(context.Background()), attributes)
claims, _ := provider.Claims(context.Background(), strategy.D().Config(), attributes)

i, err := strategy.GetRegistrationIdentity(nil, context.Background(), provider, claims, false)
require.NoError(t, err)
Expand Down

0 comments on commit ffe5337

Please sign in to comment.