Skip to content

Commit

Permalink
chore: use gcra rate limiter gateway (#3086)
Browse files Browse the repository at this point in the history
  • Loading branch information
cisse21 committed Mar 15, 2023
1 parent 10ccbf3 commit 61d9275
Show file tree
Hide file tree
Showing 13 changed files with 350 additions and 183 deletions.
25 changes: 13 additions & 12 deletions app/apphandlers/embeddedAppHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,21 @@ import (
"fmt"
"net/http"

"github.com/rudderlabs/rudder-server/config"
"github.com/rudderlabs/rudder-server/router/throttler"
"github.com/rudderlabs/rudder-server/utils/logger"
"github.com/rudderlabs/rudder-server/utils/payload"
"github.com/rudderlabs/rudder-server/utils/types/deployment"

"golang.org/x/sync/errgroup"

"github.com/rudderlabs/rudder-server/app"
"github.com/rudderlabs/rudder-server/app/cluster"
"github.com/rudderlabs/rudder-server/config"
backendconfig "github.com/rudderlabs/rudder-server/config/backend-config"
"github.com/rudderlabs/rudder-server/gateway"
gwThrottler "github.com/rudderlabs/rudder-server/gateway/throttler"
"github.com/rudderlabs/rudder-server/jobsdb"
"github.com/rudderlabs/rudder-server/jobsdb/prebackup"
"github.com/rudderlabs/rudder-server/processor"
ratelimiter "github.com/rudderlabs/rudder-server/rate-limiter"
"github.com/rudderlabs/rudder-server/router"
"github.com/rudderlabs/rudder-server/router/batchrouter"
routerManager "github.com/rudderlabs/rudder-server/router/manager"
rtThrottler "github.com/rudderlabs/rudder-server/router/throttler"
"github.com/rudderlabs/rudder-server/services/db"
destinationdebugger "github.com/rudderlabs/rudder-server/services/debugger/destination"
sourcedebugger "github.com/rudderlabs/rudder-server/services/debugger/source"
Expand All @@ -32,8 +28,11 @@ import (
"github.com/rudderlabs/rudder-server/services/multitenant"
"github.com/rudderlabs/rudder-server/services/stats"
"github.com/rudderlabs/rudder-server/services/transientsource"
"github.com/rudderlabs/rudder-server/utils/logger"
"github.com/rudderlabs/rudder-server/utils/misc"
"github.com/rudderlabs/rudder-server/utils/payload"
"github.com/rudderlabs/rudder-server/utils/types"
"github.com/rudderlabs/rudder-server/utils/types/deployment"
)

// embeddedApp is the type for embedded type implementation
Expand Down Expand Up @@ -213,9 +212,9 @@ func (a *embeddedApp) StartRudderCore(ctx context.Context, options *app.Options)
transformationhandle,
processor.WithAdaptiveLimit(adaptiveLimit),
)
throttlerFactory, err := throttler.New(stats.Default)
throttlerFactory, err := rtThrottler.New(stats.Default)
if err != nil {
return fmt.Errorf("failed to create throttler factory: %w", err)
return fmt.Errorf("failed to create rt throttler factory: %w", err)
}
rtFactory := &router.Factory{
Reporting: reportingI,
Expand Down Expand Up @@ -253,8 +252,10 @@ func (a *embeddedApp) StartRudderCore(ctx context.Context, options *app.Options)
MultiTenantStat: multitenantStats,
}

rateLimiter := ratelimiter.HandleT{}
rateLimiter.SetUp()
rateLimiter, err := gwThrottler.New(stats.Default)
if err != nil {
return fmt.Errorf("failed to create gw rate limiter: %w", err)
}
gw := gateway.HandleT{}
// This separate gateway db is created just to be used with gateway because in case of degraded mode,
// the earlier created gwDb (which was created to be used mainly with processor) will not be running, and it
Expand All @@ -274,7 +275,7 @@ func (a *embeddedApp) StartRudderCore(ctx context.Context, options *app.Options)
err = gw.Setup(
ctx,
a.app, backendconfig.DefaultBackendConfig, gatewayDB,
&rateLimiter, a.versionHandler, rsourcesService, sourceHandle,
rateLimiter, a.versionHandler, rsourcesService, sourceHandle,
)
if err != nil {
return fmt.Errorf("could not setup gateway: %w", err)
Expand Down
12 changes: 7 additions & 5 deletions app/apphandlers/gatewayAppHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ import (
"github.com/rudderlabs/rudder-server/config"
backendconfig "github.com/rudderlabs/rudder-server/config/backend-config"
"github.com/rudderlabs/rudder-server/gateway"
gwThrottler "github.com/rudderlabs/rudder-server/gateway/throttler"
"github.com/rudderlabs/rudder-server/jobsdb"
ratelimiter "github.com/rudderlabs/rudder-server/rate-limiter"
"github.com/rudderlabs/rudder-server/services/db"
sourcedebugger "github.com/rudderlabs/rudder-server/services/debugger/source"
"github.com/rudderlabs/rudder-server/services/fileuploader"
"github.com/rudderlabs/rudder-server/services/stats"
"github.com/rudderlabs/rudder-server/utils/logger"
"github.com/rudderlabs/rudder-server/utils/misc"
"github.com/rudderlabs/rudder-server/utils/types/deployment"
Expand Down Expand Up @@ -108,9 +109,10 @@ func (a *gatewayApp) StartRudderCore(ctx context.Context, options *app.Options)
})

var gw gateway.HandleT
var rateLimiter ratelimiter.HandleT

rateLimiter.SetUp()
rateLimiter, err := gwThrottler.New(stats.Default)
if err != nil {
return fmt.Errorf("failed to create rate limiter: %w", err)
}
gw.SetReadonlyDB(readonlyGatewayDB)
rsourcesService, err := NewRsourcesService(deploymentType)
if err != nil {
Expand All @@ -119,7 +121,7 @@ func (a *gatewayApp) StartRudderCore(ctx context.Context, options *app.Options)
err = gw.Setup(
ctx,
a.app, backendconfig.DefaultBackendConfig, gatewayDB,
&rateLimiter, a.versionHandler, rsourcesService, sourceHandle,
rateLimiter, a.versionHandler, rsourcesService, sourceHandle,
)
if err != nil {
return fmt.Errorf("failed to setup gateway: %w", err)
Expand Down
29 changes: 18 additions & 11 deletions gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ import (
event_schema "github.com/rudderlabs/rudder-server/event-schema"
gwstats "github.com/rudderlabs/rudder-server/gateway/internal/stats"
"github.com/rudderlabs/rudder-server/gateway/response"
"github.com/rudderlabs/rudder-server/gateway/throttler"
"github.com/rudderlabs/rudder-server/gateway/webhook"
"github.com/rudderlabs/rudder-server/jobsdb"
"github.com/rudderlabs/rudder-server/middleware"
ratelimiter "github.com/rudderlabs/rudder-server/rate-limiter"
"github.com/rudderlabs/rudder-server/rruntime"
sourcedebugger "github.com/rudderlabs/rudder-server/services/debugger/source"
"github.com/rudderlabs/rudder-server/services/diagnostics"
Expand Down Expand Up @@ -161,7 +161,7 @@ type HandleT struct {
ackCount uint64
recvCount uint64
backendConfig backendconfig.BackendConfig
rateLimiter ratelimiter.RateLimiter
rateLimiter throttler.Throttler

stats stats.Stats
batchSizeStat stats.Measurement
Expand Down Expand Up @@ -538,9 +538,13 @@ func (gateway *HandleT) getJobDataFromRequest(req *webRequestT) (jobData *jobFro

if enableRateLimit {
// In case of "batch" requests, if rate-limiter returns true for LimitReached, just drop the event batch and continue.
if gateway.rateLimiter.LimitReached(workspaceId) {
err = errRequestDropped
return
ok, errCheck := gateway.rateLimiter.CheckLimitReached(context.TODO(), workspaceId)
if errCheck != nil {
gateway.stats.NewTaggedStat("gateway.rate_limiter_error", stats.CountType, stats.Tags{"workspaceId": workspaceId}).Increment()
gateway.logger.Errorf("Rate limiter error: %v Allowing the request", errCheck)
}
if ok {
return jobData, errRequestDropped
}
}

Expand Down Expand Up @@ -604,7 +608,7 @@ func (gateway *HandleT) getJobDataFromRequest(req *webRequestT) (jobData *jobFro
"library",
"version",
).(string)
if firstSDKVersion != "" && !semverRegexp.Match([]byte(firstSDKVersion)) {
if firstSDKVersion != "" && !semverRegexp.Match([]byte(firstSDKVersion)) { // skipcq: CRT-A0007
firstSDKVersion = "invalid"
}
if firstSDKName != "" || firstSDKVersion != "" {
Expand Down Expand Up @@ -898,6 +902,7 @@ func warehouseHandler(w http.ResponseWriter, r *http.Request) {
origin, err := url.Parse(misc.GetWarehouseURL())
if err != nil {
http.Error(w, err.Error(), 404)
return
}
// gateway.logger.LogRequest(r)
director := func(req *http.Request) {
Expand Down Expand Up @@ -991,6 +996,7 @@ func (gateway *HandleT) webRequestHandler(rh RequestHandler, w http.ResponseWrit
if errorMessage != "" {
gateway.logger.Infof("IP: %s -- %s -- Response: %d, %s", misc.GetIPFromReq(r), r.URL.Path, response.GetErrorStatusCode(errorMessage), errorMessage)
http.Error(w, response.GetStatus(errorMessage), response.GetErrorStatusCode(errorMessage))
return
}
}()
payload, writeKey, err := gateway.getPayloadAndWriteKey(w, r, reqType)
Expand Down Expand Up @@ -1326,8 +1332,9 @@ func (gateway *HandleT) StartAdminHandler(ctx context.Context) error {
middleware.LimitConcurrentRequests(maxConcurrentRequests),
)
srv := &http.Server{
Addr: ":" + strconv.Itoa(adminWebPort),
Handler: bugsnag.Handler(srvMux),
Addr: ":" + strconv.Itoa(adminWebPort),
Handler: bugsnag.Handler(srvMux),
ReadHeaderTimeout: ReadHeaderTimeout,
}

return rs_httputil.ListenAndServe(ctx, srv)
Expand All @@ -1343,8 +1350,8 @@ func (gateway *HandleT) backendConfigSubscriber() {
newEnabledWriteKeyWorkspaceMap = map[string]string{}
newSourceIDToNameMap = map[string]string{}
)
config := data.Data.(map[string]backendconfig.ConfigT)
for workspaceID, wsConfig := range config {
configData := data.Data.(map[string]backendconfig.ConfigT)
for workspaceID, wsConfig := range configData {
for _, source := range wsConfig.Sources {
newSourceIDToNameMap[source.ID] = source.Name
newWriteKeysSourceMap[source.WriteKey] = source
Expand Down Expand Up @@ -1429,7 +1436,7 @@ This function will block until backend config is initially received.
func (gateway *HandleT) Setup(
ctx context.Context,
application app.App, backendConfig backendconfig.BackendConfig, jobsDB jobsdb.JobsDB,
rateLimiter ratelimiter.RateLimiter, versionHandler func(w http.ResponseWriter, r *http.Request),
rateLimiter throttler.Throttler, versionHandler func(w http.ResponseWriter, r *http.Request),
rsourcesService rsources.JobService, sourcehandle sourcedebugger.SourceDebugger,
) error {
gateway.logger = pkgLogger
Expand Down
18 changes: 7 additions & 11 deletions gateway/gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ import (
"github.com/rudderlabs/rudder-server/jobsdb"
mocksApp "github.com/rudderlabs/rudder-server/mocks/app"
mocksBackendConfig "github.com/rudderlabs/rudder-server/mocks/config/backend-config"
mockThrottler "github.com/rudderlabs/rudder-server/mocks/gateway"
mocksJobsDB "github.com/rudderlabs/rudder-server/mocks/jobsdb"
mocksRateLimiter "github.com/rudderlabs/rudder-server/mocks/rate-limiter"
mocksTypes "github.com/rudderlabs/rudder-server/mocks/utils/types"
sourcedebugger "github.com/rudderlabs/rudder-server/services/debugger/source"
"github.com/rudderlabs/rudder-server/services/rsources"
Expand Down Expand Up @@ -106,8 +106,8 @@ type testContext struct {
mockCtrl *gomock.Controller
mockJobsDB *mocksJobsDB.MockJobsDB
mockBackendConfig *mocksBackendConfig.MockBackendConfig
mockRateLimiter *mockThrottler.MockThrottler
mockApp *mocksApp.MockApp
mockRateLimiter *mocksRateLimiter.MockRateLimiter

mockVersionHandler func(w http.ResponseWriter, r *http.Request)

Expand Down Expand Up @@ -138,7 +138,7 @@ func (c *testContext) Setup() {
c.mockJobsDB = mocksJobsDB.NewMockJobsDB(c.mockCtrl)
c.mockBackendConfig = mocksBackendConfig.NewMockBackendConfig(c.mockCtrl)
c.mockApp = mocksApp.NewMockApp(c.mockCtrl)
c.mockRateLimiter = mocksRateLimiter.NewMockRateLimiter(c.mockCtrl)
c.mockRateLimiter = mockThrottler.NewMockThrottler(c.mockCtrl)

c.mockBackendConfig.EXPECT().Subscribe(gomock.Any(), backendconfig.TopicProcessConfig).
DoAndReturn(func(ctx context.Context, topic backendconfig.Topic) pubsub.DataChannel {
Expand Down Expand Up @@ -505,15 +505,12 @@ var _ = Describe("Gateway", func() {
})

It("should store messages successfully if rate limit is not reached for workspace", func() {
mockCall := c.mockRateLimiter.EXPECT().LimitReached(WorkspaceID).Return(false).Times(1)
tFunc := c.asyncHelper.ExpectAndNotifyCallbackWithName("")
mockCall.Do(func(interface{}) { tFunc() })

c.mockRateLimiter.EXPECT().CheckLimitReached(gomock.Any(), gomock.Any()).Return(false, nil).Times(1)
c.mockJobsDB.EXPECT().WithStoreSafeTx(gomock.Any(), gomock.Any()).Times(1).Do(func(ctx context.Context, f func(tx jobsdb.StoreSafeTx) error) {
_ = f(jobsdb.EmptyStoreSafeTx())
}).Return(nil)
mockCall = c.mockJobsDB.EXPECT().StoreWithRetryEachInTx(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(jobsToEmptyErrors).Times(1)
tFunc = c.asyncHelper.ExpectAndNotifyCallbackWithName("")
mockCall := c.mockJobsDB.EXPECT().StoreWithRetryEachInTx(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(jobsToEmptyErrors).Times(1)
tFunc := c.asyncHelper.ExpectAndNotifyCallbackWithName("")
mockCall.Do(func(context.Context, interface{}, interface{}) { tFunc() })

expectHandlerResponse(
Expand Down Expand Up @@ -547,8 +544,7 @@ var _ = Describe("Gateway", func() {
})

It("should reject messages if rate limit is reached for workspace", func() {
c.mockRateLimiter.EXPECT().LimitReached(WorkspaceID).Return(true).Times(1)

c.mockRateLimiter.EXPECT().CheckLimitReached(gomock.Any(), gomock.Any()).Return(true, nil).Times(1)
expectHandlerResponse(
gateway.webAliasHandler,
authorizedRequest(WriteKeyEnabled, bytes.NewBufferString("{}")),
Expand Down

0 comments on commit 61d9275

Please sign in to comment.