Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CSRF Support #606

Merged
merged 9 commits into from Oct 24, 2014
145 changes: 145 additions & 0 deletions modules/csrf/app/csrf.go
@@ -0,0 +1,145 @@
package csrf

import (
"crypto/rand"
"crypto/sha1"
"crypto/subtle"
"encoding/base64"
"encoding/hex"
"fmt"
"github.com/revel/revel"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you move the revel import down to its own section. We're trying to separate the imports into "built-in" and "other" sections.

"html/template"
"io"
"math"
"net/url"
)

var allowedMethods = map[string]bool{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A comment here would be useful to explain that allowedMethods don't require a token.

"GET": true,
"HEAD": true,
"OPTIONS": true,
"TRACE": true,
}

func NewToken(c *revel.Controller) string {
token := c.Request.Header.Get("Csrf-Token")
if token == "" {
token = saltedToken(c.Session["csrfSecret"])
c.Request.Header.Set("Csrf-Token", token)
}
return token
}

func NewSecret() (string, error) {
return RandomString(64)
}

func RandomString(length int) (string, error) {
buffer := make([]byte, int(math.Ceil(float64(length)/2)))
if _, err := rand.Read(buffer); err != nil {
return "", err
}
str := hex.EncodeToString(buffer)
return str[:length], nil
}

func RefreshSecret(c *revel.Controller) {
csrfSecret, err := RandomString(64)
if err != nil {
panic(err)
}
c.Session["csrfSecret"] = csrfSecret
}

func CsrfFilter(c *revel.Controller, fc []revel.Filter) {
csrfSecret, foundSecret := c.Session["csrfSecret"]

if !foundSecret {
RefreshSecret(c)
}

// TODO: Add a hook for csrf exempt?
// Whatever the flag, it needs to be visible during the request cycle...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

c.Request.Header.Get("Origin")


// If the Request method isn't in the white listed methods
if !allowedMethods[c.Request.Method] {
// Token wasn't present at all
if !foundSecret {
reject(c)
return
}
// Referrer header isn't present
referer := c.Request.Referer()
if referer == "" {
reject(c)
return
}
// Referer is invalid, or the requested
// page is HTTPS but the Referer is NOT HTTPS
refUrl, err := url.Parse(referer)
if err != nil || c.Request.URL.Scheme == "https" && refUrl.Scheme != "https" {
reject(c)
return
}

var requestCsrfToken string
// First check post data
if c.Request.Method == "POST" {
requestCsrfToken = c.Request.FormValue("csrftoken")
}

// Then check custom headers, as with AJAX
if requestCsrfToken == "" {
requestCsrfToken = c.Request.Header.Get("X-CSRFToken")
}

if requestCsrfToken == "" || !checkToken(requestCsrfToken, csrfSecret) {
reject(c)
return
}
}

fc[0](c, fc[1:])

c.RenderArgs["_csrftoken"] = func() string {
return NewToken(c)
}
}

func createToken(salt, secret string) string {
hash := sha1.New()
io.WriteString(hash, salt+secret)
return salt + base64.StdEncoding.EncodeToString(hash.Sum(nil))
}

func checkToken(requestCsrfToken, secret string) bool {
csrfToken := createToken(requestCsrfToken[0:10], secret)
// ConstantTimeCompare will panic if the []byte aren't the same length
if len(requestCsrfToken) != len(csrfToken) {
return false
}

return subtle.ConstantTimeCompare([]byte(requestCsrfToken), []byte(csrfToken)) == 1
}

func saltedToken(secret string) string {
salt, err := RandomString(10)
if err != nil {
panic(err)
}
return createToken(salt, secret)
}

func reject(c *revel.Controller) {
c.Response.Out.WriteHeader(403)
}

func init() {
revel.TemplateFuncs["csrftoken"] = func(renderArgs map[string]interface{}) template.HTML {
tokenFunc, ok := renderArgs["_csrftoken"]
if !ok {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why 2 lines instead of 1?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, definitely some rough edges left from the refactoring that needs to be cleaned up still. When we've got some consolidated feedback I'll go through and incorporate everything at once.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at this again. Is your comment about line 134? The tokenFunc variable gets used outside of the if statement, which is why I didn't declare them in the if's initialization statement.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could do something like

    if tokenFunc, ok := renderArgs["_csrftoken"];!ok {
        panic("REVEL CSRF: _csrftoken missing from RenderArgs.")
    } else {
        return template.HTML(tokenFunc.(func() string)())
    }

panic("_csrftoken missing from RenderArgs.")
}
return template.HTML(fmt.Sprintf(`<input type="hidden" name="csrftoken" value="%s">`, tokenFunc.(func() string)()))
}
}
192 changes: 192 additions & 0 deletions modules/csrf/app/csrf_test.go
@@ -0,0 +1,192 @@
package csrf

import (
"bytes"
"github.com/revel/revel"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"testing"
)

type fooController struct {
*revel.Controller
}

func TestSecretInSession(t *testing.T) {
resp := httptest.NewRecorder()
getRequest, _ := http.NewRequest("GET", "http://www.example.com/", nil)
c := revel.NewController(revel.NewRequest(getRequest), revel.NewResponse(resp))
c.Session = make(revel.Session)
filters := []revel.Filter{
CsrfFilter,
func(c *revel.Controller, fc []revel.Filter) {
c.RenderHtml("{{ csrftoken . }}")
},
}

filters[0](c, filters)

if _, ok := c.Session["csrfSecret"]; !ok {
t.Fatal("secret should be present in session")
}
}

func TestPostWithoutToken(t *testing.T) {
resp := httptest.NewRecorder()
postRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil)
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp))
c.Session = make(revel.Session)

filters := []revel.Filter{
CsrfFilter,
func(c *revel.Controller, fc []revel.Filter) {
c.RenderHtml("{{ csrftoken . }}")
},
}

filters[0](c, filters)

if resp.Code != 403 {
t.Fatal("post without token should be forbidden")
}
}

func TestNoReferrer(t *testing.T) {
resp := httptest.NewRecorder()
postRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil)
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp))

c.Session = make(revel.Session)
secret, _ := NewSecret()

c.Session["csrfSecret"] = secret
token := NewToken(c)

// make a new request with the token
data := url.Values{}
data.Set("csrftoken", token)
formPostRequest, _ := http.NewRequest("POST", "http://www.example.com/", bytes.NewBufferString(data.Encode()))
formPostRequest.Header.Add("Content-Type", "application/x-www-form-urlencoded")
formPostRequest.Header.Add("Content-Length", strconv.Itoa(len(data.Encode())))

// and replace the old request
c.Request = revel.NewRequest(formPostRequest)

filters := []revel.Filter{
CsrfFilter,
func(c *revel.Controller, fc []revel.Filter) {
c.RenderHtml("{{ csrftoken . }}")
},
}
filters[0](c, filters)

if resp.Code != 403 {
t.Fatal("post without referer should be forbidden")
}
}

func TestRefererHttps(t *testing.T) {
resp := httptest.NewRecorder()
postRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil)
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp))

c.Session = make(revel.Session)
secret, _ := NewSecret()

c.Session["csrfSecret"] = secret
token := NewToken(c)

// make a new request with the token
data := url.Values{}
data.Set("csrftoken", token)
formPostRequest, _ := http.NewRequest("POST", "https://www.example.com/", bytes.NewBufferString(data.Encode()))
formPostRequest.Header.Add("Content-Type", "application/x-www-form-urlencoded")
formPostRequest.Header.Add("Content-Length", strconv.Itoa(len(data.Encode())))
formPostRequest.Header.Add("Referer", "http://www.example.com/")

// and replace the old request
c.Request = revel.NewRequest(formPostRequest)

filters := []revel.Filter{
CsrfFilter,
func(c *revel.Controller, fc []revel.Filter) {
c.RenderHtml("{{ csrftoken . }}")
},
}
filters[0](c, filters)

if resp.Code != 403 {
t.Fatal("posts to https should have an https referer")
}
}

func TestHeaderWithToken(t *testing.T) {
resp := httptest.NewRecorder()
postRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil)
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp))

c.Session = make(revel.Session)
secret, _ := NewSecret()

c.Session["csrfSecret"] = secret
token := NewToken(c)

// make a new request with the token
formPostRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil)
formPostRequest.Header.Add("X-CSRFToken", token)
formPostRequest.Header.Add("Referer", "http://www.example.com/")

// and replace the old request
c.Request = revel.NewRequest(formPostRequest)

filters := []revel.Filter{
CsrfFilter,
func(c *revel.Controller, fc []revel.Filter) {
c.RenderHtml("{{ csrftoken . }}")
},
}

filters[0](c, filters)

if resp.Code == 403 {
t.Fatal("post with http header token should be allowed")
}
}

func TestFormPostWithToken(t *testing.T) {
resp := httptest.NewRecorder()
postRequest, _ := http.NewRequest("POST", "http://www.example.com/", nil)
c := revel.NewController(revel.NewRequest(postRequest), revel.NewResponse(resp))

c.Session = make(revel.Session)
secret, _ := NewSecret()

c.Session["csrfSecret"] = secret
token := NewToken(c)

// make a new request with the token
data := url.Values{}
data.Set("csrftoken", token)
formPostRequest, _ := http.NewRequest("POST", "http://www.example.com/", bytes.NewBufferString(data.Encode()))
formPostRequest.Header.Add("Content-Type", "application/x-www-form-urlencoded")
formPostRequest.Header.Add("Content-Length", strconv.Itoa(len(data.Encode())))
formPostRequest.Header.Add("Referer", "http://www.example.com/")

// and replace the old request
c.Request = revel.NewRequest(formPostRequest)

filters := []revel.Filter{
CsrfFilter,
func(c *revel.Controller, fc []revel.Filter) {
c.RenderHtml("{{ csrftoken . }}")
},
}

filters[0](c, filters)

if resp.Code == 403 {
t.Fatal("form post with token should be allowed")
}
}