Skip to content

Commit

Permalink
Merge pull request cockroachdb#4567 from petermattis/pmattis/use-pq-d…
Browse files Browse the repository at this point in the history
…irectly

cli: use lib/pq directly
  • Loading branch information
petermattis committed Feb 22, 2016
2 parents d86fee8 + bcd20c7 commit 6d55a75
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 112 deletions.
21 changes: 10 additions & 11 deletions cli/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package cli

import (
"database/sql"
"fmt"
"io"
"net/url"
Expand Down Expand Up @@ -118,7 +117,7 @@ func handleInputLine(stmt *[]string, line string) int {

// preparePrompts computes a full and short prompt for the interactive
// CLI.
func preparePrompts(db *sql.DB, dbURL string) (fullPrompt string, continuePrompt string) {
func preparePrompts(dbURL string) (fullPrompt string, continuePrompt string) {
// Default prompt is part of the connection URL. eg: "marc@localhost>"
// continued statement prompt is: " -> "
fullPrompt = dbURL
Expand All @@ -142,8 +141,8 @@ func preparePrompts(db *sql.DB, dbURL string) (fullPrompt string, continuePrompt

// runInteractive runs the SQL client interactively, presenting
// a prompt to the user for each statement.
func runInteractive(db *sql.DB, dbURL string) (exitErr error) {
fullPrompt, continuePrompt := preparePrompts(db, dbURL)
func runInteractive(conn *sqlConn) (exitErr error) {
fullPrompt, continuePrompt := preparePrompts(conn.url)

if isatty.IsTerminal(os.Stdout.Fd()) {
// We only enable history management when the terminal is actually
Expand Down Expand Up @@ -211,7 +210,7 @@ func runInteractive(db *sql.DB, dbURL string) (exitErr error) {
readline.SetHistoryPath("")
}

if exitErr = runPrettyQuery(db, os.Stdout, fullStmt); exitErr != nil {
if exitErr = runPrettyQuery(conn, os.Stdout, fullStmt); exitErr != nil {
fmt.Fprintln(osStderr, exitErr)
}

Expand All @@ -224,10 +223,10 @@ func runInteractive(db *sql.DB, dbURL string) (exitErr error) {

// runOneStatement executes one statement and terminates
// on error.
func runStatements(db *sql.DB, stmts []string) error {
func runStatements(conn *sqlConn, stmts []string) error {
for _, stmt := range stmts {
fullStmt := stmt + "\n"
cols, allRows, err := runQuery(db, fullStmt)
cols, allRows, err := runQuery(conn, fullStmt)
if err != nil {
fmt.Fprintln(osStderr, err)
return err
Expand Down Expand Up @@ -261,12 +260,12 @@ func runTerm(cmd *cobra.Command, args []string) error {
return errMissingParams
}

db, dbURL := makeSQLClient()
defer func() { _ = db.Close() }()
conn := makeSQLClient()
defer conn.Close()

if cliContext.OneShotSQL {
// Single-line sql; run as simple as possible, without noise on stdout.
return runStatements(db, args)
return runStatements(conn, args)
}
return runInteractive(db, dbURL)
return runInteractive(conn)
}
145 changes: 113 additions & 32 deletions cli/sql_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,107 @@ package cli

import (
"bytes"
"database/sql"
"database/sql/driver"
"fmt"
"io"
"net"

// Import postgres driver.
_ "github.com/lib/pq"
"github.com/lib/pq"

"github.com/olekukonko/tablewriter"

"github.com/cockroachdb/cockroach/util/log"
)

func makeSQLClient() (*sql.DB, string) {
type sqlConnI interface {
driver.Conn
driver.Queryer
}

type sqlConn struct {
url string
conn sqlConnI
}

func (c *sqlConn) ensureConn() error {
if c.conn == nil {
conn, err := pq.Open(c.url)
if err != nil {
return err
}
c.conn = conn.(sqlConnI)
}
return nil
}

func (c *sqlConn) Query(query string, args []driver.Value) (*sqlRows, error) {
if err := c.ensureConn(); err != nil {
return nil, err
}
rows, err := c.conn.Query(query, args)
if err == driver.ErrBadConn {
c.Close()
}
if err != nil {
return nil, err
}
return &sqlRows{Rows: rows, conn: c}, nil
}

func (c *sqlConn) Close() {
if c.conn != nil {
err := c.conn.Close()
if err != nil && err != driver.ErrBadConn {
log.Info(err)
}
c.conn = nil
}
}

type sqlRows struct {
driver.Rows
conn *sqlConn
}

func (r *sqlRows) Close() error {
err := r.Rows.Close()
if err == driver.ErrBadConn {
r.conn.Close()
}
return err
}

func (r *sqlRows) Next(values []driver.Value) error {
err := r.Rows.Next(values)
if err == driver.ErrBadConn {
r.conn.Close()
}
return err
}

func makeSQLConn(url string) *sqlConn {
return &sqlConn{
url: url,
}
}

func makeSQLClient() *sqlConn {
sqlURL := connURL
if len(connURL) == 0 {
tmpCtx := cliContext
tmpCtx.PGAddr = net.JoinHostPort(connHost, connPGPort)
sqlURL = tmpCtx.PGURL(connUser)
}
db, err := sql.Open("postgres", sqlURL)
if err != nil {
panicf("failed to initialize SQL client: %s", err)
}
return db, sqlURL
return makeSQLConn(sqlURL)
}

// fmtMap is a mapping from column name to a function that takes the raw input,
// and outputs the string to be displayed.
type fmtMap map[string]func(interface{}) string
type fmtMap map[string]func(driver.Value) string

// runQuery takes a 'query' with optional 'parameters'.
// It runs the sql query and returns a list of columns names and a list of rows.
func runQuery(db *sql.DB, query string, parameters ...interface{}) (
func runQuery(db *sqlConn, query string, parameters ...driver.Value) (
[]string, [][]string, error) {
return runQueryWithFormat(db, nil, query, parameters...)
}
Expand All @@ -58,20 +127,34 @@ func runQuery(db *sql.DB, query string, parameters ...interface{}) (
// It runs the sql query and returns a list of columns names and a list of rows.
// If 'format' is not nil, the values with column name
// found in the map are run through the corresponding callback.
func runQueryWithFormat(db *sql.DB, format fmtMap, query string, parameters ...interface{}) (
func runQueryWithFormat(db *sqlConn, format fmtMap, query string, parameters ...driver.Value) (
[]string, [][]string, error) {
rows, err := db.Query(query, parameters...)
// driver.Value is an alias for interface{}, but must adhere to a restricted
// set of types when being passed to driver.Queryer.Query (see
// driver.IsValue). We use driver.DefaultParameterConverter to perform the
// necessary conversion. This is usually taken care of by the sql package,
// but we have to do so manually because we're talking directly to the
// driver.
for i := range parameters {
var err error
parameters[i], err = driver.DefaultParameterConverter.ConvertValue(parameters[i])
if err != nil {
return nil, nil, err
}
}

rows, err := db.Query(query, parameters)
if err != nil {
return nil, nil, fmt.Errorf("query error: %s", err)
}

defer rows.Close()
defer func() { _ = rows.Close() }()
return sqlRowsToStrings(rows, format)
}

// runPrettyQueryWithFormat takes a 'query' with optional 'parameters'.
// It runs the sql query and writes pretty output to 'w'.
func runPrettyQuery(db *sql.DB, w io.Writer, query string, parameters ...interface{}) error {
func runPrettyQuery(db *sqlConn, w io.Writer, query string, parameters ...driver.Value) error {
cols, allRows, err := runQuery(db, query, parameters...)
if err != nil {
return err
Expand All @@ -88,32 +171,30 @@ func runPrettyQuery(db *sql.DB, w io.Writer, query string, parameters ...interfa
// It returns the header row followed by all data rows.
// If both the header row and list of rows are empty, it means no row
// information was returned (eg: statement was not a query).
func sqlRowsToStrings(rows *sql.Rows, format fmtMap) ([]string, [][]string, error) {
cols, err := rows.Columns()
if err != nil {
return nil, nil, fmt.Errorf("rows.Columns() error: %s", err)
}
func sqlRowsToStrings(rows *sqlRows, format fmtMap) ([]string, [][]string, error) {
cols := rows.Columns()

if len(cols) == 0 {
return nil, nil, nil
}

vals := make([]interface{}, len(cols))
for i := range vals {
vals[i] = new(interface{})
}

vals := make([]driver.Value, len(cols))
allRows := [][]string{}
for rows.Next() {
rowStrings := make([]string, len(cols))
if err := rows.Scan(vals...); err != nil {
return nil, nil, fmt.Errorf("scan error: %s", err)

for {
err := rows.Next(vals)
if err == io.EOF {
break
}
if err != nil {
return nil, nil, err
}
rowStrings := make([]string, len(cols))
for i, v := range vals {
if f, ok := format[cols[i]]; ok {
rowStrings[i] = f(*v.(*interface{}))
rowStrings[i] = f(v)
} else {
rowStrings[i] = formatVal(*v.(*interface{}))
rowStrings[i] = formatVal(v)
}
}
allRows = append(allRows, rowStrings)
Expand Down Expand Up @@ -144,7 +225,7 @@ func printQueryOutput(w io.Writer, cols []string, allRows [][]string) {
table.Render()
}

func formatVal(val interface{}) string {
func formatVal(val driver.Value) string {
switch t := val.(type) {
case nil:
return "NULL"
Expand Down
Loading

0 comments on commit 6d55a75

Please sign in to comment.