diff --git a/oauth2/accesstoken.go b/oauth2/accesstoken.go index 7a511cf..7179ddf 100644 --- a/oauth2/accesstoken.go +++ b/oauth2/accesstoken.go @@ -15,3 +15,13 @@ type AccessToken struct { func (a *AccessToken) Valid() bool { return a != nil && a.AccessToken != "" && time.Now().UTC().Unix() > a.Expired } + +func (a *AccessToken) HasScope(scopes ...string) bool { + for _, scope := range scopes { + if ok := HierarchicScope(scope, a.Scopes); !ok { + return false + } + } + + return true +} diff --git a/oauth2/storage.go b/oauth2/storage.go index eabaea6..bc3ec74 100644 --- a/oauth2/storage.go +++ b/oauth2/storage.go @@ -1,6 +1,7 @@ package oauth2 type Storage interface { + GetClient(id string) (*Client, error) GetClientWithSecret(id, secret string) (*Client, error) GetRefreshToken(refreshToken string) (*RefreshToken, error) GetAuthorizeCode(code string) (*AuthorizeCode, error) diff --git a/oauth2/token_jwt.go b/oauth2/token_jwt.go index bcb8de2..cd75cee 100644 --- a/oauth2/token_jwt.go +++ b/oauth2/token_jwt.go @@ -50,6 +50,7 @@ func (c *JWTTokenGenerator) CreateAccessToken(req *CreateAccessTokenRequest) (st "iat": now.Unix(), "token_type": "bearer", "scope": strings.Join(req.Scopes, " "), + "extra": req.Extras, }) return token.SignedString(c.privateKey) @@ -82,3 +83,59 @@ func (c *JWTTokenGenerator) CreateCode() string { func (c *JWTTokenGenerator) CreateRefreshToken() string { return uuid.NewV4().String() } + +type JWTAccessToken struct { + Audience string + ExpiresAt int64 + ID string + IssuedAt int64 + Issuer string + Subject string + Extra map[string]interface{} + Scopes []string +} + +func (a *JWTAccessToken) Valid() bool { + return a != nil && time.Now().UTC().Unix() > a.ExpiresAt +} + +func (a *JWTAccessToken) HasScope(scopes ...string) bool { + for _, scope := range scopes { + if ok := HierarchicScope(scope, a.Scopes); !ok { + return false + } + } + + return true +} + +func ClaimJWTAccessToken(publicKey *rsa.PublicKey, accesstoken string) (*JWTAccessToken, error) { + jwttoken, err := jwt.Parse(accesstoken, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"]) + } + + return publicKey, nil + }) + if err != nil || !jwttoken.Valid { + return nil, errors.New("Invalid token") + } + + claims, ok := jwttoken.Claims.(jwt.MapClaims) + if !ok { + return nil, errors.New("Invalid jwt") + } + + at := &JWTAccessToken{ + Audience: claims["aud"].(string), + ExpiresAt: claims["exp"].(int64), + ID: claims["jti"].(string), + IssuedAt: claims["iat"].(int64), + Issuer: claims["iss"].(string), + Subject: claims["sub"].(string), + Extra: claims["extra"].(map[string]interface{}), + Scopes: strings.Fields(claims["scope"].(string)), + } + + return at, nil +} diff --git a/storage/dynamodb/storage.go b/storage/dynamodb/storage.go index c155b8d..f9b6fe6 100644 --- a/storage/dynamodb/storage.go +++ b/storage/dynamodb/storage.go @@ -31,16 +31,13 @@ func New(id, secret, region string) (*DynamoDB, error) { return &DynamoDB{dynamodb.New(sess), cache}, nil } -func (s *DynamoDB) GetClientWithSecret(id, secret string) (*oauth2.Client, error) { +func (s *DynamoDB) GetClient(id string) (*oauth2.Client, error) { res, err := s.db.GetItem(&dynamodb.GetItemInput{ TableName: aws.String("oauth_client"), Key: map[string]*dynamodb.AttributeValue{ "id": { S: aws.String(id), }, - "s": { - S: aws.String(secret), - }, }, }) @@ -58,6 +55,19 @@ func (s *DynamoDB) GetClientWithSecret(id, secret string) (*oauth2.Client, error return c, err } +func (s *DynamoDB) GetClientWithSecret(id, secret string) (*oauth2.Client, error) { + client, err := s.GetClient(id) + if err != nil { + return nil, err + } + + if client.Secret != secret { + return nil, oauth2.DbNotFoundError(err) + } + + return client, nil +} + func (s *DynamoDB) GetRefreshToken(refreshToken string) (*oauth2.RefreshToken, error) { res, err := s.db.GetItem(&dynamodb.GetItemInput{ TableName: aws.String("oauth_refreshtoken"), diff --git a/storage/dynamodb/storage_test.go b/storage/dynamodb/storage_test.go index 0098371..4470529 100644 --- a/storage/dynamodb/storage_test.go +++ b/storage/dynamodb/storage_test.go @@ -42,10 +42,14 @@ func TestCRUDClient(t *testing.T) { require.NoError(t, err) require.Equal(t, expc, c) + c, err = db.GetClient(expc.ID) + require.NoError(t, err) + require.Equal(t, expc, c) + _, err = db.db.DeleteItem(&dynamodb.DeleteItemInput{ TableName: aws.String("oauth_client"), Key: map[string]*dynamodb.AttributeValue{ - "c": { + "id": { S: aws.String(c.ID), }, }, @@ -55,6 +59,7 @@ func TestCRUDClient(t *testing.T) { c, err = db.GetClientWithSecret(c.ID, c.Secret) require.Equal(t, oauth2.DbNotFoundError(nil), err) require.Nil(t, c) + } func TestCRUDAccessToken(t *testing.T) { diff --git a/storage/memory/storage.go b/storage/memory/storage.go index 1e6ab7b..275a15f 100644 --- a/storage/memory/storage.go +++ b/storage/memory/storage.go @@ -3,6 +3,8 @@ package memory import ( "errors" + "sync" + "github.com/plimble/clover/oauth2" ) @@ -16,40 +18,63 @@ var ( ) type MemoryStorage struct { - client map[string]*oauth2.Client - scope map[string]*Scope - accessToken map[string]*oauth2.AccessToken - refreshToken map[string]*oauth2.RefreshToken - authCodeToken map[string]*oauth2.AuthorizeCode + ClientMutex sync.Mutex + ScopeMutex sync.Mutex + AccessTokenMutex sync.Mutex + RefreshTokenMutex sync.Mutex + AuthorizeCodeMutex sync.Mutex + Client map[string]*oauth2.Client + Scope map[string]*Scope + AccessToken map[string]*oauth2.AccessToken + RefreshToken map[string]*oauth2.RefreshToken + AuthorizeCode map[string]*oauth2.AuthorizeCode } func NewMemoryStorage() *MemoryStorage { - return &MemoryStorage{} + return &MemoryStorage{ + Client: make(map[string]*oauth2.Client), + Scope: make(map[string]*Scope), + AccessToken: make(map[string]*oauth2.AccessToken), + RefreshToken: make(map[string]*oauth2.RefreshToken), + AuthorizeCode: make(map[string]*oauth2.AuthorizeCode), + } } func (s *MemoryStorage) Flush() { - s.client = make(map[string]*oauth2.Client) - s.client = make(map[string]*oauth2.Client) + s.Client = make(map[string]*oauth2.Client) + s.Scope = make(map[string]*Scope) + s.AccessToken = make(map[string]*oauth2.AccessToken) + s.RefreshToken = make(map[string]*oauth2.RefreshToken) + s.AuthorizeCode = make(map[string]*oauth2.AuthorizeCode) } func (s *MemoryStorage) RevokeAccessToken(token string) error { - _, ok := s.accessToken[token] + s.AccessTokenMutex.Lock() + defer s.AccessTokenMutex.Unlock() + + _, ok := s.AccessToken[token] if !ok { return errNotFound } - delete(s.accessToken, token) + delete(s.AccessToken, token) return nil } func (s *MemoryStorage) SaveAccessToken(accessToken *oauth2.AccessToken) error { - s.accessToken[accessToken.AccessToken] = accessToken + s.AccessTokenMutex.Lock() + defer s.AccessTokenMutex.Unlock() + + s.AccessToken[accessToken.AccessToken] = accessToken return nil } func (s *MemoryStorage) GetAccessToken(token string) (*oauth2.AccessToken, error) { - accessToken, ok := s.accessToken[token] + s.AccessTokenMutex.Lock() + defer s.AccessTokenMutex.Unlock() + + accessToken, ok := s.AccessToken[token] if !ok { return nil, errNotFound } @@ -58,23 +83,32 @@ func (s *MemoryStorage) GetAccessToken(token string) (*oauth2.AccessToken, error } func (s *MemoryStorage) RevokeRefreshToken(token string) error { - _, ok := s.refreshToken[token] + s.RefreshTokenMutex.Lock() + defer s.RefreshTokenMutex.Unlock() + + _, ok := s.RefreshToken[token] if !ok { return errNotFound } - delete(s.refreshToken, token) + delete(s.RefreshToken, token) return nil } func (s *MemoryStorage) SaveRefreshToken(refreshToken *oauth2.RefreshToken) error { - s.refreshToken[refreshToken.RefreshToken] = refreshToken + s.RefreshTokenMutex.Lock() + defer s.RefreshTokenMutex.Unlock() + + s.RefreshToken[refreshToken.RefreshToken] = refreshToken return nil } func (s *MemoryStorage) GetRefreshToken(token string) (*oauth2.RefreshToken, error) { - refreshToken, ok := s.refreshToken[token] + s.RefreshTokenMutex.Lock() + defer s.RefreshTokenMutex.Unlock() + + refreshToken, ok := s.RefreshToken[token] if !ok { return nil, errNotFound } @@ -82,24 +116,19 @@ func (s *MemoryStorage) GetRefreshToken(token string) (*oauth2.RefreshToken, err return refreshToken, nil } -func (s *MemoryStorage) RevokeAuthorizeCode(code string) error { - _, ok := s.authCodeToken[code] - if !ok { - return errNotFound - } - - delete(s.authCodeToken, code) - - return nil -} - func (s *MemoryStorage) SaveAuthorizeCode(authCode *oauth2.AuthorizeCode) error { - s.authCodeToken[authCode.Code] = authCode + s.AuthorizeCodeMutex.Lock() + defer s.AuthorizeCodeMutex.Unlock() + + s.AuthorizeCode[authCode.Code] = authCode return nil } func (s *MemoryStorage) GetAuthorizeCode(code string) (*oauth2.AuthorizeCode, error) { - authCode, ok := s.authCodeToken[code] + s.AuthorizeCodeMutex.Lock() + defer s.AuthorizeCodeMutex.Unlock() + + authCode, ok := s.AuthorizeCode[code] if !ok { return nil, errNotFound } @@ -107,24 +136,11 @@ func (s *MemoryStorage) GetAuthorizeCode(code string) (*oauth2.AuthorizeCode, er return authCode, nil } -func (s *MemoryStorage) DeleteClient(id string) error { - _, ok := s.client[id] - if !ok { - return errNotFound - } - - delete(s.client, id) - - return nil -} - -func (s *MemoryStorage) SaveClient(client *oauth2.Client) error { - s.client[client.ID] = client - return nil -} - func (s *MemoryStorage) GetClientWithSecret(id, secret string) (*oauth2.Client, error) { - client, ok := s.client[id] + s.ClientMutex.Lock() + defer s.ClientMutex.Unlock() + + client, ok := s.Client[id] if !ok { return nil, errNotFound } @@ -137,7 +153,10 @@ func (s *MemoryStorage) GetClientWithSecret(id, secret string) (*oauth2.Client, } func (s *MemoryStorage) GetClient(id string) (*oauth2.Client, error) { - client, ok := s.client[id] + s.ClientMutex.Lock() + defer s.ClientMutex.Unlock() + + client, ok := s.Client[id] if !ok { return nil, errNotFound } @@ -145,49 +164,12 @@ func (s *MemoryStorage) GetClient(id string) (*oauth2.Client, error) { return client, nil } -func (s *MemoryStorage) CreateScope(scope *Scope) error { - s.scope[scope.ID] = scope - return nil -} - -func (s *MemoryStorage) GetScopeByIDs(ids []string) ([]*Scope, error) { - scopes := []*Scope{} - - for _, id := range ids { - scope, ok := s.scope[id] - if ok { - scopes = append(scopes, scope) - } - } - - return scopes, nil -} - -func (s *MemoryStorage) DeleteScope(id string) error { - _, ok := s.scope[id] - if !ok { - return errNotFound - } - - delete(s.scope, id) - - return nil -} - -func (s *MemoryStorage) GetAllScope() ([]*Scope, error) { - scopes := make([]*Scope, len(s.scope)) - index := 0 - for _, scope := range s.scope { - scopes[index] = scope - index++ - } - - return scopes, nil -} - func (s *MemoryStorage) IsAvailableScope(scopes []string) (bool, error) { + s.ScopeMutex.Lock() + defer s.ScopeMutex.Unlock() + for _, scope := range scopes { - if _, ok := s.scope[scope]; !ok { + if _, ok := s.Scope[scope]; !ok { return false, nil } }