Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions data/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package data

import (
"context"
"fmt"
"time"

"github.com/pitabwire/util"
Expand Down Expand Up @@ -98,16 +99,33 @@ func (model *BaseModel) BeforeSave(db *gorm.DB) error {
}

func (model *BaseModel) BeforeCreate(db *gorm.DB) error {
// Generate the ID first so CreatedAt can be derived from its embedded
// xid timestamp, keeping sort-by-id ≡ sort-by-created_at.
model.GenID(db.Statement.Context)

if model.Version <= 0 {
model.CreatedAt = time.Now()
model.ModifiedAt = time.Now()
created, err := createdAtFromID(model.ID)
if err != nil {
return err
}
model.CreatedAt = created
model.ModifiedAt = created
model.Version = 1
}

model.GenID(db.Statement.Context)
return nil
}

// createdAtFromID returns the time component embedded in a generated xid.
// All BaseModel IDs must be valid xids so sort-by-id ≡ sort-by-created_at
// and hypertable promotions retain monotonic time ordering.
func createdAtFromID(id string) (time.Time, error) {
parsed, err := xid.FromString(id)
if err != nil {
return time.Time{}, fmt.Errorf("BaseModel.ID %q is not a valid xid: %w", id, err)
}
return parsed.Time(), nil
}

// BeforeUpdate Updates time stamp every time we update status of a migration.
func (model *BaseModel) BeforeUpdate(db *gorm.DB) error {
model.ModifiedAt = time.Now()
Expand Down
65 changes: 54 additions & 11 deletions datastore/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,30 @@ type BaseRepository[T any] interface {
DeleteBatch(ctx context.Context, ids []string) error
}

// BaseRepositoryOption configures optional behaviour on NewBaseRepository.
type BaseRepositoryOption func(*baseRepositoryConfig)

type baseRepositoryConfig struct {
// bulkCreateConflictColumns is the explicit ON CONFLICT target used by
// BulkCreate. Empty means "bare ON CONFLICT DO NOTHING" — portable across
// plain-PK and composite-PK (hypertable) tables. Set this when the caller
// depends on conflict inference against a specific unique index, e.g.
// (id, created_at) for hypertables or a domain-unique column.
bulkCreateConflictColumns []clause.Column
}

// WithBulkCreateConflictColumns sets the ON CONFLICT target columns for
// BulkCreate. Columns must match an existing unique index on the target
// table, otherwise Postgres rejects the insert.
func WithBulkCreateConflictColumns(columns ...string) BaseRepositoryOption {
return func(c *baseRepositoryConfig) {
c.bulkCreateConflictColumns = c.bulkCreateConflictColumns[:0]
for _, col := range columns {
c.bulkCreateConflictColumns = append(c.bulkCreateConflictColumns, clause.Column{Name: col})
}
}
}

// baseRepository is the concrete implementation of BaseRepository.
type baseRepository[T data.BaseModelI] struct {
dbPool pool.Pool
Expand All @@ -55,6 +79,10 @@ type baseRepository[T data.BaseModelI] struct {

// allowedFields whitelist for safe column access (set during initialization)
allowedFields map[string]struct{}

// bulkCreateConflictColumns is the optional explicit ON CONFLICT target
// for BulkCreate. See WithBulkCreateConflictColumns.
bulkCreateConflictColumns []clause.Column
}

// NewBaseRepository creates a new base repository instance.
Expand All @@ -64,14 +92,21 @@ func NewBaseRepository[T data.BaseModelI](
dbPool pool.Pool,
workMan workerpool.Manager,
modelFactory func() T,
opts ...BaseRepositoryOption,
) BaseRepository[T] {
cfg := &baseRepositoryConfig{}
for _, opt := range opts {
opt(cfg)
}

repo := &baseRepository[T]{
dbPool: dbPool,
workMan: workMan,
modelFactory: modelFactory,
batchSize: 751, //nolint:mnd // default batch size
immutableFields: []string{"id", "created_at", "tenant_id", "partition_id"},
allowedFields: make(map[string]struct{}),
dbPool: dbPool,
workMan: workMan,
modelFactory: modelFactory,
batchSize: 751, //nolint:mnd // default batch size
immutableFields: []string{"id", "created_at", "tenant_id", "partition_id"},
allowedFields: make(map[string]struct{}),
bulkCreateConflictColumns: cfg.bulkCreateConflictColumns,
}

db := dbPool.DB(ctx, true)
Expand Down Expand Up @@ -176,16 +211,24 @@ func (br *baseRepository[T]) BatchSize() int {
}

// BulkCreate inserts multiple entities efficiently in a single transaction.
// The conflict target defaults to bare ON CONFLICT DO NOTHING so the same
// repository works for plain-PK and composite-PK (hypertable) tables.
// Callers that need inference against a specific unique index set it via
// WithBulkCreateConflictColumns at construction.
func (br *baseRepository[T]) BulkCreate(ctx context.Context, entities []T) error {
if len(entities) == 0 {
return nil
}

// CreateInBatches uses GORM's batch insert which is more efficient
// The batch size is configured in pool options (InsertBatchSize)
return br.Pool().DB(ctx, false).Clauses(clause.OnConflict{
DoNothing: true,
}).CreateInBatches(entities, br.BatchSize()).Error
onConflict := clause.OnConflict{DoNothing: true}
if len(br.bulkCreateConflictColumns) > 0 {
onConflict.Columns = br.bulkCreateConflictColumns
}

// CreateInBatches uses GORM's batch insert which is more efficient;
// the batch size is configured in pool options (InsertBatchSize).
return br.Pool().DB(ctx, false).Clauses(onConflict).
CreateInBatches(entities, br.BatchSize()).Error
}

// validateAffectedColumns checks if all columns are valid and allowed.
Expand Down
19 changes: 16 additions & 3 deletions datastore/repository_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,17 +367,30 @@ func (s *RepositoryTestSuite) TestCreate() {
expectError: false,
},
{
name: "create entity with pre-set ID",
name: "create entity with pre-set xid",
setupEntity: func(_ context.Context) *TestEntity {
entity := &TestEntity{
Name: "Entity with ID",
}
// Use unique timestamp-based ID to avoid conflicts between test runs
entity.ID = fmt.Sprintf("custom-id-%d", time.Now().UnixNano())
// Caller-supplied IDs must be valid xids so CreatedAt can be
// derived deterministically from the embedded timestamp.
entity.ID = util.IDString()
return entity
},
expectError: false,
},
{
name: "create entity with non-xid ID should fail",
setupEntity: func(_ context.Context) *TestEntity {
entity := &TestEntity{
Name: "Entity with legacy id",
}
entity.ID = fmt.Sprintf("custom-id-%d", time.Now().UnixNano())
return entity
},
expectError: true,
errorMsg: "is not a valid xid",
},
{
name: "create entity with version > 0 should fail",
setupEntity: func(_ context.Context) *TestEntity {
Expand Down
2 changes: 1 addition & 1 deletion queue/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"gocloud.dev/pubsub"
)

type SubscriberState int
type SubscriberState int32

const (
SubscriberStateWaiting SubscriberState = iota
Expand Down
Loading
Loading