Skip to content

Commit

Permalink
fix: use current sequence for refetching of events (#5772)
Browse files Browse the repository at this point in the history
* fix: use current sequence for refetching of events

* fix: use client ids
  • Loading branch information
livio-a committed Apr 28, 2023
1 parent c8c5cf3 commit 458a383
Show file tree
Hide file tree
Showing 28 changed files with 273 additions and 107 deletions.
8 changes: 4 additions & 4 deletions internal/admin/repository/eventsourcing/handler/styling.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,16 @@ func (_ *Styling) AggregateTypes() []models.AggregateType {
return []models.AggregateType{org.AggregateType, instance.AggregateType}
}

func (m *Styling) CurrentSequence(instanceID string) (uint64, error) {
sequence, err := m.view.GetLatestStylingSequence(instanceID)
func (m *Styling) CurrentSequence(ctx context.Context, instanceID string) (uint64, error) {
sequence, err := m.view.GetLatestStylingSequence(ctx, instanceID)
if err != nil {
return 0, err
}
return sequence.CurrentSequence, nil
}

func (m *Styling) EventQuery(instanceIDs []string) (*models.SearchQuery, error) {
sequences, err := m.view.GetLatestStylingSequences(instanceIDs)
func (m *Styling) EventQuery(ctx context.Context, instanceIDs []string) (*models.SearchQuery, error) {
sequences, err := m.view.GetLatestStylingSequences(ctx, instanceIDs)
if err != nil {
return nil, err
}
Expand Down
9 changes: 5 additions & 4 deletions internal/admin/repository/eventsourcing/view/sequence.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package view

import (
"context"
"time"

"github.com/zitadel/zitadel/internal/eventstore/v1/models"
Expand All @@ -15,12 +16,12 @@ func (v *View) saveCurrentSequence(viewName string, event *models.Event) error {
return repository.SaveCurrentSequence(v.Db, sequencesTable, viewName, event.InstanceID, event.Sequence, event.CreationDate)
}

func (v *View) latestSequence(viewName, instanceID string) (*repository.CurrentSequence, error) {
return repository.LatestSequence(v.Db, sequencesTable, viewName, instanceID)
func (v *View) latestSequence(ctx context.Context, viewName, instanceID string) (*repository.CurrentSequence, error) {
return repository.LatestSequence(v.Db, v.TimeTravel(ctx, sequencesTable), viewName, instanceID)
}

func (v *View) latestSequences(viewName string, instanceIDs []string) ([]*repository.CurrentSequence, error) {
return repository.LatestSequences(v.Db, sequencesTable, viewName, instanceIDs)
func (v *View) latestSequences(ctx context.Context, viewName string, instanceIDs []string) ([]*repository.CurrentSequence, error) {
return repository.LatestSequences(v.Db, v.TimeTravel(ctx, sequencesTable), viewName, instanceIDs)
}

func (v *View) AllCurrentSequences(db, instanceID string) ([]*repository.CurrentSequence, error) {
Expand Down
10 changes: 6 additions & 4 deletions internal/admin/repository/eventsourcing/view/styling.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package view

import (
"context"

"github.com/zitadel/zitadel/internal/eventstore/v1/models"
"github.com/zitadel/zitadel/internal/iam/repository/view"
"github.com/zitadel/zitadel/internal/iam/repository/view/model"
Expand Down Expand Up @@ -39,12 +41,12 @@ func (v *View) UpdateOrgOwnerRemovedStyling(event *models.Event) error {
return v.ProcessedStylingSequence(event)
}

func (v *View) GetLatestStylingSequence(instanceID string) (*global_view.CurrentSequence, error) {
return v.latestSequence(stylingTyble, instanceID)
func (v *View) GetLatestStylingSequence(ctx context.Context, instanceID string) (*global_view.CurrentSequence, error) {
return v.latestSequence(ctx, stylingTyble, instanceID)
}

func (v *View) GetLatestStylingSequences(instanceIDs []string) ([]*global_view.CurrentSequence, error) {
return v.latestSequences(stylingTyble, instanceIDs)
func (v *View) GetLatestStylingSequences(ctx context.Context, instanceIDs []string) ([]*global_view.CurrentSequence, error) {
return v.latestSequences(ctx, stylingTyble, instanceIDs)
}

func (v *View) ProcessedStylingSequence(event *models.Event) error {
Expand Down
14 changes: 12 additions & 2 deletions internal/admin/repository/eventsourcing/view/view.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
package view

import (
"context"

"github.com/jinzhu/gorm"

"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/database"
)

type View struct {
Db *gorm.DB
Db *gorm.DB
client *database.DB
}

func StartView(sqlClient *database.DB) (*View, error) {
Expand All @@ -15,10 +20,15 @@ func StartView(sqlClient *database.DB) (*View, error) {
return nil, err
}
return &View{
Db: gorm,
Db: gorm,
client: sqlClient,
}, nil
}

func (v *View) Health() (err error) {
return v.Db.DB().Ping()
}

func (v *View) TimeTravel(ctx context.Context, tableName string) string {
return tableName + v.client.Timetravel(call.Took(ctx))
}
29 changes: 20 additions & 9 deletions internal/auth/repository/eventsourcing/eventstore/auth_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/zitadel/zitadel/internal/telemetry/tracing"
user_model "github.com/zitadel/zitadel/internal/user/model"
user_view_model "github.com/zitadel/zitadel/internal/user/repository/view/model"
"github.com/zitadel/zitadel/internal/view/repository"
)

const unknownUserID = "UNKNOWN"
Expand Down Expand Up @@ -64,7 +65,9 @@ type privacyPolicyProvider interface {
type userSessionViewProvider interface {
UserSessionByIDs(string, string, string) (*user_view_model.UserSessionView, error)
UserSessionsByAgentID(string, string) ([]*user_view_model.UserSessionView, error)
GetLatestUserSessionSequence(ctx context.Context, instanceID string) (*repository.CurrentSequence, error)
}

type userViewProvider interface {
UserByID(string, string) (*user_view_model.UserView, error)
}
Expand Down Expand Up @@ -654,7 +657,7 @@ func (repo *AuthRequestRepo) checkLoginName(ctx context.Context, request *domain
preferredLoginName += "@" + request.RequestedPrimaryDomain
}
}
user, err = repo.checkLoginNameInputForResourceOwner(request, preferredLoginName)
user, err = repo.checkLoginNameInputForResourceOwner(ctx, request, preferredLoginName)
} else {
user, err = repo.checkLoginNameInput(ctx, request, preferredLoginName)
}
Expand Down Expand Up @@ -729,12 +732,12 @@ func (repo *AuthRequestRepo) checkDomainDiscovery(ctx context.Context, request *

func (repo *AuthRequestRepo) checkLoginNameInput(ctx context.Context, request *domain.AuthRequest, loginNameInput string) (*user_view_model.UserView, error) {
// always check the loginname first
user, err := repo.View.UserByLoginName(loginNameInput, request.InstanceID)
user, err := repo.View.UserByLoginName(ctx, loginNameInput, request.InstanceID)
if err == nil {
// and take the user regardless if there would be a user with that email or phone
return user, repo.checkLoginPolicyWithResourceOwner(ctx, request, user.ResourceOwner)
}
user, emailErr := repo.View.UserByEmail(loginNameInput, request.InstanceID)
user, emailErr := repo.View.UserByEmail(ctx, loginNameInput, request.InstanceID)
if emailErr == nil {
// if there was a single user with the specified email
// load and check the login policy
Expand All @@ -747,7 +750,7 @@ func (repo *AuthRequestRepo) checkLoginNameInput(ctx context.Context, request *d
return user, nil
}
}
user, phoneErr := repo.View.UserByPhone(loginNameInput, request.InstanceID)
user, phoneErr := repo.View.UserByPhone(ctx, loginNameInput, request.InstanceID)
if phoneErr == nil {
// if there was a single user with the specified phone
// load and check the login policy
Expand All @@ -765,25 +768,25 @@ func (repo *AuthRequestRepo) checkLoginNameInput(ctx context.Context, request *d
return nil, err
}

func (repo *AuthRequestRepo) checkLoginNameInputForResourceOwner(request *domain.AuthRequest, loginNameInput string) (*user_view_model.UserView, error) {
func (repo *AuthRequestRepo) checkLoginNameInputForResourceOwner(ctx context.Context, request *domain.AuthRequest, loginNameInput string) (*user_view_model.UserView, error) {
// always check the loginname first
user, err := repo.View.UserByLoginNameAndResourceOwner(loginNameInput, request.RequestedOrgID, request.InstanceID)
user, err := repo.View.UserByLoginNameAndResourceOwner(ctx, loginNameInput, request.RequestedOrgID, request.InstanceID)
if err == nil {
// and take the user regardless if there would be a user with that email or phone
return user, nil
}
if request.LoginPolicy != nil && !request.LoginPolicy.DisableLoginWithEmail {
// if login by email is allowed and there was a single user with the specified email
// take that user (and ignore possible phone number matches)
user, emailErr := repo.View.UserByEmailAndResourceOwner(loginNameInput, request.RequestedOrgID, request.InstanceID)
user, emailErr := repo.View.UserByEmailAndResourceOwner(ctx, loginNameInput, request.RequestedOrgID, request.InstanceID)
if emailErr == nil {
return user, nil
}
}
if request.LoginPolicy != nil && !request.LoginPolicy.DisableLoginWithPhone {
// if login by phone is allowed and there was a single user with the specified phone
// take that user
user, phoneErr := repo.View.UserByPhoneAndResourceOwner(loginNameInput, request.RequestedOrgID, request.InstanceID)
user, phoneErr := repo.View.UserByPhoneAndResourceOwner(ctx, loginNameInput, request.RequestedOrgID, request.InstanceID)
if phoneErr == nil {
return user, nil
}
Expand Down Expand Up @@ -1298,12 +1301,20 @@ func userSessionsByUserAgentID(provider userSessionViewProvider, agentID, instan
}

func userSessionByIDs(ctx context.Context, provider userSessionViewProvider, eventProvider userEventProvider, agentID string, user *user_model.UserView) (*user_model.UserSessionView, error) {
session, err := provider.UserSessionByIDs(agentID, user.ID, authz.GetInstance(ctx).InstanceID())
instanceID := authz.GetInstance(ctx).InstanceID()
session, err := provider.UserSessionByIDs(agentID, user.ID, instanceID)
if err != nil {
if !errors.IsNotFound(err) {
return nil, err
}
sequence, err := provider.GetLatestUserSessionSequence(ctx, instanceID)
logging.WithFields("instanceID", instanceID, "userID", user.ID).
OnError(err).
Errorf("could not get current sequence for userSessionByIDs")
session = &user_view_model.UserSessionView{UserAgentID: agentID, UserID: user.ID}
if sequence != nil {
session.Sequence = sequence.CurrentSequence
}
}
events, err := eventProvider.UserEventsByID(ctx, user.ID, session.Sequence)
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
user_model "github.com/zitadel/zitadel/internal/user/model"
user_es_model "github.com/zitadel/zitadel/internal/user/repository/eventsourcing/model"
user_view_model "github.com/zitadel/zitadel/internal/user/repository/view/model"
"github.com/zitadel/zitadel/internal/view/repository"
)

var (
Expand All @@ -35,6 +36,10 @@ func (m *mockViewNoUserSession) UserSessionsByAgentID(string, string) ([]*user_v
return nil, nil
}

func (m *mockViewNoUserSession) GetLatestUserSessionSequence(ctx context.Context, instanceID string) (*repository.CurrentSequence, error) {
return &repository.CurrentSequence{}, nil
}

type mockViewErrUserSession struct{}

func (m *mockViewErrUserSession) UserSessionByIDs(string, string, string) (*user_view_model.UserSessionView, error) {
Expand All @@ -45,6 +50,10 @@ func (m *mockViewErrUserSession) UserSessionsByAgentID(string, string) ([]*user_
return nil, errors.ThrowInternal(nil, "id", "internal error")
}

func (m *mockViewErrUserSession) GetLatestUserSessionSequence(ctx context.Context, instanceID string) (*repository.CurrentSequence, error) {
return &repository.CurrentSequence{}, nil
}

type mockViewUserSession struct {
ExternalLoginVerification time.Time
PasswordlessVerification time.Time
Expand Down Expand Up @@ -82,6 +91,10 @@ func (m *mockViewUserSession) UserSessionsByAgentID(string, string) ([]*user_vie
return sessions, nil
}

func (m *mockViewUserSession) GetLatestUserSessionSequence(ctx context.Context, instanceID string) (*repository.CurrentSequence, error) {
return &repository.CurrentSequence{}, nil
}

type mockViewNoUser struct{}

func (m *mockViewNoUser) UserByID(string, string) (*user_view_model.UserView, error) {
Expand Down
15 changes: 12 additions & 3 deletions internal/auth/repository/eventsourcing/eventstore/refresh_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,24 @@ func (r *RefreshTokenRepo) RefreshTokenByToken(ctx context.Context, refreshToken
}

func (r *RefreshTokenRepo) RefreshTokenByID(ctx context.Context, tokenID, userID string) (*usr_model.RefreshTokenView, error) {
tokenView, viewErr := r.View.RefreshTokenByID(tokenID, authz.GetInstance(ctx).InstanceID())
instanceID := authz.GetInstance(ctx).InstanceID()
tokenView, viewErr := r.View.RefreshTokenByID(tokenID, instanceID)
if viewErr != nil && !errors.IsNotFound(viewErr) {
return nil, viewErr
}
if errors.IsNotFound(viewErr) {
sequence, err := r.View.GetLatestRefreshTokenSequence(ctx, instanceID)
logging.WithFields("instanceID", instanceID, "userID", userID, "tokenID", tokenID).
OnError(err).
Errorf("could not get current sequence for RefreshTokenByID")

tokenView = new(model.RefreshTokenView)
tokenView.ID = tokenID
tokenView.UserID = userID
tokenView.InstanceID = authz.GetInstance(ctx).InstanceID()
tokenView.InstanceID = instanceID
if sequence != nil {
tokenView.Sequence = sequence.CurrentSequence
}
}

events, esErr := r.getUserEvents(ctx, userID, tokenView.InstanceID, tokenView.Sequence)
Expand Down Expand Up @@ -80,7 +89,7 @@ func (r *RefreshTokenRepo) SearchMyRefreshTokens(ctx context.Context, userID str
if err != nil {
return nil, err
}
sequence, err := r.View.GetLatestRefreshTokenSequence(authz.GetInstance(ctx).InstanceID())
sequence, err := r.View.GetLatestRefreshTokenSequence(ctx, authz.GetInstance(ctx).InstanceID())
logging.Log("EVENT-GBdn4").OnError(err).WithField("traceID", tracing.TraceIDFromCtx(ctx)).Warn("could not read latest refresh token sequence")
request.Queries = append(request.Queries, &usr_model.RefreshTokenSearchQuery{Key: usr_model.RefreshTokenSearchKeyUserID, Method: domain.SearchMethodEquals, Value: userID})
tokens, count, err := r.View.SearchRefreshTokens(request)
Expand Down
14 changes: 12 additions & 2 deletions internal/auth/repository/eventsourcing/eventstore/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,25 @@ func (repo *TokenRepo) IsTokenValid(ctx context.Context, userID, tokenID string)
}

func (repo *TokenRepo) TokenByIDs(ctx context.Context, userID, tokenID string) (*usr_model.TokenView, error) {
token, viewErr := repo.View.TokenByIDs(tokenID, userID, authz.GetInstance(ctx).InstanceID())
instanceID := authz.GetInstance(ctx).InstanceID()

token, viewErr := repo.View.TokenByIDs(tokenID, userID, instanceID)
if viewErr != nil && !errors.IsNotFound(viewErr) {
return nil, viewErr
}
if errors.IsNotFound(viewErr) {
sequence, err := repo.View.GetLatestTokenSequence(ctx, instanceID)
logging.WithFields("instanceID", instanceID, "userID", userID, "tokenID", tokenID).
OnError(err).
Errorf("could not get current sequence for TokenByIDs")

token = new(model.TokenView)
token.ID = tokenID
token.UserID = userID
token.InstanceID = authz.GetInstance(ctx).InstanceID()
token.InstanceID = instanceID
if sequence != nil {
token.Sequence = sequence.CurrentSequence
}
}

events, esErr := repo.getUserEvents(ctx, userID, token.InstanceID, token.Sequence)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,16 @@ func (t *RefreshToken) AggregateTypes() []es_models.AggregateType {
return []es_models.AggregateType{user.AggregateType, project.AggregateType, instance.AggregateType}
}

func (t *RefreshToken) CurrentSequence(instanceID string) (uint64, error) {
sequence, err := t.view.GetLatestRefreshTokenSequence(instanceID)
func (t *RefreshToken) CurrentSequence(ctx context.Context, instanceID string) (uint64, error) {
sequence, err := t.view.GetLatestRefreshTokenSequence(ctx, instanceID)
if err != nil {
return 0, err
}
return sequence.CurrentSequence, nil
}

func (t *RefreshToken) EventQuery(instanceIDs []string) (*es_models.SearchQuery, error) {
sequences, err := t.view.GetLatestRefreshTokenSequences(instanceIDs)
func (t *RefreshToken) EventQuery(ctx context.Context, instanceIDs []string) (*es_models.SearchQuery, error) {
sequences, err := t.view.GetLatestRefreshTokenSequences(ctx, instanceIDs)
if err != nil {
return nil, err
}
Expand Down
20 changes: 11 additions & 9 deletions internal/auth/repository/eventsourcing/handler/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,16 @@ func (_ *Token) AggregateTypes() []es_models.AggregateType {
return []es_models.AggregateType{user.AggregateType, project.AggregateType, instance.AggregateType}
}

func (t *Token) CurrentSequence(instanceID string) (uint64, error) {
sequence, err := t.view.GetLatestTokenSequence(instanceID)
func (t *Token) CurrentSequence(ctx context.Context, instanceID string) (uint64, error) {
sequence, err := t.view.GetLatestTokenSequence(ctx, instanceID)
if err != nil {
return 0, err
}
return sequence.CurrentSequence, nil
}

func (t *Token) EventQuery(instanceIDs []string) (*es_models.SearchQuery, error) {
sequences, err := t.view.GetLatestTokenSequences(instanceIDs)
func (t *Token) EventQuery(ctx context.Context, instanceIDs []string) (*es_models.SearchQuery, error) {
sequences, err := t.view.GetLatestTokenSequences(ctx, instanceIDs)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -145,11 +145,13 @@ func (t *Token) Reduce(event *es_models.Event) (err error) {
if err != nil {
return err
}
applicationsIDs := make([]string, 0, len(project.Applications))
clientIDs := make([]string, 0, len(project.Applications))
for _, app := range project.Applications {
applicationsIDs = append(applicationsIDs, app.AppID)
if app.OIDCConfig != nil {
clientIDs = append(clientIDs, app.OIDCConfig.ClientID)
}
}
return t.view.DeleteApplicationTokens(event, applicationsIDs...)
return t.view.DeleteApplicationTokens(event, clientIDs...)
case instance.InstanceRemovedEventType:
return t.view.DeleteInstanceTokens(event)
case org.OrgRemovedEventType:
Expand Down Expand Up @@ -208,7 +210,7 @@ func (t *Token) OnSuccess(instanceIDs []string) error {
}

func (t *Token) getProjectByID(ctx context.Context, projID, instanceID string) (*proj_model.Project, error) {
query, err := proj_view.ProjectByIDQuery(projID, instanceID, 0)
projectQuery, err := proj_view.ProjectByIDQuery(projID, instanceID, 0)
if err != nil {
return nil, err
}
Expand All @@ -217,7 +219,7 @@ func (t *Token) getProjectByID(ctx context.Context, projID, instanceID string) (
AggregateID: projID,
},
}
err = es_sdk.Filter(ctx, t.Eventstore().FilterEvents, esProject.AppendEvents, query)
err = es_sdk.Filter(ctx, t.Eventstore().FilterEvents, esProject.AppendEvents, projectQuery)
if err != nil && !caos_errs.IsNotFound(err) {
return nil, err
}
Expand Down

1 comment on commit 458a383

@vercel
Copy link

@vercel vercel bot commented on 458a383 Apr 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Successfully deployed to the following URLs:

docs – ./

docs-git-main-zitadel.vercel.app
zitadel-docs.vercel.app
docs-zitadel.vercel.app

Please sign in to comment.