Skip to content

Commit

Permalink
chore: set gw rate limits at event level (#4069)
Browse files Browse the repository at this point in the history
  • Loading branch information
cisse21 committed Nov 2, 2023
1 parent 9fd3c82 commit 8dd4ab9
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 30 deletions.
7 changes: 4 additions & 3 deletions gateway/gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,7 @@ var _ = Describe("Gateway", func() {
})

It("should store messages successfully if rate limit is not reached for workspace", func() {
c.mockRateLimiter.EXPECT().CheckLimitReached(gomock.Any(), gomock.Any()).Return(false, nil).Times(1)
c.mockRateLimiter.EXPECT().CheckLimitReached(gomock.Any(), gomock.Any(), gomock.Any()).Return(false, nil).Times(1)
c.mockJobsDB.EXPECT().WithStoreSafeTx(gomock.Any(), gomock.Any()).Times(1).Do(func(ctx context.Context, f func(tx jobsdb.StoreSafeTx) error) {
_ = f(jobsdb.EmptyStoreSafeTx())
}).Return(nil)
Expand Down Expand Up @@ -884,10 +884,11 @@ var _ = Describe("Gateway", func() {
})

It("should reject messages if rate limit is reached for workspace", func() {
c.mockRateLimiter.EXPECT().CheckLimitReached(gomock.Any(), gomock.Any()).Return(true, nil).Times(1)
conf.Set("Gateway.allowReqsWithoutUserIDAndAnonymousID", true)
c.mockRateLimiter.EXPECT().CheckLimitReached(gomock.Any(), gomock.Any(), gomock.Any()).Return(true, nil).Times(1)
expectHandlerResponse(
gateway.webAliasHandler(),
authorizedRequest(WriteKeyEnabled, bytes.NewBufferString("{}")),
authorizedRequest(WriteKeyEnabled, bytes.NewBufferString(`{"data": "valid-json"}`)),
http.StatusTooManyRequests,
response.TooManyRequests+"\n",
"alias",
Expand Down
29 changes: 15 additions & 14 deletions gateway/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,9 @@ func (gw *Handle) getJobDataFromRequest(req *webRequestT) (jobData *jobFromReq,
userIDHeader = req.userIDHeader
ipAddr = req.ipAddr
body = req.requestPayload

// values retrieved from first event in batch
sourcesJobRunID, sourcesTaskRunID = req.authContext.SourceJobRunID, req.authContext.SourceTaskRunID
)

fillMessageID := func(event map[string]interface{}) {
Expand Down Expand Up @@ -295,18 +298,6 @@ func (gw *Handle) getJobDataFromRequest(req *webRequestT) (jobData *jobFromReq,
eventsBatch := gjson.GetBytes(body, "batch").Array()
jobData.numEvents = len(eventsBatch)

if gw.conf.enableRateLimit.Load() {
// In case of "batch" requests, if rate-limiter returns true for LimitReached, just drop the event batch and continue.
ok, errCheck := gw.rateLimiter.CheckLimitReached(context.TODO(), workspaceId)
if errCheck != nil {
gw.stats.NewTaggedStat("gateway.rate_limiter_error", stats.CountType, stats.Tags{"workspaceId": workspaceId}).Increment()
gw.logger.Errorf("Rate limiter error: %v Allowing the request", errCheck)
}
if ok {
return jobData, errRequestDropped
}
}

type jobObject struct {
userID string
events []map[string]interface{}
Expand All @@ -317,8 +308,6 @@ func (gw *Handle) getJobDataFromRequest(req *webRequestT) (jobData *jobFromReq,
out []jobObject

marshalledParams []byte
// values retrieved from first event in batch
sourcesJobRunID, sourcesTaskRunID = req.authContext.SourceJobRunID, req.authContext.SourceTaskRunID

// facts about the batch populated as we iterate over events
containsAudienceList, suppressed bool
Expand Down Expand Up @@ -408,6 +397,18 @@ func (gw *Handle) getJobDataFromRequest(req *webRequestT) (jobData *jobFromReq,
})
}

if gw.conf.enableRateLimit.Load() && sourcesJobRunID == "" && sourcesTaskRunID == "" {
// In case of "batch" requests, if rate-limiter returns true for LimitReached, just drop the event batch and continue.
ok, errCheck := gw.rateLimiter.CheckLimitReached(context.TODO(), workspaceId, int64(len(eventsBatch)))
if errCheck != nil {
gw.stats.NewTaggedStat("gateway.rate_limiter_error", stats.CountType, stats.Tags{"workspaceId": workspaceId}).Increment()
gw.logger.Errorf("Rate limiter error: %v Allowing the request", errCheck)
}
if ok {
return jobData, errRequestDropped
}
}

if len(out) == 0 && suppressed {
err = errRequestSuppressed
return
Expand Down
10 changes: 5 additions & 5 deletions gateway/throttler/throttler.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type Limiter interface {
}

type Throttler interface {
CheckLimitReached(context context.Context, workspaceId string) (bool, error)
CheckLimitReached(context context.Context, workspaceId string, eventCount int64) (bool, error)
}

type Factory struct {
Expand All @@ -45,9 +45,9 @@ func New(stats stats.Stats) (*Factory, error) {
return &f, nil
}

func (f *Factory) CheckLimitReached(context context.Context, workspaceId string) (bool, error) {
func (f *Factory) CheckLimitReached(context context.Context, workspaceId string, eventCount int64) (bool, error) {
t := f.get(workspaceId)
return t.checkLimitReached(context, workspaceId)
return t.checkLimitReached(context, workspaceId, eventCount)
}

func (f *Factory) get(workspaceId string) *throttler {
Expand Down Expand Up @@ -98,8 +98,8 @@ type throttler struct {
}

// checkLimitReached returns true if we're not allowed to process the number of event
func (t *throttler) checkLimitReached(ctx context.Context, key string) (limited bool, retErr error) {
allowed, _, err := t.limiter.Allow(ctx, 1, t.config.limit, getWindowInSecs(t.config.window), key)
func (t *throttler) checkLimitReached(ctx context.Context, key string, count int64) (limited bool, retErr error) {
allowed, _, err := t.limiter.Allow(ctx, count, t.config.limit, getWindowInSecs(t.config.window), key)
if err != nil {
return false, fmt.Errorf("could not limit: %w", err)
}
Expand Down
8 changes: 4 additions & 4 deletions gateway/throttler/throttler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ func TestGateway_Throttler(t *testing.T) {
}

for i := 0; i < eventLimit; i++ {
_, err := testThrottler.checkLimitReached(context.TODO(), workspaceId)
_, err := testThrottler.checkLimitReached(context.TODO(), workspaceId, 1)
require.NoError(t, err)
}

startTime := time.Now()
var passed int
for i := 0; i < 2*eventLimit; i++ {
allowed, err := testThrottler.checkLimitReached(context.TODO(), workspaceId)
allowed, err := testThrottler.checkLimitReached(context.TODO(), workspaceId, 1)
require.NoError(t, err)
if allowed {
passed++
Expand Down Expand Up @@ -69,14 +69,14 @@ func TestGateway_Factory(t *testing.T) {
require.NotNil(t, rateLimiter)

for i := 0; i < eventLimit; i++ {
_, err := rateLimiter.CheckLimitReached(context.TODO(), workspaceId)
_, err := rateLimiter.CheckLimitReached(context.TODO(), workspaceId, 1)
require.NoError(t, err)
}

startTime := time.Now()
var passed int
for i := 0; i < 2*eventLimit; i++ {
allowed, err := rateLimiter.CheckLimitReached(context.TODO(), workspaceId)
allowed, err := rateLimiter.CheckLimitReached(context.TODO(), workspaceId, 1)
require.NoError(t, err)
if allowed {
passed++
Expand Down
8 changes: 4 additions & 4 deletions mocks/gateway/throttler.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 8dd4ab9

Please sign in to comment.