/
shell.go
170 lines (154 loc) · 3.64 KB
/
shell.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
package tent
import (
"bytes"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"log"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"syscall"
"unsafe"
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
gossh "golang.org/x/crypto/ssh"
)
const (
	// fallbackPort is the port UserPort reports when the uid hash cannot be
	// decoded into a number.
	fallbackPort uint64 = 2222
	// tokenLength is the number of random bytes in each authentication token
	// (before base64 encoding).
	tokenLength int = 24
)
// UserPort returns the TCP port used by the current user's shell server.
// The value is derived deterministically from the process's uid, so every
// process run by the same user computes the same port. The result lies in
// [10240, 65535), or is fallbackPort if the uid hash cannot be decoded.
func UserPort() uint64 {
	return userPortFor(os.Getuid())
}

// userPortFor maps uid to a port in [10240, 65535) by hashing it with SHA-256
// and decoding the digest as a uvarint; it returns fallbackPort if decoding fails.
func userPortFor(uid int) uint64 {
	const (
		minPort = 10240 // lower bound of the "safe" range (inclusive)
		maxPort = 65535 // upper bound of the "safe" range (exclusive)
	)
	h := sha256.New()
	// hash.Hash writes never fail, so the results are intentionally ignored.
	h.Write([]byte("spanner tent uid to port"))
	binary.Write(h, binary.BigEndian, int64(uid))

	// NOTE(review): ReadUvarint stops at the first digest byte whose high bit
	// is clear, so about half of all uids decode to a value below 128 and
	// share ports 10240-10367. Decoding is kept as-is because changing it
	// would move every user's port; consider a better spread deliberately.
	number, err := binary.ReadUvarint(bytes.NewReader(h.Sum(nil)))
	if err != nil {
		return fallbackPort
	}

	// Reduce first, then lift. Lifting before the modulo (as this code used
	// to) let large values wrap back below minPort, including 0 and
	// privileged ports. Values that were already in range are unchanged.
	if number >= maxPort {
		number %= maxPort
	}
	if number < minPort {
		number += minPort
	}
	return number
}
// generateToken returns a fresh random token: tokenLength bytes read from
// crypto/rand, encoded as unpadded URL-safe base64. It panics if the system's
// random source cannot be read.
func generateToken() string {
	raw := make([]byte, tokenLength)
	if _, err := rand.Read(raw); err != nil {
		panic(err)
	}
	return base64.RawURLEncoding.EncodeToString(raw)
}
// writeTokenFile generates a fresh authentication token and stores it in a
// new file in the current working directory.
//
// Every call uses its own file, named ".spanner-token-" plus a random suffix,
// created with O_EXCL and mode 0600. Because names are distinct, concurrent
// calls (e.g. two clients authenticating at once) cannot overwrite or delete
// one another's token. Because of O_EXCL, a pre-existing file — possibly with
// looser permissions — is never reused.
//
// It returns the token and the file's path; the caller must remove the file.
// On error it returns empty strings and leaves no file behind.
func writeTokenFile() (token, filename string, err error) {
	cwd, err := os.Getwd()
	if err != nil {
		return "", "", fmt.Errorf("get current working directory: %w", err)
	}

	var suffix [6]byte
	if _, err := rand.Read(suffix[:]); err != nil {
		return "", "", fmt.Errorf("generating token file name: %w", err)
	}
	name := filepath.Join(cwd, ".spanner-token-"+base64.RawURLEncoding.EncodeToString(suffix[:]))

	f, err := os.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0600)
	if err != nil {
		return "", "", fmt.Errorf("writing token file: %w", err)
	}
	tok := generateToken()
	_, werr := io.WriteString(f, tok)
	// A failed Close can mean the data never reached the file, so it counts
	// as a write failure too.
	cerr := f.Close()
	if werr == nil {
		werr = cerr
	}
	if werr != nil {
		os.Remove(name)
		return "", "", fmt.Errorf("writing token file: %w", werr)
	}
	return tok, name, nil
}
// kiHandler authenticates a keyboard-interactive login. It writes a fresh
// random token to a token file and sends the file's path to the client as
// the single prompt. It accepts the login only if the client answers with
// exactly that token, demonstrating it can read the file. The token file is
// removed when the handler returns. Failures to create the file are logged
// and reject the login.
func kiHandler(ctx ssh.Context, challenger gossh.KeyboardInteractiveChallenge) bool {
	token, path, err := writeTokenFile()
	if err != nil {
		log.Printf("write token to file: %v", err)
		return false
	}
	defer os.Remove(path)

	prompts := []string{path + "\n"}
	echoes := []bool{false} // the answer is a secret; ask the client not to echo it
	answers, err := challenger("", "", prompts, echoes)
	if err != nil {
		return false
	}
	// NOTE(review): indexing assumes the challenger returns one answer per
	// prompt; x/crypto/ssh appears to enforce this — verify when upgrading.
	return token == answers[0]
}
// sessionHandler runs an interactive bash for an SSH session that requested a
// PTY. It relays the session's input to the PTY and the PTY's output back to
// the session, forwarding window-size changes. When the shell's output ends,
// it reports bash's exit status to the client. Sessions without a PTY are
// refused with exit status 1.
func sessionHandler(s ssh.Session) {
	io.WriteString(s, "Connected to spanner shell.\n")
	ptyReq, winCh, isPty := s.Pty()
	if !isPty {
		io.WriteString(s, "No PTY requested.\n")
		s.Exit(1)
		return
	}

	cmd := exec.Command("bash")
	cmd.Env = append(os.Environ(), fmt.Sprintf("TERM=%s", ptyReq.Term))
	f, err := pty.Start(cmd)
	if err != nil {
		log.Printf("start pty: %v", err)
		// Tell the client why the session ends instead of closing silently.
		io.WriteString(s, "Failed to start shell.\n")
		s.Exit(1)
		return
	}
	// Close the PTY master when the handler returns so each session does not
	// leak a file descriptor.
	defer f.Close()

	go func() {
		for win := range winCh {
			setWinsize(f, win.Width, win.Height)
		}
	}()
	go func() {
		io.Copy(f, s) // stdin
	}()
	io.Copy(s, f) // stdout

	// Report bash's own exit status rather than letting the session end
	// without one. Exit codes <= 0 on failure (e.g. killed by a signal, where
	// ExitCode is -1) are reported as 1.
	status := 0
	if err := cmd.Wait(); err != nil {
		status = 1
		if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() > 0 {
			status = exitErr.ExitCode()
		} else {
			log.Printf("wait for shell: %v", err)
		}
	}
	s.Exit(status)
}
// setWinsize tells the terminal behind f its new size, w columns by h rows,
// using the TIOCSWINSZ ioctl. The ioctl's result is ignored.
func setWinsize(f *os.File, w, h int) {
	// Field order follows the ioctl's struct winsize: rows, cols, xpixel, ypixel.
	type winsize struct {
		rows, cols, xpixel, ypixel uint16
	}
	ws := winsize{rows: uint16(h), cols: uint16(w)}
	// The unsafe.Pointer-to-uintptr conversion must stay inside the Syscall
	// argument list so the pointee is kept alive for the duration of the call.
	syscall.Syscall(
		syscall.SYS_IOCTL,
		f.Fd(),
		uintptr(syscall.TIOCSWINSZ),
		uintptr(unsafe.Pointer(&ws)),
	)
}
// kiClientHandler answers the keyboard-interactive challenge issued by
// kiHandler. The server's single prompt is the path of a token file, with
// surrounding whitespace such as the trailing newline trimmed. The answer is
// that file's contents.
//
// The process listening on the user's port is not necessarily our own
// server, so the requested path is treated as untrusted. Only a regular file
// whose name starts with ".spanner-token" is read. A symlink in the final
// path component is refused. At most maxTokenBytes are read. Without these
// checks a hostile listener could make the client read and send back
// arbitrary files.
//
// A round with no prompts (permitted by RFC 4256) gets no answers; more than
// one prompt is an error.
//
// NOTE(review): syscall.O_NOFOLLOW / O_NONBLOCK are Unix-only, as is the
// TIOCSWINSZ ioctl this file already relies on.
func kiClientHandler(name, instruction string, questions []string, echos []bool) (answers []string, err error) {
	const (
		tokenFilePrefix = ".spanner-token"
		maxTokenBytes   = 4096
	)
	switch len(questions) {
	case 0:
		return nil, nil
	case 1:
		// handled below
	default:
		return nil, fmt.Errorf("unexpected keyboard-interactive challenge with %d questions", len(questions))
	}

	tokenFilename := strings.TrimSpace(questions[0])
	if !strings.HasPrefix(filepath.Base(tokenFilename), tokenFilePrefix) {
		return nil, fmt.Errorf("read token file: refusing to read %q: not a token file", tokenFilename)
	}

	// O_NOFOLLOW rejects a symlink as the final component. O_NONBLOCK keeps a
	// FIFO from blocking the open before the regular-file check below; it
	// has no effect on reads from regular files.
	f, err := os.OpenFile(tokenFilename, os.O_RDONLY|syscall.O_NOFOLLOW|syscall.O_NONBLOCK, 0)
	if err != nil {
		return nil, fmt.Errorf("read token file: %w", err)
	}
	defer f.Close()

	info, err := f.Stat()
	if err != nil {
		return nil, fmt.Errorf("read token file: %w", err)
	}
	if !info.Mode().IsRegular() {
		return nil, fmt.Errorf("read token file: %q is not a regular file", tokenFilename)
	}

	token, err := ioutil.ReadAll(io.LimitReader(f, maxTokenBytes))
	if err != nil {
		return nil, fmt.Errorf("read token file: %w", err)
	}
	return []string{string(token)}, nil
}
func shellIsPresent() bool {
config := &gossh.ClientConfig{
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
Auth: []gossh.AuthMethod{
gossh.KeyboardInteractive(kiClientHandler),
},
}
addr := "localhost:" + strconv.FormatUint(UserPort(), 10)
conn, err := gossh.Dial("tcp", addr, config)
if err != nil {
return false
}
defer conn.Close()
return true
}
// RunShellServer serves the current user's SSH shell on UserPort until
// ListenAndServe returns. If shellIsPresent reports that a server already
// answers on that port, it returns nil without starting another. Any
// ListenAndServe error is wrapped and returned.
//
// NOTE(review): Addr has no host part, so the server listens on all
// interfaces, while shellIsPresent only probes localhost — confirm this is
// intended.
func RunShellServer() error {
	if shellIsPresent() {
		return nil
	}
	server := &ssh.Server{
		Addr:                       fmt.Sprintf(":%d", UserPort()),
		Handler:                    sessionHandler,
		KeyboardInteractiveHandler: kiHandler,
	}
	if err := server.ListenAndServe(); err != nil {
		return fmt.Errorf("listen and serve ssh: %w", err)
	}
	return nil
}