-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.go
137 lines (110 loc) · 3.34 KB
/
utils.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
package drill
import (
"encoding/binary"
"errors"
"io"
"net"
"github.com/zeroshade/go-drill/internal/rpc/proto/exec/rpc"
"github.com/zeroshade/go-drill/internal/rpc/proto/exec/user"
"google.golang.org/protobuf/proto"
)
// encoder abstracts how RPC messages are written to and read from a
// connection. rpcEncoder, defined alongside it, is the implementation in this
// file; the method descriptions below follow what rpcEncoder does.
type encoder interface {
	// WriteRaw writes the already-encoded bytes b to conn as one framed message.
	WriteRaw(conn net.Conn, b []byte) (int, error)
	// Write encodes msg together with the given mode, RPC type and
	// coordination id, then writes the result to conn as one framed message.
	Write(conn net.Conn, mode rpc.RpcMode, typ user.RpcType, coord int32, msg proto.Message) (int, error)

	// ReadRaw reads the next framed message from conn without decoding its body.
	ReadRaw(conn net.Conn) (*rpc.CompleteRpcMessage, error)
	// ReadMsg reads the next framed message from conn, decodes its body into
	// msg and returns the message header.
	ReadMsg(conn net.Conn, msg proto.Message) (*rpc.RpcHeader, error)
}
// rpcEncoder implements encoder. It holds no state: every method works
// directly on the net.Conn it is handed, framing messages with a varint length
// prefix and (de)serializing them as rpc.CompleteRpcMessage protobufs.
type rpcEncoder struct{}

// WriteRaw frames b with its length prefix and writes it to conn, returning
// the byte count and error reported by conn.Write.
func (rpcEncoder) WriteRaw(conn net.Conn, b []byte) (int, error) {
	framed := makePrefixedMessage(b)
	return conn.Write(framed)
}

// Write wraps msg in an rpc.CompleteRpcMessage whose header carries mode, typ
// and coord, and writes it to conn as a single framed message. If encoding
// fails, nothing is written and (0, err) is returned.
func (e rpcEncoder) Write(conn net.Conn, mode rpc.RpcMode, typ user.RpcType, coord int32, msg proto.Message) (int, error) {
	payload, err := encodeRPCMessage(mode, typ, coord, msg)
	if err != nil {
		return 0, err
	}
	// Framing and writing are exactly what WriteRaw does; reuse it.
	return e.WriteRaw(conn, payload)
}

// ReadRaw reads the next framed message from conn and parses the
// rpc.CompleteRpcMessage envelope, leaving its protobuf body undecoded.
func (rpcEncoder) ReadRaw(conn net.Conn) (*rpc.CompleteRpcMessage, error) {
	return readPrefixedRaw(conn)
}

// ReadMsg reads the next framed message from conn, unmarshals its protobuf
// body into msg and returns the message header.
func (rpcEncoder) ReadMsg(conn net.Conn, msg proto.Message) (*rpc.RpcHeader, error) {
	return readPrefixedMessage(conn, msg)
}
var (
	// errInvalidResponse is returned by readPrefixed when a frame's length
	// prefix is rejected as invalid.
	errInvalidResponse = errors.New("invalid response")
)
// makePrefixedMessage frames data for the wire by prepending its length,
// encoded as a base-128 unsigned varint (the encoding readPrefixed decodes).
//
// A nil data yields nil. A non-nil empty slice yields the single prefix byte
// 0x00.
//
// The result is built in a single allocation sized for the largest varint32
// prefix plus the payload. Allocating only the prefix would force append to
// reallocate and copy it to fit data.
func makePrefixedMessage(data []byte) []byte {
	if data == nil {
		return nil
	}
	buf := make([]byte, binary.MaxVarintLen32, binary.MaxVarintLen32+len(data))
	n := binary.PutUvarint(buf, uint64(len(data)))
	// buf[:n] has room for data, so this append does not allocate.
	return append(buf[:n], data...)
}
// readPrefixed reads one length-prefixed frame from r and returns its body.
//
// A frame is a base-128 unsigned varint of at most binary.MaxVarintLen32
// bytes holding the body length, followed by exactly that many body bytes.
// This is the layout makePrefixedMessage produces.
//
// The prefix is read one byte at a time so that no byte past the end of this
// frame is consumed from r. Reading a fixed-size chunk up front would swallow
// the start of the next frame whenever a frame is shorter than that chunk.
//
// Errors:
//   - io.ErrUnexpectedEOF if r ends before the frame is complete, including
//     when r is already at EOF before the first prefix byte.
//   - errInvalidResponse if the prefix runs longer than
//     binary.MaxVarintLen32 bytes, or declares a length that does not fit in
//     a non-negative int32.
//   - any other error returned by r, unchanged.
func readPrefixed(r io.Reader) ([]byte, error) {
	// The prefix is a varint32, so a legitimate frame length fits in a
	// non-negative int32. Rejecting anything larger stops a corrupt prefix
	// from forcing a multi-gigabyte allocation. It also keeps the size within
	// int range on 32-bit platforms, where make would otherwise panic.
	const maxFrameLen = 1<<31 - 1

	// io.ReadFull reports "nothing read" as io.EOF. For a frame reader that is
	// still a truncated frame, so surface it as io.ErrUnexpectedEOF.
	noEOF := func(err error) error {
		if err == io.EOF {
			return io.ErrUnexpectedEOF
		}
		return err
	}

	var prefix [binary.MaxVarintLen32]byte
	n := 0
	// Keep reading while the most recent byte has its continuation bit set.
	for n == 0 || prefix[n-1]&0x80 != 0 {
		if n == len(prefix) {
			return nil, errInvalidResponse
		}
		if _, err := io.ReadFull(r, prefix[n:n+1]); err != nil {
			return nil, noEOF(err)
		}
		n++
	}

	length, vlen := binary.Uvarint(prefix[:n])
	if vlen <= 0 || length > maxFrameLen {
		return nil, errInvalidResponse
	}

	body := make([]byte, length)
	if _, err := io.ReadFull(r, body); err != nil {
		return nil, noEOF(err)
	}
	return body, nil
}
// readPrefixedRaw reads one length-prefixed frame from r and parses it as an
// rpc.CompleteRpcMessage envelope, leaving the protobuf body undecoded.
// Framing errors from readPrefixed are returned unchanged.
func readPrefixedRaw(r io.Reader) (*rpc.CompleteRpcMessage, error) {
	frame, readErr := readPrefixed(r)
	if readErr != nil {
		return nil, readErr
	}
	return getRawRPCMessage(frame)
}
// readPrefixedMessage reads one length-prefixed frame from r, decodes its
// body into msg and returns the RPC header produced by decodeRPCMessage.
// Framing errors from readPrefixed are returned unchanged with a nil header.
func readPrefixedMessage(r io.Reader, msg proto.Message) (*rpc.RpcHeader, error) {
	frame, readErr := readPrefixed(r)
	if readErr != nil {
		return nil, readErr
	}
	return decodeRPCMessage(frame, msg)
}
// encodeRPCMessage marshals msg and wraps the resulting bytes in an
// rpc.CompleteRpcMessage. The envelope's header records mode, msgType (stored
// as an int32) and coordID. It returns the marshalled envelope, or the first
// marshalling error encountered.
func encodeRPCMessage(mode rpc.RpcMode, msgType user.RpcType, coordID int32, msg proto.Message) ([]byte, error) {
	body, err := proto.Marshal(msg)
	if err != nil {
		return nil, err
	}

	// The header fields are pointer-typed. mode is a by-value parameter, so
	// taking its address hands the header a private copy.
	header := &rpc.RpcHeader{
		Mode:           &mode,
		CoordinationId: proto.Int32(coordID),
		RpcType:        proto.Int32(int32(msgType)),
	}
	return proto.Marshal(&rpc.CompleteRpcMessage{
		Header:       header,
		ProtobufBody: body,
	})
}
// getRawRPCMessage unmarshals data into a freshly allocated
// rpc.CompleteRpcMessage without decoding the embedded protobuf body.
// On failure it returns nil and the unmarshal error.
func getRawRPCMessage(data []byte) (*rpc.CompleteRpcMessage, error) {
	envelope := new(rpc.CompleteRpcMessage)
	if err := proto.Unmarshal(data, envelope); err != nil {
		return nil, err
	}
	return envelope, nil
}
// decodeRPCMessage parses data as an rpc.CompleteRpcMessage, unmarshals its
// protobuf body into msg and returns the envelope's header.
//
// If the envelope itself cannot be parsed, the header is nil. If only the
// body fails to unmarshal, the parsed header (possibly nil) is still returned
// alongside that error, so callers can inspect it.
func decodeRPCMessage(data []byte, msg proto.Message) (*rpc.RpcHeader, error) {
	envelope, err := getRawRPCMessage(data)
	if err != nil {
		return nil, err
	}
	bodyErr := proto.Unmarshal(envelope.ProtobufBody, msg)
	return envelope.GetHeader(), bodyErr
}