-
Notifications
You must be signed in to change notification settings - Fork 0
/
server.go
133 lines (119 loc) · 3.93 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
/*
Copyright 2018 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package proxy
import (
	"crypto/tls"
	"net"
	"net/http"
	"time"

	"github.com/gravitational/teleport"
	"github.com/gravitational/teleport/lib/auth"
	"github.com/gravitational/teleport/lib/limiter"
	"github.com/gravitational/trace"
	log "github.com/sirupsen/logrus"
)
// TLSServerConfig is a configuration for the proxy's TLS server.
// Use CheckAndSetDefaults to validate it before use.
type TLSServerConfig struct {
	// ForwarderConfig is a config of a forwarder. It is embedded, so its
	// fields are promoted, and it is validated first by CheckAndSetDefaults.
	ForwarderConfig
	// TLS is a base TLS configuration. It is required and must have
	// ClientCAs, RootCAs and at least one certificate set.
	// NOTE(review): this is a pointer and is mutated in place (ClientAuth,
	// GetConfigForClient), so it should not be shared between servers.
	TLS *tls.Config
	// LimiterConfig configures the limiter that restricts requests by
	// frequency and by the number of simultaneous connections per client.
	LimiterConfig limiter.LimiterConfig
	// AccessPoint is caching access point. It is required and is used to
	// load the pool of trusted client certificate authorities.
	AccessPoint auth.AccessPoint
	// Component is used for debugging purposes
	Component string
}
// CheckAndSetDefaults validates the configuration and fills in defaults.
//
// It first validates the embedded ForwarderConfig. It then requires:
//   - a TLS config with client CAs, root CAs and at least one certificate;
//   - an AccessPoint.
//
// As a side effect it sets c.TLS.ClientAuth to tls.RequireAndVerifyClientCert.
// The assignment happens right after the nil check on c.TLS, so it is applied
// even when one of the later checks fails.
func (c *TLSServerConfig) CheckAndSetDefaults() error {
	if err := c.ForwarderConfig.CheckAndSetDefaults(); err != nil {
		return trace.Wrap(err)
	}
	tlsConf := c.TLS
	if tlsConf == nil {
		return trace.BadParameter("missing parameter TLS")
	}
	tlsConf.ClientAuth = tls.RequireAndVerifyClientCert
	switch {
	case tlsConf.ClientCAs == nil:
		return trace.BadParameter("missing parameter TLS.ClientCAs")
	case tlsConf.RootCAs == nil:
		return trace.BadParameter("missing parameter TLS.RootCAs")
	case len(tlsConf.Certificates) == 0:
		return trace.BadParameter("missing parameter TLS.Certificates")
	case c.AccessPoint == nil:
		return trace.BadParameter("missing parameter AccessPoint")
	default:
		return nil
	}
}
// TLSServer is the proxy's TLS server: an HTTP server whose connections are
// upgraded to TLS by its own Serve method, using the embedded configuration.
type TLSServer struct {
	// Server is the underlying HTTP server; its methods and fields are
	// promoted (TLSServer.Serve overrides Serve to add the TLS layer).
	*http.Server
	// TLSServerConfig is the configuration this server was created with.
	TLSServerConfig
}
// NewTLSServer returns a new, unstarted TLS server.
//
// It validates cfg, creates the forwarder, and builds the request chain
// limiter -> auth middleware -> forwarder. The returned server re-reads the
// trusted client certificate authorities on every TLS handshake via
// GetConfigForClient.
//
// NOTE(review): cfg.TLS is a pointer and is modified in place
// (ClientAuth and GetConfigForClient are overwritten), so the same
// *tls.Config must not be shared with another server.
func NewTLSServer(cfg TLSServerConfig) (*TLSServer, error) {
	// readHeaderTimeout bounds the time allowed to read request headers.
	// This protects the proxy from clients that hold connections open by
	// sending headers very slowly. The read deadline is reset once headers
	// have been read, so request bodies and long-lived streams are not
	// affected.
	const readHeaderTimeout = 30 * time.Second

	if err := cfg.CheckAndSetDefaults(); err != nil {
		return nil, trace.Wrap(err)
	}
	// lim limits requests by frequency and amount of simultaneous
	// connections per client. It is not named "limiter" so that it does
	// not shadow the imported limiter package.
	lim, err := limiter.NewLimiter(cfg.LimiterConfig)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	fwd, err := NewForwarder(cfg.ForwarderConfig)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	// authMiddleware authenticates requests assuming TLS client
	// authentication, adds the authentication information to the context
	// and passes the request on to the forwarder.
	authMiddleware := &auth.AuthMiddleware{
		AccessPoint:   cfg.AccessPoint,
		AcceptedUsage: []string{teleport.UsageKubeOnly},
	}
	// Chain the forwarder behind the auth middleware...
	authMiddleware.Wrap(fwd)
	// ...and the auth middleware behind the limiter.
	lim.WrapHandle(authMiddleware)
	// CheckAndSetDefaults set RequireAndVerifyClientCert. Relax it here so
	// the TLS handshake verifies a client certificate only if one is sent.
	// NOTE(review): presumably requests without a certificate are rejected
	// later by the auth middleware or the forwarder; confirm.
	cfg.TLS.ClientAuth = tls.VerifyClientCertIfGiven
	server := &TLSServer{
		TLSServerConfig: cfg,
		Server: &http.Server{
			Handler:           lim,
			ReadHeaderTimeout: readHeaderTimeout,
		},
	}
	// Reload the trusted client CAs on every TLS handshake.
	server.TLS.GetConfigForClient = server.GetConfigForClient
	return server, nil
}
// Serve accepts connections on the given TCP listener, wraps them in TLS
// using the server's TLS config, and serves HTTP over the resulting
// connections. It blocks until the underlying http.Server stops serving and
// returns the error reported by it.
func (t *TLSServer) Serve(listener net.Listener) error {
	tlsListener := tls.NewListener(listener, t.TLS)
	return t.Server.Serve(tlsListener)
}
// GetConfigForClient is installed as the tls.Config.GetConfigForClient hook,
// so it runs for every incoming TLS handshake.
//
// It reloads the pool of trusted local and remote certificate authorities
// from the access point and returns a copy of the base TLS config that uses
// that pool as ClientCAs. This picks up changes to the trusted CAs per
// connection.
//
// If the pool cannot be loaded, the error is logged and (nil, nil) is
// returned. crypto/tls then keeps using the base config, with its original
// ClientCAs.
func (t *TLSServer) GetConfigForClient(info *tls.ClientHelloInfo) (*tls.Config, error) {
	caPool, err := auth.ClientCertPool(t.AccessPoint)
	if err != nil {
		log.Errorf("failed to retrieve client pool: %v", trace.DebugReport(err))
		// A nil config makes crypto/tls fall back to the base config.
		return nil, nil
	}
	perConn := t.TLS.Clone()
	perConn.ClientCAs = caPool
	return perConn, nil
}