From 990a40510c1ecf88467ccb858d981a1643d3e590 Mon Sep 17 00:00:00 2001 From: Akash Chetty Date: Thu, 25 May 2023 19:05:22 +0530 Subject: [PATCH] chore: pass context (#3326) --- warehouse/admin.go | 5 +- warehouse/api.go | 65 +++--- warehouse/api_test.go | 8 +- warehouse/archive/archiver.go | 24 +- warehouse/client/client.go | 12 +- warehouse/client/controlplane/client_test.go | 6 +- warehouse/identities.go | 32 +-- warehouse/identity/identity.go | 127 +++++----- .../azure-synapse/azure-synapse.go | 74 +++--- warehouse/integrations/bigquery/bigquery.go | 182 +++++++-------- .../integrations/bigquery/bigquery_test.go | 13 +- .../bigquery/middleware/middleware_test.go | 5 +- .../integrations/clickhouse/clickhouse.go | 68 +++--- .../clickhouse/clickhouse_test.go | 102 ++++---- warehouse/integrations/datalake/datalake.go | 42 ++-- .../datalake/schema-repository/glue.go | 35 +-- .../datalake/schema-repository/glue_test.go | 16 +- .../datalake/schema-repository/local.go | 27 +-- .../datalake/schema-repository/local_test.go | 46 ++-- .../schema-repository/schema_repository.go | 16 +- .../deltalake-native/deltalake.go | 92 ++++---- .../integrations/deltalake/client/client.go | 7 +- warehouse/integrations/deltalake/deltalake.go | 141 ++++++----- warehouse/integrations/manager/manager.go | 32 +-- warehouse/integrations/mssql/mssql.go | 106 ++++----- .../integrations/postgres-legacy/postgres.go | 112 +++++---- warehouse/integrations/postgres/load.go | 2 +- warehouse/integrations/postgres/load_test.go | 28 ++- warehouse/integrations/postgres/postgres.go | 66 +++--- warehouse/integrations/redshift/redshift.go | 88 +++---- .../integrations/redshift/redshift_test.go | 4 +- warehouse/integrations/snowflake/snowflake.go | 122 +++++----- warehouse/integrations/testhelper/verify.go | 3 +- warehouse/internal/repo/staging_test.go | 2 +- .../loadfiles/downloader/downloader.go | 2 +- .../loadfiles/downloader/downloader_test.go | 30 ++- warehouse/internal/service/recovery.go | 4 +- warehouse/internal/service/recovery_test.go | 2 +- warehouse/jobs/jobs.go | 11 +- warehouse/schema.go | 26 +-- warehouse/schema_test.go | 16 +- warehouse/slave.go | 42 ++-- warehouse/stats.go | 13 +- warehouse/stats_test.go | 6 +- warehouse/upload.go | 218 +++++++++--------- warehouse/upload_test.go | 33 ++- warehouse/utils/utils.go | 10 +- warehouse/validations/steps.go | 4 +- warehouse/validations/validate.go | 119 +++++----- warehouse/validations/validate_test.go | 41 ++-- warehouse/validations/validations.go | 7 +- warehouse/validations/validations_test.go | 21 +- warehouse/warehouse.go | 42 ++-- warehouse/warehouse_test.go | 2 +- warehouse/warehousegrpc.go | 18 +- warehouse/warehousegrpc_test.go | 10 +- 56 files changed, 1244 insertions(+), 1143 deletions(-) diff --git a/warehouse/admin.go b/warehouse/admin.go index 13affbea01..3edd5c1072 100644 --- a/warehouse/admin.go +++ b/warehouse/admin.go @@ -1,6 +1,7 @@ package warehouse import ( + "context" "errors" "fmt" "strings" @@ -78,7 +79,7 @@ func (*WarehouseAdmin) Query(s QueryInput, reply *warehouseutils.QueryResult) er if err != nil { return err } - client, err := whManager.Connect(warehouse) + client, err := whManager.Connect(context.TODO(), warehouse) if err != nil { return err } @@ -109,7 +110,7 @@ func (*WarehouseAdmin) ConfigurationTest(s ConfigurationTestInput, reply *Config pkgLogger.Infof(`[WH Admin]: Validating warehouse destination: %s:%s`, warehouse.Type, warehouse.Destination.ID) destinationValidator := validations.NewDestinationValidator() - res := 
destinationValidator.Validate(&warehouse.Destination) + res := destinationValidator.Validate(context.TODO(), &warehouse.Destination) reply.Valid = res.Success reply.Error = res.Error diff --git a/warehouse/api.go b/warehouse/api.go index a65c5cd453..24ed09e071 100644 --- a/warehouse/api.go +++ b/warehouse/api.go @@ -184,7 +184,7 @@ var statusMap = map[string]string{ "failed": "%failed%", } -func (uploadsReq *UploadsReq) GetWhUploads() (uploadsRes *proto.WHUploadsResponse, err error) { +func (uploadsReq *UploadsReq) GetWhUploads(ctx context.Context) (uploadsRes *proto.WHUploadsResponse, err error) { uploadsRes = &proto.WHUploadsResponse{ Uploads: make([]*proto.WHUploadResponse, 0), } @@ -206,15 +206,15 @@ func (uploadsReq *UploadsReq) GetWhUploads() (uploadsRes *proto.WHUploadsRespons } if UploadAPI.isMultiWorkspace { - uploadsRes, err = uploadsReq.warehouseUploadsForHosted(authorizedSourceIDs, `id, source_id, destination_id, destination_type, namespace, status, error, first_event_at, last_event_at, last_exec_at, updated_at, timings, metadata->>'nextRetryTime', metadata->>'archivedStagingAndLoadFiles'`) + uploadsRes, err = uploadsReq.warehouseUploadsForHosted(ctx, authorizedSourceIDs, `id, source_id, destination_id, destination_type, namespace, status, error, first_event_at, last_event_at, last_exec_at, updated_at, timings, metadata->>'nextRetryTime', metadata->>'archivedStagingAndLoadFiles'`) return } - uploadsRes, err = uploadsReq.warehouseUploads(`id, source_id, destination_id, destination_type, namespace, status, error, first_event_at, last_event_at, last_exec_at, updated_at, timings, metadata->>'nextRetryTime', metadata->>'archivedStagingAndLoadFiles'`) + uploadsRes, err = uploadsReq.warehouseUploads(ctx, `id, source_id, destination_id, destination_type, namespace, status, error, first_event_at, last_event_at, last_exec_at, updated_at, timings, metadata->>'nextRetryTime', metadata->>'archivedStagingAndLoadFiles'`) return } -func (uploadsReq *UploadsReq) TriggerWhUploads() (response *proto.TriggerWhUploadsResponse, err error) { +func (uploadsReq *UploadsReq) TriggerWhUploads(ctx context.Context) (response *proto.TriggerWhUploadsResponse, err error) { err = uploadsReq.validateReq() defer func() { if err != nil { @@ -245,7 +245,7 @@ func (uploadsReq *UploadsReq) TriggerWhUploads() (response *proto.TriggerWhUploa return } if pendingUploadCount == int64(0) { - pendingStagingFileCount, err = repo.NewStagingFiles(dbHandle).CountPendingForDestination(context.TODO(), uploadsReq.DestinationID) + pendingStagingFileCount, err = repo.NewStagingFiles(dbHandle).CountPendingForDestination(ctx, uploadsReq.DestinationID) if err != nil { return } @@ -269,7 +269,7 @@ func (uploadsReq *UploadsReq) TriggerWhUploads() (response *proto.TriggerWhUploa return } -func (uploadReq *UploadReq) GetWHUpload() (*proto.WHUploadResponse, error) { +func (uploadReq *UploadReq) GetWHUpload(ctx context.Context) (*proto.WHUploadResponse, error) { err := uploadReq.validateReq() if err != nil { return &proto.WHUploadResponse{}, status.Errorf(codes.Code(code.Code_INVALID_ARGUMENT), err.Error()) @@ -287,7 +287,7 @@ func (uploadReq *UploadReq) GetWHUpload() (*proto.WHUploadResponse, error) { isUploadArchived sql.NullBool ) - row := uploadReq.API.dbHandle.QueryRow(query) + row := uploadReq.API.dbHandle.QueryRowContext(ctx, query) err = row.Scan( &upload.Id, &upload.SourceId, @@ -358,7 +358,7 @@ func (uploadReq *UploadReq) GetWHUpload() (*proto.WHUploadResponse, error) { Name: "", API: uploadReq.API, } - upload.Tables, err = 
tableUploadReq.GetWhTableUploads() + upload.Tables, err = tableUploadReq.GetWhTableUploads(ctx) if err != nil { return &proto.WHUploadResponse{}, status.Errorf(codes.Code(code.Code_INTERNAL), err.Error()) } @@ -366,7 +366,7 @@ func (uploadReq *UploadReq) GetWHUpload() (*proto.WHUploadResponse, error) { return &upload, nil } -func (uploadReq *UploadReq) TriggerWHUpload() (response *proto.TriggerWhUploadsResponse, err error) { +func (uploadReq *UploadReq) TriggerWHUpload(ctx context.Context) (response *proto.TriggerWhUploadsResponse, err error) { err = uploadReq.validateReq() defer func() { if err != nil { @@ -380,7 +380,7 @@ func (uploadReq *UploadReq) TriggerWHUpload() (response *proto.TriggerWhUploadsR return } - upload, err := repo.NewUploads(uploadReq.API.dbHandle).Get(context.TODO(), uploadReq.UploadId) + upload, err := repo.NewUploads(uploadReq.API.dbHandle).Get(ctx, uploadReq.UploadId) if err == model.ErrUploadNotFound { return &proto.TriggerWhUploadsResponse{ Message: NoSuchSync, @@ -401,7 +401,8 @@ func (uploadReq *UploadReq) TriggerWHUpload() (response *proto.TriggerWhUploadsR uploadJobT := UploadJob{ upload: upload, dbHandle: uploadReq.API.dbHandle, - Now: timeutil.Now, + now: timeutil.Now, + ctx: ctx, } err = uploadJobT.triggerUploadNow() @@ -415,14 +416,14 @@ func (uploadReq *UploadReq) TriggerWHUpload() (response *proto.TriggerWhUploadsR return } -func (tableUploadReq TableUploadReq) GetWhTableUploads() ([]*proto.WHTable, error) { +func (tableUploadReq TableUploadReq) GetWhTableUploads(ctx context.Context) ([]*proto.WHTable, error) { err := tableUploadReq.validateReq() if err != nil { return []*proto.WHTable{}, err } query := tableUploadReq.generateQuery(`id, wh_upload_id, table_name, total_events, status, error, last_exec_time, updated_at`) tableUploadReq.API.log.Debug(query) - rows, err := tableUploadReq.API.dbHandle.Query(query) + rows, err := tableUploadReq.API.dbHandle.QueryContext(ctx, query) if err != nil { tableUploadReq.API.log.Errorf(err.Error()) return []*proto.WHTable{}, err @@ -544,12 +545,12 @@ func (uploadsReq *UploadsReq) authorizedSources() (sourceIDs []string) { return sourceIDs } -func (uploadsReq *UploadsReq) getUploadsFromDB(isMultiWorkspace bool, query string) ([]*proto.WHUploadResponse, int32, error) { +func (uploadsReq *UploadsReq) getUploadsFromDB(ctx context.Context, isMultiWorkspace bool, query string) ([]*proto.WHUploadResponse, int32, error) { var totalUploadCount int32 var err error uploads := make([]*proto.WHUploadResponse, 0) - rows, err := uploadsReq.API.dbHandle.Query(query) + rows, err := uploadsReq.API.dbHandle.QueryContext(ctx, query) if err != nil { uploadsReq.API.log.Errorf(err.Error()) return nil, 0, err @@ -651,7 +652,7 @@ func (uploadsReq *UploadsReq) getUploadsFromDB(isMultiWorkspace bool, query stri return uploads, totalUploadCount, err } -func (uploadsReq *UploadsReq) getTotalUploadCount(whereClause string) (int32, error) { +func (uploadsReq *UploadsReq) getTotalUploadCount(ctx context.Context, whereClause string) (int32, error) { var totalUploadCount int32 query := fmt.Sprintf(` select @@ -665,12 +666,12 @@ func (uploadsReq *UploadsReq) getTotalUploadCount(whereClause string) (int32, er query += fmt.Sprintf(` %s`, whereClause) } uploadsReq.API.log.Info(query) - err := uploadsReq.API.dbHandle.QueryRow(query).Scan(&totalUploadCount) + err := uploadsReq.API.dbHandle.QueryRowContext(ctx, query).Scan(&totalUploadCount) return totalUploadCount, err } // for hosted workspaces - we get the uploads and the total upload count using the same 
query -func (uploadsReq *UploadsReq) warehouseUploadsForHosted(authorizedSourceIDs []string, selectFields string) (uploadsRes *proto.WHUploadsResponse, err error) { +func (uploadsReq *UploadsReq) warehouseUploadsForHosted(ctx context.Context, authorizedSourceIDs []string, selectFields string) (uploadsRes *proto.WHUploadsResponse, err error) { var ( uploads []*proto.WHUploadResponse totalUploadCount int32 @@ -724,7 +725,7 @@ func (uploadsReq *UploadsReq) warehouseUploadsForHosted(authorizedSourceIDs []st uploadsReq.API.log.Info(query) // get uploads from db - uploads, totalUploadCount, err = uploadsReq.getUploadsFromDB(true, query) + uploads, totalUploadCount, err = uploadsReq.getUploadsFromDB(ctx, true, query) if err != nil { uploadsReq.API.log.Errorf(err.Error()) return &proto.WHUploadsResponse{}, err @@ -743,7 +744,7 @@ func (uploadsReq *UploadsReq) warehouseUploadsForHosted(authorizedSourceIDs []st } // for non hosted workspaces - we get the uploads and the total upload count using separate queries -func (uploadsReq *UploadsReq) warehouseUploads(selectFields string) (uploadsRes *proto.WHUploadsResponse, err error) { +func (uploadsReq *UploadsReq) warehouseUploads(ctx context.Context, selectFields string) (uploadsRes *proto.WHUploadsResponse, err error) { var ( uploads []*proto.WHUploadResponse totalUploadCount int32 @@ -784,13 +785,13 @@ func (uploadsReq *UploadsReq) warehouseUploads(selectFields string) (uploadsRes // we get uploads for non hosted workspaces in two steps // this is because getting this info via 2 queries is faster than getting it via one query(using the 'count(*) OVER()' clause) // step1 - get all uploads - uploads, _, err = uploadsReq.getUploadsFromDB(false, query) + uploads, _, err = uploadsReq.getUploadsFromDB(ctx, false, query) if err != nil { uploadsReq.API.log.Errorf(err.Error()) return &proto.WHUploadsResponse{}, err } // step2 - get total upload count - totalUploadCount, err = uploadsReq.getTotalUploadCount(whereClause) + totalUploadCount, err = uploadsReq.getTotalUploadCount(ctx, whereClause) if err != nil { uploadsReq.API.log.Errorf(err.Error()) return &proto.WHUploadsResponse{}, err @@ -825,8 +826,13 @@ func checkMapForValidKey(configMap map[string]interface{}, key string) bool { func validateObjectStorage(ctx context.Context, request *ObjectStorageValidationRequest) error { pkgLogger.Infof("Received call to validate object storage for type: %s\n", request.Type) + settings, err := getFileManagerSettings(ctx, request.Type, request.Config) + if err != nil { + return fmt.Errorf("unable to create file manager settings: \n%s", err.Error()) + } + factory := &filemanager.FileManagerFactoryT{} - fileManager, err := factory.New(getFileManagerSettings(request.Type, request.Config)) + fileManager, err := factory.New(settings) if err != nil { return fmt.Errorf("unable to create file manager: \n%s", err.Error()) } @@ -874,20 +880,22 @@ func validateObjectStorage(ctx context.Context, request *ObjectStorageValidation return nil } -func getFileManagerSettings(provider string, inputConfig map[string]interface{}) *filemanager.SettingsT { +func getFileManagerSettings(ctx context.Context, provider string, inputConfig map[string]interface{}) (*filemanager.SettingsT, error) { settings := &filemanager.SettingsT{ Provider: provider, Config: inputConfig, } - overrideWithEnv(settings) - return settings + if err := overrideWithEnv(ctx, settings); err != nil { + return nil, fmt.Errorf("overriding config with env: %w", err) + } + return settings, nil } // overrideWithEnv overrides the 
config keys in the fileManager settings // with fallback values pulled from env. Only supported for S3 for now. -func overrideWithEnv(settings *filemanager.SettingsT) { - envConfig := filemanager.GetProviderConfigFromEnv(context.TODO(), settings.Provider) +func overrideWithEnv(ctx context.Context, settings *filemanager.SettingsT) error { + envConfig := filemanager.GetProviderConfigFromEnv(ctx, settings.Provider) if settings.Provider == "S3" { ifNotExistThenSet("prefix", envConfig["prefix"], settings.Config) @@ -898,6 +906,7 @@ func overrideWithEnv(settings *filemanager.SettingsT) { ifNotExistThenSet("externalID", envConfig["externalID"], settings.Config) ifNotExistThenSet("regionHint", envConfig["regionHint"], settings.Config) } + return ctx.Err() } func ifNotExistThenSet(keyToReplace string, replaceWith interface{}, configMap map[string]interface{}) { diff --git a/warehouse/api_test.go b/warehouse/api_test.go index f00e8732b7..97b58d2c32 100644 --- a/warehouse/api_test.go +++ b/warehouse/api_test.go @@ -1,6 +1,8 @@ package warehouse import ( + "context" + . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/rudderlabs/rudder-server/services/filemanager" @@ -8,16 +10,18 @@ import ( var _ = Describe("warehouse_api", func() { Context("Testing objectStorageValidation ", func() { + ctx := context.Background() + It("Should fallback to backup credentials when fields missing(as of now backup only supported for s3)", func() { fm := &filemanager.SettingsT{ Provider: "AZURE_BLOB", Config: map[string]interface{}{"containerName": "containerName1", "prefix": "prefix1", "accountKey": "accountKey1"}, } - overrideWithEnv(fm) + overrideWithEnv(ctx, fm) Expect(fm.Config["accountName"]).To(BeNil()) fm.Provider = "S3" fm.Config = map[string]interface{}{"bucketName": "bucket1", "prefix": "prefix1", "accessKeyID": "KeyID1"} - overrideWithEnv(fm) + overrideWithEnv(ctx, fm) Expect(fm.Config["accessKey"]).ToNot(BeNil()) }) It("Should set value for key when key not present", func() { diff --git a/warehouse/archive/archiver.go b/warehouse/archive/archiver.go index d8f72a63b0..94439f5418 100644 --- a/warehouse/archive/archiver.go +++ b/warehouse/archive/archiver.go @@ -69,7 +69,7 @@ type Archiver struct { Multitenant *multitenant.Manager } -func (a *Archiver) backupRecords(args backupRecordsArgs) (backupLocation string, err error) { +func (a *Archiver) backupRecords(ctx context.Context, args backupRecordsArgs) (backupLocation string, err error) { a.Logger.Infof(`Starting backupRecords for uploadId: %s, sourceId: %s, destinationId: %s, tableName: %s,`, args.uploadID, args.sourceID, args.destID, args.tableName) tmpDirPath, err := misc.CreateTMPDIR() if err != nil { @@ -91,7 +91,7 @@ func (a *Archiver) backupRecords(args backupRecordsArgs) (backupLocation string, fManager, err := a.FileManager.New(&filemanager.SettingsT{ Provider: config.GetString("JOBS_BACKUP_STORAGE_PROVIDER", "S3"), - Config: filemanager.GetProviderConfigForBackupsFromEnv(context.TODO()), + Config: filemanager.GetProviderConfigForBackupsFromEnv(ctx), }) if err != nil { err = fmt.Errorf("error in creating a file manager for:%s. 
Error: %w", config.GetString("JOBS_BACKUP_STORAGE_PROVIDER", "S3"), err) @@ -133,7 +133,7 @@ func (a *Archiver) backupRecords(args backupRecordsArgs) (backupLocation string, return } -func (a *Archiver) deleteFilesInStorage(locations []string) error { +func (a *Archiver) deleteFilesInStorage(ctx context.Context, locations []string) error { fManager, err := a.FileManager.New(&filemanager.SettingsT{ Provider: warehouseutils.S3, Config: misc.GetRudderObjectStorageConfig(""), @@ -143,7 +143,7 @@ func (a *Archiver) deleteFilesInStorage(locations []string) error { return err } - err = fManager.DeleteObjects(context.TODO(), locations) + err = fManager.DeleteObjects(ctx, locations) if err != nil { a.Logger.Errorf("Error in deleting objects in Rudder S3: %v", err) } @@ -239,7 +239,7 @@ func (a *Archiver) Do(ctx context.Context) error { var archivedUploads int for _, u := range uploadsToArchive { - txn, err := a.DB.Begin() + txn, err := a.DB.BeginTx(ctx, &sql.TxOptions{}) if err != nil { a.Logger.Errorf(`Error creating txn in archiveUploadFiles. Error: %v`, err) continue @@ -267,7 +267,7 @@ func (a *Archiver) Do(ctx context.Context) error { u.endStagingFileId, ) - stagingFileRows, err := txn.Query(stmt) + stagingFileRows, err := txn.QueryContext(ctx, stmt) if err != nil { a.Logger.Errorf(`Error running txn in archiveUploadFiles. Query: %s Error: %v`, stmt, err) txn.Rollback() @@ -297,7 +297,7 @@ func (a *Archiver) Do(ctx context.Context) error { if len(stagingFileIDs) > 0 { if !hasUsedRudderStorage { filterSQL := fmt.Sprintf(`id IN (%v)`, misc.IntArrayToString(stagingFileIDs, ",")) - storedStagingFilesLocation, err = a.backupRecords(backupRecordsArgs{ + storedStagingFilesLocation, err = a.backupRecords(ctx, backupRecordsArgs{ tableName: warehouseutils.WarehouseStagingFilesTable, sourceID: u.sourceID, destID: u.destID, @@ -315,7 +315,7 @@ func (a *Archiver) Do(ctx context.Context) error { } if hasUsedRudderStorage { - err = a.deleteFilesInStorage(stagingFileLocations) + err = a.deleteFilesInStorage(ctx, stagingFileLocations) if err != nil { a.Logger.Errorf(`Error deleting staging files from Rudder S3. Error: %v`, stmt, err) txn.Rollback() @@ -333,7 +333,7 @@ func (a *Archiver) Do(ctx context.Context) error { warehouseutils.WarehouseStagingFilesTable, misc.IntArrayToString(stagingFileIDs, ","), ) - _, err = txn.Query(stmt) + _, err = txn.QueryContext(ctx, stmt) if err != nil { a.Logger.Errorf(`Error running txn in archiveUploadFiles. Query: %s Error: %v`, stmt, err) txn.Rollback() @@ -349,7 +349,7 @@ func (a *Archiver) Do(ctx context.Context) error { `, warehouseutils.WarehouseLoadFilesTable, ) - loadLocationRows, err := txn.Query(stmt, pq.Array(stagingFileIDs)) + loadLocationRows, err := txn.QueryContext(ctx, stmt, pq.Array(stagingFileIDs)) if err != nil { a.Logger.Errorf(`Error running txn in archiveUploadFiles. Query: %s Error: %v`, stmt, err) txn.Rollback() @@ -380,7 +380,7 @@ func (a *Archiver) Do(ctx context.Context) error { } paths = append(paths, u.Path[1:]) } - err = a.deleteFilesInStorage(paths) + err = a.deleteFilesInStorage(ctx, paths) if err != nil { a.Logger.Errorf(`Error deleting load files from Rudder S3. Error: %v`, stmt, err) txn.Rollback() @@ -403,7 +403,7 @@ func (a *Archiver) Do(ctx context.Context) error { warehouseutils.WarehouseUploadsTable, u.uploadID, ) - _, err = txn.Exec(stmt, u.uploadMetdata) + _, err = txn.ExecContext(ctx, stmt, u.uploadMetdata) if err != nil { a.Logger.Errorf(`Error running txn in archiveUploadFiles. 
Query: %s Error: %v`, stmt, err) txn.Rollback() diff --git a/warehouse/client/client.go b/warehouse/client/client.go index 6508cf2c05..2fb9c45afd 100644 --- a/warehouse/client/client.go +++ b/warehouse/client/client.go @@ -72,8 +72,8 @@ func (cl *Client) sqlQuery(statement string) (result warehouseutils.QueryResult, func (cl *Client) bqQuery(statement string) (result warehouseutils.QueryResult, err error) { query := cl.BQ.Query(statement) - context := context.Background() - it, err := query.Read(context) + ctx := context.Background() + it, err := query.Read(ctx) if err != nil { return } @@ -101,7 +101,7 @@ func (cl *Client) bqQuery(statement string) (result warehouseutils.QueryResult, } func (cl *Client) dbQuery(statement string) (result warehouseutils.QueryResult, err error) { - executeResponse, err := cl.DeltalakeClient.Client.ExecuteQuery(cl.DeltalakeClient.Context, &proto.ExecuteQueryRequest{ + executeResponse, err := cl.DeltalakeClient.Client.ExecuteQuery(context.TODO(), &proto.ExecuteQueryRequest{ Config: cl.DeltalakeClient.CredConfig, SqlStatement: statement, Identifier: cl.DeltalakeClient.CredIdentifier, @@ -130,10 +130,10 @@ func (cl *Client) Query(statement string) (result warehouseutils.QueryResult, er func (cl *Client) Close() { switch cl.Type { case BQClient: - cl.BQ.Close() + _ = cl.BQ.Close() case DeltalakeClient: - cl.DeltalakeClient.Close() + cl.DeltalakeClient.Close(context.TODO()) default: - cl.SQL.Close() + _ = cl.SQL.Close() } } diff --git a/warehouse/client/controlplane/client_test.go b/warehouse/client/controlplane/client_test.go index 9e754fc25d..f5f16dee43 100644 --- a/warehouse/client/controlplane/client_test.go +++ b/warehouse/client/controlplane/client_test.go @@ -53,9 +53,11 @@ func TestFetchSSHKeys(t *testing.T) { } for _, tc := range testcases { - tc := tc + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + svc := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _, ok := r.BasicAuth() require.True(t, ok) @@ -73,7 +75,7 @@ func TestFetchSSHKeys(t *testing.T) { Password: "password", }) - keys, err := client.GetDestinationSSHKeys(context.TODO(), tc.destinationID) + keys, err := client.GetDestinationSSHKeys(ctx, tc.destinationID) require.Equal(t, tc.expectedError, err) require.Equal(t, tc.expectedKeyPair, keys) }) diff --git a/warehouse/identities.go b/warehouse/identities.go index 327e495ed6..8299d03eaa 100644 --- a/warehouse/identities.go +++ b/warehouse/identities.go @@ -160,20 +160,20 @@ func (wh *HandleT) hasLocalIdentityData(warehouse model.Warehouse) (exists bool) return } -func (wh *HandleT) hasWarehouseData(warehouse model.Warehouse) (bool, error) { +func (wh *HandleT) hasWarehouseData(ctx context.Context, warehouse model.Warehouse) (bool, error) { whManager, err := manager.New(wh.destType) if err != nil { panic(err) } - empty, err := whManager.IsEmpty(warehouse) + empty, err := whManager.IsEmpty(ctx, warehouse) if err != nil { return false, err } return !empty, nil } -func (wh *HandleT) setupIdentityTables(warehouse model.Warehouse) { +func (wh *HandleT) setupIdentityTables(ctx context.Context, warehouse model.Warehouse) { var name sql.NullString sqlStatement := fmt.Sprintf(`SELECT to_regclass('%s')`, warehouseutils.IdentityMappingsTableName(warehouse)) err := wh.dbHandle.QueryRow(sqlStatement).Scan(&name) @@ -196,7 +196,7 @@ func (wh *HandleT) setupIdentityTables(warehouse model.Warehouse) { `, warehouseutils.IdentityMergeRulesTableName(warehouse), ) - _, err = wh.dbHandle.Exec(sqlStatement) + _, 
err = wh.dbHandle.ExecContext(ctx, sqlStatement) if err != nil { panic(fmt.Errorf("Query: %s\nfailed with Error : %w", sqlStatement, err)) } @@ -210,7 +210,7 @@ func (wh *HandleT) setupIdentityTables(warehouse model.Warehouse) { warehouseutils.IdentityMergeRulesTableName(warehouse), ) - _, err = wh.dbHandle.Exec(sqlStatement) + _, err = wh.dbHandle.ExecContext(ctx, sqlStatement) if err != nil { panic(fmt.Errorf("Query: %s\nfailed with Error : %w", sqlStatement, err)) } @@ -227,7 +227,7 @@ func (wh *HandleT) setupIdentityTables(warehouse model.Warehouse) { warehouseutils.IdentityMappingsTableName(warehouse), ) - _, err = wh.dbHandle.Exec(sqlStatement) + _, err = wh.dbHandle.ExecContext(ctx, sqlStatement) if err != nil { panic(fmt.Errorf("Query: %s\nfailed with Error : %w", sqlStatement, err)) } @@ -244,7 +244,7 @@ func (wh *HandleT) setupIdentityTables(warehouse model.Warehouse) { warehouseutils.IdentityMappingsUniqueMappingConstraintName(warehouse), ) - _, err = wh.dbHandle.Exec(sqlStatement) + _, err = wh.dbHandle.ExecContext(ctx, sqlStatement) if err != nil { panic(fmt.Errorf("Query: %s\nfailed with Error : %w", sqlStatement, err)) } @@ -255,7 +255,7 @@ func (wh *HandleT) setupIdentityTables(warehouse model.Warehouse) { warehouseutils.IdentityMappingsTableName(warehouse), ) - _, err = wh.dbHandle.Exec(sqlStatement) + _, err = wh.dbHandle.ExecContext(ctx, sqlStatement) if err != nil { panic(fmt.Errorf("Query: %s\nfailed with Error : %w", sqlStatement, err)) } @@ -268,7 +268,7 @@ func (wh *HandleT) setupIdentityTables(warehouse model.Warehouse) { warehouseutils.IdentityMappingsTableName(warehouse), ) - _, err = wh.dbHandle.Exec(sqlStatement) + _, err = wh.dbHandle.ExecContext(ctx, sqlStatement) if err != nil { panic(fmt.Errorf("Query: %s\nfailed with Error : %w", sqlStatement, err)) } @@ -359,7 +359,7 @@ func (*HandleT) setFailedStat(warehouse model.Warehouse, err error) { } } -func (wh *HandleT) populateHistoricIdentities(warehouse model.Warehouse) { +func (wh *HandleT) populateHistoricIdentities(ctx context.Context, warehouse model.Warehouse) { if isDestHistoricIdentitiesPopulated(warehouse) || isDestHistoricIdentitiesPopulateInProgress(warehouse) { return } @@ -389,7 +389,7 @@ func (wh *HandleT) populateHistoricIdentities(warehouse model.Warehouse) { return } var hasData bool - hasData, err = wh.hasWarehouseData(warehouse) + hasData, err = wh.hasWarehouseData(ctx, warehouse) if err != nil { pkgLogger.Errorf(`[WH]: Error checking for data in %s:%s:%s, err: %s`, wh.destType, warehouse.Destination.ID, warehouse.Destination.Name, err.Error()) return @@ -408,12 +408,12 @@ func (wh *HandleT) populateHistoricIdentities(warehouse model.Warehouse) { panic(err) } - job := wh.uploadJobFactory.NewUploadJob(&model.UploadJob{ + job := wh.uploadJobFactory.NewUploadJob(ctx, &model.UploadJob{ Upload: upload, Warehouse: warehouse, }, whManager) - tableUploadsCreated, tableUploadsErr := job.tableUploadsRepo.ExistsForUploadID(context.TODO(), job.upload.ID) + tableUploadsCreated, tableUploadsErr := job.tableUploadsRepo.ExistsForUploadID(ctx, job.upload.ID) if tableUploadsErr != nil { pkgLogger.Warnw("table uploads exists", logfield.UploadJobID, job.upload.ID, @@ -433,14 +433,14 @@ func (wh *HandleT) populateHistoricIdentities(warehouse model.Warehouse) { } } - err = whManager.Setup(job.warehouse, job) + err = whManager.Setup(ctx, job.warehouse, job) if err != nil { job.setUploadError(err, model.Aborted) return } - defer whManager.Cleanup() + defer whManager.Cleanup(ctx) - err = 
job.schemaHandle.fetchSchemaFromWarehouse(whManager) + err = job.schemaHandle.fetchSchemaFromWarehouse(ctx, whManager) if err != nil { pkgLogger.Errorf(`[WH]: Failed fetching schema from warehouse: %v`, err) job.setUploadError(err, model.Aborted) diff --git a/warehouse/identity/identity.go b/warehouse/identity/identity.go index ba3a50ded9..678386e4d1 100644 --- a/warehouse/identity/identity.go +++ b/warehouse/identity/identity.go @@ -32,39 +32,52 @@ func init() { } type WarehouseManager interface { - DownloadIdentityRules(*misc.GZipWriter) error + DownloadIdentityRules(context.Context, *misc.GZipWriter) error } -type HandleT struct { - Warehouse model.Warehouse - DB *sql.DB - Uploader warehouseutils.Uploader - UploadID int64 - WarehouseManager WarehouseManager - LoadFileDownloader downloader.Downloader +type Identity struct { + ctx context.Context + warehouse model.Warehouse + db *sql.DB + uploader warehouseutils.Uploader + uploadID int64 + warehouseManager WarehouseManager + downloader downloader.Downloader } -func (idr *HandleT) mergeRulesTable() string { - return warehouseutils.IdentityMergeRulesTableName(idr.Warehouse) +func New(ctx context.Context, warehouse model.Warehouse, db *sql.DB, uploader warehouseutils.Uploader, uploadID int64, warehouseManager WarehouseManager, loadFileDownloader downloader.Downloader) *Identity { + return &Identity{ + ctx: ctx, + warehouse: warehouse, + db: db, + uploader: uploader, + uploadID: uploadID, + warehouseManager: warehouseManager, + downloader: loadFileDownloader, + } +} + +func (idr *Identity) mergeRulesTable() string { + return warehouseutils.IdentityMergeRulesTableName(idr.warehouse) } -func (idr *HandleT) mappingsTable() string { - return warehouseutils.IdentityMappingsTableName(idr.Warehouse) +func (idr *Identity) mappingsTable() string { + return warehouseutils.IdentityMappingsTableName(idr.warehouse) } -func (idr *HandleT) whMergeRulesTable() string { - return warehouseutils.ToProviderCase(idr.Warehouse.Destination.DestinationDefinition.Name, warehouseutils.IdentityMergeRulesTable) +func (idr *Identity) whMergeRulesTable() string { + return warehouseutils.ToProviderCase(idr.warehouse.Destination.DestinationDefinition.Name, warehouseutils.IdentityMergeRulesTable) } -func (idr *HandleT) whMappingsTable() string { - return warehouseutils.ToProviderCase(idr.Warehouse.Destination.DestinationDefinition.Name, warehouseutils.IdentityMappingsTable) +func (idr *Identity) whMappingsTable() string { + return warehouseutils.ToProviderCase(idr.warehouse.Destination.DestinationDefinition.Name, warehouseutils.IdentityMappingsTable) } -func (idr *HandleT) applyRule(txn *sql.Tx, ruleID int64, gzWriter *misc.GZipWriter) (totalRowsModified int, err error) { +func (idr *Identity) applyRule(txn *sql.Tx, ruleID int64, gzWriter *misc.GZipWriter) (totalRowsModified int, err error) { sqlStatement := fmt.Sprintf(`SELECT merge_property_1_type, merge_property_1_value, merge_property_2_type, merge_property_2_value FROM %s WHERE id=%v`, idr.mergeRulesTable(), ruleID) var prop1Val, prop2Val, prop1Type, prop2Type sql.NullString - err = txn.QueryRow(sqlStatement).Scan(&prop1Type, &prop1Val, &prop2Type, &prop2Val) + err = txn.QueryRowContext(idr.ctx, sqlStatement).Scan(&prop1Type, &prop1Val, &prop2Type, &prop2Val) if err != nil { return } @@ -76,7 +89,7 @@ func (idr *HandleT) applyRule(txn *sql.Tx, ruleID int64, gzWriter *misc.GZipWrit } sqlStatement = fmt.Sprintf(`SELECT ARRAY_AGG(DISTINCT(rudder_id)) FROM %s WHERE (merge_property_type='%s' AND merge_property_value=%s) %s`, 
idr.mappingsTable(), prop1Type.String, misc.QuoteLiteral(prop1Val.String), additionalClause) pkgLogger.Debugf(`IDR: Fetching all rudder_id's corresponding to the merge_rule: %v`, sqlStatement) - err = txn.QueryRow(sqlStatement).Scan(pq.Array(&rudderIDs)) + err = txn.QueryRowContext(idr.ctx, sqlStatement).Scan(pq.Array(&rudderIDs)) if err != nil { pkgLogger.Errorf("IDR: Error fetching all rudder_id's corresponding to the merge_rule: %v\nwith Error: %v", sqlStatement, err) return @@ -107,9 +120,9 @@ func (idr *HandleT) applyRule(txn *sql.Tx, ruleID int64, gzWriter *misc.GZipWrit row2Values = fmt.Sprintf(`, (%s)`, misc.SingleQuoteLiteralJoin(row2)) } - sqlStatement = fmt.Sprintf(`INSERT INTO %s (merge_property_type, merge_property_value, rudder_id, updated_at) VALUES (%s) %s ON CONFLICT ON CONSTRAINT %s DO NOTHING`, idr.mappingsTable(), row1Values, row2Values, warehouseutils.IdentityMappingsUniqueMappingConstraintName(idr.Warehouse)) + sqlStatement = fmt.Sprintf(`INSERT INTO %s (merge_property_type, merge_property_value, rudder_id, updated_at) VALUES (%s) %s ON CONFLICT ON CONSTRAINT %s DO NOTHING`, idr.mappingsTable(), row1Values, row2Values, warehouseutils.IdentityMappingsUniqueMappingConstraintName(idr.warehouse)) pkgLogger.Debugf(`IDR: Inserting properties from merge_rule into mappings table: %v`, sqlStatement) - _, err = txn.Exec(sqlStatement) + _, err = txn.ExecContext(idr.ctx, sqlStatement) if err != nil { pkgLogger.Errorf(`IDR: Error inserting properties from merge_rule into mappings table: %v`, err) return @@ -132,7 +145,7 @@ func (idr *HandleT) applyRule(txn *sql.Tx, ruleID int64, gzWriter *misc.GZipWrit sqlStatement := fmt.Sprintf(`SELECT merge_property_type, merge_property_value FROM %s WHERE rudder_id IN (%v)`, idr.mappingsTable(), quotedRudderIDs) pkgLogger.Debugf(`IDR: Get all merge properties from mapping table with rudder_id's %v: %v`, quotedRudderIDs, sqlStatement) var tableRows *sql.Rows - tableRows, err = txn.Query(sqlStatement) + tableRows, err = txn.QueryContext(idr.ctx, sqlStatement) if err != nil { return } @@ -149,23 +162,23 @@ func (idr *HandleT) applyRule(txn *sql.Tx, ruleID int64, gzWriter *misc.GZipWrit sqlStatement = fmt.Sprintf(`UPDATE %s SET rudder_id='%s', updated_at='%s' WHERE rudder_id IN (%v)`, idr.mappingsTable(), newID, currentTimeString, misc.SingleQuoteLiteralJoin(rudderIDs[1:])) var res sql.Result - res, err = txn.Exec(sqlStatement) + res, err = txn.ExecContext(idr.ctx, sqlStatement) if err != nil { return } affectedRowCount, _ := res.RowsAffected() pkgLogger.Debugf(`IDR: Updated rudder_id for all properties in mapping table. 
Updated %v rows: %v `, affectedRowCount, sqlStatement) - sqlStatement = fmt.Sprintf(`INSERT INTO %s (merge_property_type, merge_property_value, rudder_id, updated_at) VALUES (%s) %s ON CONFLICT ON CONSTRAINT %s DO NOTHING`, idr.mappingsTable(), row1Values, row2Values, warehouseutils.IdentityMappingsUniqueMappingConstraintName(idr.Warehouse)) + sqlStatement = fmt.Sprintf(`INSERT INTO %s (merge_property_type, merge_property_value, rudder_id, updated_at) VALUES (%s) %s ON CONFLICT ON CONSTRAINT %s DO NOTHING`, idr.mappingsTable(), row1Values, row2Values, warehouseutils.IdentityMappingsUniqueMappingConstraintName(idr.warehouse)) pkgLogger.Debugf(`IDR: Insert new mappings into %s: %v`, idr.mappingsTable(), sqlStatement) - _, err = txn.Exec(sqlStatement) + _, err = txn.ExecContext(idr.ctx, sqlStatement) if err != nil { return } } columnNames := []string{"merge_property_type", "merge_property_value", "rudder_id", "updated_at"} for _, row := range rows { - eventLoader := encoding.GetNewEventLoader(idr.Warehouse.Type, idr.Uploader.GetLoadFileType(), gzWriter) + eventLoader := encoding.GetNewEventLoader(idr.warehouse.Type, idr.uploader.GetLoadFileType(), gzWriter) // TODO : support add row for parquet loader eventLoader.AddRow(columnNames, row) data, _ := eventLoader.WriteToString() @@ -175,7 +188,7 @@ func (idr *HandleT) applyRule(txn *sql.Tx, ruleID int64, gzWriter *misc.GZipWrit return len(rows), err } -func (idr *HandleT) addRules(txn *sql.Tx, loadFileNames []string, gzWriter *misc.GZipWriter) (ids []int64, err error) { +func (idr *Identity) addRules(txn *sql.Tx, loadFileNames []string, gzWriter *misc.GZipWriter) (ids []int64, err error) { // add rules from load files into temp table // use original table to delete redundant ones from temp table // insert from temp table into original table @@ -186,7 +199,7 @@ func (idr *HandleT) addRules(txn *sql.Tx, loadFileNames []string, gzWriter *misc WITH NO DATA;`, mergeRulesStagingTable, idr.mergeRulesTable()) pkgLogger.Infof(`IDR: Creating temp table %s in postgres for loading %s: %v`, mergeRulesStagingTable, idr.mergeRulesTable(), sqlStatement) - _, err = txn.Exec(sqlStatement) + _, err = txn.ExecContext(idr.ctx, sqlStatement) if err != nil { pkgLogger.Errorf(`IDR: Error creating temp table %s in postgres: %v`, mergeRulesStagingTable, err) return @@ -218,7 +231,7 @@ func (idr *HandleT) addRules(txn *sql.Tx, loadFileNames []string, gzWriter *misc } defer gzipReader.Close() - eventReader := encoding.NewEventReader(gzipReader, idr.Warehouse.Type) + eventReader := encoding.NewEventReader(gzipReader, idr.warehouse.Type) columnNames := []string{"merge_property_1_type", "merge_property_1_value", "merge_property_2_type", "merge_property_2_value"} for { var record []string @@ -239,7 +252,7 @@ func (idr *HandleT) addRules(txn *sql.Tx, loadFileNames []string, gzWriter *misc // add rowID which allows us to insert in same order from staging to original merge _rules table rowID++ recordInterface[4] = rowID - _, err = stmt.Exec(recordInterface[:]...) + _, err = stmt.ExecContext(idr.ctx, recordInterface[:]...) 
if err != nil { pkgLogger.Errorf("IDR: Error while adding rowID to merge_rules table: %v", err) return @@ -247,9 +260,9 @@ func (idr *HandleT) addRules(txn *sql.Tx, loadFileNames []string, gzWriter *misc } } - _, err = stmt.Exec() + _, err = stmt.ExecContext(idr.ctx) if err != nil { - pkgLogger.Errorf(`IDR: Error bulk copy using CopyIn: %v for uploadID: %v`, err, idr.UploadID) + pkgLogger.Errorf(`IDR: Error bulk copy using CopyIn: %v for uploadID: %v`, err, idr.uploadID) return } @@ -265,7 +278,7 @@ func (idr *HandleT) addRules(txn *sql.Tx, loadFileNames []string, gzWriter *misc (original.merge_property_2_value = staging.merge_property_2_value)`, mergeRulesStagingTable, idr.mergeRulesTable()) pkgLogger.Infof(`IDR: Deleting from staging table %s using %s: %v`, mergeRulesStagingTable, idr.mergeRulesTable(), sqlStatement) - _, err = txn.Exec(sqlStatement) + _, err = txn.ExecContext(idr.ctx, sqlStatement) if err != nil { pkgLogger.Errorf(`IDR: Error deleting from staging table %s using %s: %v`, mergeRulesStagingTable, idr.mergeRulesTable(), err) return @@ -290,7 +303,7 @@ func (idr *HandleT) addRules(txn *sql.Tx, loadFileNames []string, gzWriter *misc ) t ORDER BY id ASC RETURNING id`, idr.mergeRulesTable(), mergeRulesStagingTable) pkgLogger.Infof(`IDR: Inserting into %s from %s: %v`, idr.mergeRulesTable(), mergeRulesStagingTable, sqlStatement) - rows, err := txn.Query(sqlStatement) + rows, err := txn.QueryContext(idr.ctx, sqlStatement) if err != nil { pkgLogger.Errorf(`IDR: Error inserting into %s from %s: %v`, idr.mergeRulesTable(), mergeRulesStagingTable, err) return @@ -304,15 +317,15 @@ func (idr *HandleT) addRules(txn *sql.Tx, loadFileNames []string, gzWriter *misc } ids = append(ids, id) } - pkgLogger.Debugf(`IDR: Number of merge rules inserted for uploadID %v : %v`, idr.UploadID, len(ids)) + pkgLogger.Debugf(`IDR: Number of merge rules inserted for uploadID %v : %v`, idr.uploadID, len(ids)) return ids, nil } -func (idr *HandleT) writeTableToFile(tableName string, txn *sql.Tx, gzWriter *misc.GZipWriter) (err error) { +func (idr *Identity) writeTableToFile(tableName string, txn *sql.Tx, gzWriter *misc.GZipWriter) (err error) { batchSize := int64(500) sqlStatement := fmt.Sprintf(`SELECT COUNT(*) FROM %s`, tableName) var totalRows int64 - err = txn.QueryRow(sqlStatement).Scan(&totalRows) + err = txn.QueryRowContext(idr.ctx, sqlStatement).Scan(&totalRows) if err != nil { return } @@ -322,14 +335,14 @@ func (idr *HandleT) writeTableToFile(tableName string, txn *sql.Tx, gzWriter *mi sqlStatement = fmt.Sprintf(`SELECT merge_property_1_type, merge_property_1_value, merge_property_2_type, merge_property_2_value FROM %s LIMIT %d OFFSET %d`, tableName, batchSize, offset) var rows *sql.Rows - rows, err = txn.Query(sqlStatement) + rows, err = txn.QueryContext(idr.ctx, sqlStatement) if err != nil { return } columnNames := []string{"merge_property_1_type", "merge_property_1_value", "merge_property_2_type", "merge_property_2_value"} for rows.Next() { var rowData []string - eventLoader := encoding.GetNewEventLoader(idr.Warehouse.Type, idr.Uploader.GetLoadFileType(), gzWriter) + eventLoader := encoding.GetNewEventLoader(idr.warehouse.Type, idr.uploader.GetLoadFileType(), gzWriter) var prop1Val, prop2Val, prop1Type, prop2Type sql.NullString err = rows.Scan( &prop1Type, @@ -357,45 +370,45 @@ func (idr *HandleT) writeTableToFile(tableName string, txn *sql.Tx, gzWriter *mi return } -func (idr *HandleT) uploadFile(filePath string, txn *sql.Tx, tableName string, totalRecords int) (err error) { +func (idr 
*Identity) uploadFile(filePath string, txn *sql.Tx, tableName string, totalRecords int) (err error) { outputFile, err := os.Open(filePath) if err != nil { panic(err) } - storageProvider := warehouseutils.ObjectStorageType(idr.Warehouse.Destination.DestinationDefinition.Name, idr.Warehouse.Destination.Config, idr.Uploader.UseRudderStorage()) + storageProvider := warehouseutils.ObjectStorageType(idr.warehouse.Destination.DestinationDefinition.Name, idr.warehouse.Destination.Config, idr.uploader.UseRudderStorage()) uploader, err := filemanager.DefaultFileManagerFactory.New(&filemanager.SettingsT{ Provider: storageProvider, Config: misc.GetObjectStorageConfig(misc.ObjectStorageOptsT{ Provider: storageProvider, - Config: idr.Warehouse.Destination.Config, - UseRudderStorage: idr.Uploader.UseRudderStorage(), + Config: idr.warehouse.Destination.Config, + UseRudderStorage: idr.uploader.UseRudderStorage(), }), }) if err != nil { - pkgLogger.Errorf("IDR: Error in creating a file manager for :%s: , %v", idr.Warehouse.Destination.DestinationDefinition.Name, err) + pkgLogger.Errorf("IDR: Error in creating a file manager for :%s: , %v", idr.warehouse.Destination.DestinationDefinition.Name, err) return err } - output, err := uploader.Upload(context.TODO(), outputFile, config.GetString("WAREHOUSE_BUCKET_LOAD_OBJECTS_FOLDER_NAME", "rudder-warehouse-load-objects"), tableName, idr.Warehouse.Source.ID, tableName) + output, err := uploader.Upload(idr.ctx, outputFile, config.GetString("WAREHOUSE_BUCKET_LOAD_OBJECTS_FOLDER_NAME", "rudder-warehouse-load-objects"), tableName, idr.warehouse.Source.ID, tableName) if err != nil { return } - sqlStatement := fmt.Sprintf(`UPDATE %s SET location='%s', total_events=%d WHERE wh_upload_id=%d AND table_name='%s'`, warehouseutils.WarehouseTableUploadsTable, output.Location, totalRecords, idr.UploadID, warehouseutils.ToProviderCase(idr.Warehouse.Destination.DestinationDefinition.Name, tableName)) + sqlStatement := fmt.Sprintf(`UPDATE %s SET location='%s', total_events=%d WHERE wh_upload_id=%d AND table_name='%s'`, warehouseutils.WarehouseTableUploadsTable, output.Location, totalRecords, idr.uploadID, warehouseutils.ToProviderCase(idr.warehouse.Destination.DestinationDefinition.Name, tableName)) pkgLogger.Infof(`IDR: Updating load file location for table: %s: %s `, tableName, sqlStatement) - _, err = txn.Exec(sqlStatement) + _, err = txn.ExecContext(idr.ctx, sqlStatement) if err != nil { pkgLogger.Errorf(`IDR: Error updating load file location for table: %s: %v`, tableName, err) } return err } -func (idr *HandleT) createTempGzFile(dirName string) (gzWriter misc.GZipWriter, path string) { +func (idr *Identity) createTempGzFile(dirName string) (gzWriter misc.GZipWriter, path string) { tmpDirPath, err := misc.CreateTMPDIR() if err != nil { panic(err) } - fileExtension := warehouseutils.GetTempFileExtension(idr.Warehouse.Type) - path = tmpDirPath + dirName + fmt.Sprintf(`%s_%s/%v/`, idr.Warehouse.Destination.DestinationDefinition.Name, idr.Warehouse.Destination.ID, idr.UploadID) + misc.FastUUID().String() + "." + fileExtension + fileExtension := warehouseutils.GetTempFileExtension(idr.warehouse.Type) + path = tmpDirPath + dirName + fmt.Sprintf(`%s_%s/%v/`, idr.warehouse.Destination.DestinationDefinition.Name, idr.warehouse.Destination.ID, idr.uploadID) + misc.FastUUID().String() + "." 
+ fileExtension err = os.MkdirAll(filepath.Dir(path), os.ModePerm) if err != nil { panic(err) @@ -407,8 +420,8 @@ func (idr *HandleT) createTempGzFile(dirName string) (gzWriter misc.GZipWriter, return } -func (idr *HandleT) processMergeRules(fileNames []string) (err error) { - txn, err := idr.DB.Begin() +func (idr *Identity) processMergeRules(fileNames []string) (err error) { + txn, err := idr.db.BeginTx(idr.ctx, &sql.TxOptions{}) if err != nil { panic(err) } @@ -439,7 +452,7 @@ func (idr *HandleT) processMergeRules(fileNames []string) (err error) { } totalMappingRecords += count if idx%1000 == 0 { - pkgLogger.Infof(`IDR: Applied %d rules out of %d. Total Mapping records added: %d. Namespace: %s, Destination: %s:%s`, idx+1, len(ruleIDs), totalMappingRecords, idr.Warehouse.Namespace, idr.Warehouse.Type, idr.Warehouse.Destination.ID) + pkgLogger.Infof(`IDR: Applied %d rules out of %d. Total Mapping records added: %d. Namespace: %s, Destination: %s:%s`, idx+1, len(ruleIDs), totalMappingRecords, idr.warehouse.Namespace, idr.warehouse.Type, idr.warehouse.Destination.ID) } } mappingsFileGzWriter.CloseGZ() @@ -472,10 +485,10 @@ func (idr *HandleT) processMergeRules(fileNames []string) (err error) { // 2. Append to local identity merge rules table // 3. Apply each merge rule and update local identity mapping table // 4. Upload the diff of each table to load files for both tables -func (idr *HandleT) Resolve() (err error) { +func (idr *Identity) Resolve() (err error) { var loadFileNames []string defer misc.RemoveFilePaths(loadFileNames...) - loadFileNames, err = idr.LoadFileDownloader.Download(context.TODO(), idr.whMergeRulesTable()) + loadFileNames, err = idr.downloader.Download(idr.ctx, idr.whMergeRulesTable()) if err != nil { pkgLogger.Errorf(`IDR: Failed to download load files for %s with error: %v`, idr.mergeRulesTable(), err) return @@ -484,11 +497,11 @@ func (idr *HandleT) Resolve() (err error) { return idr.processMergeRules(loadFileNames) } -func (idr *HandleT) ResolveHistoricIdentities() (err error) { +func (idr *Identity) ResolveHistoricIdentities() (err error) { var loadFileNames []string defer misc.RemoveFilePaths(loadFileNames...) 
gzWriter, path := idr.createTempGzFile(fmt.Sprintf(`/%s/`, misc.RudderIdentityMergeRulesTmp)) - err = idr.WarehouseManager.DownloadIdentityRules(&gzWriter) + err = idr.warehouseManager.DownloadIdentityRules(idr.ctx, &gzWriter) gzWriter.CloseGZ() if err != nil { pkgLogger.Errorf(`IDR: Failed to download identity information from warehouse with error: %v`, err) diff --git a/warehouse/integrations/azure-synapse/azure-synapse.go b/warehouse/integrations/azure-synapse/azure-synapse.go index 4ee708530b..bddb4a22df 100644 --- a/warehouse/integrations/azure-synapse/azure-synapse.go +++ b/warehouse/integrations/azure-synapse/azure-synapse.go @@ -178,7 +178,7 @@ func columnsWithDataTypes(columns model.TableSchema, prefix string) string { return strings.Join(arr, ",") } -func (*AzureSynapse) IsEmpty(_ model.Warehouse) (empty bool, err error) { +func (*AzureSynapse) IsEmpty(context.Context, model.Warehouse) (empty bool, err error) { return } @@ -223,7 +223,7 @@ func (as *AzureSynapse) loadTable(ctx context.Context, tableName string, tableSc } if !skipTempTableDelete { - defer as.dropStagingTable(stagingTableName) + defer as.dropStagingTable(ctx, stagingTableName) } stmt, err := txn.PrepareContext(ctx, mssql.CopyIn(as.Namespace+"."+stagingTableName, mssql.BulkOptions{CheckConstraints: false}, append(sortedColumnKeys, extraColumns...)...)) @@ -446,9 +446,9 @@ func (as *AzureSynapse) loadUserTables(ctx context.Context) (errorMap map[string unionStagingTableName := warehouseutils.StagingTableName(provider, "users_identifies_union", tableNameLimit) stagingTableName := warehouseutils.StagingTableName(provider, warehouseutils.UsersTable, tableNameLimit) - defer as.dropStagingTable(stagingTableName) - defer as.dropStagingTable(unionStagingTableName) - defer as.dropStagingTable(identifyStagingTable) + defer as.dropStagingTable(ctx, stagingTableName) + defer as.dropStagingTable(ctx, unionStagingTableName) + defer as.dropStagingTable(ctx, identifyStagingTable) userColMap := as.Uploader.GetTableSchemaInWarehouse(warehouseutils.UsersTable) var userColNames, firstValProps []string @@ -547,53 +547,53 @@ func (as *AzureSynapse) loadUserTables(ctx context.Context) (errorMap map[string return } -func (*AzureSynapse) DeleteBy([]string, warehouseutils.DeleteByParams) error { +func (*AzureSynapse) DeleteBy(context.Context, []string, warehouseutils.DeleteByParams) error { return fmt.Errorf(warehouseutils.NotImplementedErrorCode) } -func (as *AzureSynapse) CreateSchema() (err error) { +func (as *AzureSynapse) CreateSchema(ctx context.Context) (err error) { sqlStatement := fmt.Sprintf(`IF NOT EXISTS ( SELECT * FROM sys.schemas WHERE name = N'%s' ) EXEC('CREATE SCHEMA [%s]'); `, as.Namespace, as.Namespace) as.Logger.Infof("SYNAPSE: Creating schema name in synapse for AZ:%s : %v", as.Warehouse.Destination.ID, sqlStatement) - _, err = as.DB.Exec(sqlStatement) + _, err = as.DB.ExecContext(ctx, sqlStatement) if err == io.EOF { return nil } return } -func (as *AzureSynapse) dropStagingTable(stagingTableName string) { +func (as *AzureSynapse) dropStagingTable(ctx context.Context, stagingTableName string) { as.Logger.Infof("AZ: dropping table %+v\n", stagingTableName) - _, err := as.DB.Exec(fmt.Sprintf(`IF OBJECT_ID ('%[1]s','U') IS NOT NULL DROP TABLE %[1]s;`, as.Namespace+"."+stagingTableName)) + _, err := as.DB.ExecContext(ctx, fmt.Sprintf(`IF OBJECT_ID ('%[1]s','U') IS NOT NULL DROP TABLE %[1]s;`, as.Namespace+"."+stagingTableName)) if err != nil { as.Logger.Errorf("AZ: Error dropping staging table %s in synapse: %v", 
as.Namespace+"."+stagingTableName, err) } } -func (as *AzureSynapse) createTable(name string, columns model.TableSchema) (err error) { +func (as *AzureSynapse) createTable(ctx context.Context, name string, columns model.TableSchema) (err error) { sqlStatement := fmt.Sprintf(`IF NOT EXISTS (SELECT 1 FROM sys.objects WHERE object_id = OBJECT_ID(N'%[1]s') AND type = N'U') CREATE TABLE %[1]s ( %v )`, name, columnsWithDataTypes(columns, "")) as.Logger.Infof("AZ: Creating table in synapse for AZ:%s : %v", as.Warehouse.Destination.ID, sqlStatement) - _, err = as.DB.Exec(sqlStatement) + _, err = as.DB.ExecContext(ctx, sqlStatement) return } -func (as *AzureSynapse) CreateTable(tableName string, columnMap model.TableSchema) (err error) { +func (as *AzureSynapse) CreateTable(ctx context.Context, tableName string, columnMap model.TableSchema) (err error) { // Search paths doesn't exist unlike Postgres, default is dbo. Hence, use namespace wherever possible - err = as.createTable(as.Namespace+"."+tableName, columnMap) + err = as.createTable(ctx, as.Namespace+"."+tableName, columnMap) return err } -func (as *AzureSynapse) DropTable(tableName string) (err error) { +func (as *AzureSynapse) DropTable(ctx context.Context, tableName string) (err error) { sqlStatement := `DROP TABLE "%[1]s"."%[2]s"` as.Logger.Infof("AZ: Dropping table in synapse for AZ:%s : %v", as.Warehouse.Destination.ID, sqlStatement) - _, err = as.DB.Exec(fmt.Sprintf(sqlStatement, as.Namespace, tableName)) + _, err = as.DB.ExecContext(ctx, fmt.Sprintf(sqlStatement, as.Namespace, tableName)) return } -func (as *AzureSynapse) AddColumns(tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) { +func (as *AzureSynapse) AddColumns(ctx context.Context, tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) { var ( query string queryBuilder strings.Builder @@ -633,11 +633,11 @@ func (as *AzureSynapse) AddColumns(tableName string, columnsInfo []warehouseutil query += ";" as.Logger.Infof("AZ: Adding columns for destinationID: %s, tableName: %s with query: %v", as.Warehouse.Destination.ID, tableName, query) - _, err = as.DB.Exec(query) + _, err = as.DB.ExecContext(ctx, query) return } -func (*AzureSynapse) AlterColumn(_, _, _ string) (model.AlterTableResponse, error) { +func (*AzureSynapse) AlterColumn(context.Context, string, string, string) (model.AlterTableResponse, error) { return model.AlterTableResponse{}, nil } @@ -653,7 +653,7 @@ func (as *AzureSynapse) TestConnection(ctx context.Context, _ model.Warehouse) e return nil } -func (as *AzureSynapse) Setup(warehouse model.Warehouse, uploader warehouseutils.Uploader) (err error) { +func (as *AzureSynapse) Setup(_ context.Context, warehouse model.Warehouse, uploader warehouseutils.Uploader) (err error) { as.Warehouse = warehouse as.Namespace = warehouse.Namespace as.Uploader = uploader @@ -664,11 +664,11 @@ func (as *AzureSynapse) Setup(warehouse model.Warehouse, uploader warehouseutils return err } -func (as *AzureSynapse) CrashRecover() { - as.dropDanglingStagingTables() +func (as *AzureSynapse) CrashRecover(ctx context.Context) { + as.dropDanglingStagingTables(ctx) } -func (as *AzureSynapse) dropDanglingStagingTables() bool { +func (as *AzureSynapse) dropDanglingStagingTables(ctx context.Context) bool { sqlStatement := fmt.Sprintf(` select table_name @@ -681,12 +681,12 @@ func (as *AzureSynapse) dropDanglingStagingTables() bool { as.Namespace, fmt.Sprintf(`%s%%`, warehouseutils.StagingTablePrefix(provider)), ) - rows, err := as.DB.Query(sqlStatement) + rows, 
err := as.DB.QueryContext(ctx, sqlStatement) if err != nil { as.Logger.Errorf("WH: SYNAPSE: Error dropping dangling staging tables in synapse: %v\nQuery: %s\n", err, sqlStatement) return false } - defer rows.Close() + defer func() { _ = rows.Close() }() var stagingTableNames []string for rows.Next() { @@ -700,7 +700,7 @@ func (as *AzureSynapse) dropDanglingStagingTables() bool { as.Logger.Infof("WH: SYNAPSE: Dropping dangling staging tables: %+v %+v\n", len(stagingTableNames), stagingTableNames) delSuccess := true for _, stagingTableName := range stagingTableNames { - _, err := as.DB.Exec(fmt.Sprintf(`DROP TABLE "%[1]s"."%[2]s"`, as.Namespace, stagingTableName)) + _, err := as.DB.ExecContext(ctx, fmt.Sprintf(`DROP TABLE "%[1]s"."%[2]s"`, as.Namespace, stagingTableName)) if err != nil { as.Logger.Errorf("WH: SYNAPSE: Error dropping dangling staging table: %s in redshift: %v\n", stagingTableName, err) delSuccess = false @@ -710,7 +710,7 @@ func (as *AzureSynapse) dropDanglingStagingTables() bool { } // FetchSchema returns the schema of the warehouse -func (as *AzureSynapse) FetchSchema() (model.Schema, model.Schema, error) { +func (as *AzureSynapse) FetchSchema(ctx context.Context) (model.Schema, model.Schema, error) { schema := make(model.Schema) unrecognizedSchema := make(model.Schema) @@ -725,7 +725,9 @@ func (as *AzureSynapse) FetchSchema() (model.Schema, model.Schema, error) { table_schema = @schema and table_name not like @prefix ` - rows, err := as.DB.Query(sqlStatement, + rows, err := as.DB.QueryContext( + ctx, + sqlStatement, sql.Named("schema", as.Namespace), sql.Named("prefix", fmt.Sprintf("%s%%", warehouseutils.StagingTablePrefix(provider))), ) @@ -774,23 +776,23 @@ func (as *AzureSynapse) LoadTable(ctx context.Context, tableName string) error { return err } -func (as *AzureSynapse) Cleanup() { +func (as *AzureSynapse) Cleanup(ctx context.Context) { if as.DB != nil { // extra check aside dropStagingTable(table) - as.dropDanglingStagingTables() + as.dropDanglingStagingTables(ctx) as.DB.Close() } } -func (*AzureSynapse) LoadIdentityMergeRulesTable() (err error) { +func (*AzureSynapse) LoadIdentityMergeRulesTable(context.Context) (err error) { return } -func (*AzureSynapse) LoadIdentityMappingsTable() (err error) { +func (*AzureSynapse) LoadIdentityMappingsTable(context.Context) (err error) { return } -func (*AzureSynapse) DownloadIdentityRules(*misc.GZipWriter) (err error) { +func (*AzureSynapse) DownloadIdentityRules(context.Context, *misc.GZipWriter) (err error) { return } @@ -810,7 +812,7 @@ func (as *AzureSynapse) GetTotalCountInTable(ctx context.Context, tableName stri return total, err } -func (as *AzureSynapse) Connect(warehouse model.Warehouse) (client.Client, error) { +func (as *AzureSynapse) Connect(_ context.Context, warehouse model.Warehouse) (client.Client, error) { as.Warehouse = warehouse as.Namespace = warehouse.Namespace dbHandle, err := connect(as.getConnectionCredentials()) @@ -821,14 +823,14 @@ func (as *AzureSynapse) Connect(warehouse model.Warehouse) (client.Client, error return client.Client{Type: client.SQLClient, SQL: dbHandle}, err } -func (as *AzureSynapse) LoadTestTable(_, tableName string, payloadMap map[string]interface{}, _ string) (err error) { +func (as *AzureSynapse) LoadTestTable(ctx context.Context, _, tableName string, payloadMap map[string]interface{}, _ string) (err error) { sqlStatement := fmt.Sprintf(`INSERT INTO %q.%q (%v) VALUES (%s)`, as.Namespace, tableName, fmt.Sprintf(`%q, %q`, "id", "val"), fmt.Sprintf(`'%d', '%s'`, payloadMap["id"], 
payloadMap["val"]), ) - _, err = as.DB.Exec(sqlStatement) + _, err = as.DB.ExecContext(ctx, sqlStatement) return } diff --git a/warehouse/integrations/bigquery/bigquery.go b/warehouse/integrations/bigquery/bigquery.go index 31b6a08032..8267a9e0fb 100644 --- a/warehouse/integrations/bigquery/bigquery.go +++ b/warehouse/integrations/bigquery/bigquery.go @@ -26,13 +26,12 @@ import ( ) type BigQuery struct { - backgroundContext context.Context - db *bigquery.Client - namespace string - warehouse model.Warehouse - projectID string - uploader warehouseutils.Uploader - Logger logger.Logger + db *bigquery.Client + namespace string + warehouse model.Warehouse + projectID string + uploader warehouseutils.Uploader + Logger logger.Logger setUsersLoadPartitionFirstEventFilter bool customPartitionsEnabled bool @@ -170,13 +169,13 @@ func getTableSchema(columns model.TableSchema) []*bigquery.FieldSchema { return schema } -func (bq *BigQuery) DeleteTable(tableName string) (err error) { +func (bq *BigQuery) DeleteTable(ctx context.Context, tableName string) (err error) { tableRef := bq.db.Dataset(bq.namespace).Table(tableName) - err = tableRef.Delete(bq.backgroundContext) + err = tableRef.Delete(ctx) return } -func (bq *BigQuery) CreateTable(tableName string, columnMap model.TableSchema) error { +func (bq *BigQuery) CreateTable(ctx context.Context, tableName string, columnMap model.TableSchema) error { bq.Logger.Infof("BQ: Creating table: %s in bigquery dataset: %s in project: %s", tableName, bq.namespace, bq.projectID) sampleSchema := getTableSchema(columnMap) metaData := &bigquery.TableMetadata{ @@ -184,31 +183,31 @@ func (bq *BigQuery) CreateTable(tableName string, columnMap model.TableSchema) e TimePartitioning: &bigquery.TimePartitioning{}, } tableRef := bq.db.Dataset(bq.namespace).Table(tableName) - err := tableRef.Create(bq.backgroundContext, metaData) + err := tableRef.Create(ctx, metaData) if !checkAndIgnoreAlreadyExistError(err) { return fmt.Errorf("create table: %w", err) } if !bq.dedupEnabled() { - if err = bq.createTableView(tableName, columnMap); err != nil { + if err = bq.createTableView(ctx, tableName, columnMap); err != nil { return fmt.Errorf("create view: %w", err) } } return nil } -func (bq *BigQuery) DropTable(tableName string) (err error) { - err = bq.DeleteTable(tableName) +func (bq *BigQuery) DropTable(ctx context.Context, tableName string) (err error) { + err = bq.DeleteTable(ctx, tableName) if err != nil { return } if !bq.dedupEnabled() { - err = bq.DeleteTable(tableName + "_view") + err = bq.DeleteTable(ctx, tableName+"_view") } return } -func (bq *BigQuery) createTableView(tableName string, columnMap model.TableSchema) (err error) { +func (bq *BigQuery) createTableView(ctx context.Context, tableName string, columnMap model.TableSchema) (err error) { partitionKey := "id" if column, ok := partitionKeyMap[tableName]; ok { partitionKey = column @@ -229,13 +228,13 @@ func (bq *BigQuery) createTableView(tableName string, columnMap model.TableSchem ViewQuery: viewQuery, } tableRef := bq.db.Dataset(bq.namespace).Table(tableName + "_view") - err = tableRef.Create(bq.backgroundContext, metaData) + err = tableRef.Create(ctx, metaData) return } -func (bq *BigQuery) schemaExists(_, _ string) (exists bool, err error) { +func (bq *BigQuery) schemaExists(ctx context.Context, _, _ string) (exists bool, err error) { ds := bq.db.Dataset(bq.namespace) - _, err = ds.Metadata(bq.backgroundContext) + _, err = ds.Metadata(ctx) if err != nil { if e, ok := err.(*googleapi.Error); ok && e.Code == 404 { 
bq.Logger.Debugf("BQ: Dataset %s not found", bq.namespace) @@ -246,7 +245,7 @@ func (bq *BigQuery) schemaExists(_, _ string) (exists bool, err error) { return true, nil } -func (bq *BigQuery) CreateSchema() (err error) { +func (bq *BigQuery) CreateSchema(ctx context.Context) (err error) { bq.Logger.Infof("BQ: Creating bigquery dataset: %s in project: %s", bq.namespace, bq.projectID) location := strings.TrimSpace(warehouseutils.GetConfigValue(GCPLocation, bq.warehouse)) if location == "" { @@ -254,7 +253,7 @@ func (bq *BigQuery) CreateSchema() (err error) { } var schemaExists bool - schemaExists, err = bq.schemaExists(bq.namespace, location) + schemaExists, err = bq.schemaExists(ctx, bq.namespace, location) if err != nil { bq.Logger.Errorf("BQ: Error checking if schema: %s exists: %v", bq.namespace, err) return err @@ -269,7 +268,7 @@ func (bq *BigQuery) CreateSchema() (err error) { Location: location, } bq.Logger.Infof("BQ: Creating schema: %s ...", bq.namespace) - err = ds.Create(bq.backgroundContext, meta) + err = ds.Create(ctx, meta) if err != nil { if e, ok := err.(*googleapi.Error); ok && e.Code == 409 { bq.Logger.Infof("BQ: Create schema %s failed as schema already exists", bq.namespace) @@ -293,15 +292,15 @@ func checkAndIgnoreAlreadyExistError(err error) bool { return true } -func (bq *BigQuery) dropStagingTable(stagingTableName string) { +func (bq *BigQuery) dropStagingTable(ctx context.Context, stagingTableName string) { bq.Logger.Infof("BQ: Deleting table: %s in bigquery dataset: %s in project: %s", stagingTableName, bq.namespace, bq.projectID) - err := bq.DeleteTable(stagingTableName) + err := bq.DeleteTable(ctx, stagingTableName) if err != nil { bq.Logger.Errorf("BQ: Error dropping staging table %s in bigquery dataset %s in project %s : %v", stagingTableName, bq.namespace, bq.projectID, err) } } -func (bq *BigQuery) DeleteBy(tableNames []string, params warehouseutils.DeleteByParams) error { +func (bq *BigQuery) DeleteBy(ctx context.Context, tableNames []string, params warehouseutils.DeleteByParams) error { for _, tb := range tableNames { bq.Logger.Infof("BQ: Cleaning up the following tables in bigquery for BQ:%s", tb) tableName := fmt.Sprintf("`%s`.`%s`", bq.namespace, tb) @@ -328,12 +327,12 @@ func (bq *BigQuery) DeleteBy(tableNames []string, params warehouseutils.DeleteBy {Name: "starttime", Value: params.StartTime}, } if bq.enableDeleteByJobs { - job, err := bq.getMiddleware().Run(bq.backgroundContext, query) + job, err := bq.getMiddleware().Run(ctx, query) if err != nil { bq.Logger.Errorf("BQ: Error initiating load job: %v\n", err) return err } - status, err := job.Wait(bq.backgroundContext) + status, err := job.Wait(ctx) if err != nil { bq.Logger.Errorf("BQ: Error running job: %v\n", err) return err @@ -350,17 +349,17 @@ func partitionedTable(tableName, partitionDate string) string { return fmt.Sprintf(`%s$%v`, tableName, strings.ReplaceAll(partitionDate, "-", "")) } -func (bq *BigQuery) loadTable(tableName string, _, getLoadFileLocFromTableUploads, skipTempTableDelete bool) (stagingLoadTable StagingLoadTable, err error) { +func (bq *BigQuery) loadTable(ctx context.Context, tableName string, _, getLoadFileLocFromTableUploads, skipTempTableDelete bool) (stagingLoadTable StagingLoadTable, err error) { bq.Logger.Infof("BQ: Starting load for table:%s\n", tableName) var loadFiles []warehouseutils.LoadFile if getLoadFileLocFromTableUploads { - loadFile, err := bq.uploader.GetSingleLoadFile(tableName) + loadFile, err := bq.uploader.GetSingleLoadFile(ctx, tableName) if err != nil 
{ return stagingLoadTable, err } loadFiles = append(loadFiles, loadFile) } else { - loadFiles = bq.uploader.GetLoadFilesMetadata(warehouseutils.GetLoadFilesOptions{Table: tableName}) + loadFiles = bq.uploader.GetLoadFilesMetadata(ctx, warehouseutils.GetLoadFilesOptions{Table: tableName}) } gcsLocations := warehouseutils.GetGCSLocations(loadFiles, warehouseutils.GCSLocationOptions{}) bq.Logger.Infof("BQ: Loading data into table: %s in bigquery dataset: %s in project: %s", tableName, bq.namespace, bq.projectID) @@ -381,12 +380,12 @@ func (bq *BigQuery) loadTable(tableName string, _, getLoadFileLocFromTableUpload loader := bq.db.Dataset(bq.namespace).Table(outputTable).LoaderFrom(gcsRef) - job, err := loader.Run(bq.backgroundContext) + job, err := loader.Run(ctx) if err != nil { bq.Logger.Errorf("BQ: Error initiating append load job: %v\n", err) return } - status, err := job.Wait(bq.backgroundContext) + status, err := job.Wait(ctx) if err != nil { bq.Logger.Errorf("BQ: Error running append load job: %v\n", err) return @@ -409,19 +408,19 @@ func (bq *BigQuery) loadTable(tableName string, _, getLoadFileLocFromTableUpload TimePartitioning: &bigquery.TimePartitioning{}, } tableRef := bq.db.Dataset(bq.namespace).Table(stagingTableName) - err = tableRef.Create(bq.backgroundContext, metaData) + err = tableRef.Create(ctx, metaData) if err != nil { bq.Logger.Infof("BQ: Error creating temporary staging table %s", stagingTableName) return } loader := bq.db.Dataset(bq.namespace).Table(stagingTableName).LoaderFrom(gcsRef) - job, err := loader.Run(bq.backgroundContext) + job, err := loader.Run(ctx) if err != nil { bq.Logger.Errorf("BQ: Error initiating staging table load job: %v\n", err) return } - status, err := job.Wait(bq.backgroundContext) + status, err := job.Wait(ctx) if err != nil { bq.Logger.Errorf("BQ: Error running staging table load job: %v\n", err) return @@ -432,7 +431,7 @@ func (bq *BigQuery) loadTable(tableName string, _, getLoadFileLocFromTableUpload } if !skipTempTableDelete { - defer bq.dropStagingTable(stagingTableName) + defer bq.dropStagingTable(ctx, stagingTableName) } primaryKey := "id" @@ -495,12 +494,12 @@ func (bq *BigQuery) loadTable(tableName string, _, getLoadFileLocFromTableUpload bq.Logger.Infof("BQ: Dedup records for table:%s using staging table: %s\n", tableName, sqlStatement) q := bq.db.Query(sqlStatement) - job, err = bq.getMiddleware().Run(bq.backgroundContext, q) + job, err = bq.getMiddleware().Run(ctx, q) if err != nil { bq.Logger.Errorf("BQ: Error initiating merge load job: %v\n", err) return } - status, err = job.Wait(bq.backgroundContext) + status, err = job.Wait(ctx) if err != nil { bq.Logger.Errorf("BQ: Error running merge load job: %v\n", err) return @@ -521,10 +520,10 @@ func (bq *BigQuery) loadTable(tableName string, _, getLoadFileLocFromTableUpload return } -func (bq *BigQuery) LoadUserTables(context.Context) (errorMap map[string]error) { +func (bq *BigQuery) LoadUserTables(ctx context.Context) (errorMap map[string]error) { errorMap = map[string]error{warehouseutils.IdentifiesTable: nil} bq.Logger.Infof("BQ: Starting load for identifies and users tables\n") - identifyLoadTable, err := bq.loadTable(warehouseutils.IdentifiesTable, true, false, true) + identifyLoadTable, err := bq.loadTable(ctx, warehouseutils.IdentifiesTable, true, false, true) if err != nil { errorMap[warehouseutils.IdentifiesTable] = err return @@ -572,10 +571,10 @@ func (bq *BigQuery) LoadUserTables(context.Context) (errorMap map[string]error) bqTable := func(name string) string { return 
fmt.Sprintf("`%s`.`%s`", bq.namespace, name) } bqUsersView := bqTable(warehouseutils.UsersView) - viewExists, _ := bq.tableExists(warehouseutils.UsersView) + viewExists, _ := bq.tableExists(ctx, warehouseutils.UsersView) if !viewExists { bq.Logger.Infof("BQ: Creating view: %s in bigquery dataset: %s in project: %s", warehouseutils.UsersView, bq.namespace, bq.projectID) - _ = bq.createTableView(warehouseutils.UsersTable, userColMap) + _ = bq.createTableView(ctx, warehouseutils.UsersTable, userColMap) } bqIdentifiesTable := bqTable(warehouseutils.IdentifiesTable) @@ -609,13 +608,13 @@ func (bq *BigQuery) LoadUserTables(context.Context) (errorMap map[string]error) query.QueryConfig.Dst = bq.db.Dataset(bq.namespace).Table(partitionedUsersTable) query.WriteDisposition = bigquery.WriteAppend - job, err := bq.getMiddleware().Run(bq.backgroundContext, query) + job, err := bq.getMiddleware().Run(ctx, query) if err != nil { bq.Logger.Errorf("BQ: Error initiating load job: %v\n", err) errorMap[warehouseutils.UsersTable] = err return } - status, err := job.Wait(bq.backgroundContext) + status, err := job.Wait(ctx) if err != nil { bq.Logger.Errorf("BQ: Error running load job: %v\n", err) errorMap[warehouseutils.UsersTable] = fmt.Errorf(`append: %v`, err.Error()) @@ -634,14 +633,14 @@ func (bq *BigQuery) LoadUserTables(context.Context) (errorMap map[string]error) query := bq.db.Query(sqlStatement) query.QueryConfig.Dst = bq.db.Dataset(bq.namespace).Table(stagingTableName) query.WriteDisposition = bigquery.WriteAppend - job, err := bq.getMiddleware().Run(bq.backgroundContext, query) + job, err := bq.getMiddleware().Run(ctx, query) if err != nil { bq.Logger.Errorf("BQ: Error initiating staging table for users : %v\n", err) errorMap[warehouseutils.UsersTable] = err return } - status, err := job.Wait(bq.backgroundContext) + status, err := job.Wait(ctx) if err != nil { bq.Logger.Errorf("BQ: Error initiating staging table for users %v\n", err) errorMap[warehouseutils.UsersTable] = fmt.Errorf(`merge: %v`, err.Error()) @@ -652,8 +651,8 @@ func (bq *BigQuery) LoadUserTables(context.Context) (errorMap map[string]error) errorMap[warehouseutils.UsersTable] = status.Err() return } - defer bq.dropStagingTable(identifyLoadTable.stagingTableName) - defer bq.dropStagingTable(stagingTableName) + defer bq.dropStagingTable(ctx, identifyLoadTable.stagingTableName) + defer bq.dropStagingTable(ctx, stagingTableName) primaryKey := "ID" columnNames := append([]string{"ID"}, userColNames...) 
@@ -682,13 +681,13 @@ func (bq *BigQuery) LoadUserTables(context.Context) (errorMap map[string]error) bq.Logger.Infof(`BQ: Loading data into users table: %v`, sqlStatement) // partitionedUsersTable := partitionedTable(warehouseutils.UsersTable, partitionDate) q := bq.db.Query(sqlStatement) - job, err = bq.getMiddleware().Run(bq.backgroundContext, q) + job, err = bq.getMiddleware().Run(ctx, q) if err != nil { bq.Logger.Errorf("BQ: Error initiating merge load job: %v\n", err) errorMap[warehouseutils.UsersTable] = err return } - status, err = job.Wait(bq.backgroundContext) + status, err = job.Wait(ctx) if err != nil { bq.Logger.Errorf("BQ: Error running merge load job: %v\n", err) errorMap[warehouseutils.UsersTable] = fmt.Errorf(`merge: %v`, err.Error()) @@ -728,10 +727,9 @@ func Connect(context context.Context, cred *BQCredentials) (*bigquery.Client, er return client, err } -func (bq *BigQuery) connect(cred BQCredentials) (*bigquery.Client, error) { +func (bq *BigQuery) connect(ctx context.Context, cred BQCredentials) (*bigquery.Client, error) { bq.Logger.Infof("BQ: Connecting to BigQuery in project: %s", cred.ProjectID) - bq.backgroundContext = context.Background() - client, err := Connect(bq.backgroundContext, &cred) + client, err := Connect(ctx, &cred) return client, err } @@ -739,14 +737,14 @@ func (bq *BigQuery) dedupEnabled() bool { return bq.isDedupEnabled || bq.isUsersTableDedupEnabled } -func (bq *BigQuery) CrashRecover() { +func (bq *BigQuery) CrashRecover(ctx context.Context) { if !bq.dedupEnabled() { return } - bq.dropDanglingStagingTables() + bq.dropDanglingStagingTables(ctx) } -func (bq *BigQuery) dropDanglingStagingTables() bool { +func (bq *BigQuery) dropDanglingStagingTables(ctx context.Context) bool { sqlStatement := fmt.Sprintf(` SELECT table_name @@ -760,7 +758,7 @@ func (bq *BigQuery) dropDanglingStagingTables() bool { fmt.Sprintf(`%s%%`, warehouseutils.StagingTablePrefix(provider)), ) query := bq.db.Query(sqlStatement) - it, err := bq.getMiddleware().Read(bq.backgroundContext, query) + it, err := bq.getMiddleware().Read(ctx, query) if err != nil { bq.Logger.Errorf("WH: BQ: Error dropping dangling staging tables in BQ: %v\nQuery: %s\n", err, sqlStatement) return false @@ -784,7 +782,7 @@ func (bq *BigQuery) dropDanglingStagingTables() bool { bq.Logger.Infof("WH: PG: Dropping dangling staging tables: %+v %+v\n", len(stagingTableNames), stagingTableNames) delSuccess := true for _, stagingTableName := range stagingTableNames { - err := bq.DeleteTable(stagingTableName) + err := bq.DeleteTable(ctx, stagingTableName) if err != nil { bq.Logger.Errorf("WH: BQ: Error dropping dangling staging table: %s in BQ: %v", stagingTableName, err) delSuccess = false @@ -793,13 +791,13 @@ func (bq *BigQuery) dropDanglingStagingTables() bool { return delSuccess } -func (bq *BigQuery) IsEmpty(warehouse model.Warehouse) (empty bool, err error) { +func (bq *BigQuery) IsEmpty(ctx context.Context, warehouse model.Warehouse) (empty bool, err error) { empty = true bq.warehouse = warehouse bq.namespace = warehouse.Namespace bq.projectID = strings.TrimSpace(warehouseutils.GetConfigValue(GCPProjectID, bq.warehouse)) bq.Logger.Infof("BQ: Connecting to BigQuery in project: %s", bq.projectID) - bq.db, err = bq.connect(BQCredentials{ + bq.db, err = bq.connect(ctx, BQCredentials{ ProjectID: bq.projectID, Credentials: warehouseutils.GetConfigValue(GCPCredentials, bq.warehouse), }) @@ -811,14 +809,14 @@ func (bq *BigQuery) IsEmpty(warehouse model.Warehouse) (empty bool, err error) { tables := 
[]string{"tracks", "pages", "screens", "identifies", "aliases"} for _, tableName := range tables { var exists bool - exists, err = bq.tableExists(tableName) + exists, err = bq.tableExists(ctx, tableName) if err != nil { return } if !exists { continue } - count, err := bq.GetTotalCountInTable(bq.backgroundContext, tableName) + count, err := bq.GetTotalCountInTable(ctx, tableName) if err != nil { return empty, err } @@ -830,14 +828,13 @@ func (bq *BigQuery) IsEmpty(warehouse model.Warehouse) (empty bool, err error) { return } -func (bq *BigQuery) Setup(warehouse model.Warehouse, uploader warehouseutils.Uploader) (err error) { +func (bq *BigQuery) Setup(ctx context.Context, warehouse model.Warehouse, uploader warehouseutils.Uploader) (err error) { bq.warehouse = warehouse bq.namespace = warehouse.Namespace bq.uploader = uploader bq.projectID = strings.TrimSpace(warehouseutils.GetConfigValue(GCPProjectID, bq.warehouse)) - bq.backgroundContext = context.Background() - bq.db, err = bq.connect(BQCredentials{ + bq.db, err = bq.connect(ctx, BQCredentials{ ProjectID: bq.projectID, Credentials: warehouseutils.GetConfigValue(GCPCredentials, bq.warehouse), }) @@ -848,7 +845,7 @@ func (*BigQuery) TestConnection(context.Context, model.Warehouse) (err error) { return nil } -func (bq *BigQuery) LoadTable(_ context.Context, tableName string) error { +func (bq *BigQuery) LoadTable(ctx context.Context, tableName string) error { var getLoadFileLocFromTableUploads bool switch tableName { case warehouseutils.IdentityMappingsTable, warehouseutils.IdentityMergeRulesTable: @@ -856,14 +853,14 @@ func (bq *BigQuery) LoadTable(_ context.Context, tableName string) error { default: getLoadFileLocFromTableUploads = false } - _, err := bq.loadTable(tableName, false, getLoadFileLocFromTableUploads, false) + _, err := bq.loadTable(ctx, tableName, false, getLoadFileLocFromTableUploads, false) return err } -func (bq *BigQuery) AddColumns(tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) { +func (bq *BigQuery) AddColumns(ctx context.Context, tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) { bq.Logger.Infof("BQ: Adding columns for destinationID: %s, tableName: %s, dataset: %s, project: %s", bq.warehouse.Destination.ID, tableName, bq.namespace, bq.projectID) tableRef := bq.db.Dataset(bq.namespace).Table(tableName) - meta, err := tableRef.Metadata(bq.backgroundContext) + meta, err := tableRef.Metadata(ctx) if err != nil { return } @@ -878,7 +875,7 @@ func (bq *BigQuery) AddColumns(tableName string, columnsInfo []warehouseutils.Co tableMetadataToUpdate := bigquery.TableMetadataToUpdate{ Schema: newSchema, } - _, err = tableRef.Update(bq.backgroundContext, tableMetadataToUpdate, meta.ETag) + _, err = tableRef.Update(ctx, tableMetadataToUpdate, meta.ETag) // Handle error in case of single column if len(columnsInfo) == 1 { @@ -892,12 +889,12 @@ func (bq *BigQuery) AddColumns(tableName string, columnsInfo []warehouseutils.Co return } -func (*BigQuery) AlterColumn(_, _, _ string) (model.AlterTableResponse, error) { +func (*BigQuery) AlterColumn(context.Context, string, string, string) (model.AlterTableResponse, error) { return model.AlterTableResponse{}, nil } // FetchSchema queries bigquery and returns the schema associated with provided namespace -func (bq *BigQuery) FetchSchema() (model.Schema, model.Schema, error) { +func (bq *BigQuery) FetchSchema(ctx context.Context) (model.Schema, model.Schema, error) { schema := make(model.Schema) unrecognizedSchema := make(model.Schema) @@ -920,7 
+917,7 @@ func (bq *BigQuery) FetchSchema() (model.Schema, model.Schema, error) { ) query := bq.db.Query(sqlStatement) - it, err := bq.getMiddleware().Read(bq.backgroundContext, query) + it, err := bq.getMiddleware().Read(ctx, query) if err != nil { if e, ok := err.(*googleapi.Error); ok && e.Code == 404 { // if dataset resource is not found, return empty schema @@ -968,24 +965,24 @@ func (bq *BigQuery) FetchSchema() (model.Schema, model.Schema, error) { return schema, unrecognizedSchema, nil } -func (bq *BigQuery) Cleanup() { +func (bq *BigQuery) Cleanup(context.Context) { if bq.db != nil { _ = bq.db.Close() } } -func (bq *BigQuery) LoadIdentityMergeRulesTable() (err error) { +func (bq *BigQuery) LoadIdentityMergeRulesTable(ctx context.Context) (err error) { identityMergeRulesTable := warehouseutils.IdentityMergeRulesWarehouseTableName(warehouseutils.BQ) - return bq.LoadTable(context.TODO(), identityMergeRulesTable) + return bq.LoadTable(ctx, identityMergeRulesTable) } -func (bq *BigQuery) LoadIdentityMappingsTable() (err error) { +func (bq *BigQuery) LoadIdentityMappingsTable(ctx context.Context) (err error) { identityMappingsTable := warehouseutils.IdentityMappingsWarehouseTableName(warehouseutils.BQ) - return bq.LoadTable(context.TODO(), identityMappingsTable) + return bq.LoadTable(ctx, identityMappingsTable) } -func (bq *BigQuery) tableExists(tableName string) (exists bool, err error) { - _, err = bq.db.Dataset(bq.namespace).Table(tableName).Metadata(context.Background()) +func (bq *BigQuery) tableExists(ctx context.Context, tableName string) (exists bool, err error) { + _, err = bq.db.Dataset(bq.namespace).Table(tableName).Metadata(ctx) if err == nil { return true, nil } @@ -997,8 +994,8 @@ func (bq *BigQuery) tableExists(tableName string) (exists bool, err error) { return false, err } -func (bq *BigQuery) columnExists(columnName, tableName string) (exists bool, err error) { - tableMetadata, err := bq.db.Dataset(bq.namespace).Table(tableName).Metadata(context.Background()) +func (bq *BigQuery) columnExists(ctx context.Context, columnName, tableName string) (exists bool, err error) { + tableMetadata, err := bq.db.Dataset(bq.namespace).Table(tableName).Metadata(ctx) if err != nil { return false, err } @@ -1020,25 +1017,25 @@ type identityRules struct { MergeProperty2Value string `json:"merge_property_2_value"` } -func (bq *BigQuery) DownloadIdentityRules(gzWriter *misc.GZipWriter) (err error) { +func (bq *BigQuery) DownloadIdentityRules(ctx context.Context, gzWriter *misc.GZipWriter) (err error) { getFromTable := func(tableName string) (err error) { var exists bool - exists, err = bq.tableExists(tableName) + exists, err = bq.tableExists(ctx, tableName) if err != nil || !exists { return } - tableMetadata, err := bq.db.Dataset(bq.namespace).Table(tableName).Metadata(context.Background()) + tableMetadata, err := bq.db.Dataset(bq.namespace).Table(tableName).Metadata(ctx) if err != nil { return err } totalRows := int64(tableMetadata.NumRows) // check if table in warehouse has anonymous_id and user_id and construct accordingly - hasAnonymousID, err := bq.columnExists("anonymous_id", tableName) + hasAnonymousID, err := bq.columnExists(ctx, "anonymous_id", tableName) if err != nil { return } - hasUserID, err := bq.columnExists("user_id", tableName) + hasUserID, err := bq.columnExists(ctx, "user_id", tableName) if err != nil { return } @@ -1060,7 +1057,6 @@ func (bq *BigQuery) DownloadIdentityRules(gzWriter *misc.GZipWriter) (err error) for { sqlStatement := fmt.Sprintf(`SELECT DISTINCT %[1]s 
FROM %[2]s.%[3]s LIMIT %[4]d OFFSET %[5]d`, toSelectFields, bq.namespace, tableName, batchSize, offset) bq.Logger.Infof("BQ: Downloading distinct combinations of anonymous_id, user_id: %s, totalRows: %d", sqlStatement, totalRows) - ctx := context.Background() query := bq.db.Query(sqlStatement) job, err := bq.getMiddleware().Run(ctx, query) if err != nil { @@ -1165,11 +1161,11 @@ func (bq *BigQuery) GetTotalCountInTable(ctx context.Context, tableName string) return total, nil } -func (bq *BigQuery) Connect(warehouse model.Warehouse) (client.Client, error) { +func (bq *BigQuery) Connect(ctx context.Context, warehouse model.Warehouse) (client.Client, error) { bq.warehouse = warehouse bq.namespace = warehouse.Namespace bq.projectID = strings.TrimSpace(warehouseutils.GetConfigValue(GCPProjectID, bq.warehouse)) - dbClient, err := bq.connect(BQCredentials{ + dbClient, err := bq.connect(ctx, BQCredentials{ ProjectID: bq.projectID, Credentials: warehouseutils.GetConfigValue(GCPCredentials, bq.warehouse), }) @@ -1180,7 +1176,7 @@ func (bq *BigQuery) Connect(warehouse model.Warehouse) (client.Client, error) { return client.Client{Type: client.BQClient, BQ: dbClient}, err } -func (bq *BigQuery) LoadTestTable(location, tableName string, _ map[string]interface{}, _ string) (err error) { +func (bq *BigQuery) LoadTestTable(ctx context.Context, location, tableName string, _ map[string]interface{}, _ string) (err error) { gcsLocations := warehouseutils.GetGCSLocation(location, warehouseutils.GCSLocationOptions{}) gcsRef := bigquery.NewGCSReference([]string{gcsLocations}...) gcsRef.SourceFormat = bigquery.JSON @@ -1190,11 +1186,11 @@ func (bq *BigQuery) LoadTestTable(location, tableName string, _ map[string]inter outputTable := partitionedTable(tableName, time.Now().Format("2006-01-02")) loader := bq.db.Dataset(bq.namespace).Table(outputTable).LoaderFrom(gcsRef) - job, err := loader.Run(bq.backgroundContext) + job, err := loader.Run(ctx) if err != nil { return } - status, err := job.Wait(bq.backgroundContext) + status, err := job.Wait(ctx) if err != nil { return } diff --git a/warehouse/integrations/bigquery/bigquery_test.go b/warehouse/integrations/bigquery/bigquery_test.go index 4f8bb51334..85e934fa57 100644 --- a/warehouse/integrations/bigquery/bigquery_test.go +++ b/warehouse/integrations/bigquery/bigquery_test.go @@ -130,8 +130,9 @@ func TestIntegration(t *testing.T) { health.WaitUntilReady(ctx, t, serviceHealthEndpoint, time.Minute, time.Second, "serviceHealthEndpoint") t.Run("Event flow", func(t *testing.T) { + ctx := context.Background() db, err := bigquery.NewClient( - context.TODO(), + ctx, bqTestCredentials.ProjectID, option.WithCredentialsJSON([]byte(bqTestCredentials.Credentials)), ) require.NoError(t, err) @@ -141,7 +142,7 @@ func TestIntegration(t *testing.T) { t.Cleanup(func() { for _, dataset := range []string{namespace, sourcesNamespace} { require.NoError(t, testhelper.WithConstantRetries(func() error { - return db.Dataset(dataset).DeleteWithContents(context.TODO()) + return db.Dataset(dataset).DeleteWithContents(ctx) })) } }) @@ -181,7 +182,7 @@ func TestIntegration(t *testing.T) { prerequisite: func(t testing.TB) { t.Helper() - _ = db.Dataset(namespace).DeleteWithContents(context.TODO()) + _ = db.Dataset(namespace).DeleteWithContents(ctx) }, stagingFilePrefix: "testdata/upload-job-merge-mode", }, @@ -206,7 +207,7 @@ func TestIntegration(t *testing.T) { prerequisite: func(t testing.TB) { t.Helper() - _ = db.Dataset(namespace).DeleteWithContents(context.TODO()) + _ = 
db.Dataset(namespace).DeleteWithContents(ctx) }, stagingFilePrefix: "testdata/sources-job", }, @@ -227,7 +228,7 @@ func TestIntegration(t *testing.T) { prerequisite: func(t testing.TB) { t.Helper() - _ = db.Dataset(namespace).DeleteWithContents(context.TODO()) + _ = db.Dataset(namespace).DeleteWithContents(ctx) }, stagingFilePrefix: "testdata/upload-job-append-mode", }, @@ -249,7 +250,7 @@ func TestIntegration(t *testing.T) { prerequisite: func(t testing.TB) { t.Helper() - _ = db.Dataset(namespace).DeleteWithContents(context.TODO()) + _ = db.Dataset(namespace).DeleteWithContents(ctx) err = db.Dataset(namespace).Create(context.Background(), &bigquery.DatasetMetadata{ Location: "US", diff --git a/warehouse/integrations/bigquery/middleware/middleware_test.go b/warehouse/integrations/bigquery/middleware/middleware_test.go index 4fcd08e137..dc2d942736 100644 --- a/warehouse/integrations/bigquery/middleware/middleware_test.go +++ b/warehouse/integrations/bigquery/middleware/middleware_test.go @@ -23,7 +23,9 @@ func TestQueryWrapper(t *testing.T) { bqTestCredentials, err := bqHelper.GetBQTestCredentials() require.NoError(t, err) - db, err := bigquery.Connect(context.TODO(), &bigquery.BQCredentials{ + ctx := context.Background() + + db, err := bigquery.Connect(ctx, &bigquery.BQCredentials{ ProjectID: bqTestCredentials.ProjectID, Credentials: bqTestCredentials.Credentials, }) @@ -47,7 +49,6 @@ func TestQueryWrapper(t *testing.T) { } var ( - ctx = context.Background() queryThreshold = 300 * time.Second keysAndValues = []any{"key1", "value2", "key2", "value2"} ) diff --git a/warehouse/integrations/clickhouse/clickhouse.go b/warehouse/integrations/clickhouse/clickhouse.go index 0a62266956..8a73b48179 100644 --- a/warehouse/integrations/clickhouse/clickhouse.go +++ b/warehouse/integrations/clickhouse/clickhouse.go @@ -358,7 +358,7 @@ func (ch *Clickhouse) getClickHouseColumnTypeForSpecificTable(tableName, columnN return getClickhouseColumnTypeForSpecificColumn(columnName, columnType, true) } -func (*Clickhouse) DeleteBy([]string, warehouseutils.DeleteByParams) error { +func (*Clickhouse) DeleteBy(context.Context, []string, warehouseutils.DeleteByParams) error { return fmt.Errorf(warehouseutils.NotImplementedErrorCode) } @@ -538,7 +538,7 @@ func (ch *Clickhouse) loadByCopyCommand(ctx context.Context, tableName string, t return fmt.Sprintf(`%s %s`, name, rudderDataTypesMapToClickHouse[tableSchemaInUpload[name]]) }, ",") - csvObjectLocation, err := ch.Uploader.GetSampleLoadFileLocation(tableName) + csvObjectLocation, err := ch.Uploader.GetSampleLoadFileLocation(ctx, tableName) if err != nil { return fmt.Errorf("sample load file location with error: %w", err) } @@ -739,10 +739,10 @@ func (ch *Clickhouse) loadTablesFromFilesNamesWithRetry(ctx context.Context, tab return } -func (ch *Clickhouse) schemaExists(schemaName string) (exists bool, err error) { +func (ch *Clickhouse) schemaExists(ctx context.Context, schemaName string) (exists bool, err error) { var count int64 sqlStatement := "SELECT count(*) FROM system.databases WHERE name = ?" 
-	err = ch.DB.QueryRow(sqlStatement, schemaName).Scan(&count)
+	err = ch.DB.QueryRowContext(ctx, sqlStatement, schemaName).Scan(&count)
 	// ignore err if no results for query
 	if err == sql.ErrNoRows {
 		err = nil
@@ -752,9 +752,9 @@ func (ch *Clickhouse) schemaExists(schemaName string) (exists bool, err error) {
 }
 
 // createSchema creates a database in clickhouse
-func (ch *Clickhouse) createSchema() (err error) {
+func (ch *Clickhouse) createSchema(ctx context.Context) (err error) {
 	var schemaExists bool
-	schemaExists, err = ch.schemaExists(ch.Namespace)
+	schemaExists, err = ch.schemaExists(ctx, ch.Namespace)
 	if err != nil {
 		ch.Logger.Errorf("CH: Error checking if database: %s exists: %v", ch.Namespace, err)
 		return err
@@ -767,7 +767,7 @@ func (ch *Clickhouse) createSchema() (err error) {
 	if err != nil {
 		return err
 	}
-	defer dbHandle.Close()
+	defer func() { _ = dbHandle.Close() }()
 	cluster := warehouseutils.GetConfigValue(Cluster, ch.Warehouse)
 	clusterClause := ""
 	if len(strings.TrimSpace(cluster)) > 0 {
@@ -775,7 +775,7 @@ func (ch *Clickhouse) createSchema() (err error) {
 	}
 	sqlStatement := fmt.Sprintf(`CREATE DATABASE IF NOT EXISTS %q %s`, ch.Namespace, clusterClause)
 	ch.Logger.Infof("CH: Creating database in clickhouse for ch:%s : %v", ch.Warehouse.Destination.ID, sqlStatement)
-	_, err = dbHandle.Exec(sqlStatement)
+	_, err = dbHandle.ExecContext(ctx, sqlStatement)
 	return
 }
 
@@ -784,7 +784,7 @@ createUsersTable creates a user's table with engine AggregatingMergeTree, this
 lets us choose aggregation logic before merging records with same user id.
 current behaviour is to replace user properties with the latest non-null values
 */
-func (ch *Clickhouse) createUsersTable(name string, columns model.TableSchema) (err error) {
+func (ch *Clickhouse) createUsersTable(ctx context.Context, name string, columns model.TableSchema) (err error) {
 	sortKeyFields := []string{"id"}
 	notNullableColumns := []string{"received_at", "id"}
 	clusterClause := ""
@@ -798,7 +798,7 @@ func (ch *Clickhouse) createUsersTable(name string, columns model.TableSchema) (
 	}
 	sqlStatement := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %q.%q %s ( %v ) ENGINE = %s(%s) ORDER BY %s PARTITION BY toDate(%s)`, ch.Namespace, name, clusterClause, ch.ColumnsWithDataTypes(name, columns, notNullableColumns), engine, engineOptions, getSortKeyTuple(sortKeyFields), partitionField)
 	ch.Logger.Infof("CH: Creating table in clickhouse for ch:%s : %v", ch.Warehouse.Destination.ID, sqlStatement)
-	_, err = ch.DB.Exec(sqlStatement)
+	_, err = ch.DB.ExecContext(ctx, sqlStatement)
 	return
 }
 
@@ -817,7 +817,7 @@ func getSortKeyTuple(sortKeyFields []string) string {
 
 // CreateTable creates table with engine ReplacingMergeTree(), this is used for dedupe event data and replace it will the latest data if duplicate data found. This logic is handled by clickhouse
 // The engine differs from MergeTree in that it removes duplicate entries with the same sorting key value.
-func (ch *Clickhouse) CreateTable(tableName string, columns model.TableSchema) (err error) {
+func (ch *Clickhouse) CreateTable(ctx context.Context, tableName string, columns model.TableSchema) (err error) {
 	sortKeyFields := []string{"received_at", "id"}
 	if tableName == warehouseutils.DiscardsTable {
 		sortKeyFields = []string{"received_at"}
@@ -827,7 +827,7 @@ func (ch *Clickhouse) CreateTable(tableName string, columns model.TableSchema) (
 	}
 	var sqlStatement string
 	if tableName == warehouseutils.UsersTable {
-		return ch.createUsersTable(tableName, columns)
+		return ch.createUsersTable(ctx, tableName, columns)
 	}
 	clusterClause := ""
 	engine := "ReplacingMergeTree"
@@ -851,22 +851,22 @@ func (ch *Clickhouse) CreateTable(tableName string, columns model.TableSchema) (
 
 	sqlStatement = fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %q.%q %s ( %v ) ENGINE = %s(%s) %s %s`, ch.Namespace, tableName, clusterClause, ch.ColumnsWithDataTypes(tableName, columns, sortKeyFields), engine, engineOptions, orderByClause, partitionByClause)
 	ch.Logger.Infof("CH: Creating table in clickhouse for ch:%s : %v", ch.Warehouse.Destination.ID, sqlStatement)
-	_, err = ch.DB.Exec(sqlStatement)
+	_, err = ch.DB.ExecContext(ctx, sqlStatement)
 	return
 }
 
-func (ch *Clickhouse) DropTable(tableName string) (err error) {
+func (ch *Clickhouse) DropTable(ctx context.Context, tableName string) (err error) {
 	cluster := warehouseutils.GetConfigValue(Cluster, ch.Warehouse)
 	clusterClause := ""
 	if len(strings.TrimSpace(cluster)) > 0 {
 		clusterClause = fmt.Sprintf(`ON CLUSTER %q`, cluster)
 	}
 	sqlStatement := fmt.Sprintf(`DROP TABLE %q.%q %s `, ch.Warehouse.Namespace, tableName, clusterClause)
-	_, err = ch.DB.Exec(sqlStatement)
+	_, err = ch.DB.ExecContext(ctx, sqlStatement)
 	return
 }
 
-func (ch *Clickhouse) AddColumns(tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) {
+func (ch *Clickhouse) AddColumns(ctx context.Context, tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) {
 	var (
 		query        string
 		queryBuilder strings.Builder
@@ -901,19 +901,19 @@
 	query += ";"
 	ch.Logger.Infof("CH: Adding columns for destinationID: %s, tableName: %s with query: %v", ch.Warehouse.Destination.ID, tableName, query)
-	_, err = ch.DB.Exec(query)
+	_, err = ch.DB.ExecContext(ctx, query)
 	return
 }
 
-func (ch *Clickhouse) CreateSchema() (err error) {
+func (ch *Clickhouse) CreateSchema(ctx context.Context) (err error) {
 	if len(ch.Uploader.GetSchemaInWarehouse()) > 0 {
 		return nil
 	}
-	err = ch.createSchema()
+	err = ch.createSchema(ctx)
 	return err
 }
 
-func (*Clickhouse) AlterColumn(_, _, _ string) (model.AlterTableResponse, error) {
+func (*Clickhouse) AlterColumn(context.Context, string, string, string) (model.AlterTableResponse, error) {
 	return model.AlterTableResponse{}, nil
 }
@@ -930,7 +930,7 @@ func (ch *Clickhouse) TestConnection(ctx context.Context, _ model.Warehouse) err
 	return nil
 }
 
-func (ch *Clickhouse) Setup(warehouse model.Warehouse, uploader warehouseutils.Uploader) (err error) {
+func (ch *Clickhouse) Setup(_ context.Context, warehouse model.Warehouse, uploader warehouseutils.Uploader) (err error) {
 	ch.Warehouse = warehouse
 	ch.Namespace = warehouse.Namespace
 	ch.Uploader = uploader
@@ -942,10 +942,10 @@ func (ch *Clickhouse) Setup(warehouse model.Warehouse, uploader warehouseutils.U
 	return err
 }
 
-func (*Clickhouse) CrashRecover() {}
+func (*Clickhouse) CrashRecover(context.Context) {}
 
 // FetchSchema queries clickhouse and returns the schema associated with provided namespace
-func (ch *Clickhouse) FetchSchema() (model.Schema, model.Schema, error) {
+func (ch *Clickhouse) FetchSchema(ctx context.Context) (model.Schema, model.Schema, error) {
 	schema := make(model.Schema)
 	unrecognizedSchema := make(model.Schema)
 
@@ -960,7 +960,7 @@ func (ch *Clickhouse) FetchSchema() (model.Schema, model.Schema, error) {
 			database = ?
` - rows, err := ch.DB.Query(sqlStatement, ch.Namespace) + rows, err := ch.DB.QueryContext(ctx, sqlStatement, ch.Namespace) if errors.Is(err, sql.ErrNoRows) { return schema, unrecognizedSchema, nil } @@ -1025,25 +1025,25 @@ func (ch *Clickhouse) LoadTable(ctx context.Context, tableName string) error { return err } -func (ch *Clickhouse) Cleanup() { +func (ch *Clickhouse) Cleanup(context.Context) { if ch.DB != nil { _ = ch.DB.Close() } } -func (*Clickhouse) LoadIdentityMergeRulesTable() (err error) { +func (*Clickhouse) LoadIdentityMergeRulesTable(context.Context) (err error) { return } -func (*Clickhouse) LoadIdentityMappingsTable() (err error) { +func (*Clickhouse) LoadIdentityMappingsTable(context.Context) (err error) { return } -func (*Clickhouse) DownloadIdentityRules(*misc.GZipWriter) (err error) { +func (*Clickhouse) DownloadIdentityRules(context.Context, *misc.GZipWriter) (err error) { return } -func (*Clickhouse) IsEmpty(_ model.Warehouse) (empty bool, err error) { +func (*Clickhouse) IsEmpty(context.Context, model.Warehouse) (empty bool, err error) { return } @@ -1063,7 +1063,7 @@ func (ch *Clickhouse) GetTotalCountInTable(ctx context.Context, tableName string return total, err } -func (ch *Clickhouse) Connect(warehouse model.Warehouse) (client.Client, error) { +func (ch *Clickhouse) Connect(_ context.Context, warehouse model.Warehouse) (client.Client, error) { ch.Warehouse = warehouse ch.Namespace = warehouse.Namespace ch.ObjectStorage = warehouseutils.ObjectStorageType( @@ -1086,7 +1086,7 @@ func (ch *Clickhouse) GetLogIdentifier(args ...string) string { return fmt.Sprintf("[%s][%s][%s][%s][%s]", ch.Warehouse.Type, ch.Warehouse.Source.ID, ch.Warehouse.Destination.ID, ch.Warehouse.Namespace, strings.Join(args, "][")) } -func (ch *Clickhouse) LoadTestTable(_, tableName string, payloadMap map[string]interface{}, _ string) (err error) { +func (ch *Clickhouse) LoadTestTable(ctx context.Context, _, tableName string, payloadMap map[string]interface{}, _ string) (err error) { var columns []string var recordInterface []interface{} @@ -1101,17 +1101,17 @@ func (ch *Clickhouse) LoadTestTable(_, tableName string, payloadMap map[string]i strings.Join(columns, ","), generateArgumentString(len(columns)), ) - txn, err := ch.DB.Begin() + txn, err := ch.DB.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return } - stmt, err := txn.Prepare(sqlStatement) + stmt, err := txn.PrepareContext(ctx, sqlStatement) if err != nil { return } - if _, err = stmt.Exec(recordInterface...); err != nil { + if _, err = stmt.ExecContext(ctx, recordInterface...); err != nil { return } diff --git a/warehouse/integrations/clickhouse/clickhouse_test.go b/warehouse/integrations/clickhouse/clickhouse_test.go index c7f63e218d..53a7718dd8 100644 --- a/warehouse/integrations/clickhouse/clickhouse_test.go +++ b/warehouse/integrations/clickhouse/clickhouse_test.go @@ -370,23 +370,25 @@ type mockUploader struct { metadata []warehouseutils.LoadFile } -func (*mockUploader) GetSchemaInWarehouse() model.Schema { return model.Schema{} } -func (*mockUploader) GetLocalSchema() (model.Schema, error) { return model.Schema{}, nil } -func (*mockUploader) UpdateLocalSchema(_ model.Schema) error { return nil } -func (*mockUploader) ShouldOnDedupUseNewRecord() bool { return false } -func (*mockUploader) UseRudderStorage() bool { return false } -func (*mockUploader) GetLoadFileGenStartTIme() time.Time { return time.Time{} } -func (*mockUploader) GetLoadFileType() string { return "JSON" } -func (*mockUploader) GetFirstLastEvent() (time.Time, 
time.Time) { return time.Time{}, time.Time{} } +func (*mockUploader) GetSchemaInWarehouse() model.Schema { return model.Schema{} } +func (*mockUploader) GetLocalSchema(context.Context) (model.Schema, error) { + return model.Schema{}, nil +} +func (*mockUploader) UpdateLocalSchema(context.Context, model.Schema) error { return nil } +func (*mockUploader) ShouldOnDedupUseNewRecord() bool { return false } +func (*mockUploader) UseRudderStorage() bool { return false } +func (*mockUploader) GetLoadFileGenStartTIme() time.Time { return time.Time{} } +func (*mockUploader) GetLoadFileType() string { return "JSON" } +func (*mockUploader) GetFirstLastEvent() (time.Time, time.Time) { return time.Time{}, time.Time{} } func (*mockUploader) GetTableSchemaInWarehouse(_ string) model.TableSchema { return model.TableSchema{} } -func (*mockUploader) GetSingleLoadFile(_ string) (warehouseutils.LoadFile, error) { +func (*mockUploader) GetSingleLoadFile(_ context.Context, _ string) (warehouseutils.LoadFile, error) { return warehouseutils.LoadFile{}, nil } -func (m *mockUploader) GetSampleLoadFileLocation(_ string) (string, error) { +func (m *mockUploader) GetSampleLoadFileLocation(_ context.Context, _ string) (string, error) { minioHostPort := fmt.Sprintf("localhost:%s", m.minioPort) sampleLocation := m.metadata[0].Location @@ -398,7 +400,7 @@ func (m *mockUploader) GetTableSchemaInUpload(string) model.TableSchema { return m.tableSchema } -func (m *mockUploader) GetLoadFilesMetadata(warehouseutils.GetLoadFilesOptions) []warehouseutils.LoadFile { +func (m *mockUploader) GetLoadFilesMetadata(context.Context, warehouseutils.GetLoadFilesOptions) []warehouseutils.LoadFile { return m.metadata } @@ -542,7 +544,9 @@ func TestClickhouse_LoadTableRoundTrip(t *testing.T) { }) require.NoError(t, err) - uploadOutput, err := fm.Upload(context.TODO(), f, fmt.Sprintf("test_prefix_%d", i)) + ctx := context.Background() + + uploadOutput, err := fm.Upload(ctx, f, fmt.Sprintf("test_prefix_%d", i)) require.NoError(t, err) mockUploader.metadata = append(mockUploader.metadata, warehouseutils.LoadFile{ @@ -550,29 +554,29 @@ func TestClickhouse_LoadTableRoundTrip(t *testing.T) { }) t.Log("Setting up clickhouse") - err = ch.Setup(warehouse, mockUploader) + err = ch.Setup(ctx, warehouse, mockUploader) require.NoError(t, err) t.Log("Verifying connection") - _, err = ch.Connect(warehouse) + _, err = ch.Connect(ctx, warehouse) require.NoError(t, err) t.Log("Verifying empty schema") - schema, unrecognizedSchema, err := ch.FetchSchema() + schema, unrecognizedSchema, err := ch.FetchSchema(ctx) require.NoError(t, err) require.Empty(t, schema) require.Empty(t, unrecognizedSchema) t.Log("Creating schema") - err = ch.CreateSchema() + err = ch.CreateSchema(ctx) require.NoError(t, err) t.Log("Creating schema twice should not fail") - err = ch.CreateSchema() + err = ch.CreateSchema(ctx) require.NoError(t, err) t.Log("Creating table") - err = ch.CreateTable(table, model.TableSchema{ + err = ch.CreateTable(ctx, table, model.TableSchema{ "id": "string", "test_int": "int", "test_float": "float", @@ -589,7 +593,7 @@ func TestClickhouse_LoadTableRoundTrip(t *testing.T) { require.NoError(t, err) t.Log("Adding columns") - err = ch.AddColumns(table, []warehouseutils.ColumnInfo{ + err = ch.AddColumns(ctx, table, []warehouseutils.ColumnInfo{ {Name: "alter_test_int", Type: "int"}, {Name: "alter_test_float", Type: "float"}, {Name: "alter_test_bool", Type: "boolean"}, @@ -599,7 +603,7 @@ func TestClickhouse_LoadTableRoundTrip(t *testing.T) { require.NoError(t, 
err) t.Log("Verifying schema") - schema, unrecognizedSchema, err = ch.FetchSchema() + schema, unrecognizedSchema, err = ch.FetchSchema(ctx) require.NoError(t, err) require.NotEmpty(t, schema) require.Empty(t, unrecognizedSchema) @@ -628,21 +632,21 @@ func TestClickhouse_LoadTableRoundTrip(t *testing.T) { } t.Log("Loading data into table") - err = ch.LoadTable(context.TODO(), table) + err = ch.LoadTable(ctx, table) require.NoError(t, err) t.Log("Checking table count") - count, err := ch.GetTotalCountInTable(context.TODO(), table) + count, err := ch.GetTotalCountInTable(ctx, table) require.NoError(t, err) require.EqualValues(t, 2, count) t.Log("Drop table") - err = ch.DropTable(table) + err = ch.DropTable(ctx, table) require.NoError(t, err) t.Log("Creating users identifies and table") for _, tableName := range []string{warehouseutils.IdentifiesTable, warehouseutils.UsersTable} { - err = ch.CreateTable(tableName, model.TableSchema{ + err = ch.CreateTable(ctx, tableName, model.TableSchema{ "id": "string", "user_id": "string", "test_int": "int", @@ -657,12 +661,12 @@ func TestClickhouse_LoadTableRoundTrip(t *testing.T) { t.Log("Drop users identifies and table") for _, tableName := range []string{warehouseutils.IdentifiesTable, warehouseutils.UsersTable} { - err = ch.DropTable(tableName) + err = ch.DropTable(ctx, tableName) require.NoError(t, err) } t.Log("Verifying empty schema") - schema, unrecognizedSchema, err = ch.FetchSchema() + schema, unrecognizedSchema, err = ch.FetchSchema(ctx) require.NoError(t, err) require.Empty(t, schema) require.Empty(t, unrecognizedSchema) @@ -698,6 +702,8 @@ func TestClickhouse_TestConnection(t *testing.T) { db := connectClickhouseDB(context.Background(), t, dsn) defer func() { _ = db.Close() }() + ctx := context.Background() + testCases := []struct { name string host string @@ -756,12 +762,12 @@ func TestClickhouse_TestConnection(t *testing.T) { }, } - err := ch.Setup(warehouse, &mockUploader{}) + err := ch.Setup(ctx, warehouse, &mockUploader{}) require.NoError(t, err) ch.SetConnectionTimeout(tc.timeout) - ctx, cancel := context.WithTimeout(context.TODO(), tc.timeout) + ctx, cancel := context.WithTimeout(context.Background(), tc.timeout) defer cancel() err = ch.TestConnection(ctx, warehouse) @@ -829,6 +835,8 @@ func TestClickhouse_LoadTestTable(t *testing.T) { }, } + ctx := context.Background() + for i, tc := range testCases { tc := tc i := i @@ -860,18 +868,18 @@ func TestClickhouse_LoadTestTable(t *testing.T) { payload[k] = v } - err := ch.Setup(warehouse, &mockUploader{}) + err := ch.Setup(ctx, warehouse, &mockUploader{}) require.NoError(t, err) - err = ch.CreateSchema() + err = ch.CreateSchema(ctx) require.NoError(t, err) tableName := fmt.Sprintf("%s_%d", tableName, i) - err = ch.CreateTable(tableName, testColumns) + err = ch.CreateTable(ctx, tableName, testColumns) require.NoError(t, err) - err = ch.LoadTestTable("", tableName, payload, "") + err = ch.LoadTestTable(ctx, "", tableName, payload, "") if tc.wantError != nil { require.ErrorContains(t, err, tc.wantError.Error()) return @@ -909,6 +917,8 @@ func TestClickhouse_FetchSchema(t *testing.T) { db := connectClickhouseDB(context.Background(), t, dsn) defer func() { _ = db.Close() }() + ctx := context.Background() + t.Run("Success", func(t *testing.T) { ch := clickhouse.New() ch.Logger = logger.NOP @@ -927,13 +937,13 @@ func TestClickhouse_FetchSchema(t *testing.T) { }, } - err := ch.Setup(warehouse, &mockUploader{}) + err := ch.Setup(ctx, warehouse, &mockUploader{}) require.NoError(t, err) - err = 
ch.CreateSchema() + err = ch.CreateSchema(ctx) require.NoError(t, err) - err = ch.CreateTable(table, model.TableSchema{ + err = ch.CreateTable(ctx, table, model.TableSchema{ "id": "string", "test_int": "int", "test_float": "float", @@ -949,7 +959,7 @@ func TestClickhouse_FetchSchema(t *testing.T) { }) require.NoError(t, err) - schema, unrecognizedSchema, err := ch.FetchSchema() + schema, unrecognizedSchema, err := ch.FetchSchema(ctx) require.NoError(t, err) require.NotEmpty(t, schema) require.Empty(t, unrecognizedSchema) @@ -973,10 +983,10 @@ func TestClickhouse_FetchSchema(t *testing.T) { }, } - err := ch.Setup(warehouse, &mockUploader{}) + err := ch.Setup(ctx, warehouse, &mockUploader{}) require.NoError(t, err) - schema, unrecognizedSchema, err := ch.FetchSchema() + schema, unrecognizedSchema, err := ch.FetchSchema(ctx) require.ErrorContains(t, err, errors.New("dial tcp: lookup clickhouse").Error()) require.Empty(t, schema) require.Empty(t, unrecognizedSchema) @@ -1000,10 +1010,10 @@ func TestClickhouse_FetchSchema(t *testing.T) { }, } - err := ch.Setup(warehouse, &mockUploader{}) + err := ch.Setup(ctx, warehouse, &mockUploader{}) require.NoError(t, err) - schema, unrecognizedSchema, err := ch.FetchSchema() + schema, unrecognizedSchema, err := ch.FetchSchema(ctx) require.NoError(t, err) require.Empty(t, schema) require.Empty(t, unrecognizedSchema) @@ -1027,13 +1037,13 @@ func TestClickhouse_FetchSchema(t *testing.T) { }, } - err := ch.Setup(warehouse, &mockUploader{}) + err := ch.Setup(ctx, warehouse, &mockUploader{}) require.NoError(t, err) - err = ch.CreateSchema() + err = ch.CreateSchema(ctx) require.NoError(t, err) - schema, unrecognizedSchema, err := ch.FetchSchema() + schema, unrecognizedSchema, err := ch.FetchSchema(ctx) require.NoError(t, err) require.Empty(t, schema) require.Empty(t, unrecognizedSchema) @@ -1057,10 +1067,10 @@ func TestClickhouse_FetchSchema(t *testing.T) { }, } - err := ch.Setup(warehouse, &mockUploader{}) + err := ch.Setup(ctx, warehouse, &mockUploader{}) require.NoError(t, err) - err = ch.CreateSchema() + err = ch.CreateSchema(ctx) require.NoError(t, err) _, err = ch.DB.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s (x Enum('hello' = 1, 'world' = 2)) ENGINE = TinyLog;", @@ -1069,7 +1079,7 @@ func TestClickhouse_FetchSchema(t *testing.T) { )) require.NoError(t, err) - schema, unrecognizedSchema, err := ch.FetchSchema() + schema, unrecognizedSchema, err := ch.FetchSchema(ctx) require.NoError(t, err) require.NotEmpty(t, schema) require.NotEmpty(t, unrecognizedSchema) diff --git a/warehouse/integrations/datalake/datalake.go b/warehouse/integrations/datalake/datalake.go index f8ba69a357..b397e7bb2d 100644 --- a/warehouse/integrations/datalake/datalake.go +++ b/warehouse/integrations/datalake/datalake.go @@ -40,7 +40,7 @@ func New() *Datalake { } } -func (d *Datalake) Setup(warehouse model.Warehouse, uploader warehouseutils.Uploader) (err error) { +func (d *Datalake) Setup(_ context.Context, warehouse model.Warehouse, uploader warehouseutils.Uploader) (err error) { d.Warehouse = warehouse d.Uploader = uploader @@ -49,30 +49,30 @@ func (d *Datalake) Setup(warehouse model.Warehouse, uploader warehouseutils.Uplo return err } -func (*Datalake) CrashRecover() {} +func (*Datalake) CrashRecover(context.Context) {} -func (d *Datalake) FetchSchema() (model.Schema, model.Schema, error) { - return d.SchemaRepository.FetchSchema(d.Warehouse) +func (d *Datalake) FetchSchema(ctx context.Context) (model.Schema, model.Schema, error) { + return 
d.SchemaRepository.FetchSchema(ctx, d.Warehouse)
 }
 
-func (d *Datalake) CreateSchema() (err error) {
-	return d.SchemaRepository.CreateSchema()
+func (d *Datalake) CreateSchema(ctx context.Context) (err error) {
+	return d.SchemaRepository.CreateSchema(ctx)
 }
 
-func (d *Datalake) CreateTable(tableName string, columnMap model.TableSchema) (err error) {
-	return d.SchemaRepository.CreateTable(tableName, columnMap)
+func (d *Datalake) CreateTable(ctx context.Context, tableName string, columnMap model.TableSchema) (err error) {
+	return d.SchemaRepository.CreateTable(ctx, tableName, columnMap)
 }
 
-func (*Datalake) DropTable(_ string) (err error) {
+func (*Datalake) DropTable(context.Context, string) (err error) {
 	return fmt.Errorf("datalake err :not implemented")
 }
 
-func (d *Datalake) AddColumns(tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) {
-	return d.SchemaRepository.AddColumns(tableName, columnsInfo)
+func (d *Datalake) AddColumns(ctx context.Context, tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) {
+	return d.SchemaRepository.AddColumns(ctx, tableName, columnsInfo)
 }
 
-func (d *Datalake) AlterColumn(tableName, columnName, columnType string) (model.AlterTableResponse, error) {
-	return d.SchemaRepository.AlterColumn(tableName, columnName, columnType)
+func (d *Datalake) AlterColumn(ctx context.Context, tableName, columnName, columnType string) (model.AlterTableResponse, error) {
+	return d.SchemaRepository.AlterColumn(ctx, tableName, columnName, columnType)
 }
 
 func (d *Datalake) LoadTable(_ context.Context, tableName string) error {
@@ -80,7 +80,7 @@ func (d *Datalake) LoadTable(_ context.Context, tableName string) error {
 	return nil
 }
 
-func (*Datalake) DeleteBy([]string, warehouseutils.DeleteByParams) (err error) {
+func (*Datalake) DeleteBy(context.Context, []string, warehouseutils.DeleteByParams) (err error) {
 	return fmt.Errorf(warehouseutils.NotImplementedErrorCode)
 }
 
@@ -95,20 +95,20 @@ func (d *Datalake) LoadUserTables(context.Context) map[string]error {
 	return errorMap
 }
 
-func (d *Datalake) LoadIdentityMergeRulesTable() error {
+func (d *Datalake) LoadIdentityMergeRulesTable(context.Context) error {
 	d.Logger.Infof("Skipping load for identity merge rules : %s is a datalake destination", d.Warehouse.Destination.ID)
 	return nil
 }
 
-func (d *Datalake) LoadIdentityMappingsTable() error {
+func (d *Datalake) LoadIdentityMappingsTable(context.Context) error {
 	d.Logger.Infof("Skipping load for identity mappings : %s is a datalake destination", d.Warehouse.Destination.ID)
 	return nil
 }
 
-func (*Datalake) Cleanup() {
+func (*Datalake) Cleanup(context.Context) {
 }
 
-func (*Datalake) IsEmpty(_ model.Warehouse) (bool, error) {
+func (*Datalake) IsEmpty(context.Context, model.Warehouse) (bool, error) {
 	return false, nil
 }
 
@@ -116,7 +116,7 @@ func (*Datalake) TestConnection(context.Context, model.Warehouse) error {
 	return fmt.Errorf("datalake err :not implemented")
 }
 
-func (*Datalake) DownloadIdentityRules(*misc.GZipWriter) error {
+func (*Datalake) DownloadIdentityRules(context.Context, *misc.GZipWriter) error {
 	return fmt.Errorf("datalake err :not implemented")
 }
 
@@ -124,11 +124,11 @@ func (*Datalake) GetTotalCountInTable(context.Context, string) (int64, error) {
 	return 0, nil
 }
 
-func (*Datalake) Connect(_ model.Warehouse) (client.Client, error) {
+func (*Datalake) Connect(context.Context, model.Warehouse) (client.Client, error) {
 	return client.Client{}, fmt.Errorf("datalake err :not implemented")
 }
 
-func (*Datalake) LoadTestTable(_, _ string, _ 
map[string]interface{}, _ string) error { +func (*Datalake) LoadTestTable(context.Context, string, string, map[string]interface{}, string) error { return fmt.Errorf("datalake err :not implemented") } diff --git a/warehouse/integrations/datalake/schema-repository/glue.go b/warehouse/integrations/datalake/schema-repository/glue.go index fbc74e31ab..f48d61c5dc 100644 --- a/warehouse/integrations/datalake/schema-repository/glue.go +++ b/warehouse/integrations/datalake/schema-repository/glue.go @@ -1,6 +1,7 @@ package schemarepository import ( + "context" "fmt" "net/url" "regexp" @@ -57,7 +58,7 @@ func NewGlueSchemaRepository(wh model.Warehouse) (*GlueSchemaRepository, error) return &gl, nil } -func (gl *GlueSchemaRepository) FetchSchema(warehouse model.Warehouse) (model.Schema, model.Schema, error) { +func (gl *GlueSchemaRepository) FetchSchema(ctx context.Context, warehouse model.Warehouse) (model.Schema, model.Schema, error) { schema := model.Schema{} unrecognizedSchema := model.Schema{} var err error @@ -72,7 +73,7 @@ func (gl *GlueSchemaRepository) FetchSchema(warehouse model.Warehouse) (model.Sc getTablesInput.NextToken = getTablesOutput.NextToken } - getTablesOutput, err = gl.GlueClient.GetTables(getTablesInput) + getTablesOutput, err = gl.GlueClient.GetTablesWithContext(ctx, getTablesInput) if err != nil { if _, ok := err.(*glue.EntityNotFoundException); ok { gl.Logger.Debugf("FetchSchema: database %s not found in glue. returning empty schema", warehouse.Namespace) @@ -112,8 +113,8 @@ func (gl *GlueSchemaRepository) FetchSchema(warehouse model.Warehouse) (model.Sc return schema, unrecognizedSchema, err } -func (gl *GlueSchemaRepository) CreateSchema() (err error) { - _, err = gl.GlueClient.CreateDatabase(&glue.CreateDatabaseInput{ +func (gl *GlueSchemaRepository) CreateSchema(ctx context.Context) (err error) { + _, err = gl.GlueClient.CreateDatabaseWithContext(ctx, &glue.CreateDatabaseInput{ DatabaseInput: &glue.DatabaseInput{ Name: &gl.Namespace, }, @@ -125,7 +126,7 @@ func (gl *GlueSchemaRepository) CreateSchema() (err error) { return } -func (gl *GlueSchemaRepository) CreateTable(tableName string, columnMap model.TableSchema) (err error) { +func (gl *GlueSchemaRepository) CreateTable(ctx context.Context, tableName string, columnMap model.TableSchema) (err error) { partitionKeys, err := gl.partitionColumns() if err != nil { return fmt.Errorf("partition keys: %w", err) @@ -145,7 +146,7 @@ func (gl *GlueSchemaRepository) CreateTable(tableName string, columnMap model.Ta // add storage descriptor to create table request input.TableInput.StorageDescriptor = gl.getStorageDescriptor(tableName, columnMap) - _, err = gl.GlueClient.CreateTable(&input) + _, err = gl.GlueClient.CreateTableWithContext(ctx, &input) if err != nil { _, ok := err.(*glue.AlreadyExistsException) if ok { @@ -155,7 +156,7 @@ func (gl *GlueSchemaRepository) CreateTable(tableName string, columnMap model.Ta return } -func (gl *GlueSchemaRepository) updateTable(tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) { +func (gl *GlueSchemaRepository) updateTable(ctx context.Context, tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) { updateTableInput := glue.UpdateTableInput{ DatabaseName: aws.String(gl.Namespace), TableInput: &glue.TableInput{ @@ -164,7 +165,7 @@ func (gl *GlueSchemaRepository) updateTable(tableName string, columnsInfo []ware } // fetch schema from glue - schema, _, err := gl.FetchSchema(gl.Warehouse) + schema, _, err := gl.FetchSchema(ctx, gl.Warehouse) if err != nil { 
return err } @@ -190,16 +191,16 @@ func (gl *GlueSchemaRepository) updateTable(tableName string, columnsInfo []ware updateTableInput.TableInput.PartitionKeys = partitionKeys // update table - _, err = gl.GlueClient.UpdateTable(&updateTableInput) + _, err = gl.GlueClient.UpdateTableWithContext(ctx, &updateTableInput) return } -func (gl *GlueSchemaRepository) AddColumns(tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) { - return gl.updateTable(tableName, columnsInfo) +func (gl *GlueSchemaRepository) AddColumns(ctx context.Context, tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) { + return gl.updateTable(ctx, tableName, columnsInfo) } -func (gl *GlueSchemaRepository) AlterColumn(tableName, columnName, columnType string) (model.AlterTableResponse, error) { - return model.AlterTableResponse{}, gl.updateTable(tableName, []warehouseutils.ColumnInfo{{Name: columnName, Type: columnType}}) +func (gl *GlueSchemaRepository) AlterColumn(ctx context.Context, tableName, columnName, columnType string) (model.AlterTableResponse, error) { + return model.AlterTableResponse{}, gl.updateTable(ctx, tableName, []warehouseutils.ColumnInfo{{Name: columnName, Type: columnType}}) } func getGlueClient(wh model.Warehouse) (*glue.Glue, error) { @@ -250,7 +251,7 @@ func (gl *GlueSchemaRepository) getS3LocationForTable(tableName string) string { // RefreshPartitions takes a tableName and a list of loadFiles and refreshes all the // partitions that are modified by the path in those loadFiles. It returns any error // reported by Glue -func (gl *GlueSchemaRepository) RefreshPartitions(tableName string, loadFiles []warehouseutils.LoadFile) error { +func (gl *GlueSchemaRepository) RefreshPartitions(ctx context.Context, tableName string, loadFiles []warehouseutils.LoadFile) error { gl.Logger.Infof("Refreshing partitions for table: %s", tableName) // Skip if time window layout is not defined @@ -299,7 +300,7 @@ func (gl *GlueSchemaRepository) RefreshPartitions(tableName string, loadFiles [] // partitions) changes in Glue tables (since the number of versions of a Glue table // is limited) for location, partition := range locationsToPartition { - _, err := gl.GlueClient.GetPartition(&glue.GetPartitionInput{ + _, err := gl.GlueClient.GetPartitionWithContext(ctx, &glue.GetPartitionInput{ DatabaseName: aws.String(gl.Namespace), PartitionValues: partition.Values, TableName: aws.String(tableName), @@ -320,13 +321,13 @@ func (gl *GlueSchemaRepository) RefreshPartitions(tableName string, loadFiles [] } // Updating table partitions with empty columns to create partition keys if not created - if err = gl.updateTable(tableName, []warehouseutils.ColumnInfo{}); err != nil { + if err = gl.updateTable(ctx, tableName, []warehouseutils.ColumnInfo{}); err != nil { return fmt.Errorf("update table: %w", err) } gl.Logger.Debugf("Refreshing %d partitions", len(partitionInputs)) - if _, err = gl.GlueClient.BatchCreatePartition(&glue.BatchCreatePartitionInput{ + if _, err = gl.GlueClient.BatchCreatePartitionWithContext(ctx, &glue.BatchCreatePartitionInput{ DatabaseName: aws.String(gl.Namespace), PartitionInputList: partitionInputs, TableName: aws.String(tableName), diff --git a/warehouse/integrations/datalake/schema-repository/glue_test.go b/warehouse/integrations/datalake/schema-repository/glue_test.go index 5bd8b5d2d5..c752eb6390 100644 --- a/warehouse/integrations/datalake/schema-repository/glue_test.go +++ b/warehouse/integrations/datalake/schema-repository/glue_test.go @@ -121,16 +121,18 @@ func 
TestGlueSchemaRepositoryRoundTrip(t *testing.T) { warehouseutils.Init() encoding.Init() + ctx := context.Background() + g, err := NewGlueSchemaRepository(warehouse) g.Logger = logger.NOP require.NoError(t, err) t.Logf("Creating schema %s", testNamespace) - err = g.CreateSchema() + err = g.CreateSchema(ctx) require.NoError(t, err) t.Log("Creating already existing schema should not fail") - err = g.CreateSchema() + err = g.CreateSchema(ctx) require.NoError(t, err) t.Cleanup(func() { @@ -142,15 +144,15 @@ func TestGlueSchemaRepositoryRoundTrip(t *testing.T) { }) t.Logf("Creating table %s", testTable) - err = g.CreateTable(testTable, testColumns) + err = g.CreateTable(ctx, testTable, testColumns) require.NoError(t, err) t.Log("Creating already existing table should not fail") - err = g.CreateTable(testTable, testColumns) + err = g.CreateTable(ctx, testTable, testColumns) require.NoError(t, err) t.Log("Adding columns to table") - err = g.AddColumns(testTable, []warehouseutils.ColumnInfo{ + err = g.AddColumns(ctx, testTable, []warehouseutils.ColumnInfo{ {Name: "alter_test_bool", Type: "boolean"}, {Name: "alter_test_string", Type: "string"}, {Name: "alter_test_int", Type: "int"}, @@ -178,10 +180,10 @@ func TestGlueSchemaRepositoryRoundTrip(t *testing.T) { }) require.NoError(t, err) - uploadOutput, err := fm.Upload(context.TODO(), f, fmt.Sprintf("rudder-test-payload/s3-datalake/%s/%s/", warehouseutils.RandHex(), tc.windowLayout)) + uploadOutput, err := fm.Upload(ctx, f, fmt.Sprintf("rudder-test-payload/s3-datalake/%s/%s/", warehouseutils.RandHex(), tc.windowLayout)) require.NoError(t, err) - err = g.RefreshPartitions(testTable, []warehouseutils.LoadFile{ + err = g.RefreshPartitions(ctx, testTable, []warehouseutils.LoadFile{ { Location: uploadOutput.Location, }, diff --git a/warehouse/integrations/datalake/schema-repository/local.go b/warehouse/integrations/datalake/schema-repository/local.go index 21c58a3a24..8a293f1917 100644 --- a/warehouse/integrations/datalake/schema-repository/local.go +++ b/warehouse/integrations/datalake/schema-repository/local.go @@ -1,6 +1,7 @@ package schemarepository import ( + "context" "fmt" "github.com/rudderlabs/rudder-server/warehouse/internal/model" @@ -22,8 +23,8 @@ func NewLocalSchemaRepository(wh model.Warehouse, uploader warehouseutils.Upload return &ls, nil } -func (ls *LocalSchemaRepository) FetchSchema(_ model.Warehouse) (model.Schema, model.Schema, error) { - schema, err := ls.uploader.GetLocalSchema() +func (ls *LocalSchemaRepository) FetchSchema(ctx context.Context, _ model.Warehouse) (model.Schema, model.Schema, error) { + schema, err := ls.uploader.GetLocalSchema(ctx) if err != nil { return model.Schema{}, model.Schema{}, fmt.Errorf("fetching local schema: %w", err) } @@ -31,13 +32,13 @@ func (ls *LocalSchemaRepository) FetchSchema(_ model.Warehouse) (model.Schema, m return schema, model.Schema{}, nil } -func (*LocalSchemaRepository) CreateSchema() (err error) { +func (*LocalSchemaRepository) CreateSchema(context.Context) (err error) { return nil } -func (ls *LocalSchemaRepository) CreateTable(tableName string, columnMap model.TableSchema) (err error) { +func (ls *LocalSchemaRepository) CreateTable(ctx context.Context, tableName string, columnMap model.TableSchema) (err error) { // fetch schema from local db - schema, err := ls.uploader.GetLocalSchema() + schema, err := ls.uploader.GetLocalSchema(ctx) if err != nil { return fmt.Errorf("fetching local schema: %w", err) } @@ -50,12 +51,12 @@ func (ls *LocalSchemaRepository) CreateTable(tableName string, 
columnMap model.T schema[tableName] = columnMap // update schema - return ls.uploader.UpdateLocalSchema(schema) + return ls.uploader.UpdateLocalSchema(ctx, schema) } -func (ls *LocalSchemaRepository) AddColumns(tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) { +func (ls *LocalSchemaRepository) AddColumns(ctx context.Context, tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) { // fetch schema from local db - schema, err := ls.uploader.GetLocalSchema() + schema, err := ls.uploader.GetLocalSchema(ctx) if err != nil { return fmt.Errorf("fetching local schema: %w", err) } @@ -70,12 +71,12 @@ func (ls *LocalSchemaRepository) AddColumns(tableName string, columnsInfo []ware } // update schema - return ls.uploader.UpdateLocalSchema(schema) + return ls.uploader.UpdateLocalSchema(ctx, schema) } -func (ls *LocalSchemaRepository) AlterColumn(tableName, columnName, columnType string) (model.AlterTableResponse, error) { +func (ls *LocalSchemaRepository) AlterColumn(ctx context.Context, tableName, columnName, columnType string) (model.AlterTableResponse, error) { // fetch schema from local db - schema, err := ls.uploader.GetLocalSchema() + schema, err := ls.uploader.GetLocalSchema(ctx) if err != nil { return model.AlterTableResponse{}, fmt.Errorf("fetching local schema: %w", err) } @@ -93,9 +94,9 @@ func (ls *LocalSchemaRepository) AlterColumn(tableName, columnName, columnType s schema[tableName][columnName] = columnType // update schema - return model.AlterTableResponse{}, ls.uploader.UpdateLocalSchema(schema) + return model.AlterTableResponse{}, ls.uploader.UpdateLocalSchema(ctx, schema) } -func (*LocalSchemaRepository) RefreshPartitions(_ string, _ []warehouseutils.LoadFile) error { +func (*LocalSchemaRepository) RefreshPartitions(context.Context, string, []warehouseutils.LoadFile) error { return nil } diff --git a/warehouse/integrations/datalake/schema-repository/local_test.go b/warehouse/integrations/datalake/schema-repository/local_test.go index a66467576a..9cb4de04bc 100644 --- a/warehouse/integrations/datalake/schema-repository/local_test.go +++ b/warehouse/integrations/datalake/schema-repository/local_test.go @@ -1,6 +1,7 @@ package schemarepository_test import ( + "context" "fmt" "testing" "time" @@ -17,15 +18,18 @@ type mockUploader struct { localSchema model.Schema } -func (*mockUploader) GetSchemaInWarehouse() model.Schema { return model.Schema{} } -func (*mockUploader) ShouldOnDedupUseNewRecord() bool { return false } -func (*mockUploader) UseRudderStorage() bool { return false } -func (*mockUploader) GetLoadFileGenStartTIme() time.Time { return time.Time{} } -func (*mockUploader) GetLoadFileType() string { return "JSON" } -func (*mockUploader) GetFirstLastEvent() (time.Time, time.Time) { return time.Time{}, time.Time{} } -func (*mockUploader) GetTableSchemaInUpload(string) model.TableSchema { return nil } -func (*mockUploader) GetSampleLoadFileLocation(string) (string, error) { return "", nil } -func (*mockUploader) GetLoadFilesMetadata(warehouseutils.GetLoadFilesOptions) []warehouseutils.LoadFile { +func (*mockUploader) GetSchemaInWarehouse() model.Schema { return model.Schema{} } +func (*mockUploader) ShouldOnDedupUseNewRecord() bool { return false } +func (*mockUploader) UseRudderStorage() bool { return false } +func (*mockUploader) GetLoadFileGenStartTIme() time.Time { return time.Time{} } +func (*mockUploader) GetLoadFileType() string { return "JSON" } +func (*mockUploader) GetFirstLastEvent() (time.Time, time.Time) { return time.Time{}, 
time.Time{} } +func (*mockUploader) GetTableSchemaInUpload(string) model.TableSchema { return nil } +func (*mockUploader) GetSampleLoadFileLocation(context.Context, string) (string, error) { + return "", nil +} + +func (*mockUploader) GetLoadFilesMetadata(context.Context, warehouseutils.GetLoadFilesOptions) []warehouseutils.LoadFile { return nil } @@ -33,19 +37,21 @@ func (*mockUploader) GetTableSchemaInWarehouse(string) model.TableSchema { return model.TableSchema{} } -func (*mockUploader) GetSingleLoadFile(string) (warehouseutils.LoadFile, error) { +func (*mockUploader) GetSingleLoadFile(context.Context, string) (warehouseutils.LoadFile, error) { return warehouseutils.LoadFile{}, nil } -func (m *mockUploader) GetLocalSchema() (model.Schema, error) { +func (m *mockUploader) GetLocalSchema(context.Context) (model.Schema, error) { return m.localSchema, nil } -func (m *mockUploader) UpdateLocalSchema(model.Schema) error { +func (m *mockUploader) UpdateLocalSchema(context.Context, model.Schema) error { return m.mockError } func TestLocalSchemaRepository_CreateTable(t *testing.T) { + t.Parallel() + testCases := []struct { name string mockError error @@ -88,7 +94,9 @@ func TestLocalSchemaRepository_CreateTable(t *testing.T) { s, err := schemarepository.NewLocalSchemaRepository(warehouse, uploader) require.NoError(t, err) - err = s.CreateTable("test_table", model.TableSchema{ + ctx := context.Background() + + err = s.CreateTable(ctx, "test_table", model.TableSchema{ "test_column_2": "test_type_2", }) if tc.wantError != nil { @@ -101,6 +109,8 @@ func TestLocalSchemaRepository_CreateTable(t *testing.T) { } func TestLocalSchemaRepository_AddColumns(t *testing.T) { + t.Parallel() + testCases := []struct { name string mockError error @@ -146,7 +156,9 @@ func TestLocalSchemaRepository_AddColumns(t *testing.T) { s, err := schemarepository.NewLocalSchemaRepository(warehouse, uploader) require.NoError(t, err) - err = s.AddColumns("test_table", []warehouseutils.ColumnInfo{ + ctx := context.Background() + + err = s.AddColumns(ctx, "test_table", []warehouseutils.ColumnInfo{ { Name: "test_column_2", Type: "test_type_2", @@ -162,6 +174,8 @@ func TestLocalSchemaRepository_AddColumns(t *testing.T) { } func TestLocalSchemaRepository_AlterColumn(t *testing.T) { + t.Parallel() + testCases := []struct { name string mockError error @@ -216,7 +230,9 @@ func TestLocalSchemaRepository_AlterColumn(t *testing.T) { s, err := schemarepository.NewLocalSchemaRepository(warehouse, uploader) require.NoError(t, err) - _, err = s.AlterColumn("test_table", "test_column_1", "test_type_2") + ctx := context.Background() + + _, err = s.AlterColumn(ctx, "test_table", "test_column_1", "test_type_2") if tc.wantError != nil { require.EqualError(t, err, tc.wantError.Error()) } else { diff --git a/warehouse/integrations/datalake/schema-repository/schema_repository.go b/warehouse/integrations/datalake/schema-repository/schema_repository.go index 208a66450d..4072374965 100644 --- a/warehouse/integrations/datalake/schema-repository/schema_repository.go +++ b/warehouse/integrations/datalake/schema-repository/schema_repository.go @@ -1,11 +1,11 @@ package schemarepository import ( + "context" "fmt" - "github.com/rudderlabs/rudder-server/warehouse/internal/model" - "github.com/rudderlabs/rudder-server/utils/misc" + "github.com/rudderlabs/rudder-server/warehouse/internal/model" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" ) @@ -33,12 +33,12 @@ var ( ) type SchemaRepository interface { - FetchSchema(warehouse 
model.Warehouse) (model.Schema, model.Schema, error) - CreateSchema() (err error) - CreateTable(tableName string, columnMap model.TableSchema) (err error) - AddColumns(tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) - AlterColumn(tableName, columnName, columnType string) (model.AlterTableResponse, error) - RefreshPartitions(tableName string, loadFiles []warehouseutils.LoadFile) error + FetchSchema(ctx context.Context, warehouse model.Warehouse) (model.Schema, model.Schema, error) + CreateSchema(ctx context.Context) (err error) + CreateTable(ctx context.Context, tableName string, columnMap model.TableSchema) (err error) + AddColumns(ctx context.Context, tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) + AlterColumn(ctx context.Context, tableName, columnName, columnType string) (model.AlterTableResponse, error) + RefreshPartitions(ctx context.Context, tableName string, loadFiles []warehouseutils.LoadFile) error } func UseGlue(w *model.Warehouse) bool { diff --git a/warehouse/integrations/deltalake-native/deltalake.go b/warehouse/integrations/deltalake-native/deltalake.go index 4545796871..39767d6373 100644 --- a/warehouse/integrations/deltalake-native/deltalake.go +++ b/warehouse/integrations/deltalake-native/deltalake.go @@ -156,7 +156,7 @@ func WithConfig(h *Deltalake, config *config.Config) { } // Setup sets up the warehouse -func (d *Deltalake) Setup(warehouse model.Warehouse, uploader warehouseutils.Uploader) error { +func (d *Deltalake) Setup(_ context.Context, warehouse model.Warehouse, uploader warehouseutils.Uploader) error { d.Warehouse = warehouse d.Namespace = warehouse.Namespace d.Uploader = uploader @@ -230,13 +230,13 @@ func (d *Deltalake) connect() (*sqlmiddleware.DB, error) { } // CrashRecover crash recover scenarios -func (d *Deltalake) CrashRecover() { - d.dropDanglingStagingTables() +func (d *Deltalake) CrashRecover(ctx context.Context) { + d.dropDanglingStagingTables(ctx) } // dropDanglingStagingTables drops dangling staging tables -func (d *Deltalake) dropDanglingStagingTables() { - tableNames, err := d.fetchTables(rudderStagingTableRegex) +func (d *Deltalake) dropDanglingStagingTables(ctx context.Context) { + tableNames, err := d.fetchTables(ctx, rudderStagingTableRegex) if err != nil { d.Logger.Warnw("fetching tables for dropping dangling staging tables", logfield.SourceID, d.Warehouse.Source.ID, @@ -250,14 +250,14 @@ func (d *Deltalake) dropDanglingStagingTables() { return } - d.dropStagingTables(tableNames) + d.dropStagingTables(ctx, tableNames) } // fetchTables fetches tables from the database -func (d *Deltalake) fetchTables(regex string) ([]string, error) { +func (d *Deltalake) fetchTables(ctx context.Context, regex string) ([]string, error) { query := fmt.Sprintf(`SHOW tables FROM %s LIKE '%s';`, d.Namespace, regex) - rows, err := d.DB.Query(query) + rows, err := d.DB.QueryContext(ctx, query) if err != nil { if strings.Contains(err.Error(), schemaNotFound) { return nil, nil @@ -287,9 +287,9 @@ func (d *Deltalake) fetchTables(regex string) ([]string, error) { } // dropStagingTables drops all the staging tables -func (d *Deltalake) dropStagingTables(stagingTables []string) { +func (d *Deltalake) dropStagingTables(ctx context.Context, stagingTables []string) { for _, stagingTable := range stagingTables { - err := d.dropTable(stagingTable) + err := d.dropTable(ctx, stagingTable) if err != nil { d.Logger.Warnw("dropping staging table", logfield.SourceID, d.Warehouse.Source.ID, @@ -306,10 +306,10 @@ func (d *Deltalake) 
dropStagingTables(stagingTables []string) { } // DropTable drops a table from the warehouse -func (d *Deltalake) dropTable(table string) error { +func (d *Deltalake) dropTable(ctx context.Context, table string) error { query := fmt.Sprintf(`DROP TABLE %s.%s;`, d.Namespace, table) - _, err := d.DB.Exec(query) + _, err := d.DB.ExecContext(ctx, query) if err != nil { return fmt.Errorf("executing drop table: %w", err) } @@ -318,17 +318,17 @@ func (d *Deltalake) dropTable(table string) error { } // FetchSchema fetches the schema from the warehouse -func (d *Deltalake) FetchSchema() (model.Schema, model.Schema, error) { +func (d *Deltalake) FetchSchema(ctx context.Context) (model.Schema, model.Schema, error) { schema := make(model.Schema) unrecognizedSchema := make(model.Schema) - tableNames, err := d.fetchTables(nonRudderStagingTableRegex) + tableNames, err := d.fetchTables(ctx, nonRudderStagingTableRegex) if err != nil { return model.Schema{}, model.Schema{}, fmt.Errorf("fetching tables: %w", err) } // For each table, fetch the attributes for _, tableName := range tableNames { - tableSchema, err := d.fetchTableAttributes(tableName) + tableSchema, err := d.fetchTableAttributes(ctx, tableName) if err != nil { return model.Schema{}, model.Schema{}, fmt.Errorf("fetching table attributes: %w", err) } @@ -357,12 +357,12 @@ func (d *Deltalake) FetchSchema() (model.Schema, model.Schema, error) { } // fetchTableAttributes fetches the attributes of a table -func (d *Deltalake) fetchTableAttributes(tableName string) (model.TableSchema, error) { +func (d *Deltalake) fetchTableAttributes(ctx context.Context, tableName string) (model.TableSchema, error) { tableSchema := make(model.TableSchema) query := fmt.Sprintf(`DESCRIBE QUERY TABLE %s.%s;`, d.Namespace, tableName) - rows, err := d.DB.Query(query) + rows, err := d.DB.QueryContext(ctx, query) if err != nil { return nil, fmt.Errorf("executing fetching table attributes: %w", err) } @@ -387,12 +387,12 @@ func (d *Deltalake) fetchTableAttributes(tableName string) (model.TableSchema, e } // CreateSchema creates a schema in the warehouse if it does not exist. -func (d *Deltalake) CreateSchema() error { - if exists, err := d.schemaExists(); err != nil { +func (d *Deltalake) CreateSchema(ctx context.Context) error { + if exists, err := d.schemaExists(ctx); err != nil { return fmt.Errorf("checking if schema exists: %w", err) } else if exists { return nil - } else if err := d.createSchema(); err != nil { + } else if err := d.createSchema(ctx); err != nil { return fmt.Errorf("create schema: %w", err) } @@ -400,11 +400,11 @@ func (d *Deltalake) CreateSchema() error { } // schemaExists checks if a schema exists in the warehouse. -func (d *Deltalake) schemaExists() (bool, error) { +func (d *Deltalake) schemaExists(ctx context.Context) (bool, error) { query := fmt.Sprintf(`SHOW SCHEMAS LIKE '%s';`, d.Namespace) var schema string - err := d.DB.QueryRow(query).Scan(&schema) + err := d.DB.QueryRowContext(ctx, query).Scan(&schema) if err == sql.ErrNoRows { return false, nil @@ -416,10 +416,10 @@ func (d *Deltalake) schemaExists() (bool, error) { } // createSchema creates a schema in the warehouse. 
-func (d *Deltalake) createSchema() error { +func (d *Deltalake) createSchema(ctx context.Context) error { query := fmt.Sprintf(`CREATE SCHEMA IF NOT EXISTS %s;`, d.Namespace) - _, err := d.DB.Exec(query) + _, err := d.DB.ExecContext(ctx, query) if err != nil { return fmt.Errorf("executing create schema: %w", err) } @@ -428,7 +428,7 @@ func (d *Deltalake) createSchema() error { } // CreateTable creates a table in the warehouse. -func (d *Deltalake) CreateTable(tableName string, columns model.TableSchema) error { +func (d *Deltalake) CreateTable(ctx context.Context, tableName string, columns model.TableSchema) error { var partitionedSql, tableLocationSql string tableLocationSql = d.tableLocationQuery(tableName) @@ -452,7 +452,7 @@ func (d *Deltalake) CreateTable(tableName string, columns model.TableSchema) err partitionedSql, ) - _, err := d.DB.Exec(query) + _, err := d.DB.ExecContext(ctx, query) if err != nil { return fmt.Errorf("creating table: %w", err) } @@ -489,7 +489,7 @@ func (d *Deltalake) tableLocationQuery(tableName string) string { } // AddColumns adds columns to the table. -func (d *Deltalake) AddColumns(tableName string, columnsInfo []warehouseutils.ColumnInfo) error { +func (d *Deltalake) AddColumns(ctx context.Context, tableName string, columnsInfo []warehouseutils.ColumnInfo) error { var queryBuilder strings.Builder queryBuilder.WriteString(fmt.Sprintf(` @@ -507,7 +507,7 @@ func (d *Deltalake) AddColumns(tableName string, columnsInfo []warehouseutils.Co query := strings.TrimSuffix(queryBuilder.String(), ",") query += ");" - _, err := d.DB.Exec(query) + _, err := d.DB.ExecContext(ctx, query) // Handle error in case of single column if len(columnsInfo) == 1 { @@ -535,7 +535,7 @@ func (d *Deltalake) AddColumns(tableName string, columnsInfo []warehouseutils.Co } // AlterColumn alters a column in the warehouse -func (*Deltalake) AlterColumn(_, _, _ string) (model.AlterTableResponse, error) { +func (*Deltalake) AlterColumn(context.Context, string, string, string) (model.AlterTableResponse, error) { return model.AlterTableResponse{}, nil } @@ -572,19 +572,19 @@ func (d *Deltalake) loadTable(ctx context.Context, tableName string, tableSchema logfield.TableName, tableName, ) - if err = d.CreateTable(stagingTableName, tableSchemaAfterUpload); err != nil { + if err = d.CreateTable(ctx, stagingTableName, tableSchemaAfterUpload); err != nil { return "", fmt.Errorf("creating staging table: %w", err) } if !skipTempTableDelete { - defer d.dropStagingTables([]string{stagingTableName}) + defer d.dropStagingTables(ctx, []string{stagingTableName}) } if auth, err = d.authQuery(); err != nil { return "", fmt.Errorf("getting auth query: %w", err) } - objectsLocation, err := d.Uploader.GetSampleLoadFileLocation(tableName) + objectsLocation, err := d.Uploader.GetSampleLoadFileLocation(ctx, tableName) if err != nil { return "", fmt.Errorf("getting sample load file location: %w", err) } @@ -958,7 +958,7 @@ func (d *Deltalake) LoadUserTables(ctx context.Context) map[string]error { } } - defer d.dropStagingTables([]string{identifyStagingTable}) + defer d.dropStagingTables(ctx, []string{identifyStagingTable}) if len(usersSchemaInUpload) == 0 { return map[string]error{ @@ -1034,7 +1034,7 @@ func (d *Deltalake) LoadUserTables(ctx context.Context) map[string]error { } } - defer d.dropStagingTables([]string{stagingTableName}) + defer d.dropStagingTables(ctx, []string{stagingTableName}) columnKeys := append([]string{`id`}, userColNames...) 
@@ -1171,25 +1171,25 @@ func getColumnProperties(usersSchemaInWarehouse model.TableSchema) ([]string, []
 }
 
 // LoadIdentityMergeRulesTable loads identifies merge rules tables
-func (*Deltalake) LoadIdentityMergeRulesTable() error {
+func (*Deltalake) LoadIdentityMergeRulesTable(context.Context) error {
 	return nil
 }
 
 // LoadIdentityMappingsTable loads identifies mappings table
-func (*Deltalake) LoadIdentityMappingsTable() error {
+func (*Deltalake) LoadIdentityMappingsTable(context.Context) error {
 	return nil
 }
 
 // Cleanup cleans up the warehouse
-func (d *Deltalake) Cleanup() {
+func (d *Deltalake) Cleanup(ctx context.Context) {
 	if d.DB != nil {
-		d.dropDanglingStagingTables()
+		d.dropDanglingStagingTables(ctx)
 		_ = d.DB.Close()
 	}
 }
 
 // IsEmpty checks if the warehouse is empty or not
-func (*Deltalake) IsEmpty(model.Warehouse) (bool, error) {
+func (*Deltalake) IsEmpty(context.Context, model.Warehouse) (bool, error) {
 	return false, nil
 }
 
@@ -1207,7 +1207,7 @@ func (d *Deltalake) TestConnection(ctx context.Context, _ model.Warehouse) error
 }
 
 // DownloadIdentityRules downloads identity rules
-func (*Deltalake) DownloadIdentityRules(*misc.GZipWriter) error {
+func (*Deltalake) DownloadIdentityRules(context.Context, *misc.GZipWriter) error {
 	return nil
 }
 
@@ -1233,7 +1233,7 @@ func (d *Deltalake) GetTotalCountInTable(ctx context.Context, tableName string)
 }
 
 // Connect returns Client
-func (d *Deltalake) Connect(warehouse model.Warehouse) (warehouseclient.Client, error) {
+func (d *Deltalake) Connect(_ context.Context, warehouse model.Warehouse) (warehouseclient.Client, error) {
 	d.Warehouse = warehouse
 	d.Namespace = warehouse.Namespace
 	d.ObjectStorage = warehouseutils.ObjectStorageType(
@@ -1251,7 +1251,7 @@ func (d *Deltalake) Connect(warehouse model.Warehouse) (warehouseclient.Client,
 }
 
 // LoadTestTable loads the test table
-func (d *Deltalake) LoadTestTable(location, tableName string, _ map[string]interface{}, format string) error {
+func (d *Deltalake) LoadTestTable(ctx context.Context, location, tableName string, _ map[string]interface{}, format string) error {
 	auth, err := d.authQuery()
 	if err != nil {
 		return fmt.Errorf("auth query: %w", err)
@@ -1308,7 +1308,7 @@ func (d *Deltalake) LoadTestTable(location, tableName string, _ map[string]inter
 		)
 	}
 
-	_, err = d.DB.Exec(query)
+	_, err = d.DB.ExecContext(ctx, query)
 	if err != nil {
 		return fmt.Errorf("loading test table: %w", err)
 	}
@@ -1327,10 +1327,10 @@ func (*Deltalake) ErrorMappings() []model.JobError {
 }
 
 // DropTable drops a table in the warehouse
-func (d *Deltalake) DropTable(tableName string) error {
-	return d.dropTable(tableName)
+func (d *Deltalake) DropTable(ctx context.Context, tableName string) error {
+	return d.dropTable(ctx, tableName)
 }
 
-func (*Deltalake) DeleteBy([]string, warehouseutils.DeleteByParams) error {
+func (*Deltalake) DeleteBy(context.Context, []string, warehouseutils.DeleteByParams) error {
 	return fmt.Errorf(warehouseutils.NotImplementedErrorCode)
 }
diff --git a/warehouse/integrations/deltalake/client/client.go b/warehouse/integrations/deltalake/client/client.go
index 6d7aefcda9..628c0e1d39 100644
--- a/warehouse/integrations/deltalake/client/client.go
+++ b/warehouse/integrations/deltalake/client/client.go
@@ -19,14 +19,13 @@ type Client struct {
 	Logger         logger.Logger
 	CredConfig     *proto.ConnectionConfig
 	CredIdentifier string
-	Context        context.Context
 	Conn           *grpc.ClientConn
 	Client         proto.DatabricksClient
 }
 
 // Close closes sql connection as well as closes grpc connection
-func (client *Client) 
Close() { - closeConnectionResponse, err := client.Client.Close(client.Context, &proto.CloseRequest{ +func (client *Client) Close(ctx context.Context) { + closeConnectionResponse, err := client.Client.Close(ctx, &proto.CloseRequest{ Config: client.CredConfig, Identifier: client.CredIdentifier, }) @@ -36,5 +35,5 @@ func (client *Client) Close() { if closeConnectionResponse.GetErrorCode() != "" { client.Logger.Errorf("Error closing connection in delta lake with response: %v", err, closeConnectionResponse.GetErrorMessage()) } - client.Conn.Close() + _ = client.Conn.Close() } diff --git a/warehouse/integrations/deltalake/deltalake.go b/warehouse/integrations/deltalake/deltalake.go index 05b3fa53ea..63ec95e95c 100644 --- a/warehouse/integrations/deltalake/deltalake.go +++ b/warehouse/integrations/deltalake/deltalake.go @@ -206,8 +206,7 @@ func checkAndIgnoreAlreadyExistError(errorCode, ignoreError string) bool { } // NewClient creates deltalake client -func (dl *Deltalake) NewClient(cred *client.Credentials, connectTimeout time.Duration) (Client *client.Client, err error) { - ctx := context.Background() +func (dl *Deltalake) NewClient(ctx context.Context, cred *client.Credentials, connectTimeout time.Duration) (Client *client.Client, err error) { identifier := misc.FastUUID().String() connConfig := &proto.ConnectionConfig{ Host: cred.Host, @@ -265,27 +264,26 @@ func (dl *Deltalake) NewClient(cred *client.Credentials, connectTimeout time.Dur CredIdentifier: identifier, Conn: conn, Client: dbClient, - Context: ctx, } // Setting up catalog at the client level if catalog := warehouseutils.GetConfigValue(Catalog, dl.Warehouse); catalog != "" { sqlStatement := fmt.Sprintf("USE CATALOG `%s`;", catalog) - if err = dl.ExecuteSQLClient(Client, sqlStatement); err != nil { + if err = dl.ExecuteSQLClient(ctx, Client, sqlStatement); err != nil { return } } return } -func (*Deltalake) DeleteBy([]string, warehouseutils.DeleteByParams) error { +func (*Deltalake) DeleteBy(context.Context, []string, warehouseutils.DeleteByParams) error { return fmt.Errorf(warehouseutils.NotImplementedErrorCode) } // fetchTables fetch tables with tableNames -func (dl *Deltalake) fetchTables(dbT *client.Client, schema string) (tableNames []string, err error) { - fetchTableResponse, err := dbT.Client.FetchTables(dbT.Context, &proto.FetchTablesRequest{ +func (dl *Deltalake) fetchTables(ctx context.Context, dbT *client.Client, schema string) (tableNames []string, err error) { + fetchTableResponse, err := dbT.Client.FetchTables(ctx, &proto.FetchTablesRequest{ Config: dbT.CredConfig, Identifier: dbT.CredIdentifier, Schema: schema, @@ -302,10 +300,10 @@ func (dl *Deltalake) fetchTables(dbT *client.Client, schema string) (tableNames } // fetchPartitionColumns return the partition columns for the corresponding tables -func (dl *Deltalake) fetchPartitionColumns(dbT *client.Client, tableName string) ([]string, error) { +func (dl *Deltalake) fetchPartitionColumns(ctx context.Context, dbT *client.Client, tableName string) ([]string, error) { sqlStatement := fmt.Sprintf(`SHOW PARTITIONS %s.%s`, dl.Warehouse.Namespace, tableName) - columnsResponse, err := dbT.Client.FetchPartitionColumns(dbT.Context, &proto.FetchPartitionColumnsRequest{ + columnsResponse, err := dbT.Client.FetchPartitionColumns(ctx, &proto.FetchPartitionColumnsRequest{ Config: dbT.CredConfig, Identifier: dbT.CredIdentifier, SqlStatement: sqlStatement, @@ -328,12 +326,12 @@ func isPartitionedByEventDate(partitionedColumns []string) bool { // Checks whether the table is partition 
with event_date column
// If specified, then calculates the date range from the first and last event at and adds it to the IN predicate query for event_date
// If not specified, then returns empty string
-func (dl *Deltalake) partitionQuery(tableName string) (string, error) {
+func (dl *Deltalake) partitionQuery(ctx context.Context, tableName string) (string, error) {
 	if !dl.EnablePartitionPruning {
 		return "", nil
 	}
 
-	partitionColumns, err := dl.fetchPartitionColumns(dl.Client, tableName)
+	partitionColumns, err := dl.fetchPartitionColumns(ctx, dl.Client, tableName)
 	if err != nil {
 		return "", fmt.Errorf("failed to prepare partition query, error: %w", err)
 	}
@@ -357,8 +355,8 @@ func (dl *Deltalake) partitionQuery(tableName string) (string, error) {
 }
 
 // ExecuteSQLClient executes sql client using grpc Client
-func (*Deltalake) ExecuteSQLClient(client *client.Client, sqlStatement string) (err error) {
-	executeResponse, err := client.Client.Execute(client.Context, &proto.ExecuteRequest{
+func (*Deltalake) ExecuteSQLClient(ctx context.Context, client *client.Client, sqlStatement string) (err error) {
+	executeResponse, err := client.Client.Execute(ctx, &proto.ExecuteRequest{
 		Config:       client.CredConfig,
 		Identifier:   client.CredIdentifier,
 		SqlStatement: sqlStatement,
@@ -374,9 +372,9 @@ func (*Deltalake) ExecuteSQLClient(client *client.Client, sqlStatement string) (
 }
 
 // schemaExists checks if schema exists or not.
-func (dl *Deltalake) schemaExists(schemaName string) (exists bool, err error) {
+func (dl *Deltalake) schemaExists(ctx context.Context, schemaName string) (exists bool, err error) {
 	sqlStatement := fmt.Sprintf(`SHOW SCHEMAS LIKE '%s';`, schemaName)
-	fetchSchemasResponse, err := dl.Client.Client.FetchSchemas(dl.Client.Context, &proto.FetchSchemasRequest{
+	fetchSchemasResponse, err := dl.Client.Client.FetchSchemas(ctx, &proto.FetchSchemasRequest{
 		Config:       dl.Client.CredConfig,
 		Identifier:   dl.Client.CredIdentifier,
 		SqlStatement: sqlStatement,
@@ -393,19 +391,19 @@ func (dl *Deltalake) schemaExists(schemaName string) (exists bool, err error) {
 }
 
 // createSchema creates schema
-func (dl *Deltalake) createSchema() (err error) {
+func (dl *Deltalake) createSchema(ctx context.Context) (err error) {
 	sqlStatement := fmt.Sprintf(`CREATE SCHEMA IF NOT EXISTS %s;`, dl.Namespace)
 	dl.Logger.Infof("%s Creating schema in delta lake with SQL:%v", dl.GetLogIdentifier(), sqlStatement)
-	err = dl.ExecuteSQLClient(dl.Client, sqlStatement)
+	err = dl.ExecuteSQLClient(ctx, dl.Client, sqlStatement)
 	return
 }
 
 // dropStagingTables drops staging tables
-func (dl *Deltalake) dropStagingTables(tableNames []string) {
+func (dl *Deltalake) dropStagingTables(ctx context.Context, tableNames []string) {
 	for _, stagingTableName := range tableNames {
 		dl.Logger.Infof("%s Dropping table %+v\n", dl.GetLogIdentifier(), stagingTableName)
 		sqlStatement := fmt.Sprintf(`DROP TABLE %[1]s.%[2]s;`, dl.Namespace, stagingTableName)
-		dropTableResponse, err := dl.Client.Client.Execute(dl.Client.Context, &proto.ExecuteRequest{
+		dropTableResponse, err := dl.Client.Client.Execute(ctx, &proto.ExecuteRequest{
 			Config:       dl.Client.CredConfig,
 			Identifier:   dl.Client.CredIdentifier,
 			SqlStatement: sqlStatement,
@@ -498,20 +496,20 @@ func getTableSchemaDiff(tableSchemaInUpload, tableSchemaAfterUpload model.TableS
 }
 
 // loadTable Loads table with table name
-func (dl *Deltalake) 
loadTable(ctx context.Context, tableName string, tableSchemaInUpload, tableSchemaAfterUpload model.TableSchema, skipTempTableDelete bool) (stagingTableName string, err error) { // Getting sorted column keys from tableSchemaInUpload sortedColumnKeys := warehouseutils.SortColumnKeysFromColumnMap(tableSchemaInUpload) // Creating staging table stagingTableName = warehouseutils.StagingTableName(provider, tableName, tableNameLimit) - err = dl.CreateTable(stagingTableName, tableSchemaAfterUpload) + err = dl.CreateTable(ctx, stagingTableName, tableSchemaAfterUpload) if err != nil { return } // Dropping staging tables if required if !skipTempTableDelete { - defer dl.dropStagingTables([]string{stagingTableName}) + defer dl.dropStagingTables(ctx, []string{stagingTableName}) } // Get the credentials string to copy from the staging location to table @@ -521,7 +519,7 @@ func (dl *Deltalake) loadTable(tableName string, tableSchemaInUpload, tableSchem } // Getting objects location - objectsLocation, err := dl.Uploader.GetSampleLoadFileLocation(tableName) + objectsLocation, err := dl.Uploader.GetSampleLoadFileLocation(ctx, tableName) if err != nil { return } @@ -560,7 +558,7 @@ func (dl *Deltalake) loadTable(tableName string, tableSchemaInUpload, tableSchem } // Executing copy sql statement - err = dl.ExecuteSQLClient(dl.Client, sqlStatement) + err = dl.ExecuteSQLClient(ctx, dl.Client, sqlStatement) if err != nil { dl.Logger.Errorf("%s Error running COPY command with SQL: %s\n error: %v", dl.GetLogIdentifier(tableName), sqlStatement, err) return @@ -576,7 +574,7 @@ func (dl *Deltalake) loadTable(tableName string, tableSchemaInUpload, tableSchem } else { // Partition query var partitionQuery string - partitionQuery, err = dl.partitionQuery(tableName) + partitionQuery, err = dl.partitionQuery(ctx, tableName) if err != nil { err = fmt.Errorf("failed getting partition query during load table, error: %w", err) return @@ -593,7 +591,7 @@ func (dl *Deltalake) loadTable(tableName string, tableSchemaInUpload, tableSchem dl.Logger.Infof("%v Inserting records using staging table with SQL: %s\n", dl.GetLogIdentifier(tableName), sqlStatement) // Executing load table sql statement - err = dl.ExecuteSQLClient(dl.Client, sqlStatement) + err = dl.ExecuteSQLClient(ctx, dl.Client, sqlStatement) if err != nil { dl.Logger.Errorf("%v Error inserting into original table: %v\n", dl.GetLogIdentifier(tableName), err) return @@ -604,20 +602,20 @@ func (dl *Deltalake) loadTable(tableName string, tableSchemaInUpload, tableSchem } // loadUserTables Loads users table -func (dl *Deltalake) loadUserTables() (errorMap map[string]error) { +func (dl *Deltalake) loadUserTables(ctx context.Context) (errorMap map[string]error) { // Creating errorMap errorMap = map[string]error{warehouseutils.IdentifiesTable: nil} dl.Logger.Infof("%s Starting load for identifies and users tables\n", dl.GetLogIdentifier()) // Loading identifies tables - identifyStagingTable, err := dl.loadTable(warehouseutils.IdentifiesTable, dl.Uploader.GetTableSchemaInUpload(warehouseutils.IdentifiesTable), dl.Uploader.GetTableSchemaInWarehouse(warehouseutils.IdentifiesTable), true) + identifyStagingTable, err := dl.loadTable(ctx, warehouseutils.IdentifiesTable, dl.Uploader.GetTableSchemaInUpload(warehouseutils.IdentifiesTable), dl.Uploader.GetTableSchemaInWarehouse(warehouseutils.IdentifiesTable), true) if err != nil { errorMap[warehouseutils.IdentifiesTable] = err return } // dropping identifies staging table - defer dl.dropStagingTables([]string{identifyStagingTable}) + 
defer dl.dropStagingTables(ctx, []string{identifyStagingTable})
 
 	// Checking if users schema is present in GetTableSchemaInUpload
 	if len(dl.Uploader.GetTableSchemaInUpload(warehouseutils.UsersTable)) == 0 {
@@ -665,7 +663,7 @@ func (dl *Deltalake) loadUserTables() (errorMap map[string]error) {
 	)
 
 	// Executing create sql statement
-	err = dl.ExecuteSQLClient(dl.Client, sqlStatement)
+	err = dl.ExecuteSQLClient(ctx, dl.Client, sqlStatement)
 	if err != nil {
 		dl.Logger.Errorf("%s Creating staging table for users failed with SQL: %s\n", dl.GetLogIdentifier(), sqlStatement)
 		dl.Logger.Errorf("%s Error creating users staging table from original table and identifies staging table: %v\n", dl.GetLogIdentifier(), err)
@@ -674,7 +672,7 @@ func (dl *Deltalake) loadUserTables() (errorMap map[string]error) {
 	}
 
 	// Dropping staging users table
-	defer dl.dropStagingTables([]string{stagingTableName})
+	defer dl.dropStagingTables(ctx, []string{stagingTableName})
 
 	// Creating the column Keys
 	columnKeys := append([]string{`id`}, userColNames...)
@@ -689,7 +687,7 @@ func (dl *Deltalake) loadUserTables() (errorMap map[string]error) {
 	} else {
 		// Partition query
 		var partitionQuery string
-		partitionQuery, err = dl.partitionQuery(warehouseutils.UsersTable)
+		partitionQuery, err = dl.partitionQuery(ctx, warehouseutils.UsersTable)
 		if err != nil {
 			err = fmt.Errorf("failed getting partition query during load users table, error: %w", err)
 			errorMap[warehouseutils.UsersTable] = err
@@ -707,7 +705,7 @@ func (dl *Deltalake) loadUserTables() (errorMap map[string]error) {
 	dl.Logger.Infof("%s Inserting records using staging table with SQL: %s\n", dl.GetLogIdentifier(warehouseutils.UsersTable), sqlStatement)
 
 	// Executing the load users table sql statement
-	err = dl.ExecuteSQLClient(dl.Client, sqlStatement)
+	err = dl.ExecuteSQLClient(ctx, dl.Client, sqlStatement)
 	if err != nil {
 		dl.Logger.Errorf("%s Error inserting into users table from staging table: %v\n", err)
 		errorMap[warehouseutils.UsersTable] = err
@@ -736,9 +734,9 @@ func (dl *Deltalake) getTableLocationSql(tableName string) (tableLocation string
 }
 
 // dropDanglingStagingTables drop dangling staging tables.
-func (dl *Deltalake) dropDanglingStagingTables() { +func (dl *Deltalake) dropDanglingStagingTables(ctx context.Context) { // Fetching the staging tables - tableNames, err := dl.fetchTables(dl.Client, dl.Namespace) + tableNames, err := dl.fetchTables(ctx, dl.Client, dl.Namespace) if err != nil { return } @@ -754,22 +752,22 @@ func (dl *Deltalake) dropDanglingStagingTables() { } // Drop staging tables - dl.dropStagingTables(filteredTablesNames) + dl.dropStagingTables(ctx, filteredTablesNames) } // connectToWarehouse returns the database connection configured with Credentials -func (dl *Deltalake) connectToWarehouse() (Client *client.Client, err error) { +func (dl *Deltalake) connectToWarehouse(ctx context.Context) (Client *client.Client, err error) { credT := &client.Credentials{ Host: warehouseutils.GetConfigValue(Host, dl.Warehouse), Port: warehouseutils.GetConfigValue(Port, dl.Warehouse), Path: warehouseutils.GetConfigValue(Path, dl.Warehouse), Token: warehouseutils.GetConfigValue(Token, dl.Warehouse), } - return dl.NewClient(credT, dl.ConnectTimeout) + return dl.NewClient(ctx, credT, dl.ConnectTimeout) } // CreateTable creates tables with table name and columns -func (dl *Deltalake) CreateTable(tableName string, columns model.TableSchema) (err error) { +func (dl *Deltalake) CreateTable(ctx context.Context, tableName string, columns model.TableSchema) (err error) { name := fmt.Sprintf(`%s.%s`, dl.Namespace, tableName) tableLocationSql := dl.getTableLocationSql(tableName) @@ -785,14 +783,14 @@ func (dl *Deltalake) CreateTable(tableName string, columns model.TableSchema) (e sqlStatement := fmt.Sprintf(`%s %s ( %v ) USING DELTA %s %s;`, createTableClauseSql, name, ColumnsWithDataTypes(columns, ""), tableLocationSql, partitionedSql) dl.Logger.Infof("%s Creating table in delta lake with SQL: %v", dl.GetLogIdentifier(tableName), sqlStatement) - err = dl.ExecuteSQLClient(dl.Client, sqlStatement) + err = dl.ExecuteSQLClient(ctx, dl.Client, sqlStatement) return } -func (dl *Deltalake) DropTable(tableName string) (err error) { +func (dl *Deltalake) DropTable(ctx context.Context, tableName string) (err error) { dl.Logger.Infof("%s Dropping table %s", dl.GetLogIdentifier(), tableName) sqlStatement := fmt.Sprintf(`DROP TABLE %[1]s.%[2]s;`, dl.Namespace, tableName) - dropTableResponse, err := dl.Client.Client.Execute(dl.Client.Context, &proto.ExecuteRequest{ + dropTableResponse, err := dl.Client.Client.Execute(ctx, &proto.ExecuteRequest{ Config: dl.Client.CredConfig, Identifier: dl.Client.CredIdentifier, SqlStatement: sqlStatement, @@ -807,7 +805,7 @@ func (dl *Deltalake) DropTable(tableName string) (err error) { return } -func (dl *Deltalake) AddColumns(tableName string, columnsInfo []warehouseutils.ColumnInfo) error { +func (dl *Deltalake) AddColumns(ctx context.Context, tableName string, columnsInfo []warehouseutils.ColumnInfo) error { var ( query string queryBuilder strings.Builder @@ -830,7 +828,7 @@ func (dl *Deltalake) AddColumns(tableName string, columnsInfo []warehouseutils.C query += ");" dl.Logger.Infof("DL: Adding columns for destinationID: %s, tableName: %s with query: %v", dl.Warehouse.Destination.ID, tableName, query) - executeResponse, err := dl.Client.Client.Execute(dl.Client.Context, &proto.ExecuteRequest{ + executeResponse, err := dl.Client.Client.Execute(ctx, &proto.ExecuteRequest{ Config: dl.Client.CredConfig, Identifier: dl.Client.CredIdentifier, SqlStatement: query, @@ -854,10 +852,10 @@ func (dl *Deltalake) AddColumns(tableName string, columnsInfo []warehouseutils.C } // 
CreateSchema checks if schema exists or not. If it does not exist, it creates the schema. -func (dl *Deltalake) CreateSchema() (err error) { +func (dl *Deltalake) CreateSchema(ctx context.Context) (err error) { // Checking if schema exists or not var schemaExists bool - schemaExists, err = dl.schemaExists(dl.Namespace) + schemaExists, err = dl.schemaExists(ctx, dl.Namespace) if err != nil { dl.Logger.Errorf("%s Error checking if schema exists: %s, error: %v", dl.GetLogIdentifier(), dl.Namespace, err) return err @@ -868,22 +866,22 @@ func (dl *Deltalake) CreateSchema() (err error) { } // Creating schema - return dl.createSchema() + return dl.createSchema(ctx) } // AlterColumn alter table with column name and type -func (*Deltalake) AlterColumn(_, _, _ string) (model.AlterTableResponse, error) { +func (*Deltalake) AlterColumn(context.Context, string, string, string) (model.AlterTableResponse, error) { return model.AlterTableResponse{}, nil } // FetchSchema queries delta lake and returns the schema associated with provided namespace -func (dl *Deltalake) FetchSchema() (model.Schema, model.Schema, error) { +func (dl *Deltalake) FetchSchema(ctx context.Context) (model.Schema, model.Schema, error) { // Schema Initialization schema := make(model.Schema) unrecognizedSchema := make(model.Schema) // Fetching the tables - tableNames, err := dl.fetchTables(dl.Client, dl.Namespace) + tableNames, err := dl.fetchTables(ctx, dl.Client, dl.Namespace) if err != nil { return nil, nil, fmt.Errorf("fetching tables: %w", err) } @@ -900,7 +898,7 @@ func (dl *Deltalake) FetchSchema() (model.Schema, model.Schema, error) { // For each table we are generating schema for _, tableName := range filteredTablesNames { - fetchTableAttributesResponse, err := dl.Client.Client.FetchTableAttributes(dl.Client.Context, &proto.FetchTableAttributesRequest{ + fetchTableAttributesResponse, err := dl.Client.Client.FetchTableAttributes(ctx, &proto.FetchTableAttributesRequest{ Config: dl.Client.CredConfig, Identifier: dl.Client.CredIdentifier, Schema: dl.Namespace, @@ -938,13 +936,13 @@ func (dl *Deltalake) FetchSchema() (model.Schema, model.Schema, error) { } // Setup populate the Deltalake -func (dl *Deltalake) Setup(warehouse model.Warehouse, uploader warehouseutils.Uploader) (err error) { +func (dl *Deltalake) Setup(ctx context.Context, warehouse model.Warehouse, uploader warehouseutils.Uploader) (err error) { dl.Warehouse = warehouse dl.Namespace = warehouse.Namespace dl.Uploader = uploader dl.ObjectStorage = warehouseutils.ObjectStorageType(warehouseutils.DELTALAKE, warehouse.Destination.Config, dl.Uploader.UseRudderStorage()) - dl.Client, err = dl.connectToWarehouse() + dl.Client, err = dl.connectToWarehouse(ctx) return err } @@ -954,46 +952,46 @@ func (*Deltalake) TestConnection(context.Context, model.Warehouse) error { } // Cleanup cleanup when upload is done. 
-func (dl *Deltalake) Cleanup() { +func (dl *Deltalake) Cleanup(ctx context.Context) { if dl.Client != nil { - dl.dropDanglingStagingTables() - dl.Client.Close() + dl.dropDanglingStagingTables(ctx) + dl.Client.Close(ctx) } } // CrashRecover crash recover scenarios -func (dl *Deltalake) CrashRecover() { - dl.dropDanglingStagingTables() +func (dl *Deltalake) CrashRecover(ctx context.Context) { + dl.dropDanglingStagingTables(ctx) } // IsEmpty checks if the warehouse is empty or not -func (*Deltalake) IsEmpty(model.Warehouse) (empty bool, err error) { +func (*Deltalake) IsEmpty(context.Context, model.Warehouse) (empty bool, err error) { return } // LoadUserTables loads user tables -func (dl *Deltalake) LoadUserTables(context.Context) map[string]error { - return dl.loadUserTables() +func (dl *Deltalake) LoadUserTables(ctx context.Context) map[string]error { + return dl.loadUserTables(ctx) } // LoadTable loads table for table name -func (dl *Deltalake) LoadTable(_ context.Context, tableName string) error { - _, err := dl.loadTable(tableName, dl.Uploader.GetTableSchemaInUpload(tableName), dl.Uploader.GetTableSchemaInWarehouse(tableName), false) +func (dl *Deltalake) LoadTable(ctx context.Context, tableName string) error { + _, err := dl.loadTable(ctx, tableName, dl.Uploader.GetTableSchemaInUpload(tableName), dl.Uploader.GetTableSchemaInWarehouse(tableName), false) return err } // LoadIdentityMergeRulesTable loads identifies merge rules tables -func (*Deltalake) LoadIdentityMergeRulesTable() (err error) { +func (*Deltalake) LoadIdentityMergeRulesTable(context.Context) (err error) { return } // LoadIdentityMappingsTable loads identifies mappings table -func (*Deltalake) LoadIdentityMappingsTable() (err error) { +func (*Deltalake) LoadIdentityMappingsTable(context.Context) (err error) { return } // DownloadIdentityRules download identity rules -func (*Deltalake) DownloadIdentityRules(*misc.GZipWriter) (err error) { +func (*Deltalake) DownloadIdentityRules(context.Context, *misc.GZipWriter) (err error) { return } @@ -1024,7 +1022,7 @@ func (dl *Deltalake) GetTotalCountInTable(ctx context.Context, tableName string) } // Connect returns Client -func (dl *Deltalake) Connect(warehouse model.Warehouse) (warehouseclient.Client, error) { +func (dl *Deltalake) Connect(ctx context.Context, warehouse model.Warehouse) (warehouseclient.Client, error) { dl.Warehouse = warehouse dl.Namespace = warehouse.Namespace dl.ObjectStorage = warehouseutils.ObjectStorageType( @@ -1032,7 +1030,7 @@ func (dl *Deltalake) Connect(warehouse model.Warehouse) (warehouseclient.Client, warehouse.Destination.Config, misc.IsConfiguredToUseRudderObjectStorage(dl.Warehouse.Destination.Config), ) - Client, err := dl.connectToWarehouse() + Client, err := dl.connectToWarehouse(ctx) if err != nil { return warehouseclient.Client{}, err } @@ -1049,8 +1047,7 @@ func (dl *Deltalake) GetLogIdentifier(args ...string) string { } // GetDatabricksVersion Gets the databricks version by making a grpc call to Version stub. 
-func GetDatabricksVersion() (databricksBuildVersion string) { - ctx := context.Background() +func GetDatabricksVersion(ctx context.Context) (databricksBuildVersion string) { connectorURL := config.GetString("DATABRICKS_CONNECTOR_URL", "localhost:50051") conn, err := grpc.DialContext(ctx, connectorURL, grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -1069,7 +1066,7 @@ func GetDatabricksVersion() (databricksBuildVersion string) { return } -func (dl *Deltalake) LoadTestTable(location, tableName string, _ map[string]interface{}, format string) (err error) { +func (dl *Deltalake) LoadTestTable(ctx context.Context, location, tableName string, _ map[string]interface{}, format string) (err error) { // Get the credentials string to copy from the staging location to table auth, err := dl.credentialsStr() if err != nil { @@ -1104,7 +1101,7 @@ func (dl *Deltalake) LoadTestTable(location, tableName string, _ map[string]inte ) } - err = dl.ExecuteSQLClient(dl.Client, sqlStatement) + err = dl.ExecuteSQLClient(ctx, dl.Client, sqlStatement) return } diff --git a/warehouse/integrations/manager/manager.go b/warehouse/integrations/manager/manager.go index 223a74c34d..60faa7ef63 100644 --- a/warehouse/integrations/manager/manager.go +++ b/warehouse/integrations/manager/manager.go @@ -28,31 +28,31 @@ import ( ) type Manager interface { - Setup(warehouse model.Warehouse, uploader warehouseutils.Uploader) error - CrashRecover() - FetchSchema() (model.Schema, model.Schema, error) - CreateSchema() (err error) - CreateTable(tableName string, columnMap model.TableSchema) (err error) - AddColumns(tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) - AlterColumn(tableName, columnName, columnType string) (model.AlterTableResponse, error) + Setup(ctx context.Context, warehouse model.Warehouse, uploader warehouseutils.Uploader) error + CrashRecover(ctx context.Context) + FetchSchema(ctx context.Context) (model.Schema, model.Schema, error) + CreateSchema(ctx context.Context) (err error) + CreateTable(ctx context.Context, tableName string, columnMap model.TableSchema) (err error) + AddColumns(ctx context.Context, tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) + AlterColumn(ctx context.Context, tableName, columnName, columnType string) (model.AlterTableResponse, error) LoadTable(ctx context.Context, tableName string) error LoadUserTables(ctx context.Context) map[string]error - LoadIdentityMergeRulesTable() error - LoadIdentityMappingsTable() error - Cleanup() - IsEmpty(warehouse model.Warehouse) (bool, error) + LoadIdentityMergeRulesTable(ctx context.Context) error + LoadIdentityMappingsTable(ctx context.Context) error + Cleanup(ctx context.Context) + IsEmpty(ctx context.Context, warehouse model.Warehouse) (bool, error) TestConnection(ctx context.Context, warehouse model.Warehouse) error - DownloadIdentityRules(*misc.GZipWriter) error + DownloadIdentityRules(ctx context.Context, gzWriter *misc.GZipWriter) error GetTotalCountInTable(ctx context.Context, tableName string) (int64, error) - Connect(warehouse model.Warehouse) (client.Client, error) - LoadTestTable(location, stagingTableName string, payloadMap map[string]interface{}, loadFileFormat string) error + Connect(ctx context.Context, warehouse model.Warehouse) (client.Client, error) + LoadTestTable(ctx context.Context, location, stagingTableName string, payloadMap map[string]interface{}, loadFileFormat string) error SetConnectionTimeout(timeout time.Duration) ErrorMappings() []model.JobError } type WarehouseDelete 
interface { - DropTable(tableName string) (err error) - DeleteBy(tableName []string, params warehouseutils.DeleteByParams) error + DropTable(ctx context.Context, tableName string) (err error) + DeleteBy(ctx context.Context, tableName []string, params warehouseutils.DeleteByParams) error } type WarehouseOperations interface { diff --git a/warehouse/integrations/mssql/mssql.go b/warehouse/integrations/mssql/mssql.go index eb9767ed3d..e4d17afc08 100644 --- a/warehouse/integrations/mssql/mssql.go +++ b/warehouse/integrations/mssql/mssql.go @@ -193,11 +193,11 @@ func ColumnsWithDataTypes(columns model.TableSchema, prefix string) string { return strings.Join(arr, ",") } -func (*MSSQL) IsEmpty(_ model.Warehouse) (empty bool, err error) { +func (*MSSQL) IsEmpty(context.Context, model.Warehouse) (empty bool, err error) { return } -func (ms *MSSQL) DeleteBy(tableNames []string, params warehouseutils.DeleteByParams) (err error) { +func (ms *MSSQL) DeleteBy(ctx context.Context, tableNames []string, params warehouseutils.DeleteByParams) (err error) { for _, tb := range tableNames { ms.Logger.Infof("MSSQL: Cleaning up the table %q ", tb) sqlStatement := fmt.Sprintf(`DELETE FROM "%[1]s"."%[2]s" WHERE @@ -213,7 +213,7 @@ func (ms *MSSQL) DeleteBy(tableNames []string, params warehouseutils.DeleteByPar ms.Logger.Infof("MSSQL: Executing the statement %v", sqlStatement) if ms.EnableDeleteByJobs { - _, err = ms.DB.Exec(sqlStatement, + _, err = ms.DB.ExecContext(ctx, sqlStatement, sql.Named("jobrunid", params.JobRunId), sql.Named("taskrunid", params.TaskRunId), sql.Named("sourceid", params.SourceId), @@ -258,17 +258,17 @@ func (ms *MSSQL) loadTable(ctx context.Context, tableName string, tableSchemaInU _, err = txn.ExecContext(ctx, sqlStatement) if err != nil { ms.Logger.Errorf("MSSQL: Error creating temporary table for table:%s: %v\n", tableName, err) - txn.Rollback() + _ = txn.Rollback() return } if !skipTempTableDelete { - defer ms.dropStagingTable(stagingTableName) + defer ms.dropStagingTable(ctx, stagingTableName) } stmt, err := txn.PrepareContext(ctx, mssql.CopyIn(ms.Namespace+"."+stagingTableName, mssql.BulkOptions{CheckConstraints: false}, sortedColumnKeys...)) if err != nil { ms.Logger.Errorf("MSSQL: Error while preparing statement for transaction in db for loading in staging table:%s: %v\nstmt: %v", stagingTableName, err, stmt) - txn.Rollback() + _ = txn.Rollback() return } for _, objectFileName := range fileNames { @@ -276,7 +276,7 @@ func (ms *MSSQL) loadTable(ctx context.Context, tableName string, tableSchemaInU gzipFile, err = os.Open(objectFileName) if err != nil { ms.Logger.Errorf("MSSQL: Error opening file using os.Open for file:%s while loading to table %s", objectFileName, tableName) - txn.Rollback() + _ = txn.Rollback() return } @@ -285,7 +285,7 @@ func (ms *MSSQL) loadTable(ctx context.Context, tableName string, tableSchemaInU if err != nil { ms.Logger.Errorf("MSSQL: Error reading file using gzip.NewReader for file:%s while loading to table %s", gzipFile, tableName) gzipFile.Close() - txn.Rollback() + _ = txn.Rollback() return } @@ -300,13 +300,13 @@ func (ms *MSSQL) loadTable(ctx context.Context, tableName string, tableSchemaInU break } ms.Logger.Errorf("MSSQL: Error while reading csv file %s for loading in staging table:%s: %v", objectFileName, stagingTableName, err) - txn.Rollback() + _ = txn.Rollback() return } if len(sortedColumnKeys) != len(record) { err = fmt.Errorf(`load file CSV columns for a row mismatch number found in upload schema. 
Columns in CSV row: %d, Columns in upload schema of table-%s: %d. Processed rows in csv file until mismatch: %d`, len(record), tableName, len(sortedColumnKeys), csvRowsProcessedCount) ms.Logger.Error(err) - txn.Rollback() + _ = txn.Rollback() return } var recordInterface []interface{} @@ -389,18 +389,18 @@ func (ms *MSSQL) loadTable(ctx context.Context, tableName string, tableSchemaInU _, err = stmt.ExecContext(ctx, finalColumnValues...) if err != nil { ms.Logger.Errorf("MSSQL: Error in exec statement for loading in staging table:%s: %v", stagingTableName, err) - txn.Rollback() + _ = txn.Rollback() return } csvRowsProcessedCount++ } - gzipReader.Close() + _ = gzipReader.Close() gzipFile.Close() } _, err = stmt.ExecContext(ctx) if err != nil { - txn.Rollback() + _ = txn.Rollback() ms.Logger.Errorf("MSSQL: Rollback transaction as there was error while loading staging table:%s: %v", stagingTableName, err) return @@ -423,7 +423,7 @@ func (ms *MSSQL) loadTable(ctx context.Context, tableName string, tableSchemaInU _, err = txn.ExecContext(ctx, sqlStatement) if err != nil { ms.Logger.Errorf("MSSQL: Error deleting from original table for dedup: %v\n", err) - txn.Rollback() + _ = txn.Rollback() return } @@ -438,13 +438,13 @@ func (ms *MSSQL) loadTable(ctx context.Context, tableName string, tableSchemaInU if err != nil { ms.Logger.Errorf("MSSQL: Error inserting into original table: %v\n", err) - txn.Rollback() + _ = txn.Rollback() return } if err = txn.Commit(); err != nil { ms.Logger.Errorf("MSSQL: Error while committing transaction as there was error while loading staging table:%s: %v", stagingTableName, err) - txn.Rollback() + _ = txn.Rollback() return } @@ -488,9 +488,9 @@ func (ms *MSSQL) loadUserTables(ctx context.Context) (errorMap map[string]error) unionStagingTableName := warehouseutils.StagingTableName(provider, "users_identifies_union", tableNameLimit) stagingTableName := warehouseutils.StagingTableName(provider, warehouseutils.UsersTable, tableNameLimit) - defer ms.dropStagingTable(stagingTableName) - defer ms.dropStagingTable(unionStagingTableName) - defer ms.dropStagingTable(identifyStagingTable) + defer ms.dropStagingTable(ctx, stagingTableName) + defer ms.dropStagingTable(ctx, unionStagingTableName) + defer ms.dropStagingTable(ctx, identifyStagingTable) userColMap := ms.Uploader.GetTableSchemaInWarehouse(warehouseutils.UsersTable) var userColNames, firstValProps []string @@ -565,7 +565,7 @@ func (ms *MSSQL) loadUserTables(ctx context.Context) (errorMap map[string]error) _, err = tx.ExecContext(ctx, sqlStatement) if err != nil { ms.Logger.Errorf("MSSQL: Error deleting from original table for dedup: %v\n", err) - tx.Rollback() + _ = tx.Rollback() errorMap[warehouseutils.UsersTable] = err return } @@ -576,7 +576,7 @@ func (ms *MSSQL) loadUserTables(ctx context.Context) (errorMap map[string]error) if err != nil { ms.Logger.Errorf("MSSQL: Error inserting into users table from staging table: %v\n", err) - tx.Rollback() + _ = tx.Rollback() errorMap[warehouseutils.UsersTable] = err return } @@ -584,56 +584,56 @@ func (ms *MSSQL) loadUserTables(ctx context.Context) (errorMap map[string]error) err = tx.Commit() if err != nil { ms.Logger.Errorf("MSSQL: Error in transaction commit for users table: %v\n", err) - tx.Rollback() + _ = tx.Rollback() errorMap[warehouseutils.UsersTable] = err return } return } -func (ms *MSSQL) CreateSchema() (err error) { +func (ms *MSSQL) CreateSchema(ctx context.Context) (err error) { sqlStatement := fmt.Sprintf(`IF NOT EXISTS ( SELECT * FROM sys.schemas WHERE name 
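Note: the repeated txn.Rollback() to _ = txn.Rollback() changes above make the deliberate discard of the rollback error explicit. A minimal sketch of that pattern with database/sql, assuming an illustrative table name and parameter:

package ctxexample

import (
	"context"
	"database/sql"
	"fmt"
)

// deleteInTx runs a single statement in a transaction. On failure the
// rollback error is intentionally discarded (the caller cares about the
// original error), matching the `_ = txn.Rollback()` style above.
func deleteInTx(ctx context.Context, db *sql.DB, jobRunID string) error {
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin transaction: %w", err)
	}

	// Illustrative schema/table; not taken from the patch.
	if _, err := tx.ExecContext(ctx,
		`DELETE FROM "warehouse"."tracks" WHERE context_sources_job_run_id = $1`,
		jobRunID,
	); err != nil {
		_ = tx.Rollback()
		return fmt.Errorf("delete: %w", err)
	}

	if err := tx.Commit(); err != nil {
		_ = tx.Rollback()
		return fmt.Errorf("commit: %w", err)
	}
	return nil
}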
= N'%s' ) EXEC('CREATE SCHEMA [%s]'); `, ms.Namespace, ms.Namespace) ms.Logger.Infof("MSSQL: Creating schema name in mssql for MSSQL:%s : %v", ms.Warehouse.Destination.ID, sqlStatement) - _, err = ms.DB.Exec(sqlStatement) + _, err = ms.DB.ExecContext(ctx, sqlStatement) if err == io.EOF { return nil } return } -func (ms *MSSQL) dropStagingTable(stagingTableName string) { +func (ms *MSSQL) dropStagingTable(ctx context.Context, stagingTableName string) { ms.Logger.Infof("MSSQL: dropping table %+v\n", stagingTableName) - _, err := ms.DB.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS %s`, ms.Namespace+"."+stagingTableName)) + _, err := ms.DB.ExecContext(ctx, fmt.Sprintf(`DROP TABLE IF EXISTS %s`, ms.Namespace+"."+stagingTableName)) if err != nil { ms.Logger.Errorf("MSSQL: Error dropping staging table %s in mssql: %v", ms.Namespace+"."+stagingTableName, err) } } -func (ms *MSSQL) createTable(name string, columns model.TableSchema) (err error) { +func (ms *MSSQL) createTable(ctx context.Context, name string, columns model.TableSchema) (err error) { sqlStatement := fmt.Sprintf(`IF NOT EXISTS (SELECT 1 FROM sys.objects WHERE object_id = OBJECT_ID(N'%[1]s') AND type = N'U') CREATE TABLE %[1]s ( %v )`, name, ColumnsWithDataTypes(columns, "")) ms.Logger.Infof("MSSQL: Creating table in mssql for MSSQL:%s : %v", ms.Warehouse.Destination.ID, sqlStatement) - _, err = ms.DB.Exec(sqlStatement) + _, err = ms.DB.ExecContext(ctx, sqlStatement) return } -func (ms *MSSQL) CreateTable(tableName string, columnMap model.TableSchema) (err error) { +func (ms *MSSQL) CreateTable(ctx context.Context, tableName string, columnMap model.TableSchema) (err error) { // Search paths doesn't exist unlike Postgres, default is dbo. Hence, use namespace wherever possible - err = ms.createTable(ms.Namespace+"."+tableName, columnMap) + err = ms.createTable(ctx, ms.Namespace+"."+tableName, columnMap) return err } -func (ms *MSSQL) DropTable(tableName string) (err error) { +func (ms *MSSQL) DropTable(ctx context.Context, tableName string) (err error) { sqlStatement := `DROP TABLE "%[1]s"."%[2]s"` ms.Logger.Infof("AZ: Dropping table in synapse for AZ:%s : %v", ms.Warehouse.Destination.ID, sqlStatement) - _, err = ms.DB.Exec(fmt.Sprintf(sqlStatement, ms.Namespace, tableName)) + _, err = ms.DB.ExecContext(ctx, fmt.Sprintf(sqlStatement, ms.Namespace, tableName)) return } -func (ms *MSSQL) AddColumns(tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) { +func (ms *MSSQL) AddColumns(ctx context.Context, tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) { var ( query string queryBuilder strings.Builder @@ -673,11 +673,11 @@ func (ms *MSSQL) AddColumns(tableName string, columnsInfo []warehouseutils.Colum query += ";" ms.Logger.Infof("MSSQL: Adding columns for destinationID: %s, tableName: %s with query: %v", ms.Warehouse.Destination.ID, tableName, query) - _, err = ms.DB.Exec(query) + _, err = ms.DB.ExecContext(ctx, query) return } -func (*MSSQL) AlterColumn(_, _, _ string) (model.AlterTableResponse, error) { +func (*MSSQL) AlterColumn(context.Context, string, string, string) (model.AlterTableResponse, error) { return model.AlterTableResponse{}, nil } @@ -693,7 +693,7 @@ func (ms *MSSQL) TestConnection(ctx context.Context, _ model.Warehouse) error { return nil } -func (ms *MSSQL) Setup(warehouse model.Warehouse, uploader warehouseutils.Uploader) (err error) { +func (ms *MSSQL) Setup(_ context.Context, warehouse model.Warehouse, uploader warehouseutils.Uploader) (err error) { ms.Warehouse = warehouse 
ms.Namespace = warehouse.Namespace ms.Uploader = uploader @@ -704,11 +704,11 @@ func (ms *MSSQL) Setup(warehouse model.Warehouse, uploader warehouseutils.Upload return err } -func (ms *MSSQL) CrashRecover() { - ms.dropDanglingStagingTables() +func (ms *MSSQL) CrashRecover(ctx context.Context) { + ms.dropDanglingStagingTables(ctx) } -func (ms *MSSQL) dropDanglingStagingTables() bool { +func (ms *MSSQL) dropDanglingStagingTables(ctx context.Context) bool { sqlStatement := fmt.Sprintf(` select table_name @@ -721,12 +721,12 @@ func (ms *MSSQL) dropDanglingStagingTables() bool { ms.Namespace, fmt.Sprintf(`%s%%`, warehouseutils.StagingTablePrefix(provider)), ) - rows, err := ms.DB.Query(sqlStatement) + rows, err := ms.DB.QueryContext(ctx, sqlStatement) if err != nil { ms.Logger.Errorf("WH: MSSQL: Error dropping dangling staging tables in MSSQL: %v\nQuery: %s\n", err, sqlStatement) return false } - defer rows.Close() + defer func() { _ = rows.Close() }() var stagingTableNames []string for rows.Next() { @@ -740,7 +740,7 @@ func (ms *MSSQL) dropDanglingStagingTables() bool { ms.Logger.Infof("WH: MSSQL: Dropping dangling staging tables: %+v %+v\n", len(stagingTableNames), stagingTableNames) delSuccess := true for _, stagingTableName := range stagingTableNames { - _, err := ms.DB.Exec(fmt.Sprintf(`DROP TABLE "%[1]s"."%[2]s"`, ms.Namespace, stagingTableName)) + _, err := ms.DB.ExecContext(ctx, fmt.Sprintf(`DROP TABLE "%[1]s"."%[2]s"`, ms.Namespace, stagingTableName)) if err != nil { ms.Logger.Errorf("WH: MSSQL: Error dropping dangling staging table: %s in redshift: %v\n", stagingTableName, err) delSuccess = false @@ -750,7 +750,7 @@ func (ms *MSSQL) dropDanglingStagingTables() bool { } // FetchSchema queries mssql and returns the schema associated with provided namespace -func (ms *MSSQL) FetchSchema() (model.Schema, model.Schema, error) { +func (ms *MSSQL) FetchSchema(ctx context.Context) (model.Schema, model.Schema, error) { schema := make(model.Schema) unrecognizedSchema := make(model.Schema) @@ -765,7 +765,7 @@ func (ms *MSSQL) FetchSchema() (model.Schema, model.Schema, error) { table_schema = @schema and table_name not like @prefix ` - rows, err := ms.DB.Query(sqlStatement, + rows, err := ms.DB.QueryContext(ctx, sqlStatement, sql.Named("schema", ms.Namespace), sql.Named("prefix", fmt.Sprintf("%s%%", warehouseutils.StagingTablePrefix(provider))), ) @@ -814,23 +814,23 @@ func (ms *MSSQL) LoadTable(ctx context.Context, tableName string) error { return err } -func (ms *MSSQL) Cleanup() { +func (ms *MSSQL) Cleanup(ctx context.Context) { if ms.DB != nil { // extra check aside dropStagingTable(table) - ms.dropDanglingStagingTables() - ms.DB.Close() + ms.dropDanglingStagingTables(ctx) + _ = ms.DB.Close() } } -func (*MSSQL) LoadIdentityMergeRulesTable() (err error) { +func (*MSSQL) LoadIdentityMergeRulesTable(context.Context) (err error) { return } -func (*MSSQL) LoadIdentityMappingsTable() (err error) { +func (*MSSQL) LoadIdentityMappingsTable(context.Context) (err error) { return } -func (*MSSQL) DownloadIdentityRules(*misc.GZipWriter) (err error) { +func (*MSSQL) DownloadIdentityRules(context.Context, *misc.GZipWriter) (err error) { return } @@ -850,7 +850,7 @@ func (ms *MSSQL) GetTotalCountInTable(ctx context.Context, tableName string) (in return total, err } -func (ms *MSSQL) Connect(warehouse model.Warehouse) (client.Client, error) { +func (ms *MSSQL) Connect(_ context.Context, warehouse model.Warehouse) (client.Client, error) { ms.Warehouse = warehouse ms.Namespace = warehouse.Namespace 
ms.ObjectStorage = warehouseutils.ObjectStorageType( @@ -866,14 +866,14 @@ func (ms *MSSQL) Connect(warehouse model.Warehouse) (client.Client, error) { return client.Client{Type: client.SQLClient, SQL: dbHandle}, err } -func (ms *MSSQL) LoadTestTable(_, tableName string, payloadMap map[string]interface{}, _ string) (err error) { +func (ms *MSSQL) LoadTestTable(ctx context.Context, _, tableName string, payloadMap map[string]interface{}, _ string) (err error) { sqlStatement := fmt.Sprintf(`INSERT INTO %q.%q (%v) VALUES (%s)`, ms.Namespace, tableName, fmt.Sprintf(`%q, %q`, "id", "val"), fmt.Sprintf(`'%d', '%s'`, payloadMap["id"], payloadMap["val"]), ) - _, err = ms.DB.Exec(sqlStatement) + _, err = ms.DB.ExecContext(ctx, sqlStatement) return } diff --git a/warehouse/integrations/postgres-legacy/postgres.go b/warehouse/integrations/postgres-legacy/postgres.go index 7218a019d9..25d20cebb0 100644 --- a/warehouse/integrations/postgres-legacy/postgres.go +++ b/warehouse/integrations/postgres-legacy/postgres.go @@ -277,12 +277,12 @@ func ColumnsWithDataTypes(columns map[string]string, prefix string) string { return strings.Join(arr, ",") } -func (*Postgres) IsEmpty(_ model.Warehouse) (empty bool, err error) { +func (*Postgres) IsEmpty(context.Context, model.Warehouse) (empty bool, err error) { return } func (pg *Postgres) DownloadLoadFiles(ctx context.Context, tableName string) ([]string, error) { - objects := pg.Uploader.GetLoadFilesMetadata(warehouseutils.GetLoadFilesOptions{Table: tableName}) + objects := pg.Uploader.GetLoadFilesMetadata(ctx, warehouseutils.GetLoadFilesOptions{Table: tableName}) storageProvider := warehouseutils.ObjectStorageType(pg.Warehouse.Destination.DestinationDefinition.Name, pg.Warehouse.Destination.Config, pg.Uploader.UseRudderStorage()) downloader, err := filemanager.DefaultFileManagerFactory.New(&filemanager.SettingsT{ Provider: storageProvider, @@ -400,7 +400,7 @@ func (pg *Postgres) loadTable(ctx context.Context, tableName string, tableSchema return } if !skipTempTableDelete { - defer pg.dropStagingTable(stagingTableName) + defer pg.dropStagingTable(ctx, stagingTableName) } stmt, err := txn.PrepareContext(ctx, pq.CopyInSchema(pg.Namespace, stagingTableName, sortedColumnKeys...)) @@ -468,7 +468,7 @@ func (pg *Postgres) loadTable(ctx context.Context, tableName string, tableSchema } csvRowsProcessedCount++ } - gzipReader.Close() + _ = gzipReader.Close() gzipFile.Close() } @@ -495,7 +495,7 @@ func (pg *Postgres) loadTable(ctx context.Context, tableName string, tableSchema } sqlStatement = fmt.Sprintf(`DELETE FROM "%[1]s"."%[2]s" USING "%[1]s"."%[3]s" as _source where (_source.%[4]s = "%[1]s"."%[2]s"."%[4]s" %[5]s)`, pg.Namespace, tableName, stagingTableName, primaryKey, additionalJoinClause) pg.logger.Infof("PG: Deduplicate records for table:%s using staging table: %s\n", tableName, sqlStatement) - err = pg.handleExec(&QueryParams{ + err = pg.handleExecContext(ctx, &QueryParams{ txn: txn, query: sqlStatement, enableWithQueryPlan: pg.EnableSQLStatementExecutionPlan || slices.Contains(pg.EnableSQLStatementExecutionPlanWorkspaceIDs, pg.Warehouse.WorkspaceID), @@ -514,7 +514,7 @@ func (pg *Postgres) loadTable(ctx context.Context, tableName string, tableSchema ) AS _ where _rudder_staging_row_number = 1 `, pg.Namespace, tableName, quotedColumnNames, stagingTableName, partitionKey) pg.logger.Infof("PG: Inserting records for table:%s using staging table: %s\n", tableName, sqlStatement) - err = pg.handleExec(&QueryParams{ + err = pg.handleExecContext(ctx, &QueryParams{ txn: txn, 
query: sqlStatement, enableWithQueryPlan: pg.EnableSQLStatementExecutionPlan || slices.Contains(pg.EnableSQLStatementExecutionPlanWorkspaceIDs, pg.Warehouse.WorkspaceID), @@ -539,7 +539,7 @@ func (pg *Postgres) loadTable(ctx context.Context, tableName string, tableSchema } // DeleteBy Need to create a structure with delete parameters instead of simply adding a long list of params -func (pg *Postgres) DeleteBy(tableNames []string, params warehouseutils.DeleteByParams) (err error) { +func (pg *Postgres) DeleteBy(ctx context.Context, tableNames []string, params warehouseutils.DeleteByParams) (err error) { pg.logger.Infof("PG: Cleaning up the following tables in postgres for PG:%s : %+v", tableNames, params) for _, tb := range tableNames { sqlStatement := fmt.Sprintf(`DELETE FROM "%[1]s"."%[2]s" WHERE @@ -553,7 +553,7 @@ func (pg *Postgres) DeleteBy(tableNames []string, params warehouseutils.DeleteBy pg.logger.Infof("PG: Deleting rows in table in postgres for PG:%s", pg.Warehouse.Destination.ID) pg.logger.Debugf("PG: Executing the statement %v", sqlStatement) if pg.EnableDeleteByJobs { - _, err = pg.DB.Exec(sqlStatement, + _, err = pg.DB.ExecContext(ctx, sqlStatement, params.JobRunId, params.TaskRunId, params.SourceId, @@ -579,7 +579,7 @@ func (pg *Postgres) loadUserTables(ctx context.Context) (errorMap map[string]err pg.logger.Infof("PG: Updated search_path to %s in postgres for PG:%s : %v", pg.Namespace, pg.Warehouse.Destination.ID, sqlStatement) pg.logger.Infof("PG: Starting load for identifies and users tables\n") identifyStagingTable, err := pg.loadTable(ctx, warehouseutils.IdentifiesTable, pg.Uploader.GetTableSchemaInUpload(warehouseutils.IdentifiesTable), true) - defer pg.dropStagingTable(identifyStagingTable) + defer pg.dropStagingTable(ctx, identifyStagingTable) if err != nil { errorMap[warehouseutils.IdentifiesTable] = err return @@ -600,8 +600,8 @@ func (pg *Postgres) loadUserTables(ctx context.Context) (errorMap map[string]err unionStagingTableName := warehouseutils.StagingTableName(provider, "users_identifies_union", tableNameLimit) stagingTableName := warehouseutils.StagingTableName(provider, warehouseutils.UsersTable, tableNameLimit) - defer pg.dropStagingTable(stagingTableName) - defer pg.dropStagingTable(unionStagingTableName) + defer pg.dropStagingTable(ctx, stagingTableName) + defer pg.dropStagingTable(ctx, unionStagingTableName) userColMap := pg.Uploader.GetTableSchemaInWarehouse(warehouseutils.UsersTable) var userColNames, firstValProps []string @@ -674,7 +674,7 @@ func (pg *Postgres) loadUserTables(ctx context.Context) (errorMap map[string]err "destId": pg.Warehouse.Destination.ID, "tableName": warehouseutils.UsersTable, } - err = pg.handleExec(&QueryParams{ + err = pg.handleExecContext(ctx, &QueryParams{ txn: tx, query: sqlStatement, enableWithQueryPlan: pg.EnableSQLStatementExecutionPlan || slices.Contains(pg.EnableSQLStatementExecutionPlanWorkspaceIDs, pg.Warehouse.WorkspaceID), @@ -689,7 +689,7 @@ func (pg *Postgres) loadUserTables(ctx context.Context) (errorMap map[string]err sqlStatement = fmt.Sprintf(`INSERT INTO "%[1]s"."%[2]s" (%[4]s) SELECT %[4]s FROM "%[1]s"."%[3]s"`, pg.Namespace, warehouseutils.UsersTable, stagingTableName, strings.Join(append([]string{"id"}, userColNames...), ",")) pg.logger.Infof("PG: Inserting records for table:%s using staging table: %s\n", warehouseutils.UsersTable, sqlStatement) - err = pg.handleExec(&QueryParams{ + err = pg.handleExecContext(ctx, &QueryParams{ txn: tx, query: sqlStatement, enableWithQueryPlan: 
pg.EnableSQLStatementExecutionPlan || slices.Contains(pg.EnableSQLStatementExecutionPlanWorkspaceIDs, pg.Warehouse.WorkspaceID), @@ -714,15 +714,15 @@ func (pg *Postgres) loadUserTables(ctx context.Context) (errorMap map[string]err return } -func (pg *Postgres) schemaExists(_ string) (exists bool, err error) { +func (pg *Postgres) schemaExists(ctx context.Context) (exists bool, err error) { sqlStatement := fmt.Sprintf(`SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_namespace WHERE nspname = '%s');`, pg.Namespace) - err = pg.DB.QueryRow(sqlStatement).Scan(&exists) + err = pg.DB.QueryRowContext(ctx, sqlStatement).Scan(&exists) return } -func (pg *Postgres) CreateSchema() (err error) { +func (pg *Postgres) CreateSchema(ctx context.Context) (err error) { var schemaExists bool - schemaExists, err = pg.schemaExists(pg.Namespace) + schemaExists, err = pg.schemaExists(ctx) if err != nil { pg.logger.Errorf("PG: Error checking if schema: %s exists: %v", pg.Namespace, err) return err @@ -733,45 +733,45 @@ func (pg *Postgres) CreateSchema() (err error) { } sqlStatement := fmt.Sprintf(`CREATE SCHEMA IF NOT EXISTS %q`, pg.Namespace) pg.logger.Infof("PG: Creating schema name in postgres for PG:%s : %v", pg.Warehouse.Destination.ID, sqlStatement) - _, err = pg.DB.Exec(sqlStatement) + _, err = pg.DB.ExecContext(ctx, sqlStatement) return } -func (pg *Postgres) dropStagingTable(stagingTableName string) { +func (pg *Postgres) dropStagingTable(ctx context.Context, stagingTableName string) { pg.logger.Infof("PG: dropping table %+v\n", stagingTableName) - _, err := pg.DB.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS "%[1]s"."%[2]s"`, pg.Namespace, stagingTableName)) + _, err := pg.DB.ExecContext(ctx, fmt.Sprintf(`DROP TABLE IF EXISTS "%[1]s"."%[2]s"`, pg.Namespace, stagingTableName)) if err != nil { pg.logger.Errorf("PG: Error dropping staging table %s in postgres: %v", stagingTableName, err) } } -func (pg *Postgres) createTable(name string, columns model.TableSchema) (err error) { +func (pg *Postgres) createTable(ctx context.Context, name string, columns model.TableSchema) (err error) { sqlStatement := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS "%[1]s"."%[2]s" ( %v )`, pg.Namespace, name, ColumnsWithDataTypes(columns, "")) pg.logger.Infof("PG: Creating table in postgres for PG:%s : %v", pg.Warehouse.Destination.ID, sqlStatement) - _, err = pg.DB.Exec(sqlStatement) + _, err = pg.DB.ExecContext(ctx, sqlStatement) return } -func (pg *Postgres) CreateTable(tableName string, columnMap model.TableSchema) (err error) { +func (pg *Postgres) CreateTable(ctx context.Context, tableName string, columnMap model.TableSchema) (err error) { // set the schema in search path. 
so that we can query table with unqualified name which is just the table name rather than using schema.table in queries sqlStatement := fmt.Sprintf(`SET search_path to %q`, pg.Namespace) - _, err = pg.DB.Exec(sqlStatement) + _, err = pg.DB.ExecContext(ctx, sqlStatement) if err != nil { return err } pg.logger.Infof("PG: Updated search_path to %s in postgres for PG:%s : %v", pg.Namespace, pg.Warehouse.Destination.ID, sqlStatement) - err = pg.createTable(tableName, columnMap) + err = pg.createTable(ctx, tableName, columnMap) return err } -func (pg *Postgres) DropTable(tableName string) (err error) { +func (pg *Postgres) DropTable(ctx context.Context, tableName string) (err error) { sqlStatement := `DROP TABLE "%[1]s"."%[2]s"` pg.logger.Infof("PG: Dropping table in postgres for PG:%s : %v", pg.Warehouse.Destination.ID, sqlStatement) - _, err = pg.DB.Exec(fmt.Sprintf(sqlStatement, pg.Namespace, tableName)) + _, err = pg.DB.ExecContext(ctx, fmt.Sprintf(sqlStatement, pg.Namespace, tableName)) return } -func (pg *Postgres) AddColumns(tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) { +func (pg *Postgres) AddColumns(ctx context.Context, tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) { var ( query string queryBuilder strings.Builder @@ -779,7 +779,7 @@ func (pg *Postgres) AddColumns(tableName string, columnsInfo []warehouseutils.Co // set the schema in search path. so that we can query table with unqualified name which is just the table name rather than using schema.table in queries query = fmt.Sprintf(`SET search_path to %q`, pg.Namespace) - if _, err = pg.DB.Exec(query); err != nil { + if _, err = pg.DB.ExecContext(ctx, query); err != nil { return } pg.logger.Infof("PG: Updated search_path to %s in postgres for PG:%s : %v", pg.Namespace, pg.Warehouse.Destination.ID, query) @@ -799,11 +799,11 @@ func (pg *Postgres) AddColumns(tableName string, columnsInfo []warehouseutils.Co query += ";" pg.logger.Infof("PG: Adding columns for destinationID: %s, tableName: %s with query: %v", pg.Warehouse.Destination.ID, tableName, query) - _, err = pg.DB.Exec(query) + _, err = pg.DB.ExecContext(ctx, query) return } -func (*Postgres) AlterColumn(_, _, _ string) (model.AlterTableResponse, error) { +func (*Postgres) AlterColumn(context.Context, string, string, string) (model.AlterTableResponse, error) { return model.AlterTableResponse{}, nil } @@ -825,10 +825,7 @@ func (pg *Postgres) TestConnection(ctx context.Context, warehouse model.Warehous return nil } -func (pg *Postgres) Setup( - warehouse model.Warehouse, - uploader warehouseutils.Uploader, -) (err error) { +func (pg *Postgres) Setup(_ context.Context, warehouse model.Warehouse, uploader warehouseutils.Uploader) (err error) { pg.Warehouse = warehouse pg.Namespace = warehouse.Namespace pg.Uploader = uploader @@ -838,11 +835,11 @@ func (pg *Postgres) Setup( return err } -func (pg *Postgres) CrashRecover() { - pg.dropDanglingStagingTables() +func (pg *Postgres) CrashRecover(ctx context.Context) { + pg.dropDanglingStagingTables(ctx) } -func (pg *Postgres) dropDanglingStagingTables() bool { +func (pg *Postgres) dropDanglingStagingTables(ctx context.Context) bool { sqlStatement := ` SELECT table_name @@ -852,7 +849,7 @@ func (pg *Postgres) dropDanglingStagingTables() bool { table_schema = $1 AND table_name like $2; ` - rows, err := pg.DB.Query( + rows, err := pg.DB.QueryContext(ctx, sqlStatement, pg.Namespace, fmt.Sprintf(`%s%%`, warehouseutils.StagingTablePrefix(provider)), @@ -861,7 +858,7 @@ func (pg *Postgres) 
dropDanglingStagingTables() bool { pg.logger.Errorf("WH: PG: Error dropping dangling staging tables in PG: %v\nQuery: %s\n", err, sqlStatement) return false } - defer rows.Close() + defer func() { _ = rows.Close() }() var stagingTableNames []string for rows.Next() { @@ -875,7 +872,7 @@ func (pg *Postgres) dropDanglingStagingTables() bool { pg.logger.Infof("WH: PG: Dropping dangling staging tables: %+v %+v\n", len(stagingTableNames), stagingTableNames) delSuccess := true for _, stagingTableName := range stagingTableNames { - _, err := pg.DB.Exec(fmt.Sprintf(`DROP TABLE "%[1]s"."%[2]s"`, pg.Namespace, stagingTableName)) + _, err := pg.DB.ExecContext(ctx, fmt.Sprintf(`DROP TABLE "%[1]s"."%[2]s"`, pg.Namespace, stagingTableName)) if err != nil { pg.logger.Errorf("WH: PG: Error dropping dangling staging table: %s in PG: %v\n", stagingTableName, err) delSuccess = false @@ -885,7 +882,7 @@ func (pg *Postgres) dropDanglingStagingTables() bool { } // FetchSchema queries postgres and returns the schema associated with provided namespace -func (pg *Postgres) FetchSchema() (model.Schema, model.Schema, error) { +func (pg *Postgres) FetchSchema(ctx context.Context) (model.Schema, model.Schema, error) { schema := make(model.Schema) unrecognizedSchema := make(model.Schema) @@ -900,7 +897,8 @@ func (pg *Postgres) FetchSchema() (model.Schema, model.Schema, error) { table_schema = $1 AND table_name NOT LIKE $2; ` - rows, err := pg.DB.Query( + rows, err := pg.DB.QueryContext( + ctx, sqlStatement, pg.Namespace, fmt.Sprintf(`%s%%`, warehouseutils.StagingTablePrefix(provider)), @@ -950,22 +948,22 @@ func (pg *Postgres) LoadTable(ctx context.Context, tableName string) error { return err } -func (pg *Postgres) Cleanup() { +func (pg *Postgres) Cleanup(ctx context.Context) { if pg.DB != nil { - pg.dropDanglingStagingTables() - pg.DB.Close() + pg.dropDanglingStagingTables(ctx) + _ = pg.DB.Close() } } -func (*Postgres) LoadIdentityMergeRulesTable() (err error) { +func (*Postgres) LoadIdentityMergeRulesTable(context.Context) (err error) { return } -func (*Postgres) LoadIdentityMappingsTable() (err error) { +func (*Postgres) LoadIdentityMappingsTable(context.Context) (err error) { return } -func (*Postgres) DownloadIdentityRules(*misc.GZipWriter) (err error) { +func (*Postgres) DownloadIdentityRules(context.Context, *misc.GZipWriter) (err error) { return } @@ -985,7 +983,7 @@ func (pg *Postgres) GetTotalCountInTable(ctx context.Context, tableName string) return total, err } -func (pg *Postgres) Connect(warehouse model.Warehouse) (client.Client, error) { +func (pg *Postgres) Connect(_ context.Context, warehouse model.Warehouse) (client.Client, error) { if warehouse.Destination.Config["sslMode"] == "verify-ca" { if err := warehouseutils.WriteSSLKeys(warehouse.Destination); err.IsError() { pg.logger.Error(err.Error()) @@ -1007,14 +1005,14 @@ func (pg *Postgres) Connect(warehouse model.Warehouse) (client.Client, error) { return client.Client{Type: client.SQLClient, SQL: dbHandle.DB}, err } -func (pg *Postgres) LoadTestTable(_, tableName string, payloadMap map[string]interface{}, _ string) (err error) { +func (pg *Postgres) LoadTestTable(ctx context.Context, _, tableName string, payloadMap map[string]interface{}, _ string) (err error) { sqlStatement := fmt.Sprintf(`INSERT INTO %q.%q (%v) VALUES (%s)`, pg.Namespace, tableName, fmt.Sprintf(`%q, %q`, "id", "val"), fmt.Sprintf(`'%d', '%s'`, payloadMap["id"], payloadMap["val"]), ) - _, err = pg.DB.Exec(sqlStatement) + _, err = pg.DB.ExecContext(ctx, sqlStatement) return } @@ 
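Note: dropDanglingStagingTables and FetchSchema above now go through QueryContext and discard the deferred Close error explicitly. A self-contained sketch of that query pattern (schema and prefix values are illustrative):

package ctxexample

import (
	"context"
	"database/sql"
)

// stagingTables lists staging tables for a schema, mirroring the pattern
// above: QueryContext plus a deferred Close whose error is explicitly
// discarded, then a Scan loop followed by rows.Err().
func stagingTables(ctx context.Context, db *sql.DB, schema, prefix string) ([]string, error) {
	rows, err := db.QueryContext(ctx,
		`SELECT table_name FROM information_schema.tables
		 WHERE table_schema = $1 AND table_name LIKE $2`,
		schema, prefix+"%",
	)
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	var names []string
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return nil, err
		}
		names = append(names, name)
	}
	return names, rows.Err()
}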
-1040,7 +1038,7 @@ func (q *QueryParams) validate() (err error) { // Print execution plan if enableWithQueryPlan is set to true else return result set. // Currently, these statements are supported by EXPLAIN // Any INSERT, UPDATE, DELETE whose execution plan you wish to see. -func (pg *Postgres) handleExec(e *QueryParams) (err error) { +func (pg *Postgres) handleExecContext(ctx context.Context, e *QueryParams) (err error) { sqlStatement := e.query if err = e.validate(); err != nil { @@ -1053,9 +1051,9 @@ func (pg *Postgres) handleExec(e *QueryParams) (err error) { var rows *sql.Rows if e.txn != nil { - rows, err = e.txn.Query(sqlStatement) + rows, err = e.txn.QueryContext(ctx, sqlStatement) } else if e.db != nil { - rows, err = e.db.Query(sqlStatement) + rows, err = e.db.QueryContext(ctx, sqlStatement) } if err != nil { err = fmt.Errorf("[WH][POSTGRES] error occurred while handling transaction for query: %s with err: %w", sqlStatement, err) @@ -1076,9 +1074,9 @@ func (pg *Postgres) handleExec(e *QueryParams) (err error) { `))) } if e.txn != nil { - _, err = e.txn.Exec(sqlStatement) + _, err = e.txn.ExecContext(ctx, sqlStatement) } else if e.db != nil { - _, err = e.db.Exec(sqlStatement) + _, err = e.db.ExecContext(ctx, sqlStatement) } return } diff --git a/warehouse/integrations/postgres/load.go b/warehouse/integrations/postgres/load.go index 0ff3de32c9..1e65ff172d 100644 --- a/warehouse/integrations/postgres/load.go +++ b/warehouse/integrations/postgres/load.go @@ -223,7 +223,7 @@ func (pg *Postgres) loadTable( logfield.Query, query, ) - result, err := txn.Exec(query) + result, err := txn.ExecContext(ctx, query) if err != nil { return loadTableResponse{}, fmt.Errorf("deleting from original table for dedup: %w", err) } diff --git a/warehouse/integrations/postgres/load_test.go b/warehouse/integrations/postgres/load_test.go index d5e6e6cd1e..079f44f306 100644 --- a/warehouse/integrations/postgres/load_test.go +++ b/warehouse/integrations/postgres/load_test.go @@ -39,23 +39,25 @@ type mockUploader struct { schema model.Schema } -func (*mockUploader) GetSchemaInWarehouse() model.Schema { return model.Schema{} } -func (*mockUploader) GetLocalSchema() (model.Schema, error) { return model.Schema{}, nil } -func (*mockUploader) UpdateLocalSchema(_ model.Schema) error { return nil } -func (*mockUploader) ShouldOnDedupUseNewRecord() bool { return false } -func (*mockUploader) UseRudderStorage() bool { return false } -func (*mockUploader) GetLoadFileGenStartTIme() time.Time { return time.Time{} } -func (*mockUploader) GetLoadFileType() string { return "JSON" } -func (*mockUploader) GetFirstLastEvent() (time.Time, time.Time) { return time.Time{}, time.Time{} } -func (*mockUploader) GetLoadFilesMetadata(warehouseutils.GetLoadFilesOptions) []warehouseutils.LoadFile { +func (*mockUploader) GetSchemaInWarehouse() model.Schema { return model.Schema{} } +func (*mockUploader) GetLocalSchema(context.Context) (model.Schema, error) { + return model.Schema{}, nil +} +func (*mockUploader) UpdateLocalSchema(context.Context, model.Schema) error { return nil } +func (*mockUploader) ShouldOnDedupUseNewRecord() bool { return false } +func (*mockUploader) UseRudderStorage() bool { return false } +func (*mockUploader) GetLoadFileGenStartTIme() time.Time { return time.Time{} } +func (*mockUploader) GetLoadFileType() string { return "JSON" } +func (*mockUploader) GetFirstLastEvent() (time.Time, time.Time) { return time.Time{}, time.Time{} } +func (*mockUploader) GetLoadFilesMetadata(context.Context, 
warehouseutils.GetLoadFilesOptions) []warehouseutils.LoadFile { return []warehouseutils.LoadFile{} } -func (*mockUploader) GetSingleLoadFile(_ string) (warehouseutils.LoadFile, error) { +func (*mockUploader) GetSingleLoadFile(context.Context, string) (warehouseutils.LoadFile, error) { return warehouseutils.LoadFile{}, nil } -func (*mockUploader) GetSampleLoadFileLocation(_ string) (string, error) { +func (*mockUploader) GetSampleLoadFileLocation(context.Context, string) (string, error) { return "", nil } @@ -88,6 +90,8 @@ func cloneFiles(t *testing.T, files []string) []string { } func TestLoadTable(t *testing.T) { + t.Parallel() + misc.Init() warehouseutils.Init() @@ -394,6 +398,8 @@ func TestLoadTable(t *testing.T) { } func TestLoadUsersTable(t *testing.T) { + t.Parallel() + misc.Init() warehouseutils.Init() diff --git a/warehouse/integrations/postgres/postgres.go b/warehouse/integrations/postgres/postgres.go index 5fbb7c22fd..c16ebfc38b 100644 --- a/warehouse/integrations/postgres/postgres.go +++ b/warehouse/integrations/postgres/postgres.go @@ -259,12 +259,12 @@ func ColumnsWithDataTypes(columns model.TableSchema, prefix string) string { return strings.Join(arr, ",") } -func (*Postgres) IsEmpty(_ model.Warehouse) (empty bool, err error) { +func (*Postgres) IsEmpty(context.Context, model.Warehouse) (empty bool, err error) { return } // DeleteBy Need to create a structure with delete parameters instead of simply adding a long list of params -func (pg *Postgres) DeleteBy(tableNames []string, params warehouseutils.DeleteByParams) (err error) { +func (pg *Postgres) DeleteBy(ctx context.Context, tableNames []string, params warehouseutils.DeleteByParams) (err error) { pg.Logger.Infof("PG: Cleaning up the following tables in postgres for PG:%s : %+v", tableNames, params) for _, tb := range tableNames { sqlStatement := fmt.Sprintf(`DELETE FROM "%[1]s"."%[2]s" WHERE @@ -278,7 +278,7 @@ func (pg *Postgres) DeleteBy(tableNames []string, params warehouseutils.DeleteBy pg.Logger.Infof("PG: Deleting rows in table in postgres for PG:%s", pg.Warehouse.Destination.ID) pg.Logger.Debugf("PG: Executing the statement %v", sqlStatement) if pg.EnableDeleteByJobs { - _, err = pg.DB.Exec(sqlStatement, + _, err = pg.DB.ExecContext(ctx, sqlStatement, params.JobRunId, params.TaskRunId, params.SourceId, @@ -293,15 +293,15 @@ func (pg *Postgres) DeleteBy(tableNames []string, params warehouseutils.DeleteBy return nil } -func (pg *Postgres) schemaExists(_ string) (exists bool, err error) { +func (pg *Postgres) schemaExists(ctx context.Context, _ string) (exists bool, err error) { sqlStatement := fmt.Sprintf(`SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_namespace WHERE nspname = '%s');`, pg.Namespace) - err = pg.DB.QueryRow(sqlStatement).Scan(&exists) + err = pg.DB.QueryRowContext(ctx, sqlStatement).Scan(&exists) return } -func (pg *Postgres) CreateSchema() (err error) { +func (pg *Postgres) CreateSchema(ctx context.Context) (err error) { var schemaExists bool - schemaExists, err = pg.schemaExists(pg.Namespace) + schemaExists, err = pg.schemaExists(ctx, pg.Namespace) if err != nil { pg.Logger.Errorf("PG: Error checking if schema: %s exists: %v", pg.Namespace, err) return err @@ -312,37 +312,37 @@ func (pg *Postgres) CreateSchema() (err error) { } sqlStatement := fmt.Sprintf(`CREATE SCHEMA IF NOT EXISTS %q`, pg.Namespace) pg.Logger.Infof("PG: Creating schema name in postgres for PG:%s : %v", pg.Warehouse.Destination.ID, sqlStatement) - _, err = pg.DB.Exec(sqlStatement) + _, err = pg.DB.ExecContext(ctx, sqlStatement) return 
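Note: schemaExists above is a single-row probe; with QueryRowContext the scan is abandoned when the upload's context is cancelled. A minimal sketch of that shape, using a query parameter instead of Sprintf (a sketch-level choice, not what the patch itself does):

package ctxexample

import (
	"context"
	"database/sql"
)

// schemaExists reports whether the given schema exists, mirroring the
// QueryRowContext-based existence probes added in this patch.
func schemaExists(ctx context.Context, db *sql.DB, namespace string) (bool, error) {
	var exists bool
	err := db.QueryRowContext(ctx,
		`SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_namespace WHERE nspname = $1)`,
		namespace,
	).Scan(&exists)
	if err != nil {
		return false, err
	}
	return exists, nil
}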
} -func (pg *Postgres) createTable(name string, columns model.TableSchema) (err error) { +func (pg *Postgres) createTable(ctx context.Context, name string, columns model.TableSchema) (err error) { sqlStatement := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS "%[1]s"."%[2]s" ( %v )`, pg.Namespace, name, ColumnsWithDataTypes(columns, "")) pg.Logger.Infof("PG: Creating table in postgres for PG:%s : %v", pg.Warehouse.Destination.ID, sqlStatement) - _, err = pg.DB.Exec(sqlStatement) + _, err = pg.DB.ExecContext(ctx, sqlStatement) return } -func (pg *Postgres) CreateTable(tableName string, columnMap model.TableSchema) (err error) { +func (pg *Postgres) CreateTable(ctx context.Context, tableName string, columnMap model.TableSchema) (err error) { // set the schema in search path. so that we can query table with unqualified name which is just the table name rather than using schema.table in queries sqlStatement := fmt.Sprintf(`SET search_path to %q`, pg.Namespace) - _, err = pg.DB.Exec(sqlStatement) + _, err = pg.DB.ExecContext(ctx, sqlStatement) if err != nil { return err } pg.Logger.Infof("PG: Updated search_path to %s in postgres for PG:%s : %v", pg.Namespace, pg.Warehouse.Destination.ID, sqlStatement) - err = pg.createTable(tableName, columnMap) + err = pg.createTable(ctx, tableName, columnMap) return err } -func (pg *Postgres) DropTable(tableName string) (err error) { +func (pg *Postgres) DropTable(ctx context.Context, tableName string) (err error) { sqlStatement := `DROP TABLE "%[1]s"."%[2]s"` pg.Logger.Infof("PG: Dropping table in postgres for PG:%s : %v", pg.Warehouse.Destination.ID, sqlStatement) - _, err = pg.DB.Exec(fmt.Sprintf(sqlStatement, pg.Namespace, tableName)) + _, err = pg.DB.ExecContext(ctx, fmt.Sprintf(sqlStatement, pg.Namespace, tableName)) return } -func (pg *Postgres) AddColumns(tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) { +func (pg *Postgres) AddColumns(ctx context.Context, tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) { var ( query string queryBuilder strings.Builder @@ -350,7 +350,7 @@ func (pg *Postgres) AddColumns(tableName string, columnsInfo []warehouseutils.Co // set the schema in search path. 
so that we can query table with unqualified name which is just the table name rather than using schema.table in queries query = fmt.Sprintf(`SET search_path to %q`, pg.Namespace) - if _, err = pg.DB.Exec(query); err != nil { + if _, err = pg.DB.ExecContext(ctx, query); err != nil { return } pg.Logger.Infof("PG: Updated search_path to %s in postgres for PG:%s : %v", pg.Namespace, pg.Warehouse.Destination.ID, query) @@ -370,11 +370,11 @@ func (pg *Postgres) AddColumns(tableName string, columnsInfo []warehouseutils.Co query += ";" pg.Logger.Infof("PG: Adding columns for destinationID: %s, tableName: %s with query: %v", pg.Warehouse.Destination.ID, tableName, query) - _, err = pg.DB.Exec(query) + _, err = pg.DB.ExecContext(ctx, query) return } -func (*Postgres) AlterColumn(_, _, _ string) (model.AlterTableResponse, error) { +func (*Postgres) AlterColumn(context.Context, string, string, string) (model.AlterTableResponse, error) { return model.AlterTableResponse{}, nil } @@ -396,10 +396,7 @@ func (pg *Postgres) TestConnection(ctx context.Context, warehouse model.Warehous return nil } -func (pg *Postgres) Setup( - warehouse model.Warehouse, - uploader warehouseutils.Uploader, -) (err error) { +func (pg *Postgres) Setup(_ context.Context, warehouse model.Warehouse, uploader warehouseutils.Uploader) (err error) { pg.Warehouse = warehouse pg.Namespace = warehouse.Namespace pg.Uploader = uploader @@ -410,10 +407,10 @@ func (pg *Postgres) Setup( return err } -func (pg *Postgres) CrashRecover() {} +func (*Postgres) CrashRecover(context.Context) {} // FetchSchema queries postgres and returns the schema associated with provided namespace -func (pg *Postgres) FetchSchema() (model.Schema, model.Schema, error) { +func (pg *Postgres) FetchSchema(ctx context.Context) (model.Schema, model.Schema, error) { schema := make(model.Schema) unrecognizedSchema := make(model.Schema) @@ -428,7 +425,8 @@ func (pg *Postgres) FetchSchema() (model.Schema, model.Schema, error) { table_schema = $1 AND table_name NOT LIKE $2; ` - rows, err := pg.DB.Query( + rows, err := pg.DB.QueryContext( + ctx, sqlStatement, pg.Namespace, fmt.Sprintf(`%s%%`, warehouseutils.StagingTablePrefix(provider)), @@ -469,21 +467,21 @@ func (pg *Postgres) FetchSchema() (model.Schema, model.Schema, error) { return schema, unrecognizedSchema, nil } -func (pg *Postgres) Cleanup() { +func (pg *Postgres) Cleanup(context.Context) { if pg.DB != nil { - pg.DB.Close() + _ = pg.DB.Close() } } -func (*Postgres) LoadIdentityMergeRulesTable() (err error) { +func (*Postgres) LoadIdentityMergeRulesTable(context.Context) (err error) { return } -func (*Postgres) LoadIdentityMappingsTable() (err error) { +func (*Postgres) LoadIdentityMappingsTable(context.Context) (err error) { return } -func (*Postgres) DownloadIdentityRules(*misc.GZipWriter) (err error) { +func (*Postgres) DownloadIdentityRules(context.Context, *misc.GZipWriter) (err error) { return } @@ -503,7 +501,7 @@ func (pg *Postgres) GetTotalCountInTable(ctx context.Context, tableName string) return total, err } -func (pg *Postgres) Connect(warehouse model.Warehouse) (client.Client, error) { +func (pg *Postgres) Connect(_ context.Context, warehouse model.Warehouse) (client.Client, error) { if warehouse.Destination.Config["sslMode"] == "verify-ca" { if err := warehouseutils.WriteSSLKeys(warehouse.Destination); err.IsError() { pg.Logger.Error(err.Error()) @@ -525,14 +523,14 @@ func (pg *Postgres) Connect(warehouse model.Warehouse) (client.Client, error) { return client.Client{Type: client.SQLClient, SQL: 
dbHandle.DB}, err } -func (pg *Postgres) LoadTestTable(_, tableName string, payloadMap map[string]interface{}, _ string) (err error) { +func (pg *Postgres) LoadTestTable(ctx context.Context, _, tableName string, payloadMap map[string]interface{}, _ string) (err error) { sqlStatement := fmt.Sprintf(`INSERT INTO %q.%q (%v) VALUES (%s)`, pg.Namespace, tableName, fmt.Sprintf(`%q, %q`, "id", "val"), fmt.Sprintf(`'%d', '%s'`, payloadMap["id"], payloadMap["val"]), ) - _, err = pg.DB.Exec(sqlStatement) + _, err = pg.DB.ExecContext(ctx, sqlStatement) return } diff --git a/warehouse/integrations/redshift/redshift.go b/warehouse/integrations/redshift/redshift.go index 4999ad8455..00d59dd87b 100644 --- a/warehouse/integrations/redshift/redshift.go +++ b/warehouse/integrations/redshift/redshift.go @@ -220,7 +220,7 @@ func ColumnsWithDataTypes(columns model.TableSchema, prefix string) string { return strings.Join(arr, ",") } -func (rs *Redshift) CreateTable(tableName string, columns model.TableSchema) (err error) { +func (rs *Redshift) CreateTable(ctx context.Context, tableName string, columns model.TableSchema) (err error) { name := fmt.Sprintf(`%q.%q`, rs.Namespace, tableName) sortKeyField := "received_at" if _, ok := columns["received_at"]; !ok { @@ -235,24 +235,24 @@ func (rs *Redshift) CreateTable(tableName string, columns model.TableSchema) (er } sqlStatement := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s ( %v ) %s SORTKEY(%q) `, name, ColumnsWithDataTypes(columns, ""), distKeySql, sortKeyField) rs.Logger.Infof("Creating table in redshift for RS:%s : %v", rs.Warehouse.Destination.ID, sqlStatement) - _, err = rs.DB.Exec(sqlStatement) + _, err = rs.DB.ExecContext(ctx, sqlStatement) return } -func (rs *Redshift) DropTable(tableName string) (err error) { +func (rs *Redshift) DropTable(ctx context.Context, tableName string) (err error) { sqlStatement := `DROP TABLE "%[1]s"."%[2]s"` rs.Logger.Infof("RS: Dropping table in redshift for RS:%s : %v", rs.Warehouse.Destination.ID, sqlStatement) - _, err = rs.DB.Exec(fmt.Sprintf(sqlStatement, rs.Namespace, tableName)) + _, err = rs.DB.ExecContext(ctx, fmt.Sprintf(sqlStatement, rs.Namespace, tableName)) return } -func (rs *Redshift) schemaExists(_ string) (exists bool, err error) { +func (rs *Redshift) schemaExists(ctx context.Context) (exists bool, err error) { sqlStatement := fmt.Sprintf(`SELECT EXISTS (SELECT 1 FROM pg_catalog.pg_namespace WHERE nspname = '%s');`, rs.Namespace) - err = rs.DB.QueryRow(sqlStatement).Scan(&exists) + err = rs.DB.QueryRowContext(ctx, sqlStatement).Scan(&exists) return } -func (rs *Redshift) AddColumns(tableName string, columnsInfo []warehouseutils.ColumnInfo) error { +func (rs *Redshift) AddColumns(ctx context.Context, tableName string, columnsInfo []warehouseutils.ColumnInfo) error { for _, columnInfo := range columnsInfo { columnType := getRSDataType(columnInfo.Type) query := fmt.Sprintf(` @@ -268,7 +268,7 @@ func (rs *Redshift) AddColumns(tableName string, columnsInfo []warehouseutils.Co ) rs.Logger.Infof("RS: Adding column for destinationID: %s, tableName: %s with query: %v", rs.Warehouse.Destination.ID, tableName, query) - if _, err := rs.DB.Exec(query); err != nil { + if _, err := rs.DB.ExecContext(ctx, query); err != nil { if CheckAndIgnoreColumnAlreadyExistError(err) { rs.Logger.Infow("column already exists", logfield.SourceID, rs.Warehouse.Source.ID, @@ -304,7 +304,7 @@ func CheckAndIgnoreColumnAlreadyExistError(err error) bool { return true } -func (rs *Redshift) DeleteBy(tableNames []string, params 
warehouseutils.DeleteByParams) (err error) { +func (rs *Redshift) DeleteBy(ctx context.Context, tableNames []string, params warehouseutils.DeleteByParams) (err error) { rs.Logger.Infof("RS: Cleaning up the following tables in redshift for RS:%s : %+v", tableNames, params) rs.Logger.Infof("RS: Flag for enableDeleteByJobs is %t", rs.EnableDeleteByJobs) for _, tb := range tableNames { @@ -321,7 +321,7 @@ func (rs *Redshift) DeleteBy(tableNames []string, params warehouseutils.DeleteBy rs.Logger.Infof("RS: Executing the query %v", sqlStatement) if rs.EnableDeleteByJobs { - _, err = rs.DB.Exec(sqlStatement, + _, err = rs.DB.ExecContext(ctx, sqlStatement, params.JobRunId, params.TaskRunId, params.SourceId, @@ -337,15 +337,15 @@ func (rs *Redshift) DeleteBy(tableNames []string, params warehouseutils.DeleteBy return nil } -func (rs *Redshift) createSchema() (err error) { +func (rs *Redshift) createSchema(ctx context.Context) (err error) { sqlStatement := fmt.Sprintf(`CREATE SCHEMA IF NOT EXISTS %q`, rs.Namespace) rs.Logger.Infof("Creating schema name in redshift for RS:%s : %v", rs.Warehouse.Destination.ID, sqlStatement) - _, err = rs.DB.Exec(sqlStatement) + _, err = rs.DB.ExecContext(ctx, sqlStatement) return } func (rs *Redshift) generateManifest(ctx context.Context, tableName string) (string, error) { - loadFiles := rs.Uploader.GetLoadFilesMetadata(warehouseutils.GetLoadFilesOptions{Table: tableName}) + loadFiles := rs.Uploader.GetLoadFilesMetadata(ctx, warehouseutils.GetLoadFilesOptions{Table: tableName}) loadFiles = warehouseutils.GetS3Locations(loadFiles) var manifest S3Manifest for idx, loadFile := range loadFiles { @@ -400,10 +400,10 @@ func (rs *Redshift) generateManifest(ctx context.Context, tableName string) (str return uploadOutput.Location, nil } -func (rs *Redshift) dropStagingTables(stagingTableNames []string) { +func (rs *Redshift) dropStagingTables(ctx context.Context, stagingTableNames []string) { for _, stagingTableName := range stagingTableNames { rs.Logger.Infof("WH: dropping table %+v\n", stagingTableName) - _, err := rs.DB.Exec(fmt.Sprintf(`DROP TABLE "%[1]s"."%[2]s"`, rs.Namespace, stagingTableName)) + _, err := rs.DB.ExecContext(ctx, fmt.Sprintf(`DROP TABLE "%[1]s"."%[2]s"`, rs.Namespace, stagingTableName)) if err != nil { rs.Logger.Errorf("WH: RS: Error dropping staging tables in redshift: %v", err) } @@ -463,7 +463,7 @@ func (rs *Redshift) loadTable(ctx context.Context, tableName string, tableSchema } if !skipTempTableDelete { - defer rs.dropStagingTables([]string{stagingTableName}) + defer rs.dropStagingTables(ctx, []string{stagingTableName}) } manifestS3Location, region := warehouseutils.GetS3Location(manifestLocation) @@ -796,7 +796,7 @@ func (rs *Redshift) loadUserTables(ctx context.Context) map[string]error { } } - defer rs.dropStagingTables([]string{identifyStagingTable}) + defer rs.dropStagingTables(ctx, []string{identifyStagingTable}) if len(rs.Uploader.GetTableSchemaInUpload(warehouseutils.UsersTable)) == 0 { return map[string]error{ @@ -904,7 +904,7 @@ func (rs *Redshift) loadUserTables(ctx context.Context) map[string]error { warehouseutils.UsersTable: fmt.Errorf("creating staging table for users: %w", err), } } - defer rs.dropStagingTables([]string{stagingTableName}) + defer rs.dropStagingTables(ctx, []string{stagingTableName}) primaryKey := "id" query = fmt.Sprintf(` @@ -1019,7 +1019,7 @@ func (rs *Redshift) loadUserTables(ctx context.Context) map[string]error { } } -func (rs *Redshift) connect() (*sqlmiddleware.DB, error) { +func (rs *Redshift) connect(ctx 
context.Context) (*sqlmiddleware.DB, error) { cred := rs.getConnectionCredentials() dsn := url.URL{ Scheme: "postgres", @@ -1053,7 +1053,7 @@ func (rs *Redshift) connect() (*sqlmiddleware.DB, error) { } stmt := `SET query_group to 'RudderStack'` - _, err = db.Exec(stmt) + _, err = db.ExecContext(ctx, stmt) if err != nil { return nil, fmt.Errorf("redshift set query_group error : %v", err) } @@ -1078,7 +1078,7 @@ func (rs *Redshift) connect() (*sqlmiddleware.DB, error) { return middleware, nil } -func (rs *Redshift) dropDanglingStagingTables() bool { +func (rs *Redshift) dropDanglingStagingTables(ctx context.Context) bool { sqlStatement := ` SELECT table_name @@ -1088,7 +1088,7 @@ func (rs *Redshift) dropDanglingStagingTables() bool { table_schema = $1 AND table_name like $2; ` - rows, err := rs.DB.Query( + rows, err := rs.DB.QueryContext(ctx, sqlStatement, rs.Namespace, fmt.Sprintf(`%s%%`, warehouseutils.StagingTablePrefix(provider)), @@ -1111,7 +1111,7 @@ func (rs *Redshift) dropDanglingStagingTables() bool { rs.Logger.Infof("WH: RS: Dropping dangling staging tables: %+v %+v\n", len(stagingTableNames), stagingTableNames) delSuccess := true for _, stagingTableName := range stagingTableNames { - _, err := rs.DB.Exec(fmt.Sprintf(`DROP TABLE "%[1]s"."%[2]s"`, rs.Namespace, stagingTableName)) + _, err := rs.DB.ExecContext(ctx, fmt.Sprintf(`DROP TABLE "%[1]s"."%[2]s"`, rs.Namespace, stagingTableName)) if err != nil { rs.Logger.Errorf("WH: RS: Error dropping dangling staging table: %s in redshift: %v\n", stagingTableName, err) delSuccess = false @@ -1120,9 +1120,9 @@ func (rs *Redshift) dropDanglingStagingTables() bool { return delSuccess } -func (rs *Redshift) CreateSchema() (err error) { +func (rs *Redshift) CreateSchema(ctx context.Context) (err error) { var schemaExists bool - schemaExists, err = rs.schemaExists(rs.Namespace) + schemaExists, err = rs.schemaExists(ctx) if err != nil { rs.Logger.Errorf("RS: Error checking if schema: %s exists: %v", rs.Namespace, err) return err @@ -1131,10 +1131,10 @@ func (rs *Redshift) CreateSchema() (err error) { rs.Logger.Infof("RS: Skipping creating schema: %s since it already exists", rs.Namespace) return } - return rs.createSchema() + return rs.createSchema(ctx) } -func (rs *Redshift) AlterColumn(tableName, columnName, columnType string) (model.AlterTableResponse, error) { +func (rs *Redshift) AlterColumn(ctx context.Context, tableName, columnName, columnType string) (model.AlterTableResponse, error) { var ( query string stagingColumnName string @@ -1143,7 +1143,6 @@ func (rs *Redshift) AlterColumn(tableName, columnName, columnType string) (model isDependent bool tx *sqlmiddleware.Tx err error - ctx = context.TODO() ) // Begin a transaction @@ -1283,7 +1282,7 @@ func (rs *Redshift) getConnectionCredentials() RedshiftCredentials { } // FetchSchema queries redshift and returns the schema associated with provided namespace -func (rs *Redshift) FetchSchema() (model.Schema, model.Schema, error) { +func (rs *Redshift) FetchSchema(ctx context.Context) (model.Schema, model.Schema, error) { schema := make(model.Schema) unrecognizedSchema := make(model.Schema) @@ -1300,7 +1299,8 @@ func (rs *Redshift) FetchSchema() (model.Schema, model.Schema, error) { and table_name not like $2; ` - rows, err := rs.DB.Query( + rows, err := rs.DB.QueryContext( + ctx, sqlStatement, rs.Namespace, fmt.Sprintf(`%s%%`, warehouseutils.StagingTablePrefix(provider)), @@ -1352,12 +1352,12 @@ func calculateDataType(columnType string, charLength sql.NullInt64) (string, boo return "", false } 
-func (rs *Redshift) Setup(warehouse model.Warehouse, uploader warehouseutils.Uploader) (err error) { +func (rs *Redshift) Setup(ctx context.Context, warehouse model.Warehouse, uploader warehouseutils.Uploader) (err error) { rs.Warehouse = warehouse rs.Namespace = warehouse.Namespace rs.Uploader = uploader - rs.DB, err = rs.connect() + rs.DB, err = rs.connect(ctx) return err } @@ -1373,18 +1373,18 @@ func (rs *Redshift) TestConnection(ctx context.Context, _ model.Warehouse) error return nil } -func (rs *Redshift) Cleanup() { +func (rs *Redshift) Cleanup(ctx context.Context) { if rs.DB != nil { - rs.dropDanglingStagingTables() + rs.dropDanglingStagingTables(ctx) _ = rs.DB.Close() } } -func (rs *Redshift) CrashRecover() { - rs.dropDanglingStagingTables() +func (rs *Redshift) CrashRecover(ctx context.Context) { + rs.dropDanglingStagingTables(ctx) } -func (*Redshift) IsEmpty(_ model.Warehouse) (empty bool, err error) { +func (*Redshift) IsEmpty(context.Context, model.Warehouse) (empty bool, err error) { return } @@ -1397,15 +1397,15 @@ func (rs *Redshift) LoadTable(ctx context.Context, tableName string) error { return err } -func (*Redshift) LoadIdentityMergeRulesTable() (err error) { +func (*Redshift) LoadIdentityMergeRulesTable(context.Context) (err error) { return } -func (*Redshift) LoadIdentityMappingsTable() (err error) { +func (*Redshift) LoadIdentityMappingsTable(context.Context) (err error) { return } -func (*Redshift) DownloadIdentityRules(*misc.GZipWriter) (err error) { +func (*Redshift) DownloadIdentityRules(context.Context, *misc.GZipWriter) (err error) { return } @@ -1425,10 +1425,10 @@ func (rs *Redshift) GetTotalCountInTable(ctx context.Context, tableName string) return total, err } -func (rs *Redshift) Connect(warehouse model.Warehouse) (client.Client, error) { +func (rs *Redshift) Connect(ctx context.Context, warehouse model.Warehouse) (client.Client, error) { rs.Warehouse = warehouse rs.Namespace = warehouse.Namespace - dbHandle, err := rs.connect() + dbHandle, err := rs.connect(ctx) if err != nil { return client.Client{}, err } @@ -1436,7 +1436,7 @@ func (rs *Redshift) Connect(warehouse model.Warehouse) (client.Client, error) { return client.Client{Type: client.SQLClient, SQL: dbHandle.DB}, err } -func (rs *Redshift) LoadTestTable(location, tableName string, _ map[string]interface{}, format string) (err error) { +func (rs *Redshift) LoadTestTable(ctx context.Context, location, tableName string, _ map[string]interface{}, format string) (err error) { tempAccessKeyId, tempSecretAccessKey, token, err := warehouseutils.GetTemporaryS3Cred(&rs.Warehouse.Destination) if err != nil { rs.Logger.Errorf("RS: Failed to create temp credentials before copying, while create load for table %v, err%v", tableName, err) @@ -1479,7 +1479,7 @@ func (rs *Redshift) LoadTestTable(location, tableName string, _ map[string]inter rs.Logger.Infof("RS: Running COPY command for load test table: %s with sqlStatement: %s", tableName, sanitisedSQLStmt) } - _, err = rs.DB.Exec(sqlStatement) + _, err = rs.DB.ExecContext(ctx, sqlStatement) return normalizeError(err) } diff --git a/warehouse/integrations/redshift/redshift_test.go b/warehouse/integrations/redshift/redshift_test.go index f384cda105..dbda47dcaf 100644 --- a/warehouse/integrations/redshift/redshift_test.go +++ b/warehouse/integrations/redshift/redshift_test.go @@ -386,6 +386,8 @@ func TestRedshift_AlterColumn(t *testing.T) { }, } + ctx := context.Background() + for _, tc := range testCases { tc := tc @@ -451,7 +453,7 @@ func 
TestRedshift_AlterColumn(t *testing.T) { ) require.ErrorContains(t, err, errors.New("pq: value too long for type character varying(512)").Error()) - res, err := rs.AlterColumn(testTable, testColumn, testColumnType) + res, err := rs.AlterColumn(ctx, testTable, testColumn, testColumnType) require.NoError(t, err) if tc.createView { diff --git a/warehouse/integrations/snowflake/snowflake.go b/warehouse/integrations/snowflake/snowflake.go index c43a45fe84..6de4351cc6 100644 --- a/warehouse/integrations/snowflake/snowflake.go +++ b/warehouse/integrations/snowflake/snowflake.go @@ -184,38 +184,38 @@ func (sf *Snowflake) schemaIdentifier() string { ) } -func (sf *Snowflake) createTable(tableName string, columns model.TableSchema) (err error) { +func (sf *Snowflake) createTable(ctx context.Context, tableName string, columns model.TableSchema) (err error) { schemaIdentifier := sf.schemaIdentifier() sqlStatement := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s.%q ( %v )`, schemaIdentifier, tableName, ColumnsWithDataTypes(columns, "")) sf.Logger.Infof("Creating table in snowflake for SF:%s : %v", sf.Warehouse.Destination.ID, sqlStatement) - _, err = sf.DB.Exec(sqlStatement) + _, err = sf.DB.ExecContext(ctx, sqlStatement) return } -func (sf *Snowflake) tableExists(tableName string) (exists bool, err error) { +func (sf *Snowflake) tableExists(ctx context.Context, tableName string) (exists bool, err error) { sqlStatement := fmt.Sprintf(`SELECT EXISTS ( SELECT 1 FROM information_schema.tables WHERE table_schema = '%s' AND table_name = '%s' )`, sf.Namespace, tableName) - err = sf.DB.QueryRow(sqlStatement).Scan(&exists) + err = sf.DB.QueryRowContext(ctx, sqlStatement).Scan(&exists) return } -func (sf *Snowflake) columnExists(columnName, tableName string) (exists bool, err error) { +func (sf *Snowflake) columnExists(ctx context.Context, columnName, tableName string) (exists bool, err error) { sqlStatement := fmt.Sprintf(`SELECT EXISTS ( SELECT 1 FROM information_schema.columns WHERE table_schema = '%s' AND table_name = '%s' AND column_name = '%s' )`, sf.Namespace, tableName, columnName) - err = sf.DB.QueryRow(sqlStatement).Scan(&exists) + err = sf.DB.QueryRowContext(ctx, sqlStatement).Scan(&exists) return } -func (sf *Snowflake) schemaExists() (exists bool, err error) { +func (sf *Snowflake) schemaExists(ctx context.Context) (exists bool, err error) { sqlStatement := fmt.Sprintf("SELECT EXISTS ( SELECT 1 FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = '%s' )", sf.Namespace) - r := sf.DB.QueryRow(sqlStatement) + r := sf.DB.QueryRowContext(ctx, sqlStatement) err = r.Scan(&exists) // ignore err if no results for query if err == sql.ErrNoRows { @@ -224,11 +224,11 @@ func (sf *Snowflake) schemaExists() (exists bool, err error) { return } -func (sf *Snowflake) createSchema() (err error) { +func (sf *Snowflake) createSchema(ctx context.Context) (err error) { schemaIdentifier := sf.schemaIdentifier() sqlStatement := fmt.Sprintf(`CREATE SCHEMA IF NOT EXISTS %s`, schemaIdentifier) sf.Logger.Infof("SF: Creating schema name in snowflake for %s:%s : %v", sf.Warehouse.Namespace, sf.Warehouse.Destination.ID, sqlStatement) - _, err = sf.DB.Exec(sqlStatement) + _, err = sf.DB.ExecContext(ctx, sqlStatement) return } @@ -256,7 +256,7 @@ func (sf *Snowflake) authString() string { return auth } -func (sf *Snowflake) DeleteBy(tableNames []string, params warehouseutils.DeleteByParams) (err error) { +func (sf *Snowflake) DeleteBy(ctx context.Context, tableNames []string, params warehouseutils.DeleteByParams) (err error) { for _, 
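Note: in the redshift_test.go change above, a context.Background() created once in the test is passed into the method under test instead of having the method mint its own. A hedged sketch of that shape (alterColumner, the simplified signature, and the table/column names are hypothetical, for illustration only):

package ctxexample

import (
	"context"
	"testing"
)

// alterColumner is a hypothetical narrow interface standing in for the
// warehouse manager under test.
type alterColumner interface {
	AlterColumn(ctx context.Context, table, column, columnType string) error
}

func testAlterColumn(t *testing.T, rs alterColumner) {
	t.Parallel()

	// The test owns the context; deadlines or cancellation can be layered
	// on top without touching the code under test.
	ctx := context.Background()

	if err := rs.AlterColumn(ctx, "test_table", "test_column", "text"); err != nil {
		t.Fatalf("AlterColumn: %v", err)
	}
}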
tb := range tableNames { sf.Logger.Infof("SF: Cleaning up the following tables in snowflake for SF:%s", tb) sqlStatement := fmt.Sprintf(` @@ -280,7 +280,7 @@ func (sf *Snowflake) DeleteBy(tableNames []string, params warehouseutils.DeleteB sf.Logger.Debugf("SF: Executing the sql statement %v", sqlStatement) if sf.EnableDeleteByJobs { - _, err = sf.DB.Exec(sqlStatement) + _, err = sf.DB.ExecContext(ctx, sqlStatement) if err != nil { sf.Logger.Errorf("Error %s", err) return err @@ -307,7 +307,7 @@ func (sf *Snowflake) loadTable(ctx context.Context, tableName string, tableSchem logfield.TableName, tableName, ) - if db, err = sf.connect(optionalCreds{schemaName: sf.Namespace}); err != nil { + if db, err = sf.connect(ctx, optionalCreds{schemaName: sf.Namespace}); err != nil { return tableLoadResp{}, fmt.Errorf("connect: %w", err) } @@ -357,7 +357,7 @@ func (sf *Snowflake) loadTable(ctx context.Context, tableName string, tableSchem return tableLoadResp{}, fmt.Errorf("create temporary table: %w", err) } - csvObjectLocation, err = sf.Uploader.GetSampleLoadFileLocation(tableName) + csvObjectLocation, err = sf.Uploader.GetSampleLoadFileLocation(ctx, tableName) if err != nil { return tableLoadResp{}, fmt.Errorf("getting sample load file location: %w", err) } @@ -591,17 +591,17 @@ func (sf *Snowflake) loadTable(ctx context.Context, tableName string, tableSchem return res, nil } -func (sf *Snowflake) LoadIdentityMergeRulesTable() (err error) { +func (sf *Snowflake) LoadIdentityMergeRulesTable(ctx context.Context) (err error) { sf.Logger.Infof("SF: Starting load for table:%s\n", identityMergeRulesTable) sf.Logger.Infof("SF: Fetching load file location for %s", identityMergeRulesTable) var loadFile warehouseutils.LoadFile - loadFile, err = sf.Uploader.GetSingleLoadFile(identityMergeRulesTable) + loadFile, err = sf.Uploader.GetSingleLoadFile(ctx, identityMergeRulesTable) if err != nil { return err } - dbHandle, err := sf.connect(optionalCreds{schemaName: sf.Namespace}) + dbHandle, err := sf.connect(ctx, optionalCreds{schemaName: sf.Namespace}) if err != nil { sf.Logger.Errorf("SF: Error establishing connection for copying table:%s: %v\n", identityMergeRulesTable, err) return @@ -622,7 +622,7 @@ func (sf *Snowflake) LoadIdentityMergeRulesTable() (err error) { sf.Logger.Infof("SF: Dedup records for table:%s using staging table: %s\n", identityMergeRulesTable, sanitisedSQLStmt) } - _, err = dbHandle.Exec(sqlStatement) + _, err = dbHandle.ExecContext(ctx, sqlStatement) if err != nil { sf.Logger.Errorf("SF: Error running MERGE for dedup: %v\n", err) return @@ -631,17 +631,17 @@ func (sf *Snowflake) LoadIdentityMergeRulesTable() (err error) { return } -func (sf *Snowflake) LoadIdentityMappingsTable() (err error) { +func (sf *Snowflake) LoadIdentityMappingsTable(ctx context.Context) (err error) { sf.Logger.Infof("SF: Starting load for table:%s\n", identityMappingsTable) sf.Logger.Infof("SF: Fetching load file location for %s", identityMappingsTable) var loadFile warehouseutils.LoadFile - loadFile, err = sf.Uploader.GetSingleLoadFile(identityMappingsTable) + loadFile, err = sf.Uploader.GetSingleLoadFile(ctx, identityMappingsTable) if err != nil { return err } - dbHandle, err := sf.connect(optionalCreds{schemaName: sf.Namespace}) + dbHandle, err := sf.connect(ctx, optionalCreds{schemaName: sf.Namespace}) if err != nil { sf.Logger.Errorf("SF: Error establishing connection for copying table:%s: %v\n", identityMappingsTable, err) return @@ -652,7 +652,7 @@ func (sf *Snowflake) LoadIdentityMappingsTable() (err error) { 
sqlStatement := fmt.Sprintf(`CREATE TEMPORARY TABLE %[1]s.%[2]q LIKE %[1]s.%[3]q`, schemaIdentifier, stagingTableName, identityMappingsTable) sf.Logger.Infof("SF: Creating temporary table for table:%s at %s\n", identityMappingsTable, sqlStatement) - _, err = dbHandle.Exec(sqlStatement) + _, err = dbHandle.ExecContext(ctx, sqlStatement) if err != nil { sf.Logger.Errorf("SF: Error creating temporary table for table:%s: %v\n", identityMappingsTable, err) return @@ -660,7 +660,7 @@ func (sf *Snowflake) LoadIdentityMappingsTable() (err error) { sqlStatement = fmt.Sprintf(`ALTER TABLE %s.%q ADD COLUMN "ID" int AUTOINCREMENT start 1 increment 1`, schemaIdentifier, stagingTableName) sf.Logger.Infof("SF: Adding autoincrement column for table:%s at %s\n", stagingTableName, sqlStatement) - _, err = dbHandle.Exec(sqlStatement) + _, err = dbHandle.ExecContext(ctx, sqlStatement) if err != nil && !checkAndIgnoreAlreadyExistError(err) { sf.Logger.Errorf("SF: Error adding autoincrement column for table:%s: %v\n", stagingTableName, err) return @@ -671,7 +671,7 @@ func (sf *Snowflake) LoadIdentityMappingsTable() (err error) { FILE_FORMAT = ( TYPE = csv FIELD_OPTIONALLY_ENCLOSED_BY = '"' ESCAPE_UNENCLOSED_FIELD = NONE ) TRUNCATECOLUMNS = TRUE`, fmt.Sprintf(`%s.%q`, schemaIdentifier, stagingTableName), loadLocation, sf.authString()) sf.Logger.Infof("SF: Dedup records for table:%s using staging table: %s\n", identityMappingsTable, sqlStatement) - _, err = dbHandle.Exec(sqlStatement) + _, err = dbHandle.ExecContext(ctx, sqlStatement) if err != nil { sf.Logger.Errorf("SF: Error running MERGE for dedup: %v\n", err) return @@ -689,7 +689,7 @@ func (sf *Snowflake) LoadIdentityMappingsTable() (err error) { WHEN NOT MATCHED THEN INSERT ("MERGE_PROPERTY_TYPE", "MERGE_PROPERTY_VALUE", "RUDDER_ID", "UPDATED_AT") VALUES (staging."MERGE_PROPERTY_TYPE", staging."MERGE_PROPERTY_VALUE", staging."RUDDER_ID", staging."UPDATED_AT")`, identityMappingsTable, stagingTableName, schemaIdentifier) sf.Logger.Infof("SF: Dedup records for table:%s using staging table: %s\n", identityMappingsTable, sqlStatement) - _, err = dbHandle.Exec(sqlStatement) + _, err = dbHandle.ExecContext(ctx, sqlStatement) if err != nil { sf.Logger.Errorf("SF: Error running MERGE for dedup: %v\n", err) return @@ -951,7 +951,7 @@ func (sf *Snowflake) loadUserTables(ctx context.Context) map[string]error { } } -func (sf *Snowflake) connect(opts optionalCreds) (*sqlmiddleware.DB, error) { +func (sf *Snowflake) connect(ctx context.Context, opts optionalCreds) (*sqlmiddleware.DB, error) { cred := sf.getConnectionCredentials(opts) urlConfig := snowflake.Config{ Account: cred.Account, @@ -980,7 +980,7 @@ func (sf *Snowflake) connect(opts optionalCreds) (*sqlmiddleware.DB, error) { } alterStatement := `ALTER SESSION SET ABORT_DETACHED_QUERY=TRUE` - _, err = db.Exec(alterStatement) + _, err = db.ExecContext(ctx, alterStatement) if err != nil { return nil, fmt.Errorf("SF: snowflake alter session error : (%v)", err) } @@ -1005,10 +1005,10 @@ func (sf *Snowflake) connect(opts optionalCreds) (*sqlmiddleware.DB, error) { return middleware, nil } -func (sf *Snowflake) CreateSchema() (err error) { +func (sf *Snowflake) CreateSchema(ctx context.Context) (err error) { var schemaExists bool schemaIdentifier := sf.schemaIdentifier() - schemaExists, err = sf.schemaExists() + schemaExists, err = sf.schemaExists(ctx) if err != nil { sf.Logger.Errorf("SF: Error checking if schema: %s exists: %v", schemaIdentifier, err) return err @@ -1017,22 +1017,22 @@ func (sf *Snowflake) 
CreateSchema() (err error) { sf.Logger.Infof("SF: Skipping creating schema: %s since it already exists", schemaIdentifier) return } - return sf.createSchema() + return sf.createSchema(ctx) } -func (sf *Snowflake) CreateTable(tableName string, columnMap model.TableSchema) (err error) { - return sf.createTable(tableName, columnMap) +func (sf *Snowflake) CreateTable(ctx context.Context, tableName string, columnMap model.TableSchema) (err error) { + return sf.createTable(ctx, tableName, columnMap) } -func (sf *Snowflake) DropTable(tableName string) (err error) { +func (sf *Snowflake) DropTable(ctx context.Context, tableName string) (err error) { schemaIdentifier := sf.schemaIdentifier() sqlStatement := fmt.Sprintf(`DROP TABLE %[1]s.%[2]q`, schemaIdentifier, tableName) sf.Logger.Infof("SF: Dropping table in snowflake for SF:%s : %v", sf.Warehouse.Destination.ID, sqlStatement) - _, err = sf.DB.Exec(sqlStatement) + _, err = sf.DB.ExecContext(ctx, sqlStatement) return } -func (sf *Snowflake) AddColumns(tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) { +func (sf *Snowflake) AddColumns(ctx context.Context, tableName string, columnsInfo []warehouseutils.ColumnInfo) (err error) { var ( query string queryBuilder strings.Builder @@ -1057,7 +1057,7 @@ func (sf *Snowflake) AddColumns(tableName string, columnsInfo []warehouseutils.C query += ";" sf.Logger.Infof("SF: Adding columns for destinationID: %s, tableName: %s with query: %v", sf.Warehouse.Destination.ID, tableName, query) - _, err = sf.DB.Exec(query) + _, err = sf.DB.ExecContext(ctx, query) // Handle error in case of single column if len(columnsInfo) == 1 { @@ -1071,15 +1071,15 @@ func (sf *Snowflake) AddColumns(tableName string, columnsInfo []warehouseutils.C return } -func (*Snowflake) AlterColumn(_, _, _ string) (model.AlterTableResponse, error) { +func (*Snowflake) AlterColumn(context.Context, string, string, string) (model.AlterTableResponse, error) { return model.AlterTableResponse{}, nil } // DownloadIdentityRules gets distinct combinations of anonymous_id, user_id from tables in warehouse -func (sf *Snowflake) DownloadIdentityRules(gzWriter *misc.GZipWriter) (err error) { +func (sf *Snowflake) DownloadIdentityRules(ctx context.Context, gzWriter *misc.GZipWriter) (err error) { getFromTable := func(tableName string) (err error) { var exists bool - exists, err = sf.tableExists(tableName) + exists, err = sf.tableExists(ctx, tableName) if err != nil || !exists { return } @@ -1087,17 +1087,17 @@ func (sf *Snowflake) DownloadIdentityRules(gzWriter *misc.GZipWriter) (err error schemaIdentifier := sf.schemaIdentifier() sqlStatement := fmt.Sprintf(`SELECT count(*) FROM %s.%q`, schemaIdentifier, tableName) var totalRows int64 - err = sf.DB.QueryRow(sqlStatement).Scan(&totalRows) + err = sf.DB.QueryRowContext(ctx, sqlStatement).Scan(&totalRows) if err != nil { return } // check if table in warehouse has anonymous_id and user_id and construct accordingly - hasAnonymousID, err := sf.columnExists("ANONYMOUS_ID", tableName) + hasAnonymousID, err := sf.columnExists(ctx, "ANONYMOUS_ID", tableName) if err != nil { return } - hasUserID, err := sf.columnExists("USER_ID", tableName) + hasUserID, err := sf.columnExists(ctx, "USER_ID", tableName) if err != nil { return } @@ -1121,7 +1121,7 @@ func (sf *Snowflake) DownloadIdentityRules(gzWriter *misc.GZipWriter) (err error sqlStatement = fmt.Sprintf(`SELECT DISTINCT %s FROM %s.%q LIMIT %d OFFSET %d`, toSelectFields, schemaIdentifier, tableName, batchSize, offset) sf.Logger.Infof("SF: 
Downloading distinct combinations of anonymous_id, user_id: %s, totalRows: %d", sqlStatement, totalRows) var rows *sql.Rows - rows, err = sf.DB.Query(sqlStatement) + rows, err = sf.DB.QueryContext(ctx, sqlStatement) if err != nil { return } @@ -1147,9 +1147,9 @@ func (sf *Snowflake) DownloadIdentityRules(gzWriter *misc.GZipWriter) (err error } else { csvRow = append(csvRow, "user_id", userID.String, "anonymous_id", anonymousID.String) } - csvWriter.Write(csvRow) + _ = csvWriter.Write(csvRow) csvWriter.Flush() - gzWriter.WriteGZ(buff.String()) + _ = gzWriter.WriteGZ(buff.String()) } offset += batchSize @@ -1170,23 +1170,23 @@ func (sf *Snowflake) DownloadIdentityRules(gzWriter *misc.GZipWriter) (err error return nil } -func (*Snowflake) CrashRecover() {} +func (*Snowflake) CrashRecover(context.Context) {} -func (sf *Snowflake) IsEmpty(warehouse model.Warehouse) (empty bool, err error) { +func (sf *Snowflake) IsEmpty(ctx context.Context, warehouse model.Warehouse) (empty bool, err error) { empty = true sf.Warehouse = warehouse sf.Namespace = warehouse.Namespace - sf.DB, err = sf.connect(optionalCreds{}) + sf.DB, err = sf.connect(ctx, optionalCreds{}) if err != nil { return } - defer sf.DB.Close() + defer func() { _ = sf.DB.Close() }() tables := []string{"TRACKS", "PAGES", "SCREENS", "IDENTIFIES", "ALIASES"} for _, tableName := range tables { var exists bool - exists, err = sf.tableExists(tableName) + exists, err = sf.tableExists(ctx, tableName) if err != nil { return } @@ -1196,7 +1196,7 @@ func (sf *Snowflake) IsEmpty(warehouse model.Warehouse) (empty bool, err error) schemaIdentifier := sf.schemaIdentifier() sqlStatement := fmt.Sprintf(`SELECT COUNT(*) FROM %s.%q`, schemaIdentifier, tableName) var count int64 - err = sf.DB.QueryRow(sqlStatement).Scan(&count) + err = sf.DB.QueryRowContext(ctx, sqlStatement).Scan(&count) if err != nil { return } @@ -1221,14 +1221,14 @@ func (sf *Snowflake) getConnectionCredentials(opts optionalCreds) Credentials { } } -func (sf *Snowflake) Setup(warehouse model.Warehouse, uploader warehouseutils.Uploader) (err error) { +func (sf *Snowflake) Setup(ctx context.Context, warehouse model.Warehouse, uploader warehouseutils.Uploader) (err error) { sf.Warehouse = warehouse sf.Namespace = warehouse.Namespace sf.CloudProvider = warehouseutils.SnowflakeCloudProvider(warehouse.Destination.Config) sf.Uploader = uploader sf.ObjectStorage = warehouseutils.ObjectStorageType(warehouseutils.SNOWFLAKE, warehouse.Destination.Config, sf.Uploader.UseRudderStorage()) - sf.DB, err = sf.connect(optionalCreds{}) + sf.DB, err = sf.connect(ctx, optionalCreds{}) return err } @@ -1245,7 +1245,7 @@ func (sf *Snowflake) TestConnection(ctx context.Context, _ model.Warehouse) erro } // FetchSchema queries the snowflake database and returns the schema -func (sf *Snowflake) FetchSchema() (model.Schema, model.Schema, error) { +func (sf *Snowflake) FetchSchema(ctx context.Context) (model.Schema, model.Schema, error) { schema := make(model.Schema) unrecognizedSchema := make(model.Schema) @@ -1261,7 +1261,7 @@ func (sf *Snowflake) FetchSchema() (model.Schema, model.Schema, error) { table_schema = ? 
` - rows, err := sf.DB.Query(sqlStatement, sf.Namespace) + rows, err := sf.DB.QueryContext(ctx, sqlStatement, sf.Namespace) if errors.Is(err, sql.ErrNoRows) { return schema, unrecognizedSchema, nil } @@ -1300,9 +1300,9 @@ func (sf *Snowflake) FetchSchema() (model.Schema, model.Schema, error) { return schema, unrecognizedSchema, nil } -func (sf *Snowflake) Cleanup() { +func (sf *Snowflake) Cleanup(context.Context) { if sf.DB != nil { - sf.DB.Close() + _ = sf.DB.Close() } } @@ -1331,7 +1331,7 @@ func (sf *Snowflake) GetTotalCountInTable(ctx context.Context, tableName string) return total, err } -func (sf *Snowflake) Connect(warehouse model.Warehouse) (client.Client, error) { +func (sf *Snowflake) Connect(ctx context.Context, warehouse model.Warehouse) (client.Client, error) { sf.Warehouse = warehouse sf.Namespace = warehouse.Namespace sf.CloudProvider = warehouseutils.SnowflakeCloudProvider(warehouse.Destination.Config) @@ -1340,7 +1340,7 @@ func (sf *Snowflake) Connect(warehouse model.Warehouse) (client.Client, error) { warehouse.Destination.Config, misc.IsConfiguredToUseRudderObjectStorage(sf.Warehouse.Destination.Config), ) - dbHandle, err := sf.connect(optionalCreds{}) + dbHandle, err := sf.connect(ctx, optionalCreds{}) if err != nil { return client.Client{}, err } @@ -1348,7 +1348,7 @@ func (sf *Snowflake) Connect(warehouse model.Warehouse) (client.Client, error) { return client.Client{Type: client.SQLClient, SQL: dbHandle.DB}, err } -func (sf *Snowflake) LoadTestTable(location, tableName string, _ map[string]interface{}, _ string) (err error) { +func (sf *Snowflake) LoadTestTable(ctx context.Context, location, tableName string, _ map[string]interface{}, _ string) (err error) { loadFolder := warehouseutils.GetObjectFolder(sf.ObjectStorage, location) schemaIdentifier := sf.schemaIdentifier() sqlStatement := fmt.Sprintf(`COPY INTO %v(%v) FROM '%v' %s PATTERN = '.*\.csv\.gz' @@ -1359,7 +1359,7 @@ func (sf *Snowflake) LoadTestTable(location, tableName string, _ map[string]inte sf.authString(), ) - _, err = sf.DB.Exec(sqlStatement) + _, err = sf.DB.ExecContext(ctx, sqlStatement) return } diff --git a/warehouse/integrations/testhelper/verify.go b/warehouse/integrations/testhelper/verify.go index e71baf3200..05c2b7c258 100644 --- a/warehouse/integrations/testhelper/verify.go +++ b/warehouse/integrations/testhelper/verify.go @@ -1,6 +1,7 @@ package testhelper import ( + "context" "database/sql" "encoding/base64" "encoding/json" @@ -284,7 +285,7 @@ func VerifyConfigurationTest(t testing.TB, destination backendconfig.Destination t.Logf("Started configuration tests for destination type: %s", destination.DestinationDefinition.Name) require.NoError(t, WithConstantRetries(func() error { - response := validations.NewDestinationValidator().Validate(&destination) + response := validations.NewDestinationValidator().Validate(context.Background(), &destination) if !response.Success { return fmt.Errorf("failed to validate credentials for destination: %s with error: %s", destination.DestinationDefinition.Name, response.Error) } diff --git a/warehouse/internal/repo/staging_test.go b/warehouse/internal/repo/staging_test.go index ec51af037a..bb0bc3a7f6 100644 --- a/warehouse/internal/repo/staging_test.go +++ b/warehouse/internal/repo/staging_test.go @@ -185,7 +185,7 @@ func TestStagingFileRepo_Many(t *testing.T) { t.Run("GetForUploadID", func(t *testing.T) { t.Parallel() u := repo.NewUploads(db) - uploadId, err := u.CreateWithStagingFiles(context.TODO(), model.Upload{}, stagingFiles) + uploadId, err := 
u.CreateWithStagingFiles(ctx, model.Upload{}, stagingFiles) require.NoError(t, err) testcases := []struct { name string diff --git a/warehouse/internal/service/loadfiles/downloader/downloader.go b/warehouse/internal/service/loadfiles/downloader/downloader.go index 9f02e4954a..960cb881ab 100644 --- a/warehouse/internal/service/loadfiles/downloader/downloader.go +++ b/warehouse/internal/service/loadfiles/downloader/downloader.go @@ -46,7 +46,7 @@ func (l *downloaderImpl) Download(ctx context.Context, tableName string) ([]stri fileNamesLock sync.RWMutex ) - objects := l.uploader.GetLoadFilesMetadata(warehouseutils.GetLoadFilesOptions{Table: tableName}) + objects := l.uploader.GetLoadFilesMetadata(ctx, warehouseutils.GetLoadFilesOptions{Table: tableName}) storageProvider := warehouseutils.ObjectStorageType( l.warehouse.Destination.DestinationDefinition.Name, l.warehouse.Destination.Config, diff --git a/warehouse/internal/service/loadfiles/downloader/downloader_test.go b/warehouse/internal/service/loadfiles/downloader/downloader_test.go index b95eee187c..b36e4c0d4d 100644 --- a/warehouse/internal/service/loadfiles/downloader/downloader_test.go +++ b/warehouse/internal/service/loadfiles/downloader/downloader_test.go @@ -26,32 +26,38 @@ type mockUploader struct { loadFiles []warehouseutils.LoadFile } -func (*mockUploader) GetSchemaInWarehouse() model.Schema { return model.Schema{} } -func (*mockUploader) GetLocalSchema() (model.Schema, error) { return model.Schema{}, nil } -func (*mockUploader) UpdateLocalSchema(model.Schema) error { return nil } -func (*mockUploader) ShouldOnDedupUseNewRecord() bool { return false } -func (*mockUploader) GetLoadFileGenStartTIme() time.Time { return time.Time{} } -func (*mockUploader) GetLoadFileType() string { return "JSON" } -func (*mockUploader) GetFirstLastEvent() (time.Time, time.Time) { return time.Time{}, time.Time{} } -func (*mockUploader) GetSampleLoadFileLocation(string) (string, error) { return "", nil } -func (*mockUploader) UseRudderStorage() bool { return false } +func (*mockUploader) GetSchemaInWarehouse() model.Schema { return model.Schema{} } +func (*mockUploader) GetLocalSchema(context.Context) (model.Schema, error) { + return model.Schema{}, nil +} +func (*mockUploader) UpdateLocalSchema(context.Context, model.Schema) error { return nil } +func (*mockUploader) ShouldOnDedupUseNewRecord() bool { return false } +func (*mockUploader) GetLoadFileGenStartTIme() time.Time { return time.Time{} } +func (*mockUploader) GetLoadFileType() string { return "JSON" } +func (*mockUploader) GetFirstLastEvent() (time.Time, time.Time) { return time.Time{}, time.Time{} } +func (*mockUploader) GetSampleLoadFileLocation(context.Context, string) (string, error) { + return "", nil +} +func (*mockUploader) UseRudderStorage() bool { return false } func (*mockUploader) GetTableSchemaInWarehouse(string) model.TableSchema { return model.TableSchema{} } -func (*mockUploader) GetSingleLoadFile(string) (warehouseutils.LoadFile, error) { +func (*mockUploader) GetSingleLoadFile(context.Context, string) (warehouseutils.LoadFile, error) { return warehouseutils.LoadFile{}, nil } -func (m *mockUploader) GetTableSchemaInUpload(string) model.TableSchema { +func (*mockUploader) GetTableSchemaInUpload(string) model.TableSchema { return model.TableSchema{} } -func (m *mockUploader) GetLoadFilesMetadata(warehouseutils.GetLoadFilesOptions) []warehouseutils.LoadFile { +func (m *mockUploader) GetLoadFilesMetadata(context.Context, warehouseutils.GetLoadFilesOptions) []warehouseutils.LoadFile 
{ return m.loadFiles } func TestDownloader(t *testing.T) { + t.Parallel() + misc.Init() pool, err := dockertest.NewPool("") diff --git a/warehouse/internal/service/recovery.go b/warehouse/internal/service/recovery.go index 5fa4f8a6da..54e04199a0 100644 --- a/warehouse/internal/service/recovery.go +++ b/warehouse/internal/service/recovery.go @@ -25,7 +25,7 @@ type repo interface { } type destination interface { - CrashRecover() + CrashRecover(ctx context.Context) } type Recovery struct { @@ -77,7 +77,7 @@ func (r *Recovery) Recover(ctx context.Context, whManager destination, wh model. } once.Do(func() { - whManager.CrashRecover() + whManager.CrashRecover(ctx) }) return nil diff --git a/warehouse/internal/service/recovery_test.go b/warehouse/internal/service/recovery_test.go index e4595d811a..ae32d384aa 100644 --- a/warehouse/internal/service/recovery_test.go +++ b/warehouse/internal/service/recovery_test.go @@ -26,7 +26,7 @@ func (r *mockRepo) InterruptedDestinations(_ context.Context, destinationType st return r.m[destinationType], r.err } -func (d *mockDestination) CrashRecover() { +func (d *mockDestination) CrashRecover(_ context.Context) { d.recovered += 1 } diff --git a/warehouse/jobs/jobs.go b/warehouse/jobs/jobs.go index 5003ce265b..84f7b8b1a6 100644 --- a/warehouse/jobs/jobs.go +++ b/warehouse/jobs/jobs.go @@ -1,6 +1,7 @@ package jobs import ( + "context" "time" "github.com/rudderlabs/rudder-server/warehouse/internal/model" @@ -14,11 +15,11 @@ func (*WhAsyncJob) GetSchemaInWarehouse() model.Schema { return model.Schema{} } -func (*WhAsyncJob) GetLocalSchema() (model.Schema, error) { +func (*WhAsyncJob) GetLocalSchema(context.Context) (model.Schema, error) { return model.Schema{}, nil } -func (*WhAsyncJob) UpdateLocalSchema(model.Schema) error { +func (*WhAsyncJob) UpdateLocalSchema(context.Context, model.Schema) error { return nil } @@ -30,15 +31,15 @@ func (*WhAsyncJob) GetTableSchemaInUpload(string) model.TableSchema { return model.TableSchema{} } -func (*WhAsyncJob) GetLoadFilesMetadata(warehouseutils.GetLoadFilesOptions) []warehouseutils.LoadFile { +func (*WhAsyncJob) GetLoadFilesMetadata(context.Context, warehouseutils.GetLoadFilesOptions) []warehouseutils.LoadFile { return []warehouseutils.LoadFile{} } -func (*WhAsyncJob) GetSampleLoadFileLocation(string) (string, error) { +func (*WhAsyncJob) GetSampleLoadFileLocation(context.Context, string) (string, error) { return "", nil } -func (*WhAsyncJob) GetSingleLoadFile(string) (warehouseutils.LoadFile, error) { +func (*WhAsyncJob) GetSingleLoadFile(context.Context, string) (warehouseutils.LoadFile, error) { return warehouseutils.LoadFile{}, nil } diff --git a/warehouse/schema.go b/warehouse/schema.go index 97fe246904..6df5116a37 100644 --- a/warehouse/schema.go +++ b/warehouse/schema.go @@ -41,7 +41,7 @@ type stagingFileRepo interface { } type fetchSchemaRepo interface { - FetchSchema() (model.Schema, model.Schema, error) + FetchSchema(ctx context.Context) (model.Schema, model.Schema, error) } type Schema struct { @@ -74,8 +74,8 @@ func NewSchema( } } -func (sh *Schema) updateLocalSchema(uploadId int64, updatedSchema model.Schema) error { - _, err := sh.schemaRepo.Insert(context.TODO(), &model.WHSchema{ +func (sh *Schema) updateLocalSchema(ctx context.Context, uploadId int64, updatedSchema model.Schema) error { + _, err := sh.schemaRepo.Insert(ctx, &model.WHSchema{ UploadID: uploadId, SourceID: sh.warehouse.Source.ID, Namespace: sh.warehouse.Namespace, @@ -93,8 +93,8 @@ func (sh *Schema) updateLocalSchema(uploadId int64, 
updatedSchema model.Schema) } // fetchSchemaFromLocal fetches schema from local -func (sh *Schema) fetchSchemaFromLocal() error { - localSchema, err := sh.getLocalSchema() +func (sh *Schema) fetchSchemaFromLocal(ctx context.Context) error { + localSchema, err := sh.getLocalSchema(ctx) if err != nil { return fmt.Errorf("fetching schema from local: %w", err) } @@ -104,9 +104,9 @@ func (sh *Schema) fetchSchemaFromLocal() error { return nil } -func (sh *Schema) getLocalSchema() (model.Schema, error) { +func (sh *Schema) getLocalSchema(ctx context.Context) (model.Schema, error) { whSchema, err := sh.schemaRepo.GetForNamespace( - context.TODO(), + ctx, sh.warehouse.Source.ID, sh.warehouse.Destination.ID, sh.warehouse.Namespace, @@ -121,8 +121,8 @@ func (sh *Schema) getLocalSchema() (model.Schema, error) { } // fetchSchemaFromWarehouse fetches schema from warehouse -func (sh *Schema) fetchSchemaFromWarehouse(repo fetchSchemaRepo) error { - warehouseSchema, unrecognizedWarehouseSchema, err := repo.FetchSchema() +func (sh *Schema) fetchSchemaFromWarehouse(ctx context.Context, repo fetchSchemaRepo) error { + warehouseSchema, unrecognizedWarehouseSchema, err := repo.FetchSchema(ctx) if err != nil { return fmt.Errorf("fetching schema from warehouse: %w", err) } @@ -157,8 +157,8 @@ func (sh *Schema) skipDeprecatedColumns(schema model.Schema) { } } -func (sh *Schema) prepareUploadSchema(stagingFiles []*model.StagingFile) error { - consolidatedSchema, err := sh.consolidateStagingFilesSchemaUsingWarehouseSchema(stagingFiles) +func (sh *Schema) prepareUploadSchema(ctx context.Context, stagingFiles []*model.StagingFile) error { + consolidatedSchema, err := sh.consolidateStagingFilesSchemaUsingWarehouseSchema(ctx, stagingFiles) if err != nil { return fmt.Errorf("consolidating staging files schema: %w", err) } @@ -168,12 +168,12 @@ func (sh *Schema) prepareUploadSchema(stagingFiles []*model.StagingFile) error { } // consolidateStagingFilesSchemaUsingWarehouseSchema consolidates staging files schema with warehouse schema -func (sh *Schema) consolidateStagingFilesSchemaUsingWarehouseSchema(stagingFiles []*model.StagingFile) (model.Schema, error) { +func (sh *Schema) consolidateStagingFilesSchemaUsingWarehouseSchema(ctx context.Context, stagingFiles []*model.StagingFile) (model.Schema, error) { consolidatedSchema := model.Schema{} batches := lo.Chunk(stagingFiles, sh.stagingFilesSchemaPaginationSize) for _, batch := range batches { schemas, err := sh.stagingFileRepo.GetSchemasByIDs( - context.TODO(), + ctx, repo.StagingFileIDs(batch), ) if err != nil { diff --git a/warehouse/schema_test.go b/warehouse/schema_test.go index e74a6ee638..893a72d62c 100644 --- a/warehouse/schema_test.go +++ b/warehouse/schema_test.go @@ -303,7 +303,7 @@ type mockFetchSchemaFromWarehouse struct { err error } -func (m *mockFetchSchemaFromWarehouse) FetchSchema() (model.Schema, model.Schema, error) { +func (m *mockFetchSchemaFromWarehouse) FetchSchema(context.Context) (model.Schema, model.Schema, error) { return m.schemaInWarehouse, m.unrecognizedSchemaInWarehouse, m.err } @@ -398,14 +398,16 @@ func TestSchema_GetUpdateLocalSchema(t *testing.T) { schemaRepo: mockSchemaRepo, } - err := sch.updateLocalSchema(uploadID, tc.mockSchema.Schema) + ctx := context.Background() + + err := sch.updateLocalSchema(ctx, uploadID, tc.mockSchema.Schema) if tc.wantError == nil { require.NoError(t, err) } else { require.ErrorContains(t, err, tc.wantError.Error()) } - err = sch.fetchSchemaFromLocal() + err = sch.fetchSchemaFromLocal(ctx) require.Equal(t, 
tc.wantSchema, sch.localSchema) if tc.wantError == nil { require.NoError(t, err) @@ -566,7 +568,9 @@ func TestSchema_FetchSchemaFromWarehouse(t *testing.T) { log: logger.NOP, } - err := sh.fetchSchemaFromWarehouse(&fechSchemaRepo) + ctx := context.Background() + + err := sh.fetchSchemaFromWarehouse(ctx, &fechSchemaRepo) if tc.wantError != nil { require.EqualError(t, err, tc.wantError.Error()) } else { @@ -887,6 +891,8 @@ func TestSchema_PrepareUploadSchema(t *testing.T) { } }) + ctx := context.Background() + testsCases := []struct { name string warehouseType string @@ -1791,7 +1797,7 @@ func TestSchema_PrepareUploadSchema(t *testing.T) { stagingFilesSchemaPaginationSize: 2, } - err := sh.prepareUploadSchema(stagingFiles) + err := sh.prepareUploadSchema(ctx, stagingFiles) if tc.wantError != nil { require.EqualError(t, err, tc.wantError.Error()) } else { diff --git a/warehouse/slave.go b/warehouse/slave.go index 85e276068f..751bf18ae0 100644 --- a/warehouse/slave.go +++ b/warehouse/slave.go @@ -130,7 +130,7 @@ func (job *Payload) getFileManager(config interface{}, useRudderStorage bool) (f * If error occurs with the current config and current revision is different from staging revision * We retry with the staging revision config if it is present */ -func (jobRun *JobRun) downloadStagingFile() error { +func (jobRun *JobRun) downloadStagingFile(ctx context.Context) error { job := jobRun.job downloadTask := func(config interface{}, useRudderStorage bool) (err error) { filePath := jobRun.stagingFilePath @@ -147,7 +147,7 @@ func (jobRun *JobRun) downloadStagingFile() error { downloadStart := time.Now() - err = downloader.Download(context.TODO(), file, job.StagingFileLocation) + err = downloader.Download(ctx, file, job.StagingFileLocation) if err != nil { pkgLogger.Errorf("[WH]: Failed to download file") return err @@ -214,7 +214,7 @@ type loadFileUploadOutput struct { UseRudderStorage bool } -func (jobRun *JobRun) uploadLoadFilesToObjectStorage() ([]loadFileUploadOutput, error) { +func (jobRun *JobRun) uploadLoadFilesToObjectStorage(ctx context.Context) ([]loadFileUploadOutput, error) { job := jobRun.job uploader, err := job.getFileManager(job.DestinationConfig, job.UseRudderStorage) if err != nil { @@ -233,7 +233,7 @@ func (jobRun *JobRun) uploadLoadFilesToObjectStorage() ([]loadFileUploadOutput, // close chan to avoid memory leak ranging over it defer close(uploadJobChan) uploadErrorChan := make(chan error, numLoadFileUploadWorkers) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(ctx) defer cancel() for i := 0; i < numLoadFileUploadWorkers; i++ { go func(ctx context.Context) { @@ -245,7 +245,7 @@ func (jobRun *JobRun) uploadLoadFilesToObjectStorage() ([]loadFileUploadOutput, default: tableName := uploadJob.tableName loadFileUploadStart := time.Now() - uploadOutput, err := jobRun.uploadLoadFileToObjectStorage(uploader, uploadJob.outputFile, tableName) + uploadOutput, err := jobRun.uploadLoadFileToObjectStorage(ctx, uploader, uploadJob.outputFile, tableName) if err != nil { uploadErrorChan <- err return @@ -293,7 +293,7 @@ func (jobRun *JobRun) uploadLoadFilesToObjectStorage() ([]loadFileUploadOutput, } } -func (jobRun *JobRun) uploadLoadFileToObjectStorage(uploader filemanager.FileManager, uploadFile encoding.LoadFileWriter, tableName string) (filemanager.UploadOutput, error) { +func (jobRun *JobRun) uploadLoadFileToObjectStorage(ctx context.Context, uploader filemanager.FileManager, uploadFile encoding.LoadFileWriter, tableName string) 
(filemanager.UploadOutput, error) { job := jobRun.job file, err := os.Open(uploadFile.GetLoadFile().Name()) // opens file in read mode if err != nil { @@ -304,9 +304,9 @@ func (jobRun *JobRun) uploadLoadFileToObjectStorage(uploader filemanager.FileMan pkgLogger.Debugf("[WH]: %s: Uploading load_file to %s for table: %s with staging_file id: %v", job.DestinationType, warehouseutils.ObjectStorageType(job.DestinationType, job.DestinationConfig, job.UseRudderStorage), tableName, job.StagingFileID) var uploadLocation filemanager.UploadOutput if slices.Contains(warehouseutils.TimeWindowDestinations, job.DestinationType) { - uploadLocation, err = uploader.Upload(context.TODO(), file, warehouseutils.GetTablePathInObjectStorage(jobRun.job.DestinationNamespace, tableName), job.LoadFilePrefix) + uploadLocation, err = uploader.Upload(ctx, file, warehouseutils.GetTablePathInObjectStorage(jobRun.job.DestinationNamespace, tableName), job.LoadFilePrefix) } else { - uploadLocation, err = uploader.Upload(context.TODO(), file, config.GetString("WAREHOUSE_BUCKET_LOAD_OBJECTS_FOLDER_NAME", "rudder-warehouse-load-objects"), tableName, job.SourceID, getBucketFolder(job.UniqueLoadGenID, tableName)) + uploadLocation, err = uploader.Upload(ctx, file, config.GetString("WAREHOUSE_BUCKET_LOAD_OBJECTS_FOLDER_NAME", "rudder-warehouse-load-objects"), tableName, job.SourceID, getBucketFolder(job.UniqueLoadGenID, tableName)) } return uploadLocation, err } @@ -385,7 +385,7 @@ func (event *BatchRouterEvent) GetColumnInfo(columnName string) (columnInfo ware // 5. Delete the staging and load files from tmp directory // -func processStagingFile(job Payload, workerIndex int) (loadFileUploadOutputs []loadFileUploadOutput, err error) { +func processStagingFile(ctx context.Context, job Payload, workerIndex int) (loadFileUploadOutputs []loadFileUploadOutput, err error) { processStartTime := time.Now() jobRun := JobRun{ job: job, @@ -404,7 +404,7 @@ func processStagingFile(job Payload, workerIndex int) (loadFileUploadOutputs []l jobRun.setStagingFileDownloadPath(workerIndex) // This creates the file, so on successful creation remove it - err = jobRun.downloadStagingFile() + err = jobRun.downloadStagingFile(ctx) if err != nil { return loadFileUploadOutputs, err } @@ -586,11 +586,11 @@ func processStagingFile(job Payload, workerIndex int) (loadFileUploadOutputs []l pkgLogger.Errorf("Error while closing load file %s : %v", loadFile.GetLoadFile().Name(), err) } } - loadFileUploadOutputs, err = jobRun.uploadLoadFilesToObjectStorage() + loadFileUploadOutputs, err = jobRun.uploadLoadFilesToObjectStorage(ctx) return loadFileUploadOutputs, err } -func processClaimedUploadJob(claimedJob pgnotifier.Claim, workerIndex int) { +func processClaimedUploadJob(ctx context.Context, claimedJob pgnotifier.Claim, workerIndex int) { claimProcessTimeStart := time.Now() defer func() { warehouseutils.NewTimerStat(statsWorkerClaimProcessingTime, warehouseutils.Tag{Name: tagWorkerid, Value: fmt.Sprintf("%d", workerIndex)}).Since(claimProcessTimeStart) @@ -612,7 +612,7 @@ func processClaimedUploadJob(claimedJob pgnotifier.Claim, workerIndex int) { } job.BatchID = claimedJob.BatchID pkgLogger.Infof(`Starting processing staging-file:%v from claim:%v`, job.StagingFileID, claimedJob.ID) - loadFileOutputs, err := processStagingFile(job, workerIndex) + loadFileOutputs, err := processStagingFile(ctx, job, workerIndex) if err != nil { handleErr(err, claimedJob) return @@ -636,7 +636,7 @@ type AsyncJobRunResult struct { Id string } -func runAsyncJob(asyncjob 
jobs.AsyncJobPayload) (AsyncJobRunResult, error) { +func runAsyncJob(ctx context.Context, asyncjob jobs.AsyncJobPayload) (AsyncJobRunResult, error) { warehouse, err := getDestinationFromSlaveConnectionMap(asyncjob.DestinationID, asyncjob.SourceID) if err != nil { return AsyncJobRunResult{Id: asyncjob.Id, Result: false}, err @@ -653,11 +653,11 @@ func runAsyncJob(asyncjob jobs.AsyncJobPayload) (AsyncJobRunResult, error) { if err != nil { return AsyncJobRunResult{Id: asyncjob.Id, Result: false}, err } - err = whManager.Setup(warehouse, whasyncjob) + err = whManager.Setup(ctx, warehouse, whasyncjob) if err != nil { return AsyncJobRunResult{Id: asyncjob.Id, Result: false}, err } - defer whManager.Cleanup() + defer whManager.Cleanup(ctx) tableNames := []string{asyncjob.TableName} if asyncjob.AsyncJobType == "deletebyjobrunid" { pkgLogger.Info("[WH-Jobs]: Running DeleteByJobRunID on slave worker") @@ -668,7 +668,7 @@ func runAsyncJob(asyncjob jobs.AsyncJobPayload) (AsyncJobRunResult, error) { JobRunId: metadata.JobRunId, StartTime: metadata.StartTime, } - err = whManager.DeleteBy(tableNames, params) + err = whManager.DeleteBy(ctx, tableNames, params) } asyncJobRunResult := AsyncJobRunResult{ Result: err == nil, @@ -677,7 +677,7 @@ func runAsyncJob(asyncjob jobs.AsyncJobPayload) (AsyncJobRunResult, error) { return asyncJobRunResult, err } -func processClaimedAsyncJob(claimedJob pgnotifier.Claim) { +func processClaimedAsyncJob(ctx context.Context, claimedJob pgnotifier.Claim) { pkgLogger.Infof("[WH-Jobs]: Got request for processing Async Job with Batch ID %s", claimedJob.BatchID) handleErr := func(err error, claim pgnotifier.Claim) { pkgLogger.Errorf("[WH]: Error processing claim: %v", err) @@ -692,7 +692,7 @@ func processClaimedAsyncJob(claimedJob pgnotifier.Claim) { handleErr(err, claimedJob) return } - result, err := runAsyncJob(job) + result, err := runAsyncJob(ctx, job) if err != nil { handleErr(err, claimedJob) return @@ -726,9 +726,9 @@ func setupSlave(ctx context.Context) error { pkgLogger.Infof("[WH]: Successfully claimed job:%v by slave worker-%v-%v & job type %s", claimedJob.ID, idx, slaveID, claimedJob.JobType) if claimedJob.JobType == jobs.AsyncJobType { - processClaimedAsyncJob(claimedJob) + processClaimedAsyncJob(ctx, claimedJob) } else { - processClaimedUploadJob(claimedJob, idx) + processClaimedUploadJob(ctx, claimedJob, idx) } pkgLogger.Infof("[WH]: Successfully processed job:%v by slave worker-%v-%v", claimedJob.ID, idx, slaveID) diff --git a/warehouse/stats.go b/warehouse/stats.go index 2f258a0b0a..777123a71b 100644 --- a/warehouse/stats.go +++ b/warehouse/stats.go @@ -1,7 +1,6 @@ package warehouse import ( - "context" "fmt" "strconv" "strings" @@ -119,7 +118,7 @@ func (job *UploadJob) generateUploadSuccessMetrics() { err error ) numUploadedEvents, err = job.tableUploadsRepo.TotalExportedEvents( - context.TODO(), + job.ctx, job.upload.ID, []string{}, ) @@ -136,7 +135,7 @@ func (job *UploadJob) generateUploadSuccessMetrics() { } numStagedEvents, err = repo.NewStagingFiles(dbHandle).TotalEventsForUpload( - context.TODO(), + job.ctx, job.upload, ) if err != nil { @@ -168,7 +167,7 @@ func (job *UploadJob) generateUploadAbortedMetrics() { err error ) numUploadedEvents, err = job.tableUploadsRepo.TotalExportedEvents( - context.TODO(), + job.ctx, job.upload.ID, []string{}, ) @@ -185,7 +184,7 @@ func (job *UploadJob) generateUploadAbortedMetrics() { } numStagedEvents, err = repo.NewStagingFiles(dbHandle).TotalEventsForUpload( - context.TODO(), + job.ctx, job.upload, ) if err != nil { 
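// ---------------------------------------------------------------------------
// Editor's note (not part of the patch): a minimal, self-contained sketch of
// the pattern this change applies throughout the warehouse package — accept or
// carry a context.Context instead of calling context.TODO(), and switch to the
// *Context variants of database/sql so cancellation and timeouts propagate.
// All names below (reportJob, countStagedEvents, the table name) are
// hypothetical and only illustrate the technique; they are not the repo's API.
package main

import (
	"context"
	"database/sql"
	"fmt"
	"time"
)

type reportJob struct {
	ctx context.Context // carried by the job, mirroring how UploadJob stores ctx after this patch
	db  *sql.DB
}

// countStagedEvents runs its query with the job's context so a cancelled or
// timed-out upload stops the query instead of leaving it running server-side.
func (j *reportJob) countStagedEvents(uploadID int64) (int64, error) {
	ctx, cancel := context.WithTimeout(j.ctx, 30*time.Second)
	defer cancel()

	var count int64
	err := j.db.QueryRowContext(ctx,
		`SELECT COUNT(*) FROM wh_staging_files WHERE upload_id = $1`, uploadID,
	).Scan(&count)
	if err != nil {
		return 0, fmt.Errorf("counting staged events: %w", err)
	}
	return count, nil
}

func main() {
	// Wiring is illustrative only; a real caller would inject an open *sql.DB
	// and a request- or job-scoped context rather than context.Background().
	j := &reportJob{ctx: context.Background()}
	_ = j
	fmt.Println("sketch only")
}
// ---------------------------------------------------------------------------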
@@ -228,7 +227,7 @@ func (job *UploadJob) recordTableLoad(tableName string, numEvents int64) { Value: strings.ToLower(tableName), }).Count(int(numEvents)) // Delay for the oldest event in the batch - firstEventAt, err := repo.NewStagingFiles(dbHandle).FirstEventForUpload(context.TODO(), job.upload) + firstEventAt, err := repo.NewStagingFiles(dbHandle).FirstEventForUpload(job.ctx, job.upload) if err != nil { pkgLogger.Errorf("[WH]: Failed to generate delay metrics: %s, Err: %v", job.warehouse.Identifier, err) return @@ -254,7 +253,7 @@ func (job *UploadJob) recordLoadFileGenerationTimeStat(startID, endID int64) (er (SELECT created_at FROM %[1]s WHERE id=%[3]d) f2 `, warehouseutils.WarehouseLoadFilesTable, startID, endID) var timeTakenInS time.Duration - err = job.dbHandle.QueryRow(stmt).Scan(&timeTakenInS) + err = job.dbHandle.QueryRowContext(job.ctx, stmt).Scan(&timeTakenInS) if err != nil { return } diff --git a/warehouse/stats_test.go b/warehouse/stats_test.go index 26b226cf6a..b0555b6fb1 100644 --- a/warehouse/stats_test.go +++ b/warehouse/stats_test.go @@ -35,7 +35,7 @@ var _ = Describe("Stats", Ordered, func() { initWarehouse() - err = setupDB(context.TODO(), getConnectionString()) + err = setupDB(context.Background(), getConnectionString()) Expect(err).To(BeNil()) sqlStatement, err := os.ReadFile("testdata/sql/stats_test.sql") @@ -73,6 +73,7 @@ var _ = Describe("Stats", Ordered, func() { }, stats: mockStats, tableUploadsRepo: repo.NewTableUploads(pgResource.DB), + ctx: context.Background(), } }) @@ -102,6 +103,7 @@ var _ = Describe("Stats", Ordered, func() { }, stats: mockStats, tableUploadsRepo: repo.NewTableUploads(pgResource.DB), + ctx: context.Background(), } }) @@ -127,6 +129,7 @@ var _ = Describe("Stats", Ordered, func() { Type: "POSTGRES", }, stats: mockStats, + ctx: context.Background(), } job.recordTableLoad("tracks", 4) }) @@ -147,6 +150,7 @@ var _ = Describe("Stats", Ordered, func() { }, dbHandle: pgResource.DB, stats: mockStats, + ctx: context.Background(), } err = job.recordLoadFileGenerationTimeStat(1, 4) diff --git a/warehouse/upload.go b/warehouse/upload.go index e040e96c1c..98004f85ec 100644 --- a/warehouse/upload.go +++ b/warehouse/upload.go @@ -77,6 +77,7 @@ type UploadJobFactory struct { } type UploadJob struct { + ctx context.Context dbHandle *sql.DB destinationValidator validations.DestinationValidator loadfile *loadfiles.LoadFileGenerator @@ -94,18 +95,18 @@ type UploadJob struct { stagingFileIDs []int64 schemaLock sync.Mutex uploadLock sync.Mutex - AlertSender alerta.AlertSender - Now func() time.Time + alertSender alerta.AlertSender + now func() time.Time pendingTableUploads []model.PendingTableUpload pendingTableUploadsRepo pendingTableUploadsRepo pendingTableUploadsOnce sync.Once pendingTableUploadsError error - RefreshPartitionBatchSize int - RetryTimeWindow time.Duration - MinRetryAttempts int + refreshPartitionBatchSize int + retryTimeWindow time.Duration + minRetryAttempts int - ErrorHandler ErrorHandler + errorHandler ErrorHandler } type UploadColumn struct { @@ -171,8 +172,9 @@ func setMaxParallelLoads() { } } -func (f *UploadJobFactory) NewUploadJob(dto *model.UploadJob, whManager manager.Manager) *UploadJob { +func (f *UploadJobFactory) NewUploadJob(ctx context.Context, dto *model.UploadJob, whManager manager.Manager) *UploadJob { return &UploadJob{ + ctx: ctx, dbHandle: f.dbHandle, loadfile: f.loadFile, recovery: f.recovery, @@ -191,16 +193,16 @@ func (f *UploadJobFactory) NewUploadJob(dto *model.UploadJob, whManager manager. 
pendingTableUploadsRepo: repo.NewUploads(f.dbHandle), pendingTableUploads: []model.PendingTableUpload{}, - RefreshPartitionBatchSize: config.GetInt("Warehouse.refreshPartitionBatchSize", 100), - RetryTimeWindow: retryTimeWindow, - MinRetryAttempts: minRetryAttempts, + refreshPartitionBatchSize: config.GetInt("Warehouse.refreshPartitionBatchSize", 100), + retryTimeWindow: retryTimeWindow, + minRetryAttempts: minRetryAttempts, - AlertSender: alerta.NewClient( + alertSender: alerta.NewClient( config.GetString("ALERTA_URL", "https://alerta.rudderstack.com/api/"), ), - Now: timeutil.Now, + now: timeutil.Now, - ErrorHandler: ErrorHandler{whManager}, + errorHandler: ErrorHandler{whManager}, } } @@ -244,6 +246,7 @@ func (job *UploadJob) trackLongRunningUpload() chan struct{} { func (job *UploadJob) generateUploadSchema() error { err := job.schemaHandle.prepareUploadSchema( + job.ctx, job.stagingFiles, ) if err != nil { @@ -275,19 +278,19 @@ func (job *UploadJob) initTableUploads() error { } return job.tableUploadsRepo.Insert( - context.TODO(), + job.ctx, job.upload.ID, tables, ) } func (job *UploadJob) syncRemoteSchema() (bool, error) { - err := job.schemaHandle.fetchSchemaFromLocal() + err := job.schemaHandle.fetchSchemaFromLocal(job.ctx) if err != nil { return false, fmt.Errorf("fetching schema from local: %w", err) } - err = job.schemaHandle.fetchSchemaFromWarehouse(job.whManager) + err = job.schemaHandle.fetchSchemaFromWarehouse(job.ctx, job.whManager) if err != nil { return false, fmt.Errorf("fetching schema from warehouse: %w", err) } @@ -303,7 +306,7 @@ func (job *UploadJob) syncRemoteSchema() (bool, error) { logfield.Namespace, job.warehouse.Namespace, ) - err = job.schemaHandle.updateLocalSchema(job.upload.ID, job.schemaHandle.schemaInWarehouse) + err = job.schemaHandle.updateLocalSchema(job.ctx, job.upload.ID, job.schemaHandle.schemaInWarehouse) if err != nil { return false, fmt.Errorf("updating local schema: %w", err) } @@ -342,7 +345,7 @@ func (job *UploadJob) getTotalRowsInLoadFiles() int64 { misc.IntArrayToString(job.stagingFileIDs, ","), warehouseutils.ToProviderCase(job.warehouse.Type, warehouseutils.DiscardsTable), ) - err := dbHandle.QueryRow(sqlStatement).Scan(&total) + err := dbHandle.QueryRowContext(job.ctx, sqlStatement).Scan(&total) if err != nil { pkgLogger.Errorf(`Error in getTotalRowsInLoadFiles: %v`, err) } @@ -364,7 +367,7 @@ func (job *UploadJob) matchRowsInStagingAndLoadFiles(ctx context.Context) error func (job *UploadJob) run() (err error) { timerStat := job.timerStat("upload_time") - start := job.Now() + start := job.now() ch := job.trackLongRunningUpload() defer func() { _ = job.setUploadColumns(UploadColumnsOpts{Fields: []UploadColumn{{Column: UploadInProgress, Value: false}}}) @@ -375,7 +378,7 @@ func (job *UploadJob) run() (err error) { job.uploadLock.Lock() defer job.uploadLock.Unlock() - _ = job.setUploadColumns(UploadColumnsOpts{Fields: []UploadColumn{{Column: UploadLastExecAtField, Value: job.Now()}, {Column: UploadInProgress, Value: true}}}) + _ = job.setUploadColumns(UploadColumnsOpts{Fields: []UploadColumn{{Column: UploadLastExecAtField, Value: job.now()}, {Column: UploadInProgress, Value: true}}}) if len(job.stagingFiles) == 0 { err := fmt.Errorf("no staging files found") @@ -384,14 +387,14 @@ func (job *UploadJob) run() (err error) { } whManager := job.whManager - err = whManager.Setup(job.warehouse, job) + err = whManager.Setup(job.ctx, job.warehouse, job) if err != nil { _, _ = job.setUploadError(err, InternalProcessingFailed) return err } - defer 
whManager.Cleanup() + defer whManager.Cleanup(job.ctx) - err = job.recovery.Recover(context.TODO(), whManager, job.warehouse) + err = job.recovery.Recover(job.ctx, whManager, job.warehouse) if err != nil { _, _ = job.setUploadError(err, InternalProcessingFailed) return err @@ -425,7 +428,7 @@ func (job *UploadJob) run() (err error) { } for { - stateStartTime := job.Now() + stateStartTime := job.now() err = nil _ = job.setUploadStatus(UploadStatusOpts{Status: nextUploadState.inProgress}) @@ -456,9 +459,9 @@ func (job *UploadJob) run() (err error) { generateAll := hasSchemaChanged || slices.Contains(warehousesToAlwaysRegenerateAllLoadFilesOnResume, job.warehouse.Type) || config.GetBool("Warehouse.alwaysRegenerateAllLoadFiles", true) var startLoadFileID, endLoadFileID int64 if generateAll { - startLoadFileID, endLoadFileID, err = job.loadfile.ForceCreateLoadFiles(context.TODO(), job.DTO()) + startLoadFileID, endLoadFileID, err = job.loadfile.ForceCreateLoadFiles(job.ctx, job.DTO()) } else { - startLoadFileID, endLoadFileID, err = job.loadfile.CreateLoadFiles(context.TODO(), job.DTO()) + startLoadFileID, endLoadFileID, err = job.loadfile.CreateLoadFiles(job.ctx, job.DTO()) } if err != nil { break @@ -469,7 +472,7 @@ func (job *UploadJob) run() (err error) { break } - err = job.matchRowsInStagingAndLoadFiles(context.TODO()) + err = job.matchRowsInStagingAndLoadFiles(job.ctx) if err != nil { break } @@ -482,7 +485,7 @@ func (job *UploadJob) run() (err error) { newStatus = nextUploadState.failed for tableName := range job.upload.UploadSchema { err = job.tableUploadsRepo.PopulateTotalEventsFromStagingFileIDs( - context.TODO(), + job.ctx, job.upload.ID, tableName, job.stagingFileIDs, @@ -500,7 +503,7 @@ func (job *UploadJob) run() (err error) { case model.CreatedRemoteSchema: newStatus = nextUploadState.failed if len(job.schemaHandle.schemaInWarehouse) == 0 { - err = whManager.CreateSchema() + err = whManager.CreateSchema(job.ctx) if err != nil { break } @@ -618,7 +621,7 @@ func (job *UploadJob) run() (err error) { uploadStatusOpts := UploadStatusOpts{Status: newStatus} if newStatus == model.ExportedData { - rowCount, _ := repo.NewStagingFiles(dbHandle).TotalEventsForUpload(context.TODO(), job.upload) + rowCount, _ := repo.NewStagingFiles(dbHandle).TotalEventsForUpload(job.ctx, job.upload) reportingMetric := types.PUReportedMetric{ ConnectionDetails: types.ConnectionDetails{ @@ -723,7 +726,7 @@ func (job *UploadJob) exportRegularTables(specialTables []string, loadFilesTable func (job *UploadJob) TablesToSkip() (map[string]model.PendingTableUpload, map[string]model.PendingTableUpload, error) { job.pendingTableUploadsOnce.Do(func() { job.pendingTableUploads, job.pendingTableUploadsError = job.pendingTableUploadsRepo.PendingTableUploads( - context.TODO(), + job.ctx, job.upload.Namespace, job.upload.ID, job.upload.DestinationID, @@ -751,14 +754,15 @@ func (job *UploadJob) TablesToSkip() (map[string]model.PendingTableUpload, map[s } func (job *UploadJob) resolveIdentities(populateHistoricIdentities bool) (err error) { - idr := identity.HandleT{ - Warehouse: job.warehouse, - DB: job.dbHandle, - UploadID: job.upload.ID, - Uploader: job, - WarehouseManager: job.whManager, - LoadFileDownloader: downloader.NewDownloader(&job.warehouse, job, 8), - } + idr := identity.New( + job.ctx, + job.warehouse, + job.dbHandle, + job, + job.upload.ID, + job.whManager, + downloader.NewDownloader(&job.warehouse, job, 8), + ) if populateHistoricIdentities { return idr.ResolveHistoricIdentities() } @@ -769,7 +773,7 @@ func 
(job *UploadJob) UpdateTableSchema(tName string, tableSchemaDiff warehouseu pkgLogger.Infof(`[WH]: Starting schema update for table %s in namespace %s of destination %s:%s`, tName, job.warehouse.Namespace, job.warehouse.Type, job.warehouse.Destination.ID) if tableSchemaDiff.TableToBeCreated { - err = job.whManager.CreateTable(tName, tableSchemaDiff.ColumnMap) + err = job.whManager.CreateTable(job.ctx, tName, tableSchemaDiff.ColumnMap) if err != nil { pkgLogger.Errorf("Error creating table %s on namespace: %s, error: %v", tName, job.warehouse.Namespace, err) return err @@ -794,7 +798,7 @@ func (job *UploadJob) alterColumnsToWarehouse(tName string, columnsMap model.Tab var errs []error for columnName, columnType := range columnsMap { - res, err := job.whManager.AlterColumn(tName, columnName, columnType) + res, err := job.whManager.AlterColumn(job.ctx, tName, columnName, columnType) if err != nil { errs = append(errs, err) continue @@ -826,7 +830,7 @@ func (job *UploadJob) alterColumnsToWarehouse(tName string, columnsMap model.Tab query := strings.Join(queries, "\n") pkgLogger.Infof("altering dependent columns: %s", query) - err := job.AlertSender.SendAlert(context.TODO(), "warehouse-column-changes", + err := job.alertSender.SendAlert(job.ctx, "warehouse-column-changes", alerta.SendAlertOpts{ Severity: alerta.SeverityCritical, Priority: alerta.PriorityP1, @@ -871,7 +875,7 @@ func (job *UploadJob) addColumnsToWarehouse(tName string, columnsMap model.Table chunks := lo.Chunk(columnsToAdd, columnsBatchSize) for _, chunk := range chunks { - err = job.whManager.AddColumns(tName, chunk) + err = job.whManager.AddColumns(job.ctx, tName, chunk) if err != nil { err = fmt.Errorf("failed to add columns for table %s in namespace %s of destination %s:%s with error: %w", tName, job.warehouse.Namespace, job.warehouse.Type, job.warehouse.Destination.ID, err) break @@ -937,7 +941,7 @@ func (job *UploadJob) loadAllTablesExcept(skipLoadForTables []string, loadFilesT wg.Done() if slices.Contains(alwaysMarkExported, strings.ToLower(tableName)) { status := model.TableUploadExported - _ = job.tableUploadsRepo.Set(context.TODO(), job.upload.ID, tableName, repo.TableUploadSetOptions{ + _ = job.tableUploadsRepo.Set(job.ctx, job.upload.ID, tableName, repo.TableUploadSetOptions{ Status: &status, }) } @@ -964,7 +968,7 @@ func (job *UploadJob) loadAllTablesExcept(skipLoadForTables []string, loadFilesT if alteredSchemaInAtLeastOneTable { pkgLogger.Infof("loadAllTablesExcept: schema changed - updating local schema for %s", job.warehouse.Identifier) - _ = job.schemaHandle.updateLocalSchema(job.upload.ID, job.schemaHandle.schemaInWarehouse) + _ = job.schemaHandle.updateLocalSchema(job.ctx, job.upload.ID, job.schemaHandle.schemaInWarehouse) } return loadErrors @@ -991,7 +995,7 @@ func (job *UploadJob) getTotalCount(tName string) (int64, error) { ) operation := func() error { - ctx, cancel := context.WithTimeout(context.TODO(), tableCountQueryTimeout) + ctx, cancel := context.WithTimeout(job.ctx, tableCountQueryTimeout) defer cancel() total, countErr = job.whManager.GetTotalCountInTable(ctx, tName) @@ -1013,7 +1017,7 @@ func (job *UploadJob) loadTable(tName string) (bool, error) { if err != nil { status := model.TableUploadUpdatingSchemaFailed errorsString := misc.QuoteLiteral(err.Error()) - _ = job.tableUploadsRepo.Set(context.TODO(), job.upload.ID, tName, repo.TableUploadSetOptions{ + _ = job.tableUploadsRepo.Set(job.ctx, job.upload.ID, tName, repo.TableUploadSetOptions{ Status: &status, Error: &errorsString, }) @@ -1032,8 
+1036,8 @@ func (job *UploadJob) loadTable(tName string) (bool, error) { ) status := model.TableUploadExecuting - lastExecTime := job.Now() - _ = job.tableUploadsRepo.Set(context.TODO(), job.upload.ID, tName, repo.TableUploadSetOptions{ + lastExecTime := job.now() + _ = job.tableUploadsRepo.Set(job.ctx, job.upload.ID, tName, repo.TableUploadSetOptions{ Status: &status, LastExecTime: &lastExecTime, }) @@ -1061,11 +1065,11 @@ func (job *UploadJob) loadTable(tName string) (bool, error) { } } - err = job.whManager.LoadTable(context.TODO(), tName) + err = job.whManager.LoadTable(job.ctx, tName) if err != nil { status := model.TableUploadExportingFailed errorsString := misc.QuoteLiteral(err.Error()) - _ = job.tableUploadsRepo.Set(context.TODO(), job.upload.ID, tName, repo.TableUploadSetOptions{ + _ = job.tableUploadsRepo.Set(job.ctx, job.upload.ID, tName, repo.TableUploadSetOptions{ Status: &status, Error: &errorsString, }) @@ -1089,7 +1093,7 @@ func (job *UploadJob) loadTable(tName string) (bool, error) { ) return } - tableUpload, errEventCount := job.tableUploadsRepo.GetByUploadIDAndTableName(context.TODO(), job.upload.ID, tName) + tableUpload, errEventCount := job.tableUploadsRepo.GetByUploadIDAndTableName(job.ctx, job.upload.ID, tName) if errEventCount != nil { return } @@ -1101,10 +1105,10 @@ func (job *UploadJob) loadTable(tName string) (bool, error) { }() status = model.TableUploadExported - _ = job.tableUploadsRepo.Set(context.TODO(), job.upload.ID, tName, repo.TableUploadSetOptions{ + _ = job.tableUploadsRepo.Set(job.ctx, job.upload.ID, tName, repo.TableUploadSetOptions{ Status: &status, }) - tableUpload, queryErr := job.tableUploadsRepo.GetByUploadIDAndTableName(context.TODO(), job.upload.ID, tName) + tableUpload, queryErr := job.tableUploadsRepo.GetByUploadIDAndTableName(job.ctx, job.upload.ID, tName) if queryErr == nil { job.recordTableLoad(tName, tableUpload.TotalEvents) } @@ -1178,8 +1182,8 @@ func (job *UploadJob) loadUserTables(loadFilesTableMap map[tableNameT]bool) ([]e // Load all user tables status := model.TableUploadExecuting - lastExecTime := job.Now() - _ = job.tableUploadsRepo.Set(context.TODO(), job.upload.ID, job.identifiesTableName(), repo.TableUploadSetOptions{ + lastExecTime := job.now() + _ = job.tableUploadsRepo.Set(job.ctx, job.upload.ID, job.identifiesTableName(), repo.TableUploadSetOptions{ Status: &status, LastExecTime: &lastExecTime, }) @@ -1188,7 +1192,7 @@ func (job *UploadJob) loadUserTables(loadFilesTableMap map[tableNameT]bool) ([]e if err != nil { status := model.TableUploadUpdatingSchemaFailed errorsString := misc.QuoteLiteral(err.Error()) - _ = job.tableUploadsRepo.Set(context.TODO(), job.upload.ID, job.identifiesTableName(), repo.TableUploadSetOptions{ + _ = job.tableUploadsRepo.Set(job.ctx, job.upload.ID, job.identifiesTableName(), repo.TableUploadSetOptions{ Status: &status, Error: &errorsString, }) @@ -1197,8 +1201,8 @@ func (job *UploadJob) loadUserTables(loadFilesTableMap map[tableNameT]bool) ([]e var alteredUserSchema bool if _, ok := job.upload.UploadSchema[job.usersTableName()]; ok { status := model.TableUploadExecuting - lastExecTime := job.Now() - _ = job.tableUploadsRepo.Set(context.TODO(), job.upload.ID, job.usersTableName(), repo.TableUploadSetOptions{ + lastExecTime := job.now() + _ = job.tableUploadsRepo.Set(job.ctx, job.upload.ID, job.usersTableName(), repo.TableUploadSetOptions{ Status: &status, LastExecTime: &lastExecTime, }) @@ -1206,7 +1210,7 @@ func (job *UploadJob) loadUserTables(loadFilesTableMap map[tableNameT]bool) ([]e if err 
!= nil { status = model.TableUploadUpdatingSchemaFailed errorsString := misc.QuoteLiteral(err.Error()) - _ = job.tableUploadsRepo.Set(context.TODO(), job.upload.ID, job.usersTableName(), repo.TableUploadSetOptions{ + _ = job.tableUploadsRepo.Set(job.ctx, job.upload.ID, job.usersTableName(), repo.TableUploadSetOptions{ Status: &status, Error: &errorsString, }) @@ -1219,11 +1223,11 @@ func (job *UploadJob) loadUserTables(loadFilesTableMap map[tableNameT]bool) ([]e return []error{}, nil } - errorMap := job.whManager.LoadUserTables(context.TODO()) + errorMap := job.whManager.LoadUserTables(job.ctx) if alteredIdentitySchema || alteredUserSchema { pkgLogger.Infof("loadUserTables: schema changed - updating local schema for %s", job.warehouse.Identifier) - _ = job.schemaHandle.updateLocalSchema(job.upload.ID, job.schemaHandle.schemaInWarehouse) + _ = job.schemaHandle.updateLocalSchema(job.ctx, job.upload.ID, job.schemaHandle.schemaInWarehouse) } return job.processLoadTableResponse(errorMap) } @@ -1273,7 +1277,7 @@ func (job *UploadJob) loadIdentityTables(populateHistoricIdentities bool) (loadE if err != nil { status := model.TableUploadUpdatingSchemaFailed errorsString := misc.QuoteLiteral(err.Error()) - _ = job.tableUploadsRepo.Set(context.TODO(), job.upload.ID, tableName, repo.TableUploadSetOptions{ + _ = job.tableUploadsRepo.Set(job.ctx, job.upload.ID, tableName, repo.TableUploadSetOptions{ Status: &status, Error: &errorsString, }) @@ -1283,15 +1287,15 @@ func (job *UploadJob) loadIdentityTables(populateHistoricIdentities bool) (loadE job.setUpdatedTableSchema(tableName, tableSchemaDiff.UpdatedSchema) status := model.TableUploadUpdatedSchema - _ = job.tableUploadsRepo.Set(context.TODO(), job.upload.ID, tableName, repo.TableUploadSetOptions{ + _ = job.tableUploadsRepo.Set(job.ctx, job.upload.ID, tableName, repo.TableUploadSetOptions{ Status: &status, }) alteredSchema = true } status := model.TableUploadExecuting - lastExecTime := job.Now() - err = job.tableUploadsRepo.Set(context.TODO(), job.upload.ID, tableName, repo.TableUploadSetOptions{ + lastExecTime := job.now() + err = job.tableUploadsRepo.Set(job.ctx, job.upload.ID, tableName, repo.TableUploadSetOptions{ Status: &status, LastExecTime: &lastExecTime, }) @@ -1302,9 +1306,9 @@ func (job *UploadJob) loadIdentityTables(populateHistoricIdentities bool) (loadE switch tableName { case job.identityMergeRulesTableName(): - err = job.whManager.LoadIdentityMergeRulesTable() + err = job.whManager.LoadIdentityMergeRulesTable(job.ctx) case job.identityMappingsTableName(): - err = job.whManager.LoadIdentityMappingsTable() + err = job.whManager.LoadIdentityMappingsTable(job.ctx) } if err != nil { @@ -1315,7 +1319,7 @@ func (job *UploadJob) loadIdentityTables(populateHistoricIdentities bool) (loadE if alteredSchema { pkgLogger.Infof("loadIdentityTables: schema changed - updating local schema for %s", job.warehouse.Identifier) - _ = job.schemaHandle.updateLocalSchema(job.upload.ID, job.schemaHandle.schemaInWarehouse) + _ = job.schemaHandle.updateLocalSchema(job.ctx, job.upload.ID, job.schemaHandle.schemaInWarehouse) } return job.processLoadTableResponse(errorMap) @@ -1334,18 +1338,18 @@ func (job *UploadJob) processLoadTableResponse(errorMap map[string]error) (error errors = append(errors, loadErr) errorsString := misc.QuoteLiteral(loadErr.Error()) status := model.TableUploadExportingFailed - tableUploadErr = job.tableUploadsRepo.Set(context.TODO(), job.upload.ID, tName, repo.TableUploadSetOptions{ + tableUploadErr = job.tableUploadsRepo.Set(job.ctx, 
job.upload.ID, tName, repo.TableUploadSetOptions{ Status: &status, Error: &errorsString, }) } else { status := model.TableUploadExported - tableUploadErr = job.tableUploadsRepo.Set(context.TODO(), job.upload.ID, tName, repo.TableUploadSetOptions{ + tableUploadErr = job.tableUploadsRepo.Set(job.ctx, job.upload.ID, tName, repo.TableUploadSetOptions{ Status: &status, }) if tableUploadErr == nil { // Since load is successful, we assume all events in load files are uploaded - tableUpload, queryErr := job.tableUploadsRepo.GetByUploadIDAndTableName(context.TODO(), job.upload.ID, tName) + tableUpload, queryErr := job.tableUploadsRepo.GetByUploadIDAndTableName(job.ctx, job.upload.ID, tName) if queryErr == nil { job.recordTableLoad(tName, tableUpload.TotalEvents) } @@ -1363,11 +1367,11 @@ func (job *UploadJob) processLoadTableResponse(errorMap map[string]error) (error // getNewTimings appends current status with current time to timings column // e.g. status: exported_data, timings: [{exporting_data: 2020-04-21 15:16:19.687716] -> [{exporting_data: 2020-04-21 15:16:19.687716, exported_data: 2020-04-21 15:26:34.344356}] func (job *UploadJob) getNewTimings(status string) ([]byte, model.Timings) { - timings, err := repo.NewUploads(job.dbHandle).UploadTimings(context.TODO(), job.upload.ID) + timings, err := repo.NewUploads(job.dbHandle).UploadTimings(job.ctx, job.upload.ID) if err != nil { pkgLogger.Error("error getting timing, scrapping them", err) } - timing := map[string]time.Time{status: job.Now()} + timing := map[string]time.Time{status: job.now()} timings = append(timings, timing) marshalledTimings, err := json.Marshal(timings) if err != nil { @@ -1389,7 +1393,7 @@ func (job *UploadJob) getUploadFirstAttemptTime() (timing time.Time) { warehouseutils.WarehouseUploadsTable, job.upload.ID, ) - err := job.dbHandle.QueryRow(sqlStatement).Scan(&firstTiming) + err := job.dbHandle.QueryRowContext(job.ctx, sqlStatement).Scan(&firstTiming) if err != nil { return } @@ -1410,7 +1414,7 @@ func (job *UploadJob) setUploadStatus(statusOpts UploadStatusOpts) (err error) { opts := []UploadColumn{ {Column: UploadStatusField, Value: statusOpts.Status}, {Column: UploadTimingsField, Value: marshalledTimings}, - {Column: UploadUpdatedAtField, Value: job.Now()}, + {Column: UploadUpdatedAtField, Value: job.now()}, } job.upload.Status = statusOpts.Status @@ -1423,7 +1427,7 @@ func (job *UploadJob) setUploadStatus(statusOpts UploadStatusOpts) (err error) { uploadColumnOpts := UploadColumnsOpts{Fields: additionalFields} if statusOpts.ReportingMetric != (types.PUReportedMetric{}) { - txn, err := dbHandle.Begin() + txn, err := dbHandle.BeginTx(job.ctx, &sql.TxOptions{}) if err != nil { return err } @@ -1508,9 +1512,9 @@ func (job *UploadJob) setUploadColumns(opts UploadColumnsOpts) (err error) { columns, ) if opts.Txn != nil { - _, err = opts.Txn.Exec(sqlStatement, values...) + _, err = opts.Txn.ExecContext(job.ctx, sqlStatement, values...) } else { - _, err = dbHandle.Exec(sqlStatement, values...) + _, err = dbHandle.ExecContext(job.ctx, sqlStatement, values...) 
} return err @@ -1523,7 +1527,7 @@ func (job *UploadJob) triggerUploadNow() (err error) { metadata := repo.ExtractUploadMetadata(job.upload) - metadata.NextRetryTime = job.Now().Add(-time.Hour * 1) + metadata.NextRetryTime = job.now().Add(-time.Hour * 1) metadata.Retried = true metadata.Priority = 50 @@ -1535,10 +1539,10 @@ func (job *UploadJob) triggerUploadNow() (err error) { uploadColumns := []UploadColumn{ {Column: "status", Value: newJobState}, {Column: "metadata", Value: metadataJSON}, - {Column: "updated_at", Value: job.Now()}, + {Column: "updated_at", Value: job.now()}, } - txn, err := job.dbHandle.Begin() + txn, err := job.dbHandle.BeginTx(job.ctx, &sql.TxOptions{}) if err != nil { panic(err) } @@ -1595,12 +1599,12 @@ func (job *UploadJob) Aborted(attempts int, startTime time.Time) bool { return false } - return attempts > job.MinRetryAttempts && job.Now().Sub(startTime) > job.RetryTimeWindow + return attempts > job.minRetryAttempts && job.now().Sub(startTime) > job.retryTimeWindow } func (job *UploadJob) setUploadError(statusError error, state string) (string, error) { var ( - errorTags = job.ErrorHandler.MatchErrorMappings(statusError) + errorTags = job.errorHandler.MatchErrorMappings(statusError) destCredentialsValidations *bool ) @@ -1650,7 +1654,7 @@ func (job *UploadJob) setUploadError(statusError error, state string) (string, e metadata := repo.ExtractUploadMetadata(job.upload) - metadata.NextRetryTime = job.Now().Add(DurationBeforeNextAttempt(upload.Attempts + 1)) + metadata.NextRetryTime = job.now().Add(DurationBeforeNextAttempt(upload.Attempts + 1)) metadataJSON, err := json.Marshal(metadata) if err != nil { metadataJSON = []byte("{}") @@ -1663,10 +1667,10 @@ func (job *UploadJob) setUploadError(statusError error, state string) (string, e {Column: "status", Value: state}, {Column: "metadata", Value: metadataJSON}, {Column: "error", Value: serializedErr}, - {Column: "updated_at", Value: job.Now()}, + {Column: "updated_at", Value: job.now()}, } - txn, err := job.dbHandle.Begin() + txn, err := job.dbHandle.BeginTx(job.ctx, &sql.TxOptions{}) if err != nil { return "", fmt.Errorf("unable to start transaction: %w", err) } @@ -1676,8 +1680,8 @@ func (job *UploadJob) setUploadError(statusError error, state string) (string, e return "", fmt.Errorf("unable to change upload columns: %w", err) } - inputCount, _ := repo.NewStagingFiles(dbHandle).TotalEventsForUpload(context.TODO(), upload) - outputCount, _ := job.tableUploadsRepo.TotalExportedEvents(context.TODO(), job.upload.ID, []string{ + inputCount, _ := repo.NewStagingFiles(dbHandle).TotalEventsForUpload(job.ctx, upload) + outputCount, _ := job.tableUploadsRepo.TotalExportedEvents(job.ctx, job.upload.ID, []string{ warehouseutils.ToProviderCase(job.warehouse.Type, warehouseutils.DiscardsTable), }) failCount := inputCount - outputCount @@ -1765,7 +1769,7 @@ func (job *UploadJob) validateDestinationCredentials() (bool, error) { if job.destinationValidator == nil { return false, errors.New("failed to validate as destinationValidator is not set") } - response := job.destinationValidator.Validate(&job.warehouse.Destination) + response := job.destinationValidator.Validate(job.ctx, &job.warehouse.Destination) return response.Success, nil } @@ -1803,14 +1807,14 @@ func (job *UploadJob) getLoadFilesTableMap() (loadFilesMap map[tableNameT]bool, ); `, warehouseutils.WarehouseLoadFilesTable, - ) + ) /**/ sqlStatementArgs := []interface{}{ sourceID, destID, job.upload.LoadFileStartID, job.upload.LoadFileEndID, } - rows, err := 
dbHandle.Query(sqlStatement, sqlStatementArgs...) + rows, err := dbHandle.QueryContext(job.ctx, sqlStatement, sqlStatementArgs...) if err == sql.ErrNoRows { err = nil return @@ -1841,13 +1845,13 @@ func (job *UploadJob) areIdentityTablesLoadFilesGenerated() (bool, error) { err error ) - if tu, err = job.tableUploadsRepo.GetByUploadIDAndTableName(context.TODO(), job.upload.ID, mergeRulesTable); err != nil { + if tu, err = job.tableUploadsRepo.GetByUploadIDAndTableName(job.ctx, job.upload.ID, mergeRulesTable); err != nil { return false, fmt.Errorf("table upload not found for merge rules table: %w", err) } if tu.Location == "" { return false, fmt.Errorf("merge rules location not found: %w", err) } - if tu, err = job.tableUploadsRepo.GetByUploadIDAndTableName(context.TODO(), job.upload.ID, mappingsTable); err != nil { + if tu, err = job.tableUploadsRepo.GetByUploadIDAndTableName(job.ctx, job.upload.ID, mappingsTable); err != nil { return false, fmt.Errorf("table upload not found for mappings table: %w", err) } if tu.Location == "" { @@ -1856,7 +1860,7 @@ func (job *UploadJob) areIdentityTablesLoadFilesGenerated() (bool, error) { return true, nil } -func (job *UploadJob) GetLoadFilesMetadata(options warehouseutils.GetLoadFilesOptions) (loadFiles []warehouseutils.LoadFile) { +func (job *UploadJob) GetLoadFilesMetadata(ctx context.Context, options warehouseutils.GetLoadFilesOptions) (loadFiles []warehouseutils.LoadFile) { var tableFilterSQL string if options.Table != "" { tableFilterSQL = fmt.Sprintf(` AND table_name='%s'`, options.Table) @@ -1899,7 +1903,7 @@ func (job *UploadJob) GetLoadFilesMetadata(options warehouseutils.GetLoadFilesOp ) pkgLogger.Debugf(`Fetching loadFileLocations: %v`, sqlStatement) - rows, err := dbHandle.Query(sqlStatement) + rows, err := dbHandle.QueryContext(ctx, sqlStatement) if err != nil { panic(fmt.Errorf("Query: %s\nfailed with Error : %w", sqlStatement, err)) } @@ -1920,8 +1924,8 @@ func (job *UploadJob) GetLoadFilesMetadata(options warehouseutils.GetLoadFilesOp return } -func (job *UploadJob) GetSampleLoadFileLocation(tableName string) (location string, err error) { - locations := job.GetLoadFilesMetadata(warehouseutils.GetLoadFilesOptions{Table: tableName, Limit: 1}) +func (job *UploadJob) GetSampleLoadFileLocation(ctx context.Context, tableName string) (location string, err error) { + locations := job.GetLoadFilesMetadata(ctx, warehouseutils.GetLoadFilesOptions{Table: tableName, Limit: 1}) if len(locations) == 0 { return "", fmt.Errorf(`no load file found for table:%s`, tableName) } @@ -1943,13 +1947,13 @@ func (job *UploadJob) GetTableSchemaInUpload(tableName string) model.TableSchema return job.schemaHandle.uploadSchema[tableName] } -func (job *UploadJob) GetSingleLoadFile(tableName string) (warehouseutils.LoadFile, error) { +func (job *UploadJob) GetSingleLoadFile(ctx context.Context, tableName string) (warehouseutils.LoadFile, error) { var ( tableUpload model.TableUpload err error ) - if tableUpload, err = job.tableUploadsRepo.GetByUploadIDAndTableName(context.TODO(), job.upload.ID, tableName); err != nil { + if tableUpload, err = job.tableUploadsRepo.GetByUploadIDAndTableName(ctx, job.upload.ID, tableName); err != nil { return warehouseutils.LoadFile{}, fmt.Errorf("get single load file: %w", err) } @@ -2077,12 +2081,12 @@ func initializeStateMachine() { abortState.nextState = nil } -func (job *UploadJob) GetLocalSchema() (model.Schema, error) { - return job.schemaHandle.getLocalSchema() +func (job *UploadJob) GetLocalSchema(ctx context.Context) 
(model.Schema, error) { + return job.schemaHandle.getLocalSchema(ctx) } -func (job *UploadJob) UpdateLocalSchema(schema model.Schema) error { - return job.schemaHandle.updateLocalSchema(job.upload.ID, schema) +func (job *UploadJob) UpdateLocalSchema(ctx context.Context, schema model.Schema) error { + return job.schemaHandle.updateLocalSchema(ctx, job.upload.ID, schema) } func (job *UploadJob) RefreshPartitions(loadFileStartID, loadFileEndID int64) error { @@ -2101,15 +2105,15 @@ func (job *UploadJob) RefreshPartitions(loadFileStartID, loadFileEndID int64) er // Refresh partitions if exists for tableName := range job.upload.UploadSchema { - loadFiles := job.GetLoadFilesMetadata(warehouseutils.GetLoadFilesOptions{ + loadFiles := job.GetLoadFilesMetadata(job.ctx, warehouseutils.GetLoadFilesOptions{ Table: tableName, StartID: loadFileStartID, EndID: loadFileEndID, }) - batches := schemarepository.LoadFileBatching(loadFiles, job.RefreshPartitionBatchSize) + batches := schemarepository.LoadFileBatching(loadFiles, job.refreshPartitionBatchSize) for _, batch := range batches { - if err = repository.RefreshPartitions(tableName, batch); err != nil { + if err = repository.RefreshPartitions(job.ctx, tableName, batch); err != nil { return fmt.Errorf("refresh partitions: %w", err) } } diff --git a/warehouse/upload_test.go b/warehouse/upload_test.go index 800a2ec654..3446b722df 100644 --- a/warehouse/upload_test.go +++ b/warehouse/upload_test.go @@ -215,7 +215,7 @@ var _ = Describe("Upload", Ordered, func() { initWarehouse() - err = setupDB(context.TODO(), getConnectionString()) + err = setupDB(context.Background(), getConnectionString()) Expect(err).To(BeNil()) sqlStatement, err := os.ReadFile("testdata/sql/upload_test.sql") @@ -250,6 +250,7 @@ var _ = Describe("Upload", Ordered, func() { }, stagingFileIDs: []int64{1, 2, 3, 4, 5}, dbHandle: pgResource.DB, + ctx: context.Background(), } }) @@ -259,7 +260,7 @@ var _ = Describe("Upload", Ordered, func() { }) It("Total rows in staging files", func() { - count, err := repo.NewStagingFiles(pgResource.DB).TotalEventsForUpload(context.TODO(), job.upload) + count, err := repo.NewStagingFiles(pgResource.DB).TotalEventsForUpload(context.Background(), job.upload) Expect(err).To(BeNil()) Expect(count).To(BeEquivalentTo(5)) }) @@ -269,7 +270,7 @@ var _ = Describe("Upload", Ordered, func() { Expect(err).To(BeNil()) exportingData, err := time.Parse(time.RFC3339, "2020-04-21T15:16:19.687716Z") Expect(err).To(BeNil()) - Expect(repo.NewUploads(job.dbHandle).UploadTimings(context.TODO(), job.upload.ID)). + Expect(repo.NewUploads(job.dbHandle).UploadTimings(context.Background(), job.upload.ID)). 
To(BeEquivalentTo(model.Timings{ { "exported_data": exportedData, @@ -281,7 +282,7 @@ var _ = Describe("Upload", Ordered, func() { Describe("Staging files and load files events match", func() { When("Matched", func() { It("Should not send stats", func() { - job.matchRowsInStagingAndLoadFiles(context.TODO()) + job.matchRowsInStagingAndLoadFiles(context.Background()) }) }) @@ -293,7 +294,7 @@ var _ = Describe("Upload", Ordered, func() { job.stats = mockStats job.stagingFileIDs = []int64{1, 2} - job.matchRowsInStagingAndLoadFiles(context.TODO()) + job.matchRowsInStagingAndLoadFiles(context.Background()) }) }) }) @@ -308,6 +309,8 @@ func (m *mockAlertSender) SendAlert(context.Context, string, alerta.SendAlertOpt } func TestUploadJobT_UpdateTableSchema(t *testing.T) { + t.Parallel() + Init() Init4() @@ -379,9 +382,10 @@ func TestUploadJobT_UpdateTableSchema(t *testing.T) { DestinationID: testDestinationID, DestinationType: testDestinationType, }, - AlertSender: &mockAlertSender{ + alertSender: &mockAlertSender{ mockError: tc.mockAlertError, }, + ctx: context.Background(), } _, err = rs.DB.Exec( @@ -448,7 +452,8 @@ func TestUploadJobT_UpdateTableSchema(t *testing.T) { DestinationID: testDestinationID, DestinationType: testDestinationType, }, - AlertSender: &mockAlertSender{}, + alertSender: &mockAlertSender{}, + ctx: context.Background(), } _, err = rs.DB.Exec( @@ -516,6 +521,8 @@ func TestUploadJobT_UpdateTableSchema(t *testing.T) { } func TestUploadJobT_Aborted(t *testing.T) { + t.Parallel() + var ( minAttempts = 3 minRetryWindow = 3 * time.Hour @@ -560,9 +567,10 @@ func TestUploadJobT_Aborted(t *testing.T) { t.Parallel() job := &UploadJob{ - MinRetryAttempts: minAttempts, - RetryTimeWindow: minRetryWindow, - Now: func() time.Time { return now }, + minRetryAttempts: minAttempts, + retryTimeWindow: minRetryWindow, + now: func() time.Time { return now }, + ctx: context.Background(), } require.Equal(t, tc.expected, job.Aborted(tc.attempts, tc.startTime)) @@ -582,6 +590,8 @@ func (m *mockPendingTablesRepo) PendingTableUploads(context.Context, string, int } func TestUploadJobT_TablesToSkip(t *testing.T) { + t.Parallel() + t.Run("repo error", func(t *testing.T) { t.Parallel() @@ -592,6 +602,7 @@ func TestUploadJobT_TablesToSkip(t *testing.T) { pendingTableUploadsRepo: &mockPendingTablesRepo{ err: errors.New("some error"), }, + ctx: context.Background(), } previouslyFailedTables, currentJobSucceededTables, err := job.TablesToSkip() @@ -610,6 +621,7 @@ func TestUploadJobT_TablesToSkip(t *testing.T) { ID: 1, }, pendingTableUploadsRepo: ptRepo, + ctx: context.Background(), } for i := 0; i < 5; i++ { @@ -676,6 +688,7 @@ func TestUploadJobT_TablesToSkip(t *testing.T) { pendingTableUploadsRepo: &mockPendingTablesRepo{ pendingTables: pendingTables, }, + ctx: context.Background(), } previouslyFailedTables, currentJobSucceededTables, err := job.TablesToSkip() diff --git a/warehouse/utils/utils.go b/warehouse/utils/utils.go index 2d594a1c49..f735948ecd 100644 --- a/warehouse/utils/utils.go +++ b/warehouse/utils/utils.go @@ -237,13 +237,13 @@ type KeyValue struct { type Uploader interface { GetSchemaInWarehouse() model.Schema - GetLocalSchema() (model.Schema, error) - UpdateLocalSchema(schema model.Schema) error + GetLocalSchema(ctx context.Context) (model.Schema, error) + UpdateLocalSchema(ctx context.Context, schema model.Schema) error GetTableSchemaInWarehouse(tableName string) model.TableSchema GetTableSchemaInUpload(tableName string) model.TableSchema - GetLoadFilesMetadata(options GetLoadFilesOptions) 
[]LoadFile - GetSampleLoadFileLocation(tableName string) (string, error) - GetSingleLoadFile(tableName string) (LoadFile, error) + GetLoadFilesMetadata(ctx context.Context, options GetLoadFilesOptions) []LoadFile + GetSampleLoadFileLocation(ctx context.Context, tableName string) (string, error) + GetSingleLoadFile(ctx context.Context, tableName string) (LoadFile, error) ShouldOnDedupUseNewRecord() bool UseRudderStorage() bool GetLoadFileGenStartTIme() time.Time diff --git a/warehouse/validations/steps.go b/warehouse/validations/steps.go index d37c3a67f6..246ea3e03f 100644 --- a/warehouse/validations/steps.go +++ b/warehouse/validations/steps.go @@ -1,6 +1,7 @@ package validations import ( + "context" "encoding/json" backendconfig "github.com/rudderlabs/rudder-server/backend-config" @@ -9,7 +10,7 @@ import ( warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" ) -func validateStepFunc(destination *backendconfig.DestinationT, _ string) (json.RawMessage, error) { +func validateStepFunc(_ context.Context, destination *backendconfig.DestinationT, _ string) (json.RawMessage, error) { return json.Marshal(StepsToValidate(destination)) } @@ -26,7 +27,6 @@ func StepsToValidate(dest *backendconfig.DestinationT) *model.StepsResponse { switch destType { case warehouseutils.GCS_DATALAKE, warehouseutils.AZURE_DATALAKE: - break case warehouseutils.S3_DATALAKE: wh := createDummyWarehouse(dest) if canUseGlue := schemarepository.UseGlue(&wh); !canUseGlue { diff --git a/warehouse/validations/validate.go b/warehouse/validations/validate.go index f96598924d..09e7e882c4 100644 --- a/warehouse/validations/validate.go +++ b/warehouse/validations/validate.go @@ -22,7 +22,7 @@ import ( ) type Validator interface { - Validate() error + Validate(ctx context.Context) error } type objectStorage struct { @@ -59,18 +59,23 @@ type dummyUploader struct { } type DestinationValidator interface { - Validate(dest *backendconfig.DestinationT) *model.DestinationValidationResponse + Validate(ctx context.Context, dest *backendconfig.DestinationT) *model.DestinationValidationResponse } type destinationValidationImpl struct{} -func (*dummyUploader) GetSchemaInWarehouse() model.Schema { return model.Schema{} } -func (*dummyUploader) GetLocalSchema() (model.Schema, error) { return model.Schema{}, nil } -func (*dummyUploader) UpdateLocalSchema(_ model.Schema) error { return nil } -func (*dummyUploader) ShouldOnDedupUseNewRecord() bool { return false } -func (*dummyUploader) GetFirstLastEvent() (time.Time, time.Time) { return time.Time{}, time.Time{} } -func (*dummyUploader) GetLoadFileGenStartTIme() time.Time { return time.Time{} } -func (*dummyUploader) GetSampleLoadFileLocation(string) (string, error) { return "", nil } +func (*dummyUploader) GetSchemaInWarehouse() model.Schema { return model.Schema{} } +func (*dummyUploader) GetLocalSchema(context.Context) (model.Schema, error) { + return model.Schema{}, nil +} +func (*dummyUploader) UpdateLocalSchema(context.Context, model.Schema) error { return nil } +func (*dummyUploader) ShouldOnDedupUseNewRecord() bool { return false } +func (*dummyUploader) GetFirstLastEvent() (time.Time, time.Time) { return time.Time{}, time.Time{} } +func (*dummyUploader) GetLoadFileGenStartTIme() time.Time { return time.Time{} } +func (*dummyUploader) GetSampleLoadFileLocation(context.Context, string) (string, error) { + return "", nil +} + func (*dummyUploader) GetTableSchemaInWarehouse(string) model.TableSchema { return model.TableSchema{} } @@ -79,11 +84,11 @@ func (*dummyUploader) 
GetTableSchemaInUpload(string) model.TableSchema { return model.TableSchema{} } -func (*dummyUploader) GetLoadFilesMetadata(warehouseutils.GetLoadFilesOptions) []warehouseutils.LoadFile { +func (*dummyUploader) GetLoadFilesMetadata(context.Context, warehouseutils.GetLoadFilesOptions) []warehouseutils.LoadFile { return []warehouseutils.LoadFile{} } -func (*dummyUploader) GetSingleLoadFile(string) (warehouseutils.LoadFile, error) { +func (*dummyUploader) GetSingleLoadFile(context.Context, string) (warehouseutils.LoadFile, error) { return warehouseutils.LoadFile{}, nil } @@ -99,15 +104,15 @@ func NewDestinationValidator() DestinationValidator { return &destinationValidationImpl{} } -func (*destinationValidationImpl) Validate(dest *backendconfig.DestinationT) *model.DestinationValidationResponse { - return validateDestination(dest, "") +func (*destinationValidationImpl) Validate(ctx context.Context, dest *backendconfig.DestinationT) *model.DestinationValidationResponse { + return validateDestination(ctx, dest, "") } -func validateDestinationFunc(dest *backendconfig.DestinationT, stepToValidate string) (json.RawMessage, error) { - return json.Marshal(validateDestination(dest, stepToValidate)) +func validateDestinationFunc(ctx context.Context, dest *backendconfig.DestinationT, stepToValidate string) (json.RawMessage, error) { + return json.Marshal(validateDestination(ctx, dest, stepToValidate)) } -func validateDestination(dest *backendconfig.DestinationT, stepToValidate string) *model.DestinationValidationResponse { +func validateDestination(ctx context.Context, dest *backendconfig.DestinationT, stepToValidate string) *model.DestinationValidationResponse { var ( destID = dest.ID destType = dest.DestinationDefinition.Name @@ -155,7 +160,7 @@ func validateDestination(dest *backendconfig.DestinationT, stepToValidate string // Iterate over all selected steps and validate for _, step := range stepsToValidate { - if validator, err = NewValidator(step.Name, dest); err != nil { + if validator, err = NewValidator(ctx, step.Name, dest); err != nil { err = fmt.Errorf("creating validator: %v", err) step.Error = err.Error() @@ -170,7 +175,7 @@ func validateDestination(dest *backendconfig.DestinationT, stepToValidate string break } - if stepError := validator.Validate(); stepError != nil { + if stepError := validator.Validate(ctx); stepError != nil { err = stepError step.Error = stepError.Error() } else { @@ -202,7 +207,7 @@ func validateDestination(dest *backendconfig.DestinationT, stepToValidate string return res } -func NewValidator(step string, dest *backendconfig.DestinationT) (Validator, error) { +func NewValidator(ctx context.Context, step string, dest *backendconfig.DestinationT) (Validator, error) { var ( operations manager.WarehouseOperations err error @@ -214,7 +219,7 @@ func NewValidator(step string, dest *backendconfig.DestinationT) (Validator, err destination: dest, }, nil case model.VerifyingConnections: - if operations, err = createManager(dest); err != nil { + if operations, err = createManager(ctx, dest); err != nil { return nil, fmt.Errorf("create manager: %w", err) } return &connections{ @@ -222,14 +227,14 @@ func NewValidator(step string, dest *backendconfig.DestinationT) (Validator, err manager: operations, }, nil case model.VerifyingCreateSchema: - if operations, err = createManager(dest); err != nil { + if operations, err = createManager(ctx, dest); err != nil { return nil, fmt.Errorf("create manager: %w", err) } return &createSchema{ manager: operations, }, nil case 
model.VerifyingCreateAndAlterTable: - if operations, err = createManager(dest); err != nil { + if operations, err = createManager(ctx, dest); err != nil { return nil, fmt.Errorf("create manager: %w", err) } return &createAlterTable{ @@ -237,7 +242,7 @@ func NewValidator(step string, dest *backendconfig.DestinationT) (Validator, err manager: operations, }, nil case model.VerifyingFetchSchema: - if operations, err = createManager(dest); err != nil { + if operations, err = createManager(ctx, dest); err != nil { return nil, fmt.Errorf("create manager: %w", err) } return &fetchSchema{ @@ -245,7 +250,7 @@ func NewValidator(step string, dest *backendconfig.DestinationT) (Validator, err manager: operations, }, nil case model.VerifyingLoadTable: - if operations, err = createManager(dest); err != nil { + if operations, err = createManager(ctx, dest); err != nil { return nil, fmt.Errorf("create manager: %w", err) } return &loadTable{ @@ -258,7 +263,7 @@ func NewValidator(step string, dest *backendconfig.DestinationT) (Validator, err return nil, fmt.Errorf("invalid step: %s", step) } -func (os *objectStorage) Validate() error { +func (os *objectStorage) Validate(ctx context.Context) error { var ( tempPath string err error @@ -269,46 +274,43 @@ func (os *objectStorage) Validate() error { return fmt.Errorf("creating temp load file: %w", err) } - if uploadObject, err = uploadFile(os.destination, tempPath); err != nil { + if uploadObject, err = uploadFile(ctx, os.destination, tempPath); err != nil { return fmt.Errorf("upload file: %w", err) } - if err = downloadFile(os.destination, uploadObject.ObjectName); err != nil { + if err = downloadFile(ctx, os.destination, uploadObject.ObjectName); err != nil { return fmt.Errorf("download file: %w", err) } return nil } -func (c *connections) Validate() error { - defer c.manager.Cleanup() +func (c *connections) Validate(ctx context.Context) error { + defer c.manager.Cleanup(ctx) - ctx, cancel := context.WithTimeout(context.TODO(), warehouseutils.TestConnectionTimeout) + ctx, cancel := context.WithTimeout(ctx, warehouseutils.TestConnectionTimeout) defer cancel() return c.manager.TestConnection(ctx, createDummyWarehouse(c.destination)) } -func (cs *createSchema) Validate() error { - defer cs.manager.Cleanup() +func (cs *createSchema) Validate(ctx context.Context) error { + defer cs.manager.Cleanup(ctx) - return cs.manager.CreateSchema() + return cs.manager.CreateSchema(ctx) } -func (cat *createAlterTable) Validate() error { - defer cat.manager.Cleanup() +func (cat *createAlterTable) Validate(ctx context.Context) error { + defer cat.manager.Cleanup(ctx) - if err := cat.manager.CreateTable(cat.table, tableSchemaMap); err != nil { + if err := cat.manager.CreateTable(ctx, cat.table, tableSchemaMap); err != nil { return fmt.Errorf("create table: %w", err) } - defer func() { _ = cat.manager.DropTable(cat.table) }() + defer func() { _ = cat.manager.DropTable(ctx, cat.table) }() for columnName, columnType := range alterColumnMap { - if err := cat.manager.AddColumns( - cat.table, - []warehouseutils.ColumnInfo{{Name: columnName, Type: columnType}}, - ); err != nil { + if err := cat.manager.AddColumns(ctx, cat.table, []warehouseutils.ColumnInfo{{Name: columnName, Type: columnType}}); err != nil { return fmt.Errorf("alter table: %w", err) } } @@ -316,16 +318,16 @@ func (cat *createAlterTable) Validate() error { return nil } -func (fs *fetchSchema) Validate() error { - defer fs.manager.Cleanup() +func (fs *fetchSchema) Validate(ctx context.Context) error { + defer 
fs.manager.Cleanup(ctx) - if _, _, err := fs.manager.FetchSchema(); err != nil { + if _, _, err := fs.manager.FetchSchema(ctx); err != nil { return fmt.Errorf("fetch schema: %w", err) } return nil } -func (lt *loadTable) Validate() error { +func (lt *loadTable) Validate(ctx context.Context) error { var ( destinationType = lt.destination.DestinationDefinition.Name loadFileType = warehouseutils.GetLoadFileType(destinationType) @@ -335,28 +337,23 @@ func (lt *loadTable) Validate() error { err error ) - defer lt.manager.Cleanup() + defer lt.manager.Cleanup(ctx) if tempPath, err = CreateTempLoadFile(lt.destination); err != nil { return fmt.Errorf("create temp load file: %w", err) } - if uploadOutput, err = uploadFile(lt.destination, tempPath); err != nil { + if uploadOutput, err = uploadFile(ctx, lt.destination, tempPath); err != nil { return fmt.Errorf("upload file: %w", err) } - if err = lt.manager.CreateTable(lt.table, tableSchemaMap); err != nil { + if err = lt.manager.CreateTable(ctx, lt.table, tableSchemaMap); err != nil { return fmt.Errorf("create table: %w", err) } - defer func() { _ = lt.manager.DropTable(lt.table) }() + defer func() { _ = lt.manager.DropTable(ctx, lt.table) }() - if err = lt.manager.LoadTestTable( - uploadOutput.Location, - lt.table, - payloadMap, - loadFileType, - ); err != nil { + if err = lt.manager.LoadTestTable(ctx, uploadOutput.Location, lt.table, payloadMap, loadFileType); err != nil { return fmt.Errorf("load test table: %w", err) } @@ -416,7 +413,7 @@ func CreateTempLoadFile(dest *backendconfig.DestinationT) (string, error) { return filePath, nil } -func uploadFile(dest *backendconfig.DestinationT, filePath string) (filemanager.UploadOutput, error) { +func uploadFile(ctx context.Context, dest *backendconfig.DestinationT, filePath string) (filemanager.UploadOutput, error) { var ( err error output filemanager.UploadOutput @@ -439,14 +436,14 @@ func uploadFile(dest *backendconfig.DestinationT, filePath string) (filemanager. 
defer misc.RemoveFilePaths(filePath) defer func() { _ = uploadFile.Close() }() - if output, err = fm.Upload(context.TODO(), uploadFile, prefixes...); err != nil { + if output, err = fm.Upload(ctx, uploadFile, prefixes...); err != nil { return filemanager.UploadOutput{}, fmt.Errorf("uploading file: %w", err) } return output, nil } -func downloadFile(dest *backendconfig.DestinationT, location string) error { +func downloadFile(ctx context.Context, dest *backendconfig.DestinationT, location string) error { var ( err error fm filemanager.FileManager @@ -486,7 +483,7 @@ func downloadFile(dest *backendconfig.DestinationT, location string) error { defer misc.RemoveFilePaths(filePath) defer func() { _ = downloadFile.Close() }() - if err = fm.Download(context.TODO(), downloadFile, location); err != nil { + if err = fm.Download(ctx, downloadFile, location); err != nil { return fmt.Errorf("downloading file: %w", err) } return nil @@ -517,7 +514,7 @@ func createFileManager(dest *backendconfig.DestinationT) (filemanager.FileManage return fileManager, nil } -func createManager(dest *backendconfig.DestinationT) (manager.WarehouseOperations, error) { +func createManager(ctx context.Context, dest *backendconfig.DestinationT) (manager.WarehouseOperations, error) { var ( destType = dest.DestinationDefinition.Name warehouse = createDummyWarehouse(dest) @@ -532,7 +529,7 @@ func createManager(dest *backendconfig.DestinationT) (manager.WarehouseOperation operations.SetConnectionTimeout(warehouseutils.TestConnectionTimeout) - if err = operations.Setup(warehouse, &dummyUploader{ + if err = operations.Setup(ctx, warehouse, &dummyUploader{ dest: dest, }); err != nil { return nil, fmt.Errorf("setting up manager: %w", err) diff --git a/warehouse/validations/validate_test.go b/warehouse/validations/validate_test.go index cb28abcd69..2a17960bdb 100644 --- a/warehouse/validations/validate_test.go +++ b/warehouse/validations/validate_test.go @@ -1,6 +1,7 @@ package validations_test import ( + "context" "errors" "fmt" "testing" @@ -58,6 +59,8 @@ func setup(t *testing.T, pool *dockertest.Pool) testResource { } func TestValidator(t *testing.T) { + t.Parallel() + misc.Init() warehouseutils.Init() encoding.Init() @@ -70,6 +73,8 @@ func TestValidator(t *testing.T) { sslmode = "disable" ) + ctx := context.Background() + pool, err := dockertest.NewPool("") require.NoError(t, err) @@ -82,7 +87,7 @@ func TestValidator(t *testing.T) { t.Run("Non Datalakes", func(t *testing.T) { t.Parallel() - v, err := validations.NewValidator(model.VerifyingObjectStorage, &backendconfig.DestinationT{ + v, err := validations.NewValidator(ctx, model.VerifyingObjectStorage, &backendconfig.DestinationT{ DestinationDefinition: backendconfig.DestinationDefinitionT{ Name: warehouseutils.POSTGRES, }, @@ -100,7 +105,7 @@ func TestValidator(t *testing.T) { }, }) require.NoError(t, err) - require.NoError(t, v.Validate()) + require.NoError(t, v.Validate(ctx)) }) t.Run("Datalakes", func(t *testing.T) { @@ -118,7 +123,7 @@ func TestValidator(t *testing.T) { _ = minioResource.Client.MakeBucket(bucket, "us-east-1") - v, err := validations.NewValidator(model.VerifyingObjectStorage, &backendconfig.DestinationT{ + v, err := validations.NewValidator(ctx, model.VerifyingObjectStorage, &backendconfig.DestinationT{ DestinationDefinition: backendconfig.DestinationDefinitionT{ Name: warehouseutils.S3_DATALAKE, }, @@ -136,7 +141,7 @@ func TestValidator(t *testing.T) { }, }) require.NoError(t, err) - require.NoError(t, v.Validate()) + require.NoError(t, v.Validate(ctx)) }) 
}) @@ -187,7 +192,7 @@ func TestValidator(t *testing.T) { conf[k] = v } - v, err := validations.NewValidator(model.VerifyingConnections, &backendconfig.DestinationT{ + v, err := validations.NewValidator(ctx, model.VerifyingConnections, &backendconfig.DestinationT{ DestinationDefinition: backendconfig.DestinationDefinitionT{ Name: warehouseutils.POSTGRES, }, @@ -196,9 +201,9 @@ func TestValidator(t *testing.T) { require.NoError(t, err) if tc.wantError != nil { - require.EqualError(t, v.Validate(), tc.wantError.Error()) + require.EqualError(t, v.Validate(ctx), tc.wantError.Error()) } else { - require.NoError(t, v.Validate()) + require.NoError(t, v.Validate(ctx)) } }) } @@ -265,7 +270,7 @@ func TestValidator(t *testing.T) { conf[k] = v } - v, err := validations.NewValidator(model.VerifyingCreateSchema, &backendconfig.DestinationT{ + v, err := validations.NewValidator(ctx, model.VerifyingCreateSchema, &backendconfig.DestinationT{ DestinationDefinition: backendconfig.DestinationDefinitionT{ Name: warehouseutils.POSTGRES, }, @@ -274,9 +279,9 @@ func TestValidator(t *testing.T) { require.NoError(t, err) if tc.wantError != nil { - require.EqualError(t, v.Validate(), tc.wantError.Error()) + require.EqualError(t, v.Validate(ctx), tc.wantError.Error()) } else { - require.NoError(t, v.Validate()) + require.NoError(t, v.Validate(ctx)) } }) } @@ -372,7 +377,7 @@ func TestValidator(t *testing.T) { conf[k] = v } - v, err := validations.NewValidator(model.VerifyingCreateAndAlterTable, &backendconfig.DestinationT{ + v, err := validations.NewValidator(ctx, model.VerifyingCreateAndAlterTable, &backendconfig.DestinationT{ DestinationDefinition: backendconfig.DestinationDefinitionT{ Name: warehouseutils.POSTGRES, }, @@ -381,9 +386,9 @@ func TestValidator(t *testing.T) { require.NoError(t, err) if tc.wantError != nil { - require.EqualError(t, v.Validate(), tc.wantError.Error()) + require.EqualError(t, v.Validate(ctx), tc.wantError.Error()) } else { - require.NoError(t, v.Validate()) + require.NoError(t, v.Validate(ctx)) } _, err = pgResource.DB.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s.setup_test_staging", namespace)) @@ -398,7 +403,7 @@ func TestValidator(t *testing.T) { tr := setup(t, pool) pgResource, minioResource := tr.pgResource, tr.minioResource - v, err := validations.NewValidator(model.VerifyingFetchSchema, &backendconfig.DestinationT{ + v, err := validations.NewValidator(ctx, model.VerifyingFetchSchema, &backendconfig.DestinationT{ DestinationDefinition: backendconfig.DestinationDefinitionT{ Name: warehouseutils.POSTGRES, }, @@ -425,7 +430,7 @@ func TestValidator(t *testing.T) { _, err = pgResource.DB.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s.%s(id int, val varchar)", namespace, table)) require.NoError(t, err) - require.NoError(t, v.Validate()) + require.NoError(t, v.Validate(ctx)) }) t.Run("Load table", func(t *testing.T) { @@ -530,7 +535,7 @@ func TestValidator(t *testing.T) { conf[k] = v } - v, err := validations.NewValidator(model.VerifyingLoadTable, &backendconfig.DestinationT{ + v, err := validations.NewValidator(ctx, model.VerifyingLoadTable, &backendconfig.DestinationT{ DestinationDefinition: backendconfig.DestinationDefinitionT{ Name: warehouseutils.POSTGRES, }, @@ -539,9 +544,9 @@ func TestValidator(t *testing.T) { require.NoError(t, err) if tc.wantError != nil { - require.EqualError(t, v.Validate(), tc.wantError.Error()) + require.EqualError(t, v.Validate(ctx), tc.wantError.Error()) } else { - require.NoError(t, v.Validate()) + require.NoError(t, v.Validate(ctx)) } _, err = 
pgResource.DB.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s.setup_test_staging", namespace)) diff --git a/warehouse/validations/validations.go b/warehouse/validations/validations.go index 5b25a4d40f..ac295124f6 100644 --- a/warehouse/validations/validations.go +++ b/warehouse/validations/validations.go @@ -1,6 +1,7 @@ package validations import ( + "context" "encoding/json" "fmt" "time" @@ -42,7 +43,7 @@ var ( ) type validationFunc struct { - Func func(*backendconfig.DestinationT, string) (json.RawMessage, error) + Func func(context.Context, *backendconfig.DestinationT, string) (json.RawMessage, error) } func Init() { @@ -53,7 +54,7 @@ func Init() { } // Validate the destination by running all the validation steps -func Validate(req *model.ValidationRequest) (*model.ValidationResponse, error) { +func Validate(ctx context.Context, req *model.ValidationRequest) (*model.ValidationResponse, error) { res := &model.ValidationResponse{} f, ok := validationFunctions()[req.Path] @@ -61,7 +62,7 @@ func Validate(req *model.ValidationRequest) (*model.ValidationResponse, error) { return res, fmt.Errorf("invalid path: %s", req.Path) } - result, requestError := f.Func(req.Destination, req.Step) + result, requestError := f.Func(ctx, req.Destination, req.Step) res.Data = string(result) if requestError != nil { diff --git a/warehouse/validations/validations_test.go b/warehouse/validations/validations_test.go index 2830e3ea2b..fb4ed9f33f 100644 --- a/warehouse/validations/validations_test.go +++ b/warehouse/validations/validations_test.go @@ -1,6 +1,7 @@ package validations_test import ( + "context" "errors" "testing" @@ -28,13 +29,15 @@ func TestValidate(t *testing.T) { sslmode = "disable" ) + ctx := context.Background() + pool, err := dockertest.NewPool("") require.NoError(t, err) t.Run("invalid path", func(t *testing.T) { t.Parallel() - _, err := validations.Validate(&model.ValidationRequest{ + _, err := validations.Validate(ctx, &model.ValidationRequest{ Path: "invalid", }) require.Equal(t, err, errors.New("invalid path: invalid")) @@ -43,7 +46,7 @@ func TestValidate(t *testing.T) { t.Run("steps", func(t *testing.T) { t.Parallel() - res, err := validations.Validate(&model.ValidationRequest{ + res, err := validations.Validate(ctx, &model.ValidationRequest{ Path: "steps", Destination: &backendconfig.DestinationT{ DestinationDefinition: backendconfig.DestinationDefinitionT{ @@ -62,7 +65,7 @@ func TestValidate(t *testing.T) { t.Run("invalid step", func(t *testing.T) { t.Parallel() - res, err := validations.Validate(&model.ValidationRequest{ + res, err := validations.Validate(ctx, &model.ValidationRequest{ Path: "validate", Step: "invalid", Destination: &backendconfig.DestinationT{ @@ -79,7 +82,7 @@ func TestValidate(t *testing.T) { t.Run("step not found", func(t *testing.T) { t.Parallel() - res, err := validations.Validate(&model.ValidationRequest{ + res, err := validations.Validate(ctx, &model.ValidationRequest{ Path: "validate", Step: "1000", Destination: &backendconfig.DestinationT{ @@ -96,7 +99,7 @@ func TestValidate(t *testing.T) { t.Run("invalid destination", func(t *testing.T) { t.Parallel() - res, err := validations.Validate(&model.ValidationRequest{ + res, err := validations.Validate(ctx, &model.ValidationRequest{ Path: "validate", Step: "2", Destination: &backendconfig.DestinationT{ @@ -113,7 +116,7 @@ func TestValidate(t *testing.T) { t.Run("step error", func(t *testing.T) { t.Parallel() - res, err := validations.Validate(&model.ValidationRequest{ + res, err := validations.Validate(ctx, 
&model.ValidationRequest{ Path: "validate", Destination: &backendconfig.DestinationT{ DestinationDefinition: backendconfig.DestinationDefinitionT{ @@ -129,7 +132,7 @@ func TestValidate(t *testing.T) { t.Run("invalid destination", func(t *testing.T) { t.Parallel() - res, err := validations.Validate(&model.ValidationRequest{ + res, err := validations.Validate(ctx, &model.ValidationRequest{ Path: "validate", Step: "2", Destination: &backendconfig.DestinationT{ @@ -149,7 +152,7 @@ func TestValidate(t *testing.T) { tr := setup(t, pool) pgResource, minioResource := tr.pgResource, tr.minioResource - res, err := validations.Validate(&model.ValidationRequest{ + res, err := validations.Validate(ctx, &model.ValidationRequest{ Path: "validate", Step: "", Destination: &backendconfig.DestinationT{ @@ -223,7 +226,7 @@ func TestValidate(t *testing.T) { for _, tc := range testCases { tc := tc - res, err := validations.Validate(&model.ValidationRequest{ + res, err := validations.Validate(ctx, &model.ValidationRequest{ Path: "validate", Step: tc.step, Destination: &backendconfig.DestinationT{ diff --git a/warehouse/warehouse.go b/warehouse/warehouse.go index 6f793f61aa..ba9b1daed2 100644 --- a/warehouse/warehouse.go +++ b/warehouse/warehouse.go @@ -307,7 +307,7 @@ func (wh *HandleT) backendConfigSubscriber(ctx context.Context) { destination = wh.attachSSHTunnellingInfo(ctx, destination) } - namespace := wh.getNamespace(source, destination) + namespace := wh.getNamespace(ctx, source, destination) warehouse := model.Warehouse{ WorkspaceID: workspaceID, Source: source, @@ -347,10 +347,10 @@ func (wh *HandleT) backendConfigSubscriber(ctx context.Context) { connectionsMapLock.Unlock() if warehouseutils.IDResolutionEnabled() && slices.Contains(warehouseutils.IdentityEnabledWarehouses, warehouse.Type) { - wh.setupIdentityTables(warehouse) + wh.setupIdentityTables(ctx, warehouse) if shouldPopulateHistoricIdentities && warehouse.Destination.Enabled { // non-blocking populate historic identities - wh.populateHistoricIdentities(warehouse) + wh.populateHistoricIdentities(ctx, warehouse) } } } @@ -408,7 +408,7 @@ func deepCopy(src, dest interface{}) error { // 1. user set name from destinationConfig // 2. from existing record in wh_schemas with same source + dest combo // 3. 
convert source name -func (wh *HandleT) getNamespace(source backendconfig.SourceT, destination backendconfig.DestinationT) string { +func (wh *HandleT) getNamespace(ctx context.Context, source backendconfig.SourceT, destination backendconfig.DestinationT) string { configMap := destination.Config if wh.destType == warehouseutils.CLICKHOUSE { if _, ok := configMap["database"].(string); ok { @@ -428,7 +428,7 @@ func (wh *HandleT) getNamespace(source backendconfig.SourceT, destination backen return warehouseutils.ToProviderCase(wh.destType, warehouseutils.ToSafeNamespace(wh.destType, fmt.Sprintf(`%s_%s`, namespacePrefix, source.Name))) } - namespace, err := wh.whSchemaRepo.GetNamespace(context.TODO(), source.ID, destination.ID) + namespace, err := wh.whSchemaRepo.GetNamespace(ctx, source.ID, destination.ID) if err != nil { pkgLogger.Errorw("getting namespace", logfield.SourceID, source.ID, @@ -571,9 +571,9 @@ func getUploadStartAfterTime() time.Time { return time.Now() } -func (wh *HandleT) getLatestUploadStatus(warehouse *model.Warehouse) (int64, string, int) { +func (wh *HandleT) getLatestUploadStatus(ctx context.Context, warehouse *model.Warehouse) (int64, string, int) { uploadID, status, priority, err := wh.warehouseDBHandle.GetLatestUploadStatus( - context.TODO(), + ctx, warehouse.Type, warehouse.Source.ID, warehouse.Destination.ID) @@ -597,7 +597,7 @@ func (wh *HandleT) createJobs(ctx context.Context, warehouse model.Warehouse) (e } priority := defaultUploadPriority - uploadID, uploadStatus, uploadPriority := wh.getLatestUploadStatus(&warehouse) + uploadID, uploadStatus, uploadPriority := wh.getLatestUploadStatus(ctx, &warehouse) if uploadStatus == model.Waiting { // If it is present do nothing else delete it if _, inProgress := wh.isUploadJobInProgress(warehouse, uploadID); !inProgress { @@ -747,7 +747,7 @@ func (wh *HandleT) getUploadsToProcess(ctx context.Context, availableWorkers int upload.UseRudderStorage = warehouse.GetBoolDestinationConfig("useRudderStorage") if !found { - uploadJob := wh.uploadJobFactory.NewUploadJob(&model.UploadJob{ + uploadJob := wh.uploadJobFactory.NewUploadJob(ctx, &model.UploadJob{ Upload: upload, }, nil) err := fmt.Errorf("unable to find source : %s or destination : %s, both or the connection between them", upload.SourceID, upload.DestinationID) @@ -765,7 +765,7 @@ func (wh *HandleT) getUploadsToProcess(ctx context.Context, availableWorkers int if err != nil { return nil, err } - uploadJob := wh.uploadJobFactory.NewUploadJob(&model.UploadJob{ + uploadJob := wh.uploadJobFactory.NewUploadJob(ctx, &model.UploadJob{ Warehouse: warehouse, Upload: upload, StagingFiles: stagingFilesList, @@ -863,7 +863,7 @@ func (wh *HandleT) Disable() { wh.isEnabled = false } -func (wh *HandleT) Setup(whType string) error { +func (wh *HandleT) Setup(ctx context.Context, whType string) error { pkgLogger.Infof("WH: Warehouse Router started: %s", whType) wh.Logger = pkgLogger wh.conf = config.Default @@ -877,7 +877,7 @@ func (wh *HandleT) Setup(whType string) error { wh.notifier = notifier wh.destType = whType - wh.resetInProgressJobs() + wh.resetInProgressJobs(ctx) wh.Enable() wh.workerChannelMap = make(map[string]chan *UploadJob) wh.inProgressMap = make(map[WorkerIdentifierT][]JobID) @@ -915,7 +915,7 @@ func (wh *HandleT) Setup(whType string) error { }, ) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(ctx) g, ctx := errgroup.WithContext(ctx) wh.backgroundCancel = cancel @@ -952,7 +952,7 @@ func (wh *HandleT) Shutdown() error { 
return wh.backgroundWait() } -func (wh *HandleT) resetInProgressJobs() { +func (wh *HandleT) resetInProgressJobs(ctx context.Context) { sqlStatement := fmt.Sprintf(` UPDATE %s @@ -967,7 +967,7 @@ func (wh *HandleT) resetInProgressJobs() { wh.destType, true, ) - rows, err := wh.dbHandle.Query(sqlStatement) + rows, err := wh.dbHandle.QueryContext(ctx, sqlStatement) if err != nil { panic(fmt.Errorf("query: %s failed with Error : %w", sqlStatement, err)) } @@ -1000,7 +1000,7 @@ func minimalConfigSubscriber(ctx context.Context) { whSchemaRepo: repo.NewWHSchemas(dbHandle), conf: config.Default, } - namespace := wh.getNamespace(source, destination) + namespace := wh.getNamespace(ctx, source, destination) connectionsMapLock.Lock() if _, ok := slaveConnectionsMap[destination.ID]; !ok { @@ -1048,7 +1048,7 @@ func monitorDestRouters(ctx context.Context) error { ch := tenantManager.WatchConfig(ctx) for configData := range ch { - err := onConfigDataEvent(configData, dstToWhRouter) + err := onConfigDataEvent(ctx, configData, dstToWhRouter) if err != nil { return err } @@ -1062,7 +1062,7 @@ func monitorDestRouters(ctx context.Context) error { return g.Wait() } -func onConfigDataEvent(config map[string]backendconfig.ConfigT, dstToWhRouter map[string]*HandleT) error { +func onConfigDataEvent(ctx context.Context, config map[string]backendconfig.ConfigT, dstToWhRouter map[string]*HandleT) error { pkgLogger.Debug("Got config from config-backend", config) enabledDestinations := make(map[string]bool) @@ -1078,7 +1078,7 @@ func onConfigDataEvent(config map[string]backendconfig.ConfigT, dstToWhRouter ma pkgLogger.Info("Starting a new Warehouse Destination Router: ", destination.DestinationDefinition.Name) wh = &HandleT{} wh.configSubscriberLock.Lock() - if err := wh.Setup(destination.DestinationDefinition.Name); err != nil { + if err := wh.Setup(ctx, destination.DestinationDefinition.Name); err != nil { return fmt.Errorf("setup warehouse %q: %w", destination.DestinationDefinition.Name, err) } wh.configSubscriberLock.Unlock() @@ -1431,9 +1431,9 @@ func TriggerUploadHandler(sourceID, destID string) error { return nil } -func databricksVersionHandler(w http.ResponseWriter, _ *http.Request) { +func databricksVersionHandler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(deltalake.GetDatabricksVersion())) + _, _ = w.Write([]byte(deltalake.GetDatabricksVersion(r.Context()))) } func fetchTablesHandler(w http.ResponseWriter, r *http.Request) { diff --git a/warehouse/warehouse_test.go b/warehouse/warehouse_test.go index b17031fd48..f9be44ad7a 100644 --- a/warehouse/warehouse_test.go +++ b/warehouse/warehouse_test.go @@ -355,7 +355,7 @@ func Test_GetNamespace(t *testing.T) { _, err = pgResource.DB.Exec(string(sqlStatement)) require.NoError(t, err) - namespace := wh.getNamespace(tc.source, tc.destination) + namespace := wh.getNamespace(context.Background(), tc.source, tc.destination) require.Equal(t, tc.result, namespace) }) } diff --git a/warehouse/warehousegrpc.go b/warehouse/warehousegrpc.go index 4564a81535..1cb27bfffd 100644 --- a/warehouse/warehousegrpc.go +++ b/warehouse/warehousegrpc.go @@ -25,7 +25,7 @@ type warehouseGRPC struct { EnableTunnelling bool } -func (*warehouseGRPC) GetWHUploads(_ context.Context, request *proto.WHUploadsRequest) (*proto.WHUploadsResponse, error) { +func (*warehouseGRPC) GetWHUploads(ctx context.Context, request *proto.WHUploadsRequest) (*proto.WHUploadsResponse, error) { uploadsReq := UploadsReq{ WorkspaceID: request.WorkspaceId, SourceID: 
request.SourceId, @@ -42,11 +42,11 @@ func (*warehouseGRPC) GetWHUploads(_ context.Context, request *proto.WHUploadsRe uploadsReq.SourceID, uploadsReq.DestinationID, ) - res, err := uploadsReq.GetWhUploads() + res, err := uploadsReq.GetWhUploads(ctx) return res, err } -func (*warehouseGRPC) TriggerWHUploads(_ context.Context, request *proto.WHUploadsRequest) (*proto.TriggerWhUploadsResponse, error) { +func (*warehouseGRPC) TriggerWHUploads(ctx context.Context, request *proto.WHUploadsRequest) (*proto.TriggerWhUploadsResponse, error) { uploadsReq := UploadsReq{ WorkspaceID: request.WorkspaceId, SourceID: request.SourceId, @@ -59,11 +59,11 @@ func (*warehouseGRPC) TriggerWHUploads(_ context.Context, request *proto.WHUploa uploadsReq.SourceID, uploadsReq.DestinationID, ) - res, err := uploadsReq.TriggerWhUploads() + res, err := uploadsReq.TriggerWhUploads(ctx) return res, err } -func (*warehouseGRPC) GetWHUpload(_ context.Context, request *proto.WHUploadRequest) (*proto.WHUploadResponse, error) { +func (*warehouseGRPC) GetWHUpload(ctx context.Context, request *proto.WHUploadRequest) (*proto.WHUploadResponse, error) { uploadReq := UploadReq{ UploadId: request.UploadId, WorkspaceID: request.WorkspaceId, @@ -74,7 +74,7 @@ func (*warehouseGRPC) GetWHUpload(_ context.Context, request *proto.WHUploadRequ uploadReq.WorkspaceID, uploadReq.UploadId, ) - res, err := uploadReq.GetWHUpload() + res, err := uploadReq.GetWHUpload(ctx) return res, err } @@ -82,7 +82,7 @@ func (*warehouseGRPC) GetHealth(context.Context, *emptypb.Empty) (*wrapperspb.Bo return wrapperspb.Bool(UploadAPI.enabled), nil } -func (*warehouseGRPC) TriggerWHUpload(_ context.Context, request *proto.WHUploadRequest) (*proto.TriggerWhUploadsResponse, error) { +func (*warehouseGRPC) TriggerWHUpload(ctx context.Context, request *proto.WHUploadRequest) (*proto.TriggerWhUploadsResponse, error) { uploadReq := UploadReq{ UploadId: request.UploadId, WorkspaceID: request.WorkspaceId, @@ -93,7 +93,7 @@ func (*warehouseGRPC) TriggerWHUpload(_ context.Context, request *proto.WHUpload uploadReq.WorkspaceID, uploadReq.UploadId, ) - res, err := uploadReq.TriggerWHUpload() + res, err := uploadReq.TriggerWHUpload(ctx) return res, err } @@ -143,7 +143,7 @@ func (grpc *warehouseGRPC) Validate(ctx context.Context, req *proto.WHValidation } } - res, err := validations.Validate(&model.ValidationRequest{ + res, err := validations.Validate(ctx, &model.ValidationRequest{ Path: req.Path, Step: req.Step, Destination: &destination, diff --git a/warehouse/warehousegrpc_test.go b/warehouse/warehousegrpc_test.go index 8dca43669f..2175a38737 100644 --- a/warehouse/warehousegrpc_test.go +++ b/warehouse/warehousegrpc_test.go @@ -46,6 +46,8 @@ var _ = Describe("WarehouseGrpc", func() { ) BeforeAll(func() { + c = context.Background() + pool, err := dockertest.NewPool("") Expect(err).To(BeNil()) @@ -58,7 +60,7 @@ var _ = Describe("WarehouseGrpc", func() { initWarehouse() - err = setupDB(context.TODO(), getConnectionString()) + err = setupDB(c, getConnectionString()) Expect(err).To(BeNil()) sqlStatement, err := os.ReadFile("testdata/sql/grpc_test.sql") @@ -83,7 +85,6 @@ var _ = Describe("WarehouseGrpc", func() { } w = &warehouseGRPC{} - c = context.TODO() }) AfterAll(func() { @@ -549,6 +550,8 @@ var _ = Describe("WarehouseGrpc", func() { ) BeforeAll(func() { + c = context.Background() + pool, err := dockertest.NewPool("") Expect(err).To(BeNil()) @@ -559,7 +562,7 @@ var _ = Describe("WarehouseGrpc", func() { initWarehouse() - err = setupDB(context.TODO(), 
getConnectionString()) + err = setupDB(c, getConnectionString()) Expect(err).To(BeNil()) sqlStatement, err := os.ReadFile("testdata/sql/grpc_test.sql") @@ -581,7 +584,6 @@ var _ = Describe("WarehouseGrpc", func() { } w = &warehouseGRPC{} - c = context.TODO() }) AfterAll(func() {