Skip to content

Commit

Permalink
Add buffer pool (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
v-byte-cpu authored Apr 25, 2021
1 parent 9b59f22 commit 063c909
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 33 deletions.
26 changes: 26 additions & 0 deletions pkg/packet/memory.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package packet

import (
"sync"

"github.com/google/gopacket"
)

var bufferPool = &sync.Pool{
New: func() interface{} {
return gopacket.NewSerializeBuffer()
},
}

func NewSerializeBuffer() gopacket.SerializeBuffer {
buf := bufferPool.Get().(gopacket.SerializeBuffer)
return buf
}

func FreeSerializeBuffer(buf gopacket.SerializeBuffer) (err error) {
if err = buf.Clear(); err != nil {
return
}
bufferPool.Put(buf)
return
}
3 changes: 3 additions & 0 deletions pkg/packet/sender.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ func (s *sender) SendPackets(ctx context.Context, in <-chan *BufferData) (<-chan
if err := s.w.WritePacketData(pkt.Buf.Bytes()); err != nil {
errc <- err
}
if err := FreeSerializeBuffer(pkt.Buf); err != nil {
errc <- err
}
}
}
}()
Expand Down
2 changes: 1 addition & 1 deletion pkg/scan/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func mergeErrChan(ctx context.Context, channels ...<-chan error) <-chan error {
var wg sync.WaitGroup
wg.Add(len(channels))

out := make(chan error)
out := make(chan error, 100)
multiplex := func(c <-chan error) {
defer wg.Done()
for {
Expand Down
14 changes: 6 additions & 8 deletions pkg/scan/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,16 +309,14 @@ func TestScanEngineWithResults(t *testing.T) {

done, errc := engine.Start(ctx, &Range{})
<-done
results := make([]Result, 2)
results[0] = <-resultCh.Chan()
results[1] = <-resultCh.Chan()
cancel()
require.Zero(t, len(errc), "error channel is not empty")
var results []Result
cnt := 0
for result := range resultCh.Chan() {
cnt++
if cnt > 2 {
require.Fail(t, "result channel contains more elements than expected: ", result)
}
results = append(results, result)
result, ok := <-resultCh.Chan()
if ok {
require.Fail(t, "result channel contains more elements than expected: ", result)
}

sort.Slice(results, func(i, j int) bool {
Expand Down
7 changes: 3 additions & 4 deletions pkg/scan/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type packetGenerator struct {
}

func (g *packetGenerator) Packets(ctx context.Context, in <-chan *Request) <-chan *packet.BufferData {
out := make(chan *packet.BufferData)
out := make(chan *packet.BufferData, 100)
go func() {
defer close(out)
for {
Expand All @@ -42,8 +42,7 @@ func (g *packetGenerator) Packets(ctx context.Context, in <-chan *Request) <-cha
writeBufToChan(ctx, out, &packet.BufferData{Err: r.Err})
continue
}
// TODO buffer pool
buf := gopacket.NewSerializeBuffer()
buf := packet.NewSerializeBuffer()
if err := g.filler.Fill(buf, r); err != nil {
writeBufToChan(ctx, out, &packet.BufferData{Err: err})
continue
Expand Down Expand Up @@ -86,7 +85,7 @@ func MergeBufferDataChan(ctx context.Context, channels ...<-chan *packet.BufferD
var wg sync.WaitGroup
wg.Add(len(channels))

out := make(chan *packet.BufferData)
out := make(chan *packet.BufferData, len(channels)*100)
multiplex := func(c <-chan *packet.BufferData) {
defer wg.Done()
for {
Expand Down
14 changes: 7 additions & 7 deletions pkg/scan/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (*portGenerator) Ports(ctx context.Context, r *Range) (<-chan uint16, error
if err := validatePorts(r.Ports); err != nil {
return nil, err
}
out := make(chan uint16)
out := make(chan uint16, 100)
go func() {
defer close(out)
for _, portRange := range r.Ports {
Expand Down Expand Up @@ -79,9 +79,9 @@ type IPGetter interface {
GetIP() (net.IP, error)
}

type wrapIP net.IP
type WrapIP net.IP

func (i wrapIP) GetIP() (net.IP, error) {
func (i WrapIP) GetIP() (net.IP, error) {
return net.IP(i), nil
}

Expand All @@ -99,12 +99,12 @@ func (*ipGenerator) IPs(ctx context.Context, r *Range) (<-chan IPGetter, error)
if r.DstSubnet == nil {
return nil, ErrSubnet
}
out := make(chan IPGetter)
out := make(chan IPGetter, 100)
go func() {
defer close(out)
ipnet := r.DstSubnet
for ipaddr := ipnet.IP.Mask(ipnet.Mask); ipnet.Contains(ipaddr); ip.Inc(ipaddr) {
writeIP(ctx, out, wrapIP(ip.DupIP(ipaddr)))
writeIP(ctx, out, WrapIP(ip.DupIP(ipaddr)))
}
}()
return out, nil
Expand Down Expand Up @@ -132,7 +132,7 @@ func (rg *ipPortGenerator) GenerateRequests(ctx context.Context, r *Range) (<-ch
if err != nil {
return nil, err
}
out := make(chan *Request)
out := make(chan *Request, 100)
go func() {
defer close(out)
for port := range ports {
Expand Down Expand Up @@ -281,7 +281,7 @@ func (g *fileIPGenerator) IPs(ctx context.Context, _ *Range) (<-chan IPGetter, e
writeIP(ctx, out, &ipError{error: ErrIP})
return
}
writeIP(ctx, out, wrapIP(ip))
writeIP(ctx, out, WrapIP(ip))
}
if err = scanner.Err(); err != nil {
writeIP(ctx, out, &ipError{error: err})
Expand Down
26 changes: 13 additions & 13 deletions pkg/scan/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ func TestIPGenerator(t *testing.T) {
withSubnet(&net.IPNet{IP: net.IPv4(192, 168, 0, 1), Mask: net.CIDRMask(32, 32)}),
),
expected: []interface{}{
wrapIP(net.IPv4(192, 168, 0, 1).To4()),
WrapIP(net.IPv4(192, 168, 0, 1).To4()),
},
},
{
Expand All @@ -255,8 +255,8 @@ func TestIPGenerator(t *testing.T) {
withSubnet(&net.IPNet{IP: net.IPv4(1, 0, 0, 1), Mask: net.CIDRMask(31, 32)}),
),
expected: []interface{}{
wrapIP(net.IPv4(1, 0, 0, 0).To4()),
wrapIP(net.IPv4(1, 0, 0, 1).To4()),
WrapIP(net.IPv4(1, 0, 0, 0).To4()),
WrapIP(net.IPv4(1, 0, 0, 1).To4()),
},
},
{
Expand All @@ -265,10 +265,10 @@ func TestIPGenerator(t *testing.T) {
withSubnet(&net.IPNet{IP: net.IPv4(10, 0, 0, 1), Mask: net.CIDRMask(30, 32)}),
),
expected: []interface{}{
wrapIP(net.IPv4(10, 0, 0, 0).To4()),
wrapIP(net.IPv4(10, 0, 0, 1).To4()),
wrapIP(net.IPv4(10, 0, 0, 2).To4()),
wrapIP(net.IPv4(10, 0, 0, 3).To4()),
WrapIP(net.IPv4(10, 0, 0, 0).To4()),
WrapIP(net.IPv4(10, 0, 0, 1).To4()),
WrapIP(net.IPv4(10, 0, 0, 2).To4()),
WrapIP(net.IPv4(10, 0, 0, 3).To4()),
},
},
}
Expand Down Expand Up @@ -702,14 +702,14 @@ func TestFileIPGenerator(t *testing.T) {
name: "OneIP",
input: `{"ip":"192.168.0.1"}`,
expected: []interface{}{
wrapIP(net.IPv4(192, 168, 0, 1)),
WrapIP(net.IPv4(192, 168, 0, 1)),
},
},
{
name: "OneIPWithUnknownField",
input: `{"ip":"192.168.0.1","abc":"field"}`,
expected: []interface{}{
wrapIP(net.IPv4(192, 168, 0, 1)),
WrapIP(net.IPv4(192, 168, 0, 1)),
},
},
{
Expand All @@ -719,8 +719,8 @@ func TestFileIPGenerator(t *testing.T) {
`{"ip":"192.168.0.2"}`,
}, "\n"),
expected: []interface{}{
wrapIP(net.IPv4(192, 168, 0, 1)),
wrapIP(net.IPv4(192, 168, 0, 2)),
WrapIP(net.IPv4(192, 168, 0, 1)),
WrapIP(net.IPv4(192, 168, 0, 2)),
},
},
{
Expand All @@ -737,7 +737,7 @@ func TestFileIPGenerator(t *testing.T) {
`{"ip":"192`,
}, "\n"),
expected: []interface{}{
wrapIP(net.IPv4(192, 168, 0, 1)),
WrapIP(net.IPv4(192, 168, 0, 1)),
&ipError{error: ErrJSON},
},
},
Expand All @@ -749,7 +749,7 @@ func TestFileIPGenerator(t *testing.T) {
`{"ip":"192.168.0.3","port":888}`,
}, "\n"),
expected: []interface{}{
wrapIP(net.IPv4(192, 168, 0, 1)),
WrapIP(net.IPv4(192, 168, 0, 1)),
&ipError{error: ErrJSON},
},
},
Expand Down
61 changes: 61 additions & 0 deletions pkg/scan/tcp/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package tcp
import (
"context"
"net"
"runtime"
"testing"
"time"

Expand All @@ -11,6 +12,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/v-byte-cpu/sx/pkg/scan"
"github.com/v-byte-cpu/sx/pkg/scan/arp"
)

func TestPacketFiller(t *testing.T) {
Expand Down Expand Up @@ -285,3 +287,62 @@ func TestAllFlags(t *testing.T) {
})
}
}

type mockIPGeneratorFunc func(ctx context.Context, r *scan.Range) (<-chan scan.IPGetter, error)

func (f mockIPGeneratorFunc) IPs(ctx context.Context, r *scan.Range) (<-chan scan.IPGetter, error) {
return f(ctx, r)
}

type nullPacketReadWriter struct{}

func (*nullPacketReadWriter) ReadPacketData() (data []byte, ci *gopacket.CaptureInfo, err error) {
return
}

func (*nullPacketReadWriter) WritePacketData(_ []byte) error {
return nil
}

func BenchmarkTCPScanEngine(b *testing.B) {
b.ReportAllocs()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

dstIP := net.IPv4(192, 168, 0, 3).To4()
ipgen := mockIPGeneratorFunc(func(ctx context.Context, r *scan.Range) (<-chan scan.IPGetter, error) {
out := make(chan scan.IPGetter, 100)
go func() {
defer close(out)
for i := 0; i < b.N; i++ {
select {
case <-ctx.Done():
return
case out <- scan.WrapIP(dstIP):
}
}
}()
return out, nil
})
reqgen := arp.NewCacheRequestGenerator(
scan.NewIPPortGenerator(ipgen, scan.NewPortGenerator()),
net.HardwareAddr{0x10, 0x11, 0x12, 0x13, 0x14, 0x15},
arp.NewCache())
pktgen := scan.NewPacketMultiGenerator(NewPacketFiller(), runtime.NumCPU())
psrc := scan.NewPacketSource(reqgen, pktgen)
results := scan.NewResultChan(ctx, 1000)
sm := NewScanMethod("tcpbench", psrc, results)
engine := scan.SetupPacketEngine(&nullPacketReadWriter{}, sm)

done, _ := engine.Start(ctx, &scan.Range{
SrcIP: net.IPv4(192, 168, 0, 2).To4(),
SrcMAC: net.HardwareAddr{0x1, 0x2, 0x3, 0x4, 0x5, 0x6},
Ports: []*scan.PortRange{
{
StartPort: 22,
EndPort: 22,
},
},
})
<-done
}

0 comments on commit 063c909

Please sign in to comment.