-
Notifications
You must be signed in to change notification settings - Fork 18
/
sshtunnel.go
139 lines (132 loc) · 3.07 KB
/
sshtunnel.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package sshtunnel
import (
	"fmt"
	"io"
	"net"
	"os"
	"strconv"
	"strings"
	"sync"

	"github.com/sirupsen/logrus"
	"golang.org/x/crypto/ssh"
)
// Endpoint is a network address (host and port) with an optional user name.
type Endpoint struct {
	Host string
	Port int
	User string
}

// NewEndpoint parses s in the form "[user@]host[:port]".
//
// IPv6 hosts that carry a port must be bracketed ("[::1]:22"). A host
// without a port is kept whole; this includes a bare IPv6 literal such as
// "::1" or a bracketed one such as "[::1]", whose brackets are removed.
// A missing or non-numeric port leaves Port at 0 so callers can apply
// their own default.
func NewEndpoint(s string) *Endpoint {
	endpoint := &Endpoint{Host: s}
	if user, host, found := strings.Cut(endpoint.Host, "@"); found {
		endpoint.User = user
		endpoint.Host = host
	}
	host, port, err := net.SplitHostPort(endpoint.Host)
	if err != nil {
		// No usable "host:port" pair: keep the host as given. Unwrap a
		// bracketed IPv6 literal so String can re-bracket it correctly.
		if strings.HasPrefix(endpoint.Host, "[") && strings.HasSuffix(endpoint.Host, "]") {
			endpoint.Host = endpoint.Host[1 : len(endpoint.Host)-1]
		}
		return endpoint
	}
	endpoint.Host = host
	// A non-numeric port is deliberately treated like a missing one (0);
	// the signature has no way to report the parse error.
	endpoint.Port, _ = strconv.Atoi(port)
	return endpoint
}

// String returns the endpoint as "host:port", bracketing IPv6 hosts so the
// result is a valid address for net.Listen / net.Dial. User is not included.
func (endpoint *Endpoint) String() string {
	return net.JoinHostPort(endpoint.Host, strconv.Itoa(endpoint.Port))
}
// SSHTunnel forwards TCP connections accepted on Local through an SSH
// connection to Server, which then dials Remote on the caller's behalf.
//
// Each accepted local connection gets its own SSH client connection
// (ssh.Dial is called once per connection in forward).
type SSHTunnel struct {
	Local  *Endpoint
	Server *Endpoint
	Remote *Endpoint
	Config *ssh.ClientConfig
	Log    *logrus.Entry

	// mu guards running and listener, which Start, Stop and the accept
	// goroutine all touch concurrently.
	mu       sync.Mutex
	running  bool
	listener net.Listener
}

// logf emits a debug-level message through Log; it is a no-op when Log is nil.
func (tunnel *SSHTunnel) logf(format string, args ...interface{}) {
	if tunnel.Log != nil {
		tunnel.Log.Debugf(format, args...)
	}
}

// Start binds a TCP listener on Local and accepts connections in a
// background goroutine, forwarding each one through the SSH server.
// Local.Port is overwritten with the port actually bound, which matters
// when it was 0. Start returns an error if the tunnel is already running or
// the listener cannot be created.
func (tunnel *SSHTunnel) Start() error {
	tunnel.mu.Lock()
	defer tunnel.mu.Unlock()
	if tunnel.running {
		return fmt.Errorf("already running")
	}
	listener, err := net.Listen("tcp", tunnel.Local.String())
	if err != nil {
		return err
	}
	tunnel.running = true
	tunnel.listener = listener
	tunnel.Local.Port = listener.Addr().(*net.TCPAddr).Port
	go tunnel.acceptLoop(listener)
	return nil
}

// acceptLoop forwards every connection accepted on listener in its own
// goroutine until Accept fails. That happens when Stop closes the listener,
// or on a fatal listener error.
func (tunnel *SSHTunnel) acceptLoop(listener net.Listener) {
	defer func() {
		listener.Close()
		tunnel.mu.Lock()
		// Only clear state that still belongs to this listener; a Stop+Start
		// may already have installed a new one.
		if tunnel.listener == listener {
			tunnel.running = false
			tunnel.listener = nil
		}
		tunnel.mu.Unlock()
	}()
	for {
		conn, err := listener.Accept()
		if err != nil {
			return
		}
		tunnel.logf("accepted connection")
		go tunnel.forward(conn)
	}
}

// Stop marks the tunnel as stopped and closes its listener, which unblocks
// the accept goroutine immediately. Connections already being forwarded are
// left to finish on their own. Stop on a tunnel that is not running is a
// no-op.
func (tunnel *SSHTunnel) Stop() {
	tunnel.mu.Lock()
	defer tunnel.mu.Unlock()
	tunnel.running = false
	if tunnel.listener != nil {
		tunnel.listener.Close()
		tunnel.listener = nil
	}
}

// forward relays localConn to Remote over a fresh SSH connection to Server.
// It blocks until both copy directions finish, then closes the local
// connection, the remote connection and the SSH client connection. On any
// dial failure, everything opened so far is closed before returning.
func (tunnel *SSHTunnel) forward(localConn net.Conn) {
	serverConn, err := ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config)
	if err != nil {
		tunnel.logf("server dial error: %s", err)
		localConn.Close()
		return
	}
	tunnel.logf("connected to %s (1 of 2)\n", tunnel.Server.String())
	remoteConn, err := serverConn.Dial("tcp", tunnel.Remote.String())
	if err != nil {
		tunnel.logf("remote dial error: %s", err)
		localConn.Close()
		serverConn.Close()
		return
	}
	tunnel.logf("connected to %s (2 of 2)\n", tunnel.Remote.String())

	var wg sync.WaitGroup
	wg.Add(2)
	go tunnel.pipe(&wg, localConn, remoteConn)
	go tunnel.pipe(&wg, remoteConn, localConn)
	wg.Wait()

	localConn.Close()
	remoteConn.Close()
	serverConn.Close()
}

// pipe copies src into dst and then marks wg done.
//
// On a clean EOF from src, pipe half-closes dst (CloseWrite) so dst's peer
// sees EOF while the opposite direction can keep flowing. If dst cannot
// half-close, it is closed fully. On a copy error, both connections are
// closed so the opposite pipe unblocks too. A failure on one connection no
// longer affects the tunnel's other connections.
func (tunnel *SSHTunnel) pipe(wg *sync.WaitGroup, dst, src net.Conn) {
	defer wg.Done()
	if _, err := io.Copy(dst, src); err != nil {
		tunnel.logf("io.Copy error: %s", err)
		dst.Close()
		src.Close()
		return
	}
	// NOTE(review): *net.TCPConn and the conn returned by ssh.Client.Dial are
	// expected to provide CloseWrite; anything else falls back to Close.
	if cw, ok := dst.(interface{ CloseWrite() error }); ok {
		cw.CloseWrite()
		return
	}
	dst.Close()
}
// PrivateKeyFile reads a PEM-encoded private key from file and returns an
// ssh.AuthMethod that authenticates with it.
//
// Read errors are returned unchanged: they already name the path, and
// leaving them unwrapped keeps os.IsNotExist and similar checks working.
// Parse errors are wrapped with the file name (via %w, so errors.As still
// works). NOTE(review): passphrase-protected keys fail here, because
// ssh.ParsePrivateKey does not take a passphrase.
func PrivateKeyFile(file string) (ssh.AuthMethod, error) {
	buffer, err := os.ReadFile(file)
	if err != nil {
		return nil, err
	}
	key, err := ssh.ParsePrivateKey(buffer)
	if err != nil {
		return nil, fmt.Errorf("parsing private key %q: %w", file, err)
	}
	return ssh.PublicKeys(key), nil
}
// NewSSHTunnel builds a tunnel from a local listener to destination, going
// through the SSH server described by tunnel.
//
// tunnel has the form "[user@]host[:port]"; the port defaults to 22 and the
// user becomes the SSH login name. destination is dialed from the SSH
// server's side. auth is the only authentication method offered. The local
// side is "localhost:0", so Start binds a free port and records it in
// Local.Port.
//
// SECURITY: the returned Config accepts any server host key, which leaves
// the connection open to man-in-the-middle attacks. Callers that need host
// verification should replace Config.HostKeyCallback (for example with one
// from golang.org/x/crypto/ssh/knownhosts) before calling Start.
func NewSSHTunnel(tunnel string, auth ssh.AuthMethod, destination string) *SSHTunnel {
	// Port 0: a random free port will be chosen for us when listening.
	localEndpoint := NewEndpoint("localhost:0")
	server := NewEndpoint(tunnel)
	if server.Port == 0 {
		server.Port = 22
	}
	return &SSHTunnel{
		Config: &ssh.ClientConfig{
			User: server.User,
			Auth: []ssh.AuthMethod{auth},
			// Always accept the server's key; see the SECURITY note above.
			HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		},
		Local:  localEndpoint,
		Server: server,
		Remote: NewEndpoint(destination),
	}
}