From 689bd03a29e3dc9cacfe769f47ce2d25e9939464 Mon Sep 17 00:00:00 2001 From: Jackson Owens Date: Wed, 22 Apr 2015 12:56:04 -0700 Subject: [PATCH] Add support for query logging, Exec takes builders --- ast.go | 259 ++++++++++++++++++++++++++++++-------------- builder_test.go | 46 ++++++++ db.go | 136 ++++++++++++++++++----- db_internal_test.go | 21 +++- logger.go | 49 +++++++++ 5 files changed, 401 insertions(+), 110 deletions(-) create mode 100644 logger.go diff --git a/ast.go b/ast.go index e9823a7..577c166 100644 --- a/ast.go +++ b/ast.go @@ -26,20 +26,108 @@ import ( // The Serializer interface is implemented by all // expressions/statements. type Serializer interface { - // Serialize writes the statement/expression to the io.Writer. If an - // error is returned the io.Writer may contain partial output. - Serialize(w io.Writer) error + // Serialize writes the statement/expression to the Writer. If an + // error is returned the Writer may contain partial output. + Serialize(w Writer) error } // Serialize serializes a serializer to a string. func Serialize(s Serializer) (string, error) { - var buf bytes.Buffer - if err := s.Serialize(&buf); err != nil { + w := &standardWriter{} + if err := s.Serialize(w); err != nil { return "", err } - return buf.String(), nil + return w.String(), nil } +// SerializeWithPlaceholders serializes a serializer to a string but without substituting +// values. It may be useful for logging. +func SerializeWithPlaceholders(s Serializer) (string, error) { + w := &placeholderWriter{} + if err := s.Serialize(w); err != nil { + return "", err + } + return w.String(), nil +} + +// Writer defines an interface for writing a AST as SQL. +type Writer interface { + io.Writer + + // WriteBytes writes a string of unprintable value. + WriteBytes(node BytesVal) error + // WriteEncoded writes an already encoded value. + WriteEncoded(node EncodedVal) error + // WriteNum writes a number value. + WriteNum(node NumVal) error + // WriteRaw writes a raw Go value. + WriteRaw(node RawVal) error + // WriteStr writes a SQL string value. + WriteStr(node StrVal) error +} + +type standardWriter struct { + bytes.Buffer +} + +func (w *standardWriter) WriteRaw(node RawVal) error { + return encodeSQLValue(w, node.Val) +} + +func (w *standardWriter) WriteEncoded(node EncodedVal) error { + _, err := w.Write(node.Val) + return err +} + +func (w *standardWriter) WriteStr(node StrVal) error { + return encodeSQLString(w, string(node)) +} + +func (w *standardWriter) WriteBytes(node BytesVal) error { + return encodeSQLBytes(w, []byte(node)) +} + +func (w *standardWriter) WriteNum(node NumVal) error { + _, err := io.WriteString(w, string(node)) + return err +} + +// placeholderWriter will write all SQL value types as ? placeholders. +type placeholderWriter struct { + bytes.Buffer +} + +func (w *placeholderWriter) WriteRaw(node RawVal) error { + _, err := w.Write(astPlaceholder) + return err +} + +func (w *placeholderWriter) WriteEncoded(node EncodedVal) error { + _, err := w.Write(astPlaceholder) + return err +} + +func (w *placeholderWriter) WriteStr(node StrVal) error { + _, err := w.Write(astPlaceholder) + return err +} + +func (w *placeholderWriter) WriteBytes(node BytesVal) error { + _, err := w.Write(astPlaceholder) + return err +} + +func (w *placeholderWriter) WriteNum(node NumVal) error { + _, err := w.Write(astPlaceholder) + return err +} + +var ( + // Placeholder is a placeholder for a value in a SQL statement. It is replaced with + // an actual value when the query is executed. + Placeholder = PlaceholderVal{} +) + // Statement represents a statement. type Statement interface { Serializer @@ -92,7 +180,7 @@ var ( astSelectFrom = []byte(" FROM ") ) -func (node *Select) Serialize(w io.Writer) error { +func (node *Select) Serialize(w Writer) error { if _, err := w.Write(astSelect); err != nil { return err } @@ -142,7 +230,7 @@ const ( astIntersect = "INTERSECT" ) -func (node *Union) Serialize(w io.Writer) error { +func (node *Union) Serialize(w Writer) error { if err := node.Left.Serialize(w); err != nil { return err } @@ -173,7 +261,7 @@ var ( astSpace = []byte(" ") ) -func (node *Insert) Serialize(w io.Writer) error { +func (node *Insert) Serialize(w Writer) error { if _, err := io.WriteString(w, node.Kind); err != nil { return err } @@ -229,7 +317,7 @@ var ( astSet = []byte(" SET ") ) -func (node *Update) Serialize(w io.Writer) error { +func (node *Update) Serialize(w Writer) error { if _, err := w.Write(astUpdate); err != nil { return err } @@ -268,7 +356,7 @@ var ( astDeleteFrom = []byte("FROM ") ) -func (node *Delete) Serialize(w io.Writer) error { +func (node *Delete) Serialize(w Writer) error { if _, err := w.Write(astDelete); err != nil { return err } @@ -293,7 +381,7 @@ func (node *Delete) Serialize(w io.Writer) error { // Comments represents a list of comments. type Comments []string -func (node Comments) Serialize(w io.Writer) error { +func (node Comments) Serialize(w Writer) error { for _, c := range node { if _, err := io.WriteString(w, c); err != nil { return nil @@ -312,7 +400,7 @@ var ( astCommaSpace = []byte(", ") ) -func (node SelectExprs) Serialize(w io.Writer) error { +func (node SelectExprs) Serialize(w Writer) error { var prefix []byte for _, n := range node { if _, err := w.Write(prefix); err != nil { @@ -344,7 +432,7 @@ var ( astStar = []byte("*") ) -func (node *StarExpr) Serialize(w io.Writer) error { +func (node *StarExpr) Serialize(w Writer) error { if node.TableName != "" { if err := quoteName(w, node.TableName); err != nil { return err @@ -367,7 +455,7 @@ var ( astAsPrefix = []byte(" AS `") ) -func (node *NonStarExpr) Serialize(w io.Writer) error { +func (node *NonStarExpr) Serialize(w Writer) error { if err := node.Expr.Serialize(w); err != nil { return err } @@ -396,7 +484,7 @@ var ( astCloseParen = []byte(")") ) -func (node Columns) Serialize(w io.Writer) error { +func (node Columns) Serialize(w Writer) error { if _, err := w.Write(astOpenParen); err != nil { return err } @@ -410,7 +498,7 @@ func (node Columns) Serialize(w io.Writer) error { // TableExprs represents a list of table expressions. type TableExprs []TableExpr -func (node TableExprs) Serialize(w io.Writer) error { +func (node TableExprs) Serialize(w Writer) error { var prefix []byte for _, n := range node { if _, err := w.Write(prefix); err != nil { @@ -442,7 +530,7 @@ type AliasedTableExpr struct { Hints *IndexHints } -func (node *AliasedTableExpr) Serialize(w io.Writer) error { +func (node *AliasedTableExpr) Serialize(w Writer) error { if err := node.Expr.Serialize(w); err != nil { return err } @@ -480,7 +568,7 @@ type TableName struct { Name, Qualifier string } -func (node *TableName) Serialize(w io.Writer) error { +func (node *TableName) Serialize(w Writer) error { if node.Qualifier != "" { if err := quoteName(w, node.Qualifier); err != nil { return err @@ -515,7 +603,7 @@ const ( astNaturalJoin = "NATURAL JOIN" ) -func (node *JoinTableExpr) Serialize(w io.Writer) error { +func (node *JoinTableExpr) Serialize(w Writer) error { if err := node.LeftExpr.Serialize(w); err != nil { return err } @@ -551,7 +639,7 @@ var ( astOn = []byte(" ON ") ) -func (node *OnJoinCond) Serialize(w io.Writer) error { +func (node *OnJoinCond) Serialize(w Writer) error { if _, err := w.Write(astOn); err != nil { return err } @@ -567,7 +655,7 @@ var ( astUsing = []byte(" USING ") ) -func (node *UsingJoinCond) Serialize(w io.Writer) error { +func (node *UsingJoinCond) Serialize(w Writer) error { if _, err := w.Write(astUsing); err != nil { return err } @@ -586,7 +674,7 @@ const ( astForce = "FORCE" ) -func (node *IndexHints) Serialize(w io.Writer) error { +func (node *IndexHints) Serialize(w Writer) error { if _, err := fmt.Fprintf(w, " %s INDEX ", node.Type); err != nil { return err } @@ -622,7 +710,7 @@ func NewWhere(typ string, expr BoolExpr) *Where { return &Where{Type: typ, Expr: expr} } -func (node *Where) Serialize(w io.Writer) error { +func (node *Where) Serialize(w Writer) error { if node == nil { return nil } @@ -646,8 +734,9 @@ func (*ComparisonExpr) expr() {} func (*RangeCond) expr() {} func (*NullCheck) expr() {} func (*ExistsExpr) expr() {} +func (PlaceholderVal) expr() {} func (RawVal) expr() {} -func (encodedVal) expr() {} +func (EncodedVal) expr() {} func (StrVal) expr() {} func (NumVal) expr() {} func (ValArg) expr() {} @@ -685,7 +774,7 @@ type AndExpr struct { Exprs []BoolExpr } -func (node *AndExpr) Serialize(w io.Writer) error { +func (node *AndExpr) Serialize(w Writer) error { if len(node.Exprs) == 0 { _, err := w.Write(astBoolTrue) return err @@ -721,7 +810,7 @@ type OrExpr struct { Exprs []BoolExpr } -func (node *OrExpr) Serialize(w io.Writer) error { +func (node *OrExpr) Serialize(w Writer) error { if len(node.Exprs) == 0 { _, err := w.Write(astBoolFalse) return err @@ -757,7 +846,7 @@ type NotExpr struct { Expr BoolExpr } -func (node *NotExpr) Serialize(w io.Writer) error { +func (node *NotExpr) Serialize(w Writer) error { if _, err := io.WriteString(w, node.Op); err != nil { return err } @@ -769,7 +858,7 @@ type ParenBoolExpr struct { Expr BoolExpr } -func (node *ParenBoolExpr) Serialize(w io.Writer) error { +func (node *ParenBoolExpr) Serialize(w Writer) error { if _, err := w.Write(astOpenParen); err != nil { return err } @@ -802,7 +891,7 @@ const ( astNotLike = " NOT LIKE " ) -func (node *ComparisonExpr) Serialize(w io.Writer) error { +func (node *ComparisonExpr) Serialize(w Writer) error { if err := node.Left.Serialize(w); err != nil { return err } @@ -829,7 +918,7 @@ var ( astAnd = []byte(" AND ") ) -func (node *RangeCond) Serialize(w io.Writer) error { +func (node *RangeCond) Serialize(w Writer) error { if err := node.Left.Serialize(w); err != nil { return err } @@ -857,7 +946,7 @@ const ( astIsNotNull = " IS NOT NULL" ) -func (node *NullCheck) Serialize(w io.Writer) error { +func (node *NullCheck) Serialize(w Writer) error { if err := node.Expr.Serialize(w); err != nil { return err } @@ -876,19 +965,33 @@ type ValExpr interface { Expr } -func (RawVal) valExpr() {} -func (encodedVal) valExpr() {} -func (StrVal) valExpr() {} -func (NumVal) valExpr() {} -func (ValArg) valExpr() {} -func (*NullVal) valExpr() {} -func (*ColName) valExpr() {} -func (ValTuple) valExpr() {} -func (*Subquery) valExpr() {} -func (*BinaryExpr) valExpr() {} -func (*UnaryExpr) valExpr() {} -func (*FuncExpr) valExpr() {} -func (*CaseExpr) valExpr() {} +func (PlaceholderVal) valExpr() {} +func (RawVal) valExpr() {} +func (EncodedVal) valExpr() {} +func (StrVal) valExpr() {} +func (NumVal) valExpr() {} +func (ValArg) valExpr() {} +func (*NullVal) valExpr() {} +func (*ColName) valExpr() {} +func (ValTuple) valExpr() {} +func (*Subquery) valExpr() {} +func (*BinaryExpr) valExpr() {} +func (*UnaryExpr) valExpr() {} +func (*FuncExpr) valExpr() {} +func (*CaseExpr) valExpr() {} + +var ( + astPlaceholder = []byte("?") +) + +// PlaceholderVal represents a placeholder parameter that will be supplied +// when executing the query. It will be serialized as a ?. +type PlaceholderVal struct{} + +func (node PlaceholderVal) Serialize(w Writer) error { + _, err := w.Write(astPlaceholder) + return err +} // RawVal represents a raw go value type RawVal struct { @@ -900,26 +1003,25 @@ var ( astBoolFalse = []byte("0") ) -func (node RawVal) Serialize(w io.Writer) error { - return encodeSQLValue(w, node.Val) +func (node RawVal) Serialize(w Writer) error { + return w.WriteRaw(node) } -// encodedVal represents an already encoded value. Not exported but -// this struct must be used with caution. -type encodedVal struct { +// EncodedVal represents an already encoded value. This struct must be used +// with caution because misuse can provide an avenue for SQL injection attacks. +type EncodedVal struct { Val []byte } -func (node encodedVal) Serialize(w io.Writer) error { - _, err := w.Write(node.Val) - return err +func (node EncodedVal) Serialize(w Writer) error { + return w.WriteEncoded(node) } // StrVal represents a string value. type StrVal string -func (node StrVal) Serialize(w io.Writer) error { - return encodeSQLString(w, string(node)) +func (node StrVal) Serialize(w Writer) error { + return w.WriteStr(node) } // BytesVal represents a string of unprintable value. @@ -928,8 +1030,8 @@ type BytesVal []byte func (BytesVal) expr() {} func (BytesVal) valExpr() {} -func (node BytesVal) Serialize(w io.Writer) error { - return encodeSQLBytes(w, []byte(node)) +func (node BytesVal) Serialize(w Writer) error { + return w.WriteBytes(node) } // ErrVal represents an error condition that occurred while @@ -941,22 +1043,21 @@ type ErrVal struct { func (ErrVal) expr() {} func (ErrVal) valExpr() {} -func (node ErrVal) Serialize(w io.Writer) error { +func (node ErrVal) Serialize(w Writer) error { return node.Err } // NumVal represents a number. type NumVal string -func (node NumVal) Serialize(w io.Writer) error { - _, err := io.WriteString(w, string(node)) - return err +func (node NumVal) Serialize(w Writer) error { + return w.WriteNum(node) } // ValArg represents a named bind var argument. type ValArg string -func (node ValArg) Serialize(w io.Writer) error { +func (node ValArg) Serialize(w Writer) error { _, err := fmt.Fprintf(w, ":%s", string(node)[1:]) return err } @@ -968,7 +1069,7 @@ var ( astNull = []byte("NULL") ) -func (node *NullVal) Serialize(w io.Writer) error { +func (node *NullVal) Serialize(w Writer) error { _, err := w.Write(astNull) return err } @@ -983,7 +1084,7 @@ var ( astPeriod = []byte(".") ) -func (node *ColName) Serialize(w io.Writer) error { +func (node *ColName) Serialize(w Writer) error { if node.Qualifier != "" { if err := quoteName(w, node.Qualifier); err != nil { return err @@ -1020,7 +1121,7 @@ type ValTuple struct { Exprs ValExprs } -func (node ValTuple) Serialize(w io.Writer) error { +func (node ValTuple) Serialize(w Writer) error { if _, err := w.Write(astOpenParen); err != nil { return err } @@ -1035,7 +1136,7 @@ func (node ValTuple) Serialize(w io.Writer) error { // It's not a valid expression because it's not parenthesized. type ValExprs []ValExpr -func (node ValExprs) Serialize(w io.Writer) error { +func (node ValExprs) Serialize(w Writer) error { var prefix []byte for _, n := range node { if _, err := w.Write(prefix); err != nil { @@ -1072,7 +1173,7 @@ const ( astMod = '%' ) -func (node *BinaryExpr) Serialize(w io.Writer) error { +func (node *BinaryExpr) Serialize(w Writer) error { if err := node.Left.Serialize(w); err != nil { return err } @@ -1095,7 +1196,7 @@ const ( astTilda = '~' ) -func (node *UnaryExpr) Serialize(w io.Writer) error { +func (node *UnaryExpr) Serialize(w Writer) error { if _, err := fmt.Fprintf(w, "%c", node.Operator); err != nil { return err } @@ -1113,7 +1214,7 @@ var ( astFuncDistinct = []byte("DISTINCT ") ) -func (node *FuncExpr) Serialize(w io.Writer) error { +func (node *FuncExpr) Serialize(w Writer) error { if _, err := io.WriteString(w, node.Name); err != nil { return err } @@ -1139,7 +1240,7 @@ type CaseExpr struct { Else ValExpr } -func (node *CaseExpr) Serialize(w io.Writer) error { +func (node *CaseExpr) Serialize(w Writer) error { if _, err := fmt.Fprintf(w, "CASE "); err != nil { return err } @@ -1180,7 +1281,7 @@ type When struct { Val ValExpr } -func (node *When) Serialize(w io.Writer) error { +func (node *When) Serialize(w Writer) error { fmt.Sprintf("WHEN ") if err := node.Cond.Serialize(w); err != nil { return err @@ -1196,7 +1297,7 @@ var ( astValues = []byte("VALUES ") ) -func (node Values) Serialize(w io.Writer) error { +func (node Values) Serialize(w Writer) error { prefix := astValues for _, n := range node { if _, err := w.Write(prefix); err != nil { @@ -1217,7 +1318,7 @@ var ( astGroupBy = []byte(" GROUP BY ") ) -func (node GroupBy) Serialize(w io.Writer) error { +func (node GroupBy) Serialize(w Writer) error { prefix := astGroupBy for _, n := range node { if _, err := w.Write(prefix); err != nil { @@ -1238,7 +1339,7 @@ var ( astOrderBy = []byte(" ORDER BY ") ) -func (node OrderBy) Serialize(w io.Writer) error { +func (node OrderBy) Serialize(w Writer) error { prefix := astOrderBy for _, n := range node { if _, err := w.Write(prefix); err != nil { @@ -1264,7 +1365,7 @@ const ( astDesc = " DESC" ) -func (node *Order) Serialize(w io.Writer) error { +func (node *Order) Serialize(w Writer) error { if err := node.Expr.Serialize(w); err != nil { return err } @@ -1281,7 +1382,7 @@ var ( astLimit = []byte(" LIMIT ") ) -func (node *Limit) Serialize(w io.Writer) error { +func (node *Limit) Serialize(w Writer) error { if node == nil { return nil } @@ -1302,7 +1403,7 @@ func (node *Limit) Serialize(w io.Writer) error { // UpdateExprs represents a list of update expressions. type UpdateExprs []*UpdateExpr -func (node UpdateExprs) Serialize(w io.Writer) error { +func (node UpdateExprs) Serialize(w Writer) error { var prefix []byte for _, n := range node { if _, err := w.Write(prefix); err != nil { @@ -1326,7 +1427,7 @@ var ( astUpdateEq = []byte(" = ") ) -func (node *UpdateExpr) Serialize(w io.Writer) error { +func (node *UpdateExpr) Serialize(w Writer) error { if err := node.Name.Serialize(w); err != nil { return nil } @@ -1343,7 +1444,7 @@ var ( astOnDupKeyUpdate = []byte(" ON DUPLICATE KEY UPDATE ") ) -func (node OnDup) Serialize(w io.Writer) error { +func (node OnDup) Serialize(w Writer) error { if node == nil { return nil } diff --git a/builder_test.go b/builder_test.go index b867f64..f40c72a 100644 --- a/builder_test.go +++ b/builder_test.go @@ -441,6 +441,52 @@ func TestSelectBuilder(t *testing.T) { } } +func TestSerializeWithPlaceholders(t *testing.T) { + type User struct { + Foo, Bar, Qux string + } + users := NewTable("users", User{}) + foo := users.C("foo") + + type Object struct { + Foo, Baz string + } + + testCases := []struct { + builder *SelectBuilder + expected string + }{ + {users.Select(foo.As("bar")), + "SELECT `users`.`foo` AS `bar` FROM `users`"}, + {users.Select("*").Where(foo.Eq(1)), + "SELECT * FROM `users` WHERE `users`.`foo` = ?"}, + {users.Select("*").Where(foo.Neq(false)), + "SELECT * FROM `users` WHERE `users`.`foo` != ?"}, + {users.Select("*").Where(foo.NullSafeEq(false)), + "SELECT * FROM `users` WHERE `users`.`foo` <=> ?"}, + {users.Select("*").Where(foo.Gte(time.Time{})), + "SELECT * FROM `users` WHERE `users`.`foo` >= ?"}, + {users.Select("*").Where(foo.Lt(2.5)), + "SELECT * FROM `users` WHERE `users`.`foo` < ?"}, + {users.Select("*").Where(foo.Lt(true)), + "SELECT * FROM `users` WHERE `users`.`foo` < ?"}, + {users.Select("*").Where(foo.IsNull()), + "SELECT * FROM `users` WHERE `users`.`foo` IS NULL"}, + {users.Select(foo).Where(foo.In([]string{"baz", "qux"})), + "SELECT `users`.`foo` FROM `users` WHERE `users`.`foo` IN (?, ?)"}, + {users.Select(foo).Where(foo.Like("baz")), + "SELECT `users`.`foo` FROM `users` WHERE `users`.`foo` LIKE ?"}, + } + + for _, c := range testCases { + if sql, err := SerializeWithPlaceholders(c.builder); err != nil { + t.Errorf("Expected success, but found %s\n%s", err, c.expected) + } else if c.expected != sql { + t.Errorf("Expected\n%s\nbut got\n%s", c.expected, sql) + } + } +} + func TestSelectBuilderErrors(t *testing.T) { type User struct { Foo string diff --git a/db.go b/db.go index 1c7f13a..50b4c07 100644 --- a/db.go +++ b/db.go @@ -19,9 +19,11 @@ import ( "database/sql" "errors" "fmt" + "io" "reflect" "sort" "sync" + "time" ) // ErrMixedAutoIncrIDs is returned when attempting to insert multiple @@ -34,7 +36,7 @@ var ErrMixedAutoIncrIDs = errors.New("sql: auto increment column must be all set // DB or on a Tx. type Executor interface { Delete(list ...interface{}) (int64, error) - Exec(string, ...interface{}) (sql.Result, error) + Exec(query interface{}, args ...interface{}) (sql.Result, error) Get(dest interface{}, keys ...interface{}) error Insert(list ...interface{}) error Query(query interface{}, args ...interface{}) (*Rows, error) @@ -255,6 +257,14 @@ func getUpsert(m *Model) insertPlan { return m.upsert } +// stringSerializer is a wrapper around a string that implements Serializer. +type stringSerializer string + +func (ss stringSerializer) Serialize(w Writer) error { + _, err := io.WriteString(w, string(ss)) + return err +} + // DB is a wrapper around a sql.DB which also implements the // squalor.Executor interface. DB is safe for concurrent use by // multiple goroutines. @@ -268,6 +278,7 @@ type DB struct { // // The default is false. IgnoreUnmappedCols bool + Logger QueryLogger mu sync.RWMutex models map[reflect.Type]*Model mappings map[reflect.Type]fieldMap @@ -279,11 +290,21 @@ func NewDB(db *sql.DB) *DB { DB: db, AllowStringQueries: true, IgnoreUnmappedCols: false, + Logger: nil, models: map[reflect.Type]*Model{}, mappings: map[reflect.Type]fieldMap{}, } } +func (db *DB) logQuery(query Serializer, exec Executor, start time.Time, err error) { + if db.Logger == nil { + return + } + + executionTime := time.Now().Sub(start) + db.Logger.Log(query, exec, executionTime, err) +} + // GetModel retrieves the model for the specified object. Obj must be // a struct. An error is returned if obj has not been bound to a table // via a call to BindModel. @@ -321,19 +342,19 @@ func (db *DB) getMapping(t reflect.Type) fieldMap { return mapping } -func (db *DB) queryString(query interface{}) (string, error) { +func (db *DB) getSerializer(query interface{}) (Serializer, error) { if t, ok := query.(Serializer); ok { - return Serialize(t) + return t, nil } if db.AllowStringQueries { switch t := query.(type) { case string: - return t, nil + return stringSerializer(t), nil } } - return "", fmt.Errorf("unsupported query type %T", query) + return nil, fmt.Errorf("unsupported query type %T", query) } // BindModel binds the supplied interface with the named table. You @@ -411,6 +432,25 @@ func (db *DB) Delete(list ...interface{}) (int64, error) { return deleteObjects(db, db, list) } +// Exec executes a query without returning any rows. The args are for any +// placeholder parameters in the query. +func (db *DB) Exec(query interface{}, args ...interface{}) (sql.Result, error) { + serializer, err := db.getSerializer(query) + if err != nil { + return nil, err + } + querystr, err := Serialize(serializer) + if err != nil { + return nil, err + } + + start := time.Now() + result, err := db.DB.Exec(querystr, args...) + db.logQuery(serializer, db, start, err) + + return result, err +} + // Get runs a SQL SELECT to fetch a single row. Keys must be the // primary keys defined for the table. The order must match the order // of the columns in the primary key. @@ -440,11 +480,19 @@ func (db *DB) Insert(list ...interface{}) error { // small wrapper around sql.DB.Query that returns a *squalor.Rows // instead. func (db *DB) Query(query interface{}, args ...interface{}) (*Rows, error) { - querystr, err := db.queryString(query) + serializer, err := db.getSerializer(query) if err != nil { return nil, err } + querystr, err := Serialize(serializer) + if err != nil { + return nil, err + } + + start := time.Now() rows, err := db.DB.Query(querystr, args...) + db.logQuery(serializer, db, start, err) + if err != nil { return nil, err } @@ -456,11 +504,19 @@ func (db *DB) Query(query interface{}, args ...interface{}) (*Rows, error) { // until Row's Scan method is called. This is a small wrapper around // sql.DB.QueryRow that returns a *squalor.Row instead. func (db *DB) QueryRow(query interface{}, args ...interface{}) *Row { - querystr, err := db.queryString(query) + serializer, err := db.getSerializer(query) if err != nil { return &Row{rows: Rows{Rows: nil, db: nil}, err: err} } + querystr, err := Serialize(serializer) + if err != nil { + return &Row{rows: Rows{Rows: nil, db: nil}, err: err} + } + + start := time.Now() rows, err := db.DB.Query(querystr, args...) + db.logQuery(serializer, db, start, err) + return &Row{rows: Rows{Rows: rows, db: db}, err: err} } @@ -535,6 +591,25 @@ type Tx struct { DB *DB } +// Exec executes a query that doesn't return rows. For example: an +// INSERT and UPDATE. +func (tx *Tx) Exec(query interface{}, args ...interface{}) (sql.Result, error) { + serializer, err := tx.DB.getSerializer(query) + if err != nil { + return nil, err + } + querystr, err := Serialize(serializer) + if err != nil { + return nil, err + } + + start := time.Now() + result, err := tx.Tx.Exec(querystr, args...) + tx.DB.logQuery(serializer, tx, start, err) + + return result, err +} + // Delete runs a batched SQL DELETE statement, grouping the objects by // the model type of the list elements. List elements must be pointers // to structs. @@ -576,11 +651,19 @@ func (tx *Tx) Insert(list ...interface{}) error { // small wrapper around sql.Tx.Query that returns a *squalor.Rows // instead. func (tx *Tx) Query(query interface{}, args ...interface{}) (*Rows, error) { - querystr, err := tx.DB.queryString(query) + serializer, err := tx.DB.getSerializer(query) + if err != nil { + return nil, err + } + querystr, err := Serialize(serializer) if err != nil { return nil, err } + + start := time.Now() rows, err := tx.Tx.Query(querystr, args...) + tx.DB.logQuery(serializer, tx, start, err) + if err != nil { return nil, err } @@ -592,11 +675,19 @@ func (tx *Tx) Query(query interface{}, args ...interface{}) (*Rows, error) { // until Row's Scan method is called. This is a small wrapper around // sql.Tx.QueryRow that returns a *squalor.Row instead. func (tx *Tx) QueryRow(query interface{}, args ...interface{}) *Row { - querystr, err := tx.DB.queryString(query) + serializer, err := tx.DB.getSerializer(query) + if err != nil { + return &Row{rows: Rows{Rows: nil, db: nil}, err: err} + } + querystr, err := Serialize(serializer) if err != nil { return &Row{rows: Rows{Rows: nil, db: nil}, err: err} } + + start := time.Now() rows, err := tx.Tx.Query(querystr, args...) + tx.DB.logQuery(serializer, tx, start, err) + return &Row{rows: Rows{Rows: rows, db: tx.DB}, err: err} } @@ -942,7 +1033,7 @@ func deleteModel(model *Model, exec Executor, list []interface{}) (int64, error) // the AND and IN expression. The buffers are the max size to // minimize reallocations, though it is possible we'll only use a // handful of values in the same batch. - valbuf := make([]encodedVal, len(rows)+n-1) + valbuf := make([]EncodedVal, len(rows)+n-1) argbuf := make(ValExprs, 0, len(rows)) var inTuple ValTuple inTuple.Exprs = argbuf @@ -985,12 +1076,7 @@ func deleteModel(model *Model, exec Executor, list []interface{}) (int64, error) b.Where(andExpr.And(inExpr)) } - s, err := Serialize(&b) - if err != nil { - return -1, err - } - - res, err := exec.Exec(s) + res, err := exec.Exec(&b) if err != nil { return -1, err } @@ -1124,22 +1210,18 @@ func insertModel(model *Model, exec Executor, getPlan func(m *Model) insertPlan, return ErrMixedAutoIncrIDs } - var s string - var err error + var serializer Serializer if plan.replaceBuilder != nil { b := *plan.replaceBuilder b.AddRows(rows) - s, err = Serialize(&b) + serializer = &b } else { b := *plan.insertBuilder b.AddRows(rows) - s, err = Serialize(&b) - } - if err != nil { - return err + serializer = &b } - res, err := exec.Exec(s) + res, err := exec.Exec(serializer) if err != nil { return err } @@ -1262,11 +1344,7 @@ func updateModel(model *Model, exec Executor, list []interface{}) (int64, error) } b.Where(where) - s, err := Serialize(b) - if err != nil { - return -1, err - } - res, err := exec.Exec(s) + res, err := exec.Exec(b) if err != nil { return -1, err } diff --git a/db_internal_test.go b/db_internal_test.go index 6209395..cd963d6 100644 --- a/db_internal_test.go +++ b/db_internal_test.go @@ -83,11 +83,28 @@ type recordingExecutor struct { query []string } -func (r *recordingExecutor) Exec(stmt string, args ...interface{}) (sql.Result, error) { +func (r *recordingExecutor) Exec(stmt interface{}, args ...interface{}) (sql.Result, error) { + var querystr string + var err error + + switch t := stmt.(type) { + case string: + querystr = t + return r.Exec(t, args...) + case Serializer: + querystr, err = Serialize(t) + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("unexpected stmt type") + } + if len(args) != 0 { panic(fmt.Errorf("expected 0 args: %+v", args)) } - r.exec = append(r.exec, stmt) + + r.exec = append(r.exec, querystr) return dummyResult{}, nil } diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..dbd0939 --- /dev/null +++ b/logger.go @@ -0,0 +1,49 @@ +// Copyright 2015 Square Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package squalor + +import ( + "log" + "time" +) + +// QueryLogger defines an interface for query loggers. +type QueryLogger interface { + // Log is called on completion of a query with a Serializer for the + // query, the Executor it was called on, the execution time of the query + // and an error if one occurred. + // + // The Executor may be used to trace queries within a transaction because + // queries in the same transaction will use the same executor. + Log(query Serializer, exec Executor, executionTime time.Duration, err error) +} + +// StandardLogger implements the QueryLogger interface and wraps a log.Logger. +type StandardLogger struct { + *log.Logger +} + +func (l *StandardLogger) Log(query Serializer, exec Executor, executionTime time.Duration, err error) { + querystr, serializeErr := Serialize(query) + if serializeErr != nil { + return + } + + if err != nil { + l.Printf("[%p] %s - `%s` - %s\n", exec, executionTime, querystr, err) + } else { + l.Printf("[%p] %s - `%s`\n", exec, executionTime, querystr) + } +}