Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #606 from iamjem/develop
CSRF Support
- Loading branch information
Showing
4 changed files
with
376 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
package csrf | ||
|
||
import ( | ||
"crypto/rand" | ||
"crypto/subtle" | ||
"encoding/hex" | ||
"html/template" | ||
"io" | ||
"math" | ||
"net/url" | ||
|
||
"github.com/revel/revel" | ||
) | ||
|
||
// allowMethods are HTTP methods that do NOT require a token | ||
var allowedMethods = map[string]bool{ | ||
"GET": true, | ||
"HEAD": true, | ||
"OPTIONS": true, | ||
"TRACE": true, | ||
} | ||
|
||
func RandomString(length int) (string, error) { | ||
buffer := make([]byte, int(math.Ceil(float64(length)/2))) | ||
if _, err := io.ReadFull(rand.Reader, buffer); err != nil { | ||
return "", nil | ||
} | ||
str := hex.EncodeToString(buffer) | ||
return str[:length], nil | ||
} | ||
|
||
func RefreshToken(c *revel.Controller) { | ||
token, err := RandomString(64) | ||
if err != nil { | ||
panic(err) | ||
} | ||
c.Session["csrf_token"] = token | ||
} | ||
|
||
func CsrfFilter(c *revel.Controller, fc []revel.Filter) { | ||
token, foundToken := c.Session["csrf_token"] | ||
|
||
if !foundToken { | ||
RefreshToken(c) | ||
} | ||
|
||
referer, refErr := url.Parse(c.Request.Header.Get("Referer")) | ||
isSameOrigin := sameOrigin(c.Request.URL, referer) | ||
|
||
// If the Request method isn't in the white listed methods | ||
if !allowedMethods[c.Request.Method] && !IsExempt(c) { | ||
// Token wasn't present at all | ||
if !foundToken { | ||
c.Result = c.Forbidden("REVEL CSRF: Session token missing.") | ||
return | ||
} | ||
|
||
// Referer header is invalid | ||
if refErr != nil { | ||
c.Result = c.Forbidden("REVEL CSRF: HTTP Referer malformed.") | ||
return | ||
} | ||
|
||
// Same origin | ||
if !isSameOrigin { | ||
c.Result = c.Forbidden("REVEL CSRF: Same origin mismatch.") | ||
return | ||
} | ||
|
||
var requestToken string | ||
// First check for token in post data | ||
if c.Request.Method == "POST" { | ||
requestToken = c.Request.FormValue("csrftoken") | ||
} | ||
|
||
// Then check for token in custom headers, as with AJAX | ||
if requestToken == "" { | ||
requestToken = c.Request.Header.Get("X-CSRFToken") | ||
} | ||
|
||
if requestToken == "" || !compareToken(requestToken, token) { | ||
c.Result = c.Forbidden("REVEL CSRF: Invalid token.") | ||
return | ||
} | ||
} | ||
|
||
fc[0](c, fc[1:]) | ||
|
||
// Only add token to RenderArgs if the request is: not AJAX, not missing referer header, and is same origin. | ||
if c.Request.Header.Get("X-CSRFToken") == "" && isSameOrigin { | ||
c.RenderArgs["_csrftoken"] = token | ||
} | ||
} | ||
|
||
func compareToken(requestToken, token string) bool { | ||
// ConstantTimeCompare will panic if the []byte aren't the same length | ||
if len(requestToken) != len(token) { | ||
return false | ||
} | ||
return subtle.ConstantTimeCompare([]byte(requestToken), []byte(token)) == 1 | ||
} | ||
|
||
// Validates same origin policy | ||
func sameOrigin(u1, u2 *url.URL) bool { | ||
return u1.Scheme == u2.Scheme && u1.Host == u2.Host | ||
} | ||
|
||
func init() { | ||
revel.TemplateFuncs["csrftoken"] = func(renderArgs map[string]interface{}) template.HTML { | ||
if tokenFunc, ok := renderArgs["_csrftoken"]; !ok { | ||
panic("REVEL CSRF: _csrftoken missing from RenderArgs.") | ||
} else { | ||
return template.HTML(tokenFunc.(func() string)()) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
package csrf | ||
|
||
import ( | ||
"bytes" | ||
"net/http" | ||
"net/http/httptest" | ||
"net/url" | ||
"strconv" | ||
"testing" | ||
|
||
"github.com/revel/revel" | ||
) | ||
|
||
var testFilters = []revel.Filter{ | ||
CsrfFilter, | ||
func(c *revel.Controller, fc []revel.Filter) { | ||
c.RenderHtml("{{ csrftoken . }}") | ||
}, | ||
} | ||
|
||
func TestTokenInSession(t *testing.T) { | ||
resp := httptest.NewRecorder() | ||
getRequest, _ := http.NewRequest("GET", "http://www.example.com/", nil) | ||
c := revel.NewController(revel.NewRequest(getRequest), revel.NewResponse(resp)) | ||
c.Session = make(revel.Session) | ||
|
||
testFilters[0](c, testFilters) | ||
|
||
if _, ok := c.Session["csrf_token"]; !ok { | ||
t.Fatal("token should be present in session") | ||
} | ||
} | ||
|
||
func TestPostWithoutToken(t *testing.T) { | ||
resp := httptest.NewRecorder() | ||
postRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil) | ||
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp)) | ||
c.Session = make(revel.Session) | ||
|
||
testFilters[0](c, testFilters) | ||
|
||
if c.Response.Status != 403 { | ||
t.Fatal("post without token should be forbidden") | ||
} | ||
} | ||
|
||
func TestNoReferrer(t *testing.T) { | ||
resp := httptest.NewRecorder() | ||
postRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil) | ||
|
||
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp)) | ||
c.Session = make(revel.Session) | ||
|
||
RefreshToken(c) | ||
token := c.Session["csrf_token"] | ||
|
||
// make a new request with the token | ||
data := url.Values{} | ||
data.Set("csrftoken", token) | ||
formPostRequest, _ := http.NewRequest("POST", "http://www.example.com/", bytes.NewBufferString(data.Encode())) | ||
formPostRequest.Header.Add("Content-Type", "application/x-www-form-urlencoded") | ||
formPostRequest.Header.Add("Content-Length", strconv.Itoa(len(data.Encode()))) | ||
|
||
// and replace the old request | ||
c.Request = revel.NewRequest(formPostRequest) | ||
|
||
testFilters[0](c, testFilters) | ||
|
||
if c.Response.Status != 403 { | ||
t.Fatal("post without referer should be forbidden") | ||
} | ||
} | ||
|
||
func TestRefererHttps(t *testing.T) { | ||
resp := httptest.NewRecorder() | ||
postRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil) | ||
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp)) | ||
|
||
c.Session = make(revel.Session) | ||
|
||
RefreshToken(c) | ||
token := c.Session["csrf_token"] | ||
|
||
// make a new request with the token | ||
data := url.Values{} | ||
data.Set("csrftoken", token) | ||
formPostRequest, _ := http.NewRequest("POST", "https://www.example.com/", bytes.NewBufferString(data.Encode())) | ||
formPostRequest.Header.Add("Content-Type", "application/x-www-form-urlencoded") | ||
formPostRequest.Header.Add("Content-Length", strconv.Itoa(len(data.Encode()))) | ||
formPostRequest.Header.Add("Referer", "http://www.example.com/") | ||
|
||
// and replace the old request | ||
c.Request = revel.NewRequest(formPostRequest) | ||
|
||
testFilters[0](c, testFilters) | ||
|
||
if c.Response.Status != 403 { | ||
t.Fatal("posts to https should have an https referer") | ||
} | ||
} | ||
|
||
func TestHeaderWithToken(t *testing.T) { | ||
resp := httptest.NewRecorder() | ||
postRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil) | ||
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp)) | ||
|
||
c.Session = make(revel.Session) | ||
|
||
RefreshToken(c) | ||
token := c.Session["csrf_token"] | ||
|
||
// make a new request with the token | ||
formPostRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil) | ||
formPostRequest.Header.Add("X-CSRFToken", token) | ||
formPostRequest.Header.Add("Referer", "http://www.example.com/") | ||
|
||
// and replace the old request | ||
c.Request = revel.NewRequest(formPostRequest) | ||
|
||
testFilters[0](c, testFilters) | ||
|
||
if c.Response.Status == 403 { | ||
t.Fatal("post with http header token should be allowed") | ||
} | ||
} | ||
|
||
func TestFormPostWithToken(t *testing.T) { | ||
resp := httptest.NewRecorder() | ||
postRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil) | ||
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp)) | ||
|
||
c.Session = make(revel.Session) | ||
|
||
RefreshToken(c) | ||
token := c.Session["csrf_token"] | ||
|
||
// make a new request with the token | ||
data := url.Values{} | ||
data.Set("csrftoken", token) | ||
formPostRequest, _ := http.NewRequest("POST", "http://www.example.com/", bytes.NewBufferString(data.Encode())) | ||
formPostRequest.Header.Add("Content-Type", "application/x-www-form-urlencoded") | ||
formPostRequest.Header.Add("Content-Length", strconv.Itoa(len(data.Encode()))) | ||
formPostRequest.Header.Add("Referer", "http://www.example.com/") | ||
|
||
// and replace the old request | ||
c.Request = revel.NewRequest(formPostRequest) | ||
|
||
testFilters[0](c, testFilters) | ||
|
||
if c.Response.Status == 403 { | ||
t.Fatal("form post with token should be allowed") | ||
} | ||
} | ||
|
||
func TestNoTokenInArgsWhenCORs(t *testing.T) { | ||
resp := httptest.NewRecorder() | ||
|
||
getRequest, _ := http.NewRequest("GET", "http://www.example1.com/", nil) | ||
getRequest.Header.Add("Referer", "http://www.example2.com/") | ||
|
||
c := revel.NewController(revel.NewRequest(getRequest), revel.NewResponse(resp)) | ||
c.Session = make(revel.Session) | ||
|
||
testFilters[0](c, testFilters) | ||
|
||
if _, ok := c.RenderArgs["_csrftoken"]; ok { | ||
t.Fatal("RenderArgs should not contain token when not same origin") | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
package csrf | ||
|
||
import ( | ||
"fmt" | ||
"strings" | ||
|
||
"github.com/revel/revel" | ||
) | ||
|
||
var ( | ||
exemptPath = make(map[string]bool) | ||
exemptAction = make(map[string]bool) | ||
) | ||
|
||
func MarkExempt(route string) { | ||
if strings.HasPrefix(route, "/") { | ||
// e.g. "/controller/action" | ||
exemptPath[strings.ToLower(route)] = true | ||
} else if routeParts := strings.Split(route, "."); len(routeParts) == 2 { | ||
// e.g. "ControllerName.ActionName" | ||
exemptAction[route] = true | ||
} else { | ||
err := fmt.Sprintf("csrf.MarkExempt() received invalid argument \"%v\". Either provide a path prefixed with \"/\" or controller action in the form of \"ControllerName.ActionName\".", route) | ||
panic(err) | ||
} | ||
} | ||
|
||
func IsExempt(c *revel.Controller) bool { | ||
if _, ok := exemptPath[strings.ToLower(c.Request.Request.URL.Path)]; ok { | ||
return true | ||
} else if _, ok := exemptAction[c.Action]; ok { | ||
return true | ||
} | ||
|
||
return false | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
package csrf | ||
|
||
import ( | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
|
||
"github.com/revel/revel" | ||
) | ||
|
||
func TestExemptPath(t *testing.T) { | ||
MarkExempt("/Controller/Action") | ||
|
||
resp := httptest.NewRecorder() | ||
postRequest, _ := http.NewRequest("POST", "http://www.example.com/Controller/Action", nil) | ||
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp)) | ||
c.Session = make(revel.Session) | ||
|
||
testFilters[0](c, testFilters) | ||
|
||
if c.Response.Status == 403 { | ||
t.Fatal("post to csrf exempt action should pass") | ||
} | ||
} | ||
|
||
func TestExemptPathCaseInsensitive(t *testing.T) { | ||
MarkExempt("/Controller/Action") | ||
|
||
resp := httptest.NewRecorder() | ||
postRequest, _ := http.NewRequest("POST", "http://www.example.com/controller/action", nil) | ||
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp)) | ||
c.Session = make(revel.Session) | ||
|
||
testFilters[0](c, testFilters) | ||
|
||
if c.Response.Status == 403 { | ||
t.Fatal("post to csrf exempt action should pass") | ||
} | ||
} | ||
|
||
func TestExemptAction(t *testing.T) { | ||
MarkExempt("Controller.Action") | ||
|
||
resp := httptest.NewRecorder() | ||
postRequest, _ := http.NewRequest("POST", "http://www.example.com/Controller/Action", nil) | ||
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp)) | ||
c.Session = make(revel.Session) | ||
c.Action = "Controller.Action" | ||
|
||
testFilters[0](c, testFilters) | ||
|
||
if c.Response.Status == 403 { | ||
t.Fatal("post to csrf exempt action should pass") | ||
} | ||
} |