From c34acf9e4552ed8b8ebea4adebc2bd9466f01dd8 Mon Sep 17 00:00:00 2001 From: Brandon Bennett Date: Wed, 1 May 2019 10:04:44 -0600 Subject: [PATCH] Add support for OK packets representing EOF Fixes: #805 --- connection.go | 25 ++++++-------- packets.go | 95 +++++++++++++++++++++++++++++++++++++++------------ 2 files changed, 84 insertions(+), 36 deletions(-) diff --git a/connection.go b/connection.go index 90aec643..425cc595 100644 --- a/connection.go +++ b/connection.go @@ -180,16 +180,16 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { // Read Result columnCount, err := stmt.readPrepareResultPacket() - if err == nil { - if stmt.paramCount > 0 { - if err = mc.readUntilEOF(); err != nil { - return nil, err - } - } + if err != nil { + return stmt, err + } - if columnCount > 0 { - err = mc.readUntilEOF() - } + if err := mc.readPackets(stmt.paramCount); err != nil { + return nil, err + } + + if err := mc.readPackets(int(columnCount)); err != nil { + return nil, err } return stmt, err @@ -415,11 +415,8 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { rows.mc = mc rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}} - if resLen > 0 { - // Columns - if err := mc.readUntilEOF(); err != nil { - return nil, err - } + if err := mc.readPackets(resLen); err != nil { + return nil, err } dest := make([]driver.Value, resLen) diff --git a/packets.go b/packets.go index 6664e5ae..422278fe 100644 --- a/packets.go +++ b/packets.go @@ -235,10 +235,15 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro if len(data) > pos { // character set [1 byte] // status flags [2 bytes] + pos += 1 + 2 + // capability flags (upper 2 bytes) [2 bytes] + mc.flags += clientFlag(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16 + pos += 2 + // length of auth-plugin-data [1 byte] // reserved (all [00]) [10 bytes] - pos += 1 + 2 + 2 + 1 + 10 + pos += +1 + 10 // second part of the password cipher [mininum 13 bytes], // where len=MAX(13, length of auth-plugin-data - 8) @@ -286,6 +291,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string clientLocalFiles | clientPluginAuth | clientMultiResults | + mc.flags&clientDeprecateEOF | mc.flags&clientLongFlag if mc.cfg.ClientFoundRows { @@ -610,18 +616,19 @@ func readStatus(b []byte) statusFlag { // Ok Packet // http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet func (mc *mysqlConn) handleOkPacket(data []byte) error { - var n, m int - - // 0x00 [1 byte] - + // 0x00 or 0xFE [1 byte] + n := 1 + var l int // Affected rows [Length Coded Binary] - mc.affectedRows, _, n = readLengthEncodedInteger(data[1:]) + mc.affectedRows, _, l = readLengthEncodedInteger(data[n:]) + n += l // Insert id [Length Coded Binary] - mc.insertId, _, m = readLengthEncodedInteger(data[1+n:]) + mc.insertId, _, l = readLengthEncodedInteger(data[n:]) + n += l // server_status [2 bytes] - mc.status = readStatus(data[1+n+m : 1+n+m+2]) + mc.status = readStatus(data[n : n+2]) if mc.status&statusMoreResultsExists != 0 { return nil } @@ -631,19 +638,24 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { return nil } +// isEOFPacket will return true if the data is either a EOF-Packet or OK-Packet +// acting as an EOF. +func isEOFPacket(data []byte) bool { + return data[0] == iEOF && len(data) < 9 +} + // Read Packets as Field Packets until EOF-Packet or an Error appears // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41 func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { columns := make([]mysqlField, count) - for i := 0; ; i++ { + for i := 0; i < count; i++ { data, err := mc.readPacket() if err != nil { return nil, err } - // EOF Packet - if data[0] == iEOF && (len(data) == 5 || len(data) == 1) { + if mc.flags&clientDeprecateEOF == 0 && isEOFPacket(data) { if i == count { return columns, nil } @@ -729,9 +741,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:]) //} } + return columns, nil } -// Read Packets as Field Packets until EOF-Packet or an Error appears +// Read Packets as Field Packets until EOF/OK-Packet or an Error appears // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow func (rows *textRows) readRow(dest []driver.Value) error { mc := rows.mc @@ -746,9 +759,15 @@ func (rows *textRows) readRow(dest []driver.Value) error { } // EOF Packet - if data[0] == iEOF && len(data) == 5 { - // server_status [2 bytes] - rows.mc.status = readStatus(data[3:]) + if isEOFPacket(data) { + if mc.flags&clientDeprecateEOF == 0 { + // server_status [2 bytes] + rows.mc.status = readStatus(data[3:]) + } else { + if err := mc.handleOkPacket(data); err != nil { + return err + } + } rows.rs.done = true if !rows.HasNextResultSet() { rows.mc = nil @@ -808,18 +827,44 @@ func (mc *mysqlConn) readUntilEOF() error { return err } - switch data[0] { - case iERR: + switch { + case data[0] == iERR: return mc.handleErrorPacket(data) - case iEOF: - if len(data) == 5 { + case isEOFPacket(data): + if mc.flags&clientDeprecateEOF == 0 { mc.status = readStatus(data[3:]) + } else { + return mc.handleOkPacket(data) } return nil } } } +func (mc *mysqlConn) readPackets(num int) error { + + // we need to read EOF as well + if mc.flags&clientDeprecateEOF == 0 { + num++ + } + + for i := 0; i < num; i++ { + data, err := mc.readPacket() + if err != nil { + return err + } + + switch { + case data[0] == iERR: + return mc.handleErrorPacket(data) + case mc.flags&clientDeprecateEOF == 0 && isEOFPacket(data): + mc.status = readStatus(data[3:]) + return nil + } + } + return nil +} + /****************************************************************************** * Prepared Statements * ******************************************************************************/ @@ -1178,15 +1223,21 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // packet indicator [1 byte] if data[0] != iOK { - // EOF Packet - if data[0] == iEOF && len(data) == 5 { - rows.mc.status = readStatus(data[3:]) + if isEOFPacket(data) { + if rows.mc.flags&clientDeprecateEOF == 0 { + rows.mc.status = readStatus(data[3:]) + } else { + if err := rows.mc.handleOkPacket(data); err != nil { + return err + } + } rows.rs.done = true if !rows.HasNextResultSet() { rows.mc = nil } return io.EOF } + mc := rows.mc rows.mc = nil