Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve API sanity #20

Merged
merged 1 commit into from Jan 18, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
461 changes: 318 additions & 143 deletions README.md

Large diffs are not rendered by default.

64 changes: 1 addition & 63 deletions access_request.go
@@ -1,53 +1,15 @@
package fosite

import (
"github.com/ory-am/fosite/client"
"time"
)

type AccessRequester interface {
// GetGrantType returns the requests grant type.
GetGrantType() string

// GetClient returns the requests client.
GetClient() client.Client

// GetRequestedAt returns the time the request was created.
GetRequestedAt() time.Time

// GetScopes returns the request's scopes.
GetScopes() Arguments

// SetScopes sets the request's scopes.
SetScopes(Arguments)

// GetGrantScopes returns all granted scopes.
GetGrantedScopes() Arguments

// GrantScope marks a request's scope as granted.
GrantScope(string)

// SetGrantTypeHandled marks a grant type as handled indicating that the response type is supported.
SetGrantTypeHandled(string)

// DidHandleGrantType returns if the requested grant type has been handled correctly.
DidHandleGrantType() bool
}

type AccessRequest struct {
GrantType string
HandledGrantType []string
RequestedAt time.Time
Client client.Client
Scopes Arguments
GrantedScopes []string
}

func NewAccessRequest() *AccessRequest {
return &AccessRequest{
RequestedAt: time.Now(),
HandledGrantType: []string{},
}
Request
}

func (a *AccessRequest) DidHandleGrantType() bool {
Expand All @@ -61,27 +23,3 @@ func (a *AccessRequest) SetGrantTypeHandled(name string) {
func (a *AccessRequest) GetGrantType() string {
return a.GrantType
}

func (a *AccessRequest) GetRequestedAt() time.Time {
return a.RequestedAt
}

func (a *AccessRequest) GetClient() client.Client {
return a.Client
}

func (a *AccessRequest) GetScopes() Arguments {
return a.Scopes
}

func (a *AccessRequest) SetScopes(s Arguments) {
a.Scopes = s
}

func (a *AccessRequest) GetGrantedScopes() Arguments {
return Arguments(a.GrantedScopes)
}

func (a *AccessRequest) GrantScope(scope string) {
a.GrantedScopes = append(a.GrantedScopes, scope)
}
50 changes: 30 additions & 20 deletions access_request_handler.go
Expand Up @@ -5,6 +5,7 @@ import (
"golang.org/x/net/context"
"net/http"
"strings"
"time"
)

// Implements
Expand Down Expand Up @@ -33,59 +34,68 @@ import (
// client MUST authenticate with the authorization server as described
// in Section 3.2.1.
func (f *Fosite) NewAccessRequest(ctx context.Context, r *http.Request, session interface{}) (AccessRequester, error) {
ar := NewAccessRequest()
accessRequest := &AccessRequest{
Request: Request{
Scopes: Arguments{},
Session: session,
RequestedAt: time.Now(),
},
}

if r.Method != "POST" {
return ar, errors.New(ErrInvalidRequest)
return accessRequest, errors.New(ErrInvalidRequest)
}

if f.RequiredScope == "" {
f.RequiredScope = DefaultRequiredScopeName
}

if err := r.ParseForm(); err != nil {
return ar, errors.New(ErrInvalidRequest)
return accessRequest, errors.New(ErrInvalidRequest)
}

accessRequest.Form = r.PostForm

if session == nil {
return ar, errors.New("Session must not be nil")
return accessRequest, errors.New("Session must not be nil")
}

ar.Scopes = removeEmpty(strings.Split(r.Form.Get("scope"), " "))
ar.GrantType = r.Form.Get("grant_type")
if ar.GrantType == "" {
return ar, errors.New(ErrInvalidRequest)
accessRequest.Scopes = removeEmpty(strings.Split(r.Form.Get("scope"), " "))
accessRequest.GrantType = r.Form.Get("grant_type")
if accessRequest.GrantType == "" {
return accessRequest, errors.New(ErrInvalidRequest)
}

clientID, clientSecret, ok := r.BasicAuth()
if !ok {
return ar, errors.New(ErrInvalidRequest)
return accessRequest, errors.New(ErrInvalidRequest)
}

client, err := f.Store.GetClient(clientID)
if err != nil {
return ar, errors.New(ErrInvalidClient)
return accessRequest, errors.New(ErrInvalidClient)
}

// Enforce client authentication
if err := f.Hasher.Compare(client.GetHashedSecret(), []byte(clientSecret)); err != nil {
return ar, errors.New(ErrInvalidClient)
return accessRequest, errors.New(ErrInvalidClient)
}
ar.Client = client
accessRequest.Client = client

for _, loader := range f.TokenEndpointHandlers {
if err := loader.ValidateTokenEndpointRequest(ctx, r, ar, session); err != nil {
return ar, err
if err := loader.ValidateTokenEndpointRequest(ctx, r, accessRequest); err != nil {
return accessRequest, err
}
}

if !ar.DidHandleGrantType() {
return ar, errors.New(ErrUnsupportedGrantType)
if !accessRequest.DidHandleGrantType() {
return accessRequest, errors.New(ErrUnsupportedGrantType)
}

if !ar.GetScopes().Has(f.RequiredScope) {
return ar, errors.New(ErrInvalidScope)
if !accessRequest.GetScopes().Has(f.RequiredScope) {
return accessRequest, errors.New(ErrInvalidScope)
}

ar.GrantScope(f.RequiredScope)
return ar, nil
accessRequest.GrantScope(f.RequiredScope)
return accessRequest, nil
}
14 changes: 8 additions & 6 deletions access_request_handler_test.go
Expand Up @@ -123,7 +123,7 @@ func TestNewAccessRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil)
client.EXPECT().GetHashedSecret().Return([]byte("foo"))
hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handler.EXPECT().ValidateTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(ErrServerError)
handler.EXPECT().ValidateTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any()).Return(ErrServerError)
},
handlers: TokenEndpointHandlers{"a": handler},
},
Expand All @@ -140,7 +140,7 @@ func TestNewAccessRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil)
client.EXPECT().GetHashedSecret().Return([]byte("foo"))
hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handler.EXPECT().ValidateTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
handler.EXPECT().ValidateTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
},
handlers: TokenEndpointHandlers{"a": handler},
},
Expand All @@ -157,7 +157,7 @@ func TestNewAccessRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil)
client.EXPECT().GetHashedSecret().Return([]byte("foo"))
hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handler.EXPECT().ValidateTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ context.Context, _ *http.Request, a AccessRequester, _ interface{}) {
handler.EXPECT().ValidateTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ context.Context, _ *http.Request, a AccessRequester) {
a.SetGrantTypeHandled("bar")
}).Return(nil)
},
Expand All @@ -175,7 +175,7 @@ func TestNewAccessRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil)
client.EXPECT().GetHashedSecret().Return([]byte("foo"))
hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handler.EXPECT().ValidateTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ context.Context, _ *http.Request, a AccessRequester, _ interface{}) {
handler.EXPECT().ValidateTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ context.Context, _ *http.Request, a AccessRequester) {
a.SetGrantTypeHandled("foo")
a.SetScopes([]string{"asdfasdf"})
}).Return(nil)
Expand All @@ -195,7 +195,7 @@ func TestNewAccessRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil)
client.EXPECT().GetHashedSecret().Return([]byte("foo"))
hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handler.EXPECT().ValidateTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ context.Context, _ *http.Request, a AccessRequester, _ interface{}) {
handler.EXPECT().ValidateTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ context.Context, _ *http.Request, a AccessRequester) {
a.SetGrantTypeHandled("foo")
a.SetScopes([]string{DefaultRequiredScopeName})
}).Return(nil)
Expand All @@ -204,7 +204,9 @@ func TestNewAccessRequest(t *testing.T) {
expect: &AccessRequest{
GrantType: "foo",
HandledGrantType: []string{"foo"},
Client: client,
Request: Request{
Client: client,
},
},
},
} {
Expand Down
2 changes: 1 addition & 1 deletion access_request_test.go
Expand Up @@ -7,7 +7,7 @@ import (
)

func TestAccessRequest(t *testing.T) {
ar := NewAccessRequest()
ar := &AccessRequest{}
ar.GrantType = "foobar"
ar.Client = &client.SecureClient{}
ar.GrantScope("foo")
Expand Down
16 changes: 0 additions & 16 deletions access_response.go
@@ -1,21 +1,5 @@
package fosite

type AccessResponder interface {
SetExtra(key string, value interface{})

GetExtra(key string) interface{}

SetAccessToken(string)

SetTokenType(string)

GetAccessToken() string

GetTokenType() string

ToMap() map[string]interface{}
}

func NewAccessResponse() AccessResponder {
return &AccessResponse{
Extra: map[string]interface{}{},
Expand Down
4 changes: 2 additions & 2 deletions access_response_handler.go
Expand Up @@ -6,13 +6,13 @@ import (
"net/http"
)

func (f *Fosite) NewAccessResponse(ctx context.Context, req *http.Request, requester AccessRequester, session interface{}) (AccessResponder, error) {
func (f *Fosite) NewAccessResponse(ctx context.Context, req *http.Request, requester AccessRequester) (AccessResponder, error) {
var err error
var tk TokenEndpointHandler

response := NewAccessResponse()
for _, tk = range f.TokenEndpointHandlers {
if err = tk.HandleTokenEndpointRequest(ctx, req, requester, response, session); err != nil {
if err = tk.HandleTokenEndpointRequest(ctx, req, requester, response); err != nil {
return nil, errors.Wrap(err, 1)
}
}
Expand Down
10 changes: 5 additions & 5 deletions access_response_handler_test.go
Expand Up @@ -30,21 +30,21 @@ func TestNewAccessResponse(t *testing.T) {
},
{
mock: func() {
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(ErrServerError)
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(ErrServerError)
},
handlers: TokenEndpointHandlers{"a": handler},
expectErr: ErrServerError,
},
{
mock: func() {
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
},
handlers: TokenEndpointHandlers{"a": handler},
expectErr: ErrUnsupportedGrantType,
},
{
mock: func() {
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ context.Context, _ *http.Request, _ AccessRequester, resp AccessResponder, _param4 interface{}) {
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ context.Context, _ *http.Request, _ AccessRequester, resp AccessResponder) {
resp.SetAccessToken("foo")
}).Return(nil)
},
Expand All @@ -53,7 +53,7 @@ func TestNewAccessResponse(t *testing.T) {
},
{
mock: func() {
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ context.Context, _ *http.Request, _ AccessRequester, resp AccessResponder, _param4 interface{}) {
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ context.Context, _ *http.Request, _ AccessRequester, resp AccessResponder) {
resp.SetAccessToken("foo")
resp.SetTokenType("bar")
}).Return(nil)
Expand All @@ -68,7 +68,7 @@ func TestNewAccessResponse(t *testing.T) {
} {
f.TokenEndpointHandlers = c.handlers
c.mock()
ar, err := f.NewAccessResponse(nil, nil, nil, struct{}{})
ar, err := f.NewAccessResponse(nil, nil, nil)
assert.True(t, errors.Is(c.expectErr, err), "%d", k)
assert.Equal(t, ar, c.expect)
t.Logf("Passed test case %d", k)
Expand Down