/
grpc.go
204 lines (186 loc) · 6.05 KB
/
grpc.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
package grpcserver
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"os"
"time"
grpczap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap"
"github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap"
grpctags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/reflection"
)
// ServiceAPI is implemented by each gRPC service hosted by a Server. An
// implementation attaches itself to the underlying *grpc.Server and to a
// grpc-gateway ServeMux.
type ServiceAPI interface {
	// RegisterService is expected to register the service's RPC handlers on srv.
	RegisterService(srv *grpc.Server)
	// RegisterHandlerService is expected to register the service's gateway
	// handlers on mux, returning any registration error.
	RegisterHandlerService(mux *runtime.ServeMux) error
	// String implements fmt.Stringer for the service.
	String() string
}
// Server is a very basic gRPC server: it wraps a *grpc.Server together with
// the address it should listen on and the logger used for lifecycle events.
// Construct it with New, NewWithServices or NewTLS, then call Start and,
// eventually, Close.
type Server struct {
	// listener is the host:port address handed to net.Listen by Start.
	listener string
	logger   *zap.Logger

	// BoundAddress contains the address that the server actually bound to,
	// which is useful if listener uses a dynamic port (e.g. ":0"). Start sets
	// it before returning a nil error, so it may be read safely once Start has
	// returned successfully; it is not synchronized for concurrent access while
	// Start is still running.
	BoundAddress string
	// GrpcServer is the underlying gRPC server. Services must be registered on
	// it before Start begins serving.
	GrpcServer *grpc.Server

	// grp tracks the serving and shutdown goroutines so Close can wait for
	// them and report the first error any of them returned.
	grp errgroup.Group
}
// unaryGrpcLogStart is a unary server interceptor that writes a "started"
// entry through the context-scoped zap logger (via ctxzap) and then delegates
// to the next handler, returning its response and error unchanged.
func unaryGrpcLogStart(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
	const startMsg = "started unary call"
	ctxzap.Info(ctx, startMsg)

	resp, err := handler(ctx, req)
	return resp, err
}
// streamingGrpcLogStart is the streaming counterpart of unaryGrpcLogStart: it
// logs that a stream has started using the logger carried by the stream's
// context (via ctxzap) and then hands control to the real handler, returning
// whatever error the handler returns.
func streamingGrpcLogStart(
	srv any,
	stream grpc.ServerStream,
	_ *grpc.StreamServerInfo,
	handler grpc.StreamHandler,
) error {
	streamCtx := stream.Context()
	ctxzap.Info(streamCtx, "started streaming call")

	err := handler(srv, stream)
	return err
}
// NewWithServices builds a Server that will listen on listener (once Start is
// called), using the given logger and config, and registers every service in
// svc on it. grpcOpts are forwarded to New, which appends them after its own
// defaults.
//
// This constructor adds no transport credentials itself. It therefore logs a
// warning when the listener host is neither "localhost" nor a private or
// loopback IP literal.
//
// It returns an error if svc is empty or listener is not a valid host:port.
func NewWithServices(
	listener string,
	logger *zap.Logger,
	config Config,
	svc []ServiceAPI,
	grpcOpts ...grpc.ServerOption,
) (*Server, error) {
	if len(svc) == 0 {
		return nil, errors.New("no services to register")
	}

	host, _, splitErr := net.SplitHostPort(listener)
	if splitErr != nil {
		return nil, fmt.Errorf("split local listener: %w", splitErr)
	}

	// A host that is not an IP literal (including "" for "all interfaces")
	// parses to a nil IP, which is neither private nor loopback, so it lands
	// in the warning branch.
	switch ip := net.ParseIP(host); {
	case host == "localhost", ip.IsPrivate(), ip.IsLoopback():
		logger.Info("grpc server is listening on a private IP address", zap.String("address", listener))
	default:
		logger.Warn("unsecured grpc server is listening on a public IP address", zap.String("address", listener))
	}

	srv := New(listener, logger, config, grpcOpts...)
	for i := range svc {
		svc[i].RegisterService(srv.GrpcServer)
	}
	return srv, nil
}
// NewTLS creates a new Server that will listen on config.TLSListener using
// mutual TLS, with the given logger and config, and registers every service
// in svc on it.
//
// The server presents the key pair loaded from config.TLSCert/config.TLSKey.
// It requires every client to present a certificate that verifies against
// the CA certificates in config.TLSCACert. Handshakes below TLS 1.2 are
// refused.
//
// Optional grpcOpts are applied after New's defaults and the TLS credentials,
// mirroring NewWithServices.
//
// It returns an error if svc is empty or any certificate file cannot be read
// or parsed.
func NewTLS(logger *zap.Logger, config Config, svc []ServiceAPI, grpcOpts ...grpc.ServerOption) (*Server, error) {
	if len(svc) == 0 {
		return nil, errors.New("no services to register")
	}
	serverCert, err := tls.LoadX509KeyPair(config.TLSCert, config.TLSKey)
	if err != nil {
		return nil, fmt.Errorf("load server certificate: %w", err)
	}
	caCert, err := os.ReadFile(config.TLSCACert)
	if err != nil {
		return nil, fmt.Errorf("load ca certificate: %w", err)
	}
	certPool := x509.NewCertPool()
	if !certPool.AppendCertsFromPEM(caCert) {
		// AppendCertsFromPEM only reports "nothing usable found"; name the file
		// so operators know which path to check.
		return nil, fmt.Errorf("setup CA certificate: no valid PEM certificates in %q", config.TLSCACert)
	}
	tlsConfig := &tls.Config{
		Certificates: []tls.Certificate{serverCert},
		ClientAuth:   tls.RequireAndVerifyClientCert,
		ClientCAs:    certPool,
		// Pin an explicit floor rather than relying on the toolchain default,
		// which for crypto/tls servers was TLS 1.0 before Go 1.22.
		MinVersion: tls.VersionTLS12,
	}

	// The TLS credentials come first so that caller-supplied options, appended
	// after them, keep the same precedence they have in NewWithServices.
	opts := make([]grpc.ServerOption, 0, 1+len(grpcOpts))
	opts = append(opts, grpc.Creds(credentials.NewTLS(tlsConfig)))
	opts = append(opts, grpcOpts...)

	server := New(config.TLSListener, logger, config, opts...)
	for _, s := range svc {
		s.RegisterService(server.GrpcServer)
	}
	return server, nil
}
// New creates and returns a new Server that will listen on the given address
// once Start is called.
//
// The underlying *grpc.Server is pre-configured with:
//   - tag, zap request-logging and "started call" interceptors for both unary
//     and streaming RPCs;
//   - keepalive enforcement and server-side keepalive pings;
//   - send/receive message size limits from config. A non-positive value keeps
//     grpc-go's built-in default instead of capping messages at zero bytes.
//
// Additional grpcOpts are appended after these defaults.
func New(listener string, logger *zap.Logger, config Config, grpcOpts ...grpc.ServerOption) *Server {
	opts := []grpc.ServerOption{
		grpc.ChainStreamInterceptor(
			grpctags.StreamServerInterceptor(),
			grpczap.StreamServerInterceptor(logger),
			streamingGrpcLogStart,
		),
		grpc.ChainUnaryInterceptor(
			grpctags.UnaryServerInterceptor(),
			grpczap.UnaryServerInterceptor(logger),
			unaryGrpcLogStart,
		),
		grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
			// Clients that send keepalive pings more often than once per
			// MinTime are disconnected.
			MinTime: 1 * time.Minute,
		}),
		grpc.KeepaliveParams(keepalive.ServerParameters{
			// Ping a client after 10 minutes without activity and close the
			// connection if nothing is seen within 10 seconds of the ping.
			Time:    10 * time.Minute,
			Timeout: 10 * time.Second,
		}),
	}

	// A limit of 0 would make grpc-go reject every non-empty message. A
	// zero-valued Config must therefore fall back to the library defaults
	// rather than silently breaking all RPCs.
	if config.GrpcSendMsgSize > 0 {
		opts = append(opts, grpc.MaxSendMsgSize(config.GrpcSendMsgSize))
	}
	if config.GrpcRecvMsgSize > 0 {
		opts = append(opts, grpc.MaxRecvMsgSize(config.GrpcRecvMsgSize))
	}

	opts = append(opts, grpcOpts...)
	return &Server{
		listener:   listener,
		logger:     logger,
		GrpcServer: grpc.NewServer(opts...),
	}
}
// Start binds the configured address and begins serving gRPC requests in a
// background goroutine tracked by s.grp. It does not block until the server
// stops.
//
// It logs the currently registered service names, then listens on TCP. After
// binding it records the resolved address in BoundAddress and registers the
// gRPC reflection service before serving.
//
// Start returns an error only if the address cannot be bound. A later Serve
// failure is logged and surfaces through Close, which waits on s.grp.
func (s *Server) Start() error {
	serviceNames := zapcore.ArrayMarshalerFunc(func(enc zapcore.ArrayEncoder) error {
		for name := range s.GrpcServer.GetServiceInfo() {
			enc.AppendString(name)
		}
		return nil
	})
	s.logger.Info("starting grpc server",
		zap.String("address", s.listener),
		zap.Array("services", serviceNames),
	)

	lis, listenErr := net.Listen("tcp", s.listener)
	if listenErr != nil {
		s.logger.Error("start listen server", zap.Error(listenErr))
		return listenErr
	}

	s.BoundAddress = lis.Addr().String()
	reflection.Register(s.GrpcServer)
	s.logger.Info("bound to address", zap.String("address", s.BoundAddress))

	s.grp.Go(func() error {
		serveErr := s.GrpcServer.Serve(lis)
		if serveErr == nil {
			return nil
		}
		s.logger.Error("serving grpc server", zap.Error(serveErr))
		return serveErr
	})
	return nil
}
// Close stops the server and waits for the serving goroutine to exit.
//
// It first calls GracefulStop, which stops accepting new connections and RPCs
// and waits for pending RPCs to finish. GracefulStop blocks until every
// connection is closed, so long-running streams could keep it waiting
// forever. If it has not finished within gracefulStopTimeout, Stop is called
// to forcibly close the remaining connections. When the graceful stop
// completes sooner, Close returns immediately instead of always sleeping for
// the full timeout.
//
// The returned error is the first non-nil error reported by the goroutines in
// s.grp (for example, a Serve failure from Start).
func (s *Server) Close() error {
	const gracefulStopTimeout = time.Second

	s.logger.Info("stopping the grpc server")

	// Run GracefulStop inside the errgroup so the final Wait also covers it,
	// and signal completion so we don't have to sleep unconditionally.
	stopped := make(chan struct{})
	s.grp.Go(func() error {
		defer close(stopped)
		s.GrpcServer.GracefulStop()
		return nil
	})

	timer := time.NewTimer(gracefulStopTimeout)
	defer timer.Stop()
	select {
	case <-stopped:
		// Graceful shutdown finished; nothing left to force.
	case <-timer.C:
		// Forcibly close lingering connections; this also unblocks the
		// GracefulStop goroutine above.
		s.GrpcServer.Stop()
	}
	return s.grp.Wait()
}