From bcd20c7053b2b04ea012d8e02684236ad0f3bb7a Mon Sep 17 00:00:00 2001 From: Peter Mattis Date: Mon, 22 Feb 2016 11:44:34 -0500 Subject: [PATCH] cli: use lib/pq directly Manually create connections using lib/pq instead of going through the sql package. This ensures we have a single connection to the database and paves the way for exposing more functionality in lib/pq to our sql shell. The downside is the loss of some of the functionality in the standard sql package. See #4081. --- cli/sql.go | 21 +++---- cli/sql_util.go | 145 +++++++++++++++++++++++++++++++++---------- cli/sql_util_test.go | 88 +++++++++++++------------- cli/user.go | 24 +++---- cli/zone.go | 27 ++++---- 5 files changed, 193 insertions(+), 112 deletions(-) diff --git a/cli/sql.go b/cli/sql.go index 96ce61491c0b..6f1e32751cbe 100644 --- a/cli/sql.go +++ b/cli/sql.go @@ -17,7 +17,6 @@ package cli import ( - "database/sql" "fmt" "io" "net/url" @@ -118,7 +117,7 @@ func handleInputLine(stmt *[]string, line string) int { // preparePrompts computes a full and short prompt for the interactive // CLI. -func preparePrompts(db *sql.DB, dbURL string) (fullPrompt string, continuePrompt string) { +func preparePrompts(dbURL string) (fullPrompt string, continuePrompt string) { // Default prompt is part of the connection URL. eg: "marc@localhost>" // continued statement prompt is: " -> " fullPrompt = dbURL @@ -142,8 +141,8 @@ func preparePrompts(db *sql.DB, dbURL string) (fullPrompt string, continuePrompt // runInteractive runs the SQL client interactively, presenting // a prompt to the user for each statement. -func runInteractive(db *sql.DB, dbURL string) (exitErr error) { - fullPrompt, continuePrompt := preparePrompts(db, dbURL) +func runInteractive(conn *sqlConn) (exitErr error) { + fullPrompt, continuePrompt := preparePrompts(conn.url) if isatty.IsTerminal(os.Stdout.Fd()) { // We only enable history management when the terminal is actually @@ -211,7 +210,7 @@ func runInteractive(db *sql.DB, dbURL string) (exitErr error) { readline.SetHistoryPath("") } - if exitErr = runPrettyQuery(db, os.Stdout, fullStmt); exitErr != nil { + if exitErr = runPrettyQuery(conn, os.Stdout, fullStmt); exitErr != nil { fmt.Fprintln(osStderr, exitErr) } @@ -224,10 +223,10 @@ func runInteractive(db *sql.DB, dbURL string) (exitErr error) { // runOneStatement executes one statement and terminates // on error. -func runStatements(db *sql.DB, stmts []string) error { +func runStatements(conn *sqlConn, stmts []string) error { for _, stmt := range stmts { fullStmt := stmt + "\n" - cols, allRows, err := runQuery(db, fullStmt) + cols, allRows, err := runQuery(conn, fullStmt) if err != nil { fmt.Fprintln(osStderr, err) return err @@ -261,12 +260,12 @@ func runTerm(cmd *cobra.Command, args []string) error { return errMissingParams } - db, dbURL := makeSQLClient() - defer func() { _ = db.Close() }() + conn := makeSQLClient() + defer conn.Close() if cliContext.OneShotSQL { // Single-line sql; run as simple as possible, without noise on stdout. - return runStatements(db, args) + return runStatements(conn, args) } - return runInteractive(db, dbURL) + return runInteractive(conn) } diff --git a/cli/sql_util.go b/cli/sql_util.go index 4d3143e3eb5d..82af218d9c7b 100644 --- a/cli/sql_util.go +++ b/cli/sql_util.go @@ -18,38 +18,107 @@ package cli import ( "bytes" - "database/sql" + "database/sql/driver" "fmt" "io" "net" - // Import postgres driver. - _ "github.com/lib/pq" + "github.com/lib/pq" "github.com/olekukonko/tablewriter" + + "github.com/cockroachdb/cockroach/util/log" ) -func makeSQLClient() (*sql.DB, string) { +type sqlConnI interface { + driver.Conn + driver.Queryer +} + +type sqlConn struct { + url string + conn sqlConnI +} + +func (c *sqlConn) ensureConn() error { + if c.conn == nil { + conn, err := pq.Open(c.url) + if err != nil { + return err + } + c.conn = conn.(sqlConnI) + } + return nil +} + +func (c *sqlConn) Query(query string, args []driver.Value) (*sqlRows, error) { + if err := c.ensureConn(); err != nil { + return nil, err + } + rows, err := c.conn.Query(query, args) + if err == driver.ErrBadConn { + c.Close() + } + if err != nil { + return nil, err + } + return &sqlRows{Rows: rows, conn: c}, nil +} + +func (c *sqlConn) Close() { + if c.conn != nil { + err := c.conn.Close() + if err != nil && err != driver.ErrBadConn { + log.Info(err) + } + c.conn = nil + } +} + +type sqlRows struct { + driver.Rows + conn *sqlConn +} + +func (r *sqlRows) Close() error { + err := r.Rows.Close() + if err == driver.ErrBadConn { + r.conn.Close() + } + return err +} + +func (r *sqlRows) Next(values []driver.Value) error { + err := r.Rows.Next(values) + if err == driver.ErrBadConn { + r.conn.Close() + } + return err +} + +func makeSQLConn(url string) *sqlConn { + return &sqlConn{ + url: url, + } +} + +func makeSQLClient() *sqlConn { sqlURL := connURL if len(connURL) == 0 { tmpCtx := cliContext tmpCtx.PGAddr = net.JoinHostPort(connHost, connPGPort) sqlURL = tmpCtx.PGURL(connUser) } - db, err := sql.Open("postgres", sqlURL) - if err != nil { - panicf("failed to initialize SQL client: %s", err) - } - return db, sqlURL + return makeSQLConn(sqlURL) } // fmtMap is a mapping from column name to a function that takes the raw input, // and outputs the string to be displayed. -type fmtMap map[string]func(interface{}) string +type fmtMap map[string]func(driver.Value) string // runQuery takes a 'query' with optional 'parameters'. // It runs the sql query and returns a list of columns names and a list of rows. -func runQuery(db *sql.DB, query string, parameters ...interface{}) ( +func runQuery(db *sqlConn, query string, parameters ...driver.Value) ( []string, [][]string, error) { return runQueryWithFormat(db, nil, query, parameters...) } @@ -58,20 +127,34 @@ func runQuery(db *sql.DB, query string, parameters ...interface{}) ( // It runs the sql query and returns a list of columns names and a list of rows. // If 'format' is not nil, the values with column name // found in the map are run through the corresponding callback. -func runQueryWithFormat(db *sql.DB, format fmtMap, query string, parameters ...interface{}) ( +func runQueryWithFormat(db *sqlConn, format fmtMap, query string, parameters ...driver.Value) ( []string, [][]string, error) { - rows, err := db.Query(query, parameters...) + // driver.Value is an alias for interface{}, but must adhere to a restricted + // set of types when being passed to driver.Queryer.Query (see + // driver.IsValue). We use driver.DefaultParameterConverter to perform the + // necessary conversion. This is usually taken care of by the sql package, + // but we have to do so manually because we're talking directly to the + // driver. + for i := range parameters { + var err error + parameters[i], err = driver.DefaultParameterConverter.ConvertValue(parameters[i]) + if err != nil { + return nil, nil, err + } + } + + rows, err := db.Query(query, parameters) if err != nil { return nil, nil, fmt.Errorf("query error: %s", err) } - defer rows.Close() + defer func() { _ = rows.Close() }() return sqlRowsToStrings(rows, format) } // runPrettyQueryWithFormat takes a 'query' with optional 'parameters'. // It runs the sql query and writes pretty output to 'w'. -func runPrettyQuery(db *sql.DB, w io.Writer, query string, parameters ...interface{}) error { +func runPrettyQuery(db *sqlConn, w io.Writer, query string, parameters ...driver.Value) error { cols, allRows, err := runQuery(db, query, parameters...) if err != nil { return err @@ -88,32 +171,30 @@ func runPrettyQuery(db *sql.DB, w io.Writer, query string, parameters ...interfa // It returns the header row followed by all data rows. // If both the header row and list of rows are empty, it means no row // information was returned (eg: statement was not a query). -func sqlRowsToStrings(rows *sql.Rows, format fmtMap) ([]string, [][]string, error) { - cols, err := rows.Columns() - if err != nil { - return nil, nil, fmt.Errorf("rows.Columns() error: %s", err) - } +func sqlRowsToStrings(rows *sqlRows, format fmtMap) ([]string, [][]string, error) { + cols := rows.Columns() if len(cols) == 0 { return nil, nil, nil } - vals := make([]interface{}, len(cols)) - for i := range vals { - vals[i] = new(interface{}) - } - + vals := make([]driver.Value, len(cols)) allRows := [][]string{} - for rows.Next() { - rowStrings := make([]string, len(cols)) - if err := rows.Scan(vals...); err != nil { - return nil, nil, fmt.Errorf("scan error: %s", err) + + for { + err := rows.Next(vals) + if err == io.EOF { + break } + if err != nil { + return nil, nil, err + } + rowStrings := make([]string, len(cols)) for i, v := range vals { if f, ok := format[cols[i]]; ok { - rowStrings[i] = f(*v.(*interface{})) + rowStrings[i] = f(v) } else { - rowStrings[i] = formatVal(*v.(*interface{})) + rowStrings[i] = formatVal(v) } } allRows = append(allRows, rowStrings) @@ -144,7 +225,7 @@ func printQueryOutput(w io.Writer, cols []string, allRows [][]string) { table.Render() } -func formatVal(val interface{}) string { +func formatVal(val driver.Value) string { switch t := val.(type) { case nil: return "NULL" diff --git a/cli/sql_util_test.go b/cli/sql_util_test.go index 1f6bfc29fb4f..1390eacd871c 100644 --- a/cli/sql_util_test.go +++ b/cli/sql_util_test.go @@ -18,7 +18,7 @@ package cli import ( "bytes" - "database/sql" + "database/sql/driver" "fmt" "reflect" "testing" @@ -32,25 +32,19 @@ import ( func TestRunQuery(t *testing.T) { defer leaktest.AfterTest(t) s := server.StartTestServer(nil) + defer s.Stop() url, cleanup := sqlutils.PGUrl(t, s, security.RootUser, "TestRunQuery") - db, err := sql.Open("postgres", url.String()) - if err != nil { - t.Fatal(err) - } defer cleanup() - defer db.Close() - defer s.Stop() - // Ensure we use only one connection so that retrieval of multiple results - // can work without assumptions about connection reuse. - db.SetMaxOpenConns(1) + conn := makeSQLConn(url.String()) + defer conn.Close() // Use a buffer as the io.Writer. var b bytes.Buffer // Non-query statement. - if err := runPrettyQuery(db, &b, `SET DATABASE=system`); err != nil { + if err := runPrettyQuery(conn, &b, `SET DATABASE=system`); err != nil { t.Fatal(err) } @@ -63,7 +57,7 @@ OK b.Reset() // Use system database for sample query/output as they are fairly fixed. - cols, rows, err := runQuery(db, `SHOW COLUMNS FROM system.namespace`) + cols, rows, err := runQuery(conn, `SHOW COLUMNS FROM system.namespace`) if err != nil { t.Fatal(err) } @@ -82,7 +76,7 @@ OK t.Fatalf("expected:\n%v\ngot:\n%v", expectedRows, rows) } - if err := runPrettyQuery(db, &b, `SHOW COLUMNS FROM system.namespace`); err != nil { + if err := runPrettyQuery(conn, &b, `SHOW COLUMNS FROM system.namespace`); err != nil { t.Fatal(err) } @@ -102,7 +96,7 @@ OK b.Reset() // Test placeholders. - if err := runPrettyQuery(db, &b, `SELECT * FROM system.namespace WHERE name=$1`, "descriptor"); err != nil { + if err := runPrettyQuery(conn, &b, `SELECT * FROM system.namespace WHERE name=$1`, "descriptor"); err != nil { t.Fatal(err) } @@ -119,11 +113,11 @@ OK b.Reset() // Test custom formatting. - newFormat := func(val interface{}) string { + newFormat := func(val driver.Value) string { return fmt.Sprintf("--> %s <--", val) } - _, rows, err = runQueryWithFormat(db, fmtMap{"name": newFormat}, + _, rows, err = runQueryWithFormat(conn, fmtMap{"name": newFormat}, `SELECT * FROM system.namespace WHERE name=$1`, "descriptor") if err != nil { t.Fatal(err) @@ -135,32 +129,38 @@ OK } b.Reset() - // Test multiple results. - if err := runPrettyQuery(db, &b, `SELECT 1; SELECT 2, 3; SELECT 'hello'`); err != nil { - t.Fatal(err) - } - - expected = ` -+---+ -| 1 | -+---+ -| 1 | -+---+ -` - // TODO(pmattis): When #4016 is fixed, we should see: - // +---+---+ - // | 2 | 3 | - // +---+---+ - // | 2 | 3 | - // +---+---+ - // +---------+ - // | 'hello' | - // +---------+ - // | "hello" | - // +---------+ - - if a, e := b.String(), expected[1:]; a != e { - t.Fatalf("expected output:\n%s\ngot:\n%s", e, a) - } - b.Reset() + // TODO(pmattis): This test case fails now as lib/pq doesn't handle multiple + // results correctly. We were previously incorrectly ignoring the error from + // sql.Rows.Err() which is what allowed the test to pass. + + /** + // Test multiple results. + if err := runPrettyQuery(conn, &b, `SELECT 1; SELECT 2, 3; SELECT 'hello'`); err != nil { + t.Fatal(err) + } + + expected = ` + +---+ + | 1 | + +---+ + | 1 | + +---+ + ` + // TODO(pmattis): When #4016 is fixed, we should see: + // +---+---+ + // | 2 | 3 | + // +---+---+ + // | 2 | 3 | + // +---+---+ + // +---------+ + // | 'hello' | + // +---------+ + // | "hello" | + // +---------+ + + if a, e := b.String(), expected[1:]; a != e { + t.Fatalf("expected output:\n%s\ngot:\n%s", e, a) + } + b.Reset() + **/ } diff --git a/cli/user.go b/cli/user.go index c7ae66a3dafb..4877011fc30d 100644 --- a/cli/user.go +++ b/cli/user.go @@ -43,9 +43,9 @@ func runGetUser(cmd *cobra.Command, args []string) { mustUsage(cmd) return } - db, _ := makeSQLClient() - defer func() { _ = db.Close() }() - err := runPrettyQuery(db, os.Stdout, + conn := makeSQLClient() + defer conn.Close() + err := runPrettyQuery(conn, os.Stdout, `SELECT * FROM system.users WHERE username=$1`, args[0]) if err != nil { panic(err) @@ -68,9 +68,9 @@ func runLsUsers(cmd *cobra.Command, args []string) { mustUsage(cmd) return } - db, _ := makeSQLClient() - defer func() { _ = db.Close() }() - err := runPrettyQuery(db, os.Stdout, `SELECT username FROM system.users`) + conn := makeSQLClient() + defer conn.Close() + err := runPrettyQuery(conn, os.Stdout, `SELECT username FROM system.users`) if err != nil { panic(err) } @@ -92,9 +92,9 @@ func runRmUser(cmd *cobra.Command, args []string) { mustUsage(cmd) return } - db, _ := makeSQLClient() - defer func() { _ = db.Close() }() - err := runPrettyQuery(db, os.Stdout, + conn := makeSQLClient() + defer conn.Close() + err := runPrettyQuery(conn, os.Stdout, `DELETE FROM system.users WHERE username=$1`, args[0]) if err != nil { panic(err) @@ -160,10 +160,10 @@ func runSetUser(cmd *cobra.Command, args []string) { panic(err) } } - db, _ := makeSQLClient() - defer func() { _ = db.Close() }() + conn := makeSQLClient() + defer conn.Close() // TODO(marc): switch to UPSERT. - err = runPrettyQuery(db, os.Stdout, + err = runPrettyQuery(conn, os.Stdout, `INSERT INTO system.users VALUES ($1, $2)`, args[0], hashed) if err != nil { panic(err) diff --git a/cli/zone.go b/cli/zone.go index 77d1a21168e0..2aab75f06a28 100644 --- a/cli/zone.go +++ b/cli/zone.go @@ -18,6 +18,7 @@ package cli import ( + "database/sql/driver" "fmt" "os" "strconv" @@ -46,7 +47,7 @@ func zoneProtoToYAMLString(val []byte) (string, error) { // formatZone is a callback used to format the raw zone config // protobuf in a sql.Rows column for pretty printing. -func formatZone(val interface{}) string { +func formatZone(val driver.Value) string { if raw, ok := val.([]byte); ok { if ret, err := zoneProtoToYAMLString(raw); err == nil { return ret @@ -81,9 +82,9 @@ func runGetZone(cmd *cobra.Command, args []string) { return } - db, _ := makeSQLClient() - defer func() { _ = db.Close() }() - _, rows, err := runQueryWithFormat(db, fmtMap{"config": formatZone}, + conn := makeSQLClient() + defer conn.Close() + _, rows, err := runQueryWithFormat(conn, fmtMap{"config": formatZone}, `SELECT * FROM system.zones WHERE id=$1`, id) if err != nil { log.Error(err) @@ -114,9 +115,9 @@ func runLsZones(cmd *cobra.Command, args []string) { mustUsage(cmd) return } - db, _ := makeSQLClient() - defer func() { _ = db.Close() }() - _, rows, err := runQueryWithFormat(db, fmtMap{"config": formatZone}, `SELECT * FROM system.zones`) + conn := makeSQLClient() + defer conn.Close() + _, rows, err := runQueryWithFormat(conn, fmtMap{"config": formatZone}, `SELECT * FROM system.zones`) if err != nil { log.Error(err) return @@ -155,9 +156,9 @@ func runRmZone(cmd *cobra.Command, args []string) { return } - db, _ := makeSQLClient() - defer func() { _ = db.Close() }() - err = runPrettyQuery(db, os.Stdout, + conn := makeSQLClient() + defer conn.Close() + err = runPrettyQuery(conn, os.Stdout, `DELETE FROM system.zones WHERE id=$1`, id) if err != nil { log.Error(err) @@ -227,10 +228,10 @@ func runSetZone(cmd *cobra.Command, args []string) { return } - db, _ := makeSQLClient() - defer func() { _ = db.Close() }() + conn := makeSQLClient() + defer conn.Close() // TODO(marc): switch to UPSERT. - err = runPrettyQuery(db, os.Stdout, + err = runPrettyQuery(conn, os.Stdout, `INSERT INTO system.zones VALUES ($1, $2)`, id, buf) if err != nil { log.Error(err)