Skip to content

Commit

Permalink
ref: refactor with new ClientImpl
Browse files Browse the repository at this point in the history
  • Loading branch information
khorevaa committed Sep 7, 2021
1 parent 7edf199 commit a7b1d5b
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 105 deletions.
1 change: 0 additions & 1 deletion encodingapis/ras/client/client.proto
Expand Up @@ -15,7 +15,6 @@ message ClientOptions {
bool is_endpoint = 2;
bool is_request_service = 3;
bool is_ras_service = 4;
bool is_test = 5;
}

extend google.protobuf.MethodOptions {
Expand Down
115 changes: 43 additions & 72 deletions generator/generate_client.go
Expand Up @@ -34,8 +34,6 @@ func (m clientGenerator) genService(service *protogen.Service) {
m.genClientConstructor(service)
m.genClientDefinition(service)
m.genDetectSupportedVersion(service)
m.genClientOptionsDefinition(service)
m.genDialMethodFunction(service)

for _, method := range service.Methods {
m.genMethodHandler(service, method)
Expand All @@ -50,37 +48,40 @@ func (m clientGenerator) genHeader(packageName string) {
}
func (m clientGenerator) genClientImpl(service *protogen.Service) {

m.g.P("type ", m.getClientImp(service), " interface {")
m.g.P("type ", m.getClientServiceImp(service), " interface {")
for _, method := range service.Methods {
m.g.P("", method.GoName, "(*", method.Input.GoIdent, ") (*", method.Output.GoIdent, ", error)")
m.g.P("", method.GoName, "(ctx ", ctxPackage.Ident("Context"), ", req *", method.Input.GoIdent, ") (*", method.Output.GoIdent, ", error)")
}
m.g.P("DetectSupportedVersion(err error) string")
m.g.P("}")
m.AddImpl(m.getClientImp(service), m.file.GoImportPath)

m.AddImpl(m.getClientServiceImp(service), m.file.GoImportPath)

m.g.P("type ", m.getClientImp(), " interface {")
m.g.P("// Методы для блокировки соединения sync.Mutex")
m.g.P("// берем из sync.Locker ")
m.g.P(syncPackage.Ident("Locker"))
m.g.P("// Методы для записи и чтения из соединение")
m.g.P("// берем из io.ReadWriter")
m.g.P(ioPackage.Ident("ReadWriter"))
m.g.P("}")

m.AddImpl(m.getClientImp(), m.file.GoImportPath)

}

func (m clientGenerator) genClientConstructor(service *protogen.Service) {
serviceName := m.getClientName(service)
optionsName := m.getClientOptionsName(service)
optionName := m.getClientOptionName(service)

m.g.P("func New", serviceName, "(host string, opts... ", optionName, ") ", m.getClientImp(service), "{")
m.g.P("options := &", optionsName, "{ timeout: 5 * ", timePackage.Ident("Second"), "}")
m.g.P("for _, opt := range opts {")
m.g.P("opt(options)")
m.g.P("}")
m.g.P("func New", serviceName, "(client ", m.getClientImp(), ") ", m.getClientServiceImp(service), "{")
m.g.P("return &", unexport(serviceName), "{")
m.g.P("host: host,")
m.g.P("", optionsName, ": options,")
m.g.P("mu: &", syncPackage.Ident("Mutex"), "{},")
m.g.P("client: client,")
m.g.P("}")
m.g.P("}")

}

func (m clientGenerator) genClientDefinition(service *protogen.Service) {
serviceName := m.getClientName(service)
optionsName := m.getClientOptionsName(service)

m.g.P("// ", serviceName, " is the client for RAS service.")
if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
Expand All @@ -89,10 +90,7 @@ func (m clientGenerator) genClientDefinition(service *protogen.Service) {
}
m.g.Annotate(serviceName, service.Location)
m.g.P("type ", unexport(serviceName), " struct {")
m.g.P("*", optionsName, "")
m.g.P("host string")
m.g.P("conn ", netPackage.Ident("Conn"), "")
m.g.P("mu *", syncPackage.Ident("Mutex"), "")
m.g.P("client ", m.getClientImp(), "")
m.g.P("}")
m.g.P()
}
Expand All @@ -105,7 +103,8 @@ func (m clientGenerator) genDetectSupportedVersion(service *protogen.Service) {
m.g.P()
m.g.P("var re = ", regexpPackage.Ident("MustCompile"), "(`(?m)supported=(.*?)]`)")
m.g.P()
m.g.P("func (x *", unexport(serviceName), ") DetectSupportedVersion(err error) string {")
m.g.P("// DetectSupportedVersion func helpers detect supported version in EndpointFailureAck")
m.g.P("func DetectSupportedVersion(err error) string {")
m.g.P()
m.g.P("fail, ok := err.(*", m.ObjectNamed("ras.protocol.v1.EndpointFailureAck").GoIdent, ")")
m.g.P("if !ok { return \"\" }")
Expand All @@ -127,71 +126,39 @@ func (m clientGenerator) genDetectSupportedVersion(service *protogen.Service) {
m.g.P()
}

func (m clientGenerator) genClientOptionsDefinition(service *protogen.Service) {
// Server registration.
if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
m.g.P(deprecationComment)
}
optionName := m.getClientOptionName(service)
optionsName := m.getClientOptionsName(service)
m.g.P("type ", optionName, " func(*", optionsName, ")")
m.g.P()
m.g.P("type ", optionsName, " struct {")
m.g.P("dialer *", netPackage.Ident("Dialer"), "")
m.g.P("timeout ", timePackage.Ident("Duration"), "")
m.g.P("}")
m.g.P()
m.g.P("func WithDialer(dialer *", netPackage.Ident("Dialer"), ") ", optionName, " {")
m.g.P("return func (o *", optionsName, ") { o.dialer = dialer }")
m.g.P("}")
m.g.P()
m.g.P("func SetTimeout(timeout ", timePackage.Ident("Duration"), ") ", optionName, " {")
m.g.P("return func (o *", optionsName, ") { o.timeout = timeout }")
m.g.P("}")

}

func (m clientGenerator) genDialMethodFunction(service *protogen.Service) {
m.g.P("func (x *", unexport(m.getClientName(service)), ") dial() error {")
m.g.P("if x.conn != nil { return nil }")
m.g.P("if _, err :=", netPackage.Ident("ResolveTCPAddr"), "(\"tcp\", x.host); err != nil { return err }")
m.g.P()
m.g.P("var err error")
m.g.P("if x.dialer != nil { ")
m.g.P("x.conn, err = x.dialer.Dial(\"tcp\", x.host)")
m.g.P("return err")
m.g.P("}")
m.g.P("x.conn, err = ", netPackage.Ident("Dial"), "(\"tcp\", x.host)")
m.g.P("return err")
m.g.P("}")
m.g.P()
}

func (m clientGenerator) genMethodHandler(service *protogen.Service, method *protogen.Method) {

ext := GetClientMethodExtension(method.Desc.Options())
if ext.NewEndpointFunc {
m.genNewEndpointFunc(service, method)
return
}

m.g.P("func (x *", unexport(m.getClientName(service)), ") ", method.GoName, "(req *", method.Input.GoIdent, ") (*", method.Output.GoIdent, ", error) {")
m.g.P("if err := x.dial(); err != nil { return nil, err }")
m.g.P("x.mu.Lock()")
m.g.P("defer x.mu.Unlock()")
m.g.P("func (x *", unexport(m.getClientName(service)), ") ", method.GoName, "(ctx ", ctxPackage.Ident("Context"), ", req *", method.Input.GoIdent, ") (*", method.Output.GoIdent, ", error) {")
m.g.P()
m.g.P("x.client.Lock()")
m.g.P("defer x.client.Unlock()")
m.g.P()
m.g.P("// Check context ")
m.g.P("select {")
m.g.P("case <-ctx.Done():")
m.g.P("return nil, ctx.Err()")
m.g.P("default:")
m.g.P("}")
m.g.P("")
if ext.NoPacketPack {
m.g.P("if err := req.", m.formatFuncName, "(x.conn, 0 ); err != nil { return nil, err }")
m.g.P("if err := req.", m.formatFuncName, "(x.client, 0 ); err != nil { return nil, err }")
} else {
m.g.P("packet, err := ", method.Input.GoIdent.GoImportPath.Ident("NewPacket"), "(req)")
m.g.P("if err != nil { return nil, err }")
m.g.P("if _, err := packet.WriteTo(x.conn); err != nil { return nil, err }")
m.g.P("if _, err := packet.WriteTo(x.client); err != nil { return nil, err }")
}
if isEmptyPb(method.Output.Desc) {
m.g.P("return new(", method.Output.GoIdent, "), nil")
m.g.P("}")
return
}
m.g.P("if err := x.conn.SetReadDeadline(time.Now().Add(x.timeout)); err != nil { return nil, err }")
m.g.P("ackPacket, err := ", method.Input.GoIdent.GoImportPath.Ident("NewPacket"), "(x.conn)")
m.g.P("ackPacket, err := ", method.Input.GoIdent.GoImportPath.Ident("NewPacket"), "(x.client)")
m.g.P("if err != nil { return nil, err }")
m.g.P("resp := new(", method.Output.GoIdent, ")")
m.g.P("return resp, ackPacket.Unpack(resp)")
Expand All @@ -201,7 +168,7 @@ func (m clientGenerator) genMethodHandler(service *protogen.Service, method *pro

func (m clientGenerator) genNewEndpointFunc(service *protogen.Service, method *protogen.Method) {

m.g.P("func (x *", unexport(m.getClientName(service)), ") ", method.GoName, "(req *", method.Input.GoIdent, ") (*", method.Output.GoIdent, ", error) {")
m.g.P("func (x *", unexport(m.getClientName(service)), ") ", method.GoName, "(_ ", ctxPackage.Ident("Context"), ", req *", method.Input.GoIdent, ") (*", method.Output.GoIdent, ", error) {")
m.g.P("return &", method.Output.GoIdent, "{")
m.g.P("Service: req.GetService(),")
m.g.P("Version: ", castPackage.Ident("ToInt32"), "(", castPackage.Ident("ToFloat32"), "(req.GetVersion())),")
Expand All @@ -217,7 +184,11 @@ func (m clientGenerator) getClientName(service *protogen.Service) string {
return service.GoName
}

func (m clientGenerator) getClientImp(service *protogen.Service) string {
func (m clientGenerator) getClientImp() string {
return "ClientImpl"
}

func (m clientGenerator) getClientServiceImp(service *protogen.Service) string {
return service.GoName + "Impl"
}

Expand Down
28 changes: 10 additions & 18 deletions generator/generate_endpoint.go
Expand Up @@ -47,25 +47,25 @@ func (m endpointGenerator) genHeader(packageName string) {
}
func (m endpointGenerator) genImpl(service *protogen.Service) {

m.g.P("type ", m.getClientImp(service), " interface {")
m.g.P("type ", m.getEndpointImp(service), " interface {")
for _, method := range service.Methods {
m.g.P(method.GoName, "(*", method.Input.GoIdent, ") (*", method.Output.GoIdent, ", error)")
m.g.P(method.GoName, "(ctx ", ctxPackage.Ident("Context"), ", req *", method.Input.GoIdent, ") (*", method.Output.GoIdent, ", error)")
}

m.g.P()
m.g.P("}")

m.AddImpl(m.getClientImp(service), m.file.GoImportPath)
m.AddImpl(m.getEndpointImp(service), m.file.GoImportPath)

}

func (m endpointGenerator) genConstructor(service *protogen.Service) {
clientServiceImpl := m.GetImpl("ClientServiceImpl")
endpointImpl := m.GetImpl("EndpointImpl")

serviceName := m.getClientName(service)
serviceName := m.getEndpointName(service)

m.g.P("func New", serviceName, "(clientService ", clientServiceImpl, ", endpoint ", endpointImpl, ") ", m.getClientImp(service), "{")
m.g.P("func New", serviceName, "(clientService ", clientServiceImpl, ", endpoint ", endpointImpl, ") ", m.getEndpointImp(service), "{")
m.g.P("return &", unexport(serviceName), "{")
m.g.P("endpoint,")
m.g.P("clientService,")
Expand All @@ -76,7 +76,7 @@ func (m endpointGenerator) genConstructor(service *protogen.Service) {
}

func (m endpointGenerator) genDefinition(service *protogen.Service) {
serviceName := m.getClientName(service)
serviceName := m.getEndpointName(service)
clientServiceImpl := m.GetImpl("ClientServiceImpl")
endpointImpl := m.GetImpl("EndpointImpl")

Expand All @@ -100,15 +100,15 @@ func (m endpointGenerator) genMethodHandler(service *protogen.Service, method *p

endpointMessageParser := m.GetImpl("EndpointMessageParser")

m.g.P("func (x *", unexport(m.getClientName(service)), ") ", method.GoName, "(req *", method.Input.GoIdent, ") (*", method.Output.GoIdent, ", error) {")
m.g.P("func (x *", unexport(m.getEndpointName(service)), ") ", method.GoName, "(ctx ", ctxPackage.Ident("Context"), ", req *", method.Input.GoIdent, ") (*", method.Output.GoIdent, ", error) {")
m.g.P("message, err := ", anypbPackage.Ident("UnmarshalNew"),
"(req.GetRequest(),", protoPackage.Ident("UnmarshalOptions"), "{})")
m.g.P("if err != nil { return nil, err }")
m.g.P()
m.g.P("reqMessage, err := x.NewMessage(message)")
m.g.P("if err != nil { return nil, err }")
m.g.P()
m.g.P("respMessage, err := x.client.EndpointMessage(reqMessage) ")
m.g.P("respMessage, err := x.client.EndpointMessage(ctx, reqMessage) ")
m.g.P("if err != nil { return nil, err }")
m.g.P()
m.g.P("respProtoMessage, err := ", anypbPackage.Ident("UnmarshalNew"),
Expand All @@ -128,18 +128,10 @@ func (m endpointGenerator) genMethodHandler(service *protogen.Service, method *p

}

func (m endpointGenerator) getClientName(service *protogen.Service) string {
func (m endpointGenerator) getEndpointName(service *protogen.Service) string {
return service.GoName
}

func (m endpointGenerator) getClientImp(service *protogen.Service) string {
func (m endpointGenerator) getEndpointImp(service *protogen.Service) string {
return service.GoName + "Impl"
}

func (m endpointGenerator) getClientOptionsName(service *protogen.Service) string {
return service.GoName + "Options"
}

func (m endpointGenerator) getClientOptionName(service *protogen.Service) string {
return service.GoName + "Option"
}
7 changes: 4 additions & 3 deletions generator/generate_message_service.go
Expand Up @@ -42,7 +42,7 @@ func (m messageServiceGenerator) genImpl(service *protogen.Service) {

m.g.P("type ", m.getServiceImpl(service), " interface {")
for _, method := range service.Methods {
m.g.P(method.GoName, "(*", method.Input.GoIdent, ") (*", method.Output.GoIdent, ", error)")
m.g.P(method.GoName, "(ctx ", ctxPackage.Ident("Context"), ", req *", method.Input.GoIdent, ") (*", method.Output.GoIdent, ", error)")
}
m.g.P()
m.g.P("}")
Expand Down Expand Up @@ -84,7 +84,7 @@ func (m messageServiceGenerator) genMethodHandler(service *protogen.Service, met

endpointRequest := "EndpointRequest"

m.g.P("func (x *", m.getServiceName(service), ") ", method.GoName, "(req *", method.Input.GoIdent, ") (*", method.Output.GoIdent, ", error) {")
m.g.P("func (x *", m.getServiceName(service), ") ", method.GoName, "(ctx ", ctxPackage.Ident("Context"), ", req *", method.Input.GoIdent, ") (*", method.Output.GoIdent, ", error) {")
m.g.P()
m.g.P("var resp ", method.Output.GoIdent)
m.g.P()
Expand All @@ -98,7 +98,8 @@ func (m messageServiceGenerator) genMethodHandler(service *protogen.Service, met
m.g.P("Request: anyRequest,")
m.g.P("Respond: anyRespond,")
m.g.P("}")
m.g.P("response, err := x.e.Request(endpointRequest)")
m.g.P()
m.g.P("response, err := x.e.Request(ctx, endpointRequest)")
m.g.P("if err != nil { return nil, err }")
m.g.P()
m.g.P("if err := ", anypbPackage.Ident("UnmarshalTo"),
Expand Down
6 changes: 3 additions & 3 deletions generator/generate_ras_service.go
Expand Up @@ -96,8 +96,8 @@ func (m rasServiceGenerator) genProxyMethod(service *protogen.Service, method pr
m.g.P("//")
m.g.P(deprecationComment)
}
m.g.P("func (x *", serviceName, ") ", method.proxyName, "(req *", method.Input.GoIdent, ") (*", method.Output.GoIdent, ", error) {")
m.g.P("return x.", m.getServiceName(method.Parent), ".", method.GoName, "(req)")
m.g.P("func (x *", serviceName, ") ", method.proxyName, "(ctx ", ctxPackage.Ident("Context"), ", req *", method.Input.GoIdent, ") (*", method.Output.GoIdent, ", error) {")
m.g.P("return x.", m.getServiceName(method.Parent), ".", method.GoName, "(ctx, req)")
m.g.P("}")
m.g.P()
}
Expand All @@ -107,7 +107,7 @@ func (m rasServiceGenerator) genImpl(service *protogen.Service) {
m.g.P("type ", m.getServiceImpl(service), " interface {")
for _, method := range m.idxMethods {
m.g.P("// ", method.proxyName, " proxy request ", m.getServiceName(method.Parent), ".", method.GoName)
m.g.P(method.proxyName, "(*", method.Input.GoIdent, ") (*", method.Output.GoIdent, ", error)")
m.g.P(method.proxyName, "(ctx ", ctxPackage.Ident("Context"), ", req *", method.Input.GoIdent, ") (*", method.Output.GoIdent, ", error)")
}
m.g.P()
m.g.P("}")
Expand Down
2 changes: 1 addition & 1 deletion generator/ident.go
Expand Up @@ -10,7 +10,7 @@ const (
syncPackage = protogen.GoImportPath("sync")
timePackage = protogen.GoImportPath("time")
regexpPackage = protogen.GoImportPath("regexp")
netPackage = protogen.GoImportPath("net")
ctxPackage = protogen.GoImportPath("context")
anypbPackage = protogen.GoImportPath("google.golang.org/protobuf/types/known/anypb")
protoPackage = protogen.GoImportPath("google.golang.org/protobuf/proto")
emptypbPackage = protogen.GoImportPath("google.golang.org/protobuf/types/known/emptypb")
Expand Down
14 changes: 7 additions & 7 deletions tests/buf.gen.yaml
Expand Up @@ -9,12 +9,12 @@ managed:
# - googleapies/google/api

plugins:
# - name: go
# out: ./gen
# opt: paths=source_relative
# - name: go-ras
# out: ./gen
# opt: paths=source_relative
- name: go
out: ./gen
opt: paths=source_relative
opt: paths=source_relative
- name: go-ras
out: ./gen
opt: paths=source_relative
# - name: go
# out: ./gen
# opt: paths=source_relative

0 comments on commit a7b1d5b

Please sign in to comment.