Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support Ryuk for the compose module #2485

Merged
merged 19 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/ci-test-go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ jobs:
continue-on-error: ${{ !inputs.fail-fast }}
env:
TESTCONTAINERS_RYUK_DISABLED: "${{ inputs.ryuk-disabled }}"
RYUK_CONNECTION_TIMEOUT: "${{ inputs.project-directory == 'modules/compose' && '5m' || '60s' }}"
RYUK_RECONNECTION_TIMEOUT: "${{ inputs.project-directory == 'modules/compose' && '30s' || '10s' }}"
steps:
- name: Setup rootless Docker
if: ${{ inputs.rootless-docker }}
Expand Down
9 changes: 9 additions & 0 deletions docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ func (c *DockerContainer) SetProvider(provider *DockerProvider) {
c.provider = provider
}

// SetTerminationSignal sets the termination signal for the container
func (c *DockerContainer) SetTerminationSignal(signal chan bool) {
c.terminationSignal = signal
}

func (c *DockerContainer) GetContainerID() string {
return c.ID
}
Expand Down Expand Up @@ -846,6 +851,10 @@ func (n *DockerNetwork) Remove(ctx context.Context) error {
return n.provider.client.NetworkRemove(ctx, n.ID)
}

func (n *DockerNetwork) SetTerminationSignal(signal chan bool) {
n.terminationSignal = signal
}

// DockerProvider implements the ContainerProvider interface
type DockerProvider struct {
*DockerProviderOptions
Expand Down
23 changes: 23 additions & 0 deletions modules/compose/compose.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package compose
import (
"context"
"errors"
"fmt"
"path/filepath"
"runtime"
"strings"
Expand Down Expand Up @@ -121,6 +122,25 @@ func NewDockerComposeWith(opts ...ComposeStackOption) (*dockerCompose, error) {
return nil, err
}

reaperProvider, err := testcontainers.NewDockerProvider()
if err != nil {
return nil, fmt.Errorf("failed to create reaper provider for compose: %w", err)
}

tcConfig := reaperProvider.Config()

var composeReaper *testcontainers.Reaper
if !tcConfig.RyukDisabled {
// NewReaper is deprecated: we need to find a way to create the reaper for compose
// bypassing the deprecation.
r, err := testcontainers.NewReaper(context.Background(), testcontainers.SessionID(), reaperProvider, "")
if err != nil {
return nil, fmt.Errorf("failed to create reaper for compose: %w", err)
}

composeReaper = r
}

composeAPI := &dockerCompose{
name: composeOptions.Identifier,
configs: composeOptions.Paths,
Expand All @@ -129,6 +149,9 @@ func NewDockerComposeWith(opts ...ComposeStackOption) (*dockerCompose, error) {
dockerClient: dockerCli.Client(),
waitStrategies: make(map[string]wait.Strategy),
containers: make(map[string]*testcontainers.DockerContainer),
networks: make(map[string]*testcontainers.DockerNetwork),
sessionID: testcontainers.SessionID(),
reaper: composeReaper,
}

return composeAPI, nil
Expand Down
127 changes: 124 additions & 3 deletions modules/compose/compose_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/compose-spec/compose-go/v2/types"
"github.com/docker/cli/cli/command"
"github.com/docker/compose/v2/pkg/api"
dockertypes "github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/client"
Expand Down Expand Up @@ -134,6 +135,10 @@ type dockerCompose struct {
// used in ServiceContainer(...) function to avoid calls to the Docker API
containers map[string]*testcontainers.DockerContainer

// cache for containers that are part of the stack
// used in ServiceContainer(...) function to avoid calls to the Docker API
mdelapenya marked this conversation as resolved.
Show resolved Hide resolved
networks map[string]*testcontainers.DockerNetwork

// docker/compose API service instance used to control the compose stack
composeService api.Service

Expand All @@ -147,6 +152,12 @@ type dockerCompose struct {
// compiled compose project
// can be nil if the stack wasn't started yet
project *types.Project

// sessionID is used to identify the reaper session
sessionID string

// reaper is used to clean up containers after the stack is stopped
reaper *testcontainers.Reaper
}

func (d *dockerCompose) ServiceContainer(ctx context.Context, svcName string) (*testcontainers.DockerContainer, error) {
Expand Down Expand Up @@ -235,26 +246,89 @@ func (d *dockerCompose) Up(ctx context.Context, opts ...StackUpOption) error {
return err
}

err = d.lookupNetworks(ctx)
if err != nil {
return err
}

if d.reaper != nil {
for _, n := range d.networks {
termSignal, err := d.reaper.Connect()
if err != nil {
return fmt.Errorf("failed to connect to reaper: %w", err)
}
n.SetTerminationSignal(termSignal)

// Cleanup on error, otherwise set termSignal to nil before successful return.
defer func() {
if termSignal != nil {
termSignal <- true
}
}()
}
}

errGrpContainers, errGrpCtx := errgroup.WithContext(ctx)

for _, srv := range d.project.Services {
// we are going to connect each container to the reaper
srv := srv
errGrpContainers.Go(func() error {
dc, err := d.lookupContainer(errGrpCtx, srv.Name)
if err != nil {
return err
}

if d.reaper != nil {
termSignal, err := d.reaper.Connect()
if err != nil {
return fmt.Errorf("failed to connect to reaper: %w", err)
}
dc.SetTerminationSignal(termSignal)

// Cleanup on error, otherwise set termSignal to nil before successful return.
defer func() {
if termSignal != nil {
termSignal <- true
}
}()
}

d.containers[srv.Name] = dc

return nil
})
}

// wait here for the containers lookup to finish
if err := errGrpContainers.Wait(); err != nil {
return err
}

if len(d.waitStrategies) == 0 {
return nil
}

errGrp, errGrpCtx := errgroup.WithContext(ctx)
errGrpWait, errGrpCtx := errgroup.WithContext(ctx)

for svc, strategy := range d.waitStrategies { // pinning the variables
svc := svc
strategy := strategy

errGrp.Go(func() error {
errGrpWait.Go(func() error {
target, err := d.lookupContainer(errGrpCtx, svc)
if err != nil {
return err
}

// cache all the containers on compose.up
d.containers[svc] = target

return strategy.WaitUntilReady(errGrpCtx, target)
})
}

return errGrp.Wait()
return errGrpWait.Wait()
}

func (d *dockerCompose) WaitForService(s string, strategy wait.Strategy) ComposeStack {
Expand Down Expand Up @@ -327,6 +401,34 @@ func (d *dockerCompose) lookupContainer(ctx context.Context, svcName string) (*t
return container, nil
}

func (d *dockerCompose) lookupNetworks(ctx context.Context) error {
d.containersLock.Lock()
defer d.containersLock.Unlock()

listOptions := dockertypes.NetworkListOptions{
Filters: filters.NewArgs(
filters.Arg("label", fmt.Sprintf("%s=%s", api.ProjectLabel, d.name)),
),
}

networks, err := d.dockerClient.NetworkList(ctx, listOptions)
if err != nil {
return err
}

for _, n := range networks {
dn := &testcontainers.DockerNetwork{
ID: n.ID,
Name: n.Name,
Driver: n.Driver,
}

d.networks[n.ID] = dn
}

return nil
}

func (d *dockerCompose) compileProject(ctx context.Context) (*types.Project, error) {
const nameAndDefaultConfigPath = 2
projectOptions := make([]cli.ProjectOptionsFn, len(d.projectOptions), len(d.projectOptions)+nameAndDefaultConfigPath)
Expand All @@ -353,6 +455,11 @@ func (d *dockerCompose) compileProject(ctx context.Context) (*types.Project, err
api.ConfigFilesLabel: strings.Join(proj.ComposeFiles, ","),
api.OneoffLabel: "False", // default, will be overridden by `run` command
}

for k, label := range testcontainers.GenericLabels() {
s.CustomLabels[k] = label
}

for i, envFile := range compiledOptions.EnvFiles {
// add a label for each env file, indexed by its position
s.CustomLabels[fmt.Sprintf("%s.%d", api.EnvironmentFileLabel, i)] = envFile
Expand All @@ -361,6 +468,20 @@ func (d *dockerCompose) compileProject(ctx context.Context) (*types.Project, err
proj.Services[i] = s
}

for key, n := range proj.Networks {
n.Labels = map[string]string{
api.ProjectLabel: proj.Name,
api.NetworkLabel: n.Name,
api.VersionLabel: api.ComposeVersion,
}

for k, label := range testcontainers.GenericLabels() {
n.Labels[k] = label
}

proj.Networks[key] = n
}

return proj, nil
}

Expand Down