-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathcsrf.go
139 lines (127 loc) · 3.98 KB
/
csrf.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
// Copyright 2015 The Vanadium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package util
import (
"bytes"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
"net/http"
"time"
v23 "v.io/v23"
"v.io/v23/context"
"v.io/v23/vom"
)
// cookieLen is the number of random bytes in a CSRF cookie value, measured
// before base64 encoding; decoded cookie values of any other length are
// rejected.
const cookieLen = 16
// CSRFCop implements utilities for generating and validating tokens for
// cross-site-request-forgery prevention (also called XSRF).
//
// Each token is bound to a per-client random cookie. Tokens are minted and
// checked with the principal obtained from the cop's context.
type CSRFCop struct {
	// ctx supplies the principal used to create and decode tokens (via
	// v23.GetPrincipal) and the logger used when cookies are regenerated.
	ctx *context.T
}
// NewCSRFCop returns a CSRFCop that creates and validates tokens with the
// principal attached to ctx, and logs through ctx.
func NewCSRFCop(ctx *context.T) *CSRFCop {
	cop := new(CSRFCop)
	cop.ctx = ctx
	return cop
}
// NewToken creates an anti-cross-site-request-forgery (CSRF, aka XSRF) token
// bound to the caller's CSRF cookie. The token optionally carries data that
// ValidateToken can later extract.
//
// The cookie named cookieName is read from r. If it is missing or malformed,
// MaybeSetCookie generates a fresh one and sets it on w. The token is
// produced by NewMacaroon, using the principal from the cop's context, over
// sha256(cookie value) followed by the vom encoding of data. The vom part is
// omitted when data is nil.
//
// It returns an error if the cookie, the data encoding, or the macaroon could
// not be produced. The returned token is empty whenever err is non-nil.
func (c *CSRFCop) NewToken(w http.ResponseWriter, r *http.Request, cookieName string, data interface{}) (string, error) {
	cookieValue, err := c.MaybeSetCookie(w, r, cookieName)
	if err != nil {
		// %w keeps the underlying cause inspectable with errors.Is/As.
		return "", fmt.Errorf("bad cookie: %w", err)
	}
	var encData []byte
	if data != nil {
		if encData, err = vom.Encode(data); err != nil {
			return "", err
		}
	}
	hash := sha256.Sum256(cookieValue)
	mac, err := NewMacaroon(v23.GetPrincipal(c.ctx), append(hash[:], encData...))
	if err != nil {
		// Do not hand back a partially built token alongside an error.
		return "", err
	}
	return string(mac), nil
}
// ValidateToken checks that token is a valid CSRF token for the cookie named
// cookieName in req. If decoded is non-nil, it also vom-decodes the data that
// NewToken bound to the token into decoded.
//
// It returns an error if:
//   - the cookie is missing or malformed,
//   - Macaroon(token).Decode fails under the cop's principal,
//   - the token was not created for this cookie, or
//   - the bound data cannot be decoded.
//
// A missing cookie yields req.Cookie's error unchanged (http.ErrNoCookie), so
// callers may still compare against it directly. Errors are meant for the
// server process only and should not be shown to end users.
func (c *CSRFCop) ValidateToken(token string, req *http.Request, cookieName string, decoded interface{}) error {
	cookie, err := req.Cookie(cookieName)
	if err != nil {
		return err
	}
	cookieValue, err := decodeCookieValue(cookie.Value)
	if err != nil {
		// Keep the decode failure (bad base64 / wrong length) for diagnostics.
		return fmt.Errorf("invalid cookie: %w", err)
	}
	encodedInput, err := Macaroon(token).Decode(v23.GetPrincipal(c.ctx))
	if err != nil {
		return err
	}
	// NewToken builds the payload as sha256(cookie value) followed by the
	// vom-encoded data.
	if len(encodedInput) < sha256.Size {
		return fmt.Errorf("invalid token data: too short")
	}
	hash := sha256.Sum256(cookieValue)
	if !bytes.Equal(hash[:], encodedInput[:sha256.Size]) {
		return fmt.Errorf("invalid token data")
	}
	if decoded != nil {
		if err := vom.Decode(encodedInput[sha256.Size:], decoded); err != nil {
			return fmt.Errorf("invalid token data: %w", err)
		}
	}
	return nil
}
// MaybeSetCookie returns the decoded value of the CSRF cookie named
// cookieName from req.
//
// If the cookie is absent or its value is malformed, it generates a fresh
// random value and sets it on w. It also adds the new cookie to req, so later
// calls made while handling the same request can reuse it instead of minting
// another. It returns an error only if a new cookie could not be generated.
func (c *CSRFCop) MaybeSetCookie(w http.ResponseWriter, req *http.Request, cookieName string) ([]byte, error) {
	cookie, err := req.Cookie(cookieName)
	switch err {
	case nil:
		// Use a distinct error variable. Previously `err` was shadowed here,
		// so the log below printed the nil lookup error instead of the
		// decode failure.
		v, decodeErr := decodeCookieValue(cookie.Value)
		if decodeErr == nil {
			return v, nil
		}
		c.ctx.Infof("Invalid cookie: %#v, err: %v. Regenerating one.", cookie, decodeErr)
	case http.ErrNoCookie:
		// Intentionally blank: Cookie will be generated below.
	default:
		c.ctx.Infof("Error decoding cookie %q in request: %v. Regenerating one.", cookieName, err)
	}
	cookie, v := newCookie(c.ctx, cookieName)
	if cookie == nil || v == nil {
		return nil, fmt.Errorf("failed to create cookie")
	}
	http.SetCookie(w, cookie)
	// We need to add the cookie to the request also to prevent repeatedly resetting cookies on multiple
	// calls from the same request.
	// NOTE(review): AddCookie appends, and req.Cookie returns the first
	// match. If req already carried a malformed cookie with this name, a
	// later call will presumably still see the stale one and regenerate
	// again — confirm.
	req.AddCookie(cookie)
	return v, nil
}
// newCookie generates a fresh CSRF cookie named cookieName. Its value is
// cookieLen bytes read from crypto/rand, base64url-encoded via b64encode.
//
// It returns the cookie together with the raw (unencoded) random bytes. If
// randomness cannot be read, it logs the failure through ctx and returns
// (nil, nil).
//
// The cookie expires 24 hours from now, is HttpOnly and Secure, and is scoped
// to path "/".
func newCookie(ctx *context.T, cookieName string) (*http.Cookie, []byte) {
	raw := make([]byte, cookieLen)
	_, err := rand.Read(raw)
	if err != nil {
		ctx.Errorf("newCookie failed: %v", err)
		return nil, nil
	}
	encoded := b64encode(raw)
	expiry := time.Now().Add(24 * time.Hour)
	fresh := &http.Cookie{
		Name:     cookieName,
		Value:    encoded,
		Path:     "/",
		Expires:  expiry,
		Secure:   true,
		HttpOnly: true,
	}
	return fresh, raw
}
// decodeCookieValue base64url-decodes a CSRF cookie value with b64decode. It
// accepts the result only if it is exactly cookieLen bytes long, the size
// newCookie generates.
func decodeCookieValue(v string) ([]byte, error) {
	raw, err := b64decode(v)
	switch {
	case err != nil:
		return nil, err
	case len(raw) != cookieLen:
		return nil, fmt.Errorf("invalid cookie length[%d]", len(raw))
	default:
		return raw, nil
	}
}
// b64encode encodes b as URL-safe, padded base64 (base64.URLEncoding).
func b64encode(b []byte) string {
	enc := base64.URLEncoding
	return enc.EncodeToString(b)
}

// b64decode reverses b64encode. It rejects input that is not valid padded,
// URL-safe base64.
func b64decode(s string) ([]byte, error) {
	enc := base64.URLEncoding
	return enc.DecodeString(s)
}