Skip to content

Commit

Permalink
Merge pull request #515 from weaveworks/testing
Browse files Browse the repository at this point in the history
Remove some duplicate functionality, add some basic tests.
  • Loading branch information
tomwilkie committed Sep 24, 2015
2 parents d8a3372 + e2dfcb1 commit dd696cd
Show file tree
Hide file tree
Showing 11 changed files with 261 additions and 103 deletions.
22 changes: 14 additions & 8 deletions probe/endpoint/conntrack.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,14 @@ type conntrack struct {
Flows []Flow `xml:"flow"`
}

// Conntracker is something that tracks connections.
type Conntracker interface {
WalkFlows(f func(Flow))
Stop()
}

// Conntracker uses the conntrack command to track network connections
type Conntracker struct {
type conntracker struct {
sync.Mutex
cmd exec.Cmd
activeFlows map[int64]Flow // active flows in state != TIME_WAIT
Expand All @@ -75,11 +81,11 @@ type Conntracker struct {
}

// NewConntracker creates and starts a new Conntracter
func NewConntracker(existingConns bool, args ...string) (*Conntracker, error) {
func NewConntracker(existingConns bool, args ...string) (Conntracker, error) {
if !ConntrackModulePresent() {
return nil, fmt.Errorf("No conntrack module")
}
result := &Conntracker{
result := &conntracker{
activeFlows: map[int64]Flow{},
existingConns: existingConns,
}
Expand Down Expand Up @@ -112,7 +118,7 @@ var ConntrackModulePresent = func() bool {
}

// NB this is not re-entrant!
func (c *Conntracker) run(args ...string) {
func (c *conntracker) run(args ...string) {
if c.existingConns {
// Fork another conntrack, just to capture existing connections
// for which we don't get events
Expand Down Expand Up @@ -178,7 +184,7 @@ func (c *Conntracker) run(args ...string) {
}
}

func (c *Conntracker) existingConnections(args ...string) ([]Flow, error) {
func (c *conntracker) existingConnections(args ...string) ([]Flow, error) {
args = append([]string{"-L", "-o", "xml", "-p", "tcp"}, args...)
cmd := exec.Command("conntrack", args...)
stdout, err := cmd.StdoutPipe()
Expand All @@ -203,7 +209,7 @@ func (c *Conntracker) existingConnections(args ...string) ([]Flow, error) {
}

// Stop stop stop
func (c *Conntracker) Stop() {
func (c *conntracker) Stop() {
c.Lock()
defer c.Unlock()
if c.cmd == nil {
Expand All @@ -215,7 +221,7 @@ func (c *Conntracker) Stop() {
}
}

func (c *Conntracker) handleFlow(f Flow, forceAdd bool) {
func (c *conntracker) handleFlow(f Flow, forceAdd bool) {
// A flow consists of 3 'metas' - the 'original' 4 tuple (as seen by this
// host) and the 'reply' 4 tuple, which is what it has been rewritten to.
// This code finds those metas, which are identified by a Direction
Expand Down Expand Up @@ -260,7 +266,7 @@ func (c *Conntracker) handleFlow(f Flow, forceAdd bool) {

// WalkFlows calls f with all active flows and flows that have come and gone
// since the last call to WalkFlows
func (c *Conntracker) WalkFlows(f func(Flow)) {
func (c *conntracker) WalkFlows(f func(Flow)) {
c.Lock()
defer c.Unlock()
for _, flow := range c.activeFlows {
Expand Down
90 changes: 50 additions & 40 deletions probe/endpoint/conntrack_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,54 +13,62 @@ import (
testExec "github.com/weaveworks/scope/test/exec"
)

func makeFlow(id int64, srcIP, dstIP string, srcPort, dstPort int, ty, state string) Flow {
func makeFlow(ty string) Flow {
return Flow{
XMLName: xml.Name{
Local: "flow",
},
Type: ty,
Metas: []Meta{
{
XMLName: xml.Name{
Local: "meta",
},
Direction: "original",
Layer3: Layer3{
XMLName: xml.Name{
Local: "layer3",
},
SrcIP: srcIP,
DstIP: dstIP,
},
Layer4: Layer4{
XMLName: xml.Name{
Local: "layer4",
},
SrcPort: srcPort,
DstPort: dstPort,
Proto: TCP,
},
}
}

func addMeta(f *Flow, dir, srcIP, dstIP string, srcPort, dstPort int) *Meta {
meta := Meta{
XMLName: xml.Name{
Local: "meta",
},
Direction: dir,
Layer3: Layer3{
XMLName: xml.Name{
Local: "layer3",
},
SrcIP: srcIP,
DstIP: dstIP,
},
Layer4: Layer4{
XMLName: xml.Name{
Local: "layer4",
},
{
XMLName: xml.Name{
Local: "meta",
},
Direction: "independent",
ID: id,
State: state,
Layer3: Layer3{
XMLName: xml.Name{
Local: "layer3",
},
},
Layer4: Layer4{
XMLName: xml.Name{
Local: "layer4",
},
},
SrcPort: srcPort,
DstPort: dstPort,
Proto: TCP,
},
}
f.Metas = append(f.Metas, meta)
return &meta
}

func addIndependant(f *Flow, id int64, state string) *Meta {
meta := Meta{
XMLName: xml.Name{
Local: "meta",
},
Direction: "independent",
ID: id,
State: state,
Layer3: Layer3{
XMLName: xml.Name{
Local: "layer3",
},
},
Layer4: Layer4{
XMLName: xml.Name{
Local: "layer4",
},
},
}
f.Metas = append(f.Metas, meta)
return &meta
}

func TestConntracker(t *testing.T) {
Expand Down Expand Up @@ -121,7 +129,9 @@ func TestConntracker(t *testing.T) {
}
}

flow1 := makeFlow(1, "1.2.3.4", "2.3.4.5", 2, 3, New, "")
flow1 := makeFlow(New)
addMeta(&flow1, "original", "1.2.3.4", "2.3.4.5", 2, 3)
addIndependant(&flow1, 1, "")
writeFlow(flow1)
test.Poll(t, ts, []Flow{flow1}, have)

Expand Down
18 changes: 8 additions & 10 deletions probe/endpoint/nat.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,14 @@ type endpointMapping struct {
rewrittenPort int
}

type natmapper struct {
*Conntracker
// NATMapper rewrites a report to deal with NAT's connections
type NATMapper struct {
Conntracker
}

func newNATMapper() (*natmapper, error) {
ct, err := NewConntracker(true, "--any-nat")
if err != nil {
return nil, err
}
return &natmapper{ct}, nil
// NewNATMapper is exposed for testing
func NewNATMapper(ct Conntracker) NATMapper {
return NATMapper{ct}
}

func toMapping(f Flow) *endpointMapping {
Expand All @@ -49,9 +47,9 @@ func toMapping(f Flow) *endpointMapping {
return &mapping
}

// applyNAT duplicates Nodes in the endpoint topology of a
// ApplyNAT duplicates Nodes in the endpoint topology of a
// report, based on the NAT table as returns by natTable.
func (n *natmapper) applyNAT(rpt report.Report, scope string) {
func (n NATMapper) ApplyNAT(rpt report.Report, scope string) {
n.WalkFlows(func(f Flow) {
var (
mapping = toMapping(f)
Expand Down
97 changes: 97 additions & 0 deletions probe/endpoint/nat_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package endpoint_test

import (
"reflect"
"testing"

"github.com/weaveworks/scope/probe/endpoint"
"github.com/weaveworks/scope/report"
"github.com/weaveworks/scope/test"
)

type mockConntracker struct {
flows []endpoint.Flow
}

func (m *mockConntracker) WalkFlows(f func(endpoint.Flow)) {
for _, flow := range m.flows {
f(flow)
}
}

func (m *mockConntracker) Stop() {}

func TestNat(t *testing.T) {
// test that two containers, on the docker network, get their connections mapped
// correctly.
// the setup is this:
//
// container2 (10.0.47.2:222222), host2 (2.3.4.5:22223) ->
// host1 (1.2.3.4:80), container1 (10.0.47.2:80)

// from the PoV of host1
{
flow := makeFlow("")
addIndependant(&flow, 1, "")
flow.Original = addMeta(&flow, "original", "2.3.4.5", "1.2.3.4", 222222, 80)
flow.Reply = addMeta(&flow, "reply", "10.0.47.1", "2.3.4.5", 80, 222222)
ct := &mockConntracker{
flows: []endpoint.Flow{flow},
}

have := report.MakeReport()
originalID := report.MakeEndpointNodeID("host1", "10.0.47.1", "80")
have.Endpoint.AddNode(originalID, report.MakeNodeWith(report.Metadata{
endpoint.Addr: "10.0.47.1",
endpoint.Port: "80",
"foo": "bar",
}))

want := have.Copy()
want.Endpoint.AddNode(report.MakeEndpointNodeID("host1", "1.2.3.4", "80"), report.MakeNodeWith(report.Metadata{
endpoint.Addr: "1.2.3.4",
endpoint.Port: "80",
"copy_of": originalID,
"foo": "bar",
}))

natmapper := endpoint.NewNATMapper(ct)
natmapper.ApplyNAT(have, "host1")
if !reflect.DeepEqual(want, have) {
t.Fatal(test.Diff(want, have))
}
}

// form the PoV of host2
{
flow := makeFlow("")
addIndependant(&flow, 2, "")
flow.Original = addMeta(&flow, "original", "10.0.47.2", "1.2.3.4", 22222, 80)
flow.Reply = addMeta(&flow, "reply", "1.2.3.4", "2.3.4.5", 80, 22223)
ct := &mockConntracker{
flows: []endpoint.Flow{flow},
}

have := report.MakeReport()
originalID := report.MakeEndpointNodeID("host2", "10.0.47.2", "22222")
have.Endpoint.AddNode(originalID, report.MakeNodeWith(report.Metadata{
endpoint.Addr: "10.0.47.2",
endpoint.Port: "22222",
"foo": "baz",
}))

want := have.Copy()
want.Endpoint.AddNode(report.MakeEndpointNodeID("host2", "2.3.4.5", "22223"), report.MakeNodeWith(report.Metadata{
endpoint.Addr: "2.3.4.5",
endpoint.Port: "22223",
"copy_of": originalID,
"foo": "baz",
}))

natmapper := endpoint.NewNATMapper(ct)
natmapper.ApplyNAT(have, "host1")
if !reflect.DeepEqual(want, have) {
t.Fatal(test.Diff(want, have))
}
}
}
21 changes: 11 additions & 10 deletions probe/endpoint/reporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ type Reporter struct {
hostName string
includeProcesses bool
includeNAT bool
conntracker *Conntracker
natmapper *natmapper
conntracker Conntracker
natmapper *NATMapper
revResolver *ReverseResolver
}

Expand All @@ -51,8 +51,8 @@ var SpyDuration = prometheus.NewSummaryVec(
func NewReporter(hostID, hostName string, includeProcesses bool, useConntrack bool) *Reporter {
var (
conntrackModulePresent = ConntrackModulePresent()
conntracker *Conntracker
natmapper *natmapper
conntracker Conntracker
natmapper NATMapper
err error
)
if conntrackModulePresent && useConntrack {
Expand All @@ -62,17 +62,18 @@ func NewReporter(hostID, hostName string, includeProcesses bool, useConntrack bo
}
}
if conntrackModulePresent {
natmapper, err = newNATMapper()
ct, err := NewConntracker(true, "--any-nat")
if err != nil {
log.Printf("Failed to start natMapper: %v", err)
log.Printf("Failed to start conntracker for natmapper: %v", err)
}
natmapper = NewNATMapper(ct)
}
return &Reporter{
hostID: hostID,
hostName: hostName,
includeProcesses: includeProcesses,
conntracker: conntracker,
natmapper: natmapper,
natmapper: &natmapper,
revResolver: NewReverseResolver(),
}
}
Expand Down Expand Up @@ -139,7 +140,7 @@ func (r *Reporter) Report() (report.Report, error) {
}

if r.natmapper != nil {
r.natmapper.applyNAT(rpt, r.hostID)
r.natmapper.ApplyNAT(rpt, r.hostID)
}

return rpt, nil
Expand All @@ -165,7 +166,7 @@ func (r *Reporter) addConnection(rpt *report.Report, localAddr, remoteAddr strin
// In case we have a reverse resolution for the IP, we can use it for
// the name...
if revRemoteName, err := r.revResolver.Get(remoteAddr); err == nil {
remoteNode = remoteNode.AddMetadata(map[string]string{
remoteNode = remoteNode.WithMetadata(map[string]string{
"name": revRemoteName,
})
}
Expand Down Expand Up @@ -211,7 +212,7 @@ func (r *Reporter) addConnection(rpt *report.Report, localAddr, remoteAddr strin
// In case we have a reverse resolution for the IP, we can use it for
// the name...
if revRemoteName, err := r.revResolver.Get(remoteAddr); err == nil {
remoteNode = remoteNode.AddMetadata(map[string]string{
remoteNode = remoteNode.WithMetadata(map[string]string{
"name": revRemoteName,
})
}
Expand Down
Loading

0 comments on commit dd696cd

Please sign in to comment.