diff --git a/pkg/container/docker/client.go b/pkg/container/docker/client.go index c3a51c1a3..7276a6a49 100644 --- a/pkg/container/docker/client.go +++ b/pkg/container/docker/client.go @@ -556,14 +556,21 @@ func (c *Client) GetWorkloadInfo(ctx context.Context, workloadName string) (runt created = time.Time{} // Use zero time if parsing fails } + // Convert start time + startedAt, err := time.Parse(time.RFC3339Nano, info.State.StartedAt) + if err != nil { + startedAt = time.Time{} // Use zero time if parsing fails + } + return runtime.ContainerInfo{ - Name: strings.TrimPrefix(info.Name, "/"), - Image: info.Config.Image, - Status: info.State.Status, - State: dockerToDomainStatus(info.State.Status), - Created: created, - Labels: info.Config.Labels, - Ports: ports, + Name: strings.TrimPrefix(info.Name, "/"), + Image: info.Config.Image, + Status: info.State.Status, + State: dockerToDomainStatus(info.State.Status), + Created: created, + StartedAt: startedAt, + Labels: info.Config.Labels, + Ports: ports, }, nil } diff --git a/pkg/container/docker/errors.go b/pkg/container/docker/errors.go index 03303abdb..455c62b6a 100644 --- a/pkg/container/docker/errors.go +++ b/pkg/container/docker/errors.go @@ -18,6 +18,9 @@ var ( // ErrContainerExited is returned when a container has exited unexpectedly ErrContainerExited = fmt.Errorf("container exited unexpectedly") + + // ErrContainerRemoved is returned when a container has been removed + ErrContainerRemoved = fmt.Errorf("container removed") ) // ContainerError represents an error related to container operations diff --git a/pkg/container/docker/monitor.go b/pkg/container/docker/monitor.go index 71f3c6af7..43ca3d9f3 100644 --- a/pkg/container/docker/monitor.go +++ b/pkg/container/docker/monitor.go @@ -12,13 +12,14 @@ import ( // ContainerMonitor watches a container's state and reports when it exits type ContainerMonitor struct { - runtime runtime.Runtime - containerName string - stopCh chan struct{} - errorCh chan error - wg sync.WaitGroup - running bool - mutex sync.Mutex + runtime runtime.Runtime + containerName string + stopCh chan struct{} + errorCh chan error + wg sync.WaitGroup + running bool + mutex sync.Mutex + initialStartTime time.Time // Track container start time to detect restarts } // NewMonitor creates a new container monitor @@ -49,6 +50,13 @@ func (m *ContainerMonitor) StartMonitoring(ctx context.Context) (<-chan error, e return nil, NewContainerError(ErrContainerNotRunning, m.containerName, "container is not running") } + // Get initial container info to track start time + info, err := m.runtime.GetWorkloadInfo(ctx, m.containerName) + if err != nil { + return nil, NewContainerError(err, m.containerName, fmt.Sprintf("failed to get container info: %v", err)) + } + m.initialStartTime = info.StartedAt + m.running = true m.wg.Add(1) @@ -97,14 +105,14 @@ func (m *ContainerMonitor) monitor(ctx context.Context) { running, err := m.runtime.IsWorkloadRunning(checkCtx, m.containerName) cancel() // Always cancel the context to avoid leaks if err != nil { - // If the container is not found, it may have been removed + // If the container is not found, it has been removed if IsContainerNotFound(err) { - exitErr := NewContainerError( - ErrContainerExited, + removeErr := NewContainerError( + ErrContainerRemoved, m.containerName, - fmt.Sprintf("Container %s (%s) not found, it may have been removed", m.containerName, m.containerName), + fmt.Sprintf("Container %s not found, it has been removed", m.containerName), ) - m.errorCh <- exitErr + m.errorCh <- 
removeErr return } @@ -129,6 +137,22 @@ func (m *ContainerMonitor) monitor(ctx context.Context) { m.errorCh <- exitErr return } + + // Container is running - check if it was restarted (different start time) + infoCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + info, err := m.runtime.GetWorkloadInfo(infoCtx, m.containerName) + cancel() + if err == nil && !info.StartedAt.IsZero() && !info.StartedAt.Equal(m.initialStartTime) { + // Container was restarted (has a different start time) + restartErr := NewContainerError( + ErrContainerExited, + m.containerName, + fmt.Sprintf("Container %s was restarted (start time changed from %s to %s)", + m.containerName, m.initialStartTime.Format(time.RFC3339), info.StartedAt.Format(time.RFC3339)), + ) + m.errorCh <- restartErr + return + } } } } diff --git a/pkg/container/docker/monitor_test.go b/pkg/container/docker/monitor_test.go index 359944e5e..4043df8fb 100644 --- a/pkg/container/docker/monitor_test.go +++ b/pkg/container/docker/monitor_test.go @@ -39,6 +39,10 @@ func TestContainerMonitor_StartMonitoring_WhenRunningStarts(t *testing.T) { // StartMonitoring should verify running exactly once on first call. mockRT.EXPECT().IsWorkloadRunning(ctx, "workload-1").Return(true, nil).Times(1) + // StartMonitoring now gets the container start time + mockRT.EXPECT().GetWorkloadInfo(ctx, "workload-1").Return(rt.ContainerInfo{ + StartedAt: time.Now(), + }, nil).Times(1) m := NewMonitor(mockRT, "workload-1") ch, err := m.StartMonitoring(ctx) @@ -145,6 +149,10 @@ func TestContainerMonitor_StartStop_TerminatesQuickly(t *testing.T) { defer cancel() mockRT.EXPECT().IsWorkloadRunning(ctx, "workload-5").Return(true, nil).Times(1) + // StartMonitoring now gets the container start time + mockRT.EXPECT().GetWorkloadInfo(ctx, "workload-5").Return(rt.ContainerInfo{ + StartedAt: time.Now(), + }, nil).Times(1) m := NewMonitor(mockRT, "workload-5") ch, err := m.StartMonitoring(ctx) diff --git a/pkg/container/runtime/types.go b/pkg/container/runtime/types.go index 2bbdc75e0..b234b0ae5 100644 --- a/pkg/container/runtime/types.go +++ b/pkg/container/runtime/types.go @@ -55,6 +55,8 @@ type ContainerInfo struct { State WorkloadStatus // Created is the container creation timestamp Created time.Time + // StartedAt is when the container was last started (changes on restart) + StartedAt time.Time // Labels is the container labels Labels map[string]string // Ports is the container port mappings diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 54fc3fcd4..41f331308 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -4,6 +4,7 @@ package runner import ( "bytes" "context" + "errors" "fmt" "net/http" "os" @@ -18,6 +19,7 @@ import ( "github.com/stacklok/toolhive/pkg/auth/remote" "github.com/stacklok/toolhive/pkg/client" "github.com/stacklok/toolhive/pkg/config" + ct "github.com/stacklok/toolhive/pkg/container" rt "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/labels" "github.com/stacklok/toolhive/pkg/logger" @@ -383,7 +385,7 @@ func (r *Runner) Run(ctx context.Context) error { } if !running { // Transport is no longer running (container exited or was stopped) - logger.Info("Transport is no longer running, exiting...") + logger.Warn("Transport is no longer running, attempting automatic restart...") close(doneCh) return } @@ -404,7 +406,7 @@ func (r *Runner) Run(ctx context.Context) error { case sig := <-sigCh: stopMCPServer(fmt.Sprintf("Received signal %s", sig)) case <-doneCh: - // The transport has 
already been stopped (likely by the container monitor) + // The transport has already been stopped (likely by the container exit) // Clean up the PID file and state // TODO: Stop writing to PID file once we migrate over to statuses. if err := process.RemovePIDFile(r.Config.BaseName); err != nil { @@ -414,12 +416,88 @@ func (r *Runner) Run(ctx context.Context) error { logger.Warnf("Warning: Failed to reset workload %s PID: %v", r.Config.BaseName, err) } - logger.Infof("MCP server %s stopped", r.Config.ContainerName) + // Check if workload still exists (using status manager and runtime) + // If it doesn't exist, it was removed - clean up client config + // If it exists, it exited unexpectedly - signal restart needed + exists, checkErr := r.doesWorkloadExist(ctx, r.Config.BaseName) + if checkErr != nil { + logger.Warnf("Warning: Failed to check if workload exists: %v", checkErr) + // Assume restart needed if we can't check + } else if !exists { + // Workload doesn't exist in `thv ls` - it was removed + logger.Infof( + "Workload %s no longer exists. Removing from client configurations.", + r.Config.BaseName, + ) + clientManager, clientErr := client.NewManager(ctx) + if clientErr == nil { + removeErr := clientManager.RemoveServerFromClients( + ctx, + r.Config.ContainerName, + r.Config.Group, + ) + if removeErr != nil { + logger.Warnf("Warning: Failed to remove from client config: %v", removeErr) + } else { + logger.Infof( + "Successfully removed %s from client configurations", + r.Config.ContainerName, + ) + } + } + logger.Infof("MCP server %s stopped and cleaned up", r.Config.ContainerName) + return nil // Exit gracefully, no restart + } + + // Workload still exists - signal restart needed + logger.Infof("MCP server %s stopped, restart needed", r.Config.ContainerName) + return fmt.Errorf("container exited, restart needed") } return nil } +// doesWorkloadExist checks if a workload exists in the status manager and runtime. +// For remote workloads, it trusts the status manager. +// For container workloads, it verifies the container exists in the runtime. 
+func (r *Runner) doesWorkloadExist(ctx context.Context, workloadName string) (bool, error) { + // Check if workload exists by trying to get it from status manager + workload, err := r.statusManager.GetWorkload(ctx, workloadName) + if err != nil { + if errors.Is(err, rt.ErrWorkloadNotFound) { + return false, nil + } + return false, fmt.Errorf("failed to check if workload exists: %w", err) + } + + // If remote workload, check if it should exist + if workload.Remote { + // For remote workloads, trust the status manager + return workload.Status != rt.WorkloadStatusError, nil + } + + // For container workloads, verify the container actually exists in the runtime + // Create a runtime instance to check if container exists + backend, err := ct.NewFactory().Create(ctx) + if err != nil { + logger.Warnf("Failed to create runtime to check container existence: %v", err) + // Fall back to status manager only + return workload.Status != rt.WorkloadStatusError, nil + } + + // Check if container exists in the runtime (not just running) + // GetWorkloadInfo will return an error if the container doesn't exist + _, err = backend.GetWorkloadInfo(ctx, workloadName) + if err != nil { + // Container doesn't exist + logger.Debugf("Container %s not found in runtime: %v", workloadName, err) + return false, nil + } + + // Container exists (may be running or stopped) + return true, nil +} + // handleRemoteAuthentication handles authentication for remote MCP servers func (r *Runner) handleRemoteAuthentication(ctx context.Context) (oauth2.TokenSource, error) { if r.Config.RemoteAuthConfig == nil { diff --git a/pkg/transport/http.go b/pkg/transport/http.go index 90e0431ee..10ba86e90 100644 --- a/pkg/transport/http.go +++ b/pkg/transport/http.go @@ -2,6 +2,7 @@ package transport import ( "context" + "errors" "fmt" "net/http" "net/url" @@ -10,9 +11,10 @@ import ( "golang.org/x/oauth2" "github.com/stacklok/toolhive/pkg/container" + "github.com/stacklok/toolhive/pkg/container/docker" rt "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/logger" - "github.com/stacklok/toolhive/pkg/transport/errors" + transporterrors "github.com/stacklok/toolhive/pkg/transport/errors" "github.com/stacklok/toolhive/pkg/transport/middleware" "github.com/stacklok/toolhive/pkg/transport/proxy/transparent" "github.com/stacklok/toolhive/pkg/transport/types" @@ -58,6 +60,10 @@ type HTTPTransport struct { // Container monitor monitor rt.Monitor errorCh <-chan error + + // Container exit error (for determining if restart is needed) + containerExitErr error + exitErrMutex sync.Mutex } // NewHTTPTransport creates a new HTTP transport. 
@@ -165,7 +171,7 @@ func (t *HTTPTransport) Start(ctx context.Context) error { t.proxyPort, targetURI) } else { if t.containerName == "" { - return errors.ErrContainerNameNotSet + return transporterrors.ErrContainerNameNotSet } // For local containers, use the configured target URI @@ -273,14 +279,41 @@ func (t *HTTPTransport) handleContainerExit(ctx context.Context) { case <-ctx.Done(): return case err := <-t.errorCh: - logger.Infof("Container %s exited: %v", t.containerName, err) - // Stop the transport when the container exits + // Store the exit error so runner can check if restart is needed + t.exitErrMutex.Lock() + t.containerExitErr = err + t.exitErrMutex.Unlock() + + logger.Warnf("Container %s exited: %v", t.containerName, err) + + // Check if container was removed (not just exited) using typed error + if errors.Is(err, docker.ErrContainerRemoved) { + logger.Infof("Container %s was removed. Stopping proxy and cleaning up.", t.containerName) + } else { + logger.Infof("Container %s exited. Will attempt automatic restart.", t.containerName) + } + + // Stop the transport when the container exits/removed if stopErr := t.Stop(ctx); stopErr != nil { logger.Errorf("Error stopping transport after container exit: %v", stopErr) } } } +// ShouldRestart returns true if the container exited and should be restarted. +// Returns false if the container was removed (intentionally deleted). +func (t *HTTPTransport) ShouldRestart() bool { + t.exitErrMutex.Lock() + defer t.exitErrMutex.Unlock() + + if t.containerExitErr == nil { + return false // No exit error, normal shutdown + } + + // Don't restart if container was removed (use typed error check) + return !errors.Is(t.containerExitErr, docker.ErrContainerRemoved) +} + // IsRunning checks if the transport is currently running. 
func (t *HTTPTransport) IsRunning(_ context.Context) (bool, error) { t.mutex.Lock() diff --git a/pkg/transport/http_test.go b/pkg/transport/http_test.go new file mode 100644 index 000000000..b8d4aba1d --- /dev/null +++ b/pkg/transport/http_test.go @@ -0,0 +1,51 @@ +package transport + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/stacklok/toolhive/pkg/container/docker" +) + +// TestHTTPTransport_ShouldRestart tests the ShouldRestart logic +func TestHTTPTransport_ShouldRestart(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + exitError error + expectedResult bool + }{ + { + name: "container exited - should restart", + exitError: fmt.Errorf("container exited unexpectedly"), + expectedResult: true, + }, + { + name: "container removed - should not restart", + exitError: docker.NewContainerError(docker.ErrContainerRemoved, "test", "Container removed"), + expectedResult: false, + }, + { + name: "no error - should not restart", + exitError: nil, + expectedResult: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + transport := &HTTPTransport{ + containerName: "test-container", + containerExitErr: tt.exitError, + } + + result := transport.ShouldRestart() + assert.Equal(t, tt.expectedResult, result) + }) + } +} diff --git a/pkg/transport/proxy/transparent/transparent_proxy.go b/pkg/transport/proxy/transparent/transparent_proxy.go index da3faa30d..21268844c 100644 --- a/pkg/transport/proxy/transparent/transparent_proxy.go +++ b/pkg/transport/proxy/transparent/transparent_proxy.go @@ -51,6 +51,9 @@ type TransparentProxy struct { // Mutex for protecting shared state mutex sync.Mutex + // Track if Stop() has been called + stopped bool + // Shutdown channel shutdownCh chan struct{} @@ -373,8 +376,10 @@ func (p *TransparentProxy) Start(ctx context.Context) error { ReadHeaderTimeout: 10 * time.Second, // Prevent Slowloris attacks } + // Capture server in local variable to avoid race with Stop() + server := p.server go func() { - err := p.server.Serve(ln) + err := server.Serve(ln) if err != nil && err != http.ErrServerClosed { var opErr *net.OpError if errors.As(err, &opErr) && opErr.Op == "accept" { @@ -435,6 +440,15 @@ func (p *TransparentProxy) Stop(ctx context.Context) error { p.mutex.Lock() defer p.mutex.Unlock() + // Check if already stopped + if p.stopped { + logger.Debugf("Proxy for %s is already stopped, skipping", p.targetURI) + return nil + } + + // Mark as stopped before closing channel + p.stopped = true + // Signal shutdown close(p.shutdownCh) diff --git a/pkg/transport/proxy/transparent/transparent_test.go b/pkg/transport/proxy/transparent/transparent_test.go index f9eba25fe..4ec462c64 100644 --- a/pkg/transport/proxy/transparent/transparent_test.go +++ b/pkg/transport/proxy/transparent/transparent_test.go @@ -351,3 +351,47 @@ func TestWellKnownPathWithoutAuthHandler(t *testing.T) { assert.Equal(t, http.StatusNotFound, recorder.Code, "Without auth handler, well-known path should return 404") } + +// TestTransparentProxy_IdempotentStop tests that Stop() can be called multiple times safely +func TestTransparentProxy_IdempotentStop(t *testing.T) { + t.Parallel() + + // Create a proxy + proxy := NewTransparentProxy("127.0.0.1", 0, "http://localhost:8080", nil, nil, false, false, "sse") + + ctx := context.Background() + + // Start the proxy (this creates the shutdown channel) + err := proxy.Start(ctx) + if err != nil { + t.Fatalf("Failed to start proxy: %v", err) + } + + // First stop 
should succeed + err = proxy.Stop(ctx) + assert.NoError(t, err, "First Stop() should succeed") + + // Second stop should also succeed (idempotent) + err = proxy.Stop(ctx) + assert.NoError(t, err, "Second Stop() should succeed (idempotent)") + + // Third stop should also succeed + err = proxy.Stop(ctx) + assert.NoError(t, err, "Third Stop() should succeed (idempotent)") +} + +// TestTransparentProxy_StopWithoutStart tests that Stop() works even if never started +func TestTransparentProxy_StopWithoutStart(t *testing.T) { + t.Parallel() + + // Create a proxy but don't start it + proxy := NewTransparentProxy("127.0.0.1", 0, "http://localhost:8080", nil, nil, false, false, "sse") + + ctx := context.Background() + + // Stop should handle being called without Start + err := proxy.Stop(ctx) + // This may return an error or succeed depending on implementation + // The key is it shouldn't panic + _ = err +} diff --git a/pkg/transport/stdio.go b/pkg/transport/stdio.go index d224bc387..d5a247900 100644 --- a/pkg/transport/stdio.go +++ b/pkg/transport/stdio.go @@ -20,6 +20,7 @@ import ( "golang.org/x/exp/jsonrpc2" "github.com/stacklok/toolhive/pkg/container" + "github.com/stacklok/toolhive/pkg/container/docker" rt "github.com/stacklok/toolhive/pkg/container/runtime" "github.com/stacklok/toolhive/pkg/logger" transporterrors "github.com/stacklok/toolhive/pkg/transport/errors" @@ -76,6 +77,10 @@ type StdioTransport struct { // Container monitor monitor rt.Monitor + // Container exit error (for determining if restart is needed) + containerExitErr error + exitErrMutex sync.Mutex + // Retry configuration (for testing) retryConfig *retryConfig } @@ -627,7 +632,19 @@ func (t *StdioTransport) handleContainerExit(ctx context.Context) { return } - logger.Infof("Container %s exited: %v", t.containerName, err) + // Store the exit error so runner can check if restart is needed + t.exitErrMutex.Lock() + t.containerExitErr = err + t.exitErrMutex.Unlock() + + logger.Warnf("Container %s exited: %v", t.containerName, err) + + // Check if container was removed (not just exited) using typed error + if errors.Is(err, docker.ErrContainerRemoved) { + logger.Infof("Container %s was removed. Stopping proxy and cleaning up.", t.containerName) + } else { + logger.Infof("Container %s exited. Will attempt automatic restart.", t.containerName) + } // Check if the transport is already stopped before trying to stop it select { @@ -647,3 +664,17 @@ func (t *StdioTransport) handleContainerExit(ctx context.Context) { } } } + +// ShouldRestart returns true if the container exited and should be restarted. +// Returns false if the container was removed (intentionally deleted). 
+func (t *StdioTransport) ShouldRestart() bool { + t.exitErrMutex.Lock() + defer t.exitErrMutex.Unlock() + + if t.containerExitErr == nil { + return false // No exit error, normal shutdown + } + + // Don't restart if container was removed (use typed error check) + return !errors.Is(t.containerExitErr, docker.ErrContainerRemoved) +} diff --git a/pkg/transport/stdio_test.go b/pkg/transport/stdio_test.go index 5a378d641..12a66dc08 100644 --- a/pkg/transport/stdio_test.go +++ b/pkg/transport/stdio_test.go @@ -16,6 +16,7 @@ import ( "go.uber.org/mock/gomock" "golang.org/x/exp/jsonrpc2" + "github.com/stacklok/toolhive/pkg/container/docker" "github.com/stacklok/toolhive/pkg/container/runtime/mocks" "github.com/stacklok/toolhive/pkg/logger" ) @@ -1013,3 +1014,44 @@ func TestStdinRaceCondition(t *testing.T) { // Clean up cancel() } + +// TestStdioTransport_ShouldRestart tests the ShouldRestart logic +func TestStdioTransport_ShouldRestart(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + exitError error + expectedResult bool + }{ + { + name: "container exited - should restart", + exitError: fmt.Errorf("container exited unexpectedly"), + expectedResult: true, + }, + { + name: "container removed - should not restart", + exitError: docker.NewContainerError(docker.ErrContainerRemoved, "test", "Container removed"), + expectedResult: false, + }, + { + name: "no error - should not restart", + exitError: nil, + expectedResult: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + transport := &StdioTransport{ + containerName: "test-container", + containerExitErr: tt.exitError, + } + + result := transport.ShouldRestart() + assert.Equal(t, tt.expectedResult, result) + }) + } +} diff --git a/pkg/workloads/manager.go b/pkg/workloads/manager.go index 6389a037f..675a2ca17 100644 --- a/pkg/workloads/manager.go +++ b/pkg/workloads/manager.go @@ -390,7 +390,7 @@ func (d *DefaultManager) stopContainerWorkload(ctx context.Context, name string) return d.stopSingleContainerWorkload(ctx, &container) } -// RunWorkload runs a workload in the foreground. +// RunWorkload runs a workload in the foreground with automatic restart on container exit. func (d *DefaultManager) RunWorkload(ctx context.Context, runConfig *runner.RunConfig) error { // Ensure that the workload has a status entry before starting the process. if err := d.statuses.SetWorkloadStatus(ctx, runConfig.BaseName, rt.WorkloadStatusStarting, ""); err != nil { @@ -398,15 +398,84 @@ func (d *DefaultManager) RunWorkload(ctx context.Context, runConfig *runner.RunC return fmt.Errorf("failed to create workload status: %v", err) } - mcpRunner := runner.NewRunner(runConfig, d.statuses) - err := mcpRunner.Run(ctx) - if err != nil { - // If the run failed, we should set the status to error. 
- if statusErr := d.statuses.SetWorkloadStatus(ctx, runConfig.BaseName, rt.WorkloadStatusError, err.Error()); statusErr != nil { - logger.Warnf("Failed to set workload %s status to error: %v", runConfig.BaseName, statusErr) + // Retry loop with exponential backoff for container restarts + maxRetries := 10 // Allow many retries for transient issues + retryDelay := 5 * time.Second + + for attempt := 1; attempt <= maxRetries; attempt++ { + if attempt > 1 { + logger.Infof("Restart attempt %d/%d for %s after %v delay", attempt, maxRetries, runConfig.BaseName, retryDelay) + time.Sleep(retryDelay) + + // Exponential backoff: 5s, 10s, 20s, 40s, 60s (capped) + retryDelay *= 2 + if retryDelay > 60*time.Second { + retryDelay = 60 * time.Second + } } + + mcpRunner := runner.NewRunner(runConfig, d.statuses) + err := mcpRunner.Run(ctx) + + if err != nil { + // Check if this is a "container exited, restart needed" error + if err.Error() == "container exited, restart needed" { + logger.Warnf("Container %s exited unexpectedly (attempt %d/%d). Restarting...", + runConfig.BaseName, attempt, maxRetries) + + // Remove from client config so clients notice the restart + clientManager, clientErr := client.NewManager(ctx) + if clientErr == nil { + logger.Infof("Removing %s from client configurations before restart", runConfig.BaseName) + if removeErr := clientManager.RemoveServerFromClients(ctx, runConfig.BaseName, runConfig.Group); removeErr != nil { + logger.Warnf("Warning: Failed to remove from client config: %v", removeErr) + } + } + + // Set status to starting (since we're restarting) + statusErr := d.statuses.SetWorkloadStatus( + ctx, + runConfig.BaseName, + rt.WorkloadStatusStarting, + "Container exited, restarting", + ) + if statusErr != nil { + logger.Warnf("Failed to set workload %s status to starting: %v", runConfig.BaseName, statusErr) + } + + // If we haven't exhausted retries, continue the loop + if attempt < maxRetries { + continue + } + + // Exhausted all retries + logger.Errorf("Failed to restart %s after %d attempts. Giving up.", runConfig.BaseName, maxRetries) + statusErr = d.statuses.SetWorkloadStatus( + ctx, + runConfig.BaseName, + rt.WorkloadStatusError, + "Failed to restart after container exit", + ) + if statusErr != nil { + logger.Warnf("Failed to set workload %s status to error: %v", runConfig.BaseName, statusErr) + } + return fmt.Errorf("container restart failed after %d attempts", maxRetries) + } + + // Some other error - don't retry + logger.Errorf("Workload %s failed with error: %v", runConfig.BaseName, err) + if statusErr := d.statuses.SetWorkloadStatus(ctx, runConfig.BaseName, rt.WorkloadStatusError, err.Error()); statusErr != nil { + logger.Warnf("Failed to set workload %s status to error: %v", runConfig.BaseName, statusErr) + } + return err + } + + // Success - workload completed normally + return nil } - return err + + // Should not reach here, but just in case + return fmt.Errorf("unexpected end of retry loop for %s", runConfig.BaseName) } // validateSecretParameters validates the secret parameters for a workload. 
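The retry loop added to RunWorkload above restarts the workload after a doubling delay that is capped at one minute. The following standalone sketch is not part of the change; it only pulls out that backoff arithmetic so the "5s, 10s, 20s, 40s, 60s (capped)" comment can be checked in isolation (attempt 1 is the initial run and never sleeps, matching the diff).

package main

import (
	"fmt"
	"time"
)

// Reproduces the backoff used in RunWorkload's retry loop: the first restart
// waits 5s, each later restart doubles the previous delay, and the delay is
// capped at 60s, giving 5s, 10s, 20s, 40s, 60s, 60s, ...
func main() {
	delay := 5 * time.Second
	for attempt := 2; attempt <= 10; attempt++ {
		fmt.Printf("restart attempt %d waits %v\n", attempt, delay)
		delay *= 2
		if delay > 60*time.Second {
			delay = 60 * time.Second
		}
	}
}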
diff --git a/pkg/workloads/manager_test.go b/pkg/workloads/manager_test.go index a49603de5..53319fca2 100644 --- a/pkg/workloads/manager_test.go +++ b/pkg/workloads/manager_test.go @@ -1658,3 +1658,44 @@ func TestDefaultManager_updateSingleWorkload(t *testing.T) { }) } } + +// TestDefaultManager_RunWorkload_ContainerExitHandling tests container exit handling +func TestDefaultManager_RunWorkload_ContainerExitHandling(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockRuntime := runtimeMocks.NewMockRuntime(ctrl) + mockStatusMgr := statusMocks.NewMockStatusManager(ctrl) + mockConfigProvider := configMocks.NewMockProvider(ctrl) + + mockConfigProvider.EXPECT().GetConfig().Return(&config.Config{}).AnyTimes() + + // Expect status to be set to starting + mockStatusMgr.EXPECT(). + SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusStarting, ""). + Return(nil) + + // Expect status to be set to error on failure + mockStatusMgr.EXPECT(). + SetWorkloadStatus(gomock.Any(), "test-workload", runtime.WorkloadStatusError, gomock.Any()). + Return(nil).AnyTimes() + + manager := &DefaultManager{ + runtime: mockRuntime, + statuses: mockStatusMgr, + configProvider: mockConfigProvider, + } + + runConfig := &runner.RunConfig{ + ContainerName: "test-container", + BaseName: "test-workload", + Group: "default", + } + + // RunWorkload will fail because the runner can't actually run + // This tests that the status is properly set + err := manager.RunWorkload(context.Background(), runConfig) + assert.Error(t, err) +}
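Both transports now expose the same decision through ShouldRestart: restart after an unexpected exit, stay down when the monitor reported docker.ErrContainerRemoved. The sketch below is an illustration rather than code from this diff: errContainerRemoved and shouldRestart stand in for docker.ErrContainerRemoved and the transport methods, and the error wrapping stands in for ContainerError, which must wrap the sentinel so errors.Is can match it.

package main

import (
	"errors"
	"fmt"
)

// errContainerRemoved stands in for docker.ErrContainerRemoved.
var errContainerRemoved = errors.New("container removed")

// shouldRestart mirrors the logic added to HTTPTransport and StdioTransport:
// no exit error means a normal shutdown, a removal means stay down, and any
// other exit error asks the runner to restart the workload.
func shouldRestart(exitErr error) bool {
	if exitErr == nil {
		return false
	}
	return !errors.Is(exitErr, errContainerRemoved)
}

func main() {
	fmt.Println(shouldRestart(nil))                                            // false: clean shutdown
	fmt.Println(shouldRestart(fmt.Errorf("monitor: %w", errContainerRemoved))) // false: container was removed
	fmt.Println(shouldRestart(errors.New("container exited unexpectedly")))    // true: restart needed
}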