Skip to content

Commit

Permalink
Merge pull request #606 from iamjem/develop
Browse files Browse the repository at this point in the history
CSRF Support
  • Loading branch information
brendensoares committed Oct 24, 2014
2 parents 084bb15 + 4a35e27 commit 48cc8bc
Show file tree
Hide file tree
Showing 4 changed files with 376 additions and 0 deletions.
116 changes: 116 additions & 0 deletions modules/csrf/app/csrf.go
@@ -0,0 +1,116 @@
package csrf

import (
"crypto/rand"
"crypto/subtle"
"encoding/hex"
"html/template"
"io"
"math"
"net/url"

"github.com/revel/revel"
)

// allowMethods are HTTP methods that do NOT require a token
var allowedMethods = map[string]bool{
"GET": true,
"HEAD": true,
"OPTIONS": true,
"TRACE": true,
}

func RandomString(length int) (string, error) {
buffer := make([]byte, int(math.Ceil(float64(length)/2)))
if _, err := io.ReadFull(rand.Reader, buffer); err != nil {
return "", nil
}
str := hex.EncodeToString(buffer)
return str[:length], nil
}

func RefreshToken(c *revel.Controller) {
token, err := RandomString(64)
if err != nil {
panic(err)
}
c.Session["csrf_token"] = token
}

func CsrfFilter(c *revel.Controller, fc []revel.Filter) {
token, foundToken := c.Session["csrf_token"]

if !foundToken {
RefreshToken(c)
}

referer, refErr := url.Parse(c.Request.Header.Get("Referer"))
isSameOrigin := sameOrigin(c.Request.URL, referer)

// If the Request method isn't in the white listed methods
if !allowedMethods[c.Request.Method] && !IsExempt(c) {
// Token wasn't present at all
if !foundToken {
c.Result = c.Forbidden("REVEL CSRF: Session token missing.")
return
}

// Referer header is invalid
if refErr != nil {
c.Result = c.Forbidden("REVEL CSRF: HTTP Referer malformed.")
return
}

// Same origin
if !isSameOrigin {
c.Result = c.Forbidden("REVEL CSRF: Same origin mismatch.")
return
}

var requestToken string
// First check for token in post data
if c.Request.Method == "POST" {
requestToken = c.Request.FormValue("csrftoken")
}

// Then check for token in custom headers, as with AJAX
if requestToken == "" {
requestToken = c.Request.Header.Get("X-CSRFToken")
}

if requestToken == "" || !compareToken(requestToken, token) {
c.Result = c.Forbidden("REVEL CSRF: Invalid token.")
return
}
}

fc[0](c, fc[1:])

// Only add token to RenderArgs if the request is: not AJAX, not missing referer header, and is same origin.
if c.Request.Header.Get("X-CSRFToken") == "" && isSameOrigin {
c.RenderArgs["_csrftoken"] = token
}
}

func compareToken(requestToken, token string) bool {
// ConstantTimeCompare will panic if the []byte aren't the same length
if len(requestToken) != len(token) {
return false
}
return subtle.ConstantTimeCompare([]byte(requestToken), []byte(token)) == 1
}

// Validates same origin policy
func sameOrigin(u1, u2 *url.URL) bool {
return u1.Scheme == u2.Scheme && u1.Host == u2.Host
}

func init() {
revel.TemplateFuncs["csrftoken"] = func(renderArgs map[string]interface{}) template.HTML {
if tokenFunc, ok := renderArgs["_csrftoken"]; !ok {
panic("REVEL CSRF: _csrftoken missing from RenderArgs.")
} else {
return template.HTML(tokenFunc.(func() string)())
}
}
}
169 changes: 169 additions & 0 deletions modules/csrf/app/csrf_test.go
@@ -0,0 +1,169 @@
package csrf

import (
"bytes"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"testing"

"github.com/revel/revel"
)

var testFilters = []revel.Filter{
CsrfFilter,
func(c *revel.Controller, fc []revel.Filter) {
c.RenderHtml("{{ csrftoken . }}")
},
}

func TestTokenInSession(t *testing.T) {
resp := httptest.NewRecorder()
getRequest, _ := http.NewRequest("GET", "http://www.example.com/", nil)
c := revel.NewController(revel.NewRequest(getRequest), revel.NewResponse(resp))
c.Session = make(revel.Session)

testFilters[0](c, testFilters)

if _, ok := c.Session["csrf_token"]; !ok {
t.Fatal("token should be present in session")
}
}

func TestPostWithoutToken(t *testing.T) {
resp := httptest.NewRecorder()
postRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil)
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp))
c.Session = make(revel.Session)

testFilters[0](c, testFilters)

if c.Response.Status != 403 {
t.Fatal("post without token should be forbidden")
}
}

func TestNoReferrer(t *testing.T) {
resp := httptest.NewRecorder()
postRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil)

c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp))
c.Session = make(revel.Session)

RefreshToken(c)
token := c.Session["csrf_token"]

// make a new request with the token
data := url.Values{}
data.Set("csrftoken", token)
formPostRequest, _ := http.NewRequest("POST", "http://www.example.com/", bytes.NewBufferString(data.Encode()))
formPostRequest.Header.Add("Content-Type", "application/x-www-form-urlencoded")
formPostRequest.Header.Add("Content-Length", strconv.Itoa(len(data.Encode())))

// and replace the old request
c.Request = revel.NewRequest(formPostRequest)

testFilters[0](c, testFilters)

if c.Response.Status != 403 {
t.Fatal("post without referer should be forbidden")
}
}

func TestRefererHttps(t *testing.T) {
resp := httptest.NewRecorder()
postRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil)
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp))

c.Session = make(revel.Session)

RefreshToken(c)
token := c.Session["csrf_token"]

// make a new request with the token
data := url.Values{}
data.Set("csrftoken", token)
formPostRequest, _ := http.NewRequest("POST", "https://www.example.com/", bytes.NewBufferString(data.Encode()))
formPostRequest.Header.Add("Content-Type", "application/x-www-form-urlencoded")
formPostRequest.Header.Add("Content-Length", strconv.Itoa(len(data.Encode())))
formPostRequest.Header.Add("Referer", "http://www.example.com/")

// and replace the old request
c.Request = revel.NewRequest(formPostRequest)

testFilters[0](c, testFilters)

if c.Response.Status != 403 {
t.Fatal("posts to https should have an https referer")
}
}

func TestHeaderWithToken(t *testing.T) {
resp := httptest.NewRecorder()
postRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil)
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp))

c.Session = make(revel.Session)

RefreshToken(c)
token := c.Session["csrf_token"]

// make a new request with the token
formPostRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil)
formPostRequest.Header.Add("X-CSRFToken", token)
formPostRequest.Header.Add("Referer", "http://www.example.com/")

// and replace the old request
c.Request = revel.NewRequest(formPostRequest)

testFilters[0](c, testFilters)

if c.Response.Status == 403 {
t.Fatal("post with http header token should be allowed")
}
}

func TestFormPostWithToken(t *testing.T) {
resp := httptest.NewRecorder()
postRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil)
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp))

c.Session = make(revel.Session)

RefreshToken(c)
token := c.Session["csrf_token"]

// make a new request with the token
data := url.Values{}
data.Set("csrftoken", token)
formPostRequest, _ := http.NewRequest("POST", "http://www.example.com/", bytes.NewBufferString(data.Encode()))
formPostRequest.Header.Add("Content-Type", "application/x-www-form-urlencoded")
formPostRequest.Header.Add("Content-Length", strconv.Itoa(len(data.Encode())))
formPostRequest.Header.Add("Referer", "http://www.example.com/")

// and replace the old request
c.Request = revel.NewRequest(formPostRequest)

testFilters[0](c, testFilters)

if c.Response.Status == 403 {
t.Fatal("form post with token should be allowed")
}
}

func TestNoTokenInArgsWhenCORs(t *testing.T) {
resp := httptest.NewRecorder()

getRequest, _ := http.NewRequest("GET", "http://www.example1.com/", nil)
getRequest.Header.Add("Referer", "http://www.example2.com/")

c := revel.NewController(revel.NewRequest(getRequest), revel.NewResponse(resp))
c.Session = make(revel.Session)

testFilters[0](c, testFilters)

if _, ok := c.RenderArgs["_csrftoken"]; ok {
t.Fatal("RenderArgs should not contain token when not same origin")
}
}
36 changes: 36 additions & 0 deletions modules/csrf/app/exempt.go
@@ -0,0 +1,36 @@
package csrf

import (
"fmt"
"strings"

"github.com/revel/revel"
)

var (
exemptPath = make(map[string]bool)
exemptAction = make(map[string]bool)
)

func MarkExempt(route string) {
if strings.HasPrefix(route, "/") {
// e.g. "/controller/action"
exemptPath[strings.ToLower(route)] = true
} else if routeParts := strings.Split(route, "."); len(routeParts) == 2 {
// e.g. "ControllerName.ActionName"
exemptAction[route] = true
} else {
err := fmt.Sprintf("csrf.MarkExempt() received invalid argument \"%v\". Either provide a path prefixed with \"/\" or controller action in the form of \"ControllerName.ActionName\".", route)
panic(err)
}
}

func IsExempt(c *revel.Controller) bool {
if _, ok := exemptPath[strings.ToLower(c.Request.Request.URL.Path)]; ok {
return true
} else if _, ok := exemptAction[c.Action]; ok {
return true
}

return false
}
55 changes: 55 additions & 0 deletions modules/csrf/app/exempt_test.go
@@ -0,0 +1,55 @@
package csrf

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/revel/revel"
)

func TestExemptPath(t *testing.T) {
MarkExempt("/Controller/Action")

resp := httptest.NewRecorder()
postRequest, _ := http.NewRequest("POST", "http://www.example.com/Controller/Action", nil)
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp))
c.Session = make(revel.Session)

testFilters[0](c, testFilters)

if c.Response.Status == 403 {
t.Fatal("post to csrf exempt action should pass")
}
}

func TestExemptPathCaseInsensitive(t *testing.T) {
MarkExempt("/Controller/Action")

resp := httptest.NewRecorder()
postRequest, _ := http.NewRequest("POST", "http://www.example.com/controller/action", nil)
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp))
c.Session = make(revel.Session)

testFilters[0](c, testFilters)

if c.Response.Status == 403 {
t.Fatal("post to csrf exempt action should pass")
}
}

func TestExemptAction(t *testing.T) {
MarkExempt("Controller.Action")

resp := httptest.NewRecorder()
postRequest, _ := http.NewRequest("POST", "http://www.example.com/Controller/Action", nil)
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp))
c.Session = make(revel.Session)
c.Action = "Controller.Action"

testFilters[0](c, testFilters)

if c.Response.Status == 403 {
t.Fatal("post to csrf exempt action should pass")
}
}

0 comments on commit 48cc8bc

Please sign in to comment.