diff --git a/wait/health.go b/wait/health.go index 4712444fdc..b0821c8aec 100644 --- a/wait/health.go +++ b/wait/health.go @@ -2,8 +2,9 @@ package wait import ( "context" - "github.com/docker/docker/api/types" "time" + + "github.com/docker/docker/api/types" ) // Implement interface @@ -77,6 +78,9 @@ func (ws *HealthStrategy) WaitUntilReady(ctx context.Context, target StrategyTar if err != nil { return err } + if err := checkState(state); err != nil { + return err + } if state.Health == nil || state.Health.Status != types.Healthy { time.Sleep(ws.PollInterval) continue diff --git a/wait/health_test.go b/wait/health_test.go index 689db82de6..e6c15ed3b1 100644 --- a/wait/health_test.go +++ b/wait/health_test.go @@ -3,18 +3,18 @@ package wait import ( "context" "errors" - "github.com/stretchr/testify/assert" "io" "testing" "time" "github.com/docker/docker/api/types" "github.com/docker/go-connections/nat" + "github.com/stretchr/testify/assert" tcexec "github.com/testcontainers/testcontainers-go/exec" ) type healthStrategyTarget struct { - Health *types.Health + state *types.ContainerState } func (st healthStrategyTarget) Host(ctx context.Context) (string, error) { @@ -38,13 +38,18 @@ func (st healthStrategyTarget) Exec(ctx context.Context, cmd []string, options . } func (st healthStrategyTarget) State(ctx context.Context) (*types.ContainerState, error) { - return &types.ContainerState{Health: st.Health}, nil + return st.state, nil } // TestWaitForHealthTimesOutForUnhealthy confirms that an unhealthy container will eventually // time out. func TestWaitForHealthTimesOutForUnhealthy(t *testing.T) { - target := healthStrategyTarget{Health: &types.Health{Status: types.Unhealthy}} + target := healthStrategyTarget{ + state: &types.ContainerState{ + Running: true, + Health: &types.Health{Status: types.Unhealthy}, + }, + } wg := NewHealthStrategy().WithStartupTimeout(100 * time.Millisecond) err := wg.WaitUntilReady(context.Background(), target) @@ -54,7 +59,12 @@ func TestWaitForHealthTimesOutForUnhealthy(t *testing.T) { // TestWaitForHealthSucceeds ensures that a healthy container always succeeds. func TestWaitForHealthSucceeds(t *testing.T) { - target := healthStrategyTarget{Health: &types.Health{Status: types.Healthy}} + target := healthStrategyTarget{ + state: &types.ContainerState{ + Running: true, + Health: &types.Health{Status: types.Healthy}, + }, + } wg := NewHealthStrategy().WithStartupTimeout(100 * time.Millisecond) err := wg.WaitUntilReady(context.Background(), target) @@ -64,7 +74,12 @@ func TestWaitForHealthSucceeds(t *testing.T) { // TestWaitForHealthWithNil checks that an initial `nil` Health will not casue a panic, // and if the container eventually becomes healthy, the HealthStrategy will succeed. func TestWaitForHealthWithNil(t *testing.T) { - target := &healthStrategyTarget{Health: nil} + target := &healthStrategyTarget{ + state: &types.ContainerState{ + Running: true, + Health: nil, + }, + } wg := NewHealthStrategy(). WithStartupTimeout(500 * time.Millisecond). WithPollInterval(100 * time.Millisecond) @@ -73,7 +88,7 @@ func TestWaitForHealthWithNil(t *testing.T) { // wait a bit to simulate startup time and give check time to at least // try a few times with a nil Health time.Sleep(200 * time.Millisecond) - target.Health = &types.Health{Status: types.Healthy} + target.state.Health = &types.Health{Status: types.Healthy} }(target) err := wg.WaitUntilReady(context.Background(), target) @@ -82,7 +97,12 @@ func TestWaitForHealthWithNil(t *testing.T) { // TestWaitFailsForNilHealth checks that Health always nil fails (but will NOT cause a panic) func TestWaitFailsForNilHealth(t *testing.T) { - target := &healthStrategyTarget{Health: nil} + target := &healthStrategyTarget{ + state: &types.ContainerState{ + Running: true, + Health: nil, + }, + } wg := NewHealthStrategy(). WithStartupTimeout(500 * time.Millisecond). WithPollInterval(100 * time.Millisecond) @@ -91,3 +111,49 @@ func TestWaitFailsForNilHealth(t *testing.T) { assert.NotNil(t, err) assert.True(t, errors.Is(err, context.DeadlineExceeded)) } + +func TestWaitForHealthFailsDueToOOMKilledContainer(t *testing.T) { + target := &healthStrategyTarget{ + state: &types.ContainerState{ + OOMKilled: true, + }, + } + wg := NewHealthStrategy(). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + err := wg.WaitUntilReady(context.Background(), target) + assert.NotNil(t, err) + assert.EqualError(t, err, "container crashed with out-of-memory (OOMKilled)") +} + +func TestWaitForHealthFailsDueToExitedContainer(t *testing.T) { + target := &healthStrategyTarget{ + state: &types.ContainerState{ + Status: "exited", + ExitCode: 1, + }, + } + wg := NewHealthStrategy(). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + err := wg.WaitUntilReady(context.Background(), target) + assert.NotNil(t, err) + assert.EqualError(t, err, "container exited with code 1") +} + +func TestWaitForHealthFailsDueToUnexpectedContainerStatus(t *testing.T) { + target := &healthStrategyTarget{ + state: &types.ContainerState{ + Status: "dead", + }, + } + wg := NewHealthStrategy(). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + err := wg.WaitUntilReady(context.Background(), target) + assert.NotNil(t, err) + assert.EqualError(t, err, "unexpected container status \"dead\"") +} diff --git a/wait/host_port.go b/wait/host_port.go index be317813e2..95858f5d1e 100644 --- a/wait/host_port.go +++ b/wait/host_port.go @@ -113,6 +113,9 @@ func (hp *HostPortStrategy) WaitUntilReady(ctx context.Context, target StrategyT case <-ctx.Done(): return fmt.Errorf("%s:%w", ctx.Err(), err) case <-time.After(waitInterval): + if err := checkTarget(ctx, target); err != nil { + return err + } port, err = target.MappedPort(ctx, internalPort) if err != nil { fmt.Printf("(%d) [%s] %s\n", i, port, err) @@ -128,6 +131,9 @@ func (hp *HostPortStrategy) WaitUntilReady(ctx context.Context, target StrategyT dialer := net.Dialer{} address := net.JoinHostPort(ipAddress, portString) for { + if err := checkTarget(ctx, target); err != nil { + return err + } conn, err := dialer.DialContext(ctx, proto, address) if err != nil { if v, ok := err.(*net.OpError); ok { @@ -151,6 +157,9 @@ func (hp *HostPortStrategy) WaitUntilReady(ctx context.Context, target StrategyT if ctx.Err() != nil { return ctx.Err() } + if err := checkTarget(ctx, target); err != nil { + return err + } exitCode, _, err := target.Exec(ctx, []string{"/bin/sh", "-c", command}) if err != nil { return fmt.Errorf("%w, host port waiting failed", err) diff --git a/wait/host_port_test.go b/wait/host_port_test.go new file mode 100644 index 0000000000..bb53ee6f01 --- /dev/null +++ b/wait/host_port_test.go @@ -0,0 +1,484 @@ +package wait + +import ( + "context" + "errors" + "io" + "net" + "strconv" + "testing" + "time" + + "github.com/docker/docker/api/types" + "github.com/docker/go-connections/nat" + "github.com/testcontainers/testcontainers-go/exec" +) + +func TestWaitForListeningPortSucceeds(t *testing.T) { + listener, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + + rawPort := listener.Addr().(*net.TCPAddr).Port + port, err := nat.NewPort("tcp", strconv.Itoa(rawPort)) + if err != nil { + t.Fatal(err) + } + + var mappedPortCount, execCount int + target := &MockStrategyTarget{ + HostImpl: func(_ context.Context) (string, error) { + return "localhost", nil + }, + MappedPortImpl: func(_ context.Context, _ nat.Port) (nat.Port, error) { + defer func() { mappedPortCount++ }() + if mappedPortCount == 0 { + return "", errors.New("port not found") + } + return port, nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + return &types.ContainerState{ + Running: true, + }, nil + }, + ExecImpl: func(_ context.Context, _ []string, _ ...exec.ProcessOption) (int, io.Reader, error) { + defer func() { execCount++ }() + if execCount == 0 { + return 1, nil, nil + } + return 0, nil, nil + }, + } + + wg := ForListeningPort("80"). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + if err := wg.WaitUntilReady(context.Background(), target); err != nil { + t.Fatal(err) + } +} + +func TestWaitForExposedPortSucceeds(t *testing.T) { + listener, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + + rawPort := listener.Addr().(*net.TCPAddr).Port + port, err := nat.NewPort("tcp", strconv.Itoa(rawPort)) + if err != nil { + t.Fatal(err) + } + + var mappedPortCount, execCount int + target := &MockStrategyTarget{ + HostImpl: func(_ context.Context) (string, error) { + return "localhost", nil + }, + PortsImpl: func(_ context.Context) (nat.PortMap, error) { + return nat.PortMap{ + "80": []nat.PortBinding{ + { + HostIP: "0.0.0.0", + HostPort: port.Port(), + }, + }, + }, nil + }, + MappedPortImpl: func(_ context.Context, _ nat.Port) (nat.Port, error) { + defer func() { mappedPortCount++ }() + if mappedPortCount == 0 { + return "", errors.New("port not found") + } + return port, nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + return &types.ContainerState{ + Running: true, + }, nil + }, + ExecImpl: func(_ context.Context, _ []string, _ ...exec.ProcessOption) (int, io.Reader, error) { + defer func() { execCount++ }() + if execCount == 0 { + return 1, nil, nil + } + return 0, nil, nil + }, + } + + wg := ForExposedPort(). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + if err := wg.WaitUntilReady(context.Background(), target); err != nil { + t.Fatal(err) + } +} + +func TestHostPortStrategyFailsWhileGettingPortDueToOOMKilledContainer(t *testing.T) { + var mappedPortCount int + target := &MockStrategyTarget{ + HostImpl: func(_ context.Context) (string, error) { + return "localhost", nil + }, + MappedPortImpl: func(_ context.Context, _ nat.Port) (nat.Port, error) { + defer func() { mappedPortCount++ }() + if mappedPortCount == 0 { + return "", errors.New("port not found") + } + return "49152", nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + return &types.ContainerState{ + OOMKilled: true, + }, nil + }, + } + + wg := NewHostPortStrategy("80"). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + { + err := wg.WaitUntilReady(context.Background(), target) + if err == nil { + t.Fatal("no error") + } + + expected := "container crashed with out-of-memory (OOMKilled)" + if err.Error() != expected { + t.Fatalf("expected %q, got %q", expected, err.Error()) + } + } +} + +func TestHostPortStrategyFailsWhileGettingPortDueToExitedContainer(t *testing.T) { + var mappedPortCount int + target := &MockStrategyTarget{ + HostImpl: func(_ context.Context) (string, error) { + return "localhost", nil + }, + MappedPortImpl: func(_ context.Context, _ nat.Port) (nat.Port, error) { + defer func() { mappedPortCount++ }() + if mappedPortCount == 0 { + return "", errors.New("port not found") + } + return "49152", nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + return &types.ContainerState{ + Status: "exited", + ExitCode: 1, + }, nil + }, + } + + wg := NewHostPortStrategy("80"). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + { + err := wg.WaitUntilReady(context.Background(), target) + if err == nil { + t.Fatal("no error") + } + + expected := "container exited with code 1" + if err.Error() != expected { + t.Fatalf("expected %q, got %q", expected, err.Error()) + } + } +} + +func TestHostPortStrategyFailsWhileGettingPortDueToUnexpectedContainerStatus(t *testing.T) { + var mappedPortCount int + target := &MockStrategyTarget{ + HostImpl: func(_ context.Context) (string, error) { + return "localhost", nil + }, + MappedPortImpl: func(_ context.Context, _ nat.Port) (nat.Port, error) { + defer func() { mappedPortCount++ }() + if mappedPortCount == 0 { + return "", errors.New("port not found") + } + return "49152", nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + return &types.ContainerState{ + Status: "dead", + }, nil + }, + } + + wg := NewHostPortStrategy("80"). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + { + err := wg.WaitUntilReady(context.Background(), target) + if err == nil { + t.Fatal("no error") + } + + expected := "unexpected container status \"dead\"" + if err.Error() != expected { + t.Fatalf("expected %q, got %q", expected, err.Error()) + } + } +} + +func TestHostPortStrategyFailsWhileExternalCheckingDueToOOMKilledContainer(t *testing.T) { + target := &MockStrategyTarget{ + HostImpl: func(_ context.Context) (string, error) { + return "localhost", nil + }, + MappedPortImpl: func(_ context.Context, _ nat.Port) (nat.Port, error) { + return "49152", nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + return &types.ContainerState{ + OOMKilled: true, + }, nil + }, + } + + wg := NewHostPortStrategy("80"). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + { + err := wg.WaitUntilReady(context.Background(), target) + if err == nil { + t.Fatal("no error") + } + + expected := "container crashed with out-of-memory (OOMKilled)" + if err.Error() != expected { + t.Fatalf("expected %q, got %q", expected, err.Error()) + } + } +} + +func TestHostPortStrategyFailsWhileExternalCheckingDueToExitedContainer(t *testing.T) { + target := &MockStrategyTarget{ + HostImpl: func(_ context.Context) (string, error) { + return "localhost", nil + }, + MappedPortImpl: func(_ context.Context, _ nat.Port) (nat.Port, error) { + return "49152", nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + return &types.ContainerState{ + Status: "exited", + ExitCode: 1, + }, nil + }, + } + + wg := NewHostPortStrategy("80"). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + { + err := wg.WaitUntilReady(context.Background(), target) + if err == nil { + t.Fatal("no error") + } + + expected := "container exited with code 1" + if err.Error() != expected { + t.Fatalf("expected %q, got %q", expected, err.Error()) + } + } +} + +func TestHostPortStrategyFailsWhileExternalCheckingDueToUnexpectedContainerStatus(t *testing.T) { + target := &MockStrategyTarget{ + HostImpl: func(_ context.Context) (string, error) { + return "localhost", nil + }, + MappedPortImpl: func(_ context.Context, _ nat.Port) (nat.Port, error) { + return "49152", nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + return &types.ContainerState{ + Status: "dead", + }, nil + }, + } + + wg := NewHostPortStrategy("80"). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + { + err := wg.WaitUntilReady(context.Background(), target) + if err == nil { + t.Fatal("no error") + } + + expected := "unexpected container status \"dead\"" + if err.Error() != expected { + t.Fatalf("expected %q, got %q", expected, err.Error()) + } + } +} + +func TestHostPortStrategyFailsWhileInternalCheckingDueToOOMKilledContainer(t *testing.T) { + listener, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + + rawPort := listener.Addr().(*net.TCPAddr).Port + port, err := nat.NewPort("tcp", strconv.Itoa(rawPort)) + if err != nil { + t.Fatal(err) + } + + var stateCount int + target := &MockStrategyTarget{ + HostImpl: func(_ context.Context) (string, error) { + return "localhost", nil + }, + MappedPortImpl: func(_ context.Context, _ nat.Port) (nat.Port, error) { + return port, nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + defer func() { stateCount++ }() + if stateCount == 0 { + return &types.ContainerState{ + Running: true, + }, nil + } + return &types.ContainerState{ + OOMKilled: true, + }, nil + }, + } + + wg := NewHostPortStrategy("80"). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + { + err := wg.WaitUntilReady(context.Background(), target) + if err == nil { + t.Fatal("no error") + } + + expected := "container crashed with out-of-memory (OOMKilled)" + if err.Error() != expected { + t.Fatalf("expected %q, got %q", expected, err.Error()) + } + } +} + +func TestHostPortStrategyFailsWhileInternalCheckingDueToExitedContainer(t *testing.T) { + listener, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + + rawPort := listener.Addr().(*net.TCPAddr).Port + port, err := nat.NewPort("tcp", strconv.Itoa(rawPort)) + if err != nil { + t.Fatal(err) + } + + var stateCount int + target := &MockStrategyTarget{ + HostImpl: func(_ context.Context) (string, error) { + return "localhost", nil + }, + MappedPortImpl: func(_ context.Context, _ nat.Port) (nat.Port, error) { + return port, nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + defer func() { stateCount++ }() + if stateCount == 0 { + return &types.ContainerState{ + Running: true, + }, nil + } + return &types.ContainerState{ + Status: "exited", + ExitCode: 1, + }, nil + }, + } + + wg := NewHostPortStrategy("80"). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + { + err := wg.WaitUntilReady(context.Background(), target) + if err == nil { + t.Fatal("no error") + } + + expected := "container exited with code 1" + if err.Error() != expected { + t.Fatalf("expected %q, got %q", expected, err.Error()) + } + } +} + +func TestHostPortStrategyFailsWhileInternalCheckingDueToUnexpectedContainerStatus(t *testing.T) { + listener, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + + rawPort := listener.Addr().(*net.TCPAddr).Port + port, err := nat.NewPort("tcp", strconv.Itoa(rawPort)) + if err != nil { + t.Fatal(err) + } + + var stateCount int + target := &MockStrategyTarget{ + HostImpl: func(_ context.Context) (string, error) { + return "localhost", nil + }, + MappedPortImpl: func(_ context.Context, _ nat.Port) (nat.Port, error) { + return port, nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + defer func() { stateCount++ }() + if stateCount == 0 { + return &types.ContainerState{ + Running: true, + }, nil + } + return &types.ContainerState{ + Status: "dead", + }, nil + }, + } + + wg := NewHostPortStrategy("80"). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + { + err := wg.WaitUntilReady(context.Background(), target) + if err == nil { + t.Fatal("no error") + } + + expected := "unexpected container status \"dead\"" + if err.Error() != expected { + t.Fatalf("expected %q, got %q", expected, err.Error()) + } + } +} diff --git a/wait/http.go b/wait/http.go index eb16a79bc0..83de9f28d3 100644 --- a/wait/http.go +++ b/wait/http.go @@ -150,6 +150,10 @@ func (ws *HTTPStrategy) WaitUntilReady(ctx context.Context, target StrategyTarge case <-ctx.Done(): return fmt.Errorf("%s:%w", ctx.Err(), err) case <-time.After(ws.PollInterval): + if err := checkTarget(ctx, target); err != nil { + return err + } + port, err = target.MappedPort(ctx, ws.Port) } } @@ -225,6 +229,9 @@ func (ws *HTTPStrategy) WaitUntilReady(ctx context.Context, target StrategyTarge case <-ctx.Done(): return ctx.Err() case <-time.After(ws.PollInterval): + if err := checkTarget(ctx, target); err != nil { + return err + } req, err := http.NewRequestWithContext(ctx, ws.Method, endpoint.String(), bytes.NewReader(body)) if err != nil { return err diff --git a/wait/http_test.go b/wait/http_test.go index a8477240d1..fbc0b9f0ce 100644 --- a/wait/http_test.go +++ b/wait/http_test.go @@ -5,6 +5,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "errors" "fmt" "io" "log" @@ -15,6 +16,8 @@ import ( "testing" "time" + "github.com/docker/docker/api/types" + "github.com/docker/go-connections/nat" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/wait" ) @@ -247,3 +250,212 @@ func TestHTTPStrategyWaitUntilReadyNoBasicAuth(t *testing.T) { return } } + +func TestHttpStrategyFailsWhileGettingPortDueToOOMKilledContainer(t *testing.T) { + var mappedPortCount int + target := &wait.MockStrategyTarget{ + HostImpl: func(_ context.Context) (string, error) { + return "localhost", nil + }, + MappedPortImpl: func(_ context.Context, _ nat.Port) (nat.Port, error) { + defer func() { mappedPortCount++ }() + if mappedPortCount == 0 { + return "", errors.New("port not found") + } + return "49152", nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + return &types.ContainerState{ + OOMKilled: true, + }, nil + }, + } + + wg := wait.ForHTTP("/"). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + { + err := wg.WaitUntilReady(context.Background(), target) + if err == nil { + t.Fatal("no error") + } + + expected := "container crashed with out-of-memory (OOMKilled)" + if err.Error() != expected { + t.Fatalf("expected %q, got %q", expected, err.Error()) + } + } +} + +func TestHttpStrategyFailsWhileGettingPortDueToExitedContainer(t *testing.T) { + var mappedPortCount int + target := &wait.MockStrategyTarget{ + HostImpl: func(_ context.Context) (string, error) { + return "localhost", nil + }, + MappedPortImpl: func(_ context.Context, _ nat.Port) (nat.Port, error) { + defer func() { mappedPortCount++ }() + if mappedPortCount == 0 { + return "", errors.New("port not found") + } + return "49152", nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + return &types.ContainerState{ + Status: "exited", + ExitCode: 1, + }, nil + }, + } + + wg := wait.ForHTTP("/"). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + { + err := wg.WaitUntilReady(context.Background(), target) + if err == nil { + t.Fatal("no error") + } + + expected := "container exited with code 1" + if err.Error() != expected { + t.Fatalf("expected %q, got %q", expected, err.Error()) + } + } +} + +func TestHttpStrategyFailsWhileGettingPortDueToUnexpectedContainerStatus(t *testing.T) { + var mappedPortCount int + target := &wait.MockStrategyTarget{ + HostImpl: func(_ context.Context) (string, error) { + return "localhost", nil + }, + MappedPortImpl: func(_ context.Context, _ nat.Port) (nat.Port, error) { + defer func() { mappedPortCount++ }() + if mappedPortCount == 0 { + return "", errors.New("port not found") + } + return "49152", nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + return &types.ContainerState{ + Status: "dead", + }, nil + }, + } + + wg := wait.ForHTTP("/"). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + { + err := wg.WaitUntilReady(context.Background(), target) + if err == nil { + t.Fatal("no error") + } + + expected := "unexpected container status \"dead\"" + if err.Error() != expected { + t.Fatalf("expected %q, got %q", expected, err.Error()) + } + } +} + +func TestHTTPStrategyFailsWhileRequestSendingDueToOOMKilledContainer(t *testing.T) { + target := &wait.MockStrategyTarget{ + HostImpl: func(_ context.Context) (string, error) { + return "localhost", nil + }, + MappedPortImpl: func(_ context.Context, _ nat.Port) (nat.Port, error) { + return "49152", nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + return &types.ContainerState{ + OOMKilled: true, + }, nil + }, + } + + wg := wait.ForHTTP("/"). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + { + err := wg.WaitUntilReady(context.Background(), target) + if err == nil { + t.Fatal("no error") + } + + expected := "container crashed with out-of-memory (OOMKilled)" + if err.Error() != expected { + t.Fatalf("expected %q, got %q", expected, err.Error()) + } + } +} + +func TestHttpStrategyFailsWhileRequestSendingDueToExitedContainer(t *testing.T) { + target := &wait.MockStrategyTarget{ + HostImpl: func(_ context.Context) (string, error) { + return "localhost", nil + }, + MappedPortImpl: func(_ context.Context, _ nat.Port) (nat.Port, error) { + return "49152", nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + return &types.ContainerState{ + Status: "exited", + ExitCode: 1, + }, nil + }, + } + + wg := wait.ForHTTP("/"). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + { + err := wg.WaitUntilReady(context.Background(), target) + if err == nil { + t.Fatal("no error") + } + + expected := "container exited with code 1" + if err.Error() != expected { + t.Fatalf("expected %q, got %q", expected, err.Error()) + } + } +} + +func TestHttpStrategyFailsWhileRequestSendingDueToUnexpectedContainerStatus(t *testing.T) { + target := &wait.MockStrategyTarget{ + HostImpl: func(_ context.Context) (string, error) { + return "localhost", nil + }, + MappedPortImpl: func(_ context.Context, _ nat.Port) (nat.Port, error) { + return "49152", nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + return &types.ContainerState{ + Status: "dead", + }, nil + }, + } + + wg := wait.ForHTTP("/"). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + { + err := wg.WaitUntilReady(context.Background(), target) + if err == nil { + t.Fatal("no error") + } + + expected := "unexpected container status \"dead\"" + if err.Error() != expected { + t.Fatalf("expected %q, got %q", expected, err.Error()) + } + } +} diff --git a/wait/log.go b/wait/log.go index 3e735cb801..4e14e9d313 100644 --- a/wait/log.go +++ b/wait/log.go @@ -87,6 +87,8 @@ LOOP: case <-ctx.Done(): return ctx.Err() default: + checkErr := checkTarget(ctx, target) + reader, err := target.Logs(ctx) if err != nil { time.Sleep(ws.PollInterval) @@ -100,7 +102,9 @@ LOOP: } logs := string(b) - if strings.Count(logs, ws.Log) >= ws.Occurrence { + if logs == "" && checkErr != nil { + return checkErr + } else if strings.Count(logs, ws.Log) >= ws.Occurrence { break LOOP } else { time.Sleep(ws.PollInterval) diff --git a/wait/log_test.go b/wait/log_test.go index f078b5a95b..8c534400c5 100644 --- a/wait/log_test.go +++ b/wait/log_test.go @@ -6,6 +6,8 @@ import ( "io" "testing" "time" + + "github.com/docker/docker/api/types" ) func TestWaitForLog(t *testing.T) { @@ -57,3 +59,88 @@ func TestWaitShouldFailWithExactNumberOfOccurrences(t *testing.T) { t.Fatal("expected error") } } + +func TestWaitForLogFailsDueToOOMKilledContainer(t *testing.T) { + target := &MockStrategyTarget{ + LogsImpl: func(_ context.Context) (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader([]byte(""))), nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + return &types.ContainerState{ + OOMKilled: true, + }, nil + }, + } + + wg := ForLog("docker"). + WithStartupTimeout(100 * time.Microsecond) + + { + err := wg.WaitUntilReady(context.Background(), target) + if err == nil { + t.Fatal("no error") + } + + expected := "container crashed with out-of-memory (OOMKilled)" + if err.Error() != expected { + t.Fatalf("expected %q, got %q", expected, err.Error()) + } + } +} + +func TestWaitForLogFailsDueToExitedContainer(t *testing.T) { + target := &MockStrategyTarget{ + LogsImpl: func(_ context.Context) (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader([]byte(""))), nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + return &types.ContainerState{ + Status: "exited", + ExitCode: 1, + }, nil + }, + } + + wg := ForLog("docker"). + WithStartupTimeout(100 * time.Microsecond) + + { + err := wg.WaitUntilReady(context.Background(), target) + if err == nil { + t.Fatal("no error") + } + + expected := "container exited with code 1" + if err.Error() != expected { + t.Fatalf("expected %q, got %q", expected, err.Error()) + } + } +} + +func TestWaitForLogFailsDueToUnexpectedContainerStatus(t *testing.T) { + target := &MockStrategyTarget{ + LogsImpl: func(_ context.Context) (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader([]byte(""))), nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + return &types.ContainerState{ + Status: "dead", + }, nil + }, + } + + wg := ForLog("docker"). + WithStartupTimeout(100 * time.Microsecond) + + { + err := wg.WaitUntilReady(context.Background(), target) + if err == nil { + t.Fatal("no error") + } + + expected := "unexpected container status \"dead\"" + if err.Error() != expected { + t.Fatalf("expected %q, got %q", expected, err.Error()) + } + } +} diff --git a/wait/sql.go b/wait/sql.go index 184d804ce1..9e0c05607c 100644 --- a/wait/sql.go +++ b/wait/sql.go @@ -87,6 +87,9 @@ func (w *waitForSql) WaitUntilReady(ctx context.Context, target StrategyTarget) case <-ctx.Done(): return fmt.Errorf("%s:%w", ctx.Err(), err) case <-ticker.C: + if err := checkTarget(ctx, target); err != nil { + return err + } port, err = target.MappedPort(ctx, w.Port) } } @@ -101,7 +104,9 @@ func (w *waitForSql) WaitUntilReady(ctx context.Context, target StrategyTarget) case <-ctx.Done(): return ctx.Err() case <-ticker.C: - + if err := checkTarget(ctx, target); err != nil { + return err + } if _, err := db.ExecContext(ctx, w.query); err != nil { continue } diff --git a/wait/sql_test.go b/wait/sql_test.go index 63048e6acc..3df3cff9bd 100644 --- a/wait/sql_test.go +++ b/wait/sql_test.go @@ -1,8 +1,14 @@ package wait import ( + "context" + "database/sql" + "database/sql/driver" + "errors" "testing" + "time" + "github.com/docker/docker/api/types" "github.com/docker/go-connections/nat" ) @@ -28,3 +34,285 @@ func Test_waitForSql_WithQuery(t *testing.T) { } }) } + +func init() { + sql.Register("mock", &mockSQLDriver{}) +} + +type mockSQLDriver struct { + driver.Driver +} + +func (sd *mockSQLDriver) Open(_ string) (driver.Conn, error) { + return &mockSQLConn{}, nil +} + +type mockSQLConn struct { + driver.Conn + driver.ConnBeginTx + driver.ConnPrepareContext +} + +func (sc *mockSQLConn) Close() error { + return nil +} + +func (sc *mockSQLConn) PrepareContext(_ context.Context, _ string) (driver.Stmt, error) { + return &mockSQLStmt{}, nil +} + +type mockSQLStmt struct { + driver.Stmt + driver.StmtExecContext + driver.StmtQueryContext +} + +func (st *mockSQLStmt) Close() error { + return nil +} + +func (st *mockSQLStmt) NumInput() int { + return 0 +} + +func (st *mockSQLStmt) ExecContext(_ context.Context, _ []driver.NamedValue) (driver.Result, error) { + return nil, nil +} + +func TestWaitForSQLSucceeds(t *testing.T) { + var mappedPortCount int + target := &MockStrategyTarget{ + HostImpl: func(_ context.Context) (string, error) { + return "localhost", nil + }, + MappedPortImpl: func(_ context.Context, _ nat.Port) (nat.Port, error) { + defer func() { mappedPortCount++ }() + if mappedPortCount == 0 { + return "", errors.New("port not found") + } + return "49152", nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + return &types.ContainerState{ + Running: true, + }, nil + }, + } + + wg := ForSQL("3306", "mock", func(_ string, _ nat.Port) string { return "" }). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + if err := wg.WaitUntilReady(context.Background(), target); err != nil { + t.Fatal(err) + } +} + +func TestWaitForSQLFailsWhileGettingPortDueToOOMKilledContainer(t *testing.T) { + var mappedPortCount int + target := &MockStrategyTarget{ + HostImpl: func(_ context.Context) (string, error) { + return "localhost", nil + }, + MappedPortImpl: func(_ context.Context, _ nat.Port) (nat.Port, error) { + defer func() { mappedPortCount++ }() + if mappedPortCount == 0 { + return "", errors.New("port not found") + } + return "49152", nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + return &types.ContainerState{ + OOMKilled: true, + }, nil + }, + } + + wg := ForSQL("3306", "mock", func(_ string, _ nat.Port) string { return "" }). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + { + err := wg.WaitUntilReady(context.Background(), target) + if err == nil { + t.Fatal("no error") + } + + expected := "container crashed with out-of-memory (OOMKilled)" + if err.Error() != expected { + t.Fatalf("expected %q, got %q", expected, err.Error()) + } + } +} + +func TestWaitForSQLFailsWhileGettingPortDueToExitedContainer(t *testing.T) { + var mappedPortCount int + target := &MockStrategyTarget{ + HostImpl: func(_ context.Context) (string, error) { + return "localhost", nil + }, + MappedPortImpl: func(_ context.Context, _ nat.Port) (nat.Port, error) { + defer func() { mappedPortCount++ }() + if mappedPortCount == 0 { + return "", errors.New("port not found") + } + return "49152", nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + return &types.ContainerState{ + Status: "exited", + ExitCode: 1, + }, nil + }, + } + + wg := ForSQL("3306", "mock", func(_ string, _ nat.Port) string { return "" }). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + { + err := wg.WaitUntilReady(context.Background(), target) + if err == nil { + t.Fatal("no error") + } + + expected := "container exited with code 1" + if err.Error() != expected { + t.Fatalf("expected %q, got %q", expected, err.Error()) + } + } +} + +func TestWaitForSQLFailsWhileGettingPortDueToUnexpectedContainerStatus(t *testing.T) { + var mappedPortCount int + target := &MockStrategyTarget{ + HostImpl: func(_ context.Context) (string, error) { + return "localhost", nil + }, + MappedPortImpl: func(_ context.Context, _ nat.Port) (nat.Port, error) { + defer func() { mappedPortCount++ }() + if mappedPortCount == 0 { + return "", errors.New("port not found") + } + return "49152", nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + return &types.ContainerState{ + Status: "dead", + }, nil + }, + } + + wg := ForSQL("3306", "mock", func(_ string, _ nat.Port) string { return "" }). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + { + err := wg.WaitUntilReady(context.Background(), target) + if err == nil { + t.Fatal("no error") + } + + expected := "unexpected container status \"dead\"" + if err.Error() != expected { + t.Fatalf("expected %q, got %q", expected, err.Error()) + } + } +} + +func TestWaitForSQLFailsWhileQueryExecutingDueToOOMKilledContainer(t *testing.T) { + target := &MockStrategyTarget{ + HostImpl: func(_ context.Context) (string, error) { + return "localhost", nil + }, + MappedPortImpl: func(_ context.Context, _ nat.Port) (nat.Port, error) { + return "49152", nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + return &types.ContainerState{ + OOMKilled: true, + }, nil + }, + } + + wg := ForSQL("3306", "mock", func(_ string, _ nat.Port) string { return "" }). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + { + err := wg.WaitUntilReady(context.Background(), target) + if err == nil { + t.Fatal("no error") + } + + expected := "container crashed with out-of-memory (OOMKilled)" + if err.Error() != expected { + t.Fatalf("expected %q, got %q", expected, err.Error()) + } + } +} + +func TestWaitForSQLFailsWhileQueryExecutingDueToExitedContainer(t *testing.T) { + target := &MockStrategyTarget{ + HostImpl: func(_ context.Context) (string, error) { + return "localhost", nil + }, + MappedPortImpl: func(_ context.Context, _ nat.Port) (nat.Port, error) { + return "49152", nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + return &types.ContainerState{ + Status: "exited", + ExitCode: 1, + }, nil + }, + } + + wg := ForSQL("3306", "mock", func(_ string, _ nat.Port) string { return "" }). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + { + err := wg.WaitUntilReady(context.Background(), target) + if err == nil { + t.Fatal("no error") + } + + expected := "container exited with code 1" + if err.Error() != expected { + t.Fatalf("expected %q, got %q", expected, err.Error()) + } + } +} + +func TestWaitForSQLFailsWhileQueryExecutingDueToUnexpectedContainerStatus(t *testing.T) { + target := &MockStrategyTarget{ + HostImpl: func(_ context.Context) (string, error) { + return "localhost", nil + }, + MappedPortImpl: func(_ context.Context, _ nat.Port) (nat.Port, error) { + return "49152", nil + }, + StateImpl: func(_ context.Context) (*types.ContainerState, error) { + return &types.ContainerState{ + Status: "dead", + }, nil + }, + } + + wg := ForSQL("3306", "mock", func(_ string, _ nat.Port) string { return "" }). + WithStartupTimeout(500 * time.Millisecond). + WithPollInterval(100 * time.Millisecond) + + { + err := wg.WaitUntilReady(context.Background(), target) + if err == nil { + t.Fatal("no error") + } + + expected := "unexpected container status \"dead\"" + if err.Error() != expected { + t.Fatalf("expected %q, got %q", expected, err.Error()) + } + } +} diff --git a/wait/wait.go b/wait/wait.go index dae00c46ed..559ce39794 100644 --- a/wait/wait.go +++ b/wait/wait.go @@ -2,6 +2,8 @@ package wait import ( "context" + "errors" + "fmt" "io" "time" @@ -29,6 +31,28 @@ type StrategyTarget interface { State(context.Context) (*types.ContainerState, error) } +func checkTarget(ctx context.Context, target StrategyTarget) error { + state, err := target.State(ctx) + if err != nil { + return err + } + + return checkState(state) +} + +func checkState(state *types.ContainerState) error { + switch { + case state.Running: + return nil + case state.OOMKilled: + return errors.New("container crashed with out-of-memory (OOMKilled)") + case state.Status == "exited": + return fmt.Errorf("container exited with code %d", state.ExitCode) + default: + return fmt.Errorf("unexpected container status %q", state.Status) + } +} + func defaultStartupTimeout() time.Duration { return 60 * time.Second } diff --git a/wait/wait_test.go b/wait/wait_test.go new file mode 100644 index 0000000000..e7178a63a9 --- /dev/null +++ b/wait/wait_test.go @@ -0,0 +1,43 @@ +package wait + +import ( + "context" + "io" + + "github.com/docker/docker/api/types" + "github.com/docker/go-connections/nat" + tcexec "github.com/testcontainers/testcontainers-go/exec" +) + +type MockStrategyTarget struct { + HostImpl func(context.Context) (string, error) + PortsImpl func(context.Context) (nat.PortMap, error) + MappedPortImpl func(context.Context, nat.Port) (nat.Port, error) + LogsImpl func(context.Context) (io.ReadCloser, error) + ExecImpl func(context.Context, []string, ...tcexec.ProcessOption) (int, io.Reader, error) + StateImpl func(context.Context) (*types.ContainerState, error) +} + +func (st MockStrategyTarget) Host(ctx context.Context) (string, error) { + return st.HostImpl(ctx) +} + +func (st MockStrategyTarget) Ports(ctx context.Context) (nat.PortMap, error) { + return st.PortsImpl(ctx) +} + +func (st MockStrategyTarget) MappedPort(ctx context.Context, port nat.Port) (nat.Port, error) { + return st.MappedPortImpl(ctx, port) +} + +func (st MockStrategyTarget) Logs(ctx context.Context) (io.ReadCloser, error) { + return st.LogsImpl(ctx) +} + +func (st MockStrategyTarget) Exec(ctx context.Context, cmd []string, options ...tcexec.ProcessOption) (int, io.Reader, error) { + return st.ExecImpl(ctx, cmd, options...) +} + +func (st MockStrategyTarget) State(ctx context.Context) (*types.ContainerState, error) { + return st.StateImpl(ctx) +}