From 6ef1a99c313e09155e3e69eb1fcc1a3153d60de9 Mon Sep 17 00:00:00 2001
From: Filip Borkiewicz
Date: Wed, 13 Dec 2023 18:52:40 +0000
Subject: [PATCH] Time window based replication (#60)

---
 Makefile                                      |   4 +
 cmd/event-service/di/inject_adapters.go       |   4 +
 cmd/event-service/di/service.go               |   8 +
 cmd/event-service/di/wire.go                  |  10 +-
 cmd/event-service/di/wire_gen.go              |  15 +-
 internal/fixtures/fixtures.go                 |  14 +
 service/adapters/current_time_provider.go     |  16 +
 .../adapters/mocks/current_time_provider.go   |  21 +
 service/app/sources.go                        |  29 ++
 service/domain/downloader/downloader.go       | 134 ++----
 service/domain/downloader/scheduler.go        | 400 ++++++++++++++++++
 service/domain/downloader/scheduler_test.go   | 305 +++++++++++++
 service/domain/downloader/time_window.go      |  44 ++
 service/domain/downloader/time_window_task.go | 138 ++++++
 .../downloader/time_window_task_test.go       |  22 +
 service/domain/downloader/time_window_test.go |  23 +
 service/domain/filter.go                      |  42 +-
 17 files changed, 1124 insertions(+), 105 deletions(-)
 create mode 100644 service/adapters/current_time_provider.go
 create mode 100644 service/adapters/mocks/current_time_provider.go
 create mode 100644 service/domain/downloader/scheduler.go
 create mode 100644 service/domain/downloader/scheduler_test.go
 create mode 100644 service/domain/downloader/time_window.go
 create mode 100644 service/domain/downloader/time_window_task.go
 create mode 100644 service/domain/downloader/time_window_task_test.go
 create mode 100644 service/domain/downloader/time_window_test.go

diff --git a/Makefile b/Makefile
index 8088441..9f968d0 100644
--- a/Makefile
+++ b/Makefile
@@ -17,6 +17,10 @@ fmt:
 test:
 	go test -race ./...
 
+.PHONY: test-nocache
+test-nocache:
+	go test -race -count=1 ./...
+
 .PHONY: test-bench
 test-bench:
 	go test -v -race -run="^$$" -bench=. -benchtime=1x ./...
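For orientation before the per-file diffs: the patch replaces the downloader's single "everything since the last hour" subscription with tasks generated from fixed-size time windows that advance from roughly one hour in the past toward the present, staying one window behind the current time so relays have a chance to store the events being requested. A minimal, self-contained sketch of that windowing loop (an illustration written for this note, not code from the patch; the constants mirror initialWindowAge and windowSize in service/domain/downloader/scheduler.go):

package main

import (
	"fmt"
	"time"
)

// timeWindow is a simplified stand-in for the patch's downloader.TimeWindow:
// a start instant plus a fixed duration.
type timeWindow struct {
	start    time.Time
	duration time.Duration
}

func (w timeWindow) end() time.Time { return w.start.Add(w.duration) }

func (w timeWindow) advance() timeWindow {
	return timeWindow{start: w.end(), duration: w.duration}
}

func main() {
	const (
		initialWindowAge = 1 * time.Hour   // how far back replication starts
		windowSize       = 1 * time.Minute // size of each requested window
	)

	now := time.Now()

	// The generator's starting window sits one hour plus one window in the
	// past; the first task is produced for its successor (now - 1h).
	window := timeWindow{start: now.Add(-initialWindowAge - windowSize), duration: windowSize}

	// Advance window by window, but never into the last minute before "now",
	// so relays have time to store the events that will be requested.
	for {
		next := window.advance()
		if next.end().After(now.Add(-windowSize)) {
			break
		}
		window = next
		fmt.Printf("request events with since=%s until=%s\n",
			window.start.Format(time.RFC3339), window.end().Format(time.RFC3339))
	}
}

Under these assumptions the loop yields 59 one-minute windows, which is the count the scheduler tests below assert.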
diff --git a/cmd/event-service/di/inject_adapters.go b/cmd/event-service/di/inject_adapters.go index bf9213d..e11d9ba 100644 --- a/cmd/event-service/di/inject_adapters.go +++ b/cmd/event-service/di/inject_adapters.go @@ -6,6 +6,7 @@ import ( "github.com/boreq/errors" "github.com/google/wire" "github.com/planetary-social/nos-event-service/internal/logging" + "github.com/planetary-social/nos-event-service/service/adapters" "github.com/planetary-social/nos-event-service/service/adapters/prometheus" "github.com/planetary-social/nos-event-service/service/adapters/sqlite" "github.com/planetary-social/nos-event-service/service/app" @@ -64,6 +65,9 @@ var adaptersSet = wire.NewSet( wire.Bind(new(app.Metrics), new(*prometheus.Prometheus)), wire.Bind(new(relays.Metrics), new(*prometheus.Prometheus)), wire.Bind(new(downloader.Metrics), new(*prometheus.Prometheus)), + + adapters.NewCurrentTimeProvider, + wire.Bind(new(downloader.CurrentTimeProvider), new(*adapters.CurrentTimeProvider)), ) func newAdaptersFactoryFn(deps buildTransactionSqliteAdaptersDependencies) sqlite.AdaptersFactoryFn { diff --git a/cmd/event-service/di/service.go b/cmd/event-service/di/service.go index 01270c8..9c2f986 100644 --- a/cmd/event-service/di/service.go +++ b/cmd/event-service/di/service.go @@ -23,6 +23,7 @@ type Service struct { eventSavedEventSubscriber *sqlitepubsub.EventSavedEventSubscriber metricsTimer *timer.Metrics transactionRunner *sqlite.TransactionRunner + taskScheduler *downloader.TaskScheduler migrationsRunner *migrations.Runner migrations migrations.Migrations migrationsProgressCallback migrations.ProgressCallback @@ -36,6 +37,7 @@ func NewService( eventSavedEventSubscriber *sqlitepubsub.EventSavedEventSubscriber, metricsTimer *timer.Metrics, transactionRunner *sqlite.TransactionRunner, + taskScheduler *downloader.TaskScheduler, migrationsRunner *migrations.Runner, migrations migrations.Migrations, migrationsProgressCallback migrations.ProgressCallback, @@ -49,6 +51,7 @@ func NewService( metricsTimer: metricsTimer, transactionRunner: transactionRunner, migrationsRunner: migrationsRunner, + taskScheduler: taskScheduler, migrations: migrations, migrationsProgressCallback: migrationsProgressCallback, } @@ -99,6 +102,11 @@ func (s Service) Run(ctx context.Context) error { errCh <- errors.Wrap(s.transactionRunner.Run(ctx), "transaction runner error") }() + runners++ + go func() { + errCh <- errors.Wrap(s.taskScheduler.Run(ctx), "task scheduler error") + }() + var err error for i := 0; i < runners; i++ { err = multierror.Append(err, errors.Wrap(<-errCh, "error returned by runner")) diff --git a/cmd/event-service/di/wire.go b/cmd/event-service/di/wire.go index 79cb5f0..75dd2d6 100644 --- a/cmd/event-service/di/wire.go +++ b/cmd/event-service/di/wire.go @@ -150,12 +150,20 @@ var downloaderSet = wire.NewSet( wire.Bind(new(downloader.RelayConnections), new(*relays.RelayConnections)), app.NewDatabasePublicKeySource, - wire.Bind(new(downloader.PublicKeySource), new(*app.DatabasePublicKeySource)), + newCachedPublicKeySource, + wire.Bind(new(downloader.PublicKeySource), new(*app.CachedDatabasePublicKeySource)), relays.NewEventSender, wire.Bind(new(app.EventSender), new(*relays.EventSender)), + + downloader.NewTaskScheduler, + wire.Bind(new(downloader.Scheduler), new(*downloader.TaskScheduler)), ) +func newCachedPublicKeySource(underlying *app.DatabasePublicKeySource) *app.CachedDatabasePublicKeySource { + return app.NewCachedDatabasePublicKeySource(underlying) +} + var domainSet = wire.NewSet( domain.NewRelaysExtractor, 
wire.Bind(new(app.RelaysExtractor), new(*domain.RelaysExtractor)), diff --git a/cmd/event-service/di/wire_gen.go b/cmd/event-service/di/wire_gen.go index 6cf0805..994e62f 100644 --- a/cmd/event-service/di/wire_gen.go +++ b/cmd/event-service/di/wire_gen.go @@ -89,9 +89,12 @@ func BuildService(contextContext context.Context, configConfig config.Config) (S bootstrapRelaySource := relays.NewBootstrapRelaySource() databaseRelaySource := app.NewDatabaseRelaySource(genericTransactionProvider, logger) databasePublicKeySource := app.NewDatabasePublicKeySource(genericTransactionProvider, logger) + cachedDatabasePublicKeySource := newCachedPublicKeySource(databasePublicKeySource) receivedEventPubSub := memorypubsub.NewReceivedEventPubSub() - relayDownloaderFactory := downloader.NewRelayDownloaderFactory(relayConnections, receivedEventPubSub, logger, prometheusPrometheus) - downloaderDownloader := downloader.NewDownloader(bootstrapRelaySource, databaseRelaySource, databasePublicKeySource, logger, prometheusPrometheus, relayDownloaderFactory) + currentTimeProvider := adapters.NewCurrentTimeProvider() + taskScheduler := downloader.NewTaskScheduler(cachedDatabasePublicKeySource, currentTimeProvider, logger) + relayDownloaderFactory := downloader.NewRelayDownloaderFactory(relayConnections, receivedEventPubSub, taskScheduler, logger, prometheusPrometheus) + downloaderDownloader := downloader.NewDownloader(bootstrapRelaySource, databaseRelaySource, cachedDatabasePublicKeySource, logger, prometheusPrometheus, relayDownloaderFactory) receivedEventSubscriber := memorypubsub2.NewReceivedEventSubscriber(receivedEventPubSub, saveReceivedEventHandler, logger) eventSavedEventSubscriber := sqlitepubsub.NewEventSavedEventSubscriber(processSavedEventHandler, subscriber, logger, prometheusPrometheus) metrics := timer.NewMetrics(application, logger) @@ -108,7 +111,7 @@ func BuildService(contextContext context.Context, configConfig config.Config) (S return Service{}, nil, err } loggingMigrationsProgressCallback := adapters.NewLoggingMigrationsProgressCallback(logger) - service := NewService(application, server, downloaderDownloader, receivedEventSubscriber, eventSavedEventSubscriber, metrics, transactionRunner, runner, migrationsMigrations, loggingMigrationsProgressCallback) + service := NewService(application, server, downloaderDownloader, receivedEventSubscriber, eventSavedEventSubscriber, metrics, transactionRunner, taskScheduler, runner, migrationsMigrations, loggingMigrationsProgressCallback) return service, func() { cleanup() }, nil @@ -261,6 +264,10 @@ type buildTransactionSqliteAdaptersDependencies struct { Logger logging.Logger } -var downloaderSet = wire.NewSet(downloader.NewRelayDownloaderFactory, downloader.NewDownloader, relays.NewBootstrapRelaySource, wire.Bind(new(downloader.BootstrapRelaySource), new(*relays.BootstrapRelaySource)), app.NewDatabaseRelaySource, wire.Bind(new(downloader.RelaySource), new(*app.DatabaseRelaySource)), relays.NewRelayConnections, wire.Bind(new(downloader.RelayConnections), new(*relays.RelayConnections)), app.NewDatabasePublicKeySource, wire.Bind(new(downloader.PublicKeySource), new(*app.DatabasePublicKeySource)), relays.NewEventSender, wire.Bind(new(app.EventSender), new(*relays.EventSender))) +var downloaderSet = wire.NewSet(downloader.NewRelayDownloaderFactory, downloader.NewDownloader, relays.NewBootstrapRelaySource, wire.Bind(new(downloader.BootstrapRelaySource), new(*relays.BootstrapRelaySource)), app.NewDatabaseRelaySource, wire.Bind(new(downloader.RelaySource), 
new(*app.DatabaseRelaySource)), relays.NewRelayConnections, wire.Bind(new(downloader.RelayConnections), new(*relays.RelayConnections)), app.NewDatabasePublicKeySource, newCachedPublicKeySource, wire.Bind(new(downloader.PublicKeySource), new(*app.CachedDatabasePublicKeySource)), relays.NewEventSender, wire.Bind(new(app.EventSender), new(*relays.EventSender)), downloader.NewTaskScheduler, wire.Bind(new(downloader.Scheduler), new(*downloader.TaskScheduler))) + +func newCachedPublicKeySource(underlying *app.DatabasePublicKeySource) *app.CachedDatabasePublicKeySource { + return app.NewCachedDatabasePublicKeySource(underlying) +} var domainSet = wire.NewSet(domain.NewRelaysExtractor, wire.Bind(new(app.RelaysExtractor), new(*domain.RelaysExtractor)), domain.NewContactsExtractor, wire.Bind(new(app.ContactsExtractor), new(*domain.ContactsExtractor))) diff --git a/internal/fixtures/fixtures.go b/internal/fixtures/fixtures.go index e0ff46f..e1849c4 100644 --- a/internal/fixtures/fixtures.go +++ b/internal/fixtures/fixtures.go @@ -8,11 +8,13 @@ import ( "math/rand" "os" "testing" + "time" "github.com/nbd-wtf/go-nostr" "github.com/planetary-social/nos-event-service/internal" "github.com/planetary-social/nos-event-service/internal/logging" "github.com/planetary-social/nos-event-service/service/domain" + "github.com/planetary-social/nos-event-service/service/domain/downloader" "github.com/stretchr/testify/require" ) @@ -186,6 +188,18 @@ func SomeMaybeRelayAddress() domain.MaybeRelayAddress { return domain.NewMaybeRelayAddress(SomeString()) } +func SomeTimeWindow() downloader.TimeWindow { + return downloader.MustNewTimeWindow(SomeTime(), SomeDuration()) +} + +func SomeTime() time.Time { + return time.Unix(int64(rand.Intn(10000000)), 0) +} + +func SomeDuration() time.Duration { + return time.Duration(1+rand.Intn(100)) * time.Second +} + var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") func randSeq(n int) string { diff --git a/service/adapters/current_time_provider.go b/service/adapters/current_time_provider.go new file mode 100644 index 0000000..62f55a7 --- /dev/null +++ b/service/adapters/current_time_provider.go @@ -0,0 +1,16 @@ +package adapters + +import ( + "time" +) + +type CurrentTimeProvider struct { +} + +func NewCurrentTimeProvider() *CurrentTimeProvider { + return &CurrentTimeProvider{} +} + +func (c *CurrentTimeProvider) GetCurrentTime() time.Time { + return time.Now() +} diff --git a/service/adapters/mocks/current_time_provider.go b/service/adapters/mocks/current_time_provider.go new file mode 100644 index 0000000..1de61ce --- /dev/null +++ b/service/adapters/mocks/current_time_provider.go @@ -0,0 +1,21 @@ +package mocks + +import ( + "time" +) + +type CurrentTimeProvider struct { + t time.Time +} + +func NewCurrentTimeProvider() *CurrentTimeProvider { + return &CurrentTimeProvider{} +} + +func (c *CurrentTimeProvider) SetCurrentTime(t time.Time) { + c.t = t +} + +func (c *CurrentTimeProvider) GetCurrentTime() time.Time { + return c.t +} diff --git a/service/app/sources.go b/service/app/sources.go index 552e751..61393ca 100644 --- a/service/app/sources.go +++ b/service/app/sources.go @@ -11,6 +11,8 @@ import ( "github.com/planetary-social/nos-event-service/service/domain/downloader" ) +const cachePublicKeysFor = 1 * time.Minute + type DatabaseRelaySource struct { transactionProvider TransactionProvider logger logging.Logger @@ -110,3 +112,30 @@ func (d *DatabasePublicKeySource) GetPublicKeys(ctx context.Context) (downloader publicKeysToMonitorFollowees.List(), ), 
nil } + +type CachedDatabasePublicKeySource struct { + keys *downloader.PublicKeys + t time.Time + source downloader.PublicKeySource +} + +func NewCachedDatabasePublicKeySource( + source downloader.PublicKeySource, +) *CachedDatabasePublicKeySource { + return &CachedDatabasePublicKeySource{ + source: source, + } +} + +func (d *CachedDatabasePublicKeySource) GetPublicKeys(ctx context.Context) (downloader.PublicKeys, error) { + if d.keys == nil || time.Since(d.t) > cachePublicKeysFor { + newKeys, err := d.source.GetPublicKeys(ctx) + if err != nil { + return downloader.PublicKeys{}, errors.Wrap(err, "error getting new public keys") + } + d.keys = &newKeys + d.t = time.Now() + } + + return *d.keys, nil +} diff --git a/service/domain/downloader/downloader.go b/service/domain/downloader/downloader.go index 29d2af7..ebcb0be 100644 --- a/service/domain/downloader/downloader.go +++ b/service/domain/downloader/downloader.go @@ -14,10 +14,9 @@ import ( ) const ( - downloadEventsFromLast = 1 * time.Hour - storeMetricsEvery = 1 * time.Minute + storeMetricsEvery = 1 * time.Minute - refreshKnownRelaysAndPublicKeysEvery = 1 * time.Minute + refreshKnownRelaysEvery = 1 * time.Minute ) var ( @@ -141,15 +140,8 @@ func (d *Downloader) Run(ctx context.Context) error { Message("error updating downloaders") } - if err := d.updatePublicKeys(ctx); err != nil { - d.logger. - Error(). - WithError(err). - Message("error updating downloaders") - } - select { - case <-time.After(refreshKnownRelaysAndPublicKeysEvery): + case <-time.After(refreshKnownRelaysEvery): case <-ctx.Done(): return ctx.Err() } @@ -209,7 +201,10 @@ func (d *Downloader) updateDownloaders(ctx context.Context) error { } ctx, cancel := context.WithCancel(ctx) - downloader.Start(ctx) + if err := downloader.Start(ctx); err != nil { + cancel() + return errors.Wrap(err, "error starting a downloader") + } d.relayDownloaders[relayAddress] = runningRelayDownloader{ Context: ctx, CancelFunc: cancel, @@ -239,27 +234,6 @@ func (d *Downloader) getRelays(ctx context.Context) (*internal.Set[domain.RelayA return result, nil } -func (d *Downloader) updatePublicKeys(ctx context.Context) error { - publicKeys, err := d.publicKeySource.GetPublicKeys(ctx) - if err != nil { - return errors.Wrap(err, "error getting public keys") - } - - d.relayDownloadersLock.Lock() - defer d.relayDownloadersLock.Unlock() - - isDifferentThanPrevious := publicKeys.Equal(d.previousPublicKeys) - - for _, v := range d.relayDownloaders { - if err := v.RelayDownloader.UpdateSubscription(v.Context, isDifferentThanPrevious, publicKeys); err != nil { - return errors.Wrap(err, "error updating subscription") - } - } - - d.previousPublicKeys = publicKeys - return nil -} - type runningRelayDownloader struct { Context context.Context CancelFunc context.CancelFunc @@ -267,11 +241,8 @@ type runningRelayDownloader struct { } type RelayDownloader struct { - address domain.RelayAddress - - publicKeySubscriptionCancelFunc context.CancelFunc - publicKeySubscriptionLock sync.Mutex - + address domain.RelayAddress + scheduler Scheduler receivedEventPublisher ReceivedEventPublisher relayConnections RelayConnections logger logging.Logger @@ -280,14 +251,15 @@ type RelayDownloader struct { func NewRelayDownloader( address domain.RelayAddress, + scheduler Scheduler, receivedEventPublisher ReceivedEventPublisher, relayConnections RelayConnections, logger logging.Logger, metrics Metrics, ) *RelayDownloader { v := &RelayDownloader{ - address: address, - + address: address, + scheduler: scheduler, receivedEventPublisher: 
receivedEventPublisher, relayConnections: relayConnections, logger: logger.New(fmt.Sprintf("relayDownloader(%s)", address.String())), @@ -296,35 +268,38 @@ func NewRelayDownloader( return v } -func (d *RelayDownloader) Start(ctx context.Context) { - go d.downloadMessages(ctx, domain.NewFilter( - nil, - globalEventKindsToDownload, - nil, - nil, - d.downloadSince(), - )) -} +func (d *RelayDownloader) Start(ctx context.Context) error { + ch, err := d.scheduler.GetTasks(ctx, d.address) + if err != nil { + return errors.Wrap(err, "error getting task channel") + } -func (d *RelayDownloader) downloadSince() *time.Time { - return internal.Pointer(time.Now().Add(-downloadEventsFromLast)) + go func() { + for task := range ch { + go d.performTask(task) + } + }() + return nil } -func (d *RelayDownloader) downloadMessages(ctx context.Context, filter domain.Filter) { - if err := d.downloadMessagesWithErr(ctx, filter); err != nil { +func (d *RelayDownloader) performTask(task Task) { + if err := d.performTaskWithErr(task); err != nil { d.logger.Error().WithError(err).Message("error downloading messages") + task.OnError(err) } } -func (d *RelayDownloader) downloadMessagesWithErr(ctx context.Context, filter domain.Filter) error { - ch, err := d.relayConnections.GetEvents(ctx, d.address, filter) +func (d *RelayDownloader) performTaskWithErr(task Task) error { + ch, err := d.relayConnections.GetEvents(task.Ctx(), d.address, task.Filter()) if err != nil { return errors.Wrap(err, "error getting events ch") } for eventOrEOSE := range ch { - if !eventOrEOSE.EOSE() { + if eventOrEOSE.EOSE() { + task.OnReceivedEOSE() + } else { d.metrics.ReportReceivedEvent(d.address) d.receivedEventPublisher.Publish(d.address, eventOrEOSE.Event()) } @@ -333,52 +308,10 @@ func (d *RelayDownloader) downloadMessagesWithErr(ctx context.Context, filter do return nil } -func (d *RelayDownloader) UpdateSubscription(ctx context.Context, isDifferentThanPrevious bool, publicKeys PublicKeys) error { - d.publicKeySubscriptionLock.Lock() - defer d.publicKeySubscriptionLock.Unlock() - - if d.publicKeySubscriptionCancelFunc != nil && !isDifferentThanPrevious { - return nil - } - - if d.publicKeySubscriptionCancelFunc != nil { - d.publicKeySubscriptionCancelFunc() - } - - var pTags []domain.FilterTag - for _, publicKey := range publicKeys.PublicKeysToMonitor() { - tag, err := domain.NewFilterTag(domain.TagProfile, publicKey.Hex()) - if err != nil { - return errors.Wrap(err, "error creating a filter tag") - } - pTags = append(pTags, tag) - } - - ctx, cancel := context.WithCancel(ctx) - - go d.downloadMessages(ctx, domain.NewFilter( - nil, - nil, - nil, - publicKeys.All(), - d.downloadSince(), - )) - - go d.downloadMessages(ctx, domain.NewFilter( - nil, - nil, - pTags, - nil, - d.downloadSince(), - )) - - d.publicKeySubscriptionCancelFunc = cancel - return nil -} - type RelayDownloaderFactory struct { relayConnections RelayConnections receivedEventPublisher ReceivedEventPublisher + scheduler Scheduler logger logging.Logger metrics Metrics } @@ -386,12 +319,14 @@ type RelayDownloaderFactory struct { func NewRelayDownloaderFactory( relayConnections RelayConnections, receivedEventPublisher ReceivedEventPublisher, + scheduler Scheduler, logger logging.Logger, metrics Metrics, ) *RelayDownloaderFactory { return &RelayDownloaderFactory{ relayConnections: relayConnections, receivedEventPublisher: receivedEventPublisher, + scheduler: scheduler, logger: logger.New("relayDownloaderFactory"), metrics: metrics, } @@ -400,6 +335,7 @@ func 
NewRelayDownloaderFactory( func (r *RelayDownloaderFactory) CreateRelayDownloader(address domain.RelayAddress) (*RelayDownloader, error) { return NewRelayDownloader( address, + r.scheduler, r.receivedEventPublisher, r.relayConnections, r.logger, diff --git a/service/domain/downloader/scheduler.go b/service/domain/downloader/scheduler.go new file mode 100644 index 0000000..4e47876 --- /dev/null +++ b/service/domain/downloader/scheduler.go @@ -0,0 +1,400 @@ +package downloader + +import ( + "context" + "slices" + "sync" + "time" + + "github.com/boreq/errors" + "github.com/planetary-social/nos-event-service/internal/logging" + "github.com/planetary-social/nos-event-service/service/domain" +) + +const ( + sendOutTasksEvery = 10 * time.Millisecond + + initialWindowAge = 1 * time.Hour + windowSize = 1 * time.Minute + + timeWindowTaskConcurrency = 1 +) + +type CurrentTimeProvider interface { + GetCurrentTime() time.Time +} + +type Task interface { + Ctx() context.Context + Filter() domain.Filter + + OnReceivedEOSE() + OnError(err error) +} + +type Scheduler interface { + GetTasks(ctx context.Context, relay domain.RelayAddress) (<-chan Task, error) +} + +type TaskScheduler struct { + taskGeneratorsLock sync.Mutex + taskGenerators map[domain.RelayAddress]*RelayTaskGenerator + + publicKeySource PublicKeySource + currentTimeProvider CurrentTimeProvider + logger logging.Logger +} + +func NewTaskScheduler( + publicKeySource PublicKeySource, + currentTimeProvider CurrentTimeProvider, + logger logging.Logger, +) *TaskScheduler { + return &TaskScheduler{ + taskGenerators: make(map[domain.RelayAddress]*RelayTaskGenerator), + publicKeySource: publicKeySource, + currentTimeProvider: currentTimeProvider, + logger: logger.New("taskScheduler"), + } +} + +func (t *TaskScheduler) GetTasks(ctx context.Context, relay domain.RelayAddress) (<-chan Task, error) { + generator, err := t.getOrCreateGeneratorWithLock(relay) + if err != nil { + return nil, errors.Wrap(err, "error getting a generator") + } + + ch := make(chan Task) + generator.AddSubscription(ctx, ch) + return ch, nil +} + +func (t *TaskScheduler) Run(ctx context.Context) error { + for { + hadTasks, err := t.sendOutTasks() + if err != nil { + return errors.Wrap(err, "error sending out tasks") + } + + if hadTasks { + select { + case <-ctx.Done(): + return ctx.Err() + default: + continue + } + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(sendOutTasksEvery): + continue + } + } +} + +func (t *TaskScheduler) sendOutTasks() (bool, error) { + t.taskGeneratorsLock.Lock() + defer t.taskGeneratorsLock.Unlock() + + atLeastOneHadTasks := false + for _, taskGenerator := range t.taskGenerators { + hadTasks, err := taskGenerator.SendOutTasks() + if err != nil { + return false, errors.Wrap(err, "error calling task generator") + } + if hadTasks { + atLeastOneHadTasks = true + } + } + + return atLeastOneHadTasks, nil +} + +func (t *TaskScheduler) getOrCreateGeneratorWithLock(address domain.RelayAddress) (*RelayTaskGenerator, error) { + t.taskGeneratorsLock.Lock() + defer t.taskGeneratorsLock.Unlock() + + v, ok := t.taskGenerators[address] + if ok { + return v, nil + } + + v, err := NewRelayTaskGenerator(t.publicKeySource, t.currentTimeProvider, t.logger) + if err != nil { + return nil, errors.Wrap(err, "error creating a task generator") + } + t.taskGenerators[address] = v + return v, nil +} + +type taskSubscription struct { + ctx context.Context + ch chan Task +} + +func newTaskSubscription(ctx context.Context, ch chan Task) *taskSubscription { 
+ return &taskSubscription{ + ctx: ctx, + ch: ch, + } +} + +type RelayTaskGenerator struct { + lock sync.Mutex + + taskSubscriptions []*taskSubscription + + globalTask *TimeWindowTaskGenerator + authorTask *TimeWindowTaskGenerator + tagTask *TimeWindowTaskGenerator + + publicKeySource PublicKeySource + logger logging.Logger +} + +func NewRelayTaskGenerator( + publicKeySource PublicKeySource, + currentTimeProvider CurrentTimeProvider, + logger logging.Logger, +) (*RelayTaskGenerator, error) { + globalTask, err := NewTimeWindowTaskGenerator( + globalEventKindsToDownload, + nil, + nil, + currentTimeProvider, + logger, + ) + if err != nil { + return nil, errors.Wrap(err, "error creating the global task") + } + authorTask, err := NewTimeWindowTaskGenerator( + nil, + nil, + nil, + currentTimeProvider, + logger, + ) + if err != nil { + return nil, errors.Wrap(err, "error creating the author task") + } + tagTask, err := NewTimeWindowTaskGenerator( + nil, + nil, + nil, + currentTimeProvider, + logger, + ) + if err != nil { + return nil, errors.Wrap(err, "error creating the tag task") + } + + return &RelayTaskGenerator{ + publicKeySource: publicKeySource, + globalTask: globalTask, + authorTask: authorTask, + tagTask: tagTask, + logger: logger.New("relayTaskGenerator"), + }, nil +} + +func (t *RelayTaskGenerator) AddSubscription(ctx context.Context, ch chan Task) { + t.lock.Lock() + defer t.lock.Unlock() + + taskSubscription := newTaskSubscription(ctx, ch) + t.taskSubscriptions = append(t.taskSubscriptions, taskSubscription) +} + +func (t *RelayTaskGenerator) SendOutTasks() (bool, error) { + t.lock.Lock() + defer t.lock.Unlock() + + slices.DeleteFunc(t.taskSubscriptions, func(subscription *taskSubscription) bool { + select { + case <-subscription.ctx.Done(): + return true + default: + return false + } + }) + + sentTasksForAtLeastOneSubscription := false + for _, taskSubscription := range t.taskSubscriptions { + numberOfSentTasks, err := t.pushTasks(taskSubscription.ctx, taskSubscription.ch) + if err != nil { + return false, errors.Wrap(err, "error sending out generators") + } + if numberOfSentTasks > 0 { + sentTasksForAtLeastOneSubscription = true + } + } + + return sentTasksForAtLeastOneSubscription, nil +} + +func (t *RelayTaskGenerator) pushTasks(ctx context.Context, ch chan<- Task) (int, error) { + tasks, err := t.getTasksToPush(ctx) + if err != nil { + return 0, errors.Wrap(err, "error getting tasks to push") + } + + for _, task := range tasks { + select { + case <-ctx.Done(): + return 0, ctx.Err() + case ch <- task: + continue + } + } + return len(tasks), nil +} + +func (t *RelayTaskGenerator) getTasksToPush(ctx context.Context) ([]Task, error) { + if err := t.updateFilters(ctx); err != nil { + return nil, errors.Wrap(err, "error updating filters") + } + + var result []Task + for _, generator := range t.generators() { + tasks, err := generator.Generate(ctx) + if err != nil { + return nil, errors.Wrap(err, "error calling one of the generators") + } + result = append(result, tasks...) 
+ } + return result, nil +} + +func (t *RelayTaskGenerator) generators() []*TimeWindowTaskGenerator { + generators := []*TimeWindowTaskGenerator{t.globalTask} + + if len(t.authorTask.authors) > 0 { + generators = append(generators, t.authorTask) + } + + if len(t.tagTask.tags) > 0 { + generators = append(generators, t.tagTask) + } + + return generators +} + +func (t *RelayTaskGenerator) updateFilters(ctx context.Context) error { + publicKeys, err := t.publicKeySource.GetPublicKeys(ctx) + if err != nil { + return errors.Wrap(err, "error getting public keys") + } + + var pTags []domain.FilterTag + for _, publicKey := range publicKeys.PublicKeysToMonitor() { + tag, err := domain.NewFilterTag(domain.TagProfile, publicKey.Hex()) + if err != nil { + return errors.Wrap(err, "error creating a filter tag") + } + pTags = append(pTags, tag) + } + + t.authorTask.UpdateAuthors(publicKeys.All()) + t.tagTask.UpdateTags(pTags) + return nil +} + +type TimeWindowTaskGenerator struct { + kinds []domain.EventKind + tags []domain.FilterTag + authors []domain.PublicKey + + lastWindow TimeWindow + runningTimeWindowTasks []*TimeWindowTask + lock sync.Mutex + + currentTimeProvider CurrentTimeProvider + logger logging.Logger +} + +func NewTimeWindowTaskGenerator( + kinds []domain.EventKind, + tags []domain.FilterTag, + authors []domain.PublicKey, + currentTimeProvider CurrentTimeProvider, + logger logging.Logger, +) (*TimeWindowTaskGenerator, error) { + now := currentTimeProvider.GetCurrentTime() + + startingWindow, err := NewTimeWindow(now.Add(-initialWindowAge-windowSize), windowSize) + if err != nil { + return nil, errors.Wrap(err, "error creating the starting time window") + } + + return &TimeWindowTaskGenerator{ + lastWindow: startingWindow, + kinds: kinds, + tags: tags, + authors: authors, + currentTimeProvider: currentTimeProvider, + logger: logger.New("timeWindowTaskGenerator"), + }, nil +} + +func (t *TimeWindowTaskGenerator) Generate(ctx context.Context) ([]Task, error) { + t.lock.Lock() + defer t.lock.Unlock() + + t.runningTimeWindowTasks = slices.DeleteFunc(t.runningTimeWindowTasks, func(task *TimeWindowTask) bool { + return task.CheckIfDoneAndEnd() + }) + + var result []Task + + for i := len(t.runningTimeWindowTasks); i < timeWindowTaskConcurrency; i++ { + task, ok, err := t.maybeGenerateNewTask(ctx) + if err != nil { + return nil, errors.Wrap(err, "error generating a new task") + } + if ok { + t.runningTimeWindowTasks = append(t.runningTimeWindowTasks, task) + result = append(result, task) + } + } + + for _, task := range t.runningTimeWindowTasks { + ok, err := task.MaybeReset(ctx, t.kinds, t.tags, t.authors) + if err != nil { + return nil, errors.Wrap(err, "error resetting a task") + } + if ok { + result = append(result, task) + } + } + return result, nil +} + +func (t *TimeWindowTaskGenerator) UpdateTags(tags []domain.FilterTag) { + t.lock.Lock() + defer t.lock.Unlock() + + t.tags = tags +} + +func (t *TimeWindowTaskGenerator) UpdateAuthors(authors []domain.PublicKey) { + t.lock.Lock() + defer t.lock.Unlock() + + t.authors = authors +} + +func (t *TimeWindowTaskGenerator) maybeGenerateNewTask(ctx context.Context) (*TimeWindowTask, bool, error) { + nextWindow := t.lastWindow.Advance() + now := t.currentTimeProvider.GetCurrentTime() + if nextWindow.End().After(now.Add(-time.Minute)) { + return nil, false, nil + } + t.lastWindow = nextWindow + v, err := NewTimeWindowTask(ctx, t.kinds, t.tags, t.authors, nextWindow) + if err != nil { + return nil, false, errors.Wrap(err, "error creating a task") + } + 
return v, true, nil +} diff --git a/service/domain/downloader/scheduler_test.go b/service/domain/downloader/scheduler_test.go new file mode 100644 index 0000000..32a54e5 --- /dev/null +++ b/service/domain/downloader/scheduler_test.go @@ -0,0 +1,305 @@ +package downloader_test + +import ( + "context" + "slices" + "sync" + "testing" + "time" + + "github.com/planetary-social/nos-event-service/internal" + "github.com/planetary-social/nos-event-service/internal/fixtures" + "github.com/planetary-social/nos-event-service/internal/logging" + "github.com/planetary-social/nos-event-service/service/adapters/mocks" + "github.com/planetary-social/nos-event-service/service/domain" + "github.com/planetary-social/nos-event-service/service/domain/downloader" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + numberOfTaskTypes = 3 + testTimeout = 10 * time.Second + delayWhenWaitingToConsiderThatTasksReceived = 5 * time.Second + + waitFor = 10 * time.Second + tick = 100 * time.Millisecond +) + +func TestTaskScheduler_SchedulerWaitsForTasksToCompleteBeforeProducingMore(t *testing.T) { + t.Parallel() + + ctx := fixtures.TestContext(t) + ts := newTestedTaskScheduler(ctx, t) + + start := date(2023, time.December, 27, 10, 30, 00) + ts.CurrentTimeProvider.SetCurrentTime(start) + ts.PublicKeySource.SetPublicKeys(downloader.NewPublicKeys( + []domain.PublicKey{fixtures.SomePublicKey()}, + []domain.PublicKey{fixtures.SomePublicKey()}, + )) + + ch, err := ts.Scheduler.GetTasks(ctx, fixtures.SomeRelayAddress()) + require.NoError(t, err) + + tasks := collectAllTasks(ctx, ch, false) + + require.EventuallyWithT(t, func(t *assert.CollectT) { + require.Equal(t, numberOfTaskTypes, len(tasks.Tasks())) + }, waitFor, tick) + + <-time.After(5 * time.Second) + require.Len(t, tasks.Tasks(), numberOfTaskTypes) +} + +func TestTaskScheduler_SchedulerDoesNotProduceEmptyTasks(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(fixtures.TestContext(t), testTimeout) + defer cancel() + + start := date(2023, time.December, 27, 10, 30, 00) + + ts := newTestedTaskScheduler(ctx, t) + ts.CurrentTimeProvider.SetCurrentTime(start) + ts.PublicKeySource.SetPublicKeys(downloader.NewPublicKeys(nil, nil)) + + ch, err := ts.Scheduler.GetTasks(ctx, fixtures.SomeRelayAddress()) + require.NoError(t, err) + + tasks := collectAllTasks(ctx, ch, false) + + <-time.After(5 * time.Second) + require.Len(t, tasks.Tasks(), 1) +} + +func TestTaskScheduler_SchedulerProducesTasksFromSequentialTimeWindowsLeadingUpToCurrentTime(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(fixtures.TestContext(t), testTimeout) + defer cancel() + + start := date(2023, time.December, 27, 10, 30, 00) + + ts := newTestedTaskScheduler(ctx, t) + ts.CurrentTimeProvider.SetCurrentTime(start) + ts.PublicKeySource.SetPublicKeys(downloader.NewPublicKeys( + []domain.PublicKey{fixtures.SomePublicKey()}, + []domain.PublicKey{fixtures.SomePublicKey()}, + )) + + ch, err := ts.Scheduler.GetTasks(ctx, fixtures.SomeRelayAddress()) + require.NoError(t, err) + + tasks := collectAllTasks(ctx, ch, true) + + require.EventuallyWithT(t, func(t *assert.CollectT) { + filters := make(map[downloader.TimeWindow][]domain.Filter) + for _, task := range tasks.Tasks() { + start := task.Filter().Since() + duration := task.Filter().Until().Sub(*start) + window := downloader.MustNewTimeWindow(*start, duration) + filters[window] = append(filters[window], task.Filter()) + } + + firstWindowStart := date(2023, time.December, 27, 9, 30, 00) + 
var expectedWindows []downloader.TimeWindow + for i := 0; i < 59; i++ { + window := downloader.MustNewTimeWindow(firstWindowStart.Add(time.Duration(i)*time.Minute), 1*time.Minute) + expectedWindows = append(expectedWindows, window) + } + + var windows []downloader.TimeWindow + for window, filters := range filters { + assert.Equal(t, numberOfTaskTypes, len(filters)) + windows = append(windows, window) + } + + cmp := func(a, b downloader.TimeWindow) int { + return a.Start().Compare(b.Start()) + } + + slices.SortFunc(expectedWindows, cmp) + slices.SortFunc(windows, cmp) + assertEqualWindows(t, expectedWindows, windows) + }, waitFor, tick) +} + +func TestTaskScheduler_ThereIsOneWindowOfDelayToLetRelaysSyncData(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(fixtures.TestContext(t), testTimeout) + defer cancel() + + start := date(2023, time.December, 27, 10, 30, 00) + + ts := newTestedTaskScheduler(ctx, t) + ts.CurrentTimeProvider.SetCurrentTime(start) + ts.PublicKeySource.SetPublicKeys(downloader.NewPublicKeys( + nil, + nil, + )) + + ch, err := ts.Scheduler.GetTasks(ctx, fixtures.SomeRelayAddress()) + require.NoError(t, err) + + tasks := collectAllTasks(ctx, ch, true) + + require.EventuallyWithT(t, func(t *assert.CollectT) { + var windows []downloader.TimeWindow + for _, task := range tasks.Tasks() { + start := task.Filter().Since() + duration := task.Filter().Until().Sub(*start) + window := downloader.MustNewTimeWindow(*start, duration) + windows = append(windows, window) + } + + slices.SortFunc(windows, func(a, b downloader.TimeWindow) int { + return a.Start().Compare(b.Start()) + }) + + assert.Len(t, windows, 59) + lastWindow := windows[len(windows)-1] + assert.Equal(t, date(2023, time.December, 27, 10, 29, 00), lastWindow.End().UTC()) + }, waitFor, tick) +} + +func TestTaskScheduler_TerminatesTasks(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(fixtures.TestContext(t), testTimeout) + defer cancel() + + start := date(2023, time.December, 27, 10, 30, 00) + + ts := newTestedTaskScheduler(ctx, t) + ts.CurrentTimeProvider.SetCurrentTime(start) + + ch, err := ts.Scheduler.GetTasks(ctx, fixtures.SomeRelayAddress()) + require.NoError(t, err) + + firstTaskCh := make(chan downloader.Task) + + go func() { + first := true + for { + select { + case <-ctx.Done(): + return + case v := <-ch: + go v.OnReceivedEOSE() + if first { + first = false + select { + case firstTaskCh <- v: + case <-ctx.Done(): + return + } + } + } + } + }() + + select { + case v := <-firstTaskCh: + require.EventuallyWithT(t, func(collect *assert.CollectT) { + assert.Error(collect, v.Ctx().Err()) + }, 5*time.Second, 10*time.Millisecond) + case <-ctx.Done(): + t.Fatal("timeout") + } + +} + +type testedTaskScheduler struct { + Scheduler *downloader.TaskScheduler + CurrentTimeProvider *mocks.CurrentTimeProvider + PublicKeySource *mockPublicKeySource +} + +func newTestedTaskScheduler(ctx context.Context, tb testing.TB) *testedTaskScheduler { + currentTimeProvider := mocks.NewCurrentTimeProvider() + source := newMockPublicKeySource() + logger := logging.NewDevNullLogger() + scheduler := downloader.NewTaskScheduler(source, currentTimeProvider, logger) + go func() { + _ = scheduler.Run(ctx) + }() + + return &testedTaskScheduler{ + Scheduler: scheduler, + PublicKeySource: source, + CurrentTimeProvider: currentTimeProvider, + } +} + +type mockPublicKeySource struct { + publicKeys downloader.PublicKeys +} + +func newMockPublicKeySource() *mockPublicKeySource { + return &mockPublicKeySource{ + 
publicKeys: downloader.NewPublicKeys(nil, nil), + } +} + +func (p *mockPublicKeySource) SetPublicKeys(publicKeys downloader.PublicKeys) { + p.publicKeys = publicKeys +} + +func (p *mockPublicKeySource) GetPublicKeys(ctx context.Context) (downloader.PublicKeys, error) { + return p.publicKeys, nil +} + +func date(year int, month time.Month, day, hour, min, sec int) time.Time { + return time.Date(year, month, day, hour, min, sec, 0, time.UTC) +} + +func assertEqualWindows(tb require.TestingT, a, b []downloader.TimeWindow) { + assert.Equal(tb, len(a), len(b)) + if len(a) == len(b) { + for i := 0; i < len(a); i++ { + assert.True(tb, a[i].Start().Equal(b[i].Start())) + assert.True(tb, a[i].End().Equal(b[i].End())) + } + } +} + +type collectedTasks struct { + tasks []downloader.Task + l sync.Mutex +} + +func newCollectedTasks() *collectedTasks { + return &collectedTasks{} +} + +func (c *collectedTasks) add(task downloader.Task) { + c.l.Lock() + defer c.l.Unlock() + c.tasks = append(c.tasks, task) +} + +func (c *collectedTasks) Tasks() []downloader.Task { + c.l.Lock() + defer c.l.Unlock() + return internal.CopySlice(c.tasks) +} + +func collectAllTasks(ctx context.Context, ch <-chan downloader.Task, ackAll bool) *collectedTasks { + v := newCollectedTasks() + go func() { + for { + select { + case <-ctx.Done(): + return + case task := <-ch: + if ackAll { + go task.OnReceivedEOSE() + } + v.add(task) + } + } + }() + return v +} diff --git a/service/domain/downloader/time_window.go b/service/domain/downloader/time_window.go new file mode 100644 index 0000000..b104932 --- /dev/null +++ b/service/domain/downloader/time_window.go @@ -0,0 +1,44 @@ +package downloader + +import ( + "time" + + "github.com/boreq/errors" +) + +type TimeWindow struct { + start time.Time + duration time.Duration +} + +func NewTimeWindow(start time.Time, duration time.Duration) (TimeWindow, error) { + if duration == 0 { + return TimeWindow{}, errors.New("time window must have a duration") + } + return TimeWindow{start: start, duration: duration}, nil +} + +func MustNewTimeWindow(start time.Time, duration time.Duration) TimeWindow { + v, err := NewTimeWindow(start, duration) + if err != nil { + panic(err) + } + return v +} + +func (t TimeWindow) Start() time.Time { + return t.start +} + +func (t TimeWindow) End() time.Time { + return t.start.Add(t.duration) +} + +func (t TimeWindow) Advance() TimeWindow { + newStart := t.start.Add(t.duration) + newWindow, err := NewTimeWindow(newStart, t.duration) + if err != nil { + panic(err) // guaranteed by invariants + } + return newWindow +} diff --git a/service/domain/downloader/time_window_task.go b/service/domain/downloader/time_window_task.go new file mode 100644 index 0000000..109d298 --- /dev/null +++ b/service/domain/downloader/time_window_task.go @@ -0,0 +1,138 @@ +package downloader + +import ( + "context" + "sync" + + "github.com/boreq/errors" + "github.com/planetary-social/nos-event-service/internal" + "github.com/planetary-social/nos-event-service/service/domain" +) + +var ( + TimeWindowTaskStateStarted = TimeWindowTaskState{"started"} + TimeWindowTaskStateDone = TimeWindowTaskState{"done"} + TimeWindowTaskStateError = TimeWindowTaskState{"error"} +) + +type TimeWindowTaskState struct { + s string +} + +type TimeWindowTask struct { + filter domain.Filter + window TimeWindow + + ctx context.Context + cancel context.CancelFunc + state TimeWindowTaskState + lock sync.Mutex +} + +func NewTimeWindowTask( + ctx context.Context, + kinds []domain.EventKind, + tags []domain.FilterTag, + 
authors []domain.PublicKey, + window TimeWindow, +) (*TimeWindowTask, error) { + ctx, cancel := context.WithCancel(ctx) + + t := &TimeWindowTask{ + ctx: ctx, + cancel: cancel, + state: TimeWindowTaskStateStarted, + window: window, + } + if err := t.updateFilter(kinds, tags, authors); err != nil { + return nil, errors.New("error updating the filter") + } + return t, nil +} + +func (t *TimeWindowTask) Ctx() context.Context { + return t.ctx +} + +func (t *TimeWindowTask) Filter() domain.Filter { + return t.filter +} + +func (t *TimeWindowTask) OnReceivedEOSE() { + t.lock.Lock() + defer t.lock.Unlock() + + t.state = TimeWindowTaskStateDone +} + +func (t *TimeWindowTask) OnError(err error) { + t.lock.Lock() + defer t.lock.Unlock() + + if t.state != TimeWindowTaskStateDone { + t.state = TimeWindowTaskStateError + } +} + +func (t *TimeWindowTask) CheckIfDoneAndEnd() bool { + t.lock.Lock() + defer t.lock.Unlock() + + if t.state != TimeWindowTaskStateDone { + return false + } + + t.cancel() + return true +} + +func (t *TimeWindowTask) MaybeReset(ctx context.Context, kinds []domain.EventKind, tags []domain.FilterTag, authors []domain.PublicKey) (bool, error) { + t.lock.Lock() + defer t.lock.Unlock() + + if t.state == TimeWindowTaskStateDone { + return false, errors.New("why are we trying to reset a completed task?") + } + + if !t.isDead() { + return false, nil + } + + t.cancel() + + ctx, cancel := context.WithCancel(ctx) + t.ctx = ctx + t.cancel = cancel + t.state = TimeWindowTaskStateStarted + + if err := t.updateFilter(kinds, tags, authors); err != nil { + return false, errors.New("error updating the filter") + } + + return true, nil +} + +func (t *TimeWindowTask) updateFilter(kinds []domain.EventKind, tags []domain.FilterTag, authors []domain.PublicKey) error { + filter, err := domain.NewFilter( + nil, + kinds, + tags, + authors, + internal.Pointer(t.window.Start()), + internal.Pointer(t.window.End()), + ) + if err != nil { + return errors.Wrap(err, "error creating a filter") + } + + t.filter = filter + + return nil +} + +func (t *TimeWindowTask) isDead() bool { + if err := t.ctx.Err(); err != nil { + return true + } + return t.state == TimeWindowTaskStateError +} diff --git a/service/domain/downloader/time_window_task_test.go b/service/domain/downloader/time_window_task_test.go new file mode 100644 index 0000000..85e8df7 --- /dev/null +++ b/service/domain/downloader/time_window_task_test.go @@ -0,0 +1,22 @@ +package downloader_test + +import ( + "testing" + + "github.com/planetary-social/nos-event-service/internal/fixtures" + "github.com/planetary-social/nos-event-service/service/domain/downloader" + "github.com/stretchr/testify/require" +) + +func TestTimeWindowTask_ReportingErrorsAfterTaskIsConsideredToBeDoneShouldBeIgnored(t *testing.T) { + ctx := fixtures.TestContext(t) + + task, err := downloader.NewTimeWindowTask(ctx, nil, nil, nil, fixtures.SomeTimeWindow()) + require.NoError(t, err) + + task.OnReceivedEOSE() + task.OnError(fixtures.SomeError()) + + ok := task.CheckIfDoneAndEnd() + require.True(t, ok) +} diff --git a/service/domain/downloader/time_window_test.go b/service/domain/downloader/time_window_test.go new file mode 100644 index 0000000..f85f98d --- /dev/null +++ b/service/domain/downloader/time_window_test.go @@ -0,0 +1,23 @@ +package downloader + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestTimeWindow(t *testing.T) { + start := time.Now() + + first, err := NewTimeWindow(start, time.Minute) + require.NoError(t, err) + + require.Equal(t, 
start, first.Start())
+	require.Equal(t, start.Add(time.Minute), first.End())
+
+	second := first.Advance()
+	require.Equal(t, start.Add(time.Minute), second.Start())
+	require.Equal(t, start.Add(2*time.Minute), second.End())
+	require.True(t, second.Start().Equal(first.End()))
+}
diff --git a/service/domain/filter.go b/service/domain/filter.go
index 52a4033..415f7a9 100644
--- a/service/domain/filter.go
+++ b/service/domain/filter.go
@@ -18,7 +18,8 @@ func NewFilter(
 	eventTags []FilterTag,
 	authors []PublicKey,
 	since *time.Time,
-) Filter {
+	until *time.Time,
+) (Filter, error) {
 	filter := nostr.Filter{
 		IDs:     nil,
 		Kinds:   nil,
@@ -50,13 +51,52 @@ func NewFilter(
 		filter.Authors = append(filter.Authors, author.Hex())
 	}
 
+	if since != nil && until != nil {
+		if !since.Before(*until) {
+			return Filter{}, errors.New("since must be before until")
+		}
+	}
+
 	if since != nil {
 		filter.Since = internal.Pointer(nostr.Timestamp(since.Unix()))
 	}
 
+	if until != nil {
+		filter.Until = internal.Pointer(nostr.Timestamp(until.Unix()))
+	}
+
 	return Filter{
 		filter: filter,
+	}, nil
+}
+
+func MustNewFilter(
+	eventIDs []EventId,
+	eventKinds []EventKind,
+	eventTags []FilterTag,
+	authors []PublicKey,
+	since *time.Time,
+	until *time.Time,
+) Filter {
+	v, err := NewFilter(eventIDs, eventKinds, eventTags, authors, since, until)
+	if err != nil {
+		panic(err)
+	}
+	return v
+}
+
+func (f Filter) Since() *time.Time {
+	if f.filter.Since == nil {
+		return nil
+	}
+	return internal.Pointer(time.Unix(int64(*f.filter.Since), 0))
+}
+
+func (f Filter) Until() *time.Time {
+	if f.filter.Until == nil {
+		return nil
 	}
+	return internal.Pointer(time.Unix(int64(*f.filter.Until), 0))
 }
 
 func (e Filter) Libfilter() nostr.Filter {
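Taken together, the pieces above change the downloader's contract: RelayDownloader no longer builds its own long-lived subscriptions but asks the TaskScheduler for a stream of window-scoped tasks and resolves each one. A condensed sketch of a consumer of that contract, based on the RelayDownloader changes in this patch; downloadWindow is a made-up stand-in for RelayConnections.GetEvents, not an API from the codebase:

package example

import (
	"context"

	"github.com/planetary-social/nos-event-service/service/domain"
	"github.com/planetary-social/nos-event-service/service/domain/downloader"
)

// consumeTasks mirrors what RelayDownloader.Start does after this patch:
// subscribe to the scheduler's task stream for one relay and resolve every
// task by reporting either EOSE or an error.
func consumeTasks(ctx context.Context, scheduler downloader.Scheduler, relay domain.RelayAddress) error {
	ch, err := scheduler.GetTasks(ctx, relay)
	if err != nil {
		return err
	}

	for task := range ch {
		go func(task downloader.Task) {
			// task.Filter() now carries both "since" and "until", i.e. one
			// time window; task.Ctx() is cancelled once the window is done.
			if err := downloadWindow(task.Ctx(), relay, task.Filter()); err != nil {
				task.OnError(err) // a failed window is reissued by the generator
				return
			}
			task.OnReceivedEOSE() // window complete, the next one can be scheduled
		}(task)
	}
	return nil
}

// downloadWindow is a hypothetical placeholder for the real relay
// subscription and exists only for illustration.
func downloadWindow(ctx context.Context, relay domain.RelayAddress, filter domain.Filter) error {
	return nil
}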