From 28994a181e754be7f5f8d392bffb72aa229f059d Mon Sep 17 00:00:00 2001 From: boreq Date: Fri, 1 Dec 2023 16:24:20 +0100 Subject: [PATCH] Allow transactions to timeout --- cmd/event-service/di/inject_adapters.go | 4 +- cmd/event-service/di/service.go | 9 ++ cmd/event-service/di/wire_gen.go | 15 +- service/adapters/sqlite/pubsub_test.go | 2 +- service/adapters/sqlite/sqlite.go | 130 ++++++++++++------ service/adapters/sqlite/sqlite_test.go | 9 ++ service/app/app.go | 4 + service/app/downloader.go | 11 ++ .../app/handler_add_public_key_to_monitor.go | 3 + service/app/handler_get_event.go | 3 + service/app/handler_get_events.go | 3 + service/app/handler_get_public_key_info.go | 3 + service/app/handler_process_saved_event.go | 3 + service/app/handler_save_received_event.go | 3 + service/app/handler_update_metrics.go | 3 + 15 files changed, 156 insertions(+), 49 deletions(-) diff --git a/cmd/event-service/di/inject_adapters.go b/cmd/event-service/di/inject_adapters.go index 783d3f8..bf9213d 100644 --- a/cmd/event-service/di/inject_adapters.go +++ b/cmd/event-service/di/inject_adapters.go @@ -16,7 +16,7 @@ import ( var sqliteAdaptersSet = wire.NewSet( newSqliteDB, - sqlite.NewDatabaseMutex, + sqlite.NewTransactionRunner, sqlite.NewTransactionProvider, wire.Bind(new(app.TransactionProvider), new(*sqlite.TransactionProvider)), @@ -31,7 +31,7 @@ var sqliteAdaptersSet = wire.NewSet( var sqliteTestAdaptersSet = wire.NewSet( newSqliteDB, - sqlite.NewDatabaseMutex, + sqlite.NewTransactionRunner, sqlite.NewTestTransactionProvider, diff --git a/cmd/event-service/di/service.go b/cmd/event-service/di/service.go index 2290769..01270c8 100644 --- a/cmd/event-service/di/service.go +++ b/cmd/event-service/di/service.go @@ -6,6 +6,7 @@ import ( "github.com/boreq/errors" "github.com/hashicorp/go-multierror" "github.com/planetary-social/nos-event-service/internal/migrations" + "github.com/planetary-social/nos-event-service/service/adapters/sqlite" "github.com/planetary-social/nos-event-service/service/app" "github.com/planetary-social/nos-event-service/service/domain/downloader" "github.com/planetary-social/nos-event-service/service/ports/http" @@ -21,6 +22,7 @@ type Service struct { receivedEventSubscriber *memorypubsub.ReceivedEventSubscriber eventSavedEventSubscriber *sqlitepubsub.EventSavedEventSubscriber metricsTimer *timer.Metrics + transactionRunner *sqlite.TransactionRunner migrationsRunner *migrations.Runner migrations migrations.Migrations migrationsProgressCallback migrations.ProgressCallback @@ -33,6 +35,7 @@ func NewService( receivedEventSubscriber *memorypubsub.ReceivedEventSubscriber, eventSavedEventSubscriber *sqlitepubsub.EventSavedEventSubscriber, metricsTimer *timer.Metrics, + transactionRunner *sqlite.TransactionRunner, migrationsRunner *migrations.Runner, migrations migrations.Migrations, migrationsProgressCallback migrations.ProgressCallback, @@ -44,6 +47,7 @@ func NewService( receivedEventSubscriber: receivedEventSubscriber, eventSavedEventSubscriber: eventSavedEventSubscriber, metricsTimer: metricsTimer, + transactionRunner: transactionRunner, migrationsRunner: migrationsRunner, migrations: migrations, migrationsProgressCallback: migrationsProgressCallback, @@ -90,6 +94,11 @@ func (s Service) Run(ctx context.Context) error { errCh <- errors.Wrap(s.metricsTimer.Run(ctx), "metrics timer error") }() + runners++ + go func() { + errCh <- errors.Wrap(s.transactionRunner.Run(ctx), "transaction runner error") + }() + var err error for i := 0; i < runners; i++ { err = multierror.Append(err, errors.Wrap(<-errCh, "error returned by runner")) diff --git a/cmd/event-service/di/wire_gen.go b/cmd/event-service/di/wire_gen.go index d2e4a4f..4e5f611 100644 --- a/cmd/event-service/di/wire_gen.go +++ b/cmd/event-service/di/wire_gen.go @@ -48,8 +48,8 @@ func BuildService(contextContext context.Context, configConfig config.Config) (S Logger: logger, } genericAdaptersFactoryFn := newAdaptersFactoryFn(diBuildTransactionSqliteAdaptersDependencies) - databaseMutex := sqlite.NewDatabaseMutex() - genericTransactionProvider := sqlite.NewTransactionProvider(db, genericAdaptersFactoryFn, databaseMutex) + transactionRunner := sqlite.NewTransactionRunner(db) + genericTransactionProvider := sqlite.NewTransactionProvider(db, genericAdaptersFactoryFn, transactionRunner) prometheusPrometheus, err := prometheus.NewPrometheus(logger) if err != nil { cleanup() @@ -68,7 +68,7 @@ func BuildService(contextContext context.Context, configConfig config.Config) (S relayConnections := relays.NewRelayConnections(contextContext, logger, prometheusPrometheus) eventSender := relays.NewEventSender(relayConnections) processSavedEventHandler := app.NewProcessSavedEventHandler(genericTransactionProvider, relaysExtractor, contactsExtractor, externalEventPublisher, eventSender, logger, prometheusPrometheus) - sqliteGenericTransactionProvider := sqlite.NewPubSubTxTransactionProvider(db, databaseMutex) + sqliteGenericTransactionProvider := sqlite.NewPubSubTxTransactionProvider(db, transactionRunner) pubSub := sqlite.NewPubSub(sqliteGenericTransactionProvider, logger) subscriber := sqlite.NewSubscriber(pubSub, db) updateMetricsHandler := app.NewUpdateMetricsHandler(genericTransactionProvider, subscriber, logger, prometheusPrometheus) @@ -108,7 +108,7 @@ func BuildService(contextContext context.Context, configConfig config.Config) (S return Service{}, nil, err } loggingMigrationsProgressCallback := adapters.NewLoggingMigrationsProgressCallback(logger) - service := NewService(application, server, downloaderDownloader, receivedEventSubscriber, eventSavedEventSubscriber, metrics, runner, migrationsMigrations, loggingMigrationsProgressCallback) + service := NewService(application, server, downloaderDownloader, receivedEventSubscriber, eventSavedEventSubscriber, metrics, transactionRunner, runner, migrationsMigrations, loggingMigrationsProgressCallback) return service, func() { cleanup() }, nil @@ -132,9 +132,9 @@ func BuildTestAdapters(contextContext context.Context, tb testing.TB) (sqlite.Te Logger: logger, } genericAdaptersFactoryFn := newTestAdaptersFactoryFn(diBuildTransactionSqliteAdaptersDependencies) - databaseMutex := sqlite.NewDatabaseMutex() - genericTransactionProvider := sqlite.NewTestTransactionProvider(db, genericAdaptersFactoryFn, databaseMutex) - sqliteGenericTransactionProvider := sqlite.NewPubSubTxTransactionProvider(db, databaseMutex) + transactionRunner := sqlite.NewTransactionRunner(db) + genericTransactionProvider := sqlite.NewTestTransactionProvider(db, genericAdaptersFactoryFn, transactionRunner) + sqliteGenericTransactionProvider := sqlite.NewPubSubTxTransactionProvider(db, transactionRunner) pubSub := sqlite.NewPubSub(sqliteGenericTransactionProvider, logger) subscriber := sqlite.NewSubscriber(pubSub, db) migrationsStorage, err := sqlite.NewMigrationsStorage(db) @@ -158,6 +158,7 @@ func BuildTestAdapters(contextContext context.Context, tb testing.TB) (sqlite.Te MigrationsRunner: runner, Migrations: migrationsMigrations, MigrationsProgressCallback: loggingMigrationsProgressCallback, + TransactionRunner: transactionRunner, } return testedItems, func() { cleanup() diff --git a/service/adapters/sqlite/pubsub_test.go b/service/adapters/sqlite/pubsub_test.go index 9d9cd67..be024ef 100644 --- a/service/adapters/sqlite/pubsub_test.go +++ b/service/adapters/sqlite/pubsub_test.go @@ -37,7 +37,7 @@ func TestPubSub_PublishingMessagesWithIdenticalUUIDsReturnsAnError(t *testing.T) require.NoError(t, err) err = adapters.PubSub.Publish(ctx, fixtures.SomeString(), msg) - require.EqualError(t, err, "transaction error: error calling the provided function: UNIQUE constraint failed: pubsub.uuid") + require.EqualError(t, err, "transaction error: received an error: error calling the callback: error calling the adapters callback: UNIQUE constraint failed: pubsub.uuid") } func TestPubSub_NackedMessagesAreRetried(t *testing.T) { diff --git a/service/adapters/sqlite/sqlite.go b/service/adapters/sqlite/sqlite.go index 9ac504a..0d4d5e1 100644 --- a/service/adapters/sqlite/sqlite.go +++ b/service/adapters/sqlite/sqlite.go @@ -3,7 +3,6 @@ package sqlite import ( "context" "database/sql" - "sync" "github.com/boreq/errors" "github.com/hashicorp/go-multierror" @@ -30,6 +29,7 @@ type TestedItems struct { MigrationsRunner *migrations.Runner Migrations migrations.Migrations MigrationsProgressCallback migrations.ProgressCallback + TransactionRunner *TransactionRunner } func Open(conf config.Config) (*sql.DB, error) { @@ -47,12 +47,12 @@ type TransactionProvider = GenericTransactionProvider[app.Adapters] func NewTransactionProvider( db *sql.DB, fn AdaptersFactoryFn, - mutex *DatabaseMutex, + runner *TransactionRunner, ) *TransactionProvider { return &TransactionProvider{ - db: db, - fn: fn, - mutex: mutex, + db: db, + fn: fn, + runner: runner, } } @@ -62,12 +62,12 @@ type TestTransactionProvider = GenericTransactionProvider[TestAdapters] func NewTestTransactionProvider( db *sql.DB, fn TestAdaptersFactoryFn, - mutex *DatabaseMutex, + runner *TransactionRunner, ) *TestTransactionProvider { return &TestTransactionProvider{ - db: db, - fn: fn, - mutex: mutex, + db: db, + fn: fn, + runner: runner, } } @@ -75,47 +75,105 @@ type PubSubTxTransactionProvider = GenericTransactionProvider[*sql.Tx] func NewPubSubTxTransactionProvider( db *sql.DB, - mutex *DatabaseMutex, + runner *TransactionRunner, ) *PubSubTxTransactionProvider { return &PubSubTxTransactionProvider{ db: db, fn: func(db *sql.DB, tx *sql.Tx) (*sql.Tx, error) { return tx, nil }, - mutex: mutex, + runner: runner, } } type GenericAdaptersFactoryFn[T any] func(*sql.DB, *sql.Tx) (T, error) type GenericTransactionProvider[T any] struct { - db *sql.DB - fn GenericAdaptersFactoryFn[T] - mutex *DatabaseMutex + db *sql.DB + fn GenericAdaptersFactoryFn[T] + runner *TransactionRunner } -func (t *GenericTransactionProvider[T]) Transact(ctx context.Context, f func(context.Context, T) error) error { - t.mutex.Lock() - defer t.mutex.Unlock() +func (t *GenericTransactionProvider[T]) Transact(ctx context.Context, fn func(context.Context, T) error) error { + transactionFunc := t.makeTransactionFunc(fn) + return t.runner.TryRun(ctx, transactionFunc) +} - tx, err := t.db.BeginTx(ctx, nil) - if err != nil { - return errors.Wrap(err, "error starting the transaction") +func (t *GenericTransactionProvider[T]) makeTransactionFunc(fn func(context.Context, T) error) TransactionFunc { + return func(ctx context.Context, db *sql.DB, tx *sql.Tx) error { + adapters, err := t.fn(t.db, tx) + if err != nil { + return errors.Wrap(err, "error building the adapters") + } + + if err := fn(ctx, adapters); err != nil { + return errors.Wrap(err, "error calling the adapters callback") + } + + return nil } +} - adapters, err := t.fn(t.db, tx) - if err != nil { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - err = multierror.Append(err, errors.Wrap(rollbackErr, "rollback error")) +type TransactionFunc func(context.Context, *sql.DB, *sql.Tx) error + +type TransactionRunner struct { + db *sql.DB + chIn chan transactionTask +} + +func NewTransactionRunner(db *sql.DB) *TransactionRunner { + return &TransactionRunner{ + db: db, + chIn: make(chan transactionTask), + } +} + +func (t *TransactionRunner) TryRun(ctx context.Context, fn TransactionFunc) error { + resultCh := make(chan error) + + select { + case t.chIn <- newTransactionTask(ctx, fn, resultCh): + case <-ctx.Done(): + return ctx.Err() + } + + select { + case err := <-resultCh: + return errors.Wrap(err, "received an error") + case <-ctx.Done(): + return ctx.Err() + } +} + +func (t *TransactionRunner) Run(ctx context.Context) error { + for { + select { + case task := <-t.chIn: + select { + case task.ResultCh <- t.run(task.Ctx, task.Fn): + continue + case <-task.Ctx.Done(): + continue + case <-ctx.Done(): + return ctx.Err() + } + case <-ctx.Done(): + return ctx.Err() } - return errors.Wrap(err, "error building the adapters") } +} - if err := f(ctx, adapters); err != nil { +func (t *TransactionRunner) run(ctx context.Context, fn TransactionFunc) error { + tx, err := t.db.BeginTx(ctx, nil) + if err != nil { + return errors.Wrap(err, "error starting the transaction") + } + + if err := fn(ctx, t.db, tx); err != nil { if rollbackErr := tx.Rollback(); rollbackErr != nil { err = multierror.Append(err, errors.Wrap(rollbackErr, "rollback error")) } - return errors.Wrap(err, "error calling the provided function") + return errors.Wrap(err, "error calling the callback") } if err := tx.Commit(); err != nil { @@ -125,18 +183,12 @@ func (t *GenericTransactionProvider[T]) Transact(ctx context.Context, f func(con return nil } -type DatabaseMutex struct { - m sync.Mutex -} - -func NewDatabaseMutex() *DatabaseMutex { - return &DatabaseMutex{} -} - -func (m *DatabaseMutex) Lock() { - m.m.Lock() +type transactionTask struct { + Ctx context.Context + Fn TransactionFunc + ResultCh chan<- error } -func (m *DatabaseMutex) Unlock() { - m.m.Unlock() +func newTransactionTask(ctx context.Context, fn TransactionFunc, resultCh chan<- error) transactionTask { + return transactionTask{Ctx: ctx, Fn: fn, ResultCh: resultCh} } diff --git a/service/adapters/sqlite/sqlite_test.go b/service/adapters/sqlite/sqlite_test.go index d913645..42a29b8 100644 --- a/service/adapters/sqlite/sqlite_test.go +++ b/service/adapters/sqlite/sqlite_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "github.com/boreq/errors" "github.com/planetary-social/nos-event-service/cmd/event-service/di" "github.com/planetary-social/nos-event-service/service/adapters/sqlite" "github.com/stretchr/testify/require" @@ -18,5 +19,13 @@ func NewTestAdapters(ctx context.Context, tb testing.TB) sqlite.TestedItems { err = adapters.MigrationsRunner.Run(ctx, adapters.Migrations, adapters.MigrationsProgressCallback) require.NoError(tb, err) + go func() { + if err := adapters.TransactionRunner.Run(ctx); err != nil { + if !errors.Is(err, context.Canceled) { + panic(err) + } + } + }() + return adapters } diff --git a/service/app/app.go b/service/app/app.go index c10015d..8287e80 100644 --- a/service/app/app.go +++ b/service/app/app.go @@ -14,6 +14,10 @@ var ( ErrPublicKeyToMonitorNotFound = errors.New("public key to monitor not found") ) +const ( + applicationHandlerTimeout = 30 * time.Second +) + type TransactionProvider interface { Transact(context.Context, func(context.Context, Adapters) error) error } diff --git a/service/app/downloader.go b/service/app/downloader.go index ad584df..232d309 100644 --- a/service/app/downloader.go +++ b/service/app/downloader.go @@ -2,6 +2,7 @@ package app import ( "context" + "time" "github.com/boreq/errors" "github.com/planetary-social/nos-event-service/internal" @@ -22,6 +23,11 @@ func NewDatabaseRelaySource(transactionProvider TransactionProvider, logger logg } func (m *DatabaseRelaySource) GetRelays(ctx context.Context) ([]domain.RelayAddress, error) { + start := time.Now() + defer func() { + m.logger.Debug().WithField("duration", time.Since(start)).Message("got relays") + }() + var maybeResult []domain.MaybeRelayAddress if err := m.transactionProvider.Transact(ctx, func(ctx context.Context, adapters Adapters) error { tmp, err := adapters.Relays.List(ctx) @@ -69,6 +75,11 @@ func NewDatabasePublicKeySource(transactionProvider TransactionProvider, logger } func (d *DatabasePublicKeySource) GetPublicKeys(ctx context.Context) ([]domain.PublicKey, error) { + start := time.Now() + defer func() { + d.logger.Debug().WithField("duration", time.Since(start)).Message("got public keys") + }() + result := internal.NewEmptySet[domain.PublicKey]() if err := d.transactionProvider.Transact(ctx, func(ctx context.Context, adapters Adapters) error { diff --git a/service/app/handler_add_public_key_to_monitor.go b/service/app/handler_add_public_key_to_monitor.go index 08bdb5e..b07d004 100644 --- a/service/app/handler_add_public_key_to_monitor.go +++ b/service/app/handler_add_public_key_to_monitor.go @@ -38,6 +38,9 @@ func NewAddPublicKeyToMonitorHandler( func (h *AddPublicKeyToMonitorHandler) Handle(ctx context.Context, cmd AddPublicKeyToMonitor) (err error) { defer h.metrics.StartApplicationCall("addPublicKeyToMonitor").End(&err) + ctx, cancel := context.WithTimeout(ctx, applicationHandlerTimeout) + defer cancel() + publicKeyToMonitor, err := domain.NewPublicKeyToMonitor( cmd.publicKey, time.Now(), diff --git a/service/app/handler_get_event.go b/service/app/handler_get_event.go index 175753a..bdbcbe3 100644 --- a/service/app/handler_get_event.go +++ b/service/app/handler_get_event.go @@ -37,6 +37,9 @@ func NewGetEventHandler( func (h *GetEventHandler) Handle(ctx context.Context, cmd GetEvent) (event domain.Event, err error) { defer h.metrics.StartApplicationCall("getEvent").End(&err) + ctx, cancel := context.WithTimeout(ctx, applicationHandlerTimeout) + defer cancel() + if err := h.transactionProvider.Transact(ctx, func(ctx context.Context, adapters Adapters) error { tmp, err := adapters.Events.Get(ctx, cmd.id) if err != nil { diff --git a/service/app/handler_get_events.go b/service/app/handler_get_events.go index 6a89ea1..0201f9c 100644 --- a/service/app/handler_get_events.go +++ b/service/app/handler_get_events.go @@ -59,6 +59,9 @@ func NewGetEventsHandler( func (h *GetEventsHandler) Handle(ctx context.Context, cmd GetEvents) (result GetEventsResult, err error) { defer h.metrics.StartApplicationCall("getEvents").End(&err) + ctx, cancel := context.WithTimeout(ctx, applicationHandlerTimeout) + defer cancel() + var events []domain.Event if err := h.transactionProvider.Transact(ctx, func(ctx context.Context, adapters Adapters) error { tmp, err := adapters.Events.List(ctx, cmd.after, getEventsLimit+1) diff --git a/service/app/handler_get_public_key_info.go b/service/app/handler_get_public_key_info.go index 8c3e893..9bb19be 100644 --- a/service/app/handler_get_public_key_info.go +++ b/service/app/handler_get_public_key_info.go @@ -54,6 +54,9 @@ func NewGetPublicKeyInfoHandler( func (h *GetPublicKeyInfoHandler) Handle(ctx context.Context, cmd GetPublicKeyInfo) (publicKeyInfo PublicKeyInfo, err error) { defer h.metrics.StartApplicationCall("getPublicKeyInfo").End(&err) + ctx, cancel := context.WithTimeout(ctx, applicationHandlerTimeout) + defer cancel() + var followeesCount, followersCount int if err := h.transactionProvider.Transact(ctx, func(ctx context.Context, adapters Adapters) error { diff --git a/service/app/handler_process_saved_event.go b/service/app/handler_process_saved_event.go index b8ded1c..a36688e 100644 --- a/service/app/handler_process_saved_event.go +++ b/service/app/handler_process_saved_event.go @@ -69,6 +69,9 @@ func NewProcessSavedEventHandler( func (h *ProcessSavedEventHandler) Handle(ctx context.Context, cmd ProcessSavedEvent) (err error) { defer h.metrics.StartApplicationCall("processSavedEvent").End(&err) + ctx, cancel := context.WithTimeout(ctx, applicationHandlerTimeout) + defer cancel() + event, err := h.loadEvent(ctx, cmd.id) if err != nil { return errors.Wrap(err, "error loading the event") diff --git a/service/app/handler_save_received_event.go b/service/app/handler_save_received_event.go index 03b299b..eb664e2 100644 --- a/service/app/handler_save_received_event.go +++ b/service/app/handler_save_received_event.go @@ -50,6 +50,9 @@ func NewSaveReceivedEventHandler( func (h *SaveReceivedEventHandler) Handle(ctx context.Context, cmd SaveReceivedEvent) (err error) { defer h.metrics.StartApplicationCall("saveReceivedEvent").End(&err) + ctx, cancel := context.WithTimeout(ctx, applicationHandlerTimeout) + defer cancel() + h.logger. Trace(). WithField("relay", cmd.relay.String()). diff --git a/service/app/handler_update_metrics.go b/service/app/handler_update_metrics.go index e9800fd..5540b76 100644 --- a/service/app/handler_update_metrics.go +++ b/service/app/handler_update_metrics.go @@ -31,6 +31,9 @@ func NewUpdateMetricsHandler( func (h *UpdateMetricsHandler) Handle(ctx context.Context) (err error) { defer h.metrics.StartApplicationCall("updateMetrics").End(&err) + ctx, cancel := context.WithTimeout(ctx, applicationHandlerTimeout) + defer cancel() + n, err := h.subscriber.EventSavedQueueLength(ctx) if err != nil { return errors.Wrap(err, "error reading queue length")