diff --git a/queue/client.go b/queue/client.go index bf6cf6a..c22d26a 100644 --- a/queue/client.go +++ b/queue/client.go @@ -2,15 +2,16 @@ package queue import ( "context" - "crypto/sha1" "encoding/json" "errors" "fmt" "regexp" + "strconv" "strings" "time" "github.com/redis/go-redis/v9" + "go.uber.org/multierr" "github.com/replicate/go/shuffleshard" ) @@ -20,10 +21,15 @@ var ( ErrInvalidWriteArgs = errors.New("queue: invalid write arguments") ErrNoMatchingMessageInStream = errors.New("queue: no matching message in stream") ErrInvalidMetaCancelation = errors.New("queue: invalid meta cancelation") + ErrStopGC = errors.New("queue: stop garbage collection") streamSuffixPattern = regexp.MustCompile(`\A:s(\d+)\z`) ) +const ( + metaCancelationGCBatchSize = 100 +) + type Client struct { rdb redis.Cmdable ttl time.Duration // ttl for all keys in queue @@ -52,14 +58,145 @@ func (c *Client) Prepare(ctx context.Context) error { return prepare(ctx, c.rdb) } +// OnGCFunc is called periodically during GC *before* deleting the expired keys. The +// argument given is the "track values" as extracted from the meta cancelation key. +type OnGCFunc func(ctx context.Context, trackValues []string) error + // GC performs all garbage collection operations that cannot be automatically -// performed via key expiry. -func (c *Client) GC(ctx context.Context) error { - if _, err := gcMetaCancelation(ctx, c.rdb); err != nil { +// performed via key expiry, which is the "meta:cancelation" hash at the time of this +// writing. 
+func (c *Client) GC(ctx context.Context, f OnGCFunc) (uint64, uint64, error) { + now, err := c.rdb.Time(ctx).Result() + if err != nil { + return 0, 0, err + } + + nowUnix := now.Unix() + + nonFatalErrors := []error{} + + idsToDelete := []string{} + keysToDelete := []string{} + + iter := c.rdb.HScanNoValues(ctx, MetaCancelationHash, 0, "*:expiry:*", 0).Iterator() + total := uint64(0) + twiceDeleted := uint64(0) + + for iter.Next(ctx) { + key := iter.Val() + total++ + + if len(idsToDelete) >= metaCancelationGCBatchSize { + n, err := c.gcProcessBatch(ctx, f, idsToDelete, keysToDelete) + if err != nil { + if errors.Is(err, ErrStopGC) { + return total, twiceDeleted / 2, err + } + + nonFatalErrors = append(nonFatalErrors, err) + } + + twiceDeleted += uint64(n) + + idsToDelete = []string{} + keysToDelete = []string{} + + now, err = c.rdb.Time(ctx).Result() + if err != nil { + return total, twiceDeleted / 2, err + } + + nowUnix = now.Unix() + } + + keyParts := strings.Split(key, ":") + if len(keyParts) != 3 { + continue + } + + keyTime, err := strconv.ParseInt(keyParts[2], 0, 64) + if err != nil { + nonFatalErrors = append(nonFatalErrors, err) + continue + } + + if nowUnix > keyTime { + keysToDelete = append(keysToDelete, key, keyParts[0]) + idsToDelete = append(idsToDelete, keyParts[0]) + } + } + + n, err := c.gcProcessBatch(ctx, f, idsToDelete, keysToDelete) + if err != nil { + if errors.Is(err, ErrStopGC) { + return total, twiceDeleted / 2, err + } + + nonFatalErrors = append(nonFatalErrors, err) + } + + twiceDeleted += uint64(n) + + if err := iter.Err(); err != nil { + return total, twiceDeleted / 2, err + } + + return total, twiceDeleted / 2, multierr.Combine(nonFatalErrors...) 
+} + +func (c *Client) gcProcessBatch(ctx context.Context, f OnGCFunc, idsToDelete, keysToDelete []string) (int64, error) { + if len(idsToDelete) == 0 || len(keysToDelete) == 0 { + return 0, nil + } + + if err := c.callOnGC(ctx, f, idsToDelete); err != nil { + // NOTE: The client `OnGCFunc` may request interruption via the `ErrStopGC` + // error as a way to prevent the `HDel`. + if errors.Is(err, ErrStopGC) { + return 0, err + } + } + + return c.rdb.HDel( + ctx, + MetaCancelationHash, + keysToDelete..., + ).Result() +} + +func (c *Client) callOnGC(ctx context.Context, f OnGCFunc, idsToDelete []string) error { + if f == nil { + return nil + } + + pipe := c.rdb.Pipeline() + hValCmds := make([]*redis.StringCmd, len(idsToDelete)) + + for i, idToDelete := range idsToDelete { + hValCmds[i] = pipe.HGet(ctx, MetaCancelationHash, idToDelete) + } + + if _, err := pipe.Exec(ctx); err != nil { return err } - return nil + trackValues := make([]string, len(idsToDelete)) + + for i, hValCmd := range hValCmds { + msgBytes, err := hValCmd.Bytes() + if err != nil { + return err + } + + msg := &metaCancelation{} + if err := json.Unmarshal(msgBytes, msg); err != nil { + return err + } + + trackValues[i] = msg.TrackValue + } + + return f(ctx, trackValues) } // Len calculates the aggregate length (XLEN) of the queue. 
It adds up the @@ -267,12 +404,27 @@ func (c *Client) write(ctx context.Context, args *WriteArgs) (string, error) { // Capacity: 3 (for seconds, streams, n) + len(shard) + 2*len(values) cmdArgs := make([]any, 0, 3+len(shard)+2*len(args.Values)) - cmdArgs = append(cmdArgs, int(c.ttl.Seconds())) - cmdArgs = append(cmdArgs, args.Streams) - cmdArgs = append(cmdArgs, len(shard)) + cmdArgs = append( + cmdArgs, + int(c.ttl.Seconds()), + args.Streams, + len(shard), + ) if c.trackField != "" { - cmdArgs = append(cmdArgs, c.trackField) + deadlineUnix := int64(0) + if !args.Deadline.IsZero() { + deadlineUnix = args.Deadline.Unix() + } + + cmdArgs = append( + cmdArgs, + c.trackField, + // NOTE: Deadline is an optional field in WriteArgs, so the Unix value may be + // passed as zero so that the writeTrackingScript uses a default value of the + // server time + ttl. + deadlineUnix, + ) } for _, s := range shard { @@ -290,16 +442,16 @@ func (c *Client) write(ctx context.Context, args *WriteArgs) (string, error) { } type metaCancelation struct { - StreamID string `json:"stream_id"` - MsgID string `json:"msg_id"` + StreamID string `json:"stream_id"` + MsgID string `json:"msg_id"` + TrackValue string `json:"track_value"` + Deadline int64 `json:"deadline"` } // Del supports removal of a message when the given `fieldValue` matches a "meta // cancelation" key as written when using a client with tracking support. 
func (c *Client) Del(ctx context.Context, fieldValue string) error { - metaCancelationKey := fmt.Sprintf("%x", sha1.Sum([]byte(fieldValue))) - - msgBytes, err := c.rdb.HGet(ctx, MetaCancelationHash, metaCancelationKey).Bytes() + msgBytes, err := c.rdb.HGet(ctx, MetaCancelationHash, fieldValue).Bytes() if err != nil { return err } @@ -324,8 +476,7 @@ func (c *Client) Del(ctx context.Context, fieldValue string) error { if n == 0 { return fmt.Errorf( - "key=%q field-value=%q stream=%q message-id=%q: %w", - metaCancelationKey, + "field-value=%q stream=%q message-id=%q: %w", fieldValue, msg.StreamID, msg.MsgID, diff --git a/queue/client_test.go b/queue/client_test.go index 84f2678..01df4fc 100644 --- a/queue/client_test.go +++ b/queue/client_test.go @@ -3,7 +3,6 @@ package queue_test import ( "context" crand "crypto/rand" - "crypto/sha1" "errors" "fmt" "math/rand" @@ -312,6 +311,7 @@ func runClientWriteIntegrationTest(ctx context.Context, t *testing.T, rdb *redis "name": "panda", "tracketytrack": trackID.String(), }, + Deadline: time.Now().Add(-1 * time.Hour), }) require.NoError(t, err) } @@ -439,12 +439,10 @@ func TestClientDelIntegration(t *testing.T) { require.Error(t, client.Del(ctx, trackIDs[0]+"oops")) require.Error(t, client.Del(ctx, "bogustown")) - metaCancelationKey := fmt.Sprintf("%x", sha1.Sum([]byte(trackIDs[1]))) - - metaCancel, err := rdb.HGet(ctx, queue.MetaCancelationHash, metaCancelationKey).Result() + metaCancel, err := rdb.HGet(ctx, queue.MetaCancelationHash, trackIDs[1]).Result() require.NoError(t, err) - rdb.HSet(ctx, queue.MetaCancelationHash, metaCancelationKey, "{{[,bogus"+metaCancel) + rdb.HSet(ctx, queue.MetaCancelationHash, trackIDs[1], "{{[,bogus"+metaCancel) require.Error(t, client.Del(ctx, trackIDs[1])) @@ -462,7 +460,20 @@ func TestClientGCIntegration(t *testing.T) { runClientWriteIntegrationTest(ctx, t, rdb, client, true) - require.NoError(t, client.GC(ctx)) + gcTrackedFields := []string{} + + onGCFunc := func(_ context.Context, 
trackedFields []string) error { + gcTrackedFields = append(gcTrackedFields, trackedFields...) + + return nil + } + + total, nDeleted, err := client.GC(ctx, onGCFunc) + require.NoError(t, err) + require.Equal(t, uint64(15), total) + require.Equal(t, uint64(10), nDeleted) + + require.Len(t, gcTrackedFields, 10) } // TestPickupLatencyIntegration runs a test with a mostly-empty queue -- by diff --git a/queue/queue.go b/queue/queue.go index 0ede775..c69aafe 100644 --- a/queue/queue.go +++ b/queue/queue.go @@ -35,9 +35,7 @@ package queue import ( "context" _ "embed" // to provide go:embed support - "strconv" "strings" - "time" "github.com/redis/go-redis/v9" ) @@ -76,8 +74,6 @@ var ( const ( MetaCancelationHash = "meta:cancelation" - - metaCancelationGCBatchSize = 100 ) func prepare(ctx context.Context, rdb redis.Cmdable) error { @@ -101,48 +97,3 @@ func prepare(ctx context.Context, rdb redis.Cmdable) error { } return nil } - -func gcMetaCancelation(ctx context.Context, rdb redis.Cmdable) (int, error) { - now := time.Now().UTC().Unix() - keysToDelete := []string{} - iter := rdb.HScan(ctx, MetaCancelationHash, 0, "*:expiry:*", 0).Iterator() - - for iter.Next(ctx) { - key := iter.Val() - - keyParts := strings.Split(key, ":") - if len(keyParts) != 3 { - continue - } - - keyTime, err := strconv.ParseInt(keyParts[2], 0, 64) - if err != nil { - continue - } - - if keyTime > now { - keysToDelete = append(keysToDelete, key, keyParts[0]) - } - } - - if err := iter.Err(); err != nil { - return 0, err - } - - for i := 0; i < len(keysToDelete); i += metaCancelationGCBatchSize { - sliceEnd := i + metaCancelationGCBatchSize - if sliceEnd > len(keysToDelete) { - sliceEnd = len(keysToDelete) - } - - if err := rdb.HDel( - ctx, - MetaCancelationHash, - keysToDelete[i:sliceEnd]..., - ).Err(); err != nil { - return 0, err - } - } - - return len(keysToDelete), nil -} diff --git a/queue/types.go b/queue/types.go index a5c0105..960a75f 100644 --- a/queue/types.go +++ b/queue/types.go @@ -15,8 
+15,9 @@ func (e queueError) Error() string { const Empty = queueError("queue: empty") type WriteArgs struct { - Name string // queue name - Values map[string]any // message values + Name string // queue name + Values map[string]any // message values + Deadline time.Time // time after which the message will be canceled (only when tracked) Streams int // total number of streams StreamsPerShard int // number of streams in each shard diff --git a/queue/writetracking.lua b/queue/writetracking.lua index ad13689..340a03f 100644 --- a/queue/writetracking.lua +++ b/queue/writetracking.lua @@ -1,6 +1,6 @@ -- Write commands take the form -- --- EVALSHA sha 1 key seconds streams n track_field sid [sid ...] field value [field value ...] +-- EVALSHA sha 1 key seconds streams n track_field deadline sid [sid ...] field value [field value ...] -- -- - `key` is the base key for the queue, e.g. "prediction:input:abcd1234" -- - `seconds` determines the expiry timeout for all keys that make up the @@ -12,6 +12,7 @@ -- or equal to `streams`. -- - `track_field` is the name of the key in `fields` used for tracking the stream -- message ID for cancelation. +-- - `deadline` is the Unix timestamp used in the cancelation expiry key (0 means default to server time + `seconds`). -- - `sid` are the stream IDs to consider writing to. They must be in the range -- [0, `streams`). The message will be written to the shortest of the selected -- streams. @@ -27,8 +28,9 @@ local ttl = tonumber(ARGV[1], 10) local writestreams = tonumber(ARGV[2], 10) local n = tonumber(ARGV[3], 10) local track_field = ARGV[4] -local sids = { unpack(ARGV, 5, 5 + n - 1) } -local fields = { unpack(ARGV, 5 + n, #ARGV) } +local deadline = tonumber(ARGV[5], 10) +local sids = { unpack(ARGV, 6, 6 + n - 1) } +local fields = { unpack(ARGV, 6 + n, #ARGV) } local key_meta = base .. ':meta' local key_notifications = base .. 
':notifications' @@ -116,19 +118,22 @@ local id = redis.call('XADD', key_stream, '*', unpack(fields)) redis.call('XADD', key_notifications, 'MAXLEN', '1', '*', 's', selected_sid) if track_value ~= '' then - local cancelation_key = redis.sha1hex(track_value) - local server_time = redis.call('TIME') - local expiry_unixtime = tonumber(server_time[1]) + 90000 -- 25 hours - local cancelation_expiry_key = cancelation_key .. ':expiry:' .. tostring(expiry_unixtime) + if deadline == 0 then + local server_time = redis.call('TIME') + deadline = tonumber(server_time[1], 10) + ttl + end + + local cancelation_expiry_key = track_value .. ':expiry:' .. tostring(deadline) redis.call( 'HSET', '__META_CANCELATION_HASH__', - cancelation_key, + track_value, cjson.encode({ ['stream_id'] = key_stream, ['track_value'] = track_value, ['msg_id'] = id, + ['deadline'] = deadline, }), cancelation_expiry_key, '1'