Skip to content
This repository has been archived by the owner on Dec 8, 2020. It is now read-only.

Commit

Permalink
Update: http/api.CORSBuilder: adds AllowOrigins for setting origins t…
Browse files Browse the repository at this point in the history
…hat are allowed
  • Loading branch information
kyleterry committed Jun 8, 2020
1 parent e7c2dd8 commit c10aefd
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 0 deletions.
56 changes: 56 additions & 0 deletions httputil/api/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ type corsHandler struct {
allowedHeaders map[corsMatchable]struct{}
allowedMethods map[corsMatchable]struct{}
allowedMethodsHeader string
allowedOrigins map[corsMatchable]struct{}
defaultAllowedOrigin string
}

func (ch *corsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -92,20 +94,42 @@ func (ch *corsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if len(allowedHeaders) > 0 {
w.Header().Set("access-control-allow-headers", strings.Join(allowedHeaders, ", "))
}

if len(ch.allowedOrigins) > 0 {
origin := r.Header.Get("origin")

if corsMatch(ch.allowedOrigins, origin) {
w.Header().Set("access-control-allow-origin", origin)
} else {
w.Header().Set("access-control-allow-origin", ch.defaultAllowedOrigin)
}
}
}

// CORSBuilder builds an http.Handler that can be used as a middleware to set
// CORS access control headers for relaxing same-origin browser policies.
type CORSBuilder struct {
allowedHeaders map[string]struct{}
allowedHeaderPrefixes map[string]struct{}
allowedMethods map[string]struct{}
allowedOrigins map[string]struct{}
defaultAllowedOrigin string
}

// AllowHeaderPrefix takes a header prefix and allows all headers that
// match the given prefix for requests.
//
// example AllowHeaderPrefix("example-") will allow a request with header
// Example-XYZ.
func (cb *CORSBuilder) AllowHeaderPrefix(prefix string) *CORSBuilder {
cb.allowedHeaderPrefixes[http.CanonicalHeaderKey(prefix)] = corsValue

return cb
}

// AllowHeaders takes a variadic of header strings to allow. It is similar
// to AllowHeaderPrefix, but it matches against the entire string instead
// of matching against a partial prefix.
func (cb *CORSBuilder) AllowHeaders(headers ...string) *CORSBuilder {
for _, header := range headers {
cb.allowedHeaders[http.CanonicalHeaderKey(header)] = corsValue
Expand All @@ -114,6 +138,7 @@ func (cb *CORSBuilder) AllowHeaders(headers ...string) *CORSBuilder {
return cb
}

// AllowMethods takes a variadic of http methods to allow.
func (cb *CORSBuilder) AllowMethods(methods ...string) *CORSBuilder {
for _, method := range methods {
cb.allowedMethods[strings.ToUpper(method)] = corsValue
Expand All @@ -122,9 +147,30 @@ func (cb *CORSBuilder) AllowMethods(methods ...string) *CORSBuilder {
return cb
}

// AllowOrigins takes a variadic of http origins to allow. The match is against
// the entire origin string and no patterns are allowed. The first one in the list
// is the default origin to return in the event the origin in the request isn't.
// There is no attempt to error on an origin that isn't in the list because this is
// the client's job. We simple return an origin that _is_ allowed and let the client
// block the request from happening.
func (cb *CORSBuilder) AllowOrigins(origins ...string) *CORSBuilder {
for _, origin := range origins {
cb.allowedOrigins[origin] = corsValue

if cb.defaultAllowedOrigin == "" {
cb.defaultAllowedOrigin = origin
}
}

return cb
}

// Build returns an http.Handler that can set Access-Control-Allow-* headers
// based on requests it receives.
func (cb *CORSBuilder) Build() http.Handler {
ch := &corsHandler{
allowedHeaders: make(map[corsMatchable]struct{}),
allowedOrigins: make(map[corsMatchable]struct{}),
}

for allowedHeader := range cb.allowedHeaders {
Expand Down Expand Up @@ -152,13 +198,23 @@ func (cb *CORSBuilder) Build() http.Handler {
ch.allowedMethodsHeader = strings.Join(allowedMethods, ", ")
}

if len(cb.allowedOrigins) > 0 {
for origin := range cb.allowedOrigins {
ch.allowedOrigins[corsMatchableString(origin)] = corsValue
}

ch.defaultAllowedOrigin = cb.defaultAllowedOrigin
}

return ch
}

// NewCORSBuilder returns a new CORSBuilder.
func NewCORSBuilder() *CORSBuilder {
return &CORSBuilder{
allowedHeaders: make(map[string]struct{}),
allowedHeaderPrefixes: make(map[string]struct{}),
allowedMethods: make(map[string]struct{}),
allowedOrigins: make(map[string]struct{}),
}
}
47 changes: 47 additions & 0 deletions httputil/api/cors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package api

import (
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/stretchr/testify/require"
)

func TestCORSBuilder(t *testing.T) {
handler := NewCORSBuilder().
AllowOrigins("http://example.com", "http://app.example.com").
AllowHeaderPrefix("horsehead-").
AllowHeaders("X-Custom-Header").Build()

req, err := http.NewRequest(http.MethodOptions, "http://example.com", nil)
require.NoError(t, err)
req.Header.Set("access-control-request-method", "POST")
req.Header.Set("access-control-request-headers", "Horsehead-Custom-Header, X-Custom-Header")
req.Header.Set("Origin", "http://app.example.com")

resp := httptest.NewRecorder()

handler.ServeHTTP(resp, req)
result := resp.Result()

require.Equal(t, http.StatusOK, result.StatusCode)
require.Equal(t, "http://app.example.com", result.Header.Get("Access-Control-Allow-Origin"))
require.Equal(t, "Horsehead-Custom-Header, X-Custom-Header", result.Header.Get("Access-Control-Allow-Headers"))
require.Equal(t, strings.Join(corsDefaultAllowedMethods, ", "), result.Header.Get("Access-Control-Allow-Methods"))

{
// a request that should fail
req, err := http.NewRequest(http.MethodOptions, "http://example.com", nil)
require.NoError(t, err)
req.Header.Set("access-control-request-method", "PUT")

resp := httptest.NewRecorder()

handler.ServeHTTP(resp, req)
result := resp.Result()

require.Equal(t, http.StatusMethodNotAllowed, result.StatusCode)
}
}

0 comments on commit c10aefd

Please sign in to comment.