Skip to content

Commit

Permalink
Add timer to reap connections not used by client
Browse files Browse the repository at this point in the history
  • Loading branch information
luker983 committed May 31, 2023
1 parent da1d233 commit 08673a9
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 23 deletions.
23 changes: 19 additions & 4 deletions src/cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"os"
"strings"
"sync"
"time"

"github.com/spf13/cobra"
"github.com/spf13/viper"
Expand Down Expand Up @@ -38,6 +39,8 @@ type serveCmdConfig struct {
simple bool
logging bool
logFile string
catchTimeout uint
connTimeout uint
}

type wiretapDefaultConfig struct {
Expand Down Expand Up @@ -66,6 +69,8 @@ var serveCmd = serveCmdConfig{
simple: false,
logging: false,
logFile: "wiretap.log",
catchTimeout: 1000,
connTimeout: 1000,
}

var wiretapDefault = wiretapDefaultConfig{
Expand Down Expand Up @@ -105,6 +110,8 @@ func init() {
cmd.Flags().BoolVarP(&serveCmd.simple, "simple", "", serveCmd.simple, "disable multihop and multiclient features for a simpler setup")
cmd.Flags().BoolVarP(&serveCmd.logging, "log", "l", serveCmd.logging, "enable logging to file")
cmd.Flags().StringVarP(&serveCmd.logFile, "log-file", "o", serveCmd.logFile, "write log to this filename")
cmd.Flags().UintVarP(&serveCmd.catchTimeout, "completion-timeout", "", serveCmd.catchTimeout, "time in ms for client to complete TCP connection to server")
cmd.Flags().UintVarP(&serveCmd.connTimeout, "conn-timeout", "", serveCmd.connTimeout, "time in ms for server to wait for outgoing TCP handshakes to complete")

cmd.Flags().StringVarP(&serveCmd.clientAddr4Relay, "ipv4-relay-client", "", serveCmd.clientAddr4Relay, "ipv4 relay address of client")
cmd.Flags().StringVarP(&serveCmd.clientAddr6Relay, "ipv6-relay-client", "", serveCmd.clientAddr6Relay, "ipv6 relay address of client")
Expand Down Expand Up @@ -298,9 +305,10 @@ func (c serveCmdConfig) Run() {
ListenPort: E2EEPort,
Peers: []peer.PeerConfigArgs{
{
PublicKey: viper.GetString("E2EE.Peer.publickey"),
Endpoint: viper.GetString("E2EE.Peer.endpoint"),
AllowedIPs: []string{c.clientAddr4E2EE + "/32", c.clientAddr6E2EE + "/128"},
PublicKey: viper.GetString("E2EE.Peer.publickey"),
Endpoint: viper.GetString("E2EE.Peer.endpoint"),
AllowedIPs: []string{c.clientAddr4E2EE + "/32", c.clientAddr6E2EE + "/128"},
PersistentKeepaliveInterval: viper.GetInt("Relay.Peer.keepalive"),
},
},
Addresses: []string{viper.GetString("E2EE.Interface.ipv4") + "/32", viper.GetString("E2EE.Interface.ipv6") + "/128", viper.GetString("E2EE.Interface.api") + "/128"},
Expand Down Expand Up @@ -419,7 +427,14 @@ func (c serveCmdConfig) Run() {
wg.Add(1)
lock.Lock()
go func() {
tcp.Handle(transportHandler, ipv4Addr, ipv6Addr, 1337, &lock)
config := tcp.TcpConfig{
CatchTimeout: time.Duration(c.catchTimeout) * time.Millisecond,
ConnTimeout: time.Duration(c.connTimeout) * time.Millisecond,
Ipv4Addr: ipv4Addr,
Ipv6Addr: ipv6Addr,
Port: 1337,
}
tcp.Handle(transportHandler, config, &lock)
wg.Done()
}()

Expand Down
45 changes: 27 additions & 18 deletions src/transport/tcp/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,14 @@ import (
"wiretap/transport"
)

// How much time the client has to complete the TCP handshake before connection is dropped.
const catchTimeout = time.Duration(1000) * time.Millisecond
// Configure TCP handler.
type TcpConfig struct {
CatchTimeout time.Duration
ConnTimeout time.Duration
Ipv4Addr netip.Addr
Ipv6Addr netip.Addr
Port uint16
}

// tcpConn tracks a connection, source and destination IP and Port.
type tcpConn struct {
Expand All @@ -54,6 +60,7 @@ var isOpenLock = sync.RWMutex{}
type preroutingMatch struct {
pktChan chan stack.PacketBufferPtr
endpoint *channel.Endpoint
config *TcpConfig
}

// Match looks for SYN packets (start of a tcp conn). Before proxying connection, we need to check
Expand All @@ -75,9 +82,9 @@ func (m preroutingMatch) Match(hook stack.Hook, packet stack.PacketBufferPtr, in
ctrack, ok := isOpen[c]
isOpenLock.RUnlock()

// If not in conn map, drop this packet for now, but clone so it can
// be reinjected if connections are successful.
if !ok {
// If not in conn map, drop this packet for now, but clone so it can
// be reinjected if connections are successful.
isOpenLock.Lock()
// In progress, but not ready to forward SYN packets yet.
isOpen[c] = connTrack{
Expand All @@ -87,17 +94,18 @@ func (m preroutingMatch) Match(hook stack.Hook, packet stack.PacketBufferPtr, in

packetClone := packet.Clone()
go func() {
checkIfOpen(c, m.pktChan, packetClone, m.endpoint)
checkIfOpen(c, m, packetClone)
packetClone.DecRef()
}()

// Hotdrop because we're taking control of the packet.
return false, true
// Already checking if port is open. Do nothing.
} else if ctrack.Connecting {
// Already checking if port is open. Do nothing.
return false, false
// Connection is verified to be open. Allow this connection and reset conn map.
} else {
// Connection is verified to be open. Allow this connection and reset conn map.

return true, false
}
}
Expand All @@ -111,9 +119,9 @@ func (m preroutingMatch) Match(hook stack.Hook, packet stack.PacketBufferPtr, in
}

// If destination is open, whitelist and reinject. Otherwise send reset.
func checkIfOpen(conn tcpConn, pktChan chan stack.PacketBufferPtr, packet stack.PacketBufferPtr, endpoint *channel.Endpoint) {
func checkIfOpen(conn tcpConn, m preroutingMatch, packet stack.PacketBufferPtr) {
log.Printf("(client %v) - Transport: TCP -> %v", conn.Source, conn.Dest)
c, err := net.Dial("tcp", conn.Dest)
c, err := net.DialTimeout("tcp", conn.Dest, m.config.ConnTimeout)
if err != nil {
//log.Printf("Error connecting to %s: %s\n", conn.Dest, err)

Expand All @@ -122,7 +130,7 @@ func checkIfOpen(conn tcpConn, pktChan chan stack.PacketBufferPtr, packet stack.
if syserr, ok := oerr.Err.(*os.SyscallError); ok {
if syserr.Err == syscall.ECONNREFUSED {
//log.Println("Connection refused, sending reset")
pktChan <- packet.Clone()
m.pktChan <- packet.Clone()
}
}
}
Expand All @@ -146,7 +154,7 @@ func checkIfOpen(conn tcpConn, pktChan chan stack.PacketBufferPtr, packet stack.
// Start "catch" timer to make sure connection is actually used.
go func() {
select {
case <-time.After(catchTimeout):
case <-time.After(m.config.CatchTimeout):
c.Close()
case <-caughtChan:
}
Expand All @@ -160,12 +168,12 @@ func checkIfOpen(conn tcpConn, pktChan chan stack.PacketBufferPtr, packet stack.
new_packet := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: packet.ToBuffer(),
})
endpoint.InjectInbound(netProto, new_packet)
m.endpoint.InjectInbound(netProto, new_packet)
}

// Handle creates a DNAT rule that forwards destination packets to a tcp listener.
// Once a connection is accepted, it gets handed off to handleConn().
func Handle(tnet *netstack.Net, ipv4Addr netip.Addr, ipv6Addr netip.Addr, port uint16, lock *sync.Mutex) {
func Handle(tnet *netstack.Net, config TcpConfig, lock *sync.Mutex) {
s := tnet.Stack()

// Create iptables rule.
Expand All @@ -177,14 +185,15 @@ func Handle(tnet *netstack.Net, ipv4Addr netip.Addr, ipv6Addr netip.Addr, port u
match := preroutingMatch{
pktChan: make(chan stack.PacketBufferPtr, 1),
endpoint: tnet.Endpoint(),
config: &config,
}

rule4 := stack.Rule{
Filter: headerFilter,
Matchers: []stack.Matcher{match},
Target: &stack.DNATTarget{
Addr: tcpip.Address(ipv4Addr.AsSlice()),
Port: port,
Addr: tcpip.Address(config.Ipv4Addr.AsSlice()),
Port: config.Port,
NetworkProtocol: ipv4.ProtocolNumber,
},
}
Expand All @@ -193,8 +202,8 @@ func Handle(tnet *netstack.Net, ipv4Addr netip.Addr, ipv6Addr netip.Addr, port u
Filter: headerFilter,
Matchers: []stack.Matcher{match},
Target: &stack.DNATTarget{
Addr: tcpip.Address(ipv6Addr.AsSlice()),
Port: port,
Addr: tcpip.Address(config.Ipv6Addr.AsSlice()),
Port: config.Port,
NetworkProtocol: ipv6.ProtocolNumber,
},
}
Expand All @@ -215,7 +224,7 @@ func Handle(tnet *netstack.Net, ipv4Addr netip.Addr, ipv6Addr netip.Addr, port u
}
}()

go startListener(tnet, s.IPTables(), &net.TCPAddr{Port: int(port)}, ipv4Addr, ipv6Addr, s)
go startListener(tnet, s.IPTables(), &net.TCPAddr{Port: int(config.Port)}, config.Ipv4Addr, config.Ipv6Addr, s)
}

// startListener accepts connections from WireGuard peer.
Expand Down
2 changes: 1 addition & 1 deletion wiretap.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ ARG https_proxy

# Utilities for testing
RUN apt-get update
RUN apt-get install net-tools nmap dnsutils tcpdump iproute2 vim netcat iputils-ping wireguard iperf xsel -y
RUN apt-get install net-tools nmap dnsutils tcpdump iproute2 vim netcat iputils-ping wireguard iperf xsel masscan -y

WORKDIR /wiretap
COPY ./src/go.mod ./src/go.sum ./
Expand Down

0 comments on commit 08673a9

Please sign in to comment.