Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
fix read bug
fix tcp read decode packet
  • Loading branch information
sun8911879 committed Mar 20, 2018
1 parent 5788c80 commit 7595b5f
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 59 deletions.
2 changes: 1 addition & 1 deletion obfs/base.go
Expand Up @@ -15,7 +15,7 @@ var (
type IObfs interface {
SetServerInfo(s *ssr.ServerInfoForObfs)
GetServerInfo() (s *ssr.ServerInfoForObfs)
Encode(data []byte) (encodedData []byte, err error)
Encode(data []byte) ([]byte, error)
Decode(data []byte) ([]byte, uint64, error)
SetData(data interface{})
GetData() interface{}
Expand Down
3 changes: 1 addition & 2 deletions obfs/tls12_ticket_auth.go
Expand Up @@ -206,15 +206,14 @@ func (t *tls12TicketAuth) Encode(data []byte) (encodedData []byte, err error) {
return
}

//return (outData []byte, length uint64, err error)
func (t *tls12TicketAuth) Decode(data []byte) ([]byte, uint64, error) {
if t.handshakeStatus == -1 {
return data, 0, nil
}
dataLength := len(data)

if t.handshakeStatus == 8 {
if !hmac.Equal(data[0:3], []byte{0x17, 0x3, 0x3}) {
if data[0] != 0x17 {
return nil, 0, ssr.ErrTLS12TicketAuthIncorrectMagicNumber
}
size := int(binary.BigEndian.Uint16(data[3:5]))
Expand Down
23 changes: 13 additions & 10 deletions protocol/auth_aes128_md5.go
Expand Up @@ -241,39 +241,42 @@ func (a *authAES128) PreEncrypt(plainData []byte) (outData []byte, err error) {
return
}

func (a *authAES128) PostDecrypt(plainData []byte) ([]byte, error) {
func (a *authAES128) PostDecrypt(plainData []byte) ([]byte, int, error) {
a.buffer.Reset()
plainLength := len(plainData)
datalength := plainLength
readlenth := 0
key := make([]byte, len(a.userKey)+4)
copy(key, a.userKey)
for plainLength > 4 {
binary.LittleEndian.PutUint32(key[len(key)-4:], a.recvID)

h := a.hmac(key, plainData[0:2])

if h[0] != plainData[2] || h[1] != plainData[3] {
return nil, ssr.ErrAuthAES128HMACError
return nil, 0, ssr.ErrAuthAES128HMACError
}

length := int(binary.LittleEndian.Uint16(plainData[0:2]))
if length >= 8192 || length < 8 {
return nil, ssr.ErrAuthAES128DataLengthError
return nil, 0, ssr.ErrAuthAES128DataLengthError
}

if length > plainLength {
break
}

a.recvID++
pos := int(plainData[4])
if pos != 0xFF {
if pos < 255 {
pos += 4
} else {
pos = int(binary.LittleEndian.Uint16(plainData[5:5+2])) + 4
pos = int(binary.LittleEndian.Uint16(plainData[5:7])) + 4
}

a.buffer.Write(plainData[pos : length-4])
plainData = plainData[length:]
plainLength -= length
readlenth += length
}
if datalength == readlenth {
readlenth = -1
}
return a.buffer.Bytes(), nil
return a.buffer.Bytes(), readlenth, nil
}
11 changes: 6 additions & 5 deletions protocol/auth_sha1_v4.go
Expand Up @@ -185,7 +185,8 @@ func (a *authSHA1v4) PreEncrypt(plainData []byte) (outData []byte, err error) {
return
}

func (a *authSHA1v4) PostDecrypt(plainData []byte) (outData []byte, err error) {
func (a *authSHA1v4) PostDecrypt(plainData []byte) ([]byte, int, error) {
var outData []byte
dataLength := len(plainData)
b := make([]byte, len(a.recvBuffer)+dataLength)
copy(b, a.recvBuffer)
Expand All @@ -195,13 +196,13 @@ func (a *authSHA1v4) PostDecrypt(plainData []byte) (outData []byte, err error) {
for a.recvBufferLength > 4 {
crc32 := ssr.CalcCRC32(a.recvBuffer, 2, 0xFFFFFFFF)
if binary.LittleEndian.Uint16(a.recvBuffer[2:4]) != uint16(crc32&0xFFFF) {
return nil, ssr.ErrAuthSHA1v4CRC32Error
return nil, 0, ssr.ErrAuthSHA1v4CRC32Error
}
length := int(binary.BigEndian.Uint16(a.recvBuffer[0:2]))
if length >= 8192 || length < 8 {
a.recvBufferLength = 0
a.recvBuffer = nil
return nil, ssr.ErrAuthSHA1v4DataLengthError
return nil, 0, ssr.ErrAuthSHA1v4DataLengthError
}
if length > a.recvBufferLength {
break
Expand All @@ -224,8 +225,8 @@ func (a *authSHA1v4) PostDecrypt(plainData []byte) (outData []byte, err error) {
} else {
a.recvBufferLength = 0
a.recvBuffer = nil
return nil, ssr.ErrAuthSHA1v4IncorrectChecksum
return nil, 0, ssr.ErrAuthSHA1v4IncorrectChecksum
}
}
return
return outData, 0, nil
}
4 changes: 2 additions & 2 deletions protocol/base.go
Expand Up @@ -15,8 +15,8 @@ var (
type IProtocol interface {
SetServerInfo(s *ssr.ServerInfoForObfs)
GetServerInfo() *ssr.ServerInfoForObfs
PreEncrypt(data []byte) (encryptedData []byte, err error)
PostDecrypt(data []byte) (decryptedData []byte, err error)
PreEncrypt(data []byte) ([]byte, error)
PostDecrypt(data []byte) ([]byte, int, error)
SetData(data interface{})
GetData() interface{}
}
Expand Down
4 changes: 2 additions & 2 deletions protocol/origin.go
Expand Up @@ -29,8 +29,8 @@ func (o *origin) PreEncrypt(data []byte) (encryptedData []byte, err error) {
return data, nil
}

func (o *origin) PostDecrypt(data []byte) (decryptedData []byte, err error) {
return data, nil
func (o *origin) PostDecrypt(data []byte) ([]byte, int, error) {
return data, 0, nil
}

func (o *origin) SetData(data interface{}) {
Expand Down
4 changes: 2 additions & 2 deletions protocol/verify_sha1.go
Expand Up @@ -96,6 +96,6 @@ func (v *verifySHA1) PreEncrypt(data []byte) (encryptedData []byte, err error) {
return
}

func (v *verifySHA1) PostDecrypt(data []byte) (decryptedData []byte, err error) {
return data, nil
func (v *verifySHA1) PostDecrypt(data []byte) ([]byte, int, error) {
return data, 0, nil
}
90 changes: 55 additions & 35 deletions tcp.go
Expand Up @@ -16,24 +16,28 @@ type SSTCPConn struct {
net.Conn
sync.RWMutex
*StreamCipher
IObfs obfs.IObfs
IProtocol protocol.IProtocol
readBuf []byte
readDecodeBuf *bytes.Buffer
readIndex uint64
readUserBuf *bytes.Buffer
writeBuf []byte
lastReadError error
IObfs obfs.IObfs
IProtocol protocol.IProtocol
readBuf []byte
readDecodeBuf *bytes.Buffer
readIObfsBuf *bytes.Buffer
readEncryptBuf *bytes.Buffer
readIndex uint64
readUserBuf *bytes.Buffer
writeBuf []byte
lastReadError error
}

func NewSSTCPConn(c net.Conn, cipher *StreamCipher) *SSTCPConn {
return &SSTCPConn{
Conn: c,
StreamCipher: cipher,
readBuf: leakybuf.GlobalLeakyBuf.Get(),
readDecodeBuf: bytes.NewBuffer(nil),
readUserBuf: bytes.NewBuffer(nil),
writeBuf: leakybuf.GlobalLeakyBuf.Get(),
Conn: c,
StreamCipher: cipher,
readBuf: leakybuf.GlobalLeakyBuf.Get(),
readDecodeBuf: bytes.NewBuffer(nil),
readIObfsBuf: bytes.NewBuffer(nil),
readUserBuf: bytes.NewBuffer(nil),
readEncryptBuf: bytes.NewBuffer(nil),
writeBuf: leakybuf.GlobalLeakyBuf.Get(),
}
}

Expand Down Expand Up @@ -85,7 +89,8 @@ func (c *SSTCPConn) Read(b []byte) (n int, err error) {
}
//未读取够长度继续读取并解码
decodelength := c.readDecodeBuf.Len()
if (decodelength == 0 || (c.readIndex != 0 && c.readIndex > uint64(decodelength))) && c.lastReadError == nil {
if (decodelength == 0 || c.readEncryptBuf.Len() > 0 || (c.readIndex != 0 && c.readIndex > uint64(decodelength))) && c.lastReadError == nil {
c.readIndex = 0
n, c.lastReadError = c.Conn.Read(c.readBuf)
//写入decode 缓存
c.readDecodeBuf.Write(c.readBuf[0:n])
Expand All @@ -97,9 +102,10 @@ func (c *SSTCPConn) Read(b []byte) (n int, err error) {
decodelength = c.readDecodeBuf.Len()
decodebytes := c.readDecodeBuf.Bytes()
c.readDecodeBuf.Reset()

for {
decodedData, length, err := c.IObfs.Decode(decodebytes)

decodedData, length, err := c.IObfs.Decode(decodebytes)
if length == 0 && err != nil {
return 0, err
}
Expand All @@ -110,34 +116,36 @@ func (c *SSTCPConn) Read(b []byte) (n int, err error) {
return 0, nil
}

//未读取完全数据
//数据不够长度
if err != nil && length > 5 {
if uint64(decodelength) > length {
return 0, fmt.Errorf("data length: %d,decode data length: %d unknown panic", decodelength, length)
}
c.readIndex = length
//c.readDecodeBuf.Write(decodebytes)
c.readDecodeBuf.Write(decodebytes)
if c.readIObfsBuf.Len() == 0 {
return 0, nil
}
break
}
//完全读取数据
if length == 0 {
c.readDecodeBuf.Write(decodedData)
decodebytes = decodebytes[:0]

if length > 1 {
//读出数据 但是有多余的数据 返回已经读取数值
c.readIObfsBuf.Write(decodedData)
decodebytes = decodebytes[length:]
decodelength = len(decodebytes)
break
}
c.readDecodeBuf.Write(decodedData)
decodebytes = decodebytes[length:]
decodelength = len(decodebytes)
if decodelength < 5 {
break
continue
}

//完全读取数据 -- length == 0
c.readIObfsBuf.Write(decodedData)
break
}
decodedData := c.readDecodeBuf.Bytes()
decodelength = c.readDecodeBuf.Len()
c.readDecodeBuf.Reset()
c.readDecodeBuf.Write(decodebytes)
//Protocol decrypt

decodedData := c.readIObfsBuf.Bytes()
decodelength = c.readIObfsBuf.Len()
c.readIObfsBuf.Reset()

if c.dec == nil {
iv := decodedData[0:c.info.ivLen]
if err = c.initDecrypt(iv); err != nil {
Expand All @@ -154,10 +162,22 @@ func (c *SSTCPConn) Read(b []byte) (n int, err error) {
buf := make([]byte, decodelength)
c.decrypt(buf, decodedData)

postDecryptedData, err := c.IProtocol.PostDecrypt(buf)
c.readEncryptBuf.Write(buf)
encryptbuf := c.readEncryptBuf.Bytes()
c.readEncryptBuf.Reset()
postDecryptedData, length, err := c.IProtocol.PostDecrypt(encryptbuf)
if err != nil {
return 0, err
}
if length == 0 {
c.readEncryptBuf.Write(encryptbuf)
return 0, nil
}

if length > 0 {
c.readEncryptBuf.Write(encryptbuf[length:])
}

postDecryptedlength := len(postDecryptedData)
blength := len(b)
copy(b, postDecryptedData)
Expand Down

0 comments on commit 7595b5f

Please sign in to comment.