Skip to content

Commit

Permalink
Device token api endpoint (dexidp#1)
Browse files Browse the repository at this point in the history
* Added /device/token handler with associated business logic and storage tests.

* Use crypto rand for user code
  • Loading branch information
justin-slowik authored and wolfeidau committed Mar 14, 2020
1 parent 5d306df commit 547942f
Show file tree
Hide file tree
Showing 10 changed files with 161 additions and 47 deletions.
82 changes: 57 additions & 25 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"path"
Expand Down Expand Up @@ -1434,9 +1433,8 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
case http.MethodPost:
err := r.ParseForm()
if err != nil {
message := "Could not parse Device Request body"
s.logger.Errorf("%s : %v", message, err)
respondWithError(w, message, err)
s.logger.Errorf("Could not parse Device Request body: %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound)
return
}

Expand All @@ -1450,7 +1448,11 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
deviceCode := storage.NewDeviceCode()

//make user code
userCode := storage.NewUserCode()
userCode, err := storage.NewUserCode()
if err != nil {
s.logger.Errorf("Error generating user code: %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
}

//make a pkce verification code
pkceCode := storage.NewID()
Expand All @@ -1469,24 +1471,21 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
}

if err := s.storage.CreateDeviceRequest(deviceReq); err != nil {
message := fmt.Sprintf("Failed to store device request %v", err)
s.logger.Errorf(message)
respondWithError(w, message, err)
s.logger.Errorf("Failed to store device request; %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return
}

//Store the device token
deviceToken := storage.DeviceToken{
DeviceCode: deviceCode,
Status: "pending",
Token: "",
Status: deviceTokenPending,
Expiry: expireTime,
}

if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
message := fmt.Sprintf("Failed to store device token %v", err)
s.logger.Errorf(message)
respondWithError(w, message, err)
s.logger.Errorf("Failed to store device token %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return
}

Expand All @@ -1503,20 +1502,53 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
enc.Encode(code)

default:
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
s.renderError(r, w, http.StatusBadRequest, "Invalid device code request type")
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
}
}

func respondWithError(w io.Writer, errorMessage string, err error) {
resp := struct {
Error string `json:"error"`
ErrorMessage string `json:"message"`
}{
Error: err.Error(),
ErrorMessage: errorMessage,
}
func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
switch r.Method {
case http.MethodPost:
err := r.ParseForm()
if err != nil {
message := "Could not parse Device Token Request body"
s.logger.Warnf("%s : %v", message, err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
return
}

deviceCode := r.Form.Get("device_code")
if deviceCode == "" {
message := "No device code received"
s.tokenErrHelper(w, errInvalidRequest, message, http.StatusBadRequest)
return
}

grantType := r.PostFormValue("grant_type")
if grantType != grantTypeDeviceCode {
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
return
}

//Grab the device token from the db
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
if err != nil || s.now().After(deviceToken.Expiry) {
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get device code: %v", err)
}
s.tokenErrHelper(w, errInvalidRequest, "Invalid or expired device code.", http.StatusBadRequest)
return
}

enc := json.NewEncoder(w)
enc.SetIndent("", " ")
enc.Encode(resp)
switch deviceToken.Status {
case deviceTokenPending:
s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized)
case deviceTokenComplete:
w.Write([]byte(deviceToken.Token))
}
default:
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
}
}
6 changes: 6 additions & 0 deletions server/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ const (
grantTypeAuthorizationCode = "authorization_code"
grantTypeRefreshToken = "refresh_token"
grantTypePassword = "password"
grantTypeDeviceCode = "device_code"
)

const (
Expand All @@ -130,6 +131,11 @@ const (
responseTypeIDToken = "id_token" // ID Token in url fragment
)

const (
deviceTokenPending = "authorization_pending"
deviceTokenComplete = "complete"
)

func parseScopes(scopes []string) connector.Scopes {
var s connector.Scopes
for _, scope := range scopes {
Expand Down
1 change: 1 addition & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
handleFunc("/auth", s.handleAuthorization)
handleFunc("/auth/{connector}", s.handleConnectorLogin)
handleFunc("/device/code", s.handleDeviceCode)
handleFunc("/device/token", s.handleDeviceToken)
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
// Strip the X-Remote-* headers to prevent security issues on
// misconfigured authproxy connector setups.
Expand Down
32 changes: 20 additions & 12 deletions storage/conformance/conformance.go
Original file line number Diff line number Diff line change
Expand Up @@ -833,8 +833,13 @@ func testGC(t *testing.T, s storage.Storage) {
t.Errorf("expected storage.ErrNotFound, got %v", err)
}

userCode, err := storage.NewUserCode()
if err != nil {
t.Errorf("Unexpected Error: %v", err)
}

d := storage.DeviceRequest{
UserCode: storage.NewUserCode(),
UserCode: userCode,
DeviceCode: storage.NewID(),
ClientID: "client1",
Scopes: []string{"openid", "email"},
Expand Down Expand Up @@ -892,22 +897,21 @@ func testGC(t *testing.T, s storage.Storage) {
t.Errorf("expected no device token garbage collection results, got %#v", result)
}
}
//if _, err := s.GetDeviceRequest(d.UserCode); err != nil {
// t.Errorf("expected to be able to get auth request after GC: %v", err)
//}
if _, err := s.GetDeviceToken(dt.DeviceCode); err != nil {
t.Errorf("expected to be able to get device token after GC: %v", err)
}
}
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
t.Errorf("garbage collection failed: %v", err)
} else if r.DeviceTokens != 1 {
t.Errorf("expected to garbage collect 1 device token, got %d", r.DeviceTokens)
}

//TODO add this code back once Getters are written for device tokens
//if _, err := s.GetDeviceRequest(d.UserCode); err == nil {
// t.Errorf("expected device request to be GC'd")
//} else if err != storage.ErrNotFound {
// t.Errorf("expected storage.ErrNotFound, got %v", err)
//}
if _, err := s.GetDeviceToken(dt.DeviceCode); err == nil {
t.Errorf("expected device token to be GC'd")
} else if err != storage.ErrNotFound {
t.Errorf("expected storage.ErrNotFound, got %v", err)
}
}

// testTimezones tests that backends either fully support timezones or
Expand Down Expand Up @@ -957,8 +961,12 @@ func testTimezones(t *testing.T, s storage.Storage) {
}

func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
userCode, err := storage.NewUserCode()
if err != nil {
panic(err)
}
d1 := storage.DeviceRequest{
UserCode: storage.NewUserCode(),
UserCode: userCode,
DeviceCode: storage.NewID(),
ClientID: "client1",
Scopes: []string{"openid", "email"},
Expand All @@ -971,7 +979,7 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
}

// Attempt to create same DeviceRequest twice.
err := s.CreateDeviceRequest(d1)
err = s.CreateDeviceRequest(d1)
mustBeErrAlreadyExists(t, "device request", err)

//No manual deletes for device requests, will be handled by garbage collection routines
Expand Down
7 changes: 7 additions & 0 deletions storage/etcd/etcd.go
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,13 @@ func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
return c.txnCreate(ctx, keyID(deviceRequestPrefix, t.DeviceCode), fromStorageDeviceToken(t))
}

func (c *conn) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
defer cancel()
err = c.getKey(ctx, keyID(deviceTokenPrefix, deviceCode), &t)
return t, err
}

func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken, err error) {
res, err := c.db.Get(ctx, deviceTokenPrefix, clientv3.WithPrefix())
if err != nil {
Expand Down
8 changes: 8 additions & 0 deletions storage/kubernetes/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -641,3 +641,11 @@ func (cli *client) CreateDeviceRequest(d storage.DeviceRequest) error {
func (cli *client) CreateDeviceToken(t storage.DeviceToken) error {
return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t))
}

func (cli *client) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
var token DeviceToken
if err := cli.get(resourceDeviceToken, deviceCode, &token); err != nil {
return storage.DeviceToken{}, err
}
return toStorageDeviceToken(token), nil
}
9 changes: 9 additions & 0 deletions storage/kubernetes/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -741,3 +741,12 @@ func (cli *client) fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
}
return req
}

func toStorageDeviceToken(t DeviceToken) storage.DeviceToken {
return storage.DeviceToken{
DeviceCode: t.ObjectMeta.Name,
Status: t.Status,
Token: t.Token,
Expiry: t.Expiry,
}
}
11 changes: 11 additions & 0 deletions storage/memory/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -503,3 +503,14 @@ func (s *memStorage) CreateDeviceToken(t storage.DeviceToken) (err error) {
})
return
}

func (s *memStorage) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) {
s.tx(func() {
var ok bool
if t, ok = s.deviceTokens[deviceCode]; !ok {
err = storage.ErrNotFound
return
}
})
return
}
22 changes: 22 additions & 0 deletions storage/sql/crud.go
Original file line number Diff line number Diff line change
Expand Up @@ -922,3 +922,25 @@ func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
}
return nil
}

func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
return getDeviceToken(c, deviceCode)
}

func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err error) {
err = q.QueryRow(`
select
status, token, expiry
from device_token where device_code = $1;
`, deviceCode).Scan(
&a.Status, &a.Token, &a.Expiry,
)
if err != nil {
if err == sql.ErrNoRows {
return a, storage.ErrNotFound
}
return a, fmt.Errorf("select device token: %v", err)
}
a.DeviceCode = deviceCode
return a, nil
}
30 changes: 20 additions & 10 deletions storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"encoding/base32"
"errors"
"io"
mrand "math/rand"
"math/big"
"strings"
"time"

Expand All @@ -25,6 +25,9 @@ var (
// TODO(ericchiang): refactor ID creation onto the storage.
var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567")

//Valid characters for user codes
const validUserCharacters = "BCDFGHJKLMNPQRSTVWXZ"

// NewDeviceCode returns a 32 char alphanumeric cryptographically secure string
func NewDeviceCode() string {
return newSecureID(32)
Expand Down Expand Up @@ -79,6 +82,7 @@ type Storage interface {
GetPassword(email string) (Password, error)
GetOfflineSessions(userID string, connID string) (OfflineSessions, error)
GetConnector(id string) (Connector, error)
GetDeviceToken(deviceCode string) (DeviceToken, error)

ListClients() ([]Client, error)
ListRefreshTokens() ([]RefreshToken, error)
Expand Down Expand Up @@ -357,18 +361,24 @@ type Keys struct {
NextRotation time.Time
}

func NewUserCode() string {
mrand.Seed(time.Now().UnixNano())
return randomString(4) + "-" + randomString(4)
// NewUserCode returns a randomized 8 character user code for the device flow.
// No vowels are included to prevent accidental generation of words
func NewUserCode() (string, error) {
code, err := randomString(8)
if err != nil {
return "", err
}
return code[:4] + "-" + code[4:], nil
}

func randomString(n int) string {
var letter = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
b := make([]rune, n)
for i := range b {
b[i] = letter[mrand.Intn(len(letter))]
func randomString(n int) (string, error) {
v := big.NewInt(int64(len(validUserCharacters)))
bytes := make([]byte, n)
for i := 0; i < n; i++ {
c, _ := rand.Int(rand.Reader, v)
bytes[i] = validUserCharacters[c.Int64()]
}
return string(b)
return string(bytes), nil
}

//DeviceRequest represents an OIDC device authorization request. It holds the state of a device request until the user
Expand Down

0 comments on commit 547942f

Please sign in to comment.