Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CSRF Support #606

Merged
merged 9 commits into from Oct 24, 2014
116 changes: 116 additions & 0 deletions modules/csrf/app/csrf.go
@@ -0,0 +1,116 @@
package csrf

import (
"crypto/rand"
"crypto/subtle"
"encoding/hex"
"html/template"
"io"
"math"
"net/url"

"github.com/revel/revel"
)

// allowMethods are HTTP methods that do NOT require a token
var allowedMethods = map[string]bool{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A comment here would be useful to explain that allowedMethods don't require a token.

"GET": true,
"HEAD": true,
"OPTIONS": true,
"TRACE": true,
}

func RandomString(length int) (string, error) {
buffer := make([]byte, int(math.Ceil(float64(length)/2)))
if _, err := io.ReadFull(rand.Reader, buffer); err != nil {
return "", nil
}
str := hex.EncodeToString(buffer)
return str[:length], nil
}

func RefreshToken(c *revel.Controller) {
token, err := RandomString(64)
if err != nil {
panic(err)
}
c.Session["csrf_token"] = token
}

func CsrfFilter(c *revel.Controller, fc []revel.Filter) {
token, foundToken := c.Session["csrf_token"]

if !foundToken {
RefreshToken(c)
}

referer, refErr := url.Parse(c.Request.Header.Get("Referer"))
isSameOrigin := sameOrigin(c.Request.URL, referer)

// If the Request method isn't in the white listed methods
if !allowedMethods[c.Request.Method] && !IsExempt(c) {
// Token wasn't present at all
if !foundToken {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, can foundToken be ever false on line 53? Haven't we already checked it on line 43 and called RefreshToken(c)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AnonX yes, if the token wasn't sent in the form (right?).

c.Result = c.Forbidden("REVEL CSRF: Session token missing.")
return
}

// Referer header is invalid
if refErr != nil {
c.Result = c.Forbidden("REVEL CSRF: HTTP Referer malformed.")
return
}

// Same origin
if !isSameOrigin {
c.Result = c.Forbidden("REVEL CSRF: Same origin mismatch.")
return
}

var requestToken string
// First check for token in post data
if c.Request.Method == "POST" {
requestToken = c.Request.FormValue("csrftoken")
}

// Then check for token in custom headers, as with AJAX
if requestToken == "" {
requestToken = c.Request.Header.Get("X-CSRFToken")
}

if requestToken == "" || !compareToken(requestToken, token) {
c.Result = c.Forbidden("REVEL CSRF: Invalid token.")
return
}
}

fc[0](c, fc[1:])

// Only add token to RenderArgs if the request is: not AJAX, not missing referer header, and is same origin.
if c.Request.Header.Get("X-CSRFToken") == "" && isSameOrigin {
c.RenderArgs["_csrftoken"] = token
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At this point the return from function sameOrigin must be true because you checked for it being false on line 69, you should be able to remove the criteria from the if statement

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thats only if its not a white listed method. If someone does an AJAX request cross-domain using a safe method (like GET) we need to check it there too. I could possibly pull the check up outside the if statement though.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Look at the if statement on line 69 again if !sameOrigin(c.Request.URL, referrer) {... return}, so if sameOrigin returns false then the code returns. It will never reach line 94 to call the same function sameOrigin(c.Request.URL, referrer) which must be true if it got this far

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but line 69 happens only when the condition on line 55 is true (only if the HTTP method is not safe, like a POST). The check on line 94 always gets called when its either a safe HTTP method (like a GET request), OR the unsafe request passed the previous tests inside the line 55 if statement. The reason the sameOrigin() call is on line 94 is to ensure that tokens aren't leaked in "safe" (like an HTTP GET) cross origin requests, for example if someone built an API that works cross origin.

Like I mentioned though, since the condition is always needed at some point in the filter, it could be move above line 55. I'll do that.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quite right, harder to visual the code on a screen without bracket highlights..

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, I was about to pull open my editor because these diff views are a bit lacking in legibility.

}
}

func compareToken(requestToken, token string) bool {
// ConstantTimeCompare will panic if the []byte aren't the same length
if len(requestToken) != len(token) {
return false
}
return subtle.ConstantTimeCompare([]byte(requestToken), []byte(token)) == 1
}

// Validates same origin policy
func sameOrigin(u1, u2 *url.URL) bool {
return u1.Scheme == u2.Scheme && u1.Host == u2.Host
}

func init() {
revel.TemplateFuncs["csrftoken"] = func(renderArgs map[string]interface{}) template.HTML {
if tokenFunc, ok := renderArgs["_csrftoken"]; !ok {
panic("REVEL CSRF: _csrftoken missing from RenderArgs.")
} else {
return template.HTML(tokenFunc.(func() string)())
}
}
}
169 changes: 169 additions & 0 deletions modules/csrf/app/csrf_test.go
@@ -0,0 +1,169 @@
package csrf

import (
"bytes"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"testing"

"github.com/revel/revel"
)

var testFilters = []revel.Filter{
CsrfFilter,
func(c *revel.Controller, fc []revel.Filter) {
c.RenderHtml("{{ csrftoken . }}")
},
}

func TestTokenInSession(t *testing.T) {
resp := httptest.NewRecorder()
getRequest, _ := http.NewRequest("GET", "http://www.example.com/", nil)
c := revel.NewController(revel.NewRequest(getRequest), revel.NewResponse(resp))
c.Session = make(revel.Session)

testFilters[0](c, testFilters)

if _, ok := c.Session["csrf_token"]; !ok {
t.Fatal("token should be present in session")
}
}

func TestPostWithoutToken(t *testing.T) {
resp := httptest.NewRecorder()
postRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil)
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp))
c.Session = make(revel.Session)

testFilters[0](c, testFilters)

if c.Response.Status != 403 {
t.Fatal("post without token should be forbidden")
}
}

func TestNoReferrer(t *testing.T) {
resp := httptest.NewRecorder()
postRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil)

c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp))
c.Session = make(revel.Session)

RefreshToken(c)
token := c.Session["csrf_token"]

// make a new request with the token
data := url.Values{}
data.Set("csrftoken", token)
formPostRequest, _ := http.NewRequest("POST", "http://www.example.com/", bytes.NewBufferString(data.Encode()))
formPostRequest.Header.Add("Content-Type", "application/x-www-form-urlencoded")
formPostRequest.Header.Add("Content-Length", strconv.Itoa(len(data.Encode())))

// and replace the old request
c.Request = revel.NewRequest(formPostRequest)

testFilters[0](c, testFilters)

if c.Response.Status != 403 {
t.Fatal("post without referer should be forbidden")
}
}

func TestRefererHttps(t *testing.T) {
resp := httptest.NewRecorder()
postRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil)
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp))

c.Session = make(revel.Session)

RefreshToken(c)
token := c.Session["csrf_token"]

// make a new request with the token
data := url.Values{}
data.Set("csrftoken", token)
formPostRequest, _ := http.NewRequest("POST", "https://www.example.com/", bytes.NewBufferString(data.Encode()))
formPostRequest.Header.Add("Content-Type", "application/x-www-form-urlencoded")
formPostRequest.Header.Add("Content-Length", strconv.Itoa(len(data.Encode())))
formPostRequest.Header.Add("Referer", "http://www.example.com/")

// and replace the old request
c.Request = revel.NewRequest(formPostRequest)

testFilters[0](c, testFilters)

if c.Response.Status != 403 {
t.Fatal("posts to https should have an https referer")
}
}

func TestHeaderWithToken(t *testing.T) {
resp := httptest.NewRecorder()
postRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil)
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp))

c.Session = make(revel.Session)

RefreshToken(c)
token := c.Session["csrf_token"]

// make a new request with the token
formPostRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil)
formPostRequest.Header.Add("X-CSRFToken", token)
formPostRequest.Header.Add("Referer", "http://www.example.com/")

// and replace the old request
c.Request = revel.NewRequest(formPostRequest)

testFilters[0](c, testFilters)

if c.Response.Status == 403 {
t.Fatal("post with http header token should be allowed")
}
}

func TestFormPostWithToken(t *testing.T) {
resp := httptest.NewRecorder()
postRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil)
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp))

c.Session = make(revel.Session)

RefreshToken(c)
token := c.Session["csrf_token"]

// make a new request with the token
data := url.Values{}
data.Set("csrftoken", token)
formPostRequest, _ := http.NewRequest("POST", "http://www.example.com/", bytes.NewBufferString(data.Encode()))
formPostRequest.Header.Add("Content-Type", "application/x-www-form-urlencoded")
formPostRequest.Header.Add("Content-Length", strconv.Itoa(len(data.Encode())))
formPostRequest.Header.Add("Referer", "http://www.example.com/")

// and replace the old request
c.Request = revel.NewRequest(formPostRequest)

testFilters[0](c, testFilters)

if c.Response.Status == 403 {
t.Fatal("form post with token should be allowed")
}
}

func TestNoTokenInArgsWhenCORs(t *testing.T) {
resp := httptest.NewRecorder()

getRequest, _ := http.NewRequest("GET", "http://www.example1.com/", nil)
getRequest.Header.Add("Referer", "http://www.example2.com/")

c := revel.NewController(revel.NewRequest(getRequest), revel.NewResponse(resp))
c.Session = make(revel.Session)

testFilters[0](c, testFilters)

if _, ok := c.RenderArgs["_csrftoken"]; ok {
t.Fatal("RenderArgs should not contain token when not same origin")
}
}
36 changes: 36 additions & 0 deletions modules/csrf/app/exempt.go
@@ -0,0 +1,36 @@
package csrf

import (
"fmt"
"strings"

"github.com/revel/revel"
)

var (
exemptPath = make(map[string]bool)
exemptAction = make(map[string]bool)
)

func MarkExempt(route string) {
if strings.HasPrefix(route, "/") {
// e.g. "/controller/action"
exemptPath[strings.ToLower(route)] = true
} else if routeParts := strings.Split(route, "."); len(routeParts) == 2 {
// e.g. "ControllerName.ActionName"
exemptAction[route] = true
} else {
err := fmt.Sprintf("csrf.MarkExempt() received invalid argument \"%v\". Either provide a path prefixed with \"/\" or controller action in the form of \"ControllerName.ActionName\".", route)
panic(err)
}
}

func IsExempt(c *revel.Controller) bool {
if _, ok := exemptPath[strings.ToLower(c.Request.Request.URL.Path)]; ok {
return true
} else if _, ok := exemptAction[c.Action]; ok {
return true
}

return false
}
55 changes: 55 additions & 0 deletions modules/csrf/app/exempt_test.go
@@ -0,0 +1,55 @@
package csrf

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/revel/revel"
)

func TestExemptPath(t *testing.T) {
MarkExempt("/Controller/Action")

resp := httptest.NewRecorder()
postRequest, _ := http.NewRequest("POST", "http://www.example.com/Controller/Action", nil)
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp))
c.Session = make(revel.Session)

testFilters[0](c, testFilters)

if c.Response.Status == 403 {
t.Fatal("post to csrf exempt action should pass")
}
}

func TestExemptPathCaseInsensitive(t *testing.T) {
MarkExempt("/Controller/Action")

resp := httptest.NewRecorder()
postRequest, _ := http.NewRequest("POST", "http://www.example.com/controller/action", nil)
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp))
c.Session = make(revel.Session)

testFilters[0](c, testFilters)

if c.Response.Status == 403 {
t.Fatal("post to csrf exempt action should pass")
}
}

func TestExemptAction(t *testing.T) {
MarkExempt("Controller.Action")

resp := httptest.NewRecorder()
postRequest, _ := http.NewRequest("POST", "http://www.example.com/Controller/Action", nil)
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp))
c.Session = make(revel.Session)
c.Action = "Controller.Action"

testFilters[0](c, testFilters)

if c.Response.Status == 403 {
t.Fatal("post to csrf exempt action should pass")
}
}