Skip to content

Commit

Permalink
all: api refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed Jan 17, 2016
1 parent e9e9390 commit d936c91
Show file tree
Hide file tree
Showing 62 changed files with 1,540 additions and 1,117 deletions.
461 changes: 318 additions & 143 deletions README.md

Large diffs are not rendered by default.

64 changes: 1 addition & 63 deletions access_request.go
Original file line number Diff line number Diff line change
@@ -1,53 +1,15 @@
package fosite

import (
"github.com/ory-am/fosite/client"
"time"
)

type AccessRequester interface {
// GetGrantType returns the requests grant type.
GetGrantType() string

// GetClient returns the requests client.
GetClient() client.Client

// GetRequestedAt returns the time the request was created.
GetRequestedAt() time.Time

// GetScopes returns the request's scopes.
GetScopes() Arguments

// SetScopes sets the request's scopes.
SetScopes(Arguments)

// GetGrantScopes returns all granted scopes.
GetGrantedScopes() Arguments

// GrantScope marks a request's scope as granted.
GrantScope(string)

// SetGrantTypeHandled marks a grant type as handled indicating that the response type is supported.
SetGrantTypeHandled(string)

// DidHandleGrantType returns if the requested grant type has been handled correctly.
DidHandleGrantType() bool
}

type AccessRequest struct {
GrantType string
HandledGrantType []string
RequestedAt time.Time
Client client.Client
Scopes Arguments
GrantedScopes []string
}

func NewAccessRequest() *AccessRequest {
return &AccessRequest{
RequestedAt: time.Now(),
HandledGrantType: []string{},
}
Request
}

func (a *AccessRequest) DidHandleGrantType() bool {
Expand All @@ -61,27 +23,3 @@ func (a *AccessRequest) SetGrantTypeHandled(name string) {
func (a *AccessRequest) GetGrantType() string {
return a.GrantType
}

func (a *AccessRequest) GetRequestedAt() time.Time {
return a.RequestedAt
}

func (a *AccessRequest) GetClient() client.Client {
return a.Client
}

func (a *AccessRequest) GetScopes() Arguments {
return a.Scopes
}

func (a *AccessRequest) SetScopes(s Arguments) {
a.Scopes = s
}

func (a *AccessRequest) GetGrantedScopes() Arguments {
return Arguments(a.GrantedScopes)
}

func (a *AccessRequest) GrantScope(scope string) {
a.GrantedScopes = append(a.GrantedScopes, scope)
}
50 changes: 30 additions & 20 deletions access_request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"golang.org/x/net/context"
"net/http"
"strings"
"time"
)

// Implements
Expand Down Expand Up @@ -33,59 +34,68 @@ import (
// client MUST authenticate with the authorization server as described
// in Section 3.2.1.
func (f *Fosite) NewAccessRequest(ctx context.Context, r *http.Request, session interface{}) (AccessRequester, error) {
ar := NewAccessRequest()
accessRequest := &AccessRequest{
Request: Request{
Scopes: Arguments{},
Session: session,
RequestedAt: time.Now(),
},
}

if r.Method != "POST" {
return ar, errors.New(ErrInvalidRequest)
return accessRequest, errors.New(ErrInvalidRequest)
}

if f.RequiredScope == "" {
f.RequiredScope = DefaultRequiredScopeName
}

if err := r.ParseForm(); err != nil {
return ar, errors.New(ErrInvalidRequest)
return accessRequest, errors.New(ErrInvalidRequest)
}

accessRequest.Form = r.PostForm

if session == nil {
return ar, errors.New("Session must not be nil")
return accessRequest, errors.New("Session must not be nil")
}

ar.Scopes = removeEmpty(strings.Split(r.Form.Get("scope"), " "))
ar.GrantType = r.Form.Get("grant_type")
if ar.GrantType == "" {
return ar, errors.New(ErrInvalidRequest)
accessRequest.Scopes = removeEmpty(strings.Split(r.Form.Get("scope"), " "))
accessRequest.GrantType = r.Form.Get("grant_type")
if accessRequest.GrantType == "" {
return accessRequest, errors.New(ErrInvalidRequest)
}

clientID, clientSecret, ok := r.BasicAuth()
if !ok {
return ar, errors.New(ErrInvalidRequest)
return accessRequest, errors.New(ErrInvalidRequest)
}

client, err := f.Store.GetClient(clientID)
if err != nil {
return ar, errors.New(ErrInvalidClient)
return accessRequest, errors.New(ErrInvalidClient)
}

// Enforce client authentication
if err := f.Hasher.Compare(client.GetHashedSecret(), []byte(clientSecret)); err != nil {
return ar, errors.New(ErrInvalidClient)
return accessRequest, errors.New(ErrInvalidClient)
}
ar.Client = client
accessRequest.Client = client

for _, loader := range f.TokenEndpointHandlers {
if err := loader.ValidateTokenEndpointRequest(ctx, r, ar, session); err != nil {
return ar, err
if err := loader.ValidateTokenEndpointRequest(ctx, r, accessRequest); err != nil {
return accessRequest, err
}
}

if !ar.DidHandleGrantType() {
return ar, errors.New(ErrUnsupportedGrantType)
if !accessRequest.DidHandleGrantType() {
return accessRequest, errors.New(ErrUnsupportedGrantType)
}

if !ar.GetScopes().Has(f.RequiredScope) {
return ar, errors.New(ErrInvalidScope)
if !accessRequest.GetScopes().Has(f.RequiredScope) {
return accessRequest, errors.New(ErrInvalidScope)
}

ar.GrantScope(f.RequiredScope)
return ar, nil
accessRequest.GrantScope(f.RequiredScope)
return accessRequest, nil
}
14 changes: 8 additions & 6 deletions access_request_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func TestNewAccessRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil)
client.EXPECT().GetHashedSecret().Return([]byte("foo"))
hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handler.EXPECT().ValidateTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(ErrServerError)
handler.EXPECT().ValidateTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any()).Return(ErrServerError)
},
handlers: TokenEndpointHandlers{"a": handler},
},
Expand All @@ -140,7 +140,7 @@ func TestNewAccessRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil)
client.EXPECT().GetHashedSecret().Return([]byte("foo"))
hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handler.EXPECT().ValidateTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
handler.EXPECT().ValidateTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
},
handlers: TokenEndpointHandlers{"a": handler},
},
Expand All @@ -157,7 +157,7 @@ func TestNewAccessRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil)
client.EXPECT().GetHashedSecret().Return([]byte("foo"))
hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handler.EXPECT().ValidateTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ context.Context, _ *http.Request, a AccessRequester, _ interface{}) {
handler.EXPECT().ValidateTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ context.Context, _ *http.Request, a AccessRequester) {
a.SetGrantTypeHandled("bar")
}).Return(nil)
},
Expand All @@ -175,7 +175,7 @@ func TestNewAccessRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil)
client.EXPECT().GetHashedSecret().Return([]byte("foo"))
hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handler.EXPECT().ValidateTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ context.Context, _ *http.Request, a AccessRequester, _ interface{}) {
handler.EXPECT().ValidateTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ context.Context, _ *http.Request, a AccessRequester) {
a.SetGrantTypeHandled("foo")
a.SetScopes([]string{"asdfasdf"})
}).Return(nil)
Expand All @@ -195,7 +195,7 @@ func TestNewAccessRequest(t *testing.T) {
store.EXPECT().GetClient(gomock.Eq("foo")).Return(client, nil)
client.EXPECT().GetHashedSecret().Return([]byte("foo"))
hasher.EXPECT().Compare(gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
handler.EXPECT().ValidateTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ context.Context, _ *http.Request, a AccessRequester, _ interface{}) {
handler.EXPECT().ValidateTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ context.Context, _ *http.Request, a AccessRequester) {
a.SetGrantTypeHandled("foo")
a.SetScopes([]string{DefaultRequiredScopeName})
}).Return(nil)
Expand All @@ -204,7 +204,9 @@ func TestNewAccessRequest(t *testing.T) {
expect: &AccessRequest{
GrantType: "foo",
HandledGrantType: []string{"foo"},
Client: client,
Request: Request{
Client: client,
},
},
},
} {
Expand Down
2 changes: 1 addition & 1 deletion access_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
)

func TestAccessRequest(t *testing.T) {
ar := NewAccessRequest()
ar := &AccessRequest{}
ar.GrantType = "foobar"
ar.Client = &client.SecureClient{}
ar.GrantScope("foo")
Expand Down
16 changes: 0 additions & 16 deletions access_response.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,5 @@
package fosite

type AccessResponder interface {
SetExtra(key string, value interface{})

GetExtra(key string) interface{}

SetAccessToken(string)

SetTokenType(string)

GetAccessToken() string

GetTokenType() string

ToMap() map[string]interface{}
}

func NewAccessResponse() AccessResponder {
return &AccessResponse{
Extra: map[string]interface{}{},
Expand Down
4 changes: 2 additions & 2 deletions access_response_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ import (
"net/http"
)

func (f *Fosite) NewAccessResponse(ctx context.Context, req *http.Request, requester AccessRequester, session interface{}) (AccessResponder, error) {
func (f *Fosite) NewAccessResponse(ctx context.Context, req *http.Request, requester AccessRequester) (AccessResponder, error) {
var err error
var tk TokenEndpointHandler

response := NewAccessResponse()
for _, tk = range f.TokenEndpointHandlers {
if err = tk.HandleTokenEndpointRequest(ctx, req, requester, response, session); err != nil {
if err = tk.HandleTokenEndpointRequest(ctx, req, requester, response); err != nil {
return nil, errors.Wrap(err, 1)
}
}
Expand Down
10 changes: 5 additions & 5 deletions access_response_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,21 @@ func TestNewAccessResponse(t *testing.T) {
},
{
mock: func() {
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(ErrServerError)
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(ErrServerError)
},
handlers: TokenEndpointHandlers{"a": handler},
expectErr: ErrServerError,
},
{
mock: func() {
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
},
handlers: TokenEndpointHandlers{"a": handler},
expectErr: ErrUnsupportedGrantType,
},
{
mock: func() {
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ context.Context, _ *http.Request, _ AccessRequester, resp AccessResponder, _param4 interface{}) {
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ context.Context, _ *http.Request, _ AccessRequester, resp AccessResponder) {
resp.SetAccessToken("foo")
}).Return(nil)
},
Expand All @@ -53,7 +53,7 @@ func TestNewAccessResponse(t *testing.T) {
},
{
mock: func() {
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ context.Context, _ *http.Request, _ AccessRequester, resp AccessResponder, _param4 interface{}) {
handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_ context.Context, _ *http.Request, _ AccessRequester, resp AccessResponder) {
resp.SetAccessToken("foo")
resp.SetTokenType("bar")
}).Return(nil)
Expand All @@ -68,7 +68,7 @@ func TestNewAccessResponse(t *testing.T) {
} {
f.TokenEndpointHandlers = c.handlers
c.mock()
ar, err := f.NewAccessResponse(nil, nil, nil, struct{}{})
ar, err := f.NewAccessResponse(nil, nil, nil)
assert.True(t, errors.Is(c.expectErr, err), "%d", k)
assert.Equal(t, ar, c.expect)
t.Logf("Passed test case %d", k)
Expand Down
Loading

0 comments on commit d936c91

Please sign in to comment.