diff --git a/pkg/bootstrap/user/userprofile.go b/pkg/bootstrap/user/userprofile.go new file mode 100644 index 0000000..82fdf60 --- /dev/null +++ b/pkg/bootstrap/user/userprofile.go @@ -0,0 +1,59 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package user + +import ( + "encoding/json" + "fmt" + + "github.com/trustbloc/edge-core/pkg/storage" +) + +// Profile is the user's bootstrap profile. +type Profile struct { + ID string + SDSPrimaryVaultID string + KeyStoreIDs []string +} + +// ProfileStore is the user Profile CRUD API. +type ProfileStore struct { + s storage.Store +} + +// NewStore returns a new ProfileStore. +func NewStore(s storage.Store) *ProfileStore { + return &ProfileStore{s: s} +} + +// Save saves the user Profile. +func (ps *ProfileStore) Save(p *Profile) error { + bits, err := json.Marshal(p) + + if err != nil { + return fmt.Errorf("failed to marshal user profile : %w", err) + } + + err = ps.s.Put(p.ID, bits) + if err != nil { + return fmt.Errorf("failed to save user profile : %w", err) + } + + return nil +} + +// Get fetches the user Profile. +func (ps *ProfileStore) Get(id string) (*Profile, error) { + bits, err := ps.s.Get(id) + if err != nil { + return nil, fmt.Errorf("failed to fetch user profile : %w", err) + } + + p := &Profile{} + + return p, json.Unmarshal(bits, p) +} diff --git a/pkg/bootstrap/user/userprofile_test.go b/pkg/bootstrap/user/userprofile_test.go new file mode 100644 index 0000000..316054f --- /dev/null +++ b/pkg/bootstrap/user/userprofile_test.go @@ -0,0 +1,92 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package user + +import ( + "encoding/json" + "errors" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "github.com/trustbloc/edge-core/pkg/storage/mockstore" +) + +func TestNewStore(t *testing.T) { + p := NewStore(&mockstore.MockStore{}) + require.NotNil(t, p) +} + +func TestSave(t *testing.T) { + t.Run("saves profile", func(t *testing.T) { + expected := &Profile{ + ID: uuid.New().String(), + SDSPrimaryVaultID: uuid.New().String(), + KeyStoreIDs: []string{uuid.New().String()}, + } + + store := &mockstore.MockStore{ + Store: make(map[string][]byte), + } + + err := NewStore(store).Save(expected) + require.NoError(t, err) + result := &Profile{} + err = json.Unmarshal(store.Store[expected.ID], result) + require.NoError(t, err) + require.Equal(t, expected, result) + }) + + t.Run("wraps store error", func(t *testing.T) { + expected := errors.New("test") + store := &mockstore.MockStore{ + Store: make(map[string][]byte), + ErrPut: expected, + } + err := NewStore(store).Save(&Profile{ID: "test"}) + require.True(t, errors.Is(err, expected)) + }) +} + +func TestGet(t *testing.T) { + t.Run("fetches profile", func(t *testing.T) { + expected := &Profile{ + ID: uuid.New().String(), + SDSPrimaryVaultID: uuid.New().String(), + KeyStoreIDs: []string{uuid.New().String()}, + } + store := &mockstore.MockStore{ + Store: map[string][]byte{ + expected.ID: toBytes(t, expected), + }, + } + result, err := NewStore(store).Get(expected.ID) + require.NoError(t, err) + require.Equal(t, expected, result) + }) + + t.Run("wraps store error", func(t *testing.T) { + expected := errors.New("test") + store := &mockstore.MockStore{ + Store: map[string][]byte{ + "test": {}, + }, + ErrGet: expected, + } + _, err := NewStore(store).Get("test") + require.True(t, errors.Is(err, expected)) + }) +} + +func toBytes(t *testing.T, v interface{}) []byte { + t.Helper() + + bits, err := json.Marshal(v) + require.NoError(t, err) + + return bits +} diff --git a/pkg/internal/common/mockstorage/storage.go b/pkg/internal/common/mockstorage/storage.go new file mode 100644 index 0000000..713ea42 --- /dev/null +++ b/pkg/internal/common/mockstorage/storage.go @@ -0,0 +1,109 @@ +/* +Copyright SecureKey Technologies Inc. All Rights Reserved. + +SPDX-License-Identifier: Apache-2.0 +*/ + +package mockstorage + +import ( + "errors" + "fmt" + "sync" + + "github.com/trustbloc/edge-core/pkg/storage" +) + +// Provider mock store provider. +type Provider struct { + Stores map[string]storage.Store + Store *MockStore + ErrCreateStore error + ErrOpenStoreHandle error + FailNameSpace string +} + +// NewMockStoreProvider new store provider instance. +func NewMockStoreProvider() *Provider { + return &Provider{ + Stores: make(map[string]storage.Store), + Store: &MockStore{ + Store: make(map[string][]byte), + }, + } +} + +// CreateStore creates a new store with the given name. +func (p *Provider) CreateStore(name string) error { + return p.ErrCreateStore +} + +// OpenStore opens and returns a store for given name space. +func (p *Provider) OpenStore(name string) (storage.Store, error) { + if name == p.FailNameSpace { + return nil, fmt.Errorf("failed to open store for name space %s", name) + } + + if s, ok := p.Stores[name]; ok { + return s, nil + } + + return p.Store, p.ErrOpenStoreHandle +} + +// Close closes all stores created under this store provider. +func (p *Provider) Close() error { + return nil +} + +// CloseStore closes store for given name space. +func (p *Provider) CloseStore(name string) error { + return nil +} + +// MockStore represents a mock store. +type MockStore struct { + Store map[string][]byte + lock sync.RWMutex + ErrPut error + ErrGet error + ErrCreateIndex error + ErrQuery error + ResultsIteratorToReturn storage.ResultsIterator +} + +// Put stores the key-value pair. +func (s *MockStore) Put(k string, v []byte) error { + if k == "" { + return errors.New("key is mandatory") + } + + s.lock.Lock() + s.Store[k] = v + s.lock.Unlock() + + return s.ErrPut +} + +// Get fetches the value associated with the given key. +func (s *MockStore) Get(k string) ([]byte, error) { + s.lock.RLock() + defer s.lock.RUnlock() + + val, ok := s.Store[k] + if !ok { + return nil, storage.ErrValueNotFound + } + + return val, s.ErrGet +} + +// CreateIndex returns a mocked error. +func (s *MockStore) CreateIndex(createIndexRequest storage.CreateIndexRequest) error { + return s.ErrCreateIndex +} + +// Query returns a mocked error. +func (s *MockStore) Query(query string) (storage.ResultsIterator, error) { + return s.ResultsIteratorToReturn, s.ErrQuery +} diff --git a/pkg/restapi/operation/operations.go b/pkg/restapi/operation/operations.go index a980bac..5047a9d 100644 --- a/pkg/restapi/operation/operations.go +++ b/pkg/restapi/operation/operations.go @@ -20,6 +20,7 @@ import ( "github.com/trustbloc/edge-core/pkg/storage" "golang.org/x/oauth2" + "github.com/trustbloc/hub-auth/pkg/bootstrap/user" "github.com/trustbloc/hub-auth/pkg/internal/common/support" ) @@ -35,6 +36,10 @@ const ( var logger = log.New("hub-auth-restapi") +type oidcClaims struct { + Sub string `json:"sub"` +} + // Handler http handler for each controller API endpoint. type Handler interface { Path() string @@ -103,7 +108,6 @@ type oauth2Token interface { // Operation defines handlers. type Operation struct { - handlers []Handler client httpClient requestTokens map[string]string transientStore storage.Store @@ -186,8 +190,6 @@ func New(config *Config) (*Operation, error) { svc.bootstrapStore = bootstrapStore - svc.registerHandler() - return svc, nil } @@ -246,50 +248,44 @@ func (c *Operation) createOIDCRequest(w http.ResponseWriter, r *http.Request) { } } -func (c *Operation) handleOIDCCallback(w http.ResponseWriter, r *http.Request) { //nolint:funlen +func (c *Operation) handleOIDCCallback(w http.ResponseWriter, r *http.Request) { //nolint:funlen,gocyclo state := r.URL.Query().Get("state") if state == "" { - logger.Errorf("missing state") - c.hubAuthResult(w, "missing state") + handleAuthError(w, http.StatusBadRequest, "missing state") return } code := r.URL.Query().Get("code") if code == "" { - logger.Errorf("missing code") - c.hubAuthResult(w, "missing code") + handleAuthError(w, http.StatusBadRequest, "missing code") return } _, err := c.transientStore.Get(state) if errors.Is(err, storage.ErrValueNotFound) { - logger.Errorf("invalid state parameter") - c.hubAuthResult(w, "invalid state parameter") + handleAuthError(w, http.StatusBadRequest, "invalid state parameter") return } if err != nil { - logger.Errorf("failed to query transient store for state : %s", err) - c.hubAuthResult(w, fmt.Sprintf("failed to query transient store for state : %s", err)) + handleAuthError(w, http.StatusInternalServerError, fmt.Sprintf("failed to query transient store for state : %s", err)) return } oauthToken, err := c.oauth2Config().Exchange(r.Context(), code) if err != nil { - logger.Errorf("failed to exchange oauth2 code for token : %s", err) - c.hubAuthResult(w, fmt.Sprintf("failed to exchange oauth2 code for token : %s", err)) + handleAuthError(w, http.StatusBadGateway, fmt.Sprintf("failed to exchange oauth2 code for token : %s", err)) return } rawIDToken, ok := oauthToken.Extra("id_token").(string) if !ok { - logger.Errorf("missing id_token") - c.hubAuthResult(w, "missing id_token") + handleAuthError(w, http.StatusBadGateway, "missing id_token") return } @@ -298,34 +294,68 @@ func (c *Operation) handleOIDCCallback(w http.ResponseWriter, r *http.Request) { ClientID: c.oidcClientID, }).Verify(r.Context(), rawIDToken) if err != nil { - logger.Errorf("failed to verify id_token : %s", err) - c.hubAuthResult(w, fmt.Sprintf("failed to verify id_token : %s", err)) + handleAuthError(w, http.StatusForbidden, fmt.Sprintf("failed to verify id_token : %s", err)) return } - userData := make(map[string]interface{}) + claims := &oidcClaims{} - err = oidcToken.Claims(&userData) + err = oidcToken.Claims(claims) if err != nil { - logger.Errorf("failed to extract user data from id_token : %s", err) - c.hubAuthResult(w, fmt.Sprintf("failed to extract user data from id_token : %s", err)) + handleAuthError(w, http.StatusInternalServerError, fmt.Sprintf("failed to extract claims from id_token : %s", err)) return } - // todo #issue-25 handle user data - _, err = json.Marshal(userData) + userProfile, err := user.NewStore(c.bootstrapStore).Get(claims.Sub) + if errors.Is(err, storage.ErrValueNotFound) { + userProfile, err = c.onboardUser(claims.Sub) + if err != nil { + handleAuthError(w, http.StatusInternalServerError, fmt.Sprintf("failed to onboard new user : %s", err)) + + return + } + } + if err != nil { - logger.Errorf("failed to marshal user data : %s", err) - c.hubAuthResult(w, fmt.Sprintf("failed to marshal user data : %s", err)) + handleAuthError(w, http.StatusInternalServerError, fmt.Sprintf("failed to fetch user profile from store : %s", err)) return } + + handleAuthResult(w, r, userProfile) } -func (c *Operation) hubAuthResult(w http.ResponseWriter, data string) { +// TODO onboard user at key server and SDS: https://github.com/trustbloc/hub-auth/issues/38 +func (c *Operation) onboardUser(id string) (*user.Profile, error) { + userProfile := &user.Profile{ + ID: id, + } + + err := user.NewStore(c.bootstrapStore).Save(userProfile) + if err != nil { + return nil, fmt.Errorf("failed to save user profile : %w", err) + } + + return userProfile, nil +} + +// TODO redirect to the UI: https://github.com/trustbloc/hub-auth/issues/39 +func handleAuthResult(w http.ResponseWriter, r *http.Request, _ *user.Profile) { + http.Redirect(w, r, "", http.StatusFound) +} + +func handleAuthError(w http.ResponseWriter, status int, msg string) { // todo #issue-25 handle user data + logger.Errorf(msg) + + w.WriteHeader(status) + + _, err := w.Write([]byte(msg)) + if err != nil { + logger.Errorf("failed to write error response : %w", err) + } } // writeResponse writes interface value to response. @@ -339,20 +369,14 @@ func (c *Operation) writeErrorResponse(rw http.ResponseWriter, status int, msg s } } -// registerHandler register handlers to be exposed from this service as REST API endpoints. -func (c *Operation) registerHandler() { - // Add more protocol endpoints here to expose them as controller API endpoints - c.handlers = []Handler{ +// GetRESTHandlers get all controller API handler available for this service. +func (c *Operation) GetRESTHandlers() []Handler { + return []Handler{ support.NewHTTPHandler(oauth2GetRequestPath, http.MethodGet, c.createOIDCRequest), support.NewHTTPHandler(oauth2CallbackPath, http.MethodGet, c.handleOIDCCallback), } } -// GetRESTHandlers get all controller API handler available for this service. -func (c *Operation) GetRESTHandlers() []Handler { - return c.handlers -} - func (c *Operation) oauth2Config(scopes ...string) oauth2Config { return c.oauth2ConfigFunc(scopes...) } diff --git a/pkg/restapi/operation/operations_test.go b/pkg/restapi/operation/operations_test.go index 6fdd51a..4cd3987 100644 --- a/pkg/restapi/operation/operations_test.go +++ b/pkg/restapi/operation/operations_test.go @@ -18,12 +18,14 @@ import ( "testing" "github.com/coreos/go-oidc" - "github.com/google/uuid" "github.com/stretchr/testify/require" + "github.com/trustbloc/edge-core/pkg/storage" "github.com/trustbloc/edge-core/pkg/storage/memstore" "github.com/trustbloc/edge-core/pkg/storage/mockstore" "golang.org/x/oauth2" + + "github.com/trustbloc/hub-auth/pkg/internal/common/mockstorage" ) func TestNew(t *testing.T) { @@ -135,7 +137,7 @@ func TestCreateOIDCRequest(t *testing.T) { } func TestHandleOIDCCallback(t *testing.T) { - t.Run("success", func(t *testing.T) { + t.Run("onboard user", func(t *testing.T) { state := uuid.New().String() code := uuid.New().String() @@ -161,13 +163,21 @@ func TestHandleOIDCCallback(t *testing.T) { o.oidcProvider = &mockOIDCProvider{ verifier: &mockVerifier{ - verifyVal: &mockToken{}, + verifyVal: &mockToken{ + oidcClaimsFunc: func(v interface{}) error { + c, ok := v.(*oidcClaims) + require.True(t, ok) + c.Sub = uuid.New().String() + + return nil + }, + }, }, } result := httptest.NewRecorder() o.handleOIDCCallback(result, newOIDCCallback(state, code)) - require.Equal(t, http.StatusOK, result.Code) + require.Equal(t, http.StatusFound, result.Code) }) t.Run("error missing state", func(t *testing.T) { @@ -177,7 +187,7 @@ func TestHandleOIDCCallback(t *testing.T) { require.NoError(t, err) result := httptest.NewRecorder() svc.handleOIDCCallback(result, newOIDCCallback("", "code")) - require.Equal(t, http.StatusOK, result.Code) + require.Equal(t, http.StatusBadRequest, result.Code) }) t.Run("error missing code", func(t *testing.T) { @@ -187,7 +197,7 @@ func TestHandleOIDCCallback(t *testing.T) { require.NoError(t, err) result := httptest.NewRecorder() svc.handleOIDCCallback(result, newOIDCCallback("state", "")) - require.Equal(t, http.StatusOK, result.Code) + require.Equal(t, http.StatusBadRequest, result.Code) }) t.Run("error invalid state parameter", func(t *testing.T) { @@ -197,7 +207,7 @@ func TestHandleOIDCCallback(t *testing.T) { require.NoError(t, err) result := httptest.NewRecorder() svc.handleOIDCCallback(result, newOIDCCallback("state", "code")) - require.Equal(t, http.StatusOK, result.Code) + require.Equal(t, http.StatusBadRequest, result.Code) }) t.Run("generic transient store error", func(t *testing.T) { @@ -218,7 +228,108 @@ func TestHandleOIDCCallback(t *testing.T) { require.NoError(t, err) result := httptest.NewRecorder() svc.handleOIDCCallback(result, newOIDCCallback(state, "code")) - require.Equal(t, http.StatusOK, result.Code) + require.Equal(t, http.StatusInternalServerError, result.Code) + }) + + t.Run("generic bootstrap store FETCH error", func(t *testing.T) { + sub := uuid.New().String() + state := uuid.New().String() + config, cleanup := config() + defer cleanup() + + config.Provider = &mockstorage.Provider{ + Stores: map[string]storage.Store{ + transientStoreName: &mockstore.MockStore{ + Store: map[string][]byte{ + state: []byte(state), + }, + }, + bootstrapStoreName: &mockstore.MockStore{ + Store: map[string][]byte{ + sub: {}, + }, + ErrGet: errors.New("generic"), + }, + }, + } + + svc, err := New(config) + require.NoError(t, err) + + svc.oauth2ConfigFunc = func(...string) oauth2Config { + return &mockOAuth2Config{exchangeVal: &mockToken{ + oauth2Claim: uuid.New().String(), + }} + } + + svc.oidcProvider = &mockOIDCProvider{ + verifier: &mockVerifier{ + verifyVal: &mockToken{ + oidcClaimsFunc: func(v interface{}) error { + c, ok := v.(*oidcClaims) + require.True(t, ok) + c.Sub = sub + + return nil + }, + }, + }, + } + + result := httptest.NewRecorder() + svc.handleOIDCCallback(result, newOIDCCallback(state, "code")) + require.Equal(t, http.StatusInternalServerError, result.Code) + }) + + t.Run("generic bootstrap store PUT error while onboarding user", func(t *testing.T) { + sub := uuid.New().String() + state := uuid.New().String() + config, cleanup := config() + defer cleanup() + + config.Provider = &mockstorage.Provider{ + Stores: map[string]storage.Store{ + transientStoreName: &mockstore.MockStore{ + Store: map[string][]byte{ + state: []byte(state), + }, + }, + bootstrapStoreName: &mockstore.MockStore{ + Store: map[string][]byte{ + sub: []byte("{}"), + }, + ErrGet: storage.ErrValueNotFound, + ErrPut: errors.New("generic"), + }, + }, + } + + svc, err := New(config) + require.NoError(t, err) + + svc.oauth2ConfigFunc = func(...string) oauth2Config { + return &mockOAuth2Config{exchangeVal: &mockToken{ + oauth2Claim: uuid.New().String(), + }} + } + + svc.oidcProvider = &mockOIDCProvider{ + verifier: &mockVerifier{ + verifyVal: &mockToken{ + oidcClaimsFunc: func(v interface{}) error { + c, ok := v.(*oidcClaims) + require.True(t, ok) + c.Sub = sub + + return nil + }, + }, + }, + } + + result := httptest.NewRecorder() + svc.handleOIDCCallback(result, newOIDCCallback(state, "code")) + require.Equal(t, http.StatusInternalServerError, result.Code) }) t.Run("error exchanging auth code", func(t *testing.T) { @@ -239,7 +350,7 @@ func TestHandleOIDCCallback(t *testing.T) { } result := httptest.NewRecorder() svc.handleOIDCCallback(result, newOIDCCallback(state, "code")) - require.Equal(t, http.StatusOK, result.Code) + require.Equal(t, http.StatusBadGateway, result.Code) }) t.Run("error missing id_token", func(t *testing.T) { @@ -261,7 +372,7 @@ func TestHandleOIDCCallback(t *testing.T) { } result := httptest.NewRecorder() svc.handleOIDCCallback(result, newOIDCCallback(state, "code")) - require.Equal(t, http.StatusOK, result.Code) + require.Equal(t, http.StatusBadGateway, result.Code) }) t.Run("error id_token verification", func(t *testing.T) { @@ -283,7 +394,7 @@ func TestHandleOIDCCallback(t *testing.T) { } result := httptest.NewRecorder() svc.handleOIDCCallback(result, newOIDCCallback(state, "code")) - require.Equal(t, http.StatusOK, result.Code) + require.Equal(t, http.StatusForbidden, result.Code) }) t.Run("error scanning id_token claims", func(t *testing.T) { @@ -305,7 +416,7 @@ func TestHandleOIDCCallback(t *testing.T) { } result := httptest.NewRecorder() svc.handleOIDCCallback(result, newOIDCCallback(state, "code")) - require.Equal(t, http.StatusOK, result.Code) + require.Equal(t, http.StatusInternalServerError, result.Code) }) }