diff --git a/pkg/http/handler.go b/pkg/http/handler.go index 26bbf868..278f0545 100644 --- a/pkg/http/handler.go +++ b/pkg/http/handler.go @@ -4,7 +4,6 @@ import ( "encoding/json" "log/slog" baseHttp "net/http" - "strconv" "github.com/getsentry/sentry-go" ) @@ -37,25 +36,18 @@ func captureApiError(r *baseHttp.Request, apiErr *ApiError) { return } - level := sentry.LevelWarning - if apiErr.Status >= baseHttp.StatusInternalServerError { - level = sentry.LevelError + errToCapture := error(apiErr) + if apiErr.Err != nil { + errToCapture = apiErr.Err } notify := func(hub *sentry.Hub) { hub.WithScope(func(scope *sentry.Scope) { - scope.SetLevel(level) - scope.SetTag("http.method", r.Method) - scope.SetTag("http.status_code", strconv.Itoa(apiErr.Status)) - scope.SetTag("http.route", r.URL.Path) - scope.SetRequest(r) - scope.SetExtra("api_error_status_text", baseHttp.StatusText(apiErr.Status)) + scopeApiError := NewScopeApiError(scope, r, apiErr) - if apiErr.Data != nil { - scope.SetExtra("api_error_data", apiErr.Data) - } + scopeApiError.Enrich() - hub.CaptureException(apiErr) + hub.CaptureException(errToCapture) }) } diff --git a/pkg/http/handler_test.go b/pkg/http/handler_test.go index 87e33c80..cbc8d603 100644 --- a/pkg/http/handler_test.go +++ b/pkg/http/handler_test.go @@ -1,10 +1,16 @@ package http import ( + "context" "encoding/json" + "errors" + "fmt" "net/http" "net/http/httptest" "testing" + + "github.com/getsentry/sentry-go" + "github.com/oullin/pkg/portal" ) func TestMakeApiHandler(t *testing.T) { @@ -13,6 +19,7 @@ func TestMakeApiHandler(t *testing.T) { return &ApiError{ Message: "bad", Status: http.StatusBadRequest, + Err: errors.New("bad"), } }) @@ -33,3 +40,104 @@ func TestMakeApiHandler(t *testing.T) { t.Fatalf("invalid response") } } + +func TestScopeApiErrorRequestID(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set(portal.RequestIDHeader, "header-id") + + scopeApiError := &ScopeApiError{request: req} + + if got := scopeApiError.RequestID(); got != "header-id" { + t.Fatalf("expected header request id, got %s", got) + } + + ctxReq := req.WithContext(context.WithValue(req.Context(), portal.RequestIDKey, "context-id")) + + scopeApiError.request = ctxReq + + if got := scopeApiError.RequestID(); got != "context-id" { + t.Fatalf("expected context request id, got %s", got) + } +} + +func TestScopeApiErrorAccountName(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set(portal.UsernameHeader, "header-user") + + scopeApiError := &ScopeApiError{request: req} + + if got := scopeApiError.accountName(); got != "header-user" { + t.Fatalf("expected header user, got %s", got) + } + + ctxReq := req.WithContext(context.WithValue(req.Context(), portal.AuthAccountNameKey, "context-user")) + + scopeApiError.request = ctxReq + + if got := scopeApiError.accountName(); got != "context-user" { + t.Fatalf("expected context user, got %s", got) + } +} + +func TestScopeApiErrorBuildErrorChain(t *testing.T) { + root := errors.New("root") + wrapped := fmt.Errorf("layer: %w", root) + + chain := (&ScopeApiError{}).buildErrorChain(wrapped) + + if len(chain) != 2 { + t.Fatalf("expected 2 errors in chain, got %d", len(chain)) + } + + if chain[0] != wrapped.Error() || chain[1] != root.Error() { + t.Fatalf("unexpected error chain: %#v", chain) + } +} + +func TestScopeApiErrorEnrichSetsLevelAndTags(t *testing.T) { + scope := sentry.NewScope() + req := httptest.NewRequest("POST", "/resource", nil) + + apiErr := &ApiError{Status: http.StatusInternalServerError, Err: errors.New("boom")} + + NewScopeApiError(scope, req, apiErr).Enrich() + + event := scope.ApplyToEvent(sentry.NewEvent(), nil, nil) + if event == nil { + t.Fatalf("expected event after scope enrichment") + } + + if event.Level != sentry.LevelError { + t.Fatalf("expected error level, got %s", event.Level) + } + + if got := event.Tags["http.method"]; got != "POST" { + t.Fatalf("expected POST method tag, got %s", got) + } + + if got := event.Tags["http.status_code"]; got != "500" { + t.Fatalf("expected 500 status code tag, got %s", got) + } + + if got := event.Tags["http.route"]; got != "/resource" { + t.Fatalf("expected /resource route tag, got %s", got) + } +} + +func TestScopeApiErrorEnrichSetsWarningLevelForClientErrors(t *testing.T) { + scope := sentry.NewScope() + req := httptest.NewRequest("GET", "/client", nil) + + apiErr := &ApiError{Status: http.StatusBadRequest, Err: errors.New("bad request")} + + NewScopeApiError(scope, req, apiErr).Enrich() + + event := scope.ApplyToEvent(sentry.NewEvent(), nil, nil) + if event == nil { + t.Fatalf("expected event after scope enrichment") + } + + if event.Level != sentry.LevelWarning { + t.Fatalf("expected warning level, got %s", event.Level) + } +} diff --git a/pkg/http/response.go b/pkg/http/response.go index 188d41f8..ffeb7548 100644 --- a/pkg/http/response.go +++ b/pkg/http/response.go @@ -2,6 +2,7 @@ package http import ( "encoding/json" + "errors" "fmt" "log/slog" baseHttp "net/http" @@ -86,9 +87,12 @@ func (r *Response) RespondWithNotModified() { } func InternalError(msg string) *ApiError { + message := fmt.Sprintf("Internal server error: %s", msg) + return &ApiError{ - Message: fmt.Sprintf("Internal server error: %s", msg), + Message: message, Status: baseHttp.StatusInternalServerError, + Err: errors.New(message), } } @@ -98,13 +102,17 @@ func LogInternalError(msg string, err error) *ApiError { return &ApiError{ Message: fmt.Sprintf("Internal server error: %s", msg), Status: baseHttp.StatusInternalServerError, + Err: err, } } func BadRequestError(msg string) *ApiError { + message := fmt.Sprintf("Bad request error: %s", msg) + return &ApiError{ - Message: fmt.Sprintf("Bad request error: %s", msg), + Message: message, Status: baseHttp.StatusBadRequest, + Err: errors.New(message), } } @@ -114,6 +122,7 @@ func LogBadRequestError(msg string, err error) *ApiError { return &ApiError{ Message: fmt.Sprintf("Bad request error: %s", msg), Status: baseHttp.StatusBadRequest, + Err: err, } } @@ -123,20 +132,27 @@ func LogUnauthorisedError(msg string, err error) *ApiError { return &ApiError{ Message: fmt.Sprintf("Unauthorised request: %s", msg), Status: baseHttp.StatusUnauthorized, + Err: err, } } -func UnprocessableEntity(msg string, errors map[string]any) *ApiError { +func UnprocessableEntity(msg string, errs map[string]any) *ApiError { + message := fmt.Sprintf("Unprocessable entity: %s", msg) + return &ApiError{ - Message: fmt.Sprintf("Unprocessable entity: %s", msg), + Message: message, Status: baseHttp.StatusUnprocessableEntity, - Data: errors, + Data: errs, + Err: errors.New(message), } } func NotFound(msg string) *ApiError { + message := fmt.Sprintf("Not found error: %s", msg) + return &ApiError{ - Message: fmt.Sprintf("Not found error: %s", msg), + Message: message, Status: baseHttp.StatusNotFound, + Err: errors.New(message), } } diff --git a/pkg/http/schema.go b/pkg/http/schema.go index 08c879ef..5492fd36 100644 --- a/pkg/http/schema.go +++ b/pkg/http/schema.go @@ -12,6 +12,7 @@ type ApiError struct { Message string `json:"message"` Status int `json:"status"` Data map[string]any `json:"data"` + Err error `json:"-"` } func (e *ApiError) Error() string { @@ -22,6 +23,14 @@ func (e *ApiError) Error() string { return e.Message } +func (e *ApiError) Unwrap() error { + if e == nil { + return nil + } + + return e.Err +} + type ApiHandler func(baseHttp.ResponseWriter, *baseHttp.Request) *ApiError type Middleware func(ApiHandler) ApiHandler diff --git a/pkg/http/schema_test.go b/pkg/http/schema_test.go index a68ae0f7..9a2d9bd8 100644 --- a/pkg/http/schema_test.go +++ b/pkg/http/schema_test.go @@ -1,11 +1,15 @@ package http -import "testing" +import ( + "errors" + "testing" +) func TestApiErrorError(t *testing.T) { e := &ApiError{ Message: "boom", Status: 500, + Err: errors.New("boom"), } if e.Error() != "boom" { @@ -18,3 +22,25 @@ func TestApiErrorError(t *testing.T) { t.Fatalf("nil error wrong") } } + +func TestApiErrorUnwrap(t *testing.T) { + cause := errors.New("root cause") + e := &ApiError{ + Message: "boom", + Status: 500, + Err: cause, + } + + if !errors.Is(e, cause) { + t.Fatalf("expected errors.Is to match the wrapped cause") + } + + if got := e.Unwrap(); got != cause { + t.Fatalf("expected unwrap to return the cause") + } + + var nilErr *ApiError + if nilErr.Unwrap() != nil { + t.Fatalf("expected nil unwrap to be nil") + } +} diff --git a/pkg/http/scope_api_error.go b/pkg/http/scope_api_error.go new file mode 100644 index 00000000..34b10a6c --- /dev/null +++ b/pkg/http/scope_api_error.go @@ -0,0 +1,132 @@ +package http + +import ( + "errors" + "fmt" + baseHttp "net/http" + "strconv" + "strings" + + "github.com/getsentry/sentry-go" + "github.com/oullin/pkg/portal" +) + +type ScopeApiError struct { + scope *sentry.Scope + request *baseHttp.Request + apiErr *ApiError +} + +func NewScopeApiError(scope *sentry.Scope, r *baseHttp.Request, apiErr *ApiError) *ScopeApiError { + return &ScopeApiError{scope: scope, request: r, apiErr: apiErr} +} + +func (s *ScopeApiError) RequestID() string { + if s == nil || s.request == nil { + return "" + } + + if v, ok := s.request.Context().Value(portal.RequestIDKey).(string); ok { + if id := strings.TrimSpace(v); id != "" { + return id + } + } + + return s.headerValue(portal.RequestIDHeader) +} + +func (s *ScopeApiError) Enrich() { + if s == nil || s.scope == nil || s.request == nil || s.apiErr == nil { + return + } + + level := sentry.LevelWarning + if s.apiErr.Status >= baseHttp.StatusInternalServerError { + level = sentry.LevelError + } + + s.scope.SetLevel(level) + s.scope.SetTag("http.method", s.request.Method) + s.scope.SetTag("http.status_code", strconv.Itoa(s.apiErr.Status)) + s.scope.SetTag("http.route", s.request.URL.Path) + + s.scope.SetRequest(s.request) + s.scope.SetExtra("api_error_status_text", baseHttp.StatusText(s.apiErr.Status)) + s.scope.SetExtra("api_error_message", s.apiErr.Message) + + if requestID := s.RequestID(); requestID != "" { + s.scope.SetTag("http.request_id", requestID) + s.scope.SetExtra("http_request_id", requestID) + } + + if s.apiErr.Data != nil { + s.scope.SetExtra("api_error_data", s.apiErr.Data) + } + + if s.apiErr.Err != nil { + s.scope.SetExtra("api_error_cause", s.apiErr.Err.Error()) + s.scope.SetTag("api.error.cause_type", fmt.Sprintf("%T", s.apiErr.Err)) + + s.scope.SetExtra("api_error_cause_chain", s.buildErrorChain(s.apiErr.Err)) + } + + if accountName := s.accountName(); accountName != "" { + s.scope.SetExtra("api_account_name", accountName) + } + + if username := s.headerValue(portal.UsernameHeader); username != "" { + s.scope.SetExtra("api_username_header", username) + } + + if origin := s.headerValue(portal.IntendedOriginHeader); origin != "" { + s.scope.SetExtra("api_intended_origin", origin) + } + + if ts := s.headerValue(portal.TimestampHeader); ts != "" { + s.scope.SetExtra("api_request_timestamp", ts) + } + + if nonce := s.headerValue(portal.NonceHeader); nonce != "" { + s.scope.SetExtra("api_request_nonce", nonce) + } + + if publicKey := s.headerValue(portal.TokenHeader); publicKey != "" { + s.scope.SetExtra("api_public_key", publicKey) + } + + if clientIP := strings.TrimSpace(portal.ParseClientIP(s.request)); clientIP != "" { + s.scope.SetExtra("http_client_ip", clientIP) + } +} + +func (s *ScopeApiError) accountName() string { + if s == nil || s.request == nil { + return "" + } + + if v, ok := s.request.Context().Value(portal.AuthAccountNameKey).(string); ok { + if name := strings.TrimSpace(v); name != "" { + return name + } + } + + return s.headerValue(portal.UsernameHeader) +} + +func (s *ScopeApiError) headerValue(key string) string { + if s == nil || s.request == nil { + return "" + } + + return strings.TrimSpace(s.request.Header.Get(key)) +} + +func (s *ScopeApiError) buildErrorChain(err error) []string { + chain := make([]string, 0, 4) + + for current := err; current != nil; current = errors.Unwrap(current) { + chain = append(chain, current.Error()) + } + + return chain +}