Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

identity: preserve session refresh schedule #4633

Merged
merged 2 commits into from Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions internal/identity/manager/data.go
Expand Up @@ -54,6 +54,9 @@ func (u *User) UnmarshalJSON(data []byte) error {
// A Session is a session managed by the Manager.
type Session struct {
*session.Session
// lastRefresh is the time of the last refresh attempt (which may or may
// not have succeeded), or else the time the Manager first became aware of
// the session (if it has not yet attempted to refresh this session).
lastRefresh time.Time
// gracePeriod is the amount of time before expiration to attempt a refresh.
gracePeriod time.Duration
Expand Down
34 changes: 25 additions & 9 deletions internal/identity/manager/manager.go
Expand Up @@ -107,6 +107,10 @@ func (mgr *Manager) GetDataBrokerServiceClient() databroker.DataBrokerServiceCli
return mgr.cfg.Load().dataBrokerClient
}

func (mgr *Manager) now() time.Time {
return mgr.cfg.Load().now()
}

func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecordsMessage, clear <-chan struct{}) error {
// wait for initial sync
select {
Expand Down Expand Up @@ -145,7 +149,7 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
case <-timer.C:
}

now := time.Now()
now := mgr.now()
nextTime = now.Add(maxWait)

// refresh sessions
Expand Down Expand Up @@ -182,6 +186,15 @@ func (mgr *Manager) refreshLoop(ctx context.Context, update <-chan updateRecords
}
}

// refreshSession handles two distinct session lifecycle events:
//
// 1. If the session itself has expired, delete the session.
// 2. If the session's underlying OAuth2 access token is nearing expiration
// (but the session itself is still valid), refresh the access token.
//
// After a successful access token refresh, this method will also trigger a
// user info refresh. If an access token refresh or a user info refresh fails
// with a permanent error, the session will be deleted.
func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string) {
log.Info(ctx).
Str("user_id", userID).
Expand All @@ -208,7 +221,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
}

expiry := s.GetExpiresAt().AsTime()
if !expiry.After(time.Now()) {
if !expiry.After(mgr.now()) {
log.Info(ctx).
Str("user_id", userID).
Str("session_id", sessionID).
Expand Down Expand Up @@ -262,16 +275,17 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
return
}

res, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session)
if err != nil {
if _, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session); err != nil {
log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()).
Msg("failed to update session")
return
}

mgr.onUpdateSession(ctx, res.GetRecord(), s.Session)
s.lastRefresh = mgr.now()
mgr.sessions.ReplaceOrInsert(s)
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(userID, sessionID))
}

func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
Expand All @@ -291,7 +305,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
Msg("no user found for refresh")
return
}
u.lastRefresh = time.Now()
u.lastRefresh = mgr.now()
mgr.userScheduler.Add(u.NextRefresh(), u.GetId())

for _, s := range mgr.sessions.GetSessionsForUser(userID) {
Expand Down Expand Up @@ -343,7 +357,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag
log.Warn(ctx).Msgf("error unmarshaling session: %s", err)
continue
}
mgr.onUpdateSession(ctx, record, &pbSession)
mgr.onUpdateSession(record, &pbSession)
case grpcutil.GetTypeURL(new(user.User)):
var pbUser user.User
err := record.GetData().UnmarshalTo(&pbUser)
Expand All @@ -356,7 +370,7 @@ func (mgr *Manager) onUpdateRecords(ctx context.Context, msg updateRecordsMessag
}
}

func (mgr *Manager) onUpdateSession(_ context.Context, record *databroker.Record, session *session.Session) {
func (mgr *Manager) onUpdateSession(record *databroker.Record, session *session.Session) {
mgr.sessionScheduler.Remove(toSessionSchedulerKey(session.GetUserId(), session.GetId()))

if record.GetDeletedAt() != nil {
Expand All @@ -366,7 +380,9 @@ func (mgr *Manager) onUpdateSession(_ context.Context, record *databroker.Record

// update session
s, _ := mgr.sessions.Get(session.GetUserId(), session.GetId())
s.lastRefresh = time.Now()
if s.lastRefresh.IsZero() {
s.lastRefresh = mgr.now()
}
s.gracePeriod = mgr.cfg.Load().sessionRefreshGracePeriod
s.coolOffDuration = mgr.cfg.Load().sessionRefreshCoolOffDuration
s.Session = session
Expand Down
184 changes: 176 additions & 8 deletions internal/identity/manager/manager_test.go
Expand Up @@ -3,6 +3,7 @@ package manager
import (
"context"
"errors"
"fmt"
"testing"
"time"

Expand All @@ -24,18 +25,23 @@ import (
"github.com/pomerium/pomerium/pkg/protoutil"
)

type mockAuthenticator struct{}
type mockAuthenticator struct {
refreshResult *oauth2.Token
refreshError error
revokeError error
updateUserInfoError error
}

func (mock mockAuthenticator) Refresh(_ context.Context, _ *oauth2.Token, _ identity.State) (*oauth2.Token, error) {
return nil, errors.New("update session")
func (mock *mockAuthenticator) Refresh(_ context.Context, _ *oauth2.Token, _ identity.State) (*oauth2.Token, error) {
return mock.refreshResult, mock.refreshError
}

func (mock mockAuthenticator) Revoke(_ context.Context, _ *oauth2.Token) error {
return errors.New("not implemented")
func (mock *mockAuthenticator) Revoke(_ context.Context, _ *oauth2.Token) error {
return mock.revokeError
}

func (mock mockAuthenticator) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ any) error {
return errors.New("update user info")
func (mock *mockAuthenticator) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ any) error {
return mock.updateUserInfoError
}

func TestManager_refresh(t *testing.T) {
Expand Down Expand Up @@ -86,6 +92,9 @@ func TestManager_onUpdateRecords(t *testing.T) {
})

if _, ok := mgr.sessions.Get("user1", "session1"); assert.True(t, ok) {
tm, id := mgr.sessionScheduler.Next()
assert.Equal(t, now.Add(10*time.Second), tm)
assert.Equal(t, "user1\037session1", id)
}
if _, ok := mgr.users.Get("user1"); assert.True(t, ok) {
tm, id := mgr.userScheduler.Next()
Expand All @@ -94,6 +103,147 @@ func TestManager_onUpdateRecords(t *testing.T) {
}
}

func TestManager_onUpdateSession(t *testing.T) {
startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC)

s := &session.Session{
Id: "session-id",
UserId: "user-id",
OauthToken: &session.OAuthToken{
AccessToken: "access-token",
ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)),
},
IssuedAt: timestamppb.New(startTime),
ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)),
}

assertNextScheduled := func(t *testing.T, mgr *Manager, expectedTime time.Time) {
t.Helper()
tm, key := mgr.sessionScheduler.Next()
assert.Equal(t, expectedTime, tm)
assert.Equal(t, "user-id\037session-id", key)
}

t.Run("initial refresh event when not expiring soon", func(t *testing.T) {
now := startTime
mgr := New(WithNow(func() time.Time { return now }))

// When the Manager first becomes aware of a session it should schedule
// a refresh event for one minute before access token expiration.
mgr.onUpdateSession(mkRecord(s), s)
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))
})
t.Run("initial refresh event when expiring soon", func(t *testing.T) {
now := startTime
mgr := New(WithNow(func() time.Time { return now }))

// When the Manager first becomes aware of a session, if that session
// is expiring within the gracePeriod (1 minute), it should schedule a
// refresh event for as soon as possible, subject to the
// coolOffDuration (10 seconds).
now = now.Add(4*time.Minute + 30*time.Second) // 30 s before expiration
mgr.onUpdateSession(mkRecord(s), s)
assertNextScheduled(t, mgr, now.Add(10*time.Second))
})
t.Run("update near scheduled refresh", func(t *testing.T) {
now := startTime
mgr := New(WithNow(func() time.Time { return now }))

mgr.onUpdateSession(mkRecord(s), s)
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))

// If a session is updated close to the time when it is scheduled to be
// refreshed, the scheduled refresh event should not be pushed back.
now = now.Add(3*time.Minute + 55*time.Second) // 5 s before refresh
mgr.onUpdateSession(mkRecord(s), s)
assertNextScheduled(t, mgr, now.Add(5*time.Second))

// However, if an update changes the access token validity, the refresh
// event should be rescheduled accordingly. (This should be uncommon,
// as only the refresh loop itself should modify the access token.)
s2 := proto.Clone(s).(*session.Session)
s2.OauthToken.ExpiresAt = timestamppb.New(now.Add(5 * time.Minute))
mgr.onUpdateSession(mkRecord(s2), s2)
assertNextScheduled(t, mgr, now.Add(4*time.Minute))
})
t.Run("session record deleted", func(t *testing.T) {
now := startTime
mgr := New(WithNow(func() time.Time { return now }))

mgr.onUpdateSession(mkRecord(s), s)
assertNextScheduled(t, mgr, startTime.Add(4*time.Minute))

// If a session is deleted, any scheduled refresh event should be canceled.
record := mkRecord(s)
record.DeletedAt = timestamppb.New(now)
mgr.onUpdateSession(record, s)
_, key := mgr.sessionScheduler.Next()
assert.Empty(t, key)
})
}

func TestManager_refreshSession(t *testing.T) {
startTime := time.Date(2023, 10, 19, 12, 0, 0, 0, time.UTC)

var auth mockAuthenticator

ctrl := gomock.NewController(t)
client := mock_databroker.NewMockDataBrokerServiceClient(ctrl)

now := startTime
mgr := New(
WithDataBrokerClient(client),
WithNow(func() time.Time { return now }),
WithAuthenticator(&auth),
)

// Initialize the Manager with a new session.
s := &session.Session{
Id: "session-id",
UserId: "user-id",
OauthToken: &session.OAuthToken{
AccessToken: "access-token",
ExpiresAt: timestamppb.New(startTime.Add(5 * time.Minute)),
RefreshToken: "refresh-token",
},
IssuedAt: timestamppb.New(startTime),
ExpiresAt: timestamppb.New(startTime.Add(24 * time.Hour)),
}
mgr.sessions.ReplaceOrInsert(Session{
Session: s,
lastRefresh: startTime,
gracePeriod: time.Minute,
coolOffDuration: 10 * time.Second,
})

// After a success token refresh, the manager should schedule another
// refresh event.
now = now.Add(4 * time.Minute)
auth.refreshResult, auth.refreshError = &oauth2.Token{
AccessToken: "new-access-token",
RefreshToken: "new-refresh-token",
Expiry: now.Add(5 * time.Minute),
}, nil
expectedSession := proto.Clone(s).(*session.Session)
expectedSession.OauthToken = &session.OAuthToken{
AccessToken: "new-access-token",
ExpiresAt: timestamppb.New(now.Add(5 * time.Minute)),
RefreshToken: "new-refresh-token",
}
client.EXPECT().Put(gomock.Any(),
objectsAreEqualMatcher{&databroker.PutRequest{Records: []*databroker.Record{{
Type: "type.googleapis.com/session.Session",
Id: "session-id",
Data: protoutil.NewAny(expectedSession),
}}}}).
Return(nil /* this result is currently unused */, nil)
mgr.refreshSession(context.Background(), "user-id", "session-id")

tm, key := mgr.sessionScheduler.Next()
assert.Equal(t, now.Add(4*time.Minute), tm)
assert.Equal(t, "user-id\037session-id", key)
}

func TestManager_reportErrors(t *testing.T) {
ctrl := gomock.NewController(t)

Expand Down Expand Up @@ -135,7 +285,10 @@ func TestManager_reportErrors(t *testing.T) {
mgr := New(
WithEventManager(evtMgr),
WithDataBrokerClient(client),
WithAuthenticator(mockAuthenticator{}),
WithAuthenticator(&mockAuthenticator{
refreshError: errors.New("update session"),
updateUserInfoError: errors.New("update user info"),
}),
)

mgr.onUpdateRecords(ctx, updateRecordsMessage{
Expand Down Expand Up @@ -172,3 +325,18 @@ type recordable interface {
proto.Message
GetId() string
}

// objectsAreEqualMatcher implements gomock.Matcher using ObjectsAreEqual. This
// is especially helpful when working with pointers, as it will compare the
// underlying values rather than the pointers themselves.
type objectsAreEqualMatcher struct {
expected interface{}
}

func (m objectsAreEqualMatcher) Matches(x interface{}) bool {
return assert.ObjectsAreEqual(m.expected, x)
}

func (m objectsAreEqualMatcher) String() string {
return fmt.Sprintf("is equal to %v (%T)", m.expected, m.expected)
}