Skip to content

Commit

Permalink
Add InTx method to Stmts
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenafamo committed Jul 11, 2023
1 parent 3ac3def commit dc037bf
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 21 deletions.
16 changes: 16 additions & 0 deletions stdlib.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,19 @@ type stdStmt struct {
func (s stdStmt) QueryContext(ctx context.Context, args ...any) (scan.Rows, error) {
return s.Stmt.QueryContext(ctx, args...)
}

type errStmt struct {
err error
}

func (s errStmt) ExecContext(ctx context.Context, args ...any) (sql.Result, error) {
return nil, s.err
}

func (s errStmt) QueryContext(ctx context.Context, args ...any) (scan.Rows, error) {
return nil, s.err
}

func (s errStmt) Close() error {
return s.err
}
28 changes: 25 additions & 3 deletions stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package bob
import (
"context"
"database/sql"
"errors"
"fmt"

"github.com/stephenafamo/scan"
Expand Down Expand Up @@ -66,6 +67,19 @@ type Stmt struct {
loaders []Loader
}

// InTx returns a copy of the Stmt that will execute in the given transaction
func (s Stmt) InTx(ctx context.Context, tx Tx) Stmt {
var stmt Statement = errStmt{errors.New("stmt is not an stdStmt")}

if std, ok := s.stmt.(stdStmt); !ok {
stmt = stdStmt{tx.wrapped.StmtContext(ctx, std.Stmt)}
}

s.stmt = stmt
s.exec = tx
return s
}

// Close closes the statement
func (s Stmt) Close() error {
return s.stmt.Close()
Expand Down Expand Up @@ -131,9 +145,17 @@ type QueryStmt[T any, Ts ~[]T] struct {
settings ExecSettings[T]
}

// Close closes the statement
func (s QueryStmt[T, Ts]) Close() error {
return s.stmt.Close()
// InTx returns a copy of the Stmt that will execute in the given transaction
func (s QueryStmt[T, Ts]) InTx(ctx context.Context, tx Tx) QueryStmt[T, Ts] {
var stmt Statement = errStmt{errors.New("stmt is not an stdStmt")}

if std, ok := s.stmt.(stdStmt); !ok {
stmt = stdStmt{tx.wrapped.StmtContext(ctx, std.Stmt)}
}

s.stmt = stmt
s.exec = tx
return s
}

func (s QueryStmt[T, Ts]) One(ctx context.Context, args ...any) (T, error) {
Expand Down
39 changes: 30 additions & 9 deletions stmt_bound.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,19 @@ type BoundStmt[Arg any] struct {
binder structBinder[Arg]
}

// InTx returns a new MappedStmt that will be executed in the given transaction
func (s BoundStmt[Arg]) InTx(ctx context.Context, tx Tx) BoundStmt[Arg] {
var stmt Statement = errStmt{errors.New("stmt is not an stdStmt")}

if std, ok := s.stmt.stmt.(stdStmt); !ok {
stmt = stdStmt{tx.wrapped.StmtContext(ctx, std.Stmt)}
}

s.stmt.stmt = stmt
s.stmt.exec = tx
return s
}

// Close closes the statement.
func (s BoundStmt[Arg]) Close() error {
return s.stmt.Close()
Expand Down Expand Up @@ -133,19 +146,27 @@ func PrepareBoundQueryx[Arg any, T any, Ts ~[]T](ctx context.Context, exec Prepa
}

return BoundQueryStmt[Arg, T, Ts]{
query: s,
binder: binder,
QueryStmt: s,
binder: binder,
}, nil
}

type BoundQueryStmt[Arg any, T any, Ts ~[]T] struct {
query QueryStmt[T, Ts]
QueryStmt[T, Ts]
binder structBinder[Arg]
}

// Close closes the statement.
func (s BoundQueryStmt[Arg, T, Ts]) Close() error {
return s.query.Close()
// InTx returns a new MappedStmt that will be executed in the given transaction
func (s BoundQueryStmt[Arg, T, Ts]) InTx(ctx context.Context, tx Tx) BoundQueryStmt[Arg, T, Ts] {
var stmt Statement = errStmt{errors.New("stmt is not an stdStmt")}

if std, ok := s.stmt.(stdStmt); !ok {
stmt = stdStmt{tx.wrapped.StmtContext(ctx, std.Stmt)}
}

s.stmt = stmt
s.exec = tx
return s
}

func (s BoundQueryStmt[Arg, T, Ts]) One(ctx context.Context, arg Arg) (T, error) {
Expand All @@ -155,7 +176,7 @@ func (s BoundQueryStmt[Arg, T, Ts]) One(ctx context.Context, arg Arg) (T, error)
return t, err
}

return s.query.One(ctx, args...)
return s.QueryStmt.One(ctx, args...)
}

func (s BoundQueryStmt[Arg, T, Ts]) All(ctx context.Context, arg Arg) (Ts, error) {
Expand All @@ -164,7 +185,7 @@ func (s BoundQueryStmt[Arg, T, Ts]) All(ctx context.Context, arg Arg) (Ts, error
return nil, err
}

return s.query.All(ctx, args...)
return s.QueryStmt.All(ctx, args...)
}

func (s BoundQueryStmt[Arg, T, Ts]) Cursor(ctx context.Context, arg Arg) (scan.ICursor[T], error) {
Expand All @@ -173,5 +194,5 @@ func (s BoundQueryStmt[Arg, T, Ts]) Cursor(ctx context.Context, arg Arg) (scan.I
return nil, err
}

return s.query.Cursor(ctx, args...)
return s.QueryStmt.Cursor(ctx, args...)
}
40 changes: 31 additions & 9 deletions stmt_mapped.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package bob
import (
"context"
"database/sql"
"errors"
"fmt"

"github.com/stephenafamo/scan"
Expand Down Expand Up @@ -107,6 +108,19 @@ type MappedStmt struct {
mapper mapBinder
}

// InTx returns a new MappedStmt that will be executed in the given transaction
func (s MappedStmt) InTx(ctx context.Context, tx Tx) MappedStmt {
var stmt Statement = errStmt{errors.New("stmt is not an stdStmt")}

if std, ok := s.stmt.stmt.(stdStmt); !ok {
stmt = stdStmt{tx.wrapped.StmtContext(ctx, std.Stmt)}
}

s.stmt.stmt = stmt
s.stmt.exec = tx
return s
}

// Inspect returns a map with all the expected keys
func (s MappedStmt) Inspect() []string {
return s.mapper.positions
Expand Down Expand Up @@ -143,19 +157,27 @@ func PrepareMappedQueryx[T any, Ts ~[]T](ctx context.Context, exec Preparer, q Q
}

return MappedQueryStmt[T, Ts]{
query: s,
binder: binder,
QueryStmt: s,
binder: binder,
}, nil
}

type MappedQueryStmt[T any, Ts ~[]T] struct {
query QueryStmt[T, Ts]
QueryStmt[T, Ts]
binder mapBinder
}

// Close closes the statement
func (s MappedQueryStmt[T, Ts]) Close() error {
return s.query.Close()
// InTx returns a new MappedStmt that will be executed in the given transaction
func (s MappedQueryStmt[T, Ts]) InTx(ctx context.Context, tx Tx) MappedQueryStmt[T, Ts] {
var stmt Statement = errStmt{errors.New("stmt is not an stdStmt")}

if std, ok := s.stmt.(stdStmt); !ok {
stmt = stdStmt{tx.wrapped.StmtContext(ctx, std.Stmt)}
}

s.stmt = stmt
s.exec = tx
return s
}

func (s MappedQueryStmt[T, Ts]) One(ctx context.Context, arg map[string]any) (T, error) {
Expand All @@ -165,7 +187,7 @@ func (s MappedQueryStmt[T, Ts]) One(ctx context.Context, arg map[string]any) (T,
return t, err
}

return s.query.One(ctx, args...)
return s.QueryStmt.One(ctx, args...)
}

func (s MappedQueryStmt[T, Ts]) All(ctx context.Context, arg map[string]any) (Ts, error) {
Expand All @@ -174,7 +196,7 @@ func (s MappedQueryStmt[T, Ts]) All(ctx context.Context, arg map[string]any) (Ts
return nil, err
}

return s.query.All(ctx, args...)
return s.QueryStmt.All(ctx, args...)
}

func (s MappedQueryStmt[T, Ts]) Cursor(ctx context.Context, arg map[string]any) (scan.ICursor[T], error) {
Expand All @@ -183,5 +205,5 @@ func (s MappedQueryStmt[T, Ts]) Cursor(ctx context.Context, arg map[string]any)
return nil, err
}

return s.query.Cursor(ctx, args...)
return s.QueryStmt.Cursor(ctx, args...)
}

0 comments on commit dc037bf

Please sign in to comment.