Skip to content

Commit

Permalink
TCP SYN scan
Browse files Browse the repository at this point in the history
  • Loading branch information
v-byte-cpu committed Mar 20, 2021
1 parent 4d2ee64 commit 1c1e57d
Show file tree
Hide file tree
Showing 32 changed files with 1,627 additions and 309 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

The goal of this project is to create the fastest network scanner with clean and simple code.

Right now, only ARP scan is supported.
Features:
* ARP scan
* TCP SYN scan

## Building

Expand Down
133 changes: 20 additions & 113 deletions command/arp.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,21 @@ package command
import (
"context"
"errors"
"io"
"net"
"os"
"os/signal"
"strings"
"sync"
"time"

"github.com/spf13/cobra"
"github.com/v-byte-cpu/sx/command/log"
"github.com/v-byte-cpu/sx/pkg/ip"
"github.com/v-byte-cpu/sx/pkg/packet/afpacket"
"github.com/v-byte-cpu/sx/pkg/scan"
"github.com/v-byte-cpu/sx/pkg/scan/arp"
)

var errSrcIP = errors.New("invalid source IP")

var interfaceFlag string
var srcIPFlag string
var srcMACFlag string
var liveModeFlag bool
var arpLiveModeFlag bool

func init() {
arpCmd.Flags().StringVarP(&interfaceFlag, "iface", "i", "", "set interface to send/receive packets")
arpCmd.Flags().StringVar(&srcIPFlag, "srcip", "", "set source IP address for generated packets")
arpCmd.Flags().StringVar(&srcMACFlag, "srcmac", "", "set source MAC address for generated packets")
arpCmd.Flags().BoolVar(&liveModeFlag, "live", false, "enable live mode")
arpCmd.Flags().BoolVar(&arpLiveModeFlag, "live", false, "enable live mode")
rootCmd.AddCommand(arpCmd)
}

Expand All @@ -45,113 +32,33 @@ var arpCmd = &cobra.Command{
return nil
},
RunE: func(cmd *cobra.Command, args []string) (err error) {
dstSubnet, err := ip.ParseIPNet(args[0])
if err != nil {
var r *scan.Range
if r, err = parseScanRange(args[0]); err != nil {
return err
}

var iface *net.Interface
var srcIP net.IP

if len(interfaceFlag) > 0 {
if iface, err = net.InterfaceByName(interfaceFlag); err != nil {
return err
}
if srcIP, err = ip.GetSubnetInterfaceIP(iface, dstSubnet); err != nil {
return err
}
} else {
if iface, srcIP, err = ip.GetSubnetInterface(dstSubnet); err != nil {
return err
}
}

if len(srcIPFlag) > 0 {
if srcIP = net.ParseIP(srcIPFlag); srcIP == nil {
return errSrcIP
}
}

srcMAC := iface.HardwareAddr
if len(srcMACFlag) > 0 {
if srcMAC, err = net.ParseMAC(srcMACFlag); err != nil {
return err
}
}

ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
defer cancel()

var logger log.Logger
if logger, err = getLogger(ctx, os.Stdout); err != nil {
if logger, err = getLogger("arp", os.Stdout); err != nil {
return err
}
if arpLiveModeFlag {
logger = log.NewUniqueLogger(logger)
}

r := &scan.Range{Subnet: dstSubnet, Interface: iface, SrcIP: srcIP.To4(), SrcMAC: srcMAC}
return startEngine(ctx, logger, r)
},
}

func getLogger(ctx context.Context, w io.Writer) (logger log.Logger, err error) {
opts := []log.LoggerOption{log.FlushInterval(1 * time.Second)}
if jsonFlag {
opts = append(opts, log.JSON())
}
if logger, err = log.NewLogger(w, "arp", opts...); err != nil {
return
}
if liveModeFlag {
logger = log.NewUniqueLogger(ctx, logger)
}
return
}

func startEngine(ctx context.Context, logger log.Logger, r *scan.Range) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

// setup network interface to read/write packets
rw, err := afpacket.NewPacketSource(r.Interface.Name)
if err != nil {
return err
}
defer rw.Close()
err = rw.SetBPFFilter(arp.BPFFilter(r))
if err != nil {
return err
}

var opts []arp.ScanMethodOption
if liveModeFlag {
opts = append(opts, arp.LiveMode(1*time.Second))
}
m := arp.NewScanMethod(ctx, opts...)

// setup result logging
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
logger.LogResults(m.Results())
}()

// start scan
engine := scan.SetupEngine(rw, m)
done, errc := engine.Start(ctx, r)
go func() {
defer cancel()
<-done
<-time.After(300 * time.Millisecond)
}()

// error logging
wg.Add(1)
go func() {
defer wg.Done()
for err := range errc {
logger.Error(err)
var opts []arp.ScanMethodOption
if arpLiveModeFlag {
opts = append(opts, arp.LiveMode(1*time.Second))
}
}()
wg.Wait()
return nil
m := arp.NewScanMethod(ctx, opts...)

return startEngine(ctx, &engineConfig{
logger: logger,
scanRange: r,
scanMethod: m,
bpfFilter: arp.BPFFilter,
})
},
}
7 changes: 5 additions & 2 deletions command/log/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package log

import (
"bufio"
"context"
"io"
"time"

Expand All @@ -11,7 +12,7 @@ import (

type Logger interface {
Error(err error)
LogResults(results <-chan scan.Result)
LogResults(ctx context.Context, results <-chan scan.Result)
}

type FlushWriter interface {
Expand Down Expand Up @@ -75,13 +76,15 @@ func (l *logger) Error(err error) {
l.zapl.Error(l.label, zap.Error(err))
}

func (l *logger) LogResults(results <-chan scan.Result) {
func (l *logger) LogResults(ctx context.Context, results <-chan scan.Result) {
bw := bufio.NewWriter(l.w)
defer bw.Flush()
var err error
timec := time.After(l.flushInterval)
for {
select {
case <-ctx.Done():
return
case result, ok := <-results:
if !ok {
return
Expand Down
29 changes: 27 additions & 2 deletions command/log/logger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package log

import (
"bytes"
"context"
"net"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -64,7 +66,7 @@ func TestJSONLoggerResults(t *testing.T) {
resultCh <- result
}
close(resultCh)
logger.LogResults(resultCh)
logger.LogResults(context.Background(), resultCh)

assert.Equal(t, string(tt.expected), buf.String())
})
Expand Down Expand Up @@ -124,9 +126,32 @@ func TestPlainLoggerResults(t *testing.T) {
resultCh <- result
}
close(resultCh)
logger.LogResults(resultCh)
logger.LogResults(context.Background(), resultCh)

assert.Equal(t, string(tt.expected), buf.String())
})
}
}

func TestLoggerContextExit(t *testing.T) {
t.Parallel()

done := make(chan interface{})
go func() {
defer close(done)

ctx, cancel := context.WithCancel(context.Background())
cancel()

var buf bytes.Buffer
logger, err := NewLogger(&buf, "arp", Plain())
require.NoError(t, err)

logger.LogResults(ctx, nil)
}()
select {
case <-done:
case <-time.After(3 * time.Second):
require.Fail(t, "test timeout")
}
}
33 changes: 20 additions & 13 deletions command/log/unique_logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,45 @@ import (
)

type UniqueLogger struct {
ctx context.Context
logger Logger
}

func NewUniqueLogger(ctx context.Context, logger Logger) *UniqueLogger {
return &UniqueLogger{ctx, logger}
func NewUniqueLogger(logger Logger) *UniqueLogger {
return &UniqueLogger{logger}
}

func (l *UniqueLogger) Error(err error) {
l.logger.Error(err)
}

func (l *UniqueLogger) LogResults(results <-chan scan.Result) {
l.logger.LogResults(l.uniqResults(results))
func (l *UniqueLogger) LogResults(ctx context.Context, results <-chan scan.Result) {
l.logger.LogResults(ctx, l.uniqResults(ctx, results))
}

func (l *UniqueLogger) uniqResults(in <-chan scan.Result) <-chan scan.Result {
func (*UniqueLogger) uniqResults(ctx context.Context, in <-chan scan.Result) <-chan scan.Result {
results := make(chan scan.Result, cap(in))
go func() {
defer close(results)
var member struct{}
set := make(map[string]interface{})

for result := range in {
id := result.ID()
if _, exists := set[id]; !exists {
set[id] = member
select {
case results <- result:
case <-l.ctx.Done():
for {
select {
case <-ctx.Done():
return
case result, ok := <-in:
if !ok {
return
}
id := result.ID()
if _, exists := set[id]; !exists {
set[id] = member
select {
case <-ctx.Done():
return
case results <- result:
}
}
}
}
}()
Expand Down
29 changes: 27 additions & 2 deletions command/log/unique_logger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -71,16 +72,40 @@ func TestUniqueLoggerResults(t *testing.T) {
var buf bytes.Buffer
plainLogger, err := NewLogger(&buf, "arp")
require.NoError(t, err)
logger := NewUniqueLogger(context.Background(), plainLogger)
logger := NewUniqueLogger(plainLogger)

resultCh := make(chan scan.Result, len(tt.results))
for _, result := range tt.results {
resultCh <- result
}
close(resultCh)
logger.LogResults(resultCh)
logger.LogResults(context.Background(), resultCh)

assert.Equal(t, string(tt.expected), buf.String())
})
}
}

func TestUniqueLoggerContextExit(t *testing.T) {
t.Parallel()

done := make(chan interface{})
go func() {
defer close(done)

ctx, cancel := context.WithCancel(context.Background())
cancel()

var buf bytes.Buffer
logger, err := NewLogger(&buf, "arp", Plain())
require.NoError(t, err)

uniqLogger := NewUniqueLogger(logger)
<-uniqLogger.uniqResults(ctx, nil)
}()
select {
case <-done:
case <-time.After(3 * time.Second):
require.Fail(t, "test timeout")
}
}
Loading

0 comments on commit 1c1e57d

Please sign in to comment.