-
-
Notifications
You must be signed in to change notification settings - Fork 75
/
auth.go
121 lines (103 loc) · 4.36 KB
/
auth.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
package tdameritrade
import (
	"context"
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"errors"
	"fmt"
	"net/http"
	"strings"

	"golang.org/x/oauth2"
)
var (
	// ErrNoCode is returned by FinishOAuth2Flow when the request from TD
	// Ameritrade has no "code" query parameter, or it is empty.
	ErrNoCode = errors.New("missing code in request from TD Ameritrade")
	// ErrNoState is returned by FinishOAuth2Flow when the request from TD
	// Ameritrade has no "state" query parameter (or it is empty), which may
	// indicate a CSRF attempt. It is also returned when the PersistentStore
	// yields an empty expected state.
	ErrNoState = errors.New("missing state in request from TD Ameritrade")
)
// PersistentStore keeps the data an Authenticator needs between HTTP requests:
// the OAuth2 state generated by StartOAuth2Flow and the token obtained by
// FinishOAuth2Flow. Because every method receives the request (and the writers
// receive the ResponseWriter), implementations are free to keep this data in
// cookies, JWTs, a server-side session, or anywhere else.
//
// GetState must return exactly the value StoreState saved for the same user;
// otherwise FinishOAuth2Flow rejects the login.
type PersistentStore interface {
	// StoreToken saves the OAuth2 token obtained for the user behind req.
	StoreToken(token *oauth2.Token, w http.ResponseWriter, req *http.Request) error
	// GetToken loads the token previously saved with StoreToken.
	GetToken(req *http.Request) (*oauth2.Token, error)
	// StoreState saves the OAuth2 state value generated for the user behind req.
	StoreState(state string, w http.ResponseWriter, req *http.Request) error
	// GetState loads the state value previously saved with StoreState.
	// An empty result is treated by FinishOAuth2Flow as missing state.
	GetState(req *http.Request) (string, error)
}
// Authenticator runs TD Ameritrade's OAuth2 login flow.
//
// When the flow starts it generates a random state value. When the flow
// finishes it checks that TD Ameritrade sent the same value back, which
// protects users from CSRF attacks.
//
// Prefer NewAuthenticator to building this struct by hand. TD Ameritrade
// expects client IDs in the form clientid@AMER.OAUTHAP, a requirement its
// documentation does not make obvious, and NewAuthenticator applies that suffix
// for you.
// See https://developer.tdameritrade.com/content/authentication-faq
type Authenticator struct {
	// Store persists the OAuth2 state and token between requests.
	Store PersistentStore
	// OAuth2 is the OAuth2 client configuration used to talk to TD Ameritrade.
	OAuth2 oauth2.Config
	// AuthOpts are forwarded to both AuthCodeURL and Exchange.
	AuthOpts []oauth2.AuthCodeOption
}
// NewAuthenticator returns an Authenticator that persists OAuth2 state and
// tokens through store and talks to TD Ameritrade using config.
//
// TD Ameritrade requires client IDs of the form clientid@AMER.OAUTHAP, which is
// easy to miss in its documentation. NewAuthenticator appends the
// "@AMER.OAUTHAP" suffix to config.ClientID unless it is already present. This
// saves callers hours of frustration and lets them pass either the bare ID or
// the fully qualified one without ending up with a doubled suffix.
//
// opts are forwarded to AuthCodeURL when the flow starts and to Exchange when
// it finishes.
func NewAuthenticator(store PersistentStore, config oauth2.Config, opts ...oauth2.AuthCodeOption) *Authenticator {
	// The parameter is deliberately not named "oauth2": that would shadow the
	// package inside this function. config is a copy, so editing ClientID does
	// not touch the caller's value.
	const clientIDSuffix = "@AMER.OAUTHAP"
	if !strings.HasSuffix(config.ClientID, clientIDSuffix) {
		config.ClientID += clientIDSuffix
	}
	return &Authenticator{
		Store:    store,
		OAuth2:   config,
		AuthOpts: opts,
	}
}
// AuthenticatedClient builds a Client for the user behind req. The Client
// authenticates with the OAuth2 token held in the Store. If the Store cannot
// supply a token, its error is returned unchanged.
func (a *Authenticator) AuthenticatedClient(ctx context.Context, req *http.Request) (*Client, error) {
	tok, err := a.Store.GetToken(req)
	if err != nil {
		return nil, err
	}
	return NewClient(a.OAuth2.Client(ctx, tok))
}
// StartOAuth2Flow begins a TD Ameritrade login. It generates a random state
// value, saves it through the Store, and returns TD Ameritrade's authorization
// URL. Redirect the user to that URL.
//
// The state is generated here, not by callers, on purpose. Experience shows
// that callers often do not know what OAuth2 state is for and leave themselves
// open to CSRF. Callers only need to persist the value and hand it back when
// GetState is called.
func (a *Authenticator) StartOAuth2Flow(w http.ResponseWriter, req *http.Request) (string, error) {
	// 32 bytes from crypto/rand, encoded as unpadded URL-safe base64 so the
	// value can be placed in a query string as-is.
	var nonce [32]byte
	if _, err := rand.Read(nonce[:]); err != nil {
		return "", err
	}
	state := base64.RawURLEncoding.EncodeToString(nonce[:])

	if err := a.Store.StoreState(state, w, req); err != nil {
		return "", err
	}
	return a.OAuth2.AuthCodeURL(state, a.AuthOpts...), nil
}
// FinishOAuth2Flow completes the login of a user redirected back from TD
// Ameritrade.
//
// It proceeds in order:
//   - reads the "code" and "state" query parameters from req;
//   - checks that state matches the value saved by StartOAuth2Flow, which
//     prevents CSRF;
//   - exchanges the code for a token and saves the token through the Store;
//   - returns an authenticated Client.
//
// It returns ErrNoCode if the code is missing or empty. It returns ErrNoState
// if the state is missing or empty, or if the Store yields an empty expected
// state. It returns a non-nil error if the states differ, or if the Store or
// the token exchange fails.
func (a *Authenticator) FinishOAuth2Flow(ctx context.Context, w http.ResponseWriter, req *http.Request) (*Client, error) {
	// Parse the query once. Get returns "" for an absent key, a key with no
	// values, and an empty first value, which are exactly the "missing" cases.
	query := req.URL.Query()
	code := query.Get("code")
	if code == "" {
		return nil, ErrNoCode
	}
	state := query.Get("state")
	if state == "" {
		return nil, ErrNoState
	}

	expectedState, err := a.Store.GetState(req)
	if err != nil {
		return nil, err
	}
	// Sanity check: GetState should never return an empty string. Refusing to
	// continue stops a misconfigured store from silently disabling CSRF
	// protection.
	if expectedState == "" {
		return nil, ErrNoState
	}
	// Compare in constant time so response timing does not reveal how much of
	// the expected state matched. Keep expectedState out of the error: it is a
	// secret CSRF token, and errors are often logged or shown to users.
	if subtle.ConstantTimeCompare([]byte(state), []byte(expectedState)) != 1 {
		return nil, fmt.Errorf("invalid state in request from TD Ameritrade: got %q", state)
	}

	token, err := a.OAuth2.Exchange(ctx, code, a.AuthOpts...)
	if err != nil {
		return nil, err
	}
	if err := a.Store.StoreToken(token, w, req); err != nil {
		return nil, err
	}
	return NewClient(a.OAuth2.Client(ctx, token))
}