forked from colinmarc/hdfs
/
sasl_transport.go
82 lines (70 loc) · 2.25 KB
/
sasl_transport.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
package rpc
import (
"bytes"
"fmt"
"io"
hadoop "github.com/stubey/hdfs/v2/internal/protocol/hadoop_common"
"github.com/golang/protobuf/proto"
"gopkg.in/jcmturner/gokrb5.v7/crypto"
"gopkg.in/jcmturner/gokrb5.v7/gssapi"
"gopkg.in/jcmturner/gokrb5.v7/iana/keyusage"
krbtypes "gopkg.in/jcmturner/gokrb5.v7/types"
)
// saslTransport implements encrypted or signed RPC.
//
// It layers GSS-API message wrapping on top of basicTransport: each response
// arrives as an RpcSaslProto in the WRAP state whose token carries the real
// RPC response, protected either by encryption (privacy) or by a checksum.
type saslTransport struct {
	// basicTransport handles the outer RPC framing; readResponse uses it to
	// read the SASL envelope before unwrapping the inner response.
	basicTransport

	// sessionKey is the encryption key used to decrypt (and, on the write
	// path, presumably encrypt) the payload. When privacy is false it is the
	// key used to verify the wrap token's checksum instead.
	sessionKey krbtypes.EncryptionKey

	// privacy indicates full message encryption. When false, payloads are
	// only integrity-protected: they are read in the clear after checksum
	// verification.
	privacy bool
}
// readResponse reads a SASL-wrapped RPC response from r and decodes it into
// resp.
//
// The outer envelope is read with basicTransport as an RpcSaslProto, which
// must be in the WRAP state. Its token is a GSS-API wrap token from the
// acceptor. If t.privacy is set, the token payload is decrypted with
// t.sessionKey. Otherwise, the token's checksum is verified and the payload
// is used as-is. In both cases the resulting bytes are parsed as an ordinary
// RPC response packet (header plus resp).
//
// It returns errUnexpectedSequenceNumber if the inner header's call ID does
// not match requestID, and a *NamenodeError if the inner header reports a
// non-SUCCESS status.
func (t *saslTransport) readResponse(r io.Reader, method string, requestID int32, resp proto.Message) error {
	// First, read the sasl payload as a standard rpc response.
	sasl := hadoop.RpcSaslProto{}
	if err := t.basicTransport.readResponse(r, method, saslRpcCallId, &sasl); err != nil {
		// Returned unwrapped: callers may compare this error directly.
		return err
	}
	if sasl.GetState() != hadoop.RpcSaslProto_WRAP {
		return fmt.Errorf("unexpected SASL state: %s", sasl.GetState().String())
	}

	// The SaslProto contains the actual payload, wrapped by the acceptor.
	var wrapToken gssapi.WrapToken
	if err := wrapToken.Unmarshal(sasl.GetToken(), true); err != nil {
		return fmt.Errorf("unmarshaling SASL wrap token: %w", err)
	}

	// Unwrap the token into the bytes of a normal RPC response packet.
	var packet []byte
	if t.privacy {
		decrypted, err := crypto.DecryptMessage(wrapToken.Payload, t.sessionKey, keyusage.GSSAPI_ACCEPTOR_SEAL)
		if err != nil {
			return fmt.Errorf("decrypting SASL payload: %w", err)
		}
		packet = decrypted
	} else {
		ok, err := wrapToken.Verify(t.sessionKey, keyusage.GSSAPI_ACCEPTOR_SEAL)
		if err != nil {
			return fmt.Errorf("unverifiable message from namenode: %w", err)
		}
		// Fail closed: never accept a payload whose checksum was not
		// positively verified, even if no error was reported.
		if !ok {
			return fmt.Errorf("unverifiable message from namenode: checksum verification failed")
		}
		packet = wrapToken.Payload
	}

	rrh := &hadoop.RpcResponseHeaderProto{}
	if err := readRPCPacket(bytes.NewReader(packet), rrh, resp); err != nil {
		// Returned unwrapped: callers may compare I/O errors directly.
		return err
	}

	switch {
	case int32(rrh.GetCallId()) != requestID:
		return errUnexpectedSequenceNumber
	case rrh.GetStatus() != hadoop.RpcResponseHeaderProto_SUCCESS:
		return &NamenodeError{
			method:    method,
			message:   rrh.GetErrorMsg(),
			code:      int(rrh.GetErrorDetail()),
			exception: rrh.GetExceptionClassName(),
		}
	}
	return nil
}