Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 167 additions & 16 deletions queue/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@ package queue

import (
"context"
"crypto/sha1"
"encoding/json"
"errors"
"fmt"
"regexp"
"strconv"
"strings"
"time"

"github.com/redis/go-redis/v9"
"go.uber.org/multierr"

"github.com/replicate/go/shuffleshard"
)
Expand All @@ -20,10 +21,15 @@ var (
ErrInvalidWriteArgs = errors.New("queue: invalid write arguments")
ErrNoMatchingMessageInStream = errors.New("queue: no matching message in stream")
ErrInvalidMetaCancelation = errors.New("queue: invalid meta cancelation")
ErrStopGC = errors.New("queue: stop garbage collection")

streamSuffixPattern = regexp.MustCompile(`\A:s(\d+)\z`)
)

const (
	// metaCancelationGCBatchSize is the number of expired ids GC accumulates
	// before it flushes them through gcProcessBatch (callback plus HDEL).
	metaCancelationGCBatchSize = 100
)

type Client struct {
rdb redis.Cmdable
ttl time.Duration // ttl for all keys in queue
Expand Down Expand Up @@ -52,14 +58,145 @@ func (c *Client) Prepare(ctx context.Context) error {
return prepare(ctx, c.rdb)
}

// OnGCFunc is called once per batch of expired entries during GC, *before* the
// expired keys of that batch are deleted. The argument given is the "track values"
// as extracted from the meta cancelation records of the entries in the batch.
//
// Returning an error that matches ErrStopGC (via errors.Is) prevents the deletion
// of the batch and makes GC return immediately with that error. Any other returned
// error does not prevent the deletion.
type OnGCFunc func(ctx context.Context, trackValues []string) error

// GC performs all garbage collection operations that cannot be automatically
// performed via key expiry.
func (c *Client) GC(ctx context.Context) error {
if _, err := gcMetaCancelation(ctx, c.rdb); err != nil {
// performed via key expiry, which is the "meta:cancelation" hash at the time of this
// writing.
//
// The first result is the number of hash fields visited by the scan, the second is
// the number of expired entries removed (every entry is stored as two hash fields,
// so the HDEL counts are halved), and the error is either a fatal error or the
// combination of all non-fatal errors collected along the way.
func (c *Client) GC(ctx context.Context, f OnGCFunc) (uint64, uint64, error) {
	serverTime, err := c.rdb.Time(ctx).Result()
	if err != nil {
		return 0, 0, err
	}

	// Expiry is judged against the Redis server clock, not the local one.
	cutoff := serverTime.Unix()

	var (
		scanned       uint64
		fieldsRemoved uint64
		softErrs      []error
		expiredIDs    []string
		expiredKeys   []string
	)

	scan := c.rdb.HScanNoValues(ctx, MetaCancelationHash, 0, "*:expiry:*", 0).Iterator()

	for scan.Next(ctx) {
		field := scan.Val()
		scanned++

		// Flush a full batch before inspecting the current field.
		if len(expiredIDs) >= metaCancelationGCBatchSize {
			removed, batchErr := c.gcProcessBatch(ctx, f, expiredIDs, expiredKeys)

			switch {
			case errors.Is(batchErr, ErrStopGC):
				return scanned, fieldsRemoved / 2, batchErr
			case batchErr != nil:
				softErrs = append(softErrs, batchErr)
			}

			fieldsRemoved += uint64(removed)
			expiredIDs, expiredKeys = nil, nil

			// A batch may take a while, so re-read the server clock afterwards.
			serverTime, err = c.rdb.Time(ctx).Result()
			if err != nil {
				return scanned, fieldsRemoved / 2, err
			}

			cutoff = serverTime.Unix()
		}

		// Only fields made of exactly three colon-separated parts are considered;
		// the first part is the id and the last part is the expiry timestamp.
		parts := strings.Split(field, ":")
		if len(parts) != 3 {
			continue
		}

		expiresAt, parseErr := strconv.ParseInt(parts[2], 0, 64)
		if parseErr != nil {
			softErrs = append(softErrs, parseErr)
			continue
		}

		if expiresAt >= cutoff {
			continue
		}

		// Both the expiry field and the id field it belongs to get removed.
		id := parts[0]
		expiredKeys = append(expiredKeys, field, id)
		expiredIDs = append(expiredIDs, id)
	}

	// Flush whatever is left over after the scan ended.
	removed, batchErr := c.gcProcessBatch(ctx, f, expiredIDs, expiredKeys)

	switch {
	case errors.Is(batchErr, ErrStopGC):
		return scanned, fieldsRemoved / 2, batchErr
	case batchErr != nil:
		softErrs = append(softErrs, batchErr)
	}

	fieldsRemoved += uint64(removed)

	if scanErr := scan.Err(); scanErr != nil {
		return scanned, fieldsRemoved / 2, scanErr
	}

	return scanned, fieldsRemoved / 2, multierr.Combine(softErrs...)
}

// gcProcessBatch notifies f about the ids in idsToDelete and then removes
// keysToDelete from the meta cancelation hash. It returns the number of hash
// fields removed as reported by HDEL. An empty batch is a no-op.
func (c *Client) gcProcessBatch(ctx context.Context, f OnGCFunc, idsToDelete, keysToDelete []string) (int64, error) {
	hasWork := len(idsToDelete) > 0 && len(keysToDelete) > 0
	if !hasWork {
		return 0, nil
	}

	// NOTE: Only `ErrStopGC` from the callback step aborts the batch; it is the
	// way for the client `OnGCFunc` to prevent the `HDel`. Every other error from
	// that step is dropped and the deletion goes ahead.
	if cbErr := c.callOnGC(ctx, f, idsToDelete); cbErr != nil && errors.Is(cbErr, ErrStopGC) {
		return 0, cbErr
	}

	hdel := c.rdb.HDel(ctx, MetaCancelationHash, keysToDelete...)

	return hdel.Result()
}

// callOnGC fetches the meta cancelation record of every id in idsToDelete in a
// single pipeline, extracts the track value of each record, and passes the
// resulting slice (same order as idsToDelete) to f.
//
// A nil f makes this a no-op. Any failure to execute the pipeline, read a reply, or
// decode a record is returned without calling f; otherwise the result of f is
// returned as is.
func (c *Client) callOnGC(ctx context.Context, f OnGCFunc, idsToDelete []string) error {
	if f == nil {
		return nil
	}

	pipe := c.rdb.Pipeline()
	hValCmds := make([]*redis.StringCmd, len(idsToDelete))

	for i, idToDelete := range idsToDelete {
		hValCmds[i] = pipe.HGet(ctx, MetaCancelationHash, idToDelete)
	}

	if _, err := pipe.Exec(ctx); err != nil {
		return err
	}

	// NOTE: There must be no early return here; the replies collected above still
	// have to be decoded and handed to the callback.
	trackValues := make([]string, len(idsToDelete))

	for i, hValCmd := range hValCmds {
		msgBytes, err := hValCmd.Bytes()
		if err != nil {
			return err
		}

		msg := &metaCancelation{}
		if err := json.Unmarshal(msgBytes, msg); err != nil {
			return err
		}

		trackValues[i] = msg.TrackValue
	}

	return f(ctx, trackValues)
}

// Len calculates the aggregate length (XLEN) of the queue. It adds up the
Expand Down Expand Up @@ -267,12 +404,27 @@ func (c *Client) write(ctx context.Context, args *WriteArgs) (string, error) {
// Capacity: 3 (for seconds, streams, n) + len(shard) + 2*len(values)
cmdArgs := make([]any, 0, 3+len(shard)+2*len(args.Values))

cmdArgs = append(cmdArgs, int(c.ttl.Seconds()))
cmdArgs = append(cmdArgs, args.Streams)
cmdArgs = append(cmdArgs, len(shard))
cmdArgs = append(
cmdArgs,
int(c.ttl.Seconds()),
args.Streams,
len(shard),
)

if c.trackField != "" {
cmdArgs = append(cmdArgs, c.trackField)
deadlineUnix := int64(0)
if !args.Deadline.IsZero() {
deadlineUnix = args.Deadline.Unix()
}

cmdArgs = append(
cmdArgs,
c.trackField,
// NOTE: Deadline is an optional field in WriteArgs, so the Unix value may be
// passed as zero so that the writeTrackingScript uses a default value of the
// server time + ttl.
deadlineUnix,
)
}

for _, s := range shard {
Expand All @@ -290,16 +442,16 @@ func (c *Client) write(ctx context.Context, args *WriteArgs) (string, error) {
}

// metaCancelation is the JSON record stored per tracked message in the meta
// cancelation hash. Each field is declared exactly once; duplicate field names
// do not compile.
type metaCancelation struct {
	StreamID   string `json:"stream_id"`   // stream holding the message
	MsgID      string `json:"msg_id"`      // id of the message within the stream
	TrackValue string `json:"track_value"` // value handed to OnGCFunc during GC
	Deadline   int64  `json:"deadline"`    // NOTE(review): presumably a Unix timestamp in seconds; confirm against the write script
}

// Del supports removal of a message when the given `fieldValue` matches a "meta
// cancelation" key as written when using a client with tracking support.
func (c *Client) Del(ctx context.Context, fieldValue string) error {
metaCancelationKey := fmt.Sprintf("%x", sha1.Sum([]byte(fieldValue)))

msgBytes, err := c.rdb.HGet(ctx, MetaCancelationHash, metaCancelationKey).Bytes()
msgBytes, err := c.rdb.HGet(ctx, MetaCancelationHash, fieldValue).Bytes()
if err != nil {
return err
}
Expand All @@ -324,8 +476,7 @@ func (c *Client) Del(ctx context.Context, fieldValue string) error {

if n == 0 {
return fmt.Errorf(
"key=%q field-value=%q stream=%q message-id=%q: %w",
metaCancelationKey,
"field-value=%q stream=%q message-id=%q: %w",
fieldValue,
msg.StreamID,
msg.MsgID,
Expand Down
23 changes: 17 additions & 6 deletions queue/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package queue_test
import (
"context"
crand "crypto/rand"
"crypto/sha1"
"errors"
"fmt"
"math/rand"
Expand Down Expand Up @@ -312,6 +311,7 @@ func runClientWriteIntegrationTest(ctx context.Context, t *testing.T, rdb *redis
"name": "panda",
"tracketytrack": trackID.String(),
},
Deadline: time.Now().Add(-1 * time.Hour),
})
require.NoError(t, err)
}
Expand Down Expand Up @@ -439,12 +439,10 @@ func TestClientDelIntegration(t *testing.T) {
require.Error(t, client.Del(ctx, trackIDs[0]+"oops"))
require.Error(t, client.Del(ctx, "bogustown"))

metaCancelationKey := fmt.Sprintf("%x", sha1.Sum([]byte(trackIDs[1])))

metaCancel, err := rdb.HGet(ctx, queue.MetaCancelationHash, metaCancelationKey).Result()
metaCancel, err := rdb.HGet(ctx, queue.MetaCancelationHash, trackIDs[1]).Result()
require.NoError(t, err)

rdb.HSet(ctx, queue.MetaCancelationHash, metaCancelationKey, "{{[,bogus"+metaCancel)
rdb.HSet(ctx, queue.MetaCancelationHash, trackIDs[1], "{{[,bogus"+metaCancel)

require.Error(t, client.Del(ctx, trackIDs[1]))

Expand All @@ -462,7 +460,20 @@ func TestClientGCIntegration(t *testing.T) {

runClientWriteIntegrationTest(ctx, t, rdb, client, true)

require.NoError(t, client.GC(ctx))
gcTrackedFields := []string{}

onGCFunc := func(_ context.Context, trackedFields []string) error {
gcTrackedFields = append(gcTrackedFields, trackedFields...)

return nil
}

total, nDeleted, err := client.GC(ctx, onGCFunc)
require.NoError(t, err)
require.Equal(t, uint64(15), total)
require.Equal(t, uint64(10), nDeleted)

require.Len(t, gcTrackedFields, 10)
}

// TestPickupLatencyIntegration runs a test with a mostly-empty queue -- by
Expand Down
49 changes: 0 additions & 49 deletions queue/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ package queue
import (
"context"
_ "embed" // to provide go:embed support
"strconv"
"strings"
"time"

"github.com/redis/go-redis/v9"
)
Expand Down Expand Up @@ -76,8 +74,6 @@ var (

const (
MetaCancelationHash = "meta:cancelation"

metaCancelationGCBatchSize = 100
)

func prepare(ctx context.Context, rdb redis.Cmdable) error {
Expand All @@ -101,48 +97,3 @@ func prepare(ctx context.Context, rdb redis.Cmdable) error {
}
return nil
}

// gcMetaCancelation removes expired entries from the meta cancelation hash.
//
// It scans the hash for fields matching "*:expiry:*", treats the first
// colon-separated part as the entry id and the third as a Unix expiry timestamp,
// and deletes both the expiry field and the id field of every entry whose
// timestamp lies in the past (judged by the local clock). Deletions are issued in
// chunks of metaCancelationGCBatchSize fields.
//
// The returned count is the number of hash fields submitted for deletion (two per
// expired entry). On a scan or delete error it returns 0 and the error.
func gcMetaCancelation(ctx context.Context, rdb redis.Cmdable) (int, error) {
	now := time.Now().UTC().Unix()
	keysToDelete := []string{}
	iter := rdb.HScan(ctx, MetaCancelationHash, 0, "*:expiry:*", 0).Iterator()

	for iter.Next(ctx) {
		key := iter.Val()

		keyParts := strings.Split(key, ":")
		if len(keyParts) != 3 {
			continue
		}

		// Items whose last part is not a number are skipped on purpose.
		keyTime, err := strconv.ParseInt(keyParts[2], 0, 64)
		if err != nil {
			continue
		}

		// An entry is expired once its timestamp is in the past. The comparison
		// used to be inverted, which removed live entries and kept expired ones.
		if now > keyTime {
			keysToDelete = append(keysToDelete, key, keyParts[0])
		}
	}

	if err := iter.Err(); err != nil {
		return 0, err
	}

	for i := 0; i < len(keysToDelete); i += metaCancelationGCBatchSize {
		sliceEnd := i + metaCancelationGCBatchSize
		if sliceEnd > len(keysToDelete) {
			sliceEnd = len(keysToDelete)
		}

		if err := rdb.HDel(
			ctx,
			MetaCancelationHash,
			keysToDelete[i:sliceEnd]...,
		).Err(); err != nil {
			return 0, err
		}
	}

	return len(keysToDelete), nil
}
5 changes: 3 additions & 2 deletions queue/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ func (e queueError) Error() string {
const Empty = queueError("queue: empty")

type WriteArgs struct {
Name string // queue name
Values map[string]any // message values
Name string // queue name
Values map[string]any // message values
Deadline time.Time // time after which message will be cancel (only when tracked)

Streams int // total number of streams
StreamsPerShard int // number of streams in each shard
Expand Down
Loading