-
Notifications
You must be signed in to change notification settings - Fork 568
/
postgres_tracker.go
309 lines (284 loc) · 8.26 KB
/
postgres_tracker.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
package track
import (
"context"
"database/sql"
"sort"
"time"
"github.com/pachyderm/pachyderm/v2/src/internal/errors"
"github.com/pachyderm/pachyderm/v2/src/internal/pacherr"
"github.com/pachyderm/pachyderm/v2/src/internal/pachsql"
)
// Compile-time check that *postgresTracker satisfies the Tracker interface.
var _ Tracker = (*postgresTracker)(nil)

// postgresTracker implements Tracker on top of the storage.tracker_objects
// and storage.tracker_refs tables.
type postgresTracker struct {
	db *pachsql.DB
}
// NewPostgresTracker returns a Tracker that records objects, their expiry
// times and the references between them in Postgres via db.
// The storage.tracker_objects and storage.tracker_refs tables are expected to
// exist already; SetupPostgresTrackerV0 creates them.
func NewPostgresTracker(db *pachsql.DB) Tracker {
	return &postgresTracker{db: db}
}
// DB exposes the database handle that the tracker was constructed with.
func (t *postgresTracker) DB() *pachsql.DB {
	handle := t.db
	return handle
}
// CreateTx creates the object id inside tx with the given ttl and records a
// reference from it to every id in pointsTo.
//
// Errors:
//   - ErrSelfReference if id appears in pointsTo.
//   - ErrDanglingRef if the object is new and some id in pointsTo does not exist.
//   - ErrDifferentObjectExists if id already exists but references a different
//     set of objects (order and duplicates in pointsTo are ignored).
//
// If id already exists, its expiry is still updated by putObject before the
// comparison, even when ErrDifferentObjectExists is returned. Callers
// presumably roll tx back on error.
func (t *postgresTracker) CreateTx(tx *pachsql.Tx, id string, pointsTo []string, ttl time.Duration) error {
	for i := range pointsTo {
		if pointsTo[i] == id {
			return ErrSelfReference
		}
	}
	targets := dedupedStrings(pointsTo)

	// Insert the object, or refresh the expiry of an existing one.
	intID, created, err := t.putObject(tx, id, ttl)
	if err != nil {
		return err
	}
	if created {
		return t.addReferences(tx, intID, targets)
	}

	// Re-creating an existing object is only allowed when it points at
	// exactly the same downstream objects.
	existing, err := t.getDownstream(tx, intID)
	if err != nil {
		return err
	}
	if stringsMatch(targets, existing) {
		return nil
	}
	return ErrDifferentObjectExists
}
// putObject inserts the object id or, if it already exists, updates its
// expiry. It returns the object's internal int_id and whether a new row was
// inserted (false means an existing row was updated).
//
// With a ttl other than NoTTL, the new expiry is the later of the current
// expiry and now+ttl. A NULL expires_at means "never expires", which is
// already the maximum, so such objects keep a NULL expiry. With
// ttl == NoTTL, any existing expiry is removed (set to NULL).
func (t *postgresTracker) putObject(tx *pachsql.Tx, id string, ttl time.Duration) (int, bool, error) {
	// After an upsert, xmax is 0 only for a freshly inserted row.
	// About xmax: https://stackoverflow.com/a/39204667
	var res struct {
		IntID int `db:"int_id"`
		XMax  int `db:"xmax"`
	}
	if ttl != NoTTL {
		// GREATEST ignores NULL arguments, so a bare greatest(...) would give a
		// never-expiring object a finite expiry. Treat NULL as the maximum.
		if err := tx.Get(&res, `
			INSERT INTO storage.tracker_objects (str_id, expires_at)
			VALUES ($1, CURRENT_TIMESTAMP + $2 * interval '1 microsecond')
			ON CONFLICT (str_id) DO
			UPDATE SET expires_at = CASE
				WHEN storage.tracker_objects.expires_at IS NULL THEN NULL
				ELSE greatest(storage.tracker_objects.expires_at, EXCLUDED.expires_at)
			END
			WHERE storage.tracker_objects.str_id = $1
			RETURNING int_id, xmax
		`, id, ttl.Microseconds()); err != nil {
			return 0, false, errors.EnsureStack(err)
		}
	} else {
		if err := tx.Get(&res, `
			INSERT INTO storage.tracker_objects (str_id)
			VALUES ($1)
			ON CONFLICT (str_id) DO
			UPDATE SET expires_at = NULL
			WHERE storage.tracker_objects.str_id = $1
			RETURNING int_id, xmax
		`, id); err != nil {
			return 0, false, errors.EnsureStack(err)
		}
	}
	return res.IntID, res.XMax == 0, nil
}
// addReferences records, inside tx, a reference from the object with internal
// id intID to each object whose str_id is in pointsTo. pointsTo should hold no
// duplicates (CreateTx dedupes it before calling). ErrDanglingRef is returned
// when fewer references were inserted than ids requested, i.e. some id does
// not name an existing object.
func (t *postgresTracker) addReferences(tx *pachsql.Tx, intID int, pointsTo []string) error {
	if len(pointsTo) == 0 {
		return nil
	}
	var insertedTo []int
	err := tx.Select(&insertedTo,
		`INSERT INTO storage.tracker_refs (from_id, to_id)
SELECT $1, int_id FROM storage.tracker_objects WHERE str_id = ANY($2)
RETURNING to_id`,
		intID, pointsTo)
	if err != nil {
		return errors.EnsureStack(err)
	}
	// str_id is UNIQUE, so each existing target yields exactly one row.
	if len(insertedTo) != len(pointsTo) {
		return ErrDanglingRef
	}
	return nil
}
// SetTTL sets the expiry of the object id to now+ttl, unconditionally
// replacing any existing expiry (so it may shorten it). It returns the new
// expiry, or a pacherr NotExist error if id is not tracked.
func (t *postgresTracker) SetTTL(ctx context.Context, id string, ttl time.Duration) (time.Time, error) {
	var expiresAt time.Time
	if err := t.db.GetContext(ctx, &expiresAt, `
		UPDATE storage.tracker_objects
		SET expires_at = CURRENT_TIMESTAMP + $2 * interval '1 microsecond'
		WHERE str_id = $1
		RETURNING expires_at
	`, id, ttl.Microseconds()); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return time.Time{}, pacherr.NewNotExist("tracker", id)
		}
		// Attach a stack trace like the other tracker methods do.
		return time.Time{}, errors.EnsureStack(err)
	}
	return expiresAt, nil
}
// SetTTLPrefix sets the expiry of every object whose id starts with prefix to
// now+ttl. It returns the earliest resulting expiry (or now+ttl when no object
// matched) and the number of objects updated.
func (t *postgresTracker) SetTTLPrefix(ctx context.Context, prefix string, ttl time.Duration) (time.Time, int, error) {
	var x struct {
		Count     int       `db:"count"`
		ExpiresAt time.Time `db:"expires_at"`
	}
	// The prefix is escaped so LIKE metacharacters in it ('%', '_', '\') match
	// literally. Otherwise a prefix containing '_' would also update ids that
	// merely match it as a pattern.
	if err := t.db.GetContext(ctx, &x, `
		WITH rows AS (
			UPDATE storage.tracker_objects
			SET expires_at = CURRENT_TIMESTAMP + $2 * interval '1 microsecond'
			WHERE str_id LIKE $1 || '%'
			RETURNING expires_at
		)
		SELECT COUNT(*) AS count,
			COALESCE(MIN(expires_at), CURRENT_TIMESTAMP + $2 * interval '1 microsecond') AS expires_at
		FROM rows
	`, escapeLikePattern(prefix), ttl.Microseconds()); err != nil {
		return time.Time{}, 0, errors.EnsureStack(err)
	}
	return x.ExpiresAt, x.Count, nil
}

// escapeLikePattern escapes backslash, '%' and '_' in s with a backslash, so
// that s matches itself literally inside a Postgres LIKE pattern that uses
// the default '\' escape character. It works bytewise, which is safe for
// UTF-8 because these ASCII bytes never occur inside multi-byte sequences.
func escapeLikePattern(s string) string {
	escaped := make([]byte, 0, len(s))
	for i := 0; i < len(s); i++ {
		switch c := s[i]; c {
		case '\\', '%', '_':
			escaped = append(escaped, '\\', c)
		default:
			escaped = append(escaped, c)
		}
	}
	return string(escaped)
}
// GetDownstream returns the ids of the objects that id references. An id that
// does not exist, or references nothing, yields an empty result, not an error.
func (t *postgresTracker) GetDownstream(ctx context.Context, id string) ([]string, error) {
	const query = `
WITH target AS (
SELECT int_id FROM storage.tracker_objects WHERE str_id = $1
)
SELECT str_id
FROM storage.tracker_objects
WHERE int_id IN (
SELECT to_id FROM storage.tracker_refs WHERE from_id IN (SELECT int_id FROM target)
)
`
	var downstream []string
	err := t.db.SelectContext(ctx, &downstream, query, id)
	if err != nil {
		return nil, errors.EnsureStack(err)
	}
	return downstream, nil
}
// GetUpstream returns the ids of the objects that reference id. The result
// slice starts out empty rather than nil. Assuming SelectContext only appends,
// an unreferenced or unknown id therefore yields [] instead of nil.
func (t *postgresTracker) GetUpstream(ctx context.Context, id string) ([]string, error) {
	// Unquoted identifiers are case-folded by Postgres, so TARGET refers to
	// the "target" CTE.
	const query = `WITH target AS (
SELECT int_id FROM storage.tracker_objects WHERE str_id = $1
)
SELECT str_id
FROM storage.tracker_objects
WHERE int_id IN (
SELECT from_id FROM storage.tracker_refs WHERE to_id IN (SELECT int_id FROM TARGET)
)`
	upstream := make([]string, 0)
	if err := t.db.SelectContext(ctx, &upstream, query, id); err != nil {
		return nil, errors.EnsureStack(err)
	}
	return upstream, nil
}
// GetExpiresAt returns the expiry recorded for id, or a pacherr NotExist error
// if id is not tracked.
//
// NOTE(review): expires_at is nullable (objects created with NoTTL store NULL),
// and scanning NULL into a time.Time presumably fails. Such objects would then
// surface a scan error here. Confirm whether callers expect that.
func (t *postgresTracker) GetExpiresAt(ctx context.Context, id string) (time.Time, error) {
	var expiresAt time.Time
	err := t.db.GetContext(ctx, &expiresAt,
		`SELECT expires_at FROM storage.tracker_objects WHERE str_id = $1
`, id)
	switch {
	case err == nil:
		return expiresAt, nil
	case errors.Is(err, sql.ErrNoRows):
		return time.Time{}, pacherr.NewNotExist("tracker", id)
	default:
		return time.Time{}, errors.EnsureStack(err)
	}
}
// DeleteTx deletes the object id, and the references it holds to other
// objects, inside tx. If any object still references id, ErrDanglingRef is
// returned before anything is deleted.
func (t *postgresTracker) DeleteTx(tx *pachsql.Tx, id string) error {
	const countReferrers = `
WITH target AS (
SELECT int_id FROM storage.tracker_objects WHERE str_id = $1
)
SELECT count(distinct from_id) FROM storage.tracker_refs WHERE to_id IN (SELECT int_id FROM target)
`
	const dropOutgoingRefs = `
WITH target AS (
SELECT int_id FROM storage.tracker_objects WHERE str_id = $1
)
DELETE FROM storage.tracker_refs WHERE from_id IN (SELECT int_id FROM target)
`
	const dropObject = `DELETE FROM storage.tracker_objects WHERE str_id = $1`

	var referrers int
	if err := tx.Get(&referrers, countReferrers, id); err != nil {
		return errors.EnsureStack(err)
	}
	if referrers > 0 {
		return ErrDanglingRef
	}
	if _, err := tx.Exec(dropOutgoingRefs, id); err != nil {
		return errors.EnsureStack(err)
	}
	_, err := tx.Exec(dropObject, id)
	return errors.EnsureStack(err)
}
// IterateDeletable calls cb with the id of each object that has expired and
// is not referenced by any other object. It stops at, and returns, the first
// error from cb. Objects with a NULL expires_at never match. At most 10000 ids
// are visited per call; GC is expected to call this repeatedly.
func (t *postgresTracker) IterateDeletable(ctx context.Context, cb func(id string) error) (retErr error) {
	// The subquery only tests for existence, hence SELECT 1. The LIMIT bounds
	// the work per call. All ids are fetched before cb runs, so cb may delete.
	const query = `SELECT str_id FROM storage.tracker_objects as objs
WHERE NOT EXISTS (SELECT 1 FROM storage.tracker_refs as refs where objs.int_id = refs.to_id)
AND expires_at <= CURRENT_TIMESTAMP LIMIT 10000`
	var candidates []string
	if err := t.db.SelectContext(ctx, &candidates, query); err != nil {
		return errors.EnsureStack(err)
	}
	for i := range candidates {
		if err := cb(candidates[i]); err != nil {
			return err
		}
	}
	return nil
}
// getDownstream returns, inside tx, the str_ids of the objects referenced by
// the object with internal id intID. The result starts as an empty non-nil
// slice.
func (t *postgresTracker) getDownstream(tx *pachsql.Tx, intID int) ([]string, error) {
	const query = `
SELECT str_id FROM storage.tracker_objects
JOIN storage.tracker_refs ON int_id = to_id
WHERE from_id = $1
`
	referenced := make([]string, 0)
	if err := tx.Select(&referenced, query, intID); err != nil {
		return nil, errors.EnsureStack(err)
	}
	return referenced, nil
}
// stringsMatch reports whether a and b hold the same strings with the same
// multiplicities, ignoring order. It sorts private copies, so the callers'
// slices are left untouched.
func stringsMatch(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	// Copy before sorting: sorting a or b directly would reorder the
	// caller's data as a hidden side effect.
	sa := append([]string(nil), a...)
	sb := append([]string(nil), b...)
	sort.Strings(sa)
	sort.Strings(sb)
	for i := range sa {
		if sa[i] != sb[i] {
			return false
		}
	}
	return true
}
// dedupedStrings returns a sorted copy of xs with duplicates removed. xs
// itself is never modified. The result is non-nil even when xs is empty.
func dedupedStrings(xs []string) []string {
	clone := make([]string, len(xs))
	copy(clone, xs)
	return removeDuplicates(clone)
}
// removeDuplicates sorts xs in place and compacts it so that each distinct
// string appears once. The returned slice shares xs's backing array.
func removeDuplicates(xs []string) []string {
	sort.Strings(xs)
	kept := 0
	for _, s := range xs {
		// After sorting, duplicates are adjacent, so comparing against the last
		// kept element is enough.
		if kept > 0 && xs[kept-1] == s {
			continue
		}
		xs[kept] = s
		kept++
	}
	return xs[:kept]
}
// SetupPostgresTrackerV0 sets up the tables for the postgres tracker by
// executing the schema DDL inside tx. That DDL creates
// storage.tracker_objects, storage.tracker_refs and an index on tracker_refs.
// It presumably relies on the "storage" schema already existing; that is
// assumed to be created by an earlier migration (TODO confirm).
// DO NOT MODIFY THIS FUNCTION
// IT HAS BEEN USED IN A RELEASED MIGRATION
func SetupPostgresTrackerV0(ctx context.Context, tx *pachsql.Tx) error {
_, err := tx.ExecContext(ctx, schema)
return errors.EnsureStack(err)
}
// schema is the DDL executed by SetupPostgresTrackerV0. Because that function
// is part of a released migration, this text must not change either.
//
// tracker_objects maps string ids (str_id) to internal int_ids and holds an
// optional expiry; a NULL expires_at means the object has no ttl.
// tracker_refs stores directed references from_id -> to_id. The extra
// (to_id, from_id) index supports lookups by referenced object.
var schema = `
CREATE TABLE storage.tracker_objects (
int_id BIGSERIAL PRIMARY KEY,
str_id VARCHAR(4096) UNIQUE,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP
);
CREATE TABLE storage.tracker_refs (
from_id INT8 NOT NULL,
to_id INT8 NOT NULL,
PRIMARY KEY (from_id, to_id)
);
CREATE INDEX ON storage.tracker_refs (
to_id,
from_id
);
`