diff --git a/CLAUDE.md b/CLAUDE.md index c5d6f516..bcfac746 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -240,7 +240,7 @@ deps = [ ### Testing - **Table-driven tests** — prefer table-driven tests with `t.Run` subtests over individual test functions. -- **Avoid asserting on error messages** — assert on error type or generic error. +- **Avoid asserting on error messages** — assert on error type or check the error with `require.Error`, do not `assert.Contains(t, err.Error(), message)` - **No change detector tests** — don't assert on default values, internal structure, or implementation details that can change without affecting behavior. Test what the code *does*, not how it's constructed. - **No `time.Sleep` for synchronization** — use channels, callbacks, condition variables. - **Use testify** — `assert`/`require` instead of `t.Fatal()`. diff --git a/core/request/BUILD.bazel b/core/request/BUILD.bazel index 9fb5ed72..2a533a91 100644 --- a/core/request/BUILD.bazel +++ b/core/request/BUILD.bazel @@ -2,21 +2,32 @@ load("@rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "request", - srcs = ["request.go"], + srcs = [ + "log.go", + "request.go", + ], importpath = "github.com/uber/submitqueue/core/request", visibility = ["//visibility:public"], deps = [ + "//core/consumer", "//entity", + "//entity/queue", "//extension/storage", ], ) go_test( name = "request_test", - srcs = ["request_test.go"], + srcs = [ + "log_test.go", + "request_test.go", + ], embed = [":request"], deps = [ + "//core/consumer", "//entity", + "//entity/queue", + "//extension/queue/mock", "//extension/storage", "//extension/storage/mock", "@com_github_stretchr_testify//assert", diff --git a/core/request/log.go b/core/request/log.go new file mode 100644 index 00000000..4a163edd --- /dev/null +++ b/core/request/log.go @@ -0,0 +1,63 @@ +// Copyright (c) 2025 Uber Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package request + +import ( + "context" + "fmt" + + "github.com/uber/submitqueue/core/consumer" + "github.com/uber/submitqueue/entity" + entityqueue "github.com/uber/submitqueue/entity/queue" +) + +// PublishLog publishes a single request log entry to the log topic for async persistence. +// The partitionKey ensures ordering of log entries for the same request; typically set to the request ID. +func PublishLog(ctx context.Context, registry consumer.TopicRegistry, logEntry entity.RequestLog, partitionKey string) error { + payload, err := logEntry.ToBytes() + if err != nil { + return fmt.Errorf("failed to serialize request log: %w", err) + } + + msg := entityqueue.NewMessage(logEntry.RequestID, payload, partitionKey, nil) + + q, ok := registry.Queue(consumer.TopicKeyLog) + if !ok { + return fmt.Errorf("no queue registered for topic key %s", consumer.TopicKeyLog) + } + + topicName, ok := registry.TopicName(consumer.TopicKeyLog) + if !ok { + return fmt.Errorf("no topic name registered for topic key %s", consumer.TopicKeyLog) + } + + if err := q.Publisher().Publish(ctx, topicName, msg); err != nil { + return fmt.Errorf("failed to publish message: %w", err) + } + + return nil +} + +// PublishBatchLogs publishes a request log entry for each request ID in the batch to the log topic. +// Each entry uses the request ID as the partition key to ensure per-request ordering. +func PublishBatchLogs(ctx context.Context, registry consumer.TopicRegistry, requestIDs []string, status entity.RequestStatus, metadata map[string]string) error { + for _, requestID := range requestIDs { + logEntry := entity.NewRequestLog(requestID, status, 0, "", metadata) + if err := PublishLog(ctx, registry, logEntry, requestID); err != nil { + return fmt.Errorf("failed to publish request log for request %s: %w", requestID, err) + } + } + return nil +} diff --git a/core/request/log_test.go b/core/request/log_test.go new file mode 100644 index 00000000..ac53e499 --- /dev/null +++ b/core/request/log_test.go @@ -0,0 +1,115 @@ +// Copyright (c) 2025 Uber Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package request + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + "github.com/uber/submitqueue/core/consumer" + "github.com/uber/submitqueue/entity" + "github.com/uber/submitqueue/entity/queue" + queuemock "github.com/uber/submitqueue/extension/queue/mock" + "go.uber.org/mock/gomock" +) + +func newTestRegistry(t *testing.T, ctrl *gomock.Controller, publishErr error) consumer.TopicRegistry { + mockPub := queuemock.NewMockPublisher(ctrl) + mockPub.EXPECT().Publish(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, topic string, msg queue.Message) error { + return publishErr + }, + ).AnyTimes() + + mockQ := queuemock.NewMockQueue(ctrl) + mockQ.EXPECT().Publisher().Return(mockPub).AnyTimes() + + registry, err := consumer.NewTopicRegistry( + []consumer.TopicConfig{{Key: consumer.TopicKeyLog, Name: "log", Queue: mockQ}}, + ) + require.NoError(t, err) + return registry +} + +func TestPublishLog_Success(t *testing.T) { + ctrl := gomock.NewController(t) + registry := newTestRegistry(t, ctrl, nil) + + logEntry := entity.NewRequestLog("req/1", entity.RequestStatusStarted, 1, "", nil) + err := PublishLog(context.Background(), registry, logEntry, "req/1") + require.NoError(t, err) +} + +func TestPublishLog_PublishFailure(t *testing.T) { + ctrl := gomock.NewController(t) + registry := newTestRegistry(t, ctrl, fmt.Errorf("connection refused")) + + logEntry := entity.NewRequestLog("req/1", entity.RequestStatusStarted, 1, "", nil) + err := PublishLog(context.Background(), registry, logEntry, "req/1") + require.Error(t, err) +} + +func TestPublishBatchLogs_Success(t *testing.T) { + ctrl := gomock.NewController(t) + registry := newTestRegistry(t, ctrl, nil) + + err := PublishBatchLogs(context.Background(), registry, + []string{"req/1", "req/2", "req/3"}, + entity.RequestStatusScored, + map[string]string{"batch_id": "b/1"}, + ) + require.NoError(t, err) +} + +func TestPublishBatchLogs_PartialFailure(t *testing.T) { + ctrl := gomock.NewController(t) + + callCount := 0 + mockPub := queuemock.NewMockPublisher(ctrl) + mockPub.EXPECT().Publish(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, topic string, msg queue.Message) error { + callCount++ + if callCount == 2 { + return fmt.Errorf("publish failed") + } + return nil + }, + ).AnyTimes() + + mockQ := queuemock.NewMockQueue(ctrl) + mockQ.EXPECT().Publisher().Return(mockPub).AnyTimes() + + registry, err := consumer.NewTopicRegistry( + []consumer.TopicConfig{{Key: consumer.TopicKeyLog, Name: "log", Queue: mockQ}}, + ) + require.NoError(t, err) + + err = PublishBatchLogs(context.Background(), registry, + []string{"req/1", "req/2", "req/3"}, + entity.RequestStatusScored, + map[string]string{"batch_id": "b/1"}, + ) + require.Error(t, err) +} + +func TestPublishBatchLogs_Empty(t *testing.T) { + ctrl := gomock.NewController(t) + registry := newTestRegistry(t, ctrl, nil) + + err := PublishBatchLogs(context.Background(), registry, nil, entity.RequestStatusScored, nil) + require.NoError(t, err) +} diff --git a/entity/batch.go b/entity/batch.go index 6bf958ea..1e8e2a55 100644 --- a/entity/batch.go +++ b/entity/batch.go @@ -32,6 +32,8 @@ const ( BatchStateSucceeded BatchState = "succeeded" // BatchStateFailed is the terminal state of a batch that has failed. BatchStateFailed BatchState = "failed" + // BatchStateScored is the state of a batch that has been scored for build success probability. + BatchStateScored BatchState = "scored" // BatchStateCancelled is the terminal state of a batch that was cancelled before completion. BatchStateCancelled BatchState = "cancelled" ) @@ -81,6 +83,10 @@ type Batch struct { // - queueA/batch/3 will contain queueA/batch/1 Dependencies []string + // Score is the predicted probability of build success for this batch, ranging from 0.0 to 1.0. + // Set during the scoring phase. Zero value means the batch has not been scored yet. + Score float64 + // The state of the batch lifecycle this batch is in. Updateable field with Version for optimistic locking. State BatchState diff --git a/entity/request_log.go b/entity/request_log.go index 7d85ee9f..35569778 100644 --- a/entity/request_log.go +++ b/entity/request_log.go @@ -49,6 +49,9 @@ const ( // RequestStatusBatched indicates that the request has been included in a new batch and will be sent to speculation. RequestStatusBatched RequestStatus = "batched" + // RequestStatusScored indicates that the batch containing the request has been scored for build success probability. + RequestStatusScored RequestStatus = "scored" + // RequestStatusSpeculating indicates that the request is currently being speculated (e.g., speculative merge/rebase, etc.). RequestStatusSpeculating RequestStatus = "speculating" diff --git a/example/server/orchestrator/BUILD.bazel b/example/server/orchestrator/BUILD.bazel index 4a89170f..26c5316d 100644 --- a/example/server/orchestrator/BUILD.bazel +++ b/example/server/orchestrator/BUILD.bazel @@ -12,12 +12,14 @@ go_library( visibility = ["//visibility:private"], deps = [ "//core/consumer", + "//entity", "//extension/counter", "//extension/counter/mysql", "//extension/mergechecker", "//extension/mergechecker/github", "//extension/queue", "//extension/queue/mysql", + "//extension/scorer/heuristic", "//extension/storage", "//extension/storage/mysql", "//orchestrator/controller", diff --git a/example/server/orchestrator/main.go b/example/server/orchestrator/main.go index 25f13364..d7daae4a 100644 --- a/example/server/orchestrator/main.go +++ b/example/server/orchestrator/main.go @@ -30,12 +30,14 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/uber-go/tally/v4" "github.com/uber/submitqueue/core/consumer" + "github.com/uber/submitqueue/entity" "github.com/uber/submitqueue/extension/counter" mysqlcounter "github.com/uber/submitqueue/extension/counter/mysql" "github.com/uber/submitqueue/extension/mergechecker" githubchecker "github.com/uber/submitqueue/extension/mergechecker/github" extqueue "github.com/uber/submitqueue/extension/queue" queueMySQL "github.com/uber/submitqueue/extension/queue/mysql" + "github.com/uber/submitqueue/extension/scorer/heuristic" "github.com/uber/submitqueue/extension/storage" mysqlstorage "github.com/uber/submitqueue/extension/storage/mysql" "github.com/uber/submitqueue/orchestrator/controller" @@ -417,6 +419,19 @@ func registerControllers(c consumer.Consumer, logger *zap.SugaredLogger, scope t logger, scope, store, + // TODO: replace with a real scorer + heuristic.New( + []heuristic.Bucket{ + {Min: 0, Max: 1, Score: 0.95}, + {Min: 2, Max: 5, Score: 0.80}, + {Min: 6, Max: 20, Score: 0.60}, + {Min: 21, Max: 1<<31 - 1, Score: 0.40}, + }, + func(_ context.Context, change entity.Change) (int, error) { + return len(change.URIs), nil + }, + scope.SubScope("scorer"), + ), registry, consumer.TopicKeyScore, "orchestrator-score", diff --git a/extension/storage/batch_store.go b/extension/storage/batch_store.go index 92e65b1a..412a90cb 100644 --- a/extension/storage/batch_store.go +++ b/extension/storage/batch_store.go @@ -35,6 +35,10 @@ type BatchStore interface { // The implementation should increment the version by 1 atomically with the state update. UpdateState(ctx context.Context, id string, version int32, newState entity.BatchState) error + // UpdateScoreAndState atomically updates the score and state of a batch if the current version matches the expected version. + // If versions do not match, returns ErrVersionMismatch. The implementation should increment the version by 1 atomically. + UpdateScoreAndState(ctx context.Context, id string, version int32, score float64, newState entity.BatchState) error + // GetByQueueAndStates retrieves all batches that belong to the given queue and are in the given states. GetByQueueAndStates(ctx context.Context, queue string, states []entity.BatchState) ([]entity.Batch, error) } diff --git a/extension/storage/mock/batch_store_mock.go b/extension/storage/mock/batch_store_mock.go index f026713f..4b79c0d5 100644 --- a/extension/storage/mock/batch_store_mock.go +++ b/extension/storage/mock/batch_store_mock.go @@ -85,6 +85,20 @@ func (mr *MockBatchStoreMockRecorder) GetByQueueAndStates(ctx, queue, states any return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByQueueAndStates", reflect.TypeOf((*MockBatchStore)(nil).GetByQueueAndStates), ctx, queue, states) } +// UpdateScoreAndState mocks base method. +func (m *MockBatchStore) UpdateScoreAndState(ctx context.Context, id string, version int32, score float64, newState entity.BatchState) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateScoreAndState", ctx, id, version, score, newState) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateScoreAndState indicates an expected call of UpdateScoreAndState. +func (mr *MockBatchStoreMockRecorder) UpdateScoreAndState(ctx, id, version, score, newState any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateScoreAndState", reflect.TypeOf((*MockBatchStore)(nil).UpdateScoreAndState), ctx, id, version, score, newState) +} + // UpdateState mocks base method. func (m *MockBatchStore) UpdateState(ctx context.Context, id string, version int32, newState entity.BatchState) error { m.ctrl.T.Helper() diff --git a/extension/storage/mysql/batch_store.go b/extension/storage/mysql/batch_store.go index 7cad69ca..b14b4fe7 100644 --- a/extension/storage/mysql/batch_store.go +++ b/extension/storage/mysql/batch_store.go @@ -50,9 +50,9 @@ func (s *batchStore) Get(ctx context.Context, id string) (ret entity.Batch, retE var dependenciesJSON []byte err := s.db.QueryRowContext(ctx, - "SELECT id, queue, contains, dependencies, state, version FROM batch WHERE id = ?", + "SELECT id, queue, contains, dependencies, score, state, version FROM batch WHERE id = ?", id, - ).Scan(&batch.ID, &batch.Queue, &containsJSON, &dependenciesJSON, &batch.State, &batch.Version) + ).Scan(&batch.ID, &batch.Queue, &containsJSON, &dependenciesJSON, &batch.Score, &batch.State, &batch.Version) if errors.Is(err, sql.ErrNoRows) { return entity.Batch{}, storage.WrapNotFound(err) @@ -88,8 +88,8 @@ func (s *batchStore) Create(ctx context.Context, batch entity.Batch) (retErr err } _, err = s.db.ExecContext(ctx, - "INSERT INTO batch (id, queue, contains, dependencies, state, version) VALUES (?, ?, ?, ?, ?, ?)", - batch.ID, batch.Queue, containsJSON, dependenciesJSON, batch.State, batch.Version, + "INSERT INTO batch (id, queue, contains, dependencies, score, state, version) VALUES (?, ?, ?, ?, ?, ?, ?)", + batch.ID, batch.Queue, containsJSON, dependenciesJSON, batch.Score, batch.State, batch.Version, ) if err != nil { var mysqlErr *mysql.MySQLError @@ -137,6 +137,41 @@ func (s *batchStore) UpdateState(ctx context.Context, id string, version int32, return nil } +// UpdateScoreAndState atomically updates the score and state of a batch if the current version matches the expected version. +// If versions do not match, returns ErrVersionMismatch. The implementation increments the version by 1 atomically. +func (s *batchStore) UpdateScoreAndState(ctx context.Context, id string, version int32, score float64, newState entity.BatchState) (retErr error) { + op := metrics.Begin(s.scope, "update_score_and_state") + defer func() { op.Complete(retErr) }() + + result, err := s.db.ExecContext(ctx, + "UPDATE batch SET score = ?, state = ?, version = version + 1 WHERE id = ? AND version = ?", + score, newState, id, version, + ) + if err != nil { + return fmt.Errorf( + "failed to update batch score and state for id=%q version=%d score=%f newState=%v: %w", + id, version, score, newState, err, + ) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf( + "failed to get rows affected from update score and state for id=%q version=%d score=%f newState=%v: %w", + id, version, score, newState, err, + ) + } + + if rowsAffected != 1 { + return fmt.Errorf( + "version mismatch for batch update score and state: id=%q expected_version=%d score=%f newState=%v: %w", + id, version, score, newState, storage.ErrVersionMismatch, + ) + } + + return nil +} + // GetByQueueAndStates retrieves all batches that belong to the given queue and are in the given states. func (s *batchStore) GetByQueueAndStates(ctx context.Context, queue string, states []entity.BatchState) (ret []entity.Batch, retErr error) { op := metrics.Begin(s.scope, "get_by_queue_and_states") @@ -146,7 +181,7 @@ func (s *batchStore) GetByQueueAndStates(ctx context.Context, queue string, stat return nil, nil } - query := "SELECT id, queue, contains, dependencies, state, version FROM batch WHERE queue = ? AND state IN (?" + strings.Repeat(", ?", len(states)-1) + ")" + query := "SELECT id, queue, contains, dependencies, score, state, version FROM batch WHERE queue = ? AND state IN (?" + strings.Repeat(", ?", len(states)-1) + ")" args := make([]any, 1+len(states)) args[0] = queue @@ -166,7 +201,7 @@ func (s *batchStore) GetByQueueAndStates(ctx context.Context, queue string, stat var containsJSON []byte var dependenciesJSON []byte - if err := rows.Scan(&batch.ID, &batch.Queue, &containsJSON, &dependenciesJSON, &batch.State, &batch.Version); err != nil { + if err := rows.Scan(&batch.ID, &batch.Queue, &containsJSON, &dependenciesJSON, &batch.Score, &batch.State, &batch.Version); err != nil { return nil, fmt.Errorf("failed to scan batch entity by queue=%q states=%v from the database: %w", queue, states, err) } diff --git a/extension/storage/mysql/schema/batch.sql b/extension/storage/mysql/schema/batch.sql index 398d17ca..8e7deda0 100644 --- a/extension/storage/mysql/schema/batch.sql +++ b/extension/storage/mysql/schema/batch.sql @@ -3,6 +3,7 @@ CREATE TABLE IF NOT EXISTS batch ( queue VARCHAR(255) NOT NULL, contains JSON NOT NULL, dependencies JSON NOT NULL, + score DOUBLE NOT NULL, state VARCHAR(255) NOT NUll, version INT NOT NULL, PRIMARY KEY (id), diff --git a/orchestrator/controller/score/BUILD.bazel b/orchestrator/controller/score/BUILD.bazel index d7fb2254..26938183 100644 --- a/orchestrator/controller/score/BUILD.bazel +++ b/orchestrator/controller/score/BUILD.bazel @@ -7,8 +7,10 @@ go_library( visibility = ["//visibility:public"], deps = [ "//core/consumer", + "//core/request", "//entity", "//entity/queue", + "//extension/scorer", "//extension/storage", "@com_github_uber_go_tally_v4//:tally", "@org_uber_go_zap//:zap", @@ -25,6 +27,7 @@ go_test( "//entity", "//entity/queue", "//extension/queue/mock", + "//extension/scorer/mock", "//extension/storage/mock", "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//require", diff --git a/orchestrator/controller/score/score.go b/orchestrator/controller/score/score.go index 4a916caf..f3f3b687 100644 --- a/orchestrator/controller/score/score.go +++ b/orchestrator/controller/score/score.go @@ -20,19 +20,23 @@ import ( "github.com/uber-go/tally/v4" "github.com/uber/submitqueue/core/consumer" + corerequest "github.com/uber/submitqueue/core/request" "github.com/uber/submitqueue/entity" entityqueue "github.com/uber/submitqueue/entity/queue" + "github.com/uber/submitqueue/extension/scorer" "github.com/uber/submitqueue/extension/storage" "go.uber.org/zap" ) // Controller handles score queue messages. -// It consumes batches, scores them, and publishes to the speculate stage. +// It consumes batches, scores them using the provided scorer, persists the score, +// and publishes to the speculate stage. // Implements consumer.Controller interface for integration with the consumer. type Controller struct { logger *zap.SugaredLogger metricsScope tally.Scope store storage.Storage + scorer scorer.Scorer registry consumer.TopicRegistry topicKey consumer.TopicKey consumerGroup string @@ -46,6 +50,7 @@ func NewController( logger *zap.SugaredLogger, scope tally.Scope, store storage.Storage, + scorer scorer.Scorer, registry consumer.TopicRegistry, topicKey consumer.TopicKey, consumerGroup string, @@ -54,6 +59,7 @@ func NewController( logger: logger.Named("score_controller"), metricsScope: scope.SubScope("score_controller"), store: store, + scorer: scorer, registry: registry, topicKey: topicKey, consumerGroup: consumerGroup, @@ -61,7 +67,9 @@ func NewController( } // Process processes a score delivery from the queue. -// Deserializes the batch, scores it, and publishes to the speculate topic. +// Deserializes the batch, scores each request's change using the scorer, +// persists the minimum score, publishes request log entries, +// and publishes to the speculate topic. // Returns nil to ack (success), or error to nack (retry). func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) error { c.metricsScope.Counter("received").Inc(1) @@ -91,9 +99,32 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) er "partition_key", msg.PartitionKey, ) - // TODO: Add scoring logic - // - Evaluate batch priority - // - Apply scoring heuristics + // Score each request's change and take the minimum (worst-case) as the batch score + batchScore, err := c.scoreBatch(ctx, batch) + if err != nil { + c.metricsScope.Counter("scorer_errors").Inc(1) + return fmt.Errorf("failed to score batch %s: %w", batch.ID, err) + } + + // Atomically update score and state to "scored" in the database + if err := c.store.GetBatchStore().UpdateScoreAndState(ctx, batch.ID, batch.Version, batchScore, entity.BatchStateScored); err != nil { + c.metricsScope.Counter("storage_errors").Inc(1) + return fmt.Errorf("failed to update score for batch %s: %w", batch.ID, err) + } + + c.logger.Infow("scored batch", + "batch_id", batch.ID, + "score", batchScore, + ) + + // Publish request log entries for all requests in the batch + if err := corerequest.PublishBatchLogs(ctx, c.registry, batch.Contains, entity.RequestStatusScored, map[string]string{ + "batch_id": batch.ID, + "score": fmt.Sprintf("%.4f", batchScore), + }); err != nil { + c.metricsScope.Counter("request_log_errors").Inc(1) + return fmt.Errorf("failed to publish request logs for batch %s: %w", batch.ID, err) + } // Publish to speculate topic if err := c.publish(ctx, consumer.TopicKeySpeculate, batch.ID, batch.Queue); err != nil { @@ -111,6 +142,25 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) er return nil // Success - message will be acked } +// scoreBatch scores each request's change in the batch and returns the combined probability. +// Uses multiplicative probability: if any single request fails, the entire batch fails, +// so the batch score is the product of individual request scores. +func (c *Controller) scoreBatch(ctx context.Context, batch entity.Batch) (float64, error) { + score := 1.0 + for _, requestID := range batch.Contains { + request, err := c.store.GetRequestStore().Get(ctx, requestID) + if err != nil { + return 0, fmt.Errorf("failed to get request %s: %w", requestID, err) + } + s, err := c.scorer.Score(ctx, request.Change) + if err != nil { + return 0, fmt.Errorf("failed to score request %s: %w", requestID, err) + } + score *= s + } + return score, nil +} + // publish publishes a batch ID to the specified topic key. func (c *Controller) publish(ctx context.Context, key consumer.TopicKey, batchID string, partitionKey string) error { bid := entity.BatchID{ID: batchID} diff --git a/orchestrator/controller/score/score_test.go b/orchestrator/controller/score/score_test.go index 763dc46e..6face5d7 100644 --- a/orchestrator/controller/score/score_test.go +++ b/orchestrator/controller/score/score_test.go @@ -27,6 +27,7 @@ import ( "github.com/uber/submitqueue/entity" "github.com/uber/submitqueue/entity/queue" queuemock "github.com/uber/submitqueue/extension/queue/mock" + scorermock "github.com/uber/submitqueue/extension/scorer/mock" storagemock "github.com/uber/submitqueue/extension/storage/mock" "go.uber.org/mock/gomock" "go.uber.org/zap/zaptest" @@ -42,25 +43,44 @@ func batchIDPayload(t *testing.T, id string) []byte { // testBatch returns a standard test batch for score tests. func testBatch() entity.Batch { return entity.Batch{ - ID: "test-queue/batch/1", - Queue: "test-queue", - State: entity.BatchStateCreated, + ID: "test-queue/batch/1", + Queue: "test-queue", + Contains: []string{"test-queue/1"}, + State: entity.BatchStateCreated, + Version: 1, + } +} + +// testRequest returns a standard test request for score tests. +func testRequest() entity.Request { + return entity.Request{ + ID: "test-queue/1", + Queue: "test-queue", + Change: entity.Change{ + URIs: []string{"github://uber/repo/pull/1/abc123"}, + }, + State: entity.RequestStateStarted, Version: 1, } } -// newMockStorage creates a MockStorage with a MockBatchStore that returns the given batch on Get. -func newMockStorage(ctrl *gomock.Controller, batch entity.Batch) *storagemock.MockStorage { +// newMockStorage creates a MockStorage with a MockBatchStore and MockRequestStore. +func newMockStorage(ctrl *gomock.Controller, batch entity.Batch, request entity.Request) *storagemock.MockStorage { mockBatchStore := storagemock.NewMockBatchStore(ctrl) mockBatchStore.EXPECT().Get(gomock.Any(), batch.ID).Return(batch, nil).AnyTimes() + mockBatchStore.EXPECT().UpdateScoreAndState(gomock.Any(), batch.ID, batch.Version, gomock.Any(), entity.BatchStateScored).Return(nil).AnyTimes() + + mockRequestStore := storagemock.NewMockRequestStore(ctrl) + mockRequestStore.EXPECT().Get(gomock.Any(), request.ID).Return(request, nil).AnyTimes() store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(mockBatchStore).AnyTimes() + store.EXPECT().GetRequestStore().Return(mockRequestStore).AnyTimes() return store } // newTestController creates a controller with test dependencies. -func newTestController(t *testing.T, ctrl *gomock.Controller, store *storagemock.MockStorage, publishErr error) *Controller { +func newTestController(t *testing.T, ctrl *gomock.Controller, store *storagemock.MockStorage, scorer *scorermock.MockScorer, publishErr error) *Controller { logger := zaptest.NewLogger(t).Sugar() scope := tally.NoopScope @@ -75,18 +95,23 @@ func newTestController(t *testing.T, ctrl *gomock.Controller, store *storagemock mockQ.EXPECT().Publisher().Return(mockPub).AnyTimes() registry, err := consumer.NewTopicRegistry( - []consumer.TopicConfig{{Key: consumer.TopicKeySpeculate, Name: "speculate", Queue: mockQ}}, + []consumer.TopicConfig{ + {Key: consumer.TopicKeySpeculate, Name: "speculate", Queue: mockQ}, + {Key: consumer.TopicKeyLog, Name: "log", Queue: mockQ}, + }, ) require.NoError(t, err) - return NewController(logger, scope, store, registry, consumer.TopicKeyScore, "orchestrator-score") + return NewController(logger, scope, store, scorer, registry, consumer.TopicKeyScore, "orchestrator-score") } func TestNewController(t *testing.T) { ctrl := gomock.NewController(t) batch := testBatch() - store := newMockStorage(ctrl, batch) - controller := newTestController(t, ctrl, store, nil) + request := testRequest() + store := newMockStorage(ctrl, batch, request) + mockScorer := scorermock.NewMockScorer(ctrl) + controller := newTestController(t, ctrl, store, mockScorer, nil) require.NotNil(t, controller) assert.Equal(t, consumer.TopicKeyScore, controller.TopicKey()) @@ -98,8 +123,67 @@ func TestController_Process_Success(t *testing.T) { ctrl := gomock.NewController(t) batch := testBatch() - store := newMockStorage(ctrl, batch) - controller := newTestController(t, ctrl, store, nil) + request := testRequest() + store := newMockStorage(ctrl, batch, request) + + mockScorer := scorermock.NewMockScorer(ctrl) + mockScorer.EXPECT().Score(gomock.Any(), request.Change).Return(0.85, nil) + + controller := newTestController(t, ctrl, store, mockScorer, nil) + + msg := queue.NewMessage(batch.ID, batchIDPayload(t, batch.ID), batch.Queue, nil) + delivery := queuemock.NewMockDelivery(ctrl) + delivery.EXPECT().Message().Return(msg).AnyTimes() + delivery.EXPECT().Attempt().Return(1).AnyTimes() + + err := controller.Process(context.Background(), delivery) + require.NoError(t, err) +} + +func TestController_Process_MultipleRequests_MinScore(t *testing.T) { + ctrl := gomock.NewController(t) + + batch := entity.Batch{ + ID: "test-queue/batch/1", + Queue: "test-queue", + Contains: []string{"test-queue/1", "test-queue/2"}, + State: entity.BatchStateCreated, + Version: 1, + } + + request1 := entity.Request{ + ID: "test-queue/1", + Queue: "test-queue", + Change: entity.Change{URIs: []string{"github://uber/repo/pull/1/abc"}}, + State: entity.RequestStateStarted, + Version: 1, + } + request2 := entity.Request{ + ID: "test-queue/2", + Queue: "test-queue", + Change: entity.Change{URIs: []string{"github://uber/repo/pull/2/def"}}, + State: entity.RequestStateStarted, + Version: 1, + } + + mockBatchStore := storagemock.NewMockBatchStore(ctrl) + mockBatchStore.EXPECT().Get(gomock.Any(), batch.ID).Return(batch, nil) + // Expect the multiplicative score (0.9 * 0.6 = 0.54) to be persisted + mockBatchStore.EXPECT().UpdateScoreAndState(gomock.Any(), batch.ID, batch.Version, 0.54, entity.BatchStateScored).Return(nil) + + mockRequestStore := storagemock.NewMockRequestStore(ctrl) + mockRequestStore.EXPECT().Get(gomock.Any(), "test-queue/1").Return(request1, nil) + mockRequestStore.EXPECT().Get(gomock.Any(), "test-queue/2").Return(request2, nil) + + store := storagemock.NewMockStorage(ctrl) + store.EXPECT().GetBatchStore().Return(mockBatchStore).AnyTimes() + store.EXPECT().GetRequestStore().Return(mockRequestStore).AnyTimes() + + mockScorer := scorermock.NewMockScorer(ctrl) + mockScorer.EXPECT().Score(gomock.Any(), request1.Change).Return(0.9, nil) + mockScorer.EXPECT().Score(gomock.Any(), request2.Change).Return(0.6, nil) + + controller := newTestController(t, ctrl, store, mockScorer, nil) msg := queue.NewMessage(batch.ID, batchIDPayload(t, batch.ID), batch.Queue, nil) delivery := queuemock.NewMockDelivery(ctrl) @@ -118,7 +202,8 @@ func TestController_Process_StorageFailure(t *testing.T) { store := storagemock.NewMockStorage(ctrl) store.EXPECT().GetBatchStore().Return(mockBatchStore).AnyTimes() - controller := newTestController(t, ctrl, store, nil) + mockScorer := scorermock.NewMockScorer(ctrl) + controller := newTestController(t, ctrl, store, mockScorer, nil) msg := queue.NewMessage("test-queue/batch/1", batchIDPayload(t, "test-queue/batch/1"), "test-queue", nil) delivery := queuemock.NewMockDelivery(ctrl) @@ -130,12 +215,78 @@ func TestController_Process_StorageFailure(t *testing.T) { assert.False(t, errs.IsRetryable(err)) } +func TestController_Process_ScorerFailure(t *testing.T) { + ctrl := gomock.NewController(t) + + batch := testBatch() + request := testRequest() + + mockBatchStore := storagemock.NewMockBatchStore(ctrl) + mockBatchStore.EXPECT().Get(gomock.Any(), batch.ID).Return(batch, nil) + + mockRequestStore := storagemock.NewMockRequestStore(ctrl) + mockRequestStore.EXPECT().Get(gomock.Any(), request.ID).Return(request, nil) + + store := storagemock.NewMockStorage(ctrl) + store.EXPECT().GetBatchStore().Return(mockBatchStore).AnyTimes() + store.EXPECT().GetRequestStore().Return(mockRequestStore).AnyTimes() + + mockScorer := scorermock.NewMockScorer(ctrl) + mockScorer.EXPECT().Score(gomock.Any(), request.Change).Return(0.0, fmt.Errorf("no bucket matches value 99")) + + controller := newTestController(t, ctrl, store, mockScorer, nil) + + msg := queue.NewMessage(batch.ID, batchIDPayload(t, batch.ID), batch.Queue, nil) + delivery := queuemock.NewMockDelivery(ctrl) + delivery.EXPECT().Message().Return(msg).AnyTimes() + delivery.EXPECT().Attempt().Return(1).AnyTimes() + + err := controller.Process(context.Background(), delivery) + require.Error(t, err) +} + +func TestController_Process_UpdateScoreFailure(t *testing.T) { + ctrl := gomock.NewController(t) + + batch := testBatch() + request := testRequest() + + mockBatchStore := storagemock.NewMockBatchStore(ctrl) + mockBatchStore.EXPECT().Get(gomock.Any(), batch.ID).Return(batch, nil) + mockBatchStore.EXPECT().UpdateScoreAndState(gomock.Any(), batch.ID, batch.Version, gomock.Any(), entity.BatchStateScored).Return(fmt.Errorf("version mismatch")) + + mockRequestStore := storagemock.NewMockRequestStore(ctrl) + mockRequestStore.EXPECT().Get(gomock.Any(), request.ID).Return(request, nil) + + store := storagemock.NewMockStorage(ctrl) + store.EXPECT().GetBatchStore().Return(mockBatchStore).AnyTimes() + store.EXPECT().GetRequestStore().Return(mockRequestStore).AnyTimes() + + mockScorer := scorermock.NewMockScorer(ctrl) + mockScorer.EXPECT().Score(gomock.Any(), request.Change).Return(0.85, nil) + + controller := newTestController(t, ctrl, store, mockScorer, nil) + + msg := queue.NewMessage(batch.ID, batchIDPayload(t, batch.ID), batch.Queue, nil) + delivery := queuemock.NewMockDelivery(ctrl) + delivery.EXPECT().Message().Return(msg).AnyTimes() + delivery.EXPECT().Attempt().Return(1).AnyTimes() + + err := controller.Process(context.Background(), delivery) + require.Error(t, err) +} + func TestController_Process_PublishFailure(t *testing.T) { ctrl := gomock.NewController(t) batch := testBatch() - store := newMockStorage(ctrl, batch) - controller := newTestController(t, ctrl, store, fmt.Errorf("publish failed")) + request := testRequest() + store := newMockStorage(ctrl, batch, request) + + mockScorer := scorermock.NewMockScorer(ctrl) + mockScorer.EXPECT().Score(gomock.Any(), request.Change).Return(0.85, nil) + + controller := newTestController(t, ctrl, store, mockScorer, fmt.Errorf("publish failed")) msg := queue.NewMessage(batch.ID, batchIDPayload(t, batch.ID), batch.Queue, nil) delivery := queuemock.NewMockDelivery(ctrl) @@ -149,8 +300,10 @@ func TestController_Process_PublishFailure(t *testing.T) { func TestController_InterfaceImplementation(t *testing.T) { ctrl := gomock.NewController(t) batch := testBatch() - store := newMockStorage(ctrl, batch) - controller := newTestController(t, ctrl, store, nil) + request := testRequest() + store := newMockStorage(ctrl, batch, request) + mockScorer := scorermock.NewMockScorer(ctrl) + controller := newTestController(t, ctrl, store, mockScorer, nil) var _ consumer.Controller = controller } diff --git a/orchestrator/controller/start/BUILD.bazel b/orchestrator/controller/start/BUILD.bazel index 416ec03a..df436444 100644 --- a/orchestrator/controller/start/BUILD.bazel +++ b/orchestrator/controller/start/BUILD.bazel @@ -7,6 +7,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//core/consumer", + "//core/request", "//entity", "//entity/queue", "//extension/storage", diff --git a/orchestrator/controller/start/start.go b/orchestrator/controller/start/start.go index be8ddf97..643f6ef1 100644 --- a/orchestrator/controller/start/start.go +++ b/orchestrator/controller/start/start.go @@ -21,6 +21,7 @@ import ( "github.com/uber-go/tally/v4" "github.com/uber/submitqueue/core/consumer" + corerequest "github.com/uber/submitqueue/core/request" "github.com/uber/submitqueue/entity" entityqueue "github.com/uber/submitqueue/entity/queue" "github.com/uber/submitqueue/extension/storage" @@ -109,7 +110,7 @@ func (c *Controller) Process(ctx context.Context, delivery consumer.Delivery) er logEntry := entity.NewRequestLog(request.ID, entity.RequestStatusStarted, request.Version, "", nil) // Using request.ID as the partition key to ensure ordering of log entries for the same request // and parallel processing of log entries for different requests. - if err := c.publishLog(ctx, logEntry, request.ID); err != nil { + if err := corerequest.PublishLog(ctx, c.registry, logEntry, request.ID); err != nil { c.metricsScope.Counter("request_log_errors").Inc(1) return fmt.Errorf("failed to publish request log: %w", err) } @@ -157,32 +158,6 @@ func (c *Controller) publish(ctx context.Context, key consumer.TopicKey, request return nil } -// publishLog publishes a request log entry to the log topic for async persistence. -func (c *Controller) publishLog(ctx context.Context, logEntry entity.RequestLog, partitionKey string) error { - payload, err := logEntry.ToBytes() - if err != nil { - return fmt.Errorf("failed to serialize request log: %w", err) - } - - msg := entityqueue.NewMessage(logEntry.RequestID, payload, partitionKey, nil) - - q, ok := c.registry.Queue(consumer.TopicKeyLog) - if !ok { - return fmt.Errorf("no queue registered for topic key %s", consumer.TopicKeyLog) - } - - topicName, ok := c.registry.TopicName(consumer.TopicKeyLog) - if !ok { - return fmt.Errorf("no topic name registered for topic key %s", consumer.TopicKeyLog) - } - - if err := q.Publisher().Publish(ctx, topicName, msg); err != nil { - return fmt.Errorf("failed to publish message: %w", err) - } - - return nil -} - // Name returns the controller name for logging and metrics. func (c *Controller) Name() string { return "start"