diff --git a/docker.go b/docker.go
index cb799dc05a..e790c1b53e 100644
--- a/docker.go
+++ b/docker.go
@@ -62,16 +62,21 @@ type DockerContainer struct {
 	isRunning     bool
 	imageWasBuilt bool
 	// keepBuiltImage makes Terminate not remove the image if imageWasBuilt.
-	keepBuiltImage      bool
-	provider            *DockerProvider
-	sessionID           string
-	terminationSignal   chan bool
-	consumers           []LogConsumer
-	raw                 *types.ContainerJSON
-	stopLogProductionCh chan bool
-	logProductionDone   chan bool
-	logProductionError  chan error
-	logProductionMutex  sync.Mutex
+	keepBuiltImage     bool
+	provider           *DockerProvider
+	sessionID          string
+	terminationSignal  chan bool
+	consumers          []LogConsumer
+	raw                *types.ContainerJSON
+	logProductionError chan error
+
+	// logProductionMutex protects logProductionStop channel so it can be started again.
+	// TODO: Remove locking once StartLogProducer has been removed and hence logging can
+	// only be started once.
+	logProductionMutex     sync.Mutex
+	logProductionWaitGroup sync.WaitGroup
+	logProductionStop      chan struct{}
+	logProductionTimeout   *time.Duration
 
 	logger         Logging
 	lifecycleHooks []ContainerLifecycleHooks
@@ -652,9 +657,12 @@ func (c *DockerContainer) startLogProduction(ctx context.Context, opts ...LogPro
 		c.logProductionMutex.Lock()
 		defer c.logProductionMutex.Unlock()
 
-		if c.stopLogProductionCh != nil {
+		if c.logProductionStop != nil {
 			return errors.New("log production already started")
 		}
+
+		c.logProductionStop = make(chan struct{})
+		c.logProductionWaitGroup.Add(1)
 	}
 
 	for _, opt := range opts {
@@ -676,21 +684,12 @@ func (c *DockerContainer) startLogProduction(ctx context.Context, opts ...LogPro
 		c.logProductionTimeout = &maxLogProductionTimeout
 	}
 
-	c.stopLogProductionCh = make(chan bool)
-	c.logProductionDone = make(chan bool)
 	c.logProductionError = make(chan error, 1)
 
-	go func(stop <-chan bool, done chan<- bool, errorCh chan error) {
-		// signal the log production is done once go routine exits, this prevents race conditions around start/stop
-		// set c.stopLogProductionCh to nil so that it can be started again
+	go func() {
 		defer func() {
-			defer c.logProductionMutex.Unlock()
-			close(done)
-			close(errorCh)
-			{
-				c.logProductionMutex.Lock()
-				c.stopLogProductionCh = nil
-			}
+			close(c.logProductionError)
+			c.logProductionWaitGroup.Done()
 		}()
 
 		since := ""
@@ -708,15 +707,15 @@ func (c *DockerContainer) startLogProduction(ctx context.Context, opts ...LogPro
 
 		r, err := c.provider.client.ContainerLogs(ctx, c.GetContainerID(), options)
 		if err != nil {
-			errorCh <- err
+			c.logProductionError <- err
 			return
 		}
 		defer c.provider.Close()
 
 		for {
 			select {
-			case <-stop:
-				errorCh <- r.Close()
+			case <-c.logProductionStop:
+				c.logProductionError <- r.Close()
 				return
 			default:
 				h := make([]byte, 8)
@@ -772,7 +771,7 @@ func (c *DockerContainer) startLogProduction(ctx context.Context, opts ...LogPro
 				}
 			}
 		}
-	}(c.stopLogProductionCh, c.logProductionDone, c.logProductionError)
+	}()
 
 	return nil
 }
@@ -782,17 +781,18 @@ func (c *DockerContainer) StopLogProducer() error {
 	return c.stopLogProduction()
 }
 
-// StopLogProducer will stop the concurrent process that is reading logs
+// stopLogProduction will stop the concurrent process that is reading logs
 // and sending them to each added LogConsumer
 func (c *DockerContainer) stopLogProduction() error {
 	c.logProductionMutex.Lock()
 	defer c.logProductionMutex.Unlock()
-	if c.stopLogProductionCh != nil {
-		c.stopLogProductionCh <- true
-		// block until the log production is actually done in order to avoid strange races
-		<-c.logProductionDone
-		c.stopLogProductionCh = nil
-		c.logProductionDone = nil
+	if c.logProductionStop != nil {
+		close(c.logProductionStop)
+		c.logProductionWaitGroup.Wait()
+		// Set c.logProductionStop to nil so that it can be started again.
+		// TODO: Remove this once StartLogProducer has been removed and hence logging can
+		// only be started once.
+		c.logProductionStop = nil
 		return <-c.logProductionError
 	}
 	return nil
@@ -1113,17 +1113,16 @@ func (p *DockerProvider) CreateContainer(ctx context.Context, req ContainerReque
 	}
 
 	c := &DockerContainer{
-		ID:                  resp.ID,
-		WaitingFor:          req.WaitingFor,
-		Image:               imageName,
-		imageWasBuilt:       req.ShouldBuildImage(),
-		keepBuiltImage:      req.ShouldKeepBuiltImage(),
-		sessionID:           core.SessionID(),
-		provider:            p,
-		terminationSignal:   termSignal,
-		stopLogProductionCh: nil,
-		logger:              p.Logger,
-		lifecycleHooks:      req.LifecycleHooks,
+		ID:                resp.ID,
+		WaitingFor:        req.WaitingFor,
+		Image:             imageName,
+		imageWasBuilt:     req.ShouldBuildImage(),
+		keepBuiltImage:    req.ShouldKeepBuiltImage(),
+		sessionID:         core.SessionID(),
+		provider:          p,
+		terminationSignal: termSignal,
+		logger:            p.Logger,
+		lifecycleHooks:    req.LifecycleHooks,
 	}
 
 	err = c.createdHook(ctx)
@@ -1216,15 +1215,14 @@ func (p *DockerProvider) ReuseOrCreateContainer(ctx context.Context, req Contain
 	}
 
 	dc := &DockerContainer{
-		ID:                  c.ID,
-		WaitingFor:          req.WaitingFor,
-		Image:               c.Image,
-		sessionID:           sessionID,
-		provider:            p,
-		terminationSignal:   termSignal,
-		stopLogProductionCh: nil,
-		logger:              p.Logger,
-		lifecycleHooks:      []ContainerLifecycleHooks{combineContainerHooks(defaultHooks, req.LifecycleHooks)},
+		ID:                c.ID,
+		WaitingFor:        req.WaitingFor,
+		Image:             c.Image,
+		sessionID:         sessionID,
+		provider:          p,
+		terminationSignal: termSignal,
+		logger:            p.Logger,
+		lifecycleHooks:    []ContainerLifecycleHooks{combineContainerHooks(defaultHooks, req.LifecycleHooks)},
 	}
 
 	err = dc.startedHook(ctx)
@@ -1526,7 +1524,6 @@ func containerFromDockerResponse(ctx context.Context, response types.Container)
 
 	container.sessionID = core.SessionID()
 	container.consumers = []LogConsumer{}
-	container.stopLogProductionCh = nil
 	container.isRunning = response.State == "running"
 
 	// the termination signal should be obtained from the reaper
diff --git a/logconsumer_test.go b/logconsumer_test.go
index 9c9b25fa09..192a00f954 100644
--- a/logconsumer_test.go
+++ b/logconsumer_test.go
@@ -9,6 +9,7 @@ import (
 	"net/http"
 	"os"
 	"strings"
+	"sync"
 	"testing"
 	"time"
 
@@ -23,8 +24,9 @@ import (
 const lastMessage = "DONE"
 
 type TestLogConsumer struct {
-	Msgs []string
-	Done chan bool
+	mtx  sync.Mutex
+	msgs []string
+	Done chan struct{}
 
 	// Accepted provides a blocking way of ensuring the logs messages have been consumed.
 	// This allows for proper synchronization during Test_StartStop in particular.
@@ -35,11 +37,21 @@ type TestLogConsumer struct {
 func (g *TestLogConsumer) Accept(l Log) {
 	s := string(l.Content)
 	if s == fmt.Sprintf("echo %s\n", lastMessage) {
-		g.Done <- true
+		close(g.Done)
 		return
 	}
 	g.Accepted <- s
-	g.Msgs = append(g.Msgs, s)
+
+	g.mtx.Lock()
+	defer g.mtx.Unlock()
+	g.msgs = append(g.msgs, s)
+}
+
+func (g *TestLogConsumer) Msgs() []string {
+	g.mtx.Lock()
+	defer g.mtx.Unlock()
+
+	return g.msgs
 }
 
 // devNullAcceptorChan returns string channel that essentially sends all strings to dev null
@@ -57,8 +69,8 @@ func Test_LogConsumerGetsCalled(t *testing.T) {
 	ctx := context.Background()
 
 	g := TestLogConsumer{
-		Msgs:     []string{},
-		Done:     make(chan bool),
+		msgs:     []string{},
+		Done:     make(chan struct{}),
 		Accepted: devNullAcceptorChan(),
 	}
 
@@ -100,7 +112,7 @@ func Test_LogConsumerGetsCalled(t *testing.T) {
 		t.Fatal("never received final log message")
 	}
 
-	assert.Equal(t, []string{"ready\n", "echo hello\n", "echo there\n"}, g.Msgs)
+	assert.Equal(t, []string{"ready\n", "echo hello\n", "echo there\n"}, g.Msgs())
 
 	terminateContainerOnEnd(t, ctx, c)
 }
@@ -172,13 +184,13 @@ func Test_MultipleLogConsumers(t *testing.T) {
 	ctx := context.Background()
 
 	first := TestLogConsumer{
-		Msgs:     []string{},
-		Done:     make(chan bool),
+		msgs:     []string{},
+		Done:     make(chan struct{}),
 		Accepted: devNullAcceptorChan(),
 	}
 	second := TestLogConsumer{
-		Msgs:     []string{},
-		Done:     make(chan bool),
+		msgs:     []string{},
+		Done:     make(chan struct{}),
 		Accepted: devNullAcceptorChan(),
 	}
 
@@ -214,13 +226,13 @@ func Test_MultipleLogConsumers(t *testing.T) {
 	<-first.Done
 	<-second.Done
 
-	assert.Equal(t, []string{"ready\n", "echo mlem\n"}, first.Msgs)
-	assert.Equal(t, []string{"ready\n", "echo mlem\n"}, second.Msgs)
+	assert.Equal(t, []string{"ready\n", "echo mlem\n"}, first.Msgs())
+	assert.Equal(t, []string{"ready\n", "echo mlem\n"}, second.Msgs())
 
 	require.NoError(t, c.Terminate(ctx))
 }
 
 func TestContainerLogWithErrClosed(t *testing.T) {
-	if os.Getenv("XDG_RUNTIME_DIR") != "" {
+	if os.Getenv("GITHUB_RUN_ID") != "" {
 		t.Skip("Skipping as flaky on GitHub Actions, Please see https://github.com/testcontainers/testcontainers-go/issues/1924")
 	}
@@ -290,8 +302,8 @@ func TestContainerLogWithErrClosed(t *testing.T) {
 	}
 
 	consumer := TestLogConsumer{
-		Msgs:     []string{},
-		Done:     make(chan bool),
+		msgs:     []string{},
+		Done:     make(chan struct{}),
 		Accepted: devNullAcceptorChan(),
 	}
 
@@ -317,7 +329,7 @@ func TestContainerLogWithErrClosed(t *testing.T) {
 
 	// Gather the initial container logs
 	time.Sleep(time.Second * 1)
-	existingLogs := len(consumer.Msgs)
+	existingLogs := len(consumer.Msgs())
 
 	hitNginx := func() {
 		i, _, err := dind.Exec(ctx, []string{"wget", "--spider", "localhost:" + port.Port()})
@@ -328,10 +340,11 @@ func TestContainerLogWithErrClosed(t *testing.T) {
 
 	hitNginx()
 	time.Sleep(time.Second * 1)
-	if len(consumer.Msgs)-existingLogs != 1 {
-		t.Fatalf("logConsumer should have 1 new log message, instead has: %v", consumer.Msgs[existingLogs:])
+	msgs := consumer.Msgs()
+	if len(msgs)-existingLogs != 1 {
+		t.Fatalf("logConsumer should have 1 new log message, instead has: %v", msgs[existingLogs:])
 	}
-	existingLogs = len(consumer.Msgs)
+	existingLogs = len(consumer.Msgs())
 
 	iptableArgs := []string{
 		"INPUT", "-p", "tcp", "--dport", "2375",
@@ -351,10 +364,11 @@ func TestContainerLogWithErrClosed(t *testing.T) {
 	hitNginx()
 	hitNginx()
 	time.Sleep(time.Second * 1)
-	if len(consumer.Msgs)-existingLogs != 2 {
+	msgs = consumer.Msgs()
+	if len(msgs)-existingLogs != 2 {
 		t.Fatalf(
 			"LogConsumer should have 2 new log messages after detecting closed connection and"+
-				" re-requesting logs. Instead has:\n%s", consumer.Msgs[existingLogs:],
+				" re-requesting logs. Instead has:\n%s", msgs[existingLogs:],
 		)
 	}
 }
@@ -389,8 +403,8 @@ func TestContainerLogsShouldBeWithoutStreamHeader(t *testing.T) {
 func TestContainerLogsEnableAtStart(t *testing.T) {
 	ctx := context.Background()
 	g := TestLogConsumer{
-		Msgs:     []string{},
-		Done:     make(chan bool),
+		msgs:     []string{},
+		Done:     make(chan struct{}),
 		Accepted: devNullAcceptorChan(),
 	}
 
@@ -434,7 +448,7 @@ func TestContainerLogsEnableAtStart(t *testing.T) {
 	case <-time.After(10 * time.Second):
 		t.Fatal("never received final log message")
 	}
-	assert.Equal(t, []string{"ready\n", "echo hello\n", "echo there\n"}, g.Msgs)
+	assert.Equal(t, []string{"ready\n", "echo hello\n", "echo there\n"}, g.Msgs())
 
 	terminateContainerOnEnd(t, ctx, c)
 }
@@ -443,8 +457,8 @@ func Test_StartLogProductionStillStartsWithTooLowTimeout(t *testing.T) {
 	ctx := context.Background()
 
 	g := TestLogConsumer{
-		Msgs:     []string{},
-		Done:     make(chan bool),
+		msgs:     []string{},
+		Done:     make(chan struct{}),
 		Accepted: devNullAcceptorChan(),
 	}
 
@@ -475,8 +489,8 @@ func Test_StartLogProductionStillStartsWithTooHighTimeout(t *testing.T) {
 	ctx := context.Background()
 
 	g := TestLogConsumer{
-		Msgs:     []string{},
-		Done:     make(chan bool),
+		msgs:     []string{},
+		Done:     make(chan struct{}),
 		Accepted: devNullAcceptorChan(),
 	}
 
@@ -518,10 +532,11 @@ func Test_MultiContainerLogConsumer_CancelledContext(t *testing.T) {
 
 	// Context with cancellation functionality for simulating user interruption
 	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel() // Ensure it gets called.
 
 	first := TestLogConsumer{
-		Msgs:     []string{},
-		Done:     make(chan bool),
+		msgs:     []string{},
+		Done:     make(chan struct{}),
 		Accepted: devNullAcceptorChan(),
 	}
 
@@ -555,8 +570,8 @@ func Test_MultiContainerLogConsumer_CancelledContext(t *testing.T) {
 	require.NoError(t, err)
 
 	second := TestLogConsumer{
-		Msgs:     []string{},
-		Done:     make(chan bool),
+		msgs:     []string{},
+		Done:     make(chan struct{}),
 		Accepted: devNullAcceptorChan(),
 	}
 
@@ -592,7 +607,7 @@ func Test_MultiContainerLogConsumer_CancelledContext(t *testing.T) {
 	// Handling the termination of the containers
 	defer func() {
 		shutdownCtx, shutdownCancel := context.WithTimeout(
-			context.Background(), 60*time.Second,
+			context.Background(), 10*time.Second,
 		)
 		defer shutdownCancel()
 		_ = c.Terminate(shutdownCtx)
@@ -604,8 +619,8 @@ func Test_MultiContainerLogConsumer_CancelledContext(t *testing.T) {
 
 	// We check log size due to context cancellation causing
 	// varying message counts, leading to test failure.
-	assert.GreaterOrEqual(t, len(first.Msgs), 2)
-	assert.GreaterOrEqual(t, len(second.Msgs), 2)
+	assert.GreaterOrEqual(t, len(first.Msgs()), 2)
+	assert.GreaterOrEqual(t, len(second.Msgs()), 2)
 
 	// Restore stderr
 	w.Close()