Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: use types.Row to write result set. #5056

Merged
merged 7 commits into from Nov 10, 2017
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
24 changes: 14 additions & 10 deletions expression/aggregation/avg.go
Expand Up @@ -82,20 +82,24 @@ func (af *avgFunction) Update(ctx *AggEvaluateContext, sc *variable.StatementCon

// GetResult implements Aggregation interface.
func (af *avgFunction) GetResult(ctx *AggEvaluateContext) (d types.Datum) {
var x *types.MyDecimal
switch ctx.Value.Kind() {
case types.KindFloat64:
t := ctx.Value.GetFloat64() / float64(ctx.Count)
d.SetValue(t)
x = new(types.MyDecimal)
err := x.FromFloat64(ctx.Value.GetFloat64())
terror.Log(err)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

errors.Trace(err)

case types.KindMysqlDecimal:
x := ctx.Value.GetMysqlDecimal()
y := types.NewDecFromInt(ctx.Count)
to := new(types.MyDecimal)
err := types.DecimalDiv(x, y, to, types.DivFracIncr)
terror.Log(errors.Trace(err))
err = to.Round(to, ctx.Value.Frac()+types.DivFracIncr, types.ModeHalfEven)
terror.Log(errors.Trace(err))
d.SetMysqlDecimal(to)
x = ctx.Value.GetMysqlDecimal()
default:
return
}
y := types.NewDecFromInt(ctx.Count)
to := new(types.MyDecimal)
err := types.DecimalDiv(x, y, to, types.DivFracIncr)
terror.Log(errors.Trace(err))
err = to.Round(to, ctx.Value.Frac()+types.DivFracIncr, types.ModeHalfEven)
terror.Log(errors.Trace(err))
d.SetMysqlDecimal(to)
return
}

Expand Down
8 changes: 8 additions & 0 deletions expression/aggregation/sum.go
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/types"
)

Expand All @@ -42,6 +43,13 @@ func (sf *sumFunction) Update(ctx *AggEvaluateContext, sc *variable.StatementCon

// GetResult implements Aggregation interface.
func (sf *sumFunction) GetResult(ctx *AggEvaluateContext) (d types.Datum) {
if ctx.Value.Kind() == types.KindFloat64 {
dec := new(types.MyDecimal)
err := dec.FromFloat64(ctx.Value.GetFloat64())
terror.Log(err)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

d.SetMysqlDecimal(dec)
return
}
return ctx.Value
}

Expand Down
21 changes: 4 additions & 17 deletions server/conn.go
Expand Up @@ -813,9 +813,6 @@ func (cc *clientConn) writeResultset(rs ResultSet, binary bool, more bool) error
if err = cc.writeEOF(false); err != nil {
return errors.Trace(err)
}

numBytes4Null := ((len(columns) + 7 + 2) / 8)
rowBuffer := make([]byte, 1+numBytes4Null, 1+numBytes4Null+8*(len(columns)))
for {
if err != nil {
return errors.Trace(err)
Expand All @@ -825,24 +822,14 @@ func (cc *clientConn) writeResultset(rs ResultSet, binary bool, more bool) error
}
data = data[0:4]
if binary {
rowBuffer = rowBuffer[0 : 1+numBytes4Null : cap(rowBuffer)]
rowBuffer, err = dumpRowValuesBinary(rowBuffer, columns, row)
data, err = dumpBinaryRow(data, columns, row)
if err != nil {
return errors.Trace(err)
}
data = append(data, rowBuffer...)
} else {
for i, value := range row {
if value.IsNull() {
data = append(data, 0xfb)
continue
}
var valData []byte
valData, err = dumpTextValue(columns[i], value)
if err != nil {
return errors.Trace(err)
}
data = dumpLengthEncodedString(data, valData)
data, err = dumpTextRow(data, columns, row)
if err != nil {
return errors.Trace(err)
}
}

Expand Down
2 changes: 1 addition & 1 deletion server/driver.go
Expand Up @@ -121,6 +121,6 @@ type PreparedStatement interface {
// ResultSet is the result set of an query.
type ResultSet interface {
Columns() ([]*ColumnInfo, error)
Next() ([]types.Datum, error)
Next() (types.Row, error)
Close() error
}
4 changes: 2 additions & 2 deletions server/driver_tidb.go
Expand Up @@ -292,13 +292,13 @@ type tidbResultSet struct {
recordSet ast.RecordSet
}

func (trs *tidbResultSet) Next() ([]types.Datum, error) {
func (trs *tidbResultSet) Next() (types.Row, error) {
row, err := trs.recordSet.Next()
if err != nil {
return nil, errors.Trace(err)
}
if row != nil {
return row.Data, nil
return types.DatumRow(row.Data), nil
}
return nil, nil
}
Expand Down
22 changes: 22 additions & 0 deletions server/server_test.go
Expand Up @@ -723,6 +723,28 @@ func runTestTLSConnection(t *C, overrider configOverrider) error {
return err
}

func runTestSumAvg(c *C) {
runTests(c, nil, func(dbt *DBTest) {
dbt.mustExec("create table sumavg (a int, b decimal, c double)")
dbt.mustExec("insert sumavg values (1, 1, 1)")
rows := dbt.mustQuery("select sum(a), sum(b), sum(c) from sumavg")
c.Assert(rows.Next(), IsTrue)
var outA, outB, outC float64
err := rows.Scan(&outA, &outB, &outC)
c.Assert(err, IsNil)
c.Assert(outA, Equals, 1.0)
c.Assert(outB, Equals, 1.0)
c.Assert(outC, Equals, 1.0)
rows = dbt.mustQuery("select avg(a), avg(b), avg(c) from sumavg")
c.Assert(rows.Next(), IsTrue)
err = rows.Scan(&outA, &outB, &outC)
c.Assert(err, IsNil)
c.Assert(outA, Equals, 1.0)
c.Assert(outB, Equals, 1.0)
c.Assert(outC, Equals, 1.0)
})
}

func getMetrics(t *C) []byte {
resp, err := http.Get("http://127.0.0.1:10090/metrics")
t.Assert(err, IsNil)
Expand Down
8 changes: 7 additions & 1 deletion server/tidb_test.go
Expand Up @@ -419,5 +419,11 @@ func (ts *TidbTestSuite) TestShowCreateTableFlen(c *C) {
c.Assert(err, IsNil)
c.Assert(len(cols), Equals, 2)
c.Assert(int(cols[0].ColumnLength), Equals, 5*tmysql.MaxBytesOfCharacter)
c.Assert(int(cols[1].ColumnLength), Equals, len(row[1].GetString())*tmysql.MaxBytesOfCharacter)
str, _ := row.GetString(1)
c.Assert(int(cols[1].ColumnLength), Equals, len(str)*tmysql.MaxBytesOfCharacter)
}

func (ts *TidbTestSuite) TestSumAvg(c *C) {
c.Parallel()
runTestSumAvg(c)
}
209 changes: 117 additions & 92 deletions server/util.go
Expand Up @@ -189,7 +189,7 @@ func dumpBinaryTime(dur time.Duration) (data []byte) {
return
}

func dumpBinaryDateTime(t types.Time, loc *time.Location) (data []byte, err error) {
func dumpBinaryDateTime(data []byte, t types.Time, loc *time.Location) ([]byte, error) {
if t.Type == mysql.TypeTimestamp && loc != nil {
// TODO: Consider time_zone variable.
t1, err := t.Time.GoTime(time.Local)
Expand All @@ -214,112 +214,137 @@ func dumpBinaryDateTime(t types.Time, loc *time.Location) (data []byte, err erro
data = dumpUint16(data, uint16(year)) //year
data = append(data, byte(mon), byte(day))
}
return
return data, nil
}

func dumpRowValuesBinary(buffer []byte, columns []*ColumnInfo, row []types.Datum) ([]byte, error) {
if len(columns) != len(row) {
return nil, mysql.ErrMalformPacket
func dumpBinaryRow(buffer []byte, columns []*ColumnInfo, row types.Row) ([]byte, error) {
buffer = append(buffer, mysql.OKHeader)
nullBitmapOff := len(buffer)
numBytes4Null := (len(columns) + 7 + 2) / 8
for i := 0; i < numBytes4Null; i++ {
buffer = append(buffer, 0)
}
buffer[0] = mysql.OKHeader
nulls := buffer[1:]
for i, val := range row {
if val.IsNull() {
for i := range columns {
if row.IsNull(i) {
bytePos := (i + 2) / 8
bitPos := byte((i + 2) % 8)
nulls[bytePos] |= 1 << bitPos
buffer[nullBitmapOff+bytePos] |= 1 << bitPos
continue
}
}
for i, val := range row {
switch val.Kind() {
case types.KindInt64:
v := val.GetInt64()
switch columns[i].Type {
case mysql.TypeTiny:
buffer = append(buffer, byte(v))
case mysql.TypeShort, mysql.TypeYear:
buffer = dumpUint16(buffer, uint16(v))
case mysql.TypeInt24, mysql.TypeLong:
buffer = dumpUint32(buffer, uint32(v))
case mysql.TypeLonglong:
buffer = dumpUint64(buffer, uint64(v))
}
case types.KindUint64:
v := val.GetUint64()
switch columns[i].Type {
case mysql.TypeTiny:
buffer = append(buffer, byte(v))
case mysql.TypeShort, mysql.TypeYear:
buffer = dumpUint16(buffer, uint16(v))
case mysql.TypeInt24, mysql.TypeLong:
buffer = dumpUint32(buffer, uint32(v))
case mysql.TypeLonglong:
buffer = dumpUint64(buffer, v)
}
case types.KindFloat32:
floatBits := math.Float32bits(val.GetFloat32())
buffer = dumpUint32(buffer, floatBits)
case types.KindFloat64:
floatBits := math.Float64bits(val.GetFloat64())
buffer = dumpUint64(buffer, floatBits)
case types.KindString, types.KindBytes:
buffer = dumpLengthEncodedString(buffer, val.GetBytes())
case types.KindMysqlDecimal:
buffer = dumpLengthEncodedString(buffer, hack.Slice(val.GetMysqlDecimal().String()))
case types.KindMysqlTime:
tmp, err := dumpBinaryDateTime(val.GetMysqlTime(), nil)
switch columns[i].Type {
case mysql.TypeTiny:
v, _ := row.GetInt64(i)
buffer = append(buffer, byte(v))
case mysql.TypeShort, mysql.TypeYear:
v, _ := row.GetInt64(i)
buffer = dumpUint16(buffer, uint16(v))
case mysql.TypeInt24, mysql.TypeLong:
v, _ := row.GetInt64(i)
buffer = dumpUint32(buffer, uint32(v))
case mysql.TypeLonglong:
v, _ := row.GetUint64(i)
buffer = dumpUint64(buffer, v)
case mysql.TypeFloat:
v, _ := row.GetFloat32(i)
buffer = dumpUint32(buffer, math.Float32bits(v))
case mysql.TypeDouble:
v, _ := row.GetFloat64(i)
buffer = dumpUint64(buffer, math.Float64bits(v))
case mysql.TypeNewDecimal:
v, _ := row.GetMyDecimal(i)
buffer = dumpLengthEncodedString(buffer, hack.Slice(v.String()))
case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar,
mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeBlob:
v, _ := row.GetBytes(i)
buffer = dumpLengthEncodedString(buffer, v)
case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp:
v, _ := row.GetTime(i)
var err error
buffer, err = dumpBinaryDateTime(buffer, v, nil)
if err != nil {
return buffer, errors.Trace(err)
}
buffer = append(buffer, tmp...)
case types.KindMysqlDuration:
buffer = append(buffer, dumpBinaryTime(val.GetMysqlDuration().Duration)...)
case types.KindMysqlSet:
buffer = dumpLengthEncodedString(buffer, hack.Slice(val.GetMysqlSet().String()))
case types.KindMysqlEnum:
buffer = dumpLengthEncodedString(buffer, hack.Slice(val.GetMysqlEnum().String()))
case types.KindBinaryLiteral, types.KindMysqlBit:
buffer = dumpLengthEncodedString(buffer, hack.Slice(val.GetBinaryLiteral().ToString()))
case mysql.TypeDuration:
v, _ := row.GetDuration(i)
buffer = append(buffer, dumpBinaryTime(v.Duration)...)
case mysql.TypeEnum:
v, _ := row.GetEnum(i)
buffer = dumpLengthEncodedString(buffer, hack.Slice(v.String()))
case mysql.TypeSet:
v, _ := row.GetSet(i)
buffer = dumpLengthEncodedString(buffer, hack.Slice(v.String()))
case mysql.TypeJSON:
v, _ := row.GetJSON(i)
buffer = dumpLengthEncodedString(buffer, hack.Slice(v.String()))
default:
return nil, errInvalidType.Gen("invalid type %v", columns[i].Type)
}
}
return buffer, nil
}

func dumpTextValue(colInfo *ColumnInfo, value types.Datum) ([]byte, error) {
switch value.Kind() {
case types.KindInt64:
return strconv.AppendInt(nil, value.GetInt64(), 10), nil
case types.KindUint64:
return strconv.AppendUint(nil, value.GetUint64(), 10), nil
case types.KindFloat32:
prec := -1
if colInfo.Decimal > 0 && int(colInfo.Decimal) != mysql.NotFixedDec {
prec = int(colInfo.Decimal)
func dumpTextRow(buffer []byte, columns []*ColumnInfo, row types.Row) ([]byte, error) {
tmp := make([]byte, 0, 20)
for i, col := range columns {
if row.IsNull(i) {
buffer = append(buffer, 0xfb)
continue
}
return strconv.AppendFloat(nil, value.GetFloat64(), 'f', prec, 32), nil
case types.KindFloat64:
prec := -1
if colInfo.Decimal > 0 && int(colInfo.Decimal) != mysql.NotFixedDec {
prec = int(colInfo.Decimal)
switch col.Type {
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeYear, mysql.TypeInt24, mysql.TypeLong:
v, _ := row.GetInt64(i)
tmp = strconv.AppendInt(tmp[:0], v, 10)
buffer = dumpLengthEncodedString(buffer, tmp)
case mysql.TypeLonglong:
if mysql.HasUnsignedFlag(uint(columns[i].Flag)) {
v, _ := row.GetUint64(i)
tmp = strconv.AppendUint(tmp[:0], v, 10)
} else {
v, _ := row.GetInt64(i)
tmp = strconv.AppendInt(tmp[:0], v, 10)
}
buffer = dumpLengthEncodedString(buffer, tmp)
case mysql.TypeFloat:
prec := -1
if columns[i].Decimal > 0 && int(col.Decimal) != mysql.NotFixedDec {
prec = int(col.Decimal)
}
v, _ := row.GetFloat32(i)
tmp = strconv.AppendFloat(tmp[:0], float64(v), 'f', prec, 32)
buffer = dumpLengthEncodedString(buffer, tmp)
case mysql.TypeDouble:
prec := -1
if col.Decimal > 0 && int(col.Decimal) != mysql.NotFixedDec {
prec = int(col.Decimal)
}
v, _ := row.GetFloat64(i)
tmp = strconv.AppendFloat(tmp[:0], v, 'f', prec, 64)
buffer = dumpLengthEncodedString(buffer, tmp)
case mysql.TypeNewDecimal:
v, _ := row.GetMyDecimal(i)
buffer = dumpLengthEncodedString(buffer, hack.Slice(v.String()))
case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar,
mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeBlob:
v, _ := row.GetBytes(i)
buffer = dumpLengthEncodedString(buffer, v)
case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp:
v, _ := row.GetTime(i)
buffer = dumpLengthEncodedString(buffer, hack.Slice(v.String()))
case mysql.TypeDuration:
v, _ := row.GetDuration(i)
buffer = dumpLengthEncodedString(buffer, hack.Slice(v.String()))
case mysql.TypeEnum:
v, _ := row.GetEnum(i)
buffer = dumpLengthEncodedString(buffer, hack.Slice(v.String()))
case mysql.TypeSet:
v, _ := row.GetSet(i)
buffer = dumpLengthEncodedString(buffer, hack.Slice(v.String()))
case mysql.TypeJSON:
v, _ := row.GetJSON(i)
buffer = dumpLengthEncodedString(buffer, hack.Slice(v.String()))
default:
return nil, errInvalidType.Gen("invalid type %v", columns[i].Type)
}
return strconv.AppendFloat(nil, value.GetFloat64(), 'f', prec, 64), nil
case types.KindString, types.KindBytes:
return value.GetBytes(), nil
case types.KindMysqlTime:
return hack.Slice(value.GetMysqlTime().String()), nil
case types.KindMysqlDuration:
return hack.Slice(value.GetMysqlDuration().String()), nil
case types.KindMysqlDecimal:
return hack.Slice(value.GetMysqlDecimal().String()), nil
case types.KindMysqlEnum:
return hack.Slice(value.GetMysqlEnum().String()), nil
case types.KindMysqlSet:
return hack.Slice(value.GetMysqlSet().String()), nil
case types.KindMysqlJSON:
return hack.Slice(value.GetMysqlJSON().String()), nil
case types.KindBinaryLiteral, types.KindMysqlBit:
return hack.Slice(value.GetBinaryLiteral().ToString()), nil
default:
return nil, errInvalidType.Gen("invalid type %v", value.Kind())
}
return buffer, nil
}