diff --git a/selfservice/hook/web_hook.go b/selfservice/hook/web_hook.go
index 98bc24fc27c..d3a4107e05f 100644
--- a/selfservice/hook/web_hook.go
+++ b/selfservice/hook/web_hook.go
@@ -8,10 +8,13 @@ import (
 	"encoding/json"
 	"fmt"
 	"net/http"
+	"time"
 
 	"github.com/pkg/errors"
 	"github.com/tidwall/gjson"
+	"go.opentelemetry.io/otel/attribute"
 	"go.opentelemetry.io/otel/codes"
+	semconv "go.opentelemetry.io/otel/semconv/v1.11.0"
 	"go.opentelemetry.io/otel/trace"
 
 	"github.com/ory/kratos/ui/node"
@@ -29,7 +32,6 @@ import (
 	"github.com/ory/kratos/session"
 	"github.com/ory/kratos/text"
 	"github.com/ory/kratos/x"
-	"github.com/ory/x/otelx"
 )
 
 var (
@@ -253,22 +255,9 @@ func (e *WebHook) ExecuteSettingsPrePersistHook(_ http.ResponseWriter, req *http
 }
 
+// execute sends the webhook request in a goroutine, wrapped in a "Webhook" span. When "response.ignore"
+// is set it returns nil immediately and only logs the outcome. A response status of 400 or above is
+// reported as an error; with "can_interrupt" set, the error parsed from the response body takes precedence.
 func (e *WebHook) execute(ctx context.Context, data *templateContext) error {
-	span := trace.SpanFromContext(ctx)
-	attrs := map[string]string{
-		"webhook.http.method":  data.RequestMethod,
-		"webhook.http.url":     data.RequestURL,
-		"webhook.http.headers": fmt.Sprintf("%#v", data.RequestHeaders),
-	}
-
-	if data.Identity != nil {
-		attrs["webhook.identity.id"] = data.Identity.ID.String()
-	} else {
-		attrs["webhook.identity.id"] = ""
-	}
-
-	span.SetAttributes(otelx.StringAttrs(attrs)...)
-	defer span.End()
-
 	builder, err := request.NewBuilder(e.conf, e.deps)
 	if err != nil {
 		return err
@@ -281,35 +270,81 @@ func (e *WebHook) execute(ctx context.Context, data *templateContext) error {
 		return err
 	}
 
-	errChan := make(chan error, 1)
+	attrs := semconv.HTTPClientAttributesFromHTTPRequest(req.Request)
+	if data.Identity != nil {
+		attrs = append(attrs,
+			attribute.String("webhook.identity.id", data.Identity.ID.String()),
+			attribute.String("webhook.identity.nid", data.Identity.NID.String()),
+		)
+	}
+
+	var (
+		httpClient     = e.deps.HTTPClient(ctx)
+		ignoreResponse = gjson.GetBytes(e.conf, "response.ignore").Bool()
+		canInterrupt   = gjson.GetBytes(e.conf, "can_interrupt").Bool()
+		tracer         = trace.SpanFromContext(ctx).TracerProvider().Tracer("kratos-webhooks")
+		spanOpts       = []trace.SpanStartOption{trace.WithAttributes(attrs...)}
+		errChan        = make(chan error, 1)
+	)
+
+	ctx, span := tracer.Start(ctx, "Webhook", spanOpts...)
+	e.deps.Logger().WithRequest(req.Request).Info("Dispatching webhook")
+
+	req = req.WithContext(ctx)
+	if ignoreResponse {
+		// This is one of the few places where spawning a context.Background() is ok. We need to do this
+		// because the function runs asynchronously and we don't want to cancel the request if the
+		// incoming request context is cancelled.
+		//
+		// The webhook will still cancel after 30 seconds as that is the configured timeout for the HTTP client.
+		//
+		// The span is re-attached to the detached context so the outgoing request stays linked to it.
+		req = req.WithContext(trace.ContextWithSpan(context.Background(), span))
+	}
+
+	startTime := time.Now()
 	go func() {
 		defer close(errChan)
+		defer span.End()
 
-		resp, err := e.deps.HTTPClient(ctx).Do(req.WithContext(ctx))
+		resp, err := httpClient.Do(req)
 		if err != nil {
+			span.SetStatus(codes.Error, err.Error())
 			errChan <- errors.WithStack(err)
 			return
 		}
 		defer resp.Body.Close()
+		span.SetAttributes(semconv.HTTPAttributesFromHTTPStatusCode(resp.StatusCode)...)
 
 		if resp.StatusCode >= http.StatusBadRequest {
-			if gjson.GetBytes(e.conf, "can_interrupt").Bool() {
+			span.SetStatus(codes.Error, "HTTP status code >= 400")
+			if canInterrupt {
 				if err := parseWebhookResponse(resp); err != nil {
+					span.SetStatus(codes.Error, err.Error())
 					errChan <- err
+					// Stop here: errChan holds a single value, so a second send would block until the first is received.
+					return
 				}
 			}
-			errChan <- fmt.Errorf("web hook failed with status code %v", resp.StatusCode)
-			span.SetStatus(codes.Error, fmt.Sprintf("web hook failed with status code %v", resp.StatusCode))
+			errChan <- fmt.Errorf("webhook failed with status code %v", resp.StatusCode)
 			return
 		}
 
 		errChan <- nil
 	}()
 
-	if gjson.GetBytes(e.conf, "response.ignore").Bool() {
+	if ignoreResponse {
+		traceID, spanID := span.SpanContext().TraceID(), span.SpanContext().SpanID()
+		logger := e.deps.Logger().WithField("otel", map[string]string{
+			"trace_id": traceID.String(),
+			"span_id":  spanID.String(),
+		})
 		go func() {
-			err := <-errChan
-			e.deps.Logger().WithError(err).Warning("A web hook request failed but the error was ignored because the configuration indicated that the upstream response should be ignored.")
+			if err := <-errChan; err != nil {
+				logger.WithField("duration", time.Since(startTime)).WithError(err).Warning("Webhook request failed but the error was ignored because the configuration indicated that the upstream response should be ignored.")
+			} else {
+				logger.WithField("duration", time.Since(startTime)).Info("Webhook request succeeded")
+			}
 		}()
 		return nil
 	}
@@ -323,7 +358,7 @@ func parseWebhookResponse(resp *http.Response) (err error) {
 	}
 	var hookResponse rawHookResponse
 	if err := json.NewDecoder(resp.Body).Decode(&hookResponse); err != nil {
-		return errors.Wrap(err, "hook response could not be unmarshalled properly from JSON")
+		return errors.Wrap(err, "webhook response could not be unmarshalled properly from JSON")
 	}
 
 	var validationErrs []*schema.ValidationError
@@ -343,11 +378,11 @@ func parseWebhookResponse(resp *http.Response) (err error) {
 				Context: detail.Context,
 			})
 		}
-		validationErrs = append(validationErrs, schema.NewHookValidationError(msg.InstancePtr, "a web-hook target returned an error", messages))
+		validationErrs = append(validationErrs, schema.NewHookValidationError(msg.InstancePtr, "a webhook target returned an error", messages))
 	}
 
 	if len(validationErrs) == 0 {
-		return errors.New("error while parsing hook response: got no validation errors")
+		return errors.New("error while parsing webhook response: got no validation errors")
 	}
 
 	return schema.NewValidationListError(validationErrs)
diff --git a/selfservice/hook/web_hook_integration_test.go b/selfservice/hook/web_hook_integration_test.go
index 1c7ddb21df0..87bec5b2400 100644
--- a/selfservice/hook/web_hook_integration_test.go
+++ b/selfservice/hook/web_hook_integration_test.go
@@ -19,6 +19,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/sirupsen/logrus/hooks/test"
 	"github.com/stretchr/testify/require"
 
 	"github.com/ory/kratos/schema"
@@ -365,7 +366,7 @@ func TestWebHooks(t *testing.T) {
 		}`,
 	)
 
-	webhookError := schema.NewValidationListError([]*schema.ValidationError{schema.NewHookValidationError("#/traits/username", "a web-hook target returned an error", text.Messages{{ID: 1234, Type: "info", Text: "error message"}})})
+	webhookError := schema.NewValidationListError([]*schema.ValidationError{schema.NewHookValidationError("#/traits/username", "a webhook target returned an error", text.Messages{{ID: 1234, Type: "info", Text: "error message"}})})
 	for _, tc := range []struct {
 		uc          string
 		callWebHook func(wh *hook.WebHook, req *http.Request, f flow.Flow, s *session.Session) error
@@ -839,3 +840,86 @@ func TestDisallowPrivateIPRanges(t *testing.T) {
 		require.Contains(t, err.Error(), "192.168.178.0 is not a public IP address")
 	})
 }
+
+// TestAsyncWebhook verifies that a webhook configured with "response.ignore" returns immediately and
+// still completes after the context of the triggering request has been cancelled.
+func TestAsyncWebhook(t *testing.T) {
+	_, reg := internal.NewFastRegistryWithMocks(t)
+	logger := logrusx.New("kratos", "test")
+	logHook := new(test.Hook)
+	logger.Logger.Hooks.Add(logHook)
+	whDeps := struct {
+		x.SimpleLoggerWithClient
+		*jsonnetsecure.TestProvider
+	}{
+		x.SimpleLoggerWithClient{L: logger, C: reg.HTTPClient(context.Background()), T: otelx.NewNoop(logger, &otelx.Config{ServiceName: "kratos"})},
+		jsonnetsecure.NewTestProvider(t),
+	}
+
+	req := &http.Request{
+		Header: map[string][]string{"Some-Header": {"Some-Value"}},
+		Host:   "www.ory.sh",
+		TLS:    new(tls.ConnectionState),
+		URL:    &url.URL{Path: "/some_end_point"},
+		Method: http.MethodPost,
+	}
+
+	incomingCtx, incomingCancel := context.WithCancel(context.Background())
+	if deadline, ok := t.Deadline(); ok {
+		// cancel this context one second before test timeout for clean shutdown
+		var cleanup context.CancelFunc
+		incomingCtx, cleanup = context.WithDeadline(incomingCtx, deadline.Add(-time.Second))
+		defer cleanup()
+	}
+
+	req = req.WithContext(incomingCtx)
+	s := &session.Session{ID: x.NewUUID(), Identity: &identity.Identity{ID: x.NewUUID()}}
+	f := &login.Flow{ID: x.NewUUID()}
+
+	handlerEntered, blockHandlerOnExit := make(chan struct{}), make(chan struct{})
+	webhookReceiver := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		close(handlerEntered)
+		<-blockHandlerOnExit
+		w.Write([]byte("ok"))
+	}))
+	t.Cleanup(webhookReceiver.Close)
+
+	wh := hook.NewWebHook(&whDeps, json.RawMessage(fmt.Sprintf(`
+	{
+		"url": %q,
+		"method": "GET",
+		"body": "file://stub/test_body.jsonnet",
+		"response": {
+			"ignore": true
+		}
+	}`, webhookReceiver.URL)))
+	err := wh.ExecuteLoginPostHook(nil, req, node.DefaultGroup, f, s)
+	require.NoError(t, err) // execution returns immediately for async webhook
+	select {
+	case <-time.After(200 * time.Millisecond):
+		t.Fatal("timed out waiting for webhook request to reach test handler")
+	case <-handlerEntered:
+		// ok
+	}
+	// at this point, a goroutine is in the middle of the call to our test handler and waiting for a response
+	incomingCancel() // simulate the incoming Kratos request having finished
+	close(blockHandlerOnExit)
+	timeout := time.After(200 * time.Millisecond)
+	succeeded := func() bool {
+		for _, entry := range logHook.AllEntries() {
+			if entry.Message == "Webhook request succeeded" {
+				return true
+			}
+		}
+		return false
+	}
+	// Check the log before each wait so a success is never turned into a timeout failure.
+	for !succeeded() {
+		select {
+		case <-timeout:
+			t.Fatal("timed out waiting for successful webhook completion")
+		case <-time.After(50 * time.Millisecond):
+			// continue loop
+		}
+	}
+}