From 5dacdcfbad533a310f33bf47fb8abd494f1b16fe Mon Sep 17 00:00:00 2001 From: Akash Chetty Date: Fri, 25 Aug 2023 12:07:06 +0530 Subject: [PATCH] chore: cleanup http handlers (#3767) --- warehouse/archive/archiver_test.go | 11 +- warehouse/backend_config_test.go | 4 +- warehouse/http.go | 430 ++++++++++ warehouse/http_test.go | 1064 +++++++++++++++++++++++++ warehouse/internal/api/http.go | 34 +- warehouse/internal/api/http_test.go | 14 +- warehouse/internal/errors/errors.go | 14 + warehouse/jobs/handlers.go | 143 ---- warehouse/jobs/http.go | 155 ++++ warehouse/jobs/http_test.go | 354 ++++++++ warehouse/jobs/runner.go | 4 +- warehouse/jobs/types.go | 5 - warehouse/jobs/utils.go | 7 - warehouse/jobs/utils_test.go | 36 - warehouse/logfield/logfield.go | 1 + warehouse/mode.go | 34 + warehouse/mode_test.go | 157 ++++ warehouse/multitenant/manager.go | 30 +- warehouse/multitenant/manager_test.go | 31 +- warehouse/router_test.go | 26 +- warehouse/slave_worker_test.go | 4 +- warehouse/utils/utils.go | 25 - warehouse/warehouse.go | 412 +--------- warehouse/warehousegrpc_test.go | 5 +- 24 files changed, 2315 insertions(+), 685 deletions(-) create mode 100644 warehouse/http.go create mode 100644 warehouse/http_test.go create mode 100644 warehouse/internal/errors/errors.go delete mode 100644 warehouse/jobs/handlers.go create mode 100644 warehouse/jobs/http.go create mode 100644 warehouse/jobs/http_test.go delete mode 100644 warehouse/jobs/utils_test.go create mode 100644 warehouse/mode.go create mode 100644 warehouse/mode_test.go diff --git a/warehouse/archive/archiver_test.go b/warehouse/archive/archiver_test.go index 402ca05660..a2ae148743 100644 --- a/warehouse/archive/archiver_test.go +++ b/warehouse/archive/archiver_test.go @@ -12,6 +12,8 @@ import ( "testing" "time" + backendConfig "github.com/rudderlabs/rudder-server/backend-config" + "github.com/rudderlabs/rudder-go-kit/config" "github.com/golang/mock/gomock" @@ -147,15 +149,18 @@ func TestArchiver(t *testing.T) { `, tc.workspaceID, now) require.NoError(t, err) + c := config.New() + c.Set("Warehouse.degradedWorkspaceIDs", tc.degradedWorkspaceIDs) + + tenantManager := multitenant.New(c, backendConfig.DefaultBackendConfig) + archiver := archive.New( config.Default, logger.NOP, mockStats, pgResource.DB, filemanager.New, - &multitenant.Manager{ - DegradedWorkspaceIDs: tc.degradedWorkspaceIDs, - }, + tenantManager, ) ctx, cancel := context.WithCancel(context.Background()) diff --git a/warehouse/backend_config_test.go b/warehouse/backend_config_test.go index 52a8d7f1a2..73b631a81a 100644 --- a/warehouse/backend_config_test.go +++ b/warehouse/backend_config_test.go @@ -110,9 +110,7 @@ func TestBackendConfigManager(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - tenantManager = &multitenant.Manager{ - BackendConfig: mockBackendConfig, - } + tenantManager = multitenant.New(config.Default, mockBackendConfig) t.Run("Subscriptions", func(t *testing.T) { bcm := newBackendConfigManager(c, db, tenantManager, logger.NOP) diff --git a/warehouse/http.go b/warehouse/http.go new file mode 100644 index 0000000000..5b6754d188 --- /dev/null +++ b/warehouse/http.go @@ -0,0 +1,430 @@ +package warehouse + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + "strconv" + "strings" + "time" + + "github.com/bugsnag/bugsnag-go/v2" + "github.com/go-chi/chi/v5" + + "github.com/rudderlabs/rudder-server/warehouse/internal/api" + ierrors 
"github.com/rudderlabs/rudder-server/warehouse/internal/errors" + lf "github.com/rudderlabs/rudder-server/warehouse/logfield" + + "github.com/rudderlabs/rudder-go-kit/config" + kithttputil "github.com/rudderlabs/rudder-go-kit/httputil" + "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-go-kit/stats" + backendconfig "github.com/rudderlabs/rudder-server/backend-config" + "github.com/rudderlabs/rudder-server/services/pgnotifier" + sqlmw "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" + "github.com/rudderlabs/rudder-server/warehouse/internal/model" + "github.com/rudderlabs/rudder-server/warehouse/internal/repo" + "github.com/rudderlabs/rudder-server/warehouse/jobs" + "github.com/rudderlabs/rudder-server/warehouse/multitenant" + warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" +) + +type pendingEventsRequest struct { + SourceID string `json:"source_id"` + TaskRunID string `json:"task_run_id"` +} + +type pendingEventsResponse struct { + PendingEvents bool `json:"pending_events"` + PendingStagingFilesCount int64 `json:"pending_staging_files"` + PendingUploadCount int64 `json:"pending_uploads"` + AbortedEvents bool `json:"aborted_events"` +} + +type fetchTablesRequest struct { + Connections []warehouseutils.SourceIDDestinationID `json:"connections"` +} + +type fetchTablesResponse struct { + ConnectionsTables []warehouseutils.FetchTableInfo `json:"connections_tables"` +} + +type triggerUploadRequest struct { + SourceID string `json:"source_id"` + DestinationID string `json:"destination_id"` +} + +type Api struct { + mode string + logger logger.Logger + statsFactory stats.Stats + db *sqlmw.DB + notifier *pgnotifier.PGNotifier + bcConfig backendconfig.BackendConfig + tenantManager *multitenant.Manager + bcManager *backendConfigManager + asyncManager *jobs.AsyncJobWh + stagingRepo *repo.StagingFiles + uploadRepo *repo.Uploads + schemaRepo *repo.WHSchema + + config struct { + healthTimeout time.Duration + readerHeaderTimeout time.Duration + runningMode string + webPort int + mode string + } +} + +func NewApi( + mode string, + conf *config.Config, + log logger.Logger, + statsFactory stats.Stats, + bcConfig backendconfig.BackendConfig, + db *sqlmw.DB, + notifier *pgnotifier.PGNotifier, + tenantManager *multitenant.Manager, + bcManager *backendConfigManager, + asyncManager *jobs.AsyncJobWh, +) *Api { + a := &Api{ + mode: mode, + logger: log.Child("api"), + db: db, + notifier: notifier, + bcConfig: bcConfig, + statsFactory: statsFactory, + tenantManager: tenantManager, + bcManager: bcManager, + asyncManager: asyncManager, + stagingRepo: repo.NewStagingFiles(db), + uploadRepo: repo.NewUploads(db), + schemaRepo: repo.NewWHSchemas(db), + } + a.config.healthTimeout = conf.GetDuration("Warehouse.healthTimeout", 10, time.Second) + a.config.readerHeaderTimeout = conf.GetDuration("Warehouse.readerHeaderTimeout", 3, time.Second) + a.config.runningMode = conf.GetString("Warehouse.runningMode", "") + a.config.webPort = conf.GetInt("Warehouse.webPort", 8082) + + return a +} + +func (a *Api) Start(ctx context.Context) error { + srvMux := chi.NewRouter() + + if isStandAlone(a.mode) { + srvMux.Get("/health", a.healthHandler) + } + if a.config.runningMode != DegradedMode { + if isMaster(a.mode) { + a.addMasterEndpoints(ctx, srvMux) + + a.logger.Infow("Starting warehouse master service on" + strconv.Itoa(a.config.webPort)) + } else { + a.logger.Infow("Starting warehouse slave service on" + strconv.Itoa(a.config.webPort)) + } + } + + srv 
:= &http.Server{
+		Addr:              net.JoinHostPort("", strconv.Itoa(a.config.webPort)),
+		Handler:           bugsnag.Handler(srvMux),
+		ReadHeaderTimeout: a.config.readerHeaderTimeout,
+	}
+	return kithttputil.ListenAndServe(ctx, srv)
+}
+
+func (a *Api) addMasterEndpoints(ctx context.Context, r chi.Router) {
+	a.logger.Infow("waiting for BackendConfig before starting on port " + strconv.Itoa(a.config.webPort))
+
+	a.bcConfig.WaitForConfig(ctx)
+
+	r.Handle("/v1/process", (&api.WarehouseAPI{
+		Logger:      a.logger,
+		Stats:       a.statsFactory,
+		Repo:        a.stagingRepo,
+		Multitenant: a.tenantManager,
+	}).Handler())
+
+	r.Route("/v1", func(r chi.Router) {
+		r.Route("/warehouse", func(r chi.Router) {
+			r.Post("/pending-events", a.logMiddleware(a.pendingEventsHandler))
+			r.Post("/trigger-upload", a.logMiddleware(a.triggerUploadHandler))
+
+			r.Post("/jobs", a.logMiddleware(a.asyncManager.InsertJobHandler))       // TODO: add degraded mode
+			r.Get("/jobs/status", a.logMiddleware(a.asyncManager.StatusJobHandler)) // TODO: add degraded mode
+
+			r.Get("/fetch-tables", a.logMiddleware(a.fetchTablesHandler)) // TODO: Remove this endpoint once sources change is released
+		})
+	})
+	r.Route("/internal", func(r chi.Router) {
+		r.Route("/v1", func(r chi.Router) {
+			r.Route("/warehouse", func(r chi.Router) {
+				r.Get("/fetch-tables", a.logMiddleware(a.fetchTablesHandler))
+			})
+		})
+	})
+}
+
+func (a *Api) healthHandler(w http.ResponseWriter, r *http.Request) {
+	var dbService, notifierService string
+
+	ctx, cancel := context.WithTimeout(r.Context(), a.config.healthTimeout)
+	defer cancel()
+
+	if a.config.runningMode != DegradedMode {
+		if !checkHealth(ctx, a.notifier.GetDBHandle()) {
+			http.Error(w, "Cannot connect to notifierService", http.StatusInternalServerError)
+			return
+		}
+		notifierService = "UP"
+	}
+
+	if isMaster(a.mode) {
+		if !checkHealth(ctx, a.db.DB) {
+			http.Error(w, "Cannot connect to dbService", http.StatusInternalServerError)
+			return
+		}
+		dbService = "UP"
+	}
+
+	healthVal := fmt.Sprintf(`
+{
+	"server": "UP",
+	"db": %q,
+	"notifier": %q,
+	"acceptingEvents": "TRUE",
+	"warehouseMode": %q
+}
+	`,
+		dbService,
+		notifierService,
+		strings.ToUpper(a.mode),
+	)
+
+	_, _ = w.Write([]byte(healthVal))
+}
+
+func checkHealth(ctx context.Context, db *sql.DB) bool {
+	if db == nil {
+		return false
+	}
+
+	healthCheckMsg := "Rudder Warehouse DB Health Check"
+	msg := ""
+
+	err := db.QueryRowContext(ctx, `SELECT '`+healthCheckMsg+`'::text as message;`).Scan(&msg)
+	if err != nil {
+		return false
+	}
+
+	return healthCheckMsg == msg
+}
+
+// pendingEventsHandler checks whether there are any pending staging files or uploads for the given source ID
+func (a *Api) pendingEventsHandler(w http.ResponseWriter, r *http.Request) {
+	defer func() { _ = r.Body.Close() }()
+
+	var payload pendingEventsRequest
+	if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
+		a.logger.Warnw("invalid JSON in request body for pending events", lf.Error, err.Error())
+		http.Error(w, ierrors.ErrInvalidJSONRequestBody.Error(), http.StatusBadRequest)
+		return
+	}
+
+	sourceID, taskRunID := payload.SourceID, payload.TaskRunID
+	if sourceID == "" || taskRunID == "" {
+		a.logger.Warnw("empty source or task run id for pending events",
+			lf.SourceID, payload.SourceID,
+			lf.TaskRunID, payload.TaskRunID,
+		)
+		http.Error(w, "empty source or task run id", http.StatusBadRequest)
+		return
+	}
+
+	workspaceID, err := a.tenantManager.SourceToWorkspace(r.Context(), sourceID)
+	if err != nil {
+		a.logger.Warnw("workspace from source not found for pending 
events", lf.SourceID, payload.SourceID) + http.Error(w, ierrors.ErrWorkspaceFromSourceNotFound.Error(), http.StatusBadRequest) + return + } + + if a.tenantManager.DegradedWorkspace(workspaceID) { + a.logger.Infow("workspace is degraded for pending events", lf.WorkspaceID, workspaceID) + http.Error(w, ierrors.ErrWorkspaceDegraded.Error(), http.StatusServiceUnavailable) + return + } + + pendingStagingFileCount, err := a.stagingRepo.CountPendingForSource(r.Context(), sourceID) + if err != nil { + if errors.Is(r.Context().Err(), context.Canceled) { + http.Error(w, ierrors.ErrRequestCancelled.Error(), http.StatusBadRequest) + return + } + a.logger.Errorw("counting pending staging files", lf.Error, err.Error()) + http.Error(w, "can't get pending staging files count", http.StatusInternalServerError) + return + } + + filters := []repo.FilterBy{ + {Key: "source_id", Value: sourceID}, + {Key: "metadata->>'source_task_run_id'", Value: taskRunID}, + {Key: "status", NotEquals: true, Value: model.ExportedData}, + {Key: "status", NotEquals: true, Value: model.Aborted}, + } + pendingUploadCount, err := a.uploadRepo.Count(r.Context(), filters...) + if err != nil { + if errors.Is(r.Context().Err(), context.Canceled) { + http.Error(w, ierrors.ErrRequestCancelled.Error(), http.StatusBadRequest) + return + } + a.logger.Errorw("counting pending uploads", lf.Error, err.Error()) + http.Error(w, "can't get pending uploads count", http.StatusInternalServerError) + return + } + + filters = []repo.FilterBy{ + {Key: "source_id", Value: sourceID}, + {Key: "metadata->>'source_task_run_id'", Value: payload.TaskRunID}, + {Key: "status", Value: "aborted"}, + } + abortedUploadCount, err := a.uploadRepo.Count(r.Context(), filters...) + if err != nil { + if errors.Is(r.Context().Err(), context.Canceled) { + http.Error(w, ierrors.ErrRequestCancelled.Error(), http.StatusBadRequest) + return + } + a.logger.Errorw("counting aborted uploads", lf.Error, err.Error()) + http.Error(w, "can't get aborted uploads count", http.StatusInternalServerError) + return + } + + pendingEventsAvailable := (pendingStagingFileCount + pendingUploadCount) > 0 + triggerPendingUpload, _ := strconv.ParseBool(r.URL.Query().Get(triggerUploadQPName)) + + if pendingEventsAvailable && triggerPendingUpload { + a.logger.Infow("triggering upload for all destinations connected to source", + lf.WorkspaceID, workspaceID, + lf.SourceID, payload.SourceID, + ) + + wh := a.bcManager.WarehousesBySourceID(sourceID) + if len(wh) == 0 { + a.logger.Warnw("no warehouse found for pending events", + lf.WorkspaceID, workspaceID, + lf.SourceID, payload.SourceID, + ) + http.Error(w, ierrors.ErrNoWarehouseFound.Error(), http.StatusBadRequest) + return + } + + for _, warehouse := range wh { + triggerUpload(warehouse) + } + } + + resBody, err := json.Marshal(pendingEventsResponse{ + PendingEvents: pendingEventsAvailable, + PendingStagingFilesCount: pendingStagingFileCount, + PendingUploadCount: pendingUploadCount, + AbortedEvents: abortedUploadCount > 0, + }) + if err != nil { + a.logger.Errorw("marshalling response for pending events", lf.Error, err.Error()) + http.Error(w, ierrors.ErrMarshallResponse.Error(), http.StatusInternalServerError) + return + } + + _, _ = w.Write(resBody) +} + +func (a *Api) triggerUploadHandler(w http.ResponseWriter, r *http.Request) { + defer func() { _ = r.Body.Close() }() + + var payload triggerUploadRequest + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + a.logger.Warnw("invalid JSON in request body for triggering upload", 
lf.Error, err.Error()) + http.Error(w, ierrors.ErrInvalidJSONRequestBody.Error(), http.StatusBadRequest) + return + } + + workspaceID, err := a.tenantManager.SourceToWorkspace(r.Context(), payload.SourceID) + if err != nil { + a.logger.Warnw("workspace from source not found for triggering upload", lf.SourceID, payload.SourceID) + http.Error(w, ierrors.ErrWorkspaceFromSourceNotFound.Error(), http.StatusBadRequest) + return + } + + if a.tenantManager.DegradedWorkspace(workspaceID) { + a.logger.Infow("workspace is degraded for triggering upload", lf.WorkspaceID, workspaceID) + http.Error(w, ierrors.ErrWorkspaceDegraded.Error(), http.StatusServiceUnavailable) + return + } + + var wh []model.Warehouse + if payload.SourceID != "" && payload.DestinationID == "" { + wh = a.bcManager.WarehousesBySourceID(payload.SourceID) + } else if payload.DestinationID != "" { + wh = a.bcManager.WarehousesByDestID(payload.DestinationID) + } + if len(wh) == 0 { + a.logger.Warnw("no warehouse found for triggering upload", + lf.WorkspaceID, workspaceID, + lf.SourceID, payload.SourceID, + lf.DestinationID, payload.DestinationID, + ) + http.Error(w, ierrors.ErrNoWarehouseFound.Error(), http.StatusBadRequest) + return + } + + for _, warehouse := range wh { + triggerUpload(warehouse) + } + + w.WriteHeader(http.StatusOK) +} + +func (a *Api) fetchTablesHandler(w http.ResponseWriter, r *http.Request) { + defer func() { _ = r.Body.Close() }() + + var payload fetchTablesRequest + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + a.logger.Warnw("invalid JSON in request body for fetching tables", lf.Error, err.Error()) + http.Error(w, ierrors.ErrInvalidJSONRequestBody.Error(), http.StatusBadRequest) + return + } + + tables, err := a.schemaRepo.GetTablesForConnection(r.Context(), payload.Connections) + if err != nil { + if errors.Is(r.Context().Err(), context.Canceled) { + http.Error(w, ierrors.ErrRequestCancelled.Error(), http.StatusBadRequest) + return + } + a.logger.Errorw("fetching tables", lf.Error, err.Error()) + http.Error(w, "can't fetch tables", http.StatusInternalServerError) + return + } + + resBody, err := json.Marshal(fetchTablesResponse{ + ConnectionsTables: tables, + }) + if err != nil { + a.logger.Errorw("marshalling response for fetching tables", lf.Error, err.Error()) + http.Error(w, ierrors.ErrMarshallResponse.Error(), http.StatusInternalServerError) + return + } + + _, _ = w.Write(resBody) +} + +func (a *Api) logMiddleware(delegate http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + a.logger.LogRequest(r) + delegate.ServeHTTP(w, r) + } +} diff --git a/warehouse/http_test.go b/warehouse/http_test.go new file mode 100644 index 0000000000..045ad1cd2d --- /dev/null +++ b/warehouse/http_test.go @@ -0,0 +1,1064 @@ +package warehouse + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/rudderlabs/rudder-server/utils/httputil" + + "golang.org/x/sync/errgroup" + + kithelper "github.com/rudderlabs/rudder-go-kit/testhelper" + sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" + "github.com/rudderlabs/rudder-server/warehouse/jobs" + "github.com/rudderlabs/rudder-server/warehouse/multitenant" + + "github.com/golang/mock/gomock" + "github.com/ory/dockertest/v3" + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/rudder-server/warehouse/internal/model" + 
"github.com/rudderlabs/rudder-server/warehouse/internal/repo" + + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-go-kit/stats" + "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource" + backendconfig "github.com/rudderlabs/rudder-server/backend-config" + mocksBackendConfig "github.com/rudderlabs/rudder-server/mocks/backend-config" + "github.com/rudderlabs/rudder-server/services/pgnotifier" + migrator "github.com/rudderlabs/rudder-server/services/sql-migrator" + "github.com/rudderlabs/rudder-server/utils/pubsub" + warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" +) + +func TestHTTPApi(t *testing.T) { + pgnotifier.Init() + Init4() + + const ( + workspaceID = "test_workspace_id" + sourceID = "test_source_id" + destinationID = "test_destination_id" + degradedWorkspaceID = "degraded_test_workspace_id" + degradedSourceID = "degraded_test_source_id" + degradedDestinationID = "degraded_test_destination_id" + unusedWorkspaceID = "unused_test_workspace_id" + unusedSourceID = "unused_test_source_id" + unusedDestinationID = "unused_test_destination_id" + unsupportedWorkspaceID = "unsupported_test_workspace_id" + unsupportedSourceID = "unsupported_test_source_id" + unsupportedDestinationID = "unsupported_test_destination_id" + workspaceIdentifier = "test_workspace-identifier" + namespace = "test_namespace" + destinationType = "test_destination_type" + sourceTaskRunID = "test_source_task_run_id" + sourceJobID = "test_source_job_id" + sourceJobRunID = "test_source_job_run_id" + ) + + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + pgResource, err := resource.SetupPostgres(pool, t) + require.NoError(t, err) + + t.Log("db:", pgResource.DBDsn) + + err = (&migrator.Migrator{ + Handle: pgResource.DB, + MigrationsTable: "wh_schema_migrations", + }).Migrate("warehouse") + require.NoError(t, err) + + ctrl := gomock.NewController(t) + mockBackendConfig := mocksBackendConfig.NewMockBackendConfig(ctrl) + mockBackendConfig.EXPECT().WaitForConfig(gomock.Any()).DoAndReturn(func(ctx context.Context) error { + return nil + }).AnyTimes() + mockBackendConfig.EXPECT().Subscribe(gomock.Any(), backendconfig.TopicBackendConfig).DoAndReturn(func(ctx context.Context, topic backendconfig.Topic) pubsub.DataChannel { + ch := make(chan pubsub.DataEvent, 1) + ch <- pubsub.DataEvent{ + Data: map[string]backendconfig.ConfigT{ + workspaceID: { + WorkspaceID: workspaceID, + Sources: []backendconfig.SourceT{ + { + ID: sourceID, + Enabled: true, + Destinations: []backendconfig.DestinationT{ + { + ID: destinationID, + Enabled: true, + DestinationDefinition: backendconfig.DestinationDefinitionT{ + Name: warehouseutils.POSTGRES, + }, + }, + }, + }, + }, + }, + degradedWorkspaceID: { + WorkspaceID: degradedWorkspaceID, + Sources: []backendconfig.SourceT{ + { + ID: degradedSourceID, + Enabled: true, + Destinations: []backendconfig.DestinationT{ + { + ID: degradedDestinationID, + Enabled: true, + DestinationDefinition: backendconfig.DestinationDefinitionT{ + Name: warehouseutils.POSTGRES, + }, + }, + }, + }, + }, + }, + unsupportedWorkspaceID: { + WorkspaceID: unsupportedWorkspaceID, + Sources: []backendconfig.SourceT{ + { + ID: unsupportedSourceID, + Enabled: true, + Destinations: []backendconfig.DestinationT{ + { + ID: unsupportedDestinationID, + Enabled: true, + DestinationDefinition: backendconfig.DestinationDefinitionT{ + Name: "unknown_destination_type", + }, + }, + }, + }, + }, + }, + unusedWorkspaceID: { + 
WorkspaceID: unusedWorkspaceID,
+					Sources: []backendconfig.SourceT{
+						{
+							ID:      unusedSourceID,
+							Enabled: true,
+							Destinations: []backendconfig.DestinationT{
+								{
+									ID:      unusedDestinationID,
+									Enabled: true,
+									DestinationDefinition: backendconfig.DestinationDefinitionT{
+										Name: warehouseutils.POSTGRES,
+									},
+								},
+							},
+						},
+					},
+				},
+			},
+			Topic: string(backendconfig.TopicBackendConfig),
+		}
+		close(ch)
+		return ch
+	}).AnyTimes()
+
+	c := config.New()
+	c.Set("Warehouse.degradedWorkspaceIDs", []string{degradedWorkspaceID})
+
+	db := sqlmiddleware.New(pgResource.DB)
+
+	notifier, err := pgnotifier.New(workspaceIdentifier, pgResource.DBDsn)
+	require.NoError(t, err)
+
+	tenantManager := multitenant.New(c, mockBackendConfig)
+
+	bcManager := newBackendConfigManager(config.Default, db, tenantManager, logger.NOP)
+
+	ctx, stopTest := context.WithCancel(context.Background())
+
+	jobsManager := jobs.InitWarehouseJobsAPI(
+		ctx,
+		db.DB,
+		&notifier,
+	)
+	jobs.WithConfig(jobsManager, config.Default)
+
+	g, gCtx := errgroup.WithContext(ctx)
+	g.Go(func() error {
+		tenantManager.Run(gCtx)
+		return nil
+	})
+	g.Go(func() error {
+		bcManager.Start(gCtx)
+		return nil
+	})
+	g.Go(func() error {
+		return jobsManager.InitAsyncJobRunner()
+	})
+
+	setupCh := make(chan struct{})
+	go func() {
+		require.NoError(t, g.Wait())
+		close(setupCh)
+	}()
+
+	now := time.Now().Truncate(time.Second).UTC()
+	stagingRepo := repo.NewStagingFiles(db, repo.WithNow(func() time.Time {
+		return now
+	}))
+	uploadsRepo := repo.NewUploads(db, repo.WithNow(func() time.Time {
+		return now
+	}))
+	tableUploadsRepo := repo.NewTableUploads(db, repo.WithNow(func() time.Time {
+		return now
+	}))
+
+	stagingFile := model.StagingFile{
+		WorkspaceID:           workspaceID,
+		Location:              "s3://bucket/path/to/file",
+		SourceID:              sourceID,
+		DestinationID:         destinationID,
+		Status:                warehouseutils.StagingFileWaitingState,
+		Error:                 fmt.Errorf("dummy error"),
+		FirstEventAt:          now.Add(time.Second),
+		UseRudderStorage:      true,
+		DestinationRevisionID: "destination_revision_id",
+		TotalEvents:           100,
+		SourceTaskRunID:       sourceTaskRunID,
+		SourceJobID:           sourceJobID,
+		SourceJobRunID:        sourceJobRunID,
+		TimeWindow:            time.Date(1993, 8, 1, 3, 0, 0, 0, time.UTC),
+	}.WithSchema([]byte(`{"type": "object"}`))
+
+	failedStagingID, err := stagingRepo.Insert(ctx, &stagingFile)
+	require.NoError(t, err)
+	pendingStagingID, err := stagingRepo.Insert(ctx, &stagingFile)
+	require.NoError(t, err)
+
+	_, err = uploadsRepo.CreateWithStagingFiles(ctx, model.Upload{
+		WorkspaceID:     workspaceID,
+		Namespace:       namespace,
+		SourceID:        sourceID,
+		DestinationID:   destinationID,
+		DestinationType: destinationType,
+		Status:          model.Aborted,
+		SourceJobRunID:  sourceJobRunID,
+		SourceTaskRunID: sourceTaskRunID,
+	}, []*model.StagingFile{{
+		ID:              failedStagingID,
+		SourceID:        sourceID,
+		DestinationID:   destinationID,
+		SourceJobRunID:  sourceJobRunID,
+		SourceTaskRunID: sourceTaskRunID,
+	}})
+	require.NoError(t, err)
+	uploadID, err := uploadsRepo.CreateWithStagingFiles(ctx, model.Upload{
+		WorkspaceID:     workspaceID,
+		Namespace:       namespace,
+		SourceID:        sourceID,
+		DestinationID:   destinationID,
+		DestinationType: destinationType,
+		Status:          model.Waiting,
+		SourceJobRunID:  sourceJobRunID,
+		SourceTaskRunID: sourceTaskRunID,
+	}, []*model.StagingFile{{
+		ID:              pendingStagingID,
+		SourceID:        sourceID,
+		DestinationID:   destinationID,
+		SourceJobRunID:  sourceJobRunID,
+		SourceTaskRunID: sourceTaskRunID,
+	}})
+	require.NoError(t, err)
+
+	err = tableUploadsRepo.Insert(ctx, uploadID, []string{
+		"test_table_1",
+		"test_table_2",
+		"test_table_3",
+		"test_table_4",
+		"test_table_5",
+
+		"rudder_discards",
+		"rudder_identity_mappings",
+		"rudder_identity_merge_rules",
+	})
+	require.NoError(t, err)
+
+	for pendingStagingFiles := 0; pendingStagingFiles < 5; pendingStagingFiles++ {
+		_, err = stagingRepo.Insert(ctx, &stagingFile)
+		require.NoError(t, err)
+	}
+
+	schemaRepo := repo.NewWHSchemas(db)
+	_, err = schemaRepo.Insert(ctx, &model.WHSchema{
+		UploadID:        1,
+		SourceID:        sourceID,
+		Namespace:       namespace,
+		DestinationID:   destinationID,
+		DestinationType: destinationType,
+		Schema: model.Schema{
+			"test_table": {
+				"test_column": "test_data_type",
+			},
+		},
+		CreatedAt: now,
+		UpdatedAt: now,
+	})
+	require.NoError(t, err)
+
+	t.Run("health handler", func(t *testing.T) {
+		testCases := []struct {
+			name        string
+			mode        string
+			runningMode string
+			response    map[string]string
+		}{
+			{
+				name: "embedded",
+				mode: config.EmbeddedMode,
+				response: map[string]string{
+					"acceptingEvents": "TRUE",
+					"db":              "UP",
+					"notifier":        "UP",
+					"server":          "UP",
+					"warehouseMode":   "EMBEDDED",
+				},
+			},
+			{
+				name: "master",
+				mode: config.MasterMode,
+				response: map[string]string{
+					"acceptingEvents": "TRUE",
+					"db":              "UP",
+					"notifier":        "UP",
+					"server":          "UP",
+					"warehouseMode":   "MASTER",
+				},
+			},
+			{
+				name:        "degraded master",
+				mode:        config.MasterMode,
+				runningMode: "degraded",
+				response: map[string]string{
+					"acceptingEvents": "TRUE",
+					"db":              "UP",
+					"notifier":        "",
+					"server":          "UP",
+					"warehouseMode":   "MASTER",
+				},
+			},
+			{
+				name: "master and slave",
+				mode: config.MasterSlaveMode,
+				response: map[string]string{
+					"acceptingEvents": "TRUE",
+					"db":              "UP",
+					"notifier":        "UP",
+					"server":          "UP",
+					"warehouseMode":   "MASTER_AND_SLAVE",
+				},
+			},
+			{
+				name: "embedded master",
+				mode: config.EmbeddedMasterMode,
+				response: map[string]string{
+					"acceptingEvents": "TRUE",
+					"db":              "UP",
+					"notifier":        "UP",
+					"server":          "UP",
+					"warehouseMode":   "EMBEDDED_MASTER",
+				},
+			},
+			{
+				name: "slave",
+				mode: config.SlaveMode,
+				response: map[string]string{
+					"acceptingEvents": "TRUE",
+					"db":              "",
+					"notifier":        "UP",
+					"server":          "UP",
+					"warehouseMode":   "SLAVE",
+				},
+			},
+			{
+				name: "off",
+				mode: config.OffMode,
+				response: map[string]string{
+					"acceptingEvents": "TRUE",
+					"db":              "",
+					"notifier":        "UP",
+					"server":          "UP",
+					"warehouseMode":   "OFF",
+				},
+			},
+		}
+
+		for _, tc := range testCases {
+			t.Run(tc.name, func(t *testing.T) {
+				req := httptest.NewRequest(http.MethodGet, "/health", nil)
+				resp := httptest.NewRecorder()
+
+				c := config.New()
+				c.Set("Warehouse.runningMode", tc.runningMode)
+
+				a := NewApi(tc.mode, c, logger.NOP, stats.Default, mockBackendConfig, db, &notifier, tenantManager, bcManager, jobsManager)
+				a.healthHandler(resp, req)
+
+				var healthBody map[string]string
+				err = json.NewDecoder(resp.Body).Decode(&healthBody)
+				require.NoError(t, err)
+				require.Equal(t, http.StatusOK, resp.Code)
+
+				require.EqualValues(t, healthBody, tc.response)
+			})
+		}
+	})
+
+	t.Run("pending events handler", func(t *testing.T) {
+		t.Run("invalid payload", func(t *testing.T) {
+			req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/pending-events", bytes.NewReader([]byte(`"Invalid payload"`)))
+			resp := httptest.NewRecorder()
+
+			a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, &notifier, tenantManager, bcManager, jobsManager)
+			a.pendingEventsHandler(resp, req)
+			require.Equal(t, http.StatusBadRequest, resp.Code)
+
+			b, err := io.ReadAll(resp.Body)
+			require.NoError(t, err)
+			
require.Equal(t, "invalid JSON in request body\n", string(b))
+		})
+
+		t.Run("empty source id or task run id", func(t *testing.T) {
+			req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/pending-events", bytes.NewReader([]byte(`
+				{
+					"source_id": "",
+					"task_run_id": ""
+				}
+			`)))
+			resp := httptest.NewRecorder()
+
+			a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, &notifier, tenantManager, bcManager, jobsManager)
+			a.pendingEventsHandler(resp, req)
+			require.Equal(t, http.StatusBadRequest, resp.Code)
+
+			b, err := io.ReadAll(resp.Body)
+			require.NoError(t, err)
+			require.Equal(t, "empty source or task run id\n", string(b))
+		})
+
+		t.Run("workspace not found", func(t *testing.T) {
+			req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/pending-events", bytes.NewReader([]byte(`
+				{
+					"source_id": "unknown_source_id",
+					"task_run_id": "unknown_task_run_id"
+				}
+			`)))
+			resp := httptest.NewRecorder()
+
+			a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, &notifier, tenantManager, bcManager, jobsManager)
+			a.pendingEventsHandler(resp, req)
+			require.Equal(t, http.StatusBadRequest, resp.Code)
+
+			b, err := io.ReadAll(resp.Body)
+			require.NoError(t, err)
+			require.Equal(t, "workspace from source not found\n", string(b))
+		})
+
+		t.Run("degraded workspace", func(t *testing.T) {
+			req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/pending-events", bytes.NewReader([]byte(`
+				{
+					"source_id": "degraded_test_source_id",
+					"task_run_id": "degraded_task_run_id"
+				}
+			`)))
+			resp := httptest.NewRecorder()
+
+			a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, &notifier, tenantManager, bcManager, jobsManager)
+			a.pendingEventsHandler(resp, req)
+			require.Equal(t, http.StatusServiceUnavailable, resp.Code)
+
+			b, err := io.ReadAll(resp.Body)
+			require.NoError(t, err)
+			require.Equal(t, "workspace is degraded\n", string(b))
+		})
+
+		t.Run("pending events available without trigger uploads", func(t *testing.T) {
+			req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/pending-events?triggerUpload=false", bytes.NewReader([]byte(`
+				{
+					"source_id": "test_source_id",
+					"task_run_id": "test_source_task_run_id"
+				}
+			`)))
+			resp := httptest.NewRecorder()
+
+			a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, &notifier, tenantManager, bcManager, jobsManager)
+			a.pendingEventsHandler(resp, req)
+			require.Equal(t, http.StatusOK, resp.Code)
+
+			var pendingEventsResponse pendingEventsResponse
+			err := json.NewDecoder(resp.Body).Decode(&pendingEventsResponse)
+			require.NoError(t, err)
+
+			require.EqualValues(t, pendingEventsResponse.PendingEvents, true)
+			require.EqualValues(t, pendingEventsResponse.PendingUploadCount, 1)
+			require.EqualValues(t, pendingEventsResponse.PendingStagingFilesCount, 5)
+			require.EqualValues(t, pendingEventsResponse.AbortedEvents, true)
+			require.False(t, isUploadTriggered(model.Warehouse{
+				Identifier: "POSTGRES:test_source_id:test_destination_id",
+			}))
+		})
+
+		t.Run("pending events available with trigger uploads", func(t *testing.T) {
+			req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/pending-events?triggerUpload=true", bytes.NewReader([]byte(`
+				{
+					"source_id": "test_source_id",
+					"task_run_id": "test_source_task_run_id"
+				}
+			`)))
+			resp := httptest.NewRecorder()
+
+			a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, &notifier, tenantManager, bcManager, jobsManager)
+			a.pendingEventsHandler(resp, req)
+			require.Equal(t, http.StatusOK, resp.Code)
+
+			var pendingEventsResponse pendingEventsResponse
+			err := json.NewDecoder(resp.Body).Decode(&pendingEventsResponse)
+			require.NoError(t, err)
+
+			defer func() {
+				clearTriggeredUpload(model.Warehouse{
+					Identifier: "POSTGRES:test_source_id:test_destination_id",
+				})
+			}()
+
+			require.EqualValues(t, pendingEventsResponse.PendingEvents, true)
+			require.EqualValues(t, pendingEventsResponse.PendingUploadCount, 1)
+			require.EqualValues(t, pendingEventsResponse.PendingStagingFilesCount, 5)
+			require.EqualValues(t, pendingEventsResponse.AbortedEvents, true)
+			require.True(t, isUploadTriggered(model.Warehouse{
+				Identifier: "POSTGRES:test_source_id:test_destination_id",
+			}))
+		})
+
+		t.Run("no pending events available", func(t *testing.T) {
+			req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/pending-events?triggerUpload=true", bytes.NewReader([]byte(`
+				{
+					"source_id": "unused_test_source_id",
+					"task_run_id": "unused_test_source_task_run_id"
+				}
+			`)))
+			resp := httptest.NewRecorder()
+
+			a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, &notifier, tenantManager, bcManager, jobsManager)
+			a.pendingEventsHandler(resp, req)
+			require.Equal(t, http.StatusOK, resp.Code)
+
+			var pendingEventsResponse pendingEventsResponse
+			err := json.NewDecoder(resp.Body).Decode(&pendingEventsResponse)
+			require.NoError(t, err)
+			require.EqualValues(t, pendingEventsResponse.PendingEvents, false)
+			require.EqualValues(t, pendingEventsResponse.PendingUploadCount, 0)
+			require.EqualValues(t, pendingEventsResponse.PendingStagingFilesCount, 0)
+			require.EqualValues(t, pendingEventsResponse.AbortedEvents, false)
+			require.False(t, isUploadTriggered(model.Warehouse{
+				Identifier: "POSTGRES:unused_test_source_id:unused_test_destination_id",
+			}))
+		})
+	})
+
+	t.Run("fetch tables handler", func(t *testing.T) {
+		t.Run("invalid payload", func(t *testing.T) {
+			req := httptest.NewRequest(http.MethodGet, "/internal/v1/warehouse/fetch-tables", bytes.NewReader([]byte(`"Invalid payload"`)))
+			resp := httptest.NewRecorder()
+
+			a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, &notifier, tenantManager, bcManager, jobsManager)
+			a.fetchTablesHandler(resp, req)
+			require.Equal(t, http.StatusBadRequest, resp.Code)
+
+			b, err := io.ReadAll(resp.Body)
+			require.NoError(t, err)
+			require.Equal(t, "invalid JSON in request body\n", string(b))
+		})
+
+		t.Run("empty connections", func(t *testing.T) {
+			req := httptest.NewRequest(http.MethodGet, "/internal/v1/warehouse/fetch-tables", bytes.NewReader([]byte(`
+				{
+					"connections": []
+				}
+			`)))
+			resp := httptest.NewRecorder()
+
+			a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, &notifier, tenantManager, bcManager, jobsManager)
+			a.fetchTablesHandler(resp, req)
+			require.Equal(t, http.StatusInternalServerError, resp.Code)
+
+			b, err := io.ReadAll(resp.Body)
+			require.NoError(t, err)
+			require.Equal(t, "can't fetch tables\n", string(b))
+		})
+
+		t.Run("succeed", func(t *testing.T) {
+			req := httptest.NewRequest(http.MethodGet, "/internal/v1/warehouse/fetch-tables", bytes.NewReader([]byte(`
+				{
+					"connections": [
+						{
+							"source_id": "test_source_id",
+							"destination_id": "test_destination_id"
+						}
+					]
+				}
+			`)))
+			resp := httptest.NewRecorder()
+
+			a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, 
mockBackendConfig, db, &notifier, tenantManager, bcManager, jobsManager)
+			a.fetchTablesHandler(resp, req)
+			require.Equal(t, http.StatusOK, resp.Code)
+
+			var ftr fetchTablesResponse
+			err = json.NewDecoder(resp.Body).Decode(&ftr)
+			require.NoError(t, err)
+			require.EqualValues(t, ftr.ConnectionsTables, []warehouseutils.FetchTableInfo{
+				{
+					SourceID:      sourceID,
+					DestinationID: destinationID,
+					Namespace:     namespace,
+					Tables:        []string{"test_table"},
+				},
+			})
+		})
+	})
+
+	t.Run("trigger uploads handler", func(t *testing.T) {
+		t.Run("invalid payload", func(t *testing.T) {
+			req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/trigger-upload", bytes.NewReader([]byte(`"Invalid payload"`)))
+			resp := httptest.NewRecorder()
+
+			a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, &notifier, tenantManager, bcManager, jobsManager)
+			a.triggerUploadHandler(resp, req)
+			require.Equal(t, http.StatusBadRequest, resp.Code)
+
+			b, err := io.ReadAll(resp.Body)
+			require.NoError(t, err)
+			require.Equal(t, "invalid JSON in request body\n", string(b))
+		})
+
+		t.Run("workspace not found", func(t *testing.T) {
+			req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/trigger-upload", bytes.NewReader([]byte(`
+				{
+					"source_id": "unknown_source_id",
+					"destination_id": "unknown_destination_id"
+				}
+			`)))
+			resp := httptest.NewRecorder()
+
+			a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, &notifier, tenantManager, bcManager, jobsManager)
+			a.triggerUploadHandler(resp, req)
+			require.Equal(t, http.StatusBadRequest, resp.Code)
+
+			b, err := io.ReadAll(resp.Body)
+			require.NoError(t, err)
+			require.Equal(t, "workspace from source not found\n", string(b))
+		})
+
+		t.Run("degraded workspaces", func(t *testing.T) {
+			req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/trigger-upload", bytes.NewReader([]byte(`
+				{
+					"source_id": "degraded_test_source_id",
+					"destination_id": "degraded_test_destination_id"
+				}
+			`)))
+			resp := httptest.NewRecorder()
+
+			a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, &notifier, tenantManager, bcManager, jobsManager)
+			a.triggerUploadHandler(resp, req)
+
+			require.Equal(t, http.StatusServiceUnavailable, resp.Code)
+			b, err := io.ReadAll(resp.Body)
+			require.NoError(t, err)
+			require.Equal(t, "workspace is degraded\n", string(b))
+		})
+
+		t.Run("no warehouses", func(t *testing.T) {
+			req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/trigger-upload", bytes.NewReader([]byte(`
+				{
+					"source_id": "unsupported_test_source_id",
+					"destination_id": "unsupported_test_destination_id"
+				}
+			`)))
+			resp := httptest.NewRecorder()
+
+			a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, &notifier, tenantManager, bcManager, jobsManager)
+			a.triggerUploadHandler(resp, req)
+
+			require.Equal(t, http.StatusBadRequest, resp.Code)
+			require.False(t, isUploadTriggered(model.Warehouse{
+				Identifier: "POSTGRES:unsupported_test_source_id:unsupported_test_destination_id",
+			}))
+		})
+
+		t.Run("without destination id", func(t *testing.T) {
+			req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/trigger-upload", bytes.NewReader([]byte(`
+				{
+					"source_id": "test_source_id",
+					"destination_id": ""
+				}
+			`)))
+			resp := httptest.NewRecorder()
+
+			a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, &notifier, tenantManager, bcManager, jobsManager)
+			
a.triggerUploadHandler(resp, req)
+
+			require.Equal(t, http.StatusOK, resp.Code)
+
+			defer func() {
+				clearTriggeredUpload(model.Warehouse{
+					Identifier: "POSTGRES:test_source_id:test_destination_id",
+				})
+			}()
+			require.True(t, isUploadTriggered(model.Warehouse{
+				Identifier: "POSTGRES:test_source_id:test_destination_id",
+			}))
+		})
+
+		t.Run("with destination id", func(t *testing.T) {
+			req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/trigger-upload", bytes.NewReader([]byte(`
+				{
+					"source_id": "test_source_id",
+					"destination_id": "test_destination_id"
+				}
+			`)))
+			resp := httptest.NewRecorder()
+
+			a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, &notifier, tenantManager, bcManager, jobsManager)
+			a.triggerUploadHandler(resp, req)
+
+			require.Equal(t, http.StatusOK, resp.Code)
+
+			defer func() {
+				clearTriggeredUpload(model.Warehouse{
+					Identifier: "POSTGRES:test_source_id:test_destination_id",
+				})
+			}()
+			require.True(t, isUploadTriggered(model.Warehouse{
+				Identifier: "POSTGRES:test_source_id:test_destination_id",
+			}))
+		})
+	})
+
+	t.Run("endpoints", func(t *testing.T) {
+		t.Run("normal mode", func(t *testing.T) {
+			webPort, err := kithelper.GetFreePort()
+			require.NoError(t, err)
+
+			c := config.New()
+			c.Set("Warehouse.webPort", webPort)
+
+			srvCtx, stopServer := context.WithCancel(ctx)
+
+			a := NewApi(config.MasterMode, c, logger.NOP, stats.Default, mockBackendConfig, db, &notifier, tenantManager, bcManager, jobsManager)
+
+			serverSetupCh := make(chan struct{})
+			go func() {
+				require.NoError(t, a.Start(srvCtx))
+
+				close(serverSetupCh)
+			}()
+
+			serverURL := fmt.Sprintf("http://localhost:%d", webPort)
+
+			t.Run("health", func(t *testing.T) {
+				require.Eventually(t, func() bool {
+					resp, err := http.Get(fmt.Sprintf("%s/health", serverURL))
+					if err != nil {
+						return false
+					}
+					defer func() {
+						httputil.CloseResponse(resp)
+					}()
+
+					return resp.StatusCode == http.StatusOK
+				},
+					time.Second*10,
+					time.Second,
+				)
+			})
+
+			t.Run("process", func(t *testing.T) {
+				pendingEventsURL := fmt.Sprintf("%s/v1/process", serverURL)
+				req, err := http.NewRequest(http.MethodPost, pendingEventsURL, bytes.NewReader([]byte(`
+					{
+						"WorkspaceID": "test_workspace_id",
+						"Schema": {
+							"test_table": {
+								"test_column": "test_data_type"
+							}
+						},
+						"BatchDestination": {
+							"Source": {
+								"ID": "test_source_id"
+							},
+							"Destination": {
+								"ID": "test_destination_id"
+							}
+						},
+						"Location": "rudder-warehouse-staging-logs/279L3gEKqwruBoKGsXZtSVX7vIy/2022-11-08/1667913810.279L3gEKqwruBoKGsXZtSVX7vIy.7a6e7785-7a75-4345-8d3c-d7a1ce49a43f.json.gz",
+						"FirstEventAt": "2022-11-08T13:23:07Z",
+						"LastEventAt": "2022-11-08T13:23:07Z",
+						"TotalEvents": 2,
+						"TotalBytes": 2000,
+						"UseRudderStorage": false,
+						"DestinationRevisionID": "2H1cLBvL3v0prRBNzpe8D34XTzU",
+						"SourceTaskRunID": "test_source_task_run_id",
+						"SourceJobID": "test_source_job_id",
+						"SourceJobRunID": "test_source_job_run_id",
+						"TimeWindow": "0001-01-01T00:40:00Z"
+					}
+				`)))
+				require.NoError(t, err)
+				req.Header.Set("Content-Type", "application/json")
+
+				resp, err := (&http.Client{}).Do(req)
+				require.NoError(t, err)
+				require.Equal(t, http.StatusOK, resp.StatusCode)
+
+				defer func() {
+					httputil.CloseResponse(resp)
+				}()
+			})
+
+			t.Run("pending events", func(t *testing.T) {
+				pendingEventsURL := fmt.Sprintf("%s/v1/warehouse/pending-events?triggerUpload=true", serverURL)
+				req, err := http.NewRequest(http.MethodPost, pendingEventsURL, bytes.NewReader([]byte(`
+					{
+						"source_id": "test_source_id",
+						"task_run_id": "test_source_task_run_id"
+					}
+				`)))
+				require.NoError(t, err)
+				req.Header.Set("Content-Type", "application/json")
+
+				resp, err := (&http.Client{}).Do(req)
+				require.NoError(t, err)
+				require.Equal(t, http.StatusOK, resp.StatusCode)
+
+				defer func() {
+					httputil.CloseResponse(resp)
+				}()
+			})
+
+			t.Run("trigger upload", func(t *testing.T) {
+				triggerUploadURL := fmt.Sprintf("%s/v1/warehouse/trigger-upload", serverURL)
+				req, err := http.NewRequest(http.MethodPost, triggerUploadURL, bytes.NewReader([]byte(`
+					{
+						"source_id": "test_source_id",
+						"destination_id": "test_destination_id"
+					}
+				`)))
+				require.NoError(t, err)
+				req.Header.Set("Content-Type", "application/json")
+
+				resp, err := (&http.Client{}).Do(req)
+				require.NoError(t, err)
+				require.Equal(t, http.StatusOK, resp.StatusCode)
+
+				defer func() {
+					httputil.CloseResponse(resp)
+				}()
+			})
+
+			t.Run("fetch tables", func(t *testing.T) {
+				for _, u := range []string{
+					fmt.Sprintf("%s/v1/warehouse/fetch-tables", serverURL),
+					fmt.Sprintf("%s/internal/v1/warehouse/fetch-tables", serverURL),
+				} {
+					req, err := http.NewRequest(http.MethodGet, u, bytes.NewReader([]byte(`
+						{
+							"connections": [
+								{
+									"source_id": "test_source_id",
+									"destination_id": "test_destination_id"
+								}
+							]
+						}
+					`)))
+					require.NoError(t, err)
+					req.Header.Set("Content-Type", "application/json")
+
+					resp, err := (&http.Client{}).Do(req)
+					require.NoError(t, err)
+					require.Equal(t, http.StatusOK, resp.StatusCode)
+
+					defer func() {
+						httputil.CloseResponse(resp)
+					}()
+				}
+			})
+
+			t.Run("jobs", func(t *testing.T) {
+				jobsURL := fmt.Sprintf("%s/v1/warehouse/jobs", serverURL)
+				req, err := http.NewRequest(http.MethodPost, jobsURL, bytes.NewReader([]byte(`
+					{
+						"source_id": "test_source_id",
+						"destination_id": "test_destination_id",
+						"job_run_id": "test_source_job_run_id",
+						"task_run_id": "test_source_task_run_id"
+					}
+				`)))
+				require.NoError(t, err)
+				req.Header.Set("Content-Type", "application/json")
+
+				resp, err := (&http.Client{}).Do(req)
+				require.NoError(t, err)
+				require.Equal(t, http.StatusOK, resp.StatusCode)
+
+				defer func() {
+					httputil.CloseResponse(resp)
+				}()
+			})
+
+			t.Run("jobs status", func(t *testing.T) {
+				qp := url.Values{}
+				qp.Add("task_run_id", sourceTaskRunID)
+				qp.Add("job_run_id", sourceJobRunID)
+				qp.Add("source_id", sourceID)
+				qp.Add("destination_id", destinationID)
+				qp.Add("workspace_id", workspaceID)
+
+				jobsStatusURL := fmt.Sprintf("%s/v1/warehouse/jobs/status?%s", serverURL, qp.Encode())
+				req, err := http.NewRequest(http.MethodGet, jobsStatusURL, nil)
+				require.NoError(t, err)
+				req.Header.Set("Content-Type", "application/json")
+
+				resp, err := (&http.Client{}).Do(req)
+				require.NoError(t, err)
+				require.Equal(t, http.StatusOK, resp.StatusCode)
+
+				defer func() {
+					httputil.CloseResponse(resp)
+				}()
+			})
+
+			stopServer()
+
+			<-serverSetupCh
+		})
+
+		t.Run("degraded mode", func(t *testing.T) {
+			webPort, err := kithelper.GetFreePort()
+			require.NoError(t, err)
+
+			c := config.New()
+			c.Set("Warehouse.webPort", webPort)
+			c.Set("Warehouse.runningMode", DegradedMode)
+
+			srvCtx, stopServer := context.WithCancel(ctx)
+
+			a := NewApi(config.MasterMode, c, logger.NOP, stats.Default, mockBackendConfig, db, &notifier, tenantManager, bcManager, jobsManager)
+
+			serverSetupCh := make(chan struct{})
+			go func() {
+				require.NoError(t, a.Start(srvCtx))
+
+				close(serverSetupCh)
+			}()
+
+			serverURL := fmt.Sprintf("http://localhost:%d", webPort)
+
+			t.Run("health endpoint should work", func(t *testing.T) {
+				require.Eventually(t, 
func() bool { + resp, err := http.Get(fmt.Sprintf("%s/health", serverURL)) + if err != nil { + return false + } + defer func() { + httputil.CloseResponse(resp) + }() + + return resp.StatusCode == http.StatusOK + }, + time.Second*10, + time.Second, + ) + }) + + t.Run("other endpoints should fail", func(t *testing.T) { + testCases := []struct { + name string + url string + method string + body io.Reader + }{ + { + name: "process", + url: fmt.Sprintf("%s/v1/process", serverURL), + method: http.MethodPost, + body: bytes.NewReader([]byte(`{}`)), + }, + { + name: "pending events", + url: fmt.Sprintf("%s/v1/warehouse/pending-events", serverURL), + method: http.MethodPost, + body: bytes.NewReader([]byte(`{}`)), + }, + { + name: "trigger upload", + url: fmt.Sprintf("%s/v1/warehouse/trigger-upload", serverURL), + method: http.MethodPost, + body: bytes.NewReader([]byte(`{}`)), + }, + { + name: "jobs", + url: fmt.Sprintf("%s/v1/warehouse/jobs", serverURL), + method: http.MethodPost, + body: bytes.NewReader([]byte(`{}`)), + }, + { + name: "jobs status", + url: fmt.Sprintf("%s/v1/warehouse/jobs/status", serverURL), + method: http.MethodGet, + body: nil, + }, + { + name: "fetch tables", + url: fmt.Sprintf("%s/v1/warehouse/fetch-tables", serverURL), + method: http.MethodGet, + body: nil, + }, + { + name: "internal fetch tables", + url: fmt.Sprintf("%s/internal/v1/warehouse/fetch-tables", serverURL), + method: http.MethodGet, + body: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req, err := http.NewRequest(tc.method, tc.url, tc.body) + require.NoError(t, err) + + resp, err := (&http.Client{}).Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusNotFound, resp.StatusCode) + + defer func() { + httputil.CloseResponse(resp) + }() + }) + } + }) + + stopServer() + + <-serverSetupCh + }) + }) + + stopTest() + + <-setupCh +} diff --git a/warehouse/internal/api/http.go b/warehouse/internal/api/http.go index e0b6580044..936fe4824f 100644 --- a/warehouse/internal/api/http.go +++ b/warehouse/internal/api/http.go @@ -2,10 +2,14 @@ package api import ( "context" + "errors" "fmt" "net/http" "time" + ierrors "github.com/rudderlabs/rudder-server/warehouse/internal/errors" + lf "github.com/rudderlabs/rudder-server/warehouse/logfield" + "github.com/go-chi/chi/v5" jsoniter "github.com/json-iterator/go" @@ -55,22 +59,22 @@ type stagingFileSchema struct { func mapStagingFile(payload *stagingFileSchema) (model.StagingFileWithSchema, error) { if payload.WorkspaceID == "" { - return model.StagingFileWithSchema{}, fmt.Errorf("workspaceId is required") + return model.StagingFileWithSchema{}, errors.New("workspaceId is required") } if payload.Location == "" { - return model.StagingFileWithSchema{}, fmt.Errorf("location is required") + return model.StagingFileWithSchema{}, errors.New("location is required") } if payload.BatchDestination.Source.ID == "" { - return model.StagingFileWithSchema{}, fmt.Errorf("batchDestination.source.id is required") + return model.StagingFileWithSchema{}, errors.New("batchDestination.source.id is required") } if payload.BatchDestination.Destination.ID == "" { - return model.StagingFileWithSchema{}, fmt.Errorf("batchDestination.destination.id is required") + return model.StagingFileWithSchema{}, errors.New("batchDestination.destination.id is required") } if len(payload.Schema) == 0 { - return model.StagingFileWithSchema{}, fmt.Errorf("schema is required") + return model.StagingFileWithSchema{}, errors.New("schema is required") } var schema []byte @@ 
-108,31 +112,35 @@ func (api *WarehouseAPI) Handler() http.Handler { } func (api *WarehouseAPI) processHandler(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - defer r.Body.Close() + defer func() { _ = r.Body.Close() }() var payload stagingFileSchema err := json.NewDecoder(r.Body).Decode(&payload) if err != nil { - api.Logger.Errorf("Error parsing body: %v", err) - http.Error(w, "can't unmarshal body", http.StatusBadRequest) + api.Logger.Warnw("invalid JSON in request body for processing staging file", lf.Error, err.Error()) + http.Error(w, ierrors.ErrInvalidJSONRequestBody.Error(), http.StatusBadRequest) return } stagingFile, err := mapStagingFile(&payload) if err != nil { - api.Logger.Warnf("invalid payload: %v", err) + api.Logger.Warnw("invalid payload for processing staging file", lf.Error, err.Error()) http.Error(w, fmt.Sprintf("invalid payload: %s", err.Error()), http.StatusBadRequest) return } if api.Multitenant.DegradedWorkspace(stagingFile.WorkspaceID) { - http.Error(w, "Workspace is degraded", http.StatusServiceUnavailable) + api.Logger.Infow("workspace is degraded for processing staging file", lf.WorkspaceID, stagingFile.WorkspaceID) + http.Error(w, ierrors.ErrWorkspaceDegraded.Error(), http.StatusServiceUnavailable) return } - if _, err := api.Repo.Insert(ctx, &stagingFile); err != nil { - api.Logger.Errorf("Error inserting staging file: %v", err) + if _, err := api.Repo.Insert(r.Context(), &stagingFile); err != nil { + if errors.Is(r.Context().Err(), context.Canceled) { + http.Error(w, ierrors.ErrRequestCancelled.Error(), http.StatusBadRequest) + return + } + api.Logger.Errorw("inserting staging file", lf.Error, err.Error()) http.Error(w, "can't insert staging file", http.StatusInternalServerError) return } diff --git a/warehouse/internal/api/http_test.go b/warehouse/internal/api/http_test.go index 1785237e74..352347baa7 100644 --- a/warehouse/internal/api/http_test.go +++ b/warehouse/internal/api/http_test.go @@ -13,6 +13,9 @@ import ( "testing" "time" + "github.com/rudderlabs/rudder-go-kit/config" + backendConfig "github.com/rudderlabs/rudder-server/backend-config" + "github.com/stretchr/testify/require" "github.com/rudderlabs/rudder-go-kit/logger" @@ -115,13 +118,13 @@ func TestAPI_Process(t *testing.T) { degradedWorkspaceIDs: []string{"279L3V7FSpx43LaNJ0nIs9KRaNC"}, respCode: http.StatusServiceUnavailable, - respBody: "Workspace is degraded\n", + respBody: "workspace is degraded\n", }, { name: "invalid request body missing", respCode: http.StatusBadRequest, - respBody: "can't unmarshal body\n", + respBody: "invalid JSON in request body\n", }, { name: "invalid request workspace id missing", @@ -166,9 +169,10 @@ func TestAPI_Process(t *testing.T) { err: tc.storeErr, } - m := &multitenant.Manager{ - DegradedWorkspaceIDs: tc.degradedWorkspaceIDs, - } + c := config.New() + c.Set("Warehouse.degradedWorkspaceIDs", tc.degradedWorkspaceIDs) + + m := multitenant.New(c, backendConfig.DefaultBackendConfig) wAPI := api.WarehouseAPI{ Repo: r, diff --git a/warehouse/internal/errors/errors.go b/warehouse/internal/errors/errors.go new file mode 100644 index 0000000000..8edf2d57e8 --- /dev/null +++ b/warehouse/internal/errors/errors.go @@ -0,0 +1,14 @@ +package errors + +import "errors" + +var ( + ErrInvalidJSONRequestBody = errors.New("invalid JSON in request body") + ErrRequestCancelled = errors.New("request cancelled") + ErrWorkspaceDegraded = errors.New("workspace is degraded") + ErrNoWarehouseFound = errors.New("no warehouse found") + ErrWorkspaceFromSourceNotFound = 
errors.New("workspace from source not found") + ErrMarshallResponse = errors.New("can't marshall response") + ErrInvalidRequest = errors.New("invalid request") + ErrJobsApiNotInitialized = errors.New("warehouse jobs api not initialized") +) diff --git a/warehouse/jobs/handlers.go b/warehouse/jobs/handlers.go deleted file mode 100644 index 1cd5d72eba..0000000000 --- a/warehouse/jobs/handlers.go +++ /dev/null @@ -1,143 +0,0 @@ -/* - Warehouse jobs package provides the capability to run arbitrary jobs on the warehouses using the query parameters provided. - Some jobs that can be run are - 1) delete by task run id, - 2) delete by job run id, - 3) delete by update_at - 4) any other update / clean up operations - - The following handlers file is the entry point for the handlers. -*/ - -package jobs - -import ( - "encoding/json" - "io" - "net/http" - "strings" -) - -// AddWarehouseJobHandler The following handler gets called for adding async -func (a *AsyncJobWh) AddWarehouseJobHandler(w http.ResponseWriter, r *http.Request) { - a.logger.Info("[WH-Jobs] Got Async Job Add Request") - if !a.enabled { - a.logger.Errorf("[WH-Jobs]: Error Warehouse Jobs API not initialized") - http.Error(w, "warehouse jobs api not initialized", http.StatusBadRequest) - return - } - body, err := io.ReadAll(r.Body) - if err != nil { - a.logger.LogRequest(r) - a.logger.Errorf("[WH-Jobs]: Error reading body: %v", err) - http.Error(w, "can't read body", http.StatusBadRequest) - return - } - _ = r.Body.Close() - var startJobPayload StartJobReqPayload - err = json.Unmarshal(body, &startJobPayload) - if err != nil { - a.logger.LogRequest(r) - a.logger.Errorf("[WH-Jobs]: Error unmarshalling body: %v", err) - http.Error(w, "can't unmarshall body", http.StatusBadRequest) - return - } - if !validatePayload(startJobPayload) { - a.logger.LogRequest(r) - a.logger.Errorf("[WH-Jobs]: Invalid Payload") - http.Error(w, "invalid Payload", http.StatusBadRequest) - return - } - tableNames, err := a.getTableNamesBy(startJobPayload.SourceID, startJobPayload.DestinationID, startJobPayload.JobRunID, startJobPayload.TaskRunID) - if err != nil { - a.logger.LogRequest(r) - a.logger.Errorf("[WH-Jobs]: Error extracting tableNames for the job run id: %v", err) - http.Error(w, "Error extracting tableNames", http.StatusBadRequest) - return - } - - var jobIds []int64 - // Add to wh_async_job queue each of the tables - for _, table := range tableNames { - - switch strings.ToLower(table) { - case "rudder_discards", "rudder_identity_mappings", "rudder_identity_merge_rules": - continue - } - - jobsMetaData := WhJobsMetaData{ - JobRunID: startJobPayload.JobRunID, - TaskRunID: startJobPayload.TaskRunID, - StartTime: startJobPayload.StartTime, - JobType: AsyncJobType, - } - metadataJson, err := json.Marshal(jobsMetaData) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - payload := AsyncJobPayload{ - SourceID: startJobPayload.SourceID, - DestinationID: startJobPayload.DestinationID, - TableName: table, - AsyncJobType: startJobPayload.AsyncJobType, - MetaData: metadataJson, - WorkspaceID: startJobPayload.WorkspaceID, - } - id, err := a.addJobsToDB(&payload) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - jobIds = append(jobIds, id) - } - whAddJobResponse := WhAddJobResponse{ - JobIds: jobIds, - Err: nil, - } - response, err := json.Marshal(whAddJobResponse) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - _, _ = w.Write(response) -} - -func (a 
*AsyncJobWh) StatusWarehouseJobHandler(w http.ResponseWriter, r *http.Request) { - a.logger.Info("[WH-Status Handler] Got Async Job Status Request") - if !a.enabled { - a.logger.Errorf("[WH]: Error Warehouse Jobs API not initialized") - http.Error(w, "warehouse jobs api not initialized", http.StatusBadRequest) - return - } - jobRunId := r.URL.Query().Get("job_run_id") - taskRunId := r.URL.Query().Get("task_run_id") - - sourceId := r.URL.Query().Get("source_id") - destinationId := r.URL.Query().Get("destination_id") - workspaceId := r.URL.Query().Get("workspace_id") - payload := StartJobReqPayload{ - TaskRunID: taskRunId, - JobRunID: jobRunId, - SourceID: sourceId, - DestinationID: destinationId, - WorkspaceID: workspaceId, - } - if !validatePayload(payload) { - a.logger.LogRequest(r) - a.logger.Errorf("[WH]: Error Invalid Payload") - http.Error(w, "invalid request", http.StatusBadRequest) - return - } - a.logger.Infof("Got Payload job_run_id %s, task_run_id %s \n", payload.JobRunID, payload.TaskRunID) - - response := a.getStatusAsyncJob(&payload) - - writeResponse, err := json.Marshal(response) - if err != nil { - a.logger.LogRequest(r) - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - _, _ = w.Write(writeResponse) -} diff --git a/warehouse/jobs/http.go b/warehouse/jobs/http.go new file mode 100644 index 0000000000..85007e68dc --- /dev/null +++ b/warehouse/jobs/http.go @@ -0,0 +1,155 @@ +package jobs + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + + ierrors "github.com/rudderlabs/rudder-server/warehouse/internal/errors" + lf "github.com/rudderlabs/rudder-server/warehouse/logfield" + + "github.com/samber/lo" +) + +type insertJobResponse struct { + JobIds []int64 `json:"jobids"` + Err error `json:"error"` +} + +// InsertJobHandler adds a job to the warehouse_jobs table +func (a *AsyncJobWh) InsertJobHandler(w http.ResponseWriter, r *http.Request) { + defer func() { _ = r.Body.Close() }() + + if !a.enabled { + a.logger.Errorw("jobs api not initialized for inserting async job") + http.Error(w, ierrors.ErrJobsApiNotInitialized.Error(), http.StatusInternalServerError) + return + } + + var payload StartJobReqPayload + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + a.logger.Warnw("invalid JSON in request body for inserting async jobs", lf.Error, err.Error()) + http.Error(w, ierrors.ErrInvalidJSONRequestBody.Error(), http.StatusBadRequest) + return + } + + if err := validatePayload(&payload); err != nil { + a.logger.Warnw("invalid payload for inserting async job", lf.Error, err.Error()) + http.Error(w, fmt.Sprintf("invalid payload: %s", err.Error()), http.StatusBadRequest) + return + } + + // TODO: Move to repository + tableNames, err := a.tableNamesBy(payload.SourceID, payload.DestinationID, payload.JobRunID, payload.TaskRunID) + if err != nil { + a.logger.Errorw("extracting tableNames for inserting async job", lf.Error, err.Error()) + http.Error(w, "can't extract tableNames", http.StatusInternalServerError) + return + } + + tableNames = lo.Filter(tableNames, func(tableName string, i int) bool { + switch strings.ToLower(tableName) { + case "rudder_discards", "rudder_identity_mappings", "rudder_identity_merge_rules": + return false + default: + return true + } + }) + + jobIds := make([]int64, 0, len(tableNames)) + for _, table := range tableNames { + metadataJson, err := json.Marshal(WhJobsMetaData{ + JobRunID: payload.JobRunID, + TaskRunID: payload.TaskRunID, + StartTime: payload.StartTime, + JobType: AsyncJobType, + }) + if err != 
nil {
+			a.logger.Errorw("marshalling metadata for inserting async job", lf.Error, err.Error())
+			http.Error(w, "can't marshall metadata", http.StatusInternalServerError)
+			return
+		}
+
+		// TODO: Move to repository
+		id, err := a.addJobsToDB(&AsyncJobPayload{
+			SourceID:      payload.SourceID,
+			DestinationID: payload.DestinationID,
+			TableName:     table,
+			AsyncJobType:  payload.AsyncJobType,
+			MetaData:      metadataJson,
+			WorkspaceID:   payload.WorkspaceID,
+		})
+		if err != nil {
+			a.logger.Errorw("inserting async job", lf.Error, err.Error())
+			http.Error(w, "can't insert async job", http.StatusInternalServerError)
+			return
+		}
+
+		jobIds = append(jobIds, id)
+	}
+
+	resBody, err := json.Marshal(insertJobResponse{
+		JobIds: jobIds,
+		Err:    nil,
+	})
+	if err != nil {
+		a.logger.Errorw("marshalling response for inserting async job", lf.Error, err.Error())
+		http.Error(w, ierrors.ErrMarshallResponse.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	_, _ = w.Write(resBody)
+}
+
+// StatusJobHandler returns the status of an async job.
+func (a *AsyncJobWh) StatusJobHandler(w http.ResponseWriter, r *http.Request) {
+	defer func() { _ = r.Body.Close() }()
+
+	if !a.enabled {
+		a.logger.Errorw("jobs api not initialized for async job status")
+		http.Error(w, ierrors.ErrJobsApiNotInitialized.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	queryParams := r.URL.Query()
+	payload := StartJobReqPayload{
+		TaskRunID:     queryParams.Get("task_run_id"),
+		JobRunID:      queryParams.Get("job_run_id"),
+		SourceID:      queryParams.Get("source_id"),
+		DestinationID: queryParams.Get("destination_id"),
+		WorkspaceID:   queryParams.Get("workspace_id"),
+	}
+	if err := validatePayload(&payload); err != nil {
+		a.logger.Warnw("invalid payload for async job status", lf.Error, err.Error())
+		http.Error(w, fmt.Sprintf("invalid request: %s", err.Error()), http.StatusBadRequest)
+		return
+	}
+
+	// TODO: Move to repository
+	jobStatus := a.jobStatus(&payload)
+	resBody, err := json.Marshal(jobStatus)
+	if err != nil {
+		a.logger.Errorw("marshalling response for async job status", lf.Error, err.Error())
+		http.Error(w, ierrors.ErrMarshallResponse.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	_, _ = w.Write(resBody)
+}
+
+func validatePayload(payload *StartJobReqPayload) error {
+	switch {
+	case payload.SourceID == "":
+		return errors.New("source_id is required")
+	case payload.DestinationID == "":
+		return errors.New("destination_id is required")
+	case payload.JobRunID == "":
+		return errors.New("job_run_id is required")
+	case payload.TaskRunID == "":
+		return errors.New("task_run_id is required")
+	default:
+		return nil
+	}
+}
diff --git a/warehouse/jobs/http_test.go b/warehouse/jobs/http_test.go
new file mode 100644
index 0000000000..966858cf2e
--- /dev/null
+++ b/warehouse/jobs/http_test.go
@@ -0,0 +1,354 @@
+package jobs
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"testing"
+	"time"
+
+	"github.com/ory/dockertest/v3"
+
+	"github.com/rudderlabs/rudder-go-kit/logger"
+	"github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource"
+	"github.com/rudderlabs/rudder-server/services/pgnotifier"
+	migrator "github.com/rudderlabs/rudder-server/services/sql-migrator"
+	sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper"
+	"github.com/rudderlabs/rudder-server/warehouse/internal/model"
+	
"github.com/rudderlabs/rudder-server/warehouse/internal/repo" + whutils "github.com/rudderlabs/rudder-server/warehouse/utils" + + "github.com/stretchr/testify/require" +) + +func TestAsyncJobHandlers(t *testing.T) { + pgnotifier.Init() + + const ( + workspaceID = "test_workspace_id" + sourceID = "test_source_id" + destinationID = "test_destination_id" + workspaceIdentifier = "test_workspace-identifier" + namespace = "test_namespace" + destinationType = "test_destination_type" + sourceTaskRunID = "test_source_task_run_id" + sourceJobID = "test_source_job_id" + sourceJobRunID = "test_source_job_run_id" + ) + + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + pgResource, err := resource.SetupPostgres(pool, t) + require.NoError(t, err) + + t.Log("db:", pgResource.DBDsn) + + err = (&migrator.Migrator{ + Handle: pgResource.DB, + MigrationsTable: "wh_schema_migrations", + }).Migrate("warehouse") + require.NoError(t, err) + + db := sqlmiddleware.New(pgResource.DB) + + notifier, err := pgnotifier.New(workspaceIdentifier, pgResource.DBDsn) + require.NoError(t, err) + + ctx := context.Background() + + now := time.Now().Truncate(time.Second).UTC() + + uploadsRepo := repo.NewUploads(db, repo.WithNow(func() time.Time { + return now + })) + tableUploadsRepo := repo.NewTableUploads(db, repo.WithNow(func() time.Time { + return now + })) + stagingRepo := repo.NewStagingFiles(db, repo.WithNow(func() time.Time { + return now + })) + + stagingFile := model.StagingFile{ + WorkspaceID: workspaceID, + Location: "s3://bucket/path/to/file", + SourceID: sourceID, + DestinationID: destinationID, + Status: whutils.StagingFileWaitingState, + Error: fmt.Errorf("dummy error"), + FirstEventAt: now.Add(time.Second), + UseRudderStorage: true, + DestinationRevisionID: "destination_revision_id", + TotalEvents: 100, + SourceTaskRunID: sourceTaskRunID, + SourceJobID: sourceJobID, + SourceJobRunID: sourceJobRunID, + TimeWindow: time.Date(1993, 8, 1, 3, 0, 0, 0, time.UTC), + }.WithSchema([]byte(`{"type": "object"}`)) + + stagingID, err := stagingRepo.Insert(ctx, &stagingFile) + require.NoError(t, err) + + uploadID, err := uploadsRepo.CreateWithStagingFiles(ctx, model.Upload{ + WorkspaceID: workspaceID, + Namespace: namespace, + SourceID: sourceID, + DestinationID: destinationID, + DestinationType: destinationType, + Status: model.Aborted, + SourceJobRunID: sourceJobRunID, + SourceTaskRunID: sourceTaskRunID, + }, []*model.StagingFile{{ + ID: stagingID, + SourceID: sourceID, + DestinationID: destinationID, + SourceJobRunID: sourceJobRunID, + SourceTaskRunID: sourceTaskRunID, + }}) + require.NoError(t, err) + + err = tableUploadsRepo.Insert(ctx, uploadID, []string{ + "test_table_1", + "test_table_2", + "test_table_3", + "test_table_4", + "test_table_5", + + "rudder_discards", + "rudder_identity_mappings", + "rudder_identity_merge_rules", + }) + require.NoError(t, err) + + t.Run("validate payload", func(t *testing.T) { + testCases := []struct { + name string + payload StartJobReqPayload + expectedError error + }{ + { + name: "invalid source", + payload: StartJobReqPayload{ + JobRunID: "job_run_id", + TaskRunID: "task_run_id", + SourceID: "", + DestinationID: "destination_id", + WorkspaceID: "workspace_id", + }, + expectedError: errors.New("source_id is required"), + }, + { + name: "invalid destination", + payload: StartJobReqPayload{ + JobRunID: "job_run_id", + TaskRunID: "task_run_id", + SourceID: "source_id", + DestinationID: "", + WorkspaceID: "workspace_id", + }, + expectedError: errors.New("destination_id is 
required"), + }, + { + name: "invalid task run", + payload: StartJobReqPayload{ + JobRunID: "job_run_id", + TaskRunID: "", + SourceID: "source_id", + DestinationID: "destination_id", + WorkspaceID: "workspace_id", + }, + expectedError: errors.New("task_run_id is required"), + }, + { + name: "invalid job run", + payload: StartJobReqPayload{ + JobRunID: "", + TaskRunID: "task_run_id", + SourceID: "source_id", + DestinationID: "destination_id", + WorkspaceID: "workspace_id", + }, + expectedError: errors.New("job_run_id is required"), + }, + { + name: "valid payload", + payload: StartJobReqPayload{ + JobRunID: "job_run_id", + TaskRunID: "task_run_id", + SourceID: "source_id", + DestinationID: "destination_id", + WorkspaceID: "workspace_id", + }, + }, + } + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + require.Equal(t, tc.expectedError, validatePayload(&tc.payload)) + }) + } + }) + + t.Run("InsertJobHandler", func(t *testing.T) { + t.Run("Not enabled", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/jobs", nil) + resp := httptest.NewRecorder() + + jobsManager := AsyncJobWh{ + dbHandle: db.DB, + enabled: false, + logger: logger.NOP, + context: ctx, + pgnotifier: ¬ifier, + } + jobsManager.InsertJobHandler(resp, req) + require.Equal(t, http.StatusInternalServerError, resp.Code) + + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "warehouse jobs api not initialized\n", string(b)) + }) + t.Run("invalid payload", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/jobs", bytes.NewReader([]byte(`"Invalid payload"`))) + resp := httptest.NewRecorder() + + jobsManager := AsyncJobWh{ + dbHandle: db.DB, + enabled: true, + logger: logger.NOP, + context: ctx, + pgnotifier: ¬ifier, + } + jobsManager.InsertJobHandler(resp, req) + require.Equal(t, http.StatusBadRequest, resp.Code) + + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "invalid JSON in request body\n", string(b)) + }) + t.Run("invalid request", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/jobs", bytes.NewReader([]byte(`{}`))) + resp := httptest.NewRecorder() + + jobsManager := AsyncJobWh{ + dbHandle: db.DB, + enabled: true, + logger: logger.NOP, + context: ctx, + pgnotifier: ¬ifier, + } + jobsManager.InsertJobHandler(resp, req) + require.Equal(t, http.StatusBadRequest, resp.Code) + + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "invalid payload: source_id is required\n", string(b)) + }) + t.Run("success", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/jobs", bytes.NewReader([]byte(` + { + "source_id": "test_source_id", + "destination_id": "test_destination_id", + "job_run_id": "test_source_job_run_id", + "task_run_id": "test_source_task_run_id" + } + `))) + resp := httptest.NewRecorder() + + jobsManager := AsyncJobWh{ + dbHandle: db.DB, + enabled: true, + logger: logger.NOP, + context: ctx, + pgnotifier: ¬ifier, + } + jobsManager.InsertJobHandler(resp, req) + require.Equal(t, http.StatusOK, resp.Code) + + var insertResponse insertJobResponse + err = json.NewDecoder(resp.Body).Decode(&insertResponse) + require.NoError(t, err) + require.Nil(t, insertResponse.Err) + require.Len(t, insertResponse.JobIds, 5) + }) + }) + + t.Run("StatusJobHandler", func(t *testing.T) { + t.Run("Not enabled", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/warehouse/jobs/status", nil) + 
resp := httptest.NewRecorder() + + jobsManager := AsyncJobWh{ + dbHandle: db.DB, + enabled: false, + logger: logger.NOP, + context: ctx, + pgnotifier: ¬ifier, + } + jobsManager.StatusJobHandler(resp, req) + require.Equal(t, http.StatusInternalServerError, resp.Code) + + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "warehouse jobs api not initialized\n", string(b)) + }) + t.Run("invalid payload", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v1/warehouse/jobs/status", nil) + resp := httptest.NewRecorder() + + jobsManager := AsyncJobWh{ + dbHandle: db.DB, + enabled: true, + logger: logger.NOP, + context: ctx, + pgnotifier: ¬ifier, + } + jobsManager.StatusJobHandler(resp, req) + require.Equal(t, http.StatusBadRequest, resp.Code) + + b, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "invalid request: source_id is required\n", string(b)) + }) + t.Run("success", func(t *testing.T) { + _, err := db.ExecContext(ctx, ` + INSERT INTO `+whutils.WarehouseAsyncJobTable+` (source_id, destination_id, status, created_at, updated_at, tablename, error, async_job_type, metadata, workspace_id) + VALUES ('test_source_id', 'test_destination_id', 'aborted', NOW(), NOW(), 'test_table_name', 'test_error', 'deletebyjobrunid', '{"job_run_id": "test_source_job_run_id", "task_run_id": "test_source_task_run_id"}', 'test_workspace_id') + `) + require.NoError(t, err) + + qp := url.Values{} + qp.Add("task_run_id", sourceTaskRunID) + qp.Add("job_run_id", sourceJobRunID) + qp.Add("source_id", sourceID) + qp.Add("destination_id", destinationID) + qp.Add("workspace_id", workspaceID) + + req := httptest.NewRequest(http.MethodGet, "/v1/warehouse/jobs/status?"+qp.Encode(), nil) + resp := httptest.NewRecorder() + + jobsManager := AsyncJobWh{ + dbHandle: db.DB, + enabled: true, + logger: logger.NOP, + context: ctx, + pgnotifier: ¬ifier, + } + jobsManager.StatusJobHandler(resp, req) + require.Equal(t, http.StatusOK, resp.Code) + + var statusResponse WhStatusResponse + err = json.NewDecoder(resp.Body).Decode(&statusResponse) + require.NoError(t, err) + require.Equal(t, statusResponse.Status, "aborted") + require.Equal(t, statusResponse.Err, "test_error") + }) + }) +} diff --git a/warehouse/jobs/runner.go b/warehouse/jobs/runner.go index 22cb7046b5..df5b001696 100644 --- a/warehouse/jobs/runner.go +++ b/warehouse/jobs/runner.go @@ -43,7 +43,7 @@ func WithConfig(a *AsyncJobWh, config *config.Config) { a.asyncJobTimeOut = config.GetDuration("Warehouse.jobs.asyncJobTimeOut", 300, time.Second) } -func (a *AsyncJobWh) getTableNamesBy(sourceID, destinationID, jobRunID, taskRunID string) ([]string, error) { +func (a *AsyncJobWh) tableNamesBy(sourceID, destinationID, jobRunID, taskRunID string) ([]string, error) { a.logger.Infof("[WH-Jobs]: Extracting table names for the job run id %s", jobRunID) var tableNames []string var err error @@ -348,7 +348,7 @@ func (a *AsyncJobWh) updateAsyncJobAttempt(ctx context.Context, Id string) error // returns status and errMessage // Only succeeded, executing & waiting states should have empty errMessage // Rest of the states failed, aborted should send an error message conveying a message -func (a *AsyncJobWh) getStatusAsyncJob(payload *StartJobReqPayload) WhStatusResponse { +func (a *AsyncJobWh) jobStatus(payload *StartJobReqPayload) WhStatusResponse { var statusResponse WhStatusResponse a.logger.Info("[WH-Jobs]: Getting status for wh async jobs %v", payload) // Need to check for count first and see if there are any rows 
matching the job_run_id and task_run_id. If none, then raise an error instead of showing complete diff --git a/warehouse/jobs/types.go b/warehouse/jobs/types.go index 8f02a27100..10313fab5e 100644 --- a/warehouse/jobs/types.go +++ b/warehouse/jobs/types.go @@ -68,11 +68,6 @@ type PGNotifierOutput struct { Id string `json:"id"` } -type WhAddJobResponse struct { - JobIds []int64 `json:"jobids"` - Err error `json:"error"` -} - type WhStatusResponse struct { Status string Err string diff --git a/warehouse/jobs/utils.go b/warehouse/jobs/utils.go index a822563659..36635db882 100644 --- a/warehouse/jobs/utils.go +++ b/warehouse/jobs/utils.go @@ -31,13 +31,6 @@ func getMessagePayloadsFromAsyncJobPayloads(asyncJobPayloads []AsyncJobPayload) return messages, nil } -func validatePayload(payload StartJobReqPayload) bool { - if payload.SourceID == "" || payload.JobRunID == "" || payload.TaskRunID == "" || payload.DestinationID == "" { - return false - } - return true -} - func getAsyncStatusMapFromAsyncPayloads(payloads []AsyncJobPayload) map[string]AsyncJobStatus { asyncJobStatusMap := make(map[string]AsyncJobStatus) for _, payload := range payloads { diff --git a/warehouse/jobs/utils_test.go b/warehouse/jobs/utils_test.go deleted file mode 100644 index b3f3fa582a..0000000000 --- a/warehouse/jobs/utils_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package jobs - -import ( - "testing" -) - -func TestValidatePayload(t *testing.T) { - payloadTests := []struct { - payload StartJobReqPayload - expected bool - }{ - { - StartJobReqPayload{ - JobRunID: "", - TaskRunID: "", - }, - false, - }, - { - StartJobReqPayload{ - JobRunID: "abc", - TaskRunID: "bbc", - SourceID: "cbc", - DestinationID: "dbc", - WorkspaceID: "ebc", - }, - true, - }, - } - for _, tt := range payloadTests { - output := validatePayload(tt.payload) - if output != tt.expected { - t.Errorf("error in function validatepayload, expected %t and got %t", tt.expected, output) - } - } -} diff --git a/warehouse/logfield/logfield.go b/warehouse/logfield/logfield.go index b7d7cb037d..4bea8771da 100644 --- a/warehouse/logfield/logfield.go +++ b/warehouse/logfield/logfield.go @@ -4,6 +4,7 @@ const ( UploadJobID = "uploadJobID" UploadStatus = "uploadStatus" UseRudderStorage = "useRudderStorage" + TaskRunID = "taskRunID" SourceID = "sourceID" SourceType = "sourceType" DestinationID = "destinationID" diff --git a/warehouse/mode.go b/warehouse/mode.go new file mode 100644 index 0000000000..4ba17e2434 --- /dev/null +++ b/warehouse/mode.go @@ -0,0 +1,34 @@ +package warehouse + +import "github.com/rudderlabs/rudder-go-kit/config" + +func isStandAlone(mode string) bool { + switch mode { + case config.EmbeddedMode, config.EmbeddedMasterMode: + return false + default: + return true + } +} + +func isMaster(mode string) bool { + switch mode { + case config.MasterMode, config.MasterSlaveMode, config.EmbeddedMode, config.EmbeddedMasterMode: + return true + default: + return false + } +} + +func isSlave(mode string) bool { + switch mode { + case config.SlaveMode, config.MasterSlaveMode, config.EmbeddedMode: + return true + default: + return false + } +} + +func isStandAloneSlave(mode string) bool { + return mode == config.SlaveMode +} diff --git a/warehouse/mode_test.go b/warehouse/mode_test.go new file mode 100644 index 0000000000..f4e93dfc04 --- /dev/null +++ b/warehouse/mode_test.go @@ -0,0 +1,157 @@ +package warehouse + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/rudder-go-kit/config" +) + +func TestIsStandAlone(t *testing.T) 
{ + testCases := []struct { + name string + isStandAlone bool + }{ + { + name: config.EmbeddedMode, + isStandAlone: false, + }, + { + name: config.EmbeddedMasterMode, + isStandAlone: false, + }, + { + name: config.MasterMode, + isStandAlone: true, + }, + { + name: config.MasterSlaveMode, + isStandAlone: true, + }, + { + name: config.SlaveMode, + isStandAlone: true, + }, + { + name: config.OffMode, + isStandAlone: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + require.Equal(t, isStandAlone(tc.name), tc.isStandAlone) + }) + } +} + +func TestIsMaster(t *testing.T) { + testCases := []struct { + name string + isMaster bool + }{ + { + name: config.EmbeddedMode, + isMaster: true, + }, + { + name: config.EmbeddedMasterMode, + isMaster: true, + }, + { + name: config.MasterMode, + isMaster: true, + }, + { + name: config.MasterSlaveMode, + isMaster: true, + }, + { + name: config.SlaveMode, + isMaster: false, + }, + { + name: config.OffMode, + isMaster: false, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + require.Equal(t, isMaster(tc.name), tc.isMaster) + }) + } +} + +func TestIsSlave(t *testing.T) { + testCases := []struct { + name string + isSlave bool + }{ + { + name: config.EmbeddedMode, + isSlave: true, + }, + { + name: config.EmbeddedMasterMode, + isSlave: false, + }, + { + name: config.MasterMode, + isSlave: false, + }, + { + name: config.MasterSlaveMode, + isSlave: true, + }, + { + name: config.SlaveMode, + isSlave: true, + }, + { + name: config.OffMode, + isSlave: false, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + require.Equal(t, isSlave(tc.name), tc.isSlave) + }) + } +} + +func TestIsStandAloneSlave(t *testing.T) { + testCases := []struct { + name string + isStandAloneSlave bool + }{ + { + name: config.EmbeddedMode, + isStandAloneSlave: false, + }, + { + name: config.EmbeddedMasterMode, + isStandAloneSlave: false, + }, + { + name: config.MasterMode, + isStandAloneSlave: false, + }, + { + name: config.MasterSlaveMode, + isStandAloneSlave: false, + }, + { + name: config.SlaveMode, + isStandAloneSlave: true, + }, + { + name: config.OffMode, + isStandAloneSlave: false, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + require.Equal(t, isStandAloneSlave(tc.name), tc.isStandAloneSlave) + }) + } +} diff --git a/warehouse/multitenant/manager.go b/warehouse/multitenant/manager.go index 0ac6d81bcd..32472215b0 100644 --- a/warehouse/multitenant/manager.go +++ b/warehouse/multitenant/manager.go @@ -9,15 +9,9 @@ import ( backendconfig "github.com/rudderlabs/rudder-server/backend-config" ) -var degradedWorkspaceIDs []string - -func init() { - config.RegisterStringSliceConfigVariable(nil, °radedWorkspaceIDs, false, "Warehouse.degradedWorkspaceIDs") -} - type Manager struct { - BackendConfig backendconfig.BackendConfig - DegradedWorkspaceIDs []string + backendConfig backendconfig.BackendConfig + degradedWorkspaceIDs []string sourceIDToWorkspaceID map[string]string excludeWorkspaceIDMap map[string]struct{} @@ -28,16 +22,20 @@ type Manager struct { initOnce sync.Once } +func New(conf *config.Config, bcConfig backendconfig.BackendConfig) *Manager { + m := &Manager{} + m.backendConfig = bcConfig + m.degradedWorkspaceIDs = conf.GetStringSlice("Warehouse.degradedWorkspaceIDs", nil) + + return m +} + func (m *Manager) init() { m.initOnce.Do(func() { - if m.DegradedWorkspaceIDs == nil { - m.DegradedWorkspaceIDs = degradedWorkspaceIDs - } - m.sourceIDToWorkspaceID = 
make(map[string]string) m.excludeWorkspaceIDMap = make(map[string]struct{}) - for _, workspaceID := range m.DegradedWorkspaceIDs { + for _, workspaceID := range m.degradedWorkspaceIDs { m.excludeWorkspaceIDMap[workspaceID] = struct{}{} } m.ready = make(chan struct{}) @@ -48,7 +46,7 @@ func (m *Manager) init() { func (m *Manager) Run(ctx context.Context) { m.init() - chIn := m.BackendConfig.Subscribe(ctx, backendconfig.TopicBackendConfig) + chIn := m.backendConfig.Subscribe(ctx, backendconfig.TopicBackendConfig) for data := range chIn { m.sourceMu.Lock() config := data.Data.(map[string]backendconfig.ConfigT) @@ -76,7 +74,7 @@ func (m *Manager) DegradedWorkspace(workspaceID string) bool { func (m *Manager) DegradedWorkspaces() []string { m.init() - return m.DegradedWorkspaceIDs + return m.degradedWorkspaceIDs } // SourceToWorkspace returns the workspaceID for a given sourceID, even if workspaceID is degraded. @@ -109,7 +107,7 @@ func (m *Manager) SourceToWorkspace(ctx context.Context, sourceID string) (strin func (m *Manager) WatchConfig(ctx context.Context) <-chan map[string]backendconfig.ConfigT { m.init() - chIn := m.BackendConfig.Subscribe(ctx, backendconfig.TopicBackendConfig) + chIn := m.backendConfig.Subscribe(ctx, backendconfig.TopicBackendConfig) chOut := make(chan map[string]backendconfig.ConfigT) diff --git a/warehouse/multitenant/manager_test.go b/warehouse/multitenant/manager_test.go index 96e4a7722c..6dfc6697a3 100644 --- a/warehouse/multitenant/manager_test.go +++ b/warehouse/multitenant/manager_test.go @@ -4,6 +4,8 @@ import ( "context" "testing" + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" @@ -92,12 +94,13 @@ func TestDegradeWorkspace(t *testing.T) { WorkspaceID: workspace, } } - m := multitenant.Manager{ - BackendConfig: &mockBackendConfig{ - config: backendConfig, - }, - DegradedWorkspaceIDs: tc.degradedWorkspaces, - } + + c := config.New() + c.Set("Warehouse.degradedWorkspaceIDs", tc.degradedWorkspaces) + + m := multitenant.New(c, &mockBackendConfig{ + config: backendConfig, + }) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -147,11 +150,9 @@ func TestSourceToWorkspace(t *testing.T) { backendConfig[workspace] = entry } - m := multitenant.Manager{ - BackendConfig: &mockBackendConfig{ - config: backendConfig, - }, - } + m := multitenant.New(config.Default, &mockBackendConfig{ + config: backendConfig, + }) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -176,11 +177,9 @@ func TestSourceToWorkspace(t *testing.T) { require.NoError(t, g.Wait()) t.Run("context canceled", func(t *testing.T) { - m := multitenant.Manager{ - BackendConfig: &mockBackendConfig{ - config: backendConfig, - }, - } + m := multitenant.New(config.Default, &mockBackendConfig{ + config: backendConfig, + }) ctx, cancel := context.WithCancel(context.Background()) cancel() diff --git a/warehouse/router_test.go b/warehouse/router_test.go index 9772b2b3c2..41b876a94e 100644 --- a/warehouse/router_test.go +++ b/warehouse/router_test.go @@ -97,9 +97,9 @@ func TestRouter(t *testing.T) { notifier, err := pgnotifier.New(workspaceIdentifier, pgResource.DBDsn) require.NoError(t, err) - tenantManager := &multitenant.Manager{ - BackendConfig: mocksBackendConfig.NewMockBackendConfig(gomock.NewController(t)), - } + ctrl := gomock.NewController(t) + + tenantManager := multitenant.New(config.Default, mocksBackendConfig.NewMockBackendConfig(ctrl)) s := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) @@ -449,9 +449,7 @@ func TestRouter(t *testing.T) { r.config.warehouseSyncFreqIgnore = true r.destType = destinationType r.logger = logger.NOP - r.tenantManager = &multitenant.Manager{ - BackendConfig: mocksBackendConfig.NewMockBackendConfig(ctrl), - } + r.tenantManager = multitenant.New(config.Default, mocksBackendConfig.NewMockBackendConfig(ctrl)) r.warehouses = []model.Warehouse{warehouse} r.uploadJobFactory = UploadJobFactory{ app: mockApp, @@ -586,9 +584,7 @@ func TestRouter(t *testing.T) { r.config.uploadAllocatorSleep = time.Millisecond * 100 r.destType = warehouseutils.RS r.logger = logger.NOP - r.tenantManager = &multitenant.Manager{ - BackendConfig: mocksBackendConfig.NewMockBackendConfig(ctrl), - } + r.tenantManager = multitenant.New(config.Default, mocksBackendConfig.NewMockBackendConfig(ctrl)) r.bcManager = newBackendConfigManager(r.conf, r.dbHandle, r.tenantManager, r.logger) r.warehouses = []model.Warehouse{warehouse} r.uploadJobFactory = UploadJobFactory{ @@ -736,9 +732,7 @@ func TestRouter(t *testing.T) { r.config.uploadAllocatorSleep = time.Millisecond * 100 r.destType = warehouseutils.RS r.logger = logger.NOP - r.tenantManager = &multitenant.Manager{ - BackendConfig: mocksBackendConfig.NewMockBackendConfig(ctrl), - } + r.tenantManager = multitenant.New(config.Default, mocksBackendConfig.NewMockBackendConfig(ctrl)) r.bcManager = newBackendConfigManager(r.conf, r.dbHandle, r.tenantManager, r.logger) r.warehouses = []model.Warehouse{warehouse} r.uploadJobFactory = UploadJobFactory{ @@ -920,8 +914,10 @@ func TestRouter(t *testing.T) { _, err = pgResource.DB.Exec(string(sqlStatement)) require.NoError(t, err) + ctrl := gomock.NewController(t) + ctx := context.Background() - tenantManager = &multitenant.Manager{} + tenantManager = multitenant.New(config.Default, mocksBackendConfig.NewMockBackendConfig(ctrl)) jobStats, err := repo.NewUploads(sqlmiddleware.New(pgResource.DB), repo.WithNow(func() time.Time { // nowSQL := "'2022-12-06 22:00:00'" @@ -1046,9 +1042,7 @@ func TestRouter(t *testing.T) { r.logger = logger.NOP r.destType = warehouseutils.RS r.config.maxConcurrentUploadJobs = 1 - r.tenantManager = &multitenant.Manager{ - BackendConfig: mockBackendConfig, - } + r.tenantManager = multitenant.New(config.Default, mockBackendConfig) r.bcManager = newBackendConfigManager(r.conf, r.dbHandle, r.tenantManager, r.logger) go func() { diff --git a/warehouse/slave_worker_test.go b/warehouse/slave_worker_test.go index 221fb399d8..11cf02dfe8 100644 --- a/warehouse/slave_worker_test.go +++ b/warehouse/slave_worker_test.go @@ -521,9 +521,7 @@ func TestSlaveWorker(t *testing.T) { return ch }).AnyTimes() - tenantManager := &multitenant.Manager{ - BackendConfig: mockBackendConfig, - } + tenantManager := multitenant.New(config.Default, mockBackendConfig) bcm := newBackendConfigManager(config.Default, nil, tenantManager, logger.NOP) ef := encoding.NewFactory(config.Default) diff --git a/warehouse/utils/utils.go b/warehouse/utils/utils.go index b2ef925ae1..b7497677da 100644 --- a/warehouse/utils/utils.go +++ b/warehouse/utils/utils.go @@ -275,32 +275,11 @@ type QueryResult struct { Values [][]string } -type PendingEventsRequest struct { - SourceID string `json:"source_id"` - TaskRunID string `json:"task_run_id"` -} - -type PendingEventsResponse struct { - PendingEvents bool `json:"pending_events"` - PendingStagingFilesCount int64 `json:"pending_staging_files"` - PendingUploadCount int64 `json:"pending_uploads"` - AbortedEvents 
bool `json:"aborted_events"` -} - -type TriggerUploadRequest struct { - SourceID string `json:"source_id"` - DestinationID string `json:"destination_id"` -} - type SourceIDDestinationID struct { SourceID string `json:"source_id"` DestinationID string `json:"destination_id"` } -type FetchTablesRequest struct { - Connections []SourceIDDestinationID `json:"connections"` -} - type FetchTableInfo struct { SourceID string `json:"source_id"` DestinationID string `json:"destination_id"` @@ -308,10 +287,6 @@ type FetchTableInfo struct { Tables []string `json:"tables"` } -type FetchTablesResponse struct { - ConnectionsTables []FetchTableInfo `json:"connections_tables"` -} - func TimingFromJSONString(str sql.NullString) (status string, recordedTime time.Time) { timingsMap := gjson.Parse(str.String).Map() for s, t := range timingsMap { diff --git a/warehouse/warehouse.go b/warehouse/warehouse.go index 13ab9e8c96..dd4ae302ab 100644 --- a/warehouse/warehouse.go +++ b/warehouse/warehouse.go @@ -3,31 +3,23 @@ package warehouse import ( "context" "database/sql" - "encoding/json" "errors" "expvar" "fmt" - "io" - "net/http" "os" - "runtime" "strconv" - "strings" "sync" "time" "github.com/rudderlabs/rudder-server/warehouse/encoding" - "github.com/bugsnag/bugsnag-go/v2" "github.com/cenkalti/backoff/v4" - "github.com/go-chi/chi/v5" "github.com/samber/lo" "golang.org/x/exp/slices" "golang.org/x/sync/errgroup" "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/filemanager" - kithttputil "github.com/rudderlabs/rudder-go-kit/httputil" "github.com/rudderlabs/rudder-go-kit/logger" "github.com/rudderlabs/rudder-go-kit/stats" "github.com/rudderlabs/rudder-server/app" @@ -43,9 +35,7 @@ import ( "github.com/rudderlabs/rudder-server/utils/types" "github.com/rudderlabs/rudder-server/warehouse/archive" "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" - "github.com/rudderlabs/rudder-server/warehouse/internal/api" "github.com/rudderlabs/rudder-server/warehouse/internal/model" - "github.com/rudderlabs/rudder-server/warehouse/internal/repo" "github.com/rudderlabs/rudder-server/warehouse/jobs" "github.com/rudderlabs/rudder-server/warehouse/multitenant" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" @@ -53,7 +43,6 @@ import ( var ( application app.App - webPort int dbHandle *sql.DB wrappedDBHandle *sqlquerywrapper.DB dbHandleTimeout time.Duration @@ -64,12 +53,10 @@ var ( lastProcessedMarkerMap map[string]int64 lastProcessedMarkerExp = expvar.NewMap("lastProcessedMarkerMap") lastProcessedMarkerMapLock sync.RWMutex - warehouseMode string bcManager *backendConfigManager triggerUploadsMap map[string]bool // `whType:sourceID:destinationID` -> boolean value representing if an upload was triggered or not triggerUploadsMapLock sync.RWMutex pkgLogger logger.Logger - runningMode string ShouldForceSetLowerVersion bool asyncWh *jobs.AsyncJobWh ) @@ -81,12 +68,6 @@ var ( var defaultUploadPriority = 100 -// warehouses worker modes -const ( - EmbeddedMode = "embedded" - EmbeddedMasterMode = "embedded_master" -) - const ( DegradedMode = "degraded" triggerUploadQPName = "triggerUpload" @@ -104,10 +85,8 @@ func Init4() { func loadConfig() { // Port where WH is running - config.RegisterIntConfigVariable(8082, &webPort, false, 1, "Warehouse.webPort") config.RegisterInt64ConfigVariable(1800, &uploadFreqInS, true, 1, "Warehouse.uploadFreqInS") lastProcessedMarkerMap = map[string]int64{} - config.RegisterStringConfigVariable("embedded", &warehouseMode, false, 
"Warehouse.mode") host = config.GetString("WAREHOUSE_JOBS_DB_HOST", "localhost") user = config.GetString("WAREHOUSE_JOBS_DB_USER", "ubuntu") dbname = config.GetString("WAREHOUSE_JOBS_DB_DB_NAME", "ubuntu") @@ -115,7 +94,6 @@ func loadConfig() { password = config.GetString("WAREHOUSE_JOBS_DB_PASSWORD", "ubuntu") // Reading secrets from sslMode = config.GetString("WAREHOUSE_JOBS_DB_SSL_MODE", "disable") triggerUploadsMap = map[string]bool{} - runningMode = config.GetString("Warehouse.runningMode", "") config.RegisterBoolConfigVariable(true, &ShouldForceSetLowerVersion, false, "SQLMigrator.forceSetLowerVersion") config.RegisterDurationConfigVariable(5, &dbHandleTimeout, true, time.Minute, []string{"Warehouse.dbHandleTimeout", "Warehouse.dbHanndleTimeoutInMin"}...) @@ -247,151 +225,6 @@ func setupTables(dbHandle *sql.DB) error { return nil } -func pendingEventsHandler(w http.ResponseWriter, r *http.Request) { - // TODO : respond with errors in a common way - pkgLogger.LogRequest(r) - - ctx := r.Context() - // read body - body, err := io.ReadAll(r.Body) - if err != nil { - pkgLogger.Errorf("[WH]: Error reading body: %v", err) - http.Error(w, "can't read body", http.StatusBadRequest) - return - } - defer func() { _ = r.Body.Close() }() - - // unmarshall body - var pendingEventsReq warehouseutils.PendingEventsRequest - err = json.Unmarshal(body, &pendingEventsReq) - if err != nil { - pkgLogger.Errorf("[WH]: Error unmarshalling body: %v", err) - http.Error(w, "can't unmarshall body", http.StatusBadRequest) - return - } - - sourceID, taskRunID := pendingEventsReq.SourceID, pendingEventsReq.TaskRunID - // return error if source id is empty - if sourceID == "" || taskRunID == "" { - pkgLogger.Errorf("empty source_id or task_run_id in the pending events request") - http.Error(w, "empty source_id or task_run_id", http.StatusBadRequest) - return - } - - workspaceID, err := tenantManager.SourceToWorkspace(ctx, sourceID) - if err != nil { - pkgLogger.Errorf("[WH]: Error checking if source is degraded: %v", err) - http.Error(w, "workspaceID from sourceID not found", http.StatusBadRequest) - return - } - - if tenantManager.DegradedWorkspace(workspaceID) { - pkgLogger.Infof("[WH]: Workspace (id: %q) is degraded: %v", workspaceID, err) - http.Error(w, "workspace is in degraded mode", http.StatusServiceUnavailable) - return - } - - pendingEvents := false - var ( - pendingStagingFileCount int64 - pendingUploadCount int64 - ) - - // check whether there are any pending staging files or uploads for the given source id - // get pending staging files - pendingStagingFileCount, err = repo.NewStagingFiles(wrappedDBHandle).CountPendingForSource(ctx, sourceID) - if err != nil { - err := fmt.Errorf("error getting pending staging file count : %v", err) - pkgLogger.Errorf("[WH]: %v", err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - filters := []repo.FilterBy{ - {Key: "source_id", Value: sourceID}, - {Key: "metadata->>'source_task_run_id'", Value: taskRunID}, - {Key: "status", NotEquals: true, Value: model.ExportedData}, - {Key: "status", NotEquals: true, Value: model.Aborted}, - } - - pendingUploadCount, err = getFilteredCount(ctx, filters...) 
- - if err != nil { - pkgLogger.Errorf("getting pending uploads count", "error", err) - http.Error(w, fmt.Sprintf( - "getting pending uploads count: %s", err.Error()), - http.StatusInternalServerError) - return - } - - filters = []repo.FilterBy{ - {Key: "source_id", Value: sourceID}, - {Key: "metadata->>'source_task_run_id'", Value: pendingEventsReq.TaskRunID}, - {Key: "status", Value: "aborted"}, - } - - abortedUploadCount, err := getFilteredCount(ctx, filters...) - if err != nil { - pkgLogger.Errorf("getting aborted uploads count", "error", err.Error()) - http.Error(w, fmt.Sprintf("getting aborted uploads count: %s", err), http.StatusInternalServerError) - return - } - - // if there are any pending staging files or uploads, set pending events as true - if (pendingStagingFileCount + pendingUploadCount) > int64(0) { - pendingEvents = true - } - - // read `triggerUpload` queryParam - var triggerPendingUpload bool - triggerUploadQP := r.URL.Query().Get(triggerUploadQPName) - if triggerUploadQP != "" { - triggerPendingUpload, _ = strconv.ParseBool(triggerUploadQP) - } - - // trigger upload if there are pending events and triggerPendingUpload is true - if pendingEvents && triggerPendingUpload { - pkgLogger.Infof("[WH]: Triggering upload for all wh destinations connected to source '%s'", sourceID) - - wh := bcManager.WarehousesBySourceID(sourceID) - - // return error if no such destinations found - if len(wh) == 0 { - err := fmt.Errorf("no warehouse destinations found for source id '%s'", sourceID) - pkgLogger.Errorf("[WH]: %v", err) - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - for _, warehouse := range wh { - triggerUpload(warehouse) - } - } - - // create and write response - res := warehouseutils.PendingEventsResponse{ - PendingEvents: pendingEvents, - PendingStagingFilesCount: pendingStagingFileCount, - PendingUploadCount: pendingUploadCount, - AbortedEvents: abortedUploadCount > 0, - } - - resBody, err := json.Marshal(res) - if err != nil { - err := fmt.Errorf("failed to marshall pending events response : %v", err) - pkgLogger.Errorf("[WH]: %v", err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - _, _ = w.Write(resBody) -} - -func getFilteredCount(ctx context.Context, filters ...repo.FilterBy) (int64, error) { - pkgLogger.Debugf("fetching filtered count") - return repo.NewUploads(wrappedDBHandle).Count(ctx, filters...) 
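
For reviewers: after this patch the same behaviour stays reachable over HTTP through the relocated handlers. A minimal client sketch, assuming a warehouse master listening on the default Warehouse.webPort (8082) and placeholder source/destination IDs:

package main

import (
	"bytes"
	"fmt"
	"net/http"
)

func main() {
	// Base URL is an assumption; Warehouse.webPort defaults to 8082.
	const base = "http://localhost:8082"

	// Ask whether a task run still has pending events; setting the
	// "triggerUpload" query parameter also triggers uploads for the source.
	body := bytes.NewBufferString(`{"source_id": "my-source", "task_run_id": "my-task-run"}`)
	resp, err := http.Post(base+"/v1/warehouse/pending-events?triggerUpload=true", "application/json", body)
	if err != nil {
		panic(err)
	}
	defer func() { _ = resp.Body.Close() }()
	fmt.Println("pending-events:", resp.Status)

	// Trigger uploads for a specific source/destination pair.
	body = bytes.NewBufferString(`{"source_id": "my-source", "destination_id": "my-destination"}`)
	resp2, err := http.Post(base+"/v1/warehouse/trigger-upload", "application/json", body)
	if err != nil {
		panic(err)
	}
	defer func() { _ = resp2.Body.Close() }()
	fmt.Println("trigger-upload:", resp2.Status)
}
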
-} - func getPendingUploadCount(filters ...warehouseutils.FilterBy) (uploadCount int64, err error) { pkgLogger.Debugf("Fetching pending upload count with filters: %v", filters) @@ -423,51 +256,6 @@ func getPendingUploadCount(filters ...warehouseutils.FilterBy) (uploadCount int6 return uploadCount, nil } -func triggerUploadHandler(w http.ResponseWriter, r *http.Request) { - // TODO : respond with errors in a common way - pkgLogger.LogRequest(r) - - ctx := r.Context() - - // read body - body, err := io.ReadAll(r.Body) - if err != nil { - pkgLogger.Errorf("[WH]: Error reading body: %v", err) - http.Error(w, "can't read body", http.StatusBadRequest) - return - } - defer func() { _ = r.Body.Close() }() - - // unmarshall body - var triggerUploadReq warehouseutils.TriggerUploadRequest - err = json.Unmarshal(body, &triggerUploadReq) - if err != nil { - pkgLogger.Errorf("[WH]: Error unmarshalling body: %v", err) - http.Error(w, "can't unmarshall body", http.StatusBadRequest) - return - } - - workspaceID, err := tenantManager.SourceToWorkspace(ctx, triggerUploadReq.SourceID) - if err != nil { - pkgLogger.Errorf("[WH]: Error checking if source is degraded: %v", err) - http.Error(w, "workspaceID from sourceID not found", http.StatusBadRequest) - return - } - - if tenantManager.DegradedWorkspace(workspaceID) { - pkgLogger.Infof("[WH]: Workspace (id: %q) is degraded: %v", workspaceID, err) - http.Error(w, "workspace is in degraded mode", http.StatusServiceUnavailable) - return - } - - err = TriggerUploadHandler(triggerUploadReq.SourceID, triggerUploadReq.DestinationID) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - w.WriteHeader(http.StatusOK) -} - func TriggerUploadHandler(sourceID, destID string) error { // return error if source id and dest id is empty if sourceID == "" && destID == "" { @@ -498,45 +286,6 @@ func TriggerUploadHandler(sourceID, destID string) error { return nil } -func fetchTablesHandler(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - defer func() { _ = r.Body.Close() }() - - body, err := io.ReadAll(r.Body) - if err != nil { - pkgLogger.Errorf("[WH]: Error reading body: %v", err) - http.Error(w, "can't read body", http.StatusBadRequest) - return - } - - var connectionsTableRequest warehouseutils.FetchTablesRequest - err = json.Unmarshal(body, &connectionsTableRequest) - if err != nil { - pkgLogger.Errorf("[WH]: Error unmarshalling body: %v", err) - http.Error(w, "can't unmarshall body", http.StatusBadRequest) - return - } - - schemaRepo := repo.NewWHSchemas(wrappedDBHandle) - tables, err := schemaRepo.GetTablesForConnection(ctx, connectionsTableRequest.Connections) - if err != nil { - pkgLogger.Errorf("[WH]: Error fetching tables: %v", err) - http.Error(w, "can't fetch tables from schemas repo", http.StatusInternalServerError) - return - } - resBody, err := json.Marshal(warehouseutils.FetchTablesResponse{ - ConnectionsTables: tables, - }) - if err != nil { - err := fmt.Errorf("failed to marshall tables to response : %v", err) - pkgLogger.Errorf("[WH]: %v", err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - _, _ = w.Write(resBody) -} - func isUploadTriggered(wh model.Warehouse) bool { triggerUploadsMapLock.RLock() defer triggerUploadsMapLock.RUnlock() @@ -556,63 +305,6 @@ func clearTriggeredUpload(wh model.Warehouse) { delete(triggerUploadsMap, wh.Identifier) } -func healthHandler(w http.ResponseWriter, _ *http.Request) { - var dbService, pgNotifierService string - - ctx, cancel := 
context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - if runningMode != DegradedMode { - if !CheckPGHealth(ctx, notifier.GetDBHandle()) { - http.Error(w, "Cannot connect to pgNotifierService", http.StatusInternalServerError) - return - } - pgNotifierService = "UP" - } - - if isMaster() { - if !CheckPGHealth(ctx, dbHandle) { - http.Error(w, "Cannot connect to dbService", http.StatusInternalServerError) - return - } - dbService = "UP" - } - - healthVal := fmt.Sprintf(` - { - "server": "UP", - "db": %q, - "pgNotifier": %q, - "acceptingEvents": "TRUE", - "warehouseMode": %q, - "goroutines": "%d" - } - `, - dbService, - pgNotifierService, - strings.ToUpper(warehouseMode), - runtime.NumGoroutine(), - ) - - _, _ = w.Write([]byte(healthVal)) -} - -func CheckPGHealth(ctx context.Context, db *sql.DB) bool { - if db == nil { - return false - } - - healthCheckMsg := "Rudder Warehouse DB Health Check" - msg := "" - - err := db.QueryRowContext(ctx, `SELECT '`+healthCheckMsg+`'::text as message;`).Scan(&msg) - if err != nil { - return false - } - - return healthCheckMsg == msg -} - func getConnectionString() string { if !CheckForWarehouseEnvVars() { return misc.GetConnectionString(config.Default) @@ -622,54 +314,6 @@ func getConnectionString() string { host, port, user, password, dbname, sslMode, appName) } -func startWebHandler(ctx context.Context) error { - srvMux := chi.NewRouter() - - // do not register same endpoint when running embedded in rudder backend - if isStandAlone() { - srvMux.Get("/health", healthHandler) - } - if runningMode != DegradedMode { - if isMaster() { - pkgLogger.Infof("WH: Warehouse master service waiting for BackendConfig before starting on %d", webPort) - backendconfig.DefaultBackendConfig.WaitForConfig(ctx) - - srvMux.Handle("/v1/process", (&api.WarehouseAPI{ - Logger: pkgLogger, - Stats: stats.Default, - Repo: repo.NewStagingFiles(wrappedDBHandle), - Multitenant: tenantManager, - }).Handler()) - - // triggers upload only when there are pending events and triggerUpload is sent for a sourceId - srvMux.Post("/v1/warehouse/pending-events", pendingEventsHandler) - // triggers uploads for a source - srvMux.Post("/v1/warehouse/trigger-upload", triggerUploadHandler) - - // Warehouse Async Job end-points - srvMux.Post("/v1/warehouse/jobs", asyncWh.AddWarehouseJobHandler) // FIXME: add degraded mode - srvMux.Get("/v1/warehouse/jobs/status", asyncWh.StatusWarehouseJobHandler) // FIXME: add degraded mode - - // fetch schema info - // TODO: Remove this endpoint once sources change is released - srvMux.Get("/v1/warehouse/fetch-tables", fetchTablesHandler) - srvMux.Get("/internal/v1/warehouse/fetch-tables", fetchTablesHandler) - - pkgLogger.Infof("WH: Starting warehouse master service in %d", webPort) - } else { - pkgLogger.Infof("WH: Starting warehouse slave service in %d", webPort) - } - } - - srv := &http.Server{ - Addr: fmt.Sprintf(":%d", webPort), - Handler: bugsnag.Handler(srvMux), - ReadHeaderTimeout: 3 * time.Second, - } - - return kithttputil.ListenAndServe(ctx, srv) -} - // CheckForWarehouseEnvVars Checks if all the required Env Variables for Warehouse are present func CheckForWarehouseEnvVars() bool { return config.IsSet("WAREHOUSE_JOBS_DB_HOST") && @@ -678,26 +322,6 @@ func CheckForWarehouseEnvVars() bool { config.IsSet("WAREHOUSE_JOBS_DB_PASSWORD") } -// This checks if gateway is running or not -func isStandAlone() bool { - return warehouseMode != EmbeddedMode && warehouseMode != EmbeddedMasterMode -} - -func isMaster() bool { - return warehouseMode == 
config.MasterMode || - warehouseMode == config.MasterSlaveMode || - warehouseMode == config.EmbeddedMode || - warehouseMode == config.EmbeddedMasterMode -} - -func isSlave() bool { - return warehouseMode == config.SlaveMode || warehouseMode == config.MasterSlaveMode || warehouseMode == config.EmbeddedMode -} - -func isStandAloneSlave() bool { - return warehouseMode == config.SlaveMode -} - func setupDB(ctx context.Context, connInfo string) error { var err error dbHandle, err = sql.Open("postgres", connInfo) @@ -746,11 +370,13 @@ func Setup(ctx context.Context) error { func Start(ctx context.Context, app app.App) error { application = app - if dbHandle == nil && !isStandAloneSlave() { + mode := config.GetString("Warehouse.mode", config.EmbeddedMode) + + if dbHandle == nil && !isStandAloneSlave(mode) { return errors.New("warehouse service cannot start, database connection is not setup") } // do not start warehouse service if rudder core is not in normal mode and warehouse is running in same process as rudder core - if !isStandAlone() && !db.IsNormalMode() { + if !isStandAlone(mode) && !db.IsNormalMode() { pkgLogger.Infof("Skipping start of warehouse service...") return nil } @@ -767,9 +393,8 @@ func Start(ctx context.Context, app app.App) error { g, gCtx := errgroup.WithContext(ctx) - tenantManager = &multitenant.Manager{ - BackendConfig: backendconfig.DefaultBackendConfig, - } + tenantManager = multitenant.New(config.Default, backendconfig.DefaultBackendConfig) + g.Go(func() error { tenantManager.Run(gCtx) return nil @@ -789,14 +414,20 @@ func Start(ctx context.Context, app app.App) error { runningMode := config.GetString("Warehouse.runningMode", "") if runningMode == DegradedMode { pkgLogger.Infof("WH: Running warehouse service in degraded mode...") - if isMaster() { + if isMaster(mode) { err := InitWarehouseAPI(dbHandle, bcManager, pkgLogger.Child("upload_api")) if err != nil { pkgLogger.Errorf("WH: Failed to start warehouse api: %v", err) return err } } - return startWebHandler(ctx) + + api := NewApi( + mode, config.Default, pkgLogger, stats.Default, + backendconfig.DefaultBackendConfig, wrappedDBHandle, nil, tenantManager, + bcManager, nil, + ) + return api.Start(ctx) } var err error workspaceIdentifier := fmt.Sprintf(`%s::%s`, config.GetKubeNamespace(), misc.GetMD5Hash(config.GetWorkspaceToken())) @@ -809,7 +440,7 @@ func Start(ctx context.Context, app app.App) error { // A different DB for warehouse is used when: // 1. MultiTenant (uses RDS) // 2. 
rudderstack-postgresql-warehouse pod in Hosted and Enterprise - if (isStandAlone() && isMaster()) || (misc.GetConnectionString(config.Default) != psqlInfo) { + if (isStandAlone(mode) && isMaster(mode)) || (misc.GetConnectionString(config.Default) != psqlInfo) { reporting := application.Features().Reporting.Setup(backendconfig.DefaultBackendConfig) g.Go(misc.WithBugsnagForWarehouse(func() error { @@ -818,7 +449,7 @@ func Start(ctx context.Context, app app.App) error { })) } - if isStandAlone() && isMaster() { + if isStandAlone(mode) && isMaster(mode) { // Report warehouse features g.Go(func() error { backendconfig.DefaultBackendConfig.WaitForConfig(gCtx) @@ -838,7 +469,7 @@ func Start(ctx context.Context, app app.App) error { }) } - if isSlave() { + if isSlave(mode) { pkgLogger.Infof("WH: Starting warehouse slave...") g.Go(misc.WithBugsnagForWarehouse(func() error { cm := newConstraintsManager(config.Default) @@ -849,7 +480,7 @@ func Start(ctx context.Context, app app.App) error { })) } - if isMaster() { + if isMaster(mode) { pkgLogger.Infof("[WH]: Starting warehouse master...") backendconfig.DefaultBackendConfig.WaitForConfig(ctx) @@ -896,7 +527,12 @@ func Start(ctx context.Context, app app.App) error { } g.Go(func() error { - return startWebHandler(gCtx) + api := NewApi( + mode, config.Default, pkgLogger, stats.Default, + backendconfig.DefaultBackendConfig, wrappedDBHandle, ¬ifier, tenantManager, + bcManager, asyncWh, + ) + return api.Start(gCtx) }) return g.Wait() diff --git a/warehouse/warehousegrpc_test.go b/warehouse/warehousegrpc_test.go index b06ab7e4b7..446ae8cb7e 100644 --- a/warehouse/warehousegrpc_test.go +++ b/warehouse/warehousegrpc_test.go @@ -20,7 +20,6 @@ import ( "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/logger" "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource" - backendConfig "github.com/rudderlabs/rudder-server/backend-config" backendconfig "github.com/rudderlabs/rudder-server/backend-config" proto "github.com/rudderlabs/rudder-server/proto/warehouse" "github.com/rudderlabs/rudder-server/testhelper/destination" @@ -641,9 +640,7 @@ func setupWarehouseGRPCTest( pkgLogger = logger.NOP - tenantManager = &multitenant.Manager{ - BackendConfig: backendConfig.DefaultBackendConfig, - } + tenantManager = multitenant.New(config.Default, backendconfig.DefaultBackendConfig) bcManager = newBackendConfigManager( config.Default, wrappedDBHandle, tenantManager, nil,
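
For reviewers: the pieces this patch moves around (the mode helpers, multitenant.New, and NewApi) come together as sketched below. This is a condensed illustration, assuming the package-level wrappedDBHandle, notifier, bcManager, asyncWh, and pkgLogger are initialized exactly as in Start; the function name startApiSketch is invented and not part of the patch.

package warehouse

import (
	"context"

	"github.com/rudderlabs/rudder-go-kit/config"
	"github.com/rudderlabs/rudder-go-kit/stats"
	backendconfig "github.com/rudderlabs/rudder-server/backend-config"
	"github.com/rudderlabs/rudder-server/warehouse/multitenant"
)

// startApiSketch condenses the wiring performed in Start: the mode is read
// once from config and passed down, instead of living in a package-level var.
func startApiSketch(ctx context.Context) error {
	mode := config.GetString("Warehouse.mode", config.EmbeddedMode)

	// Degraded workspaces are now read from config inside multitenant.New,
	// rather than being assigned to an exported struct field.
	tenantManager := multitenant.New(config.Default, backendconfig.DefaultBackendConfig)

	api := NewApi(
		mode, config.Default, pkgLogger, stats.Default,
		backendconfig.DefaultBackendConfig, wrappedDBHandle, &notifier, tenantManager,
		bcManager, asyncWh,
	)
	return api.Start(ctx)
}
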