Skip to content

Commit

Permalink
Merge 2bfac9e into 662e15d
Browse files Browse the repository at this point in the history
  • Loading branch information
szabado committed Apr 8, 2019
2 parents 662e15d + 2bfac9e commit d810a49
Show file tree
Hide file tree
Showing 15 changed files with 171 additions and 77 deletions.
2 changes: 1 addition & 1 deletion bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ type Bucket interface {
// the number of tokens becomes available. A return value of 0 would mean no waiting is
// necessary. Success is true if tokens can be obtained, false if cannot be obtained within
// the specified maximum wait time.
Take(ctx context.Context, numTokens int64, maxWaitTime time.Duration) (waitTime time.Duration, success bool)
Take(ctx context.Context, numTokens int64, maxWaitTime time.Duration) (waitTime time.Duration, success bool, err error)
Config() *pbconfig.BucketConfig
// Dynamic indicates whether a bucket is a dynamic one, or one that is statically defined in
// configuration.
Expand Down
21 changes: 16 additions & 5 deletions buckets/bucket_impl_tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ func TestTokenAcquisition(t *testing.T, bucket quotaservice.Bucket) {
// Clear any stale state
bucket.Take(context.Background(), 1, 0)

wait, s := bucket.Take(context.Background(), 1, 0)
wait, s, err := bucket.Take(context.Background(), 1, 0)
if err != nil {
t.Fatalf("expected a nil error, got %s", err)
}
if wait != 0 {
t.Fatalf("Expecting 0 wait. Was %v", wait)
}
Expand All @@ -26,8 +29,10 @@ func TestTokenAcquisition(t *testing.T, bucket quotaservice.Bucket) {
}

// Consume all tokens. This should work too.
wait, s = bucket.Take(context.Background(), 100, 0)

wait, s, err = bucket.Take(context.Background(), 100, 0)
if err != nil {
t.Fatalf("expected a nil error, got %s", err)
}
if wait != 0 {
t.Fatalf("Expecting 0 wait. Was %v", wait)
}
Expand All @@ -36,7 +41,10 @@ func TestTokenAcquisition(t *testing.T, bucket quotaservice.Bucket) {
}

// Should have no more left. Should have to wait.
wait, s = bucket.Take(context.Background(), 10, 10*time.Second)
wait, s, err = bucket.Take(context.Background(), 10, 10*time.Second)
if err != nil {
t.Fatalf("expected a nil error, got %s", err)
}
if wait < 1 {
t.Fatalf("Expecting positive wait time. Was %v", wait)
}
Expand All @@ -45,7 +53,10 @@ func TestTokenAcquisition(t *testing.T, bucket quotaservice.Bucket) {
}

// If we don't want to wait...
wait, s = bucket.Take(context.Background(), 10, 0)
wait, s, err = bucket.Take(context.Background(), 10, 0)
if err != nil {
t.Fatalf("expected a nil error, got %s", err)
}
if wait != 0 {
t.Fatalf("Expecting 0 wait time. Was %v", wait)
}
Expand Down
8 changes: 5 additions & 3 deletions buckets/memory/bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ func NewBucketFactory() quotaservice.BucketFactory {
return &bucketFactory{}
}

var _ quotaservice.Bucket = (*tokenBucket)(nil)

// tokenBucket is a single-threaded implementation. A single goroutine updates the values of
// tokensNextAvailable and accumulatedTokens. When requesting tokens, Take() puts a request on
// the waitTimer channel, and listens on the response channel in the request for a result. The
Expand All @@ -75,17 +77,17 @@ type waitTimeReq struct {
response chan int64
}

func (b *tokenBucket) Take(_ context.Context, numTokens int64, maxWaitTime time.Duration) (time.Duration, bool) {
func (b *tokenBucket) Take(_ context.Context, numTokens int64, maxWaitTime time.Duration) (time.Duration, bool, error) {
rsp := make(chan int64, 1)
b.waitTimer <- &waitTimeReq{numTokens, maxWaitTime.Nanoseconds(), rsp}
waitTimeNanos := <-rsp

if waitTimeNanos < 0 {
// Timed out
return 0, false
return 0, false, nil
}

return time.Duration(waitTimeNanos) * time.Nanosecond, true
return time.Duration(waitTimeNanos) * time.Nanosecond, true, nil
}

// calcWaitTime is designed to run in a single event loop and is not thread-safe.
Expand Down
45 changes: 22 additions & 23 deletions buckets/redis/bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,12 @@ import (
"gopkg.in/redis.v5"

"github.com/opentracing/opentracing-go"
"github.com/pkg/errors"
"github.com/square/quotaservice"
"github.com/square/quotaservice/logging"
pbconfig "github.com/square/quotaservice/protos/config"
)

// redisBucket is an interface that defines the two different bucket types used with Redis: static and dynamic buckets.
type redisBucket interface {
}

// configAttributes represents certain values from a pbconfig.BucketConfig, represented as strings, for easy use as
// parameters to a Redis call.
type configAttributes struct {
Expand All @@ -44,7 +41,7 @@ func (a *abstractBucket) Config() *pbconfig.BucketConfig {
return a.cfg
}

func (a *abstractBucket) Take(ctx context.Context, requested int64, maxWaitTime time.Duration) (time.Duration, bool) {
func (a *abstractBucket) Take(ctx context.Context, requested int64, maxWaitTime time.Duration) (time.Duration, bool, error) {
currentTimeNanos := strconv.FormatInt(time.Now().UnixNano(), 10)

maxIdleTimeMillis := a.maxIdleTimeMillis
Expand All @@ -56,32 +53,30 @@ func (a *abstractBucket) Take(ctx context.Context, requested int64, maxWaitTime
strconv.FormatInt(requested, 10), strconv.FormatInt(maxWaitTime.Nanoseconds(), 10),
maxIdleTimeMillis, a.maxDebtNanos}

var waitTime time.Duration
var err error

client := a.factory.Client().(*redis.Client)
res := a.takeFromRedis(ctx, client, args)
switch waitTimeNanos := res.Val().(type) {
if err := res.Err(); err != nil {
if isRedisClientClosedError(err) {
logging.Print("Failed to take token from redis because the client was closed, reconnecting")
a.factory.handleConnectionFailure(client)
}
return 0, false, errors.Wrap(err, "failed to take token from redis bucket")
}

var waitTime time.Duration
switch val := res.Val().(type) {
case int64:
waitTime = time.Nanosecond * time.Duration(waitTimeNanos)
break
waitTime = time.Nanosecond * time.Duration(val)
default:
err = res.Err()
if unknownCloseError(err) {
logging.Printf("Unknown response '%v' of type %T. Full result %+v",
waitTimeNanos, waitTimeNanos, res)
}
// Handle connection failure
a.factory.handleConnectionFailure(client)
return 0, false
return 0, false, errors.Errorf("unknown response of type %[1]T: %[1]v", val)
}

if waitTime < 0 {
// Timed out
return 0, false
return 0, false, nil
}

return waitTime, true
return waitTime, true, nil
}

func (a *abstractBucket) takeFromRedis(ctx context.Context, client *redis.Client, args []interface{}) *redis.Cmd {
Expand All @@ -90,7 +85,9 @@ func (a *abstractBucket) takeFromRedis(ctx context.Context, client *redis.Client
return client.EvalSha(a.factory.scriptSHA, a.keys, args...)
}

// staticBucket is an implementation of a redisBucket for use with static, named buckets.
var _ quotaservice.Bucket = (*staticBucket)(nil)

// staticBucket is an implementation of quotaservice.Bucket for use with static, named buckets.
type staticBucket struct {
*abstractBucket
}
Expand All @@ -99,7 +96,9 @@ func (s *staticBucket) Dynamic() bool {
return false
}

// dynamicBucket is an implementation of a redisBucket for use with dynamic buckets created from a template.
var _ quotaservice.Bucket = (*dynamicBucket)(nil)

// dynamicBucket is an implementation of quotaservice.Bucket for use with dynamic buckets created from a template.
type dynamicBucket struct {
*abstractBucket
}
Expand Down
13 changes: 7 additions & 6 deletions buckets/redis/bucket_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@ package redis

import (
"strconv"
"sync"
"time"

"gopkg.in/redis.v5"

"github.com/pkg/errors"
"github.com/square/quotaservice"
"github.com/square/quotaservice/logging"

"sync"

pbconfig "github.com/square/quotaservice/protos/config"
)

Expand Down Expand Up @@ -115,7 +114,7 @@ func (bf *bucketFactory) reconnectToRedis(oldClient *redis.Client) {
defer bf.Unlock()

// Always close connections on errors to prevent results leaking.
if err := bf.client.Close(); unknownCloseError(err) {
if err := bf.client.Close(); !isRedisClientClosedError(err) {
logging.Printf("Received error on Redis client close: %+v", err)
}

Expand Down Expand Up @@ -254,8 +253,10 @@ func toRedisKey(namespace, bucketName, suffix string, version int32) string {
return namespace + ":" + bucketName + ":" + suffix + ":" + strconv.Itoa(int(version))
}

func unknownCloseError(err error) bool {
return err != nil && err.Error() != "redis: client is closed"
const redisClientClosedError = "redis: client is closed"

func isRedisClientClosedError(err error) bool {
return err != nil && errors.Cause(err).Error() == redisClientClosedError
}

func checkScriptExists(c *redis.Client, sha string) bool {
Expand Down
45 changes: 45 additions & 0 deletions buckets/redis/bucket_factory_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package redis

import (
"fmt"
"testing"

"github.com/pkg/errors"
)

func TestIsRedisClientClosedError(t *testing.T) {
tests := []struct {
input error
isCloseError bool
}{
{
// Test exactly the error
input: fmt.Errorf(redisClientClosedError),
isCloseError: true,
},
{
// Test the error wrapped
input: errors.Wrap(fmt.Errorf(redisClientClosedError), "obfuscate"),
isCloseError: true,
},
{
// test not the error
input: errors.New("just another error"),
isCloseError: false,
},
{
// test not the error wrapped with the text of the error (this should never happen)
input: errors.Wrap(errors.New("just another error"), redisClientClosedError),
isCloseError: false,
},
}

for _, test := range tests {
t.Run(test.input.Error(), func(t *testing.T) {
result := isRedisClientClosedError(test.input)
if result != test.isCloseError {
t.Fatal("failed to detect error")
}
})
}
}
34 changes: 20 additions & 14 deletions buckets/redis/bucket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@ package redis

import (
"context"
"fmt"
"os"
"testing"
"time"

"gopkg.in/redis.v5"

"fmt"

"time"

"github.com/square/quotaservice/buckets"
"github.com/square/quotaservice/config"
"github.com/square/quotaservice/protos/config"
Expand Down Expand Up @@ -54,41 +52,49 @@ func TestScriptLoaded(t *testing.T) {
}

func TestFailingRedisConn(t *testing.T) {
w, s := bucket.Take(context.Background(), 1, 0)

w, s, err := bucket.Take(context.Background(), 1, 0)
if err != nil {
t.Fatalf("Expected nil error, got: %v", err)
}
if w < 0 {
t.Fatalf("Should have not seen negative wait time. Saw %v", w)
}
if !s {
t.Fatalf("Success should be true.")
}

err := bucket.factory.client.Close()
err = bucket.factory.client.Close()
if err != nil {
t.Fatal("Couldn't kill client.")
}

// Client should fail to Take(). This should start the reconnect handler
w, s = bucket.Take(context.Background(), 1, 0)
w, s, err = bucket.Take(context.Background(), 1, 0)
if err == nil {
t.Fatalf("Expected error, got nil")
}
if w < 0 {
t.Fatalf("Should have not seen negative wait time. Saw %v", w)
}
if s {
t.Fatalf("Success should be false.")
if !s {
t.Fatalf("Success should be true.")
}

for numTimeWaited := bucket.factory.connectionRetries; bucket.factory.getNumTimesConnResolved() == 0 && numTimeWaited > 0; numTimeWaited-- {
time.Sleep(5 * time.Second)
}

// Client should reconnect
w, s = bucket.Take(context.Background(), 1, 0)
w, s, err = bucket.Take(context.Background(), 1, 0)
if err != nil {
t.Fatalf("Expected nil error, got: %v", err)
}
if w < 0 {
t.Fatalf("Should have not seen negative wait time. Saw %v", w)
}
//if !s {
// t.Fatalf("Success should be true.")
//}
if !s {
t.Fatalf("Success should be true.")
}
}

func TestTokenAcquisition(t *testing.T) {
Expand Down
3 changes: 2 additions & 1 deletion client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/square/quotaservice"
"github.com/square/quotaservice/buckets/memory"
"github.com/square/quotaservice/config"
"github.com/square/quotaservice/events"
pb "github.com/square/quotaservice/protos"
qsgrpc "github.com/square/quotaservice/rpc/grpc"
"github.com/square/quotaservice/test/helpers"
Expand Down Expand Up @@ -42,7 +43,7 @@ func setUp() {
config.NewMemoryConfig(cfg),
quotaservice.NewReaperConfigForTests(),
0,
qsgrpc.New(target))
qsgrpc.New(target, events.NewNilProducer()))

if _, err := server.Start(); err != nil {
helpers.PanicError(err)
Expand Down

0 comments on commit d810a49

Please sign in to comment.