Skip to content

Commit

Permalink
feature: rate limit for generic scans (#82)
Browse files Browse the repository at this point in the history
  • Loading branch information
v-byte-cpu committed Jun 8, 2021
1 parent 5234e02 commit 171a529
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 8 deletions.
25 changes: 25 additions & 0 deletions command/config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package command

import (
"context"
"errors"
"io"
"io/ioutil"
Expand All @@ -16,6 +17,7 @@ import (
"github.com/v-byte-cpu/sx/pkg/ip"
"github.com/v-byte-cpu/sx/pkg/scan"
"github.com/v-byte-cpu/sx/pkg/scan/arp"
"go.uber.org/ratelimit"
)

const (
Expand Down Expand Up @@ -347,16 +349,25 @@ type genericScanCmdOpts struct {
ipFile string
portRanges []*scan.PortRange
workers int
rateCount int
rateWindow time.Duration
exitDelay time.Duration

rawPortRanges string
rawRateLimit string
}

func (o *genericScanCmdOpts) initCliFlags(cmd *cobra.Command) {
cmd.Flags().BoolVar(&o.json, "json", false, "enable JSON output")
cmd.Flags().StringVarP(&o.rawPortRanges, "ports", "p", "", "set ports to scan")
cmd.Flags().StringVarP(&o.ipFile, "file", "f", "", "set JSONL file with ip/port pairs to scan")
cmd.Flags().IntVarP(&o.workers, "workers", "w", defaultWorkerCount, "set workers count")
cmd.Flags().StringVarP(&o.rawRateLimit, "rate", "r", "",
strings.Join([]string{
"set rate limit for generated scan requests",
`format: "rateCount/rateWindow"`,
"where rateCount is a number of scan requests, rateWindow is the time interval",
"e.g. 1000/s -- 1000 requests per second", "500/7s -- 500 requests per 7 seconds\n"}, "\n"))
cmd.Flags().DurationVar(&o.exitDelay, "exit-delay", defaultExitDelay,
strings.Join([]string{
"set exit delay to wait for last response",
Expand All @@ -369,6 +380,11 @@ func (o *genericScanCmdOpts) parseRawOptions() (err error) {
return
}
}
if len(o.rawRateLimit) > 0 {
if o.rateCount, o.rateWindow, err = parseRateLimit(o.rawRateLimit); err != nil {
return
}
}
if o.workers <= 0 {
return errors.New("invalid workers count")
}
Expand Down Expand Up @@ -403,6 +419,15 @@ func (o *genericScanCmdOpts) getLogger(name string, w io.Writer) (logger log.Log
return
}

func (o *genericScanCmdOpts) newScanEngine(ctx context.Context, scanner scan.Scanner) *scan.GenericEngine {
if o.rateCount > 0 {
scanner = scan.NewRateLimitScanner(scanner,
ratelimit.New(o.rateCount, ratelimit.Per(o.rateWindow)))
}
results := scan.NewResultChan(ctx, 1000)
return scan.NewScanEngine(o.newIPPortGenerator(), scanner, results, scan.WithScanWorkerCount(o.workers))
}

func (o *genericScanCmdOpts) newIPPortGenerator() (reqgen scan.RequestGenerator) {
if len(o.ipFile) == 0 {
return scan.NewIPPortGenerator(scan.NewIPGenerator(), scan.NewPortGenerator())
Expand Down
6 changes: 5 additions & 1 deletion command/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,20 +151,22 @@ func TestGenericScanCmdOptsInitCliFlags(t *testing.T) {

opts.initCliFlags(cmd)
err := cmd.ParseFlags(strings.Split(
"--json -p 23-57,71-2733 -f ip_file.jsonl -w 300 --exit-delay 10s", " "))
"--json -p 23-57,71-2733 -f ip_file.jsonl -w 300 -r 500/7s --exit-delay 10s", " "))

require.NoError(t, err)
require.Equal(t, true, opts.json)
require.Equal(t, "23-57,71-2733", opts.rawPortRanges)
require.Equal(t, "ip_file.jsonl", opts.ipFile)
require.Equal(t, 300, opts.workers)
require.Equal(t, "500/7s", opts.rawRateLimit)
require.Equal(t, 10*time.Second, opts.exitDelay)
}

func TestGenericScanCmdOptsParseRawOptions(t *testing.T) {
t.Parallel()
opts := genericScanCmdOpts{
rawPortRanges: "23-57,71-2733",
rawRateLimit: "500/7s",
workers: 300,
}

Expand All @@ -174,6 +176,8 @@ func TestGenericScanCmdOptsParseRawOptions(t *testing.T) {
require.Equal(t, []*scan.PortRange{
{StartPort: 23, EndPort: 57},
{StartPort: 71, EndPort: 2733}}, opts.portRanges)
require.Equal(t, 500, opts.rateCount)
require.Equal(t, 7*time.Second, opts.rateWindow)
}

func TestIPScanCmdOptsIsARPCacheFromStdin(t *testing.T) {
Expand Down
3 changes: 1 addition & 2 deletions command/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,5 @@ func (o *dockerCmdOpts) parseRawOptions() (err error) {

func (o *dockerCmdOpts) newDockerScanEngine(ctx context.Context) scan.EngineResulter {
scanner := docker.NewScanner(o.proto, docker.WithDataTimeout(o.timeout))
results := scan.NewResultChan(ctx, 1000)
return scan.NewScanEngine(o.newIPPortGenerator(), scanner, results, scan.WithScanWorkerCount(o.workers))
return o.newScanEngine(ctx, scanner)
}
3 changes: 1 addition & 2 deletions command/elastic.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,5 @@ func (o *elasticCmdOpts) parseRawOptions() (err error) {

func (o *elasticCmdOpts) newElasticScanEngine(ctx context.Context) scan.EngineResulter {
scanner := elastic.NewScanner(o.proto, elastic.WithDataTimeout(o.timeout))
results := scan.NewResultChan(ctx, 1000)
return scan.NewScanEngine(o.newIPPortGenerator(), scanner, results, scan.WithScanWorkerCount(o.workers))
return o.newScanEngine(ctx, scanner)
}
3 changes: 1 addition & 2 deletions command/socks.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,5 @@ func (o *socksCmdOpts) newSOCKSScanEngine(ctx context.Context) scan.EngineResult
scanner := socks5.NewScanner(
socks5.WithDialTimeout(o.timeout),
socks5.WithDataTimeout(o.timeout))
results := scan.NewResultChan(ctx, 1000)
return scan.NewScanEngine(o.newIPPortGenerator(), scanner, results, scan.WithScanWorkerCount(o.workers))
return o.newScanEngine(ctx, scanner)
}
21 changes: 20 additions & 1 deletion pkg/scan/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"net"
"sync"
"time"

"github.com/v-byte-cpu/sx/pkg/packet"
)
Expand Down Expand Up @@ -133,6 +134,25 @@ type Scanner interface {
Scan(ctx context.Context, r *Request) (Result, error)
}

type RateLimiter interface {
// Take should block to make sure that the RPS is met.
Take() time.Time
}

type rateLimitScanner struct {
Scanner
limiter RateLimiter
}

func NewRateLimitScanner(delegate Scanner, limiter RateLimiter) Scanner {
return &rateLimitScanner{Scanner: delegate, limiter: limiter}
}

func (s *rateLimitScanner) Scan(ctx context.Context, r *Request) (Result, error) {
s.limiter.Take()
return s.Scanner.Scan(ctx, r)
}

type GenericEngine struct {
reqgen RequestGenerator
scanner Scanner
Expand Down Expand Up @@ -207,7 +227,6 @@ func (e *GenericEngine) worker(ctx context.Context, wg *sync.WaitGroup,
writeError(ctx, errc, r.Err)
continue
}
// TODO rate limit
result, err := e.scanner.Scan(ctx, r)
if err != nil {
writeError(ctx, errc, err)
Expand Down
37 changes: 37 additions & 0 deletions pkg/scan/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/v-byte-cpu/sx/pkg/packet"
"go.uber.org/ratelimit"
)

func TestMergeErrChanEmptyChannels(t *testing.T) {
Expand Down Expand Up @@ -200,6 +201,42 @@ func TestPacketSourceReturnsData(t *testing.T) {
waitDone(t, done)
}

func TestRateLimitScanner(t *testing.T) {
t.Parallel()

done := make(chan interface{})
go func() {
defer close(done)

ctrl := gomock.NewController(t)
scanner := NewMockScanner(ctrl)

req1 := &Request{DstIP: net.IPv4(192, 168, 0, 1), DstPort: 22}
expectedResult := &mockScanResult{"id1"}
scanner.EXPECT().Scan(gomock.Not(gomock.Nil()), req1).
Return(expectedResult, nil).AnyTimes()

rateScanner := NewRateLimitScanner(scanner,
ratelimit.New(2, ratelimit.Per(20*time.Millisecond)))
timer := time.After(10 * time.Millisecond)
count := 0
loop:
for {
select {
case <-timer:
break loop
default:
result, err := rateScanner.Scan(context.Background(), req1)
require.NoError(t, err)
require.Equal(t, expectedResult, result)
count++
}
}
require.LessOrEqual(t, count, 2)
}()
waitDone(t, done)
}

func TestScanEngineWithRequestGeneratorError(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit 171a529

Please sign in to comment.