New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: use types.Row to write result set. #5056

Merged
merged 7 commits into from Nov 10, 2017
@@ -82,20 +82,24 @@ func (af *avgFunction) Update(ctx *AggEvaluateContext, sc *variable.StatementCon
// GetResult implements Aggregation interface.
func (af *avgFunction) GetResult(ctx *AggEvaluateContext) (d types.Datum) {
var x *types.MyDecimal
switch ctx.Value.Kind() {
case types.KindFloat64:
t := ctx.Value.GetFloat64() / float64(ctx.Count)
d.SetValue(t)
case types.KindMysqlDecimal:
x := ctx.Value.GetMysqlDecimal()
y := types.NewDecFromInt(ctx.Count)
to := new(types.MyDecimal)
err := types.DecimalDiv(x, y, to, types.DivFracIncr)
terror.Log(errors.Trace(err))
err = to.Round(to, ctx.Value.Frac()+types.DivFracIncr, types.ModeHalfEven)
x = new(types.MyDecimal)
err := x.FromFloat64(ctx.Value.GetFloat64())
terror.Log(errors.Trace(err))
d.SetMysqlDecimal(to)
case types.KindMysqlDecimal:
x = ctx.Value.GetMysqlDecimal()
default:
return
}
y := types.NewDecFromInt(ctx.Count)
to := new(types.MyDecimal)
err := types.DecimalDiv(x, y, to, types.DivFracIncr)
terror.Log(errors.Trace(err))
err = to.Round(to, ctx.Value.Frac()+types.DivFracIncr, types.ModeHalfEven)
terror.Log(errors.Trace(err))
d.SetMysqlDecimal(to)
return
}
@@ -15,10 +15,12 @@ package aggregation
import (
log "github.com/Sirupsen/logrus"
"github.com/juju/errors"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/types"
)
@@ -42,6 +44,13 @@ func (sf *sumFunction) Update(ctx *AggEvaluateContext, sc *variable.StatementCon
// GetResult implements Aggregation interface.
func (sf *sumFunction) GetResult(ctx *AggEvaluateContext) (d types.Datum) {
if ctx.Value.Kind() == types.KindFloat64 {
dec := new(types.MyDecimal)
err := dec.FromFloat64(ctx.Value.GetFloat64())
terror.Log(errors.Trace(err))
d.SetMysqlDecimal(dec)
return
}
return ctx.Value
}
View
@@ -813,9 +813,6 @@ func (cc *clientConn) writeResultset(rs ResultSet, binary bool, more bool) error
if err = cc.writeEOF(false); err != nil {
return errors.Trace(err)
}
numBytes4Null := ((len(columns) + 7 + 2) / 8)
rowBuffer := make([]byte, 1+numBytes4Null, 1+numBytes4Null+8*(len(columns)))
for {
if err != nil {
return errors.Trace(err)
@@ -825,24 +822,14 @@ func (cc *clientConn) writeResultset(rs ResultSet, binary bool, more bool) error
}
data = data[0:4]
if binary {
rowBuffer = rowBuffer[0 : 1+numBytes4Null : cap(rowBuffer)]
rowBuffer, err = dumpRowValuesBinary(rowBuffer, columns, row)
data, err = dumpBinaryRow(data, columns, row)
if err != nil {
return errors.Trace(err)
}
data = append(data, rowBuffer...)
} else {
for i, value := range row {
if value.IsNull() {
data = append(data, 0xfb)
continue
}
var valData []byte
valData, err = dumpTextValue(columns[i], value)
if err != nil {
return errors.Trace(err)
}
data = dumpLengthEncodedString(data, valData)
data, err = dumpTextRow(data, columns, row)
if err != nil {
return errors.Trace(err)
}
}
View
@@ -121,6 +121,6 @@ type PreparedStatement interface {
// ResultSet is the result set of an query.
type ResultSet interface {
Columns() ([]*ColumnInfo, error)
Next() ([]types.Datum, error)
Next() (types.Row, error)
Close() error
}
View
@@ -292,13 +292,13 @@ type tidbResultSet struct {
recordSet ast.RecordSet
}
func (trs *tidbResultSet) Next() ([]types.Datum, error) {
func (trs *tidbResultSet) Next() (types.Row, error) {
row, err := trs.recordSet.Next()
if err != nil {
return nil, errors.Trace(err)
}
if row != nil {
return row.Data, nil
return types.DatumRow(row.Data), nil
}
return nil, nil
}
View
@@ -723,6 +723,28 @@ func runTestTLSConnection(t *C, overrider configOverrider) error {
return err
}
func runTestSumAvg(c *C) {
runTests(c, nil, func(dbt *DBTest) {
dbt.mustExec("create table sumavg (a int, b decimal, c double)")
dbt.mustExec("insert sumavg values (1, 1, 1)")
rows := dbt.mustQuery("select sum(a), sum(b), sum(c) from sumavg")
c.Assert(rows.Next(), IsTrue)
var outA, outB, outC float64
err := rows.Scan(&outA, &outB, &outC)
c.Assert(err, IsNil)
c.Assert(outA, Equals, 1.0)
c.Assert(outB, Equals, 1.0)
c.Assert(outC, Equals, 1.0)
rows = dbt.mustQuery("select avg(a), avg(b), avg(c) from sumavg")
c.Assert(rows.Next(), IsTrue)
err = rows.Scan(&outA, &outB, &outC)
c.Assert(err, IsNil)
c.Assert(outA, Equals, 1.0)
c.Assert(outB, Equals, 1.0)
c.Assert(outC, Equals, 1.0)
})
}
func getMetrics(t *C) []byte {
resp, err := http.Get("http://127.0.0.1:10090/metrics")
t.Assert(err, IsNil)
View
@@ -419,5 +419,11 @@ func (ts *TidbTestSuite) TestShowCreateTableFlen(c *C) {
c.Assert(err, IsNil)
c.Assert(len(cols), Equals, 2)
c.Assert(int(cols[0].ColumnLength), Equals, 5*tmysql.MaxBytesOfCharacter)
c.Assert(int(cols[1].ColumnLength), Equals, len(row[1].GetString())*tmysql.MaxBytesOfCharacter)
str, _ := row.GetString(1)
c.Assert(int(cols[1].ColumnLength), Equals, len(str)*tmysql.MaxBytesOfCharacter)
}
func (ts *TidbTestSuite) TestSumAvg(c *C) {
c.Parallel()
runTestSumAvg(c)
}
Oops, something went wrong.
ProTip! Use n and p to navigate between commits in a pull request.