Skip to content

Commit

Permalink
Check for account list length when marking llinked accounts as expired (
Browse files Browse the repository at this point in the history
  • Loading branch information
pjlast committed Aug 26, 2022
1 parent 0cf3ebe commit d85d391
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 44 deletions.
15 changes: 6 additions & 9 deletions enterprise/cmd/repo-updater/internal/authz/perms_syncer.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,18 +443,15 @@ func (s *PermsSyncer) fetchUserPermsViaExternalAccounts(ctx context.Context, use
return nil, nil, errors.Wrapf(err, "list linked accounts for %d", acct.ID)
}

linkedAcctIDs := make([]int32, len(linkedAccts))
for i, linkedAcct := range linkedAccts {
linkedAcctIDs[i] = linkedAcct.ID
acctIDs := make([]int32, 0, len(linkedAccts)+1)
acctIDs = append(acctIDs, acct.ID)
for _, linkedAcct := range linkedAccts {
acctIDs = append(acctIDs, linkedAcct.ID)
}
if err = accounts.TouchExpired(ctx, linkedAcctIDs...); err != nil {
return nil, nil, errors.Wrapf(err, "set expired for external accounts %v", linkedAcctIDs)
if err = accounts.TouchExpired(ctx, acctIDs...); err != nil {
return nil, nil, errors.Wrapf(err, "set expired for external account IDs %v", acctIDs)
}

err = accounts.TouchExpired(ctx, acct.ID)
if err != nil {
return nil, nil, errors.Wrapf(err, "set expired for external account %d", acct.ID)
}
if unauthorized {
acctLogger.Warn("setExternalAccountExpired, token is revoked",
log.Bool("unauthorized", unauthorized),
Expand Down
4 changes: 4 additions & 0 deletions internal/database/external_accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,10 @@ VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
}

func (s *userExternalAccountsStore) TouchExpired(ctx context.Context, ids ...int32) error {
if len(ids) == 0 {
return nil
}

idStrings := make([]string, len(ids))
for i, id := range ids {
idStrings[i] = strconv.Itoa(int(id))
Expand Down
81 changes: 46 additions & 35 deletions internal/database/external_accounts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -677,44 +677,55 @@ func TestExternalAccounts_TouchExpiredList(t *testing.T) {
t.Skip()
}
t.Parallel()
logger := logtest.Scoped(t)
db := NewDB(logger, dbtest.NewDB(logger, t))
ctx := context.Background()

spec := extsvc.AccountSpec{
ServiceType: "xa",
ServiceID: "xb",
ClientID: "xc",
AccountID: "xd",
}

userID, err := db.UserExternalAccounts().CreateUserAndSave(ctx, NewUser{Username: "u"}, spec, extsvc.AccountData{})
spec.ServiceID = "xb2"
require.NoError(t, err)
err = db.UserExternalAccounts().Insert(ctx, userID, spec, extsvc.AccountData{})
require.NoError(t, err)
spec.ServiceID = "xb3"
err = db.UserExternalAccounts().Insert(ctx, userID, spec, extsvc.AccountData{})
require.NoError(t, err)
t.Run("non-empty list", func(t *testing.T) {
logger := logtest.Scoped(t)
db := NewDB(logger, dbtest.NewDB(logger, t))
ctx := context.Background()

accts, err := db.UserExternalAccounts().List(ctx, ExternalAccountsListOptions{UserID: 1})
require.NoError(t, err)
require.Equal(t, 3, len(accts))
spec := extsvc.AccountSpec{
ServiceType: "xa",
ServiceID: "xb",
ClientID: "xc",
AccountID: "xd",
}

acctIds := []int32{}
for _, acct := range accts {
acctIds = append(acctIds, acct.ID)
}
userID, err := db.UserExternalAccounts().CreateUserAndSave(ctx, NewUser{Username: "u"}, spec, extsvc.AccountData{})
spec.ServiceID = "xb2"
require.NoError(t, err)
err = db.UserExternalAccounts().Insert(ctx, userID, spec, extsvc.AccountData{})
require.NoError(t, err)
spec.ServiceID = "xb3"
err = db.UserExternalAccounts().Insert(ctx, userID, spec, extsvc.AccountData{})
require.NoError(t, err)

accts, err := db.UserExternalAccounts().List(ctx, ExternalAccountsListOptions{UserID: 1})
require.NoError(t, err)
require.Equal(t, 3, len(accts))

acctIds := []int32{}
for _, acct := range accts {
acctIds = append(acctIds, acct.ID)
}

err = db.UserExternalAccounts().TouchExpired(ctx, acctIds...)
require.NoError(t, err)
err = db.UserExternalAccounts().TouchExpired(ctx, acctIds...)
require.NoError(t, err)

// Confirm all accounts are expired
accts, err = db.UserExternalAccounts().List(ctx, ExternalAccountsListOptions{UserID: 1, OnlyExpired: true})
require.NoError(t, err)
require.Equal(t, 3, len(accts))
// Confirm all accounts are expired
accts, err = db.UserExternalAccounts().List(ctx, ExternalAccountsListOptions{UserID: 1, OnlyExpired: true})
require.NoError(t, err)
require.Equal(t, 3, len(accts))

accts, err = db.UserExternalAccounts().List(ctx, ExternalAccountsListOptions{UserID: 1, ExcludeExpired: true})
require.NoError(t, err)
require.Equal(t, 0, len(accts))
accts, err = db.UserExternalAccounts().List(ctx, ExternalAccountsListOptions{UserID: 1, ExcludeExpired: true})
require.NoError(t, err)
require.Equal(t, 0, len(accts))
})
t.Run("empty list", func(t *testing.T) {
logger := logtest.Scoped(t)
db := NewDB(logger, dbtest.NewDB(logger, t))
ctx := context.Background()

acctIds := []int32{}
err := db.UserExternalAccounts().TouchExpired(ctx, acctIds...)
require.NoError(t, err)
})
}

0 comments on commit d85d391

Please sign in to comment.