diff --git a/core/services/ocr2/plugins/vault/plugin.go b/core/services/ocr2/plugins/vault/plugin.go index c8f7f8c0dbb..18a0113b9e7 100644 --- a/core/services/ocr2/plugins/vault/plugin.go +++ b/core/services/ocr2/plugins/vault/plugin.go @@ -43,6 +43,11 @@ import ( "github.com/smartcontractkit/chainlink/v2/core/logger" ) +const ( + blobBroadcastTimeout = 2 * time.Second + maxConcurrentBlobBroadcasts = 10 +) + var ( isValidIDComponent = regexp.MustCompile(`^[a-zA-Z0-9_]+$`).MatchString ) @@ -559,11 +564,12 @@ func (r *ReportingPlugin) Observation(ctx context.Context, seqNr uint64, aq type // broadcastBlobPayloads broadcasts each payload as a blob in parallel to reduce // Observation() latency (shortening this phase helps the OCR round finish within // DeltaProgress). Each call is given a 2-second timeout so that a single slow -// broadcast cannot stall the entire batch. Individual broadcast failures are logged -// and skipped rather than aborting the entire observation, so that one problematic -// payload does not prevent the remaining items from being observed. Context -// cancellation/deadline errors on the parent context are propagated immediately so -// that expired rounds fail fast. +// broadcast cannot stall the entire batch. No more than 10 broadcasts are allowed +// in flight at a time. Individual broadcast failures are logged and skipped rather +// than aborting the entire observation, so that one problematic payload does not +// prevent the remaining items from being observed. Context cancellation/deadline +// errors on the parent context are propagated immediately so that expired rounds +// fail fast. func (r *ReportingPlugin) broadcastBlobPayloads( ctx context.Context, fetcher ocr3_1types.BlobBroadcastFetcher, @@ -578,11 +584,12 @@ func (r *ReportingPlugin) broadcastBlobPayloads( r.lggr.Debugw("observation blob broadcast finished", "seqNr", seqNr, "blobCount", len(payloads), "elapsed", time.Since(start)) }() - const perBlobTimeout = 2 * time.Second var g errgroup.Group + g.SetLimit(maxConcurrentBlobBroadcasts) for i, payload := range payloads { + requestID := requestIDs[i] g.Go(func() error { - broadcastCtx, cancel := context.WithTimeout(ctx, perBlobTimeout) + broadcastCtx, cancel := context.WithTimeout(ctx, blobBroadcastTimeout) defer cancel() blobHandle, err := fetcher.BroadcastBlob(broadcastCtx, payload, ocr3_1types.BlobExpirationHintSequenceNumber{SeqNr: seqNr + 2}) @@ -592,7 +599,7 @@ func (r *ReportingPlugin) broadcastBlobPayloads( } r.lggr.Warnw("failed to broadcast pending queue item as blob, skipping", "seqNr", seqNr, - "requestID", requestIDs[i], + "requestID", requestID, "err", err) return nil } @@ -601,7 +608,7 @@ func (r *ReportingPlugin) broadcastBlobPayloads( if err != nil { r.lggr.Warnw("failed to marshal blob handle, skipping", "seqNr", seqNr, - "requestID", requestIDs[i], + "requestID", requestID, "err", err) return nil } diff --git a/core/services/ocr2/plugins/vault/plugin_test.go b/core/services/ocr2/plugins/vault/plugin_test.go index 9a5b37d4cd3..d0794daf366 100644 --- a/core/services/ocr2/plugins/vault/plugin_test.go +++ b/core/services/ocr2/plugins/vault/plugin_test.go @@ -7860,6 +7860,89 @@ func TestPlugin_broadcastBlobPayloads(t *testing.T) { } }) + t.Run("does not exceed max concurrent broadcasts", func(t *testing.T) { + lggr := logger.TestLogger(t) + r := &ReportingPlugin{ + lggr: lggr, + metrics: newTestMetrics(t), + marshalBlob: func(ocr3_1types.BlobHandle) ([]byte, error) { + return []byte("handle"), nil + }, + } + + payloads := make([][]byte, maxConcurrentBlobBroadcasts*2+1) + ids := make([]string, len(payloads)) + for i := range payloads { + payloads[i] = []byte(fmt.Sprintf("payload-%d", i)) + ids[i] = fmt.Sprintf("req-%d", i) + } + + var active atomic.Int32 + var maxActive atomic.Int32 + started := make(chan struct{}, len(payloads)) + release := make(chan struct{}) + released := atomic.Bool{} + releaseBroadcasts := func() { + if released.CompareAndSwap(false, true) { + close(release) + } + } + defer releaseBroadcasts() + + fetcher := &ctxCallbackBlobFetcher{fn: func(ctx context.Context, _ []byte) error { + current := active.Add(1) + defer active.Add(-1) + + for { + maxSeen := maxActive.Load() + if current <= maxSeen || maxActive.CompareAndSwap(maxSeen, current) { + break + } + } + + started <- struct{}{} + select { + case <-release: + return nil + case <-ctx.Done(): + return ctx.Err() + } + }} + + type broadcastResult struct { + payloads [][]byte + err error + } + done := make(chan broadcastResult, 1) + go func() { + result, err := r.broadcastBlobPayloads(t.Context(), fetcher, 1, payloads, ids) + done <- broadcastResult{payloads: result, err: err} + }() + + for i := 0; i < maxConcurrentBlobBroadcasts; i++ { + select { + case <-started: + case <-time.After(time.Second): + t.Fatalf("timed out waiting for broadcast %d to start", i+1) + } + } + + assert.Never(t, func() bool { + return maxActive.Load() > int32(maxConcurrentBlobBroadcasts) + }, 100*time.Millisecond, 10*time.Millisecond) + + releaseBroadcasts() + + select { + case result := <-done: + require.NoError(t, result.err) + assert.Len(t, result.payloads, len(payloads)) + case <-time.After(time.Second): + t.Fatal("timed out waiting for broadcasts to complete") + } + assert.LessOrEqual(t, maxActive.Load(), int32(maxConcurrentBlobBroadcasts)) + }) + t.Run("failed broadcast is skipped and logged", func(t *testing.T) { lggr, observed := logger.TestLoggerObserved(t, zapcore.WarnLevel) r := &ReportingPlugin{