diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 13fd68732c250..a17af5cee8370 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -16,6 +16,9 @@ import ( "net" "os" "os/exec" + "os/user" + "runtime" + "strings" "syscall" "time" "unsafe" @@ -100,9 +103,9 @@ func (srv *server) handleSSH(s ssh.Session) { lb := srv.lb logf := srv.logf - user := s.User() + sshUser := s.User() addr := s.RemoteAddr() - logf("Handling SSH from %v for user %v", addr, user) + logf("Handling SSH from %v for user %v", addr, sshUser) ta, ok := addr.(*net.TCPAddr) if !ok { logf("tsshd: rejecting non-TCP addr %T %v", addr, addr) @@ -140,7 +143,7 @@ func (srv *server) handleSSH(s ssh.Session) { srcIP := srcIPP.IP() sctx := &sshContext{ now: time.Now(), - sshUser: s.User(), + sshUser: sshUser, srcIP: srcIP, node: node, uprof: &uprof, @@ -165,8 +168,19 @@ func (srv *server) handleSSH(s ssh.Session) { return } var cmd *exec.Cmd - if os.Getuid() != 0 || localUser == "root" { - cmd = exec.Command("/bin/bash") + if os.Getuid() != 0 { + u, err := user.Current() + if err != nil { + logf("failed to get current user: %v", err) + s.Exit(1) + return + } + if u.Username != localUser { + fmt.Fprintf(s, "can't switch user\n") + s.Exit(1) + return + } + cmd = exec.Command(loginShell(u.Uid)) } else { cmd = exec.Command("/usr/bin/env", "su", "-", localUser) } @@ -297,3 +311,19 @@ func matchesPrincipal(ps []*tailcfg.SSHPrincipal, sctx *sshContext) bool { } return false } + +func loginShell(uid string) string { + switch runtime.GOOS { + case "linux": + out, _ := exec.Command("getent", "passwd", uid).Output() + // out is "root:x:0:0:root:/root:/bin/bash" + f := strings.SplitN(string(out), ":", 10) + if len(f) > 6 { + return f[6] // shell + } + } + if e := os.Getenv("SHELL"); e != "" { + return e + } + return "/bin/bash" +}