New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
CSRF Support #606
CSRF Support #606
Changes from all commits
0383b41
0794e7f
a7ac1dd
7ac9116
1ed0355
1e04845
d2726db
0a80ada
4a35e27
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
package csrf | ||
|
||
import ( | ||
"crypto/rand" | ||
"crypto/subtle" | ||
"encoding/hex" | ||
"html/template" | ||
"io" | ||
"math" | ||
"net/url" | ||
|
||
"github.com/revel/revel" | ||
) | ||
|
||
// allowMethods are HTTP methods that do NOT require a token | ||
var allowedMethods = map[string]bool{ | ||
"GET": true, | ||
"HEAD": true, | ||
"OPTIONS": true, | ||
"TRACE": true, | ||
} | ||
|
||
func RandomString(length int) (string, error) { | ||
buffer := make([]byte, int(math.Ceil(float64(length)/2))) | ||
if _, err := io.ReadFull(rand.Reader, buffer); err != nil { | ||
return "", nil | ||
} | ||
str := hex.EncodeToString(buffer) | ||
return str[:length], nil | ||
} | ||
|
||
func RefreshToken(c *revel.Controller) { | ||
token, err := RandomString(64) | ||
if err != nil { | ||
panic(err) | ||
} | ||
c.Session["csrf_token"] = token | ||
} | ||
|
||
func CsrfFilter(c *revel.Controller, fc []revel.Filter) { | ||
token, foundToken := c.Session["csrf_token"] | ||
|
||
if !foundToken { | ||
RefreshToken(c) | ||
} | ||
|
||
referer, refErr := url.Parse(c.Request.Header.Get("Referer")) | ||
isSameOrigin := sameOrigin(c.Request.URL, referer) | ||
|
||
// If the Request method isn't in the white listed methods | ||
if !allowedMethods[c.Request.Method] && !IsExempt(c) { | ||
// Token wasn't present at all | ||
if !foundToken { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, can There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @AnonX yes, if the token wasn't sent in the form (right?). |
||
c.Result = c.Forbidden("REVEL CSRF: Session token missing.") | ||
return | ||
} | ||
|
||
// Referer header is invalid | ||
if refErr != nil { | ||
c.Result = c.Forbidden("REVEL CSRF: HTTP Referer malformed.") | ||
return | ||
} | ||
|
||
// Same origin | ||
if !isSameOrigin { | ||
c.Result = c.Forbidden("REVEL CSRF: Same origin mismatch.") | ||
return | ||
} | ||
|
||
var requestToken string | ||
// First check for token in post data | ||
if c.Request.Method == "POST" { | ||
requestToken = c.Request.FormValue("csrftoken") | ||
} | ||
|
||
// Then check for token in custom headers, as with AJAX | ||
if requestToken == "" { | ||
requestToken = c.Request.Header.Get("X-CSRFToken") | ||
} | ||
|
||
if requestToken == "" || !compareToken(requestToken, token) { | ||
c.Result = c.Forbidden("REVEL CSRF: Invalid token.") | ||
return | ||
} | ||
} | ||
|
||
fc[0](c, fc[1:]) | ||
|
||
// Only add token to RenderArgs if the request is: not AJAX, not missing referer header, and is same origin. | ||
if c.Request.Header.Get("X-CSRFToken") == "" && isSameOrigin { | ||
c.RenderArgs["_csrftoken"] = token | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. At this point the return from function sameOrigin must be true because you checked for it being false on line 69, you should be able to remove the criteria from the if statement There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thats only if its not a white listed method. If someone does an AJAX request cross-domain using a safe method (like GET) we need to check it there too. I could possibly pull the check up outside the if statement though. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Look at the if statement on line 69 again There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, but line 69 happens only when the condition on line 55 is true (only if the HTTP method is not safe, like a POST). The check on line 94 always gets called when its either a safe HTTP method (like a GET request), OR the unsafe request passed the previous tests inside the line 55 if statement. The reason the sameOrigin() call is on line 94 is to ensure that tokens aren't leaked in "safe" (like an HTTP GET) cross origin requests, for example if someone built an API that works cross origin. Like I mentioned though, since the condition is always needed at some point in the filter, it could be move above line 55. I'll do that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Quite right, harder to visual the code on a screen without bracket highlights.. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree, I was about to pull open my editor because these diff views are a bit lacking in legibility. |
||
} | ||
} | ||
|
||
func compareToken(requestToken, token string) bool { | ||
// ConstantTimeCompare will panic if the []byte aren't the same length | ||
if len(requestToken) != len(token) { | ||
return false | ||
} | ||
return subtle.ConstantTimeCompare([]byte(requestToken), []byte(token)) == 1 | ||
} | ||
|
||
// Validates same origin policy | ||
func sameOrigin(u1, u2 *url.URL) bool { | ||
return u1.Scheme == u2.Scheme && u1.Host == u2.Host | ||
} | ||
|
||
func init() { | ||
revel.TemplateFuncs["csrftoken"] = func(renderArgs map[string]interface{}) template.HTML { | ||
if tokenFunc, ok := renderArgs["_csrftoken"]; !ok { | ||
panic("REVEL CSRF: _csrftoken missing from RenderArgs.") | ||
} else { | ||
return template.HTML(tokenFunc.(func() string)()) | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
package csrf | ||
|
||
import ( | ||
"bytes" | ||
"net/http" | ||
"net/http/httptest" | ||
"net/url" | ||
"strconv" | ||
"testing" | ||
|
||
"github.com/revel/revel" | ||
) | ||
|
||
var testFilters = []revel.Filter{ | ||
CsrfFilter, | ||
func(c *revel.Controller, fc []revel.Filter) { | ||
c.RenderHtml("{{ csrftoken . }}") | ||
}, | ||
} | ||
|
||
func TestTokenInSession(t *testing.T) { | ||
resp := httptest.NewRecorder() | ||
getRequest, _ := http.NewRequest("GET", "http://www.example.com/", nil) | ||
c := revel.NewController(revel.NewRequest(getRequest), revel.NewResponse(resp)) | ||
c.Session = make(revel.Session) | ||
|
||
testFilters[0](c, testFilters) | ||
|
||
if _, ok := c.Session["csrf_token"]; !ok { | ||
t.Fatal("token should be present in session") | ||
} | ||
} | ||
|
||
func TestPostWithoutToken(t *testing.T) { | ||
resp := httptest.NewRecorder() | ||
postRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil) | ||
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp)) | ||
c.Session = make(revel.Session) | ||
|
||
testFilters[0](c, testFilters) | ||
|
||
if c.Response.Status != 403 { | ||
t.Fatal("post without token should be forbidden") | ||
} | ||
} | ||
|
||
func TestNoReferrer(t *testing.T) { | ||
resp := httptest.NewRecorder() | ||
postRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil) | ||
|
||
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp)) | ||
c.Session = make(revel.Session) | ||
|
||
RefreshToken(c) | ||
token := c.Session["csrf_token"] | ||
|
||
// make a new request with the token | ||
data := url.Values{} | ||
data.Set("csrftoken", token) | ||
formPostRequest, _ := http.NewRequest("POST", "http://www.example.com/", bytes.NewBufferString(data.Encode())) | ||
formPostRequest.Header.Add("Content-Type", "application/x-www-form-urlencoded") | ||
formPostRequest.Header.Add("Content-Length", strconv.Itoa(len(data.Encode()))) | ||
|
||
// and replace the old request | ||
c.Request = revel.NewRequest(formPostRequest) | ||
|
||
testFilters[0](c, testFilters) | ||
|
||
if c.Response.Status != 403 { | ||
t.Fatal("post without referer should be forbidden") | ||
} | ||
} | ||
|
||
func TestRefererHttps(t *testing.T) { | ||
resp := httptest.NewRecorder() | ||
postRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil) | ||
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp)) | ||
|
||
c.Session = make(revel.Session) | ||
|
||
RefreshToken(c) | ||
token := c.Session["csrf_token"] | ||
|
||
// make a new request with the token | ||
data := url.Values{} | ||
data.Set("csrftoken", token) | ||
formPostRequest, _ := http.NewRequest("POST", "https://www.example.com/", bytes.NewBufferString(data.Encode())) | ||
formPostRequest.Header.Add("Content-Type", "application/x-www-form-urlencoded") | ||
formPostRequest.Header.Add("Content-Length", strconv.Itoa(len(data.Encode()))) | ||
formPostRequest.Header.Add("Referer", "http://www.example.com/") | ||
|
||
// and replace the old request | ||
c.Request = revel.NewRequest(formPostRequest) | ||
|
||
testFilters[0](c, testFilters) | ||
|
||
if c.Response.Status != 403 { | ||
t.Fatal("posts to https should have an https referer") | ||
} | ||
} | ||
|
||
func TestHeaderWithToken(t *testing.T) { | ||
resp := httptest.NewRecorder() | ||
postRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil) | ||
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp)) | ||
|
||
c.Session = make(revel.Session) | ||
|
||
RefreshToken(c) | ||
token := c.Session["csrf_token"] | ||
|
||
// make a new request with the token | ||
formPostRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil) | ||
formPostRequest.Header.Add("X-CSRFToken", token) | ||
formPostRequest.Header.Add("Referer", "http://www.example.com/") | ||
|
||
// and replace the old request | ||
c.Request = revel.NewRequest(formPostRequest) | ||
|
||
testFilters[0](c, testFilters) | ||
|
||
if c.Response.Status == 403 { | ||
t.Fatal("post with http header token should be allowed") | ||
} | ||
} | ||
|
||
func TestFormPostWithToken(t *testing.T) { | ||
resp := httptest.NewRecorder() | ||
postRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil) | ||
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp)) | ||
|
||
c.Session = make(revel.Session) | ||
|
||
RefreshToken(c) | ||
token := c.Session["csrf_token"] | ||
|
||
// make a new request with the token | ||
data := url.Values{} | ||
data.Set("csrftoken", token) | ||
formPostRequest, _ := http.NewRequest("POST", "http://www.example.com/", bytes.NewBufferString(data.Encode())) | ||
formPostRequest.Header.Add("Content-Type", "application/x-www-form-urlencoded") | ||
formPostRequest.Header.Add("Content-Length", strconv.Itoa(len(data.Encode()))) | ||
formPostRequest.Header.Add("Referer", "http://www.example.com/") | ||
|
||
// and replace the old request | ||
c.Request = revel.NewRequest(formPostRequest) | ||
|
||
testFilters[0](c, testFilters) | ||
|
||
if c.Response.Status == 403 { | ||
t.Fatal("form post with token should be allowed") | ||
} | ||
} | ||
|
||
func TestNoTokenInArgsWhenCORs(t *testing.T) { | ||
resp := httptest.NewRecorder() | ||
|
||
getRequest, _ := http.NewRequest("GET", "http://www.example1.com/", nil) | ||
getRequest.Header.Add("Referer", "http://www.example2.com/") | ||
|
||
c := revel.NewController(revel.NewRequest(getRequest), revel.NewResponse(resp)) | ||
c.Session = make(revel.Session) | ||
|
||
testFilters[0](c, testFilters) | ||
|
||
if _, ok := c.RenderArgs["_csrftoken"]; ok { | ||
t.Fatal("RenderArgs should not contain token when not same origin") | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
package csrf | ||
|
||
import ( | ||
"fmt" | ||
"strings" | ||
|
||
"github.com/revel/revel" | ||
) | ||
|
||
var ( | ||
exemptPath = make(map[string]bool) | ||
exemptAction = make(map[string]bool) | ||
) | ||
|
||
func MarkExempt(route string) { | ||
if strings.HasPrefix(route, "/") { | ||
// e.g. "/controller/action" | ||
exemptPath[strings.ToLower(route)] = true | ||
} else if routeParts := strings.Split(route, "."); len(routeParts) == 2 { | ||
// e.g. "ControllerName.ActionName" | ||
exemptAction[route] = true | ||
} else { | ||
err := fmt.Sprintf("csrf.MarkExempt() received invalid argument \"%v\". Either provide a path prefixed with \"/\" or controller action in the form of \"ControllerName.ActionName\".", route) | ||
panic(err) | ||
} | ||
} | ||
|
||
func IsExempt(c *revel.Controller) bool { | ||
if _, ok := exemptPath[strings.ToLower(c.Request.Request.URL.Path)]; ok { | ||
return true | ||
} else if _, ok := exemptAction[c.Action]; ok { | ||
return true | ||
} | ||
|
||
return false | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
package csrf | ||
|
||
import ( | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
|
||
"github.com/revel/revel" | ||
) | ||
|
||
func TestExemptPath(t *testing.T) { | ||
MarkExempt("/Controller/Action") | ||
|
||
resp := httptest.NewRecorder() | ||
postRequest, _ := http.NewRequest("POST", "http://www.example.com/Controller/Action", nil) | ||
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp)) | ||
c.Session = make(revel.Session) | ||
|
||
testFilters[0](c, testFilters) | ||
|
||
if c.Response.Status == 403 { | ||
t.Fatal("post to csrf exempt action should pass") | ||
} | ||
} | ||
|
||
func TestExemptPathCaseInsensitive(t *testing.T) { | ||
MarkExempt("/Controller/Action") | ||
|
||
resp := httptest.NewRecorder() | ||
postRequest, _ := http.NewRequest("POST", "http://www.example.com/controller/action", nil) | ||
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp)) | ||
c.Session = make(revel.Session) | ||
|
||
testFilters[0](c, testFilters) | ||
|
||
if c.Response.Status == 403 { | ||
t.Fatal("post to csrf exempt action should pass") | ||
} | ||
} | ||
|
||
func TestExemptAction(t *testing.T) { | ||
MarkExempt("Controller.Action") | ||
|
||
resp := httptest.NewRecorder() | ||
postRequest, _ := http.NewRequest("POST", "http://www.example.com/Controller/Action", nil) | ||
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp)) | ||
c.Session = make(revel.Session) | ||
c.Action = "Controller.Action" | ||
|
||
testFilters[0](c, testFilters) | ||
|
||
if c.Response.Status == 403 { | ||
t.Fatal("post to csrf exempt action should pass") | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A comment here would be useful to explain that
allowedMethods
don't require a token.