Skip to content

Commit

Permalink
fix: recover from panics in gRPC server (#1149)
Browse files Browse the repository at this point in the history
Panics in the gRPC server now result in `codes.Internal` being returned, instead of killing the server.
  • Loading branch information
zepatrik committed Nov 24, 2022
1 parent 2e19042 commit 3e38d13
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 11 deletions.
35 changes: 24 additions & 11 deletions internal/driver/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,14 @@ import (
"net/http"
"os"
"os/signal"
"runtime/debug"
"strings"
"syscall"

grpcRecovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/ory/keto/internal/namespace/namespacehandler"
"github.com/ory/keto/internal/schema"
rts "github.com/ory/keto/proto/ory/keto/relation_tuples/v1alpha2"
Expand All @@ -22,7 +27,6 @@ import (

"github.com/ory/x/logrusx"

grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpcLogrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus"
"github.com/julienschmidt/httprouter"
"github.com/ory/herodot"
Expand Down Expand Up @@ -422,14 +426,22 @@ func (r *RegistryDefault) OPLSyntaxRouter(ctx context.Context) http.Handler {
return handler
}

func (r *RegistryDefault) grpcRecoveryHandler(_ context.Context, p interface{}) error {
r.Logger().
WithField("reason", p).
WithField("stack_trace", string(debug.Stack())).
WithField("handler", "rate_limit").
Error("panic recovered")
return status.Errorf(codes.Internal, "%v", p)
}

func (r *RegistryDefault) unaryInterceptors(ctx context.Context) []grpc.UnaryServerInterceptor {
is := make([]grpc.UnaryServerInterceptor, len(r.defaultUnaryInterceptors), len(r.defaultUnaryInterceptors)+2)
copy(is, r.defaultUnaryInterceptors)
is := make([]grpc.UnaryServerInterceptor, len(r.defaultUnaryInterceptors)+1, len(r.defaultUnaryInterceptors)+5)
is[0] = grpcRecovery.UnaryServerInterceptor(grpcRecovery.WithRecoveryHandlerContext(r.grpcRecoveryHandler))
copy(is[1:], r.defaultUnaryInterceptors)
is = append(is,
herodot.UnaryErrorUnwrapInterceptor,
grpcMiddleware.ChainUnaryServer(
grpcLogrus.UnaryServerInterceptor(r.l.Entry),
),
grpcLogrus.UnaryServerInterceptor(r.l.Entry),
)
if r.Tracer(ctx).IsLoaded() {
is = append(is, grpcOtel.UnaryServerInterceptor(grpcOtel.WithTracerProvider(otel.GetTracerProvider())))
Expand All @@ -441,13 +453,14 @@ func (r *RegistryDefault) unaryInterceptors(ctx context.Context) []grpc.UnarySer
}

func (r *RegistryDefault) streamInterceptors(ctx context.Context) []grpc.StreamServerInterceptor {
is := make([]grpc.StreamServerInterceptor, len(r.defaultStreamInterceptors), len(r.defaultStreamInterceptors)+2)
copy(is, r.defaultStreamInterceptors)
is := make([]grpc.StreamServerInterceptor, len(r.defaultStreamInterceptors)+1, len(r.defaultStreamInterceptors)+5)
// The recovery interceptor must be the first one to recover panics in other interceptors as well.
is[0] = grpcRecovery.StreamServerInterceptor(grpcRecovery.WithRecoveryHandlerContext(r.grpcRecoveryHandler))

copy(is[1:], r.defaultStreamInterceptors)
is = append(is,
herodot.StreamErrorUnwrapInterceptor,
grpcMiddleware.ChainStreamServer(
grpcLogrus.StreamServerInterceptor(r.l.Entry),
),
grpcLogrus.StreamServerInterceptor(r.l.Entry),
)
if r.Tracer(ctx).IsLoaded() {
is = append(is, grpcOtel.StreamServerInterceptor(grpcOtel.WithTracerProvider(otel.GetTracerProvider())))
Expand Down
64 changes: 64 additions & 0 deletions internal/driver/daemon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,23 @@
package driver

import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"

"github.com/phayes/freeport"
"github.com/stretchr/testify/assert"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
grpcHealthV1 "google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/status"

"github.com/ory/keto/internal/driver/config"

"context"

prometheus "github.com/ory/x/prometheusx"
Expand Down Expand Up @@ -37,3 +49,55 @@ func TestMetricsHandler(t *testing.T) {
require.Contains(t, string(body), promLogLine)
}
}

func TestPanicRecovery(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

unaryPanicInterceptor := func(context.Context, interface{}, *grpc.UnaryServerInfo, grpc.UnaryHandler) (interface{}, error) {
panic("test panic")
}
streamPanicInterceptor := func(context.Context, interface{}, *grpc.UnaryServerInfo, grpc.UnaryHandler) (interface{}, error) {
panic("test panic")
}
port, err := freeport.GetFreePort()
require.NoError(t, err)

r := NewSqliteTestRegistry(t, false, WithGRPCUnaryInterceptors(unaryPanicInterceptor), WithGRPCUnaryInterceptors(streamPanicInterceptor))
require.NoError(t, r.Config(ctx).Set(config.KeyWriteAPIPort, port))

eg := errgroup.Group{}
doneShutdown := make(chan struct{})
eg.Go(r.serveWrite(ctx, doneShutdown))

conn, err := grpc.DialContext(ctx, fmt.Sprintf("127.0.0.1:%d", port), grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()

cl := grpcHealthV1.NewHealthClient(conn)

watcher, err := cl.Watch(ctx, &grpcHealthV1.HealthCheckRequest{})
require.NoError(t, err)
require.NoError(t, watcher.CloseSend())
for err := status.Error(codes.Unavailable, "init"); status.Code(err) != codes.Unavailable; _, err = watcher.Recv() {
}

// we want to ensure the server is still running after the panic
for i := 0; i < 10; i++ {
// Unary call
resp, err := cl.Check(ctx, &grpcHealthV1.HealthCheckRequest{})
require.Error(t, err, "%+v", resp)
assert.Equal(t, codes.Internal, status.Code(err))

// Streaming call
wResp, err := cl.Watch(ctx, &grpcHealthV1.HealthCheckRequest{})
require.NoError(t, err)
err = wResp.RecvMsg(nil)
require.Error(t, err)
assert.Equal(t, codes.Internal, status.Code(err))
}

cancel()
<-doneShutdown
require.NoError(t, eg.Wait())
}
6 changes: 6 additions & 0 deletions internal/driver/registry_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ func WithGRPCUnaryInterceptors(i ...grpc.UnaryServerInterceptor) TestRegistryOpt
}
}

func WithGRPCStreamInterceptors(i ...grpc.StreamServerInterceptor) TestRegistryOption {
return func(_ testing.TB, r *RegistryDefault) {
r.defaultStreamInterceptors = i
}
}

type selfSignedCert struct {
once sync.Once
cert *tls.Certificate
Expand Down

0 comments on commit 3e38d13

Please sign in to comment.