From 0fb0614db352cae2f9847fbcaebcbd47e5519472 Mon Sep 17 00:00:00 2001 From: arekkas Date: Sat, 7 Jul 2018 14:41:38 +0200 Subject: [PATCH] oauth2: Removes tokens when consent is revoked Closes #856 Signed-off-by: arekkas --- cmd/cli/handler_migrate.go | 2 +- cmd/server/handler_consent_factory.go | 3 +- consent/manager_memory.go | 35 ++++++++++------------- consent/manager_sql.go | 41 +++++++++++++++++++++------ consent/manager_test.go | 40 ++++++++++++++++++++------ consent/sdk_test.go | 7 +++-- consent/strategy_default_test.go | 2 +- integration/sql_schema_test.go | 6 ++-- oauth2/fosite_store_memory.go | 4 +-- oauth2/handler.go | 2 ++ oauth2/oauth2_auth_code_test.go | 4 +-- 11 files changed, 97 insertions(+), 49 deletions(-) diff --git a/cmd/cli/handler_migrate.go b/cmd/cli/handler_migrate.go index d87c96b118..071415dab7 100644 --- a/cmd/cli/handler_migrate.go +++ b/cmd/cli/handler_migrate.go @@ -117,7 +117,7 @@ func (h *MigrateHandler) runMigrateSQL(db *sqlx.DB) error { "client": &client.SQLManager{DB: db}, "oauth2": &oauth2.FositeSQLStore{DB: db}, "jwk": &jwk.SQLManager{DB: db}, - "consent": consent.NewSQLManager(db, nil), + "consent": consent.NewSQLManager(db, nil, nil), } { fmt.Printf("Applying `%s` SQL migrations...\n", k) if num, err := m.CreateSchemas(); err != nil { diff --git a/cmd/server/handler_consent_factory.go b/cmd/server/handler_consent_factory.go index c075bb9485..51bae59473 100644 --- a/cmd/server/handler_consent_factory.go +++ b/cmd/server/handler_consent_factory.go @@ -35,12 +35,13 @@ func injectConsentManager(c *config.Config, cm client.Manager) { switch con := ctx.Connection.(type) { case *config.MemoryConnection: - manager = consent.NewMemoryManager() + manager = consent.NewMemoryManager(ctx.FositeStore) break case *sqlcon.SQLConnection: manager = consent.NewSQLManager( con.GetDatabase(), cm, + ctx.FositeStore, ) break case *config.PluginConnection: diff --git a/consent/manager_memory.go b/consent/manager_memory.go index c9335c558b..3928b3bf02 100644 --- a/consent/manager_memory.go +++ b/consent/manager_memory.go @@ -36,15 +36,17 @@ type MemoryManager struct { handledAuthRequests map[string]HandledAuthenticationRequest authSessions map[string]AuthenticationSession m map[string]*sync.RWMutex + store pkg.FositeStorer } -func NewMemoryManager() *MemoryManager { +func NewMemoryManager(store pkg.FositeStorer) *MemoryManager { return &MemoryManager{ consentRequests: map[string]ConsentRequest{}, handledConsentRequests: map[string]HandledConsentRequest{}, authRequests: map[string]AuthenticationRequest{}, handledAuthRequests: map[string]HandledAuthenticationRequest{}, authSessions: map[string]AuthenticationSession{}, + store: store, m: map[string]*sync.RWMutex{ "consentRequests": new(sync.RWMutex), "handledConsentRequests": new(sync.RWMutex), @@ -56,24 +58,7 @@ func NewMemoryManager() *MemoryManager { } func (m *MemoryManager) RevokeUserConsentSession(user string) error { - m.m["handledConsentRequests"].Lock() - defer m.m["handledConsentRequests"].Unlock() - m.m["consentRequests"].Lock() - defer m.m["consentRequests"].Unlock() - - var found bool - for k, c := range m.handledConsentRequests { - if c.ConsentRequest.Subject == user { - delete(m.handledConsentRequests, k) - delete(m.consentRequests, k) - found = true - } - } - - if !found { - return errors.WithStack(pkg.ErrNotFound) - } - return nil + return m.RevokeUserClientConsentSession(user, "") } func (m *MemoryManager) RevokeUserClientConsentSession(user, client string) error { @@ -84,9 +69,19 @@ func (m *MemoryManager) RevokeUserClientConsentSession(user, client string) erro var found bool for k, c := range m.handledConsentRequests { - if c.ConsentRequest.Subject == user && c.ConsentRequest.Client.ID == client { + if c.ConsentRequest.Subject == user && (client == "" || (client != "" && c.ConsentRequest.Client.ID == client)) { delete(m.handledConsentRequests, k) delete(m.consentRequests, k) + if err := m.store.RevokeAccessToken(nil, c.Challenge); errors.Cause(err) == fosite.ErrNotFound { + // do nothing + } else if err != nil { + return err + } + if err := m.store.RevokeRefreshToken(nil, c.Challenge); errors.Cause(err) == fosite.ErrNotFound { + // do nothing + } else if err != nil { + return err + } found = true } } diff --git a/consent/manager_sql.go b/consent/manager_sql.go index 3b6d15f98d..35a4706a4a 100644 --- a/consent/manager_sql.go +++ b/consent/manager_sql.go @@ -36,14 +36,16 @@ import ( ) type SQLManager struct { - db *sqlx.DB - c client.Manager + db *sqlx.DB + c client.Manager + store pkg.FositeStorer } -func NewSQLManager(db *sqlx.DB, c client.Manager) *SQLManager { +func NewSQLManager(db *sqlx.DB, c client.Manager, store pkg.FositeStorer) *SQLManager { return &SQLManager{ - db: db, - c: c, + db: db, + c: c, + store: store, } } @@ -72,14 +74,37 @@ func (m *SQLManager) revokeConsentSession(user, client string) error { args = append(args, client) } - var queries []string + var challenges = make([]string, 0) + if err := m.db.Select(&challenges, m.db.Rebind(fmt.Sprintf( + `SELECT r.challenge FROM hydra_oauth2_consent_request_handled as h +JOIN hydra_oauth2_consent_request as r ON r.challenge = h.challenge WHERE %s`, + part, + )), args...); err != nil { + if err == sql.ErrNoRows { + return errors.WithStack(pkg.ErrNotFound) + } + return sqlcon.HandleError(err) + } + for _, challenge := range challenges { + if err := m.store.RevokeAccessToken(nil, challenge); errors.Cause(err) == fosite.ErrNotFound { + // do nothing + } else if err != nil { + return err + } + if err := m.store.RevokeRefreshToken(nil, challenge); errors.Cause(err) == fosite.ErrNotFound { + // do nothing + } else if err != nil { + return err + } + } + + var queries []string switch m.db.DriverName() { case "mysql": queries = append(queries, fmt.Sprintf(`DELETE h, r FROM hydra_oauth2_consent_request_handled as h -JOIN hydra_oauth2_consent_request as r ON -r.challenge = h.challenge +JOIN hydra_oauth2_consent_request as r ON r.challenge = h.challenge WHERE %s`, part), ) default: diff --git a/consent/manager_test.go b/consent/manager_test.go index df1a707ac5..aed1bc635c 100644 --- a/consent/manager_test.go +++ b/consent/manager_test.go @@ -18,7 +18,7 @@ * @license Apache-2.0 */ -package consent +package consent_test import ( "flag" @@ -31,12 +31,20 @@ import ( _ "github.com/lib/pq" "github.com/ory/fosite" "github.com/ory/hydra/client" + . "github.com/ory/hydra/consent" + "github.com/ory/hydra/oauth2" "github.com/ory/hydra/pkg" "github.com/ory/sqlcon/dockertest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +var clientManager = client.NewMemoryManager(&fosite.BCrypt{WorkFactor: 8}) +var fositeManager = oauth2.NewFositeMemoryStore(clientManager, time.Hour) +var managers = map[string]Manager{ + "memory": NewMemoryManager(fositeManager), +} + func mockConsentRequest(key string, remember bool, rememberFor int, hasError bool, skip bool, authAt bool) (c *ConsentRequest, h *HandledConsentRequest) { c = &ConsentRequest{ OpenIDConnectContext: &OpenIDConnectContext{ @@ -138,7 +146,7 @@ func connectToPostgres(managers map[string]Manager, c client.Manager) { return } - s := NewSQLManager(db, c) + s := NewSQLManager(db, c, fositeManager) if _, err := s.CreateSchemas(); err != nil { log.Fatalf("Could not connect to database: %v", err) return @@ -154,7 +162,7 @@ func connectToMySQL(managers map[string]Manager, c client.Manager) { return } - s := NewSQLManager(db, c) + s := NewSQLManager(db, c, fositeManager) if _, err := s.CreateSchemas(); err != nil { log.Fatalf("Could not create mysql schema: %v", err) return @@ -163,11 +171,6 @@ func connectToMySQL(managers map[string]Manager, c client.Manager) { managers["mysql"] = s } -var clientManager = client.NewMemoryManager(&fosite.BCrypt{WorkFactor: 8}) -var managers = map[string]Manager{ - "memory": NewMemoryManager(), -} - func TestMain(m *testing.M) { runner := dockertest.Register() @@ -410,7 +413,6 @@ func TestManagers(t *testing.T) { for k, m := range managers { cr1, hcr1 := mockConsentRequest("rv1", false, 0, false, false, false) cr2, hcr2 := mockConsentRequest("rv2", false, 0, false, false, false) - clientManager.CreateClient(cr1.Client) clientManager.CreateClient(cr2.Client) @@ -422,23 +424,36 @@ func TestManagers(t *testing.T) { require.NoError(t, err) t.Run("manager="+k, func(t *testing.T) { + fositeManager.CreateAccessTokenSession(nil, "trva1", &fosite.Request{ID: "challengerv1", RequestedAt: time.Now()}) + fositeManager.CreateRefreshTokenSession(nil, "rrva1", &fosite.Request{ID: "challengerv1", RequestedAt: time.Now()}) + fositeManager.CreateAccessTokenSession(nil, "trva2", &fosite.Request{ID: "challengerv2", RequestedAt: time.Now()}) + fositeManager.CreateRefreshTokenSession(nil, "rrva2", &fosite.Request{ID: "challengerv2", RequestedAt: time.Now()}) + for i, tc := range []struct { subject string client string + at string + rt string ids []string }{ { + at: "trva1", rt: "rrva1", subject: "subjectrv1", client: "", ids: []string{"challengerv1"}, }, { + at: "trva2", rt: "rrva2", subject: "subjectrv2", client: "clientrv2", ids: []string{"challengerv2"}, }, } { t.Run(fmt.Sprintf("case=%d/subject=%s", i, tc.subject), func(t *testing.T) { + _, found := fositeManager.AccessTokens[tc.at] + assert.True(t, found) + _, found = fositeManager.RefreshTokens[tc.rt] + assert.True(t, found) if tc.client == "" { require.NoError(t, m.RevokeUserConsentSession(tc.subject)) @@ -452,6 +467,13 @@ func TestManagers(t *testing.T) { assert.EqualError(t, err, pkg.ErrNotFound.Error()) }) } + + t.Logf("Got at %+v", fositeManager.AccessTokens) + t.Logf("Got rt %+v", fositeManager.RefreshTokens) + _, found = fositeManager.AccessTokens[tc.at] + assert.False(t, found) + _, found = fositeManager.RefreshTokens[tc.rt] + assert.False(t, found) }) } }) diff --git a/consent/sdk_test.go b/consent/sdk_test.go index e462a9953c..0c5dd28cce 100644 --- a/consent/sdk_test.go +++ b/consent/sdk_test.go @@ -18,15 +18,18 @@ * @license Apache-2.0 */ -package consent +package consent_test import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/julienschmidt/httprouter" "github.com/ory/herodot" + . "github.com/ory/hydra/consent" + "github.com/ory/hydra/oauth2" "github.com/ory/hydra/sdk/go/hydra" "github.com/ory/hydra/sdk/go/hydra/swagger" "github.com/sirupsen/logrus" @@ -35,7 +38,7 @@ import ( ) func TestSDK(t *testing.T) { - m := NewMemoryManager() + m := NewMemoryManager(oauth2.NewFositeMemoryStore(nil, time.Minute)) router := httprouter.New() h := NewHandler(herodot.NewJSONWriter(logrus.New()), m) diff --git a/consent/strategy_default_test.go b/consent/strategy_default_test.go index 94712288df..83113f81f0 100644 --- a/consent/strategy_default_test.go +++ b/consent/strategy_default_test.go @@ -95,7 +95,7 @@ func TestStrategy(t *testing.T) { require.NoError(t, err) writer := herodot.NewJSONWriter(nil) - manager := NewMemoryManager() + manager := NewMemoryManager(nil) handler := NewHandler(writer, manager) router := httprouter.New() handler.SetRoutes(router) diff --git a/integration/sql_schema_test.go b/integration/sql_schema_test.go index 6af630a797..b814426dcc 100644 --- a/integration/sql_schema_test.go +++ b/integration/sql_schema_test.go @@ -57,9 +57,9 @@ func TestSQLSchema(t *testing.T) { } cm := &client.SQLManager{DB: db, Hasher: &fosite.BCrypt{}} - jm := jwk.SQLManager{DB: db, Cipher: &jwk.AEAD{Key: []byte("11111111111111111111111111111111")}} - om := oauth2.FositeSQLStore{Manager: cm, DB: db, L: logrus.New()} - crm := consent.NewSQLManager(db, nil) + jm := &jwk.SQLManager{DB: db, Cipher: &jwk.AEAD{Key: []byte("11111111111111111111111111111111")}} + om := &oauth2.FositeSQLStore{Manager: cm, DB: db, L: logrus.New()} + crm := consent.NewSQLManager(db, nil, om) pm := lsql.NewSQLManager(db, nil) _, err = pm.CreateSchemas("", "hydra_policy_migration") diff --git a/oauth2/fosite_store_memory.go b/oauth2/fosite_store_memory.go index 3c283cd9c7..36a5b844df 100644 --- a/oauth2/fosite_store_memory.go +++ b/oauth2/fosite_store_memory.go @@ -193,7 +193,7 @@ func (s *FositeMemoryStore) RevokeRefreshToken(ctx context.Context, id string) e } } if !found { - return errors.New("Not found") + return errors.WithStack(fosite.ErrNotFound) } return nil } @@ -211,7 +211,7 @@ func (s *FositeMemoryStore) RevokeAccessToken(ctx context.Context, id string) er } } if !found { - return errors.New("Not found") + return errors.WithStack(fosite.ErrNotFound) } return nil } diff --git a/oauth2/handler.go b/oauth2/handler.go index 9e151df58f..5e0274241c 100644 --- a/oauth2/handler.go +++ b/oauth2/handler.go @@ -540,6 +540,8 @@ func (h *Handler) AuthHandler(w http.ResponseWriter, r *http.Request, _ httprout return } + authorizeRequest.SetID(session.Challenge) + // done response, err := h.OAuth2.NewAuthorizeResponse(ctx, authorizeRequest, &Session{ DefaultSession: &openid.DefaultSession{ diff --git a/oauth2/oauth2_auth_code_test.go b/oauth2/oauth2_auth_code_test.go index b9da12c482..b1401eeea8 100644 --- a/oauth2/oauth2_auth_code_test.go +++ b/oauth2/oauth2_auth_code_test.go @@ -119,12 +119,12 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) { var cm consent.Manager switch km { case "memory": - cm = consent.NewMemoryManager() + cm = consent.NewMemoryManager(fs) fs.(*FositeMemoryStore).Manager = hc.NewMemoryManager(hasher) case "mysql": fallthrough case "postgres": - scm := consent.NewSQLManager(databases[km], fs.(*FositeSQLStore).Manager) + scm := consent.NewSQLManager(databases[km], fs.(*FositeSQLStore).Manager, fs) _, err := scm.CreateSchemas() require.NoError(t, err)