Skip to content

Commit

Permalink
feat: implement a new API to convert and add net.Conn into gnet.Client
Browse files Browse the repository at this point in the history
Fixes #362
  • Loading branch information
panjf2000 committed Apr 28, 2022
1 parent 50406b3 commit c296922
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 26 deletions.
45 changes: 23 additions & 22 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"errors"
"net"
"strconv"
"strings"
"sync"
"syscall"

Expand Down Expand Up @@ -141,6 +140,11 @@ func (cli *Client) Dial(network, address string) (Conn, error) {
if err != nil {
return nil, err
}
return cli.Enroll(c)
}

// Enroll converts a net.Conn to gnet.Conn and then adds it into Client.
func (cli *Client) Enroll(c net.Conn) (Conn, error) {
defer c.Close()

sc, ok := c.(syscall.Conn)
Expand All @@ -152,9 +156,9 @@ func (cli *Client) Dial(network, address string) (Conn, error) {
return nil, errors.New("failed to get syscall.RawConn from net.Conn")
}

var DupFD int
var dupFD int
e := rc.Control(func(fd uintptr) {
DupFD, err = unix.Dup(int(fd))
dupFD, err = unix.Dup(int(fd))
})
if err != nil {
return nil, err
Expand All @@ -163,26 +167,13 @@ func (cli *Client) Dial(network, address string) (Conn, error) {
return nil, e
}

if strings.HasPrefix(network, "tcp") {
if cli.opts.TCPNoDelay == TCPDelay {
if err = socket.SetNoDelay(DupFD, 0); err != nil {
return nil, err
}
}
if cli.opts.TCPKeepAlive > 0 {
if err = socket.SetKeepAlivePeriod(DupFD, int(cli.opts.TCPKeepAlive.Seconds())); err != nil {
return nil, err
}
}
}

if cli.opts.SocketSendBuffer > 0 {
if err = socket.SetSendBuffer(DupFD, cli.opts.SocketSendBuffer); err != nil {
if err = socket.SetSendBuffer(dupFD, cli.opts.SocketSendBuffer); err != nil {
return nil, err
}
}
if cli.opts.SocketRecvBuffer > 0 {
if err = socket.SetRecvBuffer(DupFD, cli.opts.SocketRecvBuffer); err != nil {
if err = socket.SetRecvBuffer(dupFD, cli.opts.SocketRecvBuffer); err != nil {
return nil, err
}
}
Expand All @@ -197,18 +188,28 @@ func (cli *Client) Dial(network, address string) (Conn, error) {
return nil, err
}
ua := c.LocalAddr().(*net.UnixAddr)
ua.Name = c.RemoteAddr().String() + "." + strconv.Itoa(DupFD)
gc = newTCPConn(DupFD, cli.el, sockAddr, c.LocalAddr(), c.RemoteAddr())
ua.Name = c.RemoteAddr().String() + "." + strconv.Itoa(dupFD)
gc = newTCPConn(dupFD, cli.el, sockAddr, c.LocalAddr(), c.RemoteAddr())
case *net.TCPConn:
if cli.opts.TCPNoDelay == TCPDelay {
if err = socket.SetNoDelay(dupFD, 0); err != nil {
return nil, err
}
}
if cli.opts.TCPKeepAlive > 0 {
if err = socket.SetKeepAlivePeriod(dupFD, int(cli.opts.TCPKeepAlive.Seconds())); err != nil {
return nil, err
}
}
if sockAddr, _, _, _, err = socket.GetTCPSockAddr(c.RemoteAddr().Network(), c.RemoteAddr().String()); err != nil {
return nil, err
}
gc = newTCPConn(DupFD, cli.el, sockAddr, c.LocalAddr(), c.RemoteAddr())
gc = newTCPConn(dupFD, cli.el, sockAddr, c.LocalAddr(), c.RemoteAddr())
case *net.UDPConn:
if sockAddr, _, _, _, err = socket.GetUDPSockAddr(c.RemoteAddr().Network(), c.RemoteAddr().String()); err != nil {
return nil, err
}
gc = newUDPConn(DupFD, cli.el, c.LocalAddr(), sockAddr, true)
gc = newUDPConn(dupFD, cli.el, c.LocalAddr(), sockAddr, true)
default:
return nil, gerrors.ErrUnsupportedProtocol
}
Expand Down
24 changes: 20 additions & 4 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package gnet

import (
"math/rand"
"net"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -82,7 +83,7 @@ func (ev *clientEvents) OnTick() (delay time.Duration, action Action) {
}

func TestServeWithGnetClient(t *testing.T) {
// start a engine
// start an engine
// connect 10 clients
// each client will pipe random data for 1-3 seconds.
// the writes to the engine will be random sizes. 0KB - 1MB.
Expand Down Expand Up @@ -266,7 +267,11 @@ func (s *testClientServer) OnTick() (delay time.Duration, action Action) {
if atomic.CompareAndSwapInt32(&s.started, 0, 1) {
for i := 0; i < s.nclients; i++ {
atomic.AddInt32(&s.clientActive, 1)
go startGnetClient(s.tester, s.client, s.clientEV, s.network, s.addr, s.multicore, s.async)
var netConn bool
if i%2 == 0 {
netConn = true
}
go startGnetClient(s.tester, s.client, s.clientEV, s.network, s.addr, s.multicore, s.async, netConn)
}
}
if s.network == "udp" && atomic.LoadInt32(&s.clientActive) == 0 {
Expand Down Expand Up @@ -311,9 +316,20 @@ func testServeWithGnetClient(t *testing.T, network, addr string, reuseport, reus
assert.NoError(t, err)
}

func startGnetClient(t *testing.T, cli *Client, ev *clientEvents, network, addr string, multicore, async bool) {
func startGnetClient(t *testing.T, cli *Client, ev *clientEvents, network, addr string, multicore, async, netDial bool) {
rand.Seed(time.Now().UnixNano())
c, err := cli.Dial(network, addr)
var (
c Conn
err error
)
if netDial {
var netConn net.Conn
netConn, err = net.Dial(network, addr)
require.NoError(t, err)
c, err = cli.Enroll(netConn)
} else {
c, err = cli.Dial(network, addr)
}
require.NoError(t, err)
defer c.Close()
var rspCh chan []byte
Expand Down

0 comments on commit c296922

Please sign in to comment.