Skip to content

Change the Stage interface to make stdin/stdout handling more flexible #21

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
106 changes: 82 additions & 24 deletions pipe/command.go
Original file line number Diff line number Diff line change
@@ -18,9 +18,13 @@ import (
// commandStage is a pipeline `Stage` based on running an external
// command and piping the data through its stdin and stdout.
type commandStage struct {
name string
stdin io.Closer
cmd *exec.Cmd
name string
cmd *exec.Cmd

// lateClosers is a list of things that have to be closed once the
// command has finished.
lateClosers []io.Closer

done chan struct{}
wg errgroup.Group
stderr bytes.Buffer
@@ -30,6 +34,10 @@ type commandStage struct {
ctxErr atomic.Value
}

var (
_ Stage = (*commandStage)(nil)
)

// Command returns a pipeline `Stage` based on the specified external
// `command`, run with the given command-line `args`. Its stdin and
// stdout are handled as usual, and its stderr is collected and
@@ -59,33 +67,80 @@ func (s *commandStage) Name() string {
return s.name
}

func (s *commandStage) Preferences() StagePreferences {
prefs := StagePreferences{
StdinPreference: IOPreferenceFile,
StdoutPreference: IOPreferenceFile,
}
if s.cmd.Stdin != nil {
prefs.StdinPreference = IOPreferenceNil
}
if s.cmd.Stdout != nil {
prefs.StdoutPreference = IOPreferenceNil
}

return prefs
}

func (s *commandStage) Start(
ctx context.Context, env Env, stdin io.ReadCloser,
) (io.ReadCloser, error) {
ctx context.Context, env Env, stdin io.ReadCloser, stdout io.WriteCloser,
) error {
if s.cmd.Dir == "" {
s.cmd.Dir = env.Dir
}

s.setupEnv(ctx, env)

// Things that have to be closed as soon as the command has
// started:
var earlyClosers []io.Closer

// See the type command for `Stage` and the long comment in
// `Pipeline.WithStdin()` for the explanation of this unwrapping
// and closing behavior.

if stdin != nil {
// See the long comment in `Pipeline.Start()` for the
// explanation of this special case.
switch stdin := stdin.(type) {
case nopCloser:
case readerNopCloser:
// In this case, we shouldn't close it. But unwrap it for
// efficiency's sake:
s.cmd.Stdin = stdin.Reader
case nopCloserWriterTo:
case readerWriterToNopCloser:
// In this case, we shouldn't close it. But unwrap it for
// efficiency's sake:
s.cmd.Stdin = stdin.Reader
case *os.File:
// In this case, we can close stdin as soon as the command
// has started:
s.cmd.Stdin = stdin
earlyClosers = append(earlyClosers, stdin)
default:
// In this case, we need to close `stdin`, but we should
// only do so after the command has finished:
s.cmd.Stdin = stdin
s.lateClosers = append(s.lateClosers, stdin)
}
// Also keep a copy so that we can close it when the command exits:
s.stdin = stdin
}

stdout, err := s.cmd.StdoutPipe()
if err != nil {
return nil, err
if stdout != nil {
// See the long comment in `Pipeline.Start()` for the
// explanation of this special case.
switch stdout := stdout.(type) {
case writerNopCloser:
// In this case, we shouldn't close it. But unwrap it for
// efficiency's sake:
s.cmd.Stdout = stdout.Writer
case *os.File:
// In this case, we can close stdout as soon as the command
// has started:
s.cmd.Stdout = stdout
earlyClosers = append(earlyClosers, stdout)
default:
// In this case, we need to close `stdout`, but we should
// only do so after the command has finished:
s.cmd.Stdout = stdout
s.lateClosers = append(s.lateClosers, stdout)
}
}

// If the caller hasn't arranged otherwise, read the command's
@@ -97,7 +152,7 @@ func (s *commandStage) Start(
// can be sure.
p, err := s.cmd.StderrPipe()
if err != nil {
return nil, err
return err
}
s.wg.Go(func() error {
_, err := io.Copy(&s.stderr, p)
@@ -114,7 +169,11 @@ func (s *commandStage) Start(
s.runInOwnProcessGroup()

if err := s.cmd.Start(); err != nil {
return nil, err
return err
}

for _, closer := range earlyClosers {
_ = closer.Close()
}

// Arrange for the process to be killed (gently) if the context
@@ -128,7 +187,7 @@ func (s *commandStage) Start(
}
}()

return stdout, nil
return nil
}

// setupEnv sets or modifies the environment that will be passed to
@@ -217,19 +276,18 @@ func (s *commandStage) Wait() error {

// Make sure that any stderr is copied before `s.cmd.Wait()`
// closes the read end of the pipe:
wErr := s.wg.Wait()
wgErr := s.wg.Wait()

err := s.cmd.Wait()
err = s.filterCmdError(err)

if err == nil && wErr != nil {
err = wErr
if err == nil && wgErr != nil {
err = wgErr
}

if s.stdin != nil {
cErr := s.stdin.Close()
if cErr != nil && err == nil {
return cErr
for _, closer := range s.lateClosers {
if closeErr := closer.Close(); closeErr != nil && err == nil {
err = closeErr
}
}

3 changes: 2 additions & 1 deletion pipe/command_test.go
Original file line number Diff line number Diff line change
@@ -79,7 +79,8 @@ func TestCopyEnvWithOverride(t *testing.T) {
ex := ex
t.Run(ex.label, func(t *testing.T) {
assert.ElementsMatch(t, ex.expectedResult,
copyEnvWithOverrides(ex.env, ex.overrides))
copyEnvWithOverrides(ex.env, ex.overrides),
)
})
}
}
4 changes: 4 additions & 0 deletions pipe/export_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package pipe

// This file exports a functions to be used only for testing.
var UnwrapNopCloser = unwrapNopCloser
41 changes: 34 additions & 7 deletions pipe/function.go
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@ import (
// StageFunc is a function that can be used to power a `goStage`. It
// should read its input from `stdin` and write its output to
// `stdout`. `stdin` and `stdout` will be closed automatically (if
// necessary) once the function returns.
// non-nil) once the function returns.
//
// Neither `stdin` nor `stdout` are necessarily buffered. If the
// `StageFunc` requires buffering, it needs to arrange that itself.
@@ -38,26 +38,53 @@ type goStage struct {
err error
}

var (
_ Stage = (*goStage)(nil)
)

func (s *goStage) Name() string {
return s.name
}

func (s *goStage) Start(ctx context.Context, env Env, stdin io.ReadCloser) (io.ReadCloser, error) {
r, w := io.Pipe()
func (s *goStage) Preferences() StagePreferences {
return StagePreferences{
StdinPreference: IOPreferenceUndefined,
StdoutPreference: IOPreferenceUndefined,
}
}

func (s *goStage) Start(
ctx context.Context, env Env, stdin io.ReadCloser, stdout io.WriteCloser,
) error {
var r io.Reader = stdin
if stdin, ok := stdin.(readerNopCloser); ok {
r = stdin.Reader
}

var w io.Writer = stdout
if stdout, ok := stdout.(writerNopCloser); ok {
w = stdout.Writer
}

go func() {
s.err = s.f(ctx, env, stdin, w)
if err := w.Close(); err != nil && s.err == nil {
s.err = fmt.Errorf("error closing output pipe for stage %q: %w", s.Name(), err)
s.err = s.f(ctx, env, r, w)

if stdout != nil {
if err := stdout.Close(); err != nil && s.err == nil {
s.err = fmt.Errorf("error closing stdout for stage %q: %w", s.Name(), err)
}
}

if stdin != nil {
if err := stdin.Close(); err != nil && s.err == nil {
s.err = fmt.Errorf("error closing stdin for stage %q: %w", s.Name(), err)
}
}

close(s.done)
}()

return r, nil
return nil
}

func (s *goStage) Wait() error {
62 changes: 0 additions & 62 deletions pipe/iocopier.go

This file was deleted.

30 changes: 20 additions & 10 deletions pipe/memorylimit.go
Original file line number Diff line number Diff line change
@@ -11,12 +11,12 @@ import (

const memoryPollInterval = time.Second

// ErrMemoryLimitExceeded is the error that will be used to kill a process, if
// necessary, from MemoryLimit.
// ErrMemoryLimitExceeded is the error that will be used to kill a
// process, if necessary, from MemoryLimit.
var ErrMemoryLimitExceeded = errors.New("memory limit exceeded")

// LimitableStage is the superset of Stage that must be implemented by stages
// passed to MemoryLimit and MemoryObserver.
// LimitableStage is the superset of `Stage` that must be implemented
// by stages passed to MemoryLimit and MemoryObserver.
type LimitableStage interface {
Stage

@@ -175,12 +175,24 @@ func (m *memoryWatchStage) Name() string {
return m.stage.Name() + m.nameSuffix
}

func (m *memoryWatchStage) Start(ctx context.Context, env Env, stdin io.ReadCloser) (io.ReadCloser, error) {
io, err := m.stage.Start(ctx, env, stdin)
if err != nil {
return nil, err
func (m *memoryWatchStage) Preferences() StagePreferences {
return m.stage.Preferences()
}

func (m *memoryWatchStage) Start(
ctx context.Context, env Env, stdin io.ReadCloser, stdout io.WriteCloser,
) error {
if err := m.stage.Start(ctx, env, stdin, stdout); err != nil {
return err
}

m.monitor(ctx)

return nil
}

// monitor starts up a goroutine that monitors the memory of `m`.
func (m *memoryWatchStage) monitor(ctx context.Context) {
ctx, cancel := context.WithCancel(ctx)
m.cancel = cancel
m.wg.Add(1)
@@ -189,8 +201,6 @@ func (m *memoryWatchStage) Start(ctx context.Context, env Env, stdin io.ReadClos
m.watch(ctx, m.stage)
m.wg.Done()
}()

return io, nil
}

func (m *memoryWatchStage) Wait() error {
Loading
Oops, something went wrong.