From c45af08cdcb3962fe94efc0b5156eb3801ff6bb4 Mon Sep 17 00:00:00 2001 From: Richard Dominick <34370238+RichDom2185@users.noreply.github.com> Date: Sat, 22 Jul 2023 12:05:09 +0800 Subject: [PATCH] Improve error handling (#61) * Add 401, 403 error types * Migrate HTTP error serving to apierrors package For better separation of concerns between the router and the API error handling/HTTP error generation. * Add TODO * Use abstraction for auth middleware error handling Prevents code duplication and improves separation of concerns. Also paves the way for easier future changes to error messages when they become more structured. * Update users, stories controller error handling Ensures consistency and improves correctness. * Fix error message typo --- controller/stories/stories.go | 9 ++++++--- controller/users/users.go | 16 +++++++++++----- internal/auth/middleware.go | 10 ++++++++-- internal/errors/401.go | 17 +++++++++++++++++ internal/errors/403.go | 17 +++++++++++++++++ internal/errors/errors.go | 17 +++++++++++++++++ internal/router/errors.go | 12 ++---------- 7 files changed, 78 insertions(+), 20 deletions(-) create mode 100644 internal/errors/401.go create mode 100644 internal/errors/403.go diff --git a/controller/stories/stories.go b/controller/stories/stories.go index cc3cd8b..214c66a 100644 --- a/controller/stories/stories.go +++ b/controller/stories/stories.go @@ -40,8 +40,9 @@ func HandleRead(w http.ResponseWriter, r *http.Request) error { storyIDStr := chi.URLParam(r, "storyID") storyID, err := strconv.Atoi(storyIDStr) if err != nil { - http.Error(w, "Invalid storyID", http.StatusBadRequest) - return err + return apierrors.ClientBadRequestError{ + Message: fmt.Sprintf("Invalid storyID: %v", err), + } } // Get DB instance @@ -66,7 +67,9 @@ func HandleCreate(w http.ResponseWriter, r *http.Request) error { e, ok := err.(*json.UnmarshalTypeError) if !ok { logrus.Error(err) - return err + return apierrors.ClientBadRequestError{ + Message: fmt.Sprintf("Bad JSON parsing: %v", err), + } } // TODO: Investigate if we should use errors.Wrap instead diff --git a/controller/users/users.go b/controller/users/users.go index e20cd77..46806c1 100644 --- a/controller/users/users.go +++ b/controller/users/users.go @@ -40,8 +40,9 @@ func HandleRead(w http.ResponseWriter, r *http.Request) error { userIDStr := chi.URLParam(r, "userID") userID, err := strconv.Atoi(userIDStr) if err != nil { - http.Error(w, "Invalid userID", http.StatusBadRequest) - return err + return apierrors.ClientBadRequestError{ + Message: fmt.Sprintf("Invalid userID: %v", err), + } } // Get DB instance @@ -66,18 +67,23 @@ func HandleCreate(w http.ResponseWriter, r *http.Request) error { e, ok := err.(*json.UnmarshalTypeError) if !ok { logrus.Error(err) - return err + return apierrors.ClientBadRequestError{ + Message: fmt.Sprintf("Bad JSON parsing: %v", err), + } } // TODO: Investigate if we should use errors.Wrap instead - return apierrors.ClientBadRequestError{ + return apierrors.ClientUnprocessableEntityError{ Message: fmt.Sprintf("Invalid JSON format: %s should be a %s.", e.Field, e.Type), } } err := params.Validate() if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + logrus.Error(err) + return apierrors.ClientUnprocessableEntityError{ + Message: fmt.Sprintf("JSON validation failed: %v", err), + } } userModel := *params.ToModel() diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go index 8ce5d54..7fb75b8 100644 --- a/internal/auth/middleware.go +++ b/internal/auth/middleware.go @@ -7,6 +7,7 @@ import ( "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/source-academy/stories-backend/internal/config" + apierrors "github.com/source-academy/stories-backend/internal/errors" envutils "github.com/source-academy/stories-backend/internal/utils/env" ) @@ -22,6 +23,7 @@ func MakeMiddlewareFrom(conf *config.Config) func(http.Handler) http.Handler { key, ok := keySet.Key(0) if !ok { // Block all access if JWKS source is down, since we can't verify JWTs + // TODO: Investigate if 500 is appropriate return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -34,7 +36,9 @@ func MakeMiddlewareFrom(conf *config.Config) func(http.Handler) http.Handler { // Get JWT from request authHeader := r.Header.Get("Authorization") if authHeader == "" { - http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + apierrors.ServeHTTP(w, r, apierrors.ClientUnauthorizedError{ + Message: "Missing Authorization header", + }) return } @@ -42,7 +46,9 @@ func MakeMiddlewareFrom(conf *config.Config) func(http.Handler) http.Handler { toParse := authHeader[len("Bearer "):] token, err := jwt.ParseString(toParse, jwt.WithKey(jwa.RS256, key)) if err != nil { - fmt.Printf("Failed to verify JWS: %s\n", err) + apierrors.ServeHTTP(w, r, apierrors.ClientForbiddenError{ + Message: fmt.Sprintf("Failed to verify JWT: %s\n", err), + }) return } diff --git a/internal/errors/401.go b/internal/errors/401.go new file mode 100644 index 0000000..5c59322 --- /dev/null +++ b/internal/errors/401.go @@ -0,0 +1,17 @@ +package apierrors + +import ( + "net/http" +) + +type ClientUnauthorizedError struct { + Message string +} + +func (e ClientUnauthorizedError) Error() string { + return e.Message +} + +func (e ClientUnauthorizedError) HTTPStatusCode() int { + return http.StatusUnauthorized +} diff --git a/internal/errors/403.go b/internal/errors/403.go new file mode 100644 index 0000000..d986f5c --- /dev/null +++ b/internal/errors/403.go @@ -0,0 +1,17 @@ +package apierrors + +import ( + "net/http" +) + +type ClientForbiddenError struct { + Message string +} + +func (e ClientForbiddenError) Error() string { + return e.Message +} + +func (e ClientForbiddenError) HTTPStatusCode() int { + return http.StatusForbidden +} diff --git a/internal/errors/errors.go b/internal/errors/errors.go index 1566ea1..0e948fb 100644 --- a/internal/errors/errors.go +++ b/internal/errors/errors.go @@ -1,8 +1,25 @@ package apierrors +import ( + "errors" + "net/http" +) + // ClientError is an interface for errors that should be returned to the client. // They generally start with a 4xx HTTP status code. type ClientError interface { error HTTPStatusCode() int } + +func ServeHTTP(w http.ResponseWriter, r *http.Request, err error) { + var clientError ClientError + if errors.As(err, &clientError) { + // Client error (status 4xx), write error message and status code + http.Error(w, clientError.Error(), clientError.HTTPStatusCode()) + return + } + + // 500 Internal Server Error as a catch-all + http.Error(w, err.Error(), http.StatusInternalServerError) +} diff --git a/internal/router/errors.go b/internal/router/errors.go index 562d14e..01ad118 100644 --- a/internal/router/errors.go +++ b/internal/router/errors.go @@ -1,7 +1,6 @@ package router import ( - "errors" "net/http" apierrors "github.com/source-academy/stories-backend/internal/errors" @@ -15,14 +14,7 @@ func handleAPIError(handler func(w http.ResponseWriter, r *http.Request) error) return } - var clientError apierrors.ClientError - if errors.As(err, &clientError) { - // Client error (status 4xx), write error message and status code - http.Error(w, clientError.Error(), clientError.HTTPStatusCode()) - return - } - - // 500 Internal Server Error as a catch-all - http.Error(w, err.Error(), http.StatusInternalServerError) + // Error, write error message and status code + apierrors.ServeHTTP(w, r, err) } }