Skip to content

Commit

Permalink
[#11] display sql bind values
Browse files Browse the repository at this point in the history
  • Loading branch information
dwkang committed Aug 19, 2022
1 parent 1eb8716 commit d9a6fed
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 36 deletions.
3 changes: 2 additions & 1 deletion annotation.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ import (
)

const (
AnnotationSqlId = 20
AnnotationHttpUrl = 40
//AnnotationHttpParam = 41
//AnnotationHttpParam = 41
AnnotationHttpCookie = 45
AnnotationHttpStatusCode = 46
AnnotationHttpRequestHeader = 47
Expand Down
6 changes: 3 additions & 3 deletions noop.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,10 @@ func (se *noopSpanEvent) SetDestination(id string) {}

func (se *noopSpanEvent) SetEndPoint(endPoint string) {}

func (se *noopSpanEvent) SetSQL(sql string) {}
func (se *noopSpanEvent) SetSQL(sql string, args string) {}

func (span *noopSpanEvent) Annotations() Annotation {
return &span.annotations
func (se *noopSpanEvent) Annotations() Annotation {
return &se.annotations
}

func (se *noopSpanEvent) FixDuration(start time.Time, end time.Time) {}
Expand Down
4 changes: 2 additions & 2 deletions span_event.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,15 @@ func (se *spanEvent) SetEndPoint(endPoint string) {
se.endPoint = endPoint
}

func (se *spanEvent) SetSQL(sql string) {
func (se *spanEvent) SetSQL(sql string, args string) {
if sql == "" {
return
}

normalizer := newSqlNormalizer(sql)
nsql, param := normalizer.run()
id := se.parentSpan.agent.CacheSql(nsql)
se.annotations.AppendIntStringString(20, id, param, "" /* bind value for prepared stmt */)
se.annotations.AppendIntStringString(AnnotationSqlId, id, param, args)
}

func (se *spanEvent) Annotations() Annotation {
Expand Down
91 changes: 62 additions & 29 deletions sql_driver.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package pinpoint

import (
"bytes"
"context"
"database/sql/driver"
"errors"
"fmt"
"time"
)

Expand All @@ -13,6 +15,7 @@ type DatabaseTrace struct {
DbName string
DbHost string
QueryString string
bindArgs string

ParseDSN func(td *DatabaseTrace, dataSourceName string)
}
Expand All @@ -34,7 +37,7 @@ func NewDatabaseTracer(ctx context.Context, funcName string, dt *DatabaseTrace)
se.SetServiceType(int32(dt.QueryType))
se.SetEndPoint(dt.DbHost)
se.SetDestination(dt.DbName)
se.SetSQL(dt.QueryString)
se.SetSQL(dt.QueryString, dt.bindArgs)

return tracer
}
Expand Down Expand Up @@ -119,7 +122,7 @@ type sqlConn struct {
trace DatabaseTrace
}

func prepare(stmt driver.Stmt, err error, td *DatabaseTrace, query string, ctx context.Context) (driver.Stmt, error) {
func prepare(stmt driver.Stmt, err error, td *DatabaseTrace, query string) (driver.Stmt, error) {
if nil != err {
return nil, err
}
Expand All @@ -128,22 +131,17 @@ func prepare(stmt driver.Stmt, err error, td *DatabaseTrace, query string, ctx c
return &sqlStmt{
Stmt: stmt,
trace: td,
ctx: ctx,
}, nil
}

func (c *sqlConn) Prepare(query string) (driver.Stmt, error) {
stmt, err := c.Conn.Prepare(query)
return prepare(stmt, err, &c.trace, query, context.Background())
}

func (c *sqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if cpc, ok := c.Conn.(driver.ConnPrepareContext); ok {
stmt, err := cpc.PrepareContext(ctx, query)
return prepare(stmt, err, &c.trace, query, ctx)
return prepare(stmt, err, &c.trace, query)
}

return c.Prepare(query)
stmt, err := c.Conn.Prepare(query)
return prepare(stmt, err, &c.trace, query)
}

func newSqlSpanEvent(ctx context.Context, operation string, dt *DatabaseTrace, start time.Time, err error) {
Expand All @@ -162,6 +160,7 @@ func (c *sqlConn) ExecContext(ctx context.Context, query string, args []driver.N

if err != driver.ErrSkip {
c.trace.QueryString = query
c.trace.bindArgs = namedValueToString(args)
newSqlSpanEvent(ctx, "ConnExecContext", &c.trace, start, err)
}

Expand All @@ -183,6 +182,7 @@ func (c *sqlConn) ExecContext(ctx context.Context, query string, args []driver.N
result, err := e.Exec(query, dargs)
if err != driver.ErrSkip {
c.trace.QueryString = query
c.trace.bindArgs = valueToString(dargs)
newSqlSpanEvent(ctx, "ConnExec", &c.trace, start, err)
}

Expand All @@ -199,6 +199,7 @@ func (c *sqlConn) QueryContext(ctx context.Context, query string, args []driver.
rows, err := qc.QueryContext(ctx, query, args)
if err != driver.ErrSkip {
c.trace.QueryString = query
c.trace.bindArgs = namedValueToString(args)
newSqlSpanEvent(ctx, "ConnQueryContext", &c.trace, start, err)
}

Expand All @@ -220,6 +221,7 @@ func (c *sqlConn) QueryContext(ctx context.Context, query string, args []driver.
rows, err := q.Query(query, dargs)
if err != driver.ErrSkip {
c.trace.QueryString = query
c.trace.bindArgs = valueToString(dargs)
newSqlSpanEvent(ctx, "ConnQuery", &c.trace, start, err)
}

Expand Down Expand Up @@ -275,13 +277,14 @@ func (t *sqlTx) Rollback() (err error) {
type sqlStmt struct {
driver.Stmt
trace *DatabaseTrace
ctx context.Context
}

func (s *sqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
start := time.Now()

if sec, ok := s.Stmt.(driver.StmtExecContext); ok {
start := time.Now()
result, err := sec.ExecContext(ctx, args)
s.trace.bindArgs = namedValueToString(args)
newSqlSpanEvent(ctx, "StmtExecContext", s.trace, start, err)
return result, err
}
Expand All @@ -297,21 +300,18 @@ func (s *sqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (dr
return nil, ctx.Err()
}

s.ctx = ctx
return s.Exec(dargs)
}

func (s *sqlStmt) Exec(args []driver.Value) (driver.Result, error) {
start := time.Now()
result, err := s.Stmt.Exec(args)
newSqlSpanEvent(s.ctx, "StmtExec", s.trace, start, err)
result, err := s.Stmt.Exec(dargs)
s.trace.bindArgs = valueToString(dargs)
newSqlSpanEvent(ctx, "StmtExec", s.trace, start, err)
return result, err
}

func (s *sqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
start := time.Now()

if sqc, ok := s.Stmt.(driver.StmtQueryContext); ok {
start := time.Now()
rows, err := sqc.QueryContext(ctx, args)
s.trace.bindArgs = namedValueToString(args)
newSqlSpanEvent(ctx, "StmtQueryContext", s.trace, start, err)
return rows, err
}
Expand All @@ -327,14 +327,9 @@ func (s *sqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (d
return nil, ctx.Err()
}

s.ctx = ctx
return s.Query(dargs)
}

func (s *sqlStmt) Query(args []driver.Value) (driver.Rows, error) {
start := time.Now()
rows, err := s.Stmt.Query(args)
newSqlSpanEvent(s.ctx, "StmtQuery", s.trace, start, err)
rows, err := s.Stmt.Query(dargs)
s.trace.bindArgs = valueToString(dargs)
newSqlSpanEvent(ctx, "StmtQuery", s.trace, start, err)
return rows, err
}

Expand All @@ -349,3 +344,41 @@ func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
}
return dargs, nil
}

const maxBindArgsLength = 1024

func namedValueToString(named []driver.NamedValue) string {
var b bytes.Buffer

c := len(named) - 1
for i, param := range named {
b.WriteString(fmt.Sprint(param.Value))
if i < c {
b.WriteString(", ")
}
if b.Len() > maxBindArgsLength {
b.WriteString("...(1024)")
break
}
}

return b.String()
}

func valueToString(vals []driver.Value) string {
var b bytes.Buffer

c := len(vals) - 1
for i, v := range vals {
b.WriteString(fmt.Sprint(v))
if i < c {
b.WriteString(", ")
}
if b.Len() > maxBindArgsLength {
b.WriteString("...(1024)")
break
}
}

return b.String()
}
2 changes: 1 addition & 1 deletion tracer.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ type SpanEventRecorder interface {
SetDestination(id string)
SetEndPoint(endPoint string)
SetError(e error)
SetSQL(sql string)
SetSQL(sql string, args string)
Annotations() Annotation
FixDuration(start time.Time, end time.Time)
}
Expand Down

0 comments on commit d9a6fed

Please sign in to comment.