Skip to content

Commit

Permalink
chore: replay internal endpoint (#3746)
Browse files Browse the repository at this point in the history
  • Loading branch information
cisse21 committed Aug 21, 2023
1 parent 188b95c commit cd7557f
Show file tree
Hide file tree
Showing 16 changed files with 340 additions and 161 deletions.
4 changes: 4 additions & 0 deletions backend-config/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ type SourceT struct {
EventSchemasEnabled bool
}

func (s *SourceT) IsReplaySource() bool {
return s.OriginalID != ""
}

type WorkspaceRegulationT struct {
ID string
RegulationType string
Expand Down
26 changes: 23 additions & 3 deletions gateway/gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ const (
WriteKeyInvalid = "invalid-write-key"
WriteKeyEmpty = ""
SourceIDEnabled = "enabled-source"
ReplaySourceID = "replay-source"
ReplayWriteKey = "replay-source"
SourceIDDisabled = "disabled-source"
TestRemoteAddressWithPort = "test.com:80"
TestRemoteAddress = "test.com"
Expand Down Expand Up @@ -112,6 +114,16 @@ var sampleBackendConfig = backendconfig.ConfigT{
},
WorkspaceID: WorkspaceID,
},
{
ID: ReplaySourceID,
WriteKey: ReplayWriteKey,
Enabled: true,
OriginalID: ReplaySourceID,
SourceDefinition: backendconfig.SourceDefinitionT{
Name: SourceIDEnabled,
},
WorkspaceID: WorkspaceID,
},
},
}

Expand Down Expand Up @@ -255,7 +267,7 @@ var _ = Describe("Gateway Enterprise", func() {
It("should not accept events from suppress users", func() {
suppressedUserEventData := fmt.Sprintf(`{"batch":[{"userId":%q}]}`, SuppressedUserID)
// Why GET
expectHandlerResponse((gateway.webBatchHandler()), authorizedRequest(WriteKeyEnabled, bytes.NewBufferString(suppressedUserEventData)), http.StatusOK, "OK", "batch")
expectHandlerResponse(gateway.webBatchHandler(), authorizedRequest(WriteKeyEnabled, bytes.NewBufferString(suppressedUserEventData)), http.StatusOK, "OK", "batch")
Eventually(
func() bool {
stat := statsStore.Get(
Expand Down Expand Up @@ -450,7 +462,11 @@ var _ = Describe("Gateway", func() {
}
Expect(err).To(BeNil())
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(WriteKeyEnabled+":")))
if ep == "/internal/v1/replay" {
req.Header.Set("X-Rudder-Source-Id", ReplaySourceID)
} else {
req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(WriteKeyEnabled+":")))
}
resp, err := client.Do(req)
Expect(err).To(BeNil())
Expect(resp.StatusCode).To(SatisfyAny(Equal(http.StatusOK), Equal(http.StatusNoContent)), "endpoint: "+ep)
Expand Down Expand Up @@ -1387,6 +1403,8 @@ func endpointsToVerify() ([]string, []string, []string) {
// TODO: Remove this endpoint once sources change is released
"/v1/warehouse/fetch-tables",
"/internal/v1/warehouse/fetch-tables",
"/internal/v1/job-status/123",
"/internal/v1/job-status/123/failed-records",
}

postEndpoints := []string{
Expand All @@ -1399,10 +1417,12 @@ func endpointsToVerify() ([]string, []string, []string) {
"/v1/merge",
"/v1/group",
"/v1/import",
"/v1/audiencelist",
"/v1/audiencelist", // Get rid of this over time and use the /internal endpoint
"/v1/webhook",
"/beacon/v1/batch",
"/internal/v1/extract",
"/internal/v1/replay",
"/internal/v1/audiencelist",
"/v1/warehouse/pending-events",
"/v1/warehouse/trigger-upload",
"/v1/warehouse/jobs",
Expand Down
18 changes: 5 additions & 13 deletions gateway/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ type Handle struct {
IdleTimeout time.Duration
allowReqsWithoutUserIDAndAnonymousID bool
gwAllowPartialWriteWithErrors bool
allowBatchSplitting bool
}
}

Expand Down Expand Up @@ -434,17 +433,6 @@ func (gw *Handle) getJobDataFromRequest(req *webRequestT) (jobData *jobFromReq,
`{"error": "rudder-server gateway failed to marshal params"}`,
)
}
if !gw.conf.allowBatchSplitting {
// instead of multiple jobs with one event, create one job with all events
out = []jobObject{
{
userID: out[0].userID,
events: lo.Map(out, func(userEvent jobObject, _ int) map[string]interface{} {
return userEvent.events[0]
}),
},
}
}
jobs := make([]*jobsdb.JobT, 0)
for _, userEvent := range out {
var (
Expand All @@ -458,11 +446,15 @@ func (gw *Handle) getJobDataFromRequest(req *webRequestT) (jobData *jobFromReq,
WriteKey string `json:"writeKey"`
ReceivedAt string `json:"receivedAt"`
}
receivedAt, ok := userEvent.events[0]["receivedAt"].(string)
if !ok || !arctx.ReplaySource {
receivedAt = time.Now().Format(misc.RFC3339Milli)
}
singularEventBatch := SingularEventBatch{
Batch: userEvent.events,
RequestIP: ipAddr,
WriteKey: arctx.WriteKey,
ReceivedAt: time.Now().Format(misc.RFC3339Milli),
ReceivedAt: receivedAt,
}
payload, err = json.Marshal(singularEventBatch)
if err != nil {
Expand Down
16 changes: 16 additions & 0 deletions gateway/handle_http_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,21 @@ func (gw *Handle) sourceIDAuth(delegate http.HandlerFunc) http.HandlerFunc {
}
}

// replaySourceIDAuth middleware to authenticate sourceID in the X-Rudder-Source-Id header.
// If the sourceID is valid, i.e. it is a replay source and enabled, the source auth info is added to the request context.
// If the sourceID is invalid, the request is rejected.
func (gw *Handle) replaySourceIDAuth(delegate http.HandlerFunc) http.HandlerFunc {
return gw.sourceIDAuth(func(w http.ResponseWriter, r *http.Request) {
arctx := r.Context().Value(gwtypes.CtxParamAuthRequestContext).(*gwtypes.AuthRequestContext)
s, ok := gw.sourceIDSourceMap[arctx.SourceID]
if !ok || !s.IsReplaySource() {
gw.handleHttpError(w, r, response.InvalidReplaySource)
return
}
delegate.ServeHTTP(w, r)
})
}

// augmentAuthRequestContext adds source job run id and task run id from the request to the authentication context.
func augmentAuthRequestContext(arctx *gwtypes.AuthRequestContext, r *http.Request) {
arctx.SourceJobRunID = r.Header.Get("X-Rudder-Job-Run-Id")
Expand Down Expand Up @@ -165,6 +180,7 @@ func sourceToRequestContext(s backendconfig.SourceT) *gwtypes.AuthRequestContext
SourceName: s.Name,
SourceCategory: s.SourceDefinition.Category,
SourceDefName: s.SourceDefinition.Name,
ReplaySource: s.IsReplaySource(),
}
if arctx.SourceCategory == "" {
arctx.SourceCategory = eventStreamSourceCategory
Expand Down
60 changes: 60 additions & 0 deletions gateway/handle_http_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,4 +277,64 @@ func TestAuth(t *testing.T) {
require.Equal(t, "Source is disabled\n", string(body))
})
})

t.Run("replaySourceIDAuth", func(t *testing.T) {
t.Run("replay source", func(t *testing.T) {
sourceID := "123"
gw := newGateway(nil, map[string]backendconfig.SourceT{
sourceID: {
ID: sourceID,
Enabled: true,
OriginalID: sourceID,
},
})
r := newSourceIDRequest(sourceID)
w := httptest.NewRecorder()
gw.replaySourceIDAuth(delegate).ServeHTTP(w, r)

require.Equal(t, http.StatusOK, w.Code, "authentication should succeed")
body, err := io.ReadAll(w.Body)
require.NoError(t, err, "reading response body should succeed")
require.Equal(t, "OK", string(body))
})

t.Run("invalid source using replay endpoint", func(t *testing.T) {
sourceID := "123"
invalidSource := "345"
gw := newGateway(nil, map[string]backendconfig.SourceT{
sourceID: {
ID: sourceID,
Enabled: true,
OriginalID: "",
},
})
r := newSourceIDRequest(invalidSource)
w := httptest.NewRecorder()
gw.replaySourceIDAuth(delegate).ServeHTTP(w, r)

require.Equal(t, http.StatusUnauthorized, w.Code, "authentication should not succeed")
body, err := io.ReadAll(w.Body)
require.NoError(t, err, "reading response body should succeed")
require.Equal(t, "Invalid source id\n", string(body))
})

t.Run("regular source using replay endpoint", func(t *testing.T) {
sourceID := "123"
gw := newGateway(nil, map[string]backendconfig.SourceT{
sourceID: {
ID: sourceID,
Enabled: true,
OriginalID: "",
},
})
r := newSourceIDRequest(sourceID)
w := httptest.NewRecorder()
gw.replaySourceIDAuth(delegate).ServeHTTP(w, r)

require.Equal(t, http.StatusUnauthorized, w.Code, "authentication should not succeed")
body, err := io.ReadAll(w.Body)
require.NoError(t, err, "reading response body should succeed")
require.Equal(t, "Invalid replay source\n", string(body))
})
})
}
8 changes: 8 additions & 0 deletions gateway/handle_http_replay.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package gateway

import "net/http"

// webImportHandler can handle import requests
func (gw *Handle) webReplayHandler() http.HandlerFunc {
return gw.callType("replay", gw.replaySourceIDAuth(gw.webHandler()))
}
18 changes: 10 additions & 8 deletions gateway/handle_lifecycle.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@ import (
"strconv"
"time"

"golang.org/x/sync/errgroup"

"github.com/bugsnag/bugsnag-go/v2"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/rs/cors"
"github.com/samber/lo"
"golang.org/x/sync/errgroup"

"github.com/rudderlabs/rudder-go-kit/chiware"
"github.com/rudderlabs/rudder-go-kit/config"
Expand Down Expand Up @@ -91,7 +92,6 @@ func (gw *Handle) Setup(
// Enables accepting requests without user id and anonymous id. This is added to prevent client 4xx retries.
config.RegisterBoolConfigVariable(false, &gw.conf.allowReqsWithoutUserIDAndAnonymousID, true, "Gateway.allowReqsWithoutUserIDAndAnonymousID")
config.RegisterBoolConfigVariable(true, &gw.conf.gwAllowPartialWriteWithErrors, true, "Gateway.allowPartialWriteWithErrors")
config.RegisterBoolConfigVariable(true, &gw.conf.allowBatchSplitting, true, "Gateway.allowBatchSplitting")
config.RegisterDurationConfigVariable(0, &gw.conf.ReadTimeout, false, time.Second, []string{"ReadTimeout", "ReadTimeOutInSec"}...)
config.RegisterDurationConfigVariable(0, &gw.conf.ReadHeaderTimeout, false, time.Second, []string{"ReadHeaderTimeout", "ReadHeaderTimeoutInSec"}...)
config.RegisterDurationConfigVariable(10, &gw.conf.WriteTimeout, false, time.Second, []string{"WriteTimeout", "WriteTimeOutInSec"}...)
Expand Down Expand Up @@ -359,6 +359,10 @@ func (gw *Handle) StartWebHandler(ctx context.Context) error {
gw.logger.Infof("WebHandler Starting on %d", gw.conf.webPort)
component := "gateway"
srvMux := chi.NewRouter()
// rudder-sources new APIs
rsourcesHandler := rsources_http.NewHandler(
gw.rsourcesService,
gw.logger.Child("rsources"))
srvMux.Use(
chiware.StatMiddleware(ctx, srvMux, stats.Default, component),
middleware.LimitConcurrentRequests(gw.conf.maxConcurrentRequests),
Expand All @@ -367,7 +371,11 @@ func (gw *Handle) StartWebHandler(ctx context.Context) error {
srvMux.Route("/internal", func(r chi.Router) {
r.Post("/v1/extract", gw.webExtractHandler())
r.Get("/v1/warehouse/fetch-tables", gw.whProxy.ServeHTTP)
r.Post("/v1/audiencelist", gw.webAudienceListHandler())
r.Post("/v1/replay", gw.webReplayHandler())
r.Mount("/v1/job-status", withContentType("application/json; charset=utf-8", rsourcesHandler.ServeHTTP))
})
srvMux.Mount("/v1/job-status", withContentType("application/json; charset=utf-8", rsourcesHandler.ServeHTTP))

srvMux.Route("/v1", func(r chi.Router) {
r.Post("/alias", gw.webAliasHandler())
Expand Down Expand Up @@ -419,12 +427,6 @@ func (gw *Handle) StartWebHandler(ctx context.Context) error {
})
}

// rudder-sources new APIs
rsourcesHandler := rsources_http.NewHandler(
gw.rsourcesService,
gw.logger.Child("rsources"))
srvMux.Mount("/v1/job-status", withContentType("application/json; charset=utf-8", rsourcesHandler.ServeHTTP))

c := cors.New(cors.Options{
AllowOriginFunc: func(_ string) bool { return true },
AllowCredentials: true,
Expand Down
16 changes: 8 additions & 8 deletions gateway/internal/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ const (

// AuthRequestContext contains the authenticated source information for a request.
type AuthRequestContext struct {
SourceEnabled bool
SourceID string
WriteKey string
WorkspaceID string
SourceName string
SourceDefName string
SourceCategory string

SourceEnabled bool
SourceID string
WriteKey string
WorkspaceID string
SourceName string
SourceDefName string
SourceCategory string
ReplaySource bool
SourceJobRunID string
SourceTaskRunID string
}
Expand Down
3 changes: 3 additions & 0 deletions gateway/response/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ const (
NoSourceIdInHeader = "Failed to read source id from header"
// InvalidSourceID - Invalid source id
InvalidSourceID = "Invalid source id"
// InvalidReplaySource - Invalid replay source
InvalidReplaySource = "Invalid replay source"

transPixelResponse = "\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x00\x00\x00\x00\x00\x00\x00\x00\x21\xF9\x04" +
"\x01\x00\x00\x00\x00\x2C\x00\x00\x00\x00\x01\x00\x01\x00\x00\x02\x02\x44\x01\x00\x3B"
Expand All @@ -79,6 +81,7 @@ var statusMap = map[string]status{
InvalidJSON: {message: InvalidJSON, code: http.StatusBadRequest},
NoSourceIdInHeader: {message: NoSourceIdInHeader, code: http.StatusUnauthorized},
InvalidSourceID: {message: InvalidSourceID, code: http.StatusUnauthorized},
InvalidReplaySource: {message: InvalidReplaySource, code: http.StatusUnauthorized},

// webhook specific status
InvalidWebhookSource: {message: InvalidWebhookSource, code: http.StatusNotFound},
Expand Down
Loading

0 comments on commit cd7557f

Please sign in to comment.