Skip to content

Commit

Permalink
Improve: tunnel/tcp pipe (#219)
Browse files Browse the repository at this point in the history
Co-authored-by: xjasonlyu <xjasonlyu@gmail.com>
  • Loading branch information
nange and xjasonlyu committed Apr 3, 2023
1 parent 61a9d26 commit 2d0bd1d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 28 deletions.
15 changes: 15 additions & 0 deletions tunnel/statistic/tracker.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package statistic

import (
"errors"
"net"
"time"

Expand Down Expand Up @@ -79,6 +80,20 @@ func (tt *tcpTracker) Close() error {
return tt.Conn.Close()
}

func (tt *tcpTracker) CloseRead() error {
if cr, ok := tt.Conn.(interface{ CloseRead() error }); ok {
return cr.CloseRead()
}
return errors.New("CloseRead is not implemented")
}

func (tt *tcpTracker) CloseWrite() error {
if cw, ok := tt.Conn.(interface{ CloseWrite() error }); ok {
return cw.CloseWrite()
}
return errors.New("CloseWrite is not implemented")
}

type udpTracker struct {
net.PacketConn `json:"-"`

Expand Down
47 changes: 19 additions & 28 deletions tunnel/tcp.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package tunnel

import (
"errors"
"io"
"net"
"sync"
Expand Down Expand Up @@ -43,42 +42,34 @@ func handleTCPConn(originConn adapter.TCPConn) {
defer remoteConn.Close()

log.Infof("[TCP] %s <-> %s", metadata.SourceAddress(), metadata.DestinationAddress())
if err = pipe(originConn, remoteConn); err != nil {
log.Debugf("[TCP] %s <-> %s: %v", metadata.SourceAddress(), metadata.DestinationAddress(), err)
}
pipe(originConn, remoteConn)
}

// pipe copies copy data to & from provided net.Conn(s) bidirectionally.
func pipe(origin, remote net.Conn) error {
func pipe(origin, remote net.Conn) {
wg := sync.WaitGroup{}
wg.Add(2)

var leftErr, rightErr error

go func() {
defer wg.Done()
if err := copyBuffer(remote, origin); err != nil {
leftErr = errors.Join(leftErr, err)
}
remote.SetReadDeadline(time.Now().Add(tcpWaitTimeout))
}()

go func() {
defer wg.Done()
if err := copyBuffer(origin, remote); err != nil {
rightErr = errors.Join(rightErr, err)
}
origin.SetReadDeadline(time.Now().Add(tcpWaitTimeout))
}()
go unidirectionalStream(remote, origin, "origin->remote", &wg)
go unidirectionalStream(origin, remote, "remote->origin", &wg)

wg.Wait()
return errors.Join(leftErr, rightErr)
}

func copyBuffer(dst io.Writer, src io.Reader) error {
func unidirectionalStream(dst, src net.Conn, dir string, wg *sync.WaitGroup) {
defer wg.Done()
buf := pool.Get(pool.RelayBufferSize)
defer pool.Put(buf)

_, err := io.CopyBuffer(dst, src, buf)
return err
if _, err := io.CopyBuffer(dst, src, buf); err != nil {
log.Debugf("[TCP] copy data for %s: %v", dir, err)
}
pool.Put(buf)
// Do the upload/download side TCP half-close.
if cr, ok := src.(interface{ CloseRead() error }); ok {
cr.CloseRead()
}
if cw, ok := dst.(interface{ CloseWrite() error }); ok {
cw.CloseWrite()
}
// Set TCP half-close timeout.
dst.SetReadDeadline(time.Now().Add(tcpWaitTimeout))
}

0 comments on commit 2d0bd1d

Please sign in to comment.