diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..acb88a4 --- /dev/null +++ b/go.sum @@ -0,0 +1,11 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/tencentyun/TLSSigAPI.go b/tencentyun/TLSSigAPI.go index 0699f55..5dabf9d 100644 --- a/tencentyun/TLSSigAPI.go +++ b/tencentyun/TLSSigAPI.go @@ -8,8 +8,10 @@ import ( "encoding/base64" "encoding/json" "errors" + "io" "io/ioutil" "strconv" + "sync" "time" ) @@ -23,7 +25,7 @@ import ( * expire - UserSig 票据的过期时间,单位是秒,比如 86400 代表生成的 UserSig 票据在一天后就无法再使用了。 */ - /** +/** * Function: Used to issue UserSig that is required by the TRTC and IM services. * * Parameter description: @@ -94,7 +96,6 @@ func GenUserSigWithBuf(sdkappid int, key string, userid string, expire int, buf * - privilegeMap == 0010 1010 == 42: Indicates that the UserID has only the permissions to enter the room and receive audio/video data. */ - func GenPrivateMapKey(sdkappid int, key string, userid string, expire int, roomid uint32, privilegeMap uint32) (string, error) { var userbuf []byte = genUserBuf(userid, sdkappid, roomid, expire, privilegeMap, 0, "") return genSig(sdkappid, key, userid, expire, userbuf) @@ -252,52 +253,28 @@ func genUserBuf(account string, dwSdkappid int, dwAuthID uint32, return userBuf } -func hmacsha256(sdkappid int, key string, identifier string, currTime int64, expire int, base64UserBuf *string) string { - var contentToBeSigned string - contentToBeSigned = "TLS.identifier:" + identifier + "\n" - contentToBeSigned += "TLS.sdkappid:" + strconv.Itoa(sdkappid) + "\n" - contentToBeSigned += "TLS.time:" + strconv.FormatInt(currTime, 10) + "\n" - contentToBeSigned += "TLS.expire:" + strconv.Itoa(expire) + "\n" - if nil != base64UserBuf { - contentToBeSigned += "TLS.userbuf:" + *base64UserBuf + "\n" - } - - h := hmac.New(sha256.New, []byte(key)) - h.Write([]byte(contentToBeSigned)) - return base64.StdEncoding.EncodeToString(h.Sum(nil)) -} - func genSig(sdkappid int, key string, identifier string, expire int, userbuf []byte) (string, error) { currTime := time.Now().Unix() - sigDoc := make(map[string]interface{}) - sigDoc["TLS.ver"] = "2.0" - sigDoc["TLS.identifier"] = identifier - sigDoc["TLS.sdkappid"] = sdkappid - sigDoc["TLS.expire"] = expire - sigDoc["TLS.time"] = currTime - var base64UserBuf string - if nil != userbuf { - base64UserBuf = base64.StdEncoding.EncodeToString(userbuf) - sigDoc["TLS.userbuf"] = base64UserBuf - sigDoc["TLS.sig"] = hmacsha256(sdkappid, key, identifier, currTime, expire, &base64UserBuf) - } else { - sigDoc["TLS.sig"] = hmacsha256(sdkappid, key, identifier, currTime, expire, nil) - } - - data, err := json.Marshal(sigDoc) - if err != nil { - return "", err + sigDoc := userSig{ + Version: "2.0", + Identifier: identifier, + SdkAppID: uint64(sdkappid), + Expire: int64(expire), + Time: currTime, + UserBuf: userbuf, } + sigDoc.Sig = sigDoc.sign(key) var b bytes.Buffer - w := zlib.NewWriter(&b) - if _, err = w.Write(data); err != nil { + w := newZlibWriter(&b) + defer zlibWriterPool.Put(w) + if err := json.NewEncoder(w).Encode(sigDoc); err != nil { return "", err } - if err = w.Close(); err != nil { + if err := w.Close(); err != nil { return "", err } - return base64urlEncode(b.Bytes()), nil + return base64url.EncodeToString(b.Bytes()), nil } // VerifyUserSig 检验UserSig在now时间点时是否有效 @@ -327,7 +304,7 @@ type userSig struct { Expire int64 `json:"TLS.expire,omitempty"` Time int64 `json:"TLS.time,omitempty"` UserBuf []byte `json:"TLS.userbuf,omitempty"` - Sig string `json:"TLS.sig,omitempty"` + Sig []byte `json:"TLS.sig,omitempty"` } func newUserSig(usersig string) (userSig, error) { @@ -373,35 +350,41 @@ func (u userSig) verify(sdkappid uint64, key string, userid string, now time.Tim } else if u.UserBuf != nil { return ErrUserBufTypeNotMatch } - if u.sign(key) != u.Sig { + if !bytes.Equal(u.sign(key), u.Sig) { return ErrSigNotMatch } return nil } -func (u userSig) sign(key string) string { - var sb bytes.Buffer - sb.WriteString("TLS.identifier:") - sb.WriteString(u.Identifier) - sb.WriteString("\n") - sb.WriteString("TLS.sdkappid:") - sb.WriteString(strconv.FormatUint(u.SdkAppID, 10)) - sb.WriteString("\n") - sb.WriteString("TLS.time:") - sb.WriteString(strconv.FormatInt(u.Time, 10)) - sb.WriteString("\n") - sb.WriteString("TLS.expire:") - sb.WriteString(strconv.FormatInt(u.Expire, 10)) - sb.WriteString("\n") - if u.UserBuf != nil { - sb.WriteString("TLS.userbuf:") - sb.WriteString(base64.StdEncoding.EncodeToString(u.UserBuf)) - sb.WriteString("\n") - } +var ( + sigIdentifier = []byte("TLS.identifier:") + sigSdkAppID = []byte("TLS.sdkappid:") + sigTime = []byte("TLS.time:") + sigExpire = []byte("TLS.expire:") + sigUserBuf = []byte("TLS.userbuf:") + sigEnter = []byte("\n") +) +func (u userSig) sign(key string) []byte { h := hmac.New(sha256.New, []byte(key)) - h.Write(sb.Bytes()) - return base64.StdEncoding.EncodeToString(h.Sum(nil)) + h.Write(sigIdentifier) + h.Write([]byte(u.Identifier)) + h.Write(sigEnter) + h.Write(sigSdkAppID) + h.Write([]byte(strconv.FormatUint(u.SdkAppID, 10))) + h.Write(sigEnter) + h.Write(sigTime) + h.Write([]byte(strconv.FormatInt(u.Time, 10))) + h.Write(sigEnter) + h.Write(sigExpire) + h.Write([]byte(strconv.FormatInt(u.Expire, 10))) + h.Write(sigEnter) + if u.UserBuf != nil { + h.Write(sigUserBuf) + h.Write([]byte(base64.StdEncoding.EncodeToString(u.UserBuf))) + h.Write(sigEnter) + } + return h.Sum(nil) } // 错误类型 @@ -413,3 +396,26 @@ var ( ErrUserBufNotMatch = errors.New("userbuf not match") ErrSigNotMatch = errors.New("sig not match") ) + +var ( + zlibWriterPool sync.Pool +) + +func newZlibWriter(w io.Writer) *zlib.Writer { + v := zlibWriterPool.Get() + if v == nil { + zw, err := zlib.NewWriterLevel(w, DefaultCompressionLevel) + if err != nil { + return zlib.NewWriter(w) + } + return zw + } + zw := v.(*zlib.Writer) + zw.Reset(w) + return zw +} + +// DefaultCompressionLevel is the default compression level. +// Default is zlib.NoCompression. +// It can be set to any valid compression level to balance speed and size. +var DefaultCompressionLevel = zlib.NoCompression diff --git a/tencentyun/TLSSigAPI_test.go b/tencentyun/TLSSigAPI_test.go index 06f3593..a235996 100644 --- a/tencentyun/TLSSigAPI_test.go +++ b/tencentyun/TLSSigAPI_test.go @@ -32,3 +32,15 @@ func TestGenAndVerify(t *testing.T) { assert.Equal(t, ErrUserBufTypeNotMatch, VerifyUserSigWithBuf(1, "3", "3", bufSig, now, nil)) assert.Equal(t, ErrUserBufNotMatch, VerifyUserSigWithBuf(1, "3", "3", bufSig, now, []byte{6})) } + +func BenchmarkGenSig(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _ = GenUserSig(1, "abc", "a", 1) + } +} + +func BenchmarkGenUserSigWithBuf(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _ = GenUserSigWithBuf(1, "abc", "a", 1, []byte{1}) + } +} diff --git a/tencentyun/base64url.go b/tencentyun/base64url.go index c6a9c86..1bf9d1e 100644 --- a/tencentyun/base64url.go +++ b/tencentyun/base64url.go @@ -5,13 +5,7 @@ import ( "strings" ) -func base64urlEncode(data []byte) string { - str := base64.StdEncoding.EncodeToString(data) - str = strings.Replace(str, "+", "*", -1) - str = strings.Replace(str, "/", "-", -1) - str = strings.Replace(str, "=", "_", -1) - return str -} +var base64url = base64.NewEncoding("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789*-").WithPadding('_') func base64urlDecode(str string) ([]byte, error) { str = strings.Replace(str, "_", "=", -1)