Skip to content

Commit

Permalink
Improve error handling (#61)
Browse files Browse the repository at this point in the history
* Add 401, 403 error types

* Migrate HTTP error serving to apierrors package

For better separation of concerns between the router and the API error
handling/HTTP error generation.

* Add TODO

* Use abstraction for auth middleware error handling

Prevents code duplication and improves separation of concerns. Also
paves the way for easier future changes to error messages when they
become more structured.

* Update users, stories controller error handling

Ensures consistency and improves correctness.

* Fix error message typo
  • Loading branch information
RichDom2185 authored Jul 22, 2023
1 parent d918637 commit c45af08
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 20 deletions.
9 changes: 6 additions & 3 deletions controller/stories/stories.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ func HandleRead(w http.ResponseWriter, r *http.Request) error {
storyIDStr := chi.URLParam(r, "storyID")
storyID, err := strconv.Atoi(storyIDStr)
if err != nil {
http.Error(w, "Invalid storyID", http.StatusBadRequest)
return err
return apierrors.ClientBadRequestError{
Message: fmt.Sprintf("Invalid storyID: %v", err),
}
}

// Get DB instance
Expand All @@ -66,7 +67,9 @@ func HandleCreate(w http.ResponseWriter, r *http.Request) error {
e, ok := err.(*json.UnmarshalTypeError)
if !ok {
logrus.Error(err)
return err
return apierrors.ClientBadRequestError{
Message: fmt.Sprintf("Bad JSON parsing: %v", err),
}
}

// TODO: Investigate if we should use errors.Wrap instead
Expand Down
16 changes: 11 additions & 5 deletions controller/users/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ func HandleRead(w http.ResponseWriter, r *http.Request) error {
userIDStr := chi.URLParam(r, "userID")
userID, err := strconv.Atoi(userIDStr)
if err != nil {
http.Error(w, "Invalid userID", http.StatusBadRequest)
return err
return apierrors.ClientBadRequestError{
Message: fmt.Sprintf("Invalid userID: %v", err),
}
}

// Get DB instance
Expand All @@ -66,18 +67,23 @@ func HandleCreate(w http.ResponseWriter, r *http.Request) error {
e, ok := err.(*json.UnmarshalTypeError)
if !ok {
logrus.Error(err)
return err
return apierrors.ClientBadRequestError{
Message: fmt.Sprintf("Bad JSON parsing: %v", err),
}
}

// TODO: Investigate if we should use errors.Wrap instead
return apierrors.ClientBadRequestError{
return apierrors.ClientUnprocessableEntityError{
Message: fmt.Sprintf("Invalid JSON format: %s should be a %s.", e.Field, e.Type),
}
}

err := params.Validate()
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
logrus.Error(err)
return apierrors.ClientUnprocessableEntityError{
Message: fmt.Sprintf("JSON validation failed: %v", err),
}
}

userModel := *params.ToModel()
Expand Down
10 changes: 8 additions & 2 deletions internal/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/source-academy/stories-backend/internal/config"
apierrors "github.com/source-academy/stories-backend/internal/errors"
envutils "github.com/source-academy/stories-backend/internal/utils/env"
)

Expand All @@ -22,6 +23,7 @@ func MakeMiddlewareFrom(conf *config.Config) func(http.Handler) http.Handler {
key, ok := keySet.Key(0)
if !ok {
// Block all access if JWKS source is down, since we can't verify JWTs
// TODO: Investigate if 500 is appropriate
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
Expand All @@ -34,15 +36,19 @@ func MakeMiddlewareFrom(conf *config.Config) func(http.Handler) http.Handler {
// Get JWT from request
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
apierrors.ServeHTTP(w, r, apierrors.ClientUnauthorizedError{
Message: "Missing Authorization header",
})
return
}

// Verify JWT
toParse := authHeader[len("Bearer "):]
token, err := jwt.ParseString(toParse, jwt.WithKey(jwa.RS256, key))
if err != nil {
fmt.Printf("Failed to verify JWS: %s\n", err)
apierrors.ServeHTTP(w, r, apierrors.ClientForbiddenError{
Message: fmt.Sprintf("Failed to verify JWT: %s\n", err),
})
return
}

Expand Down
17 changes: 17 additions & 0 deletions internal/errors/401.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package apierrors

import (
"net/http"
)

type ClientUnauthorizedError struct {
Message string
}

func (e ClientUnauthorizedError) Error() string {
return e.Message
}

func (e ClientUnauthorizedError) HTTPStatusCode() int {
return http.StatusUnauthorized
}
17 changes: 17 additions & 0 deletions internal/errors/403.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package apierrors

import (
"net/http"
)

type ClientForbiddenError struct {
Message string
}

func (e ClientForbiddenError) Error() string {
return e.Message
}

func (e ClientForbiddenError) HTTPStatusCode() int {
return http.StatusForbidden
}
17 changes: 17 additions & 0 deletions internal/errors/errors.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,25 @@
package apierrors

import (
"errors"
"net/http"
)

// ClientError is an interface for errors that should be returned to the client.
// They generally start with a 4xx HTTP status code.
type ClientError interface {
error
HTTPStatusCode() int
}

func ServeHTTP(w http.ResponseWriter, r *http.Request, err error) {
var clientError ClientError
if errors.As(err, &clientError) {
// Client error (status 4xx), write error message and status code
http.Error(w, clientError.Error(), clientError.HTTPStatusCode())
return
}

// 500 Internal Server Error as a catch-all
http.Error(w, err.Error(), http.StatusInternalServerError)
}
12 changes: 2 additions & 10 deletions internal/router/errors.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package router

import (
"errors"
"net/http"

apierrors "github.com/source-academy/stories-backend/internal/errors"
Expand All @@ -15,14 +14,7 @@ func handleAPIError(handler func(w http.ResponseWriter, r *http.Request) error)
return
}

var clientError apierrors.ClientError
if errors.As(err, &clientError) {
// Client error (status 4xx), write error message and status code
http.Error(w, clientError.Error(), clientError.HTTPStatusCode())
return
}

// 500 Internal Server Error as a catch-all
http.Error(w, err.Error(), http.StatusInternalServerError)
// Error, write error message and status code
apierrors.ServeHTTP(w, r, err)
}
}

0 comments on commit c45af08

Please sign in to comment.