Skip to content

Commit

Permalink
add cors headers to s3 media handler
Browse files Browse the repository at this point in the history
  • Loading branch information
or-else committed Jul 2, 2021
1 parent 0d7ecbd commit b568849
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 52 deletions.
12 changes: 8 additions & 4 deletions server/hdl_files.go
Expand Up @@ -73,8 +73,10 @@ func largeFileServe(wrt http.ResponseWriter, req *http.Request) {
return
}

for name, value := range headers {
wrt.Header().Set(name, value)
for name, values := range headers {
for _, value := range values {
wrt.Header().Add(name, value)
}
}

if statusCode != 0 {
Expand Down Expand Up @@ -172,8 +174,10 @@ func largeFileReceive(wrt http.ResponseWriter, req *http.Request) {
return
}

for name, value := range headers {
wrt.Header().Set(name, value)
for name, values := range headers {
for _, value := range values {
wrt.Header().Add(name, value)
}
}

if statusCode != 0 {
Expand Down
44 changes: 3 additions & 41 deletions server/media/fs/filesys.go
Expand Up @@ -59,34 +59,9 @@ func (fh *fshandler) Init(jsconf string) error {
}

// Headers is used for serving CORS headers.
func (fh *fshandler) Headers(req *http.Request, serve bool) (map[string]string, int, error) {
if len(fh.corsOrigins) == 0 {
// CORS not configured.
return nil, 0, nil
}

allowedOrigin := matchOrigin(fh.corsOrigins, req.Header.Get("Origin"))
if allowedOrigin == "" {
// CORS policy does not match the origin.
return nil, 0, nil
}

var statusCode int
if req.Method == http.MethodHead || req.Method == http.MethodOptions {
statusCode = http.StatusOK
}
var allowMethods string
if serve {
allowMethods = "GET,HEAD,OPTIONS"
} else {
allowMethods = "POST,PUT,HEAD,OPTIONS"
}
return map[string]string{
"Access-Control-Allow-Origin": allowedOrigin,
"Access-Control-Allow-Headers": "*",
"Access-Control-Allow-Methods": allowMethods,
"Access-Control-Max-Age": "86400",
}, statusCode, nil
func (fh *fshandler) Headers(req *http.Request, serve bool) (http.Header, int, error) {
header, status := media.CORSHandler(req, fh.corsOrigins, serve)
return header, status, nil
}

// Upload processes request for file upload. The file is given as io.Reader.
Expand Down Expand Up @@ -189,19 +164,6 @@ func (fh *fshandler) getFileRecord(fid types.Uid) (*types.FileDef, error) {
return fd, nil
}

func matchOrigin(allowed []string, origin string) string {
if allowed[0] == "*" {
return "*"
}

for _, val := range allowed {
if val == origin {
return origin
}
}

return ""
}
func init() {
store.RegisterMediaHandler(handlerName, &fshandler{})
}
82 changes: 81 additions & 1 deletion server/media/media.go
Expand Up @@ -25,7 +25,7 @@ type Handler interface {
// Headers checks if the handler wants to provide additional HTTP headers for the request.
// It could be CORS headers, redirect to serve files from another URL, cache-control headers.
// It returns headers as a map, HTTP status code to stop processing or 0 to continue, error.
Headers(req *http.Request, serve bool) (map[string]string, int, error)
Headers(req *http.Request, serve bool) (http.Header, int, error)

// Upload processes request for file upload.
Upload(fdef *types.FileDef, file io.ReadSeeker) (string, error)
Expand All @@ -50,3 +50,83 @@ func GetIdFromUrl(url, serveUrl string) types.Uid {

return types.ParseUid(strings.Split(fname, ".")[0])
}

// matchCORSOrigin compares origin from the HTTP request to a list of allowed origins.
func matchCORSOrigin(allowed []string, origin string) string {
if origin == "" {
// Request has no Origin header.
return ""
}

if len(allowed) == 0 {
// Not configured
return ""
}

if allowed[0] == "*" {
return "*"
}

for _, val := range allowed {
if val == origin {
return origin
}
}

return ""
}

func matchCORSMethod(allowMethods []string, method string) bool {
if method == "" {
// Request has no Method header.
return false
}

method = strings.ToUpper(method)
for _, mm := range allowMethods {
if mm == method {
return true
}
}

return false
}

// CORSHandler is the default preflight OPTIONS processor for use by media handlers.
func CORSHandler(req *http.Request, allowedOrigins []string, serve bool) (http.Header, int) {
if req.Method != http.MethodOptions {
// Not an OPTIONS request. No special handling for all other requests.
return nil, 0
}

headers := map[string][]string{
// Always add Vary because of possible intermediate caches.
"Vary": []string{"Origin", "Access-Control-Request-Method"},
}

allowedOrigin := matchCORSOrigin(allowedOrigins, req.Header.Get("Origin"))
if allowedOrigin == "" {
// CORS policy does not match the origin.
return headers, http.StatusOK
}

var allowMethods []string
if serve {
allowMethods = []string{http.MethodGet, http.MethodHead, http.MethodOptions}
} else {
allowMethods = []string{http.MethodPost, http.MethodPut, http.MethodHead, http.MethodOptions}
}

if !matchCORSMethod(allowMethods, req.Header.Get("Access-Control-Request-Method")) {
// CORS policy does not allow this method.
return headers, http.StatusOK
}

headers["Access-Control-Allow-Origin"] = []string{allowedOrigin}
headers["Access-Control-Allow-Headers"] = []string{"*"}
headers["Access-Control-Allow-Methods"] = []string{strings.Join(allowMethods, ",")}
headers["Access-Control-Max-Age"] = []string{"86400"}
headers["Access-Control-Allow-Credentials"] = []string{"true"}

return headers, http.StatusOK
}
16 changes: 10 additions & 6 deletions server/media/s3/s3.go
Expand Up @@ -151,11 +151,15 @@ func (ah *awshandler) Init(jsconf string) error {
}

// Headers redirects GET, HEAD requests to the AWS server.
func (ah *awshandler) Headers(req *http.Request, serve bool) (map[string]string, int, error) {
if req.Method == http.MethodPut || req.Method == http.MethodPost || req.Method == http.MethodOptions {
func (ah *awshandler) Headers(req *http.Request, serve bool) (http.Header, int, error) {
if req.Method == http.MethodPut || req.Method == http.MethodPost {
return nil, 0, nil
}

if headers, status := media.CORSHandler(req, ah.conf.CorsOrigins, serve); status != 0 {
return headers, status, nil
}

fid := ah.GetIdFromUrl(req.URL.String())
if fid.IsZero() {
return nil, 0, types.ErrNotFound
Expand Down Expand Up @@ -184,10 +188,10 @@ func (ah *awshandler) Headers(req *http.Request, serve bool) (map[string]string,
// Return presigned URL. The URL will stop working after a short period of time to prevent use of Tinode
// as a free file server.
url, err := awsReq.Presign(time.Second * presignDuration)
headers := map[string]string{
"Location": url,
"Content-Type": "application/json; charset=utf-8",
"Cache-Control": "no-cache, no-store, must-revalidate",
headers := map[string][]string{
"Location": []string{url},
"Content-Type": []string{"application/json; charset=utf-8"},
"Cache-Control": []string{"no-cache, no-store, must-revalidate"},
}
return headers, http.StatusTemporaryRedirect, err
}
Expand Down

0 comments on commit b568849

Please sign in to comment.