From 948632c15ac88ed7609223e025fb463e2cd73042 Mon Sep 17 00:00:00 2001 From: Nino Kodabande Date: Wed, 10 May 2023 11:43:17 -0700 Subject: [PATCH] Add missing test cases for Override The add with override did not have any test coverage. Signed-off-by: Nino Kodabande --- cmd/rancher-desktop-guestagent/main.go | 25 +-- pkg/tracker/apitracker_test.go | 76 +++++++++ pkg/tracker/portstorage.go | 12 +- pkg/tracker/vtunneltracker.go | 32 +++- pkg/tracker/vtunneltracker_test.go | 204 ++++++++++++++++++++++--- 5 files changed, 305 insertions(+), 44 deletions(-) diff --git a/cmd/rancher-desktop-guestagent/main.go b/cmd/rancher-desktop-guestagent/main.go index cf5ac9b..5ca13f7 100644 --- a/cmd/rancher-desktop-guestagent/main.go +++ b/cmd/rancher-desktop-guestagent/main.go @@ -57,19 +57,14 @@ var ( "address to bind Kubernetes services to on the host, valid options are 0.0.0.0 or 127.0.0.1") ) -//nolint:lll -// Flags can only be enable in the following combination: -// +======================+=======================+===========================+=========================+=============================+ -// | | Default Network Admin | Default Network Non-Admin | Namespace Network Admin | Namespace Network Non-Admin | -// +======================+=======================+===========================+=========================+=============================+ -// | privilegedService | enable | disable | disable | disable | -// +----------------------+-----------------------+---------------------------+-------------------------+-----------------------------+ -// | docker Or containerd | enable | disable | enable | enable | -// +----------------------+-----------------------+---------------------------+-------------------------+-----------------------------+ -// | iptables | disable | enable | disable | disable | -// +----------------------+-----------------------+---------------------------+-------------------------+-----------------------------+ -// | Kubernetes | enable | enable | enable | enable | -// +----------------------+-----------------------+---------------------------+-------------------------+-----------------------------+ +// Flags can only be enabled in the following combination: +// +===========+=======================================+====================+ +// | | Default Network | Namespaced Network | +// +===========+=======================================+====================+ +// | Admin | privilegedService + docker/containerd | docker/containerd | +// +-----------+---------------------------------------+--------------------+ +// | Non-Admin | iptables | docker/containerd | +// +-----------+---------------------------------------+--------------------+ const ( wslInfName = "eth0" @@ -111,10 +106,6 @@ func main() { cancel() }() - // Only one of the containerd, docker or iptables should be enabled - // at any give run. The containerd or docker is enabled depending - // on the chosen backend engine when privileged service is enabled - // and port mappings are sent over the VTunnel. However, when if !*enableContainerd && !*enableDocker && !*enableIptables { diff --git a/pkg/tracker/apitracker_test.go b/pkg/tracker/apitracker_test.go index d7ef005..af6ccb2 100644 --- a/pkg/tracker/apitracker_test.go +++ b/pkg/tracker/apitracker_test.go @@ -71,6 +71,82 @@ func TestBasicAdd(t *testing.T) { assert.Equal(t, portMapping, actualPortMapping) } +func TestAddOverride(t *testing.T) { + t.Parallel() + + var expectedExposeReq []*types.ExposeRequest + + mux := http.NewServeMux() + + mux.HandleFunc("/services/forwarder/expose", func(w http.ResponseWriter, r *http.Request) { + var tmpReq *types.ExposeRequest + err := json.NewDecoder(r.Body).Decode(&tmpReq) + assert.NoError(t, err) + expectedExposeReq = append(expectedExposeReq, tmpReq) + }) + + testSrv := httptest.NewServer(mux) + defer testSrv.Close() + + apiTracker := tracker.NewAPITracker(testSrv.URL) + portMapping := nat.PortMap{ + "80/tcp": []nat.PortBinding{ + { + HostIP: hostIP, + HostPort: hostPort, + }, + }, + "443/tcp": []nat.PortBinding{ + { + HostIP: hostIP2, + HostPort: hostPort2, + }, + }, + } + err := apiTracker.Add(containerID, portMapping) + assert.NoError(t, err) + + firstEntryIndex := 0 + assert.Equal(t, expectedExposeReq[firstEntryIndex].Local, ipPortBuilder(hostIP, hostPort)) + assert.Equal(t, expectedExposeReq[firstEntryIndex].Remote, ipPortBuilder(hostSwitchIP, hostPort)) + + secondEntryIndex := 1 + assert.Equal(t, expectedExposeReq[secondEntryIndex].Local, ipPortBuilder(hostIP2, hostPort2)) + assert.Equal(t, expectedExposeReq[secondEntryIndex].Remote, ipPortBuilder(hostSwitchIP, hostPort2)) + + actualPortMapping := apiTracker.Get(containerID) + assert.Equal(t, portMapping, actualPortMapping) + + // reset the exposeReq slice + expectedExposeReq = nil + + portMapping2 := nat.PortMap{ + "80/tcp": []nat.PortBinding{ + { + HostIP: hostIP, + HostPort: hostPort, + }, + }, + "8080/tcp": []nat.PortBinding{ + { + HostIP: hostIP2, + HostPort: "8080", + }, + }, + } + err = apiTracker.Add(containerID, portMapping2) + assert.NoError(t, err) + + assert.Equal(t, expectedExposeReq[firstEntryIndex].Local, ipPortBuilder(hostIP, hostPort)) + assert.Equal(t, expectedExposeReq[firstEntryIndex].Remote, ipPortBuilder(hostSwitchIP, hostPort)) + + assert.Equal(t, expectedExposeReq[secondEntryIndex].Local, ipPortBuilder(hostIP2, "8080")) + assert.Equal(t, expectedExposeReq[secondEntryIndex].Remote, ipPortBuilder(hostSwitchIP, "8080")) + + actualPortMapping = apiTracker.Get(containerID) + assert.Equal(t, portMapping2, actualPortMapping) +} + func TestAddWithError(t *testing.T) { t.Parallel() diff --git a/pkg/tracker/portstorage.go b/pkg/tracker/portstorage.go index 1cba689..ea691b7 100644 --- a/pkg/tracker/portstorage.go +++ b/pkg/tracker/portstorage.go @@ -70,12 +70,22 @@ func (p *portStorage) getAll() map[string]nat.PortMap { portMappings := make(map[string]nat.PortMap, len(p.portmap)) for k, v := range p.portmap { - portMappings[k] = v + portMappings[k] = copyPortMap(v) } return portMappings } +func copyPortMap(m nat.PortMap) nat.PortMap { + portMap := make(nat.PortMap, len(m)) + + for k, v := range m { + portMap[k] = v + } + + return portMap +} + func (p *portStorage) remove(containerID string) { p.mutex.Lock() defer p.mutex.Unlock() diff --git a/pkg/tracker/vtunneltracker.go b/pkg/tracker/vtunneltracker.go index ac53478..19caf34 100644 --- a/pkg/tracker/vtunneltracker.go +++ b/pkg/tracker/vtunneltracker.go @@ -16,11 +16,16 @@ limitations under the License. package tracker import ( + "errors" + "fmt" + "github.com/docker/go-connections/nat" "github.com/rancher-sandbox/rancher-desktop-agent/pkg/forwarder" "github.com/rancher-sandbox/rancher-desktop-agent/pkg/types" ) +var ErrRemoveAll = errors.New("failed to remove all portMappings") + // VTunnelTracker keeps track of port mappings and forwards // them to the privileged service on the host over AF_VSOCK // tunnel (vtunnel). @@ -41,8 +46,8 @@ func NewVTunnelTracker(vtunnelForwarder forwarder.Forwarder, wslAddrs []types.Co } } -// Add adds a container ID and port mapping to the tracker and calls the -// vtunnle forwarder to send the port mappings to privileged service. +// Add a container ID and port mapping to the tracker and calls the +// vtunnel forwarder to send the port mappings to privileged service. func (p *VTunnelTracker) Add(containerID string, portMap nat.PortMap) error { if len(portMap) == 0 { return nil @@ -68,7 +73,7 @@ func (p *VTunnelTracker) Get(containerID string) nat.PortMap { } // Remove deletes a container ID and port mapping from the tracker and calls the -// vtunnle forwarder to send the port mappings to privileged service. +// vtunnel forwarder to send the port mappings to privileged service. func (p *VTunnelTracker) Remove(containerID string) error { portMap := p.portStorage.get(containerID) if len(portMap) != 0 { @@ -89,7 +94,26 @@ func (p *VTunnelTracker) Remove(containerID string) error { // RemoveAll removes all the port bindings from the tracker. func (p *VTunnelTracker) RemoveAll() error { - p.portStorage.removeAll() + defer p.portStorage.removeAll() + + allPortMappings := p.portStorage.getAll() + + var errs []error + + for _, portMap := range allPortMappings { + err := p.vtunnelForwarder.Send(types.PortMapping{ + Remove: true, + Ports: portMap, + ConnectAddrs: p.wslAddrs, + }) + if err != nil { + errs = append(errs, err) + } + } + + if len(errs) != 0 { + return fmt.Errorf("%w: %+v", ErrRemoveAll, errs) + } return nil } diff --git a/pkg/tracker/vtunneltracker_test.go b/pkg/tracker/vtunneltracker_test.go index 77d0921..b2a4b81 100644 --- a/pkg/tracker/vtunneltracker_test.go +++ b/pkg/tracker/vtunneltracker_test.go @@ -14,7 +14,9 @@ limitations under the License. package tracker_test import ( + "encoding/json" "errors" + "reflect" "testing" "github.com/docker/go-connections/nat" @@ -52,8 +54,6 @@ func TestVTunnelTrackerAdd(t *testing.T) { err = vtunnelTracker.Add(containerID2, portMapping2) assert.NoError(t, err) - assert.Len(t, forwarder.receivedPortMappings, 2) - assert.ElementsMatch(t, forwarder.receivedPortMappings, []types.PortMapping{ { @@ -67,11 +67,78 @@ func TestVTunnelTrackerAdd(t *testing.T) { }, }) - actualPortMapping1 := vtunnelTracker.Get(containerID) - assert.Equal(t, actualPortMapping1, portMapping) + actualPortMapping := vtunnelTracker.Get(containerID) + assert.Equal(t, actualPortMapping, portMapping) + + actualPortMapping = vtunnelTracker.Get(containerID2) + assert.Equal(t, actualPortMapping, portMapping2) +} + +func TestVTunnelTrackerAddOverride(t *testing.T) { + t.Parallel() + + wslConnectAddr := []types.ConnectAddrs{{Network: "tcp", Addr: "192.168.0.1"}} + forwarder := testVTunnelForwarder{} + vtunnelTracker := tracker.NewVTunnelTracker(&forwarder, wslConnectAddr) + + portMapping := nat.PortMap{ + "80/tcp": []nat.PortBinding{ + { + HostIP: hostIP, + HostPort: hostPort, + }, + }, + "443/tcp": []nat.PortBinding{ + { + HostIP: hostIP2, + HostPort: hostPort2, + }, + }, + } + err := vtunnelTracker.Add(containerID, portMapping) + assert.NoError(t, err) + + assert.ElementsMatch(t, forwarder.receivedPortMappings, + []types.PortMapping{ + { + Remove: false, + Ports: portMapping, + ConnectAddrs: wslConnectAddr, + }, + }) + + actualPortMapping := vtunnelTracker.Get(containerID) + assert.Equal(t, actualPortMapping, portMapping) + + portMapping2 := nat.PortMap{ + "80/tcp": []nat.PortBinding{ + { + HostIP: hostIP, + HostPort: hostPort, + }, + }, + "8080/tcp": []nat.PortBinding{ + { + HostIP: hostIP2, + HostPort: "8080", + }, + }, + } + + err = vtunnelTracker.Add(containerID, portMapping2) + assert.NoError(t, err) + + secondCallIndex := 1 + assert.Equal(t, forwarder.receivedPortMappings[secondCallIndex], + types.PortMapping{ + Remove: false, + Ports: portMapping2, + ConnectAddrs: wslConnectAddr, + }, + ) - actualPortMapping2 := vtunnelTracker.Get(containerID2) - assert.Equal(t, actualPortMapping2, portMapping2) + actualPortMapping = vtunnelTracker.Get(containerID) + assert.Equal(t, actualPortMapping, portMapping2) } func TestVTunnelTrackerAddEmptyPortMap(t *testing.T) { @@ -111,8 +178,6 @@ func TestVTunnelTrackerAddWithError(t *testing.T) { err := vtunnelTracker.Add(containerID, portMapping) assert.ErrorIs(t, err, errSend) - assert.Len(t, forwarder.receivedPortMappings, 1) - assert.ElementsMatch(t, forwarder.receivedPortMappings, []types.PortMapping{ { @@ -168,11 +233,11 @@ func TestVTunnelTrackerRemove(t *testing.T) { ConnectAddrs: wslConnectAddr, }) - actualPortMapping1 := vtunnelTracker.Get(containerID) - assert.Nil(t, actualPortMapping1) + actualPortMapping := vtunnelTracker.Get(containerID) + assert.Nil(t, actualPortMapping) - actualPortMapping2 := vtunnelTracker.Get(containerID2) - assert.Equal(t, actualPortMapping2, nat.PortMap{ + actualPortMapping = vtunnelTracker.Get(containerID2) + assert.Equal(t, actualPortMapping, nat.PortMap{ "443/tcp": []nat.PortBinding{ { HostIP: hostIP2, @@ -240,8 +305,8 @@ func TestVTunnelTrackerRemoveError(t *testing.T) { ConnectAddrs: wslConnectAddr, }) - actualPortMapping1 := vtunnelTracker.Get(containerID) - assert.Equal(t, actualPortMapping1, nat.PortMap{ + actualPortMapping := vtunnelTracker.Get(containerID) + assert.Equal(t, actualPortMapping, nat.PortMap{ "80/tcp": []nat.PortBinding{ { HostIP: hostIP, @@ -250,8 +315,8 @@ func TestVTunnelTrackerRemoveError(t *testing.T) { }, }) - actualPortMapping2 := vtunnelTracker.Get(containerID2) - assert.Equal(t, actualPortMapping2, nat.PortMap{ + actualPortMapping = vtunnelTracker.Get(containerID2) + assert.Equal(t, actualPortMapping, nat.PortMap{ "443/tcp": []nat.PortBinding{ { HostIP: hostIP2, @@ -290,16 +355,104 @@ func TestVTunnelTrackerRemoveAll(t *testing.T) { err = vtunnelTracker.Add(containerID2, portMapping2) assert.NoError(t, err) - assert.Len(t, forwarder.receivedPortMappings, 2) - err = vtunnelTracker.RemoveAll() assert.NoError(t, err) - actualPortMapping1 := vtunnelTracker.Get(containerID) - assert.Nil(t, actualPortMapping1) + actualPortMapping := vtunnelTracker.Get(containerID) + assert.Nil(t, actualPortMapping) - actualPortMapping2 := vtunnelTracker.Get(containerID2) - assert.Nil(t, actualPortMapping2) + actualPortMapping = vtunnelTracker.Get(containerID2) + assert.Nil(t, actualPortMapping) + + assert.ElementsMatch(t, forwarder.receivedPortMappings, []types.PortMapping{ + { + Remove: false, + Ports: portMapping, + ConnectAddrs: wslConnectAddr, + }, + { + Remove: false, + Ports: portMapping2, + ConnectAddrs: wslConnectAddr, + }, + { + Remove: true, + Ports: portMapping, + ConnectAddrs: wslConnectAddr, + }, + { + Remove: true, + Ports: portMapping2, + ConnectAddrs: wslConnectAddr, + }, + }) +} + +func TestVTunnelTrackerRemoveAllError(t *testing.T) { + t.Parallel() + + wslConnectAddr := []types.ConnectAddrs{{Network: "tcp", Addr: "192.168.0.1"}} + forwarder := testVTunnelForwarder{} + vtunnelTracker := tracker.NewVTunnelTracker(&forwarder, wslConnectAddr) + + portMapping := nat.PortMap{ + "80/tcp": []nat.PortBinding{ + { + HostIP: hostIP, + HostPort: hostPort, + }, + }, + } + err := vtunnelTracker.Add(containerID, portMapping) + assert.NoError(t, err) + + portMapping2 := nat.PortMap{ + "443/tcp": []nat.PortBinding{ + { + HostIP: hostIP2, + HostPort: hostPort2, + }, + }, + } + err = vtunnelTracker.Add(containerID2, portMapping2) + assert.NoError(t, err) + + forwarder.failCondition = func(pm types.PortMapping) error { + if _, ok := pm.Ports["443/tcp"]; ok { + return &json.UnsupportedValueError{ + Value: reflect.Value{}, + Str: "Not Supported!", + } + } + + return nil + } + err = vtunnelTracker.RemoveAll() + assert.ErrorIs(t, err, tracker.ErrRemoveAll) + + actualPortMapping := vtunnelTracker.Get(containerID) + assert.Nil(t, actualPortMapping) + + actualPortMapping = vtunnelTracker.Get(containerID2) + assert.Nil(t, actualPortMapping) + + assert.ElementsMatch(t, forwarder.receivedPortMappings, []types.PortMapping{ + { + Remove: false, + Ports: portMapping, + ConnectAddrs: wslConnectAddr, + }, + { + Remove: false, + Ports: portMapping2, + ConnectAddrs: wslConnectAddr, + }, + { + Remove: true, + Ports: portMapping, + ConnectAddrs: wslConnectAddr, + }, + }) } func TestVTunnelTrackerGet(t *testing.T) { @@ -329,9 +482,16 @@ var errSend = errors.New("error from Send") type testVTunnelForwarder struct { receivedPortMappings []types.PortMapping sendErr error + failCondition func(types.PortMapping) error } func (v *testVTunnelForwarder) Send(portMapping types.PortMapping) error { + if v.failCondition != nil { + if err := v.failCondition(portMapping); err != nil { + return err + } + } + v.receivedPortMappings = append(v.receivedPortMappings, portMapping) return v.sendErr