Skip to content

Commit

Permalink
Add support for OK packets representing EOF
Browse files Browse the repository at this point in the history
  • Loading branch information
nemith authored and tz70s committed Sep 5, 2020
1 parent 46351a8 commit c34acf9
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 36 deletions.
25 changes: 11 additions & 14 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,16 +180,16 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {

// Read Result
columnCount, err := stmt.readPrepareResultPacket()
if err == nil {
if stmt.paramCount > 0 {
if err = mc.readUntilEOF(); err != nil {
return nil, err
}
}
if err != nil {
return stmt, err
}

if columnCount > 0 {
err = mc.readUntilEOF()
}
if err := mc.readPackets(stmt.paramCount); err != nil {
return nil, err
}

if err := mc.readPackets(int(columnCount)); err != nil {
return nil, err
}

return stmt, err
Expand Down Expand Up @@ -415,11 +415,8 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
rows.mc = mc
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}

if resLen > 0 {
// Columns
if err := mc.readUntilEOF(); err != nil {
return nil, err
}
if err := mc.readPackets(resLen); err != nil {
return nil, err
}

dest := make([]driver.Value, resLen)
Expand Down
95 changes: 73 additions & 22 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,15 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
if len(data) > pos {
// character set [1 byte]
// status flags [2 bytes]
pos += 1 + 2

// capability flags (upper 2 bytes) [2 bytes]
mc.flags += clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
pos += 2

// length of auth-plugin-data [1 byte]
// reserved (all [00]) [10 bytes]
pos += 1 + 2 + 2 + 1 + 10
pos += +1 + 10

// second part of the password cipher [mininum 13 bytes],
// where len=MAX(13, length of auth-plugin-data - 8)
Expand Down Expand Up @@ -286,6 +291,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
clientLocalFiles |
clientPluginAuth |
clientMultiResults |
mc.flags&clientDeprecateEOF |
mc.flags&clientLongFlag

if mc.cfg.ClientFoundRows {
Expand Down Expand Up @@ -610,18 +616,19 @@ func readStatus(b []byte) statusFlag {
// Ok Packet
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
func (mc *mysqlConn) handleOkPacket(data []byte) error {
var n, m int

// 0x00 [1 byte]

// 0x00 or 0xFE [1 byte]
n := 1
var l int
// Affected rows [Length Coded Binary]
mc.affectedRows, _, n = readLengthEncodedInteger(data[1:])
mc.affectedRows, _, l = readLengthEncodedInteger(data[n:])
n += l

// Insert id [Length Coded Binary]
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
mc.insertId, _, l = readLengthEncodedInteger(data[n:])
n += l

// server_status [2 bytes]
mc.status = readStatus(data[1+n+m : 1+n+m+2])
mc.status = readStatus(data[n : n+2])
if mc.status&statusMoreResultsExists != 0 {
return nil
}
Expand All @@ -631,19 +638,24 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
return nil
}

// isEOFPacket will return true if the data is either a EOF-Packet or OK-Packet
// acting as an EOF.
func isEOFPacket(data []byte) bool {
return data[0] == iEOF && len(data) < 9
}

// Read Packets as Field Packets until EOF-Packet or an Error appears
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
columns := make([]mysqlField, count)

for i := 0; ; i++ {
for i := 0; i < count; i++ {
data, err := mc.readPacket()
if err != nil {
return nil, err
}

// EOF Packet
if data[0] == iEOF && (len(data) == 5 || len(data) == 1) {
if mc.flags&clientDeprecateEOF == 0 && isEOFPacket(data) {
if i == count {
return columns, nil
}
Expand Down Expand Up @@ -729,9 +741,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
// defaultVal, _, err = bytesToLengthCodedBinary(data[pos:])
//}
}
return columns, nil
}

// Read Packets as Field Packets until EOF-Packet or an Error appears
// Read Packets as Field Packets until EOF/OK-Packet or an Error appears
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
func (rows *textRows) readRow(dest []driver.Value) error {
mc := rows.mc
Expand All @@ -746,9 +759,15 @@ func (rows *textRows) readRow(dest []driver.Value) error {
}

// EOF Packet
if data[0] == iEOF && len(data) == 5 {
// server_status [2 bytes]
rows.mc.status = readStatus(data[3:])
if isEOFPacket(data) {
if mc.flags&clientDeprecateEOF == 0 {
// server_status [2 bytes]
rows.mc.status = readStatus(data[3:])
} else {
if err := mc.handleOkPacket(data); err != nil {
return err
}
}
rows.rs.done = true
if !rows.HasNextResultSet() {
rows.mc = nil
Expand Down Expand Up @@ -808,18 +827,44 @@ func (mc *mysqlConn) readUntilEOF() error {
return err
}

switch data[0] {
case iERR:
switch {
case data[0] == iERR:
return mc.handleErrorPacket(data)
case iEOF:
if len(data) == 5 {
case isEOFPacket(data):
if mc.flags&clientDeprecateEOF == 0 {
mc.status = readStatus(data[3:])
} else {
return mc.handleOkPacket(data)
}
return nil
}
}
}

func (mc *mysqlConn) readPackets(num int) error {

// we need to read EOF as well
if mc.flags&clientDeprecateEOF == 0 {
num++
}

for i := 0; i < num; i++ {
data, err := mc.readPacket()
if err != nil {
return err
}

switch {
case data[0] == iERR:
return mc.handleErrorPacket(data)
case mc.flags&clientDeprecateEOF == 0 && isEOFPacket(data):
mc.status = readStatus(data[3:])
return nil
}
}
return nil
}

/******************************************************************************
* Prepared Statements *
******************************************************************************/
Expand Down Expand Up @@ -1178,15 +1223,21 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {

// packet indicator [1 byte]
if data[0] != iOK {
// EOF Packet
if data[0] == iEOF && len(data) == 5 {
rows.mc.status = readStatus(data[3:])
if isEOFPacket(data) {
if rows.mc.flags&clientDeprecateEOF == 0 {
rows.mc.status = readStatus(data[3:])
} else {
if err := rows.mc.handleOkPacket(data); err != nil {
return err
}
}
rows.rs.done = true
if !rows.HasNextResultSet() {
rows.mc = nil
}
return io.EOF
}

mc := rows.mc
rows.mc = nil

Expand Down

0 comments on commit c34acf9

Please sign in to comment.