Skip to content

Commit

Permalink
Rate limiter
Browse files Browse the repository at this point in the history
  • Loading branch information
v-byte-cpu committed Mar 27, 2021
1 parent bf7f7ed commit 3e995ec
Show file tree
Hide file tree
Showing 12 changed files with 229 additions and 28 deletions.
3 changes: 3 additions & 0 deletions command/arp.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,16 @@ var arpCmd = &cobra.Command{
scanRange: r,
scanMethod: m,
bpfFilter: arp.BPFFilter,
rateCount: cliRateCount,
rateWindow: cliRateWindow,
})
},
}

func newARPScanMethod(ctx context.Context) *arp.ScanMethod {
var reqgen scan.RequestGenerator = scan.NewIPRequestGenerator(scan.NewIPGenerator())
if arpLiveModeFlag {
// TODO rescanTimeout option
reqgen = scan.NewLiveRequestGenerator(reqgen, 1*time.Second)
}
pktgen := scan.NewPacketMultiGenerator(arp.NewPacketFiller(), runtime.NumCPU())
Expand Down
102 changes: 84 additions & 18 deletions command/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,46 @@ import (
"github.com/spf13/cobra"
"github.com/v-byte-cpu/sx/command/log"
"github.com/v-byte-cpu/sx/pkg/ip"
"github.com/v-byte-cpu/sx/pkg/packet"
"github.com/v-byte-cpu/sx/pkg/packet/afpacket"
"github.com/v-byte-cpu/sx/pkg/scan"
"github.com/v-byte-cpu/sx/pkg/scan/arp"
"go.uber.org/ratelimit"
)

var rootCmd = &cobra.Command{
Use: "sx",
Short: "Fast, modern, easy-to-use network scanner",
Version: "0.1.0",
// Parse common flags
PersistentPreRunE: func(cmd *cobra.Command, args []string) (err error) {
if len(cliInterfaceFlag) > 0 {
if cliInterface, err = net.InterfaceByName(cliInterfaceFlag); err != nil {
return
}
}
if len(cliSrcIPFlag) > 0 {
if cliSrcIP = net.ParseIP(cliSrcIPFlag); cliSrcIP == nil {
return errSrcIP
}
}
if len(cliSrcMACFlag) > 0 {
if cliSrcMAC, err = net.ParseMAC(cliSrcMACFlag); err != nil {
return
}
}
if len(cliPortsFlag) > 0 {
if cliPortRanges, err = parsePortRanges(cliPortsFlag); err != nil {
return
}
}
if len(cliRateLimitFlag) > 0 {
if cliRateCount, cliRateWindow, err = parseRateLimit(cliRateLimitFlag); err != nil {
return
}
}
return
},
}

var (
Expand All @@ -32,19 +63,29 @@ var (
cliSrcIPFlag string
cliSrcMACFlag string
cliPortsFlag string
cliRateLimitFlag string

cliInterface *net.Interface
cliSrcIP net.IP
cliSrcMAC net.HardwareAddr
cliPortRanges []*scan.PortRange
cliRateCount int
cliRateWindow time.Duration
)

var (
errSrcIP = errors.New("invalid source IP")
errSrcMAC = errors.New("invalid source MAC")
errSrcInterface = errors.New("invalid source interface")
errRateLimit = errors.New("invalid ratelimit")
)

func init() {
rootCmd.PersistentFlags().BoolVar(&cliJSONFlag, "json", false, "enable JSON output")
rootCmd.PersistentFlags().StringVarP(&cliInterfaceFlag, "iface", "i", "", "set interface to send/receive packets")
rootCmd.PersistentFlags().StringVar(&cliSrcIPFlag, "srcip", "", "set source IP address for generated packets")
rootCmd.PersistentFlags().StringVar(&cliSrcMACFlag, "srcmac", "", "set source MAC address for generated packets")
rootCmd.PersistentFlags().StringVarP(&cliRateLimitFlag, "rate", "r", "", "set rate limit for generated packets")
}

func Main() {
Expand All @@ -60,14 +101,11 @@ type scanConfig struct {
gatewayIP net.IP
}

func parseScanConfig(scanName, subnet, ports string) (c *scanConfig, err error) {
func parseScanConfig(scanName, subnet string) (c *scanConfig, err error) {
var r *scan.Range
if r, err = parseScanRange(subnet); err != nil {
return
}
if r.Ports, err = parsePortRanges(ports); err != nil {
return
}

var logger log.Logger
if logger, err = getLogger(scanName, os.Stdout); err != nil {
Expand Down Expand Up @@ -108,18 +146,16 @@ func parseScanRange(subnet string) (*scan.Range, error) {
}

srcIP := srcSubnet.IP
if len(cliSrcIPFlag) > 0 {
srcIP = net.ParseIP(cliSrcIPFlag)
if cliSrcIP != nil {
srcIP = cliSrcIP
}
if srcIP == nil {
return nil, errSrcIP
}

srcMAC := iface.HardwareAddr
if len(cliSrcMACFlag) > 0 {
if srcMAC, err = net.ParseMAC(cliSrcMACFlag); err != nil {
return nil, err
}
if cliSrcMAC != nil {
srcMAC = cliSrcMAC
}
if srcMAC == nil {
return nil, errSrcMAC
Expand All @@ -128,6 +164,7 @@ func parseScanRange(subnet string) (*scan.Range, error) {
return &scan.Range{
Interface: iface,
DstSubnet: dstSubnet,
Ports: cliPortRanges,
SrcSubnet: srcSubnet,
SrcIP: srcIP.To4(),
SrcMAC: srcMAC}, nil
Expand Down Expand Up @@ -161,14 +198,35 @@ func parsePortRanges(portsRanges string) (result []*scan.PortRange, err error) {
return
}

func getSubnetInterface(dstSubnet *net.IPNet) (iface *net.Interface, srcSubnet *net.IPNet, err error) {
if len(cliInterfaceFlag) == 0 {
return ip.GetSubnetInterface(dstSubnet)
func parseRateLimit(rateLimit string) (rateCount int, rateWindow time.Duration, err error) {
parts := strings.Split(rateLimit, "/")
if len(parts) > 2 {
return 0, 0, errRateLimit
}
var rate int64
if rate, err = strconv.ParseInt(parts[0], 10, 32); err != nil || rate < 0 {
return 0, 0, errRateLimit
}
if iface, err = net.InterfaceByName(cliInterfaceFlag); err != nil {
rateCount = int(rate)
rateWindow = 1 * time.Second
if len(parts) < 2 {
return
}
if srcSubnet, err = ip.GetSubnetInterfaceIP(iface, dstSubnet); err != nil {
win := parts[1]
if len(win) > 0 && (win[0] < '0' || win[0] > '9') {
win = "1" + win
}
if rateWindow, err = time.ParseDuration(win); err != nil || rateWindow < 0 {
return 0, 0, errRateLimit
}
return
}

func getSubnetInterface(dstSubnet *net.IPNet) (iface *net.Interface, srcSubnet *net.IPNet, err error) {
if cliInterface == nil {
return ip.GetSubnetInterface(dstSubnet)
}
if srcSubnet, err = ip.GetSubnetInterfaceIP(cliInterface, dstSubnet); err != nil {
return
}
return iface, srcSubnet, nil
Expand Down Expand Up @@ -205,6 +263,8 @@ type engineConfig struct {
scanRange *scan.Range
scanMethod resultScanMethod
bpfFilter func(r *scan.Range) (filter string, maxPacketLength int)
rateCount int
rateWindow time.Duration
}

type resultScanMethod interface {
Expand All @@ -221,15 +281,21 @@ func startEngine(ctx context.Context, conf *engineConfig) error {
logger := conf.logger

// setup network interface to read/write packets
rw, err := afpacket.NewPacketSource(r.Interface.Name)
afps, err := afpacket.NewPacketSource(r.Interface.Name)
if err != nil {
return err
}
defer rw.Close()
err = rw.SetBPFFilter(conf.bpfFilter(r))
defer afps.Close()
err = afps.SetBPFFilter(conf.bpfFilter(r))
if err != nil {
return err
}
var rw packet.ReadWriter = afps
// setup rate limit for sending packets
if conf.rateCount > 0 {
rw = packet.NewRateLimitReadWriter(afps,
ratelimit.New(conf.rateCount, ratelimit.Per(conf.rateWindow)))
}

// setup result logging
var wg sync.WaitGroup
Expand Down
93 changes: 93 additions & 0 deletions command/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package command

import (
"testing"
"time"

"github.com/stretchr/testify/require"
"github.com/v-byte-cpu/sx/pkg/scan"
Expand Down Expand Up @@ -141,3 +142,95 @@ func TestParsePortRanges(t *testing.T) {
})
}
}

func TestParseRateLimitError(t *testing.T) {
t.Parallel()

tests := []struct {
name string
rateLimit string
}{
{
name: "InvalidRateLimit",
rateLimit: "abc",
},
{
name: "NegativeRateCount",
rateLimit: "-1000",
},
{
name: "InvalidRateWindow",
rateLimit: "1000/f",
},
{
name: "EmptySlashRateWindow",
rateLimit: "1000/",
},
{
name: "MultipleSlashes",
rateLimit: "1000//s",
},
{
name: "NegativeRateWindowDuration",
rateLimit: "1000/-1s",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, _, err := parseRateLimit(tt.rateLimit)
require.Error(t, err)
})
}
}

func TestParseRateLimit(t *testing.T) {
t.Parallel()

tests := []struct {
name string
rateLimit string
expectedRateCount int
expectedRateWindow time.Duration
}{
{
name: "ZeroRateCount",
rateLimit: "0",
expectedRateCount: 0,
expectedRateWindow: 1 * time.Second,
},
{
name: "EmptyRateWindow",
rateLimit: "1000",
expectedRateCount: 1000,
expectedRateWindow: 1 * time.Second,
},
{
name: "OneSecondRate",
rateLimit: "1000/1s",
expectedRateCount: 1000,
expectedRateWindow: 1 * time.Second,
},
{
name: "SevenMinureRate",
rateLimit: "5000/7m",
expectedRateCount: 5000,
expectedRateWindow: 7 * time.Minute,
},
{
name: "OneSecondRate2",
rateLimit: "1000/s",
expectedRateCount: 1000,
expectedRateWindow: 1 * time.Second,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rate, rateWindow, err := parseRateLimit(tt.rateLimit)
require.NoError(t, err)
require.Equal(t, tt.expectedRateCount, rate)
require.Equal(t, tt.expectedRateWindow, rateWindow)
})
}
}
6 changes: 4 additions & 2 deletions command/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ var tcpCmd = &cobra.Command{
defer cancel()

if len(cliTCPPacketFlags) == 0 {
return startTCPSYNScan(ctx, args[0], cliPortsFlag)
return startTCPSYNScan(ctx, args[0])
}

var tcpFlags []string
Expand All @@ -75,7 +75,7 @@ var tcpCmd = &cobra.Command{

scanName := tcp.FlagsScanType
var conf *scanConfig
if conf, err = parseScanConfig(scanName, args[0], cliPortsFlag); err != nil {
if conf, err = parseScanConfig(scanName, args[0]); err != nil {
return
}

Expand All @@ -91,6 +91,8 @@ var tcpCmd = &cobra.Command{
scanRange: conf.scanRange,
scanMethod: m,
bpfFilter: tcp.BPFFilter,
rateCount: cliRateCount,
rateWindow: cliRateWindow,
})
},
}
Expand Down
4 changes: 3 additions & 1 deletion command/tcp_fin.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ var tcpfinCmd = &cobra.Command{
scanName := tcp.FINScanType

var conf *scanConfig
if conf, err = parseScanConfig(scanName, args[0], cliPortsFlag); err != nil {
if conf, err = parseScanConfig(scanName, args[0]); err != nil {
return
}

Expand All @@ -48,6 +48,8 @@ var tcpfinCmd = &cobra.Command{
scanRange: conf.scanRange,
scanMethod: m,
bpfFilter: tcp.BPFFilter,
rateCount: cliRateCount,
rateWindow: cliRateWindow,
})
},
}
4 changes: 3 additions & 1 deletion command/tcp_null.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ var tcpnullCmd = &cobra.Command{
scanName := tcp.NULLScanType

var conf *scanConfig
if conf, err = parseScanConfig(scanName, args[0], cliPortsFlag); err != nil {
if conf, err = parseScanConfig(scanName, args[0]); err != nil {
return
}

Expand All @@ -48,6 +48,8 @@ var tcpnullCmd = &cobra.Command{
scanRange: conf.scanRange,
scanMethod: m,
bpfFilter: tcp.BPFFilter,
rateCount: cliRateCount,
rateWindow: cliRateWindow,
})
},
}
Loading

0 comments on commit 3e995ec

Please sign in to comment.