Permalink
Browse files

New methods: rou.ScanRow row.MakeRow

  • Loading branch information...
1 parent 441a576 commit 463fc9889af1452938f137bccc32baed4e228ab6 @ziutek committed May 14, 2012
Showing with 85 additions and 68 deletions.
  1. +2 −0 mysql/interface.go
  2. +19 −0 mysql/utils.go
  3. +1 −1 native/errors.go
  4. +1 −0 native/init.go
  5. +37 −35 native/mysql.go
  6. +13 −21 native/result.go
  7. +12 −11 thrsafe/thrsafe.go
View
@@ -52,6 +52,7 @@ type Stmt interface {
type Result interface {
StatusOnly() bool
+ ScanRow(Row) error
GetRow() (Row, error)
MoreResults() bool
@@ -64,6 +65,7 @@ type Result interface {
InsertId() uint64
WarnCount() int
+ MakeRow() Row
GetRows() ([]Row, error)
End() error
}
View
@@ -1,5 +1,9 @@
package mysql
+import (
+ "io"
+)
+
// This call Start and next call GetRow as long as it reads all rows from the
// result. Next it returns all readed rows as the slice of rows.
func Query(c Conn, sql string, params ...interface{}) (rows []Row, res Result, err error) {
@@ -22,6 +26,21 @@ func Exec(s Stmt, params ...interface{}) (rows []Row, res Result, err error) {
return
}
+// Calls r.MakeRow and next r.ScanRow. Doesn't return io.EOF error (returns nil
+// row insted).
+func GetRow(r Result) (Row, error) {
+ row := r.MakeRow()
+ err := r.ScanRow(row)
+ if err == io.EOF {
+ return nil, nil
+ }
+ if err != nil {
+ return nil, err
+ }
+ return row, nil
+}
+
+
// Read all unreaded rows and discard them. This function is useful if you
// don't want to use the remaining rows. It has an impact only on current
// result. If there is multi result query, you must use NextResult method and
View
@@ -19,7 +19,7 @@ var (
UNREADED_REPLY_ERROR = errors.New("reply is not completely read")
BIND_COUNT_ERROR = errors.New("wrong number of values for bind")
BIND_UNK_TYPE = errors.New("unknown value type for bind")
- RESULT_COUNT_ERROR = errors.New("wrong number of result columns")
+ ROW_LENGTH_ERROR = errors.New("wrong length of row slice")
BAD_COMMAND_ERROR = errors.New("comand isn't text SQL nor *Stmt")
WRONG_DATE_LEN_ERROR = errors.New("wrong datetime/timestamp length")
WRONG_TIME_LEN_ERROR = errors.New("wrong time length")
View
@@ -39,6 +39,7 @@ func (my *Conn) auth() {
flags := uint32(
_CLIENT_PROTOCOL_41 |
_CLIENT_LONG_PASSWORD |
+ _CLIENT_LONG_FLAG |
_CLIENT_SECURE_CONN |
_CLIENT_MULTI_STATEMENTS |
_CLIENT_MULTI_RESULTS |
View
@@ -129,7 +129,7 @@ func (my *Conn) connect() (err error) {
// Initialisation
my.init()
my.auth()
- my.getResult(nil)
+ my.getResult(nil, nil)
// Execute all registered commands
for _, cmd := range my.init_cmds {
@@ -138,18 +138,15 @@ func (my *Conn) connect() (err error) {
// Get command response
res := my.getResponse()
- if res.field_count == 0 {
+ if res.StatusOnly() {
// No fields in result (OK result)
continue
}
// Read and discard all result rows
- var row mysql.Row
+ row := res.MakeRow()
for {
- row, err = res.getRow()
- if err != nil {
- return
- }
- if row == nil {
+ err = res.getRow(row)
+ if err == io.EOF {
res, err = res.nextResult()
if err != nil {
return
@@ -158,6 +155,10 @@ func (my *Conn) connect() (err error) {
// No more rows and results from this cmd
break
}
+ row = res.MakeRow()
+ }
+ if err != nil {
+ return
}
}
}
@@ -254,16 +255,16 @@ func (my *Conn) Use(dbname string) (err error) {
// Send command
my.sendCmd(_COM_INIT_DB, dbname)
// Get server response
- my.getResult(nil)
+ my.getResult(nil, nil)
// Save new database name if no errors
my.dbname = dbname
return
}
func (my *Conn) getResponse() (res *Result) {
- res, ok := my.getResult(nil).(*Result)
- if !ok {
+ res = my.getResult(nil, nil)
+ if res == nil {
panic(BAD_RESULT_ERROR)
}
my.unreaded_reply = !res.StatusOnly()
@@ -296,21 +297,13 @@ func (my *Conn) Start(sql string, params ...interface{}) (res mysql.Result, err
return
}
-func (res *Result) getRow() (row mysql.Row, err error) {
+func (res *Result) getRow(row mysql.Row) (err error) {
defer catchError(&err)
- switch result := res.my.getResult(res).(type) {
- case mysql.Row:
- // Row of data
- row = result
-
- case *Result:
- // EOF result
-
- default:
- err = BAD_RESULT_ERROR
+ if res.my.getResult(res, row) != nil {
+ return io.EOF
}
- return
+ return nil
}
// Returns true if more results exixts. You don't have to call it before
@@ -319,26 +312,35 @@ func (res *Result) MoreResults() bool {
return res.status&_SERVER_MORE_RESULTS_EXISTS != 0
}
-// Get the data row from server. This method reads one row of result directly
-// from network connection (without rows buffering on client side).
-func (res *Result) GetRow() (row mysql.Row, err error) {
+// Get the data row from server. This method reads one row of result set
+// directly from network connection (without rows buffering on client side).
+// Returns io.EOF if there is no more rows in current result set.
+func (res *Result) ScanRow(row mysql.Row) error {
+ if row == nil {
+ return ROW_LENGTH_ERROR
+ }
if res.eor_returned {
- err = READ_AFTER_EOR_ERROR
- return
+ return READ_AFTER_EOR_ERROR
}
- if res.field_count == 0 {
+ if res.StatusOnly() {
// There is no fields in result (OK result)
res.eor_returned = true
- return
+ return io.EOF
}
- row, err = res.getRow()
- if err == nil && row == nil {
+ err := res.getRow(row)
+ if err == io.EOF {
res.eor_returned = true
if !res.MoreResults() {
res.my.unreaded_reply = false
}
}
- return
+ return err
+}
+
+// Like ScanRow but allocates memory for every row.
+// Returns nil row insted of io.EOF error.
+func (res *Result) GetRow() (mysql.Row, error) {
+ return mysql.GetRow(res)
}
func (res *Result) nextResult() (next *Result, err error) {
@@ -374,7 +376,7 @@ func (my *Conn) Ping() (err error) {
// Send command
my.sendCmd(_COM_PING)
// Get server response
- my.getResult(nil)
+ my.getResult(nil, nil)
return
}
@@ -570,7 +572,7 @@ func (stmt *Stmt) Reset() (err error) {
// Send command
stmt.my.sendCmd(_COM_STMT_RESET, stmt.id)
// Get result
- stmt.my.getResult(nil)
+ stmt.my.getResult(nil, nil)
return
}
View
@@ -67,7 +67,11 @@ func (res *Result) WarnCount() int {
return res.warning_count
}
-func (my *Conn) getResult(res *Result) interface{} {
+func (res *Result) MakeRow() mysql.Row {
+ return make(mysql.Row, res.field_count)
+}
+
+func (my *Conn) getResult(res *Result, row mysql.Row) *Result {
loop:
pr := my.newPktReader() // New reader for next packet
pkt0 := readByte(pr)
@@ -109,11 +113,15 @@ loop:
case pkt0 < 254 && res.field_count == len(res.fields):
// Row Data Packet
+ if len(row) != res.field_count {
+ panic(ROW_LENGTH_ERROR)
+ }
if res.binary {
- return my.getBinRowPacket(pr, res)
+ my.getBinRowPacket(pr, res, row)
} else {
- return my.getTextRowPacket(pr, res)
+ my.getTextRowPacket(pr, res, row)
}
+ return nil
}
}
panic(UNK_RESULT_PKT_ERROR)
@@ -226,13 +234,12 @@ func (my *Conn) getFieldPacket(pr *pktReader) (field *mysql.Field) {
return
}
-func (my *Conn) getTextRowPacket(pr *pktReader, res *Result) mysql.Row {
+func (my *Conn) getTextRowPacket(pr *pktReader, res *Result, row mysql.Row) {
if my.Debug {
log.Printf("[%2d ->] Text row data packet", my.seq-1)
}
pr.unreadByte()
- row := make(mysql.Row, res.field_count)
for ii := 0; ii < res.field_count; ii++ {
bin, null := readNullBin(pr)
if null {
@@ -242,11 +249,9 @@ func (my *Conn) getTextRowPacket(pr *pktReader, res *Result) mysql.Row {
}
}
pr.checkEof()
-
- return row
}
-func (my *Conn) getBinRowPacket(pr *pktReader, res *Result) mysql.Row {
+func (my *Conn) getBinRowPacket(pr *pktReader, res *Result, row mysql.Row) {
if my.Debug {
log.Printf("[%2d ->] Binary row data packet", my.seq-1)
}
@@ -255,7 +260,6 @@ func (my *Conn) getBinRowPacket(pr *pktReader, res *Result) mysql.Row {
null_bitmap := make([]byte, (res.field_count+7+2)>>3)
readFull(pr, null_bitmap)
- row := make(mysql.Row, res.field_count)
for ii, field := range res.fields {
null_byte := (ii + 2) >> 3
null_mask := byte(1) << uint(2+ii-(null_byte<<3))
@@ -273,60 +277,48 @@ func (my *Conn) getBinRowPacket(pr *pktReader, res *Result) mysql.Row {
} else {
row[ii] = int8(readByte(pr))
}
-
case MYSQL_TYPE_SHORT, MYSQL_TYPE_YEAR:
if unsigned {
row[ii] = readU16(pr)
} else {
row[ii] = int16(readU16(pr))
}
-
case MYSQL_TYPE_LONG, MYSQL_TYPE_INT24:
if unsigned {
row[ii] = readU32(pr)
} else {
row[ii] = int32(readU32(pr))
}
-
case MYSQL_TYPE_LONGLONG:
if unsigned {
row[ii] = readU64(pr)
} else {
row[ii] = int64(readU64(pr))
}
-
case MYSQL_TYPE_FLOAT:
row[ii] = math.Float32frombits(readU32(pr))
-
case MYSQL_TYPE_DOUBLE:
row[ii] = math.Float64frombits(readU64(pr))
-
case MYSQL_TYPE_DECIMAL, MYSQL_TYPE_NEWDECIMAL:
dec := string(readBin(pr))
var err error
row[ii], err = strconv.ParseFloat(dec, 64)
if err != nil {
panic("MySQL server returned wrong decimal value: " + dec)
}
-
case MYSQL_TYPE_STRING, MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_VARCHAR,
MYSQL_TYPE_BIT, MYSQL_TYPE_BLOB, MYSQL_TYPE_TINY_BLOB,
MYSQL_TYPE_MEDIUM_BLOB, MYSQL_TYPE_LONG_BLOB, MYSQL_TYPE_SET,
MYSQL_TYPE_ENUM, MYSQL_TYPE_GEOMETRY:
row[ii] = readBin(pr)
-
case MYSQL_TYPE_DATE, MYSQL_TYPE_NEWDATE:
row[ii] = readDate(pr)
-
case MYSQL_TYPE_DATETIME, MYSQL_TYPE_TIMESTAMP:
row[ii] = readTime(pr)
-
case MYSQL_TYPE_TIME:
row[ii] = readDuration(pr)
-
default:
panic(UNK_MYSQL_TYPE_ERROR)
}
}
- return row
}
View
@@ -8,6 +8,7 @@ import (
//"log"
"github.com/ziutek/mymysql/mysql"
_ "github.com/ziutek/mymysql/native"
+ "io"
)
type Conn struct {
@@ -84,23 +85,23 @@ func (c *Conn) Start(sql string, params ...interface{}) (mysql.Result, error) {
c.unlock()
return nil, err
}
- if len(res.Fields()) == 0 {
+ if res.StatusOnly() {
c.unlock()
}
return &Result{Result: res, conn: c}, err
}
-func (res *Result) GetRow() (mysql.Row, error) {
- //log.Println("GetRow")
- if len(res.Result.Fields()) == 0 {
- // There is no fields in result (OK result)
- return nil, nil
- }
- row, err := res.Result.GetRow()
- if err != nil || row == nil && !res.MoreResults() {
+func (res *Result) ScanRow(row mysql.Row) error {
+ //log.Println("ScanRow")
+ err := res.Result.ScanRow(row)
+ if err != nil && (err != io.EOF || !res.StatusOnly() && !res.MoreResults()) {
res.conn.unlock()
}
- return row, err
+ return err
+}
+
+func (res *Result) GetRow() (mysql.Row, error) {
+ return mysql.GetRow(res)
}
func (res *Result) NextResult() (mysql.Result, error) {
@@ -138,7 +139,7 @@ func (stmt *Stmt) Run(params ...interface{}) (mysql.Result, error) {
stmt.conn.unlock()
return nil, err
}
- if len(res.Fields()) == 0 {
+ if res.StatusOnly() {
stmt.conn.unlock()
}
return &Result{Result: res, conn: stmt.conn}, nil

0 comments on commit 463fc98

Please sign in to comment.