Skip to content

Commit

Permalink
Rework the order of ResponseWriter.Header and ResponseWriter.WriteHeader
Browse files Browse the repository at this point in the history
  • Loading branch information
parkr committed Apr 12, 2024
1 parent d529a96 commit 53f9bb6
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 15 deletions.
10 changes: 8 additions & 2 deletions cors/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,18 @@ type corsHandler struct {
// ServeHTTP adds CORS headers.
func (c corsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
// According to the documentation for ResponseWriter.WriteHeader and
// ResponseWriter.Header, modifications to the Header must happen BEFORE
// any call to WriteHeader. That means it should be, in order:
// 1. Write all your Headers
// 2. Call WriteHeader to set your status
// 3. Write to the body
c.addCORSHeaders(w, r)
w.WriteHeader(http.StatusNoContent)
return
}
c.next.ServeHTTP(w, r)
c.addCORSHeaders(w, r)
c.next.ServeHTTP(w, r)
}

func (c corsHandler) addCORSHeaders(w http.ResponseWriter, r *http.Request) {
Expand Down
3 changes: 2 additions & 1 deletion dnt/do_not_track.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ type dntMiddleware struct {

func (d dntMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if RequestsDoNotTrack(r) {
w.WriteHeader(http.StatusNoContent)
SetDoNotTrack(w)
// Note: All w.Header() modifications must be made BEFORE this call.
w.WriteHeader(http.StatusNoContent)
return
}
d.nextHandler.ServeHTTP(w, r)
Expand Down
5 changes: 3 additions & 2 deletions jsv1/version1.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,17 @@ const returnedJavaScript = "(function(){})();"
const lengthOfJavaScript = "17"

func Write(w http.ResponseWriter, code int) {
w.WriteHeader(code)
w.Header().Set("Content-Type", "application/javascript")
w.Header().Set("Content-Length", lengthOfJavaScript)

fmt.Fprintf(w, returnedJavaScript)
}

func Error(w http.ResponseWriter, code int, err string) {
w.WriteHeader(code)
content := fmt.Sprintf(`(function(){console.error("%s")})();`, err)
w.Header().Set("Content-Type", "application/javascript")
w.Header().Set("Content-Length", strconv.Itoa(len(content)))
// Note: All w.Header() modifications must be made BEFORE this call.
w.WriteHeader(code)
fmt.Fprintf(w, content)
}
3 changes: 2 additions & 1 deletion jsv2/version2.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ function logVisit(document) {
`

func Write(w http.ResponseWriter, code int) {
w.WriteHeader(code)
w.Header().Set("Content-Type", "application/javascript")
w.Header().Set("Content-Length", strconv.Itoa(len(returnedJavaScript)))
// Note: All w.Header() modifications must be made BEFORE this call.
w.WriteHeader(code)
fmt.Fprintf(w, returnedJavaScript)
}
7 changes: 0 additions & 7 deletions ping.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package ping

import (
"encoding/json"
"errors"
"fmt"
"log"
Expand Down Expand Up @@ -36,12 +35,6 @@ func parseReferer(referer string) (*url.URL, error) {
return url.Parse(referer)
}

func jsonError(w http.ResponseWriter, statusCode int, message string) {
w.WriteHeader(statusCode)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{"error": message})
}

// ping routes to pingv1 or pingv2 depending on the version code in the form.
func ping(w http.ResponseWriter, r *http.Request) {
version := r.FormValue("v")
Expand Down
10 changes: 8 additions & 2 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package ping

import (
"encoding/json"
"fmt"
"net/http"
"strings"
)
Expand All @@ -11,12 +10,19 @@ func writeJsonResponse(w http.ResponseWriter, input interface{}) {
w.Header().Set("Content-Type", "application/json")
data, err := json.Marshal(input)
if err != nil {
fmt.Fprintf(w, `{"error":"json, `+err.Error()+`"}`)
jsonError(w, http.StatusInternalServerError, err.Error())
} else {
w.Write(data)
}
}

func jsonError(w http.ResponseWriter, statusCode int, message string) {
w.Header().Set("Content-Type", "application/json")
// Note: All w.Header() modifications must be made BEFORE this call.
w.WriteHeader(statusCode)
json.NewEncoder(w).Encode(map[string]string{"error": message})
}

func sanitizeUserInput(input string) string {
escapedInput := strings.ReplaceAll(input, "\n", "")
escapedInput = strings.ReplaceAll(escapedInput, "\r", "")
Expand Down

0 comments on commit 53f9bb6

Please sign in to comment.