diff --git a/runner/runner.go b/runner/runner.go index d8f310dea0..216d615005 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -35,7 +35,6 @@ import ( "github.com/rudderlabs/rudder-server/services/db" "github.com/rudderlabs/rudder-server/services/diagnostics" "github.com/rudderlabs/rudder-server/services/oauth" - "github.com/rudderlabs/rudder-server/services/pgnotifier" "github.com/rudderlabs/rudder-server/services/streammanager/kafka" "github.com/rudderlabs/rudder-server/utils/misc" "github.com/rudderlabs/rudder-server/utils/types/deployment" @@ -323,7 +322,6 @@ func runAllInit() { diagnostics.Init() backendconfig.Init() warehouseutils.Init() - pgnotifier.Init() warehouse.Init4() validations.Init() eventschema.Init() diff --git a/services/notifier/notifier.go b/services/notifier/notifier.go new file mode 100644 index 0000000000..6f7dee9469 --- /dev/null +++ b/services/notifier/notifier.go @@ -0,0 +1,598 @@ +package notifier + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "math/rand" + "time" + + "github.com/lib/pq" + + "golang.org/x/sync/errgroup" + + "github.com/cenkalti/backoff" + "github.com/google/uuid" + + "github.com/allisson/go-pglock/v2" + "github.com/spaolacci/murmur3" + + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-go-kit/stats" + migrator "github.com/rudderlabs/rudder-server/services/sql-migrator" + "github.com/rudderlabs/rudder-server/utils/misc" + sqlmw "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" +) + +const ( + queueName = "pg_notifier_queue" + module = "pgnotifier" +) + +type JobType string + +const ( + JobTypeUpload JobType = "upload" + JobTypeAsync JobType = "async_job" +) + +type JobStatus string + +const ( + Waiting JobStatus = "waiting" + Executing JobStatus = "executing" + Succeeded JobStatus = "succeeded" + Failed JobStatus = "failed" + Aborted JobStatus = "aborted" +) + +type Job struct { + ID int64 + BatchID string + WorkerID string + WorkspaceIdentifier string + + Attempt int + Status JobStatus + Type JobType + Priority int + Error error + + Payload json.RawMessage + + CreatedAt time.Time + UpdatedAt time.Time + LastExecTime time.Time +} + +type PublishRequest struct { + Payloads []json.RawMessage + UploadSchema json.RawMessage // ATM Hack to support merging schema with the payload at the postgres level + JobType JobType + Priority int +} + +type PublishResponse struct { + Jobs []Job + Err error +} + +type ClaimJob struct { + Job *Job +} + +type ClaimJobResponse struct { + Payload json.RawMessage + Err error +} + +type notifierRepo interface { + resetForWorkspace(context.Context, string) error + insert(context.Context, *PublishRequest, string, string) error + pendingByBatchID(context.Context, string) (int64, error) + deleteByBatchID(context.Context, string) error + orphanJobIDs(context.Context, int) ([]int64, error) + getByBatchID(context.Context, string) ([]Job, error) + claim(context.Context, string) (*Job, error) + onClaimFailed(context.Context, *Job, error, int) error + onClaimSuccess(context.Context, *Job, json.RawMessage) error +} + +type Notifier struct { + conf *config.Config + logger logger.Logger + statsFactory stats.Stats + db *sqlmw.DB + repo notifierRepo + workspaceIdentifier string + batchIDGenerator func() uuid.UUID + randGenerator *rand.Rand + now func() time.Time + background struct { + group *errgroup.Group + groupCtx context.Context + groupCancel context.CancelFunc + groupWait func() 
error + } + + config struct { + host string + port int + user string + password string + database string + sslMode string + maxAttempt int + maxOpenConnections int + shouldForceSetLowerVersion bool + trackBatchInterval time.Duration + maxPollSleep misc.ValueLoader[time.Duration] + jobOrphanTimeout misc.ValueLoader[time.Duration] + queryTimeout time.Duration + } + stats struct { + insertRecords stats.Counter + publish stats.Counter + publishTime stats.Timer + claimSucceeded stats.Counter + claimSucceededTime stats.Timer + claimFailed stats.Counter + claimFailedTime stats.Timer + claimUpdateFailed stats.Counter + abortedRecords stats.Counter + } +} + +func New( + conf *config.Config, + log logger.Logger, + statsFactory stats.Stats, + workspaceIdentifier string, +) *Notifier { + n := &Notifier{ + conf: conf, + logger: log.Child("notifier"), + statsFactory: statsFactory, + workspaceIdentifier: workspaceIdentifier, + batchIDGenerator: misc.FastUUID, + randGenerator: rand.New(rand.NewSource(time.Now().UnixNano())), + now: time.Now, + } + + n.logger.Infof("Initializing Notifier...") + + n.config.host = n.conf.GetString("PGNOTIFIER_DB_HOST", "localhost") + n.config.user = n.conf.GetString("PGNOTIFIER_DB_USER", "ubuntu") + n.config.database = n.conf.GetString("PGNOTIFIER_DB_NAME", "ubuntu") + n.config.port = n.conf.GetInt("PGNOTIFIER_DB_PORT", 5432) + n.config.password = n.conf.GetString("PGNOTIFIER_DB_PASSWORD", "ubuntu") + n.config.sslMode = n.conf.GetString("PGNOTIFIER_DB_SSL_MODE", "disable") + n.config.maxAttempt = n.conf.GetInt("PgNotifier.maxAttempt", 3) + n.config.maxOpenConnections = n.conf.GetInt("PgNotifier.maxOpenConnections", 20) + n.config.shouldForceSetLowerVersion = n.conf.GetBool("SQLMigrator.forceSetLowerVersion", true) + n.config.trackBatchInterval = n.conf.GetDuration("PgNotifier.trackBatchIntervalInS", 2, time.Second) + n.config.queryTimeout = n.conf.GetDuration("Warehouse.pgNotifierQueryTimeout", 5, time.Minute) + n.config.maxPollSleep = n.conf.GetReloadableDurationVar(5000, time.Millisecond, "PgNotifier.maxPollSleep") + n.config.jobOrphanTimeout = n.conf.GetReloadableDurationVar(120, time.Second, "PgNotifier.jobOrphanTimeout") + + n.stats.insertRecords = n.statsFactory.NewTaggedStat("pg_notifier.insert_records", stats.CountType, stats.Tags{ + "module": "pg_notifier", + "queueName": queueName, + }) + n.stats.publish = n.statsFactory.NewTaggedStat("pgnotifier.publish", stats.CountType, stats.Tags{ + "module": module, + }) + n.stats.claimSucceeded = n.statsFactory.NewTaggedStat("pgnotifier.claim", stats.CountType, stats.Tags{ + "module": module, + "status": string(Succeeded), + }) + n.stats.claimFailed = n.statsFactory.NewTaggedStat("pgnotifier.claim", stats.CountType, stats.Tags{ + "module": module, + "status": string(Failed), + }) + n.stats.claimUpdateFailed = n.statsFactory.NewStat("pgnotifier.claimUpdateFailed", stats.CountType) + n.stats.publishTime = n.statsFactory.NewTaggedStat("pgnotifier.publishTime", stats.TimerType, stats.Tags{ + "module": module, + }) + n.stats.claimSucceededTime = n.statsFactory.NewTaggedStat("pgnotifier.claimTime", stats.TimerType, stats.Tags{ + "module": module, + "status": string(Succeeded), + }) + n.stats.claimFailedTime = n.statsFactory.NewTaggedStat("pgnotifier.claimTime", stats.TimerType, stats.Tags{ + "module": module, + "status": string(Failed), + }) + n.stats.abortedRecords = n.statsFactory.NewTaggedStat("pg_notifier.aborted_records", stats.CountType, stats.Tags{ + "workspace": n.workspaceIdentifier, + "module": "pg_notifier", + 
"queueName": queueName, + }) + return n +} + +func (n *Notifier) Setup( + ctx context.Context, + fallbackDSN string, +) error { + dsn := fallbackDSN + if n.checkForNotifierEnvVars() { + dsn = n.connectionString() + } + + if err := n.setupDatabase(ctx, dsn); err != nil { + return fmt.Errorf("could not setup db: %w", err) + } + n.repo = newRepo(n.db) + + groupCtx, groupCancel := context.WithCancel(ctx) + n.background.group, n.background.groupCtx = errgroup.WithContext(groupCtx) + n.background.groupCancel = groupCancel + n.background.groupWait = n.background.group.Wait + + return nil +} + +func (n *Notifier) checkForNotifierEnvVars() bool { + return n.conf.IsSet("PGNOTIFIER_DB_HOST") && + n.conf.IsSet("PGNOTIFIER_DB_USER") && + n.conf.IsSet("PGNOTIFIER_DB_NAME") && + n.conf.IsSet("PGNOTIFIER_DB_PASSWORD") +} + +func (n *Notifier) connectionString() string { + return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", + n.config.host, + n.config.port, + n.config.user, + n.config.password, + n.config.database, + n.config.sslMode, + ) +} + +func (n *Notifier) setupDatabase( + ctx context.Context, + dsn string, +) error { + database, err := sql.Open("postgres", dsn) + if err != nil { + return fmt.Errorf("could not open: %w", err) + } + database.SetMaxOpenConns(n.config.maxOpenConnections) + + if err := database.PingContext(ctx); err != nil { + return fmt.Errorf("could not ping: %w", err) + } + + n.db = sqlmw.New( + database, + sqlmw.WithLogger(n.logger.Child("notifier-db")), + sqlmw.WithQueryTimeout(n.config.queryTimeout), + sqlmw.WithStats(n.statsFactory), + ) + + if err := n.setupTables(); err != nil { + return fmt.Errorf("could not setup tables: %w", err) + } + return nil +} + +func (n *Notifier) setupTables() error { + m := &migrator.Migrator{ + Handle: n.db.DB, + MigrationsTable: "pg_notifier_queue_migrations", + ShouldForceSetLowerVersion: n.config.shouldForceSetLowerVersion, + } + + operation := func() error { + return m.Migrate("pg_notifier_queue") + } + + backoffWithMaxRetry := backoff.WithMaxRetries(backoff.NewExponentialBackOff(), 3) + + err := backoff.RetryNotify(operation, backoffWithMaxRetry, func(err error, t time.Duration) { + n.logger.Warnf("retrying warehouse database migration in %s: %v", t, err) + }) + if err != nil { + return fmt.Errorf("could not migrate pg_notifier_queue: %w", err) + } + return nil +} + +// ClearJobs deletes all jobs for the current workspace. 
+func (n *Notifier) ClearJobs(ctx context.Context) error { + if n.workspaceIdentifier == "" { + return nil + } + + n.logger.Infof("Deleting all jobs for workspace: %s", n.workspaceIdentifier) + + err := n.repo.resetForWorkspace(ctx, n.workspaceIdentifier) + if err != nil { + return fmt.Errorf("could not reset notifier for workspace: %s: %w", n.workspaceIdentifier, err) + } + return nil +} + +func (n *Notifier) CheckHealth(ctx context.Context) bool { + healthCheckMsg := "Rudder Warehouse DB Health Check" + msg := "" + + err := n.db.QueryRowContext(ctx, `SELECT '`+healthCheckMsg+`'::text as message;`).Scan(&msg) + if err != nil { + return false + } + + return healthCheckMsg == msg +} + +// Publish inserts the payloads into the database and returns a channel of type PublishResponse +func (n *Notifier) Publish( + ctx context.Context, + publishRequest *PublishRequest, +) (<-chan *PublishResponse, error) { + publishStartTime := n.now() + + batchID := n.batchIDGenerator().String() + + if err := n.repo.insert(ctx, publishRequest, n.workspaceIdentifier, batchID); err != nil { + return nil, fmt.Errorf("inserting jobs: %w", err) + } + + n.logger.Infof("Inserted %d records into %s for batch: %s", len(publishRequest.Payloads), queueName, batchID) + + n.stats.insertRecords.Count(len(publishRequest.Payloads)) + + defer func() { + n.stats.publishTime.Since(publishStartTime) + n.stats.publish.Increment() + }() + + return n.trackBatch(ctx, batchID), nil +} + +// trackBatch tracks the batch and returns a channel of type PublishResponse +func (n *Notifier) trackBatch( + ctx context.Context, + batchID string, +) <-chan *PublishResponse { + publishResCh := make(chan *PublishResponse, 1) + + n.background.group.Go(func() error { + defer close(publishResCh) + + onUpdate := func(response *PublishResponse) { + select { + case <-ctx.Done(): + return + case <-n.background.groupCtx.Done(): + return + case publishResCh <- response: + } + } + + for { + select { + case <-ctx.Done(): + return nil + case <-n.background.groupCtx.Done(): + return nil + case <-time.After(n.config.trackBatchInterval): + } + + count, err := n.repo.pendingByBatchID(ctx, batchID) + if err != nil { + onUpdate(&PublishResponse{ + Err: fmt.Errorf("could not get pending count for batch: %s: %w", batchID, err), + }) + return nil + } else if count != 0 { + continue + } + + jobs, err := n.repo.getByBatchID(ctx, batchID) + if err != nil { + onUpdate(&PublishResponse{ + Err: fmt.Errorf("could not get jobs for batch: %s: %w", batchID, err), + }) + return nil + } + + err = n.repo.deleteByBatchID(ctx, batchID) + if err != nil { + onUpdate(&PublishResponse{ + Err: fmt.Errorf("could not delete jobs for batch: %s: %w", batchID, err), + }) + return nil + } + + n.logger.Infof("Completed processing all files in batch: %s", batchID) + + onUpdate(&PublishResponse{ + Jobs: jobs, + }) + return nil + } + }) + return publishResCh +} + +// Subscribe returns a channel of type Job +func (n *Notifier) Subscribe( + ctx context.Context, + workerId string, + bufferSize int, +) <-chan *ClaimJob { + jobsCh := make(chan *ClaimJob, bufferSize) + + nextPollInterval := func(pollSleep time.Duration) time.Duration { + pollSleep = 2*pollSleep + time.Duration(n.randGenerator.Intn(100))*time.Millisecond + + if pollSleep < n.config.maxPollSleep.Load() { + return pollSleep + } + + return n.config.maxPollSleep.Load() + } + + n.background.group.Go(func() error { + defer close(jobsCh) + + pollSleep := time.Duration(0) + + for { + job, err := n.claim(ctx, workerId) + if err != nil { + var pqErr 
*pq.Error + + switch { + case errors.Is(err, sql.ErrNoRows), + errors.Is(err, context.Canceled), + errors.Is(err, context.DeadlineExceeded), + errors.As(err, &pqErr) && pqErr.Code == "57014": + default: + n.logger.Warnf("claiming job: %v", err) + } + + pollSleep = nextPollInterval(pollSleep) + } else { + jobsCh <- &ClaimJob{ + Job: job, + } + + pollSleep = time.Duration(0) + } + + select { + case <-ctx.Done(): + return nil + case <-n.background.groupCtx.Done(): + return nil + case <-time.After(pollSleep): + } + } + }) + return jobsCh +} + +// claim claims a job from the notifier queue. +func (n *Notifier) claim( + ctx context.Context, + workerID string, +) (*Job, error) { + claimStartTime := n.now() + + claimedJob, err := n.repo.claim(ctx, workerID) + if errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("no jobs found: %w", err) + } + if err != nil { + n.stats.claimFailedTime.Since(claimStartTime) + n.stats.claimFailed.Increment() + + return nil, fmt.Errorf("claiming job: %w", err) + } + + n.stats.claimSucceededTime.Since(claimStartTime) + n.stats.claimSucceeded.Increment() + + return claimedJob, nil +} + +// UpdateClaim updates the notifier with the claim response. +// If we fail to update the claim, we only log the error: maintenance workers can mark the job as waiting again, and another worker will claim it. +// A job may still be picked up but never updated; we monitor this via claim lag. +// Claim lag also verifies that the maintenance workers themselves are monitoring jobs correctly. +func (n *Notifier) UpdateClaim( + ctx context.Context, + claimedJob *ClaimJob, + response *ClaimJobResponse, +) { + if response.Err != nil { + if err := n.repo.onClaimFailed(ctx, claimedJob.Job, response.Err, n.config.maxAttempt); err != nil { + n.stats.claimUpdateFailed.Increment() + n.logger.Errorf("update claimed: on claimed failed: %v", err) + } + + if claimedJob.Job.Attempt > n.config.maxAttempt { + n.stats.abortedRecords.Increment() + } + return + } + + if err := n.repo.onClaimSuccess(ctx, claimedJob.Job, response.Payload); err != nil { + n.stats.claimUpdateFailed.Increment() + n.logger.Errorf("update claimed: on claimed success: %v", err) + } +} + +// RunMaintenance re-triggers zombie jobs that were left behind in the executing state by dead workers. +// Since it's a blocking call, it should be run in a separate goroutine. +func (n *Notifier) RunMaintenance(ctx context.Context) error { + maintenanceWorkerLockID := murmur3.Sum64([]byte(queueName)) + maintenanceWorkerLock, err := pglock.NewLock(ctx, int64(maintenanceWorkerLockID), n.db.DB) + if err != nil { + return fmt.Errorf("creating maintenance worker lock: %w", err) + } + + var locked bool + defer func() { + if locked { + if err := maintenanceWorkerLock.Unlock(ctx); err != nil { + n.logger.Warnf("unlocking maintenance worker lock: %v", err) + } + } + }() + + for { + if locked, err = maintenanceWorkerLock.Lock(ctx); err != nil { + n.logger.Warnf("acquiring maintenance worker lock: %v", err) + } else if locked { + break + } + + select { + case <-ctx.Done(): + return nil + case <-time.After(n.config.jobOrphanTimeout.Load() / 5): + } + } + + for { + orphanJobIDs, err := n.repo.orphanJobIDs(ctx, int(n.config.jobOrphanTimeout.Load()/time.Second)) + if err != nil { + var pqErr *pq.Error + + switch { + case errors.Is(err, context.Canceled), + errors.Is(err, context.DeadlineExceeded), + errors.As(err, &pqErr) && pqErr.Code == "57014": + return nil + default: +
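// Any other error while fetching orphan job IDs is unexpected: stop the maintenance loop and surface it. +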
return fmt.Errorf("fetching orphan job ids: %w", err) + } + } + + if len(orphanJobIDs) > 0 { + n.logger.Infof("Re-triggered job ids: %v", orphanJobIDs) + } + + select { + case <-ctx.Done(): + return nil + case <-time.After(n.config.jobOrphanTimeout.Load() / 5): + } + } +} + +// Shutdown waits for all the background jobs to be drained off. +func (n *Notifier) Shutdown() error { + n.logger.Infof("Shutting down notifier") + + n.background.groupCancel() + return n.background.group.Wait() +} diff --git a/services/notifier/notifier_test.go b/services/notifier/notifier_test.go new file mode 100644 index 0000000000..78adb3ea7f --- /dev/null +++ b/services/notifier/notifier_test.go @@ -0,0 +1,631 @@ +package notifier_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "testing" + "time" + + "github.com/samber/lo" + + "go.uber.org/atomic" + "golang.org/x/sync/errgroup" + + "github.com/ory/dockertest/v3" + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-go-kit/stats" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" + "github.com/rudderlabs/rudder-server/services/notifier" + + "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource" + "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource/postgres" +) + +func setup(t testing.TB) *resource.PostgresResource { + t.Helper() + + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + pgResource, err := resource.SetupPostgres(pool, t, postgres.WithOptions("max_connections=1000")) + require.NoError(t, err) + + t.Log("db:", pgResource.DBDsn) + + return pgResource +} + +func TestNotifier(t *testing.T) { + t.Parallel() + + const ( + workspaceIdentifier = "test_workspace_identifier" + workerID = "test_worker" + ) + + t.Run("clear jobs", func(t *testing.T) { + t.Parallel() + + pgResource := setup(t) + ctx := context.Background() + + publishRequest := ¬ifier.PublishRequest{ + Payloads: []json.RawMessage{ + json.RawMessage(`{"id":"1"}`), + json.RawMessage(`{"id":"2"}`), + json.RawMessage(`{"id":"3"}`), + json.RawMessage(`{"id":"4"}`), + json.RawMessage(`{"id":"5"}`), + }, + JobType: notifier.JobTypeUpload, + UploadSchema: json.RawMessage(`{"UploadSchema": "1"}`), + Priority: 50, + } + + count := func(t testing.TB) int { + t.Helper() + + var count int + err := pgResource.DB.QueryRowContext(ctx, ` + SELECT + COUNT(*) + FROM + pg_notifier_queue; + `).Scan(&count) + require.NoError(t, err) + return count + } + + notifierWithIdentifier := notifier.New(config.Default, logger.NOP, stats.Default, workspaceIdentifier) + err := notifierWithIdentifier.Setup(ctx, pgResource.DBDsn) + require.NoError(t, err) + + notifierWithoutIdentifier := notifier.New(config.Default, logger.NOP, stats.Default, "") + err = notifierWithoutIdentifier.Setup(ctx, pgResource.DBDsn) + require.NoError(t, err) + + t.Run("without workspace identifier", func(t *testing.T) { + _, err = notifierWithIdentifier.Publish(ctx, publishRequest) + require.NoError(t, err) + require.Equal(t, len(publishRequest.Payloads), count(t)) + + err = notifierWithoutIdentifier.ClearJobs(ctx) + require.NoError(t, err) + require.Equal(t, len(publishRequest.Payloads), count(t)) + + err = notifierWithIdentifier.ClearJobs(ctx) + require.NoError(t, err) + require.Zero(t, count(t)) + }) + + t.Run("with workspace identifier", func(t *testing.T) { + _, err = notifierWithoutIdentifier.Publish(ctx, publishRequest) + require.NoError(t, err) + require.Equal(t, 
len(publishRequest.Payloads), count(t)) + + err = notifierWithIdentifier.ClearJobs(ctx) + require.NoError(t, err) + require.Equal(t, len(publishRequest.Payloads), count(t)) + + err = notifierWithoutIdentifier.ClearJobs(ctx) + require.NoError(t, err) + require.Equal(t, len(publishRequest.Payloads), count(t)) + }) + + t.Run("context cancelled", func(t *testing.T) { + cancelledCtx, cancel := context.WithCancel(ctx) + cancel() + + err = notifierWithIdentifier.ClearJobs(cancelledCtx) + require.ErrorIs(t, err, context.Canceled) + }) + }) + t.Run("health check", func(t *testing.T) { + t.Parallel() + + pgResource := setup(t) + ctx := context.Background() + + t.Run("success", func(t *testing.T) { + n := notifier.New(config.Default, logger.NOP, stats.Default, workspaceIdentifier) + err := n.Setup(ctx, pgResource.DBDsn) + require.NoError(t, err) + require.True(t, n.CheckHealth(ctx)) + }) + t.Run("context cancelled", func(t *testing.T) { + cancelledCtx, cancel := context.WithCancel(ctx) + cancel() + + n := notifier.New(config.Default, logger.NOP, stats.Default, workspaceIdentifier) + err := n.Setup(ctx, pgResource.DBDsn) + require.NoError(t, err) + require.False(t, n.CheckHealth(cancelledCtx)) + }) + }) + t.Run("basic workflow", func(t *testing.T) { + t.Parallel() + + pgResource := setup(t) + ctx := context.Background() + + publishRequest := ¬ifier.PublishRequest{ + Payloads: []json.RawMessage{ + json.RawMessage(`{"id":"1"}`), + json.RawMessage(`{"id":"2"}`), + json.RawMessage(`{"id":"3"}`), + json.RawMessage(`{"id":"4"}`), + json.RawMessage(`{"id":"5"}`), + }, + JobType: notifier.JobTypeUpload, + UploadSchema: json.RawMessage(`{"UploadSchema": "1"}`), + Priority: 50, + } + + statsStore := memstats.New() + + c := config.New() + c.Set("PgNotifier.maxAttempt", 1) + c.Set("PgNotifier.maxOpenConnections", 500) + c.Set("PgNotifier.trackBatchIntervalInS", "1s") + c.Set("PgNotifier.maxPollSleep", "1s") + c.Set("PgNotifier.jobOrphanTimeout", "120s") + + groupCtx, groupCancel := context.WithCancel(ctx) + g, gCtx := errgroup.WithContext(groupCtx) + + n := notifier.New(c, logger.NOP, statsStore, workspaceIdentifier) + err := n.Setup(groupCtx, pgResource.DBDsn) + require.NoError(t, err) + + const ( + totalJobs = 12 + subscribers = 8 + subscriberWorkers = 4 + ) + + collectResponses := make(chan *notifier.PublishResponse) + + for i := 0; i < totalJobs; i++ { + g.Go(func() error { + publishCh, err := n.Publish(gCtx, publishRequest) + require.NoError(t, err) + + response, ok := <-publishCh + require.True(t, ok) + require.NotNil(t, response) + + collectResponses <- response + + response, ok = <-publishCh + require.False(t, ok) + require.Nil(t, response) + + return nil + }) + } + for i := 0; i < subscribers; i++ { + g.Go(func() error { + subscriberCh := n.Subscribe(gCtx, workerID, subscriberWorkers) + + slaveGroup, slaveCtx := errgroup.WithContext(gCtx) + + for j := 0; j < subscriberWorkers; j++ { + slaveGroup.Go(func() error { + for job := range subscriberCh { + switch job.Job.ID % 4 { + case 0, 1, 2: + n.UpdateClaim(slaveCtx, job, ¬ifier.ClaimJobResponse{ + Payload: json.RawMessage(`{"test": "payload"}`), + }) + case 3: + n.UpdateClaim(slaveCtx, job, ¬ifier.ClaimJobResponse{ + Err: errors.New("test error"), + }) + } + } + return nil + }) + } + return slaveGroup.Wait() + }) + } + g.Go(func() error { + return n.RunMaintenance(gCtx) + }) + g.Go(func() error { + <-groupCtx.Done() + return n.Shutdown() + }) + g.Go(func() error { + responses := make([]notifier.Job, 0, totalJobs*len(publishRequest.Payloads)) + + for i := 0; 
i < totalJobs; i++ { + job := <-collectResponses + require.NoError(t, job.Err) + require.Len(t, job.Jobs, len(publishRequest.Payloads)) + + successfulJobs := lo.Filter(job.Jobs, func(item notifier.Job, index int) bool { + return item.Error == nil + }) + for _, j := range successfulJobs { + require.EqualValues(t, j.Payload, json.RawMessage(`{"test": "payload"}`)) + } + responses = append(responses, job.Jobs...) + } + + successCount := (3 * totalJobs * len(publishRequest.Payloads)) / 4 + failureCount := totalJobs * len(publishRequest.Payloads) / 4 + + require.Len(t, lo.Filter(responses, func(item notifier.Job, index int) bool { + return item.Error == nil + }), + successCount, + ) + require.Len(t, lo.Filter(responses, func(item notifier.Job, index int) bool { + return item.Error != nil + }), + failureCount, + ) + + close(collectResponses) + groupCancel() + + return nil + }) + require.NoError(t, g.Wait()) + require.NoError(t, n.ClearJobs(ctx)) + require.EqualValues(t, statsStore.Get("pgnotifier.publish", stats.Tags{ + "module": "pgnotifier", + }).LastValue(), totalJobs) + require.EqualValues(t, statsStore.Get("pg_notifier.insert_records", stats.Tags{ + "module": "pg_notifier", + "queueName": "pg_notifier_queue", + }).LastValue(), totalJobs*len(publishRequest.Payloads)) + }) + t.Run("many publish jobs", func(t *testing.T) { + t.Parallel() + + pgResource := setup(t) + ctx := context.Background() + + const ( + batchSize = 1 + jobs = 25 + subscribers = 100 + subscriberWorkers = 4 + ) + + var payloads []json.RawMessage + for i := 0; i < batchSize; i++ { + payloads = append(payloads, json.RawMessage(fmt.Sprintf(`{"id": "%d"}`, i))) + } + + publishRequest := ¬ifier.PublishRequest{ + Payloads: payloads, + JobType: notifier.JobTypeUpload, + UploadSchema: json.RawMessage(`{"UploadSchema": "1"}`), + Priority: 50, + } + + c := config.New() + c.Set("PgNotifier.maxAttempt", 1) + c.Set("PgNotifier.maxOpenConnections", 900) + c.Set("PgNotifier.trackBatchIntervalInS", "1s") + c.Set("PgNotifier.maxPollSleep", "100ms") + c.Set("PgNotifier.jobOrphanTimeout", "120s") + + groupCtx, groupCancel := context.WithCancel(ctx) + g, gCtx := errgroup.WithContext(groupCtx) + + n := notifier.New(c, logger.NOP, stats.Default, workspaceIdentifier) + err := n.Setup(groupCtx, pgResource.DBDsn) + require.NoError(t, err) + + publishResponses := make(chan *notifier.PublishResponse) + + for i := 0; i < jobs; i++ { + g.Go(func() error { + publishCh, err := n.Publish(gCtx, publishRequest) + require.NoError(t, err) + + publishResponses <- <-publishCh + + return nil + }) + } + + for i := 0; i < subscribers; i++ { + g.Go(func() error { + subscriberCh := n.Subscribe(gCtx, workerID, subscriberWorkers) + + slaveGroup, slaveCtx := errgroup.WithContext(gCtx) + + for j := 0; j < subscriberWorkers; j++ { + slaveGroup.Go(func() error { + for job := range subscriberCh { + n.UpdateClaim(slaveCtx, job, ¬ifier.ClaimJobResponse{ + Payload: json.RawMessage(`{"test": "payload"}`), + }) + } + return nil + }) + } + return slaveGroup.Wait() + }) + } + g.Go(func() error { + for i := 0; i < jobs; i++ { + response := <-publishResponses + require.NoError(t, response.Err) + require.Len(t, response.Jobs, len(publishRequest.Payloads)) + } + close(publishResponses) + groupCancel() + return nil + }) + g.Go(func() error { + return n.RunMaintenance(gCtx) + }) + g.Go(func() error { + <-groupCtx.Done() + return n.Shutdown() + }) + require.NoError(t, g.Wait()) + }) + t.Run("bigger batches and many subscribers", func(t *testing.T) { + t.Parallel() + + pgResource := 
setup(t) + ctx := context.Background() + + const ( + batchSize = 500 + jobs = 1 + subscribers = 100 + subscriberWorkers = 4 + ) + + var payloads []json.RawMessage + for i := 0; i < batchSize; i++ { + payloads = append(payloads, json.RawMessage(fmt.Sprintf(`{"id": "%d"}`, i))) + } + + publishRequest := ¬ifier.PublishRequest{ + Payloads: payloads, + JobType: notifier.JobTypeUpload, + UploadSchema: json.RawMessage(`{"UploadSchema": "1"}`), + Priority: 50, + } + + c := config.New() + c.Set("PgNotifier.maxAttempt", 1) + c.Set("PgNotifier.maxOpenConnections", 900) + c.Set("PgNotifier.trackBatchIntervalInS", "1s") + c.Set("PgNotifier.maxPollSleep", "100ms") + c.Set("PgNotifier.jobOrphanTimeout", "120s") + + groupCtx, groupCancel := context.WithCancel(ctx) + g, gCtx := errgroup.WithContext(groupCtx) + + n := notifier.New(c, logger.NOP, stats.Default, workspaceIdentifier) + err := n.Setup(groupCtx, pgResource.DBDsn) + require.NoError(t, err) + + publishResponses := make(chan *notifier.PublishResponse) + + for i := 0; i < jobs; i++ { + g.Go(func() error { + publishCh, err := n.Publish(gCtx, publishRequest) + require.NoError(t, err) + + publishResponses <- <-publishCh + + return nil + }) + } + + for i := 0; i < subscribers; i++ { + g.Go(func() error { + subscriberCh := n.Subscribe(gCtx, workerID, subscriberWorkers) + + slaveGroup, slaveCtx := errgroup.WithContext(gCtx) + + for j := 0; j < subscriberWorkers; j++ { + slaveGroup.Go(func() error { + for job := range subscriberCh { + n.UpdateClaim(slaveCtx, job, ¬ifier.ClaimJobResponse{ + Payload: json.RawMessage(`{"test": "payload"}`), + }) + } + return nil + }) + } + return slaveGroup.Wait() + }) + } + g.Go(func() error { + for i := 0; i < jobs; i++ { + response := <-publishResponses + require.NoError(t, response.Err) + require.Len(t, response.Jobs, len(publishRequest.Payloads)) + } + close(publishResponses) + groupCancel() + return nil + }) + g.Go(func() error { + return n.RunMaintenance(gCtx) + }) + g.Go(func() error { + <-groupCtx.Done() + return n.Shutdown() + }) + require.NoError(t, g.Wait()) + }) + t.Run("round robin pickup and maintenance workers", func(t *testing.T) { + t.Parallel() + + pgResource := setup(t) + ctx := context.Background() + + const ( + batchSize = 1 + jobs = 1 + subscribers = 1 + subscriberWorkers = 4 + ) + + var payloads []json.RawMessage + for i := 0; i < batchSize; i++ { + payloads = append(payloads, json.RawMessage(fmt.Sprintf(`{"id": "%d"}`, i))) + } + + publishRequest := ¬ifier.PublishRequest{ + Payloads: payloads, + JobType: notifier.JobTypeUpload, + UploadSchema: json.RawMessage(`{"UploadSchema": "1"}`), + Priority: 50, + } + + c := config.New() + c.Set("PgNotifier.maxAttempt", 1) + c.Set("PgNotifier.maxOpenConnections", 900) + c.Set("PgNotifier.trackBatchIntervalInS", "1s") + c.Set("PgNotifier.maxPollSleep", "100ms") + c.Set("PgNotifier.jobOrphanTimeout", "3s") + + groupCtx, groupCancel := context.WithCancel(ctx) + g, gCtx := errgroup.WithContext(groupCtx) + + statsStore := memstats.New() + + n := notifier.New(c, logger.NOP, statsStore, workspaceIdentifier) + err := n.Setup(groupCtx, pgResource.DBDsn) + require.NoError(t, err) + + publishResponses := make(chan *notifier.PublishResponse) + + claimedWorkers := atomic.NewInt64(0) + + for i := 0; i < jobs; i++ { + g.Go(func() error { + publishCh, err := n.Publish(gCtx, publishRequest) + require.NoError(t, err) + + publishResponses <- <-publishCh + + return nil + }) + } + + for i := 0; i < subscribers; i++ { + g.Go(func() error { + subscriberCh := n.Subscribe(gCtx, workerID, 
subscriberWorkers) + + slaveGroup, slaveCtx := errgroup.WithContext(gCtx) + + blockSub := make(chan struct{}) + defer close(blockSub) + + for i := 0; i < subscriberWorkers; i++ { + slaveGroup.Go(func() error { + for job := range subscriberCh { + claimedWorkers.Add(1) + + if claimedWorkers.Load() < subscriberWorkers { + select { + case <-blockSub: + case <-slaveCtx.Done(): + return nil + } + } + + n.UpdateClaim(slaveCtx, job, ¬ifier.ClaimJobResponse{ + Payload: json.RawMessage(`{"test": "payload"}`), + }) + } + return nil + }) + } + return slaveGroup.Wait() + }) + } + g.Go(func() error { + for i := 0; i < jobs; i++ { + response := <-publishResponses + require.NoError(t, response.Err) + require.Len(t, response.Jobs, len(publishRequest.Payloads)) + } + close(publishResponses) + groupCancel() + return nil + }) + g.Go(func() error { + return n.RunMaintenance(gCtx) + }) + g.Go(func() error { + <-groupCtx.Done() + return n.Shutdown() + }) + require.NoError(t, g.Wait()) + }) + t.Run("env vars", func(t *testing.T) { + t.Parallel() + + pgResource := setup(t) + ctx := context.Background() + + c := config.New() + c.Set("PGNOTIFIER_DB_HOST", pgResource.Host) + c.Set("PGNOTIFIER_DB_USER", pgResource.User) + c.Set("PGNOTIFIER_DB_NAME", pgResource.Database) + c.Set("PGNOTIFIER_DB_PORT", pgResource.Port) + c.Set("PGNOTIFIER_DB_PASSWORD", pgResource.Password) + + n := notifier.New(c, logger.NOP, stats.Default, workspaceIdentifier) + err := n.Setup(ctx, pgResource.DBDsn) + require.NoError(t, err) + }) + t.Run("maintenance workflow", func(t *testing.T) { + t.Parallel() + + pgResource := setup(t) + + ctx := context.Background() + + c := config.New() + c.Set("PgNotifier.jobOrphanTimeout", "1s") + + g, _ := errgroup.WithContext(ctx) + g.Go(func() error { + ctxWithTimeout, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + n := notifier.New(c, logger.NOP, stats.Default, workspaceIdentifier) + err := n.Setup(ctx, pgResource.DBDsn) + require.NoError(t, err) + + err = n.RunMaintenance(ctxWithTimeout) + require.NoError(t, err) + return nil + }) + g.Go(func() error { + ctxWithTimeout, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + n := notifier.New(c, logger.NOP, stats.Default, workspaceIdentifier) + err := n.Setup(ctx, pgResource.DBDsn) + require.NoError(t, err) + + err = n.RunMaintenance(ctxWithTimeout) + require.NoError(t, err) + return nil + }) + require.NoError(t, g.Wait()) + }) +} diff --git a/services/notifier/repo.go b/services/notifier/repo.go new file mode 100644 index 0000000000..97cb5efaa1 --- /dev/null +++ b/services/notifier/repo.go @@ -0,0 +1,471 @@ +package notifier + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/lib/pq" + + "github.com/rudderlabs/rudder-server/utils/timeutil" + sqlmw "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" +) + +const ( + notifierTableName = "pg_notifier_queue" + notifierTableColumns = ` + id, + batch_id, + worker_id, + workspace, + attempt, + status, + job_type, + priority, + error, + payload, + created_at, + updated_at, + last_exec_time +` +) + +type Opt func(*repo) + +type scanFn func(dest ...any) error + +func WithNow(now func() time.Time) Opt { + return func(r *repo) { + r.now = now + } +} + +type repo struct { + db *sqlmw.DB + now func() time.Time +} + +func newRepo(db *sqlmw.DB, opts ...Opt) *repo { + r := &repo{ + db: db, + now: timeutil.Now, + } + for _, opt := range opts { + opt(r) + } + return r +} + +// ResetForWorkspace 
deletes all the jobs for a specified workspace. +func (n *repo) resetForWorkspace( + ctx context.Context, + workspaceIdentifier string, +) error { + _, err := n.db.ExecContext(ctx, ` + DELETE FROM `+notifierTableName+` + WHERE workspace = $1; + `, + workspaceIdentifier, + ) + if err != nil { + return fmt.Errorf("reset: delete for workspace %s: %w", workspaceIdentifier, err) + } + return nil +} + +// insert inserts a batch of jobs into the notifier queue. +func (n *repo) insert( + ctx context.Context, + publishRequest *PublishRequest, + workspaceIdentifier string, + batchID string, +) error { + txn, err := n.db.BeginTx(ctx, &sql.TxOptions{}) + if err != nil { + return fmt.Errorf("inserting: begin transaction: %w", err) + } + defer func() { + if err != nil { + _ = txn.Rollback() + return + } + }() + + now := n.now() + + stmt, err := txn.PrepareContext( + ctx, + pq.CopyIn( + notifierTableName, + "batch_id", + "status", + "payload", + "workspace", + "priority", + "job_type", + "created_at", + "updated_at", + ), + ) + if err != nil { + return fmt.Errorf(`inserting: CopyIn: %w`, err) + } + defer func() { _ = stmt.Close() }() + + for _, payload := range publishRequest.Payloads { + _, err = stmt.ExecContext( + ctx, + batchID, + Waiting, + string(payload), + workspaceIdentifier, + publishRequest.Priority, + publishRequest.JobType, + now.UTC(), + now.UTC(), + ) + if err != nil { + return fmt.Errorf(`inserting: CopyIn exec: %w`, err) + } + } + if _, err = stmt.ExecContext(ctx); err != nil { + return fmt.Errorf(`inserting: CopyIn final exec: %w`, err) + } + + if publishRequest.UploadSchema != nil { + _, err = txn.ExecContext(ctx, ` + UPDATE + `+notifierTableName+` + SET + payload = payload || $1 + WHERE + batch_id = $2; + `, + publishRequest.UploadSchema, + batchID, + ) + if err != nil { + return fmt.Errorf(`updating: metadata: %w`, err) + } + } + + if err = txn.Commit(); err != nil { + return fmt.Errorf(`inserting: commit: %w`, err) + } + return nil +} + +// pendingByBatchID returns the number of pending jobs for a batchID. +func (n *repo) pendingByBatchID( + ctx context.Context, + batchID string, +) (int64, error) { + var count int64 + + err := n.db.QueryRowContext(ctx, ` + SELECT + COUNT(*) + FROM + `+notifierTableName+` + WHERE + batch_id = $1 AND + status != $2 AND + status != $3 +`, + batchID, + Succeeded, + Aborted, + ).Scan(&count) + if err != nil { + return 0, fmt.Errorf("pending by batchID: %w", err) + } + + return count, nil +} + +// getByBatchID returns all the jobs for a batchID. +// TODO: ATM hack to remove `UploadSchema` from the payload so that the behaviour matches the old notifier. 
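+// The jsonb expression payload - 'UploadSchema' in the query below strips the merged schema key from each returned payload.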
+func (n *repo) getByBatchID( + ctx context.Context, + batchID string, +) ([]Job, error) { + query := ` + SELECT + id, + batch_id, + worker_id, + workspace, + attempt, + status, + job_type, + priority, + error, + payload - 'UploadSchema', + created_at, + updated_at, + last_exec_time + FROM + ` + notifierTableName + ` + WHERE + batch_id = $1 + ORDER BY + id;` + + rows, err := n.db.QueryContext(ctx, query, batchID) + if err != nil { + return nil, fmt.Errorf("getting by batchID: %w", err) + } + defer func() { _ = rows.Close() }() + + var jobs []Job + for rows.Next() { + var job Job + err := scanJob(rows.Scan, &job) + if err != nil { + return nil, fmt.Errorf("getting by batchID: scan: %w", err) + } + + jobs = append(jobs, job) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("getting by batchID: rows err: %w", err) + } + if len(jobs) == 0 { + return nil, fmt.Errorf("getting by batchID: no jobs found") + } + + return jobs, nil +} + +func scanJob(scan scanFn, job *Job) error { + var ( + jobTypeRaw sql.NullString + errorRaw sql.NullString + workerIDRaw sql.NullString + lastExecTime sql.NullTime + ) + + err := scan( + &job.ID, + &job.BatchID, + &workerIDRaw, + &job.WorkspaceIdentifier, + &job.Attempt, + &job.Status, + &jobTypeRaw, + &job.Priority, + &errorRaw, + &job.Payload, + &job.CreatedAt, + &job.UpdatedAt, + &lastExecTime, + ) + if err != nil { + return fmt.Errorf("scanning: %w", err) + } + + if workerIDRaw.Valid { + job.WorkerID = workerIDRaw.String + } + if jobTypeRaw.Valid { + switch jobTypeRaw.String { + case string(JobTypeUpload), string(JobTypeAsync): + job.Type = JobType(jobTypeRaw.String) + default: + return fmt.Errorf("scanning: unknown job type: %s", jobTypeRaw.String) + } + } else { + job.Type = JobTypeUpload + } + if errorRaw.Valid { + job.Error = errors.New(errorRaw.String) + } + if lastExecTime.Valid { + job.LastExecTime = lastExecTime.Time + } + + return nil +} + +// deleteByBatchID deletes all the jobs for a batchID. +func (n *repo) deleteByBatchID( + ctx context.Context, + batchID string, +) error { + _, err := n.db.ExecContext(ctx, ` + DELETE FROM `+notifierTableName+` WHERE batch_id = $1; + `, + batchID, + ) + if err != nil { + return fmt.Errorf("deleting by batchID: %w", err) + } + return nil +} + +func (n *repo) claim( + ctx context.Context, + workerID string, +) (*Job, error) { + row := n.db.QueryRowContext(ctx, ` + UPDATE + `+notifierTableName+` + SET + status = $1, + updated_at = $2, + last_exec_time = $2, + worker_id = $3 + WHERE + id = ( + SELECT + id + FROM + `+notifierTableName+` + WHERE + (status = $4 OR status = $5) + ORDER BY + priority ASC, + id ASC + FOR + UPDATE + SKIP LOCKED + LIMIT + 1 + ) RETURNING `+notifierTableColumns+`; +`, + Executing, + n.now(), + workerID, + Waiting, + Failed, + ) + + var job Job + err := scanJob(row.Scan, &job) + if err != nil { + return nil, fmt.Errorf("claim for workerID %s: scan: %w", workerID, err) + } + return &job, nil +} + +// onClaimFailed updates the status of a job to failed. 
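+// Once a job's attempt count exceeds maxAttempt, the job is marked as aborted instead.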
+func (n *repo) onClaimFailed( + ctx context.Context, + job *Job, + claimError error, + maxAttempt int, +) error { + query := fmt.Sprint(` + UPDATE + ` + notifierTableName + ` + SET + status =( + CASE WHEN attempt > $1 THEN CAST ( + '` + Aborted + `' AS pg_notifier_status_type + ) ELSE CAST( + '` + Failed + `' AS pg_notifier_status_type + ) END + ), + attempt = attempt + 1, + updated_at = $2, + error = $3 + WHERE + id = $4; + `, + ) + + _, err := n.db.ExecContext(ctx, + query, + maxAttempt, + n.now(), + claimError.Error(), + job.ID, + ) + if err != nil { + return fmt.Errorf("on claim failed: %w", err) + } + + return nil +} + +// onClaimSuccess updates the status of a job to succeeded. +func (n *repo) onClaimSuccess( + ctx context.Context, + job *Job, + payload json.RawMessage, +) error { + _, err := n.db.ExecContext(ctx, ` + UPDATE + `+notifierTableName+` + SET + status = $1, + updated_at = $2, + payload = $3 + WHERE + id = $4; + `, + Succeeded, + n.now(), + string(payload), + job.ID, + ) + if err != nil { + return fmt.Errorf("on claim success: %w", err) + } + + return nil +} + +// orphanJobIDs returns the IDs of the jobs that have been in the executing state for longer than the given interval. +func (n *repo) orphanJobIDs( + ctx context.Context, + intervalInSeconds int, +) ([]int64, error) { + rows, err := n.db.QueryContext(ctx, ` + UPDATE + `+notifierTableName+` + SET + status = $1, + updated_at = $2 + WHERE + id IN ( + SELECT + id + FROM + `+notifierTableName+` + WHERE + status = $3 + AND last_exec_time <= NOW() - $4 * INTERVAL '1 SECOND' + FOR + UPDATE + SKIP LOCKED + ) RETURNING id; +`, + Waiting, + n.now(), + Executing, + intervalInSeconds, + ) + if err != nil { + return nil, fmt.Errorf("orphan jobs ids: %w", err) + } + defer func() { _ = rows.Close() }() + + var ids []int64 + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + return nil, fmt.Errorf("orphan jobs ids: scanning: %w", err) + } + + ids = append(ids, id) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("orphan jobs ids: iterating over rows: %w", err) + } + + return ids, nil +} diff --git a/services/notifier/repo_test.go b/services/notifier/repo_test.go new file mode 100644 index 0000000000..f50c54eb49 --- /dev/null +++ b/services/notifier/repo_test.go @@ -0,0 +1,499 @@ +package notifier + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "strconv" + "testing" + "time" + + "github.com/google/uuid" + "github.com/samber/lo" + + "github.com/ory/dockertest/v3" + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource" + migrator "github.com/rudderlabs/rudder-server/services/sql-migrator" + sqlmw "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" +) + +func TestRepo(t *testing.T) { + const ( + workspaceIdentifier = "test_workspace_identifier" + workerID = "test_worker" + ) + + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + pgResource, err := resource.SetupPostgres(pool, t) + require.NoError(t, err) + + err = (&migrator.Migrator{ + Handle: pgResource.DB, + MigrationsTable: "pg_notifier_queue_migrations", + }).Migrate("pg_notifier_queue") + require.NoError(t, err) + + t.Log("db:", pgResource.DBDsn) + + ctx := context.Background() + now := time.Now().Truncate(time.Second).UTC() + + db := sqlmw.New(pgResource.DB) + + r := newRepo(db, WithNow(func() time.Time { + return now + })) + + publishRequest := PublishRequest{ + Payloads: []json.RawMessage{ + 
json.RawMessage(`{"id":"1"}`), + json.RawMessage(`{"id":"2"}`), + json.RawMessage(`{"id":"3"}`), + json.RawMessage(`{"id":"4"}`), + json.RawMessage(`{"id":"5"}`), + }, + JobType: JobTypeUpload, + UploadSchema: json.RawMessage(`{"UploadSchema":"1"}`), + Priority: 50, + } + + cancelledCtx, cancel := context.WithCancel(ctx) + cancel() + + t.Run("Insert and get", func(t *testing.T) { + t.Run("create jobs", func(t *testing.T) { + batchID := uuid.New().String() + + err := r.insert(ctx, &publishRequest, workspaceIdentifier, batchID) + require.NoError(t, err) + + jobs, err := r.getByBatchID(ctx, batchID) + require.NoError(t, err) + require.Len(t, jobs, len(publishRequest.Payloads)) + + for i, job := range jobs { + require.EqualValues(t, job.Payload, json.RawMessage(fmt.Sprintf(`{"id": "%d"}`, i+1))) + require.EqualValues(t, job.WorkspaceIdentifier, workspaceIdentifier) + require.EqualValues(t, job.BatchID, batchID) + require.EqualValues(t, job.Type, publishRequest.JobType) + require.EqualValues(t, job.Priority, publishRequest.Priority) + require.EqualValues(t, job.Status, Waiting) + require.EqualValues(t, job.WorkerID, "") + require.EqualValues(t, job.Attempt, 0) + require.EqualValues(t, job.CreatedAt.UTC(), now.UTC()) + require.EqualValues(t, job.UpdatedAt.UTC(), now.UTC()) + require.Nil(t, job.Error) + } + }) + + t.Run("missing batch id", func(t *testing.T) { + jobs, err := r.getByBatchID(ctx, "missing_batch_id") + require.EqualError(t, err, "getting by batchID: no jobs found") + require.Nil(t, jobs) + }) + + t.Run("context cancelled", func(t *testing.T) { + batchID := uuid.New().String() + + err := r.insert(cancelledCtx, &publishRequest, workspaceIdentifier, batchID) + require.ErrorIs(t, err, context.Canceled) + + jobs, err := r.getByBatchID(cancelledCtx, batchID) + require.ErrorIs(t, err, context.Canceled) + require.Nil(t, jobs) + }) + }) + + t.Run("delete by batch id", func(t *testing.T) { + t.Run("success", func(t *testing.T) { + batchID := uuid.New().String() + + err := r.insert(ctx, &publishRequest, workspaceIdentifier, batchID) + require.NoError(t, err) + + jobs, err := r.getByBatchID(ctx, batchID) + require.NoError(t, err) + require.Len(t, jobs, len(publishRequest.Payloads)) + + err = r.deleteByBatchID(ctx, batchID) + require.NoError(t, err) + + jobs, err = r.getByBatchID(ctx, batchID) + require.EqualError(t, err, "getting by batchID: no jobs found") + require.Nil(t, jobs) + }) + + t.Run("context cancelled", func(t *testing.T) { + batchID := uuid.New().String() + + err := r.insert(cancelledCtx, &publishRequest, workspaceIdentifier, batchID) + require.ErrorIs(t, err, context.Canceled) + + err = r.deleteByBatchID(cancelledCtx, batchID) + require.ErrorIs(t, err, context.Canceled) + }) + }) + + t.Run("reset workspace", func(t *testing.T) { + t.Run("success", func(t *testing.T) { + var workspaceIdentifiers []string + var batchIDs []string + + for i := 0; i < 10; i++ { + batchID := uuid.New().String() + workspaceIdentifier := workspaceIdentifier + "_" + uuid.New().String() + workspaceIdentifiers = append(workspaceIdentifiers, workspaceIdentifier) + batchIDs = append(batchIDs, batchID) + + err := r.insert(ctx, &publishRequest, workspaceIdentifier, batchID) + require.NoError(t, err) + } + for _, batchID := range batchIDs { + jobs, err := r.getByBatchID(ctx, batchID) + require.NoError(t, err) + require.Len(t, jobs, len(publishRequest.Payloads)) + } + + for i, workspaceIdentifier := range workspaceIdentifiers { + if i%4 == 0 { + continue + } + + err = r.resetForWorkspace(ctx, workspaceIdentifier) + require.NoError(t, err) + } + + for 
i, batchID := range batchIDs { + if i%4 == 0 { + // workspaces skipped by the reset loop above keep their jobs + jobs, err := r.getByBatchID(ctx, batchID) + require.NoError(t, err) + require.Len(t, jobs, len(publishRequest.Payloads)) + continue + } + + jobs, err := r.getByBatchID(ctx, batchID) + require.EqualError(t, err, "getting by batchID: no jobs found") + require.Nil(t, jobs) + } + }) + + t.Run("without upload schema", func(t *testing.T) { + publishRequest := PublishRequest{ + Payloads: []json.RawMessage{ + json.RawMessage(`{"id":"11"}`), + }, + JobType: JobTypeUpload, + Priority: 75, + } + + batchID := uuid.New().String() + err := r.insert(ctx, &publishRequest, workspaceIdentifier, batchID) + require.NoError(t, err) + + err = r.resetForWorkspace(ctx, workspaceIdentifier) + require.NoError(t, err) + + jobs, err := r.getByBatchID(ctx, batchID) + require.EqualError(t, err, "getting by batchID: no jobs found") + require.Nil(t, jobs) + }) + + t.Run("context cancelled", func(t *testing.T) { + batchID := uuid.New().String() + + err := r.insert(cancelledCtx, &publishRequest, workspaceIdentifier, batchID) + require.ErrorIs(t, err, context.Canceled) + + err = r.resetForWorkspace(cancelledCtx, batchID) + require.ErrorIs(t, err, context.Canceled) + }) + + t.Run("empty", func(t *testing.T) { + err := r.resetForWorkspace(ctx, "missing_batch_id") + require.NoError(t, err) + }) + }) + + t.Run("pending by batch id", func(t *testing.T) { + t.Run("success", func(t *testing.T) { + batchID := uuid.New().String() + + err := r.insert(ctx, &publishRequest, workspaceIdentifier, batchID) + require.NoError(t, err) + + jobs, err := r.getByBatchID(ctx, batchID) + require.NoError(t, err) + require.Len(t, jobs, len(publishRequest.Payloads)) + + err = r.onClaimSuccess(ctx, &jobs[0], json.RawMessage(`{"test": "payload"}`)) + require.NoError(t, err) + err = r.onClaimFailed(ctx, &jobs[1], errors.New("test error"), 100) + require.NoError(t, err) + err = r.onClaimFailed(ctx, &jobs[2], errors.New("test error"), -1) + require.NoError(t, err) + + pendingCount, err := r.pendingByBatchID(ctx, batchID) + require.NoError(t, err) + require.EqualValues(t, pendingCount, 3) + }) + + t.Run("context cancelled", func(t *testing.T) { + batchID := uuid.New().String() + + err := r.insert(cancelledCtx, &publishRequest, workspaceIdentifier, batchID) + require.ErrorIs(t, err, context.Canceled) + + pendingCount, err := r.pendingByBatchID(cancelledCtx, batchID) + require.ErrorIs(t, err, context.Canceled) + require.Zero(t, pendingCount) + }) + + t.Run("no pending", func(t *testing.T) { + batchID := uuid.New().String() + + err := r.insert(ctx, &publishRequest, workspaceIdentifier, batchID) + require.NoError(t, err) + + jobs, err := r.getByBatchID(ctx, batchID) + require.NoError(t, err) + + for _, job := range jobs { + require.NoError(t, r.onClaimSuccess(ctx, &job, json.RawMessage(`{"test": "payload"}`))) + } + + pendingCount, err := r.pendingByBatchID(ctx, batchID) + require.NoError(t, err) + require.Zero(t, pendingCount) + }) + }) + + t.Run("orphan job ids", func(t *testing.T) { + t.Run("success", func(t *testing.T) { + batchID := uuid.New().String() + orphanInterval := 5 + + err := r.insert(ctx, &publishRequest, workspaceIdentifier, batchID) + require.NoError(t, err) + + jobs, err := r.getByBatchID(ctx, batchID) + require.NoError(t, err) + + t.Run("no orphans", func(t *testing.T) { + jobIDs, err := r.orphanJobIDs(ctx, orphanInterval) + require.NoError(t, err) + require.Len(t, jobIDs, 0) + }) + + t.Run("few orphans", func(t *testing.T) { + for _, job := range jobs[:3] { + _, err := db.ExecContext(ctx, ` + 
UPDATE + pg_notifier_queue + SET + status = 'executing', + last_exec_time = NOW() - $1 * INTERVAL '1 SECOND' + WHERE + id = $2;`, + 2*orphanInterval, + job.ID, + ) + require.NoError(t, err) + } + + jobIDs, err := r.orphanJobIDs(ctx, orphanInterval) + require.NoError(t, err) + require.Len(t, jobIDs, len(jobs[:3])) + for _, job := range jobs[:3] { + require.Contains(t, jobIDs, job.ID) + } + }) + }) + + t.Run("context cancelled", func(t *testing.T) { + batchID := uuid.New().String() + + err := r.insert(cancelledCtx, &publishRequest, workspaceIdentifier, batchID) + require.ErrorIs(t, err, context.Canceled) + + jobIDs, err := r.orphanJobIDs(cancelledCtx, 0) + require.ErrorIs(t, err, context.Canceled) + require.Nil(t, jobIDs) + }) + }) + + t.Run("claim", func(t *testing.T) { + uNow := now.Add(time.Second * 10).Truncate(time.Second).UTC() + ur := newRepo(db, WithNow(func() time.Time { + return uNow + })) + + t.Run("success", func(t *testing.T) { + _, err := db.ExecContext(ctx, "TRUNCATE TABLE pg_notifier_queue;") + require.NoError(t, err) + + batchID := uuid.New().String() + + err = r.insert(ctx, &publishRequest, workspaceIdentifier, batchID) + require.NoError(t, err) + + t.Run("with jobs", func(t *testing.T) { + jobs, err := r.getByBatchID(ctx, batchID) + require.NoError(t, err) + + for i, job := range jobs { + claimedJob, err := ur.claim(ctx, workerID+strconv.Itoa(i)) + require.NoError(t, err) + require.EqualValues(t, claimedJob.ID, job.ID) + require.EqualValues(t, claimedJob.BatchID, job.BatchID) + require.EqualValues(t, claimedJob.WorkerID, workerID+strconv.Itoa(i)) + require.EqualValues(t, claimedJob.WorkspaceIdentifier, job.WorkspaceIdentifier) + require.EqualValues(t, claimedJob.Status, Executing) + require.EqualValues(t, claimedJob.Type, job.Type) + require.EqualValues(t, claimedJob.Priority, job.Priority) + require.EqualValues(t, claimedJob.Attempt, job.Attempt) + require.EqualValues(t, claimedJob.Error, job.Error) + require.EqualValues(t, claimedJob.Payload, json.RawMessage(fmt.Sprintf(`{"id": "%d", "UploadSchema": "1"}`, i+1))) + require.EqualValues(t, claimedJob.CreatedAt.UTC(), job.CreatedAt.UTC()) + require.EqualValues(t, claimedJob.UpdatedAt.UTC(), uNow.UTC()) + require.EqualValues(t, claimedJob.LastExecTime.UTC(), uNow.UTC()) + } + }) + + t.Run("no jobs", func(t *testing.T) { + claimedJob, err := ur.claim(ctx, workerID) + require.ErrorIs(t, err, sql.ErrNoRows) + require.Nil(t, claimedJob) + }) + }) + + t.Run("context cancelled", func(t *testing.T) { + batchID := uuid.New().String() + + err := r.insert(cancelledCtx, &publishRequest, workspaceIdentifier, batchID) + require.ErrorIs(t, err, context.Canceled) + + claimedJob, err := ur.claim(cancelledCtx, workerID) + require.ErrorIs(t, err, context.Canceled) + require.Nil(t, claimedJob) + }) + }) + + t.Run("claim success", func(t *testing.T) { + uNow := now.Add(time.Second * 10).Truncate(time.Second).UTC() + ur := newRepo(db, WithNow(func() time.Time { + return uNow + })) + + t.Run("success", func(t *testing.T) { + batchID := uuid.New().String() + payload := json.RawMessage(`{"test": "payload"}`) + + err := r.insert(ctx, &publishRequest, workspaceIdentifier, batchID) + require.NoError(t, err) + + jobs, err := r.getByBatchID(ctx, batchID) + require.NoError(t, err) + + for _, job := range jobs { + require.NoError(t, ur.onClaimSuccess(ctx, &job, json.RawMessage(`{"test": "payload"}`))) + } + + successClaims, err := ur.getByBatchID(ctx, batchID) + require.NoError(t, err) + for _, job := range successClaims { + require.EqualValues(t, 
job.UpdatedAt.UTC(), uNow.UTC()) + require.EqualValues(t, job.Status, Succeeded) + require.EqualValues(t, job.Payload, payload) + require.Nil(t, job.Error) + } + }) + + t.Run("context cancelled", func(t *testing.T) { + batchID := uuid.New().String() + + err := r.insert(cancelledCtx, &publishRequest, workspaceIdentifier, batchID) + require.ErrorIs(t, err, context.Canceled) + + err = ur.onClaimSuccess(cancelledCtx, &Job{ID: 1}, nil) + require.ErrorIs(t, err, context.Canceled) + }) + }) + + t.Run("claim failure", func(t *testing.T) { + uNow := now.Add(time.Second * 10).Truncate(time.Second).UTC() + ur := newRepo(db, WithNow(func() time.Time { + return uNow + })) + + t.Run("first failed and then succeeded", func(t *testing.T) { + batchID := uuid.New().String() + + err := r.insert(ctx, &publishRequest, workspaceIdentifier, batchID) + require.NoError(t, err) + + jobs, err := r.getByBatchID(ctx, batchID) + require.NoError(t, err) + + t.Run("marking failed", func(t *testing.T) { + for i, job := range jobs { + for j := 0; j < i+1; j++ { + require.NoError(t, ur.onClaimFailed(ctx, &job, errors.New("test_error"), 2)) + } + } + + failedClaims, err := ur.getByBatchID(ctx, batchID) + require.NoError(t, err) + require.Equal(t, []JobStatus{ + Failed, + Failed, + Failed, + Aborted, + Aborted, + }, + lo.Map(failedClaims, func(item Job, index int) JobStatus { + return item.Status + }), + ) + + for i, job := range failedClaims { + require.EqualValues(t, job.Error, errors.New("test_error")) + require.EqualValues(t, job.Attempt, i+1) + require.EqualValues(t, job.UpdatedAt.UTC(), uNow.UTC()) + } + }) + + t.Run("marking succeeded", func(t *testing.T) { + failedClaims, err := ur.getByBatchID(ctx, batchID) + require.NoError(t, err) + + for _, job := range failedClaims { + require.NoError(t, ur.onClaimSuccess(ctx, &job, json.RawMessage(`{"test": "payload"}`))) + } + + successClaims, err := ur.getByBatchID(ctx, batchID) + require.NoError(t, err) + for i, job := range successClaims { + require.EqualValues(t, job.UpdatedAt.UTC(), uNow.UTC()) + require.EqualValues(t, job.Status, Succeeded) + require.EqualValues(t, job.Attempt, i+1) + require.EqualValues(t, job.Error, errors.New("test_error")) + } + }) + }) + + t.Run("context cancelled", func(t *testing.T) { + batchID := uuid.New().String() + + err := r.insert(cancelledCtx, &publishRequest, workspaceIdentifier, batchID) + require.ErrorIs(t, err, context.Canceled) + + err = ur.onClaimFailed(cancelledCtx, &Job{ID: 1}, errors.New("test_error"), 0) + require.ErrorIs(t, err, context.Canceled) + }) + }) +} diff --git a/services/pgnotifier/pgnotifier.go b/services/pgnotifier/pgnotifier.go deleted file mode 100644 index feb76b1576..0000000000 --- a/services/pgnotifier/pgnotifier.go +++ /dev/null @@ -1,738 +0,0 @@ -package pgnotifier - -import ( - "context" - "database/sql" - "encoding/json" - "fmt" - "math/rand" - "time" - - "github.com/allisson/go-pglock/v2" - "github.com/lib/pq" - "github.com/spaolacci/murmur3" - - "github.com/rudderlabs/rudder-go-kit/config" - "github.com/rudderlabs/rudder-go-kit/logger" - "github.com/rudderlabs/rudder-go-kit/stats" - "github.com/rudderlabs/rudder-server/rruntime" - migrator "github.com/rudderlabs/rudder-server/services/sql-migrator" - "github.com/rudderlabs/rudder-server/utils/misc" - sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" - whUtils "github.com/rudderlabs/rudder-server/warehouse/utils" -) - -var ( - queueName string - maxAttempt int - trackBatchInterval time.Duration - 
maxPollSleep misc.ValueLoader[time.Duration] - jobOrphanTimeout misc.ValueLoader[time.Duration] - pkgLogger logger.Logger -) - -var ( - pgNotifierDBHost, pgNotifierDBUser, pgNotifierDBPassword, pgNotifierDBName, pgNotifierDBSSLMode string - pgNotifierDBPort int - pgNotifierPublish, pgNotifierPublishTime stats.Measurement - pgNotifierClaimSucceeded, pgNotifierClaimSucceededTime, pgNotifierClaimFailed, pgNotifierClaimFailedTime stats.Measurement - pgNotifierClaimUpdateFailed stats.Measurement -) - -const ( - WaitingState = "waiting" - ExecutingState = "executing" - SucceededState = "succeeded" - FailedState = "failed" - AbortedState = "aborted" -) - -const ( - AsyncJobType = "async_job" -) - -func Init() { - loadPGNotifierConfig() - queueName = "pg_notifier_queue" - pkgLogger = logger.NewLogger().Child("warehouse").Child("pgnotifier") -} - -type PGNotifier struct { - URI string - db *sqlmiddleware.DB - workspaceIdentifier string -} - -type JobPayload json.RawMessage - -type Response struct { - JobID int64 - Status string - Output json.RawMessage - Error string - JobType string -} - -type JobsResponse struct { - Status string - Output json.RawMessage - Error string - JobType string - JobRunID string - TaskRunID string -} -type Claim struct { - ID int64 - BatchID string - Status string - Workspace string - Payload json.RawMessage - Attempt int - JobType string -} - -type ClaimResponse struct { - Payload json.RawMessage - Err error -} - -type MessagePayload struct { - Jobs []JobPayload - JobType string -} - -func loadPGNotifierConfig() { - pgNotifierDBHost = config.GetString("PGNOTIFIER_DB_HOST", "localhost") - pgNotifierDBUser = config.GetString("PGNOTIFIER_DB_USER", "ubuntu") - pgNotifierDBName = config.GetString("PGNOTIFIER_DB_NAME", "ubuntu") - pgNotifierDBPort = config.GetInt("PGNOTIFIER_DB_PORT", 5432) - pgNotifierDBPassword = config.GetString("PGNOTIFIER_DB_PASSWORD", "ubuntu") // Reading secrets from - pgNotifierDBSSLMode = config.GetString("PGNOTIFIER_DB_SSL_MODE", "disable") - maxAttempt = config.GetIntVar(3, 1, "PgNotifier.maxAttempt") - trackBatchInterval = time.Duration(config.GetInt("PgNotifier.trackBatchIntervalInS", 2)) * time.Second - maxPollSleep = config.GetReloadableDurationVar(5000, time.Millisecond, "PgNotifier.maxPollSleep") - jobOrphanTimeout = config.GetReloadableDurationVar(120, time.Second, "PgNotifier.jobOrphanTimeout") -} - -// New Given default connection info return pg notifier object from it -func New(workspaceIdentifier, fallbackConnectionInfo string) (notifier PGNotifier, err error) { - // by default connection info is fallback connection info - connectionInfo := fallbackConnectionInfo - - // if PG Notifier variables are defined then use get values provided in env vars - if CheckForPGNotifierEnvVars() { - connectionInfo = GetPGNotifierConnectionString() - } - pkgLogger.Infof("PgNotifier: Initializing PgNotifier...") - dbHandle, err := sql.Open("postgres", connectionInfo) - if err != nil { - return - } - dbHandle.SetMaxOpenConns(config.GetInt("PgNotifier.maxOpenConnections", 20)) - - // setup metrics - pgNotifierModuleTag := whUtils.Tag{Name: "module", Value: "pgnotifier"} - // publish metrics - pgNotifierPublish = whUtils.NewCounterStat("pgnotifier_publish", pgNotifierModuleTag) - pgNotifierPublishTime = whUtils.NewTimerStat("pgnotifier_publish_time", pgNotifierModuleTag) - // claim metrics - pgNotifierClaimSucceeded = whUtils.NewCounterStat("pgnotifier_claim", pgNotifierModuleTag, whUtils.Tag{Name: "status", Value: "succeeded"}) - pgNotifierClaimFailed = 
whUtils.NewCounterStat("pgnotifier_claim", pgNotifierModuleTag, whUtils.Tag{Name: "status", Value: "failed"}) - pgNotifierClaimSucceededTime = whUtils.NewTimerStat("pgnotifier_claim_time", pgNotifierModuleTag, whUtils.Tag{Name: "status", Value: "succeeded"}) - pgNotifierClaimFailedTime = whUtils.NewTimerStat("pgnotifier_claim_time", pgNotifierModuleTag, whUtils.Tag{Name: "status", Value: "failed"}) - pgNotifierClaimUpdateFailed = whUtils.NewCounterStat("pgnotifier_claim_update_failed", pgNotifierModuleTag) - - notifier = PGNotifier{ - db: sqlmiddleware.New( - dbHandle, - sqlmiddleware.WithQueryTimeout(config.GetDuration("Warehouse.pgNotifierQueryTimeout", 5, time.Minute)), - ), - URI: connectionInfo, - workspaceIdentifier: workspaceIdentifier, - } - err = notifier.setupQueue() - return -} - -func (notifier *PGNotifier) GetDBHandle() *sql.DB { - return notifier.db.DB -} - -func (notifier *PGNotifier) ClearJobs(ctx context.Context) (err error) { - // clean up all jobs in pgnotifier for same workspace - // additional safety check to not delete all jobs with empty workspaceIdentifier - if notifier.workspaceIdentifier != "" { - stmt := fmt.Sprintf(` - DELETE FROM - %s - WHERE - workspace = '%s'; -`, - queueName, - notifier.workspaceIdentifier, - ) - pkgLogger.Infof("PgNotifier: Deleting all jobs for workspace: %s", notifier.workspaceIdentifier) - _, err = notifier.db.ExecContext(ctx, stmt) - if err != nil { - return - } - } - - return -} - -// CheckForPGNotifierEnvVars Checks if all the required Env Variables for PG Notifier are present -func CheckForPGNotifierEnvVars() bool { - return config.IsSet("PGNOTIFIER_DB_HOST") && - config.IsSet("PGNOTIFIER_DB_USER") && - config.IsSet("PGNOTIFIER_DB_NAME") && - config.IsSet("PGNOTIFIER_DB_PASSWORD") -} - -// GetPGNotifierConnectionString Returns PG Notifier DB Connection Configuration -func GetPGNotifierConnectionString() string { - pkgLogger.Debugf("WH: All Env variables required for separate PG Notifier are set... 
Check pg notifier says True...") - return fmt.Sprintf("host=%s port=%d user=%s "+ - "password=%s dbname=%s sslmode=%s", - pgNotifierDBHost, pgNotifierDBPort, pgNotifierDBUser, - pgNotifierDBPassword, pgNotifierDBName, pgNotifierDBSSLMode) -} - -// trackUploadBatch tracks the upload batches until they are complete and triggers output through channel of type ResponseT -func (notifier *PGNotifier) trackUploadBatch(ctx context.Context, batchID string, ch *chan []Response) { - rruntime.GoForWarehouse(func() { - for { - time.Sleep(trackBatchInterval) - // keep polling db for batch status - // or subscribe to triggers - stmt := fmt.Sprintf(` - SELECT - count(*) - FROM - %s - WHERE - batch_id = '%s' - AND status != '%s' - AND status != '%s'; -`, - queueName, - batchID, - SucceededState, - AbortedState, - ) - var count int - err := notifier.db.QueryRowContext(ctx, stmt).Scan(&count) - if err != nil { - pkgLogger.Errorf("PgNotifier: Failed to query for tracking jobs by batch_id: %s, connInfo: %s", stmt, notifier.URI) - panic(err) - } - - if count == 0 { - stmt = fmt.Sprintf(` - SELECT - payload -> 'StagingFileID', - payload -> 'Output', - status, - error - FROM - %s - WHERE - batch_id = '%s'; -`, - queueName, - batchID, - ) - rows, err := notifier.db.QueryContext(ctx, stmt) - if err != nil { - panic(err) - } - var responses []Response - for rows.Next() { - var status, jobError, output sql.NullString - var jobID int64 - err = rows.Scan(&jobID, &output, &status, &jobError) - if err != nil { - panic(fmt.Errorf("Failed to scan result from query: %s\nwith Error : %w", stmt, err)) - } - responses = append(responses, Response{ - JobID: jobID, - Output: []byte(output.String), - Status: status.String, - Error: jobError.String, - }) - } - if rows.Err() != nil { - panic(fmt.Errorf("Failed to scan result from query: %s\nwith Error : %w", stmt, rows.Err())) - } - _ = rows.Close() - *ch <- responses - pkgLogger.Infof("PgNotifier: Completed processing all files in batch: %s", batchID) - stmt = fmt.Sprintf(` - DELETE FROM - %s - WHERE - batch_id = '%s'; -`, - queueName, - batchID, - ) - _, err = notifier.db.ExecContext(ctx, stmt) - if err != nil { - pkgLogger.Errorf("PgNotifier: Error deleting from %s for batch_id:%s : %v", queueName, batchID, err) - } - break - } - pkgLogger.Debugf("PgNotifier: Pending %d files to process in batch: %s", count, batchID) - } - }) -} - -// trackAsyncBatch tracks the upload batches until they are complete and triggers output through channel of type ResponseT -func (notifier *PGNotifier) trackAsyncBatch(ctx context.Context, batchID string, ch *chan []Response) { - rruntime.GoForWarehouse(func() { - // retry := 0 - var responses []Response - for { - time.Sleep(trackBatchInterval) - // keep polling db for batch status - // or subscribe to triggers - stmt := fmt.Sprintf(`SELECT count(*) FROM %s WHERE batch_id=$1 AND status!=$2 AND status!=$3`, queueName) - var count int - err := notifier.db.QueryRowContext(ctx, stmt, batchID, SucceededState, AbortedState).Scan(&count) - if err != nil { - *ch <- responses - pkgLogger.Errorf("PgNotifier: Failed to query for tracking jobs by batch_id: %s, connInfo: %s, error : %s", stmt, notifier.URI, err.Error()) - break - } - - if count == 0 { - stmt = fmt.Sprintf(`SELECT payload, status, error FROM %s WHERE batch_id = $1`, queueName) - rows, err := notifier.db.QueryContext(ctx, stmt, batchID) - if err != nil { - *ch <- responses - pkgLogger.Errorf("PgNotifier: Failed to query for getting jobs for payload, status & error: %s, connInfo: %s, error : %s", 
stmt, notifier.URI, err.Error()) - break - } - for rows.Next() { - var status, jobError sql.NullString - var payload json.RawMessage - err = rows.Scan(&payload, &status, &jobError) - if err != nil { - continue - } - responses = append(responses, Response{ - JobID: 0, // Not required for this as there is no concept of BatchFileId - Output: payload, - Status: status.String, - Error: jobError.String, - }) - } - if err := rows.Err(); err != nil { - *ch <- responses - pkgLogger.Errorf("PgNotifier: Failed to query for getting jobs for payload with rows error, status & error: %s, connInfo: %s, error : %v", stmt, notifier.URI, err) - break - } - _ = rows.Close() - *ch <- responses - pkgLogger.Infof("PgNotifier: Completed processing asyncjobs in batch: %s", batchID) - stmt = fmt.Sprintf(`DELETE FROM %s WHERE batch_id = $1`, queueName) - pkgLogger.Infof("Query for deleting pgnotifier rows is %s for batchId : %s in queueName: %s", stmt, batchID, queueName) - _, err = notifier.db.ExecContext(ctx, stmt, batchID) - if err != nil { - pkgLogger.Errorf("PgNotifier: Error deleting from %s for batch_id:%s : %v", queueName, batchID, err) - } - break - } - pkgLogger.Debugf("PgNotifier: Pending %d files to process in batch: %s", count, batchID) - } - }) -} - -func (notifier *PGNotifier) UpdateClaimedEvent(claim *Claim, response *ClaimResponse) { - var err error - if response.Err != nil { - pkgLogger.Error(response.Err.Error()) - stmt := fmt.Sprintf(` - UPDATE - %[1]s - SET - status =( - CASE WHEN attempt > %[2]d THEN CAST ( - '%[3]s' AS pg_notifier_status_type - ) ELSE CAST( - '%[4]s' AS pg_notifier_status_type - ) END - ), - attempt = attempt + 1, - updated_at = '%[5]s', - error = %[6]s - WHERE - id = %[7]v; -`, - queueName, - maxAttempt, - AbortedState, - FailedState, - GetCurrentSQLTimestamp(), - misc.QuoteLiteral(response.Err.Error()), - claim.ID, - ) - _, err = notifier.db.Exec(stmt) - - // Sending stats when we mark pg_notifier status as aborted. 
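For reference, the retry semantics this update encodes (and which the new repo's claim-failure tests above exercise): every failure increments the attempt counter, and a job is only aborted once its attempt count exceeds maxAttempt; otherwise it is re-queued as failed. A minimal sketch of the same update in parameterized form, assuming the pg_notifier_queue schema visible in this diff (the new repo's onClaimFailed may differ in detail):

package notifier

import (
	"context"
	"database/sql"
	"time"
)

// markClaimFailed marks a claimed job as failed, or as aborted once its
// attempt count exceeds maxAttempt. Sketch only: the table and column names
// follow the pg_notifier_queue schema in this diff, and claimErr must be
// non-nil.
func markClaimFailed(ctx context.Context, db *sql.DB, jobID int64, claimErr error, maxAttempt int, now time.Time) error {
	_, err := db.ExecContext(ctx, `
		UPDATE pg_notifier_queue
		SET status = CASE WHEN attempt > $1 THEN 'aborted' ELSE 'failed' END,
		    attempt = attempt + 1,
		    updated_at = $2,
		    error = $3
		WHERE id = $4;`,
		maxAttempt, now, claimErr.Error(), jobID,
	)
	return err
}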
- if claim.Attempt > maxAttempt { - stats.Default.NewTaggedStat("pg_notifier_aborted_records", stats.CountType, map[string]string{ - "queueName": queueName, - "workspace": claim.Workspace, - "module": "pg_notifier", - }).Increment() - } - } else { - stmt := fmt.Sprintf(` - UPDATE - %[1]s - SET - status = '%[2]s', - updated_at = '%[3]s', - payload = $1 - WHERE - id = %[4]v; -`, - queueName, - SucceededState, - GetCurrentSQLTimestamp(), - claim.ID, - ) - _, err = notifier.db.Exec(stmt, response.Payload) - } - - if err != nil { - pgNotifierClaimUpdateFailed.Increment() - pkgLogger.Errorf("PgNotifier: Failed to update claimed event: %v", err) - } -} - -func (notifier *PGNotifier) claim(workerID string) (claim Claim, err error) { - claimStartTime := time.Now() - defer func() { - if err != nil { - pgNotifierClaimFailedTime.Since(claimStartTime) - pgNotifierClaimFailed.Increment() - return - } - pgNotifierClaimSucceededTime.Since(claimStartTime) - pgNotifierClaimSucceeded.Increment() - }() - var claimedID int64 - var attempt int - var batchID, status, workspace string - var jobType sql.NullString - var payload json.RawMessage - stmt := fmt.Sprintf(` - UPDATE - %[1]s - SET - status = '%[2]s', - updated_at = '%[3]s', - last_exec_time = '%[3]s', - worker_id = '%[4]v' - WHERE - id = ( - SELECT - id - FROM - %[1]s - WHERE - status = '%[5]s' - OR status = '%[6]s' - ORDER BY - priority ASC, - id ASC FOR - UPDATE - SKIP LOCKED - LIMIT - 1 - ) RETURNING id, - batch_id, - status, - payload, - workspace, - attempt, - job_type; -`, - queueName, - ExecutingState, - GetCurrentSQLTimestamp(), - workerID, - WaitingState, - FailedState, - ) - - tx, err := notifier.db.Begin() - if err != nil { - return - } - err = tx.QueryRow(stmt).Scan(&claimedID, &batchID, &status, &payload, &workspace, &attempt, &jobType) - defer func() { - if err != nil { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - err = fmt.Errorf("%v: %w", err, rollbackErr) - } - } - }() - if err == sql.ErrNoRows { - return - } - if err != nil { - pkgLogger.Errorf("PgNotifier: Claim failed: %v, query: %s, connInfo: %s", err, stmt, notifier.URI) - return - } - - err = tx.Commit() - if err != nil { - pkgLogger.Errorf("PgNotifier: Error committing claim txn: %v", err) - return - } - - // fallback to upload if jobType is not valid - if !jobType.Valid { - jobType = sql.NullString{String: "upload", Valid: true} - } - - claim = Claim{ - ID: claimedID, - BatchID: batchID, - Status: status, - Payload: payload, - Attempt: attempt, - Workspace: workspace, - JobType: jobType.String, - } - return claim, nil -} - -func (notifier *PGNotifier) Publish(ctx context.Context, payload MessagePayload, schema *whUtils.Schema, priority int) (ch chan []Response, err error) { - publishStartTime := time.Now() - jobs := payload.Jobs - defer func() { - if err == nil { - pgNotifierPublishTime.Since(publishStartTime) - pgNotifierPublish.Increment() - } - }() - - ch = make(chan []Response) - - // Using transactions for bulk copying - txn, err := notifier.db.Begin() - if err != nil { - err = fmt.Errorf("PgNotifier: Failed creating transaction for publishing with error: %w", err) - return - } - defer func() { - if err != nil { - if rollbackErr := txn.Rollback(); rollbackErr != nil { - pkgLogger.Errorf("PgNotifier: Failed rollback transaction for publishing with error: %s", rollbackErr.Error()) - } - } - }() - - stmt, err := txn.Prepare(pq.CopyIn(queueName, "batch_id", "status", "payload", "workspace", "priority", "job_type")) - if err != nil { - err = fmt.Errorf("PgNotifier: 
Failed creating prepared statement for publishing with error: %w", err) - return - } - defer stmt.Close() - - batchID := misc.FastUUID().String() - pkgLogger.Infof("PgNotifier: Inserting %d records into %s as batch: %s", len(jobs), queueName, batchID) - for _, job := range jobs { - _, err = stmt.ExecContext(ctx, batchID, WaitingState, string(job), notifier.workspaceIdentifier, priority, payload.JobType) - if err != nil { - err = fmt.Errorf("PgNotifier: Failed executing prepared statement for publishing with error: %w", err) - return - } - } - _, err = stmt.Exec() - if err != nil { - err = fmt.Errorf("PgNotifier: Failed publishing prepared statement for publishing with error: %w", err) - return - } - - uploadSchemaJSON, err := json.Marshal(struct { - UploadSchema whUtils.Schema - }{ - UploadSchema: *schema, - }) - if err != nil { - err = fmt.Errorf("PgNotifier: Failed unmarshalling uploadschema for publishing with error: %w", err) - return - } - - sqlStatement := ` - UPDATE - pg_notifier_queue - SET - payload = payload || $1 - WHERE - batch_id = $2;` - _, err = txn.ExecContext(ctx, sqlStatement, uploadSchemaJSON, batchID) - if err != nil { - err = fmt.Errorf("PgNotifier: Failed updating uploadschema for publishing with error: %w", err) - return - } - - err = txn.Commit() - if err != nil { - err = fmt.Errorf("PgNotifier: Failed committing transaction for publishing with error: %w", err) - return - } - - pkgLogger.Infof("PgNotifier: Inserted %d records into %s as batch: %s", len(jobs), queueName, batchID) - stats.Default.NewTaggedStat("pg_notifier_insert_records", stats.CountType, map[string]string{ - "queueName": queueName, - "module": "pg_notifier", - }).Count(len(jobs)) - if payload.JobType == AsyncJobType { - notifier.trackAsyncBatch(ctx, batchID, &ch) - return - } - notifier.trackUploadBatch(ctx, batchID, &ch) - return -} - -func (notifier *PGNotifier) Subscribe(ctx context.Context, workerId string, jobsBufferSize int) chan Claim { - jobs := make(chan Claim, jobsBufferSize) - rruntime.GoForWarehouse(func() { - pollSleep := time.Duration(0) - defer close(jobs) - for { - claimedJob, err := notifier.claim(workerId) - if err == nil { - jobs <- claimedJob - pollSleep = time.Duration(0) - } else { - pollSleep = 2*pollSleep + time.Duration(rand.Intn(100))*time.Millisecond - if pollSleep > maxPollSleep.Load() { - pollSleep = maxPollSleep.Load() - } - } - select { - case <-ctx.Done(): - return - case <-time.After(pollSleep): - } - } - }) - return jobs -} - -func (notifier *PGNotifier) setupQueue() (err error) { - pkgLogger.Infof("PgNotifier: Creating Job Queue Tables ") - - m := &migrator.Migrator{ - Handle: notifier.GetDBHandle(), - MigrationsTable: "pg_notifier_queue_migrations", - ShouldForceSetLowerVersion: config.GetBool("SQLMigrator.forceSetLowerVersion", true), - } - err = m.Migrate("pg_notifier_queue") - if err != nil { - panic(fmt.Errorf("could not run pg_notifier_queue migrations: %w", err)) - } - - return -} - -// GetCurrentSQLTimestamp to get sql complaint current datetime string -func GetCurrentSQLTimestamp() string { - const SQLTimeFormat = "2006-01-02 15:04:05" - return time.Now().Format(SQLTimeFormat) -} - -// RunMaintenanceWorker (blocking - to be called from go routine) re-triggers zombie jobs -// which were left behind by dead workers in executing state -func (notifier *PGNotifier) RunMaintenanceWorker(ctx context.Context) error { - maintenanceWorkerLockID := murmur3.Sum64([]byte(queueName)) - maintenanceWorkerLock, err := pglock.NewLock(ctx, int64(maintenanceWorkerLockID), 
notifier.GetDBHandle()) - if err != nil { - return err - } - - var locked bool - defer func() { - if locked { - if err := maintenanceWorkerLock.Unlock(ctx); err != nil { - pkgLogger.Errorf("Error while unlocking maintenance worker lock: %v", err) - } - } - }() - for { - locked, err = maintenanceWorkerLock.Lock(ctx) - if err != nil { - pkgLogger.Errorf("Received error trying to acquire maintenance worker lock: %v", err) - } - if locked { - break - } - - select { - case <-ctx.Done(): - return nil - case <-time.After(jobOrphanTimeout.Load() / 5): - } - } - for { - stmt := fmt.Sprintf(` - UPDATE - %[1]s - SET - status = '%[3]s', - updated_at = '%[2]s' - WHERE - id IN ( - SELECT - id - FROM - %[1]s - WHERE - status = '%[4]s' - AND last_exec_time <= NOW() - INTERVAL '%[5]v seconds' FOR - UPDATE - SKIP LOCKED - ) RETURNING id; -`, - queueName, - GetCurrentSQLTimestamp(), - WaitingState, - ExecutingState, - int(jobOrphanTimeout.Load()/time.Second), - ) - pkgLogger.Debugf("PgNotifier: re-triggering zombie jobs: %v", stmt) - rows, err := notifier.db.Query(stmt) - if err != nil { - panic(err) - } - var ids []int64 - for rows.Next() { - var id int64 - err := rows.Scan(&id) - if err != nil { - pkgLogger.Errorf("PgNotifier: Error scanning returned id from re-triggered jobs: %v", err) - continue - } - ids = append(ids, id) - } - if err := rows.Err(); err != nil { - panic(err) - } - - _ = rows.Close() - pkgLogger.Debugf("PgNotifier: Re-triggered job ids: %v", ids) - - select { - case <-ctx.Done(): - return nil - case <-time.After(jobOrphanTimeout.Load() / 5): - } - } -} diff --git a/warehouse/http.go b/warehouse/http.go index 5b6754d188..e860587cb1 100644 --- a/warehouse/http.go +++ b/warehouse/http.go @@ -12,6 +12,8 @@ import ( "strings" "time" + "github.com/rudderlabs/rudder-server/services/notifier" + "github.com/bugsnag/bugsnag-go/v2" "github.com/go-chi/chi/v5" @@ -24,7 +26,6 @@ import ( "github.com/rudderlabs/rudder-go-kit/logger" "github.com/rudderlabs/rudder-go-kit/stats" backendconfig "github.com/rudderlabs/rudder-server/backend-config" - "github.com/rudderlabs/rudder-server/services/pgnotifier" sqlmw "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" "github.com/rudderlabs/rudder-server/warehouse/internal/model" "github.com/rudderlabs/rudder-server/warehouse/internal/repo" @@ -63,7 +64,7 @@ type Api struct { logger logger.Logger statsFactory stats.Stats db *sqlmw.DB - notifier *pgnotifier.PGNotifier + notifier *notifier.Notifier bcConfig backendconfig.BackendConfig tenantManager *multitenant.Manager bcManager *backendConfigManager @@ -88,7 +89,7 @@ func NewApi( statsFactory stats.Stats, bcConfig backendconfig.BackendConfig, db *sqlmw.DB, - notifier *pgnotifier.PGNotifier, + notifier *notifier.Notifier, tenantManager *multitenant.Manager, bcManager *backendConfigManager, asyncManager *jobs.AsyncJobWh, @@ -178,7 +179,7 @@ func (a *Api) healthHandler(w http.ResponseWriter, r *http.Request) { defer cancel() if a.config.runningMode != DegradedMode { - if !checkHealth(ctx, a.notifier.GetDBHandle()) { + if !a.notifier.CheckHealth(ctx) { http.Error(w, "Cannot connect to notifierService", http.StatusInternalServerError) return } diff --git a/warehouse/http_test.go b/warehouse/http_test.go index 045ad1cd2d..2e3aaf8489 100644 --- a/warehouse/http_test.go +++ b/warehouse/http_test.go @@ -12,6 +12,8 @@ import ( "testing" "time" + "github.com/rudderlabs/rudder-server/services/notifier" + "github.com/rudderlabs/rudder-server/utils/httputil" "golang.org/x/sync/errgroup" @@ 
-34,14 +36,12 @@ import ( "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource" backendconfig "github.com/rudderlabs/rudder-server/backend-config" mocksBackendConfig "github.com/rudderlabs/rudder-server/mocks/backend-config" - "github.com/rudderlabs/rudder-server/services/pgnotifier" migrator "github.com/rudderlabs/rudder-server/services/sql-migrator" "github.com/rudderlabs/rudder-server/utils/pubsub" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" ) func TestHTTPApi(t *testing.T) { - pgnotifier.Init() Init4() const ( @@ -172,19 +172,20 @@ func TestHTTPApi(t *testing.T) { db := sqlmiddleware.New(pgResource.DB) - notifier, err := pgnotifier.New(workspaceIdentifier, pgResource.DBDsn) - require.NoError(t, err) - tenantManager := multitenant.New(c, mockBackendConfig) bcManager := newBackendConfigManager(config.Default, db, tenantManager, logger.NOP) ctx, stopTest := context.WithCancel(context.Background()) + n := notifier.New(config.Default, logger.NOP, stats.Default, workspaceIdentifier) + err = n.Setup(ctx, pgResource.DBDsn) + require.NoError(t, err) + jobsManager := jobs.InitWarehouseJobsAPI( ctx, db.DB, - ¬ifier, + n, ) jobs.WithConfig(jobsManager, config.Default) @@ -405,7 +406,7 @@ func TestHTTPApi(t *testing.T) { c := config.New() c.Set("Warehouse.runningMode", tc.runningMode) - a := NewApi(tc.mode, c, logger.NOP, stats.Default, mockBackendConfig, db, ¬ifier, tenantManager, bcManager, jobsManager) + a := NewApi(tc.mode, c, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, jobsManager) a.healthHandler(resp, req) var healthBody map[string]string @@ -423,7 +424,7 @@ func TestHTTPApi(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/pending-events", bytes.NewReader([]byte(`"Invalid payload"`))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, ¬ifier, tenantManager, bcManager, jobsManager) + a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, jobsManager) a.pendingEventsHandler(resp, req) require.Equal(t, http.StatusBadRequest, resp.Code) @@ -441,7 +442,7 @@ func TestHTTPApi(t *testing.T) { `))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, ¬ifier, tenantManager, bcManager, jobsManager) + a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, jobsManager) a.pendingEventsHandler(resp, req) require.Equal(t, http.StatusBadRequest, resp.Code) @@ -459,7 +460,7 @@ func TestHTTPApi(t *testing.T) { `))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, ¬ifier, tenantManager, bcManager, jobsManager) + a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, jobsManager) a.pendingEventsHandler(resp, req) require.Equal(t, http.StatusBadRequest, resp.Code) @@ -477,7 +478,7 @@ func TestHTTPApi(t *testing.T) { `))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, ¬ifier, tenantManager, bcManager, jobsManager) + a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, jobsManager) a.pendingEventsHandler(resp, req) 
require.Equal(t, http.StatusServiceUnavailable, resp.Code) @@ -495,7 +496,7 @@ func TestHTTPApi(t *testing.T) { `))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, ¬ifier, tenantManager, bcManager, jobsManager) + a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, jobsManager) a.pendingEventsHandler(resp, req) require.Equal(t, http.StatusOK, resp.Code) @@ -521,7 +522,7 @@ func TestHTTPApi(t *testing.T) { `))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, ¬ifier, tenantManager, bcManager, jobsManager) + a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, jobsManager) a.pendingEventsHandler(resp, req) require.Equal(t, http.StatusOK, resp.Code) @@ -553,7 +554,7 @@ func TestHTTPApi(t *testing.T) { `))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, ¬ifier, tenantManager, bcManager, jobsManager) + a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, jobsManager) a.pendingEventsHandler(resp, req) require.Equal(t, http.StatusOK, resp.Code) @@ -575,7 +576,7 @@ func TestHTTPApi(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/internal/v1/warehouse/fetch-tables", bytes.NewReader([]byte(`"Invalid payload"`))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, ¬ifier, tenantManager, bcManager, jobsManager) + a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, jobsManager) a.fetchTablesHandler(resp, req) require.Equal(t, http.StatusBadRequest, resp.Code) @@ -592,7 +593,7 @@ func TestHTTPApi(t *testing.T) { `))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, ¬ifier, tenantManager, bcManager, jobsManager) + a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, jobsManager) a.fetchTablesHandler(resp, req) require.Equal(t, http.StatusInternalServerError, resp.Code) @@ -614,7 +615,7 @@ func TestHTTPApi(t *testing.T) { `))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, ¬ifier, tenantManager, bcManager, jobsManager) + a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, jobsManager) a.fetchTablesHandler(resp, req) require.Equal(t, http.StatusOK, resp.Code) @@ -637,7 +638,7 @@ func TestHTTPApi(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/v1/warehouse/trigger-upload", bytes.NewReader([]byte(`"Invalid payload"`))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, ¬ifier, tenantManager, bcManager, jobsManager) + a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, jobsManager) a.triggerUploadHandler(resp, req) require.Equal(t, http.StatusBadRequest, resp.Code) @@ -655,7 +656,7 @@ func TestHTTPApi(t *testing.T) { `))) 
resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, ¬ifier, tenantManager, bcManager, jobsManager) + a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, jobsManager) a.triggerUploadHandler(resp, req) require.Equal(t, http.StatusBadRequest, resp.Code) @@ -673,7 +674,7 @@ func TestHTTPApi(t *testing.T) { `))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, ¬ifier, tenantManager, bcManager, jobsManager) + a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, jobsManager) a.triggerUploadHandler(resp, req) require.Equal(t, http.StatusServiceUnavailable, resp.Code) @@ -691,7 +692,7 @@ func TestHTTPApi(t *testing.T) { `))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, ¬ifier, tenantManager, bcManager, jobsManager) + a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, jobsManager) a.triggerUploadHandler(resp, req) require.Equal(t, http.StatusBadRequest, resp.Code) @@ -709,7 +710,7 @@ func TestHTTPApi(t *testing.T) { `))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, ¬ifier, tenantManager, bcManager, jobsManager) + a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, jobsManager) a.triggerUploadHandler(resp, req) require.Equal(t, http.StatusOK, resp.Code) @@ -733,7 +734,7 @@ func TestHTTPApi(t *testing.T) { `))) resp := httptest.NewRecorder() - a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, ¬ifier, tenantManager, bcManager, jobsManager) + a := NewApi(config.MasterMode, config.Default, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, jobsManager) a.triggerUploadHandler(resp, req) require.Equal(t, http.StatusOK, resp.Code) @@ -759,7 +760,7 @@ func TestHTTPApi(t *testing.T) { srvCtx, stopServer := context.WithCancel(ctx) - a := NewApi(config.MasterMode, c, logger.NOP, stats.Default, mockBackendConfig, db, ¬ifier, tenantManager, bcManager, jobsManager) + a := NewApi(config.MasterMode, c, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, jobsManager) serverSetupCh := make(chan struct{}) go func() { @@ -957,7 +958,7 @@ func TestHTTPApi(t *testing.T) { srvCtx, stopServer := context.WithCancel(ctx) - a := NewApi(config.MasterMode, c, logger.NOP, stats.Default, mockBackendConfig, db, ¬ifier, tenantManager, bcManager, jobsManager) + a := NewApi(config.MasterMode, c, logger.NOP, stats.Default, mockBackendConfig, db, n, tenantManager, bcManager, jobsManager) serverSetupCh := make(chan struct{}) go func() { diff --git a/warehouse/internal/loadfiles/loadfiles.go b/warehouse/internal/loadfiles/loadfiles.go index 4a10bd8fee..a0b663b69f 100644 --- a/warehouse/internal/loadfiles/loadfiles.go +++ b/warehouse/internal/loadfiles/loadfiles.go @@ -6,6 +6,10 @@ import ( "strings" "time" + "github.com/rudderlabs/rudder-server/services/notifier" + + stdjson "encoding/json" + "github.com/samber/lo" "golang.org/x/exp/slices" @@ -16,7 +20,6 @@ import ( "github.com/rudderlabs/rudder-go-kit/config" 
"github.com/rudderlabs/rudder-go-kit/logger" backendconfig "github.com/rudderlabs/rudder-server/backend-config" - "github.com/rudderlabs/rudder-server/services/pgnotifier" "github.com/rudderlabs/rudder-server/utils/misc" "github.com/rudderlabs/rudder-server/utils/timeutil" schemarepository "github.com/rudderlabs/rudder-server/warehouse/integrations/datalake/schema-repository" @@ -35,7 +38,7 @@ const ( var warehousesToVerifyLoadFilesFolder = []string{warehouseutils.SNOWFLAKE} type Notifier interface { - Publish(ctx context.Context, payload pgnotifier.MessagePayload, schema *warehouseutils.Schema, priority int) (ch chan []pgnotifier.Response, err error) + Publish(ctx context.Context, payload *notifier.PublishRequest) (ch <-chan *notifier.PublishResponse, err error) } type StageFileRepo interface { @@ -67,6 +70,11 @@ type LoadFileGenerator struct { } type WorkerJobResponse struct { + StagingFileID int64 `json:"StagingFileID"` + Output []LoadFileUpload `json:"Output"` +} + +type LoadFileUpload struct { TableName string Location string TotalRows int @@ -81,7 +89,6 @@ type WorkerJobRequest struct { UploadID int64 StagingFileID int64 StagingFileLocation string - UploadSchema model.Schema WorkspaceID string SourceID string SourceName string @@ -97,7 +104,6 @@ type WorkerJobRequest struct { StagingUseRudderStorage bool UniqueLoadGenID string RudderStoragePrefix string - Output []WorkerJobResponse LoadFilePrefix string // prefix for the load file name LoadFileType string } @@ -193,7 +199,7 @@ func (lf *LoadFileGenerator) createFromStaging(ctx context.Context, job *model.U var sampleError error for _, chunk := range lo.Chunk(toProcessStagingFiles, publishBatchSize) { // td : add prefix to payload for s3 dest - var messages []pgnotifier.JobPayload + var messages []stdjson.RawMessage for _, stagingFile := range chunk { payload := WorkerJobRequest{ UploadID: job.Upload.ID, @@ -229,71 +235,89 @@ func (lf *LoadFileGenerator) createFromStaging(ctx context.Context, job *model.U messages = append(messages, payloadJSON) } - schema := &job.Upload.UploadSchema - - lf.Logger.Infof("[WH]: Publishing %d staging files for %s:%s to PgNotifier", len(messages), destType, destID) - messagePayload := pgnotifier.MessagePayload{ - Jobs: messages, - JobType: "upload", + uploadSchemaJSON, err := json.Marshal(struct { + UploadSchema model.Schema + }{ + UploadSchema: job.Upload.UploadSchema, + }) + if err != nil { + return 0, 0, fmt.Errorf("error marshalling upload schema: %w", err) } - ch, err := lf.Notifier.Publish(ctx, messagePayload, (*warehouseutils.Schema)(schema), job.Upload.Priority) + lf.Logger.Infof("[WH]: Publishing %d staging files for %s:%s to notifier", len(messages), destType, destID) + + ch, err := lf.Notifier.Publish(ctx, ¬ifier.PublishRequest{ + Payloads: messages, + JobType: notifier.JobTypeUpload, + UploadSchema: uploadSchemaJSON, + Priority: job.Upload.Priority, + }) if err != nil { - return 0, 0, fmt.Errorf("error publishing to PgNotifier: %w", err) + return 0, 0, fmt.Errorf("error publishing to notifier: %w", err) } // set messages to nil to release mem allocated messages = nil startId := chunk[0].ID endId := chunk[len(chunk)-1].ID g.Go(func() error { - responses := <-ch - lf.Logger.Infow("Received responses for staging files %d:%d for %s:%s from PgNotifier", + responses, ok := <-ch + if !ok { + return fmt.Errorf("receiving notifier channel closed") + } + + lf.Logger.Infow("Received responses for staging files %d:%d for %s:%s from Notifier", "startId", startId, "endID", endId, logfield.DestinationID, 
destID, logfield.DestinationType, destType, ) + if responses.Err != nil { + return fmt.Errorf("receiving responses from notifier: %w", responses.Err) + } + var loadFiles []model.LoadFile var successfulStagingFileIDs []int64 - for _, resp := range responses { + for _, resp := range responses.Jobs { // Error handling during generating_load_files step: - // 1. any error returned by pgnotifier is set on corresponding staging_file + // 1. any error returned by notifier is set on corresponding staging_file // 2. any error affecting a batch/all the staging files like saving load file records to wh db // is returned as error to caller of the func to set error on all staging files and the whole generating_load_files step - if resp.Status == "aborted" { + var jobResponse WorkerJobResponse + if err := json.Unmarshal(resp.Payload, &jobResponse); err != nil { + return fmt.Errorf("unmarshalling response from notifier: %w", err) + } + + if resp.Status == notifier.Aborted && resp.Error != nil { lf.Logger.Errorf("[WH]: Error in generating load files: %v", resp.Error) - sampleError = fmt.Errorf(resp.Error) - err = lf.StageRepo.SetErrorStatus(ctx, resp.JobID, sampleError) + sampleError = resp.Error + err = lf.StageRepo.SetErrorStatus(ctx, jobResponse.StagingFileID, sampleError) + if err != nil { + return fmt.Errorf("set staging file error status: %w", err) + } continue } - var output []WorkerJobResponse - err = json.Unmarshal(resp.Output, &output) - if err != nil { - return fmt.Errorf("unmarshalling response from pgnotifier: %w", err) - } - if len(output) == 0 { - lf.Logger.Errorf("[WH]: No LoadFiles returned by wh worker") + + if len(jobResponse.Output) == 0 { + lf.Logger.Errorf("[WH]: No LoadFiles returned by worker") continue } - for i := range output { + + for _, output := range jobResponse.Output { loadFiles = append(loadFiles, model.LoadFile{ - TableName: output[i].TableName, - Location: output[i].Location, - TotalRows: output[i].TotalRows, - ContentLength: output[i].ContentLength, - StagingFileID: output[i].StagingFileID, - DestinationRevisionID: output[i].DestinationRevisionID, - UseRudderStorage: output[i].UseRudderStorage, + TableName: output.TableName, + Location: output.Location, + TotalRows: output.TotalRows, + ContentLength: output.ContentLength, + StagingFileID: output.StagingFileID, + DestinationRevisionID: output.DestinationRevisionID, + UseRudderStorage: output.UseRudderStorage, SourceID: job.Upload.SourceID, DestinationID: job.Upload.DestinationID, DestinationType: job.Upload.DestinationType, }) } - successfulStagingFileIDs = append(successfulStagingFileIDs, resp.JobID) + successfulStagingFileIDs = append(successfulStagingFileIDs, jobResponse.StagingFileID) } if len(loadFiles) == 0 { diff --git a/warehouse/internal/loadfiles/loadfiles_test.go b/warehouse/internal/loadfiles/loadfiles_test.go index 6606a84653..81d28526b7 100644 --- a/warehouse/internal/loadfiles/loadfiles_test.go +++ b/warehouse/internal/loadfiles/loadfiles_test.go @@ -48,7 +48,7 @@ func getStagingFiles() []*model.StagingFile { func TestCreateLoadFiles(t *testing.T) { t.Parallel() - notifer := &mockNotifier{ + notifier := &mockNotifier{ t: t, tables: []string{"track", "indentify"}, } @@ -58,7 +58,7 @@ func TestCreateLoadFiles(t *testing.T) { lf := loadfiles.LoadFileGenerator{ Logger: logger.NOP, - Notifier: notifer, + Notifier: notifier, StageRepo: stageRepo, LoadRepo: loadRepo, @@ -95,7 +95,7 @@ func TestCreateLoadFiles(t *testing.T) { require.Equal(t, int64(1), startID) require.Equal(t, int64(20), endID) - 
require.Len(t, loadRepo.store, len(stagingFiles)*len(notifer.tables)) + require.Len(t, loadRepo.store, len(stagingFiles)*len(notifier.tables)) require.Len(t, stageRepo.store, len(stagingFiles)) for _, stagingFile := range stagingFiles { @@ -114,7 +114,7 @@ func TestCreateLoadFiles(t *testing.T) { tableNames = append(tableNames, loadFile.TableName) } - require.ElementsMatch(t, notifer.tables, tableNames) + require.ElementsMatch(t, notifier.tables, tableNames) require.Equal(t, warehouseutils.StagingFileSucceededState, stageRepo.store[stagingFile.ID].Status) } @@ -130,7 +130,7 @@ func TestCreateLoadFiles(t *testing.T) { require.Equal(t, int64(21), startID) require.Equal(t, int64(22), endID) - require.Len(t, loadRepo.store, len(stagingFiles)*len(notifer.tables)) + require.Len(t, loadRepo.store, len(stagingFiles)*len(notifier.tables)) }) t.Run("force recreate", func(t *testing.T) { @@ -143,7 +143,7 @@ func TestCreateLoadFiles(t *testing.T) { require.Equal(t, int64(23), startID) require.Equal(t, int64(42), endID) - require.Len(t, loadRepo.store, len(stagingFiles)*len(notifer.tables)) + require.Len(t, loadRepo.store, len(stagingFiles)*len(notifier.tables)) require.Len(t, stageRepo.store, len(stagingFiles)) }) } diff --git a/warehouse/internal/loadfiles/mock_notifier_test.go b/warehouse/internal/loadfiles/mock_notifier_test.go index 1e99b6106d..3cf43a892a 100644 --- a/warehouse/internal/loadfiles/mock_notifier_test.go +++ b/warehouse/internal/loadfiles/mock_notifier_test.go @@ -3,13 +3,14 @@ package loadfiles_test import ( "context" "encoding/json" + "errors" "testing" + "github.com/rudderlabs/rudder-server/services/notifier" + "github.com/stretchr/testify/require" - "github.com/rudderlabs/rudder-server/services/pgnotifier" "github.com/rudderlabs/rudder-server/warehouse/internal/loadfiles" - warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" ) type mockNotifier struct { @@ -19,20 +20,20 @@ type mockNotifier struct { tables []string } -func (n *mockNotifier) Publish(_ context.Context, payload pgnotifier.MessagePayload, _ *warehouseutils.Schema, _ int) (chan []pgnotifier.Response, error) { - var responses []pgnotifier.Response - for _, p := range payload.Jobs { +func (n *mockNotifier) Publish(_ context.Context, payload *notifier.PublishRequest) (<-chan *notifier.PublishResponse, error) { + var responses notifier.PublishResponse + for _, p := range payload.Payloads { var req loadfiles.WorkerJobRequest err := json.Unmarshal(p, &req) require.NoError(n.t, err) - var resps []loadfiles.WorkerJobResponse + var loadFileUploads []loadfiles.LoadFileUpload for _, tableName := range n.tables { destinationRevisionID := req.DestinationRevisionID n.requests = append(n.requests, req) - resps = append(resps, loadfiles.WorkerJobResponse{ + loadFileUploads = append(loadFileUploads, loadfiles.LoadFileUpload{ TableName: tableName, Location: req.StagingFileLocation + "/" + req.UniqueLoadGenID + "/" + tableName, TotalRows: 10, @@ -42,28 +43,31 @@ func (n *mockNotifier) Publish(_ context.Context, payload pgnotifier.MessagePayl UseRudderStorage: req.UseRudderStorage, }) } - out, err := json.Marshal(resps) + jobResponse := loadfiles.WorkerJobResponse{ + StagingFileID: req.StagingFileID, + Output: loadFileUploads, + } + out, err := json.Marshal(jobResponse) errString := "" if err != nil { errString = err.Error() } - status := "ok" + status := notifier.Succeeded if req.StagingFileLocation == "" { errString = "staging file location is empty" - status = "aborted" + status = notifier.Aborted } - responses = 
append(responses, pgnotifier.Response{ - JobID: req.StagingFileID, - Output: out, - Error: errString, - Status: status, + responses.Jobs = append(responses.Jobs, notifier.Job{ + Payload: out, + Error: errors.New(errString), + Status: status, }) } - ch := make(chan []pgnotifier.Response, 1) - ch <- responses + ch := make(chan *notifier.PublishResponse, 1) + ch <- &responses return ch, nil } diff --git a/warehouse/jobs/http.go b/warehouse/jobs/http.go index 85007e68dc..323d3568ca 100644 --- a/warehouse/jobs/http.go +++ b/warehouse/jobs/http.go @@ -7,6 +7,8 @@ import ( "net/http" "strings" + "github.com/rudderlabs/rudder-server/services/notifier" + ierrors "github.com/rudderlabs/rudder-server/warehouse/internal/errors" lf "github.com/rudderlabs/rudder-server/warehouse/logfield" @@ -64,7 +66,7 @@ func (a *AsyncJobWh) InsertJobHandler(w http.ResponseWriter, r *http.Request) { JobRunID: payload.JobRunID, TaskRunID: payload.TaskRunID, StartTime: payload.StartTime, - JobType: AsyncJobType, + JobType: string(notifier.JobTypeAsync), }) if err != nil { a.logger.Errorw("marshalling metadata for inserting async job", lf.Error, err.Error()) diff --git a/warehouse/jobs/http_test.go b/warehouse/jobs/http_test.go index 966858cf2e..4cf83b80a6 100644 --- a/warehouse/jobs/http_test.go +++ b/warehouse/jobs/http_test.go @@ -13,11 +13,14 @@ import ( "testing" "time" + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/stats" + "github.com/rudderlabs/rudder-server/services/notifier" + "github.com/ory/dockertest/v3" "github.com/rudderlabs/rudder-go-kit/logger" "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource" - "github.com/rudderlabs/rudder-server/services/pgnotifier" migrator "github.com/rudderlabs/rudder-server/services/sql-migrator" sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" "github.com/rudderlabs/rudder-server/warehouse/internal/model" @@ -28,8 +31,6 @@ import ( ) func TestAsyncJobHandlers(t *testing.T) { - pgnotifier.Init() - const ( workspaceID = "test_workspace_id" sourceID = "test_source_id" @@ -58,11 +59,12 @@ func TestAsyncJobHandlers(t *testing.T) { db := sqlmiddleware.New(pgResource.DB) - notifier, err := pgnotifier.New(workspaceIdentifier, pgResource.DBDsn) - require.NoError(t, err) - ctx := context.Background() + n := notifier.New(config.Default, logger.NOP, stats.Default, workspaceIdentifier) + err = n.Setup(ctx, pgResource.DBDsn) + require.NoError(t, err) + now := time.Now().Truncate(time.Second).UTC() uploadsRepo := repo.NewUploads(db, repo.WithNow(func() time.Time { @@ -202,11 +204,11 @@ func TestAsyncJobHandlers(t *testing.T) { resp := httptest.NewRecorder() jobsManager := AsyncJobWh{ - dbHandle: db.DB, - enabled: false, - logger: logger.NOP, - context: ctx, - pgnotifier: ¬ifier, + dbHandle: db.DB, + enabled: false, + logger: logger.NOP, + context: ctx, + notifier: n, } jobsManager.InsertJobHandler(resp, req) require.Equal(t, http.StatusInternalServerError, resp.Code) @@ -220,11 +222,11 @@ func TestAsyncJobHandlers(t *testing.T) { resp := httptest.NewRecorder() jobsManager := AsyncJobWh{ - dbHandle: db.DB, - enabled: true, - logger: logger.NOP, - context: ctx, - pgnotifier: ¬ifier, + dbHandle: db.DB, + enabled: true, + logger: logger.NOP, + context: ctx, + notifier: n, } jobsManager.InsertJobHandler(resp, req) require.Equal(t, http.StatusBadRequest, resp.Code) @@ -238,11 +240,11 @@ func TestAsyncJobHandlers(t *testing.T) { resp := httptest.NewRecorder() jobsManager := AsyncJobWh{ - 
dbHandle: db.DB, - enabled: true, - logger: logger.NOP, - context: ctx, - pgnotifier: ¬ifier, + dbHandle: db.DB, + enabled: true, + logger: logger.NOP, + context: ctx, + notifier: n, } jobsManager.InsertJobHandler(resp, req) require.Equal(t, http.StatusBadRequest, resp.Code) @@ -263,11 +265,11 @@ func TestAsyncJobHandlers(t *testing.T) { resp := httptest.NewRecorder() jobsManager := AsyncJobWh{ - dbHandle: db.DB, - enabled: true, - logger: logger.NOP, - context: ctx, - pgnotifier: ¬ifier, + dbHandle: db.DB, + enabled: true, + logger: logger.NOP, + context: ctx, + notifier: n, } jobsManager.InsertJobHandler(resp, req) require.Equal(t, http.StatusOK, resp.Code) @@ -286,11 +288,11 @@ func TestAsyncJobHandlers(t *testing.T) { resp := httptest.NewRecorder() jobsManager := AsyncJobWh{ - dbHandle: db.DB, - enabled: false, - logger: logger.NOP, - context: ctx, - pgnotifier: ¬ifier, + dbHandle: db.DB, + enabled: false, + logger: logger.NOP, + context: ctx, + notifier: n, } jobsManager.StatusJobHandler(resp, req) require.Equal(t, http.StatusInternalServerError, resp.Code) @@ -304,11 +306,11 @@ func TestAsyncJobHandlers(t *testing.T) { resp := httptest.NewRecorder() jobsManager := AsyncJobWh{ - dbHandle: db.DB, - enabled: true, - logger: logger.NOP, - context: ctx, - pgnotifier: ¬ifier, + dbHandle: db.DB, + enabled: true, + logger: logger.NOP, + context: ctx, + notifier: n, } jobsManager.StatusJobHandler(resp, req) require.Equal(t, http.StatusBadRequest, resp.Code) @@ -335,11 +337,11 @@ func TestAsyncJobHandlers(t *testing.T) { resp := httptest.NewRecorder() jobsManager := AsyncJobWh{ - dbHandle: db.DB, - enabled: true, - logger: logger.NOP, - context: ctx, - pgnotifier: ¬ifier, + dbHandle: db.DB, + enabled: true, + logger: logger.NOP, + context: ctx, + notifier: n, } jobsManager.StatusJobHandler(resp, req) require.Equal(t, http.StatusOK, resp.Code) diff --git a/warehouse/jobs/runner.go b/warehouse/jobs/runner.go index df5b001696..c9e8dc7db4 100644 --- a/warehouse/jobs/runner.go +++ b/warehouse/jobs/runner.go @@ -7,13 +7,14 @@ import ( "fmt" "time" + "github.com/rudderlabs/rudder-server/services/notifier" + "github.com/lib/pq" "github.com/samber/lo" "golang.org/x/sync/errgroup" "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/logger" - "github.com/rudderlabs/rudder-server/services/pgnotifier" "github.com/rudderlabs/rudder-server/utils/misc" "github.com/rudderlabs/rudder-server/utils/timeutil" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" @@ -23,14 +24,14 @@ import ( func InitWarehouseJobsAPI( ctx context.Context, dbHandle *sql.DB, - notifier *pgnotifier.PGNotifier, + notifier *notifier.Notifier, ) *AsyncJobWh { return &AsyncJobWh{ - dbHandle: dbHandle, - enabled: false, - pgnotifier: notifier, - context: ctx, - logger: logger.NewLogger().Child("asyncjob"), + dbHandle: dbHandle, + enabled: false, + notifier: notifier, + context: ctx, + logger: logger.NewLogger().Child("asyncjob"), } } @@ -152,8 +153,8 @@ func (a *AsyncJobWh) cleanUpAsyncTable(ctx context.Context) error { startAsyncJobRunner is the main runner that 1) Periodically queries the db for any pending async jobs 2) Groups them together -3) Publishes them to the pgnotifier -4) Spawns a subroutine that periodically checks for responses from pgNotifier/slave worker post trackBatch +3) Publishes them to the notifier +4) Spawns a subroutine that periodically checks for responses from Notifier/slave worker post trackBatch */ func (a *AsyncJobWh) startAsyncJobRunner(ctx context.Context) error { 
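The runner body below reduces to a publish-then-select pattern against the notifier's response channel: publish the batch, then wait on the channel, the context, and a timeout. In miniature, under the Publish API introduced in this PR (the helper name and the explicit timeout parameter are illustrative, not part of the change):

package jobs

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/rudderlabs/rudder-server/services/notifier"
)

// publishAndWait publishes one batch and blocks until a response, context
// cancellation, or timeout. Sketch only; notifier.Publish and its response
// channel come from this PR, everything else is illustrative.
func publishAndWait(ctx context.Context, n *notifier.Notifier, req *notifier.PublishRequest, timeout time.Duration) (*notifier.PublishResponse, error) {
	ch, err := n.Publish(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("publishing to notifier: %w", err)
	}
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case res, ok := <-ch:
		if !ok {
			return nil, errors.New("notifier response channel closed")
		}
		if res.Err != nil {
			return nil, res.Err
		}
		return res, nil
	case <-time.After(timeout):
		return nil, errors.New("timed out waiting for notifier response")
	}
}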
a.logger.Info("[WH-Jobs]: Starting async job runner") @@ -186,13 +187,14 @@ func (a *AsyncJobWh) startAsyncJobRunner(ctx context.Context) error { _ = a.updateAsyncJobs(ctx, asyncJobStatusMap) continue } - messagePayload := pgnotifier.MessagePayload{ - Jobs: notifierClaims, - JobType: AsyncJobType, - } - ch, err := a.pgnotifier.Publish(ctx, messagePayload, &warehouseutils.Schema{}, 100) + + ch, err := a.notifier.Publish(ctx, ¬ifier.PublishRequest{ + Payloads: notifierClaims, + JobType: notifier.JobTypeAsync, + Priority: 100, + }) if err != nil { - a.logger.Errorf("[WH-Jobs]: unable to get publish async jobs to pgnotifier. Task failed with error %s", err.Error()) + a.logger.Errorf("[WH-Jobs]: unable to get publish async jobs to notifier. Task failed with error %s", err.Error()) asyncJobStatusMap := convertToPayloadStatusStructWithSingleStatus(pendingAsyncJobs, WhJobFailed, err) _ = a.updateAsyncJobs(ctx, asyncJobStatusMap) continue @@ -204,34 +206,46 @@ func (a *AsyncJobWh) startAsyncJobRunner(ctx context.Context) error { case <-ctx.Done(): a.logger.Infof("[WH-Jobs]: Context cancelled for async job runner") return nil - case responses := <-ch: - a.logger.Info("[WH-Jobs]: Response received from the pgnotifier track batch") + case responses, ok := <-ch: + if !ok { + a.logger.Error("[WH-Jobs]: Notifier track batch channel closed") + asyncJobStatusMap := convertToPayloadStatusStructWithSingleStatus(pendingAsyncJobs, WhJobFailed, fmt.Errorf("receiving channel closed")) + _ = a.updateAsyncJobs(ctx, asyncJobStatusMap) + continue + } + if responses.Err != nil { + a.logger.Errorf("[WH-Jobs]: Error received from the notifier track batch %s", responses.Err.Error()) + asyncJobStatusMap := convertToPayloadStatusStructWithSingleStatus(pendingAsyncJobs, WhJobFailed, responses.Err) + _ = a.updateAsyncJobs(ctx, asyncJobStatusMap) + continue + } + a.logger.Info("[WH-Jobs]: Response received from the notifier track batch") asyncJobsStatusMap := getAsyncStatusMapFromAsyncPayloads(pendingAsyncJobs) - a.updateStatusJobPayloadsFromPgNotifierResponse(responses, asyncJobsStatusMap) + a.updateStatusJobPayloadsFromNotifierResponse(responses, asyncJobsStatusMap) _ = a.updateAsyncJobs(ctx, asyncJobsStatusMap) case <-time.After(a.asyncJobTimeOut): - a.logger.Errorf("Go Routine timed out waiting for a response from PgNotifier", pendingAsyncJobs[0].Id) + a.logger.Errorf("Go Routine timed out waiting for a response from Notifier", pendingAsyncJobs[0].Id) asyncJobStatusMap := convertToPayloadStatusStructWithSingleStatus(pendingAsyncJobs, WhJobFailed, err) _ = a.updateAsyncJobs(ctx, asyncJobStatusMap) } } } -func (a *AsyncJobWh) updateStatusJobPayloadsFromPgNotifierResponse(r []pgnotifier.Response, m map[string]AsyncJobStatus) { - for _, resp := range r { - var pgNotifierOutput PGNotifierOutput - err := json.Unmarshal(resp.Output, &pgNotifierOutput) +func (a *AsyncJobWh) updateStatusJobPayloadsFromNotifierResponse(r *notifier.PublishResponse, m map[string]AsyncJobStatus) { + for _, resp := range r.Jobs { + var response NotifierResponse + err := json.Unmarshal(resp.Payload, &response) if err != nil { - a.logger.Errorf("error unmarshalling pgnotifier payload to AsyncJobStatusMa for Id: %s", pgNotifierOutput.Id) + a.logger.Errorf("error unmarshalling notifier payload to AsyncJobStatusMa for Id: %s", response.Id) continue } - if output, ok := m[pgNotifierOutput.Id]; ok { - output.Status = resp.Status - if resp.Error != "" { - output.Error = fmt.Errorf(resp.Error) + if output, ok := m[response.Id]; ok { + output.Status = 
string(resp.Status) + if resp.Error != nil { + output.Error = resp.Error } - m[pgNotifierOutput.Id] = output + m[response.Id] = output } } } diff --git a/warehouse/jobs/types.go b/warehouse/jobs/types.go index 10313fab5e..29788fa6f0 100644 --- a/warehouse/jobs/types.go +++ b/warehouse/jobs/types.go @@ -6,8 +6,9 @@ import ( "encoding/json" "time" + "github.com/rudderlabs/rudder-server/services/notifier" + "github.com/rudderlabs/rudder-go-kit/logger" - "github.com/rudderlabs/rudder-server/services/pgnotifier" ) // StartJobReqPayload For processing requests payload in handlers.go @@ -26,7 +27,7 @@ type StartJobReqPayload struct { type AsyncJobWh struct { dbHandle *sql.DB enabled bool - pgnotifier *pgnotifier.PGNotifier + notifier *notifier.Notifier context context.Context logger logger.Logger maxBatchSizeToProcess int @@ -61,10 +62,9 @@ const ( WhJobSucceeded string = "succeeded" WhJobAborted string = "aborted" WhJobFailed string = "failed" - AsyncJobType string = "async_job" ) -type PGNotifierOutput struct { +type NotifierResponse struct { Id string `json:"id"` } diff --git a/warehouse/jobs/utils.go b/warehouse/jobs/utils.go index 36635db882..017797a8ef 100644 --- a/warehouse/jobs/utils.go +++ b/warehouse/jobs/utils.go @@ -2,8 +2,6 @@ package jobs import ( "encoding/json" - - "github.com/rudderlabs/rudder-server/services/pgnotifier" ) func convertToPayloadStatusStructWithSingleStatus(payloads []AsyncJobPayload, status string, err error) map[string]AsyncJobStatus { @@ -18,9 +16,9 @@ func convertToPayloadStatusStructWithSingleStatus(payloads []AsyncJobPayload, st return asyncJobStatusMap } -// convert to pgNotifier Payload and return the array of payloads -func getMessagePayloadsFromAsyncJobPayloads(asyncJobPayloads []AsyncJobPayload) ([]pgnotifier.JobPayload, error) { - var messages []pgnotifier.JobPayload +// convert to notifier payloads and return the array of payloads +func getMessagePayloadsFromAsyncJobPayloads(asyncJobPayloads []AsyncJobPayload) ([]json.RawMessage, error) { + var messages []json.RawMessage for _, job := range asyncJobPayloads { message, err := json.Marshal(job) if err != nil { diff --git a/warehouse/router.go b/warehouse/router.go index 5fb3832f42..b4a1338b27 100644 --- a/warehouse/router.go +++ b/warehouse/router.go @@ -5,11 +5,14 @@ import ( "errors" "fmt" "math/rand" - "strings" "sync" "sync/atomic" "time" + "github.com/lib/pq" + + "github.com/rudderlabs/rudder-server/services/notifier" + "github.com/rudderlabs/rudder-server/warehouse/encoding" "github.com/rudderlabs/rudder-server/app" @@ -28,7 +31,6 @@ import ( "github.com/rudderlabs/rudder-go-kit/logger" "github.com/rudderlabs/rudder-go-kit/stats" "github.com/rudderlabs/rudder-server/rruntime" - "github.com/rudderlabs/rudder-server/services/pgnotifier" "github.com/rudderlabs/rudder-server/utils/misc" "github.com/rudderlabs/rudder-server/utils/timeutil" "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" @@ -76,7 +78,7 @@ type router struct { tenantManager *multitenant.Manager bcManager *backendConfigManager uploadJobFactory UploadJobFactory - notifier *pgnotifier.PGNotifier + notifier *notifier.Notifier config struct { noOfWorkers int @@ -112,7 +114,7 @@ func newRouter( logger logger.Logger, statsFactory stats.Stats, db *sqlquerywrapper.DB, - pgNotifier *pgnotifier.PGNotifier, + notifier *notifier.Notifier, tenantManager *multitenant.Manager, controlPlaneClient *controlplane.Client, bcManager *backendConfigManager, @@ -133,7 +135,7 @@ func newRouter( 
r.uploadRepo = repo.NewUploads(db) r.whSchemaRepo = repo.NewWHSchemas(db) - r.notifier = pgNotifier + r.notifier = notifier r.tenantManager = tenantManager r.bcManager = bcManager r.destType = destType @@ -152,7 +154,7 @@ func newRouter( logger: r.logger, statsFactory: r.statsFactory, dbHandle: r.dbHandle, - pgNotifier: r.notifier, + notifier: r.notifier, destinationValidator: validations.NewDestinationValidator(), loadFile: &loadfiles.LoadFileGenerator{ Logger: r.logger.Child("loadfile"), @@ -396,12 +398,14 @@ loop: uploadJobsToProcess, err := r.uploadsToProcess(ctx, availableWorkers, inProgressNamespaces) if err != nil { - if errors.Is(err, context.Canceled) || - errors.Is(err, context.DeadlineExceeded) || - strings.Contains(err.Error(), "pq: canceling statement due to user request") { + var pqErr *pq.Error + switch { + case errors.Is(err, context.Canceled), + errors.Is(err, context.DeadlineExceeded), + errors.As(err, &pqErr) && pqErr.Code == "57014": break loop - } else { + default: r.logger.Errorf(`Error executing uploadsToProcess: %v`, err) panic(err) diff --git a/warehouse/router_test.go b/warehouse/router_test.go index 750092d98a..6f9ac72381 100644 --- a/warehouse/router_test.go +++ b/warehouse/router_test.go @@ -9,6 +9,8 @@ import ( "testing" "time" + "github.com/rudderlabs/rudder-server/services/notifier" + "github.com/samber/lo" "github.com/rudderlabs/rudder-server/warehouse/encoding" @@ -33,7 +35,6 @@ import ( "github.com/rudderlabs/rudder-go-kit/logger" "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource" backendconfig "github.com/rudderlabs/rudder-server/backend-config" - "github.com/rudderlabs/rudder-server/services/pgnotifier" migrator "github.com/rudderlabs/rudder-server/services/sql-migrator" sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" "github.com/rudderlabs/rudder-server/warehouse/internal/model" @@ -41,7 +42,6 @@ import ( ) func TestRouter(t *testing.T) { - pgnotifier.Init() Init4() pool, err := dockertest.NewPool("") @@ -96,7 +96,10 @@ func TestRouter(t *testing.T) { db := sqlmiddleware.New(pgResource.DB) - notifier, err := pgnotifier.New(workspaceIdentifier, pgResource.DBDsn) + ctx := context.Background() + + n := notifier.New(config.Default, logger.NOP, stats.Default, workspaceIdentifier) + err = n.Setup(ctx, pgResource.DBDsn) require.NoError(t, err) ctrl := gomock.NewController(t) @@ -113,7 +116,7 @@ func TestRouter(t *testing.T) { ) bcm := newBackendConfigManager(config.Default, db, tenantManager, logger.NOP) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(ctx) defer cancel() ef := encoding.NewFactory(config.Default) @@ -126,7 +129,7 @@ func TestRouter(t *testing.T) { logger.NOP, stats.Default, db, - &notifier, + n, tenantManager, cp, bcm, diff --git a/warehouse/slave.go b/warehouse/slave.go index 7d70c37b6a..362809e6d1 100644 --- a/warehouse/slave.go +++ b/warehouse/slave.go @@ -3,6 +3,8 @@ package warehouse import ( "context" + "github.com/rudderlabs/rudder-server/services/notifier" + "github.com/rudderlabs/rudder-server/warehouse/encoding" "github.com/rudderlabs/rudder-go-kit/logger" @@ -11,14 +13,13 @@ import ( "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/stats" - "github.com/rudderlabs/rudder-server/services/pgnotifier" "github.com/rudderlabs/rudder-server/utils/misc" ) type slaveNotifier interface { - Subscribe(ctx context.Context, workerId string, jobsBufferSize int) chan pgnotifier.Claim -
RunMaintenanceWorker(ctx context.Context) error - UpdateClaimedEvent(claim *pgnotifier.Claim, response *pgnotifier.ClaimResponse) + Subscribe(ctx context.Context, workerId string, jobsBufferSize int) <-chan *notifier.ClaimJob + RunMaintenance(ctx context.Context) error + UpdateClaim(ctx context.Context, job *notifier.ClaimJob, response *notifier.ClaimJobResponse) } type slave struct { @@ -78,7 +79,7 @@ func (s *slave) setupSlave(ctx context.Context) error { } g.Go(misc.WithBugsnagForWarehouse(func() error { - return s.notifier.RunMaintenanceWorker(gCtx) + return s.notifier.RunMaintenance(gCtx) })) return g.Wait() diff --git a/warehouse/slave_test.go b/warehouse/slave_test.go index 28f62f3e55..688bcac98d 100644 --- a/warehouse/slave_test.go +++ b/warehouse/slave_test.go @@ -8,6 +8,8 @@ import ( "os" "testing" + "github.com/rudderlabs/rudder-server/services/notifier" + "github.com/rudderlabs/rudder-server/warehouse/encoding" "golang.org/x/sync/errgroup" @@ -20,27 +22,26 @@ import ( "github.com/rudderlabs/rudder-go-kit/filemanager" "github.com/rudderlabs/rudder-go-kit/logger" "github.com/rudderlabs/rudder-go-kit/stats" - "github.com/rudderlabs/rudder-server/services/pgnotifier" "github.com/rudderlabs/rudder-server/testhelper/destination" "github.com/rudderlabs/rudder-server/utils/misc" "github.com/rudderlabs/rudder-server/warehouse/internal/model" ) type mockSlaveNotifier struct { - subscribeCh chan *pgnotifier.ClaimResponse - publishCh chan pgnotifier.Claim + subscribeCh chan *notifier.ClaimJobResponse + publishCh chan *notifier.ClaimJob maintenanceErr error } -func (m *mockSlaveNotifier) Subscribe(context.Context, string, int) chan pgnotifier.Claim { +func (m *mockSlaveNotifier) Subscribe(context.Context, string, int) <-chan *notifier.ClaimJob { return m.publishCh } -func (m *mockSlaveNotifier) UpdateClaimedEvent(_ *pgnotifier.Claim, response *pgnotifier.ClaimResponse) { +func (m *mockSlaveNotifier) UpdateClaim(_ context.Context, _ *notifier.ClaimJob, response *notifier.ClaimJobResponse) { m.subscribeCh <- response } -func (m *mockSlaveNotifier) RunMaintenanceWorker(context.Context) error { +func (m *mockSlaveNotifier) RunMaintenance(context.Context) error { return m.maintenanceErr } @@ -73,12 +74,12 @@ func TestSlave(t *testing.T) { schemaMap := stagingSchema(t) - publishCh := make(chan pgnotifier.Claim) - subscriberCh := make(chan *pgnotifier.ClaimResponse) + publishCh := make(chan *notifier.ClaimJob) + subscriberCh := make(chan *notifier.ClaimJobResponse) defer close(publishCh) defer close(subscriberCh) - notifier := &mockSlaveNotifier{ + slaveNotifier := &mockSlaveNotifier{ publishCh: publishCh, subscribeCh: subscriberCh, } @@ -90,7 +91,7 @@ func TestSlave(t *testing.T) { config.Default, logger.NOP, stats.Default, - notifier, + slaveNotifier, newBackendConfigManager(config.Default, nil, tenantManager, logger.NOP), newConstraintsManager(config.Default), encoding.NewFactory(config.Default), @@ -128,13 +129,15 @@ func TestSlave(t *testing.T) { payloadJson, err := json.Marshal(p) require.NoError(t, err) - claim := pgnotifier.Claim{ - ID: 1, - BatchID: uuid.New().String(), - Payload: payloadJson, - Status: "waiting", - Workspace: "test_workspace", - JobType: "upload", + claim := &notifier.ClaimJob{ + Job: &notifier.Job{ + ID: 1, + BatchID: uuid.New().String(), + Payload: payloadJson, + Status: model.Waiting, + WorkspaceIdentifier: "test_workspace", + Type: notifier.JobTypeUpload, + }, } g, _ := errgroup.WithContext(ctx) @@ -153,7 +156,7 @@ func TestSlave(t *testing.T) { var uploadPayload payload
err := json.Unmarshal(response.Payload, &uploadPayload) require.NoError(t, err) - require.Equal(t, uploadPayload.BatchID, claim.BatchID) + require.Equal(t, uploadPayload.BatchID, claim.Job.BatchID) require.Equal(t, uploadPayload.UploadID, p.UploadID) require.Equal(t, uploadPayload.StagingFileID, p.StagingFileID) require.Equal(t, uploadPayload.StagingFileLocation, p.StagingFileLocation) diff --git a/warehouse/slave_worker.go b/warehouse/slave_worker.go index e84e8e7eee..9fed7bb37d 100644 --- a/warehouse/slave_worker.go +++ b/warehouse/slave_worker.go @@ -11,11 +11,12 @@ import ( "strconv" "time" + "github.com/rudderlabs/rudder-server/services/notifier" + "github.com/rudderlabs/rudder-go-kit/logger" "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/stats" - "github.com/rudderlabs/rudder-server/services/pgnotifier" "github.com/rudderlabs/rudder-server/warehouse/encoding" integrationsconfig "github.com/rudderlabs/rudder-server/warehouse/integrations/config" "github.com/rudderlabs/rudder-server/warehouse/integrations/manager" @@ -100,33 +101,36 @@ func newSlaveWorker( return s } -func (sw *slaveWorker) start(ctx context.Context, notificationChan <-chan pgnotifier.Claim, slaveID string) { +func (sw *slaveWorker) start(ctx context.Context, notificationChan <-chan *notifier.ClaimJob, slaveID string) { workerIdleTimeStart := time.Now() for { select { case <-ctx.Done(): - sw.log.Infof("[WH]: Slave worker-%d-%s is shutting down", sw.workerIdx, slaveID) + sw.log.Infof("Slave worker-%d-%s is shutting down", sw.workerIdx, slaveID) return - case claimedJob := <-notificationChan: + case claimedJob, ok := <-notificationChan: + if !ok { + return + } sw.stats.workerIdleTime.Since(workerIdleTimeStart) - sw.log.Debugf("[WH]: Successfully claimed job:%d by slave worker-%d-%s & job type %s", - claimedJob.ID, + sw.log.Debugf("Successfully claimed job:%d by slave worker-%d-%s & job type %s", + claimedJob.Job.ID, sw.workerIdx, slaveID, - claimedJob.JobType, + claimedJob.Job.Type, ) - switch claimedJob.JobType { - case jobs.AsyncJobType: + switch claimedJob.Job.Type { + case notifier.JobTypeAsync: sw.processClaimedAsyncJob(ctx, claimedJob) default: sw.processClaimedUploadJob(ctx, claimedJob) } - sw.log.Infof("[WH]: Successfully processed job:%d by slave worker-%d-%s", - claimedJob.ID, + sw.log.Infof("Successfully processed job:%d by slave worker-%d-%s", + claimedJob.Job.ID, sw.workerIdx, slaveID, ) @@ -136,13 +140,13 @@ func (sw *slaveWorker) start(ctx context.Context, notificationChan <-chan pgnoti } } -func (sw *slaveWorker) processClaimedUploadJob(ctx context.Context, claimedJob pgnotifier.Claim) { +func (sw *slaveWorker) processClaimedUploadJob(ctx context.Context, claimedJob *notifier.ClaimJob) { sw.stats.workerClaimProcessingTime.RecordDuration()() - handleErr := func(err error, claim pgnotifier.Claim) { + handleErr := func(err error, claimedJob *notifier.ClaimJob) { sw.stats.workerClaimProcessingFailed.Increment() - sw.notifier.UpdateClaimedEvent(&claim, &pgnotifier.ClaimResponse{ + sw.notifier.UpdateClaim(ctx, claimedJob, &notifier.ClaimJobResponse{ Err: err, }) } @@ -153,14 +157,14 @@ func (sw *slaveWorker) processClaimedUploadJob(ctx context.Context, claimedJob p err error ) - if err = json.Unmarshal(claimedJob.Payload, &job); err != nil { + if err = json.Unmarshal(claimedJob.Job.Payload, &job); err != nil { handleErr(err, claimedJob) return } - sw.log.Infof(`Starting processing staging-file:%v from claim:%v`, job.StagingFileID, claimedJob.ID) + sw.log.Infof(`Starting
processing staging-file:%v from claim:%v`, job.StagingFileID, claimedJob.Job.ID) - job.BatchID = claimedJob.BatchID + job.BatchID = claimedJob.Job.BatchID job.Output, err = sw.processStagingFile(ctx, job) if err != nil { handleErr(err, claimedJob) @@ -174,7 +178,7 @@ func (sw *slaveWorker) processClaimedUploadJob(ctx context.Context, claimedJob p sw.stats.workerClaimProcessingSucceeded.Increment() - sw.notifier.UpdateClaimedEvent(&claimedJob, &pgnotifier.ClaimResponse{ + sw.notifier.UpdateClaim(ctx, claimedJob, &notifier.ClaimJobResponse{ Payload: jobJSON, }) } @@ -192,7 +196,7 @@ func (sw *slaveWorker) processStagingFile(ctx context.Context, job payload) ([]u jr := newJobRun(job, sw.conf, sw.log, sw.statsFactory, sw.encodingFactory) - sw.log.Debugf("[WH]: Starting processing staging file: %v at %s for %s", + sw.log.Debugf("Starting processing staging file: %v at %s for %s", job.StagingFileID, job.StagingFileLocation, jr.identifier, @@ -259,7 +263,7 @@ func (sw *slaveWorker) processStagingFile(ctx context.Context, job payload) ([]u ) if err := json.Unmarshal(lineBytes, &batchRouterEvent); err != nil { - jr.logger.Errorf("[WH]: Failed to unmarshal JSON line to batchrouter event: %+v", batchRouterEvent) + jr.logger.Errorf("Failed to unmarshal JSON line to batchrouter event: %+v", batchRouterEvent) continue } @@ -359,7 +363,7 @@ func (sw *slaveWorker) processStagingFile(ctx context.Context, job payload) ([]u err = jr.handleDiscardTypes(tableName, columnName, columnVal, columnData, violatedConstraints, jr.outputFileWritersMap[discardsTable]) if err != nil { - jr.logger.Errorf("[WH]: Failed to write to discards: %v", err) + jr.logger.Errorf("Failed to write to discards: %v", err) } jr.tableEventCountMap[discardsTable]++ @@ -391,7 +395,7 @@ func (sw *slaveWorker) processStagingFile(ctx context.Context, job payload) ([]u jr.tableEventCountMap[tableName]++ } - jr.logger.Debugf("[WH]: Process %v bytes from downloaded staging file: %s", lineBytesCounter, job.StagingFileLocation) + jr.logger.Debugf("Process %v bytes from downloaded staging file: %s", lineBytesCounter, job.StagingFileLocation) jr.processingStagingFileStat.Since(processingStart) jr.bytesProcessedStagingFileStat.Count(lineBytesCounter) @@ -410,11 +414,11 @@ func (sw *slaveWorker) processStagingFile(ctx context.Context, job payload) ([]u return uploadsResults, err } -func (sw *slaveWorker) processClaimedAsyncJob(ctx context.Context, claimedJob pgnotifier.Claim) { - handleErr := func(err error, claim pgnotifier.Claim) { - sw.log.Errorf("[WH]: Error processing claim: %v", err) +func (sw *slaveWorker) processClaimedAsyncJob(ctx context.Context, claimedJob *notifier.ClaimJob) { + handleErr := func(err error, claimedJob *notifier.ClaimJob) { + sw.log.Errorf("Error processing claim: %v", err) - sw.notifier.UpdateClaimedEvent(&claimedJob, &pgnotifier.ClaimResponse{ + sw.notifier.UpdateClaim(ctx, claimedJob, &notifier.ClaimJobResponse{ Err: err, }) } @@ -424,7 +428,7 @@ func (sw *slaveWorker) processClaimedAsyncJob(ctx context.Context, claimedJob pg err error ) - if err := json.Unmarshal(claimedJob.Payload, &job); err != nil { + if err := json.Unmarshal(claimedJob.Job.Payload, &job); err != nil { handleErr(err, claimedJob) return } @@ -441,7 +445,7 @@ func (sw *slaveWorker) processClaimedAsyncJob(ctx context.Context, claimedJob pg return } - sw.notifier.UpdateClaimedEvent(&claimedJob, &pgnotifier.ClaimResponse{ + sw.notifier.UpdateClaim(ctx, claimedJob, &notifier.ClaimJobResponse{ Payload: jobResultJSON, }) } @@ -487,7 +491,7 @@ func (sw *slaveWorker)
runAsyncJob(ctx context.Context, asyncjob jobs.AsyncJobPa StartTime: metadata.StartTime, }) default: - err = errors.New("invalid AsyncJobType") + err = errors.New("invalid asyncJob type") } if err != nil { return result, err diff --git a/warehouse/slave_worker_job_test.go b/warehouse/slave_worker_job_test.go index fdca5f5e5d..7eaf755046 100644 --- a/warehouse/slave_worker_job_test.go +++ b/warehouse/slave_worker_job_test.go @@ -112,7 +112,7 @@ func (m *mockLoadFileWriter) Write(p []byte) (int, error) { return len(p), nil } -func (m *mockLoadFileWriter) WriteRow(r []interface{}) error { +func (m *mockLoadFileWriter) WriteRow([]interface{}) error { return errors.New("not implemented") } diff --git a/warehouse/slave_worker_test.go b/warehouse/slave_worker_test.go index 11cf02dfe8..57fe10baba 100644 --- a/warehouse/slave_worker_test.go +++ b/warehouse/slave_worker_test.go @@ -9,6 +9,8 @@ import ( "os" "testing" + "github.com/rudderlabs/rudder-server/services/notifier" + "github.com/rudderlabs/rudder-server/warehouse/encoding" "github.com/golang/mock/gomock" @@ -24,7 +26,6 @@ import ( "github.com/rudderlabs/rudder-go-kit/testhelper/docker/resource" backendconfig "github.com/rudderlabs/rudder-server/backend-config" mocksBackendConfig "github.com/rudderlabs/rudder-server/mocks/backend-config" - "github.com/rudderlabs/rudder-server/services/pgnotifier" "github.com/rudderlabs/rudder-server/testhelper/destination" "github.com/rudderlabs/rudder-server/utils/misc" "github.com/rudderlabs/rudder-server/utils/pubsub" @@ -77,10 +78,10 @@ func TestSlaveWorker(t *testing.T) { ef := encoding.NewFactory(config.Default) t.Run("success", func(t *testing.T) { - subscribeCh := make(chan *pgnotifier.ClaimResponse) + subscribeCh := make(chan *notifier.ClaimJobResponse) defer close(subscribeCh) - notifier := &mockSlaveNotifier{ + slaveNotifier := &mockSlaveNotifier{ subscribeCh: subscribeCh, } @@ -88,7 +89,7 @@ config.Default, logger.NOP, stats.Default, - notifier, + slaveNotifier, newBackendConfigManager(config.Default, nil, tenantManager, logger.NOP), newConstraintsManager(config.Default), ef, @@ -119,13 +120,15 @@ payloadJson, err := json.Marshal(p) require.NoError(t, err) - claim := pgnotifier.Claim{ - ID: 1, - BatchID: uuid.New().String(), - Payload: payloadJson, - Status: "waiting", - Workspace: "test_workspace", - JobType: "upload", + claim := &notifier.ClaimJob{ + Job: &notifier.Job{ + ID: 1, + BatchID: uuid.New().String(), + Payload: payloadJson, + Status: model.Waiting, + WorkspaceIdentifier: "test_workspace", + Type: notifier.JobTypeUpload, + }, } claimedJobDone := make(chan struct{}) @@ -141,7 +144,7 @@ var uploadPayload payload err = json.Unmarshal(response.Payload, &uploadPayload) require.NoError(t, err) - require.Equal(t, uploadPayload.BatchID, claim.BatchID) + require.Equal(t, uploadPayload.BatchID, claim.Job.BatchID) require.Equal(t, uploadPayload.UploadID, p.UploadID) require.Equal(t, uploadPayload.StagingFileID, p.StagingFileID) require.Equal(t, uploadPayload.StagingFileLocation, p.StagingFileLocation) @@ -175,10 +178,10 @@ }) t.Run("clickhouse bool", func(t *testing.T) { - subscribeCh := make(chan *pgnotifier.ClaimResponse) + subscribeCh := make(chan *notifier.ClaimJobResponse) defer close(subscribeCh) - notifier := &mockSlaveNotifier{ + slaveNotifier := &mockSlaveNotifier{ subscribeCh: subscribeCh, } @@ -186,7 +189,7 @@
config.Default, logger.NOP, stats.Default, - notifier, + slaveNotifier, newBackendConfigManager(config.Default, nil, tenantManager, logger.NOP), newConstraintsManager(config.Default), ef, @@ -217,13 +220,15 @@ payloadJson, err := json.Marshal(p) require.NoError(t, err) - claim := pgnotifier.Claim{ - ID: 1, - BatchID: uuid.New().String(), - Payload: payloadJson, - Status: "waiting", - Workspace: "test_workspace", - JobType: "upload", + claim := &notifier.ClaimJob{ + Job: &notifier.Job{ + ID: 1, + BatchID: uuid.New().String(), + Payload: payloadJson, + Status: model.Waiting, + WorkspaceIdentifier: "test_workspace", + Type: notifier.JobTypeUpload, + }, } claimedJobDone := make(chan struct{}) @@ -297,10 +302,10 @@ }) t.Run("schema limit exceeded", func(t *testing.T) { - subscribeCh := make(chan *pgnotifier.ClaimResponse) + subscribeCh := make(chan *notifier.ClaimJobResponse) defer close(subscribeCh) - notifier := &mockSlaveNotifier{ + slaveNotifier := &mockSlaveNotifier{ subscribeCh: subscribeCh, } @@ -311,7 +316,7 @@ c, logger.NOP, stats.Default, - notifier, + slaveNotifier, newBackendConfigManager(config.Default, nil, tenantManager, logger.NOP), newConstraintsManager(config.Default), ef, @@ -342,20 +347,22 @@ payloadJson, err := json.Marshal(p) require.NoError(t, err) - claim := pgnotifier.Claim{ - ID: 1, - BatchID: uuid.New().String(), - Payload: payloadJson, - Status: "waiting", - Workspace: "test_workspace", - JobType: "upload", + claimJob := &notifier.ClaimJob{ + Job: &notifier.Job{ + ID: 1, + BatchID: uuid.New().String(), + Payload: payloadJson, + Status: model.Waiting, + WorkspaceIdentifier: "test_workspace", + Type: notifier.JobTypeUpload, + }, } claimedJobDone := make(chan struct{}) go func() { defer close(claimedJobDone) - slaveWorker.processClaimedUploadJob(ctx, claim) + slaveWorker.processClaimedUploadJob(ctx, claimJob) }() response := <-subscribeCh @@ -365,10 +372,10 @@ }) t.Run("discards", func(t *testing.T) { - subscribeCh := make(chan *pgnotifier.ClaimResponse) + subscribeCh := make(chan *notifier.ClaimJobResponse) defer close(subscribeCh) - notifier := &mockSlaveNotifier{ + slaveNotifier := &mockSlaveNotifier{ subscribeCh: subscribeCh, } @@ -376,7 +383,7 @@ config.Default, logger.NOP, stats.Default, - notifier, + slaveNotifier, newBackendConfigManager(config.Default, nil, tenantManager, logger.NOP), newConstraintsManager(config.Default), ef, @@ -419,13 +426,15 @@ payloadJson, err := json.Marshal(p) require.NoError(t, err) - claim := pgnotifier.Claim{ - ID: 1, - BatchID: uuid.New().String(), - Payload: payloadJson, - Status: "waiting", - Workspace: "test_workspace", - JobType: "upload", + claim := &notifier.ClaimJob{ + Job: &notifier.Job{ + ID: 1, + BatchID: uuid.New().String(), + Payload: payloadJson, + Status: model.Waiting, + WorkspaceIdentifier: "test_workspace", + Type: notifier.JobTypeUpload, + }, } claimedJobDone := make(chan struct{}) @@ -536,10 +545,10 @@ <-setupCh t.Run("success", func(t *testing.T) { - subscribeCh := make(chan *pgnotifier.ClaimResponse) + subscribeCh := make(chan *notifier.ClaimJobResponse) defer close(subscribeCh) - notifier := &mockSlaveNotifier{ + slaveNotifier := &mockSlaveNotifier{ subscribeCh: subscribeCh, } @@ -550,7 +559,7 @@ c,
logger.NOP, stats.Default, - notifier, + slaveNotifier, bcm, newConstraintsManager(config.Default), ef, @@ -570,13 +579,15 @@ payloadJson, err := json.Marshal(p) require.NoError(t, err) - claim := pgnotifier.Claim{ - ID: 1, - BatchID: uuid.New().String(), - Payload: payloadJson, - Status: "waiting", - Workspace: "test_workspace", - JobType: "async_job", + claim := &notifier.ClaimJob{ + Job: &notifier.Job{ + ID: 1, + BatchID: uuid.New().String(), + Payload: payloadJson, + Status: model.Waiting, + WorkspaceIdentifier: "test_workspace", + Type: notifier.JobTypeAsync, + }, } claimedJobDone := make(chan struct{}) @@ -600,10 +611,10 @@ }) t.Run("invalid configurations", func(t *testing.T) { - subscribeCh := make(chan *pgnotifier.ClaimResponse) + subscribeCh := make(chan *notifier.ClaimJobResponse) defer close(subscribeCh) - notifier := &mockSlaveNotifier{ + slaveNotifier := &mockSlaveNotifier{ subscribeCh: subscribeCh, } @@ -614,7 +625,7 @@ c, logger.NOP, stats.Default, - notifier, + slaveNotifier, bcm, newConstraintsManager(config.Default), ef, @@ -633,7 +644,7 @@ sourceID: sourceID, destinationID: destinationID, jobType: "invalid_job_type", - expectedError: errors.New("invalid AsyncJobType"), + expectedError: errors.New("invalid asyncJob type"), }, { name: "invalid parameters", @@ -673,14 +684,16 @@ payloadJson, err := json.Marshal(p) require.NoError(t, err) - claim := pgnotifier.Claim{ - ID: 1, - BatchID: uuid.New().String(), - Payload: payloadJson, - Status: "waiting", - Workspace: "test_workspace", - Attempt: 0, - JobType: "async_job", + claim := &notifier.ClaimJob{ + Job: &notifier.Job{ + ID: 1, + BatchID: uuid.New().String(), + Payload: payloadJson, + Status: model.Waiting, + WorkspaceIdentifier: "test_workspace", + Attempt: 0, + Type: notifier.JobTypeAsync, + }, } claimedJobDone := make(chan struct{}) diff --git a/warehouse/upload.go b/warehouse/upload.go index f51057f403..724a33f0dd 100644 --- a/warehouse/upload.go +++ b/warehouse/upload.go @@ -12,6 +12,8 @@ import ( "sync/atomic" "time" + "github.com/rudderlabs/rudder-server/services/notifier" + "github.com/rudderlabs/rudder-server/warehouse/encoding" "github.com/rudderlabs/rudder-go-kit/logger" @@ -26,7 +28,6 @@ import ( "github.com/rudderlabs/rudder-server/jobsdb" "github.com/rudderlabs/rudder-server/rruntime" "github.com/rudderlabs/rudder-server/services/alerta" - "github.com/rudderlabs/rudder-server/services/pgnotifier" "github.com/rudderlabs/rudder-server/utils/misc" "github.com/rudderlabs/rudder-server/utils/timeutil" "github.com/rudderlabs/rudder-server/utils/types" @@ -75,7 +76,7 @@ type UploadJobFactory struct { destinationValidator validations.DestinationValidator loadFile *loadfiles.LoadFileGenerator recovery *service.Recovery - pgNotifier *pgnotifier.PGNotifier + notifier *notifier.Notifier conf *config.Config logger logger.Logger statsFactory stats.Stats @@ -91,7 +92,7 @@ type UploadJob struct { tableUploadsRepo *repo.TableUploads recovery *service.Recovery whManager manager.Manager - pgNotifier *pgnotifier.PGNotifier + notifier *notifier.Notifier schemaHandle *Schema conf *config.Config logger logger.Logger @@ -190,7 +191,7 @@ func (f *UploadJobFactory) NewUploadJob(ctx context.Context, dto *model.UploadJo dbHandle: f.dbHandle, loadfile: f.loadFile, recovery: f.recovery, - pgNotifier: f.pgNotifier, + notifier: f.notifier, whManager: whManager,
destinationValidator: f.destinationValidator, conf: f.conf, diff --git a/warehouse/warehouse.go b/warehouse/warehouse.go index 7a9b2c99e1..ff04abccfd 100644 --- a/warehouse/warehouse.go +++ b/warehouse/warehouse.go @@ -11,6 +11,8 @@ import ( "sync" "time" + "github.com/rudderlabs/rudder-server/services/notifier" + "github.com/rudderlabs/rudder-server/warehouse/encoding" "github.com/cenkalti/backoff/v4" @@ -27,7 +29,6 @@ import ( "github.com/rudderlabs/rudder-server/info" "github.com/rudderlabs/rudder-server/services/controlplane" "github.com/rudderlabs/rudder-server/services/db" - "github.com/rudderlabs/rudder-server/services/pgnotifier" migrator "github.com/rudderlabs/rudder-server/services/sql-migrator" "github.com/rudderlabs/rudder-server/services/validators" "github.com/rudderlabs/rudder-server/utils/misc" @@ -46,7 +47,7 @@ var ( dbHandle *sql.DB wrappedDBHandle *sqlquerywrapper.DB dbHandleTimeout time.Duration - notifier pgnotifier.PGNotifier + notifierInstance *notifier.Notifier tenantManager *multitenant.Manager controlPlaneClient *controlplane.Client uploadFreqInS int64 @@ -178,7 +179,7 @@ func onConfigDataEvent(ctx context.Context, configMap map[string]backendconfig.C pkgLogger, stats.Default, wrappedDBHandle, - ¬ifier, + notifierInstance, tenantManager, controlPlaneClient, bcManager, @@ -377,10 +378,11 @@ func Start(ctx context.Context, app app.App) error { return g.Wait() } + workspaceIdentifier := fmt.Sprintf(`%s::%s`, config.GetKubeNamespace(), misc.GetMD5Hash(config.GetWorkspaceToken())) - notifier, err = pgnotifier.New(workspaceIdentifier, psqlInfo) - if err != nil { - return fmt.Errorf("cannot setup pgnotifier: %w", err) + notifierInstance = notifier.New(config.Default, pkgLogger, stats.Default, workspaceIdentifier) + if err := notifierInstance.Setup(ctx, psqlInfo); err != nil { + return fmt.Errorf("cannot setup notifier: %w", err) } // Setting up reporting client only if standalone master or embedded connecting to different DB for warehouse @@ -422,7 +424,7 @@ func Start(ctx context.Context, app app.App) error { cm := newConstraintsManager(config.Default) ef := encoding.NewFactory(config.Default) - slave := newSlave(config.Default, pkgLogger, stats.Default, ¬ifier, bcManager, cm, ef) + slave := newSlave(config.Default, pkgLogger, stats.Default, notifierInstance, bcManager, cm, ef) return slave.setupSlave(gCtx) })) } @@ -441,7 +443,7 @@ func Start(ctx context.Context, app app.App) error { ) g.Go(misc.WithBugsnagForWarehouse(func() error { - return notifier.ClearJobs(gCtx) + return notifierInstance.ClearJobs(gCtx) })) g.Go(misc.WithBugsnagForWarehouse(func() error { @@ -465,7 +467,7 @@ func Start(ctx context.Context, app app.App) error { return nil }) - asyncWh = jobs.InitWarehouseJobsAPI(gCtx, dbHandle, ¬ifier) + asyncWh = jobs.InitWarehouseJobsAPI(gCtx, dbHandle, notifierInstance) jobs.WithConfig(asyncWh, config.Default) g.Go(misc.WithBugsnagForWarehouse(func() error { @@ -476,11 +478,15 @@ func Start(ctx context.Context, app app.App) error { g.Go(func() error { api := NewApi( mode, config.Default, pkgLogger, stats.Default, - backendconfig.DefaultBackendConfig, wrappedDBHandle, ¬ifier, tenantManager, + backendconfig.DefaultBackendConfig, wrappedDBHandle, notifierInstance, tenantManager, bcManager, asyncWh, ) return api.Start(gCtx) }) + g.Go(func() error { + <-gCtx.Done() + return notifierInstance.Shutdown() + }) return g.Wait() }