From 53a9eca623f1f6daebd9014aad10aefca05f70f7 Mon Sep 17 00:00:00 2001 From: Andrew Benton Date: Thu, 2 Nov 2023 11:35:37 -0700 Subject: [PATCH 1/2] feat: Implement plugin RPC handler --- codegen/codegen.go | 38 +------------ codegen/server.go | 17 ++++++ rpc/handler.go | 138 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 157 insertions(+), 36 deletions(-) create mode 100644 codegen/server.go create mode 100644 rpc/handler.go diff --git a/codegen/codegen.go b/codegen/codegen.go index 5df712b..f67aa62 100644 --- a/codegen/codegen.go +++ b/codegen/codegen.go @@ -1,48 +1,14 @@ package codegen import ( - "bufio" "context" - "fmt" - "io" - "os" pb "github.com/sqlc-dev/sqlc-go/plugin" - "google.golang.org/protobuf/proto" + "github.com/sqlc-dev/sqlc-go/rpc" ) type Handler func(context.Context, *pb.GenerateRequest) (*pb.GenerateResponse, error) func Run(h Handler) { - if err := run(h); err != nil { - fmt.Fprintf(os.Stderr, "error generating output: %s", err) - os.Exit(2) - } -} - -func run(h Handler) error { - var req pb.GenerateRequest - reqBlob, err := io.ReadAll(os.Stdin) - if err != nil { - return err - } - if err := proto.Unmarshal(reqBlob, &req); err != nil { - return err - } - resp, err := h(context.Background(), &req) - if err != nil { - return err - } - respBlob, err := proto.Marshal(resp) - if err != nil { - return err - } - w := bufio.NewWriter(os.Stdout) - if _, err := w.Write(respBlob); err != nil { - return err - } - if err := w.Flush(); err != nil { - return err - } - return nil + rpc.Handle(&server{handler: h}) } diff --git a/codegen/server.go b/codegen/server.go new file mode 100644 index 0000000..2dd5030 --- /dev/null +++ b/codegen/server.go @@ -0,0 +1,17 @@ +package codegen + +import ( + "context" + + pb "github.com/sqlc-dev/sqlc-go/plugin" +) + +type server struct { + pb.UnimplementedCodegenServiceServer + + handler Handler +} + +func (s *server) Generate(ctx context.Context, req *pb.GenerateRequest) (*pb.GenerateResponse, error) { + return s.handler(ctx, req) +} diff --git a/rpc/handler.go b/rpc/handler.go new file mode 100644 index 0000000..c6c4d25 --- /dev/null +++ b/rpc/handler.go @@ -0,0 +1,138 @@ +package rpc + +import ( + "bufio" + "context" + "fmt" + "io" + "os" + "strings" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoreflect" + + "github.com/sqlc-dev/sqlc-go/plugin" + pb "github.com/sqlc-dev/sqlc-go/plugin" +) + +func Handle(server pb.CodegenServiceServer) { + if err := handle(server); err != nil { + fmt.Fprintf(os.Stderr, "error generating output: %s", err) + os.Exit(2) + } +} + +func handle(server pb.CodegenServiceServer) error { + handler := newStdioRPCHandler() + pb.RegisterCodegenServiceServer(handler, server) + return handler.Handle() +} + +type stdioRPCHandler struct { + services map[string]*serviceInfo +} + +func newStdioRPCHandler() *stdioRPCHandler { + return &stdioRPCHandler{services: map[string]*serviceInfo{}} +} + +type serviceInfo struct { + serviceImpl any + methods map[string]*grpc.MethodDesc +} + +func (s *stdioRPCHandler) RegisterService(sd *grpc.ServiceDesc, ss any) { + // TODO some type checking, see e.g. grpc server.RegisterService() + info := &serviceInfo{ + serviceImpl: ss, + methods: make(map[string]*grpc.MethodDesc), + } + for i := range sd.Methods { + d := &sd.Methods[i] + info.methods[d.MethodName] = d + } + s.services[sd.ServiceName] = info +} + +func (s *stdioRPCHandler) Handle() error { + var methodArg string + if len(os.Args) < 2 { + // For backwards compatibility with sqlc before v1.24.0 + methodArg = fmt.Sprintf("/%s/%s", pb.CodegenService_ServiceDesc.ServiceName, "Generate") + } else { + methodArg = os.Args[1] + } + + // Adapted from grpc server handleStream() + + sm := methodArg + if sm != "" && sm[0] == '/' { + sm = sm[1:] + } + pos := strings.LastIndex(sm, "/") + if pos == -1 { + errDesc := fmt.Sprintf("malformed method name: %q", methodArg) + return status.Error(codes.Unimplemented, errDesc) + } + service := sm[:pos] + method := sm[pos+1:] + + srv, knownService := s.services[service] + if knownService { + if md, ok := srv.methods[method]; ok { + return s.processUnaryRPC(srv, md) + } + } + + // Unknown service, or known server unknown method. + var errDesc string + if !knownService { + errDesc = fmt.Sprintf("unknown service %v", service) + } else { + errDesc = fmt.Sprintf("unknown method %v for service %v", method, service) + } + + return status.Error(codes.Unimplemented, errDesc) +} + +func (s *stdioRPCHandler) processUnaryRPC(srv *serviceInfo, md *grpc.MethodDesc) error { + reqBytes, err := io.ReadAll(os.Stdin) + if err != nil { + return err + } + + var resp protoreflect.ProtoMessage + + // TODO make this generic + switch md.MethodName { + case "Generate": + var req plugin.GenerateRequest + if err := proto.Unmarshal(reqBytes, &req); err != nil { + return err + } + service, ok := srv.serviceImpl.(pb.CodegenServiceServer) + if !ok { + return status.Errorf(codes.Internal, codes.Internal.String()) + } + resp, err = service.Generate(context.Background(), &req) + if err != nil { + return err + } + } + + respBytes, err := proto.Marshal(resp) + if err != nil { + return err + } + w := bufio.NewWriter(os.Stdout) + if _, err := w.Write(respBytes); err != nil { + return err + } + if err := w.Flush(); err != nil { + return err + } + return nil +} From 0f6dcf607045bb8eae4a48baeadc5875605c961a Mon Sep 17 00:00:00 2001 From: Andrew Benton Date: Thu, 2 Nov 2023 11:51:56 -0700 Subject: [PATCH 2/2] make rpc package internal --- codegen/codegen.go | 2 +- {rpc => internal/rpc}/handler.go | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename {rpc => internal/rpc}/handler.go (100%) diff --git a/codegen/codegen.go b/codegen/codegen.go index f67aa62..1576af7 100644 --- a/codegen/codegen.go +++ b/codegen/codegen.go @@ -3,8 +3,8 @@ package codegen import ( "context" + "github.com/sqlc-dev/sqlc-go/internal/rpc" pb "github.com/sqlc-dev/sqlc-go/plugin" - "github.com/sqlc-dev/sqlc-go/rpc" ) type Handler func(context.Context, *pb.GenerateRequest) (*pb.GenerateResponse, error) diff --git a/rpc/handler.go b/internal/rpc/handler.go similarity index 100% rename from rpc/handler.go rename to internal/rpc/handler.go