From 171a5295bdca6fc933d1b450b29e09651f2043d1 Mon Sep 17 00:00:00 2001 From: v-byte-cpu <65545655+v-byte-cpu@users.noreply.github.com> Date: Wed, 9 Jun 2021 00:58:23 +0300 Subject: [PATCH] feature: rate limit for generic scans (#82) --- command/config.go | 25 +++++++++++++++++++++++++ command/config_test.go | 6 +++++- command/docker.go | 3 +-- command/elastic.go | 3 +-- command/socks.go | 3 +-- pkg/scan/engine.go | 21 ++++++++++++++++++++- pkg/scan/engine_test.go | 37 +++++++++++++++++++++++++++++++++++++ 7 files changed, 90 insertions(+), 8 deletions(-) diff --git a/command/config.go b/command/config.go index 4a68a16..52b87fc 100644 --- a/command/config.go +++ b/command/config.go @@ -1,6 +1,7 @@ package command import ( + "context" "errors" "io" "io/ioutil" @@ -16,6 +17,7 @@ import ( "github.com/v-byte-cpu/sx/pkg/ip" "github.com/v-byte-cpu/sx/pkg/scan" "github.com/v-byte-cpu/sx/pkg/scan/arp" + "go.uber.org/ratelimit" ) const ( @@ -347,9 +349,12 @@ type genericScanCmdOpts struct { ipFile string portRanges []*scan.PortRange workers int + rateCount int + rateWindow time.Duration exitDelay time.Duration rawPortRanges string + rawRateLimit string } func (o *genericScanCmdOpts) initCliFlags(cmd *cobra.Command) { @@ -357,6 +362,12 @@ func (o *genericScanCmdOpts) initCliFlags(cmd *cobra.Command) { cmd.Flags().StringVarP(&o.rawPortRanges, "ports", "p", "", "set ports to scan") cmd.Flags().StringVarP(&o.ipFile, "file", "f", "", "set JSONL file with ip/port pairs to scan") cmd.Flags().IntVarP(&o.workers, "workers", "w", defaultWorkerCount, "set workers count") + cmd.Flags().StringVarP(&o.rawRateLimit, "rate", "r", "", + strings.Join([]string{ + "set rate limit for generated scan requests", + `format: "rateCount/rateWindow"`, + "where rateCount is a number of scan requests, rateWindow is the time interval", + "e.g. 1000/s -- 1000 requests per second", "500/7s -- 500 requests per 7 seconds\n"}, "\n")) cmd.Flags().DurationVar(&o.exitDelay, "exit-delay", defaultExitDelay, strings.Join([]string{ "set exit delay to wait for last response", @@ -369,6 +380,11 @@ func (o *genericScanCmdOpts) parseRawOptions() (err error) { return } } + if len(o.rawRateLimit) > 0 { + if o.rateCount, o.rateWindow, err = parseRateLimit(o.rawRateLimit); err != nil { + return + } + } if o.workers <= 0 { return errors.New("invalid workers count") } @@ -403,6 +419,15 @@ func (o *genericScanCmdOpts) getLogger(name string, w io.Writer) (logger log.Log return } +func (o *genericScanCmdOpts) newScanEngine(ctx context.Context, scanner scan.Scanner) *scan.GenericEngine { + if o.rateCount > 0 { + scanner = scan.NewRateLimitScanner(scanner, + ratelimit.New(o.rateCount, ratelimit.Per(o.rateWindow))) + } + results := scan.NewResultChan(ctx, 1000) + return scan.NewScanEngine(o.newIPPortGenerator(), scanner, results, scan.WithScanWorkerCount(o.workers)) +} + func (o *genericScanCmdOpts) newIPPortGenerator() (reqgen scan.RequestGenerator) { if len(o.ipFile) == 0 { return scan.NewIPPortGenerator(scan.NewIPGenerator(), scan.NewPortGenerator()) diff --git a/command/config_test.go b/command/config_test.go index a99b789..5694ce0 100644 --- a/command/config_test.go +++ b/command/config_test.go @@ -151,13 +151,14 @@ func TestGenericScanCmdOptsInitCliFlags(t *testing.T) { opts.initCliFlags(cmd) err := cmd.ParseFlags(strings.Split( - "--json -p 23-57,71-2733 -f ip_file.jsonl -w 300 --exit-delay 10s", " ")) + "--json -p 23-57,71-2733 -f ip_file.jsonl -w 300 -r 500/7s --exit-delay 10s", " ")) require.NoError(t, err) require.Equal(t, true, opts.json) require.Equal(t, "23-57,71-2733", opts.rawPortRanges) require.Equal(t, "ip_file.jsonl", opts.ipFile) require.Equal(t, 300, opts.workers) + require.Equal(t, "500/7s", opts.rawRateLimit) require.Equal(t, 10*time.Second, opts.exitDelay) } @@ -165,6 +166,7 @@ func TestGenericScanCmdOptsParseRawOptions(t *testing.T) { t.Parallel() opts := genericScanCmdOpts{ rawPortRanges: "23-57,71-2733", + rawRateLimit: "500/7s", workers: 300, } @@ -174,6 +176,8 @@ func TestGenericScanCmdOptsParseRawOptions(t *testing.T) { require.Equal(t, []*scan.PortRange{ {StartPort: 23, EndPort: 57}, {StartPort: 71, EndPort: 2733}}, opts.portRanges) + require.Equal(t, 500, opts.rateCount) + require.Equal(t, 7*time.Second, opts.rateWindow) } func TestIPScanCmdOptsIsARPCacheFromStdin(t *testing.T) { diff --git a/command/docker.go b/command/docker.go index 3c3ec8c..7494258 100644 --- a/command/docker.go +++ b/command/docker.go @@ -86,6 +86,5 @@ func (o *dockerCmdOpts) parseRawOptions() (err error) { func (o *dockerCmdOpts) newDockerScanEngine(ctx context.Context) scan.EngineResulter { scanner := docker.NewScanner(o.proto, docker.WithDataTimeout(o.timeout)) - results := scan.NewResultChan(ctx, 1000) - return scan.NewScanEngine(o.newIPPortGenerator(), scanner, results, scan.WithScanWorkerCount(o.workers)) + return o.newScanEngine(ctx, scanner) } diff --git a/command/elastic.go b/command/elastic.go index 6c36150..5c27d52 100644 --- a/command/elastic.go +++ b/command/elastic.go @@ -86,6 +86,5 @@ func (o *elasticCmdOpts) parseRawOptions() (err error) { func (o *elasticCmdOpts) newElasticScanEngine(ctx context.Context) scan.EngineResulter { scanner := elastic.NewScanner(o.proto, elastic.WithDataTimeout(o.timeout)) - results := scan.NewResultChan(ctx, 1000) - return scan.NewScanEngine(o.newIPPortGenerator(), scanner, results, scan.WithScanWorkerCount(o.workers)) + return o.newScanEngine(ctx, scanner) } diff --git a/command/socks.go b/command/socks.go index 1524ffc..c7367e5 100644 --- a/command/socks.go +++ b/command/socks.go @@ -74,6 +74,5 @@ func (o *socksCmdOpts) newSOCKSScanEngine(ctx context.Context) scan.EngineResult scanner := socks5.NewScanner( socks5.WithDialTimeout(o.timeout), socks5.WithDataTimeout(o.timeout)) - results := scan.NewResultChan(ctx, 1000) - return scan.NewScanEngine(o.newIPPortGenerator(), scanner, results, scan.WithScanWorkerCount(o.workers)) + return o.newScanEngine(ctx, scanner) } diff --git a/pkg/scan/engine.go b/pkg/scan/engine.go index 234ed9f..057961d 100644 --- a/pkg/scan/engine.go +++ b/pkg/scan/engine.go @@ -6,6 +6,7 @@ import ( "context" "net" "sync" + "time" "github.com/v-byte-cpu/sx/pkg/packet" ) @@ -133,6 +134,25 @@ type Scanner interface { Scan(ctx context.Context, r *Request) (Result, error) } +type RateLimiter interface { + // Take should block to make sure that the RPS is met. + Take() time.Time +} + +type rateLimitScanner struct { + Scanner + limiter RateLimiter +} + +func NewRateLimitScanner(delegate Scanner, limiter RateLimiter) Scanner { + return &rateLimitScanner{Scanner: delegate, limiter: limiter} +} + +func (s *rateLimitScanner) Scan(ctx context.Context, r *Request) (Result, error) { + s.limiter.Take() + return s.Scanner.Scan(ctx, r) +} + type GenericEngine struct { reqgen RequestGenerator scanner Scanner @@ -207,7 +227,6 @@ func (e *GenericEngine) worker(ctx context.Context, wg *sync.WaitGroup, writeError(ctx, errc, r.Err) continue } - // TODO rate limit result, err := e.scanner.Scan(ctx, r) if err != nil { writeError(ctx, errc, err) diff --git a/pkg/scan/engine_test.go b/pkg/scan/engine_test.go index 5a55ffd..206540e 100644 --- a/pkg/scan/engine_test.go +++ b/pkg/scan/engine_test.go @@ -16,6 +16,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/v-byte-cpu/sx/pkg/packet" + "go.uber.org/ratelimit" ) func TestMergeErrChanEmptyChannels(t *testing.T) { @@ -200,6 +201,42 @@ func TestPacketSourceReturnsData(t *testing.T) { waitDone(t, done) } +func TestRateLimitScanner(t *testing.T) { + t.Parallel() + + done := make(chan interface{}) + go func() { + defer close(done) + + ctrl := gomock.NewController(t) + scanner := NewMockScanner(ctrl) + + req1 := &Request{DstIP: net.IPv4(192, 168, 0, 1), DstPort: 22} + expectedResult := &mockScanResult{"id1"} + scanner.EXPECT().Scan(gomock.Not(gomock.Nil()), req1). + Return(expectedResult, nil).AnyTimes() + + rateScanner := NewRateLimitScanner(scanner, + ratelimit.New(2, ratelimit.Per(20*time.Millisecond))) + timer := time.After(10 * time.Millisecond) + count := 0 + loop: + for { + select { + case <-timer: + break loop + default: + result, err := rateScanner.Scan(context.Background(), req1) + require.NoError(t, err) + require.Equal(t, expectedResult, result) + count++ + } + } + require.LessOrEqual(t, count, 2) + }() + waitDone(t, done) +} + func TestScanEngineWithRequestGeneratorError(t *testing.T) { t.Parallel()