Skip to content

Commit

Permalink
Use reported PIDs for DNS requests and improve data gathering process
Browse files Browse the repository at this point in the history
  • Loading branch information
dhaavi committed Jul 20, 2023
1 parent 5d7caeb commit 41c5266
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 69 deletions.
164 changes: 95 additions & 69 deletions network/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -172,6 +173,9 @@ type Connection struct { //nolint:maligned // TODO: fix alignment
StopTunnel() error
}

RecvBytes uint64
SentBytes uint64

// pkgQueue is used to serialize packet handling for a single
// connection and is served by the connections packetHandler.
pktQueue chan packet.Packet
Expand Down Expand Up @@ -264,24 +268,43 @@ func NewConnectionFromDNSRequest(ctx context.Context, fqdn string, cnames []stri
ipVersion = packet.IPv4
}

// Get Process.
// FIXME: Find direct or redirected connection and grab the PID from there.
// Create packet info for dns request connection.
pi := &packet.Info{
Inbound: false, // outbound as we are looking for the process of the source address
Version: ipVersion,
Protocol: packet.UDP,
Src: localIP, // source as in the process we are looking for
SrcPort: localPort, // source as in the process we are looking for
Dst: nil, // do not record direction
DstPort: 0, // do not record direction
PID: process.UndefinedProcessID,
}

// Check if the dns request connection was reported with process info.
dnsRequestConnID := pi.CreateConnectionID()
// Cut the destination, as the dns request may have been redirected and we
// don't know the original destination.
dnsRequestConnIDPrefix, ok := strings.CutSuffix(dnsRequestConnID, "<nil>-0")
if !ok {
log.Tracer(ctx).Warningf("network: unexpected connection ID for finding dns requests connection: %s", dnsRequestConnID)
}
// Find matching dns request connection.
dnsRequestConn, ok := conns.findByPrefix(dnsRequestConnIDPrefix)
if ok && dnsRequestConn.PID != process.UndefinedProcessID {
log.Tracer(ctx).Debugf("network: found matching dns request connection %s", dnsRequestConn)
pi.PID = dnsRequestConn.PID
}

// Find process by remote IP/Port.
pid, _, _ := process.GetPidOfConnection(
ctx,
&packet.Info{
Inbound: false, // outbound as we are looking for the process of the source address
Version: ipVersion,
Protocol: packet.UDP,
Src: localIP, // source as in the process we are looking for
SrcPort: localPort, // source as in the process we are looking for
Dst: nil, // do not record direction
DstPort: 0, // do not record direction
PID: process.UndefinedProcessID,
},
)
proc, _ := process.GetProcessWithProfile(ctx, pid)
if pi.PID == process.UndefinedProcessID {
pi.PID, _, _ = process.GetPidOfConnection(
ctx,
pi,
)
}

// Get process and profile with PID.
proc, _ := process.GetProcessWithProfile(ctx, pi.PID)

timestamp := time.Now().Unix()
dnsConn := &Connection{
Expand Down Expand Up @@ -378,8 +401,7 @@ func NewIncompleteConnection(pkt packet.Packet) *Connection {
// GatherConnectionInfo gathers information on the process and remote entity.
func (conn *Connection) GatherConnectionInfo(pkt packet.Packet) (err error) {
// Get PID if not yet available.
// FIXME: Only match for UndefinedProcessID when integrations have been updated.
if conn.PID <= 0 {
if conn.PID == process.UndefinedProcessID {
// Get process by looking at the system state tables.
// Apply direction as reported from the state tables.
conn.PID, conn.Inbound, _ = process.GetPidOfConnection(pkt.Ctx(), pkt.Info())
Expand All @@ -390,20 +412,22 @@ func (conn *Connection) GatherConnectionInfo(pkt packet.Packet) (err error) {
if conn.process == nil {
// We got connection from the system.
conn.process, err = process.GetProcessWithProfile(pkt.Ctx(), conn.PID)
if err != nil {
if err == nil {
// Add process/profile metadata for connection.
conn.ProcessContext = getProcessContext(pkt.Ctx(), conn.process)
conn.ProfileRevisionCounter = conn.process.Profile().RevisionCnt()

// Inherit internal status of profile.
if localProfile := conn.process.Profile().LocalProfile(); localProfile != nil {
conn.Internal = localProfile.Internal
}
} else {
conn.process = nil
err = fmt.Errorf("failed to get process and profile of PID %d: %w", conn.PID, err)
log.Tracer(pkt.Ctx()).Debugf("network: %s", err)
return err
}

// Add process/profile metadata for connection.
conn.ProcessContext = getProcessContext(pkt.Ctx(), conn.process)
conn.ProfileRevisionCounter = conn.process.Profile().RevisionCnt()

// Inherit internal status of profile.
if localProfile := conn.process.Profile().LocalProfile(); localProfile != nil {
conn.Internal = localProfile.Internal
if pkt.InfoOnly() {
log.Tracer(pkt.Ctx()).Debugf("network: failed to get process and profile of PID %d: %s", conn.PID, err)
} else {
log.Tracer(pkt.Ctx()).Warningf("network: failed to get process and profile of PID %d: %s", conn.PID, err)
}
}
}

Expand Down Expand Up @@ -435,48 +459,50 @@ func (conn *Connection) GatherConnectionInfo(pkt packet.Packet) (err error) {
conn.Scope = IncomingInvalid
}
} else {
// Outbound direct (possibly P2P) connection.
switch conn.Entity.IPScope {
case netutils.HostLocal:
conn.Scope = PeerHost
case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast:
conn.Scope = PeerLAN
case netutils.Global, netutils.GlobalMulticast:
conn.Scope = PeerInternet

// check if we can find a domain for that IP
ipinfo, err := resolver.GetIPInfo(conn.process.Profile().LocalProfile().ID, pkt.Info().RemoteIP().String())
if err != nil {
// Try again with the global scope, in case DNS went through the system resolver.
ipinfo, err = resolver.GetIPInfo(resolver.IPInfoProfileScopeGlobal, pkt.Info().RemoteIP().String())
}
if err == nil {
lastResolvedDomain := ipinfo.MostRecentDomain()
if lastResolvedDomain != nil {
conn.Scope = lastResolvedDomain.Domain
conn.Entity.Domain = lastResolvedDomain.Domain
conn.Entity.CNAME = lastResolvedDomain.CNAMEs
conn.DNSContext = lastResolvedDomain.DNSRequestContext
conn.Resolver = lastResolvedDomain.Resolver
removeOpenDNSRequest(conn.process.Pid, lastResolvedDomain.Domain)
}
case netutils.Undefined, netutils.Invalid:
fallthrough
default:
conn.Scope = PeerInvalid
}
}
}

// check if destination IP is the captive portal's IP
portal := netenv.GetCaptivePortal()
if pkt.Info().RemoteIP().Equal(portal.IP) {
conn.Scope = portal.Domain
conn.Entity.Domain = portal.Domain
// Find domain and DNS context of entity.
if conn.Entity.Domain == "" && conn.process.Profile() != nil {
// check if we can find a domain for that IP
ipinfo, err := resolver.GetIPInfo(conn.process.Profile().LocalProfile().ID, pkt.Info().RemoteIP().String())
if err != nil {
// Try again with the global scope, in case DNS went through the system resolver.
ipinfo, err = resolver.GetIPInfo(resolver.IPInfoProfileScopeGlobal, pkt.Info().RemoteIP().String())
}
if err == nil {
lastResolvedDomain := ipinfo.MostRecentDomain()
if lastResolvedDomain != nil {
conn.Scope = lastResolvedDomain.Domain
conn.Entity.Domain = lastResolvedDomain.Domain
conn.Entity.CNAME = lastResolvedDomain.CNAMEs
conn.DNSContext = lastResolvedDomain.DNSRequestContext
conn.Resolver = lastResolvedDomain.Resolver
removeOpenDNSRequest(conn.process.Pid, lastResolvedDomain.Domain)
}
}
}

if conn.Scope == "" {
// outbound direct (possibly P2P) connection
switch conn.Entity.IPScope {
case netutils.HostLocal:
conn.Scope = PeerHost
case netutils.LinkLocal, netutils.SiteLocal, netutils.LocalMulticast:
conn.Scope = PeerLAN
case netutils.Global, netutils.GlobalMulticast:
conn.Scope = PeerInternet

case netutils.Undefined, netutils.Invalid:
fallthrough
default:
conn.Scope = PeerInvalid
}
}
// Check if destination IP is the captive portal's IP.
if conn.Entity.Domain == "" {
portal := netenv.GetCaptivePortal()
if pkt.Info().RemoteIP().Equal(portal.IP) {
conn.Scope = portal.Domain
conn.Entity.Domain = portal.Domain
}
}

Expand Down Expand Up @@ -838,7 +864,7 @@ func packetHandlerHandleConn(ctx context.Context, conn *Connection, pkt packet.P
case conn.Verdict.Firewall != VerdictUndecided:
tracer.Debugf("filter: connection %s fast-tracked", pkt)
default:
tracer.Infof("filter: gathered data on connection %s", conn)
tracer.Debugf("filter: gathered data on connection %s", conn)
}
// Submit trace logs.
tracer.Submit()
Expand Down
16 changes: 16 additions & 0 deletions network/connection_store.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package network

import (
"strings"
"sync"
)

Expand Down Expand Up @@ -37,6 +38,21 @@ func (cs *connectionStore) get(id string) (*Connection, bool) {
return conn, ok
}

// findByPrefix returns the first connection where the key matches the given prefix.
// If the prefix matches multiple entries, the result is not deterministic.
func (cs *connectionStore) findByPrefix(prefix string) (*Connection, bool) {
cs.rw.RLock()
defer cs.rw.RUnlock()

for key, conn := range cs.items {
if strings.HasPrefix(key, prefix) {
return conn, true
}
}

return nil, false
}

func (cs *connectionStore) clone() map[string]*Connection {
cs.rw.RLock()
defer cs.rw.RUnlock()
Expand Down

0 comments on commit 41c5266

Please sign in to comment.