Skip to content

Commit

Permalink
fix(terminal):support cancel with ctrl-c when reading password
Browse files Browse the repository at this point in the history
  • Loading branch information
vimiix committed Dec 18, 2023
1 parent eb5f3cb commit 4cafadf
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 11 deletions.
35 changes: 35 additions & 0 deletions internal/terminal/terminal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package terminal

import (
"context"

"github.com/containerd/console"
)

func ReadPassword(ctx context.Context) ([]byte, error) {
c := console.Current()
defer func() {
_ = c.Reset()
}()

var (
errch = make(chan error, 1)
password []byte
)

go func() {
bs, readErr := readPassword()
if readErr != nil {
errch <- readErr
}
password = bs
errch <- nil
}()

select {
case err := <-errch:
return password, err
case <-ctx.Done():
return nil, ctx.Err()
}
}
2 changes: 1 addition & 1 deletion internal/terminal/terminal_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"github.com/vimiix/ssx/internal/lg"
)

func ReadPassword() ([]byte, error) {
func readPassword() ([]byte, error) {
return term.ReadPassword(syscall.Stdin)
}

Expand Down
2 changes: 1 addition & 1 deletion internal/terminal/terminal_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"github.com/vimiix/ssx/internal/lg"
)

func ReadPassword() ([]byte, error) {
func readPassword() ([]byte, error) {
return term.ReadPassword(int(windows.Stdin))
}

Expand Down
7 changes: 4 additions & 3 deletions ssx/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"golang.org/x/crypto/ssh"

"github.com/containerd/console"

"github.com/vimiix/ssx/internal/lg"
"github.com/vimiix/ssx/internal/terminal"
"github.com/vimiix/ssx/ssx/entry"
Expand Down Expand Up @@ -153,7 +154,7 @@ func dialContext(ctx context.Context, network, addr string, config *ssh.ClientCo
func (c *Client) login(ctx context.Context) error {
network := "tcp"
addr := net.JoinHostPort(c.entry.Host, c.entry.Port)
clientConfig, err := c.entry.GenSSHConfig()
clientConfig, err := c.entry.GenSSHConfig(ctx)
if err != nil {
return err
}
Expand All @@ -166,13 +167,13 @@ func (c *Client) login(ctx context.Context) error {

if strings.Contains(err.Error(), "no supported methods remain") {
fmt.Printf("%s@%s's password:", c.entry.User, c.entry.Host)
bs, readErr := terminal.ReadPassword()
bs, readErr := terminal.ReadPassword(ctx)
fmt.Println()
if readErr == nil {
p := string(bs)
if p != "" {
clientConfig.Auth = []ssh.AuthMethod{ssh.Password(p)}
}
fmt.Println()
cli, err = ssh.Dial(network, addr, clientConfig)
if err == nil {
c.entry.Password = p
Expand Down
13 changes: 7 additions & 6 deletions ssx/entry/entry.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package entry

import (
"bufio"
"context"
"fmt"
"net"
"os"
Expand Down Expand Up @@ -61,14 +62,14 @@ func getConnectTimeout() time.Duration {
return d
}

func (e *Entry) GenSSHConfig() (*ssh.ClientConfig, error) {
func (e *Entry) GenSSHConfig(ctx context.Context) (*ssh.ClientConfig, error) {
cb, err := e.sshHostKeyCallback()
if err != nil {
return nil, err
}
cfg := &ssh.ClientConfig{
User: e.User,
Auth: e.AuthMethods(),
Auth: e.AuthMethods(ctx),
HostKeyCallback: cb,
Timeout: getConnectTimeout(),
}
Expand Down Expand Up @@ -133,7 +134,7 @@ func (e *Entry) Tidy() error {
}

// AuthMethods all possible auth methods
func (e *Entry) AuthMethods() []ssh.AuthMethod {
func (e *Entry) AuthMethods(ctx context.Context) []ssh.AuthMethod {
var authMethods []ssh.AuthMethod
// password auth
if e.Password != "" {
Expand All @@ -146,11 +147,11 @@ func (e *Entry) AuthMethods() []ssh.AuthMethod {
authMethods = append(authMethods, keyfileAuths...)
}

authMethods = append(authMethods, e.interactAuth())
authMethods = append(authMethods, e.interactAuth(ctx))
return authMethods
}

func (e *Entry) interactAuth() ssh.AuthMethod {
func (e *Entry) interactAuth(ctx context.Context) ssh.AuthMethod {
return ssh.KeyboardInteractive(func(name, instruction string, questions []string, echos []bool) (answers []string, err error) {
answers = make([]string, 0, len(questions))
for i, q := range questions {
Expand All @@ -164,7 +165,7 @@ func (e *Entry) interactAuth() ssh.AuthMethod {
return nil, err
}
} else {
b, err := terminal.ReadPassword()
b, err := terminal.ReadPassword(ctx)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 4cafadf

Please sign in to comment.