This repository has been archived by the owner on Jan 5, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 42
/
auth.go
206 lines (178 loc) · 5.25 KB
/
auth.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
/*
* go-mysqlstack
* xelabs.org
*
* Copyright (c) XeLabs
* GPL License
*
*/
package proto
import (
"crypto/sha1"
"fmt"
"github.com/xelabs/go-mysqlstack/sqldb"
"github.com/xelabs/go-mysqlstack/sqlparser/depends/common"
)
// Auth holds the fields of a client handshake response packet. UnPack fills
// them in from a received payload.
type Auth struct {
	charset         uint8  // character set byte
	maxPacketSize   uint32 // max packet size field
	authResponseLen uint8  // one-byte length prefix of authResponse; only read when CLIENT_SECURE_CONNECTION is set
	clientFlags     uint32 // client capability flags
	authResponse    []byte // auth response bytes; CleanAuthResponse sets this to nil

	// NUL-terminated strings from the packet. pluginName and database are
	// only read when the matching capability flag is present in clientFlags.
	pluginName, database, user string
}
// NewAuth returns a pointer to a zero-valued Auth, ready to be filled in by
// UnPack or used to Pack a handshake response.
func NewAuth() *Auth {
	return new(Auth)
}
// Database returns the database name held by a. UnPack only assigns it when
// the client flags include CLIENT_CONNECT_WITH_DB.
func (a *Auth) Database() string {
	return a.database
}
// ClientFlags returns the client capability flags held by a, i.e. the first
// 32-bit value UnPack reads from the payload.
func (a *Auth) ClientFlags() uint32 {
	return a.clientFlags
}
// Charset returns the character set byte held by a, as read by UnPack.
func (a *Auth) Charset() uint8 {
	return a.charset
}
// User returns the user name held by a, as read by UnPack from the
// NUL-terminated user field.
func (a *Auth) User() string {
	return a.user
}
// AuthResponse returns the auth response bytes held by a. The slice is
// returned as stored, not copied, and is nil after CleanAuthResponse.
func (a *Auth) AuthResponse() []byte {
	return a.authResponse
}
// CleanAuthResponse sets the stored auth response to nil so that a no longer
// references those bytes. This is meant to reduce heap GC cost once the
// response is no longer needed.
func (a *Auth) CleanAuthResponse() {
	a.authResponse = nil
}
// UnPack parses the handshake response sent by the client and stores the
// decoded fields on a.
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse41
//
// Fields are read in wire order: client flags, max packet size, charset,
// 23 reserved zero bytes, NUL-terminated user, auth response, then the
// optional database and auth plugin name.
//
// An error is returned when a field cannot be read, when the client flags do
// not include CLIENT_PROTOCOL_41, or when the client names an auth plugin
// other than DefaultAuthPluginName. Fields decoded before the failing one
// remain set on a when an error is returned.
func (a *Auth) UnPack(payload []byte) error {
	var err error
	buf := common.ReadBuffer(payload)

	if a.clientFlags, err = buf.ReadU32(); err != nil {
		return fmt.Errorf("auth.unpack: can't read client flags")
	}
	if a.clientFlags&sqldb.CLIENT_PROTOCOL_41 == 0 {
		return fmt.Errorf("auth.unpack: only support protocol 4.1")
	}
	if a.maxPacketSize, err = buf.ReadU32(); err != nil {
		return fmt.Errorf("auth.unpack: can't read maxPacketSize")
	}
	if a.charset, err = buf.ReadU8(); err != nil {
		return fmt.Errorf("auth.unpack: can't read charset")
	}
	if err = buf.ReadZero(23); err != nil {
		return fmt.Errorf("auth.unpack: can't read 23zeros")
	}
	if a.user, err = buf.ReadStringNUL(); err != nil {
		return fmt.Errorf("auth.unpack: can't read user")
	}

	if (a.clientFlags & sqldb.CLIENT_SECURE_CONNECTION) > 0 {
		// One length byte followed by that many bytes of auth response.
		if a.authResponseLen, err = buf.ReadU8(); err != nil {
			return fmt.Errorf("auth.unpack: can't read authResponse length")
		}
		if a.authResponse, err = buf.ReadBytes(int(a.authResponseLen)); err != nil {
			return fmt.Errorf("auth.unpack: can't read authResponse")
		}
	} else {
		// No length prefix: a fixed 20-byte response followed by one zero byte.
		if a.authResponse, err = buf.ReadBytes(20); err != nil {
			return fmt.Errorf("auth.unpack: can't read authResponse")
		}
		if err = buf.ReadZero(1); err != nil {
			return fmt.Errorf("auth.unpack: can't read authResponse")
		}
	}

	if (a.clientFlags & sqldb.CLIENT_CONNECT_WITH_DB) > 0 {
		if a.database, err = buf.ReadStringNUL(); err != nil {
			return fmt.Errorf("auth.unpack: can't read dbname")
		}
	}

	if (a.clientFlags & sqldb.CLIENT_PLUGIN_AUTH) > 0 {
		if a.pluginName, err = buf.ReadStringNUL(); err != nil {
			return fmt.Errorf("auth.unpack: can't read pluginName")
		}
	} else {
		// The plugin name is only on the wire when CLIENT_PLUGIN_AUTH is set.
		// A client without that capability cannot name a plugin, so use the
		// default instead of rejecting it for an empty (or stale) name.
		a.pluginName = DefaultAuthPluginName
	}
	if a.pluginName != DefaultAuthPluginName {
		return fmt.Errorf("invalid authPluginName, got %v but only support %v", a.pluginName, DefaultAuthPluginName)
	}
	return nil
}
// Pack builds a HandshakeResponse41 payload and returns its bytes.
//
// capabilityFlags is written as the client capability flags after two
// adjustments: CLIENT_CONNECT_WITH_DB is set or cleared depending on whether
// database is non-empty, and CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA is cleared
// because the auth response is never written with a length-encoded length.
// Pack does not add CLIENT_PROTOCOL_41 itself; UnPack rejects packets
// without it.
//
// charset is written as the character set byte. username and database are
// written NUL-terminated, and database is omitted when empty. password and
// salt are passed to nativePassword to produce the auth response. The max
// packet size is written as 0 and DefaultAuthPluginName is always appended as
// the auth plugin name. The receiver is neither read nor modified.
func (a *Auth) Pack(capabilityFlags uint32, charset uint8, username string, password string, salt []byte, database string) []byte {
	buf := common.NewBuffer(256)
	authResponse := nativePassword(password, salt)

	// Settle the advertised capabilities before they are serialized: once
	// WriteU32 below has run, later changes never reach the packet.
	if len(database) > 0 {
		capabilityFlags |= sqldb.CLIENT_CONNECT_WITH_DB
	} else {
		capabilityFlags &= ^sqldb.CLIENT_CONNECT_WITH_DB
	}
	// The auth response below uses a fixed one-byte length (or a trailing
	// zero byte), so do not advertise length-encoded client data.
	capabilityFlags &= ^sqldb.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA

	// 4 capability flags
	buf.WriteU32(capabilityFlags)

	// 4 max-packet size (none)
	buf.WriteU32(0)

	// 1 character set
	buf.WriteU8(charset)

	// string[23] reserved (all [0])
	buf.WriteZero(23)

	// string[NUL] username
	buf.WriteString(username)
	buf.WriteZero(1)

	if (capabilityFlags & sqldb.CLIENT_SECURE_CONNECTION) > 0 {
		// 1 length of auth-response
		// string[n] auth-response
		buf.WriteU8(uint8(len(authResponse)))
		buf.WriteBytes(authResponse)
	} else {
		// auth-response followed by one zero byte
		buf.WriteBytes(authResponse)
		buf.WriteZero(1)
	}

	// string[NUL] database
	if capabilityFlags&sqldb.CLIENT_CONNECT_WITH_DB > 0 {
		buf.WriteString(database)
		buf.WriteZero(1)
	}

	// string[NUL] auth plugin name
	buf.WriteString(DefaultAuthPluginName)
	buf.WriteZero(1)

	// CLIENT_CONNECT_ATTRS none
	return buf.Datas()
}
// nativePassword computes the Native41 (4.1+) auth response for password:
//
//	SHA1(password) XOR SHA1(salt <concat> SHA1(SHA1(password)))
//
// where salt is the random data sent by the server.
// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html#packet-Authentication::Native41
//
// An empty password yields nil; otherwise the result is sha1.Size bytes long.
func nativePassword(password string, salt []byte) []byte {
	if password == "" {
		return nil
	}

	// SHA1(password)
	pwHash := sha1.Sum([]byte(password))
	// SHA1(SHA1(password))
	pwHashHash := sha1.Sum(pwHash[:])

	// SHA1(salt <concat> SHA1(SHA1(password))), fed to the hash in two
	// writes so the parts never need to be joined in a temporary slice.
	h := sha1.New()
	h.Write(salt)
	h.Write(pwHashHash[:])
	out := h.Sum(nil)

	// XOR the salted digest with SHA1(password) in place.
	for i := range out {
		out[i] ^= pwHash[i]
	}
	return out
}