Skip to content

Commit

Permalink
opt: refactor the logic of handling UDP sockets
Browse files Browse the repository at this point in the history
  • Loading branch information
panjf2000 committed Dec 4, 2021
1 parent 4f2cfa3 commit d72d3de
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 36 deletions.
2 changes: 1 addition & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func NewClient(eventHandler EventHandler, opts ...Option) (cli *Client, err erro
options.ReadBufferCap = toolkit.CeilToPowerOfTwo(rbc)
}
el.buffer = make([]byte, options.ReadBufferCap)
el.udpSockets = make(map[int]*conn)
el.clientUDPSockets = make(map[int]*conn)
el.connections = make(map[int]*conn)
el.eventHandler = eventHandler
cli.el = el
Expand Down
14 changes: 7 additions & 7 deletions connection_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ import (

type conn struct {
fd int // file descriptor
sa unix.Sockaddr // remote socket address
ctx interface{} // user-defined context
peer unix.Sockaddr // remote socket address
loop *eventloop // connected event-loop
codec ICodec // codec for TCP
opened bool // connection opened event fired
Expand All @@ -50,7 +50,7 @@ type conn struct {
func newTCPConn(fd int, el *eventloop, sa unix.Sockaddr, codec ICodec, localAddr, remoteAddr net.Addr) (c *conn) {
c = &conn{
fd: fd,
sa: sa,
peer: sa,
loop: el,
codec: codec,
localAddr: localAddr,
Expand All @@ -65,7 +65,7 @@ func newTCPConn(fd int, el *eventloop, sa unix.Sockaddr, codec ICodec, localAddr

func (c *conn) releaseTCP() {
c.opened = false
c.sa = nil
c.peer = nil
c.ctx = nil
c.localAddr = nil
c.remoteAddr = nil
Expand All @@ -79,13 +79,13 @@ func (c *conn) releaseTCP() {
func newUDPConn(fd int, el *eventloop, localAddr net.Addr, sa unix.Sockaddr, connected bool) (c *conn) {
c = &conn{
fd: fd,
sa: sa,
peer: sa,
loop: el,
localAddr: localAddr,
remoteAddr: socket.SockaddrToUDPAddr(sa),
}
if connected {
c.sa = nil
c.peer = nil
}
return
}
Expand Down Expand Up @@ -166,10 +166,10 @@ func (c *conn) asyncWrite(itf interface{}) error {
func (c *conn) sendTo(buf []byte) error {
c.loop.eventHandler.PreWrite(c)
defer c.loop.eventHandler.AfterWrite(c, buf)
if c.sa == nil {
if c.peer == nil {
return unix.Send(c.fd, buf, 0)
}
return unix.Sendto(c.fd, buf, 0, c.sa)
return unix.Sendto(c.fd, buf, 0, c.peer)
}

// ================================== Non-concurrency-safe API's ==================================
Expand Down
47 changes: 27 additions & 20 deletions eventloop_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,16 @@ import (
)

type eventloop struct {
ln *listener // listener
idx int // loop index in the server loops list
svr *server // server in loop
poller *netpoll.Poller // epoll or kqueue
buffer []byte // read packet buffer whose capacity is set by user, default value is 64KB
connCount int32 // number of active connections in event-loop
udpSockets map[int]*conn // UDP socket map: fd -> conn
connections map[int]*conn // TCP connection map: fd -> conn
eventHandler EventHandler // user eventHandler
ln *listener // listener
idx int // loop index in the server loops list
svr *server // server in loop
poller *netpoll.Poller // epoll or kqueue
buffer []byte // read packet buffer whose capacity is set by user, default value is 64KB
connCount int32 // number of active connections in event-loop
connections map[int]*conn // TCP connection map: fd -> conn
eventHandler EventHandler // user eventHandler
clientUDPSockets map[int]*conn // client-side UDP socket map: fd -> conn
serverUDPSockets map[unix.Sockaddr]*conn // server-side UDP socket map: Sockaddr -> conn
}

func (el *eventloop) getLogger() logging.Logger {
Expand All @@ -58,11 +59,17 @@ func (el *eventloop) loadConn() int32 {
return atomic.LoadInt32(&el.connCount)
}

func (el *eventloop) closeAllConns() {
func (el *eventloop) closeAllSockets() {
// Close loops and all outstanding connections
for _, c := range el.connections {
_ = el.loopCloseConn(c, nil)
}
for _, c := range el.clientUDPSockets {
c.releaseUDP()
}
for _, c := range el.serverUDPSockets {
c.releaseUDP()
}
}

func (el *eventloop) loopRegister(itf interface{}) error {
Expand All @@ -76,7 +83,7 @@ func (el *eventloop) loopRegister(itf interface{}) error {
c.releaseUDP()
return err
}
el.udpSockets[c.fd] = c
el.clientUDPSockets[c.fd] = c
return nil
}
if err := el.poller.AddRead(c.pollAttachment); err != nil {
Expand Down Expand Up @@ -277,11 +284,15 @@ func (el *eventloop) loopReadUDP(fd int) error {
return fmt.Errorf("failed to read UDP packet from fd=%d in event-loop(%d), %v",
fd, el.idx, os.NewSyscallError("recvfrom", err))
}
c := el.udpSockets[fd]
var oneOff bool
if c == nil {
c = newUDPConn(fd, el, el.ln.lnaddr, sa, false)
oneOff = true
var c *conn
if fd == el.ln.fd {
c = el.serverUDPSockets[sa]
if c == nil {
c = newUDPConn(fd, el, el.ln.lnaddr, sa, false)
el.serverUDPSockets[sa] = c
}
} else {
c = el.clientUDPSockets[fd]
}
out, action := el.eventHandler.React(el.buffer[:n], c)
if out != nil {
Expand All @@ -290,9 +301,5 @@ func (el *eventloop) loopReadUDP(fd int) error {
if action == Shutdown {
return gerrors.ErrServerShutdown
}
if oneOff {
c.releaseUDP()
}

return nil
}
4 changes: 2 additions & 2 deletions reactor_default_bsd.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (el *eventloop) activateSubReactor(lockOSThread bool) {
}

defer func() {
el.closeAllConns()
el.closeAllSockets()
el.svr.signalShutdown()
}()

Expand Down Expand Up @@ -81,7 +81,7 @@ func (el *eventloop) loopRun(lockOSThread bool) {
}

defer func() {
el.closeAllConns()
el.closeAllSockets()
el.ln.close()
el.svr.signalShutdown()
}()
Expand Down
4 changes: 2 additions & 2 deletions reactor_default_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (el *eventloop) activateSubReactor(lockOSThread bool) {
}

defer func() {
el.closeAllConns()
el.closeAllSockets()
el.svr.signalShutdown()
}()

Expand Down Expand Up @@ -96,7 +96,7 @@ func (el *eventloop) loopRun(lockOSThread bool) {
}

defer func() {
el.closeAllConns()
el.closeAllSockets()
el.ln.close()
el.svr.signalShutdown()
}()
Expand Down
4 changes: 2 additions & 2 deletions reactor_optimized_bsd.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (el *eventloop) activateSubReactor(lockOSThread bool) {
}

defer func() {
el.closeAllConns()
el.closeAllSockets()
el.svr.signalShutdown()
}()

Expand All @@ -66,7 +66,7 @@ func (el *eventloop) loopRun(lockOSThread bool) {
}

defer func() {
el.closeAllConns()
el.closeAllSockets()
el.ln.close()
el.svr.signalShutdown()
}()
Expand Down
4 changes: 2 additions & 2 deletions reactor_optimized_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (el *eventloop) activateSubReactor(lockOSThread bool) {
}

defer func() {
el.closeAllConns()
el.closeAllSockets()
el.svr.signalShutdown()
}()

Expand All @@ -65,7 +65,7 @@ func (el *eventloop) loopRun(lockOSThread bool) {
}

defer func() {
el.closeAllConns()
el.closeAllSockets()
el.ln.close()
el.svr.signalShutdown()
}()
Expand Down
4 changes: 4 additions & 0 deletions server_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"sync"
"sync/atomic"

"golang.org/x/sys/unix"

"github.com/panjf2000/gnet/internal/netpoll"
"github.com/panjf2000/gnet/pkg/errors"
)
Expand Down Expand Up @@ -109,6 +111,7 @@ func (svr *server) activateEventLoops(numEventLoop int) (err error) {
el.svr = svr
el.poller = p
el.buffer = make([]byte, svr.opts.ReadBufferCap)
el.serverUDPSockets = make(map[unix.Sockaddr]*conn)
el.connections = make(map[int]*conn)
el.eventHandler = svr.eventHandler
if err = el.poller.AddRead(el.ln.packPollAttachment(el.loopAccept)); err != nil {
Expand Down Expand Up @@ -141,6 +144,7 @@ func (svr *server) activateReactors(numEventLoop int) error {
el.svr = svr
el.poller = p
el.buffer = make([]byte, svr.opts.ReadBufferCap)
el.serverUDPSockets = make(map[unix.Sockaddr]*conn)
el.connections = make(map[int]*conn)
el.eventHandler = svr.eventHandler
svr.lb.register(el)
Expand Down

0 comments on commit d72d3de

Please sign in to comment.