From 6ad51e8e7411584d61a546b8645e4bc1b88c1fbb Mon Sep 17 00:00:00 2001 From: Saurav Malani Date: Thu, 8 Jun 2023 16:36:30 +0530 Subject: [PATCH] fix: gw transient errors crash (#3397) --- app/apphandlers/embeddedAppHandler.go | 5 + app/apphandlers/gatewayAppHandler.go | 2 + app/apphandlers/processorAppHandler.go | 5 + jobsdb/backup.go | 102 ++++++---- jobsdb/jobsdb.go | 253 +++++++++++++++---------- jobsdb/jobsdb_test.go | 4 +- jobsdb/jobsdb_utils.go | 17 +- jobsdb/migration.go | 155 ++++++++------- jobsdb/unionQuery.go | 2 +- mocks/jobsdb/mock_jobsdb.go | 19 +- mocks/jobsdb/mock_unionQuery.go | 5 +- router/batchrouter/batchrouter_test.go | 2 +- router/batchrouter/handle.go | 5 +- 13 files changed, 356 insertions(+), 220 deletions(-) diff --git a/app/apphandlers/embeddedAppHandler.go b/app/apphandlers/embeddedAppHandler.go index 6ceb9b64e5..a8e510796f 100644 --- a/app/apphandlers/embeddedAppHandler.go +++ b/app/apphandlers/embeddedAppHandler.go @@ -140,6 +140,7 @@ func (a *embeddedApp) StartRudderCore(ctx context.Context, options *app.Options) jobsdb.WithPreBackupHandlers(prebackupHandlers), jobsdb.WithDSLimit(&a.config.gatewayDSLimit), jobsdb.WithFileUploaderProvider(fileUploaderProvider), + jobsdb.WithSkipMaintenanceErr(config.GetBool("Gateway.jobsDB.skipMaintenanceError", true)), ) defer gwDBForProcessor.Close() routerDB := jobsdb.NewForReadWrite( @@ -148,6 +149,7 @@ func (a *embeddedApp) StartRudderCore(ctx context.Context, options *app.Options) jobsdb.WithPreBackupHandlers(prebackupHandlers), jobsdb.WithDSLimit(&a.config.routerDSLimit), jobsdb.WithFileUploaderProvider(fileUploaderProvider), + jobsdb.WithSkipMaintenanceErr(config.GetBool("Router.jobsDB.skipMaintenanceError", false)), ) defer routerDB.Close() batchRouterDB := jobsdb.NewForReadWrite( @@ -156,6 +158,7 @@ func (a *embeddedApp) StartRudderCore(ctx context.Context, options *app.Options) jobsdb.WithPreBackupHandlers(prebackupHandlers), jobsdb.WithDSLimit(&a.config.batchRouterDSLimit), jobsdb.WithFileUploaderProvider(fileUploaderProvider), + jobsdb.WithSkipMaintenanceErr(config.GetBool("BatchRouter.jobsDB.skipMaintenanceError", false)), ) defer batchRouterDB.Close() errDB := jobsdb.NewForReadWrite( @@ -164,11 +167,13 @@ func (a *embeddedApp) StartRudderCore(ctx context.Context, options *app.Options) jobsdb.WithPreBackupHandlers(prebackupHandlers), jobsdb.WithDSLimit(&a.config.processorDSLimit), jobsdb.WithFileUploaderProvider(fileUploaderProvider), + jobsdb.WithSkipMaintenanceErr(config.GetBool("Processor.jobsDB.skipMaintenanceError", false)), ) schemaDB := jobsdb.NewForReadWrite( "esch", jobsdb.WithClearDB(options.ClearDB), jobsdb.WithDSLimit(&a.config.processorDSLimit), + jobsdb.WithSkipMaintenanceErr(config.GetBool("Processor.jobsDB.skipMaintenanceError", false)), ) var tenantRouterDB jobsdb.MultiTenantJobsDB diff --git a/app/apphandlers/gatewayAppHandler.go b/app/apphandlers/gatewayAppHandler.go index 09b9d4d2ba..3e6dcead78 100644 --- a/app/apphandlers/gatewayAppHandler.go +++ b/app/apphandlers/gatewayAppHandler.go @@ -79,9 +79,11 @@ func (a *gatewayApp) StartRudderCore(ctx context.Context, options *app.Options) "gw", jobsdb.WithClearDB(options.ClearDB), jobsdb.WithDSLimit(&a.config.gatewayDSLimit), + jobsdb.WithSkipMaintenanceErr(config.GetBool("Gateway.jobsDB.skipMaintenanceError", true)), jobsdb.WithFileUploaderProvider(fileUploaderProvider), ) defer gatewayDB.Close() + if err := gatewayDB.Start(); err != nil { return fmt.Errorf("could not start gatewayDB: %w", err) } diff --git a/app/apphandlers/processorAppHandler.go 
b/app/apphandlers/processorAppHandler.go index 9d3bdb6818..db7220f1c0 100644 --- a/app/apphandlers/processorAppHandler.go +++ b/app/apphandlers/processorAppHandler.go @@ -148,6 +148,7 @@ func (a *processorApp) StartRudderCore(ctx context.Context, options *app.Options jobsdb.WithPreBackupHandlers(prebackupHandlers), jobsdb.WithDSLimit(&a.config.gatewayDSLimit), jobsdb.WithFileUploaderProvider(fileUploaderProvider), + jobsdb.WithSkipMaintenanceErr(config.GetBool("Gateway.jobsDB.skipMaintenanceError", true)), ) defer gwDBForProcessor.Close() routerDB := jobsdb.NewForReadWrite( @@ -156,6 +157,7 @@ func (a *processorApp) StartRudderCore(ctx context.Context, options *app.Options jobsdb.WithPreBackupHandlers(prebackupHandlers), jobsdb.WithDSLimit(&a.config.routerDSLimit), jobsdb.WithFileUploaderProvider(fileUploaderProvider), + jobsdb.WithSkipMaintenanceErr(config.GetBool("Router.jobsDB.skipMaintenanceError", false)), ) defer routerDB.Close() batchRouterDB := jobsdb.NewForReadWrite( @@ -164,6 +166,7 @@ func (a *processorApp) StartRudderCore(ctx context.Context, options *app.Options jobsdb.WithPreBackupHandlers(prebackupHandlers), jobsdb.WithDSLimit(&a.config.batchRouterDSLimit), jobsdb.WithFileUploaderProvider(fileUploaderProvider), + jobsdb.WithSkipMaintenanceErr(config.GetBool("BatchRouter.jobsDB.skipMaintenanceError", false)), ) defer batchRouterDB.Close() errDB := jobsdb.NewForReadWrite( @@ -172,11 +175,13 @@ func (a *processorApp) StartRudderCore(ctx context.Context, options *app.Options jobsdb.WithPreBackupHandlers(prebackupHandlers), jobsdb.WithDSLimit(&a.config.processorDSLimit), jobsdb.WithFileUploaderProvider(fileUploaderProvider), + jobsdb.WithSkipMaintenanceErr(config.GetBool("Processor.jobsDB.skipMaintenanceError", false)), ) schemaDB := jobsdb.NewForReadWrite( "esch", jobsdb.WithClearDB(options.ClearDB), jobsdb.WithDSLimit(&a.config.processorDSLimit), + jobsdb.WithFileUploaderProvider(fileUploaderProvider), ) var tenantRouterDB jobsdb.MultiTenantJobsDB diff --git a/jobsdb/backup.go b/jobsdb/backup.go index 3d2e561099..dc86c2f15f 100644 --- a/jobsdb/backup.go +++ b/jobsdb/backup.go @@ -51,37 +51,65 @@ func (jd *HandleT) backupDSLoop(ctx context.Context) { case <-ctx.Done(): return } - jd.logger.Debugf("backupDSLoop backup enabled %s", jd.tablePrefix) - backupDSRange := jd.getBackupDSRange() - // check if non-empty dataset is present to back up - // else continue - sleepMultiplier = 1 - if (dataSetRangeT{} == *backupDSRange) { - // sleep for more duration if no dataset is found - sleepMultiplier = 6 - continue - } + loop := func() error { + jd.logger.Debugf("backupDSLoop backup enabled %s", jd.tablePrefix) + backupDSRange, err := jd.getBackupDSRange() + if err != nil { + return fmt.Errorf("[JobsDB] :: Failed to get backup dataset range. Err: %w", err) + } + // check if non-empty dataset is present to back up + // else continue + sleepMultiplier = 1 + if (dataSetRangeT{} == *backupDSRange) { + // sleep for more duration if no dataset is found + sleepMultiplier = 6 + return nil + } - backupDS := backupDSRange.ds + backupDS := backupDSRange.ds - opPayload, err := json.Marshal(&backupDS) - jd.assertError(err) + opPayload, err := json.Marshal(&backupDS) + jd.assertError(err) - opID := jd.JournalMarkStart(backupDSOperation, opPayload) - err = jd.backupDS(ctx, backupDSRange) - if err != nil { - jd.logger.Errorf("[JobsDB] :: Failed to backup jobs table %v. 
Err: %v", backupDSRange.ds.JobStatusTable, err) + opID, err := jd.JournalMarkStart(backupDSOperation, opPayload) + if err != nil { + return fmt.Errorf("mark start of backup operation: %w", err) + } + err = jd.backupDS(ctx, backupDSRange) + if err != nil { + return fmt.Errorf("backup dataset: %w", err) + } + err = jd.JournalMarkDone(opID) + if err != nil { + return fmt.Errorf("mark end of backup operation: %w", err) + } + + // drop dataset after successfully uploading both jobs and jobs_status to s3 + opID, err = jd.JournalMarkStart(backupDropDSOperation, opPayload) + if err != nil { + return fmt.Errorf("mark start of drop backup operation: %w", err) + } + // Currently, we retry uploading a table for some time & if it fails. We only drop that table & not all `pre_drop` tables. + // So, in situation when new table creation rate is more than drop. We will still have pipe up issue. + // An easy way to fix this is, if at any point of time exponential retry fails then instead of just dropping that particular + // table drop all subsequent `pre_drop` table. As, most likely the upload of rest of the table will also fail with the same error. + err = jd.dropDS(backupDS) + if err != nil { + return fmt.Errorf(" drop dataset: %w", err) + } + err = jd.JournalMarkDone(opID) + if err != nil { + return fmt.Errorf("mark end of drop backup operation: %w", err) + } + return nil + } + if err := loop(); err != nil && ctx.Err() == nil { + if !jd.skipMaintenanceError { + panic(err) + } + jd.logger.Errorf("[JobsDB] :: Failed to backup dataset. Err: %s", err.Error()) } - jd.JournalMarkDone(opID) - // drop dataset after successfully uploading both jobs and jobs_status to s3 - opID = jd.JournalMarkStart(backupDropDSOperation, opPayload) - // Currently, we retry uploading a table for some time & if it fails. We only drop that table & not all `pre_drop` tables. - // So, in situation when new table creation rate is more than drop. We will still have pipe up issue. - // An easy way to fix this is, if at any point of time exponential retry fails then instead of just dropping that particular - // table drop all subsequent `pre_drop` table. As, most likely the upload of rest of the table will also fail with the same error. 
- jd.mustDropDS(backupDS) - jd.JournalMarkDone(opID) } } @@ -619,13 +647,15 @@ func (jd *HandleT) backupUploadWithExponentialBackoff(ctx context.Context, file return output, err } -func (jd *HandleT) getBackupDSRange() *dataSetRangeT { +func (jd *HandleT) getBackupDSRange() (*dataSetRangeT, error) { var backupDS dataSetT var backupDSRange dataSetRangeT // Read the table names from PG - tableNames := mustGetAllTableNames(jd, jd.dbHandle) - + tableNames, err := getAllTableNames(jd.dbHandle) + if err != nil { + return nil, fmt.Errorf("getAllTableNames: %w", err) + } // We check for job_status because that is renamed after job var dnumList []string for _, t := range tableNames { @@ -636,7 +666,7 @@ func (jd *HandleT) getBackupDSRange() *dataSetRangeT { } } if len(dnumList) == 0 { - return &backupDSRange + return &backupDSRange, nil } jd.statPreDropTableCount.Gauge(len(dnumList)) @@ -651,14 +681,18 @@ func (jd *HandleT) getBackupDSRange() *dataSetRangeT { var minID, maxID sql.NullInt64 jobIDSQLStatement := fmt.Sprintf(`SELECT MIN(job_id), MAX(job_id) from %q`, backupDS.JobTable) row := jd.dbHandle.QueryRow(jobIDSQLStatement) - err := row.Scan(&minID, &maxID) - jd.assertError(err) + err = row.Scan(&minID, &maxID) + if err != nil { + return nil, fmt.Errorf("getting min and max job_id: %w", err) + } var minCreatedAt, maxCreatedAt time.Time jobTimeSQLStatement := fmt.Sprintf(`SELECT MIN(created_at), MAX(created_at) from %q`, backupDS.JobTable) row = jd.dbHandle.QueryRow(jobTimeSQLStatement) err = row.Scan(&minCreatedAt, &maxCreatedAt) - jd.assertError(err) + if err != nil { + return nil, fmt.Errorf("getting min and max created_at: %w", err) + } backupDSRange = dataSetRangeT{ minJobID: minID.Int64, @@ -667,5 +701,5 @@ func (jd *HandleT) getBackupDSRange() *dataSetRangeT { endTime: maxCreatedAt.UnixNano() / int64(time.Millisecond), ds: backupDS, } - return &backupDSRange + return &backupDSRange, nil } diff --git a/jobsdb/jobsdb.go b/jobsdb/jobsdb.go index fb674582a8..17c7e8e854 100644 --- a/jobsdb/jobsdb.go +++ b/jobsdb/jobsdb.go @@ -296,7 +296,8 @@ type JobsDB interface { GetJournalEntries(opType string) (entries []JournalEntryT) JournalDeleteEntry(opID int64) - JournalMarkStart(opType string, opPayload json.RawMessage) int64 + JournalMarkStart(opType string, opPayload json.RawMessage) (int64, error) + JournalMarkDone(opID int64) error } /* @@ -429,17 +430,18 @@ type HandleT struct { statNewDSPeriod stats.Measurement newDSCreationTime time.Time invalidCacheKeyStat stats.Measurement - isStatNewDSPeriodInitialized bool statDropDSPeriod stats.Measurement dsDropTime time.Time unionQueryTime stats.Measurement - isStatDropDSPeriodInitialized bool logger logger.Logger writeCapacity chan struct{} readCapacity chan struct{} enableWriterQueue bool enableReaderQueue bool clearAll bool + skipMaintenanceError bool + isStatNewDSPeriodInitialized bool + isStatDropDSPeriodInitialized bool dsLimit *int maxReaders int maxWriters int @@ -460,7 +462,6 @@ type HandleT struct { TriggerAddNewDS func() <-chan time.Time TriggerMigrateDS func() <-chan time.Time migrateDSTimeout time.Duration - TriggerRefreshDS func() <-chan time.Time refreshDSTimeout time.Duration @@ -682,6 +683,12 @@ func WithDSLimit(limit *int) OptsFunc { } } +func WithSkipMaintenanceErr(ignore bool) OptsFunc { + return func(jd *HandleT) { + jd.skipMaintenanceError = ignore + } +} + func WithFileUploaderProvider(fileUploaderProvider fileuploader.Provider) OptsFunc { return func(jd *HandleT) { jd.fileUploaderProvider = fileUploaderProvider @@ -813,7 +820,8 
@@ func (jd *HandleT) init() { // doesn't return the full list of datasets, only the rightmost two. // But we need to run the schema migration against all datasets, no matter // whether jobsdb is a writer or not. - datasets := getDSList(jd, jd.dbHandle, jd.tablePrefix) + datasets, err := getDSList(jd, jd.dbHandle, jd.tablePrefix) + jd.assertError(err) datasetIndices := make([]string, 0) for _, dataset := range datasets { @@ -840,7 +848,8 @@ func (jd *HandleT) init() { jd.runAlwaysChangesets(templateData) // finally refresh the dataset list to make sure [datasetList] field is populated - jd.refreshDSList(l) + _, err := jd.refreshDSList(l) + jd.assertError(err) }) return nil }) @@ -942,11 +951,9 @@ func (jd *HandleT) readerSetup(ctx context.Context, l lock.LockToken) { // This is a thread-safe operation. // Even if two different services (gateway and processor) perform this operation, there should not be any problem. jd.recoverFromJournal(ReadWrite) - - jd.refreshDSRangeList(l) + jd.assertError(jd.refreshDSRangeList(l)) g := jd.backgroundGroup - g.Go(misc.WithBugsnag(func() error { jd.refreshDSListLoop(ctx) return nil @@ -966,8 +973,8 @@ func (jd *HandleT) writerSetup(ctx context.Context, l lock.LockToken) { // This is a thread-safe operation. // Even if two different services (gateway and processor) perform this operation, there should not be any problem. jd.recoverFromJournal(ReadWrite) + jd.assertError(jd.refreshDSRangeList(l)) - jd.refreshDSRangeList(l) // If no DS present, add one if len(jd.getDSList()) == 0 { jd.addNewDS(l, newDataSet(jd.tablePrefix, jd.computeNewIdxForAppend(l))) @@ -1031,11 +1038,15 @@ func (jd *HandleT) getDSList() []dataSetT { } // refreshDSList refreshes the ds list from the database -func (jd *HandleT) refreshDSList(l lock.LockToken) []dataSetT { - jd.assert(l != nil, "cannot refresh DS list without a valid lock token") +func (jd *HandleT) refreshDSList(l lock.LockToken) ([]dataSetT, error) { + if l == nil { + return nil, fmt.Errorf("cannot refresh DS list without a valid lock token") + } + var err error // Reset the global list - jd.datasetList = getDSList(jd, jd.dbHandle, jd.tablePrefix) - + if jd.datasetList, err = getDSList(jd, jd.dbHandle, jd.tablePrefix); err != nil { + return nil, fmt.Errorf("getDSList %w", err) + } // report table count metrics before shrinking the datasetList jd.statTableCount.Gauge(len(jd.datasetList)) jd.statDSCount.Gauge(len(jd.datasetList)) @@ -1049,7 +1060,7 @@ func (jd *HandleT) refreshDSList(l lock.LockToken) []dataSetT { } } - return jd.datasetList + return jd.datasetList, nil } func (jd *HandleT) getDSRangeList() []dataSetRangeT { @@ -1057,25 +1068,34 @@ func (jd *HandleT) getDSRangeList() []dataSetRangeT { } // refreshDSRangeList first refreshes the DS list and then calculate the DS range list -func (jd *HandleT) refreshDSRangeList(l lock.LockToken) { - var minID, maxID sql.NullInt64 +func (jd *HandleT) refreshDSRangeList(l lock.LockToken) error { var prevMax int64 // At this point we must have write-locked dsListLock - dsList := jd.refreshDSList(l) - + dsList, err := jd.refreshDSList(l) + if err != nil { + return fmt.Errorf("refreshDSList %w", err) + } jd.datasetRangeList = nil for idx, ds := range dsList { jd.assert(ds.Index != "", "ds.Index is empty") - sqlStatement := fmt.Sprintf(`SELECT MIN(job_id), MAX(job_id) FROM %q`, ds.JobTable) - // Note: Using Query instead of QueryRow, because the sqlmock library doesn't have support for QueryRow - row := jd.dbHandle.QueryRow(sqlStatement) - err := row.Scan(&minID, &maxID) - 
jd.assertError(err) + getIndex := func() (sql.NullInt64, sql.NullInt64, error) { + var minID, maxID sql.NullInt64 + sqlStatement := fmt.Sprintf(`SELECT MIN(job_id), MAX(job_id) FROM %q`, ds.JobTable) + row := jd.dbHandle.QueryRow(sqlStatement) + if err := row.Scan(&minID, &maxID); err != nil { + return sql.NullInt64{}, sql.NullInt64{}, fmt.Errorf("scanning min & max jobID %w", err) + } + jd.logger.Debug(sqlStatement, minID, maxID) + return minID, maxID, nil + } + minID, maxID, err := getIndex() + if err != nil { + return err + } - jd.logger.Debug(sqlStatement, minID, maxID) // We store ranges EXCEPT for // 1. the last element (which is being actively written to) // 2. Migration target ds @@ -1101,6 +1121,7 @@ func (jd *HandleT) refreshDSRangeList(l lock.LockToken) { prevMax = maxID.Int64 } } + return nil } func (jd *HandleT) getTableRowCount(jobTable string) int { @@ -1207,10 +1228,12 @@ func newDataSet(tablePrefix, dsIdx string) dataSetT { func (jd *HandleT) addNewDS(l lock.LockToken, ds dataSetT) { err := jd.WithTx(func(tx *Tx) error { - return jd.addNewDSInTx(tx, l, jd.refreshDSList(l), ds) + dsList, err := jd.refreshDSList(l) + jd.assertError(err) + return jd.addNewDSInTx(tx, l, dsList, ds) }) jd.assertError(err) - jd.refreshDSRangeList(l) + jd.assertError(jd.refreshDSRangeList(l)) } // NOTE: If addNewDSInTx is directly called, make sure to explicitly call refreshDSRangeList(l) to update the DS list in cache, once transaction has completed. @@ -1250,14 +1273,9 @@ func (jd *HandleT) addDSInTx(tx *Tx, ds dataSetT) error { return jd.createDSInTx(tx, ds) } -// mustDropDS drops a dataset and panics if it fails to do so -func (jd *HandleT) mustDropDS(ds dataSetT) { - err := jd.dropDS(ds) - jd.assertError(err) -} - func (jd *HandleT) computeNewIdxForAppend(l lock.LockToken) string { - dList := jd.refreshDSList(l) + dList, err := jd.refreshDSList(l) + jd.assertError(err) return jd.doComputeNewIdxForAppend(dList) } @@ -1607,17 +1625,21 @@ func (jd *HandleT) dropAllBackupDS() error { func (jd *HandleT) dropAllDS(l lock.LockToken) error { var err error - dList := jd.refreshDSList(l) + dList, err := jd.refreshDSList(l) + if err != nil { + return fmt.Errorf("refreshDSList: %w", err) + } for _, ds := range dList { if err = jd.dropDS(ds); err != nil { - return err + return fmt.Errorf("dropDS: %w", err) } } // Update the lists - jd.refreshDSRangeList(l) - - return err + if err = jd.refreshDSRangeList(l); err != nil { + return fmt.Errorf("refreshDSRangeList: %w", err) + } + return nil } func (jd *HandleT) internalStoreJobsInTx(ctx context.Context, tx *Tx, ds dataSetT, jobList []*JobT) error { @@ -1653,7 +1675,10 @@ func (jd *HandleT) inStoreSafeCtx(ctx context.Context, f func() error) error { if err != nil && errors.Is(err, errStaleDsList) { jd.logger.Errorf("[JobsDB] :: Store failed: %v. Retrying after refreshing DS cache", errStaleDsList) if err := jd.dsListLock.WithLockInCtx(ctx, func(l lock.LockToken) error { - _ = jd.refreshDSList(l) + _, err = jd.refreshDSList(l) + if err != nil { + return fmt.Errorf("refreshing ds list: %w", err) + } return nil }); err != nil { return err @@ -2387,57 +2412,75 @@ func (jd *HandleT) addNewDSLoop(ctx context.Context) { return case <-jd.TriggerAddNewDS(): } - - // Adding a new DS only creates a new DS & updates the cache. It doesn't move any data so we only take the list lock. 
var dsListLock lock.LockToken var releaseDsListLock chan<- lock.LockToken - // start a transaction - err := jd.WithTx(func(tx *Tx) error { - return jd.withDistributedSharedLock(context.TODO(), tx, "schema_migrate", func() error { // cannot run while schema migration is running - return jd.withDistributedLock(context.TODO(), tx, "add_ds", func() error { // only one add_ds can run at a time - var err error - // We acquire the list lock only after we have acquired the advisory lock. - // We will release the list lock after the transaction ends, that's why we need to use an async lock - dsListLock, releaseDsListLock, err = jd.dsListLock.AsyncLockWithCtx(ctx) - if err != nil { - return err - } - // refresh ds list - var dsList []dataSetT - var nextDSIdx string - // make sure we are operating on the latest version of the list - dsList = getDSList(jd, tx, jd.tablePrefix) - latestDS := dsList[len(dsList)-1] - full, err := jd.checkIfFullDSInTx(tx, latestDS) - if err != nil { - return fmt.Errorf("error while checking if DS is full: %w", err) - } - // checkIfFullDS is true for last DS in the list - if full { - if _, err = tx.Exec(fmt.Sprintf(`LOCK TABLE %q IN EXCLUSIVE MODE;`, latestDS.JobTable)); err != nil { - return fmt.Errorf("error locking table %s: %w", latestDS.JobTable, err) + addNewDS := func() error { + defer func() { + if releaseDsListLock != nil && dsListLock != nil { + releaseDsListLock <- dsListLock + } + }() + // Adding a new DS only creates a new DS & updates the cache. It doesn't move any data so we only take the list lock. + // start a transaction + err := jd.WithTx(func(tx *Tx) error { + return jd.withDistributedSharedLock(context.TODO(), tx, "schema_migrate", func() error { // cannot run while schema migration is running + return jd.withDistributedLock(context.TODO(), tx, "add_ds", func() error { // only one add_ds can run at a time + var err error + // We acquire the list lock only after we have acquired the advisory lock. 
+ // We will release the list lock after the transaction ends, that's why we need to use an async lock + dsListLock, releaseDsListLock, err = jd.dsListLock.AsyncLockWithCtx(ctx) + if err != nil { + return err } - - nextDSIdx = jd.doComputeNewIdxForAppend(dsList) - jd.logger.Infof("[[ %s : addNewDSLoop ]]: NewDS", jd.tablePrefix) - if err = jd.addNewDSInTx(tx, dsListLock, dsList, newDataSet(jd.tablePrefix, nextDSIdx)); err != nil { - return fmt.Errorf("error adding new DS: %w", err) + // refresh ds list + var dsList []dataSetT + var nextDSIdx string + // make sure we are operating on the latest version of the list + dsList, err = getDSList(jd, tx, jd.tablePrefix) + if err != nil { + return fmt.Errorf("getDSList: %w", err) } - - // previous DS should become read only - if err = setReadonlyDsInTx(tx, latestDS); err != nil { - return fmt.Errorf("error making dataset read only: %w", err) + latestDS := dsList[len(dsList)-1] + full, err := jd.checkIfFullDSInTx(tx, latestDS) + if err != nil { + return fmt.Errorf("checkIfFullDSInTx: %w", err) } - } - return nil + // checkIfFullDS is true for last DS in the list + if full { + if _, err = tx.Exec(fmt.Sprintf(`LOCK TABLE %q IN EXCLUSIVE MODE;`, latestDS.JobTable)); err != nil { + return fmt.Errorf("error locking table %s: %w", latestDS.JobTable, err) + } + + nextDSIdx = jd.doComputeNewIdxForAppend(dsList) + jd.logger.Infof("[[ %s : addNewDSLoop ]]: NewDS", jd.tablePrefix) + if err = jd.addNewDSInTx(tx, dsListLock, dsList, newDataSet(jd.tablePrefix, nextDSIdx)); err != nil { + return fmt.Errorf("error adding new DS: %w", err) + } + + // previous DS should become read only + if err = setReadonlyDsInTx(tx, latestDS); err != nil { + return fmt.Errorf("error making dataset read only: %w", err) + } + } + return nil + }) }) }) - }) - jd.assertError(err) - - // to get the updated DS list in the cache after createDS transaction has been committed. - jd.refreshDSRangeList(dsListLock) - releaseDsListLock <- dsListLock + if err != nil { + return fmt.Errorf("addNewDSLoop: %w", err) + } + // to get the updated DS list in the cache after createDS transaction has been committed. 
+ if err = jd.refreshDSRangeList(dsListLock); err != nil { + return fmt.Errorf("refreshDSRangeList: %w", err) + } + return nil + } + if err := addNewDS(); err != nil { + if !jd.skipMaintenanceError && ctx.Err() == nil { + panic(err) + } + jd.logger.Errorw("addNewDSLoop", "error", err) + } } } @@ -2466,18 +2509,29 @@ func (jd *HandleT) refreshDSListLoop(ctx context.Context) { case <-ctx.Done(): return } - start := time.Now() - jd.logger.Debugw("Start", "operation", "refreshDSListLoop") - timeoutCtx, cancel := context.WithTimeout(ctx, jd.refreshDSTimeout) - err := jd.dsListLock.WithLockInCtx(timeoutCtx, func(l lock.LockToken) error { - jd.refreshDSRangeList(l) + refresh := func() error { + jd.logger.Debugw("Start", "operation", "refreshDSListLoop") + + timeoutCtx, cancel := context.WithTimeout(ctx, jd.refreshDSTimeout) + defer cancel() + + start := time.Now() + err := jd.dsListLock.WithLockInCtx(timeoutCtx, func(l lock.LockToken) error { + return jd.refreshDSRangeList(l) + }) + stats.Default.NewTaggedStat("refresh_ds_loop", stats.TimerType, stats.Tags{"customVal": jd.tablePrefix, "error": strconv.FormatBool(err != nil)}).Since(start) + if err != nil { + return fmt.Errorf("refreshDSRangeList: %w", err) + } + return nil - }) - cancel() - if err != nil { - jd.logger.Errorf("Failed to refresh ds list: %v", err) } - stats.Default.NewTaggedStat("refresh_ds_loop", stats.TimerType, stats.Tags{"customVal": jd.tablePrefix, "error": strconv.FormatBool(err != nil)}).Since(start) + if err := refresh(); err != nil { + if !jd.skipMaintenanceError && ctx.Err() == nil { + panic(err) + } + jd.logger.Errorw("refreshDSListLoop", "error", err) + } } } @@ -2512,15 +2566,13 @@ func (jd *HandleT) dropJournal() { jd.assertError(err) } -func (jd *HandleT) JournalMarkStart(opType string, opPayload json.RawMessage) int64 { +func (jd *HandleT) JournalMarkStart(opType string, opPayload json.RawMessage) (int64, error) { var opID int64 - err := jd.WithTx(func(tx *Tx) error { + return opID, jd.WithTx(func(tx *Tx) error { var err error opID, err = jd.JournalMarkStartInTx(tx, opType, opPayload) return err }) - jd.assertError(err) - return opID } func (jd *HandleT) JournalMarkStartInTx(tx *Tx, opType string, opPayload json.RawMessage) (int64, error) { @@ -2540,11 +2592,10 @@ func (jd *HandleT) JournalMarkStartInTx(tx *Tx, opType string, opPayload json.Ra } // JournalMarkDone marks the end of a journal action -func (jd *HandleT) JournalMarkDone(opID int64) { - err := jd.WithTx(func(tx *Tx) error { +func (jd *HandleT) JournalMarkDone(opID int64) error { + return jd.WithTx(func(tx *Tx) error { return jd.journalMarkDoneInTx(tx, opID) }) - jd.assertError(err) } // JournalMarkDoneInTx marks the end of a journal action in a transaction diff --git a/jobsdb/jobsdb_test.go b/jobsdb/jobsdb_test.go index 72fa609441..dc737f1ce4 100644 --- a/jobsdb/jobsdb_test.go +++ b/jobsdb/jobsdb_test.go @@ -400,7 +400,9 @@ func TestRefreshDSList(t *testing.T) { })) require.Equal(t, 1, len(jobsDB.getDSList()), "addDS should not refresh the ds list") jobsDB.dsListLock.WithLock(func(l lock.LockToken) { - require.Equal(t, 2, len(jobsDB.refreshDSList(l)), "after refreshing the ds list jobsDB should have a ds list size of 2") + dsList, err := jobsDB.refreshDSList(l) + require.NoError(t, err) + require.Equal(t, 2, len(dsList), "after refreshing the ds list jobsDB should have a ds list size of 2") }) } diff --git a/jobsdb/jobsdb_utils.go b/jobsdb/jobsdb_utils.go index 6240c36468..55bd945979 100644 --- a/jobsdb/jobsdb_utils.go +++ b/jobsdb/jobsdb_utils.go @@ 
-19,12 +19,14 @@ type sqlDbOrTx interface { Function to return an ordered list of datasets and datasetRanges Most callers use the in-memory list of dataset and datasetRanges */ -func getDSList(jd assertInterface, dbHandle sqlDbOrTx, tablePrefix string) []dataSetT { +func getDSList(jd assertInterface, dbHandle sqlDbOrTx, tablePrefix string) ([]dataSetT, error) { var datasetList []dataSetT // Read the table names from PG - tableNames := mustGetAllTableNames(jd, dbHandle) - + tableNames, err := getAllTableNames(dbHandle) + if err != nil { + return nil, fmt.Errorf("getAllTableNames: %w", err) + } // Tables are of form jobs_ and job_status_. Iterate // through them and sort them to produce and // ordered list of datasets @@ -63,7 +65,7 @@ func getDSList(jd assertInterface, dbHandle sqlDbOrTx, tablePrefix string) []dat }) } - return datasetList + return datasetList, nil } /* @@ -77,13 +79,6 @@ func sortDnumList(dnumList []string) { }) } -// mustGetAllTableNames gets all table names from Postgres and panics in case of an error -func mustGetAllTableNames(jd assertInterface, dbHandle sqlDbOrTx) []string { - tableNames, err := getAllTableNames(dbHandle) - jd.assertError(err) - return tableNames -} - // getAllTableNames gets all table names from Postgres func getAllTableNames(dbHandle sqlDbOrTx) ([]string, error) { var tableNames []string diff --git a/jobsdb/migration.go b/jobsdb/migration.go index 3c75bb5689..371838b22a 100644 --- a/jobsdb/migration.go +++ b/jobsdb/migration.go @@ -35,16 +35,24 @@ func (jd *HandleT) migrateDSLoop(ctx context.Context) { case <-ctx.Done(): return } - start := time.Now() - jd.logger.Debugw("Start", "operation", "migrateDSLoop") - timeoutCtx, cancel := context.WithTimeout(ctx, jd.migrateDSTimeout) - err := jd.doMigrateDS(timeoutCtx) - cancel() - if err != nil { - jd.logger.Errorf("Failed to migrate ds: %v", err) + migrate := func() error { + start := time.Now() + jd.logger.Debugw("Start", "operation", "migrateDSLoop") + timeoutCtx, cancel := context.WithTimeout(ctx, jd.migrateDSTimeout) + defer cancel() + err := jd.doMigrateDS(timeoutCtx) + stats.Default.NewTaggedStat("migration_loop", stats.TimerType, stats.Tags{"customVal": jd.tablePrefix, "error": strconv.FormatBool(err != nil)}).Since(start) + if err != nil { + return fmt.Errorf("failed to migrate ds: %w", err) + } + return nil + } + if err := migrate(); err != nil && ctx.Err() == nil { + if !jd.skipMaintenanceError { + panic(err) + } + jd.logger.Errorw("Failed to migrate ds", "error", err) } - stats.Default.NewTaggedStat("migration_loop", stats.TimerType, stats.Tags{"customVal": jd.tablePrefix, "error": strconv.FormatBool(err != nil)}).Since(start) - } } @@ -59,14 +67,16 @@ func (jd *HandleT) doMigrateDS(ctx context.Context) error { return err } - migrateFrom, pendingJobsCount, insertBeforeDS := jd.getMigrationList(dsList) - + migrateFrom, pendingJobsCount, insertBeforeDS, err := jd.getMigrationList(dsList) + if err != nil { + return fmt.Errorf("could not get migration list: %w", err) + } if len(migrateFrom) == 0 { return nil } var l lock.LockToken var lockChan chan<- lock.LockToken - err := jd.WithTx(func(tx *Tx) error { + err = jd.WithTx(func(tx *Tx) error { return jd.withDistributedSharedLock(ctx, tx, "schema_migrate", func() error { // cannot run while schema migration is running // Take the lock and run actual migration if !jd.dsMigrationLock.TryLockWithCtx(ctx) { @@ -75,17 +85,24 @@ func (jd *HandleT) doMigrateDS(ctx context.Context) error { defer jd.dsMigrationLock.Unlock() // repeat the check after the 
dsMigrationLock is acquired to get correct pending jobs count. // the pending jobs count cannot change after the dsMigrationLock is acquired - if migrateFrom, pendingJobsCount, insertBeforeDS = jd.getMigrationList(dsList); len(migrateFrom) == 0 { + migrateFrom, pendingJobsCount, insertBeforeDS, err = jd.getMigrationList(dsList) + if err != nil { + return fmt.Errorf("could not get migration list: %w", err) + } + if len(migrateFrom) == 0 { return nil } if pendingJobsCount > 0 { // migrate incomplete jobs var destination dataSetT - err := jd.dsListLock.WithLockInCtx(ctx, func(l lock.LockToken) error { - destination = newDataSet(jd.tablePrefix, jd.computeNewIdxForIntraNodeMigration(l, insertBeforeDS)) + if err := jd.dsListLock.WithLockInCtx(ctx, func(l lock.LockToken) error { + dsIdx, err := jd.computeNewIdxForIntraNodeMigration(l, insertBeforeDS) + if err != nil { + return fmt.Errorf("computing new index for intra-node migration: %w", err) + } + destination = newDataSet(jd.tablePrefix, dsIdx) return nil - }) - if err != nil { + }); err != nil { return err } @@ -95,16 +112,17 @@ func (jd *HandleT) doMigrateDS(ctx context.Context) error { opPayload, err := json.Marshal(&journalOpPayloadT{From: migrateFrom, To: destination}) if err != nil { - return err + return fmt.Errorf("failed to marshal journal payload: %w", err) } + opID, err := jd.JournalMarkStartInTx(tx, migrateCopyOperation, opPayload) if err != nil { - return err + return fmt.Errorf("failed to mark journal start: %w", err) } err = jd.addDSInTx(tx, destination) if err != nil { - return err + return fmt.Errorf("failed to add dataset: %w", err) } totalJobsMigrated := 0 @@ -113,42 +131,46 @@ func (jd *HandleT) doMigrateDS(ctx context.Context) error { jd.logger.Infof("[[ migrateDSLoop ]]: Migrate: %v to: %v", source, destination) noJobsMigrated, err = jd.migrateJobsInTx(ctx, tx, source, destination) if err != nil { - return err + return fmt.Errorf("failed to migrate jobs: %w", err) } totalJobsMigrated += noJobsMigrated } - err = jd.journalMarkDoneInTx(tx, opID) - if err != nil { - return err + if err = jd.journalMarkDoneInTx(tx, opID); err != nil { + return fmt.Errorf("failed to mark journal done: %w", err) } jd.logger.Infof("[[ migrateDSLoop ]]: Total migrated %d jobs", totalJobsMigrated) } opPayload, err := json.Marshal(&journalOpPayloadT{From: migrateFrom}) if err != nil { - return err + return fmt.Errorf("failed to marshal journal payload: %w", err) } opID, err := jd.JournalMarkStartInTx(tx, postMigrateDSOperation, opPayload) if err != nil { - return err + return fmt.Errorf("failed to mark journal start: %w", err) } // acquire an async lock, as this needs to be released after the transaction commits l, lockChan, err = jd.dsListLock.AsyncLockWithCtx(ctx) if err != nil { - return err + return fmt.Errorf("failed to acquire lock: %w", err) } - err = jd.postMigrateHandleDS(tx, migrateFrom) - if err != nil { - return err + if err = jd.postMigrateHandleDS(tx, migrateFrom); err != nil { + return fmt.Errorf("failed to post migrate handle ds: %w", err) } - return jd.journalMarkDoneInTx(tx, opID) + if err = jd.journalMarkDoneInTx(tx, opID); err != nil { + return fmt.Errorf("failed to mark journal done: %w", err) + } + return nil }) }) if l != nil { + defer func() { lockChan <- l }() if err == nil { - jd.refreshDSRangeList(l) + if err = jd.refreshDSRangeList(l); err != nil { + return fmt.Errorf("failed to refresh ds range list: %w", err) + } } - lockChan <- l + } return err } @@ -227,8 +249,7 @@ func (jd *HandleT) getCleanUpCandidates(ctx 
context.Context, dsList []dataSetT) estimate float64 tableName string ) - err = rows.Scan(&estimate, &tableName) - if err != nil { + if err = rows.Scan(&estimate, &tableName); err != nil { return nil, err } estimates[tableName] = estimate @@ -336,7 +357,7 @@ func (*HandleT) cleanStatusTable(ctx context.Context, tx *Tx, table string, canB // getMigrationList returns the list of datasets to migrate from, // the number of unfinished jobs contained in these datasets // and the dataset before which the new (migrated) dataset that will hold these jobs needs to be created -func (jd *HandleT) getMigrationList(dsList []dataSetT) (migrateFrom []dataSetT, pendingJobsCount int, insertBeforeDS dataSetT) { +func (jd *HandleT) getMigrationList(dsList []dataSetT) (migrateFrom []dataSetT, pendingJobsCount int, insertBeforeDS dataSetT, err error) { var ( liveDSCount, migrateDSProbeCount int // we don't want `maxDSSize` value to change, during dsList loop @@ -360,7 +381,11 @@ func (jd *HandleT) getMigrationList(dsList []dataSetT) (migrateFrom []dataSetT, break } - migrate, isSmall, recordsLeft := jd.checkIfMigrateDS(ds) + migrate, isSmall, recordsLeft, migrateErr := jd.checkIfMigrateDS(ds) + if migrateErr != nil { + err = migrateErr + return + } jd.logger.Debugf( "[[ migrateDSLoop ]]: Migrate check %v, is small: %v, records left: %d, ds: %v", migrate, isSmall, recordsLeft, ds, @@ -433,12 +458,14 @@ func (jd *HandleT) migrateJobsInTx(ctx context.Context, tx *Tx, srcDS, destDS da return int(numJobsMigrated), nil } -func (jd *HandleT) computeNewIdxForIntraNodeMigration(l lock.LockToken, insertBeforeDS dataSetT) string { // Within the node +func (jd *HandleT) computeNewIdxForIntraNodeMigration(l lock.LockToken, insertBeforeDS dataSetT) (string, error) { // Within the node jd.logger.Debugf("computeNewIdxForIntraNodeMigration, insertBeforeDS : %v", insertBeforeDS) - dList := jd.refreshDSList(l) + dList, err := jd.refreshDSList(l) + if err != nil { + return "", fmt.Errorf("refreshDSList: %w", err) + } jd.logger.Debugf("dlist in which we are trying to find %v is %v", insertBeforeDS, dList) newDSIdx := "" - var err error jd.assert(len(dList) > 0, fmt.Sprintf("len(dList): %d <= 0", len(dList))) for idx, ds := range dList { if ds.Index == insertBeforeDS.Index { @@ -447,7 +474,7 @@ func (jd *HandleT) computeNewIdxForIntraNodeMigration(l lock.LockToken, insertBe jd.assertError(err) } } - return newDSIdx + return newDSIdx, nil } func (jd *HandleT) postMigrateHandleDS(tx *Tx, migrateFrom []dataSetT) error { @@ -491,7 +518,7 @@ func computeInsertIdx(beforeIndex, afterIndex string) (string, error) { // We migrate the DB ONCE most of the jobs have been processed (succeeded/aborted) // Or when the job_status table gets too big because of lots of retries/failures func (jd *HandleT) checkIfMigrateDS(ds dataSetT) ( - migrate, small bool, recordsLeft int, + migrate, small bool, recordsLeft int, err error, ) { defer jd.getTimerStat( "migration_ds_check", @@ -500,26 +527,26 @@ func (jd *HandleT) checkIfMigrateDS(ds dataSetT) ( var delCount, totalCount, statusCount int sqlStatement := fmt.Sprintf(`SELECT COUNT(*) from %q`, ds.JobTable) - row := jd.dbHandle.QueryRow(sqlStatement) - err := row.Scan(&totalCount) - jd.assertError(err) + if err = jd.dbHandle.QueryRow(sqlStatement).Scan(&totalCount); err != nil { + return false, false, 0, fmt.Errorf("error getting count of jobs in %s: %w", ds.JobTable, err) + } // Jobs which have either succeeded or expired sqlStatement = fmt.Sprintf(`SELECT COUNT(DISTINCT(job_id)) from %q WHERE job_state 
IN ('%s')`, ds.JobStatusTable, strings.Join(validTerminalStates, "', '")) - row = jd.dbHandle.QueryRow(sqlStatement) - err = row.Scan(&delCount) - jd.assertError(err) + if err = jd.dbHandle.QueryRow(sqlStatement).Scan(&delCount); err != nil { + return false, false, 0, fmt.Errorf("error getting count of jobs in %s: %w", ds.JobStatusTable, err) + } if jobStatusCountMigrationCheck { // Total number of job status. If this table grows too big (e.g. a lot of retries) // we migrate to a new table and get rid of old job status sqlStatement = fmt.Sprintf(`SELECT COUNT(*) from %q`, ds.JobStatusTable) - row = jd.dbHandle.QueryRow(sqlStatement) - err = row.Scan(&statusCount) - jd.assertError(err) + if err = jd.dbHandle.QueryRow(sqlStatement).Scan(&statusCount); err != nil { + return false, false, 0, fmt.Errorf("error getting count of jobs in %s: %w", ds.JobStatusTable, err) + } } recordsLeft = totalCount - delCount @@ -527,12 +554,12 @@ func (jd *HandleT) checkIfMigrateDS(ds dataSetT) ( if jd.MinDSRetentionPeriod > 0 { var maxCreatedAt time.Time sqlStatement = fmt.Sprintf(`SELECT MAX(created_at) from %q`, ds.JobTable) - row = jd.dbHandle.QueryRow(sqlStatement) - err = row.Scan(&maxCreatedAt) - jd.assertError(err) + if err = jd.dbHandle.QueryRow(sqlStatement).Scan(&maxCreatedAt); err != nil { + return false, false, 0, fmt.Errorf("error getting max created_at from %s: %w", ds.JobTable, err) + } if time.Since(maxCreatedAt) < jd.MinDSRetentionPeriod { - return false, false, recordsLeft + return false, false, recordsLeft, nil } } @@ -543,15 +570,11 @@ func (jd *HandleT) checkIfMigrateDS(ds dataSetT) ( FROM %q WHERE job_state = ANY($1) and exec_time < $2)`, ds.JobStatusTable) - stmt, err := jd.dbHandle.Prepare(sqlStatement) - jd.assertError(err) - defer func() { _ = stmt.Close() }() - - row = stmt.QueryRow(pq.Array(validTerminalStates), time.Now().Add(-1*jd.MaxDSRetentionPeriod)) - err = row.Scan(&terminalJobsExist) - jd.assertError(err) + if err = jd.dbHandle.QueryRow(sqlStatement, pq.Array(validTerminalStates), time.Now().Add(-1*jd.MaxDSRetentionPeriod)).Scan(&terminalJobsExist); err != nil { + return false, false, 0, fmt.Errorf("checking terminalJobsExist %s: %w", ds.JobStatusTable, err) + } if terminalJobsExist { - return true, false, recordsLeft + return true, false, recordsLeft, nil } } @@ -561,12 +584,12 @@ func (jd *HandleT) checkIfMigrateDS(ds dataSetT) ( } if float64(delCount)/float64(totalCount) > jobDoneMigrateThres || (float64(statusCount)/float64(totalCount) > jobStatusMigrateThres) { - return true, isSmall(), recordsLeft + return true, isSmall(), recordsLeft, nil } if isSmall() { - return true, true, recordsLeft + return true, true, recordsLeft, nil } - return false, false, recordsLeft + return false, false, recordsLeft, nil } diff --git a/jobsdb/unionQuery.go b/jobsdb/unionQuery.go index d279cfde76..11b20c8e99 100644 --- a/jobsdb/unionQuery.go +++ b/jobsdb/unionQuery.go @@ -50,7 +50,7 @@ type MultiTenantJobsDB interface { FailExecuting() GetJournalEntries(opType string) (entries []JournalEntryT) - JournalMarkStart(opType string, opPayload json.RawMessage) int64 + JournalMarkStart(opType string, opPayload json.RawMessage) (int64, error) JournalDeleteEntry(opID int64) GetPileUpCounts(context.Context) (map[string]map[string]int, error) GetActiveWorkspaces(ctx context.Context, customVal string) (workspaces []string, err error) diff --git a/mocks/jobsdb/mock_jobsdb.go b/mocks/jobsdb/mock_jobsdb.go index 6eefa45303..87c9d3c04f 100644 --- a/mocks/jobsdb/mock_jobsdb.go +++ 
b/mocks/jobsdb/mock_jobsdb.go @@ -236,12 +236,27 @@ func (mr *MockJobsDBMockRecorder) JournalDeleteEntry(arg0 interface{}) *gomock.C return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "JournalDeleteEntry", reflect.TypeOf((*MockJobsDB)(nil).JournalDeleteEntry), arg0) } +// JournalMarkDone mocks base method. +func (m *MockJobsDB) JournalMarkDone(arg0 int64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "JournalMarkDone", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// JournalMarkDone indicates an expected call of JournalMarkDone. +func (mr *MockJobsDBMockRecorder) JournalMarkDone(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "JournalMarkDone", reflect.TypeOf((*MockJobsDB)(nil).JournalMarkDone), arg0) +} + // JournalMarkStart mocks base method. -func (m *MockJobsDB) JournalMarkStart(arg0 string, arg1 json.RawMessage) int64 { +func (m *MockJobsDB) JournalMarkStart(arg0 string, arg1 json.RawMessage) (int64, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "JournalMarkStart", arg0, arg1) ret0, _ := ret[0].(int64) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // JournalMarkStart indicates an expected call of JournalMarkStart. diff --git a/mocks/jobsdb/mock_unionQuery.go b/mocks/jobsdb/mock_unionQuery.go index 2e0576ba9f..6da1fa571a 100644 --- a/mocks/jobsdb/mock_unionQuery.go +++ b/mocks/jobsdb/mock_unionQuery.go @@ -147,11 +147,12 @@ func (mr *MockMultiTenantJobsDBMockRecorder) JournalDeleteEntry(arg0 interface{} } // JournalMarkStart mocks base method. -func (m *MockMultiTenantJobsDB) JournalMarkStart(arg0 string, arg1 json.RawMessage) int64 { +func (m *MockMultiTenantJobsDB) JournalMarkStart(arg0 string, arg1 json.RawMessage) (int64, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "JournalMarkStart", arg0, arg1) ret0, _ := ret[0].(int64) - return ret0 + ret1, _ := ret[1].(error) + return ret0, ret1 } // JournalMarkStart indicates an expected call of JournalMarkStart. diff --git a/router/batchrouter/batchrouter_test.go b/router/batchrouter/batchrouter_test.go index b7219a7f1b..e8cc5931df 100644 --- a/router/batchrouter/batchrouter_test.go +++ b/router/batchrouter/batchrouter_test.go @@ -269,7 +269,7 @@ var _ = Describe("BatchRouter", func() { assertJobStatus(unprocessedJobsList[0], statuses[1], jobsdb.Executing.State, `{}`, 1) }).Return(nil) - c.mockBatchRouterJobsDB.EXPECT().JournalMarkStart(gomock.Any(), gomock.Any()).Times(1).Return(int64(1)) + c.mockBatchRouterJobsDB.EXPECT().JournalMarkStart(gomock.Any(), gomock.Any()).Times(1).Return(int64(1), nil) c.mockBatchRouterJobsDB.EXPECT().WithUpdateSafeTx(gomock.Any(), gomock.Any()).Times(1).Do(func(ctx context.Context, f func(tx jobsdb.UpdateSafeTx) error) { _ = f(jobsdb.EmptyUpdateSafeTx()) diff --git a/router/batchrouter/handle.go b/router/batchrouter/handle.go index 98051f9f14..5f54098c86 100644 --- a/router/batchrouter/handle.go +++ b/router/batchrouter/handle.go @@ -422,7 +422,10 @@ func (brt *Handle) upload(provider string, batchJobs *BatchedJobs, isWarehouse b DestinationID: batchJobs.Connection.Destination.ID, DestinationType: batchJobs.Connection.Destination.DestinationDefinition.Name, }) - opID = brt.jobsDB.JournalMarkStart(jobsdb.RawDataDestUploadOperation, opPayload) + opID, err = brt.jobsDB.JournalMarkStart(jobsdb.RawDataDestUploadOperation, opPayload) + if err != nil { + panic(fmt.Errorf("BRT: Error marking start of upload operation in journal: %v", err)) + } } startTime := time.Now()
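Usage note (appended for context, not part of the diff above): the patch introduces a WithSkipMaintenanceErr option on jobsdb that decides whether errors in the maintenance loops (backupDSLoop, addNewDSLoop, refreshDSListLoop, migrateDSLoop) panic or are only logged and retried on the next tick. The gateway enables it by default via config.GetBool("Gateway.jobsDB.skipMaintenanceError", true); router, batch router and processor keep the old panic behaviour unless configured otherwise. The minimal sketch below mirrors the gateway call site from the patch; the import paths and the standalone main wrapper are assumptions added for illustration only.

package main

import (
	"log"

	"github.com/rudderlabs/rudder-server/config" // assumed import path for the config package used as config.GetBool above
	"github.com/rudderlabs/rudder-server/jobsdb"
)

func main() {
	// true => transient maintenance errors (backup, add/refresh/migrate DS) are logged
	// instead of crashing the process; false preserves the previous panic behaviour.
	gatewayDB := jobsdb.NewForReadWrite(
		"gw",
		jobsdb.WithSkipMaintenanceErr(config.GetBool("Gateway.jobsDB.skipMaintenanceError", true)),
	)
	defer gatewayDB.Close()

	if err := gatewayDB.Start(); err != nil {
		log.Printf("could not start gatewayDB: %v", err)
		return
	}
}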