From 1e5189db851ac5637732ba064025a71c1491e974 Mon Sep 17 00:00:00 2001 From: Peter Bwire Date: Mon, 20 Apr 2026 16:45:03 +0300 Subject: [PATCH 1/2] fix: race-free subscriber/service lifecycle and xid-derived CreatedAt Four changes land together to unblock running tests under -race across frame and its consumers: - queue/subscriber: guard s.subscription with a mutex so Stop() called from listen() on context cancel and from queueManager.Close() no longer race on the nil-out. Promote s.state to atomic.Int32 (and SubscriberState to int32) so State()/IsIdle() stop racing Receive(). - service: guard Service.driver with a mutex. Stop() calling Shutdown() while initializeServerDrivers()/startServerDriver() is still running no longer races on the field. - data.BaseModel: derive CreatedAt (and ModifiedAt on first insert) from the xid embedded timestamp so sort-by-id == sort-by-created_at. Reject caller-supplied non-xid IDs explicitly rather than falling back to time.Now(); hypertable composite PKs rely on that invariant. - datastore.BaseRepository: BulkCreate retains bare ON CONFLICT DO NOTHING as the portable default, and gains WithBulkCreateConflictColumns(...) so callers can target a specific unique index - e.g. (id, created_at) for hypertable-promoted tables or a domain-unique column for upsert flows. --- data/model.go | 26 ++++++-- datastore/repository.go | 65 +++++++++++++++---- datastore/repository_test.go | 19 +++++- queue/interface.go | 2 +- queue/subscriber.go | 121 ++++++++++++++++++++++++----------- service.go | 35 ++++++++-- 6 files changed, 204 insertions(+), 64 deletions(-) diff --git a/data/model.go b/data/model.go index db5042f9..baf3e211 100644 --- a/data/model.go +++ b/data/model.go @@ -2,6 +2,7 @@ package data import ( "context" + "fmt" "time" "github.com/pitabwire/util" @@ -98,16 +99,33 @@ func (model *BaseModel) BeforeSave(db *gorm.DB) error { } func (model *BaseModel) BeforeCreate(db *gorm.DB) error { + // Generate the ID first so CreatedAt can be derived from its embedded + // xid timestamp, keeping sort-by-id ≡ sort-by-created_at. + model.GenID(db.Statement.Context) + if model.Version <= 0 { - model.CreatedAt = time.Now() - model.ModifiedAt = time.Now() + created, err := createdAtFromID(model.ID) + if err != nil { + return err + } + model.CreatedAt = created + model.ModifiedAt = created model.Version = 1 } - - model.GenID(db.Statement.Context) return nil } +// createdAtFromID returns the time component embedded in a generated xid. +// All BaseModel IDs must be valid xids so sort-by-id ≡ sort-by-created_at +// and hypertable promotions retain monotonic time ordering. +func createdAtFromID(id string) (time.Time, error) { + parsed, err := xid.FromString(id) + if err != nil { + return time.Time{}, fmt.Errorf("BaseModel.ID %q is not a valid xid: %w", id, err) + } + return parsed.Time(), nil +} + // BeforeUpdate Updates time stamp every time we update status of a migration. func (model *BaseModel) BeforeUpdate(db *gorm.DB) error { model.ModifiedAt = time.Now() diff --git a/datastore/repository.go b/datastore/repository.go index dba089e6..a9a9c4d1 100644 --- a/datastore/repository.go +++ b/datastore/repository.go @@ -41,6 +41,30 @@ type BaseRepository[T any] interface { DeleteBatch(ctx context.Context, ids []string) error } +// BaseRepositoryOption configures optional behaviour on NewBaseRepository. +type BaseRepositoryOption func(*baseRepositoryConfig) + +type baseRepositoryConfig struct { + // bulkCreateConflictColumns is the explicit ON CONFLICT target used by + // BulkCreate. Empty means "bare ON CONFLICT DO NOTHING" — portable across + // plain-PK and composite-PK (hypertable) tables. Set this when the caller + // depends on conflict inference against a specific unique index, e.g. + // (id, created_at) for hypertables or a domain-unique column. + bulkCreateConflictColumns []clause.Column +} + +// WithBulkCreateConflictColumns sets the ON CONFLICT target columns for +// BulkCreate. Columns must match an existing unique index on the target +// table, otherwise Postgres rejects the insert. +func WithBulkCreateConflictColumns(columns ...string) BaseRepositoryOption { + return func(c *baseRepositoryConfig) { + c.bulkCreateConflictColumns = c.bulkCreateConflictColumns[:0] + for _, col := range columns { + c.bulkCreateConflictColumns = append(c.bulkCreateConflictColumns, clause.Column{Name: col}) + } + } +} + // baseRepository is the concrete implementation of BaseRepository. type baseRepository[T data.BaseModelI] struct { dbPool pool.Pool @@ -55,6 +79,10 @@ type baseRepository[T data.BaseModelI] struct { // allowedFields whitelist for safe column access (set during initialization) allowedFields map[string]struct{} + + // bulkCreateConflictColumns is the optional explicit ON CONFLICT target + // for BulkCreate. See WithBulkCreateConflictColumns. + bulkCreateConflictColumns []clause.Column } // NewBaseRepository creates a new base repository instance. @@ -64,14 +92,21 @@ func NewBaseRepository[T data.BaseModelI]( dbPool pool.Pool, workMan workerpool.Manager, modelFactory func() T, + opts ...BaseRepositoryOption, ) BaseRepository[T] { + cfg := &baseRepositoryConfig{} + for _, opt := range opts { + opt(cfg) + } + repo := &baseRepository[T]{ - dbPool: dbPool, - workMan: workMan, - modelFactory: modelFactory, - batchSize: 751, //nolint:mnd // default batch size - immutableFields: []string{"id", "created_at", "tenant_id", "partition_id"}, - allowedFields: make(map[string]struct{}), + dbPool: dbPool, + workMan: workMan, + modelFactory: modelFactory, + batchSize: 751, //nolint:mnd // default batch size + immutableFields: []string{"id", "created_at", "tenant_id", "partition_id"}, + allowedFields: make(map[string]struct{}), + bulkCreateConflictColumns: cfg.bulkCreateConflictColumns, } db := dbPool.DB(ctx, true) @@ -176,16 +211,24 @@ func (br *baseRepository[T]) BatchSize() int { } // BulkCreate inserts multiple entities efficiently in a single transaction. +// The conflict target defaults to bare ON CONFLICT DO NOTHING so the same +// repository works for plain-PK and composite-PK (hypertable) tables. +// Callers that need inference against a specific unique index set it via +// WithBulkCreateConflictColumns at construction. func (br *baseRepository[T]) BulkCreate(ctx context.Context, entities []T) error { if len(entities) == 0 { return nil } - // CreateInBatches uses GORM's batch insert which is more efficient - // The batch size is configured in pool options (InsertBatchSize) - return br.Pool().DB(ctx, false).Clauses(clause.OnConflict{ - DoNothing: true, - }).CreateInBatches(entities, br.BatchSize()).Error + onConflict := clause.OnConflict{DoNothing: true} + if len(br.bulkCreateConflictColumns) > 0 { + onConflict.Columns = br.bulkCreateConflictColumns + } + + // CreateInBatches uses GORM's batch insert which is more efficient; + // the batch size is configured in pool options (InsertBatchSize). + return br.Pool().DB(ctx, false).Clauses(onConflict). + CreateInBatches(entities, br.BatchSize()).Error } // validateAffectedColumns checks if all columns are valid and allowed. diff --git a/datastore/repository_test.go b/datastore/repository_test.go index 83a34c27..4e32061f 100644 --- a/datastore/repository_test.go +++ b/datastore/repository_test.go @@ -367,17 +367,30 @@ func (s *RepositoryTestSuite) TestCreate() { expectError: false, }, { - name: "create entity with pre-set ID", + name: "create entity with pre-set xid", setupEntity: func(_ context.Context) *TestEntity { entity := &TestEntity{ Name: "Entity with ID", } - // Use unique timestamp-based ID to avoid conflicts between test runs - entity.ID = fmt.Sprintf("custom-id-%d", time.Now().UnixNano()) + // Caller-supplied IDs must be valid xids so CreatedAt can be + // derived deterministically from the embedded timestamp. + entity.ID = util.IDString() return entity }, expectError: false, }, + { + name: "create entity with non-xid ID should fail", + setupEntity: func(_ context.Context) *TestEntity { + entity := &TestEntity{ + Name: "Entity with legacy id", + } + entity.ID = fmt.Sprintf("custom-id-%d", time.Now().UnixNano()) + return entity + }, + expectError: true, + errorMsg: "is not a valid xid", + }, { name: "create entity with version > 0 should fail", setupEntity: func(_ context.Context) *TestEntity { diff --git a/queue/interface.go b/queue/interface.go index fcfff83a..d5c7a702 100644 --- a/queue/interface.go +++ b/queue/interface.go @@ -7,7 +7,7 @@ import ( "gocloud.dev/pubsub" ) -type SubscriberState int +type SubscriberState int32 const ( SubscriberStateWaiting SubscriberState = iota diff --git a/queue/subscriber.go b/queue/subscriber.go index 6eb5493e..cc8bcfaf 100644 --- a/queue/subscriber.go +++ b/queue/subscriber.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "strings" + "sync" "sync/atomic" "time" @@ -27,18 +28,38 @@ const ( ) type subscriber struct { - reference string - url string - handlers []SubscribeWorker + reference string + url string + handlers []SubscribeWorker + + // mu guards subscription lifecycle transitions (create/recreate/shutdown). + // Holding it serialises Stop() against listen()'s context-cancel path and + // external Close() callers so the field is never written by two goroutines. + mu sync.Mutex subscription *pubsub.Subscription - isInit atomic.Bool - state SubscriberState - metrics *subscriberMetrics - tracer telemetry.Tracer + + isInit atomic.Bool + state atomic.Int32 + metrics *subscriberMetrics + tracer telemetry.Tracer workManager workerpool.Manager } +func (s *subscriber) loadSubscription() *pubsub.Subscription { + s.mu.Lock() + defer s.mu.Unlock() + return s.subscription +} + +func (s *subscriber) storeState(st SubscriberState) { + s.state.Store(int32(st)) +} + +func (s *subscriber) loadState() SubscriberState { + return SubscriberState(s.state.Load()) +} + func subscriberReceiveErrorBackoffDelay(consecutiveErrors int) time.Duration { if consecutiveErrors <= 0 { return 0 @@ -80,50 +101,63 @@ func (s *subscriber) URI() string { } func (s *subscriber) Receive(ctx context.Context) (*pubsub.Message, error) { - if s.subscription == nil { + sub := s.loadSubscription() + if sub == nil { return nil, errors.New("only initialised subscriptions can pull messages") } - s.state = SubscriberStateWaiting + s.storeState(SubscriberStateWaiting) s.metrics.LastActivity.Store(time.Now().UnixNano()) - msg, err := s.subscription.Receive(ctx) + msg, err := sub.Receive(ctx) if err != nil { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return nil, err } - s.state = SubscriberStateInError + s.storeState(SubscriberStateInError) s.metrics.ErrorCount.Add(1) return nil, err } - s.state = SubscriberStateProcessing + s.storeState(SubscriberStateProcessing) s.metrics.ActiveMessages.Add(1) return msg, nil } func (s *subscriber) createSubscription(ctx context.Context) error { + s.mu.Lock() if s.subscription != nil { + s.mu.Unlock() return nil } + s.mu.Unlock() // Validate URL before attempting to open subscription if strings.TrimSpace(s.url) == "" { return errors.New("subscriber URL cannot be empty") } - if !strings.HasPrefix(s.url, "http") { - subs, err := pubsub.OpenSubscription(ctx, s.url) - if err != nil { - return fmt.Errorf("could not open topic subscription: %w", err) - } - s.subscription = subs + if strings.HasPrefix(s.url, "http") { + return nil + } + + subs, err := pubsub.OpenSubscription(ctx, s.url) + if err != nil { + return fmt.Errorf("could not open topic subscription: %w", err) } + s.mu.Lock() + defer s.mu.Unlock() + if s.subscription != nil { + // A concurrent caller won the race; discard ours to avoid leaking it. + _ = subs.Shutdown(ctx) + return nil + } + s.subscription = subs return nil } func (s *subscriber) Init(ctx context.Context) error { - if s.isInit.Load() && s.subscription != nil { + if s.isInit.Load() && s.loadSubscription() != nil { return nil } @@ -151,17 +185,19 @@ func (s *subscriber) recreateSubscription(ctx context.Context) { log.Warn("recreating subscription") - if s.subscription != nil { - err := s.subscription.Shutdown(ctx) - if err != nil { + s.mu.Lock() + sub := s.subscription + s.subscription = nil + s.mu.Unlock() + + if sub != nil { + if err := sub.Shutdown(ctx); err != nil && !isSubscriptionAlreadyShutdownErr(err) { log.WithError(err).Error("could not recreate subscription, stopping listener") s.workManager.StopError(ctx, err) } - s.subscription = nil } - err := s.createSubscription(ctx) - if err != nil { + if err := s.createSubscription(ctx); err != nil { log.WithError(err).Error("could not recreate subscription, stopping listener") s.workManager.StopError(ctx, err) } @@ -172,7 +208,7 @@ func (s *subscriber) Initiated() bool { } func (s *subscriber) State() SubscriberState { - return s.state + return s.loadState() } func (s *subscriber) Metrics() SubscriberMetrics { @@ -180,7 +216,7 @@ func (s *subscriber) Metrics() SubscriberMetrics { } func (s *subscriber) IsIdle() bool { - return s.metrics.IsIdle(s.state) + return s.metrics.IsIdle(s.loadState()) } func (s *subscriber) Stop(ctx context.Context) error { @@ -200,23 +236,32 @@ func (s *subscriber) Stop(ctx context.Context) error { s.isInit.Store(false) - if s.subscription != nil { - err := s.subscription.Shutdown(sctx) - if err != nil { - if isSubscriptionAlreadyShutdownErr(err) { - s.subscription = nil - return nil - } - return err - } - s.subscription = nil + // Detach the subscription under lock so concurrent Stop() / listen() callers + // each get either the live pointer (exactly once) or nil. + s.mu.Lock() + sub := s.subscription + s.subscription = nil + s.mu.Unlock() + + if sub == nil { + return nil } + if err := sub.Shutdown(sctx); err != nil { + if isSubscriptionAlreadyShutdownErr(err) { + return nil + } + return err + } return nil } func (s *subscriber) As(i any) bool { - return s.subscription.As(i) + sub := s.loadSubscription() + if sub == nil { + return false + } + return sub.As(i) } func (s *subscriber) processReceivedMessage(ctx context.Context, msg *pubsub.Message) error { diff --git a/service.go b/service.go index 2cb9bd59..42a062b3 100644 --- a/service.go +++ b/service.go @@ -68,7 +68,10 @@ type Service struct { backGroundClient func(ctx context.Context) error - driver server.Driver + // driverMu guards driver so Stop() can safely observe it while + // initializeServerDrivers() or startServerDriver() are still running. + driverMu sync.Mutex + driver server.Driver healthCheckers []Checker healthCheckPath string @@ -513,7 +516,7 @@ func (s *Service) openAPIHandler(base string) http.HandlerFunc { } func (s *Service) initializeServerDrivers(ctx context.Context, httpPort string) error { - if s.driver != nil { + if s.getDriver() != nil { return nil } @@ -530,17 +533,30 @@ func (s *Service) initializeServerDrivers(ctx context.Context, httpPort string) tlsConfig = nil } - s.driver = implementation.NewDefaultDriverWithTLS( + driver := implementation.NewDefaultDriverWithTLS( ctx, httpCfg, s.handler, httpPort, tlsConfig, ) + s.setDriver(driver) return nil } +func (s *Service) getDriver() server.Driver { + s.driverMu.Lock() + defer s.driverMu.Unlock() + return s.driver +} + +func (s *Service) setDriver(d server.Driver) { + s.driverMu.Lock() + defer s.driverMu.Unlock() + s.driver = d +} + func (s *Service) setupWorkloadAPITLS(ctx context.Context) (*tls.Config, bool, error) { if !s.shouldUseWorkloadAPIServerTLS() { return nil, false, nil @@ -616,19 +632,24 @@ func (s *Service) executeStartupMethods(ctx context.Context) { func (s *Service) startServerDriver(ctx context.Context, httpPort string) error { util.Log(ctx).WithField("port", httpPort).Info("Initiating server operations") + driver := s.getDriver() + if driver == nil { + return errors.New("server driver is not initialised") + } + if s.TLSEnabled() { cfg, ok := s.Config().(config.ConfigurationTLS) if !ok { return errors.New("TLS is enabled but configuration does not implement ConfigurationTLS") } - err := s.driver.ListenAndServeTLS(httpPort, cfg.TLSCertPath(), cfg.TLSCertKeyPath(), s.handler) + err := driver.ListenAndServeTLS(httpPort, cfg.TLSCertPath(), cfg.TLSCertKeyPath(), s.handler) if errors.Is(err, http.ErrServerClosed) { return nil } return err } - err := s.driver.ListenAndServe(httpPort, s.handler) + err := driver.ListenAndServe(httpPort, s.handler) if errors.Is(err, http.ErrServerClosed) { return nil } @@ -691,8 +712,8 @@ func (s *Service) Stop(ctx context.Context) { log.Info("service stopping") - if s.driver != nil { - if err := s.driver.Shutdown(ctx); err != nil && !errors.Is(err, context.Canceled) { + if driver := s.getDriver(); driver != nil { + if err := driver.Shutdown(ctx); err != nil && !errors.Is(err, context.Canceled) { log.WithError(err).Error("failed to shutdown HTTP server") } } From 5cb7e342e2fded5492abd4edd8bdd48e92c4060c Mon Sep 17 00:00:00 2001 From: Peter Bwire Date: Mon, 20 Apr 2026 16:57:27 +0300 Subject: [PATCH 2/2] =?UTF-8?q?fix:=20address=20review=20=E2=80=94=20preve?= =?UTF-8?q?nt=20subscriber=20recreate=20past=20Stop()=20and=20driver=20ini?= =?UTF-8?q?t=20leak?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two refinements from review on PR #646: - recreateSubscription: bail out if isInit is already false (Stop() ran first), and after creating the replacement subscription, re-check isInit and shut down the new subscription if Stop() raced in while we were creating. Previously a late recreate could orphan a live subscription past shutdown. - initializeServerDrivers: do the final "is driver already set" check under the driver mutex. Without it, two concurrent callers could both construct drivers and clobber each other, leaking the http.Server listener from the loser. --- queue/subscriber.go | 14 ++++++++++++++ service.go | 15 ++++++++------- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/queue/subscriber.go b/queue/subscriber.go index cc8bcfaf..bc4faab5 100644 --- a/queue/subscriber.go +++ b/queue/subscriber.go @@ -181,6 +181,7 @@ func (s *subscriber) recreateSubscription(ctx context.Context) { if !s.isInit.Load() { log.Error("only initialised subscriptions can be recreated") + return } log.Warn("recreating subscription") @@ -200,6 +201,19 @@ func (s *subscriber) recreateSubscription(ctx context.Context) { if err := s.createSubscription(ctx); err != nil { log.WithError(err).Error("could not recreate subscription, stopping listener") s.workManager.StopError(ctx, err) + return + } + + // If Stop() ran while we were recreating, drop the new subscription so we + // don't leak it past the shutdown boundary. + if !s.isInit.Load() { + s.mu.Lock() + orphan := s.subscription + s.subscription = nil + s.mu.Unlock() + if orphan != nil { + _ = orphan.Shutdown(ctx) + } } } diff --git a/service.go b/service.go index 42a062b3..c1aefc29 100644 --- a/service.go +++ b/service.go @@ -540,8 +540,15 @@ func (s *Service) initializeServerDrivers(ctx context.Context, httpPort string) httpPort, tlsConfig, ) - s.setDriver(driver) + // Double-check under lock: if a concurrent caller beat us to it, discard + // our driver to avoid leaking the underlying http.Server listener. + s.driverMu.Lock() + defer s.driverMu.Unlock() + if s.driver != nil { + return nil + } + s.driver = driver return nil } @@ -551,12 +558,6 @@ func (s *Service) getDriver() server.Driver { return s.driver } -func (s *Service) setDriver(d server.Driver) { - s.driverMu.Lock() - defer s.driverMu.Unlock() - s.driver = d -} - func (s *Service) setupWorkloadAPITLS(ctx context.Context) (*tls.Config, bool, error) { if !s.shouldUseWorkloadAPIServerTLS() { return nil, false, nil