Skip to content

Commit

Permalink
ssh/tailssh: fix double race condition with non-pty command
Browse files Browse the repository at this point in the history
There are two race conditions in output handling.

The first race condition is due to a misuse of exec.Cmd.StdoutPipe.
The documentation explicitly forbids concurrent use of StdoutPipe
with exec.Cmd.Wait (see golang/go#60908) because Wait will
close both sides of the pipe once the process ends without
any guarantees that all data has been read from the pipe.
To fix this, we allocate the os.Pipes ourselves and
manage cleanup ourselves when the process has ended.

The second race condition is because sshSession.run waits
upon exec.Cmd to finish and then immediately proceeds to call ss.Exit,
which will close all output streams going to the SSH client.
This may interrupt any asynchronous io.Copy still copying data.
To fix this, we close the write-side of the os.Pipes after
the process has finished (and before calling ss.Exit) and
synchronously wait for the io.Copy routines to finish.

Fixes #7601

Signed-off-by: Joe Tsai <joetsai@digital-static.net>
Co-authored-by: Maisem Ali <maisem@tailscale.com>
  • Loading branch information
dsnet and maisem committed Jun 22, 2023
1 parent 88097b8 commit 58a3d75
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 44 deletions.
38 changes: 12 additions & 26 deletions ssh/tailssh/incubator.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"errors"
"flag"
"fmt"
"io"
"log"
"log/syslog"
"os"
Expand Down Expand Up @@ -476,10 +475,10 @@ func (ss *sshSession) launchProcess() error {
}
go resizeWindow(ptyDup /* arbitrary fd */, winCh)

ss.wrStdin = pty
ss.rdStdout = os.NewFile(uintptr(ptyDup), pty.Name())
ss.rdStderr = nil // not available for pty
ss.tty = tty
ss.stdin = pty
ss.stdout = os.NewFile(uintptr(ptyDup), pty.Name())
ss.stderr = nil // not available for pty

return nil
}
Expand Down Expand Up @@ -658,40 +657,27 @@ func (ss *sshSession) startWithPTY() (ptyFile, tty *os.File, err error) {

// startWithStdPipes starts cmd with os.Pipe for Stdin, Stdout and Stderr.
func (ss *sshSession) startWithStdPipes() (err error) {
var stdin io.WriteCloser
var stdout, stderr io.ReadCloser
defer func() {
if err != nil {
for _, c := range []io.Closer{stdin, stdout, stderr} {
if c != nil {
c.Close()
}
}
closeAll(ss.rdStdin, ss.rdStdout, ss.rdStderr, ss.wrStdin, ss.wrStdout, ss.wrStderr)
}
}()
cmd := ss.cmd
if cmd == nil {
if ss.cmd == nil {
return errors.New("nil cmd")
}
stdin, err = cmd.StdinPipe()
if err != nil {
if ss.rdStdin, ss.wrStdin, err = os.Pipe(); err != nil {
return err
}
stdout, err = cmd.StdoutPipe()
if err != nil {
ss.cmd.Stdin = ss.rdStdin
if ss.rdStdout, ss.wrStdout, err = os.Pipe(); err != nil {
return err
}
stderr, err = cmd.StderrPipe()
if err != nil {
return err
}
if err := cmd.Start(); err != nil {
ss.cmd.Stdout = ss.wrStdout
if ss.rdStderr, ss.wrStderr, err = os.Pipe(); err != nil {
return err
}
ss.stdin = stdin
ss.stdout = stdout
ss.stderr = stderr
return nil
ss.cmd.Stderr = ss.wrStderr
return ss.cmd.Start()
}

func envForUser(u *userMeta) []string {
Expand Down
53 changes: 35 additions & 18 deletions ssh/tailssh/tailssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -823,12 +823,11 @@ type sshSession struct {
agentListener net.Listener // non-nil if agent-forwarding requested+allowed

// initialized by launchProcess:
cmd *exec.Cmd
stdin io.WriteCloser
stdout io.ReadCloser
stderr io.Reader // nil for pty sessions
ptyReq *ssh.Pty // non-nil for pty sessions
tty *os.File // non-nil for pty sessions, must be closed after process exits
cmd *exec.Cmd
rdStdin, rdStdout, rdStderr io.ReadCloser // rdStderr is nil for pty sessions
wrStdin, wrStdout, wrStderr io.WriteCloser // wrStderr is nil for pty sessions
tty io.Closer // non-nil for pty sessions, must be closed after process exits
ptyReq *ssh.Pty // non-nil for pty sessions

// We use this sync.Once to ensure that we only terminate the process once,
// either it exits itself or is terminated
Expand Down Expand Up @@ -1107,21 +1106,22 @@ func (ss *sshSession) run() {

var processDone atomic.Bool
go func() {
defer ss.stdin.Close()
if _, err := io.Copy(rec.writer("i", ss.stdin), ss); err != nil {
defer ss.wrStdin.Close()
if _, err := io.Copy(rec.writer("i", ss.wrStdin), ss); err != nil {
logf("stdin copy: %v", err)
ss.cancelCtx(err)
}
}()
outputDone := make(chan struct{})
var openOutputStreams atomic.Int32
if ss.stderr != nil {
if ss.rdStderr != nil {
openOutputStreams.Store(2)
} else {
openOutputStreams.Store(1)
}
go func() {
defer ss.stdout.Close()
_, err := io.Copy(rec.writer("o", ss), ss.stdout)
defer ss.rdStdout.Close()
_, err := io.Copy(rec.writer("o", ss), ss.rdStdout)
if err != nil && !errors.Is(err, io.EOF) {
isErrBecauseProcessExited := processDone.Load() && errors.Is(err, syscall.EIO)
if !isErrBecauseProcessExited {
Expand All @@ -1131,32 +1131,41 @@ func (ss *sshSession) run() {
}
if openOutputStreams.Add(-1) == 0 {
ss.CloseWrite()
close(outputDone)
}
}()
// stderr is nil for ptys.
if ss.stderr != nil {
// rdStderr is nil for ptys.
if ss.rdStderr != nil {
go func() {
_, err := io.Copy(ss.Stderr(), ss.stderr)
defer ss.rdStderr.Close()
_, err := io.Copy(ss.Stderr(), ss.rdStderr)
if err != nil {
logf("stderr copy: %v", err)
}
if openOutputStreams.Add(-1) == 0 {
ss.CloseWrite()
close(outputDone)
}
}()
}

if ss.tty != nil {
// If running a tty session, close the tty when the session is done.
defer ss.tty.Close()
}
err = ss.cmd.Wait()
processDone.Store(true)

// This will either make the SSH Termination goroutine be a no-op,
// or itself will be a no-op because the process was killed by the
// aforementioned goroutine.
ss.exitOnce.Do(func() {})

// Close the process-side of all pipes to signal the asynchronous
// io.Copy routines reading/writing from the pipes to terminate.
// Block for the io.Copy to finish before calling ss.Exit below.
closeAll(ss.rdStdin, ss.wrStdout, ss.wrStderr, ss.tty)
select {
case <-outputDone:
case <-ss.ctx.Done():
}

if err == nil {
ss.logf("Session complete")
ss.Exit(0)
Expand Down Expand Up @@ -1894,3 +1903,11 @@ type SSHTerminationError interface {
error
SSHTerminationMessage() string
}

func closeAll(cs ...io.Closer) {
for _, c := range cs {
if c != nil {
c.Close()
}
}
}
19 changes: 19 additions & 0 deletions ssh/tailssh/tailssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"os/user"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -947,6 +948,24 @@ func TestSSH(t *testing.T) {
// "foo\n" and "bar\n", not "\n" and "bar\n".
})

t.Run("large_file", func(t *testing.T) {
const wantSize = 1e6
var bb bytes.Buffer
cmd := exec.Command("ssh",
"-F", "none",
"-p", fmt.Sprint(port),
"-o", "StrictHostKeyChecking=no",
"user@127.0.0.1",
"head", "-c", strconv.Itoa(wantSize), "/dev/zero")
cmd.Stdout = &bb
if err := cmd.Run(); err != nil {
t.Fatal(err)
}
if gotSize := bb.Len(); gotSize != wantSize {
t.Fatalf("got %d, want %d", gotSize, int(wantSize))
}
})

t.Run("stdin", func(t *testing.T) {
if cibuild.On() {
t.Skip("Skipping for now; see https://github.com/tailscale/tailscale/issues/4051")
Expand Down

0 comments on commit 58a3d75

Please sign in to comment.