-
Notifications
You must be signed in to change notification settings - Fork 117
/
server.go
370 lines (323 loc) · 12.5 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
package server
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"time"
grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
grpc_validator "github.com/grpc-ecosystem/go-grpc-middleware/validator"
gateway "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/prometheus/client_golang/prometheus/promhttp"
runtimev1 "github.com/rilldata/rill/proto/gen/rill/runtime/v1"
"github.com/rilldata/rill/runtime"
"github.com/rilldata/rill/runtime/pkg/activity"
"github.com/rilldata/rill/runtime/pkg/graceful"
"github.com/rilldata/rill/runtime/pkg/httputil"
"github.com/rilldata/rill/runtime/pkg/middleware"
"github.com/rilldata/rill/runtime/pkg/observability"
"github.com/rilldata/rill/runtime/pkg/ratelimit"
"github.com/rilldata/rill/runtime/pkg/securetoken"
"github.com/rilldata/rill/runtime/queries"
"github.com/rilldata/rill/runtime/server/auth"
"github.com/rs/cors"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"go.opentelemetry.io/otel"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
)
var (
	// tracer is this package's OpenTelemetry tracer, named after the package import path.
	tracer = otel.Tracer("github.com/rilldata/rill/runtime/server")

	// ErrForbidden is the gRPC error returned to clients in place of queries.ErrForbidden (see mapGRPCError).
	// NOTE(review): it carries codes.Unauthenticated rather than codes.PermissionDenied even though it reports
	// a permission failure; confirm whether clients depend on this code before changing it.
	ErrForbidden = status.Error(codes.Unauthenticated, "action not allowed")
)
// Options configures a Server.
type Options struct {
	// HTTPPort is the port ServeHTTP listens on.
	HTTPPort int

	// GRPCPort is the port ServeGRPC listens on. The REST gateway built by HTTPHandler also dials
	// localhost on this port to reach the gRPC services.
	GRPCPort int

	// AllowedOrigins lists the CORS origins accepted by the HTTP handler. If it contains "*", every
	// origin is accepted and the requester's origin is returned instead of "*".
	AllowedOrigins []string

	// ServePrometheus exposes Prometheus metrics at /metrics on the HTTP handler when true.
	ServePrometheus bool

	// SessionKeyPairs are used to build the securetoken codec for ephemeral tokens. When empty,
	// NewServer generates a random codec for the duration of the process.
	SessionKeyPairs [][]byte

	// AuthEnable makes NewServer open an auth audience from AuthIssuerURL and AuthAudienceURL.
	// When false the server's audience stays nil; presumably the auth interceptors then skip
	// token validation — TODO confirm in the auth package.
	AuthEnable bool

	// AuthIssuerURL and AuthAudienceURL are passed to auth.OpenAudience when AuthEnable is true.
	AuthIssuerURL   string
	AuthAudienceURL string

	// TLSCertPath and TLSKeyPath are passed to graceful.ServeHTTP for the HTTP server only;
	// ServeGRPC does not use them.
	TLSCertPath string
	TLSKeyPath  string
}
// Server implements the runtime's RuntimeService, QueryService and ConnectorService gRPC APIs.
// ServeGRPC serves them over gRPC; ServeHTTP/HTTPHandler expose them through a REST gateway.
type Server struct {
	// The generated Unsafe*ServiceServer interfaces are embedded rather than Unimplemented* structs.
	// Per the protoc-gen-go-grpc docs this opts out of forward compatibility, i.e. new RPCs must be
	// implemented before the package compiles — confirm this is still the intent.
	runtimev1.UnsafeRuntimeServiceServer
	runtimev1.UnsafeQueryServiceServer
	runtimev1.UnsafeConnectorServiceServer

	runtime *runtime.Runtime
	opts    *Options
	logger  *zap.Logger

	// aud validates auth tokens; it is nil unless Options.AuthEnable is set (see NewServer).
	aud *auth.Audience

	// codec encodes/decodes ephemeral tokens; built in NewServer from Options.SessionKeyPairs.
	codec *securetoken.Codec

	// limiter enforces request rate limits (see checkRateLimit).
	limiter ratelimit.Limiter

	// activity is handed to the activity-tracking gRPC interceptors installed by ServeGRPC.
	activity *activity.Client
}
// Compile-time checks that *Server satisfies every gRPC service interface it is registered for in ServeGRPC.
var _ runtimev1.RuntimeServiceServer = (*Server)(nil)
var _ runtimev1.QueryServiceServer = (*Server)(nil)
var _ runtimev1.ConnectorServiceServer = (*Server)(nil)
// NewServer creates a new runtime server.
// The provided ctx is used for the lifetime of the server for background refresh of the JWKS that is used to validate auth tokens.
// When opts.AuthEnable is false, no audience is opened and the returned server has a nil audience.
func NewServer(ctx context.Context, opts *Options, rt *runtime.Runtime, logger *zap.Logger, limiter ratelimit.Limiter, activityClient *activity.Client) (*Server, error) {
	srv := &Server{
		runtime:  rt,
		opts:     opts,
		logger:   logger,
		limiter:  limiter,
		activity: activityClient,
	}

	// The runtime doesn't actually set cookies, but we use securecookie to encode/decode ephemeral tokens.
	// Without configured session key pairs, a random codec is generated for the duration of the process.
	if len(opts.SessionKeyPairs) > 0 {
		srv.codec = securetoken.NewCodec(opts.SessionKeyPairs)
	} else {
		srv.codec = securetoken.NewRandom()
	}

	if !opts.AuthEnable {
		return srv, nil
	}

	aud, err := auth.OpenAudience(ctx, logger, opts.AuthIssuerURL, opts.AuthAudienceURL)
	if err != nil {
		return nil, err
	}
	srv.aud = aud
	return srv, nil
}
// Close should be called when the server is done. It closes the auth audience opened by NewServer,
// if any, and always returns nil.
func (s *Server) Close() error {
	// TODO: This should probably trigger a server shutdown
	if aud := s.aud; aud != nil {
		aud.Close()
	}
	return nil
}
// Ping implements RuntimeService. It reports the server's current time; the version is not populated yet.
func (s *Server) Ping(ctx context.Context, req *runtimev1.PingRequest) (*runtimev1.PingResponse, error) {
	now := time.Now()
	return &runtimev1.PingResponse{
		Version: "", // TODO: Return version
		Time:    timestamppb.New(now),
	}, nil
}
// ServeGRPC starts the gRPC server on Options.GRPCPort with the runtime's interceptor chains and
// registers the RuntimeService, QueryService and ConnectorService implementations.
// It returns the result of graceful.ServeGRPC.
func (s *Server) ServeGRPC(ctx context.Context) error {
	// Interceptors are chained in slice order, the first being the outermost. Auth runs before the
	// rate-limit check, which reads the caller's claims (see checkRateLimit).
	streamChain := []grpc.StreamServerInterceptor{
		middleware.TimeoutStreamServerInterceptor(timeoutSelector),
		observability.LoggingStreamServerInterceptor(s.logger),
		grpc_validator.StreamServerInterceptor(),
		auth.StreamServerInterceptor(s.aud),
		middleware.ActivityStreamServerInterceptor(s.activity),
		errorMappingStreamServerInterceptor(),
		grpc_auth.StreamServerInterceptor(s.checkRateLimit),
	}
	unaryChain := []grpc.UnaryServerInterceptor{
		middleware.TimeoutUnaryServerInterceptor(timeoutSelector),
		observability.LoggingUnaryServerInterceptor(s.logger),
		grpc_validator.UnaryServerInterceptor(),
		auth.UnaryServerInterceptor(s.aud),
		middleware.ActivityUnaryServerInterceptor(s.activity),
		errorMappingUnaryServerInterceptor(),
		grpc_auth.UnaryServerInterceptor(s.checkRateLimit),
	}

	grpcServer := grpc.NewServer(
		grpc.ChainStreamInterceptor(streamChain...),
		grpc.ChainUnaryInterceptor(unaryChain...),
		grpc.StatsHandler(otelgrpc.NewServerHandler()),
	)

	runtimev1.RegisterRuntimeServiceServer(grpcServer, s)
	runtimev1.RegisterQueryServiceServer(grpcServer, s)
	runtimev1.RegisterConnectorServiceServer(grpcServer, s)

	s.logger.Sugar().Infof("serving runtime gRPC on port:%v", s.opts.GRPCPort)
	return graceful.ServeGRPC(ctx, grpcServer, s.opts.GRPCPort)
}
// ServeHTTP starts the HTTP server on Options.HTTPPort, serving the handler built by HTTPHandler.
// registerAdditionalHandlers is forwarded to HTTPHandler (it may be nil). Options.TLSCertPath and
// Options.TLSKeyPath are passed through to graceful.ServeHTTP.
func (s *Server) ServeHTTP(ctx context.Context, registerAdditionalHandlers func(mux *http.ServeMux)) error {
	handler, err := s.HTTPHandler(ctx, registerAdditionalHandlers)
	if err != nil {
		return err
	}

	server := &http.Server{
		Handler: handler,
		// Bound how long a client may take to send request headers so slow clients cannot hold
		// connections open indefinitely (Slowloris). Read/write timeouts are intentionally not set:
		// proxied watch streams and downloads can legitimately run for a long time.
		ReadHeaderTimeout: 30 * time.Second,
	}

	s.logger.Sugar().Infof("serving HTTP on port:%v", s.opts.HTTPPort)
	options := graceful.ServeOptions{
		Port:     s.opts.HTTPPort,
		CertPath: s.opts.TLSCertPath,
		KeyPath:  s.opts.TLSKeyPath,
	}
	return graceful.ServeHTTP(ctx, server, options)
}
// HTTPHandler builds the HTTP handler for the runtime server. It contains:
//   - a grpc-gateway REST mux for RuntimeService, QueryService and ConnectorService, mounted under
//     /v1/, which dials the local gRPC server on Options.GRPCPort without transport security;
//   - a REST-only multipart file upload endpoint;
//   - handlers for query export downloads, dynamic (resolver-backed) APIs and component data;
//   - Prometheus metrics at /metrics when Options.ServePrometheus is set;
//   - any handlers added by registerAdditionalHandlers (may be nil).
//
// The result is wrapped in CORS middleware configured from Options.AllowedOrigins.
// Registration failures are returned as errors.
func (s *Server) HTTPHandler(ctx context.Context, registerAdditionalHandlers func(mux *http.ServeMux)) (http.Handler, error) {
	// Create REST gateway
	gwMux := gateway.NewServeMux(gateway.WithErrorHandler(HTTPErrorHandler))
	opts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}
	grpcAddress := fmt.Sprintf("localhost:%d", s.opts.GRPCPort)
	err := runtimev1.RegisterRuntimeServiceHandlerFromEndpoint(ctx, gwMux, grpcAddress, opts)
	if err != nil {
		return nil, err
	}
	err = runtimev1.RegisterQueryServiceHandlerFromEndpoint(ctx, gwMux, grpcAddress, opts)
	if err != nil {
		return nil, err
	}
	err = runtimev1.RegisterConnectorServiceHandlerFromEndpoint(ctx, gwMux, grpcAddress, opts)
	if err != nil {
		return nil, err
	}

	// One-off REST-only path for multipart file upload
	// NOTE: It's local only and we should deprecate it in favor of a cloud-friendly alternative.
	err = gwMux.HandlePath("POST", "/v1/instances/{instance_id}/files/upload/-/{path=**}", auth.GatewayMiddleware(s.aud, s.UploadMultipartFile))
	if err != nil {
		// Report through the error return like the registrations above rather than panicking.
		return nil, err
	}

	// Call callback to register additional paths
	// NOTE: This is so ugly, but not worth refactoring it properly right now.
	httpMux := http.NewServeMux()
	if registerAdditionalHandlers != nil {
		registerAdditionalHandlers(httpMux)
	}

	// Add gRPC-gateway on httpMux
	httpMux.Handle("/v1/", gwMux)

	// Add HTTP handler for query export downloads
	observability.MuxHandle(httpMux, "/v1/download", observability.Middleware("runtime", s.logger, auth.HTTPMiddleware(s.aud, http.HandlerFunc(s.downloadHandler))))

	// Add handler for dynamic APIs, i.e. APIs backed by resolvers (such as custom APIs defined in YAML).
	observability.MuxHandle(httpMux, "/v1/instances/{instance_id}/api/{name...}", observability.Middleware("runtime", s.logger, auth.HTTPMiddleware(s.aud, httputil.Handler(s.apiHandler))))

	// Add handler for resolving component data
	observability.MuxHandle(httpMux, "/v1/instances/{instance_id}/components/{name}/data", observability.Middleware("runtime", s.logger, auth.HTTPMiddleware(s.aud, httputil.Handler(s.componentDataHandler))))

	// Add Prometheus
	if s.opts.ServePrometheus {
		httpMux.Handle("/metrics", promhttp.Handler())
	}

	// Build CORS options for runtime server
	// If the AllowedOrigins contains a "*" we want to return the requester's origin instead of "*" in the "Access-Control-Allow-Origin" header.
	// This is useful in development. In production, we set AllowedOrigins to non-wildcard values, so this does not have security implications.
	// Details: https://github.com/rs/cors#allow--with-credentials-security-protection
	var allowedOriginFunc func(string) bool
	allowedOrigins := s.opts.AllowedOrigins
	for _, origin := range s.opts.AllowedOrigins {
		if origin == "*" {
			allowedOriginFunc = func(origin string) bool { return true }
			allowedOrigins = nil
			break
		}
	}
	corsOpts := cors.Options{
		AllowedOrigins:  allowedOrigins,
		AllowOriginFunc: allowedOriginFunc,
		AllowedMethods: []string{
			http.MethodHead,
			http.MethodGet,
			http.MethodPost,
			http.MethodPut,
			http.MethodPatch,
			http.MethodDelete,
		},
		AllowedHeaders:   []string{"*"},
		AllowCredentials: false,
		// Set max age to 1 hour (default if not set is 5 seconds)
		MaxAge: 60 * 60,
	}

	// Wrap mux with CORS middleware
	handler := cors.New(corsOpts).Handler(httpMux)
	return handler, nil
}
// HTTPErrorHandler is the grpc-gateway error handler installed by HTTPHandler. It delegates to
// gateway.DefaultHTTPErrorHandler, but errors whose gRPC code is Unknown (i.e. errors returned
// without an explicit code) are first wrapped so they map to HTTP status 400 instead of 500.
func HTTPErrorHandler(ctx context.Context, mux *gateway.ServeMux, marshaler gateway.Marshaler, w http.ResponseWriter, r *http.Request, err error) {
	forwarded := err
	if status.Convert(err).Code() == codes.Unknown {
		forwarded = &gateway.HTTPStatusError{HTTPStatus: http.StatusBadRequest, Err: err}
	}
	gateway.DefaultHTTPErrorHandler(ctx, mux, marshaler, w, r, forwarded)
}
// timeoutSelector returns the timeout the gRPC timeout interceptors apply to a call, keyed by its
// full method name:
//   - RuntimeService Trigger* and *Reconcile methods: 59 minutes
//   - any QueryService method: 5 minutes
//   - RuntimeService WatchFiles, WatchResources and WatchLogs: 30 minutes
//   - everything else: 30 seconds
func timeoutSelector(fullMethodName string) time.Duration {
	runtimeTriggerOrReconcile := strings.HasPrefix(fullMethodName, "/rill.runtime.v1.RuntimeService") &&
		(strings.Contains(fullMethodName, "/Trigger") || strings.HasSuffix(fullMethodName, "Reconcile"))

	switch {
	case runtimeTriggerOrReconcile:
		return 59 * time.Minute // Not 60 to avoid forced timeout on ingress
	case strings.HasPrefix(fullMethodName, "/rill.runtime.v1.QueryService"):
		return 5 * time.Minute
	case fullMethodName == runtimev1.RuntimeService_WatchFiles_FullMethodName,
		fullMethodName == runtimev1.RuntimeService_WatchResources_FullMethodName,
		fullMethodName == runtimev1.RuntimeService_WatchLogs_FullMethodName:
		return 30 * time.Minute
	default:
		return 30 * time.Second
	}
}
// errorMappingUnaryServerInterceptor returns a unary interceptor that passes the handler's error
// through mapGRPCError; the handler's response is returned unchanged.
func errorMappingUnaryServerInterceptor() grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		resp, handlerErr := handler(ctx, req)
		mapped := mapGRPCError(handlerErr)
		return resp, mapped
	}
}
// errorMappingStreamServerInterceptor returns a stream interceptor that passes the error returned by
// the stream handler through mapGRPCError before it propagates to outer interceptors and the client.
func errorMappingStreamServerInterceptor() grpc.StreamServerInterceptor {
	return func(srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		return mapGRPCError(handler(srv, ss))
	}
}
// mapGRPCError rewrites errors returned from gRPC handlers before they are returned to the client:
// context deadline/cancellation errors become status errors with codes.DeadlineExceeded/codes.Canceled,
// queries.ErrForbidden becomes ErrForbidden, and anything else (including nil) is returned unchanged.
func mapGRPCError(err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, context.DeadlineExceeded):
		return status.Error(codes.DeadlineExceeded, err.Error())
	case errors.Is(err, context.Canceled):
		return status.Error(codes.Canceled, err.Error())
	case errors.Is(err, queries.ErrForbidden):
		return ErrForbidden
	default:
		return err
	}
}
// checkRateLimit is the auth function passed to the grpc_auth interceptors in ServeGRPC.
// Requests whose auth claims carry a subject pass through unchanged. Anonymous requests are limited
// with ratelimit.Public under a key derived from the gRPC method and the peer. A quota violation is
// reported as codes.ResourceExhausted; other limiter errors are returned as-is.
func (s *Server) checkRateLimit(ctx context.Context) (context.Context, error) {
	// Any request type might be limited separately as it is part of Metadata
	// Any request type might be excluded from this limit check and limited later,
	// e.g. in the corresponding request handler by calling s.limiter.Limit(ctx, "limitKey", redis_rate.PerMinute(100))
	if auth.GetClaims(ctx).Subject() != "" {
		return ctx, nil
	}

	method, ok := grpc.Method(ctx)
	if !ok {
		return ctx, errors.New("server context does not have a method")
	}

	limitKey := ratelimit.AnonLimitKey(method, observability.GrpcPeer(ctx))
	if err := s.limiter.Limit(ctx, limitKey, ratelimit.Public); err != nil {
		if errors.As(err, &ratelimit.QuotaExceededError{}) {
			// status.Error, not Errorf: the message is not a format string and may contain '%'.
			return ctx, status.Error(codes.ResourceExhausted, err.Error())
		}
		return ctx, err
	}
	return ctx, nil
}
// addInstanceRequestAttributes fetches the runtime's attributes for instanceID and attaches them to
// the current request via observability.AddRequestAttributes.
func (s *Server) addInstanceRequestAttributes(ctx context.Context, instanceID string) {
	observability.AddRequestAttributes(ctx, s.runtime.GetInstanceAttributes(ctx, instanceID)...)
}
// IssueDevJWT handles the IssueDevJWT RPC. It issues a development JWT via auth.NewDevToken whose
// attributes are the requested name, email, email domain, groups and admin flag.
func (s *Server) IssueDevJWT(ctx context.Context, req *runtimev1.IssueDevJWTRequest) (*runtimev1.IssueDevJWTResponse, error) {
	// The domain is the part after the last "@". An email without "@" has no domain, so it is left
	// empty instead of being set to the whole string.
	domain := ""
	if i := strings.LastIndex(req.Email, "@"); i >= 0 {
		domain = req.Email[i+1:]
	}

	attr := map[string]any{
		"name":   req.Name,
		"email":  req.Email,
		"domain": domain,
		"groups": req.Groups,
		"admin":  req.Admin,
	}

	jwt, err := auth.NewDevToken(attr)
	if err != nil {
		return nil, err
	}

	return &runtimev1.IssueDevJWTResponse{
		Jwt: jwt,
	}, nil
}