Skip to content

Commit

Permalink
k8s/portforward: FakePortForwardClient knows its whole call history (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Maia McCormick committed Apr 30, 2021
1 parent c8027e5 commit 1d7293d
Show file tree
Hide file tree
Showing 2 changed files with 223 additions and 22 deletions.
165 changes: 155 additions & 10 deletions internal/engine/portforward/subscriber_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/tilt-dev/tilt/pkg/apis/core/v1alpha1"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -50,8 +52,8 @@ func TestPortForward(t *testing.T) {

f.onChange()
assert.Equal(t, 1, len(f.plc.activeForwards))
assert.Equal(t, "pod-id", f.kCli.LastForwardPortPodID.String())
firstPodForwardCtx := f.kCli.LastForwardContext
assert.Equal(t, "pod-id", f.kCli.LastForwardPortPodID().String())
firstPodForwardCtx := f.kCli.LastForwardContext()

state = f.st.LockMutableStateForTesting()
mt = state.ManifestTargets["fe"]
Expand All @@ -61,7 +63,7 @@ func TestPortForward(t *testing.T) {

f.onChange()
assert.Equal(t, 1, len(f.plc.activeForwards))
assert.Equal(t, "pod-id2", f.kCli.LastForwardPortPodID.String())
assert.Equal(t, "pod-id2", f.kCli.LastForwardPortPodID().String())

state = f.st.LockMutableStateForTesting()
mt = state.ManifestTargets["fe"]
Expand All @@ -76,6 +78,69 @@ func TestPortForward(t *testing.T) {
"Expected first port-forward to be canceled")
}

func TestMultiplePortForwardsForOnePod(t *testing.T) {
f := newPLCFixture(t)
defer f.TearDown()

state := f.st.LockMutableStateForTesting()
m := model.Manifest{
Name: "fe",
}
m = m.WithDeployTarget(model.K8sTarget{
PortForwards: []model.PortForward{
{
LocalPort: 8000,
ContainerPort: 8080,
},
{
LocalPort: 8001,
ContainerPort: 8081,
},
},
})
state.UpsertManifestTarget(store.NewManifestTarget(m))
f.st.UnlockMutableState()

f.onChange()
assert.Equal(t, 0, len(f.plc.activeForwards))

state = f.st.LockMutableStateForTesting()
mt := state.ManifestTargets["fe"]
mt.State.RuntimeState = store.NewK8sRuntimeStateWithPods(mt.Manifest,
v1alpha1.Pod{Name: "pod-id", Phase: string(v1.PodRunning)})
f.st.UnlockMutableState()

f.onChange()
require.Equal(t, 1, len(f.plc.activeForwards))
require.Equal(t, 2, f.kCli.CreatePortForwardCallCount())

// PortForwards are executed async so we can't guarantee the order;
// just make sure each expected call appears exactly once
expectedRemotePorts := []int{8080, 8081}
var actualRemotePorts []int
var contexts []context.Context
for _, call := range f.kCli.PortForwardCalls() {
actualRemotePorts = append(actualRemotePorts, call.RemotePort)
contexts = append(contexts, call.Context)
assert.Equal(t, "pod-id", call.PodID.String())
}
assert.ElementsMatch(t, expectedRemotePorts, actualRemotePorts, "remote ports for which PortForward was called")

state = f.st.LockMutableStateForTesting()
mt = state.ManifestTargets["fe"]
mt.State.RuntimeState = store.NewK8sRuntimeStateWithPods(mt.Manifest,
v1alpha1.Pod{Name: "pod-id", Phase: string(v1.PodPending)})
f.st.UnlockMutableState()

f.onChange()
assert.Equal(t, 0, len(f.plc.activeForwards))

for _, ctx := range contexts {
assert.Equal(t, context.Canceled, ctx.Err(),
"found uncancelled port forward context")
}
}

func TestPortForwardAutoDiscovery(t *testing.T) {
f := newPLCFixture(t)
defer f.TearDown()
Expand Down Expand Up @@ -108,7 +173,7 @@ func TestPortForwardAutoDiscovery(t *testing.T) {

f.onChange()
assert.Equal(t, 1, len(f.plc.activeForwards))
assert.Equal(t, 8000, f.kCli.LastForwardPortRemotePort)
assert.Equal(t, 8000, f.kCli.LastForwardPortRemotePort())
}

func TestPortForwardAutoDiscovery2(t *testing.T) {
Expand Down Expand Up @@ -140,7 +205,7 @@ func TestPortForwardAutoDiscovery2(t *testing.T) {

f.onChange()
assert.Equal(t, 1, len(f.plc.activeForwards))
assert.Equal(t, 8080, f.kCli.LastForwardPortRemotePort)
assert.Equal(t, 8080, f.kCli.LastForwardPortRemotePort())
}

func TestPortForwardChangePort(t *testing.T) {
Expand All @@ -164,7 +229,7 @@ func TestPortForwardChangePort(t *testing.T) {

f.onChange()
assert.Equal(t, 1, len(f.plc.activeForwards))
assert.Equal(t, 8081, f.kCli.LastForwardPortRemotePort)
assert.Equal(t, 8081, f.kCli.LastForwardPortRemotePort())

state = f.st.LockMutableStateForTesting()
kTarget := state.ManifestTargets["fe"].Manifest.K8sTarget()
Expand All @@ -173,7 +238,87 @@ func TestPortForwardChangePort(t *testing.T) {

f.onChange()
assert.Equal(t, 1, len(f.plc.activeForwards))
assert.Equal(t, 8082, f.kCli.LastForwardPortRemotePort)
assert.Equal(t, 8082, f.kCli.LastForwardPortRemotePort())
}

func TestPortForwardChangeHost(t *testing.T) {
f := newPLCFixture(t)
defer f.TearDown()

state := f.st.LockMutableStateForTesting()
m := model.Manifest{Name: "fe"}.WithDeployTarget(model.K8sTarget{
PortForwards: []model.PortForward{
{
LocalPort: 8080,
ContainerPort: 8081,
Host: "someHostA",
},
},
})
state.UpsertManifestTarget(store.NewManifestTarget(m))
mt := state.ManifestTargets["fe"]
mt.State.RuntimeState = store.NewK8sRuntimeStateWithPods(mt.Manifest,
v1alpha1.Pod{Name: "pod-id", Phase: string(v1.PodRunning)})
f.st.UnlockMutableState()

f.onChange()
assert.Equal(t, 1, len(f.plc.activeForwards))
assert.Equal(t, 8081, f.kCli.LastForwardPortRemotePort())
assert.Equal(t, "someHostA", f.kCli.LastForwardPortHost())

state = f.st.LockMutableStateForTesting()
kTarget := state.ManifestTargets["fe"].Manifest.K8sTarget()
kTarget.PortForwards[0].Host = "otherHostB"
f.st.UnlockMutableState()

f.onChange()
assert.Equal(t, 1, len(f.plc.activeForwards))
assert.Equal(t, 8081, f.kCli.LastForwardPortRemotePort())
assert.Equal(t, "otherHostB", f.kCli.LastForwardPortHost())
}

func TestPortForwardChangeManifestName(t *testing.T) {
f := newPLCFixture(t)
defer f.TearDown()

state := f.st.LockMutableStateForTesting()
m := model.Manifest{Name: "manifestA"}.WithDeployTarget(model.K8sTarget{
PortForwards: []model.PortForward{
{
LocalPort: 8080,
ContainerPort: 8081,
},
},
})
state.UpsertManifestTarget(store.NewManifestTarget(m))
mt := state.ManifestTargets["manifestA"]
mt.State.RuntimeState = store.NewK8sRuntimeStateWithPods(mt.Manifest,
v1alpha1.Pod{Name: "pod-id", Phase: string(v1.PodRunning)})
f.st.UnlockMutableState()

f.onChange()
assert.Equal(t, 1, len(f.plc.activeForwards))
assert.Equal(t, 8081, f.kCli.LastForwardPortRemotePort())

state = f.st.LockMutableStateForTesting()
delete(state.ManifestTargets, "manifestA")
m = model.Manifest{Name: "manifestB"}.WithDeployTarget(model.K8sTarget{
PortForwards: []model.PortForward{
{
LocalPort: 8080,
ContainerPort: 8081,
},
},
})
state.UpsertManifestTarget(store.NewManifestTarget(m))
mt = state.ManifestTargets["manifestB"]
mt.State.RuntimeState = store.NewK8sRuntimeStateWithPods(mt.Manifest,
v1alpha1.Pod{Name: "pod-id", Phase: string(v1.PodRunning)})
f.st.UnlockMutableState()

f.onChange()
assert.Equal(t, 1, len(f.plc.activeForwards))
assert.Equal(t, 8081, f.kCli.LastForwardPortRemotePort())
}

func TestPortForwardRestart(t *testing.T) {
Expand All @@ -200,16 +345,16 @@ func TestPortForwardRestart(t *testing.T) {

f.onChange()
assert.Equal(t, 1, len(f.plc.activeForwards))
assert.Equal(t, 1, f.kCli.CreatePortForwardCallCount)
assert.Equal(t, 1, f.kCli.CreatePortForwardCallCount())

err := fmt.Errorf("unique-error")
f.kCli.LastForwarder.Done <- err
f.kCli.LastForwarder().Done <- err

assert.Contains(t, "unique-error", f.out.String())
time.Sleep(100 * time.Millisecond)

assert.Equal(t, 1, len(f.plc.activeForwards))
assert.Equal(t, 2, f.kCli.CreatePortForwardCallCount)
assert.Equal(t, 2, f.kCli.CreatePortForwardCallCount())
}

type portForwardTestCase struct {
Expand Down
80 changes: 68 additions & 12 deletions internal/k8s/fake_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -491,26 +491,82 @@ func (pf FakePortForwarder) ForwardPorts() error {
}

type FakePortForwardClient struct {
CreatePortForwardCallCount int
LastForwardPortPodID PodID
LastForwardPortRemotePort int
LastForwardPortHost string
LastForwarder FakePortForwarder
LastForwardContext context.Context
mu sync.Mutex
portForwardCalls []PortForwardCall
}

func NewFakePortfowardClient() *FakePortForwardClient {
return &FakePortForwardClient{
portForwardCalls: []PortForwardCall{},
}
}

type PortForwardCall struct {
PodID PodID
RemotePort int
Host string
Forwarder FakePortForwarder
Context context.Context
}

func (c *FakePortForwardClient) CreatePortForwarder(ctx context.Context, namespace Namespace, podID PodID, optionalLocalPort, remotePort int, host string) (PortForwarder, error) {
c.CreatePortForwardCallCount++
c.LastForwardContext = ctx
c.LastForwardPortPodID = podID
c.LastForwardPortRemotePort = remotePort
c.LastForwardPortHost = host
c.mu.Lock()
defer c.mu.Unlock()

result := FakePortForwarder{
localPort: optionalLocalPort,
ctx: ctx,
Done: make(chan error),
}
c.LastForwarder = result

c.portForwardCalls = append(c.portForwardCalls, PortForwardCall{
PodID: podID,
RemotePort: remotePort,
Host: host,
Forwarder: result,
Context: ctx,
})

return result, nil
}

func (c *FakePortForwardClient) CreatePortForwardCallCount() int {
c.mu.Lock()
defer c.mu.Unlock()
return len(c.portForwardCalls)
}
func (c *FakePortForwardClient) LastForwardPortPodID() PodID {
c.mu.Lock()
defer c.mu.Unlock()
return c.portForwardCalls[len(c.portForwardCalls)-1].PodID
}
func (c *FakePortForwardClient) LastForwardPortRemotePort() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.portForwardCalls[len(c.portForwardCalls)-1].RemotePort
}
func (c *FakePortForwardClient) LastForwardPortHost() string {
c.mu.Lock()
defer c.mu.Unlock()
return c.portForwardCalls[len(c.portForwardCalls)-1].Host
}
func (c *FakePortForwardClient) LastForwarder() FakePortForwarder {
c.mu.Lock()
defer c.mu.Unlock()
return c.portForwardCalls[len(c.portForwardCalls)-1].Forwarder
}
func (c *FakePortForwardClient) LastForwardContext() context.Context {
c.mu.Lock()
defer c.mu.Unlock()
return c.portForwardCalls[len(c.portForwardCalls)-1].Context
}
func (c *FakePortForwardClient) PortForwardCalls() []PortForwardCall {
c.mu.Lock()
defer c.mu.Unlock()

calls := make([]PortForwardCall, len(c.portForwardCalls))
for i, call := range c.portForwardCalls {
calls[i] = call
}
return calls
}

0 comments on commit 1d7293d

Please sign in to comment.