Skip to content

Commit

Permalink
Merge pull request #149 from resgateio/feature/gh-141-cors-support
Browse files Browse the repository at this point in the history
Feature/gh 141 cors support
  • Loading branch information
jirenius committed Mar 18, 2020
2 parents f47df28 + 992325d commit 91e0045
Show file tree
Hide file tree
Showing 14 changed files with 489 additions and 28 deletions.
23 changes: 23 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io/ioutil"
"os"
"os/signal"
"strings"
"syscall"
"time"

Expand Down Expand Up @@ -43,6 +44,7 @@ Server Options:
--tlskey <file> Private key for HTTP server certificate
--apiencoding <type> Encoding for web resources: json, jsonflat (default: json)
--creds <file> NATS User Credentials file
--alloworigin <origin> Allowed origin(s): *, or <scheme>://<hostname>[:<port>] (default: *)
-c, --config <file> Configuration file
Logging Options:
Expand All @@ -67,6 +69,22 @@ type Config struct {
server.Config
}

// StringSlice is a slice of strings implementing the flag.Value interface.
type StringSlice []string

func (s *StringSlice) String() string {
if s == nil {
return ""
}
return strings.Join(*s, ";")
}

// Set adds a value to the slice.
func (s *StringSlice) Set(v string) error {
*s = append(*s, v)
return nil
}

// SetDefault sets the default values
func (c *Config) SetDefault() {
if c.NatsURL == "" {
Expand All @@ -90,6 +108,7 @@ func (c *Config) Init(fs *flag.FlagSet, args []string) {
addr string
natsCreds string
debugTrace bool
allowOrigin StringSlice
)

fs.BoolVar(&showHelp, "h", false, "Show this message.")
Expand All @@ -115,6 +134,7 @@ func (c *Config) Init(fs *flag.FlagSet, args []string) {
fs.IntVar(&c.RequestTimeout, "r", 0, "Timeout in milliseconds for NATS requests.")
fs.IntVar(&c.RequestTimeout, "reqtimeout", 0, "Timeout in milliseconds for NATS requests.")
fs.StringVar(&natsCreds, "creds", "", "NATS User Credentials file.")
fs.Var(&allowOrigin, "alloworigin", "Allowed origin(s) for CORS.")
fs.BoolVar(&c.Debug, "D", false, "Enable debugging output.")
fs.BoolVar(&c.Debug, "debug", false, "Enable debugging output.")
fs.BoolVar(&c.Trace, "V", false, "Enable trace logging.")
Expand Down Expand Up @@ -185,6 +205,9 @@ func (c *Config) Init(fs *flag.FlagSet, args []string) {
} else {
c.NatsCreds = &natsCreds
}
case "alloworigin":
str := allowOrigin.String()
c.AllowOrigin = &str
case "i":
fallthrough
case "addr":
Expand Down
42 changes: 42 additions & 0 deletions server/apiHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"io/ioutil"
"mime"
"net/http"
"strings"

Expand All @@ -21,10 +22,49 @@ func (s *Service) initAPIHandler() error {
return fmt.Errorf("invalid apiEncoding setting (%s) - available encodings: %s", s.cfg.APIEncoding, strings.Join(keys, ", "))
}
s.enc = f(s.cfg)
mimetype, _, err := mime.ParseMediaType(s.enc.ContentType())
s.mimetype = mimetype
return err
}

// setCommonHeaders sets common headers such as Access-Control-*.
// It returns error if the origin header does not match any allowed origin.
func (s *Service) setCommonHeaders(w http.ResponseWriter, r *http.Request) error {
switch s.cfg.allowOrigin[0] {
case "*":
w.Header().Set("Access-Control-Allow-Origin", "*")

default:
// CORS validation
origin := r.Header["Origin"]
// If no Origin header is set, or the value is null, we can allow access
// as it is not coming from a CORS enabled browser.
if len(origin) > 0 && origin[0] != "null" {
if matchesOrigins(s.cfg.allowOrigin, origin[0]) {
w.Header().Set("Access-Control-Allow-Origin", origin[0])
w.Header().Set("Vary", "Origin")
} else {
// No matching origin
w.Header().Set("Access-Control-Allow-Origin", s.cfg.allowOrigin[0])
w.Header().Set("Vary", "Origin")
return reserr.ErrForbiddenOrigin
}
}
}
return nil
}

func (s *Service) apiHandler(w http.ResponseWriter, r *http.Request) {
err := s.setCommonHeaders(w, r)
if r.Method == "OPTIONS" {
w.Header().Set("Access-Control-Allow-Methods", s.cfg.allowMethods)
return
}
if err != nil {
httpError(w, err, s.enc)
return
}

path := r.URL.RawPath
if path == "" {
path = r.URL.Path
Expand Down Expand Up @@ -167,6 +207,8 @@ func httpError(w http.ResponseWriter, err error, enc APIEncoder) {
code = http.StatusInternalServerError
case reserr.CodeServiceUnavailable:
code = http.StatusServiceUnavailable
case reserr.CodeForbidden:
code = http.StatusForbidden
default:
code = http.StatusBadRequest
}
Expand Down
90 changes: 88 additions & 2 deletions server/config.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package server

import (
"errors"
"fmt"
"net"
"net/url"
"sort"
"strings"
"unicode/utf8"

"github.com/resgateio/resgate/server/codec"
)
Expand All @@ -16,6 +20,7 @@ type Config struct {
APIPath string `json:"apiPath"`
APIEncoding string `json:"apiEncoding"`
HeaderAuth *string `json:"headerAuth"`
AllowOrigin *string `json:"allowOrigin"`

TLS bool `json:"tls"`
TLSCert string `json:"certFile"`
Expand All @@ -29,6 +34,8 @@ type Config struct {
netAddr string
headerAuthRID string
headerAuthAction string
allowOrigin []string
allowMethods string
}

// SetDefault sets the default values
Expand All @@ -49,6 +56,10 @@ func (c *Config) SetDefault() {
if c.APIEncoding == "" {
c.APIEncoding = DefaultAPIEncoding
}
if c.AllowOrigin == nil {
origin := "*"
c.AllowOrigin = &origin
}
}

// prepare sets the unexported values
Expand Down Expand Up @@ -79,7 +90,7 @@ func (c *Config) prepare() error {
c.netAddr = ip.String()
}
} else {
return fmt.Errorf("invalid addr setting (%s) - must be a valid IPv4 or IPv6 address", s)
return fmt.Errorf("invalid addr setting (%s)\n\tmust be a valid IPv4 or IPv6 address", s)
}
}
} else {
Expand All @@ -94,9 +105,22 @@ func (c *Config) prepare() error {
c.headerAuthRID = s[:idx]
c.headerAuthAction = s[idx+1:]
} else {
return fmt.Errorf("invalid headerAuth setting (%s) - must be a valid resource method", s)
return fmt.Errorf("invalid headerAuth setting (%s)\n\tmust be a valid resource method", s)
}
}

if c.AllowOrigin != nil {
c.allowOrigin = strings.Split(*c.AllowOrigin, ";")
if err := validateAllowOrigin(c.allowOrigin); err != nil {
return fmt.Errorf("invalid allowOrigin setting (%s)\n\t%s\n\tvalid options are *, or a list of semi-colon separated origins", *c.AllowOrigin, err)
}
sort.Strings(c.allowOrigin)
} else {
c.allowOrigin = []string{"*"}
}

c.allowMethods = "GET, POST, OPTIONS"

if c.WSPath == "" {
c.WSPath = "/"
}
Expand All @@ -106,3 +130,65 @@ func (c *Config) prepare() error {

return nil
}

func validateAllowOrigin(s []string) error {
for i, o := range s {
o = toLowerASCII(o)
s[i] = o
if o == "*" {
if len(s) > 1 {
return fmt.Errorf("'%s' must not be used together with other origin settings", o)
}
} else {
if o == "" {
return errors.New("origin must not be empty")
}
u, err := url.Parse(o)
if err != nil || u.Scheme == "" || u.Host == "" || u.Opaque != "" || u.User != nil || u.Path != "" || len(u.Query()) > 0 || u.Fragment != "" {
return fmt.Errorf("'%s' doesn't match <scheme>://<hostname>[:<port>]", o)
}
}
}
return nil
}

// toLowerASCII converts only A-Z to lower case in a string
func toLowerASCII(s string) string {
var b strings.Builder
b.Grow(len(s))
for i := 0; i < len(s); i++ {
c := s[i]
if 'A' <= c && c <= 'Z' {
c += 'a' - 'A'
}
b.WriteByte(c)
}
return b.String()
}

func matchesOrigins(os []string, o string) bool {
origin:
for _, s := range os {
t := o
for s != "" && t != "" {
sr, size := utf8.DecodeRuneInString(s)
s = s[size:]
tr, size := utf8.DecodeRuneInString(t)
t = t[size:]
if sr == tr {
continue
}
// Lowercase A-Z. Should already be done for origins.
if 'A' <= tr && tr <= 'Z' {
tr = tr + 'a' - 'A'
}
if sr != tr {
continue origin
}
}
if s == t {
return true
}
}
return false
}
70 changes: 64 additions & 6 deletions server/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ func TestConfigPrepare(t *testing.T) {
ipv6Addr := "::1"
invalidAddr := "127.0.0"
invalidHeaderAuth := "test"
allowOriginAll := "*"
allowOriginSingle := "http://resgate.io"
allowOriginMultiple := "http://localhost;http://resgate.io"
allowOriginInvalidEmpty := ""
allowOriginInvalidEmptyOrigin := ";http://localhost"
allowOriginInvalidMultipleAll := "http://localhost;*"
allowOriginInvalidMultipleSame := "http://localhost;*"
allowOriginInvalidOrigin := "http://this.is/invalid"
defaultCfg := Config{}
defaultCfg.SetDefault()

Expand All @@ -34,14 +42,24 @@ func TestConfigPrepare(t *testing.T) {
Expected Config
PrepareError bool
}{
{defaultCfg, Config{Addr: &defaultAddr, Port: 8080, WSPath: "/", APIPath: "/api/", APIEncoding: "json", scheme: "http", netAddr: "0.0.0.0:8080"}, false},
{Config{WSPath: "/"}, Config{Addr: nil, Port: 80, WSPath: "/", APIPath: "/", scheme: "http", netAddr: "0.0.0.0:80"}, false},
{Config{WSPath: "/"}, Config{Addr: nil, Port: 80, WSPath: "/", APIPath: "/", scheme: "http", netAddr: "0.0.0.0:80"}, false},
{Config{Addr: &emptyAddr, WSPath: "/"}, Config{Addr: &emptyAddr, Port: 80, WSPath: "/", APIPath: "/", scheme: "http", netAddr: ":80"}, false},
{Config{Addr: &localAddr, WSPath: "/"}, Config{Addr: &localAddr, Port: 80, WSPath: "/", APIPath: "/", scheme: "http", netAddr: "127.0.0.1:80"}, false},
{Config{Addr: &ipv6Addr, WSPath: "/"}, Config{Addr: &ipv6Addr, Port: 80, WSPath: "/", APIPath: "/", scheme: "http", netAddr: "[::1]:80"}, false},
// Valid config
{defaultCfg, Config{Addr: &defaultAddr, Port: 8080, WSPath: "/", APIPath: "/api/", APIEncoding: "json", scheme: "http", netAddr: "0.0.0.0:8080", allowOrigin: []string{"*"}, allowMethods: "GET, POST, OPTIONS"}, false},
{Config{WSPath: "/"}, Config{Addr: nil, Port: 80, WSPath: "/", APIPath: "/", scheme: "http", netAddr: "0.0.0.0:80", allowOrigin: []string{"*"}, allowMethods: "GET, POST, OPTIONS"}, false},
{Config{Addr: &emptyAddr, WSPath: "/"}, Config{Addr: &emptyAddr, Port: 80, WSPath: "/", APIPath: "/", scheme: "http", netAddr: ":80", allowOrigin: []string{"*"}, allowMethods: "GET, POST, OPTIONS"}, false},
{Config{Addr: &localAddr, WSPath: "/"}, Config{Addr: &localAddr, Port: 80, WSPath: "/", APIPath: "/", scheme: "http", netAddr: "127.0.0.1:80", allowOrigin: []string{"*"}, allowMethods: "GET, POST, OPTIONS"}, false},
{Config{Addr: &ipv6Addr, WSPath: "/"}, Config{Addr: &ipv6Addr, Port: 80, WSPath: "/", APIPath: "/", scheme: "http", netAddr: "[::1]:80", allowOrigin: []string{"*"}, allowMethods: "GET, POST, OPTIONS"}, false},
{Config{AllowOrigin: &allowOriginAll, WSPath: "/"}, Config{Addr: nil, Port: 80, WSPath: "/", APIPath: "/", scheme: "http", netAddr: "0.0.0.0:80", allowOrigin: []string{"*"}, allowMethods: "GET, POST, OPTIONS"}, false},
{Config{AllowOrigin: &allowOriginSingle, WSPath: "/"}, Config{Addr: nil, Port: 80, WSPath: "/", APIPath: "/", scheme: "http", netAddr: "0.0.0.0:80", allowOrigin: []string{"http://resgate.io"}, allowMethods: "GET, POST, OPTIONS"}, false},
{Config{AllowOrigin: &allowOriginMultiple, WSPath: "/"}, Config{Addr: nil, Port: 80, WSPath: "/", APIPath: "/", scheme: "http", netAddr: "0.0.0.0:80", allowOrigin: []string{"http://localhost", "http://resgate.io"}, allowMethods: "GET, POST, OPTIONS"}, false},

// Invalid config
{Config{Addr: &invalidAddr, WSPath: "/"}, Config{}, true},
{Config{HeaderAuth: &invalidHeaderAuth, WSPath: "/"}, Config{}, true},
{Config{AllowOrigin: &allowOriginInvalidEmpty, WSPath: "/"}, Config{}, true},
{Config{AllowOrigin: &allowOriginInvalidEmptyOrigin, WSPath: "/"}, Config{}, true},
{Config{AllowOrigin: &allowOriginInvalidMultipleAll, WSPath: "/"}, Config{}, true},
{Config{AllowOrigin: &allowOriginInvalidMultipleSame, WSPath: "/"}, Config{}, true},
{Config{AllowOrigin: &allowOriginInvalidOrigin, WSPath: "/"}, Config{}, true},
}

for i, r := range tbl {
Expand Down Expand Up @@ -90,6 +108,19 @@ func TestConfigPrepare(t *testing.T) {
t.Fatalf("expected headerAuthRID to be:\n%s\nbut got:\n%s\nin test %d", r.Expected.headerAuthRID, cfg.headerAuthRID, i+1)
}

if len(cfg.allowOrigin) != len(r.Expected.allowOrigin) {
t.Fatalf("expected allowOrigin to be:\n%+v\nbut got:\n%+v\nin test %d", r.Expected.allowOrigin, cfg.allowOrigin, i+1)
}
for i, origin := range cfg.allowOrigin {
if origin != r.Expected.allowOrigin[i] {
t.Fatalf("expected allowOrigin to be:\n%+v\nbut got:\n%+v\nin test %d", r.Expected.allowOrigin, cfg.allowOrigin, i+1)
}
}

if cfg.allowMethods != r.Expected.allowMethods {
t.Fatalf("expected allowMethods to be:\n%s\nbut got:\n%s\nin test %d", r.Expected.allowMethods, cfg.allowMethods, i+1)
}

compareStringPtr(t, "HeaderAuth", cfg.HeaderAuth, r.Expected.HeaderAuth, i)

}
Expand Down Expand Up @@ -135,3 +166,30 @@ func TestVersionMatchesTag(t *testing.T) {
t.Fatalf("Expected version %+v, got %+v", Version, tag[1:])
}
}

func TestMatchesOrigins(t *testing.T) {
tbl := []struct {
AllowedOrigins []string
Origin string
Expected bool
}{
{[]string{"http://localhost"}, "http://localhost", true},
{[]string{"https://resgate.io"}, "https://resgate.io", true},
{[]string{"https://resgate.io"}, "https://Resgate.IO", true},
{[]string{"http://localhost", "https://resgate.io"}, "http://localhost", true},
{[]string{"http://localhost", "https://resgate.io"}, "https://resgate.io", true},
{[]string{"http://localhost", "https://resgate.io"}, "https://Resgate.IO", true},
{[]string{"http://localhost", "https://resgate.io", "http://resgate.io"}, "http://Localhost", true},
{[]string{"http://localhost", "https://resgate.io", "http://resgate.io"}, "https://Resgate.io", true},
{[]string{"http://localhost", "https://resgate.io", "http://resgate.io"}, "http://resgate.IO", true},
{[]string{"https://resgate.io"}, "http://resgate.io", false},
{[]string{"http://localhost", "https://resgate.io"}, "http://resgate.io", false},
{[]string{"http://localhost", "https://resgate.io", "http://resgate.io"}, "http://localhost/", false},
}

for i, r := range tbl {
if matchesOrigins(r.AllowedOrigins, r.Origin) != r.Expected {
t.Fatalf("expected matchesOrigins to return %#v\n\tmatchesOrigins(%#v, %#v)\n\tin test #%d", r.Expected, r.AllowedOrigins, r.Origin, i+1)
}
}
}
2 changes: 2 additions & 0 deletions server/reserr/reserr.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ const (
CodeBadRequest = "system.badRequest"
CodeMethodNotAllowed = "system.methodNotAllowed"
CodeServiceUnavailable = "system.serviceUnavailable"
CodeForbidden = "system.forbidden"
)

// Pre-defined RES errors
Expand All @@ -71,4 +72,5 @@ var (
ErrBadRequest = &Error{Code: CodeBadRequest, Message: "Bad request"}
ErrMethodNotAllowed = &Error{Code: CodeMethodNotAllowed, Message: "Method not allowed"}
ErrServiceUnavailable = &Error{Code: CodeServiceUnavailable, Message: "Service unavailable"}
ErrForbiddenOrigin = &Error{Code: CodeForbidden, Message: "Forbidden origin"}
)
5 changes: 3 additions & 2 deletions server/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ type Service struct {
cache *rescache.Cache

// httpServer
h *http.Server
enc APIEncoder
h *http.Server
enc APIEncoder
mimetype string

// wsListener/wsConn
upgrader websocket.Upgrader
Expand Down

0 comments on commit 91e0045

Please sign in to comment.