Skip to content

Commit

Permalink
ssh/tailssh: fix double race condition with non-pty command (#8405)
Browse files Browse the repository at this point in the history
There are two race conditions in output handling.

The first race condition is due to a misuse of exec.Cmd.StdoutPipe.
The documentation explicitly forbids concurrent use of StdoutPipe
with exec.Cmd.Wait (see golang/go#60908) because Wait will
close both sides of the pipe once the process ends without
any guarantees that all data has been read from the pipe.
To fix this, we allocate the os.Pipes ourselves and
manage cleanup ourselves when the process has ended.

The second race condition is because sshSession.run waits
upon exec.Cmd to finish and then immediately proceeds to call ss.Exit,
which will close all output streams going to the SSH client.
This may interrupt any asynchronous io.Copy still copying data.
To fix this, we close the write-side of the os.Pipes after
the process has finished (and before calling ss.Exit) and
synchronously wait for the io.Copy routines to finish.

Fixes #7601

Signed-off-by: Joe Tsai <joetsai@digital-static.net>
Co-authored-by: Maisem Ali <maisem@tailscale.com>
  • Loading branch information
dsnet and maisem committed Jun 22, 2023
1 parent d4de60c commit 61886e0
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 44 deletions.
41 changes: 15 additions & 26 deletions ssh/tailssh/incubator.go
Original file line number Diff line number Diff line change
Expand Up @@ -476,10 +476,10 @@ func (ss *sshSession) launchProcess() error {
}
go resizeWindow(ptyDup /* arbitrary fd */, winCh)

ss.tty = tty
ss.stdin = pty
ss.stdout = os.NewFile(uintptr(ptyDup), pty.Name())
ss.stderr = nil // not available for pty
ss.wrStdin = pty
ss.rdStdout = os.NewFile(uintptr(ptyDup), pty.Name())
ss.rdStderr = nil // not available for pty
ss.childPipes = []io.Closer{tty}

return nil
}
Expand Down Expand Up @@ -658,40 +658,29 @@ func (ss *sshSession) startWithPTY() (ptyFile, tty *os.File, err error) {

// startWithStdPipes starts cmd with os.Pipe for Stdin, Stdout and Stderr.
func (ss *sshSession) startWithStdPipes() (err error) {
var stdin io.WriteCloser
var stdout, stderr io.ReadCloser
var rdStdin, wrStdout, wrStderr io.ReadWriteCloser
defer func() {
if err != nil {
for _, c := range []io.Closer{stdin, stdout, stderr} {
if c != nil {
c.Close()
}
}
closeAll(rdStdin, ss.wrStdin, ss.rdStdout, wrStdout, ss.rdStderr, wrStderr)
}
}()
cmd := ss.cmd
if cmd == nil {
if ss.cmd == nil {
return errors.New("nil cmd")
}
stdin, err = cmd.StdinPipe()
if err != nil {
if rdStdin, ss.wrStdin, err = os.Pipe(); err != nil {
return err
}
stdout, err = cmd.StdoutPipe()
if err != nil {
if ss.rdStdout, wrStdout, err = os.Pipe(); err != nil {
return err
}
stderr, err = cmd.StderrPipe()
if err != nil {
return err
}
if err := cmd.Start(); err != nil {
if ss.rdStderr, wrStderr, err = os.Pipe(); err != nil {
return err
}
ss.stdin = stdin
ss.stdout = stdout
ss.stderr = stderr
return nil
ss.cmd.Stdin = rdStdin
ss.cmd.Stdout = wrStdout
ss.cmd.Stderr = wrStderr
ss.childPipes = []io.Closer{rdStdin, wrStdout, wrStderr}
return ss.cmd.Start()
}

func envForUser(u *userMeta) []string {
Expand Down
58 changes: 40 additions & 18 deletions ssh/tailssh/tailssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -823,12 +823,16 @@ type sshSession struct {
agentListener net.Listener // non-nil if agent-forwarding requested+allowed

// initialized by launchProcess:
cmd *exec.Cmd
stdin io.WriteCloser
stdout io.ReadCloser
stderr io.Reader // nil for pty sessions
ptyReq *ssh.Pty // non-nil for pty sessions
tty *os.File // non-nil for pty sessions, must be closed after process exits
cmd *exec.Cmd
wrStdin io.WriteCloser
rdStdout io.ReadCloser
rdStderr io.ReadCloser // rdStderr is nil for pty sessions
ptyReq *ssh.Pty // non-nil for pty sessions

// childPipes is a list of pipes that need to be closed when the process exits.
// For pty sessions, this is the tty fd.
// For non-pty sessions, this is the stdin, stdout, stderr fds.
childPipes []io.Closer

// We use this sync.Once to ensure that we only terminate the process once,
// either it exits itself or is terminated
Expand Down Expand Up @@ -1107,21 +1111,22 @@ func (ss *sshSession) run() {

var processDone atomic.Bool
go func() {
defer ss.stdin.Close()
if _, err := io.Copy(rec.writer("i", ss.stdin), ss); err != nil {
defer ss.wrStdin.Close()
if _, err := io.Copy(rec.writer("i", ss.wrStdin), ss); err != nil {
logf("stdin copy: %v", err)
ss.cancelCtx(err)
}
}()
outputDone := make(chan struct{})
var openOutputStreams atomic.Int32
if ss.stderr != nil {
if ss.rdStderr != nil {
openOutputStreams.Store(2)
} else {
openOutputStreams.Store(1)
}
go func() {
defer ss.stdout.Close()
_, err := io.Copy(rec.writer("o", ss), ss.stdout)
defer ss.rdStdout.Close()
_, err := io.Copy(rec.writer("o", ss), ss.rdStdout)
if err != nil && !errors.Is(err, io.EOF) {
isErrBecauseProcessExited := processDone.Load() && errors.Is(err, syscall.EIO)
if !isErrBecauseProcessExited {
Expand All @@ -1131,32 +1136,41 @@ func (ss *sshSession) run() {
}
if openOutputStreams.Add(-1) == 0 {
ss.CloseWrite()
close(outputDone)
}
}()
// stderr is nil for ptys.
if ss.stderr != nil {
// rdStderr is nil for ptys.
if ss.rdStderr != nil {
go func() {
_, err := io.Copy(ss.Stderr(), ss.stderr)
defer ss.rdStderr.Close()
_, err := io.Copy(ss.Stderr(), ss.rdStderr)
if err != nil {
logf("stderr copy: %v", err)
}
if openOutputStreams.Add(-1) == 0 {
ss.CloseWrite()
close(outputDone)
}
}()
}

if ss.tty != nil {
// If running a tty session, close the tty when the session is done.
defer ss.tty.Close()
}
err = ss.cmd.Wait()
processDone.Store(true)

// This will either make the SSH Termination goroutine be a no-op,
// or itself will be a no-op because the process was killed by the
// aforementioned goroutine.
ss.exitOnce.Do(func() {})

// Close the process-side of all pipes to signal the asynchronous
// io.Copy routines reading/writing from the pipes to terminate.
// Block for the io.Copy to finish before calling ss.Exit below.
closeAll(ss.childPipes...)
select {
case <-outputDone:
case <-ss.ctx.Done():
}

if err == nil {
ss.logf("Session complete")
ss.Exit(0)
Expand Down Expand Up @@ -1894,3 +1908,11 @@ type SSHTerminationError interface {
error
SSHTerminationMessage() string
}

func closeAll(cs ...io.Closer) {
for _, c := range cs {
if c != nil {
c.Close()
}
}
}
14 changes: 14 additions & 0 deletions ssh/tailssh/tailssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"os/user"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -947,6 +948,19 @@ func TestSSH(t *testing.T) {
// "foo\n" and "bar\n", not "\n" and "bar\n".
})

t.Run("large_file", func(t *testing.T) {
const wantSize = 1e6
var outBuf bytes.Buffer
cmd := execSSH("head", "-c", strconv.Itoa(wantSize), "/dev/zero")
cmd.Stdout = &outBuf
if err := cmd.Run(); err != nil {
t.Fatal(err)
}
if gotSize := outBuf.Len(); gotSize != wantSize {
t.Fatalf("got %d, want %d", gotSize, int(wantSize))
}
})

t.Run("stdin", func(t *testing.T) {
if cibuild.On() {
t.Skip("Skipping for now; see https://github.com/tailscale/tailscale/issues/4051")
Expand Down

0 comments on commit 61886e0

Please sign in to comment.