From b30aca1159166fb21d9bb5981a47dbdd94ebfead Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Thu, 12 Nov 2020 10:07:35 +0530 Subject: [PATCH 1/5] fix COM_STMT_PREPARE_OK packet Signed-off-by: Harshit Gangal --- go/mysql/encoding.go | 5 +++++ go/mysql/query.go | 32 ++++++++++++++++++++++++-------- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/go/mysql/encoding.go b/go/mysql/encoding.go index 893e58df785..38cf932634b 100644 --- a/go/mysql/encoding.go +++ b/go/mysql/encoding.go @@ -348,6 +348,11 @@ func (d *coder) writeUint16(value uint16) { d.pos = newPos } +func (d *coder) writeUint32(value uint32) { + newPos := writeUint32(d.data, d.pos, value) + d.pos = newPos +} + func (d *coder) writeLenEncString(value string) { newPos := writeLenEncString(d.data, d.pos, value) d.pos = newPos diff --git a/go/mysql/query.go b/go/mysql/query.go index d8857d75094..04c30ccc519 100644 --- a/go/mysql/query.go +++ b/go/mysql/query.go @@ -1015,6 +1015,15 @@ func (c *Conn) writeEndResult(more bool, affectedRows, lastInsertID uint64, warn return nil } +// PacketComStmtPrepareOK contains the COM_STMT_PREPARE_OK packet details +type PacketComStmtPrepareOK struct { + status uint8 + stmtID uint32 + numCols uint16 + numParams uint16 + warningCount uint16 +} + // writePrepare writes a prepare query response to the wire. func (c *Conn) writePrepare(fld []*querypb.Field, prepare *PrepareData) error { paramsCount := prepare.ParamsCount @@ -1026,14 +1035,21 @@ func (c *Conn) writePrepare(fld []*querypb.Field, prepare *PrepareData) error { prepare.ColumnNames = make([]string, columnCount) } - data, pos := c.startEphemeralPacketWithHeader(12) - - pos = writeByte(data, pos, 0x00) - pos = writeUint32(data, pos, uint32(prepare.StatementID)) - pos = writeUint16(data, pos, uint16(columnCount)) - pos = writeUint16(data, pos, uint16(paramsCount)) - pos = writeByte(data, pos, 0x00) - writeUint16(data, pos, 0x0000) + ok := PacketComStmtPrepareOK{ + status: OKPacket, + stmtID: prepare.StatementID, + numCols: (uint16)(columnCount), + numParams: paramsCount, + warningCount: 0, + } + bytes, pos := c.startEphemeralPacketWithHeader(12) + data := &coder{data: bytes, pos: pos} + data.writeByte(ok.status) + data.writeUint32(ok.stmtID) + data.writeUint16(ok.numCols) + data.writeUint16(ok.numParams) + data.writeByte(0x00) // reserved 1 byte + data.writeUint16(ok.warningCount) if err := c.writeEphemeralPacket(); err != nil { return err From 3ccaeb9c966da9547738be4873c71de4e8f64075 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Thu, 12 Nov 2020 10:49:22 +0530 Subject: [PATCH 2/5] added com_stmt_prepare unit test Signed-off-by: Harshit Gangal --- go/mysql/query_test.go | 57 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/go/mysql/query_test.go b/go/mysql/query_test.go index 84334297062..61416fdddaa 100644 --- a/go/mysql/query_test.go +++ b/go/mysql/query_test.go @@ -161,6 +161,63 @@ func TestComStmtPrepare(t *testing.T) { } } +func TestComStmtPrepareUpdStmt(t *testing.T) { + listener, sConn, cConn := createSocketPair(t) + defer func() { + listener.Close() + sConn.Close() + cConn.Close() + }() + + sql := "UPDATE test SET __bit = ?, __tinyInt = ?, __tinyIntU = ?, __smallInt = ?, __smallIntU = ?, __mediumInt = ?, __mediumIntU = ?, __int = ?, __intU = ?, __bigInt = ?, __bigIntU = ?, __decimal = ?, __float = ?, __double = ?, __date = ?, __datetime = ?, __timestamp = ?, __time = ?, __year = ?, __char = ?, __varchar = ?, __binary = ?, __varbinary = ?, __tinyblob = ?, __tinytext = ?, __blob = ?, __text = ?, __enum = ?, __set = ? WHERE __id = 0" + mockData := MockQueryPackets(t, sql) + + if err := cConn.writePacket(mockData); err != nil { + t.Fatalf("writePacket failed: %v", err) + } + + data, err := sConn.ReadPacket() + if err != nil { + t.Fatalf("sConn.ReadPacket - ComPrepare failed: %v", err) + } + + parsedQuery := sConn.parseComPrepare(data) + if parsedQuery != sql { + t.Fatalf("Received incorrect query, want: %v, got: %v", sql, parsedQuery) + } + + paramsCount := uint16(29) + prepare := &PrepareData{ + StatementID: 1, + PrepareStmt: sql, + ParamsCount: paramsCount, + } + sConn.PrepareData = make(map[uint32]*PrepareData) + sConn.PrepareData[prepare.StatementID] = prepare + + // write the response to the client + if err := sConn.writePrepare(nil, prepare); err != nil { + t.Fatalf("sConn.writePrepare failed: %v", err) + } + + resp, err := cConn.ReadPacket() + if err != nil { + t.Fatalf("cConn.ReadPacket failed: %v", err) + } + if uint32(resp[1]) != prepare.StatementID { + t.Fatalf("Received incorrect Statement ID, want: %v, got: %v", prepare.StatementID, resp[1]) + } + for i := uint16(0); i < paramsCount; i++ { + resp, err := cConn.ReadPacket() + if err != nil { + t.Fatalf("cConn.ReadPacket failed: %v", err) + } + if resp[17] != 0xfd { + t.Fatalf("Received incorrect params type, want: %v, got: %v", 0xfd, resp[21]) + } + } +} + func TestComStmtSendLongData(t *testing.T) { listener, sConn, cConn := createSocketPair(t) defer func() { From 185a50364d435e7dc8d85a3ed23407797b3302c0 Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Thu, 12 Nov 2020 19:01:40 +0530 Subject: [PATCH 3/5] year in binary is changed to varbinary Signed-off-by: Harshit Gangal --- go/sqltypes/type.go | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/go/sqltypes/type.go b/go/sqltypes/type.go index a89b19d9a1f..45edc7b62b5 100644 --- a/go/sqltypes/type.go +++ b/go/sqltypes/type.go @@ -196,37 +196,30 @@ func modifyType(typ querypb.Type, flags int64) querypb.Type { if flags&mysqlUnsigned != 0 { return Uint8 } - return Int8 case Int16: if flags&mysqlUnsigned != 0 { return Uint16 } - return Int16 case Int32: if flags&mysqlUnsigned != 0 { return Uint32 } - return Int32 case Int64: if flags&mysqlUnsigned != 0 { return Uint64 } - return Int64 case Int24: if flags&mysqlUnsigned != 0 { return Uint24 } - return Int24 case Text: if flags&mysqlBinary != 0 { return Blob } - return Text case VarChar: if flags&mysqlBinary != 0 { return VarBinary } - return VarChar case Char: if flags&mysqlBinary != 0 { return Binary @@ -237,7 +230,10 @@ func modifyType(typ querypb.Type, flags int64) querypb.Type { if flags&mysqlSet != 0 { return Set } - return Char + case Year: + if flags&mysqlBinary != 0 { + return VarBinary + } } return typ } From c7f579d63884875c1b00e2fadd6f20e23bbf73be Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Thu, 12 Nov 2020 19:02:30 +0530 Subject: [PATCH 4/5] added params type check unit test in COM_STMT_EXECUTE Signed-off-by: Harshit Gangal --- go/mysql/query_test.go | 80 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/go/mysql/query_test.go b/go/mysql/query_test.go index 61416fdddaa..1f85ab1a9f2 100644 --- a/go/mysql/query_test.go +++ b/go/mysql/query_test.go @@ -22,6 +22,9 @@ import ( "sync" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/golang/protobuf/proto" "vitess.io/vitess/go/sqltypes" @@ -279,6 +282,83 @@ func TestComStmtExecute(t *testing.T) { } } +func TestComStmtExecuteUpdStmt(t *testing.T) { + listener, sConn, cConn := createSocketPair(t) + defer func() { + listener.Close() + sConn.Close() + cConn.Close() + }() + + prepareDataMap := map[uint32]*PrepareData{ + 1: { + StatementID: 1, + ParamsCount: 29, + ParamsType: make([]int32, 29), + BindVars: map[string]*querypb.BindVariable{}, + }} + + // This is simulated packets for update query + data := []byte{ + 0x29, 0x01, 0x00, 0x00, 0x17, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x01, 0x10, 0x00, 0x01, 0x00, 0x01, 0x80, 0x02, 0x00, 0x02, 0x80, 0x03, 0x00, 0x03, + 0x80, 0x03, 0x00, 0x03, 0x80, 0x08, 0x00, 0x08, 0x80, 0x00, 0x00, 0x04, 0x00, 0x05, 0x00, 0x0a, + 0x00, 0x0c, 0x00, 0x07, 0x00, 0x0b, 0x00, 0x0d, 0x80, 0xfe, 0x00, 0xfe, 0x00, 0xfc, 0x00, 0xfc, + 0x00, 0xfc, 0x00, 0xfe, 0x00, 0xfc, 0x00, 0xfe, 0x00, 0xfe, 0x00, 0xfe, 0x00, 0x08, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0xaa, 0xe0, 0x80, 0xff, 0x00, 0x80, 0xff, 0xff, 0x00, 0x00, 0x80, 0xff, + 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x80, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x15, 0x31, 0x32, 0x33, + 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x30, 0x2e, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, + 0x38, 0x39, 0xd0, 0x0f, 0x49, 0x40, 0x44, 0x17, 0x41, 0x54, 0xfb, 0x21, 0x09, 0x40, 0x04, 0xe0, + 0x07, 0x08, 0x08, 0x0b, 0xe0, 0x07, 0x08, 0x08, 0x11, 0x19, 0x3b, 0x00, 0x00, 0x00, 0x00, 0x0b, + 0xe0, 0x07, 0x08, 0x08, 0x11, 0x19, 0x3b, 0x00, 0x00, 0x00, 0x00, 0x0c, 0x01, 0x08, 0x00, 0x00, + 0x00, 0x07, 0x3b, 0x3b, 0x00, 0x00, 0x00, 0x00, 0x04, 0x31, 0x39, 0x39, 0x39, 0x08, 0x31, 0x32, + 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x0c, 0xe9, 0x9f, 0xa9, 0xe5, 0x86, 0xac, 0xe7, 0x9c, 0x9f, + 0xe8, 0xb5, 0x9e, 0x08, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x08, 0x31, 0x32, 0x33, + 0x34, 0x35, 0x36, 0x37, 0x38, 0x08, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x0c, 0xe9, + 0x9f, 0xa9, 0xe5, 0x86, 0xac, 0xe7, 0x9c, 0x9f, 0xe8, 0xb5, 0x9e, 0x08, 0x31, 0x32, 0x33, 0x34, + 0x35, 0x36, 0x37, 0x38, 0x0c, 0xe9, 0x9f, 0xa9, 0xe5, 0x86, 0xac, 0xe7, 0x9c, 0x9f, 0xe8, 0xb5, + 0x9e, 0x03, 0x66, 0x6f, 0x6f, 0x07, 0x66, 0x6f, 0x6f, 0x2c, 0x62, 0x61, 0x72} + + stmtID, _, err := sConn.parseComStmtExecute(prepareDataMap, data[4:]) // first 4 are header + require.NoError(t, err) + require.EqualValues(t, 1, stmtID) + + prepData := prepareDataMap[stmtID] + assert.EqualValues(t, querypb.Type_BIT, prepData.ParamsType[0], "got: %s", querypb.Type(prepData.ParamsType[0])) + assert.EqualValues(t, querypb.Type_INT8, prepData.ParamsType[1], "got: %s", querypb.Type(prepData.ParamsType[1])) + assert.EqualValues(t, querypb.Type_INT8, prepData.ParamsType[2], "got: %s", querypb.Type(prepData.ParamsType[2])) + assert.EqualValues(t, querypb.Type_INT16, prepData.ParamsType[3], "got: %s", querypb.Type(prepData.ParamsType[3])) + assert.EqualValues(t, querypb.Type_INT16, prepData.ParamsType[4], "got: %s", querypb.Type(prepData.ParamsType[4])) + assert.EqualValues(t, querypb.Type_INT32, prepData.ParamsType[5], "got: %s", querypb.Type(prepData.ParamsType[5])) + assert.EqualValues(t, querypb.Type_INT32, prepData.ParamsType[6], "got: %s", querypb.Type(prepData.ParamsType[6])) + assert.EqualValues(t, querypb.Type_INT32, prepData.ParamsType[7], "got: %s", querypb.Type(prepData.ParamsType[7])) + assert.EqualValues(t, querypb.Type_INT32, prepData.ParamsType[8], "got: %s", querypb.Type(prepData.ParamsType[8])) + assert.EqualValues(t, querypb.Type_INT64, prepData.ParamsType[9], "got: %s", querypb.Type(prepData.ParamsType[9])) + assert.EqualValues(t, querypb.Type_INT64, prepData.ParamsType[10], "got: %s", querypb.Type(prepData.ParamsType[10])) + assert.EqualValues(t, querypb.Type_DECIMAL, prepData.ParamsType[11], "got: %s", querypb.Type(prepData.ParamsType[11])) + assert.EqualValues(t, querypb.Type_FLOAT32, prepData.ParamsType[12], "got: %s", querypb.Type(prepData.ParamsType[12])) + assert.EqualValues(t, querypb.Type_FLOAT64, prepData.ParamsType[13], "got: %s", querypb.Type(prepData.ParamsType[13])) + assert.EqualValues(t, querypb.Type_DATE, prepData.ParamsType[14], "got: %s", querypb.Type(prepData.ParamsType[14])) + assert.EqualValues(t, querypb.Type_DATETIME, prepData.ParamsType[15], "got: %s", querypb.Type(prepData.ParamsType[15])) + assert.EqualValues(t, querypb.Type_TIMESTAMP, prepData.ParamsType[16], "got: %s", querypb.Type(prepData.ParamsType[16])) + assert.EqualValues(t, querypb.Type_TIME, prepData.ParamsType[17], "got: %s", querypb.Type(prepData.ParamsType[17])) + + // this is year but in binary it is changed to varbinary + assert.EqualValues(t, querypb.Type_VARBINARY, prepData.ParamsType[18], "got: %s", querypb.Type(prepData.ParamsType[18])) + + assert.EqualValues(t, querypb.Type_CHAR, prepData.ParamsType[19], "got: %s", querypb.Type(prepData.ParamsType[19])) + assert.EqualValues(t, querypb.Type_CHAR, prepData.ParamsType[20], "got: %s", querypb.Type(prepData.ParamsType[20])) + assert.EqualValues(t, querypb.Type_TEXT, prepData.ParamsType[21], "got: %s", querypb.Type(prepData.ParamsType[21])) + assert.EqualValues(t, querypb.Type_TEXT, prepData.ParamsType[22], "got: %s", querypb.Type(prepData.ParamsType[22])) + assert.EqualValues(t, querypb.Type_TEXT, prepData.ParamsType[23], "got: %s", querypb.Type(prepData.ParamsType[23])) + assert.EqualValues(t, querypb.Type_CHAR, prepData.ParamsType[24], "got: %s", querypb.Type(prepData.ParamsType[24])) + assert.EqualValues(t, querypb.Type_TEXT, prepData.ParamsType[25], "got: %s", querypb.Type(prepData.ParamsType[25])) + assert.EqualValues(t, querypb.Type_CHAR, prepData.ParamsType[26], "got: %s", querypb.Type(prepData.ParamsType[26])) + assert.EqualValues(t, querypb.Type_CHAR, prepData.ParamsType[27], "got: %s", querypb.Type(prepData.ParamsType[27])) + assert.EqualValues(t, querypb.Type_CHAR, prepData.ParamsType[28], "got: %s", querypb.Type(prepData.ParamsType[28])) +} + func TestComStmtClose(t *testing.T) { listener, sConn, cConn := createSocketPair(t) defer func() { From 8c29251236c7bac9c6d3262252ea2bda83e01c8d Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Thu, 12 Nov 2020 19:28:07 +0530 Subject: [PATCH 5/5] addressed review comments Signed-off-by: Harshit Gangal --- go/mysql/query_test.go | 35 +++++++++++------------------------ 1 file changed, 11 insertions(+), 24 deletions(-) diff --git a/go/mysql/query_test.go b/go/mysql/query_test.go index 1f85ab1a9f2..a58da8fb486 100644 --- a/go/mysql/query_test.go +++ b/go/mysql/query_test.go @@ -175,19 +175,14 @@ func TestComStmtPrepareUpdStmt(t *testing.T) { sql := "UPDATE test SET __bit = ?, __tinyInt = ?, __tinyIntU = ?, __smallInt = ?, __smallIntU = ?, __mediumInt = ?, __mediumIntU = ?, __int = ?, __intU = ?, __bigInt = ?, __bigIntU = ?, __decimal = ?, __float = ?, __double = ?, __date = ?, __datetime = ?, __timestamp = ?, __time = ?, __year = ?, __char = ?, __varchar = ?, __binary = ?, __varbinary = ?, __tinyblob = ?, __tinytext = ?, __blob = ?, __text = ?, __enum = ?, __set = ? WHERE __id = 0" mockData := MockQueryPackets(t, sql) - if err := cConn.writePacket(mockData); err != nil { - t.Fatalf("writePacket failed: %v", err) - } + err := cConn.writePacket(mockData) + require.NoError(t, err, "writePacket failed") data, err := sConn.ReadPacket() - if err != nil { - t.Fatalf("sConn.ReadPacket - ComPrepare failed: %v", err) - } + require.NoError(t, err, "sConn.ReadPacket - ComPrepare failed") parsedQuery := sConn.parseComPrepare(data) - if parsedQuery != sql { - t.Fatalf("Received incorrect query, want: %v, got: %v", sql, parsedQuery) - } + require.Equal(t, sql, parsedQuery, "Received incorrect query") paramsCount := uint16(29) prepare := &PrepareData{ @@ -199,25 +194,17 @@ func TestComStmtPrepareUpdStmt(t *testing.T) { sConn.PrepareData[prepare.StatementID] = prepare // write the response to the client - if err := sConn.writePrepare(nil, prepare); err != nil { - t.Fatalf("sConn.writePrepare failed: %v", err) - } + err = sConn.writePrepare(nil, prepare) + require.NoError(t, err, "sConn.writePrepare failed") resp, err := cConn.ReadPacket() - if err != nil { - t.Fatalf("cConn.ReadPacket failed: %v", err) - } - if uint32(resp[1]) != prepare.StatementID { - t.Fatalf("Received incorrect Statement ID, want: %v, got: %v", prepare.StatementID, resp[1]) - } + require.NoError(t, err, "cConn.ReadPacket failed") + require.EqualValues(t, prepare.StatementID, resp[1], "Received incorrect Statement ID") + for i := uint16(0); i < paramsCount; i++ { resp, err := cConn.ReadPacket() - if err != nil { - t.Fatalf("cConn.ReadPacket failed: %v", err) - } - if resp[17] != 0xfd { - t.Fatalf("Received incorrect params type, want: %v, got: %v", 0xfd, resp[21]) - } + require.NoError(t, err, "cConn.ReadPacket failed") + require.EqualValues(t, 0xfd, resp[17], "Received incorrect Statement ID") } }