/
torch_request.go
126 lines (111 loc) · 3.89 KB
/
torch_request.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
package eas
import (
"github.com/golang/protobuf/proto"
"github.com/pai-eas/eas-golang-sdk/eas/types/torch_predict_protos"
)
// TorchRequest assembles a prediction request for a PyTorch-served model.
//
// Input tensors are placed by position with the AddFeed* methods, output
// indices are selected with AddFetch, and ToString serializes the whole
// request.
type TorchRequest struct {
	// RequestData is the protobuf message that ToString marshals.
	RequestData torch_predict_protos.PredictRequest
}
// AddFeedFloat32 stores a float32 tensor as input number index of the request.
//
// The tensor is tagged TorchType_DT_FLOAT, uses shape as its ArrayShape
// dimensions and content as FloatVal. Both slices are stored as-is, without
// copying. If fewer than index+1 inputs exist, Inputs is first padded with
// empty ArrayProto entries. Any tensor already at index is replaced. A
// negative index panics.
func (tr *TorchRequest) AddFeedFloat32(index int, shape []int64, content []float32) {
	inputs := tr.RequestData.Inputs
	for len(inputs) < index+1 {
		inputs = append(inputs, &torch_predict_protos.ArrayProto{})
	}
	tensor := &torch_predict_protos.ArrayProto{
		Dtype:      TorchType_DT_FLOAT,
		ArrayShape: &torch_predict_protos.ArrayShape{Dim: shape},
		FloatVal:   content,
	}
	inputs[index] = tensor
	tr.RequestData.Inputs = inputs
}
// AddFeedFloat64 stores a float64 (double) tensor at input position index.
//
// The tensor is tagged TorchType_DT_DOUBLE, carries shape as its ArrayShape
// dimensions and content as DoubleVal. Both slices are stored as-is, not
// copied. Missing positions up to index are first filled with empty
// ArrayProto values, and any tensor already at index is overwritten. A
// negative index panics.
func (tr *TorchRequest) AddFeedFloat64(index int, shape []int64, content []float64) {
	need := index + 1
	for need > len(tr.RequestData.Inputs) {
		tr.RequestData.Inputs = append(tr.RequestData.Inputs, &torch_predict_protos.ArrayProto{})
	}
	dims := &torch_predict_protos.ArrayShape{Dim: shape}
	tr.RequestData.Inputs[index] = &torch_predict_protos.ArrayProto{
		Dtype:      TorchType_DT_DOUBLE,
		ArrayShape: dims,
		DoubleVal:  content,
	}
}
// AddFeedInt32 stores an int32 tensor at input position index.
//
// The tensor is tagged TorchType_DT_INT32, with shape as its ArrayShape
// dimensions and content as IntVal. The slices are referenced, not copied.
// Inputs is padded with empty ArrayProto placeholders until position index
// exists, and that position is then replaced. A negative index panics.
func (tr *TorchRequest) AddFeedInt32(index int, shape []int64, content []int32) {
	for len(tr.RequestData.Inputs) < index+1 {
		placeholder := new(torch_predict_protos.ArrayProto)
		tr.RequestData.Inputs = append(tr.RequestData.Inputs, placeholder)
	}

	var tensor torch_predict_protos.ArrayProto
	tensor.Dtype = TorchType_DT_INT32
	tensor.ArrayShape = &torch_predict_protos.ArrayShape{Dim: shape}
	tensor.IntVal = content

	tr.RequestData.Inputs[index] = &tensor
}
// AddFeedInt64 stores an int64 tensor at input position index.
//
// The tensor is tagged TorchType_DT_INT64 and uses shape for its ArrayShape
// dimensions and content for Int64Val. Neither slice is copied. Positions
// below index that do not exist yet are filled with empty ArrayProto values,
// and a tensor already at index is replaced. A negative index panics.
func (tr *TorchRequest) AddFeedInt64(index int, shape []int64, content []int64) {
	slots := tr.RequestData.Inputs
	for index+1 > len(slots) {
		slots = append(slots, &torch_predict_protos.ArrayProto{})
	}
	arrayShape := &torch_predict_protos.ArrayShape{Dim: shape}
	slots[index] = &torch_predict_protos.ArrayProto{
		Dtype:      TorchType_DT_INT64,
		ArrayShape: arrayShape,
		Int64Val:   content,
	}
	tr.RequestData.Inputs = slots
}
// AddFetch appends outIndex to RequestData.OutputFilter, the list of output
// indices requested for the response.
//
// Repeated calls accumulate, and duplicates are not removed.
func (tr *TorchRequest) AddFetch(outIndex int32) {
	filter := tr.RequestData.OutputFilter
	filter = append(filter, outIndex)
	tr.RequestData.OutputFilter = filter
}
// ToString marshals RequestData to protobuf wire format and returns those
// bytes as a string.
//
// A marshalling failure is returned as NewPredictError(-1, "", <error text>)
// with an empty string. The value receiver is kept as in the original,
// presumably so TorchRequest values satisfy the SDK's request interface;
// NOTE(review): this copies the struct on every call.
func (tr TorchRequest) ToString() (string, error) {
	payload, err := proto.Marshal(&tr.RequestData)
	if err == nil {
		return string(payload), nil
	}
	return "", NewPredictError(-1, "", err.Error())
}
// TorchResponse holds a decoded PyTorch prediction response.
//
// Its Get* accessors read individual output tensors from Response.Outputs by
// position.
type TorchResponse struct {
	// Response is the protobuf message populated by unmarshal.
	Response torch_predict_protos.PredictResponse
}
// GetTensorShape returns the dimensions of output tensor outIndex.
//
// If that output carries no ArrayShape, the result is nil rather than a
// nil-pointer panic. An outIndex outside Response.Outputs still panics, as
// with the other Get* accessors.
func (resp *TorchResponse) GetTensorShape(outIndex int) []int64 {
	// Generated protobuf getters return the zero value on a nil receiver, so a
	// missing ArrayShape (or a nil output entry) yields a nil slice instead of
	// dereferencing nil as `.ArrayShape.Dim` would.
	return resp.Response.Outputs[outIndex].GetArrayShape().GetDim()
}
// GetFloatVal returns the float32 payload (FloatVal) of output tensor
// outIndex.
//
// It panics if outIndex is outside Response.Outputs.
func (resp *TorchResponse) GetFloatVal(outIndex int) []float32 {
	out := resp.Response.Outputs[outIndex]
	return out.GetFloatVal()
}
// GetDoubleVal returns the float64 payload (DoubleVal) of output tensor
// outIndex.
//
// It panics if outIndex is outside Response.Outputs.
func (resp *TorchResponse) GetDoubleVal(outIndex int) []float64 {
	outputs := resp.Response.Outputs
	return outputs[outIndex].GetDoubleVal()
}
// GetIntVal returns the int32 payload (IntVal) of output tensor outIndex.
//
// It panics if outIndex is outside Response.Outputs.
func (resp *TorchResponse) GetIntVal(outIndex int) []int32 {
	tensor := resp.Response.Outputs[outIndex]
	values := tensor.GetIntVal()
	return values
}
// GetInt64Val returns the int64 payload (Int64Val) of output tensor
// outIndex.
//
// It panics if outIndex is outside Response.Outputs.
func (resp *TorchResponse) GetInt64Val(outIndex int) []int64 {
	outputs := resp.Response.Outputs
	tensor := outputs[outIndex]
	return tensor.GetInt64Val()
}
// unmarshal decodes a protobuf-encoded PredictResponse from body into
// resp.Response.
//
// Decoding targets a temporary message that is copied in only on success,
// so on error resp.Response is left unchanged. The decode error is returned
// as-is.
func (resp *TorchResponse) unmarshal(body []byte) error {
	decoded := new(torch_predict_protos.PredictResponse)
	if err := proto.Unmarshal(body, decoded); err != nil {
		return err
	}
	resp.Response = *decoded
	return nil
}