diff --git a/command/ssh/proxycommand.go b/command/ssh/proxycommand.go
index fe35cf187..13bc00ffc 100644
--- a/command/ssh/proxycommand.go
+++ b/command/ssh/proxycommand.go
@@ -6,7 +6,6 @@ import (
 	"net"
 	"os"
 	"strings"
-	"sync"
 	"time"
 
 	"github.com/pkg/errors"
@@ -228,6 +227,14 @@ func getBastion(ctx *cli.Context, user, host string) (*api.SSHBastionResponse, e
 }
 
 func proxyDirect(host, port string) error {
+	return proxyDirectWithIO(host, port, os.Stdin, os.Stdout)
+}
+
+// proxyDirectWithIO opens a TCP connection to host:port and relays in to the
+// connection and the connection to out. It returns nil once everything the
+// server sent has been copied to out, or a wrapped error when setting up the
+// connection fails. in is read until EOF but is never closed here.
+func proxyDirectWithIO(host, port string, in io.ReadCloser, out io.Writer) error {
 	address := net.JoinHostPort(host, port)
 	addr, err := net.ResolveTCPAddr("tcp", address)
 	if err != nil {
@@ -239,21 +246,21 @@ func proxyDirect(host, port string) error {
 		return errors.Wrapf(err, "error connecting to %s", address)
 	}
+	defer conn.Close() //nolint:errcheck
 
-	var wg sync.WaitGroup
-	wg.Add(1)
+	// Client→server: forward in, then half-close the connection so the server
+	// sees EOF while its remaining output can still be read.
 	go func() {
-		io.Copy(conn, os.Stdin)
-		conn.CloseWrite()
-		wg.Done()
+		io.Copy(conn, in) //nolint:errcheck
+		conn.CloseWrite() //nolint:errcheck
 	}()
-	wg.Add(1)
-	go func() {
-		io.Copy(os.Stdout, conn)
-		conn.CloseRead()
-		wg.Done()
-	}()
 
-	wg.Wait()
+	// Server→client: the session is over once the server closes its side, so
+	// only this direction decides when to return. Waiting for the client→server
+	// goroutine as well can hang forever: when in is a pipe (e.g. os.Stdin
+	// under an SSH ProxyCommand) its blocked Read may never be interrupted.
+	// Returning on whichever direction finishes first is not safe either: an
+	// early EOF on in would drop output the server has not sent yet.
+	io.Copy(out, conn) //nolint:errcheck
 
 	return nil
 }
diff --git a/command/ssh/proxycommand_test.go b/command/ssh/proxycommand_test.go
new file mode 100644
index 000000000..2ac7a859c
--- /dev/null
+++ b/command/ssh/proxycommand_test.go
@@ -0,0 +1,104 @@
+package ssh
+
+import (
+	"bytes"
+	"io"
+	"net"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// TestProxyDirectExitsWhenServerCloses verifies that proxyDirectWithIO returns
+// promptly when the server closes the connection, even when stdin is still open.
+//
+// Without the fix, the two goroutines inside proxyDirectWithIO deadlock:
+//
+//  1. The server→stdout goroutine finishes (server closed the connection).
+//  2. The stdin→server goroutine blocks on Read(os.Stdin) waiting for input
+//     that never arrives, because the SSH client is itself waiting for the
+//     ProxyCommand process to exit.
+//  3. Neither side exits — the ProxyCommand hangs until an external timeout
+//     (typically ~60 s on the client) kills it.
+func TestProxyDirectExitsWhenServerCloses(t *testing.T) {
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer ln.Close()
+
+	_, port, err := net.SplitHostPort(ln.Addr().String())
+	require.NoError(t, err)
+
+	// Server sends a message then closes the connection immediately.
+	go func() {
+		conn, err := ln.Accept()
+		if err != nil {
+			return
+		}
+		conn.Write([]byte("hello")) //nolint:errcheck
+		conn.Close()
+	}()
+
+	// stdinR blocks on Read until stdinW is closed. We intentionally leave
+	// stdinW open to simulate the client not having closed its stdin yet —
+	// the normal case when sshd rejects a connection mid-session.
+	stdinR, stdinW := io.Pipe()
+	defer stdinW.Close()
+
+	var stdout bytes.Buffer
+
+	done := make(chan error, 1)
+	go func() {
+		done <- proxyDirectWithIO("127.0.0.1", port, stdinR, &stdout)
+	}()
+
+	select {
+	case err := <-done:
+		assert.NoError(t, err)
+		assert.Equal(t, "hello", stdout.String())
+	case <-time.After(2 * time.Second):
+		t.Fatal("proxyDirectWithIO did not exit after server closed the connection — stdin goroutine deadlock")
+	}
+}
+
+// TestProxyDirectDrainsServerAfterStdinEOF verifies that proxyDirectWithIO keeps
+// relaying the server's output after its input has reached EOF, and returns
+// only once the server closes the connection.
+func TestProxyDirectDrainsServerAfterStdinEOF(t *testing.T) {
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err)
+	defer ln.Close()
+
+	_, port, err := net.SplitHostPort(ln.Addr().String())
+	require.NoError(t, err)
+
+	// Server replies only after it has seen the client's half-close.
+	go func() {
+		conn, err := ln.Accept()
+		if err != nil {
+			return
+		}
+		io.Copy(io.Discard, conn) //nolint:errcheck
+		conn.Write([]byte("bye")) //nolint:errcheck
+		conn.Close()
+	}()
+
+	stdin := io.NopCloser(strings.NewReader("ping"))
+
+	var stdout bytes.Buffer
+
+	done := make(chan error, 1)
+	go func() {
+		done <- proxyDirectWithIO("127.0.0.1", port, stdin, &stdout)
+	}()
+
+	select {
+	case err := <-done:
+		assert.NoError(t, err)
+		assert.Equal(t, "bye", stdout.String())
+	case <-time.After(2 * time.Second):
+		t.Fatal("proxyDirectWithIO did not return after the server closed the connection")
+	}
+}