-
Notifications
You must be signed in to change notification settings - Fork 0
/
localsite.go
218 lines (182 loc) · 6.21 KB
/
localsite.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
/*
Copyright 2016 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package reversetunnel
import (
"fmt"
"net"
"sync"
"time"
"golang.org/x/crypto/ssh/agent"
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/forward"
"github.com/gravitational/trace"
log "github.com/sirupsen/logrus"
)
// newlocalSite creates the site that represents the local cluster named
// domainName. It sets up a caching access point for the cluster (tagged with
// "reverse" and the cluster name) and a per-site host certificate cache for
// the forwarding server, and attaches a logger labelled with the reverse
// tunnel server component and the cluster name.
func newlocalSite(srv *server, domainName string, client auth.ClientI) (*localSite, error) {
	ap, err := srv.newAccessPoint(client, []string{"reverse", domainName})
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// The host certificate cache is built per site rather than once in
	// reversetunnel.server and shared, so that the host certificates it
	// produces are signed by the correct certificate authority.
	hostCertCache, err := NewHostCertificateCache(srv.Config.KeyGen, client)
	if err != nil {
		return nil, trace.Wrap(err)
	}

	entry := log.WithFields(log.Fields{
		trace.Component:       teleport.ComponentReverseTunnelServer,
		trace.ComponentFields: map[string]string{"cluster": domainName},
	})

	site := &localSite{
		srv:              srv,
		client:           client,
		accessPoint:      ap,
		certificateCache: hostCertCache,
		domainName:       domainName,
		log:              entry,
	}
	return site, nil
}
// localSite gives direct access to the servers of the local cluster without
// going through a reverse tunnel, using standard SSH.
//
// It implements the RemoteSite interface.
type localSite struct {
	// NOTE(review): the embedded mutex is not used by the methods visible in
	// this chunk; verify whether other files rely on it before removing.
	sync.Mutex

	// NOTE(review): authServer, connections, lastUsed and lastActive are not
	// read or written by the code in this chunk — possibly vestigial; confirm
	// against the rest of the package.
	authServer string
	// log is the site's logger, tagged with the reverse tunnel server
	// component and the cluster name (see newlocalSite).
	log *log.Entry
	// domainName is the name of the local cluster; GetName and String report it.
	domainName  string
	connections []*remoteConn
	lastUsed    int
	lastActive  time.Time
	// srv is the reverse tunnel server this site was created by; its Config
	// supplies the key generator, SSH cipher/KEX/MAC settings and data dir.
	srv *server
	// client provides access to the Auth Server API of the local cluster.
	client auth.ClientI
	// accessPoint provides access to a cached subset of the Auth Server API of
	// the local cluster.
	accessPoint auth.AccessPoint
	// certificateCache caches host certificates for the forwarding server.
	certificateCache *certificateCache
}
// CachingAccessPoint returns the caching access point to the local cluster's
// Auth Server API that was created with the site. The error is always nil.
func (s *localSite) CachingAccessPoint() (ap auth.AccessPoint, err error) {
	ap = s.accessPoint
	return ap, nil
}
// GetClient returns the (non-caching) client for the local cluster's Auth
// Server API. The error is always nil.
func (s *localSite) GetClient() (clt auth.ClientI, err error) {
	clt = s.client
	return clt, nil
}
// String describes the site as "local(<cluster name>)".
func (s *localSite) String() string {
	return "local(" + s.domainName + ")"
}
// GetStatus reports the site's connection status. The local site is accessed
// directly rather than through a tunnel, so it always reports
// teleport.RemoteClusterStatusOnline.
func (s *localSite) GetStatus() string {
	status := teleport.RemoteClusterStatusOnline
	return status
}
// GetName returns the name of the local cluster this site represents.
func (s *localSite) GetName() string {
	name := s.domainName
	return name
}
// GetLastConnected always reports the current time. Presumably this is
// because the local site is not reached over a tunnel, so there is no
// connection timestamp to track.
func (s *localSite) GetLastConnected() time.Time {
	now := time.Now()
	return now
}
// DialAuthServer opens a TCP connection to one of the local cluster's auth
// servers, as listed by the Auth Server API. Servers are tried in the order
// they are returned, and the first successful connection wins.
//
// Errors from listing the auth servers are returned wrapped. If the list is
// empty, or every dial fails, a trace.ConnectionProblem error is returned. In
// the failure case it wraps the error from the last attempted dial.
func (s *localSite) DialAuthServer() (net.Conn, error) {
	authServers, err := s.client.GetAuthServers()
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// Without this check the loop below would not run and the caller would get
	// a ConnectionProblem with a nil cause, hiding why nothing was dialed.
	if len(authServers) == 0 {
		return nil, trace.ConnectionProblem(nil, "unable to connect to auth server: no auth servers found")
	}

	var lastErr error
	for _, authServer := range authServers {
		addr := authServer.GetAddr()
		conn, dialErr := net.DialTimeout("tcp", addr, defaults.DefaultDialTimeout)
		if dialErr == nil {
			return conn, nil
		}
		// Record each failure so one unreachable server does not vanish silently.
		s.log.Debugf("Failed to dial auth server %v: %v.", addr, dialErr)
		lastErr = dialErr
	}

	return nil, trace.ConnectionProblem(lastErr, "unable to connect to auth server")
}
// Dial connects to the address to on behalf of from.
//
// The cluster configuration decides how. Unless sessions are recorded at the
// proxy (services.RecordAtProxy), this is a plain direct dial. In proxy
// recording mode the connection goes through an in-memory forwarding server
// that is given userAgent. A nil userAgent is then rejected with a
// trace.BadParameter error.
func (s *localSite) Dial(from net.Addr, to net.Addr, userAgent agent.Agent) (net.Conn, error) {
	clusterConfig, err := s.accessPoint.GetClusterConfig()
	if err != nil {
		return nil, trace.Wrap(err)
	}

	if clusterConfig.GetSessionRecording() != services.RecordAtProxy {
		return s.dial(from, to)
	}

	// Proxy recording mode: the forwarding server needs the user's agent.
	if userAgent == nil {
		return nil, trace.BadParameter("user agent missing")
	}
	return s.dialWithAgent(from, to, userAgent)
}
// dial connects directly to to, using to's own network type and the default
// dial timeout. from is only used in the debug log line.
func (s *localSite) dial(from net.Addr, to net.Addr) (net.Conn, error) {
	s.log.Debugf("Dialing from %v to %v", from, to)
	network, address := to.Network(), to.String()
	return net.DialTimeout(network, address, defaults.DefaultDialTimeout)
}
// dialWithAgent connects to to through an in-memory SSH forwarding server.
//
// It fetches a host certificate for to from the site's certificate cache and
// dials the target directly. It then builds a forward.Server around that
// connection, configured with userAgent, the site's auth client and the
// tunnel server's cipher/KEX/MAC settings and data dir. The returned
// connection is the client side of that forwarding server.
//
// On any failure after the target has been dialed, the target connection is
// closed so the socket is not leaked.
func (s *localSite) dialWithAgent(from net.Addr, to net.Addr, userAgent agent.Agent) (net.Conn, error) {
	s.log.Debugf("Dialing with an agent from %v to %v", from, to)

	// get a host certificate for the forwarding node from the cache
	hostCertificate, err := s.certificateCache.GetHostCertificate(to.String())
	if err != nil {
		return nil, trace.Wrap(err)
	}

	// get a net.Conn to the target server
	targetConn, err := net.DialTimeout(to.Network(), to.String(), defaults.DefaultDialTimeout)
	if err != nil {
		return nil, err
	}

	// create a forwarding server that serves a single ssh connection on it. we
	// don't need to close this server it will close and release all resources
	// once conn is closed.
	serverConfig := forward.ServerConfig{
		AuthClient:      s.client,
		UserAgent:       userAgent,
		TargetConn:      targetConn,
		SrcAddr:         from,
		DstAddr:         to,
		HostCertificate: hostCertificate,
		Ciphers:         s.srv.Config.Ciphers,
		KEXAlgorithms:   s.srv.Config.KEXAlgorithms,
		MACAlgorithms:   s.srv.Config.MACAlgorithms,
		DataDir:         s.srv.Config.DataDir,
	}
	remoteServer, err := forward.New(serverConfig)
	if err != nil {
		// No forwarding server took ownership of targetConn, so close it here
		// to avoid leaking the socket. NOTE(review): assumes forward.New does
		// not close TargetConn itself on failure; a second Close would only
		// return an error, which is deliberately ignored.
		targetConn.Close()
		return nil, trace.Wrap(err)
	}
	go remoteServer.Serve()

	// return a connection to the forwarding server
	conn, err := remoteServer.Dial()
	if err != nil {
		// Nobody will ever close the client side, so tear down the target
		// connection ourselves. NOTE(review): this is expected to make the
		// Serve goroutine above wind down; confirm against forward.Server.
		// The Close error is ignored: we are already returning a failure.
		targetConn.Close()
		return nil, trace.Wrap(err)
	}
	return conn, nil
}
// findServer returns the first server in servers that matches addr.
//
// A server matches when all of the following hold:
//   - it has a non-empty hostname;
//   - its address has a non-empty port;
//   - addr equals either "<hostname>:<port>" or the server's full address.
//
// Servers whose address cannot be split into host and port are logged as a
// warning and skipped. If nothing matches, a trace.NotFound error is returned.
func findServer(addr string, servers []services.Server) (services.Server, error) {
	for _, candidate := range servers {
		_, port, err := net.SplitHostPort(candidate.GetAddr())
		if err != nil {
			log.Warningf("server %v(%v) has incorrect address format (%v)",
				candidate.GetAddr(), candidate.GetHostname(), err.Error())
			continue
		}

		hostname := candidate.GetHostname()
		if hostname == "" || port == "" {
			continue
		}
		if addr == hostname+":"+port || addr == candidate.GetAddr() {
			return candidate, nil
		}
	}
	return nil, trace.NotFound("server %v is unknown", addr)
}