/
server_handler.go
166 lines (162 loc) · 4.53 KB
/
server_handler.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
package chserver
import (
"net/http"
"strings"
"sync/atomic"
"time"
chshare "github.com/qbee-io/tcpforwarder/share"
"github.com/qbee-io/tcpforwarder/share/cnet"
"github.com/qbee-io/tcpforwarder/share/settings"
"github.com/qbee-io/tcpforwarder/share/tunnel"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
)
// handleClientHandler is the main HTTP handler for the tcpforwarder server.
//
// Requests are dispatched in this order:
//  1. A websocket upgrade whose Sec-WebSocket-Protocol starts with
//     "tcpforwarder-" is handed to handleWebsocket when the protocol equals
//     chshare.ProtocolVersion; any other tcpforwarder version is logged and
//     the request falls through to the steps below.
//  2. If a reverse proxy is configured, it serves every remaining request.
//  3. Otherwise "/health" and "/version" are answered directly and any other
//     path receives a 404 "Not found".
func (s *Server) handleClientHandler(w http.ResponseWriter, r *http.Request) {
	// websocket upgrade AND the tcpforwarder protocol prefix
	upgrade := r.Header.Get("Upgrade")
	protocol := r.Header.Get("Sec-WebSocket-Protocol")
	if strings.EqualFold(upgrade, "websocket") && strings.HasPrefix(protocol, "tcpforwarder-") {
		if protocol == chshare.ProtocolVersion {
			s.handleWebsocket(w, r)
			return
		}
		// print into server logs and silently fall through
		s.Infof("ignored client connection using protocol '%s', expected '%s'",
			protocol, chshare.ProtocolVersion)
	}
	// proxy target was provided
	if s.reverseProxy != nil {
		s.reverseProxy.ServeHTTP(w, r)
		return
	}
	// No proxy defined: provide access to health/version checks. Match on the
	// path only; URL.String() would include the query string (and scheme/host
	// for absolute-form request URIs), turning e.g. "/health?probe=1" into a 404.
	switch r.URL.Path {
	case "/health":
		w.Write([]byte("OK\n"))
		return
	case "/version":
		w.Write([]byte(chshare.BuildVersion))
		return
	}
	// unknown path
	w.WriteHeader(http.StatusNotFound)
	w.Write([]byte("Not found"))
}
// handleWebsocket upgrades req to a websocket, runs the SSH server handshake
// over it, and then drives one tcpforwarder session:
//
//   - each connection gets a sequential id and a "session#<id>" logger;
//   - when users are configured, the user stored in s.sessions for this SSH
//     session id is fetched and removed from the map;
//   - the client must send a "config" request within the duration that
//     settings.EnvDuration resolves for "CONFIG_TIMEOUT" (10s fallback); its
//     remotes are checked against the user's access list and the server's
//     reverse-forwarding settings, and the request is answered with success
//     or with the reason for rejection;
//   - on success a tunnel is bound to the SSH connection (and to any reversed
//     remotes) and the handler blocks until both bindings return.
//
// The SSH connection is closed before the handler returns on every path.
func (s *Server) handleWebsocket(w http.ResponseWriter, req *http.Request) {
	id := atomic.AddInt32(&s.sessCount, 1)
	l := s.Fork("session#%d", id)
	wsConn, err := upgrader.Upgrade(w, req, nil)
	if err != nil {
		l.Debugf("Failed to upgrade (%s)", err)
		return
	}
	conn := cnet.NewWebSocketConn(wsConn)
	// perform SSH handshake on net.Conn
	l.Debugf("Handshaking with %s...", req.RemoteAddr)
	sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.sshConfig)
	if err != nil {
		// ssh.NewServerConn closes conn itself when the handshake fails.
		l.Debugf("Failed to handshake (%s)", err)
		return
	}
	// Every early return below (timeout, rejected config, missing session)
	// must release the connection; otherwise a rejected client that simply
	// stays connected keeps the hijacked socket and SSH state alive. If the
	// tunnel has already closed it, the extra Close only returns an error,
	// which is ignored here.
	defer sshConn.Close()

	// pull the user from the session map
	var user *settings.User
	if s.users.Len() > 0 {
		sid := string(sshConn.SessionID())
		u, ok := s.sessions.Get(sid)
		if !ok {
			// Users are configured but no session was recorded for this
			// connection. NOTE(review): presumably reachable if the user list
			// changes between authentication and this check. Reject instead
			// of panicking: net/http would merely recover the panic and log a
			// stack trace.
			l.Infof("Rejecting connection: no authenticated session found")
			return
		}
		user = u
		s.sessions.Del(sid)
	}

	// tcpforwarder server handshake (reverse of client handshake):
	// wait for the configuration request, with timeout.
	l.Debugf("Verifying configuration")
	var cfgReq *ssh.Request
	select {
	case r, ok := <-reqs:
		if !ok {
			// reqs is closed when the SSH connection terminates; without this
			// check the nil request would be dereferenced below.
			l.Debugf("Connection closed before configuration was received")
			return
		}
		cfgReq = r
	case <-time.After(settings.EnvDuration("CONFIG_TIMEOUT", 10*time.Second)):
		l.Debugf("Timeout waiting for configuration")
		return
	}
	// failed logs err and sends it back to the client as a negative reply.
	failed := func(err error) {
		l.Debugf("Failed: %s", err)
		cfgReq.Reply(false, []byte(err.Error()))
	}
	if cfgReq.Type != "config" {
		failed(s.Errorf("expecting config request"))
		return
	}
	c, err := settings.DecodeConfig(cfgReq.Payload)
	if err != nil {
		l.Debugf("Failed to decode config (%s)", err)
		failed(s.Errorf("invalid config"))
		return
	}
	// print if client and server versions don't match
	if c.Version != chshare.BuildVersion {
		v := c.Version
		if v == "" {
			v = "<unknown>"
		}
		l.Infof("Client version (%s) differs from server version (%s)",
			v, chshare.BuildVersion)
	}
	// validate remotes
	for _, remote := range c.Remotes {
		// if a user is authenticated, ensure they have access to the
		// desired remote
		if user != nil {
			addr := remote.UserAddr()
			if !user.HasAccess(addr) {
				failed(s.Errorf("access to '%s' denied", addr))
				return
			}
		}
		// confirm reverse tunnels are allowed
		if remote.Reverse && !s.config.Reverse {
			l.Debugf("Denied reverse port forwarding request, please enable --reverse")
			failed(s.Errorf("Reverse port forwarding not enabled on server"))
			return
		}
		// confirm reverse tunnel is available
		if remote.Reverse && !remote.CanListen() {
			failed(s.Errorf("Server cannot listen on %s", remote.String()))
			return
		}
	}
	// successfully validated config!
	cfgReq.Reply(true, nil)

	// tunnel per ssh connection
	tun := tunnel.New(tunnel.Config{
		Logger:    l,
		Inbound:   s.config.Reverse,
		Outbound:  true, // server always accepts outbound
		Socks:     s.config.Socks5,
		KeepAlive: s.config.KeepAlive,
	})
	// bind; ctx is cancelled as soon as either binding returns an error
	eg, ctx := errgroup.WithContext(req.Context())
	eg.Go(func() error {
		// connected, hand over the ssh connection for the tunnel to use, and block
		return tun.BindSSH(ctx, sshConn, reqs, chans)
	})
	eg.Go(func() error {
		// connected, set up reversed remotes (if any)
		serverInbound := c.Remotes.Reversed(true)
		if len(serverInbound) == 0 {
			return nil
		}
		// block
		return tun.BindRemotes(ctx, serverInbound)
	})
	err = eg.Wait()
	if err != nil && !strings.HasSuffix(err.Error(), "EOF") {
		l.Debugf("Closed connection (%s)", err)
	} else {
		l.Debugf("Closed connection")
	}
}