From 189d984aafb42b2baab103492f74a9aa2ae8b2dc Mon Sep 17 00:00:00 2001 From: v-byte-cpu <65545655+v-byte-cpu@users.noreply.github.com> Date: Sun, 4 Apr 2021 02:50:05 +0300 Subject: [PATCH] SOCKS scan --- README.md | 1 + command/root.go | 17 ++ command/socks.go | 70 +++++ command/tcp.go | 2 +- command/udp.go | 6 +- pkg/scan/engine.go | 106 ++++++- pkg/scan/engine_test.go | 161 ++++++++++- pkg/scan/mock_engine_test.go | 40 ++- pkg/scan/mock_generator_test.go | 8 +- pkg/scan/mock_request_test.go | 42 ++- pkg/scan/request.go | 175 ++++++++++-- pkg/scan/request_easyjson.go | 92 +++++++ pkg/scan/request_test.go | 427 ++++++++++++++++++++--------- pkg/scan/socks5/message.go | 63 +++++ pkg/scan/socks5/message_test.go | 69 +++++ pkg/scan/socks5/result_easyjson.go | 113 ++++++++ pkg/scan/socks5/socks5.go | 137 +++++++++ pkg/scan/utils_test.go | 9 + 18 files changed, 1357 insertions(+), 181 deletions(-) create mode 100644 command/socks.go create mode 100644 pkg/scan/request_easyjson.go create mode 100644 pkg/scan/socks5/message.go create mode 100644 pkg/scan/socks5/message_test.go create mode 100644 pkg/scan/socks5/result_easyjson.go create mode 100644 pkg/scan/socks5/socks5.go diff --git a/README.md b/README.md index d026fb8..ae18714 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ The goal of this project is to create the fastest network scanner with clean and * **TCP FIN / NULL / Xmas scans**: Scan techniques to bypass some firewall rules * **Custom TCP scans with any TCP flags**: Send whatever exotic packets you want and get a result with all the TCP flags set in the reply packet * **UDP scan**: Scan UDP ports and get full ICMP replies to detect open ports or firewall rules + * **SOCKS5 scan**: Detect live SOCKS5 proxies by scanning ip range or list of ip/port pairs from a file * **JSON output support**: sx is designed specifically for convenient automatic processing of results ## Build from source diff --git a/command/root.go b/command/root.go index dcc2a98..0b590f6 100644 --- a/command/root.go +++ b/command/root.go @@ -72,11 +72,13 @@ var ( cliRateLimitFlag string cliExitDelayFlag string cliARPCacheFileFlag string + cliIPPortFileFlag string cliInterface *net.Interface cliSrcIP net.IP cliSrcMAC net.HardwareAddr cliPortRanges []*scan.PortRange + cliDstSubnet *net.IPNet cliRateCount int cliRateWindow time.Duration cliExitDelay = 300 * time.Millisecond @@ -301,6 +303,21 @@ func getGatewayIP(r *scan.Range) (gatewayIP net.IP, err error) { return } +func newIPPortGenerator() (reqgen scan.RequestGenerator) { + if len(cliIPPortFileFlag) == 0 { + return scan.NewIPPortGenerator(scan.NewIPGenerator(), scan.NewPortGenerator()) + } + if len(cliPortRanges) == 0 { + return scan.NewFileIPPortGenerator(func() (io.ReadCloser, error) { + return os.Open(cliIPPortFileFlag) + }) + } + ipgen := scan.NewFileIPGenerator(func() (io.ReadCloser, error) { + return os.Open(cliIPPortFileFlag) + }) + return scan.NewIPPortGenerator(ipgen, scan.NewPortGenerator()) +} + type bpfFilterFunc func(r *scan.Range) (filter string, maxPacketLength int) type engineConfig struct { diff --git a/command/socks.go b/command/socks.go new file mode 100644 index 0000000..6b25c82 --- /dev/null +++ b/command/socks.go @@ -0,0 +1,70 @@ +package command + +import ( + "context" + "errors" + "os" + "os/signal" + "strings" + "time" + + "github.com/spf13/cobra" + "github.com/v-byte-cpu/sx/command/log" + "github.com/v-byte-cpu/sx/pkg/ip" + "github.com/v-byte-cpu/sx/pkg/scan" + "github.com/v-byte-cpu/sx/pkg/scan/socks5" +) + +func init() { + socksCmd.Flags().StringVarP(&cliPortsFlag, "ports", "p", "", "set ports to scan") + socksCmd.Flags().StringVarP(&cliIPPortFileFlag, "file", "f", "", "set JSONL file with ip/port pairs to scan") + rootCmd.AddCommand(socksCmd) +} + +var socksCmd = &cobra.Command{ + Use: "socks [flags] subnet", + Example: strings.Join([]string{ + "socks -p 1080 192.168.0.1/24", "socks -p 1080-4567 10.0.0.1", + "socks -f ip_ports_file.jsonl", "socks -p 1080-4567 -f ips_file.jsonl"}, "\n"), + Short: "Perform SOCKS5 scan", + // Long: "Perform SOCKS scan. SOCKS5 scan is used by default unless --version option is specified", + PreRunE: func(cmd *cobra.Command, args []string) (err error) { + if len(args) == 0 && len(cliIPPortFileFlag) == 0 { + return errors.New("requires one ip subnet argument or file with ip/port pairs") + } + if len(args) == 0 { + return + } + cliDstSubnet, err = ip.ParseIPNet(args[0]) + return + }, + RunE: func(cmd *cobra.Command, args []string) (err error) { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + + var logger log.Logger + if logger, err = getLogger("socks", os.Stdout); err != nil { + return + } + + engine := newSOCKSScanEngine(ctx) + return startScanEngine(ctx, engine, + newEngineConfig( + withLogger(logger), + withScanRange(&scan.Range{ + DstSubnet: cliDstSubnet, + Ports: cliPortRanges, + }), + )) + }, +} + +func newSOCKSScanEngine(ctx context.Context) scan.EngineResulter { + // TODO custom dialTimeout, dataTimeout + scanner := socks5.NewScanner( + socks5.WithDialTimeout(2*time.Second), + socks5.WithDataTimeout(2*time.Second)) + results := scan.NewResultChan(ctx, 1000) + // TODO custom workerCount + return scan.NewScanEngine(newIPPortGenerator(), scanner, results, scan.WithScanWorkerCount(50)) +} diff --git a/command/tcp.go b/command/tcp.go index 14d25c8..60934bf 100644 --- a/command/tcp.go +++ b/command/tcp.go @@ -165,7 +165,7 @@ func newTCPScanMethod(ctx context.Context, conf *scanConfig, opts ...tcpScanConf portgen := scan.NewPortGenerator() ipgen := scan.NewIPGenerator() reqgen := arp.NewCacheRequestGenerator( - scan.NewIPPortRequestGenerator(ipgen, portgen), conf.gatewayIP, conf.cache) + scan.NewIPPortGenerator(ipgen, portgen), conf.gatewayIP, conf.cache) pktgen := scan.NewPacketMultiGenerator(c.packetFiller, runtime.NumCPU()) psrc := scan.NewPacketSource(reqgen, pktgen) results := scan.NewResultChan(ctx, 1000) diff --git a/command/udp.go b/command/udp.go index f1ed670..d4d3b48 100644 --- a/command/udp.go +++ b/command/udp.go @@ -3,6 +3,7 @@ package command import ( "context" "errors" + golog "log" "os" "os/signal" "runtime" @@ -17,6 +18,9 @@ import ( func init() { udpCmd.Flags().StringVarP(&cliPortsFlag, "ports", "p", "", "set ports to scan") + if err := udpCmd.MarkFlagRequired("ports"); err != nil { + golog.Fatalln(err) + } udpCmd.Flags().StringVarP(&cliARPCacheFileFlag, "arp-cache", "a", "", strings.Join([]string{"set ARP cache file", "reads from stdin by default"}, "\n")) rootCmd.AddCommand(udpCmd) @@ -58,7 +62,7 @@ func newUDPScanMethod(ctx context.Context, conf *scanConfig) *udp.ScanMethod { portgen := scan.NewPortGenerator() ipgen := scan.NewIPGenerator() reqgen := arp.NewCacheRequestGenerator( - scan.NewIPPortRequestGenerator(ipgen, portgen), conf.gatewayIP, conf.cache) + scan.NewIPPortGenerator(ipgen, portgen), conf.gatewayIP, conf.cache) pktgen := scan.NewPacketMultiGenerator(udp.NewPacketFiller(), runtime.NumCPU()) psrc := scan.NewPacketSource(reqgen, pktgen) results := scan.NewResultChan(ctx, 1000) diff --git a/pkg/scan/engine.go b/pkg/scan/engine.go index 1cac21b..d0cc4d1 100644 --- a/pkg/scan/engine.go +++ b/pkg/scan/engine.go @@ -1,4 +1,4 @@ -//go:generate mockgen -package scan -destination=mock_engine_test.go . PacketSource +//go:generate mockgen -package scan -destination=mock_engine_test.go . PacketSource,Scanner package scan @@ -103,11 +103,7 @@ func mergeErrChan(ctx context.Context, channels ...<-chan error) <-chan error { if !ok { return } - select { - case <-ctx.Done(): - return - case out <- e: - } + writeError(ctx, out, e) } } } @@ -133,3 +129,101 @@ func SetupPacketEngine(rw packet.ReadWriter, m PacketMethod) EngineResulter { engine := NewPacketEngine(m, sender, receiver) return NewEngineResulter(engine, m) } + +type Scanner interface { + Scan(ctx context.Context, r *Request) (Result, error) +} + +type GenericEngine struct { + reqgen RequestGenerator + scanner Scanner + results ResultChan + workerCount int +} + +// Assert that GenericEngine conforms to the scan.EngineResulter interface +var _ EngineResulter = (*GenericEngine)(nil) + +type GenericEngineOption func(s *GenericEngine) + +func WithScanWorkerCount(workerCount int) GenericEngineOption { + return func(s *GenericEngine) { + s.workerCount = workerCount + } +} + +func NewScanEngine(reqgen RequestGenerator, + scanner Scanner, results ResultChan, opts ...GenericEngineOption) *GenericEngine { + s := &GenericEngine{ + reqgen: reqgen, + scanner: scanner, + results: results, + workerCount: 50, + } + for _, o := range opts { + o(s) + } + return s +} + +func (e *GenericEngine) Results() <-chan Result { + return e.results.Chan() +} + +func (e *GenericEngine) Start(ctx context.Context, r *Range) (<-chan interface{}, <-chan error) { + done := make(chan interface{}) + errc := make(chan error, 100) + requests, err := e.reqgen.GenerateRequests(ctx, r) + if err != nil { + errc <- err + close(errc) + close(done) + return done, errc + } + go func() { + defer close(done) + defer close(errc) + var wg sync.WaitGroup + for i := 1; i <= e.workerCount; i++ { + wg.Add(1) + go e.worker(ctx, &wg, requests, errc) + } + wg.Wait() + }() + return done, errc +} + +func (e *GenericEngine) worker(ctx context.Context, wg *sync.WaitGroup, + requests <-chan *Request, errc chan<- error) { + defer wg.Done() + for { + select { + case <-ctx.Done(): + return + case r, ok := <-requests: + if !ok { + return + } + if r.Err != nil { + writeError(ctx, errc, r.Err) + continue + } + result, err := e.scanner.Scan(ctx, r) + if err != nil { + writeError(ctx, errc, err) + continue + } + if result != nil { + e.results.Put(result) + } + } + } +} + +func writeError(ctx context.Context, out chan<- error, err error) { + select { + case <-ctx.Done(): + return + case out <- err: + } +} diff --git a/pkg/scan/engine_test.go b/pkg/scan/engine_test.go index ccb1373..c7d5934 100644 --- a/pkg/scan/engine_test.go +++ b/pkg/scan/engine_test.go @@ -6,6 +6,7 @@ import ( "context" "errors" "net" + "sort" "testing" "time" @@ -154,11 +155,7 @@ func TestPacketSourceReturnsError(t *testing.T) { result := <-out require.Error(t, result.Err) }() - select { - case <-done: - case <-time.After(3 * time.Second): - t.Fatal("test timeout") - } + waitDone(t, done) } func TestPacketSourceReturnsData(t *testing.T) { @@ -200,9 +197,153 @@ func TestPacketSourceReturnsData(t *testing.T) { require.NoError(t, result.Err) require.Equal(t, data.Buf, result.Buf) }() - select { - case <-done: - case <-time.After(3 * time.Second): - t.Fatal("test timeout") - } + waitDone(t, done) +} + +func TestScanEngineWithRequestGeneratorError(t *testing.T) { + t.Parallel() + + done := make(chan interface{}) + go func() { + defer close(done) + + ctrl := gomock.NewController(t) + reqgen := NewMockRequestGenerator(ctrl) + scanner := NewMockScanner(ctrl) + ctx := context.Background() + + reqgen.EXPECT().GenerateRequests(gomock.Not(gomock.Nil()), &Range{}). + Return(nil, errors.New("generate error")) + engine := NewScanEngine(reqgen, scanner, NewResultChan(ctx, 10)) + + _, errc := engine.Start(ctx, &Range{}) + err := <-errc + require.Error(t, err) + }() + waitDone(t, done) +} + +func TestScanEngineWithRequestError(t *testing.T) { + t.Parallel() + + done := make(chan interface{}) + go func() { + defer close(done) + + ctrl := gomock.NewController(t) + reqgen := NewMockRequestGenerator(ctrl) + scanner := NewMockScanner(ctrl) + ctx := context.Background() + + requests := make(chan *Request, 1) + requests <- &Request{Err: errors.New("request error")} + close(requests) + reqgen.EXPECT().GenerateRequests(gomock.Not(gomock.Nil()), &Range{}). + Return(requests, nil) + engine := NewScanEngine(reqgen, scanner, NewResultChan(ctx, 10)) + + _, errc := engine.Start(ctx, &Range{}) + err := <-errc + require.Error(t, err) + }() + waitDone(t, done) +} + +func TestScanEngineWithScannerError(t *testing.T) { + t.Parallel() + + done := make(chan interface{}) + go func() { + defer close(done) + + ctrl := gomock.NewController(t) + reqgen := NewMockRequestGenerator(ctrl) + scanner := NewMockScanner(ctrl) + ctx := context.Background() + + requests := make(chan *Request, 1) + req1 := &Request{DstIP: net.IPv4(192, 168, 0, 1), DstPort: 22} + requests <- req1 + close(requests) + reqgen.EXPECT().GenerateRequests(gomock.Not(gomock.Nil()), &Range{}). + Return(requests, nil) + scanner.EXPECT().Scan(gomock.Not(gomock.Nil()), req1).Return(nil, errors.New("scan error")) + engine := NewScanEngine(reqgen, scanner, NewResultChan(ctx, 10)) + + _, errc := engine.Start(ctx, &Range{}) + err := <-errc + require.Error(t, err) + }() + waitDone(t, done) +} + +func TestScanEngineWithResults(t *testing.T) { + t.Parallel() + + done := make(chan interface{}) + go func() { + defer close(done) + + ctrl := gomock.NewController(t) + reqgen := NewMockRequestGenerator(ctrl) + scanner := NewMockScanner(ctrl) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + requests := make(chan *Request, 2) + req1 := &Request{DstIP: net.IPv4(192, 168, 0, 1), DstPort: 22} + req2 := &Request{DstIP: net.IPv4(192, 168, 0, 2), DstPort: 22} + requests <- req1 + requests <- req2 + close(requests) + reqgen.EXPECT().GenerateRequests(gomock.Not(gomock.Nil()), &Range{}). + Return(requests, nil) + + scanner.EXPECT().Scan(gomock.Not(gomock.Nil()), req1). + Return(&mockScanResult{"id1"}, nil) + scanner.EXPECT().Scan(gomock.Not(gomock.Nil()), req2). + Return(&mockScanResult{"id2"}, nil) + + resultCh := NewResultChan(ctx, 10) + engine := NewScanEngine(reqgen, scanner, resultCh, WithScanWorkerCount(10)) + + done, errc := engine.Start(ctx, &Range{}) + <-done + cancel() + require.Zero(t, len(errc), "error channel is not empty") + var results []Result + cnt := 0 + for result := range resultCh.Chan() { + cnt++ + if cnt > 2 { + require.Fail(t, "result channel contains more elements than expected: ", result) + } + results = append(results, result) + } + + sort.Slice(results, func(i, j int) bool { + return results[i].ID() < results[j].ID() + }) + require.Equal(t, results, []Result{ + &mockScanResult{"id1"}, + &mockScanResult{"id2"}, + }) + }() + waitDone(t, done) +} + +type mockScanResult struct { + id string +} + +func (r *mockScanResult) ID() string { + return r.id +} + +func (r *mockScanResult) String() string { + return r.id +} + +func (r *mockScanResult) MarshalJSON() ([]byte, error) { + return []byte(r.id), nil } diff --git a/pkg/scan/mock_engine_test.go b/pkg/scan/mock_engine_test.go index 5180ab5..92c45ea 100644 --- a/pkg/scan/mock_engine_test.go +++ b/pkg/scan/mock_engine_test.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/v-byte-cpu/sx/pkg/scan (interfaces: PacketSource) +// Source: github.com/v-byte-cpu/sx/pkg/scan (interfaces: PacketSource,Scanner) // Package scan is a generated GoMock package. package scan @@ -48,3 +48,41 @@ func (mr *MockPacketSourceMockRecorder) Packets(arg0, arg1 interface{}) *gomock. mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Packets", reflect.TypeOf((*MockPacketSource)(nil).Packets), arg0, arg1) } + +// MockScanner is a mock of Scanner interface. +type MockScanner struct { + ctrl *gomock.Controller + recorder *MockScannerMockRecorder +} + +// MockScannerMockRecorder is the mock recorder for MockScanner. +type MockScannerMockRecorder struct { + mock *MockScanner +} + +// NewMockScanner creates a new mock instance. +func NewMockScanner(ctrl *gomock.Controller) *MockScanner { + mock := &MockScanner{ctrl: ctrl} + mock.recorder = &MockScannerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockScanner) EXPECT() *MockScannerMockRecorder { + return m.recorder +} + +// Scan mocks base method. +func (m *MockScanner) Scan(arg0 context.Context, arg1 *Request) (Result, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Scan", arg0, arg1) + ret0, _ := ret[0].(Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Scan indicates an expected call of Scan. +func (mr *MockScannerMockRecorder) Scan(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockScanner)(nil).Scan), arg0, arg1) +} diff --git a/pkg/scan/mock_generator_test.go b/pkg/scan/mock_generator_test.go index c9484cf..4092d14 100644 --- a/pkg/scan/mock_generator_test.go +++ b/pkg/scan/mock_generator_test.go @@ -37,17 +37,17 @@ func (m *MockPacketFiller) EXPECT() *MockPacketFillerMockRecorder { } // Fill mocks base method. -func (m *MockPacketFiller) Fill(packet gopacket.SerializeBuffer, pair *Request) error { +func (m *MockPacketFiller) Fill(packet gopacket.SerializeBuffer, r *Request) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Fill", packet, pair) + ret := m.ctrl.Call(m, "Fill", packet, r) ret0, _ := ret[0].(error) return ret0 } // Fill indicates an expected call of Fill. -func (mr *MockPacketFillerMockRecorder) Fill(packet, pair interface{}) *gomock.Call { +func (mr *MockPacketFillerMockRecorder) Fill(packet, r interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Fill", reflect.TypeOf((*MockPacketFiller)(nil).Fill), packet, pair) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Fill", reflect.TypeOf((*MockPacketFiller)(nil).Fill), packet, r) } // MockPacketGenerator is a mock of PacketGenerator interface. diff --git a/pkg/scan/mock_request_test.go b/pkg/scan/mock_request_test.go index f7b57e5..8d72a95 100644 --- a/pkg/scan/mock_request_test.go +++ b/pkg/scan/mock_request_test.go @@ -50,6 +50,44 @@ func (mr *MockPortGeneratorMockRecorder) Ports(ctx, r interface{}) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ports", reflect.TypeOf((*MockPortGenerator)(nil).Ports), ctx, r) } +// MockIPGetter is a mock of IPGetter interface. +type MockIPGetter struct { + ctrl *gomock.Controller + recorder *MockIPGetterMockRecorder +} + +// MockIPGetterMockRecorder is the mock recorder for MockIPGetter. +type MockIPGetterMockRecorder struct { + mock *MockIPGetter +} + +// NewMockIPGetter creates a new mock instance. +func NewMockIPGetter(ctrl *gomock.Controller) *MockIPGetter { + mock := &MockIPGetter{ctrl: ctrl} + mock.recorder = &MockIPGetterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockIPGetter) EXPECT() *MockIPGetterMockRecorder { + return m.recorder +} + +// GetIP mocks base method. +func (m *MockIPGetter) GetIP() (net.IP, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetIP") + ret0, _ := ret[0].(net.IP) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetIP indicates an expected call of GetIP. +func (mr *MockIPGetterMockRecorder) GetIP() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIP", reflect.TypeOf((*MockIPGetter)(nil).GetIP)) +} + // MockIPGenerator is a mock of IPGenerator interface. type MockIPGenerator struct { ctrl *gomock.Controller @@ -74,10 +112,10 @@ func (m *MockIPGenerator) EXPECT() *MockIPGeneratorMockRecorder { } // IPs mocks base method. -func (m *MockIPGenerator) IPs(ctx context.Context, r *Range) (<-chan net.IP, error) { +func (m *MockIPGenerator) IPs(ctx context.Context, r *Range) (<-chan IPGetter, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "IPs", ctx, r) - ret0, _ := ret[0].(<-chan net.IP) + ret0, _ := ret[0].(<-chan IPGetter) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/pkg/scan/request.go b/pkg/scan/request.go index b662731..543805f 100644 --- a/pkg/scan/request.go +++ b/pkg/scan/request.go @@ -1,18 +1,26 @@ //go:generate mockgen -package scan -destination=mock_request_test.go -source request.go +//go:generate easyjson -output_filename request_easyjson.go request.go package scan import ( + "bufio" "context" "errors" + "io" "net" "time" "github.com/v-byte-cpu/sx/pkg/ip" ) -var ErrPortRange = errors.New("invalid port range") -var ErrSubnet = errors.New("invalid subnet") +var ( + ErrPortRange = errors.New("invalid port range") + ErrSubnet = errors.New("invalid subnet") + ErrIP = errors.New("invalid ip") + ErrPort = errors.New("invalid port") + ErrJSON = errors.New("invalid json") +) type Request struct { Meta map[string]interface{} @@ -67,8 +75,18 @@ func validatePorts(ports []*PortRange) error { return nil } +type IPGetter interface { + GetIP() (net.IP, error) +} + +type wrapIP net.IP + +func (i wrapIP) GetIP() (net.IP, error) { + return net.IP(i), nil +} + type IPGenerator interface { - IPs(ctx context.Context, r *Range) (<-chan net.IP, error) + IPs(ctx context.Context, r *Range) (<-chan IPGetter, error) } func NewIPGenerator() IPGenerator { @@ -77,20 +95,16 @@ func NewIPGenerator() IPGenerator { type ipGenerator struct{} -func (*ipGenerator) IPs(ctx context.Context, r *Range) (<-chan net.IP, error) { +func (*ipGenerator) IPs(ctx context.Context, r *Range) (<-chan IPGetter, error) { if r.DstSubnet == nil { return nil, ErrSubnet } - out := make(chan net.IP) + out := make(chan IPGetter) go func() { defer close(out) ipnet := r.DstSubnet for ipaddr := ipnet.IP.Mask(ipnet.Mask); ipnet.Contains(ipaddr); ip.Inc(ipaddr) { - select { - case <-ctx.Done(): - return - case out <- ip.DupIP(ipaddr): - } + writeIP(ctx, out, wrapIP(ip.DupIP(ipaddr))) } }() return out, nil @@ -100,16 +114,16 @@ type RequestGenerator interface { GenerateRequests(ctx context.Context, r *Range) (<-chan *Request, error) } -func NewIPPortRequestGenerator(ipgen IPGenerator, portgen PortGenerator) RequestGenerator { - return &ipPortRequestGenerator{ipgen, portgen} +func NewIPPortGenerator(ipgen IPGenerator, portgen PortGenerator) RequestGenerator { + return &ipPortGenerator{ipgen, portgen} } -type ipPortRequestGenerator struct { +type ipPortGenerator struct { ipgen IPGenerator portgen PortGenerator } -func (rg *ipPortRequestGenerator) GenerateRequests(ctx context.Context, r *Range) (<-chan *Request, error) { +func (rg *ipPortGenerator) GenerateRequests(ctx context.Context, r *Range) (<-chan *Request, error) { ports, err := rg.portgen.Ports(ctx, r) if err != nil { return nil, err @@ -123,9 +137,10 @@ func (rg *ipPortRequestGenerator) GenerateRequests(ctx context.Context, r *Range defer close(out) for port := range ports { for ipaddr := range ips { + dstip, err := ipaddr.GetIP() writeRequest(ctx, out, &Request{ SrcIP: r.SrcIP, SrcMAC: r.SrcMAC, - DstIP: ipaddr, DstPort: port}) + DstIP: dstip, DstPort: port, Err: err}) } if ips, err = rg.ipgen.IPs(ctx, r); err != nil { writeRequest(ctx, out, &Request{Err: err}) @@ -161,31 +176,141 @@ func (rg *ipRequestGenerator) GenerateRequests(ctx context.Context, r *Range) (< go func() { defer close(out) for ipaddr := range ips { + dstip, err := ipaddr.GetIP() writeRequest(ctx, out, &Request{ - SrcIP: r.SrcIP, SrcMAC: r.SrcMAC, DstIP: ipaddr, + SrcIP: r.SrcIP, SrcMAC: r.SrcMAC, DstIP: dstip, + Err: err, }) } }() return out, nil } -type LiveRequestGenerator struct { +//easyjson:json +type IPPort struct { + IP string `json:"ip"` + Port int `json:"port"` +} + +type fileIPPortGenerator struct { + openFile OpenFileFunc +} + +type OpenFileFunc func() (io.ReadCloser, error) + +func NewFileIPPortGenerator(openFile OpenFileFunc) RequestGenerator { + return &fileIPPortGenerator{openFile} +} + +func (rg *fileIPPortGenerator) GenerateRequests(ctx context.Context, _ *Range) (<-chan *Request, error) { + input, err := rg.openFile() + if err != nil { + return nil, err + } + out := make(chan *Request) + go func() { + defer close(out) + defer input.Close() + scanner := bufio.NewScanner(input) + var entry IPPort + for scanner.Scan() { + if err := entry.UnmarshalJSON(scanner.Bytes()); err != nil { + writeRequest(ctx, out, &Request{Err: ErrJSON}) + return + } + ip := net.ParseIP(entry.IP) + if ip == nil { + writeRequest(ctx, out, &Request{Err: ErrIP}) + return + } + if !isValidPort(entry.Port) { + writeRequest(ctx, out, &Request{Err: ErrPort}) + return + } + writeRequest(ctx, out, &Request{DstIP: ip, DstPort: uint16(entry.Port)}) + } + if err = scanner.Err(); err != nil { + writeRequest(ctx, out, &Request{Err: err}) + } + }() + return out, nil +} + +func isValidPort(port int) bool { + return port > 0 && port <= 0xFFFF +} + +type ipError struct { + error +} + +func (err *ipError) GetIP() (net.IP, error) { + return nil, err +} + +type fileIPGenerator struct { + openFile OpenFileFunc +} + +func NewFileIPGenerator(openFile OpenFileFunc) IPGenerator { + return &fileIPGenerator{openFile} +} + +func (g *fileIPGenerator) IPs(ctx context.Context, _ *Range) (<-chan IPGetter, error) { + input, err := g.openFile() + if err != nil { + return nil, err + } + out := make(chan IPGetter) + go func() { + defer close(out) + defer input.Close() + scanner := bufio.NewScanner(input) + var entry IPPort + for scanner.Scan() { + if err := entry.UnmarshalJSON(scanner.Bytes()); err != nil { + writeIP(ctx, out, &ipError{error: ErrJSON}) + return + } + ip := net.ParseIP(entry.IP) + if ip == nil { + writeIP(ctx, out, &ipError{error: ErrIP}) + return + } + writeIP(ctx, out, wrapIP(ip)) + } + if err = scanner.Err(); err != nil { + writeIP(ctx, out, &ipError{error: err}) + } + }() + return out, nil +} + +func writeIP(ctx context.Context, out chan<- IPGetter, ip IPGetter) { + select { + case <-ctx.Done(): + return + case out <- ip: + } +} + +type liveRequestGenerator struct { delegate RequestGenerator rescanTimeout time.Duration } func NewLiveRequestGenerator(rg RequestGenerator, rescanTimeout time.Duration) RequestGenerator { - return &LiveRequestGenerator{rg, rescanTimeout} + return &liveRequestGenerator{rg, rescanTimeout} } -func (rg *LiveRequestGenerator) GenerateRequests(ctx context.Context, r *Range) (<-chan *Request, error) { +func (rg *liveRequestGenerator) GenerateRequests(ctx context.Context, r *Range) (<-chan *Request, error) { requests, err := rg.delegate.GenerateRequests(ctx, r) if err != nil { return nil, err } - result := make(chan *Request, cap(requests)) + out := make(chan *Request, cap(requests)) go func() { - defer close(result) + defer close(out) var request *Request var ok bool for { @@ -195,11 +320,7 @@ func (rg *LiveRequestGenerator) GenerateRequests(ctx context.Context, r *Range) case request, ok = <-requests: } if ok { - select { - case <-ctx.Done(): - return - case result <- request: - } + writeRequest(ctx, out, request) continue } @@ -211,5 +332,5 @@ func (rg *LiveRequestGenerator) GenerateRequests(ctx context.Context, r *Range) } } }() - return result, nil + return out, nil } diff --git a/pkg/scan/request_easyjson.go b/pkg/scan/request_easyjson.go new file mode 100644 index 0000000..0b2f4a2 --- /dev/null +++ b/pkg/scan/request_easyjson.go @@ -0,0 +1,92 @@ +// Code generated by easyjson for marshaling/unmarshaling. DO NOT EDIT. + +package scan + +import ( + json "encoding/json" + easyjson "github.com/mailru/easyjson" + jlexer "github.com/mailru/easyjson/jlexer" + jwriter "github.com/mailru/easyjson/jwriter" +) + +// suppress unused package warning +var ( + _ *json.RawMessage + _ *jlexer.Lexer + _ *jwriter.Writer + _ easyjson.Marshaler +) + +func easyjson3c9d2b01DecodeGithubComVByteCpuSxPkgScan(in *jlexer.Lexer, out *IPPort) { + isTopLevel := in.IsStart() + if in.IsNull() { + if isTopLevel { + in.Consumed() + } + in.Skip() + return + } + in.Delim('{') + for !in.IsDelim('}') { + key := in.UnsafeFieldName(false) + in.WantColon() + if in.IsNull() { + in.Skip() + in.WantComma() + continue + } + switch key { + case "ip": + out.IP = string(in.String()) + case "port": + out.Port = int(in.Int()) + default: + in.SkipRecursive() + } + in.WantComma() + } + in.Delim('}') + if isTopLevel { + in.Consumed() + } +} +func easyjson3c9d2b01EncodeGithubComVByteCpuSxPkgScan(out *jwriter.Writer, in IPPort) { + out.RawByte('{') + first := true + _ = first + { + const prefix string = ",\"ip\":" + out.RawString(prefix[1:]) + out.String(string(in.IP)) + } + { + const prefix string = ",\"port\":" + out.RawString(prefix) + out.Int(int(in.Port)) + } + out.RawByte('}') +} + +// MarshalJSON supports json.Marshaler interface +func (v IPPort) MarshalJSON() ([]byte, error) { + w := jwriter.Writer{} + easyjson3c9d2b01EncodeGithubComVByteCpuSxPkgScan(&w, v) + return w.Buffer.BuildBytes(), w.Error +} + +// MarshalEasyJSON supports easyjson.Marshaler interface +func (v IPPort) MarshalEasyJSON(w *jwriter.Writer) { + easyjson3c9d2b01EncodeGithubComVByteCpuSxPkgScan(w, v) +} + +// UnmarshalJSON supports json.Unmarshaler interface +func (v *IPPort) UnmarshalJSON(data []byte) error { + r := jlexer.Lexer{Data: data} + easyjson3c9d2b01DecodeGithubComVByteCpuSxPkgScan(&r, v) + return r.Error() +} + +// UnmarshalEasyJSON supports easyjson.Unmarshaler interface +func (v *IPPort) UnmarshalEasyJSON(l *jlexer.Lexer) { + easyjson3c9d2b01DecodeGithubComVByteCpuSxPkgScan(l, v) +} diff --git a/pkg/scan/request_test.go b/pkg/scan/request_test.go index 8a9b60c..5ed2ffd 100644 --- a/pkg/scan/request_test.go +++ b/pkg/scan/request_test.go @@ -2,11 +2,14 @@ package scan import ( "context" + "errors" + "io" + "io/ioutil" "net" + "strings" "testing" "time" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -70,16 +73,30 @@ func withDstPort(dstPort uint16) scanRequestOption { } } -func TestPortGeneratorWithInvalidInput(t *testing.T) { +func chanPortToGeneric(in <-chan uint16) <-chan interface{} { + out := make(chan interface{}, cap(in)) + go func() { + defer close(out) + for i := range in { + out <- i + } + }() + return out +} + +func TestPortGenerator(t *testing.T) { t.Parallel() tests := []struct { name string scanRange *Range + expected []interface{} + err bool }{ { name: "NilPorts", scanRange: newScanRange(withPorts(nil)), + err: true, }, { name: "InvalidPortRange", @@ -89,6 +106,7 @@ func TestPortGeneratorWithInvalidInput(t *testing.T) { EndPort: 2000, }, })), + err: true, }, { name: "InvalidPortRangeAfterValid", @@ -102,37 +120,8 @@ func TestPortGeneratorWithInvalidInput(t *testing.T) { EndPort: 5000, }, })), + err: true, }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - portgen := NewPortGenerator() - _, err := portgen.Ports(context.Background(), tt.scanRange) - require.Error(t, err) - }) - } -} - -func chanPortToGeneric(in <-chan uint16) <-chan interface{} { - out := make(chan interface{}, cap(in)) - go func() { - defer close(out) - for i := range in { - out <- i - } - }() - return out -} - -func TestPortGenerator(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - scanRange *Range - expected []interface{} - }{ { name: "OnePort", scanRange: newScanRange(withPorts([]*PortRange{ @@ -213,42 +202,20 @@ func TestPortGenerator(t *testing.T) { defer close(done) portgen := NewPortGenerator() ports, err := portgen.Ports(context.Background(), tt.scanRange) + if tt.err { + require.Error(t, err) + return + } require.NoError(t, err) result := chanToSlice(t, chanPortToGeneric(ports), len(tt.expected)) require.Equal(t, tt.expected, result) }() - select { - case <-done: - case <-time.After(waitTimeout): - require.Fail(t, "test timeout") - } + waitDone(t, done) }) } } -func TestIPGeneratorWithInvalidInput(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - scanRange *Range - }{ - { - name: "NilSubnet", - scanRange: newScanRange(withSubnet(nil)), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ipgen := NewIPGenerator() - _, err := ipgen.IPs(context.Background(), tt.scanRange) - require.Error(t, err) - }) - } -} - -func chanIPToGeneric(in <-chan net.IP) <-chan interface{} { +func chanIPToGeneric(in <-chan IPGetter) <-chan interface{} { out := make(chan interface{}, cap(in)) go func() { defer close(out) @@ -266,13 +233,21 @@ func TestIPGenerator(t *testing.T) { name string scanRange *Range expected []interface{} + err bool }{ + { + name: "NilSubnet", + scanRange: newScanRange(withSubnet(nil)), + err: true, + }, { name: "OneIP", scanRange: newScanRange( withSubnet(&net.IPNet{IP: net.IPv4(192, 168, 0, 1), Mask: net.CIDRMask(32, 32)}), ), - expected: []interface{}{net.IPv4(192, 168, 0, 1).To4()}, + expected: []interface{}{ + wrapIP(net.IPv4(192, 168, 0, 1).To4()), + }, }, { name: "TwoIPs", @@ -280,8 +255,8 @@ func TestIPGenerator(t *testing.T) { withSubnet(&net.IPNet{IP: net.IPv4(1, 0, 0, 1), Mask: net.CIDRMask(31, 32)}), ), expected: []interface{}{ - net.IPv4(1, 0, 0, 0).To4(), - net.IPv4(1, 0, 0, 1).To4(), + wrapIP(net.IPv4(1, 0, 0, 0).To4()), + wrapIP(net.IPv4(1, 0, 0, 1).To4()), }, }, { @@ -290,10 +265,10 @@ func TestIPGenerator(t *testing.T) { withSubnet(&net.IPNet{IP: net.IPv4(10, 0, 0, 1), Mask: net.CIDRMask(30, 32)}), ), expected: []interface{}{ - net.IPv4(10, 0, 0, 0).To4(), - net.IPv4(10, 0, 0, 1).To4(), - net.IPv4(10, 0, 0, 2).To4(), - net.IPv4(10, 0, 0, 3).To4(), + wrapIP(net.IPv4(10, 0, 0, 0).To4()), + wrapIP(net.IPv4(10, 0, 0, 1).To4()), + wrapIP(net.IPv4(10, 0, 0, 2).To4()), + wrapIP(net.IPv4(10, 0, 0, 3).To4()), }, }, } @@ -308,50 +283,15 @@ func TestIPGenerator(t *testing.T) { defer close(done) ipgen := NewIPGenerator() ips, err := ipgen.IPs(context.Background(), tt.scanRange) + if tt.err { + require.Error(t, err) + return + } require.NoError(t, err) result := chanToSlice(t, chanIPToGeneric(ips), len(tt.expected)) require.Equal(t, tt.expected, result) }() - select { - case <-done: - case <-time.After(waitTimeout): - require.Fail(t, "test timeout") - } - }) - } -} - -func TestIPPortRequestGeneratorWithInvalidInput(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - startPort uint16 - endPort uint16 - subnets []net.IPNet - scanRange *Range - }{ - { - name: "InvalidPortRange", - scanRange: newScanRange( - withPorts([]*PortRange{ - { - StartPort: 5000, - EndPort: 2000, - }, - }), - ), - }, - { - name: "NilSubnet", - scanRange: newScanRange(withSubnet(nil)), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - reqgen := NewIPPortRequestGenerator(NewIPGenerator(), NewPortGenerator()) - _, err := reqgen.GenerateRequests(context.Background(), tt.scanRange) - assert.Error(t, err) + waitDone(t, done) }) } } @@ -367,14 +307,32 @@ func chanPairToGeneric(in <-chan *Request) <-chan interface{} { return out } -func TestIPPortRequestRegenerator(t *testing.T) { +func TestIPPortGenerator(t *testing.T) { t.Parallel() tests := []struct { name string input *Range expected []interface{} + err bool }{ + { + name: "InvalidPortRange", + input: newScanRange( + withPorts([]*PortRange{ + { + StartPort: 5000, + EndPort: 2000, + }, + }), + ), + err: true, + }, + { + name: "NilSubnet", + input: newScanRange(withSubnet(nil)), + err: true, + }, { name: "OneIpOnePort", input: newScanRange( @@ -484,37 +442,35 @@ func TestIPPortRequestRegenerator(t *testing.T) { go func() { defer close(done) - reqgen := NewIPPortRequestGenerator(NewIPGenerator(), NewPortGenerator()) + reqgen := NewIPPortGenerator(NewIPGenerator(), NewPortGenerator()) pairs, err := reqgen.GenerateRequests(context.Background(), tt.input) + if tt.err { + require.Error(t, err) + return + } require.NoError(t, err) result := chanToSlice(t, chanPairToGeneric(pairs), len(tt.expected)) require.Equal(t, tt.expected, result) }() - select { - case <-done: - case <-time.After(waitTimeout): - require.Fail(t, "test timeout") - } + waitDone(t, done) }) } } -func TestIPRequestGeneratorWithInvalidInput(t *testing.T) { - t.Parallel() - - reqgen := NewIPRequestGenerator(NewIPGenerator()) - _, err := reqgen.GenerateRequests(context.Background(), newScanRange(withSubnet(nil))) - assert.Error(t, err) -} - -func TestIPRequestRegenerator(t *testing.T) { +func TestIPRequestGenerator(t *testing.T) { t.Parallel() tests := []struct { name string input *Range expected []interface{} + err bool }{ + { + name: "NilSubnet", + input: newScanRange(withSubnet(nil)), + err: true, + }, { name: "OneIP", input: newScanRange( @@ -559,15 +515,228 @@ func TestIPRequestRegenerator(t *testing.T) { reqgen := NewIPRequestGenerator(NewIPGenerator()) pairs, err := reqgen.GenerateRequests(context.Background(), tt.input) + if tt.err { + require.Error(t, err) + return + } require.NoError(t, err) result := chanToSlice(t, chanPairToGeneric(pairs), len(tt.expected)) require.Equal(t, tt.expected, result) }() - select { - case <-done: - case <-time.After(waitTimeout): - require.Fail(t, "test timeout") - } + waitDone(t, done) + }) + } +} + +func TestFileIPPortGeneratorWithInvalidFile(t *testing.T) { + t.Parallel() + + reqgen := NewFileIPPortGenerator(func() (io.ReadCloser, error) { + return nil, errors.New("open file error") + }) + _, err := reqgen.GenerateRequests(context.Background(), &Range{}) + require.Error(t, err) +} + +func TestFileIPPortGenerator(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected []interface{} + }{ + { + name: "OneIPPort", + input: `{"ip":"192.168.0.1","port":888}`, + expected: []interface{}{ + &Request{DstIP: net.IPv4(192, 168, 0, 1), DstPort: 888}, + }, + }, + { + name: "OneIPPortWithUnknownField", + input: `{"ip":"192.168.0.1","port":888,"abc":"field"}`, + expected: []interface{}{ + &Request{DstIP: net.IPv4(192, 168, 0, 1), DstPort: 888}, + }, + }, + { + name: "TwoIPPorts", + input: strings.Join([]string{ + `{"ip":"192.168.0.1","port":888}`, + `{"ip":"192.168.0.2","port":222}`, + }, "\n"), + expected: []interface{}{ + &Request{DstIP: net.IPv4(192, 168, 0, 1), DstPort: 888}, + &Request{DstIP: net.IPv4(192, 168, 0, 2), DstPort: 222}, + }, + }, + { + name: "InvalidJSON", + input: `{"ip":"192`, + expected: []interface{}{ + &Request{Err: ErrJSON}, + }, + }, + { + name: "InvalidJSONAfterValid", + input: strings.Join([]string{ + `{"ip":"192.168.0.1","port":888}`, + `{"ip":"192`, + }, "\n"), + expected: []interface{}{ + &Request{DstIP: net.IPv4(192, 168, 0, 1), DstPort: 888}, + &Request{Err: ErrJSON}, + }, + }, + { + name: "ValidJSONAfterInvalid", + input: strings.Join([]string{ + `{"ip":"192.168.0.1","port":888}`, + `{"ip":"192`, + `{"ip":"192.168.0.3","port":888}`, + }, "\n"), + expected: []interface{}{ + &Request{DstIP: net.IPv4(192, 168, 0, 1), DstPort: 888}, + &Request{Err: ErrJSON}, + }, + }, + { + name: "InvalidIP", + input: `{"ip":"192.168.0.1111","port":888}`, + expected: []interface{}{ + &Request{Err: ErrIP}, + }, + }, + { + name: "InvalidPort", + input: `{"ip":"192.168.0.1","port":88888}`, + expected: []interface{}{ + &Request{Err: ErrPort}, + }, + }, + } + for _, vtt := range tests { + tt := vtt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + done := make(chan interface{}) + go func() { + defer close(done) + + reqgen := NewFileIPPortGenerator(func() (io.ReadCloser, error) { + return ioutil.NopCloser(strings.NewReader(tt.input)), nil + }) + pairs, err := reqgen.GenerateRequests(context.Background(), &Range{}) + require.NoError(t, err) + result := chanToSlice(t, chanPairToGeneric(pairs), len(tt.expected)) + require.Equal(t, tt.expected, result) + }() + waitDone(t, done) + }) + } +} + +func TestFileIPGeneratorWithInvalidFile(t *testing.T) { + t.Parallel() + + ipgen := NewFileIPGenerator(func() (io.ReadCloser, error) { + return nil, errors.New("open file error") + }) + _, err := ipgen.IPs(context.Background(), &Range{}) + require.Error(t, err) +} + +func TestFileIPGenerator(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected []interface{} + }{ + { + name: "OneIP", + input: `{"ip":"192.168.0.1"}`, + expected: []interface{}{ + wrapIP(net.IPv4(192, 168, 0, 1)), + }, + }, + { + name: "OneIPWithUnknownField", + input: `{"ip":"192.168.0.1","abc":"field"}`, + expected: []interface{}{ + wrapIP(net.IPv4(192, 168, 0, 1)), + }, + }, + { + name: "TwoIPs", + input: strings.Join([]string{ + `{"ip":"192.168.0.1"}`, + `{"ip":"192.168.0.2"}`, + }, "\n"), + expected: []interface{}{ + wrapIP(net.IPv4(192, 168, 0, 1)), + wrapIP(net.IPv4(192, 168, 0, 2)), + }, + }, + { + name: "InvalidJSON", + input: `{"ip":"192`, + expected: []interface{}{ + &ipError{error: ErrJSON}, + }, + }, + { + name: "InvalidJSONAfterValid", + input: strings.Join([]string{ + `{"ip":"192.168.0.1","port":888}`, + `{"ip":"192`, + }, "\n"), + expected: []interface{}{ + wrapIP(net.IPv4(192, 168, 0, 1)), + &ipError{error: ErrJSON}, + }, + }, + { + name: "ValidJSONAfterInvalid", + input: strings.Join([]string{ + `{"ip":"192.168.0.1","port":888}`, + `{"ip":"192`, + `{"ip":"192.168.0.3","port":888}`, + }, "\n"), + expected: []interface{}{ + wrapIP(net.IPv4(192, 168, 0, 1)), + &ipError{error: ErrJSON}, + }, + }, + { + name: "InvalidIP", + input: `{"ip":"192.168.0.1111"}`, + expected: []interface{}{ + &ipError{error: ErrIP}, + }, + }, + } + for _, vtt := range tests { + tt := vtt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + done := make(chan interface{}) + go func() { + defer close(done) + + ipgen := NewFileIPGenerator(func() (io.ReadCloser, error) { + return ioutil.NopCloser(strings.NewReader(tt.input)), nil + }) + ips, err := ipgen.IPs(context.Background(), &Range{}) + require.NoError(t, err) + result := chanToSlice(t, chanIPToGeneric(ips), len(tt.expected)) + require.Equal(t, tt.expected, result) + }() + waitDone(t, done) }) } } @@ -575,7 +744,7 @@ func TestIPRequestRegenerator(t *testing.T) { func TestLiveRequestGeneratorContextExit(t *testing.T) { t.Parallel() - reqgen := NewIPPortRequestGenerator(NewIPGenerator(), NewPortGenerator()) + reqgen := NewIPPortGenerator(NewIPGenerator(), NewPortGenerator()) rg := NewLiveRequestGenerator(reqgen, 5*time.Second) ctx, cancel := context.WithCancel(context.Background()) cancel() diff --git a/pkg/scan/socks5/message.go b/pkg/scan/socks5/message.go new file mode 100644 index 0000000..85b44ef --- /dev/null +++ b/pkg/scan/socks5/message.go @@ -0,0 +1,63 @@ +package socks5 + +import ( + "encoding/binary" + "io" +) + +const MethodNoAuth = 0 + +// MethodRequest is a negotiation request for the authentication method to be used. +// It is the initial message that the client sends to the SOCKS5 server. +// From RFC1928: +// +----+----------+----------+ +// |VER | NMETHODS | METHODS | +// +----+----------+----------+ +// | 1 | 1 | 1 to 255 | +// +----+----------+----------+ +type MethodRequest struct { + Ver byte // version of the protocol + NMethods byte // number of method identifier octets that appear in the METHODS field. + Methods []byte +} + +func NewMethodRequest(version byte, methods ...byte) *MethodRequest { + return &MethodRequest{ + Ver: version, + NMethods: byte(len(methods)), + Methods: methods, + } +} + +func (r *MethodRequest) Len() int64 { + return 2 + int64(r.NMethods) +} + +func (r *MethodRequest) WriteTo(w io.Writer) (int64, error) { + buf := make([]byte, 0, r.Len()) + buf = append(buf, r.Ver) + buf = append(buf, r.NMethods) + buf = append(buf, r.Methods...) + n, err := w.Write(buf) + return int64(n), err +} + +// MethodReply is a negotiation reply for the authentication method to be used. +// From RFC1928: +// +----+--------+ +// |VER | METHOD | +// +----+--------+ +// | 1 | 1 | +// +----+--------+ +type MethodReply struct { + Ver byte // version of the protocol + Method byte // server selects from one of the methods given in the request METHODS field. +} + +func (*MethodReply) Len() int64 { + return 2 +} + +func (r *MethodReply) ReadFrom(in io.Reader) (int64, error) { + return r.Len(), binary.Read(in, binary.BigEndian, r) +} diff --git a/pkg/scan/socks5/message_test.go b/pkg/scan/socks5/message_test.go new file mode 100644 index 0000000..d71be94 --- /dev/null +++ b/pkg/scan/socks5/message_test.go @@ -0,0 +1,69 @@ +package socks5 + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWriteMethodRequest(t *testing.T) { + tests := []struct { + name string + request *MethodRequest + expected []byte + }{ + { + name: "oneMethod", + request: NewMethodRequest(5, 0), + expected: []byte{5, 1, 0}, + }, + { + name: "twoMethods", + request: NewMethodRequest(5, 0, 2), + expected: []byte{5, 2, 0, 2}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + _, err := tt.request.WriteTo(&buf) + require.NoError(t, err) + require.Equal(t, tt.expected, buf.Bytes()) + }) + } +} + +func TestReadMethodReply(t *testing.T) { + tests := []struct { + name string + reply []byte + expected *MethodReply + }{ + { + name: "SOCKS5version", + reply: []byte{SOCKSVersion, 0}, + expected: &MethodReply{Ver: SOCKSVersion, Method: 0}, + }, + { + name: "SOCKS4version", + reply: []byte{4, 91}, + expected: &MethodReply{Ver: 4, Method: 91}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + _, err := buf.Write(tt.reply) + require.NoError(t, err) + + reply := &MethodReply{} + _, err = reply.ReadFrom(&buf) + require.NoError(t, err) + + require.Equal(t, tt.expected, reply) + }) + } +} diff --git a/pkg/scan/socks5/result_easyjson.go b/pkg/scan/socks5/result_easyjson.go new file mode 100644 index 0000000..64954c4 --- /dev/null +++ b/pkg/scan/socks5/result_easyjson.go @@ -0,0 +1,113 @@ +// Code generated by easyjson for marshaling/unmarshaling. DO NOT EDIT. + +package socks5 + +import ( + json "encoding/json" + easyjson "github.com/mailru/easyjson" + jlexer "github.com/mailru/easyjson/jlexer" + jwriter "github.com/mailru/easyjson/jwriter" +) + +// suppress unused package warning +var ( + _ *json.RawMessage + _ *jlexer.Lexer + _ *jwriter.Writer + _ easyjson.Marshaler +) + +func easyjsonD3b49167DecodeGithubComVByteCpuSxPkgScanSocks5(in *jlexer.Lexer, out *ScanResult) { + isTopLevel := in.IsStart() + if in.IsNull() { + if isTopLevel { + in.Consumed() + } + in.Skip() + return + } + in.Delim('{') + for !in.IsDelim('}') { + key := in.UnsafeFieldName(false) + in.WantColon() + if in.IsNull() { + in.Skip() + in.WantComma() + continue + } + switch key { + case "scan": + out.ScanType = string(in.String()) + case "version": + out.Version = int(in.Int()) + case "ip": + out.IP = string(in.String()) + case "port": + out.Port = uint16(in.Uint16()) + case "auth": + out.Auth = bool(in.Bool()) + default: + in.SkipRecursive() + } + in.WantComma() + } + in.Delim('}') + if isTopLevel { + in.Consumed() + } +} +func easyjsonD3b49167EncodeGithubComVByteCpuSxPkgScanSocks5(out *jwriter.Writer, in ScanResult) { + out.RawByte('{') + first := true + _ = first + { + const prefix string = ",\"scan\":" + out.RawString(prefix[1:]) + out.String(string(in.ScanType)) + } + { + const prefix string = ",\"version\":" + out.RawString(prefix) + out.Int(int(in.Version)) + } + { + const prefix string = ",\"ip\":" + out.RawString(prefix) + out.String(string(in.IP)) + } + { + const prefix string = ",\"port\":" + out.RawString(prefix) + out.Uint16(uint16(in.Port)) + } + if in.Auth { + const prefix string = ",\"auth\":" + out.RawString(prefix) + out.Bool(bool(in.Auth)) + } + out.RawByte('}') +} + +// MarshalJSON supports json.Marshaler interface +func (v ScanResult) MarshalJSON() ([]byte, error) { + w := jwriter.Writer{} + easyjsonD3b49167EncodeGithubComVByteCpuSxPkgScanSocks5(&w, v) + return w.Buffer.BuildBytes(), w.Error +} + +// MarshalEasyJSON supports easyjson.Marshaler interface +func (v ScanResult) MarshalEasyJSON(w *jwriter.Writer) { + easyjsonD3b49167EncodeGithubComVByteCpuSxPkgScanSocks5(w, v) +} + +// UnmarshalJSON supports json.Unmarshaler interface +func (v *ScanResult) UnmarshalJSON(data []byte) error { + r := jlexer.Lexer{Data: data} + easyjsonD3b49167DecodeGithubComVByteCpuSxPkgScanSocks5(&r, v) + return r.Error() +} + +// UnmarshalEasyJSON supports easyjson.Unmarshaler interface +func (v *ScanResult) UnmarshalEasyJSON(l *jlexer.Lexer) { + easyjsonD3b49167DecodeGithubComVByteCpuSxPkgScanSocks5(l, v) +} diff --git a/pkg/scan/socks5/socks5.go b/pkg/scan/socks5/socks5.go new file mode 100644 index 0000000..09c372b --- /dev/null +++ b/pkg/scan/socks5/socks5.go @@ -0,0 +1,137 @@ +//go:generate easyjson -output_filename result_easyjson.go socks5.go + +package socks5 + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/v-byte-cpu/sx/pkg/scan" +) + +const ( + ScanType = "socks" + SOCKSVersion = 5 + + defaultDialTimeout = 2 * time.Second + defaultDataTimeout = 2 * time.Second +) + +//easyjson:json +type ScanResult struct { + ScanType string `json:"scan"` + Version int `json:"version"` + IP string `json:"ip"` + Port uint16 `json:"port"` + Auth bool `json:"auth,omitempty"` +} + +func (r *ScanResult) String() string { + return fmt.Sprintf("%-20s %-5d", r.IP, r.Port) +} + +func (r *ScanResult) ID() string { + return fmt.Sprintf("%s:%d", r.IP, r.Port) +} + +type Scanner struct { + dataTimeout time.Duration + dialer *net.Dialer +} + +// Assert that socks5.Scanner conforms to the scan.Scanner interface +var _ scan.Scanner = (*Scanner)(nil) + +type SocksOption func(*Scanner) + +func WithDialTimeout(timeout time.Duration) SocksOption { + return func(s *Scanner) { + s.dialer.Timeout = timeout + } +} + +func WithDataTimeout(timeout time.Duration) SocksOption { + return func(s *Scanner) { + s.dataTimeout = timeout + } +} + +func NewScanner(opts ...SocksOption) *Scanner { + s := &Scanner{ + dialer: &net.Dialer{ + Timeout: defaultDialTimeout, + }, + dataTimeout: defaultDataTimeout, + } + for _, o := range opts { + o(s) + } + return s +} + +func (s *Scanner) Scan(ctx context.Context, r *scan.Request) (result scan.Result, err error) { + var conn net.Conn + if conn, err = s.dialer.DialContext(ctx, "tcp", fmt.Sprintf("%s:%d", r.DstIP, r.DstPort)); err != nil { + return + } + defer conn.Close() + // tell the operating system to discard any unsent or unacknowledged data on Close() + // it will release all socket resources and send RST packet, fine for the scan + if err = conn.(*net.TCPConn).SetLinger(0); err != nil { + return + } + + done := make(chan interface{}) + defer close(done) + go func() { + select { + // return on ctx.Done without waiting read/write timeout + case <-ctx.Done(): + conn.Close() + case <-done: + } + }() + sconn := &socksConn{conn: conn, timeout: s.dataTimeout} + + req := NewMethodRequest(SOCKSVersion, MethodNoAuth) + if _, err = req.WriteTo(sconn); err != nil { + return + } + + reply := &MethodReply{} + if _, err = reply.ReadFrom(sconn); err != nil { + return + } + + // TODO also detect auth + if reply.Ver == SOCKSVersion && reply.Method == MethodNoAuth { + result = &ScanResult{ + ScanType: ScanType, + Version: SOCKSVersion, + IP: r.DstIP.String(), + Port: r.DstPort, + } + } + return +} + +type socksConn struct { + conn net.Conn + timeout time.Duration +} + +func (c *socksConn) Read(p []byte) (n int, err error) { + if err = c.conn.SetReadDeadline(time.Now().Add(c.timeout)); err != nil { + return + } + return c.conn.Read(p) +} + +func (c *socksConn) Write(p []byte) (n int, err error) { + if err = c.conn.SetWriteDeadline(time.Now().Add(c.timeout)); err != nil { + return + } + return c.conn.Write(p) +} diff --git a/pkg/scan/utils_test.go b/pkg/scan/utils_test.go index 8f0e14d..18890bd 100644 --- a/pkg/scan/utils_test.go +++ b/pkg/scan/utils_test.go @@ -40,3 +40,12 @@ func chanErrToGeneric(in <-chan error) <-chan interface{} { }() return out } + +func waitDone(t *testing.T, done <-chan interface{}) { + t.Helper() + select { + case <-done: + case <-time.After(waitTimeout): + require.Fail(t, "test timeout") + } +}