-
Notifications
You must be signed in to change notification settings - Fork 110
/
server.go
83 lines (75 loc) · 2.01 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
package mlmodel
import (
	"context"
	"encoding/base64"
	"fmt"

	pb "go.viam.com/api/service/mlmodel/v1"
	vprotoutils "go.viam.com/utils/protoutils"
	"google.golang.org/protobuf/types/known/structpb"

	"go.viam.com/rdk/resource"
)
// serviceServer implements the MLModelService from mlmodel.proto.
//
// Each RPC handler looks up the target resource by the request's Name in
// coll and delegates the call to that Service.
type serviceServer struct {
	// Embedded per the grpc-go generated-code convention so that RPCs not
	// defined on this type fall back to the generated "unimplemented" stubs.
	pb.UnimplementedMLModelServiceServer
	// coll holds the ML model resources this server can dispatch to; the
	// handlers resolve entries from it via coll.Resource(req.Name).
	coll resource.APIResourceCollection[Service]
}
// NewRPCServiceServer builds the ML model gRPC service server, backed by the
// given resource collection.
//
// The result is deliberately returned as an untyped interface{} so that the
// concrete server type cannot be used outside of tests.
func NewRPCServiceServer(coll resource.APIResourceCollection[Service]) interface{} {
	srv := new(serviceServer)
	srv.coll = coll
	return srv
}
// Infer handles the MLModelService Infer RPC.
//
// It proceeds in four steps, returning the first error encountered unchanged:
//  1. resolve the ML model resource named by req.Name,
//  2. decode req.InputData into a Go map with asMap,
//  3. run inference on the resolved service,
//  4. pack the returned map back into a protobuf Struct for the response.
func (server *serviceServer) Infer(ctx context.Context, req *pb.InferRequest) (*pb.InferResponse, error) {
	model, err := server.coll.Resource(req.Name)
	if err != nil {
		return nil, err
	}

	inputs, err := asMap(req.InputData)
	if err != nil {
		return nil, err
	}

	results, err := model.Infer(ctx, inputs)
	if err != nil {
		return nil, err
	}

	packed, err := vprotoutils.StructToStructPb(results)
	if err != nil {
		return nil, err
	}

	return &pb.InferResponse{OutputData: packed}, nil
}
// asMap converts x into a general-purpose Go map keyed by field name.
//
// Top-level string values are treated as standard (padded) base64 and are
// decoded to []byte. Every other top-level value is converted with
// structpb.Value.AsInterface, so strings nested inside lists or structs are
// passed through as plain strings and are NOT base64-decoded.
// A nil x yields an empty, non-nil map.
//
// If any top-level string is not valid base64, asMap returns a nil map and an
// error naming the offending key. The error wraps the underlying
// encoding/base64 error, so errors.As still reaches it.
//
// NOTE(review): a plain-text (non-base64) string input is rejected here —
// confirm that this is the intended wire contract for Infer.
func asMap(x *structpb.Struct) (map[string]interface{}, error) {
	fields := x.GetFields()
	out := make(map[string]interface{}, len(fields))
	for name, val := range fields {
		switch val.GetKind().(type) {
		case *structpb.Value_StringValue:
			decoded, err := base64.StdEncoding.DecodeString(val.GetStringValue())
			if err != nil {
				// Name the key: map iteration order is random and the raw
				// base64 error alone does not say which input was malformed.
				return nil, fmt.Errorf("decoding base64 input %q: %w", name, err)
			}
			out[name] = decoded
		default:
			out[name] = val.AsInterface()
		}
	}
	return out, nil
}
// Metadata handles the MLModelService Metadata RPC.
//
// It resolves the ML model resource named by req.Name, fetches its metadata,
// and converts that metadata to protobuf form via toProto. The first error
// from any of these steps is returned unchanged.
func (server *serviceServer) Metadata(
	ctx context.Context,
	req *pb.MetadataRequest,
) (*pb.MetadataResponse, error) {
	model, err := server.coll.Resource(req.Name)
	if err != nil {
		return nil, err
	}

	info, err := model.Metadata(ctx)
	if err != nil {
		return nil, err
	}

	protoInfo, err := info.toProto()
	if err != nil {
		return nil, err
	}

	return &pb.MetadataResponse{Metadata: protoInfo}, nil
}