From 8151f29fe8de4f6ebd334e3033e74e44ef627a84 Mon Sep 17 00:00:00 2001 From: v-byte-cpu <65545655+v-byte-cpu@users.noreply.github.com> Date: Wed, 10 Mar 2021 16:17:47 +0300 Subject: [PATCH] Integrate golang linters --- .drone.yml | 10 +++++ .golangci.yml | 31 ++++++++++++++ .revive.toml | 45 ++++++++++++++++++++ README.md | 1 + command/arp.go | 43 ++++++++++++------- pkg/packet/afpacket/readwriter.go | 16 +++---- pkg/packet/receiver.go | 46 ++++++++++----------- pkg/packet/receiver_test.go | 25 ++++++----- pkg/packet/sender_test.go | 39 +++++++++-------- pkg/packet/utils_test.go | 6 ++- pkg/scan/arp/arp.go | 8 ++-- pkg/scan/engine.go | 4 +- pkg/scan/engine_test.go | 68 +++++++----------------------- pkg/scan/generator_test.go | 69 +++++++++++-------------------- pkg/scan/range.go | 4 +- pkg/scan/range_test.go | 51 +++++++++-------------- pkg/scan/utils_test.go | 42 +++++++++++++++++++ 17 files changed, 289 insertions(+), 219 deletions(-) create mode 100644 .golangci.yml create mode 100644 .revive.toml create mode 100644 pkg/scan/utils_test.go diff --git a/.drone.yml b/.drone.yml index 1a3d29d..371c240 100644 --- a/.drone.yml +++ b/.drone.yml @@ -7,6 +7,16 @@ clone: depth: 1 steps: + - name: lint + image: golangci/golangci-lint:v1.38 + commands: + - apt-get update + - apt-get install -y libpcap-dev + - golangci-lint run -v + # go 1.16 install doesn't modify go.mod and go.sum + - go install github.com/mgechev/revive@v1.0.3 + - revive -config .revive.toml -formatter friendly ./... + - name: test image: golang:1.16 commands: diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..c78abe7 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,31 @@ +linters: + enable: + - dogsled + - exportloopref + - funlen + - gocognit + - goconst + - gocritic + - gocyclo + - gofmt + - goimports + - golint + - gosec + - govet + - misspell + - nestif + - prealloc + - unconvert + - unparam + +run: + timeout: 3m + +issues: + exclude-rules: + - linters: + - gosec + text: "G404" + - linters: + - funlen + path: _test\.go diff --git a/.revive.toml b/.revive.toml new file mode 100644 index 0000000..d8f3ad6 --- /dev/null +++ b/.revive.toml @@ -0,0 +1,45 @@ +ignoreGeneratedHeader = false +severity = "warning" +confidence = 0.8 +errorCode = 1 +warningCode = 1 + +#Recommended rules +[rule.blank-imports] +[rule.context-as-argument] +[rule.context-keys-type] +[rule.dot-imports] +[rule.error-return] +[rule.error-strings] +[rule.error-naming] +[rule.if-return] +[rule.increment-decrement] +[rule.var-naming] +[rule.var-declaration] +[rule.package-comments] +[rule.range] +[rule.receiver-naming] +[rule.time-naming] +[rule.unexported-return] +[rule.indent-error-flow] +[rule.errorf] +[rule.empty-block] +[rule.superfluous-else] +[rule.unused-parameter] +[rule.unreachable-code] +[rule.redefines-builtin-id] + +#Custom rules +[rule.modifies-parameter] +[rule.unnecessary-stmt] +[rule.confusing-naming] +[rule.modifies-value-receiver] +[rule.range-val-in-closure] +[rule.range-val-address] +[rule.waitgroup-by-value] +[rule.atomic] +[rule.unused-receiver] +[rule.early-return] +[rule.unconditional-recursion] +[rule.identical-branches] +[rule.defer] \ No newline at end of file diff --git a/README.md b/README.md index 022480e..178dfc8 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # sx [![Build Status](https://cloud.drone.io/api/badges/v-byte-cpu/sx/status.svg)](https://cloud.drone.io/v-byte-cpu/sx) +[![GoReportCard Status](https://goreportcard.com/badge/github.com/v-byte-cpu/sx)](https://goreportcard.com/report/github.com/v-byte-cpu/sx) The goal of this project is to create the fastest network scanner with clean and simple code. diff --git a/command/arp.go b/command/arp.go index b71a2fe..f12003a 100644 --- a/command/arp.go +++ b/command/arp.go @@ -74,14 +74,39 @@ var arpCmd = &cobra.Command{ } } - r := &scan.Range{Subnet: dstSubnet, Interface: iface, SrcIP: srcIP, SrcMAC: srcMAC} + r := &scan.Range{Subnet: dstSubnet, Interface: iface, SrcIP: srcIP.To4(), SrcMAC: srcMAC} return startEngine(r) }, } -func startEngine(r *scan.Range) error { +func logResults(logger *zap.Logger, results <-chan *arp.ScanResult) { bw := bufio.NewWriter(os.Stdout) defer bw.Flush() + for result := range results { + // TODO refactor it using logger facade interface + if jsonFlag { + data, err := result.MarshalJSON() + if err != nil { + logger.Error("arp", zap.Error(err)) + } + _, err = bw.Write(data) + if err != nil { + logger.Error("arp", zap.Error(err)) + } + } else { + _, err := bw.WriteString(result.String()) + if err != nil { + logger.Error("arp", zap.Error(err)) + } + } + err := bw.WriteByte('\n') + if err != nil { + logger.Error("arp", zap.Error(err)) + } + } +} + +func startEngine(r *scan.Range) error { logger, err := zap.NewProduction() if err != nil { return err @@ -108,19 +133,7 @@ func startEngine(r *scan.Range) error { wg.Add(1) go func() { defer wg.Done() - for result := range m.Results() { - // TODO extract it - if jsonFlag { - data, err := result.MarshalJSON() - if err != nil { - logger.Error("arp", zap.Error(err)) - } - bw.Write(data) - } else { - bw.WriteString(result.String()) - } - bw.WriteByte('\n') - } + logResults(logger, m.Results()) }() // start scan diff --git a/pkg/packet/afpacket/readwriter.go b/pkg/packet/afpacket/readwriter.go index 2a8356d..d0de20b 100644 --- a/pkg/packet/afpacket/readwriter.go +++ b/pkg/packet/afpacket/readwriter.go @@ -9,25 +9,25 @@ import ( "golang.org/x/net/bpf" ) -type AfPacketSource struct { +type Source struct { handle *afp.TPacket } // Assert that AfPacketSource conforms to the packet.ReadWriter interface -var _ packet.ReadWriter = (*AfPacketSource)(nil) +var _ packet.ReadWriter = (*Source)(nil) -func NewPacketSource(iface string) (*AfPacketSource, error) { +func NewPacketSource(iface string) (*Source, error) { handle, err := afp.NewTPacket(afp.SocketRaw, afp.OptInterface(iface)) if err != nil { return nil, err } - return &AfPacketSource{handle}, nil + return &Source{handle}, nil } // maxPacketLength is the maximum size of packets to capture in bytes. // pcap calls it "snaplen" and default value used in tcpdump is 262144 bytes, // that is redundant for most scans, see pcap(3) and tcpdump(1) for more info -func (s *AfPacketSource) SetBPFFilter(bpfFilter string, maxPacketLength int) error { +func (s *Source) SetBPFFilter(bpfFilter string, maxPacketLength int) error { pcapBPF, err := pcap.CompileBPFFilter(layers.LinkTypeEthernet, maxPacketLength, bpfFilter) if err != nil { return err @@ -45,15 +45,15 @@ func (s *AfPacketSource) SetBPFFilter(bpfFilter string, maxPacketLength int) err return s.handle.SetBPF(bpfIns) } -func (s *AfPacketSource) Close() { +func (s *Source) Close() { s.handle.Close() } -func (s *AfPacketSource) ReadPacketData() ([]byte, *gopacket.CaptureInfo, error) { +func (s *Source) ReadPacketData() ([]byte, *gopacket.CaptureInfo, error) { data, ci, err := s.handle.ZeroCopyReadPacketData() return data, &ci, err } -func (s *AfPacketSource) WritePacketData(pkt []byte) error { +func (s *Source) WritePacketData(pkt []byte) error { return s.handle.WritePacketData(pkt) } diff --git a/pkg/packet/receiver.go b/pkg/packet/receiver.go index c182992..61b018f 100644 --- a/pkg/packet/receiver.go +++ b/pkg/packet/receiver.go @@ -61,31 +61,31 @@ func (r *receiver) ReceivePackets(ctx context.Context) <-chan error { case <-ctx.Done(): return default: - data, ci, err := r.sr.ReadPacketData() - if err != nil { - // Immediately retry for temporary errors - if isTemporaryError(err) { - continue - } - if isUnrecoverableError(err) { - return - } - // Log unknown error - select { - case <-ctx.Done(): - return - case errc <- err: - } - // Sleep briefly and try again - time.Sleep(time.Millisecond * time.Duration(5)) + } + data, ci, err := r.sr.ReadPacketData() + if err != nil { + // Immediately retry for temporary errors + if isTemporaryError(err) { continue } - if err := r.p.ProcessPacketData(data, ci); err != nil { - select { - case <-ctx.Done(): - return - case errc <- err: - } + if isUnrecoverableError(err) { + return + } + // Log unknown error + select { + case <-ctx.Done(): + return + case errc <- err: + } + // Sleep briefly and try again + time.Sleep(5 * time.Millisecond) + continue + } + if err := r.p.ProcessPacketData(data, ci); err != nil { + select { + case <-ctx.Done(): + return + case errc <- err: } } } diff --git a/pkg/packet/receiver_test.go b/pkg/packet/receiver_test.go index e771da3..0c15a2d 100644 --- a/pkg/packet/receiver_test.go +++ b/pkg/packet/receiver_test.go @@ -7,7 +7,6 @@ import ( "net" "syscall" "testing" - "time" "github.com/golang/mock/gomock" "github.com/google/gopacket" @@ -64,8 +63,8 @@ func TestReceivePacketsWithUnrecoverableError(t *testing.T) { p := NewMockProcessor(ctrl) r := NewReceiver(sr, p) - out := chanErrToGeneric(r.ReceivePackets(context.Background())) - result := chanToSlice(t, out, 0, 3*time.Second) + out := r.ReceivePackets(context.Background()) + result := chanToSlice(t, chanErrToGeneric(out), 0) assert.Equal(t, 0, len(result), "error slice is not empty") }) } @@ -90,8 +89,8 @@ func TestReceivePacketsOnePacket(t *testing.T) { ProcessPacketData(expectedData, newCaptureInfo()).Return(nil) r := NewReceiver(sr, p) - out := chanErrToGeneric(r.ReceivePackets(context.Background())) - result := chanToSlice(t, out, 0, 3*time.Second) + out := r.ReceivePackets(context.Background()) + result := chanToSlice(t, chanErrToGeneric(out), 0) assert.Equal(t, 0, len(result), "error slice is not empty") } @@ -113,8 +112,8 @@ func TestReceivePacketsOnePacketWithProcessError(t *testing.T) { ProcessPacketData(notNil, notNil).Return(errors.New("process error")) r := NewReceiver(sr, p) - out := chanErrToGeneric(r.ReceivePackets(context.Background())) - result := chanToSlice(t, out, 1, 3*time.Second) + out := r.ReceivePackets(context.Background()) + result := chanToSlice(t, chanErrToGeneric(out), 1) assert.Equal(t, 1, len(result), "error slice is invalid") assert.Error(t, result[0].(error)) } @@ -157,8 +156,8 @@ func TestReceivePacketsOnePacketWithRetryError(t *testing.T) { ProcessPacketData(expectedData, newCaptureInfo()).Return(nil) r := NewReceiver(sr, p) - out := chanErrToGeneric(r.ReceivePackets(context.Background())) - result := chanToSlice(t, out, 0, 3*time.Second) + out := r.ReceivePackets(context.Background()) + result := chanToSlice(t, chanErrToGeneric(out), 0) assert.Equal(t, 0, len(result), "error slice is not empty") }) } @@ -184,8 +183,8 @@ func TestReceivePacketsOnePacketWithUnknownError(t *testing.T) { ProcessPacketData(expectedData, newCaptureInfo()).Return(nil) r := NewReceiver(sr, p) - out := chanErrToGeneric(r.ReceivePackets(context.Background())) - result := chanToSlice(t, out, 1, 3*time.Second) + out := r.ReceivePackets(context.Background()) + result := chanToSlice(t, chanErrToGeneric(out), 1) assert.Equal(t, 1, len(result), "error slice length is invalid") assert.Error(t, result[0].(error)) } @@ -211,7 +210,7 @@ func TestReceivePacketsOnePacketWithContextCancel(t *testing.T) { }) r := NewReceiver(sr, p) - out := chanErrToGeneric(r.ReceivePackets(ctx)) - result := chanToSlice(t, out, 0, 3*time.Second) + out := r.ReceivePackets(ctx) + result := chanToSlice(t, chanErrToGeneric(out), 0) assert.Equal(t, 0, len(result), "error slice is not empty") } diff --git a/pkg/packet/sender_test.go b/pkg/packet/sender_test.go index ed69f59..2a95b0b 100644 --- a/pkg/packet/sender_test.go +++ b/pkg/packet/sender_test.go @@ -23,10 +23,9 @@ func TestSenderWithEmptyChannel(t *testing.T) { done, errc := s.SendPackets(context.Background(), in) - out := chanErrToGeneric(errc) - result := chanToSlice(t, out, 0, 3*time.Second) + result := chanToSlice(t, chanErrToGeneric(errc), 0) assert.Equal(t, 0, len(result), "error slice is not empty") - result = chanToSlice(t, done, 0, 3*time.Second) + result = chanToSlice(t, done, 0) assert.Equal(t, 0, len(result), "error slice is not empty") } @@ -35,7 +34,8 @@ func TestSenderWithOnePacket(t *testing.T) { in := make(chan *BufferData, 1) data := []byte{0x1, 0x2, 0x3} buffer := gopacket.NewSerializeBuffer() - gopacket.SerializeLayers(buffer, gopacket.SerializeOptions{}, gopacket.Payload(data)) + err := gopacket.SerializeLayers(buffer, gopacket.SerializeOptions{}, gopacket.Payload(data)) + require.NoError(t, err) in <- &BufferData{Buf: buffer} close(in) @@ -49,10 +49,9 @@ func TestSenderWithOnePacket(t *testing.T) { done, errc := s.SendPackets(context.Background(), in) - out := chanErrToGeneric(errc) - result := chanToSlice(t, out, 0, 3*time.Second) + result := chanToSlice(t, chanErrToGeneric(errc), 0) assert.Equal(t, 0, len(result), "error slice is not empty") - result = chanToSlice(t, done, 0, 3*time.Second) + result = chanToSlice(t, done, 0) assert.Equal(t, 0, len(result), "error slice is not empty") } @@ -62,12 +61,14 @@ func TestSenderWithTwoPackets(t *testing.T) { data := []byte{0x1, 0x2, 0x3} buffer := gopacket.NewSerializeBuffer() - gopacket.SerializeLayers(buffer, gopacket.SerializeOptions{}, gopacket.Payload(data)) + err := gopacket.SerializeLayers(buffer, gopacket.SerializeOptions{}, gopacket.Payload(data)) + require.NoError(t, err) in <- &BufferData{Buf: buffer} data2 := []byte{0x2, 0x3, 0x4} buffer2 := gopacket.NewSerializeBuffer() - gopacket.SerializeLayers(buffer2, gopacket.SerializeOptions{}, gopacket.Payload(data2)) + err = gopacket.SerializeLayers(buffer2, gopacket.SerializeOptions{}, gopacket.Payload(data2)) + require.NoError(t, err) in <- &BufferData{Buf: buffer2} close(in) @@ -87,10 +88,9 @@ func TestSenderWithTwoPackets(t *testing.T) { done, errc := s.SendPackets(context.Background(), in) - out := chanErrToGeneric(errc) - result := chanToSlice(t, out, 0, 3*time.Second) + result := chanToSlice(t, chanErrToGeneric(errc), 0) assert.Equal(t, 0, len(result), "error slice is not empty") - result = chanToSlice(t, done, 0, 3*time.Second) + result = chanToSlice(t, done, 0) assert.Equal(t, 0, len(result), "error slice is not empty") } @@ -106,12 +106,11 @@ func TestSenderWithInvalidPacketReturnsError(t *testing.T) { done, errc := s.SendPackets(context.Background(), in) - out := chanErrToGeneric(errc) - result := chanToSlice(t, out, 1, 3*time.Second) + result := chanToSlice(t, chanErrToGeneric(errc), 1) assert.Equal(t, 1, len(result), "error slice size is invalid") assert.Error(t, result[0].(error)) - result = chanToSlice(t, done, 0, 3*time.Second) + result = chanToSlice(t, done, 0) assert.Equal(t, 0, len(result), "error slice is not empty") } @@ -121,7 +120,8 @@ func TestSenderWithWriteErrorReturnsError(t *testing.T) { data := []byte{0x1, 0x2, 0x3} buffer := gopacket.NewSerializeBuffer() - gopacket.SerializeLayers(buffer, gopacket.SerializeOptions{}, gopacket.Payload(data)) + err := gopacket.SerializeLayers(buffer, gopacket.SerializeOptions{}, gopacket.Payload(data)) + require.NoError(t, err) in <- &BufferData{Buf: buffer} close(in) @@ -132,12 +132,11 @@ func TestSenderWithWriteErrorReturnsError(t *testing.T) { done, errc := s.SendPackets(context.Background(), in) - out := chanErrToGeneric(errc) - result := chanToSlice(t, out, 1, 3*time.Second) + result := chanToSlice(t, chanErrToGeneric(errc), 1) assert.Equal(t, 1, len(result), "error slice size is invalid") assert.Error(t, result[0].(error)) - result = chanToSlice(t, done, 0, 3*time.Second) + result = chanToSlice(t, done, 0) assert.Equal(t, 0, len(result), "error slice is not empty") } @@ -157,6 +156,6 @@ func TestSenderWithTimeout(t *testing.T) { case <-time.After(1 * time.Second): require.FailNow(t, "exit timeout") } - result := chanToSlice(t, done, 0, 3*time.Second) + result := chanToSlice(t, done, 0) assert.Equal(t, 0, len(result), "error slice is not empty") } diff --git a/pkg/packet/utils_test.go b/pkg/packet/utils_test.go index 29ba0a7..14a2f60 100644 --- a/pkg/packet/utils_test.go +++ b/pkg/packet/utils_test.go @@ -7,7 +7,9 @@ import ( "github.com/stretchr/testify/require" ) -func chanToSlice(t *testing.T, in <-chan interface{}, expectedLen int, timeout time.Duration) []interface{} { +const waitTimeout = 3 * time.Second + +func chanToSlice(t *testing.T, in <-chan interface{}, expectedLen int) []interface{} { t.Helper() result := []interface{}{} loop: @@ -21,7 +23,7 @@ loop: require.FailNow(t, "chan size is greater than expected, data:", data) } result = append(result, data) - case <-time.After(timeout): + case <-time.After(waitTimeout): t.Fatal("read timeout") } } diff --git a/pkg/scan/arp/arp.go b/pkg/scan/arp/arp.go index 4db4a36..5bdc30a 100644 --- a/pkg/scan/arp/arp.go +++ b/pkg/scan/arp/arp.go @@ -17,15 +17,15 @@ import ( type ScanMethod struct { gen *scan.PacketMultiGenerator + parser *gopacket.DecodingLayerParser results chan *ScanResult internalResults chan *ScanResult ctx context.Context + rcvDecoded []gopacket.LayerType rcvEth layers.Ethernet rcvARP layers.ARP - rcvDecoded []gopacket.LayerType rcvMacPrefix [3]byte - parser *gopacket.DecodingLayerParser } //easyjson:json @@ -88,7 +88,7 @@ func (s *ScanMethod) Packets(ctx context.Context, r *scan.Range) <-chan *packet. return s.gen.Packets(ctx, pairs) } -func (s *ScanMethod) ProcessPacketData(data []byte, ci *gopacket.CaptureInfo) error { +func (s *ScanMethod) ProcessPacketData(data []byte, _ *gopacket.CaptureInfo) error { // try to exit as early as possible select { case <-s.ctx.Done(): @@ -125,7 +125,7 @@ func newPacketFiller() *packetFiller { return &packetFiller{} } -func (f *packetFiller) Fill(packet gopacket.SerializeBuffer, pair *scan.Request) error { +func (*packetFiller) Fill(packet gopacket.SerializeBuffer, pair *scan.Request) error { eth := &layers.Ethernet{ SrcMAC: pair.SrcMAC, DstMAC: net.HardwareAddr{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, diff --git a/pkg/scan/engine.go b/pkg/scan/engine.go index bfd4f5a..f93f63f 100644 --- a/pkg/scan/engine.go +++ b/pkg/scan/engine.go @@ -11,10 +11,10 @@ import ( ) type Range struct { - SrcIP net.IP - SrcMAC net.HardwareAddr Interface *net.Interface Subnet *net.IPNet + SrcIP net.IP + SrcMAC net.HardwareAddr StartPort uint16 EndPort uint16 } diff --git a/pkg/scan/engine_test.go b/pkg/scan/engine_test.go index 74711f8..568b789 100644 --- a/pkg/scan/engine_test.go +++ b/pkg/scan/engine_test.go @@ -11,62 +11,19 @@ import ( "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "github.com/v-byte-cpu/sx/pkg/packet" ) -func chanErrorToSlice(t *testing.T, in <-chan error, expectedLen int, timeout time.Duration) []error { - t.Helper() - result := []error{} -loop: - for { - select { - case data, ok := <-in: - if !ok { - break loop - } - if len(result) == expectedLen { - require.FailNow(t, "chan size is greater than expected, data:", data) - } - result = append(result, data) - case <-time.After(timeout): - t.Fatal("read timeout") - } - } - return result -} - -// generics would be helpful :) -func chanRequestToSlice(t *testing.T, in <-chan *Request, expectedLen int, timeout time.Duration) []*Request { - t.Helper() - result := []*Request{} -loop: - for { - select { - case data, ok := <-in: - if !ok { - break loop - } - if len(result) == expectedLen { - require.FailNow(t, "chan size is greater than expected, data:", data) - } - result = append(result, data) - case <-time.After(timeout): - t.Fatal("read timeout") - } - } - return result -} - func TestMergeErrChanEmptyChannels(t *testing.T) { t.Parallel() c1 := make(chan error) close(c1) c2 := make(chan error) close(c2) + out := mergeErrChan(context.Background(), c1, c2) + result := chanToSlice(t, chanErrToGeneric(out), 0) - result := chanErrorToSlice(t, out, 0, 3*time.Second) assert.Equal(t, 0, len(result), "error slice is not empty") } @@ -77,11 +34,12 @@ func TestMergeErrChanOneElementAndEmptyChannel(t *testing.T) { close(c1) c2 := make(chan error) close(c2) + out := mergeErrChan(context.Background(), c1, c2) + result := chanToSlice(t, chanErrToGeneric(out), 1) - result := chanErrorToSlice(t, out, 1, 3*time.Second) assert.Equal(t, 1, len(result), "error slice size is invalid") - assert.Error(t, result[0]) + assert.Error(t, result[0].(error)) } func TestMergeErrChanTwoElements(t *testing.T) { @@ -92,12 +50,13 @@ func TestMergeErrChanTwoElements(t *testing.T) { c2 := make(chan error, 1) c2 <- errors.New("test error") close(c2) + out := mergeErrChan(context.Background(), c1, c2) + result := chanToSlice(t, chanErrToGeneric(out), 2) - result := chanErrorToSlice(t, out, 2, 3*time.Second) assert.Equal(t, 2, len(result), "error slice size is invalid") - assert.Error(t, result[0]) - assert.Error(t, result[1]) + assert.Error(t, result[0].(error)) + assert.Error(t, result[1].(error)) } func TestMergeErrChanContextExit(t *testing.T) { @@ -109,9 +68,10 @@ func TestMergeErrChanContextExit(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) defer cancel() + out := mergeErrChan(ctx, c1, c2) + result := chanToSlice(t, chanErrToGeneric(out), 0) - result := chanErrorToSlice(t, out, 0, 3*time.Second) assert.Equal(t, 0, len(result), "error slice is not empty") } @@ -147,8 +107,8 @@ func TestEngineStartCollectsAllErrors(t *testing.T) { EndPort: 888, }) - result := chanErrorToSlice(t, out, 2, 3*time.Second) + result := chanToSlice(t, chanErrToGeneric(out), 2) assert.Equal(t, 2, len(result), "error slice is invalid") - assert.Error(t, result[0]) - assert.Error(t, result[1]) + assert.Error(t, result[0].(error)) + assert.Error(t, result[1].(error)) } diff --git a/pkg/scan/generator_test.go b/pkg/scan/generator_test.go index 6efe3a0..9d05c64 100644 --- a/pkg/scan/generator_test.go +++ b/pkg/scan/generator_test.go @@ -25,27 +25,6 @@ func chanBufferDataToGeneric(in <-chan *packet.BufferData) <-chan interface{} { return out } -func chanToSlice(t *testing.T, in <-chan interface{}, expectedLen int, timeout time.Duration) []interface{} { - t.Helper() - result := []interface{}{} -loop: - for { - select { - case data, ok := <-in: - if !ok { - break loop - } - if len(result) == expectedLen { - require.FailNow(t, "chan size is greater than expected, data:", data) - } - result = append(result, data) - case <-time.After(timeout): - t.Fatal("read timeout") - } - } - return result -} - func TestGeneratorPacketsWithEmptyChannel(t *testing.T) { t.Parallel() in := make(chan *Request) @@ -55,8 +34,8 @@ func TestGeneratorPacketsWithEmptyChannel(t *testing.T) { f := NewMockPacketFiller(ctrl) g := NewPacketGenerator(f) - out := chanBufferDataToGeneric(g.Packets(context.Background(), in)) - result := chanToSlice(t, out, 0, 3*time.Second) + out := g.Packets(context.Background(), in) + result := chanToSlice(t, chanBufferDataToGeneric(out), 0) assert.Equal(t, 0, len(result), "result is not empty") } @@ -69,8 +48,8 @@ func TestMultiGeneratorPacketsWithEmptyChannel(t *testing.T) { f := NewMockPacketFiller(ctrl) g := NewPacketMultiGenerator(f, runtime.NumCPU()) - out := chanBufferDataToGeneric(g.Packets(context.Background(), in)) - result := chanToSlice(t, out, 0, 3*time.Second) + out := g.Packets(context.Background(), in) + result := chanToSlice(t, chanBufferDataToGeneric(out), 0) assert.Equal(t, 0, len(result), "result is not empty") } @@ -90,8 +69,8 @@ func TestGeneratorPacketsWithOnePair(t *testing.T) { g := NewPacketGenerator(f) - out := chanBufferDataToGeneric(g.Packets(context.Background(), in)) - results := chanToSlice(t, out, 1, 3*time.Second) + out := g.Packets(context.Background(), in) + results := chanToSlice(t, chanBufferDataToGeneric(out), 1) assert.Equal(t, 1, len(results), "result size is invalid") result := results[0].(*packet.BufferData) @@ -115,8 +94,8 @@ func TestMultiGeneratorPacketsWithOnePair(t *testing.T) { g := NewPacketMultiGenerator(f, runtime.NumCPU()) - out := chanBufferDataToGeneric(g.Packets(context.Background(), in)) - results := chanToSlice(t, out, 1, 3*time.Second) + out := g.Packets(context.Background(), in) + results := chanToSlice(t, chanBufferDataToGeneric(out), 1) assert.Equal(t, 1, len(results), "result size is invalid") result := results[0].(*packet.BufferData) @@ -143,8 +122,8 @@ func TestGeneratorPacketsWithTwoPairs(t *testing.T) { &Request{DstIP: net.IPv4(192, 168, 0, 1).To4(), DstPort: port + 1}) g := NewPacketGenerator(f) - out := chanBufferDataToGeneric(g.Packets(context.Background(), in)) - results := chanToSlice(t, out, 2, 3*time.Second) + out := g.Packets(context.Background(), in) + results := chanToSlice(t, chanBufferDataToGeneric(out), 2) assert.Equal(t, 2, len(results), "result size is invalid") result1 := results[0].(*packet.BufferData) @@ -175,8 +154,8 @@ func TestMultiGeneratorPacketsWithTwoPairs(t *testing.T) { g := NewPacketMultiGenerator(f, runtime.NumCPU()) - out := chanBufferDataToGeneric(g.Packets(context.Background(), in)) - results := chanToSlice(t, out, 2, 3*time.Second) + out := g.Packets(context.Background(), in) + results := chanToSlice(t, chanBufferDataToGeneric(out), 2) assert.Equal(t, 2, len(results), "result size is invalid") result1 := results[0].(*packet.BufferData) @@ -203,8 +182,8 @@ func TestGeneratorPacketsWithOnePairReturnsError(t *testing.T) { Return(errors.New("failed request")) g := NewPacketGenerator(f) - out := chanBufferDataToGeneric(g.Packets(context.Background(), in)) - results := chanToSlice(t, out, 1, 3*time.Second) + out := g.Packets(context.Background(), in) + results := chanToSlice(t, chanBufferDataToGeneric(out), 1) assert.Equal(t, 1, len(results), "result size is invalid") result := results[0].(*packet.BufferData) @@ -229,8 +208,8 @@ func TestMultiGeneratorPacketsWithOnePairReturnsError(t *testing.T) { g := NewPacketMultiGenerator(f, runtime.NumCPU()) - out := chanBufferDataToGeneric(g.Packets(context.Background(), in)) - results := chanToSlice(t, out, 1, 3*time.Second) + out := g.Packets(context.Background(), in) + results := chanToSlice(t, chanBufferDataToGeneric(out), 1) assert.Equal(t, 1, len(results), "result size is invalid") result := results[0].(*packet.BufferData) @@ -278,9 +257,9 @@ func TestMergeBufferDataChanEmptyChannels(t *testing.T) { close(c1) c2 := make(chan *packet.BufferData) close(c2) - out := chanBufferDataToGeneric(MergeBufferDataChan(context.Background(), c1, c2)) + out := MergeBufferDataChan(context.Background(), c1, c2) - result := chanToSlice(t, out, 0, 3*time.Second) + result := chanToSlice(t, chanBufferDataToGeneric(out), 0) assert.Equal(t, 0, len(result), "result slice is not empty") } @@ -291,9 +270,9 @@ func TestMergeBufferDataChanOneElementAndEmptyChannel(t *testing.T) { close(c1) c2 := make(chan *packet.BufferData) close(c2) - out := chanBufferDataToGeneric(MergeBufferDataChan(context.Background(), c1, c2)) + out := MergeBufferDataChan(context.Background(), c1, c2) - result := chanToSlice(t, out, 1, 3*time.Second) + result := chanToSlice(t, chanBufferDataToGeneric(out), 1) assert.Equal(t, 1, len(result), "result slice size is invalid") assert.NotNil(t, result[0]) } @@ -306,9 +285,9 @@ func TestMergeBufferDataChanTwoElements(t *testing.T) { c2 := make(chan *packet.BufferData, 1) c2 <- &packet.BufferData{} close(c2) - out := chanBufferDataToGeneric(MergeBufferDataChan(context.Background(), c1, c2)) + out := MergeBufferDataChan(context.Background(), c1, c2) - result := chanToSlice(t, out, 2, 3*time.Second) + result := chanToSlice(t, chanBufferDataToGeneric(out), 2) assert.Equal(t, 2, len(result), "result slice size is invalid") assert.NotNil(t, result[0]) assert.NotNil(t, result[1]) @@ -323,8 +302,8 @@ func TestMergeBufferDataChanContextExit(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) defer cancel() - out := chanBufferDataToGeneric(MergeBufferDataChan(ctx, c1, c2)) + out := MergeBufferDataChan(ctx, c1, c2) - result := chanToSlice(t, out, 0, 3*time.Second) + result := chanToSlice(t, chanBufferDataToGeneric(out), 0) assert.Equal(t, 0, len(result), "result slice is not empty") } diff --git a/pkg/scan/range.go b/pkg/scan/range.go index f0761ba..cede7e1 100644 --- a/pkg/scan/range.go +++ b/pkg/scan/range.go @@ -12,12 +12,12 @@ var ErrPortRange = errors.New("invalid port range") var ErrSubnet = errors.New("invalid subnet") type Request struct { + Meta map[string]interface{} SrcIP net.IP - SrcMAC []byte DstIP net.IP + SrcMAC []byte DstMAC []byte DstPort uint16 - Meta map[string]interface{} } func IPPortPairs(ctx context.Context, r *Range) (<-chan *Request, error) { diff --git a/pkg/scan/range_test.go b/pkg/scan/range_test.go index adeecaf..4cb8e33 100644 --- a/pkg/scan/range_test.go +++ b/pkg/scan/range_test.go @@ -4,7 +4,6 @@ import ( "context" "net" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -98,31 +97,21 @@ func TestIPPortPairsWithInvalidInput(t *testing.T) { } } -func comparePairChanToSlice(t *testing.T, expected []*Request, in <-chan *Request, timeout time.Duration) { +func comparePairChanToSlice(t *testing.T, expected []interface{}, in <-chan *Request) { t.Helper() - result := pairChanToSlice(t, in, len(expected), timeout) + result := chanToSlice(t, chanPairToGeneric(in), len(expected)) require.Equal(t, expected, result) } -func pairChanToSlice(t *testing.T, in <-chan *Request, expectedLen int, timeout time.Duration) []*Request { - t.Helper() - result := []*Request{} -loop: - for { - select { - case data, ok := <-in: - if !ok { - break loop - } - if len(result) == expectedLen { - require.FailNow(t, "chan size is greater than expected, data:", data) - } - result = append(result, data) - case <-time.After(timeout): - t.Fatal("read timeout") +func chanPairToGeneric(in <-chan *Request) <-chan interface{} { + out := make(chan interface{}, cap(in)) + go func() { + defer close(out) + for i := range in { + out <- i } - } - return result + }() + return out } func TestIPPortPairsWithOneIpOnePort(t *testing.T) { @@ -136,10 +125,10 @@ func TestIPPortPairsWithOneIpOnePort(t *testing.T) { )) assert.NoError(t, err) - expected := []*Request{ + expected := []interface{}{ newScanRequest(withDstIP(net.IPv4(192, 168, 0, 1).To4()), withDstPort(port)), } - comparePairChanToSlice(t, expected, pairs, 5*time.Second) + comparePairChanToSlice(t, expected, pairs) } func TestIPPortPairsWithOneIpTwoPorts(t *testing.T) { @@ -153,11 +142,11 @@ func TestIPPortPairsWithOneIpTwoPorts(t *testing.T) { )) assert.NoError(t, err) - expected := []*Request{ + expected := []interface{}{ newScanRequest(withDstIP(net.IPv4(192, 168, 0, 1).To4()), withDstPort(port)), newScanRequest(withDstIP(net.IPv4(192, 168, 0, 1).To4()), withDstPort(port+1)), } - comparePairChanToSlice(t, expected, pairs, 5*time.Second) + comparePairChanToSlice(t, expected, pairs) } func TestIPPortPairsWithTwoIpsOnePort(t *testing.T) { @@ -171,11 +160,11 @@ func TestIPPortPairsWithTwoIpsOnePort(t *testing.T) { )) assert.NoError(t, err) - expected := []*Request{ + expected := []interface{}{ newScanRequest(withDstIP(net.IPv4(192, 168, 0, 0).To4()), withDstPort(port)), newScanRequest(withDstIP(net.IPv4(192, 168, 0, 1).To4()), withDstPort(port)), } - comparePairChanToSlice(t, expected, pairs, 5*time.Second) + comparePairChanToSlice(t, expected, pairs) } func TestIPPortPairsWithFourIpsOnePort(t *testing.T) { @@ -189,13 +178,13 @@ func TestIPPortPairsWithFourIpsOnePort(t *testing.T) { )) assert.NoError(t, err) - expected := []*Request{ + expected := []interface{}{ newScanRequest(withDstIP(net.IPv4(192, 168, 0, 0).To4()), withDstPort(port)), newScanRequest(withDstIP(net.IPv4(192, 168, 0, 1).To4()), withDstPort(port)), newScanRequest(withDstIP(net.IPv4(192, 168, 0, 2).To4()), withDstPort(port)), newScanRequest(withDstIP(net.IPv4(192, 168, 0, 3).To4()), withDstPort(port)), } - comparePairChanToSlice(t, expected, pairs, 5*time.Second) + comparePairChanToSlice(t, expected, pairs) } func TestIPPortPairsWithTwoIpsTwoPorts(t *testing.T) { @@ -209,11 +198,11 @@ func TestIPPortPairsWithTwoIpsTwoPorts(t *testing.T) { )) assert.NoError(t, err) - expected := []*Request{ + expected := []interface{}{ newScanRequest(withDstIP(net.IPv4(192, 168, 0, 0).To4()), withDstPort(port)), newScanRequest(withDstIP(net.IPv4(192, 168, 0, 1).To4()), withDstPort(port)), newScanRequest(withDstIP(net.IPv4(192, 168, 0, 0).To4()), withDstPort(port+1)), newScanRequest(withDstIP(net.IPv4(192, 168, 0, 1).To4()), withDstPort(port+1)), } - comparePairChanToSlice(t, expected, pairs, 5*time.Second) + comparePairChanToSlice(t, expected, pairs) } diff --git a/pkg/scan/utils_test.go b/pkg/scan/utils_test.go new file mode 100644 index 0000000..8f0e14d --- /dev/null +++ b/pkg/scan/utils_test.go @@ -0,0 +1,42 @@ +package scan + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +const waitTimeout = 3 * time.Second + +func chanToSlice(t *testing.T, in <-chan interface{}, expectedLen int) []interface{} { + t.Helper() + result := []interface{}{} +loop: + for { + select { + case data, ok := <-in: + if !ok { + break loop + } + if len(result) == expectedLen { + require.FailNow(t, "chan size is greater than expected, data:", data) + } + result = append(result, data) + case <-time.After(waitTimeout): + t.Fatal("read timeout") + } + } + return result +} + +func chanErrToGeneric(in <-chan error) <-chan interface{} { + out := make(chan interface{}, cap(in)) + go func() { + defer close(out) + for i := range in { + out <- i + } + }() + return out +}