Skip to content

Commit

Permalink
fix: only query access tokens by hashed signature
Browse files Browse the repository at this point in the history
  • Loading branch information
alnr committed Aug 8, 2023
1 parent 0b56f53 commit a21e945
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 69 deletions.
12 changes: 7 additions & 5 deletions persistence/sql/persister_nid_test.go
Expand Up @@ -40,22 +40,24 @@ import (
type PersisterTestSuite struct {
suite.Suite
registries map[string]driver.Registry
clean func(*testing.T)
t1 context.Context
t2 context.Context
t1NID uuid.UUID
t2NID uuid.UUID
}

var _ PersisterTestSuite = PersisterTestSuite{}
var _ interface {
suite.SetupAllSuite
suite.TearDownTestSuite
} = (*PersisterTestSuite)(nil)

func (s *PersisterTestSuite) SetupSuite() {
s.registries = map[string]driver.Registry{
"memory": internal.NewRegistrySQLFromURL(s.T(), dbal.NewSQLiteTestDatabase(s.T()), true, &contextx.Default{}),
}

if !testing.Short() {
s.registries["postgres"], s.registries["mysql"], s.registries["cockroach"], s.clean = internal.ConnectDatabases(s.T(), true, &contextx.Default{})
s.registries["postgres"], s.registries["mysql"], s.registries["cockroach"], _ = internal.ConnectDatabases(s.T(), true, &contextx.Default{})
}

s.t1NID, s.t2NID = uuid.Must(uuid.NewV4()), uuid.Must(uuid.NewV4())
Expand Down Expand Up @@ -558,11 +560,11 @@ func (s *PersisterTestSuite) DeleteAccessTokenSession() {
require.NoError(t, r.Persister().DeleteAccessTokenSession(s.t2, sig))

actual := persistencesql.OAuth2RequestSQL{Table: "access"}
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, sig))
require.NoError(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(sig)))
require.Equal(t, s.t1NID, actual.NID)

require.NoError(t, r.Persister().DeleteAccessTokenSession(s.t1, sig))
require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, sig))
require.Error(t, r.Persister().Connection(context.Background()).Find(&actual, persistencesql.SignatureHash(sig)))
})
}
}
Expand Down
137 changes: 77 additions & 60 deletions persistence/sql/persister_oauth2.go
Expand Up @@ -67,7 +67,7 @@ func (r OAuth2RequestSQL) TableName() string {
return "hydra_oauth2_" + string(r.Table)
}

func (p *Persister) sqlSchemaFromRequest(ctx context.Context, rawSignature string, r fosite.Requester, table tableName) (*OAuth2RequestSQL, error) {
func (p *Persister) sqlSchemaFromRequest(ctx context.Context, signature string, r fosite.Requester, table tableName) (*OAuth2RequestSQL, error) {
subject := ""
if r.GetSession() == nil {
p.l.Debugf("Got an empty session in sqlSchemaFromRequest")
Expand Down Expand Up @@ -101,7 +101,7 @@ func (p *Persister) sqlSchemaFromRequest(ctx context.Context, rawSignature strin
return &OAuth2RequestSQL{
Request: r.GetID(),
ConsentChallenge: challenge,
ID: p.hashSignature(ctx, rawSignature, table),
ID: signature,
RequestedAt: r.GetRequestedAt(),
Client: r.GetClient().GetID(),
Scopes: strings.Join(r.GetRequestedScopes(), "|"),
Expand Down Expand Up @@ -160,20 +160,6 @@ func (r *OAuth2RequestSQL) toRequest(ctx context.Context, session fosite.Session
}, nil
}

// SignatureHash hashes the signature to prevent errors where the signature is
// longer than 128 characters (and thus doesn't fit into the pk).
func SignatureHash(signature string) string {
return fmt.Sprintf("%x", sha512.Sum384([]byte(signature)))
}

// hashSignature prevents errors where the signature is longer than 128 characters (and thus doesn't fit into the pk).
func (p *Persister) hashSignature(_ context.Context, signature string, table tableName) string {
if table == sqlTableAccess {
return SignatureHash(signature)
}
return signature
}

func (p *Persister) ClientAssertionJWTValid(ctx context.Context, jti string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ClientAssertionJWTValid")
defer otelx.End(span, &err)
Expand Down Expand Up @@ -228,7 +214,7 @@ func (p *Persister) SetClientAssertionJWTRaw(ctx context.Context, jti *oauth2.Bl
return sqlcon.HandleError(p.CreateWithNetwork(ctx, jti))
}

func (p *Persister) createSession(ctx context.Context, signature string, requester fosite.Requester, table tableName) (err error) {
func (p *Persister) createSession(ctx context.Context, signature string, requester fosite.Requester, table tableName) error {
req, err := p.sqlSchemaFromRequest(ctx, signature, requester, table)
if err != nil {
return err
Expand All @@ -242,28 +228,21 @@ func (p *Persister) createSession(ctx context.Context, signature string, request
return nil
}

func (p *Persister) findSessionBySignature(ctx context.Context, rawSignature string, session fosite.Session, table tableName) (_ fosite.Requester, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.findSessionBySignature")
defer otelx.End(span, &err)

func (p *Persister) findSessionBySignature(ctx context.Context, signature string, session fosite.Session, table tableName) (fosite.Requester, error) {
r := OAuth2RequestSQL{Table: table}

// We look for the signature as well as the hash of the signature here.
// This is because we now always store the hash of the signature in the database,
// regardless of the type of the signature. In previous versions, we only stored
// the hash of the signature for JWT tokens.
//
// This code will be removed in a future version.
err = p.QueryWithNetwork(ctx).Where("signature IN (?, ?)", rawSignature, SignatureHash(rawSignature)).First(&r)
err := p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&r)
if errors.Is(err, sql.ErrNoRows) {
return nil, errorsx.WithStack(fosite.ErrNotFound)
} else if err != nil {
}
if err != nil {
return nil, sqlcon.HandleError(err)
} else if !r.Active {
}
if !r.Active {
fr, err := r.toRequest(ctx, session, p)
if err != nil {
return nil, err
} else if table == sqlTableCode {
}
if table == sqlTableCode {
return fr, errorsx.WithStack(fosite.ErrInvalidatedAuthorizeCode)
}
return fr, errorsx.WithStack(fosite.ErrInactiveToken)
Expand All @@ -272,46 +251,35 @@ func (p *Persister) findSessionBySignature(ctx context.Context, rawSignature str
return r.toRequest(ctx, session, p)
}

func (p *Persister) deleteSessionBySignature(ctx context.Context, signature string, table tableName) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.deleteSessionBySignature")
defer otelx.End(span, &err)

signature = p.hashSignature(ctx, signature, table)

// We look for the signature as well as the hash of the signature here.
// This is because we now always store the hash of the signature in the database,
// regardless of the type of the signature. In previous versions, we only stored
// the hash of the signature for JWT tokens.
//
// This code will be removed in a future version.
err = sqlcon.HandleError(
func (p *Persister) deleteSessionBySignature(ctx context.Context, signature string, table tableName) error {
err := sqlcon.HandleError(
p.QueryWithNetwork(ctx).
Where("signature IN (?, ?)", signature, SignatureHash(signature)).
Where("signature = ?", signature).
Delete(&OAuth2RequestSQL{Table: table}))

if errors.Is(err, sqlcon.ErrNoRows) {
return errorsx.WithStack(fosite.ErrNotFound)
} else if errors.Is(err, sqlcon.ErrConcurrentUpdate) {
}
if errors.Is(err, sqlcon.ErrConcurrentUpdate) {
return errors.Wrap(fosite.ErrSerializationFailure, err.Error())
} else if err != nil {
return err
}
return nil
return err
}

func (p *Persister) deleteSessionByRequestID(ctx context.Context, id string, table tableName) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.deleteSessionByRequestID")
defer otelx.End(span, &err)

/* #nosec G201 table is static */
if err := p.QueryWithNetwork(ctx).
err = p.QueryWithNetwork(ctx).
Where("request_id=?", id).
Delete(&OAuth2RequestSQL{Table: table}); errors.Is(err, sql.ErrNoRows) {
Delete(&OAuth2RequestSQL{Table: table})
if errors.Is(err, sql.ErrNoRows) {
return errorsx.WithStack(fosite.ErrNotFound)
} else if err := sqlcon.HandleError(err); err != nil {
}
if err := sqlcon.HandleError(err); err != nil {
if errors.Is(err, sqlcon.ErrConcurrentUpdate) {
return errors.Wrap(fosite.ErrSerializationFailure, err.Error())
} else if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock?
}
if strings.Contains(err.Error(), "Error 1213") { // InnoDB Deadlock?
return errors.Wrap(fosite.ErrSerializationFailure, err.Error())
}
return err
Expand Down Expand Up @@ -356,14 +324,20 @@ func (p *Persister) InvalidateAuthorizeCodeSession(ctx context.Context, signatur
return sqlcon.HandleError(
p.Connection(ctx).
RawQuery(
fmt.Sprintf("UPDATE %s SET active=false WHERE signature=? AND nid = ?", OAuth2RequestSQL{Table: sqlTableCode}.TableName()),
fmt.Sprintf("UPDATE %s SET active = false WHERE signature = ? AND nid = ?", OAuth2RequestSQL{Table: sqlTableCode}.TableName()),
signature,
p.NetworkID(ctx),
).
Exec(),
)
}

// SignatureHash hashes the signature to prevent errors where the signature is
// longer than 128 characters (and thus doesn't fit into the pk).
func SignatureHash(signature string) string {
return fmt.Sprintf("%x", sha512.Sum384([]byte(signature)))
}

func (p *Persister) CreateAccessTokenSession(ctx context.Context, signature string, requester fosite.Requester) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateAccessTokenSession")
defer otelx.End(span, &err)
Expand All @@ -372,19 +346,62 @@ func (p *Persister) CreateAccessTokenSession(ctx context.Context, signature stri
append(toEventOptions(requester), events.WithGrantType(requester.GetRequestForm().Get("grant_type")))...,
)

return p.createSession(ctx, signature, requester, sqlTableAccess)
return p.createSession(ctx, SignatureHash(signature), requester, sqlTableAccess)
}

func (p *Persister) GetAccessTokenSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetAccessTokenSession")
defer otelx.End(span, &err)
return p.findSessionBySignature(ctx, signature, session, sqlTableAccess)

r := OAuth2RequestSQL{Table: sqlTableAccess}
err = p.QueryWithNetwork(ctx).Where("signature = ?", SignatureHash(signature)).First(&r)
if errors.Is(err, sql.ErrNoRows) {
// Backwards compatibility: we previously did not always hash the
// signature before inserting. In case there are still very old (but
// valid) access tokens in the database, this should get them.
err = p.QueryWithNetwork(ctx).Where("signature = ?", signature).First(&r)
if errors.Is(err, sql.ErrNoRows) {
return nil, errorsx.WithStack(fosite.ErrNotFound)
}
}
if err != nil {
return nil, sqlcon.HandleError(err)
}
if !r.Active {
fr, err := r.toRequest(ctx, session, p)
if err != nil {
return nil, err
}
return fr, errorsx.WithStack(fosite.ErrInactiveToken)
}

return r.toRequest(ctx, session, p)
}

func (p *Persister) DeleteAccessTokenSession(ctx context.Context, signature string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteAccessTokenSession")
defer otelx.End(span, &err)
return p.deleteSessionBySignature(ctx, signature, sqlTableAccess)

err = sqlcon.HandleError(
p.QueryWithNetwork(ctx).
Where("signature = ?", SignatureHash(signature)).
Delete(&OAuth2RequestSQL{Table: sqlTableAccess}))
if errors.Is(err, sqlcon.ErrNoRows) {
// Backwards compatibility: we previously did not always hash the
// signature before inserting. In case there are still very old (but
// valid) access tokens in the database, this should get them.
err = sqlcon.HandleError(
p.QueryWithNetwork(ctx).
Where("signature = ?", signature).
Delete(&OAuth2RequestSQL{Table: sqlTableAccess}))
if errors.Is(err, sqlcon.ErrNoRows) {
return errorsx.WithStack(fosite.ErrNotFound)
}
}
if errors.Is(err, sqlcon.ErrConcurrentUpdate) {
return errors.Wrap(fosite.ErrSerializationFailure, err.Error())
}
return err
}

func toEventOptions(requester fosite.Requester) []trace.EventOption {
Expand Down
2 changes: 0 additions & 2 deletions x/audit_test.go
Expand Up @@ -43,8 +43,6 @@ func TestLogAudit(t *testing.T) {
l.Logger.Out = buf
LogAudit(r, tc.message, l)

t.Logf("%s", buf.String())

assert.Contains(t, buf.String(), "audience=audit")
for _, expectContain := range tc.expectContains {
assert.Contains(t, buf.String(), expectContain)
Expand Down
3 changes: 1 addition & 2 deletions x/clean_sql.go
Expand Up @@ -10,7 +10,6 @@ import (
)

func DeleteHydraRows(t *testing.T, c *pop.Connection) {
t.Logf("Deleting hydra rows in database: %s", c.Dialect.Name())
for _, tb := range []string{
"hydra_oauth2_access",
"hydra_oauth2_refresh",
Expand Down Expand Up @@ -57,7 +56,7 @@ func CleanSQLPop(t *testing.T, c *pop.Connection) {
"schema_migration",
} {
if err := c.RawQuery("DROP TABLE IF EXISTS " + tb).Exec(); err != nil {
t.Logf(`Unable to clean up table "%s": %s`, tb, err)
t.Fatalf(`Unable to clean up table "%s": %s`, tb, err)
}
}
t.Logf("Successfully cleaned up database: %s", c.Dialect.Name())
Expand Down

0 comments on commit a21e945

Please sign in to comment.