-
Notifications
You must be signed in to change notification settings - Fork 110
/
mlmodel_service.go
58 lines (50 loc) · 1.45 KB
/
mlmodel_service.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
package inject
import (
"context"
"go.viam.com/rdk/ml"
"go.viam.com/rdk/resource"
"go.viam.com/rdk/services/mlmodel"
)
// MLModelService is a fake of the mlmodel.Service interface for use in tests.
// Every method first consults its matching *Func field; when that field is
// nil, the call is forwarded to the embedded Service instead. Close is the
// only method that also tolerates a nil embedded Service.
type MLModelService struct {
	mlmodel.Service

	name resource.Name

	// Per-method overrides; leaving one nil selects the embedded Service.
	InferFunc    func(ctx context.Context, tensors ml.Tensors) (ml.Tensors, error)
	MetadataFunc func(ctx context.Context) (mlmodel.MLMetadata, error)
	CloseFunc    func(ctx context.Context) error
}
// NewMLModelService builds an injectable mlmodel fake whose resource name is
// mlmodel.Named(name). No overrides and no underlying Service are configured.
func NewMLModelService(name string) *MLModelService {
	svc := &MLModelService{}
	svc.name = mlmodel.Named(name)
	return svc
}
// Name reports the resource name stored on the fake (set by
// NewMLModelService; the zero value if the struct was built directly).
func (s *MLModelService) Name() resource.Name {
	n := s.name
	return n
}
// Infer forwards the call to InferFunc when it is set, and otherwise to the
// embedded Service's Infer. There is no nil guard on the embedded Service, so
// calling Infer with neither InferFunc nor Service configured panics.
func (s *MLModelService) Infer(
	ctx context.Context,
	tensors ml.Tensors,
) (ml.Tensors, error) {
	if override := s.InferFunc; override != nil {
		return override(ctx, tensors)
	}
	return s.Service.Infer(ctx, tensors)
}
// Metadata returns MetadataFunc's result when that override is set, and
// otherwise defers to the embedded Service. Like Infer, it panics if neither
// MetadataFunc nor the embedded Service is configured.
func (s *MLModelService) Metadata(ctx context.Context) (mlmodel.MLMetadata, error) {
	if override := s.MetadataFunc; override != nil {
		return override(ctx)
	}
	return s.Service.Metadata(ctx)
}
// Close picks the first available option, in order:
//  1. CloseFunc, if set;
//  2. the embedded Service's Close, if a Service is present;
//  3. otherwise it does nothing and returns nil.
func (s *MLModelService) Close(ctx context.Context) error {
	switch {
	case s.CloseFunc != nil:
		return s.CloseFunc(ctx)
	case s.Service != nil:
		return s.Service.Close(ctx)
	default:
		return nil
	}
}