Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion pkg/proxy/net/packetio.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"crypto/tls"
"io"
"net"
"sync"
"time"

"github.com/pingcap/tiproxy/lib/config"
Expand All @@ -42,6 +43,11 @@ var (
ErrInvalidSequence = errors.New("invalid sequence")
)

var (
readerPool sync.Pool
writerPool sync.Pool
)

const (
DefaultConnBufferSize = 32 * 1024
)
Expand Down Expand Up @@ -86,6 +92,27 @@ type basicReadWriter struct {
inBytes uint64
outBytes uint64
sequence uint8
pooled bool
}

func getPooledReader(conn net.Conn, size int) *bufio.Reader {
if v := readerPool.Get(); v != nil {
if r := v.(*bufio.Reader); r.Size() == size {
r.Reset(conn)
return r
}
}
return bufio.NewReaderSize(conn, size)
}

func getPooledWriter(conn net.Conn, size int) *bufio.Writer {
if v := writerPool.Get(); v != nil {
if w := v.(*bufio.Writer); w.Size() == size {
w.Reset(conn)
return w
}
}
return bufio.NewWriterSize(conn, size)
}

func newBasicReadWriter(conn net.Conn, bufferSize int) *basicReadWriter {
Expand All @@ -94,7 +121,8 @@ func newBasicReadWriter(conn net.Conn, bufferSize int) *basicReadWriter {
}
return &basicReadWriter{
Conn: conn,
ReadWriter: bufio.NewReadWriter(bufio.NewReaderSize(conn, bufferSize), bufio.NewWriterSize(conn, bufferSize)),
ReadWriter: bufio.NewReadWriter(getPooledReader(conn, bufferSize), getPooledWriter(conn, bufferSize)),
pooled: true,
}
}

Expand Down Expand Up @@ -153,6 +181,22 @@ func (brw *basicReadWriter) ResetSequence() {
brw.sequence = 0
}

func (brw *basicReadWriter) Free() {
if brw.pooled {
brw.pooled = false
brw.ReadWriter.Reader.Reset(nil)
brw.ReadWriter.Writer.Reset(nil)
readerPool.Put(brw.ReadWriter.Reader)
writerPool.Put(brw.ReadWriter.Writer)
}
}

func (brw *basicReadWriter) Close() error {
err := brw.Conn.Close()
brw.Free()
return err
}

func (brw *basicReadWriter) TLSConnectionState() tls.ConnectionState {
return tls.ConnectionState{}
}
Expand Down Expand Up @@ -496,6 +540,7 @@ func (p *packetIO) Close() error {
errs = append(errs, err)
}
*/

if err := p.readWriter.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
errs = append(errs, errors.WithStack(err))
}
Expand Down
98 changes: 98 additions & 0 deletions pkg/proxy/net/packetio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,7 @@ func TestForwardUntilLongData(t *testing.T) {

func TestForwardUntilError(t *testing.T) {
srvCh := make(chan *packetIO)
exitCh := make(chan struct{})
var wg waitgroup.WaitGroup
selfErr, peerErr := errors.New("self"), errors.New("peer")
// client1 writes to server1
Expand All @@ -582,6 +583,7 @@ func TestForwardUntilError(t *testing.T) {
return srv2.Flush()
})
require.ErrorIs(t, err, peerErr)
exitCh <- struct{}{}
},
1,
)
Expand All @@ -594,6 +596,7 @@ func TestForwardUntilError(t *testing.T) {
func(t *testing.T, srv2 *packetIO) {
srv2.ApplyOpts(WithWrapError(peerErr))
srvCh <- srv2
<-exitCh
},
1,
)
Expand Down Expand Up @@ -719,3 +722,98 @@ func runForwardBenchmark(b *testing.B, f func(packetIO1, packetIO2 *packetIO)) {
_ = packetIO2.Close()
wg.Wait()
}

func TestPacketIOPooling(t *testing.T) {
testTCPConn(t,
func(t *testing.T, cli *packetIO) {
brw, ok := cli.readWriter.(*basicReadWriter)
require.True(t, ok)
require.True(t, brw.pooled, "pooled flag should be true for default buffer size")

require.NoError(t, cli.WritePacket([]byte("pooltest"), true))
},
func(t *testing.T, srv *packetIO) {
brw, ok := srv.readWriter.(*basicReadWriter)
require.True(t, ok)
require.True(t, brw.pooled, "pooled flag should be true for default buffer size")

data, err := srv.ReadPacket()
require.NoError(t, err)
require.Equal(t, []byte("pooltest"), data)
},
1,
)

lg, _ := logger.CreateLoggerForTest(t)
cli, srv := net.Pipe()
cliIO := NewPacketIO(cli, lg, DefaultConnBufferSize*2)
srvIO := NewPacketIO(srv, lg, DefaultConnBufferSize*2)
brw, ok := cliIO.readWriter.(*basicReadWriter)
require.True(t, ok)
require.True(t, brw.pooled, "pooled flag should always be true")
_ = cliIO.Close()
_ = srvIO.Close()

testTCPConn(t,
func(t *testing.T, cli *packetIO) {
require.NoError(t, cli.Close())
require.NoError(t, cli.Close())
},
func(t *testing.T, srv *packetIO) {
require.NoError(t, srv.Close())
require.NoError(t, srv.Close())
},
1,
)

for i := 0; i < 100; i++ {
c1, c2 := net.Pipe()
p1 := NewPacketIO(c1, lg, DefaultConnBufferSize)
p2 := NewPacketIO(c2, lg, DefaultConnBufferSize)
_ = p1.Close()
_ = p2.Close()
}
}

func TestPoolSizeMismatch(t *testing.T) {
lg, _ := logger.CreateLoggerForTest(t)

for i := 0; i < 10; i++ {
c1, c2 := net.Pipe()
p1 := NewPacketIO(c1, lg, DefaultConnBufferSize)
p2 := NewPacketIO(c2, lg, DefaultConnBufferSize)
_ = p1.Close()
_ = p2.Close()
}

customSize := DefaultConnBufferSize * 2
c1, c2 := net.Pipe()
p1 := NewPacketIO(c1, lg, customSize)
p2 := NewPacketIO(c2, lg, customSize)
brw1, ok := p1.readWriter.(*basicReadWriter)
require.True(t, ok)
require.True(t, brw1.pooled, "pooled should always be true")
require.Equal(t, customSize, brw1.ReadWriter.Reader.Size(), "reader should have custom size")
require.Equal(t, customSize, brw1.ReadWriter.Writer.Size(), "writer should have custom size")
_ = p1.Close()
_ = p2.Close()

c1, c2 = net.Pipe()
p1 = NewPacketIO(c1, lg, customSize)
p2 = NewPacketIO(c2, lg, customSize)
brw1, ok = p1.readWriter.(*basicReadWriter)
require.True(t, ok)
require.True(t, brw1.pooled)
require.Equal(t, customSize, brw1.ReadWriter.Reader.Size())
_ = p1.Close()
_ = p2.Close()

c1, c2 = net.Pipe()
p1 = NewPacketIO(c1, lg, DefaultConnBufferSize)
p2 = NewPacketIO(c2, lg, DefaultConnBufferSize)
brw1, ok = p1.readWriter.(*basicReadWriter)
require.True(t, ok)
require.True(t, brw1.pooled)
_ = p1.Close()
_ = p2.Close()
}