diff --git a/warehouse/integrations/deltalake/deltalake.go b/warehouse/integrations/deltalake/deltalake.go index 27f384e962..a0e4cadd8a 100644 --- a/warehouse/integrations/deltalake/deltalake.go +++ b/warehouse/integrations/deltalake/deltalake.go @@ -215,6 +215,7 @@ func (d *Deltalake) connect() (*sqlmiddleware.DB, error) { db := sql.OpenDB(connector) middleware := sqlmiddleware.New( db, + sqlmiddleware.WithStats(d.stats), sqlmiddleware.WithLogger(d.logger), sqlmiddleware.WithKeyAndValues( logfield.SourceID, d.Warehouse.Source.ID, diff --git a/warehouse/integrations/middleware/sqlquerywrapper/sql.go b/warehouse/integrations/middleware/sqlquerywrapper/sql.go index 8c78765674..538de0563a 100644 --- a/warehouse/integrations/middleware/sqlquerywrapper/sql.go +++ b/warehouse/integrations/middleware/sqlquerywrapper/sql.go @@ -5,12 +5,14 @@ import ( "database/sql" "errors" "fmt" + "sync" "time" rslogger "github.com/rudderlabs/rudder-go-kit/logger" - + "github.com/rudderlabs/rudder-go-kit/stats" "github.com/rudderlabs/rudder-server/utils/misc" "github.com/rudderlabs/rudder-server/warehouse/logfield" + warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" ) type Opt func(*DB) @@ -22,6 +24,7 @@ type logger interface { type DB struct { *sql.DB + stats stats.Stats since func(time.Time) time.Duration logger logger @@ -61,16 +64,20 @@ func (r *Rows) Err() error { type Row struct { *sql.Row context.CancelFunc + once sync.Once logQ } func (r *Row) Scan(dest ...interface{}) error { defer r.CancelFunc() - r.logQ() + r.once.Do(r.logQ) return r.Row.Scan(dest...) } +// Err provides a way for wrapping packages to check for +// query errors without calling Scan. 
func (r *Row) Err() error { + r.once.Do(r.logQ) return r.Row.Err() } @@ -86,6 +93,12 @@ func WithLogger(logger logger) Opt { } } +func WithStats(stats stats.Stats) Opt { + return func(s *DB) { + s.stats = stats + } +} + func WithKeyAndValues(keyAndValues ...any) Opt { return func(s *DB) { s.keysAndValues = keyAndValues @@ -209,6 +222,38 @@ func (db *DB) WithTx(ctx context.Context, fn func(*Tx) error) error { func (db *DB) logQuery(query string, since time.Time) logQ { return func() { + var ( + sanitizedQuery string + keysAndValues []any + ) + createLogData := func() { + sanitizedQuery, _ = misc.ReplaceMultiRegex(query, db.secretsRegex) + + keysAndValues = []any{ + logfield.Query, sanitizedQuery, + logfield.QueryExecutionTime, db.since(since), + } + keysAndValues = append(keysAndValues, db.keysAndValues...) + } + + if db.stats != nil { + var expected bool + tags := make(stats.Tags, len(db.keysAndValues)/2+1) + tags["query_type"], expected = warehouseutils.GetQueryType(query) + if !expected { + createLogData() + db.logger.Warnw("sql stats: unexpected query type", keysAndValues...) + } + for i := 0; i < len(db.keysAndValues); i += 2 { + key, ok := db.keysAndValues[i].(string) + if !ok { + continue + } + tags[key] = fmt.Sprint(db.keysAndValues[i+1]) + } + db.stats.NewTaggedStat("wh_query_count", stats.CountType, tags).Increment() + } + if db.slowQueryThreshold <= 0 { return } @@ -216,14 +261,9 @@ func (db *DB) logQuery(query string, since time.Time) logQ { return } - sanitizedQuery, _ := misc.ReplaceMultiRegex(query, db.secretsRegex) - - keysAndValues := []any{ - logfield.Query, sanitizedQuery, - logfield.QueryExecutionTime, db.since(since), + if sanitizedQuery == "" { + createLogData() } - keysAndValues = append(keysAndValues, db.keysAndValues...) - db.logger.Infow("executing query", keysAndValues...) 
} } diff --git a/warehouse/integrations/middleware/sqlquerywrapper/sql_test.go b/warehouse/integrations/middleware/sqlquerywrapper/sql_test.go index eceb46efee..81c7e2568b 100644 --- a/warehouse/integrations/middleware/sqlquerywrapper/sql_test.go +++ b/warehouse/integrations/middleware/sqlquerywrapper/sql_test.go @@ -8,6 +8,8 @@ import ( "time" rslogger "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-go-kit/stats" + "github.com/rudderlabs/rudder-go-kit/stats/memstats" "github.com/google/uuid" @@ -449,3 +451,30 @@ func TestQueryWrapper(t *testing.T) { require.NoError(t, err) }) } + +func TestWithStats(t *testing.T) { + t.Parallel() + + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + pgResource, err := resource.SetupPostgres(pool, t) + require.NoError(t, err) + + s := memstats.New() + + qw := New( + pgResource.DB, + WithKeyAndValues("k1", "v1", "k2", "v2"), + WithStats(s), + ) + row := qw.QueryRowContext(context.Background(), "SELECT 1") + require.NoError(t, row.Err()) + + measurement := s.Get("wh_query_count", stats.Tags{ + "k1": "v1", + "k2": "v2", + "query_type": "SELECT", + }) + require.NotNilf(t, measurement, "measurement should not be nil") +} diff --git a/warehouse/integrations/postgres/postgres.go b/warehouse/integrations/postgres/postgres.go index ebb7d46596..0eb63838a8 100644 --- a/warehouse/integrations/postgres/postgres.go +++ b/warehouse/integrations/postgres/postgres.go @@ -176,6 +176,7 @@ func New(conf *config.Config, log logger.Logger, stat stats.Stats) *Postgres { func (pg *Postgres) getNewMiddleWare(db *sql.DB) *sqlmiddleware.DB { middleware := sqlmiddleware.New( db, + sqlmiddleware.WithStats(pg.stats), sqlmiddleware.WithLogger(pg.logger), sqlmiddleware.WithKeyAndValues( logfield.SourceID, pg.Warehouse.Source.ID, diff --git a/warehouse/integrations/redshift/redshift.go b/warehouse/integrations/redshift/redshift.go index c54338c898..de75757f89 100644 --- a/warehouse/integrations/redshift/redshift.go +++ 
b/warehouse/integrations/redshift/redshift.go @@ -1060,6 +1060,7 @@ func (rs *Redshift) connect(ctx context.Context) (*sqlmiddleware.DB, error) { } middleware := sqlmiddleware.New( db, + sqlmiddleware.WithStats(rs.stats), sqlmiddleware.WithLogger(rs.logger), sqlmiddleware.WithKeyAndValues( logfield.SourceID, rs.Warehouse.Source.ID, diff --git a/warehouse/integrations/snowflake/snowflake.go b/warehouse/integrations/snowflake/snowflake.go index 58b0627d1f..fc2103e414 100644 --- a/warehouse/integrations/snowflake/snowflake.go +++ b/warehouse/integrations/snowflake/snowflake.go @@ -987,6 +987,7 @@ func (sf *Snowflake) connect(ctx context.Context, opts optionalCreds) (*sqlmiddl } middleware := sqlmiddleware.New( db, + sqlmiddleware.WithStats(sf.stats), sqlmiddleware.WithLogger(sf.logger), sqlmiddleware.WithKeyAndValues( logfield.SourceID, sf.Warehouse.Source.ID, diff --git a/warehouse/jobs/runner.go b/warehouse/jobs/runner.go index 6fb108c4e9..068b9877ad 100644 --- a/warehouse/jobs/runner.go +++ b/warehouse/jobs/runner.go @@ -162,6 +162,7 @@ startAsyncJobRunner is the main runner that */ func (a *AsyncJobWh) startAsyncJobRunner(ctx context.Context) error { a.logger.Info("[WH-Jobs]: Starting async job runner") + defer a.logger.Info("[WH-Jobs]: Stopping AsyncJobRunner") var wg sync.WaitGroup for { @@ -169,10 +170,8 @@ func (a *AsyncJobWh) startAsyncJobRunner(ctx context.Context) error { select { case <-ctx.Done(): - a.logger.Info("[WH-Jobs]: Stopping AsyncJobRunner") return nil case <-time.After(a.retryTimeInterval): - } pendingAsyncJobs, err := a.getPendingAsyncJobs(ctx) @@ -180,48 +179,50 @@ func (a *AsyncJobWh) startAsyncJobRunner(ctx context.Context) error { a.logger.Errorf("[WH-Jobs]: unable to get pending async jobs with error %s", err.Error()) continue } + if len(pendingAsyncJobs) == 0 { + continue + } - if len(pendingAsyncJobs) > 0 { - a.logger.Info("[WH-Jobs]: Got pending wh async jobs") - a.logger.Infof("[WH-Jobs]: Number of async wh jobs left = %d\n", 
len(pendingAsyncJobs)) - notifierClaims, err := getMessagePayloadsFromAsyncJobPayloads(pendingAsyncJobs) - if err != nil { - a.logger.Errorf("Error converting the asyncJobType to notifier payload %s ", err) - asyncJobStatusMap := convertToPayloadStatusStructWithSingleStatus(pendingAsyncJobs, WhJobFailed, err) - _ = a.updateAsyncJobs(ctx, asyncJobStatusMap) - continue - } - messagePayload := pgnotifier.MessagePayload{ - Jobs: notifierClaims, - JobType: AsyncJobType, - } - ch, err := a.pgnotifier.Publish(ctx, messagePayload, &warehouseutils.Schema{}, 100) - if err != nil { - a.logger.Errorf("[WH-Jobs]: unable to get publish async jobs to pgnotifier. Task failed with error %s", err.Error()) + a.logger.Infof("[WH-Jobs]: Number of async wh jobs left = %d", len(pendingAsyncJobs)) + + notifierClaims, err := getMessagePayloadsFromAsyncJobPayloads(pendingAsyncJobs) + if err != nil { + a.logger.Errorf("Error converting the asyncJobType to notifier payload %s ", err) + asyncJobStatusMap := convertToPayloadStatusStructWithSingleStatus(pendingAsyncJobs, WhJobFailed, err) + _ = a.updateAsyncJobs(ctx, asyncJobStatusMap) + continue + } + messagePayload := pgnotifier.MessagePayload{ + Jobs: notifierClaims, + JobType: AsyncJobType, + } + ch, err := a.pgnotifier.Publish(ctx, messagePayload, &warehouseutils.Schema{}, 100) + if err != nil { + a.logger.Errorf("[WH-Jobs]: unable to get publish async jobs to pgnotifier. 
Task failed with error %s", err.Error()) + asyncJobStatusMap := convertToPayloadStatusStructWithSingleStatus(pendingAsyncJobs, WhJobFailed, err) + _ = a.updateAsyncJobs(ctx, asyncJobStatusMap) + continue + } + asyncJobStatusMap := convertToPayloadStatusStructWithSingleStatus(pendingAsyncJobs, WhJobExecuting, err) + _ = a.updateAsyncJobs(ctx, asyncJobStatusMap) + wg.Add(1) + go func() { + defer wg.Done() + select { + case <-ctx.Done(): + a.logger.Infof("[WH-Jobs]: Context cancelled for async job runner") + case responses := <-ch: + a.logger.Info("[WH-Jobs]: Response received from the pgnotifier track batch") + asyncJobsStatusMap := getAsyncStatusMapFromAsyncPayloads(pendingAsyncJobs) + a.updateStatusJobPayloadsFromPgNotifierResponse(responses, asyncJobsStatusMap) + _ = a.updateAsyncJobs(ctx, asyncJobsStatusMap) + case <-time.After(a.asyncJobTimeOut): + a.logger.Errorf("Go Routine timed out waiting for a response from PgNotifier", pendingAsyncJobs[0].Id) asyncJobStatusMap := convertToPayloadStatusStructWithSingleStatus(pendingAsyncJobs, WhJobFailed, err) _ = a.updateAsyncJobs(ctx, asyncJobStatusMap) - continue } - asyncJobStatusMap := convertToPayloadStatusStructWithSingleStatus(pendingAsyncJobs, WhJobExecuting, err) - _ = a.updateAsyncJobs(ctx, asyncJobStatusMap) - wg.Add(1) - go func() { - select { - case responses := <-ch: - a.logger.Info("[WH-Jobs]: Response received from the pgnotifier track batch") - asyncJobsStatusMap := getAsyncStatusMapFromAsyncPayloads(pendingAsyncJobs) - a.updateStatusJobPayloadsFromPgNotifierResponse(responses, asyncJobsStatusMap) - _ = a.updateAsyncJobs(ctx, asyncJobsStatusMap) - wg.Done() - case <-time.After(a.asyncJobTimeOut): - a.logger.Errorf("Go Routine timed out waiting for a response from PgNotifier", pendingAsyncJobs[0].Id) - asyncJobStatusMap := convertToPayloadStatusStructWithSingleStatus(pendingAsyncJobs, WhJobFailed, err) - _ = a.updateAsyncJobs(ctx, asyncJobStatusMap) - wg.Done() - } - }() - wg.Wait() - } + }() + wg.Wait() } 
} @@ -293,9 +294,6 @@ func (a *AsyncJobWh) getPendingAsyncJobs(ctx context.Context) ([]AsyncJobPayload // Updates the warehouse async jobs with the status sent as a parameter func (a *AsyncJobWh) updateAsyncJobs(ctx context.Context, payloads map[string]AsyncJobStatus) error { - if ctx.Err() != nil { - return ctx.Err() - } a.logger.Info("[WH-Jobs]: Updating wh async jobs to Executing") var err error for _, payload := range payloads { @@ -304,7 +302,6 @@ func (a *AsyncJobWh) updateAsyncJobs(ctx context.Context, payloads map[string]As continue } err = a.updateAsyncJobStatus(ctx, payload.Id, payload.Status, "") - } return err } @@ -312,29 +309,30 @@ func (a *AsyncJobWh) updateAsyncJobs(ctx context.Context, payloads map[string]As func (a *AsyncJobWh) updateAsyncJobStatus(ctx context.Context, Id, status, errMessage string) error { a.logger.Infof("[WH-Jobs]: Updating status of wh async jobs to %s", status) sqlStatement := fmt.Sprintf(`UPDATE %s SET status=(CASE - WHEN attempt >= $1 - THEN $2 - ELSE $3 - END) , - error=$4 WHERE id=$5 AND status!=$6 AND status!=$7 `, warehouseutils.WarehouseAsyncJobTable) + WHEN attempt >= $1 + THEN $2 + ELSE $3 + END) , + error=$4 WHERE id=$5 AND status!=$6 AND status!=$7 `, + warehouseutils.WarehouseAsyncJobTable, + ) var err error for retryCount := 0; retryCount < a.maxQueryRetries; retryCount++ { a.logger.Debugf("[WH-Jobs]: updating async jobs table query %s, retry no : %d", sqlStatement, retryCount) - row, err := a.dbHandle.QueryContext(ctx, sqlStatement, a.maxAttemptsPerJob, WhJobAborted, status, errMessage, Id, WhJobAborted, WhJobSucceeded) + _, err = a.dbHandle.ExecContext(ctx, sqlStatement, + a.maxAttemptsPerJob, WhJobAborted, status, errMessage, Id, WhJobAborted, WhJobSucceeded, + ) if err == nil { a.logger.Info("Update successful") a.logger.Debugf("query: %s successfully executed", sqlStatement) if status == WhJobFailed { - err = a.updateAsyncJobAttempt(ctx, Id) - return err + return a.updateAsyncJobAttempt(ctx, Id) } return err 
} - _ = row.Err() - } - if err != nil { - a.logger.Errorf("query: %s failed with Error : %s", sqlStatement, err.Error()) } + + a.logger.Errorf("Query: %s failed with error: %s", sqlStatement, err.Error()) return err } diff --git a/warehouse/tracker_test.go b/warehouse/tracker_test.go index 62bf02b7eb..e89d5c2cb5 100644 --- a/warehouse/tracker_test.go +++ b/warehouse/tracker_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" "github.com/rudderlabs/rudder-server/warehouse/internal/model" "github.com/golang/mock/gomock" @@ -157,7 +158,7 @@ func TestHandleT_Track(t *testing.T) { }, NowSQL: nowSQL, stats: store, - dbHandle: pgResource.DB, + dbHandle: sqlquerywrapper.New(pgResource.DB), Logger: logger.NOP, } @@ -267,7 +268,7 @@ func TestHandleT_CronTracker(t *testing.T) { }, NowSQL: "ABC", stats: memstats.New(), - dbHandle: pgResource.DB, + dbHandle: sqlquerywrapper.New(pgResource.DB), Logger: logger.NOP, } wh.warehouses = append(wh.warehouses, warehouse) diff --git a/warehouse/upload.go b/warehouse/upload.go index a447537187..a0bd537bf5 100644 --- a/warehouse/upload.go +++ b/warehouse/upload.go @@ -11,36 +11,30 @@ import ( "sync" "time" - "github.com/rudderlabs/rudder-server/warehouse/internal/service/loadfiles/downloader" - + "github.com/cenkalti/backoff/v4" "github.com/samber/lo" - - "github.com/rudderlabs/rudder-server/warehouse/logfield" - - "github.com/rudderlabs/rudder-server/services/alerta" - - schemarepository "github.com/rudderlabs/rudder-server/warehouse/integrations/datalake/schema-repository" - sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" - - "github.com/rudderlabs/rudder-server/warehouse/integrations/manager" - "golang.org/x/exp/slices" - "github.com/cenkalti/backoff/v4" - "github.com/rudderlabs/rudder-go-kit/config" "github.com/rudderlabs/rudder-go-kit/stats" "github.com/rudderlabs/rudder-server/jobsdb" 
"github.com/rudderlabs/rudder-server/rruntime" + "github.com/rudderlabs/rudder-server/services/alerta" "github.com/rudderlabs/rudder-server/services/pgnotifier" "github.com/rudderlabs/rudder-server/utils/misc" "github.com/rudderlabs/rudder-server/utils/timeutil" "github.com/rudderlabs/rudder-server/utils/types" "github.com/rudderlabs/rudder-server/warehouse/identity" + schemarepository "github.com/rudderlabs/rudder-server/warehouse/integrations/datalake/schema-repository" + "github.com/rudderlabs/rudder-server/warehouse/integrations/manager" + "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" + sqlmiddleware "github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper" "github.com/rudderlabs/rudder-server/warehouse/internal/loadfiles" "github.com/rudderlabs/rudder-server/warehouse/internal/model" "github.com/rudderlabs/rudder-server/warehouse/internal/repo" "github.com/rudderlabs/rudder-server/warehouse/internal/service" + "github.com/rudderlabs/rudder-server/warehouse/internal/service/loadfiles/downloader" + "github.com/rudderlabs/rudder-server/warehouse/logfield" warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils" "github.com/rudderlabs/rudder-server/warehouse/validations" ) @@ -69,7 +63,7 @@ type uploadState struct { type tableNameT string type UploadJobFactory struct { - dbHandle *sql.DB + dbHandle *sqlquerywrapper.DB destinationValidator validations.DestinationValidator loadFile *loadfiles.LoadFileGenerator recovery *service.Recovery @@ -174,13 +168,9 @@ func setMaxParallelLoads() { } func (f *UploadJobFactory) NewUploadJob(ctx context.Context, dto *model.UploadJob, whManager manager.Manager) *UploadJob { - wrappedDBHandle := sqlmiddleware.New( - f.dbHandle, - sqlmiddleware.WithQueryTimeout(dbHanndleTimeout), - ) return &UploadJob{ ctx: ctx, - dbHandle: wrappedDBHandle, + dbHandle: f.dbHandle, loadfile: f.loadFile, recovery: f.recovery, pgNotifier: f.pgNotifier, @@ -1799,7 +1789,7 
@@ func (job *UploadJob) getLoadFilesTableMap() (loadFilesMap map[tableNameT]bool, job.upload.LoadFileStartID, job.upload.LoadFileEndID, } - rows, err := dbHandle.QueryContext(job.ctx, sqlStatement, sqlStatementArgs...) + rows, err := wrappedDBHandle.QueryContext(job.ctx, sqlStatement, sqlStatementArgs...) if err == sql.ErrNoRows { err = nil return diff --git a/warehouse/utils/querytype.go b/warehouse/utils/querytype.go new file mode 100644 index 0000000000..8734009199 --- /dev/null +++ b/warehouse/utils/querytype.go @@ -0,0 +1,68 @@ +package warehouseutils + +import ( + "fmt" + "regexp" + "strings" +) + +var ( + // queryTypeIndex works for both regexes as long as the groups order is not changed + queryTypeIndex int + queryTypeRegex *regexp.Regexp + + unknownQueryTypeRegex = regexp.MustCompile(`^(?i)\s*(?P<type>\w+)\s+`) +) + +func init() { + tokens := []string{ + "SELECT", "UPDATE", "DELETE FROM", "INSERT INTO", "COPY", + "CREATE TEMP TABLE", "CREATE TEMPORARY TABLE", + "CREATE DATABASE", "CREATE SCHEMA", "CREATE TABLE", "CREATE INDEX", + "ALTER TABLE", "ALTER SESSION", + "DROP TABLE", + } + queryTypeRegex = regexp.MustCompile(`^(?i)\s*(?P<type>` + strings.Join(tokens, "|") + `)\s+`) + + var found bool + for i, name := range queryTypeRegex.SubexpNames() { + if name == "type" { + found = true + queryTypeIndex = i + break + } + } + if !found { + panic(fmt.Errorf("warehouseutils: query type index not found")) + } +} + +// GetQueryType returns the type of the query. 
+func GetQueryType(query string) (string, bool) { + var ( + expected bool + queryType = "" + submatch = queryTypeRegex.FindStringSubmatch(query) + ) + + if len(submatch) > queryTypeIndex { + expected = true + queryType = strings.ToUpper(submatch[queryTypeIndex]) + if queryType == "CREATE TEMPORARY TABLE" { + queryType = "CREATE TEMP TABLE" + } + } + + if queryType == "" { // get the first word + submatch = unknownQueryTypeRegex.FindStringSubmatch(query) + if len(submatch) > queryTypeIndex { + queryType = strings.ToUpper(submatch[queryTypeIndex]) + } + } + + if queryType == "" { + queryType = "UNKNOWN" + } + + return queryType, expected +} diff --git a/warehouse/utils/querytype_benchmark_test.go b/warehouse/utils/querytype_benchmark_test.go new file mode 100644 index 0000000000..b644a3ed57 --- /dev/null +++ b/warehouse/utils/querytype_benchmark_test.go @@ -0,0 +1,37 @@ +package warehouseutils + +import ( + "testing" +) + +/* +BenchmarkGetQueryType/expected-24 4825515 248.1 ns/op +BenchmarkGetQueryType/unknown-24 922964 1296 ns/op +BenchmarkGetQueryType/empty-24 824426 1253 ns/op +*/ +func BenchmarkGetQueryType(b *testing.B) { + b.Run("expected", func(b *testing.B) { + query := "\t\n\n \t\n\n seLeCt * from table" + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = GetQueryType(query) + } + }) + b.Run("unexpected", func(b *testing.B) { + query := "\t\n\n \t\n\n something * from table" + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = GetQueryType(query) + } + }) + b.Run("empty", func(b *testing.B) { + query := "\t\n\n \t\n\n " + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = GetQueryType(query) + } + }) +} diff --git a/warehouse/utils/querytype_test.go b/warehouse/utils/querytype_test.go new file mode 100644 index 0000000000..936bfc3dec --- /dev/null +++ b/warehouse/utils/querytype_test.go @@ -0,0 +1,41 @@ +package warehouseutils + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetQueryType(t *testing.T) { + tests := 
[]struct { + name string + query string + want string + expected bool + }{ + {"select 1", "Select * from table", "SELECT", true}, + {"select 2", "\t\n\n \t\n\n seLeCt * from table", "SELECT", true}, + {"update", "\t\n\n \t\n\n UpDaTe something", "UPDATE", true}, + {"delete", "\t\n\n \t\n\n DeLeTe FROm something", "DELETE FROM", true}, + {"insert", "\t\n\n \t\n\n InSerT INTO something", "INSERT INTO", true}, + {"copy", "\t\n\n \t\n\n cOpY t1 from t2", "COPY", true}, + {"create temp table 1", "\t\n\n \t\n\n create temp table t1", "CREATE TEMP TABLE", true}, + {"create temp table 2", "\t\n\n \t\n\n create tempORARY table t1", "CREATE TEMP TABLE", true}, + {"create database", "\t\n\n \t\n\n creATE dataBASE db1", "CREATE DATABASE", true}, + {"create schema", "\t\n\n \t\n\n creATE schEMA sch1", "CREATE SCHEMA", true}, + {"create table", "\t\n\n \t\n\n creATE tABLE t1", "CREATE TABLE", true}, + {"create index", "\t\n\n \t\n\n creATE inDeX idx1", "CREATE INDEX", true}, + {"alter table", "\t\n\n \t\n\n ALTer tABLE t1", "ALTER TABLE", true}, + {"alter session", "\t\n\n \t\n\n ALTer seSsIoN s1", "ALTER SESSION", true}, + {"drop table", "\t\n\n \t\n\n dROp Table t1", "DROP TABLE", true}, + {"unexpected", "\t\n\n \t\n\n something unexpected", "SOMETHING", false}, + {"empty", "\t\n\n \t\n\n ", "UNKNOWN", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, expected := GetQueryType(tt.query) + require.Equalf(t, tt.want, got, "GetQueryType() value = %v, want %v", got, tt.want) + require.Equalf(t, tt.expected, expected, "GetQueryType() expected = %v, want %v", expected, tt.expected) + }) + } +} diff --git a/warehouse/warehouse.go b/warehouse/warehouse.go index 24fc411d63..eee2129994 100644 --- a/warehouse/warehouse.go +++ b/warehouse/warehouse.go @@ -66,7 +66,7 @@ var ( webPort int dbHandle *sql.DB wrappedDBHandle *sqlquerywrapper.DB - dbHanndleTimeout time.Duration + dbHandleTimeout time.Duration notifier pgnotifier.PGNotifier tenantManager 
*multitenant.Manager controlPlaneClient *controlplane.Client @@ -132,7 +132,7 @@ type ( type HandleT struct { destType string warehouses []model.Warehouse - dbHandle *sql.DB + dbHandle *sqlquerywrapper.DB warehouseDBHandle *DB stagingRepo *repo.StagingFiles uploadRepo *repo.Uploads @@ -205,7 +205,7 @@ func loadConfig() { config.RegisterIntConfigVariable(8, &maxParallelJobCreation, true, 1, "Warehouse.maxParallelJobCreation") config.RegisterBoolConfigVariable(false, &enableJitterForSyncs, true, "Warehouse.enableJitterForSyncs") config.RegisterDurationConfigVariable(30, &tableCountQueryTimeout, true, time.Second, []string{"Warehouse.tableCountQueryTimeout", "Warehouse.tableCountQueryTimeoutInS"}...) - config.RegisterDurationConfigVariable(5, &dbHanndleTimeout, true, time.Minute, []string{"Warehouse.dbHanndleTimeout", "Warehouse.dbHanndleTimeoutInMin"}...) + config.RegisterDurationConfigVariable(5, &dbHandleTimeout, true, time.Minute, []string{"Warehouse.dbHandleTimeout", "Warehouse.dbHanndleTimeoutInMin"}...) appName = misc.DefaultString("rudder-server").OnError(os.Hostname()) } @@ -866,7 +866,7 @@ func (wh *HandleT) Setup(ctx context.Context, whType string) error { pkgLogger.Infof("WH: Warehouse Router started: %s", whType) wh.Logger = pkgLogger wh.conf = config.Default - wh.dbHandle = dbHandle + wh.dbHandle = wrappedDBHandle // We now have access to the warehouseDBHandle through // which we will be running the db calls. 
wh.warehouseDBHandle = NewWarehouseDB(wrappedDBHandle) @@ -992,7 +992,7 @@ func minimalConfigSubscriber(ctx context.Context) { for _, destination := range source.Destinations { if slices.Contains(warehouseutils.WarehouseDestinations, destination.DestinationDefinition.Name) { wh := &HandleT{ - dbHandle: dbHandle, + dbHandle: wrappedDBHandle, destType: destination.DestinationDefinition.Name, whSchemaRepo: repo.NewWHSchemas(wrappedDBHandle), conf: config.Default, @@ -1637,7 +1637,7 @@ func setupDB(ctx context.Context, connInfo string) error { wrappedDBHandle = sqlquerywrapper.New( dbHandle, sqlquerywrapper.WithLogger(pkgLogger.Child("dbHandle")), - sqlquerywrapper.WithQueryTimeout(dbHanndleTimeout), + sqlquerywrapper.WithQueryTimeout(dbHandleTimeout), ) return setupTables(dbHandle) diff --git a/warehouse/warehouse_test.go b/warehouse/warehouse_test.go index 934dc93658..891dfc6e1c 100644 --- a/warehouse/warehouse_test.go +++ b/warehouse/warehouse_test.go @@ -151,7 +151,7 @@ func TestUploadJob_ProcessingStats(t *testing.T) { wh := HandleT{ destType: tc.destType, stats: store, - dbHandle: pgResource.DB, + dbHandle: sqlquerywrapper.New(pgResource.DB), whSchemaRepo: repo.NewWHSchemas(sqlquerywrapper.New(pgResource.DB)), } tenantManager = &multitenant.Manager{} @@ -344,7 +344,7 @@ func Test_GetNamespace(t *testing.T) { wh := HandleT{ destType: tc.destType, stats: store, - dbHandle: pgResource.DB, + dbHandle: sqlquerywrapper.New(pgResource.DB), whSchemaRepo: repo.NewWHSchemas(sqlquerywrapper.New(pgResource.DB)), conf: conf, }