Skip to content

Commit

Permalink
chore: pass context (#3326)
Browse files Browse the repository at this point in the history
  • Loading branch information
achettyiitr committed May 25, 2023
1 parent b245915 commit 990a405
Show file tree
Hide file tree
Showing 56 changed files with 1,244 additions and 1,143 deletions.
5 changes: 3 additions & 2 deletions warehouse/admin.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package warehouse

import (
"context"
"errors"
"fmt"
"strings"
Expand Down Expand Up @@ -78,7 +79,7 @@ func (*WarehouseAdmin) Query(s QueryInput, reply *warehouseutils.QueryResult) er
if err != nil {
return err
}
client, err := whManager.Connect(warehouse)
client, err := whManager.Connect(context.TODO(), warehouse)
if err != nil {
return err
}
Expand Down Expand Up @@ -109,7 +110,7 @@ func (*WarehouseAdmin) ConfigurationTest(s ConfigurationTestInput, reply *Config
pkgLogger.Infof(`[WH Admin]: Validating warehouse destination: %s:%s`, warehouse.Type, warehouse.Destination.ID)

destinationValidator := validations.NewDestinationValidator()
res := destinationValidator.Validate(&warehouse.Destination)
res := destinationValidator.Validate(context.TODO(), &warehouse.Destination)

reply.Valid = res.Success
reply.Error = res.Error
Expand Down
65 changes: 37 additions & 28 deletions warehouse/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ var statusMap = map[string]string{
"failed": "%failed%",
}

func (uploadsReq *UploadsReq) GetWhUploads() (uploadsRes *proto.WHUploadsResponse, err error) {
func (uploadsReq *UploadsReq) GetWhUploads(ctx context.Context) (uploadsRes *proto.WHUploadsResponse, err error) {
uploadsRes = &proto.WHUploadsResponse{
Uploads: make([]*proto.WHUploadResponse, 0),
}
Expand All @@ -206,15 +206,15 @@ func (uploadsReq *UploadsReq) GetWhUploads() (uploadsRes *proto.WHUploadsRespons
}

if UploadAPI.isMultiWorkspace {
uploadsRes, err = uploadsReq.warehouseUploadsForHosted(authorizedSourceIDs, `id, source_id, destination_id, destination_type, namespace, status, error, first_event_at, last_event_at, last_exec_at, updated_at, timings, metadata->>'nextRetryTime', metadata->>'archivedStagingAndLoadFiles'`)
uploadsRes, err = uploadsReq.warehouseUploadsForHosted(ctx, authorizedSourceIDs, `id, source_id, destination_id, destination_type, namespace, status, error, first_event_at, last_event_at, last_exec_at, updated_at, timings, metadata->>'nextRetryTime', metadata->>'archivedStagingAndLoadFiles'`)
return
}

uploadsRes, err = uploadsReq.warehouseUploads(`id, source_id, destination_id, destination_type, namespace, status, error, first_event_at, last_event_at, last_exec_at, updated_at, timings, metadata->>'nextRetryTime', metadata->>'archivedStagingAndLoadFiles'`)
uploadsRes, err = uploadsReq.warehouseUploads(ctx, `id, source_id, destination_id, destination_type, namespace, status, error, first_event_at, last_event_at, last_exec_at, updated_at, timings, metadata->>'nextRetryTime', metadata->>'archivedStagingAndLoadFiles'`)
return
}

func (uploadsReq *UploadsReq) TriggerWhUploads() (response *proto.TriggerWhUploadsResponse, err error) {
func (uploadsReq *UploadsReq) TriggerWhUploads(ctx context.Context) (response *proto.TriggerWhUploadsResponse, err error) {
err = uploadsReq.validateReq()
defer func() {
if err != nil {
Expand Down Expand Up @@ -245,7 +245,7 @@ func (uploadsReq *UploadsReq) TriggerWhUploads() (response *proto.TriggerWhUploa
return
}
if pendingUploadCount == int64(0) {
pendingStagingFileCount, err = repo.NewStagingFiles(dbHandle).CountPendingForDestination(context.TODO(), uploadsReq.DestinationID)
pendingStagingFileCount, err = repo.NewStagingFiles(dbHandle).CountPendingForDestination(ctx, uploadsReq.DestinationID)
if err != nil {
return
}
Expand All @@ -269,7 +269,7 @@ func (uploadsReq *UploadsReq) TriggerWhUploads() (response *proto.TriggerWhUploa
return
}

func (uploadReq *UploadReq) GetWHUpload() (*proto.WHUploadResponse, error) {
func (uploadReq *UploadReq) GetWHUpload(ctx context.Context) (*proto.WHUploadResponse, error) {
err := uploadReq.validateReq()
if err != nil {
return &proto.WHUploadResponse{}, status.Errorf(codes.Code(code.Code_INVALID_ARGUMENT), err.Error())
Expand All @@ -287,7 +287,7 @@ func (uploadReq *UploadReq) GetWHUpload() (*proto.WHUploadResponse, error) {
isUploadArchived sql.NullBool
)

row := uploadReq.API.dbHandle.QueryRow(query)
row := uploadReq.API.dbHandle.QueryRowContext(ctx, query)
err = row.Scan(
&upload.Id,
&upload.SourceId,
Expand Down Expand Up @@ -358,15 +358,15 @@ func (uploadReq *UploadReq) GetWHUpload() (*proto.WHUploadResponse, error) {
Name: "",
API: uploadReq.API,
}
upload.Tables, err = tableUploadReq.GetWhTableUploads()
upload.Tables, err = tableUploadReq.GetWhTableUploads(ctx)
if err != nil {
return &proto.WHUploadResponse{}, status.Errorf(codes.Code(code.Code_INTERNAL), err.Error())
}

return &upload, nil
}

func (uploadReq *UploadReq) TriggerWHUpload() (response *proto.TriggerWhUploadsResponse, err error) {
func (uploadReq *UploadReq) TriggerWHUpload(ctx context.Context) (response *proto.TriggerWhUploadsResponse, err error) {
err = uploadReq.validateReq()
defer func() {
if err != nil {
Expand All @@ -380,7 +380,7 @@ func (uploadReq *UploadReq) TriggerWHUpload() (response *proto.TriggerWhUploadsR
return
}

upload, err := repo.NewUploads(uploadReq.API.dbHandle).Get(context.TODO(), uploadReq.UploadId)
upload, err := repo.NewUploads(uploadReq.API.dbHandle).Get(ctx, uploadReq.UploadId)
if err == model.ErrUploadNotFound {
return &proto.TriggerWhUploadsResponse{
Message: NoSuchSync,
Expand All @@ -401,7 +401,8 @@ func (uploadReq *UploadReq) TriggerWHUpload() (response *proto.TriggerWhUploadsR
uploadJobT := UploadJob{
upload: upload,
dbHandle: uploadReq.API.dbHandle,
Now: timeutil.Now,
now: timeutil.Now,
ctx: ctx,
}

err = uploadJobT.triggerUploadNow()
Expand All @@ -415,14 +416,14 @@ func (uploadReq *UploadReq) TriggerWHUpload() (response *proto.TriggerWhUploadsR
return
}

func (tableUploadReq TableUploadReq) GetWhTableUploads() ([]*proto.WHTable, error) {
func (tableUploadReq TableUploadReq) GetWhTableUploads(ctx context.Context) ([]*proto.WHTable, error) {
err := tableUploadReq.validateReq()
if err != nil {
return []*proto.WHTable{}, err
}
query := tableUploadReq.generateQuery(`id, wh_upload_id, table_name, total_events, status, error, last_exec_time, updated_at`)
tableUploadReq.API.log.Debug(query)
rows, err := tableUploadReq.API.dbHandle.Query(query)
rows, err := tableUploadReq.API.dbHandle.QueryContext(ctx, query)
if err != nil {
tableUploadReq.API.log.Errorf(err.Error())
return []*proto.WHTable{}, err
Expand Down Expand Up @@ -544,12 +545,12 @@ func (uploadsReq *UploadsReq) authorizedSources() (sourceIDs []string) {
return sourceIDs
}

func (uploadsReq *UploadsReq) getUploadsFromDB(isMultiWorkspace bool, query string) ([]*proto.WHUploadResponse, int32, error) {
func (uploadsReq *UploadsReq) getUploadsFromDB(ctx context.Context, isMultiWorkspace bool, query string) ([]*proto.WHUploadResponse, int32, error) {
var totalUploadCount int32
var err error
uploads := make([]*proto.WHUploadResponse, 0)

rows, err := uploadsReq.API.dbHandle.Query(query)
rows, err := uploadsReq.API.dbHandle.QueryContext(ctx, query)
if err != nil {
uploadsReq.API.log.Errorf(err.Error())
return nil, 0, err
Expand Down Expand Up @@ -651,7 +652,7 @@ func (uploadsReq *UploadsReq) getUploadsFromDB(isMultiWorkspace bool, query stri
return uploads, totalUploadCount, err
}

func (uploadsReq *UploadsReq) getTotalUploadCount(whereClause string) (int32, error) {
func (uploadsReq *UploadsReq) getTotalUploadCount(ctx context.Context, whereClause string) (int32, error) {
var totalUploadCount int32
query := fmt.Sprintf(`
select
Expand All @@ -665,12 +666,12 @@ func (uploadsReq *UploadsReq) getTotalUploadCount(whereClause string) (int32, er
query += fmt.Sprintf(` %s`, whereClause)
}
uploadsReq.API.log.Info(query)
err := uploadsReq.API.dbHandle.QueryRow(query).Scan(&totalUploadCount)
err := uploadsReq.API.dbHandle.QueryRowContext(ctx, query).Scan(&totalUploadCount)
return totalUploadCount, err
}

// for hosted workspaces - we get the uploads and the total upload count using the same query
func (uploadsReq *UploadsReq) warehouseUploadsForHosted(authorizedSourceIDs []string, selectFields string) (uploadsRes *proto.WHUploadsResponse, err error) {
func (uploadsReq *UploadsReq) warehouseUploadsForHosted(ctx context.Context, authorizedSourceIDs []string, selectFields string) (uploadsRes *proto.WHUploadsResponse, err error) {
var (
uploads []*proto.WHUploadResponse
totalUploadCount int32
Expand Down Expand Up @@ -724,7 +725,7 @@ func (uploadsReq *UploadsReq) warehouseUploadsForHosted(authorizedSourceIDs []st
uploadsReq.API.log.Info(query)

// get uploads from db
uploads, totalUploadCount, err = uploadsReq.getUploadsFromDB(true, query)
uploads, totalUploadCount, err = uploadsReq.getUploadsFromDB(ctx, true, query)
if err != nil {
uploadsReq.API.log.Errorf(err.Error())
return &proto.WHUploadsResponse{}, err
Expand All @@ -743,7 +744,7 @@ func (uploadsReq *UploadsReq) warehouseUploadsForHosted(authorizedSourceIDs []st
}

// for non hosted workspaces - we get the uploads and the total upload count using separate queries
func (uploadsReq *UploadsReq) warehouseUploads(selectFields string) (uploadsRes *proto.WHUploadsResponse, err error) {
func (uploadsReq *UploadsReq) warehouseUploads(ctx context.Context, selectFields string) (uploadsRes *proto.WHUploadsResponse, err error) {
var (
uploads []*proto.WHUploadResponse
totalUploadCount int32
Expand Down Expand Up @@ -784,13 +785,13 @@ func (uploadsReq *UploadsReq) warehouseUploads(selectFields string) (uploadsRes
// we get uploads for non hosted workspaces in two steps
// this is because getting this info via 2 queries is faster than getting it via one query(using the 'count(*) OVER()' clause)
// step1 - get all uploads
uploads, _, err = uploadsReq.getUploadsFromDB(false, query)
uploads, _, err = uploadsReq.getUploadsFromDB(ctx, false, query)
if err != nil {
uploadsReq.API.log.Errorf(err.Error())
return &proto.WHUploadsResponse{}, err
}
// step2 - get total upload count
totalUploadCount, err = uploadsReq.getTotalUploadCount(whereClause)
totalUploadCount, err = uploadsReq.getTotalUploadCount(ctx, whereClause)
if err != nil {
uploadsReq.API.log.Errorf(err.Error())
return &proto.WHUploadsResponse{}, err
Expand Down Expand Up @@ -825,8 +826,13 @@ func checkMapForValidKey(configMap map[string]interface{}, key string) bool {
func validateObjectStorage(ctx context.Context, request *ObjectStorageValidationRequest) error {
pkgLogger.Infof("Received call to validate object storage for type: %s\n", request.Type)

settings, err := getFileManagerSettings(ctx, request.Type, request.Config)
if err != nil {
return fmt.Errorf("unable to create file manager settings: \n%s", err.Error())
}

factory := &filemanager.FileManagerFactoryT{}
fileManager, err := factory.New(getFileManagerSettings(request.Type, request.Config))
fileManager, err := factory.New(settings)
if err != nil {
return fmt.Errorf("unable to create file manager: \n%s", err.Error())
}
Expand Down Expand Up @@ -874,20 +880,22 @@ func validateObjectStorage(ctx context.Context, request *ObjectStorageValidation
return nil
}

func getFileManagerSettings(provider string, inputConfig map[string]interface{}) *filemanager.SettingsT {
func getFileManagerSettings(ctx context.Context, provider string, inputConfig map[string]interface{}) (*filemanager.SettingsT, error) {
settings := &filemanager.SettingsT{
Provider: provider,
Config: inputConfig,
}

overrideWithEnv(settings)
return settings
if err := overrideWithEnv(ctx, settings); err != nil {
return nil, fmt.Errorf("overriding config with env: %w", err)
}
return settings, nil
}

// overrideWithEnv overrides the config keys in the fileManager settings
// with fallback values pulled from env. Only supported for S3 for now.
func overrideWithEnv(settings *filemanager.SettingsT) {
envConfig := filemanager.GetProviderConfigFromEnv(context.TODO(), settings.Provider)
func overrideWithEnv(ctx context.Context, settings *filemanager.SettingsT) error {
envConfig := filemanager.GetProviderConfigFromEnv(ctx, settings.Provider)

if settings.Provider == "S3" {
ifNotExistThenSet("prefix", envConfig["prefix"], settings.Config)
Expand All @@ -898,6 +906,7 @@ func overrideWithEnv(settings *filemanager.SettingsT) {
ifNotExistThenSet("externalID", envConfig["externalID"], settings.Config)
ifNotExistThenSet("regionHint", envConfig["regionHint"], settings.Config)
}
return ctx.Err()
}

func ifNotExistThenSet(keyToReplace string, replaceWith interface{}, configMap map[string]interface{}) {
Expand Down
8 changes: 6 additions & 2 deletions warehouse/api_test.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
package warehouse

import (
"context"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/rudderlabs/rudder-server/services/filemanager"
)

var _ = Describe("warehouse_api", func() {
Context("Testing objectStorageValidation ", func() {
ctx := context.Background()

It("Should fallback to backup credentials when fields missing(as of now backup only supported for s3)", func() {
fm := &filemanager.SettingsT{
Provider: "AZURE_BLOB",
Config: map[string]interface{}{"containerName": "containerName1", "prefix": "prefix1", "accountKey": "accountKey1"},
}
overrideWithEnv(fm)
overrideWithEnv(ctx, fm)
Expect(fm.Config["accountName"]).To(BeNil())
fm.Provider = "S3"
fm.Config = map[string]interface{}{"bucketName": "bucket1", "prefix": "prefix1", "accessKeyID": "KeyID1"}
overrideWithEnv(fm)
overrideWithEnv(ctx, fm)
Expect(fm.Config["accessKey"]).ToNot(BeNil())
})
It("Should set value for key when key not present", func() {
Expand Down
24 changes: 12 additions & 12 deletions warehouse/archive/archiver.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ type Archiver struct {
Multitenant *multitenant.Manager
}

func (a *Archiver) backupRecords(args backupRecordsArgs) (backupLocation string, err error) {
func (a *Archiver) backupRecords(ctx context.Context, args backupRecordsArgs) (backupLocation string, err error) {
a.Logger.Infof(`Starting backupRecords for uploadId: %s, sourceId: %s, destinationId: %s, tableName: %s,`, args.uploadID, args.sourceID, args.destID, args.tableName)
tmpDirPath, err := misc.CreateTMPDIR()
if err != nil {
Expand All @@ -91,7 +91,7 @@ func (a *Archiver) backupRecords(args backupRecordsArgs) (backupLocation string,

fManager, err := a.FileManager.New(&filemanager.SettingsT{
Provider: config.GetString("JOBS_BACKUP_STORAGE_PROVIDER", "S3"),
Config: filemanager.GetProviderConfigForBackupsFromEnv(context.TODO()),
Config: filemanager.GetProviderConfigForBackupsFromEnv(ctx),
})
if err != nil {
err = fmt.Errorf("error in creating a file manager for:%s. Error: %w", config.GetString("JOBS_BACKUP_STORAGE_PROVIDER", "S3"), err)
Expand Down Expand Up @@ -133,7 +133,7 @@ func (a *Archiver) backupRecords(args backupRecordsArgs) (backupLocation string,
return
}

func (a *Archiver) deleteFilesInStorage(locations []string) error {
func (a *Archiver) deleteFilesInStorage(ctx context.Context, locations []string) error {
fManager, err := a.FileManager.New(&filemanager.SettingsT{
Provider: warehouseutils.S3,
Config: misc.GetRudderObjectStorageConfig(""),
Expand All @@ -143,7 +143,7 @@ func (a *Archiver) deleteFilesInStorage(locations []string) error {
return err
}

err = fManager.DeleteObjects(context.TODO(), locations)
err = fManager.DeleteObjects(ctx, locations)
if err != nil {
a.Logger.Errorf("Error in deleting objects in Rudder S3: %v", err)
}
Expand Down Expand Up @@ -239,7 +239,7 @@ func (a *Archiver) Do(ctx context.Context) error {
var archivedUploads int
for _, u := range uploadsToArchive {

txn, err := a.DB.Begin()
txn, err := a.DB.BeginTx(ctx, &sql.TxOptions{})
if err != nil {
a.Logger.Errorf(`Error creating txn in archiveUploadFiles. Error: %v`, err)
continue
Expand Down Expand Up @@ -267,7 +267,7 @@ func (a *Archiver) Do(ctx context.Context) error {
u.endStagingFileId,
)

stagingFileRows, err := txn.Query(stmt)
stagingFileRows, err := txn.QueryContext(ctx, stmt)
if err != nil {
a.Logger.Errorf(`Error running txn in archiveUploadFiles. Query: %s Error: %v`, stmt, err)
txn.Rollback()
Expand Down Expand Up @@ -297,7 +297,7 @@ func (a *Archiver) Do(ctx context.Context) error {
if len(stagingFileIDs) > 0 {
if !hasUsedRudderStorage {
filterSQL := fmt.Sprintf(`id IN (%v)`, misc.IntArrayToString(stagingFileIDs, ","))
storedStagingFilesLocation, err = a.backupRecords(backupRecordsArgs{
storedStagingFilesLocation, err = a.backupRecords(ctx, backupRecordsArgs{
tableName: warehouseutils.WarehouseStagingFilesTable,
sourceID: u.sourceID,
destID: u.destID,
Expand All @@ -315,7 +315,7 @@ func (a *Archiver) Do(ctx context.Context) error {
}

if hasUsedRudderStorage {
err = a.deleteFilesInStorage(stagingFileLocations)
err = a.deleteFilesInStorage(ctx, stagingFileLocations)
if err != nil {
a.Logger.Errorf(`Error deleting staging files from Rudder S3. Error: %v`, stmt, err)
txn.Rollback()
Expand All @@ -333,7 +333,7 @@ func (a *Archiver) Do(ctx context.Context) error {
warehouseutils.WarehouseStagingFilesTable,
misc.IntArrayToString(stagingFileIDs, ","),
)
_, err = txn.Query(stmt)
_, err = txn.QueryContext(ctx, stmt)
if err != nil {
a.Logger.Errorf(`Error running txn in archiveUploadFiles. Query: %s Error: %v`, stmt, err)
txn.Rollback()
Expand All @@ -349,7 +349,7 @@ func (a *Archiver) Do(ctx context.Context) error {
`,
warehouseutils.WarehouseLoadFilesTable,
)
loadLocationRows, err := txn.Query(stmt, pq.Array(stagingFileIDs))
loadLocationRows, err := txn.QueryContext(ctx, stmt, pq.Array(stagingFileIDs))
if err != nil {
a.Logger.Errorf(`Error running txn in archiveUploadFiles. Query: %s Error: %v`, stmt, err)
txn.Rollback()
Expand Down Expand Up @@ -380,7 +380,7 @@ func (a *Archiver) Do(ctx context.Context) error {
}
paths = append(paths, u.Path[1:])
}
err = a.deleteFilesInStorage(paths)
err = a.deleteFilesInStorage(ctx, paths)
if err != nil {
a.Logger.Errorf(`Error deleting load files from Rudder S3. Error: %v`, stmt, err)
txn.Rollback()
Expand All @@ -403,7 +403,7 @@ func (a *Archiver) Do(ctx context.Context) error {
warehouseutils.WarehouseUploadsTable,
u.uploadID,
)
_, err = txn.Exec(stmt, u.uploadMetdata)
_, err = txn.ExecContext(ctx, stmt, u.uploadMetdata)
if err != nil {
a.Logger.Errorf(`Error running txn in archiveUploadFiles. Query: %s Error: %v`, stmt, err)
txn.Rollback()
Expand Down

0 comments on commit 990a405

Please sign in to comment.