diff --git a/command/arp.go b/command/arp.go index 1713736..d4daa2a 100644 --- a/command/arp.go +++ b/command/arp.go @@ -16,67 +16,88 @@ import ( "github.com/v-byte-cpu/sx/pkg/scan/arp" ) -var ( - cliARPLiveTimeoutFlag string - cliARPLiveTimeout time.Duration -) +func newARPCmd() *arpCmd { + c := &arpCmd{} -func init() { - addPacketScanOptions(arpCmd, withoutGatewayMAC()) - arpCmd.Flags().StringVar(&cliARPLiveTimeoutFlag, "live", "", "enable live mode") - rootCmd.AddCommand(arpCmd) -} + cmd := &cobra.Command{ + Use: "arp [flags] subnet", + Example: strings.Join([]string{"arp 192.168.0.1/24", "arp 10.0.0.1"}, "\n"), + Short: "Perform ARP scan", + RunE: func(cmd *cobra.Command, args []string) (err error) { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() -var arpCmd = &cobra.Command{ - Use: "arp [flags] subnet", - Example: strings.Join([]string{"arp 192.168.0.1/24", "arp 10.0.0.1"}, "\n"), - Short: "Perform ARP scan", - PreRunE: func(cmd *cobra.Command, args []string) (err error) { - if len(cliARPLiveTimeoutFlag) > 0 { - if cliARPLiveTimeout, err = time.ParseDuration(cliARPLiveTimeoutFlag); err != nil { + if len(args) != 1 { + return errors.New("requires one ip subnet argument") + } + dstSubnet, err := ip.ParseIPNet(args[0]) + if err != nil { return } - } - if len(args) != 1 { - return errors.New("requires one ip subnet argument") - } - cliDstSubnet, err = ip.ParseIPNet(args[0]) - return - }, - RunE: func(cmd *cobra.Command, args []string) (err error) { - var r *scan.Range - if r, err = getScanRange(cliDstSubnet); err != nil { - return err - } - ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) - defer cancel() + if err = c.opts.parseRawOptions(); err != nil { + return + } + var r *scan.Range + if r, err = c.opts.getScanRange(dstSubnet); err != nil { + return err + } + var logger log.Logger + if logger, err = c.opts.getLogger(); err != nil { + return err + } - var logger log.Logger - if logger, err = getLogger("arp", os.Stdout); err != nil { - return err - } - if cliARPLiveTimeout > 0 { - logger = log.NewUniqueLogger(logger) - } + m := c.opts.newARPScanMethod(ctx) - m := newARPScanMethod(ctx) + return startPacketScanEngine(ctx, newPacketScanConfig( + withPacketScanMethod(m), + withPacketBPFFilter(arp.BPFFilter), + withRateCount(c.opts.rateCount), + withRateWindow(c.opts.rateWindow), + withPacketEngineConfig(newEngineConfig( + withLogger(logger), + withScanRange(r), + withExitDelay(c.opts.exitDelay), + )), + )) + }, + } - return startPacketScanEngine(ctx, newPacketScanConfig( - withPacketScanMethod(m), - withPacketBPFFilter(arp.BPFFilter), - withPacketEngineConfig(newEngineConfig( - withLogger(logger), - withScanRange(r), - )), - )) - }, + c.opts.initCliFlags(cmd) + + c.cmd = cmd + return c +} + +type arpCmd struct { + cmd *cobra.Command + opts arpCmdOpts +} + +type arpCmdOpts struct { + packetScanCmdOpts + liveTimeout time.Duration +} + +func (o *arpCmdOpts) initCliFlags(cmd *cobra.Command) { + o.packetScanCmdOpts.initCliFlags(cmd) + cmd.Flags().DurationVar(&o.liveTimeout, "live", 0, "enable live mode") +} + +func (o *arpCmdOpts) getLogger() (logger log.Logger, err error) { + if logger, err = o.packetScanCmdOpts.getLogger("arp", os.Stdout); err != nil { + return + } + if o.liveTimeout > 0 { + logger = log.NewUniqueLogger(logger) + } + return } -func newARPScanMethod(ctx context.Context) *arp.ScanMethod { +func (o *arpCmdOpts) newARPScanMethod(ctx context.Context) *arp.ScanMethod { var reqgen scan.RequestGenerator = scan.NewIPRequestGenerator(scan.NewIPGenerator()) - if cliARPLiveTimeout > 0 { - reqgen = scan.NewLiveRequestGenerator(reqgen, cliARPLiveTimeout) + if o.liveTimeout > 0 { + reqgen = scan.NewLiveRequestGenerator(reqgen, o.liveTimeout) } pktgen := scan.NewPacketMultiGenerator(arp.NewPacketFiller(), runtime.NumCPU()) psrc := scan.NewPacketSource(reqgen, pktgen) diff --git a/command/arp_test.go b/command/arp_test.go new file mode 100644 index 0000000..966c552 --- /dev/null +++ b/command/arp_test.go @@ -0,0 +1,21 @@ +package command + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestArpCmdDstSubnetRequiredArg(t *testing.T) { + cmd := newARPCmd().cmd + err := cmd.Execute() + require.Error(t, err) + require.Equal(t, "requires one ip subnet argument", err.Error()) +} + +func TestArpCmdInvalidDstSubnet(t *testing.T) { + cmd := newARPCmd().cmd + cmd.SetArgs([]string{"invalid_ip_address"}) + err := cmd.Execute() + require.Error(t, err) +} diff --git a/command/config.go b/command/config.go new file mode 100644 index 0000000..7122b66 --- /dev/null +++ b/command/config.go @@ -0,0 +1,510 @@ +package command + +import ( + "errors" + "io" + "io/ioutil" + "net" + "os" + "strconv" + "strings" + "time" + + "github.com/google/gopacket/layers" + "github.com/spf13/cobra" + "github.com/v-byte-cpu/sx/command/log" + "github.com/v-byte-cpu/sx/pkg/ip" + "github.com/v-byte-cpu/sx/pkg/scan" + "github.com/v-byte-cpu/sx/pkg/scan/arp" +) + +const ( + cliHTTPProtoFlag = "http" + cliHTTPSProtoFlag = "https" + + defaultWorkerCount = 100 + defaultTimeout = 5 * time.Second + defaultExitDelay = 300 * time.Millisecond +) + +var ( + errSrcIP = errors.New("invalid source IP") + errSrcMAC = errors.New("invalid source MAC") + errSrcInterface = errors.New("invalid source interface") + errRateLimit = errors.New("invalid ratelimit") + errStdin = errors.New("stdin is from a terminal") + errIPFlags = errors.New("invalid ip flags") +) + +type packetScanCmdOpts struct { + json bool + iface *net.Interface + srcIP net.IP + srcMAC net.HardwareAddr + rateCount int + rateWindow time.Duration + exitDelay time.Duration + + rawInterface string + rawSrcMAC string + rawRateLimit string +} + +// TODO test +func (o *packetScanCmdOpts) initCliFlags(cmd *cobra.Command) { + cmd.Flags().BoolVar(&o.json, "json", false, "enable JSON output") + cmd.Flags().StringVarP(&o.rawInterface, "iface", "i", "", "set interface to send/receive packets") + cmd.Flags().IPVar(&o.srcIP, "srcip", nil, "set source IP address for generated packets") + cmd.Flags().StringVar(&o.rawSrcMAC, "srcmac", "", "set source MAC address for generated packets") + cmd.Flags().StringVarP(&o.rawRateLimit, "rate", "r", "", + strings.Join([]string{ + "set rate limit for generated packets", + `format: "rateCount/rateWindow"`, + "where rateCount is a number of packets, rateWindow is the time interval", + "e.g. 1000/s -- 1000 packets per second", "500/7s -- 500 packets per 7 seconds\n"}, "\n")) + cmd.Flags().DurationVar(&o.exitDelay, "exit-delay", defaultExitDelay, + strings.Join([]string{ + "set exit delay to wait for last response packets", + "any expression accepted by time.ParseDuration is valid"}, "\n")) +} + +// TODO test +func (o *packetScanCmdOpts) parseRawOptions() (err error) { + if len(o.rawInterface) > 0 { + if o.iface, err = net.InterfaceByName(o.rawInterface); err != nil { + return + } + } + if len(o.rawSrcMAC) > 0 { + if o.srcMAC, err = net.ParseMAC(o.rawSrcMAC); err != nil { + return + } + } + if len(o.rawRateLimit) > 0 { + if o.rateCount, o.rateWindow, err = parseRateLimit(o.rawRateLimit); err != nil { + return + } + } + return +} + +func (o *packetScanCmdOpts) getScanRange(dstSubnet *net.IPNet) (*scan.Range, error) { + iface, srcIP, err := o.getInterface(dstSubnet) + if err != nil { + return nil, err + } + if iface == nil { + return nil, errSrcInterface + } + + if o.srcIP != nil { + srcIP = o.srcIP + } + if srcIP == nil { + return nil, errSrcIP + } + + srcMAC := iface.HardwareAddr + if o.srcMAC != nil { + srcMAC = o.srcMAC + } + if srcMAC == nil { + return nil, errSrcMAC + } + + return &scan.Range{ + Interface: iface, + DstSubnet: dstSubnet, + SrcIP: srcIP.To4(), + SrcMAC: srcMAC}, nil +} + +func (o *packetScanCmdOpts) getInterface(dstSubnet *net.IPNet) (iface *net.Interface, ifaceIP net.IP, err error) { + if dstSubnet != nil { + // try to find directly connected interface + if iface, ifaceIP, err = o.getLocalSubnetInterface(dstSubnet); err != nil { + return + } + // found local interface + if iface != nil && ifaceIP != nil { + return + } + } + if o.iface != nil { + // try to get first ip address + ifaceIP, err = ip.GetInterfaceIP(o.iface) + return o.iface, ifaceIP, err + } + // fallback to interface of default gateway + return ip.GetDefaultInterface() +} + +func (o *packetScanCmdOpts) getLocalSubnetInterface(dstSubnet *net.IPNet) (iface *net.Interface, ifaceIP net.IP, err error) { + if o.iface == nil { + return ip.GetLocalSubnetInterface(dstSubnet) + } + ifaceIP, err = ip.GetLocalSubnetInterfaceIP(o.iface, dstSubnet) + return o.iface, ifaceIP, err +} + +func (o *packetScanCmdOpts) getLogger(name string, w io.Writer) (logger log.Logger, err error) { + opts := []log.LoggerOption{log.FlushInterval(1 * time.Second)} + if o.json { + opts = append(opts, log.JSON()) + } + logger, err = log.NewLogger(w, name, opts...) + return +} + +type ipScanCmdOpts struct { + packetScanCmdOpts + ipFile string + arpCacheFile string + gatewayMAC net.HardwareAddr + + rawGatewayMAC string +} + +func (o *ipScanCmdOpts) initCliFlags(cmd *cobra.Command) { + o.packetScanCmdOpts.initCliFlags(cmd) + cmd.Flags().StringVar(&o.rawGatewayMAC, "gwmac", "", "set gateway MAC address to send generated packets to") + cmd.Flags().StringVarP(&o.ipFile, "file", "f", "", "set JSONL file with IPs to scan") + cmd.Flags().StringVarP(&o.arpCacheFile, "arp-cache", "a", "", + strings.Join([]string{"set ARP cache file", "reads from stdin by default"}, "\n")) +} + +func (o *ipScanCmdOpts) parseRawOptions() (err error) { + if err = o.packetScanCmdOpts.parseRawOptions(); err != nil { + return + } + if len(o.rawGatewayMAC) > 0 { + if o.gatewayMAC, err = net.ParseMAC(o.rawGatewayMAC); err != nil { + return + } + } + return +} + +type scanConfig struct { + logger log.Logger + scanRange *scan.Range + cache *arp.Cache + gatewayMAC net.HardwareAddr +} + +func (o *ipScanCmdOpts) parseScanConfig(scanName string, args []string) (c *scanConfig, err error) { + if err = o.validateStdin(); err != nil { + return + } + + dstSubnet, err := o.parseDstSubnet(args) + if err != nil { + return + } + var r *scan.Range + if r, err = o.getScanRange(dstSubnet); err != nil { + return + } + + var logger log.Logger + if logger, err = o.getLogger(scanName, os.Stdout); err != nil { + return + } + + var cache *arp.Cache + if cache, err = o.parseARPCache(); err != nil { + return + } + + var gatewayMAC net.HardwareAddr + if gatewayMAC, err = o.getGatewayMAC(r.Interface, cache); err != nil { + return + } + + c = &scanConfig{ + logger: logger, + scanRange: r, + cache: cache, + gatewayMAC: gatewayMAC, + } + return +} + +// TODO test +func (o *ipScanCmdOpts) validateStdin() (err error) { + if o.isARPCacheFromStdin() && o.ipFile == "-" { + return errors.New("ARP cache and IP file can not be read from stdin at the same time") + } + return +} + +// TODO test +func (o *ipScanCmdOpts) parseDstSubnet(args []string) (ipnet *net.IPNet, err error) { + if len(args) == 0 && len(o.ipFile) == 0 { + return nil, errors.New("requires one ip subnet argument or file with ip/port pairs") + } + if len(args) == 0 { + return + } + return ip.ParseIPNet(args[0]) +} + +func (o *ipScanCmdOpts) parseARPCache() (cache *arp.Cache, err error) { + var r io.ReadCloser + if r, err = o.openARPCache(); err != nil { + return + } + defer r.Close() + cache = arp.NewCache() + err = arp.FillCache(cache, r) + return +} + +func (o *ipScanCmdOpts) openARPCache() (r io.ReadCloser, err error) { + if !o.isARPCacheFromStdin() { + return os.Open(o.arpCacheFile) + } + // read from stdin + var info os.FileInfo + if info, err = os.Stdin.Stat(); err != nil { + return + } + // only data being piped to stdin is valid + if (info.Mode() & os.ModeCharDevice) != 0 { + // stdin from terminal is not valid + return nil, errStdin + } + r = io.NopCloser(os.Stdin) + return +} + +// TODO test +func (o *ipScanCmdOpts) isARPCacheFromStdin() bool { + return len(o.arpCacheFile) == 0 || o.arpCacheFile == "-" +} + +func (o *ipScanCmdOpts) getGatewayMAC(iface *net.Interface, cache *arp.Cache) (mac net.HardwareAddr, err error) { + if o.gatewayMAC != nil { + return o.gatewayMAC, nil + } + var gatewayIP net.IP + if gatewayIP, err = ip.GetDefaultGatewayIP(iface); err != nil { + return + } + mac = cache.Get(gatewayIP.To4()) + return +} + +type ipPortScanCmdOpts struct { + ipScanCmdOpts + portRanges []*scan.PortRange + + rawPortRanges string +} + +func (o *ipPortScanCmdOpts) initCliFlags(cmd *cobra.Command) { + o.ipScanCmdOpts.initCliFlags(cmd) + cmd.Flags().StringVarP(&o.rawPortRanges, "ports", "p", "", "set ports to scan") +} + +// TODO test +func (o *ipPortScanCmdOpts) parseRawOptions() (err error) { + if err = o.ipScanCmdOpts.parseRawOptions(); err != nil { + return + } + if len(o.rawPortRanges) > 0 { + if o.portRanges, err = parsePortRanges(o.rawPortRanges); err != nil { + return + } + } + return +} + +func (o *ipPortScanCmdOpts) parseScanConfig(scanName string, args []string) (c *scanConfig, err error) { + if c, err = o.ipScanCmdOpts.parseScanConfig(scanName, args); err != nil { + return + } + c.scanRange.Ports = o.portRanges + return +} + +func (o *ipPortScanCmdOpts) newIPPortGenerator() (reqgen scan.RequestGenerator) { + if len(o.ipFile) == 0 { + return scan.NewIPPortGenerator(scan.NewIPGenerator(), scan.NewPortGenerator()) + } + if len(o.portRanges) == 0 { + return scan.NewFileIPPortGenerator(func() (io.ReadCloser, error) { + return os.Open(o.ipFile) + }) + } + ipgen := scan.NewFileIPGenerator(func() (io.ReadCloser, error) { + if o.ipFile == "-" { + return ioutil.NopCloser(os.Stdin), nil + } + return os.Open(o.ipFile) + }) + return scan.NewIPPortGenerator(ipgen, scan.NewPortGenerator()) +} + +type genericScanCmdOpts struct { + json bool + ipFile string + portRanges []*scan.PortRange + workers int + exitDelay time.Duration + + rawPortRanges string +} + +// TODO test +func (o *genericScanCmdOpts) initCliFlags(cmd *cobra.Command) { + cmd.Flags().BoolVar(&o.json, "json", false, "enable JSON output") + cmd.Flags().StringVarP(&o.rawPortRanges, "ports", "p", "", "set ports to scan") + cmd.Flags().StringVarP(&o.ipFile, "file", "f", "", "set JSONL file with ip/port pairs to scan") + cmd.Flags().IntVarP(&o.workers, "workers", "w", defaultWorkerCount, "set workers count") + cmd.Flags().DurationVar(&o.exitDelay, "exit-delay", defaultExitDelay, + strings.Join([]string{ + "set exit delay to wait for last response", + "any expression accepted by time.ParseDuration is valid"}, "\n")) +} + +// TODO test +func (o *genericScanCmdOpts) parseRawOptions() (err error) { + if len(o.rawPortRanges) > 0 { + if o.portRanges, err = parsePortRanges(o.rawPortRanges); err != nil { + return + } + } + if o.workers <= 0 { + return errors.New("invalid workers count") + } + return +} + +// TODO test +func (o *genericScanCmdOpts) parseScanRange(args []string) (r *scan.Range, err error) { + dstSubnet, err := o.parseDstSubnet(args) + r = &scan.Range{ + DstSubnet: dstSubnet, + Ports: o.portRanges, + } + return +} + +// TODO test +func (o *genericScanCmdOpts) parseDstSubnet(args []string) (ipnet *net.IPNet, err error) { + if len(args) == 0 && len(o.ipFile) == 0 { + return nil, errors.New("requires one ip subnet argument or file with ip/port pairs") + } + if len(args) == 0 { + return + } + return ip.ParseIPNet(args[0]) +} + +func (o *genericScanCmdOpts) getLogger(name string, w io.Writer) (logger log.Logger, err error) { + opts := []log.LoggerOption{log.FlushInterval(1 * time.Second)} + if o.json { + opts = append(opts, log.JSON()) + } + logger, err = log.NewLogger(w, name, opts...) + return +} + +func (o *genericScanCmdOpts) newIPPortGenerator() (reqgen scan.RequestGenerator) { + if len(o.ipFile) == 0 { + return scan.NewIPPortGenerator(scan.NewIPGenerator(), scan.NewPortGenerator()) + } + if len(o.portRanges) == 0 { + return scan.NewFileIPPortGenerator(func() (io.ReadCloser, error) { + return os.Open(o.ipFile) + }) + } + ipgen := scan.NewFileIPGenerator(func() (io.ReadCloser, error) { + if o.ipFile == "-" { + return ioutil.NopCloser(os.Stdin), nil + } + return os.Open(o.ipFile) + }) + return scan.NewIPPortGenerator(ipgen, scan.NewPortGenerator()) +} + +func parsePortRange(portsRange string) (r *scan.PortRange, err error) { + ports := strings.Split(portsRange, "-") + var port uint64 + if port, err = strconv.ParseUint(ports[0], 10, 16); err != nil { + return + } + result := &scan.PortRange{StartPort: uint16(port), EndPort: uint16(port)} + if len(ports) < 2 { + return result, nil + } + if port, err = strconv.ParseUint(ports[1], 10, 16); err != nil { + return + } + result.EndPort = uint16(port) + return result, nil +} + +func parsePortRanges(portsRanges string) (result []*scan.PortRange, err error) { + var ports *scan.PortRange + for _, portsRange := range strings.Split(portsRanges, ",") { + if ports, err = parsePortRange(portsRange); err != nil { + return + } + result = append(result, ports) + } + return +} + +func parseRateLimit(rateLimit string) (rateCount int, rateWindow time.Duration, err error) { + parts := strings.Split(rateLimit, "/") + if len(parts) > 2 { + return 0, 0, errRateLimit + } + var rate int64 + if rate, err = strconv.ParseInt(parts[0], 10, 32); err != nil || rate < 0 { + return 0, 0, errRateLimit + } + rateCount = int(rate) + rateWindow = 1 * time.Second + if len(parts) < 2 { + return + } + win := parts[1] + if len(win) > 0 && (win[0] < '0' || win[0] > '9') { + win = "1" + win + } + if rateWindow, err = time.ParseDuration(win); err != nil || rateWindow < 0 { + return 0, 0, errRateLimit + } + return +} + +func parsePacketPayload(payload string) (result []byte, err error) { + var unquoted string + if unquoted, err = strconv.Unquote(`"` + payload + `"`); err != nil { + return + } + return []byte(unquoted), nil +} + +func parseIPFlags(inputFlags string) (result uint8, err error) { + if len(inputFlags) == 0 { + return + } + flags := strings.Split(strings.ToLower(inputFlags), ",") + for _, flag := range flags { + switch flag { + case "df": + result |= uint8(layers.IPv4DontFragment) + case "evil": + result |= uint8(layers.IPv4EvilBit) + case "mf": + result |= uint8(layers.IPv4MoreFragments) + default: + return 0, errIPFlags + } + } + return +} diff --git a/command/root_test.go b/command/config_test.go similarity index 100% rename from command/root_test.go rename to command/config_test.go diff --git a/command/docker.go b/command/docker.go index 84ae800..7744e7f 100644 --- a/command/docker.go +++ b/command/docker.go @@ -6,6 +6,7 @@ import ( "os" "os/signal" "strings" + "time" "github.com/spf13/cobra" "github.com/v-byte-cpu/sx/command/log" @@ -13,55 +14,80 @@ import ( "github.com/v-byte-cpu/sx/pkg/scan/docker" ) -func init() { - dockerCmd.Flags().StringVarP(&cliPortsFlag, "ports", "p", "", "set ports to scan") - dockerCmd.Flags().StringVarP(&cliIPPortFileFlag, "file", "f", "", "set JSONL file with ip/port pairs to scan") - dockerCmd.Flags().StringVar(&cliProtoFlag, "proto", "", "set protocol to use, http is used by default; only http or https are valid") - dockerCmd.Flags().IntVarP(&cliWorkerCountFlag, "workers", "w", defaultWorkerCount, "set workers count") - dockerCmd.Flags().DurationVarP(&cliTimeoutFlag, "timeout", "t", defaultTimeout, "set request timeout") - rootCmd.AddCommand(dockerCmd) +func newDockerCmd() *dockerCmd { + c := &dockerCmd{} + + cmd := &cobra.Command{ + Use: "docker [flags] [subnet]", + Example: strings.Join([]string{ + "docker -p 2375 192.168.0.1/24", "docker -p 2300-2500 10.0.0.1", + "docker --proto https -p 2300-2500 192.168.0.3", + "docker -f ip_ports_file.jsonl", "docker -p 9200-9300 -f ips_file.jsonl"}, "\n"), + Short: "Perform Docker scan", + RunE: func(cmd *cobra.Command, args []string) (err error) { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + + if err = c.opts.parseRawOptions(); err != nil { + return + } + scanRange, err := c.opts.parseScanRange(args) + if err != nil { + return + } + + var logger log.Logger + if logger, err = c.opts.getLogger(docker.ScanType, os.Stdout); err != nil { + return + } + + engine := c.opts.newDockerScanEngine(ctx) + return startScanEngine(ctx, engine, + newEngineConfig( + withLogger(logger), + withScanRange(scanRange), + withExitDelay(c.opts.exitDelay), + )) + }, + } + + c.opts.initCliFlags(cmd) + + c.cmd = cmd + return c } -var dockerCmd = &cobra.Command{ - Use: "docker [flags] [subnet]", - Example: strings.Join([]string{ - "docker -p 2375 192.168.0.1/24", "docker -p 2300-2500 10.0.0.1", - "docker --proto https -p 2300-2500 192.168.0.3", - "docker -f ip_ports_file.jsonl", "docker -p 9200-9300 -f ips_file.jsonl"}, "\n"), - Short: "Perform Docker scan", - PreRunE: func(cmd *cobra.Command, args []string) (err error) { - if len(cliProtoFlag) == 0 { - cliProtoFlag = cliHTTPProtoFlag - } - if cliProtoFlag != cliHTTPProtoFlag && cliProtoFlag != cliHTTPSProtoFlag { - return errors.New("invalid HTTP proto flag: http or https required") - } - cliDstSubnet, err = parseDstSubnet(args) - return - }, - RunE: func(cmd *cobra.Command, args []string) (err error) { - ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) - defer cancel() +type dockerCmd struct { + cmd *cobra.Command + opts dockerCmdOpts +} + +type dockerCmdOpts struct { + genericScanCmdOpts + timeout time.Duration + proto string +} - var logger log.Logger - if logger, err = getLogger("docker", os.Stdout); err != nil { - return - } +// TODO test +func (o *dockerCmdOpts) initCliFlags(cmd *cobra.Command) { + o.genericScanCmdOpts.initCliFlags(cmd) + cmd.Flags().DurationVarP(&o.timeout, "timeout", "t", defaultTimeout, "set request timeout") + cmd.Flags().StringVar(&o.proto, "proto", cliHTTPProtoFlag, "set protocol to use, only http or https are valid") +} - engine := newDockerScanEngine(ctx) - return startScanEngine(ctx, engine, - newEngineConfig( - withLogger(logger), - withScanRange(&scan.Range{ - DstSubnet: cliDstSubnet, - Ports: cliPortRanges, - }), - )) - }, +// TODO test +func (o *dockerCmdOpts) parseRawOptions() (err error) { + if err = o.genericScanCmdOpts.parseRawOptions(); err != nil { + return + } + if o.proto != cliHTTPProtoFlag && o.proto != cliHTTPSProtoFlag { + return errors.New("invalid HTTP proto flag: http or https required") + } + return } -func newDockerScanEngine(ctx context.Context) scan.EngineResulter { - scanner := docker.NewScanner(cliProtoFlag, docker.WithDataTimeout(cliTimeoutFlag)) +func (o *dockerCmdOpts) newDockerScanEngine(ctx context.Context) scan.EngineResulter { + scanner := docker.NewScanner(o.proto, docker.WithDataTimeout(o.timeout)) results := scan.NewResultChan(ctx, 1000) - return scan.NewScanEngine(newIPPortGenerator(), scanner, results, scan.WithScanWorkerCount(cliWorkerCountFlag)) + return scan.NewScanEngine(o.newIPPortGenerator(), scanner, results, scan.WithScanWorkerCount(o.workers)) } diff --git a/command/elastic.go b/command/elastic.go index 01de8d8..d2cfd6e 100644 --- a/command/elastic.go +++ b/command/elastic.go @@ -6,6 +6,7 @@ import ( "os" "os/signal" "strings" + "time" "github.com/spf13/cobra" "github.com/v-byte-cpu/sx/command/log" @@ -13,55 +14,80 @@ import ( "github.com/v-byte-cpu/sx/pkg/scan/elastic" ) -func init() { - elasticCmd.Flags().StringVarP(&cliPortsFlag, "ports", "p", "", "set ports to scan") - elasticCmd.Flags().StringVarP(&cliIPPortFileFlag, "file", "f", "", "set JSONL file with ip/port pairs to scan") - elasticCmd.Flags().StringVar(&cliProtoFlag, "proto", "", "set protocol to use, http is used by default; only http or https are valid") - elasticCmd.Flags().IntVarP(&cliWorkerCountFlag, "workers", "w", defaultWorkerCount, "set workers count") - elasticCmd.Flags().DurationVarP(&cliTimeoutFlag, "timeout", "t", defaultTimeout, "set request timeout") - rootCmd.AddCommand(elasticCmd) +func newElasticCmd() *elasticCmd { + c := &elasticCmd{} + + cmd := &cobra.Command{ + Use: "elastic [flags] [subnet]", + Example: strings.Join([]string{ + "elastic -p 9200 192.168.0.1/24", "elastic -p 9200-9300 10.0.0.1", + "elastic --proto https -p 9200-9201 192.168.0.3", + "elastic -f ip_ports_file.jsonl", "elastic -p 9200-9300 -f ips_file.jsonl"}, "\n"), + Short: "Perform Elasticsearch scan", + RunE: func(cmd *cobra.Command, args []string) (err error) { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + + if err = c.opts.parseRawOptions(); err != nil { + return + } + scanRange, err := c.opts.parseScanRange(args) + if err != nil { + return + } + + var logger log.Logger + if logger, err = c.opts.getLogger(elastic.ScanType, os.Stdout); err != nil { + return + } + + engine := c.opts.newElasticScanEngine(ctx) + return startScanEngine(ctx, engine, + newEngineConfig( + withLogger(logger), + withScanRange(scanRange), + withExitDelay(c.opts.exitDelay), + )) + }, + } + + c.opts.initCliFlags(cmd) + + c.cmd = cmd + return c } -var elasticCmd = &cobra.Command{ - Use: "elastic [flags] [subnet]", - Example: strings.Join([]string{ - "elastic -p 9200 192.168.0.1/24", "elastic -p 9200-9300 10.0.0.1", - "elastic --proto https -p 9200-9201 192.168.0.3", - "elastic -f ip_ports_file.jsonl", "elastic -p 9200-9300 -f ips_file.jsonl"}, "\n"), - Short: "Perform Elasticsearch scan", - PreRunE: func(cmd *cobra.Command, args []string) (err error) { - if len(cliProtoFlag) == 0 { - cliProtoFlag = cliHTTPProtoFlag - } - if cliProtoFlag != cliHTTPProtoFlag && cliProtoFlag != cliHTTPSProtoFlag { - return errors.New("invalid HTTP proto flag: http or https required") - } - cliDstSubnet, err = parseDstSubnet(args) - return - }, - RunE: func(cmd *cobra.Command, args []string) (err error) { - ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) - defer cancel() +type elasticCmd struct { + cmd *cobra.Command + opts elasticCmdOpts +} + +type elasticCmdOpts struct { + genericScanCmdOpts + timeout time.Duration + proto string +} - var logger log.Logger - if logger, err = getLogger("elastic", os.Stdout); err != nil { - return - } +// TODO test +func (o *elasticCmdOpts) initCliFlags(cmd *cobra.Command) { + o.genericScanCmdOpts.initCliFlags(cmd) + cmd.Flags().DurationVarP(&o.timeout, "timeout", "t", defaultTimeout, "set request timeout") + cmd.Flags().StringVar(&o.proto, "proto", cliHTTPProtoFlag, "set protocol to use, only http or https are valid") +} - engine := newElasticScanEngine(ctx) - return startScanEngine(ctx, engine, - newEngineConfig( - withLogger(logger), - withScanRange(&scan.Range{ - DstSubnet: cliDstSubnet, - Ports: cliPortRanges, - }), - )) - }, +// TODO test +func (o *elasticCmdOpts) parseRawOptions() (err error) { + if err = o.genericScanCmdOpts.parseRawOptions(); err != nil { + return + } + if o.proto != cliHTTPProtoFlag && o.proto != cliHTTPSProtoFlag { + return errors.New("invalid HTTP proto flag: http or https required") + } + return } -func newElasticScanEngine(ctx context.Context) scan.EngineResulter { - scanner := elastic.NewScanner(cliProtoFlag, elastic.WithDataTimeout(cliTimeoutFlag)) +func (o *elasticCmdOpts) newElasticScanEngine(ctx context.Context) scan.EngineResulter { + scanner := elastic.NewScanner(o.proto, elastic.WithDataTimeout(o.timeout)) results := scan.NewResultChan(ctx, 1000) - return scan.NewScanEngine(newIPPortGenerator(), scanner, results, scan.WithScanWorkerCount(cliWorkerCountFlag)) + return scan.NewScanEngine(o.newIPPortGenerator(), scanner, results, scan.WithScanWorkerCount(o.workers)) } diff --git a/command/icmp.go b/command/icmp.go index 231549e..0d9f34e 100644 --- a/command/icmp.go +++ b/command/icmp.go @@ -6,7 +6,6 @@ import ( "os" "os/signal" "runtime" - "strconv" "strings" "github.com/spf13/cobra" @@ -15,134 +14,132 @@ import ( "github.com/v-byte-cpu/sx/pkg/scan/icmp" ) -var ( - cliICMPTypeFlag string - cliICMPCodeFlag string - cliICMPPayloadFlag string +func newICMPCmd() *icmpCmd { + c := &icmpCmd{} - cliICMPType uint8 - cliICMPCode uint8 - cliICMPPayload []byte -) + cmd := &cobra.Command{ + Use: "icmp [flags] subnet", + Example: strings.Join([]string{ + "icmp 192.168.0.1/24", + "icmp --ttl 37 192.168.0.1/24", + "icmp --ipproto 157 192.168.0.1/24", + `icmp --type 13 --code 0 --payload '\x01\x02\x03' 10.0.0.1`}, "\n"), + Short: "Perform ICMP scan", + RunE: func(cmd *cobra.Command, args []string) (err error) { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + + if err = c.opts.parseRawOptions(); err != nil { + return + } + var conf *scanConfig + if conf, err = c.opts.parseScanConfig(icmp.ScanType, args); err != nil { + return + } + + m := c.opts.newICMPScanMethod(ctx, conf) + + return startPacketScanEngine(ctx, newPacketScanConfig( + withPacketScanMethod(m), + withPacketBPFFilter(icmp.BPFFilter), + withRateCount(c.opts.rateCount), + withRateWindow(c.opts.rateWindow), + withPacketEngineConfig(newEngineConfig( + withLogger(conf.logger), + withScanRange(conf.scanRange), + withExitDelay(c.opts.exitDelay), + )), + )) + }, + } + + c.opts.initCliFlags(cmd) + + c.cmd = cmd + return c +} -func init() { - addPacketScanOptions(icmpCmd) - icmpCmd.Flags().StringVarP(&cliIPPortFileFlag, "file", "f", "", "set JSONL file with IPs to scan") - icmpCmd.Flags().StringVar(&cliIPTTLFlag, "ttl", "", - strings.Join([]string{"set IP TTL field of generated packet", "64 by default"}, "\n")) - icmpCmd.Flags().StringVar(&cliIPTotalLenFlag, "iplen", "", +type icmpCmd struct { + cmd *cobra.Command + opts icmpCmdOpts +} + +type icmpCmdOpts struct { + ipScanCmdOpts + ipTTL uint8 + ipFlags uint8 + ipProtocol uint8 + ipTotalLen uint16 + + icmpType uint8 + icmpCode uint8 + icmpPayload []byte + + rawIPFlags string + rawICMPPayload string +} + +// TODO test +func (o *icmpCmdOpts) initCliFlags(cmd *cobra.Command) { + o.ipScanCmdOpts.initCliFlags(cmd) + cmd.Flags().Uint8Var(&o.ipTTL, "ttl", 64, "set IP TTL field of generated packet") + cmd.Flags().Uint8Var(&o.ipProtocol, "ipproto", 1, + strings.Join([]string{"set IP Protocol field of generated packet", "ICMP by default"}, "\n")) + cmd.Flags().StringVar(&o.rawIPFlags, "ipflags", "DF", "set IP Flags field of generated packet") + cmd.Flags().Uint16Var(&o.ipTotalLen, "iplen", 0, strings.Join([]string{"set IP Total Length field of generated packet", "calculated by default"}, "\n")) - icmpCmd.Flags().StringVar(&cliIPProtocolFlag, "ipproto", "", - strings.Join([]string{"set IP Protocol field of generated packet", "1 (ICMP) by default"}, "\n")) - icmpCmd.Flags().StringVar(&cliIPFlagsFlag, "ipflags", "", - strings.Join([]string{"set IP Flags field of generated packet", "DF by default"}, "\n")) - icmpCmd.Flags().StringVarP(&cliICMPTypeFlag, "type", "t", "", + cmd.Flags().Uint8VarP(&o.icmpType, "type", "t", 8, strings.Join([]string{"set ICMP type of generated packet", "ICMP Echo (Type 8) by default"}, "\n")) - icmpCmd.Flags().StringVarP(&cliICMPCodeFlag, "code", "c", "", - strings.Join([]string{"set ICMP code of generated packet", "0 by default"}, "\n")) - icmpCmd.Flags().StringVarP(&cliICMPPayloadFlag, "payload", "p", "", + cmd.Flags().Uint8VarP(&o.icmpCode, "code", "c", 0, "set ICMP code of generated packet") + cmd.Flags().StringVarP(&o.rawICMPPayload, "payload", "p", "", strings.Join([]string{"set byte payload of generated packet", "48 random bytes by default"}, "\n")) - - icmpCmd.Flags().StringVarP(&cliARPCacheFileFlag, "arp-cache", "a", "", - strings.Join([]string{"set ARP cache file", "reads from stdin by default"}, "\n")) - rootCmd.AddCommand(icmpCmd) } -var icmpCmd = &cobra.Command{ - Use: "icmp [flags] subnet", - Example: strings.Join([]string{ - "icmp 192.168.0.1/24", - "icmp --ttl 37 192.168.0.1/24", - "icmp --ipproto 157 192.168.0.1/24", - `icmp --type 13 --code 0 --payload '\x01\x02\x03' 10.0.0.1`}, "\n"), - Short: "Perform ICMP scan", - PreRunE: func(cmd *cobra.Command, args []string) (err error) { - if cliDstSubnet, err = parseDstSubnet(args); err != nil { - return - } - if err = validatePacketScanStdin(); err != nil { +// TODO test +func (o *icmpCmdOpts) parseRawOptions() (err error) { + if err = o.ipScanCmdOpts.parseRawOptions(); err != nil { + return + } + if len(o.rawIPFlags) > 0 { + if o.ipFlags, err = parseIPFlags(o.rawIPFlags); err != nil { return } - var icmpType uint64 - if len(cliICMPTypeFlag) > 0 { - if icmpType, err = strconv.ParseUint(cliICMPTypeFlag, 10, 8); err != nil { - return - } - cliICMPType = uint8(icmpType) - } - var icmpCode uint64 - if len(cliICMPCodeFlag) > 0 { - if icmpCode, err = strconv.ParseUint(cliICMPCodeFlag, 10, 8); err != nil { - return - } - cliICMPCode = uint8(icmpCode) - } - if len(cliICMPPayloadFlag) > 0 { - if cliICMPPayload, err = parsePacketPayload(cliICMPPayloadFlag); err != nil { - return - } - } - return - }, - RunE: func(cmd *cobra.Command, args []string) (err error) { - ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) - defer cancel() - - var conf *scanConfig - if conf, err = parseScanConfig(icmp.ScanType, cliDstSubnet); err != nil { + } + if len(o.rawICMPPayload) > 0 { + if o.icmpPayload, err = parsePacketPayload(o.rawICMPPayload); err != nil { return } - - m := newICMPScanMethod(ctx, conf) - - return startPacketScanEngine(ctx, newPacketScanConfig( - withPacketScanMethod(m), - withPacketBPFFilter(icmp.BPFFilter), - withPacketEngineConfig(newEngineConfig( - withLogger(conf.logger), - withScanRange(conf.scanRange), - )), - )) - }, + } + return } -func newICMPScanMethod(ctx context.Context, conf *scanConfig) *icmp.ScanMethod { +func (o *icmpCmdOpts) newICMPScanMethod(ctx context.Context, conf *scanConfig) *icmp.ScanMethod { ipgen := scan.NewIPGenerator() - if len(cliIPPortFileFlag) > 0 { + if len(o.ipFile) > 0 { ipgen = scan.NewFileIPGenerator(func() (io.ReadCloser, error) { - return os.Open(cliIPPortFileFlag) + return os.Open(o.ipFile) }) } reqgen := arp.NewCacheRequestGenerator( scan.NewIPRequestGenerator(ipgen), conf.gatewayMAC, conf.cache) - pktgen := scan.NewPacketMultiGenerator(icmp.NewPacketFiller(getICMPOptions()...), runtime.NumCPU()) + pktgen := scan.NewPacketMultiGenerator(icmp.NewPacketFiller(o.getICMPOptions()...), runtime.NumCPU()) psrc := scan.NewPacketSource(reqgen, pktgen) results := scan.NewResultChan(ctx, 1000) return icmp.NewScanMethod(psrc, results) } -func getICMPOptions() (opts []icmp.PacketFillerOption) { - if len(cliIPTTLFlag) > 0 { - opts = append(opts, icmp.WithTTL(cliTTL)) - } - if len(cliIPTotalLenFlag) > 0 { - opts = append(opts, icmp.WithIPTotalLength(cliIPTotalLen)) - } - if len(cliIPProtocolFlag) > 0 { - opts = append(opts, icmp.WithIPProtocol(cliIPProtocol)) - } - if len(cliIPFlagsFlag) > 0 { - opts = append(opts, icmp.WithIPFlags(cliIPFlags)) - } - if len(cliICMPTypeFlag) > 0 { - opts = append(opts, icmp.WithType(cliICMPType)) - } - if len(cliICMPCodeFlag) > 0 { - opts = append(opts, icmp.WithCode(cliICMPCode)) - } - if len(cliICMPPayloadFlag) > 0 { - opts = append(opts, icmp.WithPayload(cliICMPPayload)) +func (o *icmpCmdOpts) getICMPOptions() (opts []icmp.PacketFillerOption) { + opts = append(opts, + icmp.WithTTL(o.ipTTL), + icmp.WithIPProtocol(o.ipProtocol), + icmp.WithIPFlags(o.ipFlags), + icmp.WithIPTotalLength(o.ipTotalLen), + icmp.WithType(o.icmpType), + icmp.WithCode(o.icmpCode)) + + if len(o.icmpPayload) > 0 { + opts = append(opts, icmp.WithPayload(o.icmpPayload)) } return } diff --git a/command/root.go b/command/root.go index ed6db5f..bc26219 100644 --- a/command/root.go +++ b/command/root.go @@ -1,470 +1,59 @@ package command import ( - "bufio" "context" - "errors" - "io" - "io/ioutil" - "net" "os" - "strconv" - "strings" "sync" "time" - "github.com/google/gopacket/layers" "github.com/spf13/cobra" "github.com/v-byte-cpu/sx/command/log" - "github.com/v-byte-cpu/sx/pkg/ip" "github.com/v-byte-cpu/sx/pkg/packet" "github.com/v-byte-cpu/sx/pkg/packet/afpacket" "github.com/v-byte-cpu/sx/pkg/scan" - "github.com/v-byte-cpu/sx/pkg/scan/arp" "go.uber.org/ratelimit" ) -var rootCmd = &cobra.Command{ - Use: "sx", - Short: "Fast, modern, easy-to-use network scanner", - // Parse common flags - PersistentPreRunE: func(cmd *cobra.Command, args []string) (err error) { - if len(cliInterfaceFlag) > 0 { - if cliInterface, err = net.InterfaceByName(cliInterfaceFlag); err != nil { - return - } - } - if len(cliSrcIPFlag) > 0 { - if cliSrcIP = net.ParseIP(cliSrcIPFlag); cliSrcIP == nil { - return errSrcIP - } - } - if len(cliSrcMACFlag) > 0 { - if cliSrcMAC, err = net.ParseMAC(cliSrcMACFlag); err != nil { - return - } - } - if len(cliGatewayMACFlag) > 0 { - if cliGatewayMAC, err = net.ParseMAC(cliGatewayMACFlag); err != nil { - return - } - } - if len(cliPortsFlag) > 0 { - if cliPortRanges, err = parsePortRanges(cliPortsFlag); err != nil { - return - } - } - if len(cliRateLimitFlag) > 0 { - if cliRateCount, cliRateWindow, err = parseRateLimit(cliRateLimitFlag); err != nil { - return - } - } - if len(cliExitDelayFlag) > 0 { - if cliExitDelay, err = time.ParseDuration(cliExitDelayFlag); err != nil { - return - } - } - var ttl uint64 - if len(cliIPTTLFlag) > 0 { - if ttl, err = strconv.ParseUint(cliIPTTLFlag, 10, 8); err != nil { - return - } - cliTTL = uint8(ttl) - } - var ipLen uint64 - if len(cliIPTotalLenFlag) > 0 { - if ipLen, err = strconv.ParseUint(cliIPTotalLenFlag, 10, 16); err != nil { - return - } - cliIPTotalLen = uint16(ipLen) - } - var ipProto uint64 - if len(cliIPProtocolFlag) > 0 { - if ipProto, err = strconv.ParseUint(cliIPProtocolFlag, 10, 8); err != nil { - return - } - cliIPProtocol = uint8(ipProto) - } - if len(cliIPFlagsFlag) > 0 { - if cliIPFlags, err = parseIPFlags(cliIPFlagsFlag); err != nil { - return - } - } - if cliWorkerCountFlag <= 0 { - return errors.New("invalid workers count") - } - return - }, -} - -var ( - cliJSONFlag bool - cliInterfaceFlag string - cliSrcIPFlag string - cliSrcMACFlag string - cliGatewayMACFlag string - cliPortsFlag string - cliRateLimitFlag string - cliExitDelayFlag string - cliARPCacheFileFlag string - cliIPPortFileFlag string - cliProtoFlag string - cliIPTTLFlag string - cliIPTotalLenFlag string - cliIPProtocolFlag string - cliIPFlagsFlag string - cliWorkerCountFlag int - cliTimeoutFlag time.Duration - - cliInterface *net.Interface - cliSrcIP net.IP - cliSrcMAC net.HardwareAddr - cliGatewayMAC net.HardwareAddr - cliPortRanges []*scan.PortRange - cliDstSubnet *net.IPNet - cliRateCount int - cliRateWindow time.Duration - cliExitDelay = 300 * time.Millisecond - cliIPTotalLen uint16 - cliIPProtocol uint8 - cliIPFlags uint8 - cliTTL uint8 -) - -const ( - cliHTTPProtoFlag = "http" - cliHTTPSProtoFlag = "https" - - defaultWorkerCount = 100 - defaultTimeout = 5 * time.Second -) - -var ( - errSrcIP = errors.New("invalid source IP") - errSrcMAC = errors.New("invalid source MAC") - errSrcInterface = errors.New("invalid source interface") - errRateLimit = errors.New("invalid ratelimit") - errStdin = errors.New("stdin is from a terminal") - errIPFlags = errors.New("invalid ip flags") -) - -func init() { - rootCmd.PersistentFlags().BoolVar(&cliJSONFlag, "json", false, "enable JSON output") -} - -type cliPacketScanConfig struct { - gatewayMAC bool -} - -type cliPacketScanOption func(c *cliPacketScanConfig) - -func withoutGatewayMAC() cliPacketScanOption { - return func(c *cliPacketScanConfig) { - c.gatewayMAC = false - } -} - -func addPacketScanOptions(cmd *cobra.Command, opts ...cliPacketScanOption) { - conf := &cliPacketScanConfig{gatewayMAC: true} - for _, o := range opts { - o(conf) - } - cmd.PersistentFlags().StringVarP(&cliInterfaceFlag, "iface", "i", "", "set interface to send/receive packets") - cmd.PersistentFlags().StringVar(&cliSrcIPFlag, "srcip", "", "set source IP address for generated packets") - cmd.PersistentFlags().StringVar(&cliSrcMACFlag, "srcmac", "", "set source MAC address for generated packets") - if conf.gatewayMAC { - cmd.PersistentFlags().StringVar(&cliGatewayMACFlag, "gwmac", "", "set gateway MAC address to send generated packets to") - } - cmd.PersistentFlags().StringVarP(&cliRateLimitFlag, "rate", "r", "", - strings.Join([]string{ - "set rate limit for generated packets", - `format: "rateCount/rateWindow"`, - "where rateCount is a number of packets, rateWindow is the time interval", - "e.g. 1000/s -- 1000 packets per second", "500/7s -- 500 packets per 7 seconds\n"}, "\n")) - cmd.PersistentFlags().StringVar(&cliExitDelayFlag, "exit-delay", "", - strings.Join([]string{ - "set exit delay to wait for response packets", - "any expression accepted by time.ParseDuration is valid (300ms by default)"}, "\n")) -} - -func validatePacketScanStdin() (err error) { - if isARPCacheFromStdin() && cliIPPortFileFlag == "-" { - return errors.New("ARP cache and IP file can not be read from stdin at the same time") - } - return -} - func Main(version string) { - rootCmd.Version = version - if err := rootCmd.Execute(); err != nil { + if err := newRootCmd(version).Execute(); err != nil { os.Exit(1) } } -type scanConfig struct { - logger log.Logger - scanRange *scan.Range - cache *arp.Cache - gatewayMAC net.HardwareAddr -} - -func parseScanConfig(scanName string, dstSubnet *net.IPNet) (c *scanConfig, err error) { - var r *scan.Range - if r, err = getScanRange(dstSubnet); err != nil { - return - } - - var logger log.Logger - if logger, err = getLogger(scanName, os.Stdout); err != nil { - return - } - - var cache *arp.Cache - if cache, err = parseARPCache(); err != nil { - return - } - - var gatewayMAC net.HardwareAddr - if gatewayMAC, err = getGatewayMAC(r.Interface, cache); err != nil { - return - } - - c = &scanConfig{ - logger: logger, - scanRange: r, - cache: cache, - gatewayMAC: gatewayMAC, - } - return -} - -func parseDstSubnet(args []string) (ipnet *net.IPNet, err error) { - if len(args) == 0 && len(cliIPPortFileFlag) == 0 { - return nil, errors.New("requires one ip subnet argument or file with ip/port pairs") - } - if len(args) == 0 { - return - } - return ip.ParseIPNet(args[0]) -} - -func isARPCacheFromStdin() bool { - return len(cliARPCacheFileFlag) == 0 || cliARPCacheFileFlag == "-" -} - -func parseARPCache() (cache *arp.Cache, err error) { - var r io.Reader - if isARPCacheFromStdin() { - var info os.FileInfo - if info, err = os.Stdin.Stat(); err != nil { - return - } - // only data being piped to stdin is valid - if (info.Mode() & os.ModeCharDevice) != 0 { - // stdin from terminal is not valid - return nil, errStdin - } - r = os.Stdin - } else { - var f *os.File - if f, err = os.Open(cliARPCacheFileFlag); err != nil { - return - } - defer f.Close() - r = bufio.NewReader(f) - } - cache = arp.NewCache() - err = arp.FillCache(cache, r) - return -} - -func getScanRange(dstSubnet *net.IPNet) (*scan.Range, error) { - iface, srcIP, err := getInterface(dstSubnet) - if err != nil { - return nil, err - } - if iface == nil { - return nil, errSrcInterface - } - - if cliSrcIP != nil { - srcIP = cliSrcIP - } - if srcIP == nil { - return nil, errSrcIP - } - - srcMAC := iface.HardwareAddr - if cliSrcMAC != nil { - srcMAC = cliSrcMAC - } - if srcMAC == nil { - return nil, errSrcMAC - } - - return &scan.Range{ - Interface: iface, - DstSubnet: dstSubnet, - Ports: cliPortRanges, - SrcIP: srcIP.To4(), - SrcMAC: srcMAC}, nil -} - -func parsePortRange(portsRange string) (r *scan.PortRange, err error) { - ports := strings.Split(portsRange, "-") - var port uint64 - if port, err = strconv.ParseUint(ports[0], 10, 16); err != nil { - return - } - result := &scan.PortRange{StartPort: uint16(port), EndPort: uint16(port)} - if len(ports) < 2 { - return result, nil +func newRootCmd(version string) *cobra.Command { + cmd := &cobra.Command{ + Use: "sx", + Short: "Fast, modern, easy-to-use network scanner", + Version: version, } - if port, err = strconv.ParseUint(ports[1], 10, 16); err != nil { - return - } - result.EndPort = uint16(port) - return result, nil -} -func parsePortRanges(portsRanges string) (result []*scan.PortRange, err error) { - var ports *scan.PortRange - for _, portsRange := range strings.Split(portsRanges, ",") { - if ports, err = parsePortRange(portsRange); err != nil { - return - } - result = append(result, ports) - } - return -} + tcpCmd := newTCPFlagsCmd().cmd + tcpCmd.AddCommand( + newTCPSYNCmd().cmd, + newTCPFINCmd().cmd, + newTCPNULLCmd().cmd, + newTCPXmasCmd().cmd, + ) -func parseRateLimit(rateLimit string) (rateCount int, rateWindow time.Duration, err error) { - parts := strings.Split(rateLimit, "/") - if len(parts) > 2 { - return 0, 0, errRateLimit - } - var rate int64 - if rate, err = strconv.ParseInt(parts[0], 10, 32); err != nil || rate < 0 { - return 0, 0, errRateLimit - } - rateCount = int(rate) - rateWindow = 1 * time.Second - if len(parts) < 2 { - return - } - win := parts[1] - if len(win) > 0 && (win[0] < '0' || win[0] > '9') { - win = "1" + win - } - if rateWindow, err = time.ParseDuration(win); err != nil || rateWindow < 0 { - return 0, 0, errRateLimit - } - return -} - -func parsePacketPayload(payload string) (result []byte, err error) { - var unquoted string - if unquoted, err = strconv.Unquote(`"` + payload + `"`); err != nil { - return - } - return []byte(unquoted), nil -} + cmd.AddCommand( + newARPCmd().cmd, + newICMPCmd().cmd, + newUDPCmd().cmd, + tcpCmd, + newSocksCmd().cmd, + newDockerCmd().cmd, + newElasticCmd().cmd, + ) -func parseIPFlags(inputFlags string) (result uint8, err error) { - if len(inputFlags) == 0 { - return - } - flags := strings.Split(strings.ToLower(inputFlags), ",") - for _, flag := range flags { - switch flag { - case "df": - result |= uint8(layers.IPv4DontFragment) - case "evil": - result |= uint8(layers.IPv4EvilBit) - case "mf": - result |= uint8(layers.IPv4MoreFragments) - default: - return 0, errIPFlags - } - } - return -} - -func getLocalSubnetInterface(dstSubnet *net.IPNet) (iface *net.Interface, ifaceIP net.IP, err error) { - if cliInterface == nil { - return ip.GetLocalSubnetInterface(dstSubnet) - } - ifaceIP, err = ip.GetLocalSubnetInterfaceIP(cliInterface, dstSubnet) - return cliInterface, ifaceIP, err -} - -func getInterface(dstSubnet *net.IPNet) (iface *net.Interface, ifaceIP net.IP, err error) { - if dstSubnet != nil { - // try to find directly connected interface - if iface, ifaceIP, err = getLocalSubnetInterface(dstSubnet); err != nil { - return - } - // found local interface - if iface != nil && ifaceIP != nil { - return - } - } - if cliInterface != nil { - // try to get first ip address - ifaceIP, err = ip.GetInterfaceIP(cliInterface) - return cliInterface, ifaceIP, err - } - // fallback to interface of default gateway - return ip.GetDefaultInterface() -} - -func getGatewayMAC(iface *net.Interface, cache *arp.Cache) (mac net.HardwareAddr, err error) { - if cliGatewayMAC != nil { - return cliGatewayMAC, nil - } - var gatewayIP net.IP - if gatewayIP, err = ip.GetDefaultGatewayIP(iface); err != nil { - return - } - mac = cache.Get(gatewayIP.To4()) - return -} - -func getLogger(name string, w io.Writer) (logger log.Logger, err error) { - opts := []log.LoggerOption{log.FlushInterval(1 * time.Second)} - if cliJSONFlag { - opts = append(opts, log.JSON()) - } - logger, err = log.NewLogger(w, name, opts...) - return -} - -func newIPPortGenerator() (reqgen scan.RequestGenerator) { - if len(cliIPPortFileFlag) == 0 { - return scan.NewIPPortGenerator(scan.NewIPGenerator(), scan.NewPortGenerator()) - } - if len(cliPortRanges) == 0 { - return scan.NewFileIPPortGenerator(func() (io.ReadCloser, error) { - return os.Open(cliIPPortFileFlag) - }) - } - ipgen := scan.NewFileIPGenerator(func() (io.ReadCloser, error) { - if cliIPPortFileFlag == "-" { - return ioutil.NopCloser(os.Stdin), nil - } - return os.Open(cliIPPortFileFlag) - }) - return scan.NewIPPortGenerator(ipgen, scan.NewPortGenerator()) + return cmd } type bpfFilterFunc func(r *scan.Range) (filter string, maxPacketLength int) type engineConfig struct { - logger log.Logger - scanRange *scan.Range - rateCount int - rateWindow time.Duration - exitDelay time.Duration + logger log.Logger + scanRange *scan.Range + exitDelay time.Duration } type engineConfigOption func(c *engineConfig) @@ -481,11 +70,15 @@ func withScanRange(r *scan.Range) engineConfigOption { } } +func withExitDelay(exitDelay time.Duration) engineConfigOption { + return func(c *engineConfig) { + c.exitDelay = exitDelay + } +} + func newEngineConfig(opts ...engineConfigOption) *engineConfig { c := &engineConfig{ - rateCount: cliRateCount, - rateWindow: cliRateWindow, - exitDelay: cliExitDelay, + exitDelay: defaultExitDelay, } for _, o := range opts { o(c) @@ -497,6 +90,8 @@ type packetScanConfig struct { *engineConfig scanMethod scan.PacketMethod bpfFilter bpfFilterFunc + rateCount int + rateWindow time.Duration } type packetScanConfigOption func(c *packetScanConfig) @@ -519,6 +114,18 @@ func withPacketBPFFilter(bpfFilter bpfFilterFunc) packetScanConfigOption { } } +func withRateCount(rateCount int) packetScanConfigOption { + return func(c *packetScanConfig) { + c.rateCount = rateCount + } +} + +func withRateWindow(rateWindow time.Duration) packetScanConfigOption { + return func(c *packetScanConfig) { + c.rateWindow = rateWindow + } +} + func newPacketScanConfig(opts ...packetScanConfigOption) *packetScanConfig { c := &packetScanConfig{} for _, o := range opts { @@ -531,19 +138,19 @@ func startPacketScanEngine(ctx context.Context, conf *packetScanConfig) error { r := conf.scanRange // setup network interface to read/write packets - afps, err := afpacket.NewPacketSource(r.Interface.Name) + ps, err := afpacket.NewPacketSource(r.Interface.Name) if err != nil { return err } - defer afps.Close() - err = afps.SetBPFFilter(conf.bpfFilter(r)) + defer ps.Close() + err = ps.SetBPFFilter(conf.bpfFilter(r)) if err != nil { return err } - var rw packet.ReadWriter = afps + var rw packet.ReadWriter = ps // setup rate limit for sending packets if conf.rateCount > 0 { - rw = packet.NewRateLimitReadWriter(afps, + rw = packet.NewRateLimitReadWriter(ps, ratelimit.New(conf.rateCount, ratelimit.Per(conf.rateWindow))) } engine := scan.SetupPacketEngine(rw, conf.scanMethod) diff --git a/command/socks.go b/command/socks.go index 9a13a3b..3e75a2b 100644 --- a/command/socks.go +++ b/command/socks.go @@ -13,50 +13,68 @@ import ( "github.com/v-byte-cpu/sx/pkg/scan/socks5" ) -func init() { - socksCmd.Flags().StringVarP(&cliPortsFlag, "ports", "p", "", "set ports to scan") - socksCmd.Flags().StringVarP(&cliIPPortFileFlag, "file", "f", "", "set JSONL file with ip/port pairs to scan") - socksCmd.Flags().IntVarP(&cliWorkerCountFlag, "workers", "w", defaultWorkerCount, "set workers count") - socksCmd.Flags().DurationVarP(&cliTimeoutFlag, "timeout", "t", 2*time.Second, "set connect and data timeout") - rootCmd.AddCommand(socksCmd) +func newSocksCmd() *socksCmd { + c := &socksCmd{} + + cmd := &cobra.Command{ + Use: "socks [flags] subnet", + Example: strings.Join([]string{ + "socks -p 1080 192.168.0.1/24", "socks -p 1080-4567 10.0.0.1", + "socks -f ip_ports_file.jsonl", "socks -p 1080-4567 -f ips_file.jsonl"}, "\n"), + Short: "Perform SOCKS5 scan", + RunE: func(cmd *cobra.Command, args []string) (err error) { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + + if err = c.opts.parseRawOptions(); err != nil { + return + } + scanRange, err := c.opts.parseScanRange(args) + if err != nil { + return + } + + var logger log.Logger + if logger, err = c.opts.getLogger(socks5.ScanType, os.Stdout); err != nil { + return + } + + engine := c.opts.newSOCKSScanEngine(ctx) + return startScanEngine(ctx, engine, + newEngineConfig( + withLogger(logger), + withScanRange(scanRange), + withExitDelay(c.opts.exitDelay), + )) + }, + } + + c.opts.initCliFlags(cmd) + + c.cmd = cmd + return c +} + +type socksCmd struct { + cmd *cobra.Command + opts socksCmdOpts +} + +type socksCmdOpts struct { + genericScanCmdOpts + timeout time.Duration } -var socksCmd = &cobra.Command{ - Use: "socks [flags] subnet", - Example: strings.Join([]string{ - "socks -p 1080 192.168.0.1/24", "socks -p 1080-4567 10.0.0.1", - "socks -f ip_ports_file.jsonl", "socks -p 1080-4567 -f ips_file.jsonl"}, "\n"), - Short: "Perform SOCKS5 scan", - // Long: "Perform SOCKS scan. SOCKS5 scan is used by default unless --version option is specified", - PreRunE: func(cmd *cobra.Command, args []string) (err error) { - cliDstSubnet, err = parseDstSubnet(args) - return - }, - RunE: func(cmd *cobra.Command, args []string) (err error) { - ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) - defer cancel() - - var logger log.Logger - if logger, err = getLogger("socks", os.Stdout); err != nil { - return - } - - engine := newSOCKSScanEngine(ctx) - return startScanEngine(ctx, engine, - newEngineConfig( - withLogger(logger), - withScanRange(&scan.Range{ - DstSubnet: cliDstSubnet, - Ports: cliPortRanges, - }), - )) - }, +// TODO test +func (o *socksCmdOpts) initCliFlags(cmd *cobra.Command) { + o.genericScanCmdOpts.initCliFlags(cmd) + cmd.Flags().DurationVarP(&o.timeout, "timeout", "t", 2*time.Second, "set connect and data timeout") } -func newSOCKSScanEngine(ctx context.Context) scan.EngineResulter { +func (o *socksCmdOpts) newSOCKSScanEngine(ctx context.Context) scan.EngineResulter { scanner := socks5.NewScanner( - socks5.WithDialTimeout(cliTimeoutFlag), - socks5.WithDataTimeout(cliTimeoutFlag)) + socks5.WithDialTimeout(o.timeout), + socks5.WithDataTimeout(o.timeout)) results := scan.NewResultChan(ctx, 1000) - return scan.NewScanEngine(newIPPortGenerator(), scanner, results, scan.WithScanWorkerCount(cliWorkerCountFlag)) + return scan.NewScanEngine(o.newIPPortGenerator(), scanner, results, scan.WithScanWorkerCount(o.workers)) } diff --git a/command/tcp.go b/command/tcp.go index 9799b70..447184d 100644 --- a/command/tcp.go +++ b/command/tcp.go @@ -14,8 +14,6 @@ import ( "github.com/v-byte-cpu/sx/pkg/scan/tcp" ) -var cliTCPPacketFlags string - const ( cliTCPSYNPacketFlag = "syn" cliTCPACKPacketFlag = "ack" @@ -32,73 +30,90 @@ var ( errTCPflag = errors.New("invalid TCP packet flag") ) -func init() { - addPacketScanOptions(tcpCmd) - tcpCmd.PersistentFlags().StringVarP(&cliIPPortFileFlag, "file", "f", "", "set JSONL file with ip/port pairs to scan") - tcpCmd.PersistentFlags().StringVarP(&cliPortsFlag, "ports", "p", "", "set ports to scan") - tcpCmd.PersistentFlags().StringVarP(&cliARPCacheFileFlag, "arp-cache", "a", "", - strings.Join([]string{"set ARP cache file", "reads from stdin by default"}, "\n")) - tcpCmd.Flags().StringVar(&cliTCPPacketFlags, "flags", "", "set TCP flags") - rootCmd.AddCommand(tcpCmd) -} +func newTCPFlagsCmd() *tcpFlagsCmd { + c := &tcpFlagsCmd{} + + cmd := &cobra.Command{ + Use: "tcp [flags] subnet", + Example: strings.Join([]string{ + "tcp -p 22 192.168.0.1/24", "tcp -p 22-4567 10.0.0.1", + "tcp --flags fin,ack -p 22 192.168.0.3"}, "\n"), + Short: "Perform TCP scan", + Long: "Perform TCP scan. TCP SYN scan is used by default unless --flags option is specified", + RunE: func(cmd *cobra.Command, args []string) (err error) { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + + if err = c.opts.parseRawOptions(); err != nil { + return + } + if len(c.opts.tcpFlags) == 0 { + return newTCPSYNCmdOpts(c.opts.tcpCmdOpts).startScan(ctx, args) + } + + scanName := tcp.FlagsScanType + var conf *scanConfig + if conf, err = c.opts.parseScanConfig(scanName, args); err != nil { + return + } + + var opts []tcp.PacketFillerOption + for _, flag := range c.opts.tcpFlags { + opts = append(opts, tcpPacketFlagOptions[flag]) + } + + m := c.opts.newTCPScanMethod(ctx, conf, + withTCPScanName(scanName), + withTCPPacketFiller(tcp.NewPacketFiller(opts...)), + withTCPPacketFilterFunc(tcp.TrueFilter), + withTCPPacketFlags(tcp.AllFlags), + ) + + return startPacketScanEngine(ctx, newPacketScanConfig( + withPacketScanMethod(m), + withPacketBPFFilter(tcp.BPFFilter), + withRateCount(c.opts.rateCount), + withRateWindow(c.opts.rateWindow), + withPacketEngineConfig(newEngineConfig( + withLogger(conf.logger), + withScanRange(conf.scanRange), + withExitDelay(c.opts.exitDelay), + )), + )) + }, + } -var tcpCmd = &cobra.Command{ - Use: "tcp [flags] subnet", - Example: strings.Join([]string{ - "tcp -p 22 192.168.0.1/24", "tcp -p 22-4567 10.0.0.1", - "tcp --flags fin,ack -p 22 192.168.0.3"}, "\n"), - Short: "Perform TCP scan", - Long: "Perform TCP scan. TCP SYN scan is used by default unless --flags option is specified", - PersistentPreRunE: func(cmd *cobra.Command, args []string) (err error) { - if err = rootCmd.PersistentPreRunE(cmd, args); err != nil { - return - } - if err = validatePacketScanStdin(); err != nil { - return - } - cliDstSubnet, err = parseDstSubnet(args) - return - }, - RunE: func(cmd *cobra.Command, args []string) (err error) { - ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) - defer cancel() + c.opts.initCliFlags(cmd) - if len(cliTCPPacketFlags) == 0 { - return startTCPSYNScan(ctx, cliDstSubnet) - } + c.cmd = cmd + return c +} - var tcpFlags []string - if tcpFlags, err = parseTCPFlags(cliTCPPacketFlags); err != nil { - return err - } +type tcpFlagsCmd struct { + cmd *cobra.Command + opts tcpFlagsCmdOpts +} - var opts []tcp.PacketFillerOption - for _, flag := range tcpFlags { - opts = append(opts, tcpPacketFlagOptions[flag]) - } +type tcpFlagsCmdOpts struct { + tcpCmdOpts + tcpFlags []string - scanName := tcp.FlagsScanType - var conf *scanConfig - if conf, err = parseScanConfig(scanName, cliDstSubnet); err != nil { - return - } + rawTCPFlags string +} - m := newTCPScanMethod(ctx, conf, - withTCPScanName(scanName), - withTCPPacketFiller(tcp.NewPacketFiller(opts...)), - withTCPPacketFilterFunc(tcp.TrueFilter), - withTCPPacketFlags(tcp.AllFlags), - ) - - return startPacketScanEngine(ctx, newPacketScanConfig( - withPacketScanMethod(m), - withPacketBPFFilter(tcp.BPFFilter), - withPacketEngineConfig(newEngineConfig( - withLogger(conf.logger), - withScanRange(conf.scanRange), - )), - )) - }, +// TODO test +func (o *tcpFlagsCmdOpts) initCliFlags(cmd *cobra.Command) { + o.ipPortScanCmdOpts.initCliFlags(cmd) + cmd.Flags().StringVar(&o.rawTCPFlags, "flags", "", "set TCP flags") +} + +// TODO test +func (o *tcpFlagsCmdOpts) parseRawOptions() (err error) { + if err = o.ipPortScanCmdOpts.parseRawOptions(); err != nil { + return + } + o.tcpFlags, err = parseTCPFlags(o.rawTCPFlags) + return } var tcpPacketFlagOptions = map[string]tcp.PacketFillerOption{ @@ -113,6 +128,7 @@ var tcpPacketFlagOptions = map[string]tcp.PacketFillerOption{ cliTCPNSPacketFlag: tcp.WithNS(), } +// TODO lowercase test func parseTCPFlags(tcpFlags string) ([]string, error) { if len(tcpFlags) == 0 { return []string{}, nil @@ -126,6 +142,25 @@ func parseTCPFlags(tcpFlags string) ([]string, error) { return flags, nil } +type tcpCmdOpts struct { + ipPortScanCmdOpts +} + +func (o *tcpCmdOpts) newTCPScanMethod(ctx context.Context, conf *scanConfig, opts ...tcpScanConfigOption) *tcp.ScanMethod { + c := &tcpScanConfig{} + for _, opt := range opts { + opt(c) + } + reqgen := arp.NewCacheRequestGenerator(o.newIPPortGenerator(), conf.gatewayMAC, conf.cache) + pktgen := scan.NewPacketMultiGenerator(c.packetFiller, runtime.NumCPU()) + psrc := scan.NewPacketSource(reqgen, pktgen) + results := scan.NewResultChan(ctx, 1000) + return tcp.NewScanMethod( + c.scanName, psrc, results, + tcp.WithPacketFilterFunc(c.packetFilter), + tcp.WithPacketFlagsFunc(c.packetFlags)) +} + type tcpScanConfig struct { scanName string packetFiller scan.PacketFiller @@ -158,18 +193,3 @@ func withTCPPacketFlags(packetFlags tcp.PacketFlagsFunc) tcpScanConfigOption { c.packetFlags = packetFlags } } - -func newTCPScanMethod(ctx context.Context, conf *scanConfig, opts ...tcpScanConfigOption) *tcp.ScanMethod { - c := &tcpScanConfig{} - for _, o := range opts { - o(c) - } - reqgen := arp.NewCacheRequestGenerator(newIPPortGenerator(), conf.gatewayMAC, conf.cache) - pktgen := scan.NewPacketMultiGenerator(c.packetFiller, runtime.NumCPU()) - psrc := scan.NewPacketSource(reqgen, pktgen) - results := scan.NewResultChan(ctx, 1000) - return tcp.NewScanMethod( - c.scanName, psrc, results, - tcp.WithPacketFilterFunc(c.packetFilter), - tcp.WithPacketFlagsFunc(c.packetFlags)) -} diff --git a/command/tcp_fin.go b/command/tcp_fin.go index e4e8480..c52e7ed 100644 --- a/command/tcp_fin.go +++ b/command/tcp_fin.go @@ -10,39 +10,55 @@ import ( "github.com/v-byte-cpu/sx/pkg/scan/tcp" ) -func init() { - tcpCmd.AddCommand(tcpfinCmd) +func newTCPFINCmd() *tcpFINCmd { + c := &tcpFINCmd{} + + cmd := &cobra.Command{ + Use: "fin [flags] subnet", + Example: strings.Join([]string{"tcp fin -p 22 192.168.0.1/24", "tcp fin -p 22-4567 10.0.0.1"}, "\n"), + Short: "Perform TCP FIN scan", + RunE: func(cmd *cobra.Command, args []string) (err error) { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + + if err = c.opts.parseRawOptions(); err != nil { + return + } + + scanName := tcp.FINScanType + var conf *scanConfig + if conf, err = c.opts.parseScanConfig(scanName, args); err != nil { + return + } + + m := c.opts.newTCPScanMethod(ctx, conf, + withTCPScanName(scanName), + withTCPPacketFiller(tcp.NewPacketFiller(tcp.WithFIN())), + withTCPPacketFilterFunc(tcp.TrueFilter), + withTCPPacketFlags(tcp.AllFlags), + ) + + return startPacketScanEngine(ctx, newPacketScanConfig( + withPacketScanMethod(m), + withPacketBPFFilter(tcp.BPFFilter), + withRateCount(c.opts.rateCount), + withRateWindow(c.opts.rateWindow), + withPacketEngineConfig(newEngineConfig( + withLogger(conf.logger), + withScanRange(conf.scanRange), + withExitDelay(c.opts.exitDelay), + )), + )) + }, + } + + c.opts.initCliFlags(cmd) + + c.cmd = cmd + return c } -var tcpfinCmd = &cobra.Command{ - Use: "fin [flags] subnet", - Example: strings.Join([]string{"tcp fin -p 22 192.168.0.1/24", "tcp fin -p 22-4567 10.0.0.1"}, "\n"), - Short: "Perform TCP FIN scan", - RunE: func(cmd *cobra.Command, args []string) (err error) { - ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) - defer cancel() - - scanName := tcp.FINScanType - - var conf *scanConfig - if conf, err = parseScanConfig(scanName, cliDstSubnet); err != nil { - return - } - - m := newTCPScanMethod(ctx, conf, - withTCPScanName(scanName), - withTCPPacketFiller(tcp.NewPacketFiller(tcp.WithFIN())), - withTCPPacketFilterFunc(tcp.TrueFilter), - withTCPPacketFlags(tcp.AllFlags), - ) - - return startPacketScanEngine(ctx, newPacketScanConfig( - withPacketScanMethod(m), - withPacketBPFFilter(tcp.BPFFilter), - withPacketEngineConfig(newEngineConfig( - withLogger(conf.logger), - withScanRange(conf.scanRange), - )), - )) - }, +type tcpFINCmd struct { + cmd *cobra.Command + opts tcpCmdOpts } diff --git a/command/tcp_null.go b/command/tcp_null.go index ed6f722..7429ff7 100644 --- a/command/tcp_null.go +++ b/command/tcp_null.go @@ -10,39 +10,55 @@ import ( "github.com/v-byte-cpu/sx/pkg/scan/tcp" ) -func init() { - tcpCmd.AddCommand(tcpnullCmd) +func newTCPNULLCmd() *tcpNULLCmd { + c := &tcpNULLCmd{} + + cmd := &cobra.Command{ + Use: "null [flags] subnet", + Example: strings.Join([]string{"tcp null -p 22 192.168.0.1/24", "tcp null -p 22-4567 10.0.0.1"}, "\n"), + Short: "Perform TCP NULL scan", + RunE: func(cmd *cobra.Command, args []string) (err error) { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + + if err = c.opts.parseRawOptions(); err != nil { + return + } + + scanName := tcp.NULLScanType + var conf *scanConfig + if conf, err = c.opts.parseScanConfig(scanName, args); err != nil { + return + } + + m := c.opts.newTCPScanMethod(ctx, conf, + withTCPScanName(scanName), + withTCPPacketFiller(tcp.NewPacketFiller()), + withTCPPacketFilterFunc(tcp.TrueFilter), + withTCPPacketFlags(tcp.AllFlags), + ) + + return startPacketScanEngine(ctx, newPacketScanConfig( + withPacketScanMethod(m), + withPacketBPFFilter(tcp.BPFFilter), + withRateCount(c.opts.rateCount), + withRateWindow(c.opts.rateWindow), + withPacketEngineConfig(newEngineConfig( + withLogger(conf.logger), + withScanRange(conf.scanRange), + withExitDelay(c.opts.exitDelay), + )), + )) + }, + } + + c.opts.initCliFlags(cmd) + + c.cmd = cmd + return c } -var tcpnullCmd = &cobra.Command{ - Use: "null [flags] subnet", - Example: strings.Join([]string{"tcp null -p 22 192.168.0.1/24", "tcp null -p 22-4567 10.0.0.1"}, "\n"), - Short: "Perform TCP NULL scan", - RunE: func(cmd *cobra.Command, args []string) (err error) { - ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) - defer cancel() - - scanName := tcp.NULLScanType - - var conf *scanConfig - if conf, err = parseScanConfig(scanName, cliDstSubnet); err != nil { - return - } - - m := newTCPScanMethod(ctx, conf, - withTCPScanName(scanName), - withTCPPacketFiller(tcp.NewPacketFiller()), - withTCPPacketFilterFunc(tcp.TrueFilter), - withTCPPacketFlags(tcp.AllFlags), - ) - - return startPacketScanEngine(ctx, newPacketScanConfig( - withPacketScanMethod(m), - withPacketBPFFilter(tcp.BPFFilter), - withPacketEngineConfig(newEngineConfig( - withLogger(conf.logger), - withScanRange(conf.scanRange), - )), - )) - }, +type tcpNULLCmd struct { + cmd *cobra.Command + opts tcpCmdOpts } diff --git a/command/tcp_syn.go b/command/tcp_syn.go index d820857..8c1937c 100644 --- a/command/tcp_syn.go +++ b/command/tcp_syn.go @@ -2,7 +2,6 @@ package command import ( "context" - "net" "os" "os/signal" "strings" @@ -12,30 +11,52 @@ import ( "github.com/v-byte-cpu/sx/pkg/scan/tcp" ) -func init() { - tcpCmd.AddCommand(tcpsynCmd) +func newTCPSYNCmd() *tcpSYNCmd { + c := &tcpSYNCmd{} + + cmd := &cobra.Command{ + Use: "syn [flags] subnet", + Example: strings.Join([]string{"tcp syn -p 22 192.168.0.1/24", "tcp syn -p 22-4567 10.0.0.1"}, "\n"), + Short: "Perform TCP SYN scan", + RunE: func(cmd *cobra.Command, args []string) (err error) { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + + if err = c.opts.parseRawOptions(); err != nil { + return + } + return c.opts.startScan(ctx, args) + }, + } + + c.opts.initCliFlags(cmd) + + c.cmd = cmd + return c +} + +type tcpSYNCmd struct { + cmd *cobra.Command + opts tcpSYNCmdOpts +} + +type tcpSYNCmdOpts struct { + tcpCmdOpts } -var tcpsynCmd = &cobra.Command{ - Use: "syn [flags] subnet", - Example: strings.Join([]string{"tcp syn -p 22 192.168.0.1/24", "tcp syn -p 22-4567 10.0.0.1"}, "\n"), - Short: "Perform TCP SYN scan", - RunE: func(cmd *cobra.Command, args []string) (err error) { - ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) - defer cancel() - return startTCPSYNScan(ctx, cliDstSubnet) - }, +func newTCPSYNCmdOpts(opts tcpCmdOpts) *tcpSYNCmdOpts { + return &tcpSYNCmdOpts{opts} } -func startTCPSYNScan(ctx context.Context, dstSubnet *net.IPNet) (err error) { +func (o *tcpSYNCmdOpts) startScan(ctx context.Context, args []string) (err error) { scanName := tcp.SYNScanType var conf *scanConfig - if conf, err = parseScanConfig(scanName, dstSubnet); err != nil { + if conf, err = o.parseScanConfig(scanName, args); err != nil { return } - m := newTCPScanMethod(ctx, conf, + m := o.newTCPScanMethod(ctx, conf, withTCPScanName(scanName), withTCPPacketFiller(tcp.NewPacketFiller(tcp.WithSYN())), withTCPPacketFilterFunc(func(pkt *layers.TCP) bool { @@ -48,9 +69,12 @@ func startTCPSYNScan(ctx context.Context, dstSubnet *net.IPNet) (err error) { return startPacketScanEngine(ctx, newPacketScanConfig( withPacketScanMethod(m), withPacketBPFFilter(tcp.SYNACKBPFFilter), + withRateCount(o.rateCount), + withRateWindow(o.rateWindow), withPacketEngineConfig(newEngineConfig( withLogger(conf.logger), withScanRange(conf.scanRange), + withExitDelay(o.exitDelay), )), )) } diff --git a/command/tcp_xmas.go b/command/tcp_xmas.go index cb137b7..cd28a6b 100644 --- a/command/tcp_xmas.go +++ b/command/tcp_xmas.go @@ -10,39 +10,55 @@ import ( "github.com/v-byte-cpu/sx/pkg/scan/tcp" ) -func init() { - tcpCmd.AddCommand(tcpxmasCmd) +func newTCPXmasCmd() *tcpXmasCmd { + c := &tcpXmasCmd{} + + cmd := &cobra.Command{ + Use: "xmas [flags] subnet", + Example: strings.Join([]string{"tcp xmas -p 22 192.168.0.1/24", "tcp xmas -p 22-4567 10.0.0.1"}, "\n"), + Short: "Perform TCP Xmas scan", + RunE: func(cmd *cobra.Command, args []string) (err error) { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + + if err = c.opts.parseRawOptions(); err != nil { + return + } + + scanName := tcp.XmasScanType + var conf *scanConfig + if conf, err = c.opts.parseScanConfig(scanName, args); err != nil { + return + } + + m := c.opts.newTCPScanMethod(ctx, conf, + withTCPScanName(scanName), + withTCPPacketFiller(tcp.NewPacketFiller(tcp.WithFIN(), tcp.WithPSH(), tcp.WithURG())), + withTCPPacketFilterFunc(tcp.TrueFilter), + withTCPPacketFlags(tcp.AllFlags), + ) + + return startPacketScanEngine(ctx, newPacketScanConfig( + withPacketScanMethod(m), + withPacketBPFFilter(tcp.BPFFilter), + withRateCount(c.opts.rateCount), + withRateWindow(c.opts.rateWindow), + withPacketEngineConfig(newEngineConfig( + withLogger(conf.logger), + withScanRange(conf.scanRange), + withExitDelay(c.opts.exitDelay), + )), + )) + }, + } + + c.opts.initCliFlags(cmd) + + c.cmd = cmd + return c } -var tcpxmasCmd = &cobra.Command{ - Use: "xmas [flags] subnet", - Example: strings.Join([]string{"tcp xmas -p 22 192.168.0.1/24", "tcp xmas -p 22-4567 10.0.0.1"}, "\n"), - Short: "Perform TCP Xmas scan", - RunE: func(cmd *cobra.Command, args []string) (err error) { - ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) - defer cancel() - - scanName := tcp.XmasScanType - - var conf *scanConfig - if conf, err = parseScanConfig(scanName, cliDstSubnet); err != nil { - return - } - - m := newTCPScanMethod(ctx, conf, - withTCPScanName(scanName), - withTCPPacketFiller(tcp.NewPacketFiller(tcp.WithFIN(), tcp.WithPSH(), tcp.WithURG())), - withTCPPacketFilterFunc(tcp.TrueFilter), - withTCPPacketFlags(tcp.AllFlags), - ) - - return startPacketScanEngine(ctx, newPacketScanConfig( - withPacketScanMethod(m), - withPacketBPFFilter(tcp.BPFFilter), - withPacketEngineConfig(newEngineConfig( - withLogger(conf.logger), - withScanRange(conf.scanRange), - )), - )) - }, +type tcpXmasCmd struct { + cmd *cobra.Command + opts tcpCmdOpts } diff --git a/command/udp.go b/command/udp.go index a9a2535..4584151 100644 --- a/command/udp.go +++ b/command/udp.go @@ -14,99 +14,119 @@ import ( "github.com/v-byte-cpu/sx/pkg/scan/udp" ) -var ( - cliUDPPayloadFlag string +func newUDPCmd() *udpCmd { + c := &udpCmd{} - cliUDPPayload []byte -) + cmd := &cobra.Command{ + Use: "udp [flags] subnet", + Example: strings.Join([]string{ + "udp -p 22 192.168.0.1/24", + "udp -p 22-4567 10.0.0.1", + "udp --ttl 37 -p 53 192.168.0.1/24", + "udp --ipproto 157 -p 53 192.168.0.1/24", + `udp --payload '\x01\x02\x03' -p 53 192.168.0.1/24`}, "\n"), + Short: "Perform UDP scan", + RunE: func(cmd *cobra.Command, args []string) (err error) { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + + if err = c.opts.parseRawOptions(); err != nil { + return + } + var conf *scanConfig + if conf, err = c.opts.parseScanConfig(udp.ScanType, args); err != nil { + return + } + + m := c.opts.newUDPScanMethod(ctx, conf) + + return startPacketScanEngine(ctx, newPacketScanConfig( + withPacketScanMethod(m), + withPacketBPFFilter(icmp.BPFFilter), + withRateCount(c.opts.rateCount), + withRateWindow(c.opts.rateWindow), + withPacketEngineConfig(newEngineConfig( + withLogger(conf.logger), + withScanRange(conf.scanRange), + withExitDelay(c.opts.exitDelay), + )), + )) + }, + } + + c.opts.initCliFlags(cmd) + + c.cmd = cmd + return c +} -func init() { - addPacketScanOptions(udpCmd) - udpCmd.Flags().StringVarP(&cliIPPortFileFlag, "file", "f", "", "set JSONL file with ip/port pairs to scan") - udpCmd.Flags().StringVar(&cliIPTTLFlag, "ttl", "", - strings.Join([]string{"set IP TTL field of generated packet", "64 by default"}, "\n")) - udpCmd.Flags().StringVar(&cliIPTotalLenFlag, "iplen", "", +type udpCmd struct { + cmd *cobra.Command + opts udpCmdOpts +} + +type udpCmdOpts struct { + ipPortScanCmdOpts + ipTTL uint8 + ipFlags uint8 + ipProtocol uint8 + ipTotalLen uint16 + + udpPayload []byte + + rawIPFlags string + rawUDPPayload string +} + +// TODO test +func (o *udpCmdOpts) initCliFlags(cmd *cobra.Command) { + o.ipPortScanCmdOpts.initCliFlags(cmd) + cmd.Flags().Uint8Var(&o.ipTTL, "ttl", 64, "set IP TTL field of generated packet") + cmd.Flags().Uint8Var(&o.ipProtocol, "ipproto", 17, + strings.Join([]string{"set IP Protocol field of generated packet", "UDP by default"}, "\n")) + cmd.Flags().StringVar(&o.rawIPFlags, "ipflags", "DF", "set IP Flags field of generated packet") + cmd.Flags().Uint16Var(&o.ipTotalLen, "iplen", 0, strings.Join([]string{"set IP Total Length field of generated packet", "calculated by default"}, "\n")) - udpCmd.Flags().StringVar(&cliIPProtocolFlag, "ipproto", "", - strings.Join([]string{"set IP Protocol field of generated packet", "17 (UDP) by default"}, "\n")) - udpCmd.Flags().StringVar(&cliIPFlagsFlag, "ipflags", "", - strings.Join([]string{"set IP Flags field of generated packet", "DF by default"}, "\n")) - udpCmd.Flags().StringVarP(&cliPortsFlag, "ports", "p", "", "set ports to scan") - udpCmd.Flags().StringVar(&cliUDPPayloadFlag, "payload", "", + cmd.Flags().StringVar(&o.rawUDPPayload, "payload", "", strings.Join([]string{"set byte payload of generated packet", "0 bytes by default"}, "\n")) - - udpCmd.Flags().StringVarP(&cliARPCacheFileFlag, "arp-cache", "a", "", - strings.Join([]string{"set ARP cache file", "reads from stdin by default"}, "\n")) - rootCmd.AddCommand(udpCmd) } -var udpCmd = &cobra.Command{ - Use: "udp [flags] subnet", - Example: strings.Join([]string{ - "udp -p 22 192.168.0.1/24", - "udp -p 22-4567 10.0.0.1", - "udp --ttl 37 -p 53 192.168.0.1/24", - "udp --ipproto 157 -p 53 192.168.0.1/24", - `udp --payload '\x01\x02\x03' -p 53 192.168.0.1/24`}, "\n"), - Short: "Perform UDP scan", - PreRunE: func(cmd *cobra.Command, args []string) (err error) { - if cliDstSubnet, err = parseDstSubnet(args); err != nil { - return - } - if err = validatePacketScanStdin(); err != nil { +// TODO test +func (o *udpCmdOpts) parseRawOptions() (err error) { + if err = o.ipPortScanCmdOpts.parseRawOptions(); err != nil { + return + } + if len(o.rawIPFlags) > 0 { + if o.ipFlags, err = parseIPFlags(o.rawIPFlags); err != nil { return } - if len(cliUDPPayloadFlag) > 0 { - cliUDPPayload, err = parsePacketPayload(cliUDPPayloadFlag) - } - return - }, - RunE: func(cmd *cobra.Command, args []string) (err error) { - ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) - defer cancel() - - var conf *scanConfig - if conf, err = parseScanConfig(udp.ScanType, cliDstSubnet); err != nil { + } + if len(o.rawUDPPayload) > 0 { + if o.udpPayload, err = parsePacketPayload(o.rawUDPPayload); err != nil { return } - - m := newUDPScanMethod(ctx, conf) - - return startPacketScanEngine(ctx, newPacketScanConfig( - withPacketScanMethod(m), - withPacketBPFFilter(icmp.BPFFilter), - withPacketEngineConfig(newEngineConfig( - withLogger(conf.logger), - withScanRange(conf.scanRange), - )), - )) - }, + } + return } -func newUDPScanMethod(ctx context.Context, conf *scanConfig) *udp.ScanMethod { - reqgen := arp.NewCacheRequestGenerator(newIPPortGenerator(), conf.gatewayMAC, conf.cache) - pktgen := scan.NewPacketMultiGenerator(udp.NewPacketFiller(getUDPOptions()...), runtime.NumCPU()) +func (o *udpCmdOpts) newUDPScanMethod(ctx context.Context, conf *scanConfig) *udp.ScanMethod { + reqgen := arp.NewCacheRequestGenerator(o.newIPPortGenerator(), conf.gatewayMAC, conf.cache) + pktgen := scan.NewPacketMultiGenerator(udp.NewPacketFiller(o.getUDPOptions()...), runtime.NumCPU()) psrc := scan.NewPacketSource(reqgen, pktgen) results := scan.NewResultChan(ctx, 1000) return udp.NewScanMethod(psrc, results) } -func getUDPOptions() (opts []udp.PacketFillerOption) { - if len(cliIPTTLFlag) > 0 { - opts = append(opts, udp.WithTTL(cliTTL)) - } - if len(cliIPTotalLenFlag) > 0 { - opts = append(opts, udp.WithIPTotalLength(cliIPTotalLen)) - } - if len(cliIPProtocolFlag) > 0 { - opts = append(opts, udp.WithIPProtocol(cliIPProtocol)) - } - if len(cliIPFlagsFlag) > 0 { - opts = append(opts, udp.WithIPFlags(cliIPFlags)) - } - if len(cliUDPPayloadFlag) > 0 { - opts = append(opts, udp.WithPayload(cliUDPPayload)) +func (o *udpCmdOpts) getUDPOptions() (opts []udp.PacketFillerOption) { + opts = append(opts, + udp.WithTTL(o.ipTTL), + udp.WithIPProtocol(o.ipProtocol), + udp.WithIPFlags(o.ipFlags), + udp.WithIPTotalLength(o.ipTotalLen)) + + if len(o.udpPayload) > 0 { + opts = append(opts, udp.WithPayload(o.udpPayload)) } return } diff --git a/pkg/scan/engine.go b/pkg/scan/engine.go index 17d7804..234ed9f 100644 --- a/pkg/scan/engine.go +++ b/pkg/scan/engine.go @@ -157,7 +157,7 @@ func NewScanEngine(reqgen RequestGenerator, reqgen: reqgen, scanner: scanner, results: results, - workerCount: 50, + workerCount: 100, } for _, o := range opts { o(s) diff --git a/pkg/scan/request.go b/pkg/scan/request.go index b58209e..c977768 100644 --- a/pkg/scan/request.go +++ b/pkg/scan/request.go @@ -202,7 +202,6 @@ func NewFileIPPortGenerator(openFile OpenFileFunc) RequestGenerator { return &fileIPPortGenerator{openFile} } -// TODO add meta field func (rg *fileIPPortGenerator) GenerateRequests(ctx context.Context, r *Range) (<-chan *Request, error) { input, err := rg.openFile() if err != nil {