/
ssh.go
313 lines (289 loc) · 6.9 KB
/
ssh.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
package util
import (
"bytes"
"errors"
"fmt"
"io"
"net"
"net/netip"
"os"
"path/filepath"
"strconv"
"sync"
"time"
"github.com/kevinburke/ssh_config"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"k8s.io/client-go/util/homedir"
)
// SshConfig describes how to reach and authenticate against an SSH server.
//
// Either ConfigAlias is set, in which case the connection parameters are
// looked up in the ssh config files, or Addr/User plus at least one of
// Password and Keyfile are used directly.
type SshConfig struct {
	// Addr is the network address handed to ssh.Dial (host:port).
	Addr string
	// User is the SSH login name.
	User string
	// Password, when non-empty, enables password authentication.
	Password string
	// Keyfile, when non-empty, is the path of the private key file used for
	// public-key authentication. A leading "~" is expanded by publicKeyFile.
	Keyfile string
	// ConfigAlias, when non-empty, is the host alias resolved through the
	// ssh config files (including its ProxyJump chain) instead of Addr.
	ConfigAlias string
	// RemoteKubeconfig is not read anywhere in this file; presumably the
	// path of a kubeconfig on the remote host — confirm against callers.
	RemoteKubeconfig string
}
// Main opens an SSH connection described by conf and tunnels TCP connections
// from a local listener to remoteEndpoint through it.
//
// When conf.ConfigAlias is set the connection is established through
// jumpRecursion; otherwise conf.Addr is dialed directly using the key file
// and/or password from conf.
//
// The listener is bound to an ephemeral port on localhost. Its address is
// stored in *localEndpoint, then one value is sent on done to signal that the
// tunnel is ready (this send blocks until the caller receives it). After that
// Main keeps accepting connections and only returns when setup fails or the
// listener has been closed.
func Main(remoteEndpoint, localEndpoint *netip.AddrPort, conf *SshConfig, done chan struct{}) error {
	var remote *ssh.Client
	var err error
	if conf.ConfigAlias != "" {
		remote, err = jumpRecursion(conf.ConfigAlias)
	} else {
		var auth []ssh.AuthMethod
		if conf.Keyfile != "" {
			auth = append(auth, publicKeyFile(conf.Keyfile))
		}
		if conf.Password != "" {
			auth = append(auth, ssh.Password(conf.Password))
		}
		// refer to https://godoc.org/golang.org/x/crypto/ssh for other authentication types
		sshConfig := &ssh.ClientConfig{
			// SSH connection username
			User: conf.User,
			Auth: auth,
			// NOTE(security): the server host key is not verified.
			HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		}
		// Connect to the SSH server
		remote, err = ssh.Dial("tcp", conf.Addr, sshConfig)
	}
	if err != nil {
		log.Errorf("Dial INTO remote server error: %s", err)
		return err
	}
	// Release the SSH connection on every return path.
	defer remote.Close()

	// Listen on an ephemeral local port; traffic is forwarded to the remote side.
	listen, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		return err
	}
	defer listen.Close()

	*localEndpoint, err = netip.ParseAddrPort(listen.Addr().String())
	if err != nil {
		return err
	}
	done <- struct{}{}

	// Handle incoming connections on the local end of the tunnel.
	for {
		local, err := listen.Accept()
		if err != nil {
			// A closed listener never recovers; stop instead of spinning.
			if errors.Is(err, net.ErrClosed) {
				return err
			}
			log.Error(err)
			// Brief pause so a persistent accept error does not busy-loop.
			time.Sleep(100 * time.Millisecond)
			continue
		}
		go func() {
			defer local.Close()

			var conn net.Conn
			var err error
			// Retry the remote dial a few times, one second apart.
			for i := 0; i < 5; i++ {
				conn, err = remote.Dial("tcp", remoteEndpoint.String())
				if err == nil {
					break
				}
				time.Sleep(time.Second)
			}
			if conn == nil {
				log.Errorf("dial remote endpoint %s: %v", remoteEndpoint.String(), err)
				return
			}
			// Close the remote side too, so the still-running copy direction
			// is unblocked once handleClient returns.
			defer conn.Close()

			handleClient(local, conn)
		}()
	}
}
// Run executes cmd on the SSH server described by conf and returns what the
// command wrote to stdout (output) and stderr (errOut).
//
// When conf.ConfigAlias is set the connection is established through
// jumpRecursion; otherwise conf.Addr is dialed directly using the key file
// and/or password from conf.
//
// env is honoured only when it has exactly two elements, interpreted as a
// single name/value pair. A failure to set it is logged and otherwise
// ignored. err is the connection, session or command error, if any.
func Run(conf *SshConfig, cmd string, env []string) (output []byte, errOut []byte, err error) {
	var remote *ssh.Client
	if conf.ConfigAlias != "" {
		remote, err = jumpRecursion(conf.ConfigAlias)
	} else {
		var auth []ssh.AuthMethod
		if conf.Keyfile != "" {
			auth = append(auth, publicKeyFile(conf.Keyfile))
		}
		if conf.Password != "" {
			auth = append(auth, ssh.Password(conf.Password))
		}
		// refer to https://godoc.org/golang.org/x/crypto/ssh for other authentication types
		sshConfig := &ssh.ClientConfig{
			// SSH connection username
			User: conf.User,
			Auth: auth,
			// NOTE(security): the server host key is not verified.
			HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		}
		// Connect to the SSH server
		remote, err = ssh.Dial("tcp", conf.Addr, sshConfig)
	}
	if err != nil {
		log.Errorf("Dial INTO remote server error: %s", err)
		return nil, nil, err
	}
	defer remote.Close()

	var session *ssh.Session
	session, err = remote.NewSession()
	if err != nil {
		return nil, nil, err
	}
	// Release the session channel; deferred calls run in reverse order, so
	// this happens before the client itself is closed.
	defer session.Close()

	if len(env) == 2 {
		// The server must allow the variable, e.g. in /etc/ssh/sshd_config:
		// AcceptEnv DEBIAN_FRONTEND
		if setErr := session.Setenv(env[0], env[1]); setErr != nil {
			// Best effort: a rejected variable must not fail the command.
			log.Warn(setErr)
		}
	}

	var out bytes.Buffer
	var er bytes.Buffer
	session.Stdout = &out
	session.Stderr = &er
	err = session.Run(cmd)
	return out.Bytes(), er.Bytes(), err
}
// publicKeyFile loads the private key stored at file and returns an
// ssh.AuthMethod that performs public-key authentication with it.
//
// A path of "~" or one starting with "~" followed by a path separator is
// resolved against the current user's home directory; any other path is made
// absolute as is. If the path cannot be resolved, or the file cannot be read
// or parsed, the failure is logged through log.Fatalln, including the
// underlying error.
func publicKeyFile(file string) ssh.AuthMethod {
	// Only expand the caller's own home directory. A "~name" prefix is not
	// a reference to it and must not be rewritten to "$HOME/name".
	if file == "~" || (len(file) > 1 && file[0] == '~' && os.IsPathSeparator(file[1])) {
		file = filepath.Join(homedir.HomeDir(), file[1:])
	}

	abs, err := filepath.Abs(file)
	if err != nil {
		log.Fatalln(fmt.Sprintf("Cannot resolve SSH private key file path %s: %v", file, err))
		return nil
	}

	buffer, err := os.ReadFile(abs)
	if err != nil {
		log.Fatalln(fmt.Sprintf("Cannot read SSH private key file %s: %v", abs, err))
		return nil
	}

	key, err := ssh.ParsePrivateKey(buffer)
	if err != nil {
		log.Fatalln(fmt.Sprintf("Cannot parse SSH private key file %s: %v", abs, err))
		return nil
	}
	return ssh.PublicKeys(key)
}
// handleClient pipes data between client and remote in both directions and
// returns as soon as either direction finishes.
//
// It does not close either connection. The caller is expected to close both
// afterwards, which also terminates the copy that is still running.
func handleClient(client net.Conn, remote net.Conn) {
	// Buffered for both senders: only one value is ever received, so the
	// slower goroutine must be able to send without blocking forever.
	chDone := make(chan bool, 2)

	// start remote -> local data transfer
	go func() {
		_, err := io.Copy(client, remote)
		if err != nil {
			log.Debugf("error while copy remote->local: %s", err)
		}
		chDone <- true
	}()

	// start local -> remote data transfer
	go func() {
		_, err := io.Copy(remote, client)
		if err != nil {
			log.Debugf("error while copy local->remote: %s", err)
		}
		chDone <- true
	}()

	<-chDone
}
// jumpRecursion connects to the host configured under alias name, following
// its ProxyJump chain from the ssh config files.
//
// The chain is collected starting at name and then dialed in reverse: the
// outermost bastion is dialed directly and every following hop is reached
// through the previous client. The client for name itself is returned.
//
// An error is returned when the ProxyJump chain refers back to a host already
// in it, or when any hop cannot be reached. In the latter case the client
// opened for the previous hop is closed before returning.
func jumpRecursion(name string) (client *ssh.Client, err error) {
	const jumper = "ProxyJump"

	bastionList := []*SshConfig{getBastion(name)}
	// Track visited aliases so a cyclic configuration cannot loop forever.
	visited := map[string]struct{}{name: {}}
	for {
		value := confList.Get(name, jumper)
		if value == "" {
			break
		}
		if _, dup := visited[value]; dup {
			return nil, fmt.Errorf("cyclic ProxyJump chain: host %q is referenced twice", value)
		}
		visited[value] = struct{}{}
		bastionList = append(bastionList, getBastion(value))
		name = value
	}

	// The last collected entry is the outermost bastion, so walk backwards.
	for i := len(bastionList) - 1; i >= 0; i-- {
		hop := bastionList[i]
		if hop == nil {
			if client != nil {
				client.Close()
			}
			return nil, errors.New("config is nil")
		}
		if client == nil {
			client, err = dial(hop)
			if err != nil {
				return nil, err
			}
			continue
		}
		next, jumpErr := jump(client, hop)
		if jumpErr != nil {
			// Do not leak the connection to the previous hop.
			client.Close()
			return nil, jumpErr
		}
		client = next
	}
	return client, nil
}
// getBastion builds the connection settings for the ssh config host alias
// name from its Hostname, User, Port and IdentityFile entries.
//
// When no Hostname is configured the alias itself is used as the host name,
// and when no Port is configured 22 is used. The result is never nil.
func getBastion(name string) *SshConfig {
	host := confList.Get(name, "Hostname")
	if host == "" {
		// Without an explicit Hostname the alias is the host to connect to;
		// an empty host would produce the address ":port".
		host = name
	}

	port := confList.Get(name, "Port")
	if port == "" {
		port = strconv.Itoa(22)
	}

	return &SshConfig{
		ConfigAlias: name,
		Addr:        net.JoinHostPort(host, port),
		User:        confList.Get(name, "User"),
		Keyfile:     confList.Get(name, "IdentityFile"),
	}
}
// dial opens a direct SSH connection to from.Addr, authenticating as
// from.User with the private key in from.Keyfile.
//
// NOTE(security): the server host key is not verified.
func dial(from *SshConfig) (*ssh.Client, error) {
	bastionCfg := &ssh.ClientConfig{
		User:            from.User,
		Auth:            []ssh.AuthMethod{publicKeyFile(from.Keyfile)},
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}
	return ssh.Dial("tcp", from.Addr, bastionCfg)
}
// jump opens an SSH connection to to.Addr through the already established
// client bClient and returns a client for the new host.
//
// Authentication uses to.User with the private key in to.Keyfile.
//
// NOTE(security): the server host key is not verified.
func jump(bClient *ssh.Client, to *SshConfig) (*ssh.Client, error) {
	// TCP connection to the target, carried over the bastion's SSH session.
	tunnel, err := bClient.Dial("tcp", to.Addr)
	if err != nil {
		return nil, err
	}

	targetCfg := &ssh.ClientConfig{
		User:            to.User,
		Auth:            []ssh.AuthMethod{publicKeyFile(to.Keyfile)},
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}

	// Run the SSH handshake with the target over the tunnelled connection.
	sshConn, newChans, globalReqs, err := ssh.NewClientConn(tunnel, to.Addr, targetCfg)
	if err != nil {
		return nil, err
	}
	return ssh.NewClient(sshConn, newChans, globalReqs), nil
}
// conf is an ordered list of parsed ssh config files; earlier entries take
// precedence over later ones.
type conf []*ssh_config.Config

// Get returns the value of key for host alias from the first config file
// that defines it, falling back to ssh_config.Get when none of them does.
//
// A lookup that succeeds with an empty value is treated as "not defined in
// this file" so that later files and the fallback are still consulted.
// NOTE(review): this relies on Config.Get reporting a missing key as
// ("", nil) — confirm against the ssh_config version in use.
func (c conf) Get(alias string, key string) string {
	for _, s := range c {
		if v, err := s.Get(alias, key); err == nil && v != "" {
			return v
		}
	}
	return ssh_config.Get(alias, key)
}
var (
	// once guards the one-time population of confList.
	once sync.Once
	// confList holds the ssh config files that could be read and parsed,
	// user config first, system config second.
	confList conf
)

// init loads ~/.ssh/config and /etc/ssh/ssh_config into confList. Files that
// cannot be read or parsed are skipped silently.
func init() {
	once.Do(func() {
		candidates := []string{
			filepath.Join(homedir.HomeDir(), ".ssh", "config"),
			filepath.Join("/", "etc", "ssh", "ssh_config"),
		}
		for _, path := range candidates {
			raw, readErr := os.ReadFile(path)
			if readErr != nil {
				continue
			}
			if parsed, decodeErr := ssh_config.DecodeBytes(raw); decodeErr == nil {
				confList = append(confList, parsed)
			}
		}
	})
}