Skip to content

Commit

Permalink
Fix #17 - panic() after reusing prepared stmt (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
huebnerr authored and sitingren committed Aug 20, 2019
1 parent 81454c5 commit dc2aed2
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 5 deletions.
2 changes: 1 addition & 1 deletion driver.go
Expand Up @@ -47,7 +47,7 @@ type Driver struct{}

const (
driverName string = "vertica-sql-go"
driverVersion string = "0.1.2"
driverVersion string = "0.1.3"
protocolVersion uint32 = 0x00030008
)

Expand Down
34 changes: 33 additions & 1 deletion driver_test.go
Expand Up @@ -122,7 +122,7 @@ func assertErr(t *testing.T, err error, errorSubstring string) {
return
}

t.Fatalf("expected an error, but it was '%s' instead of containing '%s'", errStr, errorSubstring)
t.Fatalf("expected an error containing '%s', but found '%s'", errorSubstring, errStr)
}

func assertNext(t *testing.T, rows *sql.Rows) {
Expand Down Expand Up @@ -477,6 +477,38 @@ func TestValueTypes(t *testing.T) {
assertNoErr(t, rows.Close())
}

func TestStmtReuseBug(t *testing.T) {
connDB := openConnection(t)
defer closeConnection(t, connDB)

var res bool

stmt, err := connDB.PrepareContext(ctx, "SELECT true AS res")
assertNoErr(t, err)

// first call
rows, err := stmt.QueryContext(ctx)
assertNoErr(t, err)

defer rows.Close()

assertNext(t, rows)
assertNoErr(t, rows.Scan(&res))
assertEqual(t, res, true)
assertNoNext(t, rows)

// second call
rows, err = stmt.QueryContext(ctx)
assertNoErr(t, err)

defer rows.Close()

assertNext(t, rows)
assertNoErr(t, rows.Scan(&res))
assertEqual(t, res, true)
assertNoNext(t, rows)
}

func init() {
logger.SetLogLevel(logger.INFO)

Expand Down
13 changes: 10 additions & 3 deletions stmt.go
Expand Up @@ -67,6 +67,7 @@ type stmt struct {
preparedName string
parseState parseState
paramTypes []common.ParameterType
lastRowDesc *msgs.BERowDescMsg
}

func newStmt(connection *connection, command string) (*stmt, error) {
Expand Down Expand Up @@ -176,7 +177,7 @@ func (s *stmt) QueryContextRaw(ctx context.Context, args []driver.NamedValue) (*
return s.collectResults()
}

// We aren't a prepared statement, manually interpolate and do as a simpe query.
// We aren't a prepared statement, manually interpolate and do as a simple query.
cmd, err = s.interpolate(args)

if err != nil {
Expand All @@ -196,9 +197,12 @@ func (s *stmt) QueryContextRaw(ctx context.Context, args []driver.NamedValue) (*

switch msg := bMsg.(type) {
case *msgs.BEDataRowMsg:
if rows == emptyRowSet {
rows = newRows(s.lastRowDesc, s.conn.serverTZOffset)
}
rows.addRow(msg)
case *msgs.BERowDescMsg:
rows = newRows(msg, s.conn.serverTZOffset)
s.lastRowDesc = msg
case *msgs.BECmdCompleteMsg:
break
case *msgs.BEErrorMsg:
Expand Down Expand Up @@ -341,9 +345,12 @@ func (s *stmt) collectResults() (*rows, error) {

switch msg := bMsg.(type) {
case *msgs.BEDataRowMsg:
if rows == emptyRowSet {
rows = newRows(s.lastRowDesc, s.conn.serverTZOffset)
}
rows.addRow(msg)
case *msgs.BERowDescMsg:
rows = newRows(msg, s.conn.serverTZOffset)
s.lastRowDesc = msg
case *msgs.BEErrorMsg:
return emptyRowSet, msg.ToErrorType()
case *msgs.BEEmptyQueryResponseMsg:
Expand Down

0 comments on commit dc2aed2

Please sign in to comment.