forked from cockroachdb/cockroach
/
server.go
313 lines (273 loc) · 9.28 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
// Copyright 2015 The Cockroach Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
// implied. See the License for the specific language governing
// permissions and limitations under the License.
//
// Author: Ben Darnell
package pgwire
import (
"crypto/tls"
"fmt"
"io"
"net"
"time"
"golang.org/x/net/context"
"github.com/cockroachdb/cockroach/pkg/base"
"github.com/cockroachdb/cockroach/pkg/sql"
"github.com/cockroachdb/cockroach/pkg/sql/mon"
"github.com/cockroachdb/cockroach/pkg/sql/parser"
"github.com/cockroachdb/cockroach/pkg/util"
"github.com/cockroachdb/cockroach/pkg/util/envutil"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/metric"
"github.com/cockroachdb/cockroach/pkg/util/syncutil"
"github.com/pkg/errors"
)
// ErrSSLRequired is the error message reported to a client that attempts a
// cleartext connection to a server that is not running in insecure mode.
const ErrSSLRequired = "cleartext connections are not permitted"

// ErrDraining is the error message reported to a client that connects while
// the server is draining and therefore refusing new clients.
const ErrDraining = "server is not accepting clients"
// MetaConns is the fully-qualified metric metadata for the counter of
// connections currently being served (ServerMetrics.Conns).
var MetaConns = metric.Metadata{Name: "sql.conns"}

// MetaBytesIn is the fully-qualified metric metadata for the counter of
// bytes received from SQL clients (ServerMetrics.BytesInCount).
var MetaBytesIn = metric.Metadata{Name: "sql.bytesin"}

// MetaBytesOut is the fully-qualified metric metadata for the counter of
// bytes sent to SQL clients (ServerMetrics.BytesOutCount).
var MetaBytesOut = metric.Metadata{Name: "sql.bytesout"}
// Codes carried in the first uint32 of a client's initial message. Each is
// composed of a 16-bit major and a 16-bit minor half.
const (
	// version30 identifies a protocol 3.0 StartupMessage (major 3, minor 0).
	version30 = 3 << 16
	// versionSSL identifies an SSLRequest: the reserved code 1234 in the
	// high half and 5679 in the low half.
	versionSSL = 1234<<16 | 5679
)

// drainMaxWait bounds how long SetDraining(true) keeps retrying while it
// waits for open connections to go away.
const drainMaxWait time.Duration = 10 * time.Second
var (
	// baseSQLMemoryBudget is the number of bytes reserved for each connection
	// once it has authenticated. It is obtained through
	// envutil.EnvOrDefaultInt64 with the COCKROACH_BASE_SQL_MEMORY_BUDGET key,
	// using 2.1 × mon.DefaultPoolAllocationSize as the default.
	baseSQLMemoryBudget = envutil.EnvOrDefaultInt64("COCKROACH_BASE_SQL_MEMORY_BUDGET",
		int64(2.1*float64(mon.DefaultPoolAllocationSize)))

	// connReservationBatchSize is how many connections' worth of
	// baseSQLMemoryBudget the connection monitor reserves in one go.
	connReservationBatchSize = 5
)
// One-byte answers to a client's SSLRequest. The server sends "S" before
// switching the connection to TLS, and "N" when it runs in insecure mode and
// the handshake continues in cleartext.
var (
	sslSupported   = []byte("S")
	sslUnsupported = []byte("N")
)
// Server implements the server side of the PostgreSQL wire protocol.
type Server struct {
	// AmbientCtx is the ambient logging context supplied to MakeServer.
	AmbientCtx log.AmbientContext
	// cfg provides the Insecure flag and the server TLS configuration that
	// ServeConn consults during the handshake.
	cfg *base.Config
	// executor is handed to every v3 connection that ServeConn creates.
	executor *sql.Executor
	// metrics holds the server's counters and memory metrics; Metrics()
	// returns a pointer to it.
	metrics ServerMetrics
	// mu guards draining. SetDraining writes it; IsDraining and ServeConn
	// read it.
	mu struct {
		syncutil.Mutex
		draining bool
	}
	// sqlMemoryPool is the memory monitor for client SQL activity. MakeServer
	// starts it with a standalone budget of maxSQLMem bytes, and it is passed
	// to each v3 connection.
	sqlMemoryPool mon.MemoryMonitor
	// connMonitor is started with sqlMemoryPool as its pool. ServeConn
	// reserves baseSQLMemoryBudget from it for each authenticated connection.
	connMonitor mon.MemoryMonitor
}
// ServerMetrics is the set of metrics for the pgwire server.
type ServerMetrics struct {
	// BytesInCount counts bytes read from clients (metadata MetaBytesIn).
	BytesInCount *metric.Counter
	// BytesOutCount counts bytes written to clients (metadata MetaBytesOut).
	// NOTE(review): presumably updated mostly by the v3 connection, which
	// receives these metrics — confirm.
	BytesOutCount *metric.Counter
	// Conns counts connections being served. ServeConn increments it for the
	// lifetime of each connection accepted while the server is not draining
	// (metadata MetaConns).
	Conns *metric.Counter
	// ConnMemMetrics ("conns") is wired to the server's connMonitor by
	// MakeServer.
	ConnMemMetrics sql.MemoryMetrics
	// SQLMemMetrics ("client") is wired to the server's sqlMemoryPool by
	// MakeServer.
	SQLMemMetrics sql.MemoryMetrics
	// internalMemMetrics is stored as passed to makeServerMetrics.
	// NOTE(review): not read in this file section; presumably consumed
	// elsewhere in the package — confirm.
	internalMemMetrics *sql.MemoryMetrics
}
// makeServerMetrics builds the metric set for a new Server. It creates fresh
// counters for connections and bytes in/out, and memory metrics named
// "conns" and "client". The caller-supplied internal memory metrics are
// stored as-is.
func makeServerMetrics(internalMemMetrics *sql.MemoryMetrics) ServerMetrics {
	var m ServerMetrics
	m.Conns = metric.NewCounter(MetaConns)
	m.BytesInCount = metric.NewCounter(MetaBytesIn)
	m.BytesOutCount = metric.NewCounter(MetaBytesOut)
	m.ConnMemMetrics = sql.MakeMemMetrics("conns")
	m.SQLMemMetrics = sql.MakeMemMetrics("client")
	m.internalMemMetrics = internalMemMetrics
	return m
}
var (
	// noteworthySQLMemoryUsageBytes is the minimum size the client SQL pool
	// must track before it starts explicitly logging overall usage growth.
	// It is read via envutil.EnvOrDefaultInt64 with the
	// COCKROACH_NOTEWORTHY_SQL_MEMORY_USAGE key and defaults to 100 MiB.
	noteworthySQLMemoryUsageBytes = envutil.EnvOrDefaultInt64("COCKROACH_NOTEWORTHY_SQL_MEMORY_USAGE", 100<<20)

	// noteworthyConnMemoryUsageBytes is the minimum size the connection
	// monitor must track before it starts explicitly logging overall usage
	// growth. It is read via envutil.EnvOrDefaultInt64 with the
	// COCKROACH_NOTEWORTHY_CONN_MEMORY_USAGE key and defaults to 2 MiB.
	noteworthyConnMemoryUsageBytes = envutil.EnvOrDefaultInt64("COCKROACH_NOTEWORTHY_CONN_MEMORY_USAGE", 2<<20)
)
// MakeServer creates a Server and starts its two memory monitors.
//
// The "sql" monitor (sqlMemoryPool) is given a standalone budget of maxSQLMem
// bytes. The "conn" monitor (connMonitor) draws from sqlMemoryPool in chunks
// of connReservationBatchSize × baseSQLMemoryBudget bytes. internalMemMetrics
// is recorded in the server's metrics.
func MakeServer(
	ambientCtx log.AmbientContext,
	cfg *base.Config,
	executor *sql.Executor,
	internalMemMetrics *sql.MemoryMetrics,
	maxSQLMem int64,
) *Server {
	s := &Server{
		AmbientCtx: ambientCtx,
		cfg:        cfg,
		executor:   executor,
		metrics:    makeServerMetrics(internalMemMetrics),
	}

	sqlMetrics := &s.metrics.SQLMemMetrics
	s.sqlMemoryPool = mon.MakeMonitor("sql",
		sqlMetrics.CurBytesCount,
		sqlMetrics.MaxBytesHist,
		0, noteworthySQLMemoryUsageBytes)
	s.sqlMemoryPool.Start(context.Background(), nil, mon.MakeStandaloneBudget(maxSQLMem))

	// Reserve memory for several connections at once so that the shared pool
	// is not hit on every new connection.
	connMetrics := &s.metrics.ConnMemMetrics
	connChunk := int64(connReservationBatchSize) * baseSQLMemoryBudget
	s.connMonitor = mon.MakeMonitor("conn",
		connMetrics.CurBytesCount,
		connMetrics.MaxBytesHist,
		connChunk, noteworthyConnMemoryUsageBytes)
	s.connMonitor.Start(context.Background(), &s.sqlMemoryPool, mon.BoundAccount{})

	return s
}
// Match reports whether rd appears to be a Postgres connection. It reads one
// untyped message from rd and accepts it only if the message's leading uint32
// is the protocol 3.0 code or the SSLRequest code. Any read or decode failure
// yields false.
func Match(rd io.Reader) bool {
	var buf readBuffer
	if _, err := buf.readUntypedMsg(rd); err != nil {
		return false
	}
	version, err := buf.getUint32()
	if err != nil {
		return false
	}
	switch version {
	case version30, versionSSL:
		return true
	default:
		return false
	}
}
// IsDraining reports whether the server is currently refusing new client
// connections. The flag is read while holding s.mu.
func (s *Server) IsDraining() bool {
	s.mu.Lock()
	draining := s.mu.draining
	s.mu.Unlock()
	return draining
}
// Metrics returns a pointer to the server's own metrics struct (not a copy).
func (s *Server) Metrics() *ServerMetrics {
	m := &s.metrics
	return m
}
// SetDraining switches the server into or out of draining mode.
//
// With drain=true, new connections are no longer served, and the call
// retries for up to drainMaxWait until no counted connections remain open.
// If that does not happen in time, an error is returned and the server stays
// in draining mode; connections that are still open are left alone.
//
// With drain=false, the server goes back to accepting connections and the
// call returns nil right away.
func (s *Server) SetDraining(drain bool) error {
	func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		s.mu.draining = drain
	}()

	if drain {
		connsDrained := func() error {
			// TODO(tschottdorf): Do more plumbing to actively disrupt
			// connections; see #6283. There isn't much of a point until
			// we know what load-balanced clients like to see (#6295).
			open := s.metrics.Conns.Count()
			if open == 0 {
				return nil
			}
			return fmt.Errorf("timed out waiting for %d open connections to drain", open)
		}
		return util.RetryForDuration(drainMaxWait, connsDrained)
	}
	return nil
}
// ServeConn serves a single connection, driving the handshake process
// and delegating to the appropriate connection type.
//
// The first message read from conn must carry either the protocol 3.0 code
// or an SSLRequest. An SSLRequest is answered with a single byte: 'N' when
// the server is insecure, or 'S' when it is secure, in which case conn is
// then wrapped in TLS. The startup message that follows is then read.
//
// For a 3.0 startup message, a v3 connection is created. That connection
// either reports an error to the client (malformed options, a cleartext
// connection to a secure server, a draining server, or failed
// authentication) or serves the session once memory has been reserved for
// it. Any other protocol version is rejected with an error.
func (s *Server) ServeConn(ctx context.Context, conn net.Conn) error {
	draining := s.IsDraining()

	// If the Server is draining, we will use the connection only to send an
	// error, so we don't count it in the stats. This makes sense since
	// SetDraining() waits for that number to drop to zero, so we don't want
	// it to oscillate unnecessarily.
	if !draining {
		s.metrics.Conns.Inc(1)
		defer s.metrics.Conns.Dec(1)
	}

	var buf readBuffer
	n, err := buf.readUntypedMsg(conn)
	if err != nil {
		return err
	}
	s.metrics.BytesInCount.Inc(int64(n))
	version, err := buf.getUint32()
	if err != nil {
		return err
	}

	errSSLRequired := false
	if version == versionSSL {
		if len(buf.msg) > 0 {
			return errors.Errorf("unexpected data after SSLRequest: %q", buf.msg)
		}

		reply := sslSupported
		if s.cfg.Insecure {
			reply = sslUnsupported
		}
		written, writeErr := conn.Write(reply)
		// This reply bypasses the v3 connection, so account for it here,
		// mirroring the BytesInCount accounting of the handshake messages.
		s.metrics.BytesOutCount.Inc(int64(written))
		if writeErr != nil {
			return writeErr
		}
		if !s.cfg.Insecure {
			tlsConfig, err := s.cfg.GetServerTLSConfig()
			if err != nil {
				return err
			}
			conn = tls.Server(conn, tlsConfig)
		}

		// The client follows up with its actual startup message, over TLS
		// when the upgrade was accepted.
		n, err = buf.readUntypedMsg(conn)
		if err != nil {
			return err
		}
		s.metrics.BytesInCount.Inc(int64(n))
		version, err = buf.getUint32()
		if err != nil {
			return err
		}
	} else if !s.cfg.Insecure {
		errSSLRequired = true
	}

	if version != version30 {
		return errors.Errorf("unknown protocol version %d", version)
	}

	// We make a connection before anything. If there is an error parsing the
	// connection arguments, the connection will only be used to send a
	// report of that error.
	v3conn := makeV3Conn(conn, &s.metrics, &s.sqlMemoryPool, s.executor)
	defer v3conn.finish(ctx)

	if v3conn.sessionArgs, err = parseOptions(buf.msg); err != nil {
		return v3conn.sendInternalError(err.Error())
	}
	if errSSLRequired {
		return v3conn.sendInternalError(ErrSSLRequired)
	}
	if draining {
		// TODO(tschottdorf): Likely not handled gracefully by clients.
		// See #6295.
		return v3conn.sendInternalError(ErrDraining)
	}

	v3conn.sessionArgs.User = parser.Name(v3conn.sessionArgs.User).Normalize()
	if err := v3conn.handleAuthentication(ctx, s.cfg.Insecure); err != nil {
		return v3conn.sendInternalError(err.Error())
	}

	// Reserve some memory for this connection using the server's connection
	// monitor. This reduces pressure on the shared pool, because the monitor
	// allocates from the shared pool in chunks that are larger than
	// baseSQLMemoryBudget.
	//
	// The reservation is made only after authentication has succeeded, to
	// prevent a DoS attack in which many open-but-unauthenticated connections
	// exhaust the memory available to connections that are already open.
	acc := s.connMonitor.MakeBoundAccount(ctx)
	if err := acc.Grow(baseSQLMemoryBudget); err != nil {
		return errors.Errorf("unable to pre-allocate %d bytes for this connection: %v",
			baseSQLMemoryBudget, err)
	}
	return v3conn.serve(ctx, acc)
}