/
sshtun.go
329 lines (272 loc) · 9.57 KB
/
sshtun.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
// Package sshtun provides an SSH tunnel with port forwarding.
package sshtun
import (
"context"
"fmt"
"io"
"net"
"sync"
"time"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
)
// SSHTun represents an SSH tunnel that forwards connections accepted on a
// local endpoint to a remote endpoint through an SSH server.
// Instances must be created with New or NewUnix; the zero value has no mutex
// and is not usable.
type SSHTun struct {
	// mutex guards started, ctx/cancel assignment, active and sshClient.
	mutex *sync.Mutex

	// State of the current Start call.
	ctx     context.Context
	cancel  context.CancelFunc
	started bool

	// SSH login and authentication settings.
	user          string
	authType      AuthType
	authKeyFile   string
	authKeyReader io.Reader
	authPassword  string

	// SSH server to connect to, and the two ends of the forwarding.
	server, local, remote *Endpoint

	// Timeout placed in the SSH client configuration.
	timeout time.Duration

	// Optional callbacks for tunnel and tunneled-connection state changes.
	connState         func(*SSHTun, ConnState)
	tunneledConnState func(*SSHTun, *TunneledConnState)

	// active counts connections currently being handled; sshClient is dialed
	// when the first one is added and closed when the last one is removed.
	active    int
	sshClient *ssh.Client
	sshConfig *ssh.ClientConfig
}
// ConnState is the state of an SSH tunnel, as reported to the optional
// callback registered with SetConnState.
type ConnState int

// Tunnel states. The numeric values are fixed.
const (
	// StateStopped: the tunnel is not running. Calling Start moves it to
	// StateStarting.
	StateStopped ConnState = 0

	// StateStarting: the tunnel is initializing and preparing to listen.
	// On success it moves to StateStarted, otherwise back to StateStopped.
	StateStarting ConnState = 1

	// StateStarted: the tunnel is accepting connections. Calling Stop or
	// hitting an error moves it to StateStopped.
	StateStarted ConnState = 2
)
// New creates an SSH tunnel through server that forwards TCP connections
// accepted on localhost:localPort to localhost:remotePort on the remote side.
//
// Defaults: SSH port 22, user root, a 15 second timeout and automatic
// detection of the authentication method (see Start for details).
// SetPassword switches to password authentication and SetKeyFile to key file
// authentication. SetUser and SetPort change the SSH user and port.
// SetLocalEndpoint and SetRemoteEndpoint use hosts other than localhost.
// Tunnel state changes can be observed through a callback set with
// SetConnState, and the states of tunneled connections through one set with
// SetTunneledConnState.
func New(localPort int, server string, remotePort int) *SSHTun {
	t := defaultSSHTun(server)
	t.local, t.remote = NewTCPEndpoint("localhost", localPort), NewTCPEndpoint("localhost", remotePort)
	return t
}
// NewUnix is like New, but forwards between unix sockets: connections accepted
// on localUnixSocket are forwarded to remoteUnixSocket on the remote side.
func NewUnix(localUnixSocket string, server string, remoteUnixSocket string) *SSHTun {
	t := defaultSSHTun(server)
	t.local, t.remote = NewUnixEndpoint(localUnixSocket), NewUnixEndpoint(remoteUnixSocket)
	return t
}
// defaultSSHTun returns a tunnel to server with the package defaults:
// SSH port 22, user "root", AuthTypeAuto and a 15 second timeout.
// The local and remote endpoints are left nil for the caller to set.
func defaultSSHTun(server string) *SSHTun {
	tun := new(SSHTun)
	tun.mutex = new(sync.Mutex)
	tun.server = NewTCPEndpoint(server, 22)
	tun.user = "root"
	tun.authType = AuthTypeAuto
	tun.timeout = 15 * time.Second
	return tun
}
// SetPort sets the port of the SSH server to connect to (22 by default).
func (tun *SSHTun) SetPort(port int) {
	srv := tun.server
	srv.port = port
}
// SetUser sets the user name for the SSH login ("root" by default).
func (tun *SSHTun) SetUser(user string) {
	tun.user = user
}
// SetKeyFile selects key-based authentication using the private key in file.
// An empty file falls back to the default linux private key locations:
// `~/.ssh/id_rsa`, `~/.ssh/id_dsa`, `~/.ssh/id_ecdsa`, `~/.ssh/id_ecdsa_sk`,
// `~/.ssh/id_ed25519` and `~/.ssh/id_ed25519_sk`.
func (tun *SSHTun) SetKeyFile(file string) {
	tun.authKeyFile = file
	tun.authType = AuthTypeKeyFile
}
// SetEncryptedKeyFile selects authentication with a passphrase-protected
// private key stored in file, decrypted with password.
// An empty file falls back to the default linux private key locations:
// `~/.ssh/id_rsa`, `~/.ssh/id_dsa`, `~/.ssh/id_ecdsa`, `~/.ssh/id_ecdsa_sk`,
// `~/.ssh/id_ed25519` and `~/.ssh/id_ed25519_sk`.
func (tun *SSHTun) SetEncryptedKeyFile(file string, password string) {
	tun.authKeyFile, tun.authPassword = file, password
	tun.authType = AuthTypeEncryptedKeyFile
}
// SetKeyReader selects key-based authentication, reading the private key
// from reader.
func (tun *SSHTun) SetKeyReader(reader io.Reader) {
	tun.authKeyReader = reader
	tun.authType = AuthTypeKeyReader
}
// SetEncryptedKeyReader selects authentication with a passphrase-protected
// private key read from reader and decrypted with password.
func (tun *SSHTun) SetEncryptedKeyReader(reader io.Reader, password string) {
	tun.authKeyReader, tun.authPassword = reader, password
	tun.authType = AuthTypeEncryptedKeyReader
}
// SetSSHAgent selects authentication through the ssh-agent.
func (tun *SSHTun) SetSSHAgent() {
	tun.authType = AuthTypeSSHAgent
}
// SetPassword selects password authentication with the given password.
func (tun *SSHTun) SetPassword(password string) {
	tun.authPassword = password
	tun.authType = AuthTypePassword
}
// SetLocalHost replaces the host of the current local endpoint
// ("localhost" for tunnels created with New).
func (tun *SSHTun) SetLocalHost(host string) {
	ep := tun.local
	ep.host = host
}
// SetRemoteHost replaces the host of the current remote endpoint
// ("localhost" for tunnels created with New).
func (tun *SSHTun) SetRemoteHost(host string) {
	ep := tun.remote
	ep.host = host
}
// SetLocalEndpoint replaces the endpoint on which local connections are accepted.
func (tun *SSHTun) SetLocalEndpoint(endpoint *Endpoint) {
	tun.local = endpoint
}
// SetRemoteEndpoint replaces the endpoint that tunneled connections are
// forwarded to.
func (tun *SSHTun) SetRemoteEndpoint(endpoint *Endpoint) {
	tun.remote = endpoint
}
// SetTimeout sets the connection timeout placed in the SSH client
// configuration (15 seconds by default).
func (tun *SSHTun) SetTimeout(timeout time.Duration) {
	tun.timeout = timeout
}
// SetConnState registers an optional callback that Start invokes every time
// the tunnel changes state. Pass nil to disable it.
// See the ConnState type and its constants for the possible states.
func (tun *SSHTun) SetConnState(connStateFun func(*SSHTun, ConnState)) {
	tun.connState = connStateFun
}
// SetTunneledConnState registers an optional callback that receives state
// changes of the individual connections carried through the tunnel.
// Pass nil to disable it.
func (tun *SSHTun) SetTunneledConnState(tunneledConnStateFun func(*SSHTun, *TunneledConnState)) {
	tun.tunneledConnState = tunneledConnStateFun
}
// Start starts the SSH tunnel and blocks until it stops. The tunnel stops when:
//   - Stop is called,
//   - ctx is cancelled, or
//   - an error occurs.
//
// Once listening has started, stopping through Stop or ctx makes Start
// return nil. Otherwise the error that stopped the tunnel is returned.
// Calling Start on a tunnel that is already started returns an error.
//
// Note on SSH authentication: when the tunnel's authType is AuthTypeAuto, the
// default key files are tried first. If that doesn't succeed, the SSH agent is
// tried, and if that also fails the whole authentication fails. Password and
// encrypted key file authentication must therefore be selected explicitly.
func (tun *SSHTun) Start(ctx context.Context) error {
	tun.mutex.Lock()
	if tun.started {
		tun.mutex.Unlock()
		return fmt.Errorf("already started")
	}
	tun.started = true
	tun.ctx, tun.cancel = context.WithCancel(ctx)
	// Keep this run's cancel func locally. After stop() clears started, a new
	// Start may overwrite tun.cancel, and this run must only release its own
	// context.
	cancel := tun.cancel
	tun.mutex.Unlock()
	// Release the derived context on every return path. Without this, a config
	// failure, a listen failure or an error from the accept loop leaves it
	// alive until the parent ctx ends.
	defer cancel()

	if tun.connState != nil {
		tun.connState(tun, StateStarting)
	}

	config, err := tun.initSSHConfig()
	if err != nil {
		return tun.stop(fmt.Errorf("ssh config failed: %w", err))
	}
	tun.sshConfig = config

	listenConfig := net.ListenConfig{}
	localListener, err := listenConfig.Listen(tun.ctx, tun.local.Type(), tun.local.String())
	if err != nil {
		return tun.stop(fmt.Errorf("local listen %s on %s failed: %w", tun.local.Type(), tun.local.String(), err))
	}

	errChan := make(chan error)
	go func() {
		errChan <- tun.listen(localListener)
	}()

	if tun.connState != nil {
		tun.connState(tun, StateStarted)
	}

	return tun.stop(<-errChan)
}
// Stop closes all connections and makes Start exit gracefully. It cancels the
// context of the running tunnel. It is a no-op when the tunnel is not started.
func (tun *SSHTun) Stop() {
	tun.mutex.Lock()
	defer tun.mutex.Unlock()
	if !tun.started {
		return
	}
	tun.cancel()
}
// initSSHConfig builds the SSH client configuration from the tunnel's user,
// timeout and the single authentication method returned by getSSHAuthMethod.
// It returns an error if the authentication method cannot be built.
//
// SECURITY: the server's host key is not verified (any key is accepted).
// This makes the connection vulnerable to man-in-the-middle attacks.
// Consider the ssh/knownhosts package or ssh.FixedHostKey instead.
func (tun *SSHTun) initSSHConfig() (*ssh.ClientConfig, error) {
	authMethod, err := tun.getSSHAuthMethod()
	if err != nil {
		return nil, err
	}
	return &ssh.ClientConfig{
		User: tun.user,
		Auth: []ssh.AuthMethod{authMethod},
		// Same accept-everything behavior as before, but via the library's
		// explicitly named helper so linters and reviewers can spot it.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         tun.timeout,
	}, nil
}
// stop marks the tunnel as not started and reports StateStopped to the
// optional connState callback. It returns err unchanged so that Start can
// write `return tun.stop(err)`.
func (tun *SSHTun) stop(err error) error {
	tun.mutex.Lock()
	tun.started = false
	tun.mutex.Unlock()

	if notify := tun.connState; notify != nil {
		notify(tun, StateStopped)
	}
	return err
}
// listen accepts connections on localListener and serves each one through
// handle in its own goroutine of an errgroup derived from tun.ctx.
// It waits until that group's context is done. That happens when tun.ctx is
// cancelled or when any goroutine (accept loop or handler) returns an error.
// It then closes the listener and waits for all goroutines.
// It returns nil if tun.ctx was cancelled (a requested shutdown); otherwise it
// returns the first error reported by the group.
func (tun *SSHTun) listen(localListener net.Listener) error {
	g, gctx := errgroup.WithContext(tun.ctx)

	acceptLoop := func() error {
		for {
			conn, err := localListener.Accept()
			if err != nil {
				return fmt.Errorf("local accept %s on %s failed: %w", tun.local.Type(), tun.local.String(), err)
			}
			g.Go(func() error {
				return tun.handle(conn)
			})
		}
	}
	g.Go(acceptLoop)

	<-gctx.Done()
	// Closing the listener makes the blocked Accept fail, ending acceptLoop.
	localListener.Close()
	waitErr := g.Wait()

	if tun.ctx.Err() != nil {
		return nil
	}
	return waitErr
}
// handle serves one accepted local connection. It works in three steps:
//  1. Register the connection with addConn, which dials the SSH server if no
//     other connection is active.
//  2. Forward the connection.
//  3. Unregister it with removeConn.
//
// It returns addConn's error if registration fails. Inside listen's errgroup,
// that error cancels the group and stops the tunnel.
func (tun *SSHTun) handle(localConn net.Conn) error {
	if err := tun.addConn(); err != nil {
		// forward never runs on this path, so nothing else closes the accepted
		// connection; close it here (best effort) to avoid leaking the socket.
		_ = localConn.Close()
		return err
	}
	defer tun.removeConn()
	tun.forward(localConn)
	return nil
}
// addConn registers one more active connection. When it is the first one,
// the shared SSH client is dialed using tun.sshConfig. The count is only
// incremented if the dial succeeds.
//
// NOTE(review): ssh.Dial runs while tun.mutex is held. Stop, addConn and
// removeConn therefore wait for the dial to finish. Consider dialing outside
// the lock if that latency matters.
func (tun *SSHTun) addConn() error {
	tun.mutex.Lock()
	defer tun.mutex.Unlock()

	if tun.active != 0 {
		// An SSH client is already up; just share it.
		tun.active++
		return nil
	}

	client, err := ssh.Dial(tun.server.Type(), tun.server.String(), tun.sshConfig)
	if err != nil {
		return fmt.Errorf("ssh dial %s to %s failed: %w", tun.server.Type(), tun.server.String(), err)
	}
	tun.sshClient = client
	tun.active++
	return nil
}
// removeConn unregisters one active connection. When the last one goes away,
// it closes the shared SSH client and clears it.
func (tun *SSHTun) removeConn() {
	tun.mutex.Lock()
	defer tun.mutex.Unlock()

	tun.active--
	if tun.active != 0 {
		return
	}
	// Best-effort close: the client is discarded either way.
	_ = tun.sshClient.Close()
	tun.sshClient = nil
}