/
oauth.go
164 lines (136 loc) · 4.53 KB
/
oauth.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
package oauth2
import (
	"context"
	crand "crypto/rand"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"math/rand"
	"net/http"
	"net/url"
	"strings"
	"time"
)
// init seeds the package-global math/rand source with the current time so
// that its pseudo-random sequence differs from one process run to the next.
//
// NOTE: math/rand output is predictable and must not be relied on for
// values that need to be unguessable.
func init() {
	now := time.Now()
	rand.Seed(now.UnixNano())
}
// Config describes a typical 3-legged OAuth2 flow, with both the
// client application information and the server's endpoint URLs.
type Config struct {
// ClientID is the application's ID. It is sent as "client_id" in both the
// authorization redirect and the token exchange.
ClientID string
// ClientSecret is the application's secret. It is sent as "client_secret"
// in the token exchange request body, so keep it out of logs.
ClientSecret string
// AuthURL is the provider's authorization endpoint that StartFlow
// redirects the user to.
AuthURL string
// TokenURL is the provider's endpoint that the authorization code is
// POSTed to in exchange for an access token.
TokenURL string
// RedirectURI is the callback URL the provider sends the user back to
// after authorization. It is sent as "redirect_uri" in both the
// authorization redirect and the token exchange.
RedirectURI string
// Scope specifies optional requested permissions, as a *space*-separated
// list. It is sent as "scope" in the authorization redirect.
Scope string
// Provider identifies the OAuth provider.
// NOTE(review): not read by StartFlow, Authenticate or getAccessToken;
// its type and purpose are defined elsewhere.
Provider Provider
}
// TokenInfo represents the credentials used to authorize
// the requests to access protected resources on the OAuth 2.0
// provider's backend. It is decoded from the JSON body returned by the
// token endpoint; JSON fields not listed here are ignored.
type TokenInfo struct {
// AccessToken is the token that authorizes and authenticates
// the requests. A token exchange without it is treated as an error.
AccessToken string `json:"access_token"`
// TokenType is the type of token, as reported in "token_type".
TokenType string `json:"token_type,omitempty"`
// Scope is the raw "scope" string the provider granted for this token;
// it is not parsed or normalized here.
Scope string `json:"scope,omitempty"`
}
// JSONError represents an oauth error response in json form.
// Only the "error" code is decoded; any other fields in the response
// (such as a description) are ignored.
type JSONError struct {
// Error is the provider's error code; empty when the response has none.
Error string `json:"error"`
}
const (
	// stateCookieName names the cookie that carries the CSRF state value
	// from StartFlow to Authenticate.
	stateCookieName = "oauthState"
	// defaultTimeout is the deadline getAccessToken derives its
	// token-exchange request context from.
	defaultTimeout = 5 * time.Second
)
// StartFlow begins the OAuth2 authorization-code flow by answering with a
// 302 Found redirect to cfg.AuthURL.
//
// A 15-character random state value is generated to protect against
// cross-site request forgery. It is sent to the provider as the "state"
// query parameter and also stored in the HttpOnly cookie stateCookieName
// (valid for 10 minutes), so Authenticate can check that the two match when
// the provider redirects back to cfg.RedirectURI.
func StartFlow(cfg Config, w http.ResponseWriter) {
	state := randStringBytes(15)

	values := url.Values{}
	values.Set("client_id", cfg.ClientID)
	values.Set("scope", cfg.Scope)
	values.Set("redirect_uri", cfg.RedirectURI)
	values.Set("response_type", "code")
	values.Set("state", state)

	http.SetCookie(w, &http.Cookie{
		Name:  stateCookieName,
		Value: state,
		// Without an explicit Path the browser scopes the cookie to the
		// directory of the URL StartFlow was served from, so it would not
		// be sent to a callback handler mounted under a different path.
		Path:     "/",
		MaxAge:   60 * 10, // 10 minutes
		HttpOnly: true,
	})

	// AuthURL may already carry query parameters of its own; append to
	// them rather than emitting a second '?'.
	sep := "?"
	if strings.Contains(cfg.AuthURL, "?") {
		sep = "&"
	}
	w.Header().Set("Location", cfg.AuthURL+sep+values.Encode())
	w.WriteHeader(http.StatusFound)
}
// Authenticate completes the OAuth2 flow when the provider redirects the
// user back to the application.
//
// It fails when the request carries an "error" value from the provider,
// when the "state" value does not equal the state cookie written by
// StartFlow (a missing cookie counts as a mismatch), or when no "code" is
// present. Otherwise the code is exchanged at cfg.TokenURL and the
// resulting token is returned.
func Authenticate(cfg Config, r *http.Request) (TokenInfo, error) {
	if providerErr := r.FormValue("error"); providerErr != "" {
		return TokenInfo{}, fmt.Errorf("error: %v", providerErr)
	}

	returnedState := r.FormValue("state")
	cookie, cookieErr := r.Cookie(stateCookieName)
	// cookie is only dereferenced once cookieErr is known to be nil.
	stateOK := cookieErr == nil && cookie.Value == returnedState
	if !stateOK {
		return TokenInfo{}, fmt.Errorf("error: oauth state param could not be verified")
	}

	authCode := r.FormValue("code")
	if len(authCode) == 0 {
		return TokenInfo{}, fmt.Errorf("error: no auth code provided")
	}
	return getAccessToken(cfg, returnedState, authCode)
}
// getAccessToken exchanges an authorization code for an access token by
// POSTing a form-encoded authorization_code grant to cfg.TokenURL, then
// decodes the JSON response into a TokenInfo.
//
// The whole exchange (sending the request and reading the response) runs
// under a context with the defaultTimeout deadline. The state argument is
// not sent to the token endpoint.
//
// An error is returned if:
//   - the request cannot be built or sent;
//   - the status is not 200 OK;
//   - the body names an OAuth "error";
//   - the body is not valid JSON;
//   - no access_token is present.
func getAccessToken(cfg Config, state, code string) (TokenInfo, error) {
	form := url.Values{}
	form.Set("client_id", cfg.ClientID)
	form.Set("client_secret", cfg.ClientSecret)
	form.Set("code", code)
	form.Set("redirect_uri", cfg.RedirectURI)
	form.Set("grant_type", "authorization_code")

	req, err := http.NewRequest("POST", cfg.TokenURL, strings.NewReader(form.Encode()))
	if err != nil {
		return TokenInfo{}, fmt.Errorf("error building token exchange request: %v", err)
	}
	ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
	defer cancel()
	// WithContext returns a shallow copy; the result must be kept or the
	// deadline is never attached to the request.
	req = req.WithContext(ctx)
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return TokenInfo{}, err
	}
	// Close the body on every path so the connection is released.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Token endpoints report failures with a non-200 status and, per
		// RFC 6749 §5.2, a JSON body naming the error; surface it if present.
		var statusErr JSONError
		if b, readErr := ioutil.ReadAll(resp.Body); readErr == nil &&
			json.Unmarshal(b, &statusErr) == nil && statusErr.Error != "" {
			return TokenInfo{}, fmt.Errorf("error: expected http status 200 on token exchange, but got %v (%q)", resp.StatusCode, statusErr.Error)
		}
		return TokenInfo{}, fmt.Errorf("error: expected http status 200 on token exchange, but got %v", resp.StatusCode)
	}

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return TokenInfo{}, fmt.Errorf("error reading token exchange response: %q", err)
	}

	// A 200 response may still carry an OAuth "error" field. This decode is
	// best effort: a body that does not match is parsed as a token below.
	var jsonError JSONError
	_ = json.Unmarshal(body, &jsonError)
	if jsonError.Error != "" {
		return TokenInfo{}, fmt.Errorf("error: got %q on token exchange", jsonError.Error)
	}

	var tokenInfo TokenInfo
	if err := json.Unmarshal(body, &tokenInfo); err != nil {
		return TokenInfo{}, fmt.Errorf("error on parsing oauth token: %v", err)
	}
	if tokenInfo.AccessToken == "" {
		return TokenInfo{}, fmt.Errorf("error: no access_token on token exchange")
	}
	return tokenInfo, nil
}
// letterBytes is the alphabet randStringBytes draws from: ASCII letters and
// digits (62 symbols).
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

// randStringBytes returns a string of n characters drawn uniformly from
// letterBytes using crypto/rand. It is used for the OAuth state value, which
// must be unguessable for CSRF protection, so a time-seeded math/rand
// source is not suitable here. A non-positive n yields "".
//
// It panics if the system's secure random source fails, because continuing
// with a predictable value would defeat the protection the caller relies on.
func randStringBytes(n int) string {
	if n <= 0 {
		return ""
	}
	// Largest multiple of len(letterBytes) that fits in a byte. Bytes at or
	// above it are rejected so every symbol is equally likely (no modulo bias).
	const limit = 256 - 256%len(letterBytes)

	out := make([]byte, 0, n)
	// Over-allocate slightly so one read usually suffices despite rejections.
	buf := make([]byte, n+n/4+1)
	for len(out) < n {
		if _, err := crand.Read(buf); err != nil {
			panic("oauth2: reading crypto/rand: " + err.Error())
		}
		for _, c := range buf {
			if int(c) >= limit {
				continue
			}
			out = append(out, letterBytes[int(c)%len(letterBytes)])
			if len(out) == n {
				break
			}
		}
	}
	return string(out)
}