From 69d876d519658c957c12bfa9422dd6c472e7bf1b Mon Sep 17 00:00:00 2001 From: Sergei Khoroshilov Date: Wed, 27 Mar 2024 02:45:51 +0100 Subject: [PATCH] feat: migrate other services to use cases (2nd iteration) --- cmd/swat4master/application/application.go | 2 - cmd/swat4master/container/container.go | 39 ++- cmd/swat4master/modules/cleaner/cleaner.go | 40 ++- .../modules/refresher/refresher.go | 25 +- cmd/swat4master/modules/reviver/reviver.go | 30 +- internal/core/entities/addr/addr_test.go | 74 +++++ internal/core/entities/addr/public.go | 36 +++ internal/core/entities/addr/public_test.go | 58 ++++ internal/core/entities/server/server_test.go | 29 +- internal/core/usecases/addserver/addserver.go | 16 +- .../core/usecases/addserver/addserver_test.go | 65 +--- .../usecases/cleanservers/cleanservers.go} | 73 ++--- .../cleanservers/cleanservers_test.go | 190 ++++++++++++ internal/core/usecases/getserver/getserver.go | 16 +- .../core/usecases/getserver/getserver_test.go | 60 +--- .../usecases/refreshservers/refreshservers.go | 70 +++++ .../refreshservers/refreshservers_test.go | 139 +++++++++ .../usecases/reviveservers/reviveservers.go | 110 +++++++ .../reviveservers/reviveservers_test.go | 166 +++++++++++ .../persistence/memory/instances/instances.go | 4 +- internal/rest/api/servers_add.go | 13 +- internal/rest/api/servers_view.go | 24 +- internal/services/cleaning/cleaning_test.go | 130 -------- .../services/discovery/finding/finding.go | 123 -------- .../discovery/finding/finding_test.go | 282 ------------------ internal/testutils/factories/info.go | 41 +++ internal/testutils/factories/instance.go | 19 ++ internal/testutils/factories/server.go | 9 +- tests/api/servers_view_test.go | 7 +- tests/modules/cleaner_test.go | 127 ++++++-- tests/modules/refresher_test.go | 212 +++++++------ tests/modules/reviver_test.go | 206 +++++++------ 32 files changed, 1419 insertions(+), 1016 deletions(-) create mode 100644 internal/core/entities/addr/addr_test.go create mode 100644 internal/core/entities/addr/public.go create mode 100644 internal/core/entities/addr/public_test.go rename internal/{services/cleaning/cleaning.go => core/usecases/cleanservers/cleanservers.go} (51%) create mode 100644 internal/core/usecases/cleanservers/cleanservers_test.go create mode 100644 internal/core/usecases/refreshservers/refreshservers.go create mode 100644 internal/core/usecases/refreshservers/refreshservers_test.go create mode 100644 internal/core/usecases/reviveservers/reviveservers.go create mode 100644 internal/core/usecases/reviveservers/reviveservers_test.go delete mode 100644 internal/services/cleaning/cleaning_test.go delete mode 100644 internal/services/discovery/finding/finding.go delete mode 100644 internal/services/discovery/finding/finding_test.go create mode 100644 internal/testutils/factories/info.go create mode 100644 internal/testutils/factories/instance.go diff --git a/cmd/swat4master/application/application.go b/cmd/swat4master/application/application.go index d69aa0d..4b17cc8 100644 --- a/cmd/swat4master/application/application.go +++ b/cmd/swat4master/application/application.go @@ -7,7 +7,6 @@ import ( "github.com/sergeii/swat4master/cmd/swat4master/container" "github.com/sergeii/swat4master/cmd/swat4master/logging" "github.com/sergeii/swat4master/cmd/swat4master/persistence" - "github.com/sergeii/swat4master/internal/services/discovery/finding" "github.com/sergeii/swat4master/internal/services/monitoring" "github.com/sergeii/swat4master/internal/services/probe" "github.com/sergeii/swat4master/internal/validation" @@ -20,7 +19,6 @@ var Module = fx.Module("application", fx.Provide(validation.New), fx.Provide(persistence.Provide), fx.Provide(monitoring.NewMetricService), - fx.Provide(finding.NewService), fx.Provide(probe.NewService), container.Module, ) diff --git a/cmd/swat4master/container/container.go b/cmd/swat4master/container/container.go index ed91f56..89250b2 100644 --- a/cmd/swat4master/container/container.go +++ b/cmd/swat4master/container/container.go @@ -4,20 +4,26 @@ import ( "go.uber.org/fx" "github.com/sergeii/swat4master/internal/core/usecases/addserver" + "github.com/sergeii/swat4master/internal/core/usecases/cleanservers" "github.com/sergeii/swat4master/internal/core/usecases/getserver" "github.com/sergeii/swat4master/internal/core/usecases/listservers" + "github.com/sergeii/swat4master/internal/core/usecases/refreshservers" "github.com/sergeii/swat4master/internal/core/usecases/removeserver" "github.com/sergeii/swat4master/internal/core/usecases/renewserver" "github.com/sergeii/swat4master/internal/core/usecases/reportserver" + "github.com/sergeii/swat4master/internal/core/usecases/reviveservers" ) type Container struct { - GetServer getserver.UseCase - AddServer addserver.UseCase - ListServers listservers.UseCase - ReportServer reportserver.UseCase - RenewServer renewserver.UseCase - RemoveServer removeserver.UseCase + GetServer getserver.UseCase + AddServer addserver.UseCase + ListServers listservers.UseCase + ReportServer reportserver.UseCase + RenewServer renewserver.UseCase + RemoveServer removeserver.UseCase + CleanServers cleanservers.UseCase + RefreshServers refreshservers.UseCase + ReviveServers reviveservers.UseCase } func New( @@ -27,14 +33,20 @@ func New( reportServerUseCase reportserver.UseCase, renewServerUseCase renewserver.UseCase, removeServerUseCase removeserver.UseCase, + cleanServersUseCase cleanservers.UseCase, + refreshServersUseCase refreshservers.UseCase, + reviveServersUseCase reviveservers.UseCase, ) Container { return Container{ - GetServer: getServerUseCase, - AddServer: addServerUseCase, - ListServers: listServersUseCase, - ReportServer: reportServerUseCase, - RenewServer: renewServerUseCase, - RemoveServer: removeServerUseCase, + GetServer: getServerUseCase, + AddServer: addServerUseCase, + ListServers: listServersUseCase, + ReportServer: reportServerUseCase, + RenewServer: renewServerUseCase, + RemoveServer: removeServerUseCase, + CleanServers: cleanServersUseCase, + RefreshServers: refreshServersUseCase, + ReviveServers: reviveServersUseCase, } } @@ -45,5 +57,8 @@ var Module = fx.Module("container", fx.Provide(reportserver.New), fx.Provide(renewserver.New), fx.Provide(removeserver.New), + fx.Provide(cleanservers.New), + fx.Provide(refreshservers.New), + fx.Provide(reviveservers.New), fx.Provide(New), ) diff --git a/cmd/swat4master/modules/cleaner/cleaner.go b/cmd/swat4master/modules/cleaner/cleaner.go index e56125a..eb42448 100644 --- a/cmd/swat4master/modules/cleaner/cleaner.go +++ b/cmd/swat4master/modules/cleaner/cleaner.go @@ -2,13 +2,15 @@ package cleaner import ( "context" + "time" "github.com/jonboulle/clockwork" "github.com/rs/zerolog" "go.uber.org/fx" "github.com/sergeii/swat4master/cmd/swat4master/config" - "github.com/sergeii/swat4master/internal/services/cleaning" + "github.com/sergeii/swat4master/internal/core/usecases/cleanservers" + "github.com/sergeii/swat4master/internal/services/monitoring" ) type Cleaner struct{} @@ -18,7 +20,8 @@ func Run( stopped chan struct{}, clock clockwork.Clock, logger *zerolog.Logger, - service *cleaning.Service, + metrics *monitoring.MetricService, + uc cleanservers.UseCase, cfg config.Config, ) { ticker := clock.NewTicker(cfg.CleanInterval) @@ -39,11 +42,7 @@ func Run( close(stopped) return case <-tickerCh: - if err := service.Clean(ctx, clock.Now().Add(-cfg.CleanRetention)); err != nil { - logger.Error(). - Err(err). - Msg("Failed to clean outdated servers") - } + clean(ctx, clock, logger, metrics, uc, cfg.CleanRetention) } } } @@ -52,7 +51,8 @@ func NewCleaner( lc fx.Lifecycle, cfg config.Config, clock clockwork.Clock, - service *cleaning.Service, + metrics *monitoring.MetricService, + uc cleanservers.UseCase, logger *zerolog.Logger, ) *Cleaner { stopped := make(chan struct{}) @@ -60,7 +60,7 @@ func NewCleaner( lc.Append(fx.Hook{ OnStart: func(context.Context) error { - go Run(stop, stopped, clock, logger, service, cfg) // nolint: contextcheck + go Run(stop, stopped, clock, logger, metrics, uc, cfg) // nolint: contextcheck return nil }, OnStop: func(context.Context) error { @@ -73,10 +73,24 @@ func NewCleaner( return &Cleaner{} } +func clean( + ctx context.Context, + clock clockwork.Clock, + logger *zerolog.Logger, + metrics *monitoring.MetricService, + uc cleanservers.UseCase, + retention time.Duration, +) { + resp, err := uc.Execute(ctx, clock.Now().Add(-retention)) + if err != nil { + logger.Error(). + Err(err). + Msg("Failed to clean outdated servers") + } + metrics.CleanerRemovals.Add(float64(resp.Count)) + metrics.CleanerErrors.Add(float64(resp.Errors)) +} + var Module = fx.Module("cleaner", - fx.Provide( - fx.Private, - cleaning.NewService, - ), fx.Provide(NewCleaner), ) diff --git a/cmd/swat4master/modules/refresher/refresher.go b/cmd/swat4master/modules/refresher/refresher.go index 60988ba..3acb9aa 100644 --- a/cmd/swat4master/modules/refresher/refresher.go +++ b/cmd/swat4master/modules/refresher/refresher.go @@ -8,7 +8,8 @@ import ( "go.uber.org/fx" "github.com/sergeii/swat4master/cmd/swat4master/config" - "github.com/sergeii/swat4master/internal/services/discovery/finding" + "github.com/sergeii/swat4master/internal/core/usecases/refreshservers" + "github.com/sergeii/swat4master/internal/services/monitoring" ) type Refresher struct{} @@ -18,7 +19,8 @@ func Run( stopped chan struct{}, clock clockwork.Clock, logger *zerolog.Logger, - service *finding.Service, + metrics *monitoring.MetricService, + uc refreshservers.UseCase, cfg config.Config, ) { refresher := clock.NewTicker(cfg.DiscoveryRefreshInterval) @@ -36,7 +38,7 @@ func Run( close(stopped) return case <-refresher.Chan(): - refresh(ctx, clock, logger, service, cfg) + refresh(ctx, clock, logger, metrics, uc, cfg) } } } @@ -45,7 +47,8 @@ func NewRefresher( lc fx.Lifecycle, cfg config.Config, clock clockwork.Clock, - service *finding.Service, + metrics *monitoring.MetricService, + uc refreshservers.UseCase, logger *zerolog.Logger, ) *Refresher { stopped := make(chan struct{}) @@ -53,7 +56,7 @@ func NewRefresher( lc.Append(fx.Hook{ OnStart: func(context.Context) error { - go Run(stop, stopped, clock, logger, service, cfg) // nolint: contextcheck + go Run(stop, stopped, clock, logger, metrics, uc, cfg) // nolint: contextcheck return nil }, OnStop: func(context.Context) error { @@ -70,18 +73,22 @@ func refresh( ctx context.Context, clock clockwork.Clock, logger *zerolog.Logger, - service *finding.Service, + metrics *monitoring.MetricService, + uc refreshservers.UseCase, cfg config.Config, ) { // make sure the probes don't run beyond the next cycle of discovery deadline := clock.Now().Add(cfg.DiscoveryRefreshInterval) - cnt, err := service.RefreshDetails(ctx, deadline) + + result, err := uc.Execute(ctx, deadline) if err != nil { logger.Warn().Err(err).Msg("Unable to refresh details for servers") return } - if cnt > 0 { - logger.Info().Int("count", cnt).Msg("Added servers to details discovery queue") + + if result.Count > 0 { + metrics.DiscoveryQueueProduced.Add(float64(result.Count)) + logger.Info().Int("count", result.Count).Msg("Added servers to details discovery queue") } else { logger.Debug().Msg("Added no servers to details discovery queue") } diff --git a/cmd/swat4master/modules/reviver/reviver.go b/cmd/swat4master/modules/reviver/reviver.go index e4f9c0e..c8db343 100644 --- a/cmd/swat4master/modules/reviver/reviver.go +++ b/cmd/swat4master/modules/reviver/reviver.go @@ -8,7 +8,8 @@ import ( "go.uber.org/fx" "github.com/sergeii/swat4master/cmd/swat4master/config" - "github.com/sergeii/swat4master/internal/services/discovery/finding" + "github.com/sergeii/swat4master/internal/core/usecases/reviveservers" + "github.com/sergeii/swat4master/internal/services/monitoring" ) type Reviver struct{} @@ -18,7 +19,8 @@ func Run( stopped chan struct{}, clock clockwork.Clock, logger *zerolog.Logger, - service *finding.Service, + metrics *monitoring.MetricService, + uc reviveservers.UseCase, cfg config.Config, ) { reviver := clock.NewTicker(cfg.DiscoveryRevivalInterval) @@ -41,7 +43,7 @@ func Run( close(stopped) return case <-reviverCh: - revive(ctx, clock, logger, service, cfg) + revive(ctx, clock, logger, metrics, uc, cfg) } } } @@ -50,7 +52,8 @@ func NewReviver( lc fx.Lifecycle, cfg config.Config, clock clockwork.Clock, - service *finding.Service, + metrics *monitoring.MetricService, + uc reviveservers.UseCase, logger *zerolog.Logger, ) *Reviver { stopped := make(chan struct{}) @@ -58,7 +61,7 @@ func NewReviver( lc.Append(fx.Hook{ OnStart: func(context.Context) error { - go Run(stop, stopped, clock, logger, service, cfg) // nolint: contextcheck + go Run(stop, stopped, clock, logger, metrics, uc, cfg) // nolint: contextcheck return nil }, OnStop: func(context.Context) error { @@ -75,7 +78,8 @@ func revive( ctx context.Context, clock clockwork.Clock, logger *zerolog.Logger, - service *finding.Service, + metrics *monitoring.MetricService, + uc reviveservers.UseCase, cfg config.Config, ) { now := clock.Now() @@ -83,22 +87,24 @@ func revive( // make sure the probes don't run beyond the next cycle of discovery deadline := now.Add(cfg.DiscoveryRevivalInterval) - cnt, err := service.ReviveServers( - ctx, + ucRequest := reviveservers.NewRequest( now.Add(-cfg.DiscoveryRevivalScope), // min scope now.Add(-cfg.DiscoveryRevivalInterval), // max scope now, // min countdown now.Add(cfg.DiscoveryRevivalCountdown), // max countdown deadline, ) + result, err := uc.Execute(ctx, ucRequest) if err != nil { - logger.Warn().Err(err).Msg("Unable to refresh revive outdated servers") + logger.Warn().Err(err).Msg("Unable to revive outdated servers") return } - if cnt > 0 { - logger.Info().Int("count", cnt).Msg("Added servers to port discovery queue") + + if result.Count > 0 { + metrics.DiscoveryQueueProduced.Add(float64(result.Count)) + logger.Info().Int("count", result.Count).Msg("Added servers to revival discovery queue") } else { - logger.Debug().Msg("Added no servers to port discovery queue") + logger.Debug().Msg("Added no servers to revival discovery queue") } } diff --git a/internal/core/entities/addr/addr_test.go b/internal/core/entities/addr/addr_test.go new file mode 100644 index 0000000..b3a922d --- /dev/null +++ b/internal/core/entities/addr/addr_test.go @@ -0,0 +1,74 @@ +package addr_test + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/sergeii/swat4master/internal/core/entities/addr" +) + +func TestAddr(t *testing.T) { + tests := []struct { + name string + ip string + want bool + }{ + { + name: "public address is accepted", + ip: "1.1.1.1", + want: true, + }, + { + name: "private network address accepted", + ip: "192.168.10.12", + want: true, + }, + { + name: "another private network address is accepted", + ip: "10.0.0.1", + want: true, + }, + { + name: "loopback address is accepted", + ip: "127.0.0.1", + want: true, + }, + { + name: "non-routable ip address is not accepted", + ip: "0.0.0.0", + want: false, + }, + { + name: "invalid ip address", + ip: "256.500.0.1", + want: false, + }, + { + name: "ipv4 address is required", + ip: "2001:0db8:85a3:0000:0000:8a2e:0370:7334", + want: false, + }, + { + name: "multicast address is not accepted", + ip: "224.0.0.1", + want: false, + }, + { + name: "link local broadcast address is not accepted", + ip: "169.254.0.1", + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := addr.New(net.ParseIP(tt.ip), 10480) + if tt.want { + require.NoError(t, err) + } else { + require.ErrorIs(t, err, addr.ErrInvalidIP) + } + }) + } +} diff --git a/internal/core/entities/addr/public.go b/internal/core/entities/addr/public.go new file mode 100644 index 0000000..db4dad1 --- /dev/null +++ b/internal/core/entities/addr/public.go @@ -0,0 +1,36 @@ +package addr + +import ( + "errors" + "net" +) + +type PublicAddr struct { + addr Addr +} + +var BlankPublicAddr PublicAddr // nolint: gochecknoglobals + +var ErrInvalidPublicIP = errors.New("invalid public IP address") + +func NewPublicAddr(addr Addr) (PublicAddr, error) { + ipv4 := net.IPv4(addr.IP[0], addr.IP[1], addr.IP[2], addr.IP[3]) + + if ipv4.IsPrivate() || ipv4.IsLoopback() { + return PublicAddr{}, ErrInvalidPublicIP + } + + return PublicAddr{addr}, nil +} + +func MustNewPublicAddr(addr Addr) PublicAddr { + pa, err := NewPublicAddr(addr) + if err != nil { + panic(err) + } + return pa +} + +func (pa PublicAddr) ToAddr() Addr { + return pa.addr +} diff --git a/internal/core/entities/addr/public_test.go b/internal/core/entities/addr/public_test.go new file mode 100644 index 0000000..5f61422 --- /dev/null +++ b/internal/core/entities/addr/public_test.go @@ -0,0 +1,58 @@ +package addr_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/sergeii/swat4master/internal/core/entities/addr" +) + +func TestPublicAddr(t *testing.T) { + tests := []struct { + name string + ip string + want bool + }{ + { + "public address", + "147.128.88.19", + true, + }, + { + "localhost", + "127.0.0.1", + false, + }, + { + "192.168.0.0/16 range", + "192.168.1.1", + false, + }, + { + "172.16.0.0/12 range", + "172.16.128.18", + false, + }, + { + "10.0.0.0/8 range", + "10.39.1.19", + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + anyAddr := addr.MustNewFromDotted(tt.ip, 10480) + + publicAddr, err := addr.NewPublicAddr(anyAddr) + + if tt.want { + assert.NoError(t, err) + assert.Equal(t, anyAddr, publicAddr.ToAddr()) + } else { + assert.ErrorIs(t, err, addr.ErrInvalidPublicIP) + } + }) + } +} diff --git a/internal/core/entities/server/server_test.go b/internal/core/entities/server/server_test.go index 302c526..c0fbcb6 100644 --- a/internal/core/entities/server/server_test.go +++ b/internal/core/entities/server/server_test.go @@ -123,46 +123,21 @@ func TestServer_New_ValidIPAddress(t *testing.T) { ip: "1.1.1.1", want: true, }, - { - name: "private network address accepted", - ip: "192.168.10.12", - want: true, - }, - { - name: "another private network address is accepted", - ip: "10.0.0.1", - want: true, - }, { name: "loopback address is accepted", ip: "127.0.0.1", want: true, }, { - name: "invalid ip address", + name: "invalid ip address is not accepted", ip: "256.500.0.1", want: false, }, { - name: "unspecified ip address", - ip: "0.0.0.0", - want: false, - }, - { - name: "ipv4 address is required", + name: "ipv6 address is not accepted", ip: "2001:0db8:85a3:0000:0000:8a2e:0370:7334", want: false, }, - { - name: "multicast address is not accepted", - ip: "224.0.0.1", - want: false, - }, - { - name: "link local broadcast address is not accepted", - ip: "169.254.0.1", - want: false, - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/internal/core/usecases/addserver/addserver.go b/internal/core/usecases/addserver/addserver.go index f94afa7..e1a44e1 100644 --- a/internal/core/usecases/addserver/addserver.go +++ b/internal/core/usecases/addserver/addserver.go @@ -39,12 +39,8 @@ func New( } } -func (uc UseCase) Execute(ctx context.Context, address addr.Addr) (server.Server, error) { - if err := uc.validateAddress(address); err != nil { - return server.Blank, err - } - - svr, err := uc.getOrCreateServer(ctx, address) +func (uc UseCase) Execute(ctx context.Context, publicAddr addr.PublicAddr) (server.Server, error) { + svr, err := uc.getOrCreateServer(ctx, publicAddr.ToAddr()) if err != nil { return server.Blank, err } @@ -56,14 +52,6 @@ func (uc UseCase) Execute(ctx context.Context, address addr.Addr) (server.Server return svr, nil } -func (uc UseCase) validateAddress(address addr.Addr) error { - ipv4 := address.GetIP() - if !ipv4.IsGlobalUnicast() || ipv4.IsPrivate() { - return ErrInvalidAddress - } - return nil -} - func (uc UseCase) getOrCreateServer(ctx context.Context, address addr.Addr) (server.Server, error) { svr, err := uc.serverRepo.Get(ctx, address) if err != nil { diff --git a/internal/core/usecases/addserver/addserver_test.go b/internal/core/usecases/addserver/addserver_test.go index 0d2ece4..0c30eef 100644 --- a/internal/core/usecases/addserver/addserver_test.go +++ b/internal/core/usecases/addserver/addserver_test.go @@ -141,7 +141,7 @@ func TestAddServerUseCase_ServerExists(t *testing.T) { probeRepo.On("AddBetween", ctx, mock.Anything, mock.Anything, mock.Anything).Return(nil) uc := addserver.New(serverRepo, probeRepo, &logger) - addedSvr, err := uc.Execute(ctx, svrAddr) + addedSvr, err := uc.Execute(ctx, addr.MustNewPublicAddr(svrAddr)) if tt.wantErr != nil { assert.ErrorIs(t, err, tt.wantErr) @@ -200,7 +200,7 @@ func TestAddServerUseCase_ServerDoesNotExist(t *testing.T) { probeRepo.On("AddBetween", ctx, mock.Anything, mock.Anything, mock.Anything).Return(nil) uc := addserver.New(serverRepo, probeRepo, &logger) - _, err := uc.Execute(ctx, svrAddr) + _, err := uc.Execute(ctx, addr.MustNewPublicAddr(svrAddr)) assert.ErrorIs(t, err, addserver.ErrServerDiscoveryInProgress) serverRepo.AssertCalled(t, "Get", ctx, svrAddr) @@ -224,64 +224,3 @@ func TestAddServerUseCase_ServerDoesNotExist(t *testing.T) { repositories.NC, ) } - -func TestAddServerUseCase_ValidateAddress(t *testing.T) { - tests := []struct { - name string - ip string - port int - want bool - }{ - { - "positive case", - "1.1.1.1", - 10480, - true, - }, - { - "private ip address", - "127.0.0.1", - 10480, - false, - }, - { - "Private address", - "192.168.1.1", - 10480, - false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx := context.TODO() - logger := zerolog.Nop() - - svrAddr := addr.MustNewFromDotted(tt.ip, tt.port) - svr := factories.BuildServer( - factories.WithAddress(tt.ip, tt.port), - factories.WithDiscoveryStatus(ds.Details), - ) - - serverRepo := new(MockServerRepository) - serverRepo.On("Get", ctx, svrAddr).Return(server.Blank, repositories.ErrServerNotFound) - serverRepo.On("Add", ctx, mock.Anything, mock.Anything).Return(svr, nil) - serverRepo.On("Update", ctx, mock.Anything, mock.Anything).Return(svr, nil) - - probeRepo := new(MockProbeRepository) - probeRepo.On("AddBetween", ctx, mock.Anything, mock.Anything, mock.Anything).Return(nil) - - uc := addserver.New(serverRepo, probeRepo, &logger) - _, err := uc.Execute(ctx, svrAddr) - - if tt.want { - assert.NoError(t, err) - } else { - assert.ErrorIs(t, err, addserver.ErrInvalidAddress) - serverRepo.AssertNotCalled(t, "Get", mock.Anything, mock.Anything) - serverRepo.AssertNotCalled(t, "Add", mock.Anything, mock.Anything, mock.Anything) - serverRepo.AssertNotCalled(t, "Update", mock.Anything, mock.Anything, mock.Anything) - } - }) - } -} diff --git a/internal/services/cleaning/cleaning.go b/internal/core/usecases/cleanservers/cleanservers.go similarity index 51% rename from internal/services/cleaning/cleaning.go rename to internal/core/usecases/cleanservers/cleanservers.go index 74e2bd4..ecbf488 100644 --- a/internal/services/cleaning/cleaning.go +++ b/internal/core/usecases/cleanservers/cleanservers.go @@ -1,4 +1,4 @@ -package cleaning +package cleanservers import ( "context" @@ -9,67 +9,71 @@ import ( "github.com/sergeii/swat4master/internal/core/entities/filterset" "github.com/sergeii/swat4master/internal/core/entities/server" "github.com/sergeii/swat4master/internal/core/repositories" - "github.com/sergeii/swat4master/internal/services/monitoring" ) -type Service struct { - servers repositories.ServerRepository - instances repositories.InstanceRepository - metrics *monitoring.MetricService - logger *zerolog.Logger +type UseCase struct { + serverRepo repositories.ServerRepository + instanceRepo repositories.InstanceRepository + logger *zerolog.Logger } -func NewService( - servers repositories.ServerRepository, - instances repositories.InstanceRepository, - metrics *monitoring.MetricService, +func New( + serverRepo repositories.ServerRepository, + instanceRepo repositories.InstanceRepository, logger *zerolog.Logger, -) *Service { - return &Service{ - servers: servers, - instances: instances, - metrics: metrics, - logger: logger, +) UseCase { + return UseCase{ + serverRepo: serverRepo, + instanceRepo: instanceRepo, + logger: logger, } } -func (s *Service) Clean(ctx context.Context, until time.Time) error { +type Response struct { + Count int + Errors int +} + +var NoResponse = Response{} + +func (uc UseCase) Execute(ctx context.Context, until time.Time) (Response, error) { var before, after, removed, errors int var err error - if before, err = s.servers.Count(ctx); err != nil { - return err + if before, err = uc.serverRepo.Count(ctx); err != nil { + return NoResponse, err } - s.logger.Info(). + + uc.logger.Info(). Stringer("until", until).Int("servers", before). Msg("Starting to clean outdated servers") fs := filterset.New().UpdatedBefore(until) - outdatedServers, err := s.servers.Filter(ctx, fs) + outdatedServers, err := uc.serverRepo.Filter(ctx, fs) if err != nil { - s.logger.Error().Err(err).Msg("Unable to obtain servers for cleanup") - return err + uc.logger.Error().Err(err).Msg("Unable to obtain servers for cleanup") + return NoResponse, err } for _, svr := range outdatedServers { - if err = s.instances.RemoveByAddr(ctx, svr.Addr); err != nil { - s.logger.Error(). + if err = uc.instanceRepo.RemoveByAddr(ctx, svr.Addr); err != nil { + uc.logger.Error(). Err(err). Stringer("until", until).Stringer("addr", svr.Addr). Msg("Failed to remove instance for removed server") errors++ continue } - if err = s.servers.Remove(ctx, svr, func(conflict *server.Server) bool { + if err = uc.serverRepo.Remove(ctx, svr, func(conflict *server.Server) bool { if conflict.RefreshedAt.After(until) { - s.logger.Info(). + uc.logger.Info(). Stringer("server", conflict).Stringer("refreshed", conflict.RefreshedAt). Msg("Removed server is more recent") return false } return true }); err != nil { - s.logger.Error(). + uc.logger.Error(). Err(err). Stringer("until", until).Stringer("addr", svr.Addr). Msg("Failed to remove outdated server") @@ -79,18 +83,15 @@ func (s *Service) Clean(ctx context.Context, until time.Time) error { removed++ } - s.metrics.CleanerRemovals.Add(float64(removed)) - s.metrics.CleanerErrors.Add(float64(errors)) - - if after, err = s.servers.Count(ctx); err != nil { - return err + if after, err = uc.serverRepo.Count(ctx); err != nil { + return NoResponse, err } - s.logger.Info(). + uc.logger.Info(). Stringer("until", until). Int("removed", removed).Int("errors", errors). Int("before", before).Int("after", after). Msg("Finished cleaning outdated servers") - return nil + return Response{Count: removed, Errors: errors}, nil } diff --git a/internal/core/usecases/cleanservers/cleanservers_test.go b/internal/core/usecases/cleanservers/cleanservers_test.go new file mode 100644 index 0000000..ec95b9b --- /dev/null +++ b/internal/core/usecases/cleanservers/cleanservers_test.go @@ -0,0 +1,190 @@ +package cleanservers_test + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/sergeii/swat4master/internal/core/entities/addr" + "github.com/sergeii/swat4master/internal/core/entities/filterset" + "github.com/sergeii/swat4master/internal/core/entities/server" + "github.com/sergeii/swat4master/internal/core/repositories" + "github.com/sergeii/swat4master/internal/core/usecases/cleanservers" + "github.com/sergeii/swat4master/internal/testutils/factories" +) + +type MockServerRepository struct { + mock.Mock + repositories.ServerRepository +} + +func (m *MockServerRepository) Count(ctx context.Context) (int, error) { + args := m.Called(ctx) + return args.Int(0), args.Error(1) +} + +func (m *MockServerRepository) Filter( + ctx context.Context, + fs filterset.FilterSet, +) ([]server.Server, error) { + args := m.Called(ctx, fs) + return args.Get(0).([]server.Server), args.Error(1) // nolint: forcetypeassert +} + +func (m *MockServerRepository) Remove( + ctx context.Context, + svr server.Server, + onConflict func(*server.Server) bool, +) error { + args := m.Called(ctx, svr, onConflict) + return args.Error(0) +} + +type MockInstanceRepository struct { + mock.Mock + repositories.InstanceRepository +} + +func (m *MockInstanceRepository) RemoveByAddr(ctx context.Context, addr addr.Addr) error { + args := m.Called(ctx, addr) + return args.Error(0) +} + +func TestCleanServersUseCase_Success(t *testing.T) { + ctx := context.TODO() + logger := zerolog.Nop() + + until := time.Now().Add(-24 * time.Hour) // Example time filter + + outdatedServers := []server.Server{ + factories.BuildRandomServer(), + factories.BuildRandomServer(), + } + + serverRepo := new(MockServerRepository) + serverRepo.On("Count", ctx).Return(10, nil).Once() + serverRepo.On("Count", ctx).Return(8, nil).Once() + serverRepo.On("Filter", ctx, mock.Anything).Return(outdatedServers, nil).Once() + serverRepo.On("Remove", ctx, mock.Anything, mock.Anything).Return(nil).Times(2) + + instanceRepo := new(MockInstanceRepository) + instanceRepo.On("RemoveByAddr", ctx, mock.Anything).Return(nil).Times(2) + + uc := cleanservers.New(serverRepo, instanceRepo, &logger) + response, err := uc.Execute(ctx, until) + + assert.NoError(t, err) + assert.Equal(t, 2, response.Count) + assert.Equal(t, 0, response.Errors) + + serverRepo.AssertExpectations(t) + instanceRepo.AssertExpectations(t) + + serverRepo.AssertCalled( + t, + "Filter", + ctx, + mock.MatchedBy(func(fs filterset.FilterSet) bool { + updatedBefore, _ := fs.GetUpdatedBefore() + return updatedBefore.Equal(until) + }), + ) + for _, svr := range outdatedServers { + serverRepo.AssertCalled(t, "Remove", ctx, svr, mock.Anything) + instanceRepo.AssertCalled(t, "RemoveByAddr", ctx, svr.Addr) + } +} + +func TestCleanServersUseCase_NothingToClean(t *testing.T) { + ctx := context.TODO() + logger := zerolog.Nop() + + until := time.Now().Add(-24 * time.Hour) // Example time filter + + serverRepo := new(MockServerRepository) + serverRepo.On("Count", ctx).Return(0, nil).Times(2) + serverRepo.On("Filter", ctx, mock.Anything).Return([]server.Server{}, nil).Once() + + instanceRepo := new(MockInstanceRepository) + + uc := cleanservers.New(serverRepo, instanceRepo, &logger) + response, err := uc.Execute(ctx, until) + + assert.NoError(t, err) + assert.Equal(t, 0, response.Count) + assert.Equal(t, 0, response.Errors) + + serverRepo.AssertExpectations(t) + instanceRepo.AssertExpectations(t) + + serverRepo.AssertNotCalled(t, "Remove", mock.Anything, mock.Anything, mock.Anything) + instanceRepo.AssertNotCalled(t, "RemoveByAddr", mock.Anything, mock.Anything) +} + +func TestCleanServersUseCase_RemoveErrors(t *testing.T) { + ctx := context.TODO() + logger := zerolog.Nop() + + until := time.Now().Add(-24 * time.Hour) // Example time filter + + svr1 := factories.BuildRandomServer() + svr2 := factories.BuildRandomServer() + svr3 := factories.BuildRandomServer() + outdatedServers := []server.Server{svr1, svr2, svr3} + + serverRepo := new(MockServerRepository) + serverRepo.On("Count", ctx).Return(3, nil).Once() + serverRepo.On("Count", ctx).Return(2, nil).Once() + serverRepo.On("Filter", ctx, mock.Anything).Return(outdatedServers, nil).Once() + serverRepo.On("Remove", ctx, svr2, mock.Anything).Return(nil).Once() + serverRepo.On("Remove", ctx, svr3, mock.Anything).Return(errors.New("error")).Once() + + instanceRepo := new(MockInstanceRepository) + instanceRepo.On("RemoveByAddr", ctx, svr1.Addr).Return(errors.New("error")).Once() + instanceRepo.On("RemoveByAddr", ctx, svr2.Addr).Return(nil).Once() + instanceRepo.On("RemoveByAddr", ctx, svr3.Addr).Return(nil).Once() + + uc := cleanservers.New(serverRepo, instanceRepo, &logger) + response, err := uc.Execute(ctx, until) + + assert.NoError(t, err) + assert.Equal(t, 1, response.Count) + assert.Equal(t, 2, response.Errors) + + serverRepo.AssertExpectations(t) + instanceRepo.AssertExpectations(t) + + serverRepo.AssertNumberOfCalls(t, "Remove", 2) + instanceRepo.AssertNumberOfCalls(t, "RemoveByAddr", 3) +} + +func TestCleanServersUseCase_CountError(t *testing.T) { + ctx := context.TODO() + logger := zerolog.Nop() + + until := time.Now().Add(-24 * time.Hour) // Example time filter + countErr := errors.New("error") + + serverRepo := new(MockServerRepository) + serverRepo.On("Count", ctx).Return(0, countErr).Once() + + instanceRepo := new(MockInstanceRepository) + + uc := cleanservers.New(serverRepo, instanceRepo, &logger) + response, err := uc.Execute(ctx, until) + + assert.ErrorIs(t, err, countErr) + assert.Equal(t, cleanservers.NoResponse, response) + + serverRepo.AssertExpectations(t) + instanceRepo.AssertExpectations(t) + + serverRepo.AssertNumberOfCalls(t, "Filter", 0) + serverRepo.AssertNumberOfCalls(t, "Remove", 0) + instanceRepo.AssertNumberOfCalls(t, "RemoveByAddr", 0) +} diff --git a/internal/core/usecases/getserver/getserver.go b/internal/core/usecases/getserver/getserver.go index 78a8b9b..846813f 100644 --- a/internal/core/usecases/getserver/getserver.go +++ b/internal/core/usecases/getserver/getserver.go @@ -29,12 +29,8 @@ func New( } } -func (uc UseCase) Execute(ctx context.Context, address addr.Addr) (server.Server, error) { - if err := uc.validateAddress(address); err != nil { - return server.Blank, err - } - - svr, err := uc.serverRepo.Get(ctx, address) +func (uc UseCase) Execute(ctx context.Context, publicAddr addr.PublicAddr) (server.Server, error) { + svr, err := uc.serverRepo.Get(ctx, publicAddr.ToAddr()) if err != nil { switch { case errors.Is(err, repositories.ErrServerNotFound): @@ -50,11 +46,3 @@ func (uc UseCase) Execute(ctx context.Context, address addr.Addr) (server.Server return svr, nil } - -func (uc UseCase) validateAddress(address addr.Addr) error { - ipv4 := address.GetIP() - if !ipv4.IsGlobalUnicast() || ipv4.IsPrivate() { - return ErrInvalidAddress - } - return nil -} diff --git a/internal/core/usecases/getserver/getserver_test.go b/internal/core/usecases/getserver/getserver_test.go index 1a80e46..3f5ca52 100644 --- a/internal/core/usecases/getserver/getserver_test.go +++ b/internal/core/usecases/getserver/getserver_test.go @@ -34,7 +34,7 @@ func TestGetServerUseCase_OK(t *testing.T) { mockRepo.On("Get", ctx, svr.Addr).Return(svr, nil) uc := getserver.New(mockRepo) - got, err := uc.Execute(ctx, svr.Addr) + got, err := uc.Execute(ctx, addr.MustNewPublicAddr(svr.Addr)) assert.NoError(t, err) assert.Equal(t, 10481, got.QueryPort) @@ -57,67 +57,13 @@ func TestGetServerUseCase_NotFound(t *testing.T) { mockRepo.On("Get", ctx, svrAddr).Return(server.Blank, repositories.ErrServerNotFound) uc := getserver.New(mockRepo) - _, err := uc.Execute(ctx, svrAddr) + _, err := uc.Execute(ctx, addr.MustNewPublicAddr(svrAddr)) assert.ErrorIs(t, err, getserver.ErrServerNotFound) mockRepo.AssertExpectations(t) } -func TestGetServerUseCase_ValidateAddress(t *testing.T) { - tests := []struct { - name string - ip string - port int - want bool - }{ - { - "positive case", - "1.1.1.1", - 10480, - true, - }, - { - "private ip address", - "127.0.0.1", - 10480, - false, - }, - { - "Private address", - "192.168.1.1", - 10480, - false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx := context.TODO() - - svr := factories.BuildServer( - factories.WithAddress(tt.ip, tt.port), - factories.WithDiscoveryStatus(ds.Details), - ) - - mockRepo := new(MockServerRepository) - mockRepo.On("Get", ctx, svr.Addr).Return(svr, nil) - - uc := getserver.New(mockRepo) - got, err := uc.Execute(ctx, svr.Addr) - - if tt.want { - assert.NoError(t, err) - assert.Equal(t, "Swat4 Server", got.Info.Hostname) - mockRepo.AssertCalled(t, "Get", ctx, got.Addr) - } else { - assert.ErrorIs(t, err, getserver.ErrInvalidAddress) - mockRepo.AssertNotCalled(t, "Get", mock.Anything, mock.Anything) - } - }) - } -} - func TestGetServerUseCase_ValidateStatus(t *testing.T) { tests := []struct { name string @@ -163,7 +109,7 @@ func TestGetServerUseCase_ValidateStatus(t *testing.T) { mockRepo.On("Get", ctx, svr.Addr).Return(svr, nil) uc := getserver.New(mockRepo) - got, err := uc.Execute(ctx, svr.Addr) + got, err := uc.Execute(ctx, addr.MustNewPublicAddr(svr.Addr)) mockRepo.AssertExpectations(t) diff --git a/internal/core/usecases/refreshservers/refreshservers.go b/internal/core/usecases/refreshservers/refreshservers.go new file mode 100644 index 0000000..cb3326d --- /dev/null +++ b/internal/core/usecases/refreshservers/refreshservers.go @@ -0,0 +1,70 @@ +package refreshservers + +import ( + "context" + "time" + + "github.com/rs/zerolog" + + "github.com/sergeii/swat4master/internal/core/entities/addr" + ds "github.com/sergeii/swat4master/internal/core/entities/discovery/status" + "github.com/sergeii/swat4master/internal/core/entities/filterset" + "github.com/sergeii/swat4master/internal/core/entities/probe" + "github.com/sergeii/swat4master/internal/core/repositories" +) + +type UseCase struct { + serverRepo repositories.ServerRepository + probeRepo repositories.ProbeRepository + logger *zerolog.Logger +} + +func New( + serverRepo repositories.ServerRepository, + probeRepo repositories.ProbeRepository, + logger *zerolog.Logger, +) UseCase { + return UseCase{ + serverRepo: serverRepo, + probeRepo: probeRepo, + logger: logger, + } +} + +type Response struct { + Count int +} + +var NoResponse = Response{} + +func (uc UseCase) Execute(ctx context.Context, deadline time.Time) (Response, error) { + fs := filterset.New().WithStatus(ds.Port).NoStatus(ds.DetailsRetry) + serversWithDetails, err := uc.serverRepo.Filter(ctx, fs) + if err != nil { + uc.logger.Error().Err(err).Msg("Unable to obtain servers for refresh") + return NoResponse, err + } + + cnt := 0 + for _, svr := range serversWithDetails { + if err := uc.addProbe(ctx, svr.Addr, svr.QueryPort, deadline); err != nil { + uc.logger.Warn(). + Err(err).Stringer("server", svr). + Msg("Failed to add server to details discovery queue") + continue + } + cnt++ + } + + return Response{cnt}, nil +} + +func (uc UseCase) addProbe( + ctx context.Context, + svrAddr addr.Addr, + queryPort int, + deadline time.Time, +) error { + prb := probe.New(svrAddr, queryPort, probe.GoalDetails) + return uc.probeRepo.AddBetween(ctx, prb, repositories.NC, deadline) +} diff --git a/internal/core/usecases/refreshservers/refreshservers_test.go b/internal/core/usecases/refreshservers/refreshservers_test.go new file mode 100644 index 0000000..6da3338 --- /dev/null +++ b/internal/core/usecases/refreshservers/refreshservers_test.go @@ -0,0 +1,139 @@ +package refreshservers_test + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + ds "github.com/sergeii/swat4master/internal/core/entities/discovery/status" + "github.com/sergeii/swat4master/internal/core/entities/filterset" + "github.com/sergeii/swat4master/internal/core/entities/probe" + "github.com/sergeii/swat4master/internal/core/entities/server" + "github.com/sergeii/swat4master/internal/core/repositories" + "github.com/sergeii/swat4master/internal/core/usecases/refreshservers" + "github.com/sergeii/swat4master/internal/testutils/factories" +) + +type MockServerRepository struct { + mock.Mock + repositories.ServerRepository +} + +func (m *MockServerRepository) Filter(ctx context.Context, fs filterset.FilterSet) ([]server.Server, error) { + args := m.Called(ctx, fs) + return args.Get(0).([]server.Server), args.Error(1) // nolint: forcetypeassert +} + +type MockProbeRepository struct { + mock.Mock + repositories.ProbeRepository +} + +func (m *MockProbeRepository) AddBetween(ctx context.Context, prb probe.Probe, countdown, deadline time.Time) error { + args := m.Called(ctx, prb, countdown, deadline) + return args.Error(0) +} + +func TestRefreshServersUseCase_Success(t *testing.T) { + ctx := context.TODO() + logger := zerolog.Nop() + + now := time.Now() + deadline := now.Add(time.Minute * 10) + + svr1 := factories.BuildRandomServer() + svr2 := factories.BuildRandomServer() + serversToRevive := []server.Server{svr1, svr2} + + serverRepo := new(MockServerRepository) + serverRepo.On("Filter", ctx, mock.Anything).Return(serversToRevive, nil).Once() + + probeRepo := new(MockProbeRepository) + probeRepo.On("AddBetween", ctx, mock.Anything, mock.Anything, mock.Anything).Return(nil).Twice() + + uc := refreshservers.New(serverRepo, probeRepo, &logger) + resp, err := uc.Execute(ctx, deadline) + + assert.NoError(t, err) + assert.Equal(t, 2, resp.Count) + + serverRepo.AssertExpectations(t) + probeRepo.AssertExpectations(t) + + serverRepo.AssertCalled( + t, + "Filter", + ctx, + mock.MatchedBy(func(fs filterset.FilterSet) bool { + noStatus, hasNoStatus := fs.GetNoStatus() + withStatus, hasWithStatus := fs.GetWithStatus() + wantNoStatus := hasNoStatus && (noStatus == ds.DetailsRetry) + wantWithStatus := hasWithStatus && (withStatus == ds.Port) + return wantWithStatus && wantNoStatus + }), + ) + for _, svr := range serversToRevive { + probeRepo.AssertCalled( + t, + "AddBetween", + ctx, + probe.New(svr.Addr, svr.QueryPort, probe.GoalDetails), + repositories.NC, + deadline, + ) + } +} + +func TestRefreshServersUseCase_FilterError(t *testing.T) { + ctx := context.TODO() + logger := zerolog.Nop() + + filterErr := errors.New("filter error") + + serverRepo := new(MockServerRepository) + serverRepo.On("Filter", ctx, mock.Anything).Return([]server.Server{}, filterErr).Once() + + probeRepo := new(MockProbeRepository) + + uc := refreshservers.New(serverRepo, probeRepo, &logger) + resp, err := uc.Execute(ctx, time.Now()) + + assert.ErrorIs(t, err, filterErr) + assert.Equal(t, 0, resp.Count) + + serverRepo.AssertExpectations(t) + probeRepo.AssertExpectations(t) + + probeRepo.AssertNotCalled(t, "AddBetween", mock.Anything, mock.Anything, mock.Anything, mock.Anything) +} + +func TestRefreshServersUseCase_AddProbeError(t *testing.T) { + ctx := context.TODO() + logger := zerolog.Nop() + + addProbeErr := errors.New("probe error") + + svr1 := factories.BuildRandomServer() + svr2 := factories.BuildRandomServer() + serversToRevive := []server.Server{svr1, svr2} + + serverRepo := new(MockServerRepository) + serverRepo.On("Filter", ctx, mock.Anything).Return(serversToRevive, nil).Once() + + probeRepo := new(MockProbeRepository) + probeRepo.On("AddBetween", ctx, mock.Anything, mock.Anything, mock.Anything).Return(addProbeErr).Twice() + + uc := refreshservers.New(serverRepo, probeRepo, &logger) + resp, err := uc.Execute(ctx, time.Now()) + + assert.NoError(t, err) + assert.Equal(t, 0, resp.Count) + + serverRepo.AssertExpectations(t) + probeRepo.AssertExpectations(t) +} diff --git a/internal/core/usecases/reviveservers/reviveservers.go b/internal/core/usecases/reviveservers/reviveservers.go new file mode 100644 index 0000000..30b9163 --- /dev/null +++ b/internal/core/usecases/reviveservers/reviveservers.go @@ -0,0 +1,110 @@ +package reviveservers + +import ( + "context" + "time" + + "github.com/rs/zerolog" + + "github.com/sergeii/swat4master/internal/core/entities/addr" + ds "github.com/sergeii/swat4master/internal/core/entities/discovery/status" + "github.com/sergeii/swat4master/internal/core/entities/filterset" + "github.com/sergeii/swat4master/internal/core/entities/probe" + "github.com/sergeii/swat4master/internal/core/repositories" + "github.com/sergeii/swat4master/pkg/random" +) + +type UseCase struct { + serverRepo repositories.ServerRepository + probeRepo repositories.ProbeRepository + logger *zerolog.Logger +} + +func New( + serverRepo repositories.ServerRepository, + probeRepo repositories.ProbeRepository, + logger *zerolog.Logger, +) UseCase { + return UseCase{ + serverRepo: serverRepo, + probeRepo: probeRepo, + logger: logger, + } +} + +type Request struct { + MinScope time.Time + MaxScope time.Time + MinCountdown time.Time + MaxCountdown time.Time + Deadline time.Time +} + +func NewRequest( + minScope time.Time, + maxScope time.Time, + minCountdown time.Time, + maxCountdown time.Time, + deadline time.Time, +) Request { + return Request{ + MinScope: minScope, + MaxScope: maxScope, + MinCountdown: minCountdown, + MaxCountdown: maxCountdown, + Deadline: deadline, + } +} + +type Response struct { + Count int +} + +var NoResponse = Response{} + +func (uc UseCase) Execute(ctx context.Context, req Request) (Response, error) { + fs := filterset.New().ActiveAfter(req.MinScope).ActiveBefore(req.MaxScope).NoStatus(ds.Port | ds.PortRetry) + + serversWithoutPort, err := uc.serverRepo.Filter(ctx, fs) + if err != nil { + uc.logger.Error().Err(err).Msg("Unable to obtain servers for port discovery") + return NoResponse, err + } + + cnt := 0 + for _, svr := range serversWithoutPort { + countdown := selectCountdown(req.MinCountdown, req.MaxCountdown) + if err := uc.addProbe(ctx, svr.Addr, countdown, req.Deadline); err != nil { + uc.logger.Warn(). + Err(err). + Stringer("server", svr).Time("countdown", countdown).Time("deadline", req.Deadline). + Msg("Failed to add server to port discovery queue") + continue + } + uc.logger.Debug(). + Time("countdown", countdown).Time("deadline", req.Deadline).Stringer("server", svr). + Msg("Added server to port discovery queue") + cnt++ + } + + return Response{cnt}, nil +} + +func (uc UseCase) addProbe( + ctx context.Context, + svrAddr addr.Addr, + countdown time.Time, + deadline time.Time, +) error { + prb := probe.New(svrAddr, svrAddr.Port, probe.GoalPort) + return uc.probeRepo.AddBetween(ctx, prb, countdown, deadline) +} + +func selectCountdown(min, max time.Time) time.Time { + if !max.After(min) { + return min + } + spread := max.Sub(min) + countdown := random.RandInt(0, int(spread)) + return min.Add(time.Duration(countdown)) +} diff --git a/internal/core/usecases/reviveservers/reviveservers_test.go b/internal/core/usecases/reviveservers/reviveservers_test.go new file mode 100644 index 0000000..07fa80d --- /dev/null +++ b/internal/core/usecases/reviveservers/reviveservers_test.go @@ -0,0 +1,166 @@ +package reviveservers_test + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + ds "github.com/sergeii/swat4master/internal/core/entities/discovery/status" + "github.com/sergeii/swat4master/internal/core/entities/filterset" + "github.com/sergeii/swat4master/internal/core/entities/probe" + "github.com/sergeii/swat4master/internal/core/entities/server" + "github.com/sergeii/swat4master/internal/core/repositories" + "github.com/sergeii/swat4master/internal/core/usecases/reviveservers" + "github.com/sergeii/swat4master/internal/testutils/factories" +) + +type MockServerRepository struct { + mock.Mock + repositories.ServerRepository +} + +func (m *MockServerRepository) Filter(ctx context.Context, fs filterset.FilterSet) ([]server.Server, error) { + args := m.Called(ctx, fs) + return args.Get(0).([]server.Server), args.Error(1) // nolint: forcetypeassert +} + +type MockProbeRepository struct { + mock.Mock + repositories.ProbeRepository +} + +func (m *MockProbeRepository) AddBetween(ctx context.Context, prb probe.Probe, countdown, deadline time.Time) error { + args := m.Called(ctx, prb, countdown, deadline) + return args.Error(0) +} + +func TestReviveServersUseCase_Success(t *testing.T) { + ctx := context.TODO() + logger := zerolog.Nop() + + now := time.Now() + minScope := now.Add(-time.Hour) + maxScope := now.Add(-time.Minute * 10) + minCountdown := now + maxCountdown := now.Add(time.Minute * 5) + deadline := now.Add(time.Minute * 10) + + svr1 := factories.BuildRandomServer() + svr2 := factories.BuildRandomServer() + serversToRevive := []server.Server{svr1, svr2} + + serverRepo := new(MockServerRepository) + serverRepo.On("Filter", ctx, mock.Anything).Return(serversToRevive, nil).Once() + + probeRepo := new(MockProbeRepository) + probeRepo.On("AddBetween", ctx, mock.Anything, mock.Anything, mock.Anything).Return(nil).Twice() + + req := reviveservers.NewRequest(minScope, maxScope, minCountdown, maxCountdown, deadline) + uc := reviveservers.New(serverRepo, probeRepo, &logger) + resp, err := uc.Execute(ctx, req) + + assert.NoError(t, err) + assert.Equal(t, 2, resp.Count) + + serverRepo.AssertExpectations(t) + probeRepo.AssertExpectations(t) + + serverRepo.AssertCalled( + t, + "Filter", + ctx, + mock.MatchedBy(func(fs filterset.FilterSet) bool { + noStatus, hasNoStatus := fs.GetNoStatus() + activeAfter, hasActiveAfter := fs.GetActiveAfter() + activeBefore, hasActiveBefore := fs.GetActiveBefore() + wantNoStatus := hasNoStatus && (noStatus == ds.Port|ds.PortRetry) + wantActiveAfter := hasActiveAfter && activeAfter.Equal(req.MinScope) + wantActiveBefore := hasActiveBefore && activeBefore.Equal(req.MaxScope) + return wantNoStatus && wantActiveAfter && wantActiveBefore + }), + ) + for _, svr := range serversToRevive { + probeRepo.AssertCalled( + t, + "AddBetween", + ctx, + probe.New(svr.Addr, svr.Addr.Port, probe.GoalPort), + mock.MatchedBy(func(countdown time.Time) bool { + gteMinCountdown := countdown.Equal(req.MinCountdown) || countdown.After(req.MinCountdown) + lteMaxCountdown := countdown.Equal(req.MaxCountdown) || countdown.Before(req.MaxCountdown) + return gteMinCountdown && lteMaxCountdown + }), + deadline, + ) + } +} + +func TestReviveServersUseCase_FilterError(t *testing.T) { + ctx := context.TODO() + logger := zerolog.Nop() + + now := time.Now() + filterErr := errors.New("filter error") + + serverRepo := new(MockServerRepository) + serverRepo.On("Filter", ctx, mock.Anything).Return([]server.Server{}, filterErr).Once() + + probeRepo := new(MockProbeRepository) + + req := reviveservers.NewRequest( + now.Add(-time.Hour), + now.Add(-time.Minute*10), + now, + now.Add(time.Minute*5), + now.Add(time.Minute*10), + ) + uc := reviveservers.New(serverRepo, probeRepo, &logger) + resp, err := uc.Execute(ctx, req) + + assert.ErrorIs(t, err, filterErr) + assert.Equal(t, 0, resp.Count) + + serverRepo.AssertExpectations(t) + probeRepo.AssertExpectations(t) + + probeRepo.AssertNotCalled(t, "AddBetween", mock.Anything, mock.Anything, mock.Anything, mock.Anything) +} + +func TestReviveServersUseCase_AddProbeError(t *testing.T) { + ctx := context.TODO() + logger := zerolog.Nop() + + now := time.Now() + addProbeErr := errors.New("probe error") + + svr1 := factories.BuildRandomServer() + svr2 := factories.BuildRandomServer() + serversToRevive := []server.Server{svr1, svr2} + + serverRepo := new(MockServerRepository) + serverRepo.On("Filter", ctx, mock.Anything).Return(serversToRevive, nil).Once() + + probeRepo := new(MockProbeRepository) + probeRepo.On("AddBetween", ctx, mock.Anything, mock.Anything, mock.Anything).Return(addProbeErr).Twice() + + req := reviveservers.NewRequest( + now.Add(-time.Hour), + now.Add(-time.Minute*10), + now, + now.Add(time.Minute*5), + now.Add(time.Minute*10), + ) + uc := reviveservers.New(serverRepo, probeRepo, &logger) + resp, err := uc.Execute(ctx, req) + + assert.NoError(t, err) + assert.Equal(t, 0, resp.Count) + + serverRepo.AssertExpectations(t) + probeRepo.AssertExpectations(t) +} diff --git a/internal/persistence/memory/instances/instances.go b/internal/persistence/memory/instances/instances.go index b6aedc1..17cd2f3 100644 --- a/internal/persistence/memory/instances/instances.go +++ b/internal/persistence/memory/instances/instances.go @@ -52,11 +52,11 @@ func (r *Repository) RemoveByID(_ context.Context, id string) error { func (r *Repository) RemoveByAddr(_ context.Context, insAddr addr.Addr) error { r.mutex.Lock() defer r.mutex.Unlock() - instance, exists := r.addrs[insAddr] + ins, exists := r.addrs[insAddr] if !exists { return nil } - delete(r.ids, instance.ID) + delete(r.ids, ins.ID) delete(r.addrs, insAddr) return nil } diff --git a/internal/rest/api/servers_add.go b/internal/rest/api/servers_add.go index 5b38f67..7ac8c2a 100644 --- a/internal/rest/api/servers_add.go +++ b/internal/rest/api/servers_add.go @@ -48,17 +48,22 @@ func (a *API) AddServer(c *gin.Context) { c.JSON(http.StatusOK, model.NewServerFromDomain(svr)) } -func parseAddServerAddress(c *gin.Context) (addr.Addr, error) { +func parseAddServerAddress(c *gin.Context) (addr.PublicAddr, error) { var req model.NewServer if err := c.ShouldBindJSON(&req); err != nil { - return addr.Blank, err + return addr.BlankPublicAddr, err } address, err := addr.NewFromDotted(req.IP, req.Port) if err != nil { - return addr.Blank, err + return addr.BlankPublicAddr, err } - return address, nil + pubAddress, err := addr.NewPublicAddr(address) + if err != nil { + return addr.BlankPublicAddr, err + } + + return pubAddress, nil } diff --git a/internal/rest/api/servers_view.go b/internal/rest/api/servers_view.go index 50c5b41..8794740 100644 --- a/internal/rest/api/servers_view.go +++ b/internal/rest/api/servers_view.go @@ -19,8 +19,8 @@ import ( // @Success 200 {object} model.ServerDetail // @Router /servers/:address [get] func (a *API) ViewServer(c *gin.Context) { - address, err := addr.NewFromString(c.Param("address")) - if err != nil { + address, parseErr := parseViewServerAddress(c) + if parseErr != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid server address"}) return } @@ -30,17 +30,17 @@ func (a *API) ViewServer(c *gin.Context) { switch { case errors.Is(err, getserver.ErrServerNotFound): a.logger.Debug(). - Stringer("addr", address). + Stringer("addr", address.ToAddr()). Msg("Requested server not found") c.Status(http.StatusNotFound) case errors.Is(err, getserver.ErrInvalidAddress): a.logger.Debug(). - Stringer("addr", address). + Stringer("addr", address.ToAddr()). Msg("Requested server address is invalid") c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid server address"}) case errors.Is(err, getserver.ErrServerHasNoDetails): a.logger.Debug(). - Stringer("addr", address). + Stringer("addr", address.ToAddr()). Msg("Requested server has no details") c.Status(http.StatusNoContent) } @@ -49,3 +49,17 @@ func (a *API) ViewServer(c *gin.Context) { c.JSON(http.StatusOK, model.NewServerDetailFromDomain(svr)) } + +func parseViewServerAddress(c *gin.Context) (addr.PublicAddr, error) { + address, err := addr.NewFromString(c.Param("address")) + if err != nil { + return addr.BlankPublicAddr, err + } + + pubAddress, err := addr.NewPublicAddr(address) + if err != nil { + return addr.BlankPublicAddr, err + } + + return pubAddress, nil +} diff --git a/internal/services/cleaning/cleaning_test.go b/internal/services/cleaning/cleaning_test.go deleted file mode 100644 index 242eeb2..0000000 --- a/internal/services/cleaning/cleaning_test.go +++ /dev/null @@ -1,130 +0,0 @@ -package cleaning_test - -import ( - "context" - "net" - "testing" - "time" - - "github.com/jonboulle/clockwork" - "github.com/rs/zerolog" - "github.com/stretchr/testify/assert" - "go.uber.org/fx" - "go.uber.org/fx/fxtest" - - "github.com/sergeii/swat4master/internal/core/entities/addr" - "github.com/sergeii/swat4master/internal/core/entities/instance" - "github.com/sergeii/swat4master/internal/core/entities/server" - repos "github.com/sergeii/swat4master/internal/core/repositories" - "github.com/sergeii/swat4master/internal/persistence/memory" - "github.com/sergeii/swat4master/internal/services/cleaning" - "github.com/sergeii/swat4master/internal/services/monitoring" -) - -func makeApp(tb fxtest.TB, extra ...fx.Option) { - fxopts := []fx.Option{ - fx.Provide(clockwork.NewRealClock), - fx.Provide(func(c clockwork.Clock) (repos.ServerRepository, repos.InstanceRepository, repos.ProbeRepository) { - mem := memory.New(c) - return mem.Servers, mem.Instances, mem.Probes - }), - fx.Provide(func() *zerolog.Logger { - logger := zerolog.Nop() - return &logger - }), - fx.Provide( - monitoring.NewMetricService, - cleaning.NewService, - ), - fx.NopLogger, - } - fxopts = append(fxopts, extra...) - app := fxtest.New(tb, fxopts...) - app.RequireStart().RequireStop() -} - -func TestCleaningService_Clean(t *testing.T) { - var service *cleaning.Service - var serversRepo repos.ServerRepository - var instancesRepo repos.InstanceRepository - - ctx := context.TODO() - - makeApp(t, fx.Populate(&service, &serversRepo, &instancesRepo)) - - instance1 := instance.MustNew("foo", net.ParseIP("1.1.1.1"), 10480) - instance3 := instance.MustNew("bar", net.ParseIP("3.3.3.3"), 10480) - instance4 := instance.MustNew("baz", net.ParseIP("4.4.4.4"), 10480) - - instancesRepo.Add(ctx, instance1) // nolint: errcheck - instancesRepo.Add(ctx, instance3) // nolint: errcheck - instancesRepo.Add(ctx, instance4) // nolint: errcheck - - server1 := server.MustNew(net.ParseIP("1.1.1.1"), 10480, 10481) - server2 := server.MustNew(net.ParseIP("2.2.2.2"), 10480, 10481) - server3 := server.MustNew(net.ParseIP("3.3.3.3"), 10480, 10481) - server4 := server.MustNew(net.ParseIP("4.4.4.4"), 10480, 10481) - - beforeAll := time.Now() - - serversRepo.Add(ctx, server1, repos.ServerOnConflictIgnore) // nolint: errcheck - - before2 := time.Now() - - serversRepo.Add(ctx, server2, repos.ServerOnConflictIgnore) // nolint: errcheck - serversRepo.Add(ctx, server3, repos.ServerOnConflictIgnore) // nolint: errcheck - serversRepo.Add(ctx, server4, repos.ServerOnConflictIgnore) // nolint: errcheck - - afterAll := time.Now() - - svrCount, _ := serversRepo.Count(ctx) - insCount, _ := instancesRepo.Count(ctx) - assert.Equal(t, 4, svrCount) - assert.Equal(t, 3, insCount) - - err := service.Clean(context.TODO(), beforeAll) - assert.NoError(t, err) - // no changes - svrCount, _ = serversRepo.Count(ctx) - assert.Equal(t, 4, svrCount) - insCount, _ = instancesRepo.Count(ctx) - assert.Equal(t, 3, insCount) - - err = service.Clean(context.TODO(), before2) - assert.NoError(t, err) - svrCount, _ = serversRepo.Count(ctx) - assert.Equal(t, 3, svrCount) - insCount, _ = instancesRepo.Count(ctx) - assert.Equal(t, 2, insCount) - _, getSvrErr := serversRepo.Get(ctx, addr.MustNewFromDotted("1.1.1.1", 10480)) - assert.ErrorIs(t, getSvrErr, repos.ErrServerNotFound) - _, getInsErr := instancesRepo.GetByID(ctx, "foo") - assert.ErrorIs(t, getInsErr, repos.ErrInstanceNotFound) - - serversRepo.Update(ctx, server3, repos.ServerOnConflictIgnore) // nolint: errcheck - - err = service.Clean(context.TODO(), afterAll) - assert.NoError(t, err) - svrCount, _ = serversRepo.Count(ctx) - assert.Equal(t, 1, svrCount) - insCount, _ = instancesRepo.Count(ctx) - assert.Equal(t, 1, insCount) - _, getSvrErr = serversRepo.Get(ctx, addr.MustNewFromDotted("3.3.3.3", 10480)) - assert.NoError(t, getSvrErr) - _, getInsErr = instancesRepo.GetByID(ctx, "bar") - assert.NoError(t, getInsErr) - - err = service.Clean(context.TODO(), time.Now()) - assert.NoError(t, err) - svrCount, _ = serversRepo.Count(ctx) - assert.Equal(t, 0, svrCount) - insCount, _ = instancesRepo.Count(ctx) - assert.Equal(t, 0, insCount) -} - -func TestCleaningService_Clean_EmptyNoError(t *testing.T) { - var service *cleaning.Service - makeApp(t, fx.Populate(&service)) - err := service.Clean(context.TODO(), time.Now()) - assert.NoError(t, err) -} diff --git a/internal/services/discovery/finding/finding.go b/internal/services/discovery/finding/finding.go deleted file mode 100644 index 9b3fbfb..0000000 --- a/internal/services/discovery/finding/finding.go +++ /dev/null @@ -1,123 +0,0 @@ -package finding - -import ( - "context" - "time" - - "github.com/rs/zerolog" - - "github.com/sergeii/swat4master/internal/core/entities/addr" - ds "github.com/sergeii/swat4master/internal/core/entities/discovery/status" - "github.com/sergeii/swat4master/internal/core/entities/filterset" - "github.com/sergeii/swat4master/internal/core/entities/probe" - "github.com/sergeii/swat4master/internal/core/repositories" - ps "github.com/sergeii/swat4master/internal/services/probe" - "github.com/sergeii/swat4master/pkg/random" -) - -type Service struct { - servers repositories.ServerRepository - queue *ps.Service - logger *zerolog.Logger -} - -func NewService( - servers repositories.ServerRepository, - queue *ps.Service, - logger *zerolog.Logger, -) *Service { - service := &Service{ - servers: servers, - queue: queue, - logger: logger, - } - return service -} - -func (s *Service) RefreshDetails( - ctx context.Context, - deadline time.Time, -) (int, error) { - fs := filterset.New().WithStatus(ds.Port).NoStatus(ds.DetailsRetry) - serversWithDetails, err := s.servers.Filter(ctx, fs) - if err != nil { - s.logger.Error().Err(err).Msg("Unable to obtain servers for details discovery") - return -1, err - } - - cnt := 0 - for _, svr := range serversWithDetails { - if err := s.DiscoverDetails(ctx, svr.Addr, svr.QueryPort, deadline); err != nil { - s.logger.Warn(). - Err(err).Stringer("server", svr). - Msg("Failed to add server to details discovery queue") - continue - } - cnt++ - } - - return cnt, nil -} - -func (s *Service) ReviveServers( - ctx context.Context, - minScope time.Time, - maxScope time.Time, - minCountdown time.Time, - maxCountdown time.Time, - deadline time.Time, -) (int, error) { - fs := filterset.New().ActiveAfter(minScope).ActiveBefore(maxScope).NoStatus(ds.Port | ds.PortRetry) - serversWithoutPort, err := s.servers.Filter(ctx, fs) - if err != nil { - s.logger.Error().Err(err).Msg("Unable to obtain servers for port discovery") - return -1, err - } - - cnt := 0 - for _, svr := range serversWithoutPort { - countdown := selectCountdown(minCountdown, maxCountdown) - if err := s.DiscoverPort(ctx, svr.Addr, countdown, deadline); err != nil { - s.logger.Warn(). - Err(err). - Stringer("server", svr).Time("countdown", countdown).Time("deadline", deadline). - Msg("Failed to add server to port discovery queue") - continue - } - s.logger.Debug(). - Time("countdown", countdown).Time("deadline", deadline).Stringer("server", svr). - Msg("Added server to port discovery queue") - cnt++ - } - - return cnt, nil -} - -func (s *Service) DiscoverDetails( - ctx context.Context, - addr addr.Addr, - queryPort int, - deadline time.Time, -) error { - prb := probe.New(addr, queryPort, probe.GoalDetails) - return s.queue.AddBefore(ctx, prb, deadline) -} - -func (s *Service) DiscoverPort( - ctx context.Context, - addr addr.Addr, - countdown time.Time, - deadline time.Time, -) error { - prb := probe.New(addr, addr.Port, probe.GoalPort) - return s.queue.AddBetween(ctx, prb, countdown, deadline) -} - -func selectCountdown(min, max time.Time) time.Time { - if !max.After(min) { - return min - } - spread := max.Sub(min) - countdown := random.RandInt(0, int(spread)) - return min.Add(time.Duration(countdown)) -} diff --git a/internal/services/discovery/finding/finding_test.go b/internal/services/discovery/finding/finding_test.go deleted file mode 100644 index 5a91c3c..0000000 --- a/internal/services/discovery/finding/finding_test.go +++ /dev/null @@ -1,282 +0,0 @@ -package finding_test - -import ( - "context" - "net" - "testing" - "time" - - "github.com/jonboulle/clockwork" - "github.com/rs/zerolog" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.uber.org/fx" - "go.uber.org/fx/fxtest" - - "github.com/sergeii/swat4master/internal/core/entities/addr" - ds "github.com/sergeii/swat4master/internal/core/entities/discovery/status" - "github.com/sergeii/swat4master/internal/core/entities/probe" - "github.com/sergeii/swat4master/internal/core/entities/server" - repos "github.com/sergeii/swat4master/internal/core/repositories" - "github.com/sergeii/swat4master/internal/persistence/memory" - "github.com/sergeii/swat4master/internal/services/discovery/finding" - "github.com/sergeii/swat4master/internal/services/monitoring" - sp "github.com/sergeii/swat4master/internal/services/probe" -) - -func makeApp(tb fxtest.TB, extra ...fx.Option) { - fxopts := []fx.Option{ - fx.Provide(func(c clockwork.Clock) (repos.ServerRepository, repos.InstanceRepository, repos.ProbeRepository) { - mem := memory.New(c) - return mem.Servers, mem.Instances, mem.Probes - }), - fx.Provide(func() *zerolog.Logger { - logger := zerolog.Nop() - return &logger - }), - fx.Provide( - monitoring.NewMetricService, - sp.NewService, - finding.NewService, - ), - fx.NopLogger, - } - fxopts = append(fxopts, extra...) - app := fxtest.New(tb, fxopts...) - app.RequireStart().RequireStop() -} - -func provideClock(c clockwork.Clock) fx.Option { - return fx.Provide( - func() clockwork.Clock { - return c - }, - ) -} - -func TestFindingService_DiscoverDetails(t *testing.T) { - ctx := context.TODO() - c := clockwork.NewFakeClock() - - var queue repos.ProbeRepository - var finder *finding.Service - makeApp(t, fx.Populate(&finder, &queue), provideClock(c)) - - deadline := c.Now().Add(time.Millisecond * 10) - for _, ipaddr := range []string{"1.1.1.1", "2.2.2.2", "3.3.3.3"} { - err := finder.DiscoverDetails(ctx, addr.MustNewFromDotted(ipaddr, 10480), 10481, deadline) - assert.NoError(t, err) - } - - p1, _ := queue.Pop(ctx) - assert.Equal(t, "1.1.1.1", p1.Addr.GetDottedIP()) - assert.Equal(t, probe.GoalDetails, p1.Goal) - assert.Equal(t, 10481, p1.Port) - - c.Advance(time.Millisecond * 15) - - _, err := queue.Pop(ctx) - assert.ErrorIs(t, err, repos.ErrProbeQueueIsEmpty) -} - -func TestFindingService_DiscoverPort(t *testing.T) { - ctx := context.TODO() - c := clockwork.NewFakeClock() - - var queue repos.ProbeRepository - var finder *finding.Service - makeApp(t, fx.Populate(&finder, &queue), provideClock(c)) - - now := c.Now() - countdown := now.Add(time.Millisecond * 5) - deadline := now.Add(time.Millisecond * 15) - - for _, ipaddr := range []string{"1.1.1.1", "2.2.2.2", "3.3.3.3"} { - err := finder.DiscoverPort(ctx, addr.MustNewFromDotted(ipaddr, 10480), countdown, deadline) - assert.NoError(t, err) - } - - _, err := queue.Pop(ctx) - assert.ErrorIs(t, err, repos.ErrProbeIsNotReady) - - c.Advance(time.Millisecond * 5) - - p1, _ := queue.Pop(ctx) - assert.Equal(t, "1.1.1.1", p1.Addr.GetDottedIP()) - assert.Equal(t, probe.GoalPort, p1.Goal) - assert.Equal(t, 10480, p1.Port) - - c.Advance(time.Millisecond * 5) - - p2, _ := queue.Pop(ctx) - assert.Equal(t, "2.2.2.2", p2.Addr.GetDottedIP()) - assert.Equal(t, probe.GoalPort, p2.Goal) - assert.Equal(t, 10480, p2.Port) - - c.Advance(time.Millisecond * 10) - - _, err = queue.Pop(ctx) - assert.ErrorIs(t, err, repos.ErrProbeQueueIsEmpty) -} - -func TestFindingService_RefreshDetails(t *testing.T) { - ctx := context.TODO() - c := clockwork.NewFakeClock() - - var serversRepo repos.ServerRepository - var probesRepo repos.ProbeRepository - var service *finding.Service - makeApp(t, fx.Populate(&serversRepo, &probesRepo, &service), provideClock(c)) - - gs1 := server.MustNew(net.ParseIP("1.1.1.1"), 10480, 10481) - gs1.Refresh(c.Now()) - gs1.UpdateDiscoveryStatus(ds.Master) - - gs2 := server.MustNew(net.ParseIP("2.2.2.2"), 10480, 10481) - gs2.Refresh(c.Now()) - gs2.UpdateDiscoveryStatus(ds.Port) - - gs3 := server.MustNew(net.ParseIP("3.3.3.3"), 10480, 10481) - gs3.Refresh(c.Now()) - gs3.UpdateDiscoveryStatus(ds.Master | ds.Details | ds.Port) - - gs4 := server.MustNew(net.ParseIP("4.4.4.4"), 10480, 10481) - gs4.Refresh(c.Now()) - gs4.UpdateDiscoveryStatus(ds.NoDetails) - - gs5 := server.MustNew(net.ParseIP("5.5.5.5"), 10480, 10481) - gs5.Refresh(c.Now()) - gs5.UpdateDiscoveryStatus(ds.DetailsRetry) - - gs6 := server.MustNew(net.ParseIP("6.6.6.6"), 10480, 10481) - gs6.Refresh(c.Now()) - gs6.UpdateDiscoveryStatus(ds.Port | ds.Details | ds.DetailsRetry) - - gs7 := server.MustNew(net.ParseIP("7.7.7.7"), 10480, 10481) - gs7.Refresh(c.Now()) - gs7.UpdateDiscoveryStatus(ds.Master | ds.Info | ds.Details | ds.Port) - - gs1, _ = serversRepo.Add(ctx, gs1, repos.ServerOnConflictIgnore) - gs2, _ = serversRepo.Add(ctx, gs2, repos.ServerOnConflictIgnore) - gs3, _ = serversRepo.Add(ctx, gs3, repos.ServerOnConflictIgnore) - gs4, _ = serversRepo.Add(ctx, gs4, repos.ServerOnConflictIgnore) - gs5, _ = serversRepo.Add(ctx, gs5, repos.ServerOnConflictIgnore) - gs6, _ = serversRepo.Add(ctx, gs6, repos.ServerOnConflictIgnore) - gs7, _ = serversRepo.Add(ctx, gs7, repos.ServerOnConflictIgnore) - - deadline := c.Now().Add(time.Second * 60) - - refreshedCount, err := service.RefreshDetails(ctx, deadline) - require.NoError(t, err) - assert.Equal(t, 3, refreshedCount) - - probeCnt, _ := probesRepo.Count(ctx) - assert.Equal(t, 3, probeCnt) - - refreshedServers := make([]string, 0, 3) - for i := 0; i < 3; i++ { - prb, err := probesRepo.PopAny(ctx) - require.NoError(t, err) - require.Equal(t, probe.GoalDetails, prb.Goal) - refreshedServers = append(refreshedServers, prb.Addr.GetDottedIP()) - } - assert.Equal(t, []string{"7.7.7.7", "3.3.3.3", "2.2.2.2"}, refreshedServers) -} - -func TestFindingService_ReviveServers(t *testing.T) { - ctx := context.TODO() - c := clockwork.NewFakeClock() - - var serversRepo repos.ServerRepository - var probesRepo repos.ProbeRepository - var service *finding.Service - makeApp(t, fx.Populate(&serversRepo, &probesRepo, &service), provideClock(c)) - - c.Advance(time.Millisecond) - - gs1 := server.MustNew(net.ParseIP("1.1.1.1"), 10480, 10481) - gs1.Refresh(c.Now()) - gs1.UpdateDiscoveryStatus(ds.Master) - - c.Advance(time.Millisecond) - - gs2 := server.MustNew(net.ParseIP("2.2.2.2"), 10480, 10481) - gs2.Refresh(c.Now()) - gs2.UpdateDiscoveryStatus(ds.Port) - - before3rd := c.Now() - - c.Advance(time.Millisecond) - - gs3 := server.MustNew(net.ParseIP("3.3.3.3"), 10480, 10481) - gs3.Refresh(c.Now()) - gs3.UpdateDiscoveryStatus(ds.Master | ds.Details | ds.Port) - - c.Advance(time.Millisecond) - - gs4 := server.MustNew(net.ParseIP("4.4.4.4"), 10480, 10481) - gs4.Refresh(c.Now()) - gs4.UpdateDiscoveryStatus(ds.NoDetails) - - c.Advance(time.Millisecond) - - gs5 := server.MustNew(net.ParseIP("5.5.5.5"), 10480, 10481) - gs5.Refresh(c.Now()) - gs5.UpdateDiscoveryStatus(ds.DetailsRetry) - - c.Advance(time.Millisecond) - gs6 := server.MustNew(net.ParseIP("6.6.6.6"), 10480, 10481) - gs6.Refresh(c.Now()) - gs6.UpdateDiscoveryStatus(ds.Port | ds.Details | ds.DetailsRetry) - - c.Advance(time.Millisecond) - - gs7 := server.MustNew(net.ParseIP("7.7.7.7"), 10480, 10481) - gs7.Refresh(c.Now()) - gs7.UpdateDiscoveryStatus(ds.Master | ds.Info | ds.Details) - - c.Advance(time.Millisecond) - - gs8 := server.MustNew(net.ParseIP("8.8.8.8"), 10480, 10481) - gs8.Refresh(c.Now()) - gs8.UpdateDiscoveryStatus(ds.Master | ds.PortRetry) - - beforeLast := c.Now() - - c.Advance(time.Millisecond) - - gs9 := server.MustNew(net.ParseIP("9.9.9.9"), 10480, 10481) - gs9.Refresh(c.Now()) - gs9.UpdateDiscoveryStatus(ds.Info) - - gs1, _ = serversRepo.Add(ctx, gs1, repos.ServerOnConflictIgnore) - gs2, _ = serversRepo.Add(ctx, gs2, repos.ServerOnConflictIgnore) - gs3, _ = serversRepo.Add(ctx, gs3, repos.ServerOnConflictIgnore) - gs4, _ = serversRepo.Add(ctx, gs4, repos.ServerOnConflictIgnore) - gs5, _ = serversRepo.Add(ctx, gs5, repos.ServerOnConflictIgnore) - gs6, _ = serversRepo.Add(ctx, gs6, repos.ServerOnConflictIgnore) - gs7, _ = serversRepo.Add(ctx, gs7, repos.ServerOnConflictIgnore) - gs8, _ = serversRepo.Add(ctx, gs8, repos.ServerOnConflictIgnore) - gs9, _ = serversRepo.Add(ctx, gs9, repos.ServerOnConflictIgnore) - - now := c.Now() - minCountdown := now - maxCountdown := now - deadline := now.Add(time.Second * 60) - - revivedCnt, err := service.ReviveServers(ctx, before3rd, beforeLast, minCountdown, maxCountdown, deadline) - require.NoError(t, err) - assert.Equal(t, 3, revivedCnt) - - probeCnt, _ := probesRepo.Count(ctx) - assert.Equal(t, 3, probeCnt) - - revivedServers := make([]string, 0, 3) - for i := 0; i < 3; i++ { - prb, err := probesRepo.PopAny(ctx) - require.NoError(t, err) - require.Equal(t, probe.GoalPort, prb.Goal) - revivedServers = append(revivedServers, prb.Addr.GetDottedIP()) - } - assert.Equal(t, []string{"7.7.7.7", "5.5.5.5", "4.4.4.4"}, revivedServers) -} diff --git a/internal/testutils/factories/info.go b/internal/testutils/factories/info.go new file mode 100644 index 0000000..ca3acf3 --- /dev/null +++ b/internal/testutils/factories/info.go @@ -0,0 +1,41 @@ +package factories + +import ( + "github.com/sergeii/swat4master/internal/core/entities/details" + "github.com/sergeii/swat4master/pkg/slice" +) + +type BuildInfoOption func(map[string]string) + +func WithFields(extra map[string]string) BuildInfoOption { + return func(fields map[string]string) { + for k, v := range extra { + fields[k] = v + } + } +} + +func BuildInfo(opts ...BuildInfoOption) details.Info { + fields := map[string]string{ + "hostname": slice.RandomChoice([]string{ + "Swat4 Server", + "Awesome Server", + "Another Swat4 Server", + "Pro Server", + }), + "hostport": slice.RandomChoice([]string{ + "10480", + "10580", + }), + "mapname": slice.RandomChoice([]string{"A-Bomb Nightclub", "Food Wall Restaurant", "-EXP- FunTime Amusements"}), + "gamever": slice.RandomChoice([]string{"1.0", "1.1"}), + "gamevariant": slice.RandomChoice([]string{"SWAT 4", "SEF", "SWAT 4X"}), + "gametype": slice.RandomChoice([]string{"VIP Escort", "Rapid Deployment", "Barricaded Suspects", "CO-OP"}), + } + + for _, opt := range opts { + opt(fields) + } + + return details.MustNewInfoFromParams(fields) +} diff --git a/internal/testutils/factories/instance.go b/internal/testutils/factories/instance.go new file mode 100644 index 0000000..2444c78 --- /dev/null +++ b/internal/testutils/factories/instance.go @@ -0,0 +1,19 @@ +package factories + +import ( + "context" + + "github.com/sergeii/swat4master/internal/core/entities/instance" + "github.com/sergeii/swat4master/internal/core/repositories" +) + +func SaveInstance( + ctx context.Context, + repo repositories.InstanceRepository, + ins instance.Instance, +) instance.Instance { + if err := repo.Add(ctx, ins); err != nil { + panic(err) + } + return ins +} diff --git a/internal/testutils/factories/server.go b/internal/testutils/factories/server.go index 574e601..686d42a 100644 --- a/internal/testutils/factories/server.go +++ b/internal/testutils/factories/server.go @@ -50,14 +50,6 @@ func WithInfo(fields map[string]string) BuildServerOption { } } -func WithExtraFields(extraFields map[string]string) BuildServerOption { - return func(p *BuildServerParams) { - for k, v := range extraFields { - p.Info[k] = v - } - } -} - func WithNoInfo() BuildServerOption { return func(p *BuildServerParams) { p.Info = nil @@ -102,6 +94,7 @@ func BuildServer(opts ...BuildServerOption) server.Server { svr.UpdateDetails(details.MustNewDetailsFromParams(params.Info, params.Players, params.Objectives), time.Now()) svr.UpdateDiscoveryStatus(params.DiscoveryStatus) + return svr } diff --git a/tests/api/servers_view_test.go b/tests/api/servers_view_test.go index 99be648..04d1720 100644 --- a/tests/api/servers_view_test.go +++ b/tests/api/servers_view_test.go @@ -512,10 +512,15 @@ func TestAPI_ViewServer_ValidateAddress(t *testing.T) { false, }, { - "private ip address", + "local ip address", "127.0.0.1:10480", false, }, + { + "private ip address", + "192.168.1.1:10480", + false, + }, { "v6 ip address", "2001:db8:3c4d:15::1a2f:1a2b:10480", diff --git a/tests/modules/cleaner_test.go b/tests/modules/cleaner_test.go index 2d0e78e..b6c319d 100644 --- a/tests/modules/cleaner_test.go +++ b/tests/modules/cleaner_test.go @@ -6,23 +6,21 @@ import ( "testing" "time" + "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/assert" "go.uber.org/fx" "github.com/sergeii/swat4master/cmd/swat4master/application" "github.com/sergeii/swat4master/cmd/swat4master/config" "github.com/sergeii/swat4master/cmd/swat4master/modules/cleaner" - "github.com/sergeii/swat4master/internal/core/entities/server" + "github.com/sergeii/swat4master/internal/core/entities/instance" "github.com/sergeii/swat4master/internal/core/repositories" + "github.com/sergeii/swat4master/internal/services/monitoring" + "github.com/sergeii/swat4master/internal/testutils/factories" ) -func TestCleaner_Run(t *testing.T) { - var repo repositories.ServerRepository - - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - app := fx.New( +func makeAppWithCleaner(extra ...fx.Option) (*fx.App, func()) { + fxopts := []fx.Option{ application.Module, fx.Provide(func() config.Config { return config.Config{ @@ -33,32 +31,117 @@ func TestCleaner_Run(t *testing.T) { cleaner.Module, fx.NopLogger, fx.Invoke(func(*cleaner.Cleaner) {}), - fx.Populate(&repo), - ) - app.Start(context.TODO()) // nolint: errcheck - defer func() { + } + fxopts = append(fxopts, extra...) + app := fx.New(fxopts...) + return app, func() { app.Stop(context.TODO()) // nolint: errcheck - }() + } +} + +func TestCleaner_OK(t *testing.T) { + var serverRepo repositories.ServerRepository + var instanceRepo repositories.InstanceRepository + var metrics *monitoring.MetricService + + ctx := context.TODO() + app, cancel := makeAppWithCleaner( + fx.Populate(&serverRepo, &instanceRepo, &metrics), + ) + defer cancel() + app.Start(ctx) // nolint: errcheck - gs1 := server.MustNew(net.ParseIP("1.1.1.1"), 10480, 10481) - gs2 := server.MustNew(net.ParseIP("2.2.2.2"), 10480, 10481) - gs3 := server.MustNew(net.ParseIP("3.3.3.3"), 10480, 10481) + ins1 := instance.MustNew("foo", net.ParseIP("1.1.1.1"), 10480) + ins2 := instance.MustNew("bar", net.ParseIP("3.3.3.3"), 10480) + ins4 := instance.MustNew("baz", net.ParseIP("4.4.4.4"), 10480) - repo.Add(ctx, gs1, nil) // nolint: errcheck - repo.Add(ctx, gs2, repositories.ServerOnConflictIgnore) // nolint: errcheck - repo.Add(ctx, gs3, repositories.ServerOnConflictIgnore) // nolint: errcheck + factories.SaveInstance(ctx, instanceRepo, ins1) + factories.SaveInstance(ctx, instanceRepo, ins2) + factories.SaveInstance(ctx, instanceRepo, ins4) + + gs1 := factories.CreateServer( + ctx, + serverRepo, + factories.WithAddress("1.1.1.1", 10480), + factories.WithQueryPort(10481), + ) + factories.CreateServer( + ctx, + serverRepo, + factories.WithAddress("2.2.2.2", 10480), + factories.WithQueryPort(10481), + ) + factories.CreateServer( + ctx, + serverRepo, + factories.WithAddress("3.3.3.3", 10480), + factories.WithQueryPort(10481), + ) // wait for cleaner to run some cycles <-time.After(time.Millisecond * 100) // refresh server 1 to prevent it from being cleaned gs1.Refresh(time.Now()) - repo.Update(ctx, gs1, repositories.ServerOnConflictIgnore) // nolint: errcheck + serverRepo.Update(ctx, gs1, repositories.ServerOnConflictIgnore) // nolint: errcheck + + // add a new server with an instance, it should not be cleaned right away + gs5 := factories.CreateServer( + ctx, + serverRepo, + factories.WithAddress("5.5.5.5", 10480), + factories.WithQueryPort(10481), + ) + + ins5 := instance.MustNew("qux", net.ParseIP("5.5.5.5"), 10480) + factories.SaveInstance(ctx, instanceRepo, ins5) // wait for cleaner to clean servers 2 and 3 <-time.After(time.Millisecond * 150) - cnt, err := repo.Count(ctx) + // check that the refreshed server and the new one are still there + svrCount, err := serverRepo.Count(ctx) + assert.NoError(t, err) + assert.Equal(t, 2, svrCount) + + _, err = serverRepo.Get(ctx, gs1.Addr) assert.NoError(t, err) - assert.Equal(t, 1, cnt) + _, err = serverRepo.Get(ctx, gs5.Addr) + assert.NoError(t, err) + + // only 1 instance was removed because only 1 of them belonged to a removed server + insCount, err := instanceRepo.Count(ctx) + assert.NoError(t, err) + assert.Equal(t, 3, insCount) + + _, err = instanceRepo.GetByID(ctx, "foo") + assert.NoError(t, err) + _, err = instanceRepo.GetByID(ctx, "baz") + assert.NoError(t, err) + _, err = instanceRepo.GetByID(ctx, "qux") + assert.NoError(t, err) + + removalValue := testutil.ToFloat64(metrics.CleanerRemovals) + assert.Equal(t, 2.0, removalValue) + errorValue := testutil.ToFloat64(metrics.CleanerErrors) + assert.Equal(t, 0.0, errorValue) +} + +func TestCleaner_NoErrorWhenNothingToClean(t *testing.T) { + var metrics *monitoring.MetricService + + ctx := context.TODO() + app, cancel := makeAppWithCleaner( + fx.Populate(&metrics), + ) + defer cancel() + app.Start(ctx) // nolint: errcheck + + // wait for cleaner to run some cycles + <-time.After(time.Millisecond * 100) + + removalValue := testutil.ToFloat64(metrics.CleanerRemovals) + errorValue := testutil.ToFloat64(metrics.CleanerErrors) + assert.Equal(t, 0.0, removalValue) + assert.Equal(t, 0.0, errorValue) } diff --git a/tests/modules/refresher_test.go b/tests/modules/refresher_test.go index b768147..2286e70 100644 --- a/tests/modules/refresher_test.go +++ b/tests/modules/refresher_test.go @@ -2,10 +2,10 @@ package modules_test import ( "context" - "net" "testing" "time" + "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/fx" @@ -13,43 +13,16 @@ import ( "github.com/sergeii/swat4master/cmd/swat4master/application" "github.com/sergeii/swat4master/cmd/swat4master/config" "github.com/sergeii/swat4master/cmd/swat4master/modules/refresher" - "github.com/sergeii/swat4master/internal/core/entities/details" ds "github.com/sergeii/swat4master/internal/core/entities/discovery/status" "github.com/sergeii/swat4master/internal/core/entities/probe" "github.com/sergeii/swat4master/internal/core/entities/server" "github.com/sergeii/swat4master/internal/core/repositories" + "github.com/sergeii/swat4master/internal/services/monitoring" + "github.com/sergeii/swat4master/internal/testutils/factories" ) -func TestRefresher_Run_OK(t *testing.T) { - var serverRepo repositories.ServerRepository - var probeRepo repositories.ProbeRepository - - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - assertProbes := func(wantCount, wantExpired int, wantDetails []string) { - count, err := probeRepo.Count(ctx) - require.NoError(t, err) - assert.Equal(t, wantCount, count) - - probes, expired, err := probeRepo.PopMany(ctx, count) - require.NoError(t, err) - detailsProbes := make([]string, 0, len(wantDetails)) - portProbes := make([]string, 0) - for _, prb := range probes { - switch prb.Goal { - case probe.GoalDetails: - detailsProbes = append(detailsProbes, prb.Addr.String()) - case probe.GoalPort: - portProbes = append(portProbes, prb.Addr.String()) - } - } - assert.Equal(t, wantExpired, expired) - assert.Equal(t, wantDetails, detailsProbes) - assert.Equal(t, []string{}, portProbes) - } - - app := fx.New( +func makeAppWithRefresher(extra ...fx.Option) (*fx.App, func()) { + fxopts := []fx.Option{ application.Module, fx.Provide(func() config.Config { return config.Config{ @@ -59,62 +32,110 @@ func TestRefresher_Run_OK(t *testing.T) { refresher.Module, fx.NopLogger, fx.Invoke(func(*refresher.Refresher) {}), - fx.Populate(&serverRepo, &probeRepo), - ) + } + fxopts = append(fxopts, extra...) + app := fx.New(fxopts...) + return app, func() { + app.Stop(context.TODO()) // nolint: errcheck + } +} - info := details.MustNewInfoFromParams(map[string]string{ - "hostname": "Awesome Server", - "hostport": "10580", - "mapname": "A-Bomb Nightclub", - "gamever": "1.1", - "gamevariant": "SWAT 4", - "gametype": "CO-OP", - }) - - gs1 := server.MustNew(net.ParseIP("1.1.1.1"), 10480, 10481) - gs1.UpdateInfo(info, time.Now()) - gs1.UpdateDiscoveryStatus(ds.Master) - - gs2 := server.MustNew(net.ParseIP("2.2.2.2"), 10480, 10481) - gs2.UpdateInfo(info, time.Now()) - gs2.UpdateDiscoveryStatus(ds.Port) - - gs3 := server.MustNew(net.ParseIP("3.3.3.3"), 10480, 10481) - gs3.UpdateInfo(info, time.Now()) - gs3.UpdateDiscoveryStatus(ds.Master | ds.Details | ds.Port) - - gs4 := server.MustNew(net.ParseIP("5.5.5.5"), 10480, 10481) - gs4.UpdateInfo(info, time.Now()) - gs4.UpdateDiscoveryStatus(ds.DetailsRetry) - - gs5 := server.MustNew(net.ParseIP("6.6.6.6"), 10480, 10481) - gs5.UpdateInfo(info, time.Now()) - gs5.UpdateDiscoveryStatus(ds.Port | ds.Details | ds.DetailsRetry) - - gs6 := server.MustNew(net.ParseIP("7.7.7.7"), 10480, 10481) - gs6.UpdateInfo(info, time.Now()) - gs6.UpdateDiscoveryStatus(ds.Master | ds.Info | ds.Details) - - gs7 := server.MustNew(net.ParseIP("9.9.9.9"), 10480, 10481) - gs7.UpdateInfo(info, time.Now()) - gs7.UpdateDiscoveryStatus(ds.Port | ds.PortRetry) - - for _, gs := range []*server.Server{&gs1, &gs2, &gs3, &gs4, &gs5, &gs6, &gs7} { - *gs, _ = serverRepo.Add(ctx, *gs, repositories.ServerOnConflictIgnore) +type refresherProbeCount struct { + count int + expired int + probes []string +} + +func countRefresherProbes( + ctx context.Context, + repo repositories.ProbeRepository, +) (refresherProbeCount, error) { + count, err := repo.Count(ctx) + if err != nil { + return refresherProbeCount{}, err } - app.Start(context.TODO()) // nolint: errcheck - defer func() { - app.Stop(context.TODO()) // nolint: errcheck - }() + probes, expired, err := repo.PopMany(ctx, count) + if err != nil { + return refresherProbeCount{}, err + } + + detailsProbes := make([]string, 0, count) + for _, prb := range probes { + if prb.Goal == probe.GoalDetails { + detailsProbes = append(detailsProbes, prb.Addr.String()) + } + } + return refresherProbeCount{ + count: count, + expired: expired, + probes: detailsProbes, + }, nil +} + +func TestRefresher_OK(t *testing.T) { + var serverRepo repositories.ServerRepository + var probeRepo repositories.ProbeRepository + var metrics *monitoring.MetricService + + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + gs1 := factories.BuildServer( + factories.WithAddress("1.1.1.1", 10480), + factories.WithQueryPort(10481), + factories.WithDiscoveryStatus(ds.Master), + ) + gs2 := factories.BuildServer( + factories.WithAddress("2.2.2.2", 10480), + factories.WithQueryPort(10481), + factories.WithDiscoveryStatus(ds.Port), + ) + gs3 := factories.BuildServer( + factories.WithAddress("3.3.3.3", 10480), + factories.WithQueryPort(10481), + factories.WithDiscoveryStatus(ds.Master|ds.Details|ds.Port), + ) + gs4 := factories.BuildServer( + factories.WithAddress("5.5.5.5", 10480), + factories.WithQueryPort(10481), + factories.WithDiscoveryStatus(ds.DetailsRetry), + ) + gs5 := factories.BuildServer( + factories.WithAddress("6.6.6.6", 10480), + factories.WithQueryPort(10481), + factories.WithDiscoveryStatus(ds.Port|ds.Details|ds.DetailsRetry), + ) + gs6 := factories.BuildServer( + factories.WithAddress("7.7.7.7", 10480), + factories.WithQueryPort(10481), + factories.WithDiscoveryStatus(ds.Master|ds.Info|ds.Details), + ) + gs7 := factories.BuildServer( + factories.WithAddress("9.9.9.9", 10480), + factories.WithQueryPort(10481), + factories.WithDiscoveryStatus(ds.Port|ds.PortRetry), + ) + + app, cancel := makeAppWithRefresher( + fx.Populate(&serverRepo, &probeRepo, &metrics), + ) + defer cancel() + app.Start(ctx) // nolint: errcheck + + for _, gs := range []server.Server{gs1, gs2, gs3, gs4, gs5, gs6, gs7} { + factories.SaveServer(ctx, serverRepo, gs) + } // let refresher run a cycle <-time.After(time.Millisecond * 150) // details probes are added - assertProbes(3, 0, - []string{"9.9.9.9:10480", "3.3.3.3:10480", "2.2.2.2:10480"}, - ) + result, err := countRefresherProbes(ctx, probeRepo) + require.NoError(t, err) + assert.Equal(t, 3, result.count) + assert.Equal(t, 0, result.expired) + assert.Equal(t, []string{"9.9.9.9:10480", "3.3.3.3:10480", "2.2.2.2:10480"}, result.probes) // clear the server's refreshable status, so that it doesn't get picked up again gs3.ClearDiscoveryStatus(ds.Details | ds.Port) @@ -127,22 +148,20 @@ func TestRefresher_Run_OK(t *testing.T) { // let refresher run another cycle <-time.After(time.Millisecond * 100) - assertProbes(3, - 0, - []string{ - "6.6.6.6:10480", "9.9.9.9:10480", "2.2.2.2:10480", - }, - ) + result, err = countRefresherProbes(ctx, probeRepo) + require.NoError(t, err) + assert.Equal(t, 3, result.count) + assert.Equal(t, 0, result.expired) + assert.Equal(t, []string{"6.6.6.6:10480", "9.9.9.9:10480", "2.2.2.2:10480"}, result.probes) // run a couple of cycles, expect some probes to expire <-time.After(time.Millisecond * 200) - assertProbes(6, - 3, - []string{ - "6.6.6.6:10480", "9.9.9.9:10480", "2.2.2.2:10480", - }, - ) + result, err = countRefresherProbes(ctx, probeRepo) + require.NoError(t, err) + assert.Equal(t, 6, result.count) + assert.Equal(t, 3, result.expired) + assert.Equal(t, []string{"6.6.6.6:10480", "9.9.9.9:10480", "2.2.2.2:10480"}, result.probes) // make the remaining servers non-refreshable gs2.ClearDiscoveryStatus(ds.Port) @@ -156,5 +175,12 @@ func TestRefresher_Run_OK(t *testing.T) { // run another cycle, expect no probes <-time.After(time.Millisecond * 100) - assertProbes(0, 0, []string{}) + result, err = countRefresherProbes(ctx, probeRepo) + require.NoError(t, err) + assert.Equal(t, 0, result.count) + assert.Equal(t, 0, result.expired) + assert.Equal(t, []string{}, result.probes) + + producedMetricValue := testutil.ToFloat64(metrics.DiscoveryQueueProduced) + assert.Equal(t, 12.0, producedMetricValue) } diff --git a/tests/modules/reviver_test.go b/tests/modules/reviver_test.go index a7b076b..8cc9867 100644 --- a/tests/modules/reviver_test.go +++ b/tests/modules/reviver_test.go @@ -2,10 +2,10 @@ package modules_test import ( "context" - "net" "testing" "time" + "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/fx" @@ -13,43 +13,16 @@ import ( "github.com/sergeii/swat4master/cmd/swat4master/application" "github.com/sergeii/swat4master/cmd/swat4master/config" "github.com/sergeii/swat4master/cmd/swat4master/modules/reviver" - "github.com/sergeii/swat4master/internal/core/entities/details" ds "github.com/sergeii/swat4master/internal/core/entities/discovery/status" "github.com/sergeii/swat4master/internal/core/entities/probe" "github.com/sergeii/swat4master/internal/core/entities/server" "github.com/sergeii/swat4master/internal/core/repositories" + "github.com/sergeii/swat4master/internal/services/monitoring" + "github.com/sergeii/swat4master/internal/testutils/factories" ) -func TestReviver_Run_OK(t *testing.T) { - var serverRepo repositories.ServerRepository - var probeRepo repositories.ProbeRepository - - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - - assertProbes := func(wantCount, wantExpired int, wantPorts []string) { - count, err := probeRepo.Count(ctx) - require.NoError(t, err) - assert.Equal(t, wantCount, count) - - probes, expired, err := probeRepo.PopMany(ctx, count) - require.NoError(t, err) - detailsProbes := make([]string, 0) - portProbes := make([]string, 0, len(wantPorts)) - for _, prb := range probes { - switch prb.Goal { - case probe.GoalDetails: - detailsProbes = append(detailsProbes, prb.Addr.String()) - case probe.GoalPort: - portProbes = append(portProbes, prb.Addr.String()) - } - } - assert.Equal(t, wantExpired, expired) - assert.Equal(t, []string{}, detailsProbes) - assert.Equal(t, wantPorts, portProbes) - } - - app := fx.New( +func makeAppWithReviver(extra ...fx.Option) (*fx.App, func()) { + fxopts := []fx.Option{ application.Module, fx.Provide(func() config.Config { return config.Config{ @@ -62,58 +35,105 @@ func TestReviver_Run_OK(t *testing.T) { reviver.Module, fx.NopLogger, fx.Invoke(func(*reviver.Reviver) {}), - fx.Populate(&serverRepo, &probeRepo), - ) + } + fxopts = append(fxopts, extra...) + app := fx.New(fxopts...) + return app, func() { + app.Stop(context.TODO()) // nolint: errcheck + } +} - info := details.MustNewInfoFromParams(map[string]string{ - "hostname": "Awesome Server", - "hostport": "10580", - "mapname": "A-Bomb Nightclub", - "gamever": "1.1", - "gamevariant": "SWAT 4", - "gametype": "CO-OP", - }) - - gs1 := server.MustNew(net.ParseIP("1.1.1.1"), 10480, 10481) - gs1.UpdateInfo(info, time.Now()) - gs1.UpdateDiscoveryStatus(ds.Master) - - gs2 := server.MustNew(net.ParseIP("2.2.2.2"), 10480, 10481) - gs2.UpdateInfo(info, time.Now()) - gs2.UpdateDiscoveryStatus(ds.Port) - - gs3 := server.MustNew(net.ParseIP("3.3.3.3"), 10480, 10481) - gs3.UpdateInfo(info, time.Now()) - gs3.UpdateDiscoveryStatus(ds.Master | ds.Details | ds.Port) - - gs4 := server.MustNew(net.ParseIP("4.4.4.4"), 10480, 10481) - gs4.UpdateInfo(info, time.Now()) - gs4.UpdateDiscoveryStatus(ds.DetailsRetry) - - gs5 := server.MustNew(net.ParseIP("5.5.5.5"), 10480, 10481) - gs5.UpdateInfo(info, time.Now()) - gs5.UpdateDiscoveryStatus(ds.Master | ds.Info | ds.Details) - - gs6 := server.MustNew(net.ParseIP("6.6.6.6"), 10480, 10481) - gs6.UpdateInfo(info, time.Now()) - gs6.UpdateDiscoveryStatus(ds.Master | ds.PortRetry) - - for _, gs := range []*server.Server{&gs1, &gs2, &gs3, &gs4, &gs5, &gs6} { - *gs, _ = serverRepo.Add(ctx, *gs, repositories.ServerOnConflictIgnore) +type reviverProbeCount struct { + count int + expired int + probes []string +} + +func countReviverProbes( + ctx context.Context, + repo repositories.ProbeRepository, +) (reviverProbeCount, error) { + count, err := repo.Count(ctx) + if err != nil { + return reviverProbeCount{}, err } - app.Start(context.TODO()) // nolint: errcheck - defer func() { - app.Stop(context.TODO()) // nolint: errcheck - }() + probes, expired, err := repo.PopMany(ctx, count) + if err != nil { + return reviverProbeCount{}, err + } + + portProbes := make([]string, 0, count) + for _, prb := range probes { + if prb.Goal == probe.GoalPort { + portProbes = append(portProbes, prb.Addr.String()) + } + } + return reviverProbeCount{ + count: count, + expired: expired, + probes: portProbes, + }, nil +} + +func TestReviver_OK(t *testing.T) { + var serverRepo repositories.ServerRepository + var probeRepo repositories.ProbeRepository + var metrics *monitoring.MetricService + + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + gs1 := factories.BuildServer( + factories.WithAddress("1.1.1.1", 10480), + factories.WithQueryPort(10481), + factories.WithDiscoveryStatus(ds.Master), + ) + gs2 := factories.BuildServer( + factories.WithAddress("2.2.2.2", 10480), + factories.WithQueryPort(10481), + factories.WithDiscoveryStatus(ds.Port), + ) + gs3 := factories.BuildServer( + factories.WithAddress("3.3.3.3", 10480), + factories.WithQueryPort(10481), + factories.WithDiscoveryStatus(ds.Master|ds.Details|ds.Port), + ) + gs4 := factories.BuildServer( + factories.WithAddress("4.4.4.4", 10480), + factories.WithQueryPort(10481), + factories.WithDiscoveryStatus(ds.DetailsRetry), + ) + gs5 := factories.BuildServer( + factories.WithAddress("5.5.5.5", 10480), + factories.WithQueryPort(10481), + factories.WithDiscoveryStatus(ds.Master|ds.Info|ds.Details), + ) + gs6 := factories.BuildServer( + factories.WithAddress("6.6.6.6", 10480), + factories.WithQueryPort(10481), + factories.WithDiscoveryStatus(ds.Master|ds.PortRetry), + ) + + app, cancel := makeAppWithReviver( + fx.Populate(&serverRepo, &probeRepo, &metrics), + ) + defer cancel() + app.Start(ctx) // nolint: errcheck + + for _, gs := range []server.Server{gs1, gs2, gs3, gs4, gs5, gs6} { + factories.SaveServer(ctx, serverRepo, gs) + } // let refresher run a cycle <-time.After(time.Millisecond * 150) // port probes are added - assertProbes(3, 0, - []string{"5.5.5.5:10480", "4.4.4.4:10480", "1.1.1.1:10480"}, - ) + result, err := countReviverProbes(ctx, probeRepo) + require.NoError(t, err) + assert.Equal(t, 3, result.count) + assert.Equal(t, 0, result.expired) + assert.Equal(t, []string{"5.5.5.5:10480", "4.4.4.4:10480", "1.1.1.1:10480"}, result.probes) // make gs3 non-revivable gs3.ClearDiscoveryStatus(ds.Port) @@ -121,16 +141,19 @@ func TestReviver_Run_OK(t *testing.T) { // let reviver run another cycle <-time.After(time.Millisecond * 100) - assertProbes(4, 0, - []string{"3.3.3.3:10480", "5.5.5.5:10480", "4.4.4.4:10480", "1.1.1.1:10480"}, - ) + result, err = countReviverProbes(ctx, probeRepo) + require.NoError(t, err) + assert.Equal(t, 4, result.count) + assert.Equal(t, 0, result.expired) + assert.Equal(t, []string{"3.3.3.3:10480", "5.5.5.5:10480", "4.4.4.4:10480", "1.1.1.1:10480"}, result.probes) // run a couple of cycles, expect some probes to expire <-time.After(time.Millisecond * 200) - - assertProbes(8, 4, - []string{"3.3.3.3:10480", "5.5.5.5:10480", "4.4.4.4:10480", "1.1.1.1:10480"}, - ) + result, err = countReviverProbes(ctx, probeRepo) + require.NoError(t, err) + assert.Equal(t, 8, result.count) + assert.Equal(t, 4, result.expired) + assert.Equal(t, []string{"3.3.3.3:10480", "5.5.5.5:10480", "4.4.4.4:10480", "1.1.1.1:10480"}, result.probes) // make the remaining servers non-revivable gs1.UpdateDiscoveryStatus(ds.Port) @@ -143,11 +166,20 @@ func TestReviver_Run_OK(t *testing.T) { // run another cycle, expect only gs3 to be revived <-time.After(time.Millisecond * 100) - - assertProbes(1, 0, []string{"3.3.3.3:10480"}) + result, err = countReviverProbes(ctx, probeRepo) + require.NoError(t, err) + assert.Equal(t, 1, result.count) + assert.Equal(t, 0, result.expired) + assert.Equal(t, []string{"3.3.3.3:10480"}, result.probes) // the remaining server goes out of revival scope <-time.After(time.Millisecond * 500) - - assertProbes(4, 4, []string{}) + result, err = countReviverProbes(ctx, probeRepo) + require.NoError(t, err) + assert.Equal(t, 4, result.count) + assert.Equal(t, 4, result.expired) + assert.Equal(t, []string{}, result.probes) + + producedMetricValue := testutil.ToFloat64(metrics.DiscoveryQueueProduced) + assert.Equal(t, 20.0, producedMetricValue) }