Skip to content

Commit

Permalink
Make TCP keepalive options configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
luker983 committed Jun 8, 2023
1 parent 915d031 commit 0bd8d00
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 43 deletions.
75 changes: 46 additions & 29 deletions src/cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,21 @@ import (
)

type serveCmdConfig struct {
configFile string
clientAddr4E2EE string
clientAddr6E2EE string
clientAddr4Relay string
clientAddr6Relay string
quiet bool
debug bool
simple bool
logging bool
logFile string
catchTimeout uint
connTimeout uint
configFile string
clientAddr4E2EE string
clientAddr6E2EE string
clientAddr4Relay string
clientAddr6Relay string
quiet bool
debug bool
simple bool
logging bool
logFile string
catchTimeout uint
connTimeout uint
keepaliveIdle uint
keepaliveCount uint
keepaliveInterval uint
}

type wiretapDefaultConfig struct {
Expand All @@ -59,18 +62,21 @@ type wiretapDefaultConfig struct {

// Defaults for serve command.
var serveCmd = serveCmdConfig{
configFile: "",
clientAddr4E2EE: ClientE2EESubnet4.Addr().Next().String(),
clientAddr6E2EE: ClientE2EESubnet6.Addr().Next().String(),
clientAddr4Relay: ClientRelaySubnet4.Addr().Next().Next().String(),
clientAddr6Relay: ClientRelaySubnet6.Addr().Next().Next().String(),
quiet: false,
debug: false,
simple: false,
logging: false,
logFile: "wiretap.log",
catchTimeout: 5000,
connTimeout: 5000,
configFile: "",
clientAddr4E2EE: ClientE2EESubnet4.Addr().Next().String(),
clientAddr6E2EE: ClientE2EESubnet6.Addr().Next().String(),
clientAddr4Relay: ClientRelaySubnet4.Addr().Next().Next().String(),
clientAddr6Relay: ClientRelaySubnet6.Addr().Next().Next().String(),
quiet: false,
debug: false,
simple: false,
logging: false,
logFile: "wiretap.log",
catchTimeout: 5 * 1000,
connTimeout: 5 * 1000,
keepaliveIdle: 60,
keepaliveCount: 3,
keepaliveInterval: 60,
}

var wiretapDefault = wiretapDefaultConfig{
Expand Down Expand Up @@ -112,6 +118,9 @@ func init() {
cmd.Flags().StringVarP(&serveCmd.logFile, "log-file", "o", serveCmd.logFile, "write log to this filename")
cmd.Flags().UintVarP(&serveCmd.catchTimeout, "completion-timeout", "", serveCmd.catchTimeout, "time in ms for client to complete TCP connection to server")
cmd.Flags().UintVarP(&serveCmd.connTimeout, "conn-timeout", "", serveCmd.connTimeout, "time in ms for server to wait for outgoing TCP handshakes to complete")
cmd.Flags().UintVarP(&serveCmd.connTimeout, "keepalive-idle", "", serveCmd.keepaliveIdle, "time in seconds before TCP keepalives are sent to client")
cmd.Flags().UintVarP(&serveCmd.connTimeout, "keepalive-interval", "", serveCmd.keepaliveInterval, "time in seconds between TCP keepalives")
cmd.Flags().UintVarP(&serveCmd.connTimeout, "keepalive-count", "", serveCmd.keepaliveCount, "number of unacknowledged TCP keepalives before closing connection")

cmd.Flags().StringVarP(&serveCmd.clientAddr4Relay, "ipv4-relay-client", "", serveCmd.clientAddr4Relay, "ipv4 relay address of client")
cmd.Flags().StringVarP(&serveCmd.clientAddr6Relay, "ipv6-relay-client", "", serveCmd.clientAddr6Relay, "ipv6 relay address of client")
Expand Down Expand Up @@ -219,6 +228,11 @@ func init() {
"api",
"keepalive",
"mtu",
"conn-timeout",
"completion-timeout",
"keepalive-interval",
"keepalive-count",
"keepalive-idle",
} {
err := cmd.Flags().MarkHidden(f)
if err != nil {
Expand Down Expand Up @@ -429,11 +443,14 @@ func (c serveCmdConfig) Run() {
lock.Lock()
go func() {
config := tcp.TcpConfig{
CatchTimeout: time.Duration(c.catchTimeout) * time.Millisecond,
ConnTimeout: time.Duration(c.connTimeout) * time.Millisecond,
Ipv4Addr: ipv4Addr,
Ipv6Addr: ipv6Addr,
Port: 1337,
CatchTimeout: time.Duration(c.catchTimeout) * time.Millisecond,
ConnTimeout: time.Duration(c.connTimeout) * time.Millisecond,
KeepaliveIdle: time.Duration(c.keepaliveIdle) * time.Second,
KeepaliveInterval: time.Duration(c.keepaliveInterval) * time.Second,
KeepaliveCount: int(c.keepaliveCount),
Ipv4Addr: ipv4Addr,
Ipv6Addr: ipv6Addr,
Port: 1337,
}
tcp.Handle(transportHandler, config, &lock)
wg.Done()
Expand Down
19 changes: 13 additions & 6 deletions src/transport/tcp/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)

const keepaliveIdleDefault = 1 * time.Minute
const keepaliveCountDefault = 2

// Address conversion adapted from https://git.zx2c4.com/wireguard-go/tree/tun/netstack/tun.go.
/* SPDX-License-Identifier: MIT
*
Expand Down Expand Up @@ -291,7 +288,7 @@ func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn {

// Changed from original:
// AcceptFrom is identical to Accept except that it also returns the Remote Address as seen by the endpoint.
func (l *TCPListener) AcceptFrom() (net.Conn, net.Addr, error) {
func (l *TCPListener) AcceptFrom(c *TcpConfig) (net.Conn, net.Addr, error) {
remoteAddr := tcpip.FullAddress{}
n, wq, err := l.ep.Accept(&remoteAddr)

Expand Down Expand Up @@ -327,7 +324,7 @@ func (l *TCPListener) AcceptFrom() (net.Conn, net.Addr, error) {

// Enable keepalive and set defaults so that after (idle + (count * interval)) connection will be dropped if unresponsive.
n.SocketOptions().SetKeepAlive(true)
keepaliveIdle := tcpip.KeepaliveIdleOption(keepaliveIdleDefault)
keepaliveIdle := tcpip.KeepaliveIdleOption(c.KeepaliveIdle)
err = n.SetSockOpt(&keepaliveIdle)
if err != nil {
return nil, nil, &net.OpError{
Expand All @@ -337,7 +334,17 @@ func (l *TCPListener) AcceptFrom() (net.Conn, net.Addr, error) {
Err: errors.New(err.String()),
}
}
err = n.SetSockOptInt(tcpip.KeepaliveCountOption, keepaliveCountDefault)
keepaliveInterval := tcpip.KeepaliveIntervalOption(c.KeepaliveInterval)
err = n.SetSockOpt(&keepaliveInterval)
if err != nil {
return nil, nil, &net.OpError{
Op: "accept",
Net: "tcp",
Addr: l.Addr(),
Err: errors.New(err.String()),
}
}
err = n.SetSockOptInt(tcpip.KeepaliveCountOption, c.KeepaliveCount)
if err != nil {
return nil, nil, &net.OpError{
Op: "accept",
Expand Down
19 changes: 11 additions & 8 deletions src/transport/tcp/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,14 @@ import (

// Configure TCP handler.
type TcpConfig struct {
CatchTimeout time.Duration
ConnTimeout time.Duration
Ipv4Addr netip.Addr
Ipv6Addr netip.Addr
Port uint16
Ipv4Addr netip.Addr
Ipv6Addr netip.Addr
Port uint16
CatchTimeout time.Duration
ConnTimeout time.Duration
KeepaliveIdle time.Duration
KeepaliveInterval time.Duration
KeepaliveCount int
}

// tcpConn tracks a connection, source and destination IP and Port.
Expand Down Expand Up @@ -227,11 +230,11 @@ func Handle(tnet *netstack.Net, config TcpConfig, lock *sync.Mutex) {
}
}()

go startListener(tnet, s.IPTables(), &net.TCPAddr{Port: int(config.Port)}, config.Ipv4Addr, config.Ipv6Addr, s)
go startListener(tnet, s.IPTables(), &net.TCPAddr{Port: int(config.Port)}, config.Ipv4Addr, config.Ipv6Addr, s, &config)
}

// startListener accepts connections from WireGuard peer.
func startListener(tnet *netstack.Net, tables *stack.IPTables, listenAddr *net.TCPAddr, localAddr4 netip.Addr, localAddr6 netip.Addr, s *stack.Stack) {
func startListener(tnet *netstack.Net, tables *stack.IPTables, listenAddr *net.TCPAddr, localAddr4 netip.Addr, localAddr6 netip.Addr, s *stack.Stack, c *TcpConfig) {
// Workaround to get true remote address even when connection closes prematurely.
l, err := listenTCP(s, listenAddr)
if err != nil {
Expand All @@ -243,7 +246,7 @@ func startListener(tnet *netstack.Net, tables *stack.IPTables, listenAddr *net.T
log.Println("Transport: TCP listener up")
for {
// Every TCP connection gets accepted here, modified Accept function sets correct remote address.
c, remoteAddr, err := l.AcceptFrom()
c, remoteAddr, err := l.AcceptFrom(c)
if err != nil || remoteAddr == nil {
log.Println("Failed to accept connection:", err)
continue
Expand Down

0 comments on commit 0bd8d00

Please sign in to comment.