chore: warehouse sql stats (#3638)
fracasula committed Jul 21, 2023
1 parent f9a19f3 commit e20976d
Showing 14 changed files with 304 additions and 96 deletions.
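At a glance: the sqlquerywrapper middleware gains a WithStats option and now increments a wh_query_count counter for every query it runs, tagged with a query_type derived from warehouseutils.GetQueryType plus the key/value pairs each integration already passes for logging; the deltalake, postgres, redshift and snowflake hunks below wire the option in. A minimal wiring sketch based on those hunks — newWarehouseDB, statsFactory, log and sourceID are placeholders, not identifiers from the commit:

// Hypothetical wiring mirroring the integration hunks below.
func newWarehouseDB(db *sql.DB, statsFactory stats.Stats, log logger.Logger, sourceID string) *sqlmiddleware.DB {
	return sqlmiddleware.New(
		db,
		sqlmiddleware.WithStats(statsFactory), // new option: every query increments wh_query_count
		sqlmiddleware.WithLogger(log),
		sqlmiddleware.WithKeyAndValues(
			logfield.SourceID, sourceID, // each key/value pair also becomes a stats tag
		),
	)
}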
1 change: 1 addition & 0 deletions warehouse/integrations/deltalake/deltalake.go
@@ -215,6 +215,7 @@ func (d *Deltalake) connect() (*sqlmiddleware.DB, error) {
db := sql.OpenDB(connector)
middleware := sqlmiddleware.New(
db,
sqlmiddleware.WithStats(d.stats),
sqlmiddleware.WithLogger(d.logger),
sqlmiddleware.WithKeyAndValues(
logfield.SourceID, d.Warehouse.Source.ID,
58 changes: 49 additions & 9 deletions warehouse/integrations/middleware/sqlquerywrapper/sql.go
@@ -5,12 +5,14 @@ import (
"database/sql"
"errors"
"fmt"
"sync"
"time"

rslogger "github.com/rudderlabs/rudder-go-kit/logger"

"github.com/rudderlabs/rudder-go-kit/stats"
"github.com/rudderlabs/rudder-server/utils/misc"
"github.com/rudderlabs/rudder-server/warehouse/logfield"
warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils"
)

type Opt func(*DB)
@@ -22,6 +24,7 @@ type logger interface {

type DB struct {
*sql.DB
stats stats.Stats

since func(time.Time) time.Duration
logger logger
@@ -61,16 +64,20 @@ func (r *Rows) Err() error {
type Row struct {
*sql.Row
context.CancelFunc
once sync.Once
logQ
}

func (r *Row) Scan(dest ...interface{}) error {
defer r.CancelFunc()
r.logQ()
r.once.Do(r.logQ)
return r.Row.Scan(dest...)
}

// Err provides a way for wrapping packages to check for
// query errors without calling Scan.
func (r *Row) Err() error {
r.once.Do(r.logQ)
return r.Row.Err()
}
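The Scan/Err pair above is now guarded by a sync.Once, so the query is logged (and, with stats configured, counted) exactly once regardless of whether a caller checks Err before Scan. A hedged usage sketch — countRows, the context and the query are illustrative, not repository code:

// countRows shows Err being checked before Scan on the wrapped row.
func countRows(ctx context.Context, wrappedDB *sqlquerywrapper.DB) (int64, error) {
	row := wrappedDB.QueryRowContext(ctx, `SELECT count(*) FROM some_table`)
	if err := row.Err(); err != nil { // fires logQ via once.Do, just as Scan would
		return 0, err
	}
	var n int64
	if err := row.Scan(&n); err != nil { // logQ does not fire again for the same row
		return 0, err
	}
	return n, nil
}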

@@ -86,6 +93,12 @@ func WithLogger(logger logger) Opt {
}
}

func WithStats(stats stats.Stats) Opt {
return func(s *DB) {
s.stats = stats
}
}

func WithKeyAndValues(keyAndValues ...any) Opt {
return func(s *DB) {
s.keysAndValues = keyAndValues
@@ -209,21 +222,48 @@ func (db *DB) WithTx(ctx context.Context, fn func(*Tx) error) error {

func (db *DB) logQuery(query string, since time.Time) logQ {
return func() {
var (
sanitizedQuery string
keysAndValues []any
)
createLogData := func() {
sanitizedQuery, _ = misc.ReplaceMultiRegex(query, db.secretsRegex)

keysAndValues = []any{
logfield.Query, sanitizedQuery,
logfield.QueryExecutionTime, db.since(since),
}
keysAndValues = append(keysAndValues, db.keysAndValues...)
}

if db.stats != nil {
var expected bool
tags := make(stats.Tags, len(db.keysAndValues)/2+1)
tags["query_type"], expected = warehouseutils.GetQueryType(query)
if !expected {
createLogData()
db.logger.Warnw("sql stats: unexpected query type", keysAndValues...)
}
for i := 0; i < len(db.keysAndValues); i += 2 {
key, ok := db.keysAndValues[i].(string)
if !ok {
continue
}
tags[key] = fmt.Sprint(db.keysAndValues[i+1])
}
db.stats.NewTaggedStat("wh_query_count", stats.CountType, tags).Increment()
}

if db.slowQueryThreshold <= 0 {
return
}
if db.since(since) < db.slowQueryThreshold {
return
}

sanitizedQuery, _ := misc.ReplaceMultiRegex(query, db.secretsRegex)

keysAndValues := []any{
logfield.Query, sanitizedQuery,
logfield.QueryExecutionTime, db.since(since),
if sanitizedQuery == "" {
createLogData()
}
keysAndValues = append(keysAndValues, db.keysAndValues...)

db.logger.Infow("executing query", keysAndValues...)
}
}
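The stats branch above builds its tag set by walking keysAndValues two entries at a time, keeping only string keys, and logs a warning (with log fields built lazily via createLogData) when GetQueryType does not recognise the statement. The same tag construction in isolation — tagsFromKeyValues is a hypothetical helper, not part of the commit, with an extra bounds check for odd-length input:

// tagsFromKeyValues converts logger-style key/value pairs into stats.Tags,
// skipping pairs whose key is not a string, as logQuery does above.
func tagsFromKeyValues(kvs []any) stats.Tags {
	tags := make(stats.Tags, len(kvs)/2)
	for i := 0; i+1 < len(kvs); i += 2 {
		key, ok := kvs[i].(string)
		if !ok {
			continue
		}
		tags[key] = fmt.Sprint(kvs[i+1])
	}
	return tags
}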
29 changes: 29 additions & 0 deletions warehouse/integrations/middleware/sqlquerywrapper/sql_test.go
@@ -8,6 +8,8 @@ import (
"time"

rslogger "github.com/rudderlabs/rudder-go-kit/logger"
"github.com/rudderlabs/rudder-go-kit/stats"
"github.com/rudderlabs/rudder-go-kit/stats/memstats"

"github.com/google/uuid"

@@ -449,3 +451,30 @@ func TestQueryWrapper(t *testing.T) {
require.NoError(t, err)
})
}

func TestWithStats(t *testing.T) {
t.Parallel()

pool, err := dockertest.NewPool("")
require.NoError(t, err)

pgResource, err := resource.SetupPostgres(pool, t)
require.NoError(t, err)

s := memstats.New()

qw := New(
pgResource.DB,
WithKeyAndValues("k1", "v1", "k2", "v2"),
WithStats(s),
)
row := qw.QueryRowContext(context.Background(), "SELECT 1")
require.NoError(t, row.Err())

measurement := s.Get("wh_query_count", stats.Tags{
"k1": "v1",
"k2": "v2",
"query_type": "SELECT",
})
require.NotNilf(t, measurement, "measurement should not be nil")
}
1 change: 1 addition & 0 deletions warehouse/integrations/postgres/postgres.go
@@ -176,6 +176,7 @@ func New(conf *config.Config, log logger.Logger, stat stats.Stats) *Postgres {
func (pg *Postgres) getNewMiddleWare(db *sql.DB) *sqlmiddleware.DB {
middleware := sqlmiddleware.New(
db,
sqlmiddleware.WithStats(pg.stats),
sqlmiddleware.WithLogger(pg.logger),
sqlmiddleware.WithKeyAndValues(
logfield.SourceID, pg.Warehouse.Source.ID,
1 change: 1 addition & 0 deletions warehouse/integrations/redshift/redshift.go
@@ -1060,6 +1060,7 @@ func (rs *Redshift) connect(ctx context.Context) (*sqlmiddleware.DB, error) {
}
middleware := sqlmiddleware.New(
db,
sqlmiddleware.WithStats(rs.stats),
sqlmiddleware.WithLogger(rs.logger),
sqlmiddleware.WithKeyAndValues(
logfield.SourceID, rs.Warehouse.Source.ID,
1 change: 1 addition & 0 deletions warehouse/integrations/snowflake/snowflake.go
@@ -987,6 +987,7 @@ func (sf *Snowflake) connect(ctx context.Context, opts optionalCreds) (*sqlmiddl
}
middleware := sqlmiddleware.New(
db,
sqlmiddleware.WithStats(sf.stats),
sqlmiddleware.WithLogger(sf.logger),
sqlmiddleware.WithKeyAndValues(
logfield.SourceID, sf.Warehouse.Source.ID,
110 changes: 54 additions & 56 deletions warehouse/jobs/runner.go
@@ -162,66 +162,67 @@ startAsyncJobRunner is the main runner that
*/
func (a *AsyncJobWh) startAsyncJobRunner(ctx context.Context) error {
a.logger.Info("[WH-Jobs]: Starting async job runner")
defer a.logger.Info("[WH-Jobs]: Stopping AsyncJobRunner")

var wg sync.WaitGroup
for {
a.logger.Debug("[WH-Jobs]: Scanning for waiting async job")

select {
case <-ctx.Done():
a.logger.Info("[WH-Jobs]: Stopping AsyncJobRunner")
return nil
case <-time.After(a.retryTimeInterval):

}

pendingAsyncJobs, err := a.getPendingAsyncJobs(ctx)
if err != nil {
a.logger.Errorf("[WH-Jobs]: unable to get pending async jobs with error %s", err.Error())
continue
}
if len(pendingAsyncJobs) == 0 {
continue
}

if len(pendingAsyncJobs) > 0 {
a.logger.Info("[WH-Jobs]: Got pending wh async jobs")
a.logger.Infof("[WH-Jobs]: Number of async wh jobs left = %d\n", len(pendingAsyncJobs))
notifierClaims, err := getMessagePayloadsFromAsyncJobPayloads(pendingAsyncJobs)
if err != nil {
a.logger.Errorf("Error converting the asyncJobType to notifier payload %s ", err)
asyncJobStatusMap := convertToPayloadStatusStructWithSingleStatus(pendingAsyncJobs, WhJobFailed, err)
_ = a.updateAsyncJobs(ctx, asyncJobStatusMap)
continue
}
messagePayload := pgnotifier.MessagePayload{
Jobs: notifierClaims,
JobType: AsyncJobType,
}
ch, err := a.pgnotifier.Publish(ctx, messagePayload, &warehouseutils.Schema{}, 100)
if err != nil {
a.logger.Errorf("[WH-Jobs]: unable to get publish async jobs to pgnotifier. Task failed with error %s", err.Error())
a.logger.Infof("[WH-Jobs]: Number of async wh jobs left = %d", len(pendingAsyncJobs))

notifierClaims, err := getMessagePayloadsFromAsyncJobPayloads(pendingAsyncJobs)
if err != nil {
a.logger.Errorf("Error converting the asyncJobType to notifier payload %s ", err)
asyncJobStatusMap := convertToPayloadStatusStructWithSingleStatus(pendingAsyncJobs, WhJobFailed, err)
_ = a.updateAsyncJobs(ctx, asyncJobStatusMap)
continue
}
messagePayload := pgnotifier.MessagePayload{
Jobs: notifierClaims,
JobType: AsyncJobType,
}
ch, err := a.pgnotifier.Publish(ctx, messagePayload, &warehouseutils.Schema{}, 100)
if err != nil {
a.logger.Errorf("[WH-Jobs]: unable to get publish async jobs to pgnotifier. Task failed with error %s", err.Error())
asyncJobStatusMap := convertToPayloadStatusStructWithSingleStatus(pendingAsyncJobs, WhJobFailed, err)
_ = a.updateAsyncJobs(ctx, asyncJobStatusMap)
continue
}
asyncJobStatusMap := convertToPayloadStatusStructWithSingleStatus(pendingAsyncJobs, WhJobExecuting, err)
_ = a.updateAsyncJobs(ctx, asyncJobStatusMap)
wg.Add(1)
go func() {
defer wg.Done()
select {
case <-ctx.Done():
a.logger.Infof("[WH-Jobs]: Context cancelled for async job runner")
case responses := <-ch:
a.logger.Info("[WH-Jobs]: Response received from the pgnotifier track batch")
asyncJobsStatusMap := getAsyncStatusMapFromAsyncPayloads(pendingAsyncJobs)
a.updateStatusJobPayloadsFromPgNotifierResponse(responses, asyncJobsStatusMap)
_ = a.updateAsyncJobs(ctx, asyncJobsStatusMap)
case <-time.After(a.asyncJobTimeOut):
a.logger.Errorf("Go Routine timed out waiting for a response from PgNotifier", pendingAsyncJobs[0].Id)
asyncJobStatusMap := convertToPayloadStatusStructWithSingleStatus(pendingAsyncJobs, WhJobFailed, err)
_ = a.updateAsyncJobs(ctx, asyncJobStatusMap)
continue
}
asyncJobStatusMap := convertToPayloadStatusStructWithSingleStatus(pendingAsyncJobs, WhJobExecuting, err)
_ = a.updateAsyncJobs(ctx, asyncJobStatusMap)
wg.Add(1)
go func() {
select {
case responses := <-ch:
a.logger.Info("[WH-Jobs]: Response received from the pgnotifier track batch")
asyncJobsStatusMap := getAsyncStatusMapFromAsyncPayloads(pendingAsyncJobs)
a.updateStatusJobPayloadsFromPgNotifierResponse(responses, asyncJobsStatusMap)
_ = a.updateAsyncJobs(ctx, asyncJobsStatusMap)
wg.Done()
case <-time.After(a.asyncJobTimeOut):
a.logger.Errorf("Go Routine timed out waiting for a response from PgNotifier", pendingAsyncJobs[0].Id)
asyncJobStatusMap := convertToPayloadStatusStructWithSingleStatus(pendingAsyncJobs, WhJobFailed, err)
_ = a.updateAsyncJobs(ctx, asyncJobStatusMap)
wg.Done()
}
}()
wg.Wait()
}
}()
wg.Wait()
}
}
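The refactor above also tightens the goroutine bookkeeping: wg.Done() moves into a defer and a ctx.Done() case joins the select, so cancellation is observed without waiting for asyncJobTimeOut. A stripped-down sketch of that pattern — ch, handle and timeout are placeholders, not identifiers from the commit:

wg.Add(1)
go func() {
	defer wg.Done() // released on every path, including cancellation
	select {
	case <-ctx.Done():
	case responses := <-ch:
		handle(responses)
	case <-time.After(timeout):
	}
}()
wg.Wait()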

@@ -293,9 +294,6 @@ func (a *AsyncJobWh) getPendingAsyncJobs(ctx context.Context) ([]AsyncJobPayload

// Updates the warehouse async jobs with the status sent as a parameter
func (a *AsyncJobWh) updateAsyncJobs(ctx context.Context, payloads map[string]AsyncJobStatus) error {
if ctx.Err() != nil {
return ctx.Err()
}
a.logger.Info("[WH-Jobs]: Updating wh async jobs to Executing")
var err error
for _, payload := range payloads {
@@ -304,37 +302,37 @@ func (a *AsyncJobWh) updateAsyncJobs(ctx context.Context, payloads map[string]As
continue
}
err = a.updateAsyncJobStatus(ctx, payload.Id, payload.Status, "")

}
return err
}

func (a *AsyncJobWh) updateAsyncJobStatus(ctx context.Context, Id, status, errMessage string) error {
a.logger.Infof("[WH-Jobs]: Updating status of wh async jobs to %s", status)
sqlStatement := fmt.Sprintf(`UPDATE %s SET status=(CASE
WHEN attempt >= $1
THEN $2
ELSE $3
END) ,
error=$4 WHERE id=$5 AND status!=$6 AND status!=$7 `, warehouseutils.WarehouseAsyncJobTable)
WHEN attempt >= $1
THEN $2
ELSE $3
END) ,
error=$4 WHERE id=$5 AND status!=$6 AND status!=$7 `,
warehouseutils.WarehouseAsyncJobTable,
)
var err error
for retryCount := 0; retryCount < a.maxQueryRetries; retryCount++ {
a.logger.Debugf("[WH-Jobs]: updating async jobs table query %s, retry no : %d", sqlStatement, retryCount)
row, err := a.dbHandle.QueryContext(ctx, sqlStatement, a.maxAttemptsPerJob, WhJobAborted, status, errMessage, Id, WhJobAborted, WhJobSucceeded)
_, err := a.dbHandle.ExecContext(ctx, sqlStatement,
a.maxAttemptsPerJob, WhJobAborted, status, errMessage, Id, WhJobAborted, WhJobSucceeded,
)
if err == nil {
a.logger.Info("Update successful")
a.logger.Debugf("query: %s successfully executed", sqlStatement)
if status == WhJobFailed {
err = a.updateAsyncJobAttempt(ctx, Id)
return err
return a.updateAsyncJobAttempt(ctx, Id)
}
return err
}
_ = row.Err()
}
if err != nil {
a.logger.Errorf("query: %s failed with Error : %s", sqlStatement, err.Error())
}

a.logger.Errorf("Query: %s failed with error: %s", sqlStatement, err.Error())
return err
}
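The other notable change above is the switch from QueryContext to ExecContext for the UPDATE: an UPDATE returns no result set, and the old call left behind an unclosed *sql.Rows along with the connection it holds. A minimal sketch of the pattern adopted — markJobStatus, its table and arguments are illustrative, not repository code:

// markJobStatus illustrates ExecContext for statements that return no rows:
// sql.Result needs no Close, so the connection returns to the pool immediately.
func markJobStatus(ctx context.Context, db *sql.DB, table, id, status string) error {
	_, err := db.ExecContext(ctx,
		fmt.Sprintf(`UPDATE %s SET status = $1 WHERE id = $2`, table),
		status, id,
	)
	return err
}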

5 changes: 3 additions & 2 deletions warehouse/tracker_test.go
@@ -7,6 +7,7 @@ import (
"testing"
"time"

"github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper"
"github.com/rudderlabs/rudder-server/warehouse/internal/model"

"github.com/golang/mock/gomock"
@@ -157,7 +158,7 @@ func TestHandleT_Track(t *testing.T) {
},
NowSQL: nowSQL,
stats: store,
dbHandle: pgResource.DB,
dbHandle: sqlquerywrapper.New(pgResource.DB),
Logger: logger.NOP,
}

@@ -267,7 +268,7 @@ func TestHandleT_CronTracker(t *testing.T) {
},
NowSQL: "ABC",
stats: memstats.New(),
dbHandle: pgResource.DB,
dbHandle: sqlquerywrapper.New(pgResource.DB),
Logger: logger.NOP,
}
wh.warehouses = append(wh.warehouses, warehouse)
