/
access_tracker.go
171 lines (148 loc) · 4.66 KB
/
access_tracker.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
package authorize
import (
"context"
"fmt"
"sync/atomic"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sets"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
)
// Tuning values for the access tracker.
const (
	// accessTrackerMaxSize is a size bound for buffered/tracked accesses
	// (presumably passed as maxSize to NewAccessTracker — confirm at the call site).
	accessTrackerMaxSize = 1000

	// accessTrackerDebouncePeriod is a flush interval (presumably passed as
	// debouncePeriod to NewAccessTracker — confirm at the call site).
	accessTrackerDebouncePeriod = 10 * time.Second

	// accessTrackerUpdateTimeout bounds each individual databroker update
	// performed by updateSession and updateServiceAccount.
	accessTrackerUpdateTimeout = 3 * time.Second
)
// AccessTrackerProvider supplies the databroker service client that an
// AccessTracker uses when writing last-access timestamps.
type AccessTrackerProvider interface {
	// GetDataBrokerServiceClient returns the client used for databroker updates.
	// It is called once per flush.
	GetDataBrokerServiceClient() databroker.DataBrokerServiceClient
}
// An AccessTracker collects accesses to sessions and service accounts and
// periodically (see Run) writes their last-access timestamps to the databroker.
type AccessTracker struct {
	// droppedAccesses counts accesses discarded because a buffer channel was
	// full. It is read and written with the sync/atomic 64-bit functions.
	// On 386, ARM and 32-bit MIPS those require a 64-bit aligned address.
	// Only the first word of an allocated struct is guaranteed to be aligned,
	// so this field must stay first.
	droppedAccesses int64

	// Buffered channels carrying IDs from the Track* methods to Run, plus
	// the size/timing configuration supplied to NewAccessTracker.
	provider               AccessTrackerProvider
	sessionAccesses        chan string
	serviceAccountAccesses chan string
	maxSize                int
	debouncePeriod         time.Duration
}
// NewAccessTracker creates a new AccessTracker whose databroker client is
// obtained from provider.
//
// maxSize is the buffer capacity of both access channels, and Run also uses it
// as the limit of its per-period sets. debouncePeriod is the interval between
// flushes performed by Run.
func NewAccessTracker(
	provider AccessTrackerProvider,
	maxSize int,
	debouncePeriod time.Duration,
) *AccessTracker {
	tracker := new(AccessTracker)
	tracker.provider = provider
	tracker.maxSize = maxSize
	tracker.debouncePeriod = debouncePeriod
	tracker.sessionAccesses = make(chan string, maxSize)
	tracker.serviceAccountAccesses = make(chan string, maxSize)
	return tracker
}
// Run processes tracked accesses until ctx is canceled.
//
// IDs sent by TrackSessionAccess and TrackServiceAccountAccess are collected
// into size-limited sets. Every debouncePeriod, the collected IDs are written
// to the databroker.
//
// Each set is flushed independently and is cleared only after all of its
// entries were written successfully. A failed flush is therefore retried on
// the next tick. It does not cause already-written entries of the other set
// to be re-written, which would move their accessed_at timestamps forward.
//
// Accesses that have not been flushed yet are discarded when ctx is canceled.
func (tracker *AccessTracker) Run(ctx context.Context) {
	ticker := time.NewTicker(tracker.debouncePeriod)
	defer ticker.Stop()

	sessionAccesses := sets.NewSizeLimited[string](tracker.maxSize)
	serviceAccountAccesses := sets.NewSizeLimited[string](tracker.maxSize)

	runTrackSessionAccess := func(sessionID string) {
		sessionAccesses.Add(sessionID)
	}
	runTrackServiceAccountAccess := func(serviceAccountID string) {
		serviceAccountAccesses.Add(serviceAccountID)
	}
	runSubmit := func() {
		if dropped := atomic.SwapInt64(&tracker.droppedAccesses, 0); dropped > 0 {
			log.Error(ctx).
				Int64("dropped", dropped).
				Msg("authorize: failed to track all session accesses")
		}

		client := tracker.provider.GetDataBrokerServiceClient()

		// Flush sessions. The callback returns false on the first error, which
		// presumably stops iteration. On error the set is kept so its entries
		// are retried on the next tick.
		var sessionErr error
		sessionAccesses.ForEach(func(sessionID string) bool {
			sessionErr = tracker.updateSession(ctx, client, sessionID)
			return sessionErr == nil
		})
		if sessionErr != nil {
			log.Error(ctx).Err(sessionErr).Msg("authorize: error updating session last access timestamp")
		} else {
			sessionAccesses = sets.NewSizeLimited[string](tracker.maxSize)
		}

		// Flush service accounts regardless of the session result, with the
		// same keep-on-error retry behavior.
		var serviceAccountErr error
		serviceAccountAccesses.ForEach(func(serviceAccountID string) bool {
			serviceAccountErr = tracker.updateServiceAccount(ctx, client, serviceAccountID)
			return serviceAccountErr == nil
		})
		if serviceAccountErr != nil {
			log.Error(ctx).Err(serviceAccountErr).Msg("authorize: error updating service account last access timestamp")
		} else {
			serviceAccountAccesses = sets.NewSizeLimited[string](tracker.maxSize)
		}
	}

	for {
		select {
		case <-ctx.Done():
			return
		case id := <-tracker.sessionAccesses:
			runTrackSessionAccess(id)
		case id := <-tracker.serviceAccountAccesses:
			runTrackServiceAccountAccess(id)
		case <-ticker.C:
			// Flushing runs on this goroutine, so the access channels are not
			// drained while databroker updates are in flight.
			runSubmit()
		}
	}
}
// TrackServiceAccountAccess records an access to the service account with the
// given ID.
//
// It never blocks. When the buffer is full, the access is discarded and
// counted, and Run reports the dropped count on its next flush.
func (tracker *AccessTracker) TrackServiceAccountAccess(serviceAccountID string) {
	select {
	case tracker.serviceAccountAccesses <- serviceAccountID:
		return
	default:
	}
	// Buffer full: count the drop instead of blocking the caller.
	atomic.AddInt64(&tracker.droppedAccesses, 1)
}
// TrackSessionAccess records an access to the session with the given ID.
//
// It never blocks. When the buffer is full, the access is discarded and
// counted, and Run reports the dropped count on its next flush.
func (tracker *AccessTracker) TrackSessionAccess(sessionID string) {
	select {
	case tracker.sessionAccesses <- sessionID:
		return
	default:
	}
	// Buffer full: count the drop instead of blocking the caller.
	atomic.AddInt64(&tracker.droppedAccesses, 1)
}
// updateServiceAccount sets the AccessedAt timestamp of the service account
// identified by serviceAccountID to the current time.
//
// It reads the service account and writes the whole record back. A service
// account that cannot be found is skipped without error. The read and the
// write share one accessTrackerUpdateTimeout deadline.
//
// NOTE(review): the get-then-put is not atomic as written. A concurrent change
// made between the two calls may be overwritten unless PutServiceAccount
// performs a version check, which is not visible here.
func (tracker *AccessTracker) updateServiceAccount(
	ctx context.Context,
	client databroker.DataBrokerServiceClient,
	serviceAccountID string,
) error {
	ctx, cancel := context.WithTimeout(ctx, accessTrackerUpdateTimeout)
	defer cancel()

	sa, err := user.GetServiceAccount(ctx, client, serviceAccountID)
	switch {
	case status.Code(err) == codes.NotFound:
		// Nothing to update (presumably deleted since the access was tracked).
		return nil
	case err != nil:
		return err
	}

	sa.AccessedAt = timestamppb.Now()
	if _, err := user.PutServiceAccount(ctx, client, sa); err != nil {
		return err
	}
	return nil
}
// updateSession sets the accessed_at field of the session identified by
// sessionID to the current time.
//
// The patch is restricted to that single field via a field mask. The call is
// bounded by accessTrackerUpdateTimeout.
//
// NOTE(review): unlike updateServiceAccount, a NotFound error is not filtered
// out here. Confirm whether session.Patch can report NotFound for a deleted
// session, since Run retries a failed session flush on every tick.
func (tracker *AccessTracker) updateSession(
	ctx context.Context,
	client databroker.DataBrokerServiceClient,
	sessionID string,
) error {
	ctx, cancel := context.WithTimeout(ctx, accessTrackerUpdateTimeout)
	defer cancel()

	patch := &session.Session{
		Id:         sessionID,
		AccessedAt: timestamppb.Now(),
	}
	mask, err := fieldmaskpb.New(patch, "accessed_at")
	if err != nil {
		return fmt.Errorf("internal error: %w", err)
	}

	if _, err := session.Patch(ctx, client, patch, mask); err != nil {
		return err
	}
	return nil
}