Skip to content

Commit

Permalink
Support only mysql 4.0--
Browse files Browse the repository at this point in the history
  • Loading branch information
yuuki committed May 15, 2016
1 parent 7ebe0a5 commit 76b4b4f
Showing 1 changed file with 30 additions and 80 deletions.
110 changes: 30 additions & 80 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"crypto/tls"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
Expand Down Expand Up @@ -164,9 +163,6 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {

// capability flags (lower 2 bytes) [2 bytes]
mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
if mc.flags&clientProtocol41 == 0 {
return nil, ErrOldProtocol
}
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
return nil, ErrNoTLS
}
Expand Down Expand Up @@ -219,14 +215,9 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
// Adjust client flags based on server support
clientFlags := clientProtocol41 |
clientSecureConn |
clientLongPassword |
clientFlags := clientLongPassword |
clientTransactions |
clientLocalFiles |
clientPluginAuth |
clientMultiResults |
mc.flags&clientLongFlag
clientLongFlag

if mc.cfg.ClientFoundRows {
clientFlags |= clientFoundRows
Expand All @@ -241,10 +232,17 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
clientFlags |= clientMultiStatements
}

// User Password
scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd))
// User password
scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.Passwd))

pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1
pktLen := 2 + 3 + len(mc.cfg.User) + 1

scrambleBuffLen := len(scrambleBuff)
if scrambleBuffLen > 0 {
pktLen += scrambleBuffLen
} else {
pktLen += 1
}

// To specify a db name
if n := len(mc.cfg.DBName); n > 0 {
Expand All @@ -260,32 +258,19 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
return driver.ErrBadConn
}

// ClientFlags [32 bit]
// ClientFlags [16 bit]
data[4] = byte(clientFlags)
data[5] = byte(clientFlags >> 8)
data[6] = byte(clientFlags >> 16)
data[7] = byte(clientFlags >> 24)

// MaxPacketSize [32 bit] (none)
// MaxPacketSize [24 bit] (none)
data[6] = 0x00
data[7] = 0x00
data[8] = 0x00
data[9] = 0x00
data[10] = 0x00
data[11] = 0x00

// Charset [1 byte]
var found bool
data[12], found = collations[mc.cfg.Collation]
if !found {
// Note possibility for false negatives:
// could be triggered although the collation is valid if the
// collations map does not contain entries the server supports.
return errors.New("unknown collation")
}

// SSL Connection Request Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
if mc.cfg.tls != nil {
// Send TLS / SSL request packet
//TODO Send TLS / SSL request packet
if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil {
return err
}
Expand All @@ -299,33 +284,26 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
mc.buf.nc = tlsConn
}

// Filler [23 bytes] (all 0x00)
pos := 13
for ; pos < 13+23; pos++ {
data[pos] = 0
}

// User [null terminated string]
pos := 9
if len(mc.cfg.User) > 0 {
pos += copy(data[pos:], mc.cfg.User)
}
data[pos] = 0x00
pos++

// ScrambleBuffer [length encoded integer]
data[pos] = byte(len(scrambleBuff))
pos += 1 + copy(data[pos+1:], scrambleBuff)

// Databasename [null terminated string]
if len(mc.cfg.DBName) > 0 {
pos += copy(data[pos:], mc.cfg.DBName)
if len(scrambleBuff) > 0 {
// ScrambleBuffer [length encoded integer]
pos += copy(data[pos:], scrambleBuff)
} else {
data[pos] = 0x00
pos++
}

// Assume native client during response
pos += copy(data[pos:], "mysql_native_password")
data[pos] = 0x00
// Databasename
if len(mc.cfg.DBName) > 0 {
pos += copy(data[pos:], mc.cfg.DBName)
}

// Send Auth packet
return mc.writePacket(data)
Expand Down Expand Up @@ -586,18 +564,7 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns))
}

// Catalog
pos, err := skipLengthEncodedString(data)
if err != nil {
return nil, err
}

// Database [len coded string]
n, err := skipLengthEncodedString(data[pos:])
if err != nil {
return nil, err
}
pos += n
pos := 0

// Table [len coded string]
if mc.cfg.ColumnsWithAlias {
Expand All @@ -608,42 +575,25 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
pos += n
columns[i].tableName = string(tableName)
} else {
n, err = skipLengthEncodedString(data[pos:])
n, err := skipLengthEncodedString(data[pos:])
if err != nil {
return nil, err
}
pos += n
}

// Original table [len coded string]
n, err = skipLengthEncodedString(data[pos:])
if err != nil {
return nil, err
}
pos += n

// Name [len coded string]
name, _, n, err := readLengthEncodedString(data[pos:])
if err != nil {
return nil, err
}
columns[i].name = string(name)
pos += n

// Original name [len coded string]
n, err = skipLengthEncodedString(data[pos:])
if err != nil {
return nil, err
}

// Filler [uint8]
// Charset [charset, collation uint8]
// Length [uint32]
pos += n + 1 + 2 + 4
pos += n + 4

// Field type [uint8]
columns[i].fieldType = data[pos]
pos++
pos += 2 + 1

// Flags [uint16]
columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
Expand Down

0 comments on commit 76b4b4f

Please sign in to comment.