diff --git a/debug_exec.go b/debug_exec.go index 0ff69d55..55e49c02 100644 --- a/debug_exec.go +++ b/debug_exec.go @@ -57,6 +57,15 @@ type debugExecutor struct { exec Executor } +func (d debugExecutor) PrepareContext(ctx context.Context, query string) (Statement, error) { + d.printer.PrintQuery(query) + if p, ok := d.exec.(Preparer); ok { + return p.PrepareContext(ctx, query) + } + + return nil, fmt.Errorf("executor does not implement Preparer") +} + func (d debugExecutor) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) { d.printer.PrintQuery(query, args...) return d.exec.ExecContext(ctx, query, args...) diff --git a/dialect/mysql/table.go b/dialect/mysql/table.go index 05ea1c6e..267286a8 100644 --- a/dialect/mysql/table.go +++ b/dialect/mysql/table.go @@ -12,6 +12,7 @@ import ( "github.com/stephenafamo/bob/dialect/mysql/sm" "github.com/stephenafamo/bob/dialect/mysql/um" "github.com/stephenafamo/bob/internal" + "github.com/stephenafamo/bob/internal/mappings" "github.com/stephenafamo/bob/orm" ) @@ -22,7 +23,7 @@ func NewTable[T any, Tset any](tableName string, uniques ...[]string) *Table[T, func NewTablex[T any, Tslice ~[]T, Tset any](tableName string, uniques ...[]string) *Table[T, Tslice, Tset] { var zeroSet Tset - setMapping := internal.GetMappings(reflect.TypeOf(zeroSet)) + setMapping := mappings.GetMappings(reflect.TypeOf(zeroSet)) view, mappings := newView[T, Tslice](tableName) t := &Table[T, Tslice, Tset]{ diff --git a/dialect/mysql/view.go b/dialect/mysql/view.go index e9898437..6cbbb9b9 100644 --- a/dialect/mysql/view.go +++ b/dialect/mysql/view.go @@ -9,6 +9,7 @@ import ( "github.com/stephenafamo/bob/dialect/mysql/dialect" "github.com/stephenafamo/bob/dialect/mysql/sm" "github.com/stephenafamo/bob/internal" + "github.com/stephenafamo/bob/internal/mappings" "github.com/stephenafamo/bob/orm" "github.com/stephenafamo/scan" ) @@ -25,17 +26,17 @@ func NewViewx[T any, Tslice ~[]T](tableName string) *View[T, Tslice] { func newView[T any, Tslice ~[]T](tableName string) (*View[T, Tslice], internal.Mapping) { var zero T - mappings := internal.GetMappings(reflect.TypeOf(zero)) + mapping := mappings.GetMappings(reflect.TypeOf(zero)) alias := tableName - allCols := mappings.Columns(alias) + allCols := internal.MappingCols(mapping, alias) return &View[T, Tslice]{ name: tableName, alias: alias, - mapping: mappings, + mapping: mapping, allCols: allCols, scanner: scan.StructMapper[T](), - }, mappings + }, mapping } type View[T any, Tslice ~[]T] struct { diff --git a/dialect/psql/table.go b/dialect/psql/table.go index 33da57e1..6715b96e 100644 --- a/dialect/psql/table.go +++ b/dialect/psql/table.go @@ -11,6 +11,7 @@ import ( "github.com/stephenafamo/bob/dialect/psql/im" "github.com/stephenafamo/bob/dialect/psql/um" "github.com/stephenafamo/bob/internal" + "github.com/stephenafamo/bob/internal/mappings" "github.com/stephenafamo/bob/orm" ) @@ -21,7 +22,7 @@ func NewTable[T any, Tset any](schema, tableName string) *Table[T, []T, Tset] { func NewTablex[T any, Tslice ~[]T, Tset any](schema, tableName string) *Table[T, Tslice, Tset] { var zeroSet Tset - setMapping := internal.GetMappings(reflect.TypeOf(zeroSet)) + setMapping := mappings.GetMappings(reflect.TypeOf(zeroSet)) view, mappings := newView[T, Tslice](schema, tableName) return &Table[T, Tslice, Tset]{ View: view, diff --git a/dialect/psql/view.go b/dialect/psql/view.go index e69380b1..42e1e0f5 100644 --- a/dialect/psql/view.go +++ b/dialect/psql/view.go @@ -10,6 +10,7 @@ import ( "github.com/stephenafamo/bob/dialect/psql/dialect" "github.com/stephenafamo/bob/dialect/psql/sm" "github.com/stephenafamo/bob/internal" + "github.com/stephenafamo/bob/internal/mappings" "github.com/stephenafamo/bob/orm" "github.com/stephenafamo/scan" ) @@ -32,22 +33,22 @@ func NewViewx[T any, Tslice ~[]T](schema, tableName string) *View[T, Tslice] { func newView[T any, Tslice ~[]T](schema, tableName string) (*View[T, Tslice], internal.Mapping) { var zero T - mappings := internal.GetMappings(reflect.TypeOf(zero)) + mapping := mappings.GetMappings(reflect.TypeOf(zero)) alias := tableName if schema != "" { alias = fmt.Sprintf("%s.%s", schema, tableName) } - allCols := mappings.Columns(alias) + allCols := internal.MappingCols(mapping, alias) return &View[T, Tslice]{ schema: schema, name: tableName, alias: alias, - mapping: mappings, + mapping: mapping, allCols: allCols, scanner: scan.StructMapper[T](), - }, mappings + }, mapping } type View[T any, Tslice ~[]T] struct { diff --git a/dialect/sqlite/table.go b/dialect/sqlite/table.go index 2650069d..c910da99 100644 --- a/dialect/sqlite/table.go +++ b/dialect/sqlite/table.go @@ -11,6 +11,7 @@ import ( "github.com/stephenafamo/bob/dialect/sqlite/im" "github.com/stephenafamo/bob/dialect/sqlite/um" "github.com/stephenafamo/bob/internal" + "github.com/stephenafamo/bob/internal/mappings" "github.com/stephenafamo/bob/orm" ) @@ -21,7 +22,7 @@ func NewTable[T any, Tset any](schema, tableName string) *Table[T, []T, Tset] { func NewTablex[T any, Tslice ~[]T, Tset any](schema, tableName string) *Table[T, Tslice, Tset] { var zeroSet Tset - setMapping := internal.GetMappings(reflect.TypeOf(zeroSet)) + setMapping := mappings.GetMappings(reflect.TypeOf(zeroSet)) view, mappings := newView[T, Tslice](schema, tableName) return &Table[T, Tslice, Tset]{ View: view, diff --git a/dialect/sqlite/view.go b/dialect/sqlite/view.go index 130f4f5e..5bd357ce 100644 --- a/dialect/sqlite/view.go +++ b/dialect/sqlite/view.go @@ -10,6 +10,7 @@ import ( "github.com/stephenafamo/bob/dialect/sqlite/dialect" "github.com/stephenafamo/bob/dialect/sqlite/sm" "github.com/stephenafamo/bob/internal" + "github.com/stephenafamo/bob/internal/mappings" "github.com/stephenafamo/bob/orm" "github.com/stephenafamo/scan" ) @@ -32,22 +33,22 @@ func NewViewx[T any, Tslice ~[]T](schema, tableName string) *View[T, Tslice] { func newView[T any, Tslice ~[]T](schema, tableName string) (*View[T, Tslice], internal.Mapping) { var zero T - mappings := internal.GetMappings(reflect.TypeOf(zero)) + mapping := mappings.GetMappings(reflect.TypeOf(zero)) alias := tableName if schema != "" { alias = fmt.Sprintf("%s.%s", schema, tableName) } - allCols := mappings.Columns(alias) + allCols := internal.MappingCols(mapping, alias) return &View[T, Tslice]{ schema: schema, name: tableName, alias: alias, - mapping: mappings, + mapping: mapping, allCols: allCols, scanner: scan.StructMapper[T](), - }, mappings + }, mapping } type View[T any, Tslice ~[]T] struct { diff --git a/internal/mappings/mapping.go b/internal/mappings/mapping.go new file mode 100644 index 00000000..20e014d3 --- /dev/null +++ b/internal/mappings/mapping.go @@ -0,0 +1,117 @@ +package mappings + +import ( + "reflect" + "regexp" + "strings" +) + +var ( + matchFirstCapRe = regexp.MustCompile("(.)([A-Z][a-z]+)") + matchAllCapRe = regexp.MustCompile("([a-z0-9])([A-Z])") +) + +type colProperties struct { + Name string + IsPK bool + IsGenerated bool + AutoIncrement bool +} + +func getColProperties(tag string) colProperties { + var p colProperties + if tag == "" { + return p + } + + parts := strings.Split(tag, ",") + p.Name = parts[0] + + for _, part := range parts[1:] { + switch part { + case "pk": + p.IsPK = true + case "generated": + p.IsGenerated = true + case "autoincr": + p.AutoIncrement = true + } + } + + return p +} + +type Mapping struct { + All []string + PKs []string + NonPKs []string + Generated []string + NonGenerated []string + AutoIncrement []string +} + +func GetMappings(typ reflect.Type) Mapping { + c := Mapping{} + + if typ.Kind() == reflect.Pointer { + typ = typ.Elem() + } + + if typ.Kind() != reflect.Struct { + return c + } + + c.All = make([]string, typ.NumField()) + c.PKs = make([]string, typ.NumField()) + c.NonPKs = make([]string, typ.NumField()) + c.Generated = make([]string, typ.NumField()) + c.NonGenerated = make([]string, typ.NumField()) + c.AutoIncrement = make([]string, typ.NumField()) + + // Go through the struct fields and populate the map. + // Recursively go into any child structs, adding a prefix where necessary + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + + // Don't consider unexported fields + if !field.IsExported() { + continue + } + + // Skip columns that have the tag "-" + tag := field.Tag.Get("db") + if tag == "-" { + continue + } + + if tag == "" { + tag = snakeCase(field.Name) + } + + props := getColProperties(tag) + + c.All[field.Index[0]] = props.Name + if props.IsPK { + c.PKs[field.Index[0]] = props.Name + } else { + c.NonPKs[field.Index[0]] = props.Name + } + if props.IsGenerated { + c.Generated[field.Index[0]] = props.Name + } else { + c.NonGenerated[field.Index[0]] = props.Name + } + if props.AutoIncrement { + c.AutoIncrement[field.Index[0]] = props.Name + } + } + + return c +} + +// snakeCaseFieldFunc is a NameMapperFunc that maps struct field to snake case. +func snakeCase(str string) string { + snake := matchFirstCapRe.ReplaceAllString(str, "${1}_${2}") + snake = matchAllCapRe.ReplaceAllString(snake, "${1}_${2}") + return strings.ToLower(snake) +} diff --git a/internal/reflect.go b/internal/reflect.go index 2cbf47af..1d7aa5a3 100644 --- a/internal/reflect.go +++ b/internal/reflect.go @@ -4,42 +4,22 @@ import ( "errors" "fmt" "reflect" - "regexp" - "strings" "github.com/stephenafamo/bob" "github.com/stephenafamo/bob/expr" + "github.com/stephenafamo/bob/internal/mappings" "github.com/stephenafamo/bob/orm" ) -var ( - matchFirstCapRe = regexp.MustCompile("(.)([A-Z][a-z]+)") - matchAllCapRe = regexp.MustCompile("([a-z0-9])([A-Z])") -) - -// snakeCaseFieldFunc is a NameMapperFunc that maps struct field to snake case. -func snakeCase(str string) string { - snake := matchFirstCapRe.ReplaceAllString(str, "${1}_${2}") - snake = matchAllCapRe.ReplaceAllString(snake, "${1}_${2}") - return strings.ToLower(snake) -} - //nolint:gochecknoglobals var unsettableTyp = reflect.TypeOf((*interface{ IsUnset() bool })(nil)).Elem() -type Mapping struct { - All []string - PKs []string - NonPKs []string - Generated []string - NonGenerated []string - AutoIncrement []string -} +type Mapping = mappings.Mapping -func (c Mapping) Columns(table ...string) orm.Columns { +func MappingCols(m Mapping, table ...string) orm.Columns { // to make sure we don't modify the passed slice - cols := make([]string, 0, len(c.All)) - for _, col := range c.All { + cols := make([]string, 0, len(m.All)) + for _, col := range m.All { if col == "" { continue } @@ -47,100 +27,11 @@ func (c Mapping) Columns(table ...string) orm.Columns { cols = append(cols, col) } - copy(cols, c.All) + copy(cols, m.All) return orm.NewColumns(cols...).WithParent(table...) } -type colProperties struct { - Name string - IsPK bool - IsGenerated bool - AutoIncrement bool -} - -func getColProperties(tag string) colProperties { - var p colProperties - if tag == "" { - return p - } - - parts := strings.Split(tag, ",") - p.Name = parts[0] - - for _, part := range parts[1:] { - switch part { - case "pk": - p.IsPK = true - case "generated": - p.IsGenerated = true - case "autoincr": - p.AutoIncrement = true - } - } - - return p -} - -func GetMappings(typ reflect.Type) Mapping { - c := Mapping{} - - if typ.Kind() == reflect.Pointer { - typ = typ.Elem() - } - - if typ.Kind() != reflect.Struct { - return c - } - - c.All = make([]string, typ.NumField()) - c.PKs = make([]string, typ.NumField()) - c.NonPKs = make([]string, typ.NumField()) - c.Generated = make([]string, typ.NumField()) - c.NonGenerated = make([]string, typ.NumField()) - c.AutoIncrement = make([]string, typ.NumField()) - - // Go through the struct fields and populate the map. - // Recursively go into any child structs, adding a prefix where necessary - for i := 0; i < typ.NumField(); i++ { - field := typ.Field(i) - - // Don't consider unexported fields - if !field.IsExported() { - continue - } - - // Skip columns that have the tag "-" - tag := field.Tag.Get("db") - if tag == "-" { - continue - } - - if tag == "" { - tag = snakeCase(field.Name) - } - - props := getColProperties(tag) - - c.All[field.Index[0]] = props.Name - if props.IsPK { - c.PKs[field.Index[0]] = props.Name - } else { - c.NonPKs[field.Index[0]] = props.Name - } - if props.IsGenerated { - c.Generated[field.Index[0]] = props.Name - } else { - c.NonGenerated[field.Index[0]] = props.Name - } - if props.AutoIncrement { - c.AutoIncrement[field.Index[0]] = props.Name - } - } - - return c -} - // Get the values for non generated columns func GetColumnValues[T any](mapping Mapping, filter []string, objs ...T) ([]string, [][]bob.Expression, error) { if len(objs) == 0 { @@ -229,7 +120,6 @@ func getObjVals(mapping Mapping, cols []string, val reflect.Value) ([]bob.Expres if name == c { field := val.Field(index) values = append(values, expr.Arg(field.Interface())) - break } } } diff --git a/internal/reflect_test.go b/internal/reflect_test.go index 5ce9b5e1..9341faf5 100644 --- a/internal/reflect_test.go +++ b/internal/reflect_test.go @@ -14,6 +14,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/stephenafamo/bob" "github.com/stephenafamo/bob/expr" + "github.com/stephenafamo/bob/internal/mappings" ) type User struct { @@ -100,7 +101,7 @@ func testGetColumns[T any](t *testing.T, expected Mapping) { var x T xTyp := reflect.TypeOf(x) t.Run(xTyp.Name(), func(t *testing.T) { - cols := GetMappings(xTyp) + cols := mappings.GetMappings(xTyp) if diff := cmp.Diff(expected, cols); diff != "" { t.Fatal(diff) } @@ -207,7 +208,7 @@ func testGetColumnValues[T any](t *testing.T, name string, tc testGetColumnsCase t.Helper() var x T xTyp := reflect.TypeOf(x) - cols := GetMappings(xTyp) + cols := mappings.GetMappings(xTyp) if name == "" { name = xTyp.Name() } diff --git a/named.go b/named.go new file mode 100644 index 00000000..60cae76c --- /dev/null +++ b/named.go @@ -0,0 +1,60 @@ +package bob + +import ( + "database/sql/driver" + "errors" + "io" +) + +var ErrRawNamedArg = errors.New("named arg used without rebinding") + +// named args should ONLY be used to prepare statements +type namedArg string + +// Value implements the driver.Valuer interface. +// it always returns an error because named args should only be used to prepare statements +func (n namedArg) Value() (driver.Value, error) { + return nil, ErrRawNamedArg +} + +// Named args should ONLY be used to prepare statements +func Named(names ...string) Expression { + return named{names: names} +} + +// NamedGroup is like Named, but wraps in parentheses +func NamedGroup(names ...string) Expression { + return named{names: names} +} + +type named struct { + names []string + grouped bool +} + +func (a named) WriteSQL(w io.Writer, d Dialect, start int) ([]any, error) { + if len(a.names) == 0 { + return nil, nil + } + + args := make([]any, len(a.names)) + + if a.grouped { + w.Write([]byte("(")) + } + + for k, name := range a.names { + if k > 0 { + w.Write([]byte(", ")) + } + + d.WriteArg(w, start+k) + args[k] = namedArg(name) + } + + if a.grouped { + w.Write([]byte(")")) + } + + return args, nil +} diff --git a/stmt.go b/stmt.go index 9d8123cf..04f9e701 100644 --- a/stmt.go +++ b/stmt.go @@ -3,10 +3,20 @@ package bob import ( "context" "database/sql" + "fmt" "github.com/stephenafamo/scan" ) +type ErrMismatchedArgs struct { + Expected int + Got int +} + +func (e ErrMismatchedArgs) Error() string { + return fmt.Sprintf("expected %d args, got %d", e.Expected, e.Got) +} + type Preparer interface { Executor PrepareContext(ctx context.Context, query string) (Statement, error) @@ -15,26 +25,23 @@ type Preparer interface { type Statement interface { ExecContext(ctx context.Context, args ...any) (sql.Result, error) QueryContext(ctx context.Context, args ...any) (scan.Rows, error) + Close() error } -// NewStmt wraps an [*sql.Stmt] and returns a type that implements [Queryer] but still -// retains the expected methods used by *sql.Stmt -// This is useful when an existing *sql.Stmt is used in other places in the codebase -func Prepare(ctx context.Context, exec Preparer, q Query) (Stmt, error) { +func prepare(ctx context.Context, exec Preparer, q Query) (Stmt, []any, error) { query, args, err := Build(q) if err != nil { - return Stmt{}, err + return Stmt{}, nil, err } stmt, err := exec.PrepareContext(ctx, query) if err != nil { - return Stmt{}, err + return Stmt{}, nil, err } s := Stmt{ - exec: exec, - stmt: stmt, - lenArgs: len(args), + exec: exec, + stmt: stmt, } if l, ok := q.(Loadable); ok { @@ -43,17 +50,27 @@ func Prepare(ctx context.Context, exec Preparer, q Query) (Stmt, error) { copy(s.loaders, loaders) } - return s, nil + return s, args, nil +} + +// Prepare prepares a query using the [Preparer] and returns a [Stmt] +func Prepare(ctx context.Context, exec Preparer, q Query) (Stmt, error) { + s, _, err := prepare(ctx, exec, q) + return s, err } // Stmt is similar to *sql.Stmt but implements [Queryer] type Stmt struct { stmt Statement exec Executor - lenArgs int loaders []Loader } +// Close closes the statement +func (s Stmt) Close() error { + return s.stmt.Close() +} + // Exec executes a query without returning any rows. The args are for any placeholder parameters in the query. func (s Stmt) Exec(ctx context.Context, args ...any) (sql.Result, error) { result, err := s.stmt.ExecContext(ctx, args...) @@ -75,11 +92,16 @@ func PrepareQuery[T any](ctx context.Context, exec Preparer, q Query, m scan.Map } func PrepareQueryx[T any, Ts ~[]T](ctx context.Context, exec Preparer, q Query, m scan.Mapper[T], opts ...ExecOption[T]) (QueryStmt[T, Ts], error) { + s, _, err := prepareQuery[T, Ts](ctx, exec, q, m, opts...) + return s, err +} + +func prepareQuery[T any, Ts ~[]T](ctx context.Context, exec Preparer, q Query, m scan.Mapper[T], opts ...ExecOption[T]) (QueryStmt[T, Ts], []any, error) { var qs QueryStmt[T, Ts] - s, err := Prepare(ctx, exec, q) + s, args, err := prepare(ctx, exec, q) if err != nil { - return qs, err + return qs, nil, err } settings := ExecSettings[T]{} @@ -99,7 +121,7 @@ func PrepareQueryx[T any, Ts ~[]T](ctx context.Context, exec Preparer, q Query, settings: settings, } - return qs, nil + return qs, args, nil } type QueryStmt[T any, Ts ~[]T] struct { @@ -109,6 +131,11 @@ type QueryStmt[T any, Ts ~[]T] struct { settings ExecSettings[T] } +// Close closes the statement +func (s QueryStmt[T, Ts]) Close() error { + return s.stmt.Close() +} + func (s QueryStmt[T, Ts]) One(ctx context.Context, args ...any) (T, error) { var t T diff --git a/stmt_bound.go b/stmt_bound.go new file mode 100644 index 00000000..b0c8b6b0 --- /dev/null +++ b/stmt_bound.go @@ -0,0 +1,169 @@ +package bob + +import ( + "context" + "database/sql" + "errors" + "reflect" + + "github.com/stephenafamo/bob/internal/mappings" + "github.com/stephenafamo/scan" +) + +type structBinder[T any] struct { + args []string + fields []string +} + +func (b structBinder[T]) toArgs(arg T) ([]any, error) { + val := reflect.ValueOf(arg) + if val.Kind() == reflect.Pointer { + if val.IsNil() { + return nil, errors.New("object is nil") + } + val = val.Elem() + } + + values := make([]any, len(b.args)) + +ArgLoop: + for index, argName := range b.args { + for _, fieldName := range b.fields { + if fieldName == argName { + field := val.Field(index) + values[index] = field.Interface() + continue ArgLoop + } + } + return nil, ErrMissingArg{Name: argName} + } + + return values, nil +} + +func makeStructBinder[Arg any](args []any) (structBinder[Arg], error) { + argPositions := make([]string, len(args)) + for pos, arg := range args { + if name, ok := arg.(namedArg); ok { + argPositions[pos] = string(name) + continue + } + + return structBinder[Arg]{}, ErrNamedArgRequired{arg} + } + + var x Arg + fieldPositions := mappings.GetMappings(reflect.TypeOf(x)).All + + // check if all positions have matching fields +ArgLoop: + for _, name := range argPositions { + for _, field := range fieldPositions { + if field == name { + continue ArgLoop + } + } + return structBinder[Arg]{}, ErrMissingArg{Name: name} + } + + return structBinder[Arg]{ + args: argPositions, + fields: fieldPositions, + }, nil +} + +// PrepareBound prepares a query using the [Preparer] and returns a [NamedStmt] +func PrepareBound[Arg any](ctx context.Context, exec Preparer, q Query) (BoundStmt[Arg], error) { + stmt, args, err := prepare(ctx, exec, q) + if err != nil { + return BoundStmt[Arg]{}, err + } + + binder, err := makeStructBinder[Arg](args) + if err != nil { + return BoundStmt[Arg]{}, err + } + + return BoundStmt[Arg]{ + stmt: stmt, + binder: binder, + }, nil +} + +// BoundStmt is similar to *sql.Stmt but implements [Queryer] +// instead of taking a list of args, it takes a struct to bind to the query +type BoundStmt[Arg any] struct { + stmt Stmt + binder structBinder[Arg] +} + +// Close closes the statement. +func (s BoundStmt[Arg]) Close() error { + return s.stmt.Close() +} + +// Exec executes a query without returning any rows. The args are for any placeholder parameters in the query. +func (s BoundStmt[Arg]) Exec(ctx context.Context, arg Arg) (sql.Result, error) { + args, err := s.binder.toArgs(arg) + if err != nil { + return nil, err + } + + return s.stmt.Exec(ctx, args...) +} + +func PrepareBoundQuery[Arg any, T any](ctx context.Context, exec Preparer, q Query, m scan.Mapper[T], opts ...ExecOption[T]) (BoundQueryStmt[Arg, T, []T], error) { + return PrepareBoundQueryx[Arg, T, []T](ctx, exec, q, m, opts...) +} + +func PrepareBoundQueryx[Arg any, T any, Ts ~[]T](ctx context.Context, exec Preparer, q Query, m scan.Mapper[T], opts ...ExecOption[T]) (BoundQueryStmt[Arg, T, Ts], error) { + s, args, err := prepareQuery[T, Ts](ctx, exec, q, m, opts...) + + binder, err := makeStructBinder[Arg](args) + if err != nil { + return BoundQueryStmt[Arg, T, Ts]{}, err + } + + return BoundQueryStmt[Arg, T, Ts]{ + query: s, + binder: binder, + }, nil +} + +type BoundQueryStmt[Arg any, T any, Ts ~[]T] struct { + query QueryStmt[T, Ts] + binder structBinder[Arg] +} + +// Close closes the statment. +func (s BoundQueryStmt[Arg, T, Ts]) Close() error { + return s.query.Close() +} + +func (s BoundQueryStmt[Arg, T, Ts]) One(ctx context.Context, arg Arg) (T, error) { + args, err := s.binder.toArgs(arg) + if err != nil { + var t T + return t, err + } + + return s.query.One(ctx, args...) +} + +func (s BoundQueryStmt[Arg, T, Ts]) All(ctx context.Context, arg Arg) (Ts, error) { + args, err := s.binder.toArgs(arg) + if err != nil { + return nil, err + } + + return s.query.All(ctx, args...) +} + +func (s BoundQueryStmt[Arg, T, Ts]) Cursor(ctx context.Context, arg Arg) (scan.ICursor[T], error) { + args, err := s.binder.toArgs(arg) + if err != nil { + return nil, err + } + + return s.query.Cursor(ctx, args...) +} diff --git a/stmt_mapped.go b/stmt_mapped.go new file mode 100644 index 00000000..7833bbdc --- /dev/null +++ b/stmt_mapped.go @@ -0,0 +1,179 @@ +package bob + +import ( + "context" + "database/sql" + "fmt" + + "github.com/stephenafamo/scan" +) + +type ErrNamedArgRequired struct{ value any } + +func (e ErrNamedArgRequired) Error() string { + return fmt.Sprintf("expected named arg, got %#v", e.value) +} + +type ErrDuplicateArg struct{ Name string } + +func (e ErrDuplicateArg) Error() string { + return fmt.Sprintf("duplicate arg %s", e.Name) +} + +type ErrMissingArg struct{ Name string } + +func (e ErrMissingArg) Error() string { + return fmt.Sprintf("missing arg %s", e.Name) +} + +type mapBinder struct { + unique int + positions []string +} + +func (m mapBinder) toArgs(mapArgs map[string]any) ([]any, error) { + if len(mapArgs) != m.unique { + return nil, ErrMismatchedArgs{ + Expected: m.unique, + Got: len(mapArgs), + } + } + + args := make([]any, len(m.positions)) + for position, name := range m.positions { + value, ok := mapArgs[name] + if !ok { + return nil, ErrMissingArg{Name: name} + } + + args[position] = value + } + + return args, nil +} + +func makeMapBinder(args []any) (mapBinder, error) { + positions := make([]string, len(args)) + for pos, arg := range args { + if name, ok := arg.(namedArg); ok { + positions[pos] = string(name) + continue + } + + return mapBinder{}, ErrNamedArgRequired{arg} + } + + // count unique names + unique := make(map[string]struct{}) + for _, name := range positions { + if _, ok := unique[name]; !ok { + unique[name] = struct{}{} + } + } + + return mapBinder{ + unique: len(unique), + positions: positions, + }, nil +} + +// PrepareMapped prepares a query using the [Preparer] and returns a [NamedStmt] +func PrepareMapped(ctx context.Context, exec Preparer, q Query) (MappedStmt, error) { + stmt, args, err := prepare(ctx, exec, q) + if err != nil { + return MappedStmt{}, err + } + + m, err := makeMapBinder(args) + if err != nil { + return MappedStmt{}, err + } + + return MappedStmt{ + stmt: stmt, + mapper: m, + }, nil +} + +// MappedStmt is similar to *sql.Stmt but implements [Queryer] +// instead of taking a list of args, it takes a map of args or a struct to bind to the query +type MappedStmt struct { + stmt Stmt + mapper mapBinder +} + +// Inspect returns a map with all the expected keys +func (s MappedStmt) Inspect() []string { + return s.mapper.positions +} + +// Close closes the statement +func (s MappedStmt) Close() error { + return s.stmt.Close() +} + +// Exec executes a query without returning any rows. The args are for any placeholder parameters in the query. +func (s MappedStmt) Exec(ctx context.Context, mappedArgs map[string]any) (sql.Result, error) { + args, err := s.mapper.toArgs(mappedArgs) + if err != nil { + return nil, err + } + + return s.stmt.Exec(ctx, args...) +} + +func PrepareMappedQuery[T any](ctx context.Context, exec Preparer, q Query, m scan.Mapper[T], opts ...ExecOption[T]) (MappedQueryStmt[T, []T], error) { + return PrepareMappedQueryx[T, []T](ctx, exec, q, m, opts...) +} + +func PrepareMappedQueryx[T any, Ts ~[]T](ctx context.Context, exec Preparer, q Query, m scan.Mapper[T], opts ...ExecOption[T]) (MappedQueryStmt[T, Ts], error) { + s, args, err := prepareQuery[T, Ts](ctx, exec, q, m, opts...) + + binder, err := makeMapBinder(args) + if err != nil { + return MappedQueryStmt[T, Ts]{}, err + } + + return MappedQueryStmt[T, Ts]{ + query: s, + binder: binder, + }, nil +} + +type MappedQueryStmt[T any, Ts ~[]T] struct { + query QueryStmt[T, Ts] + binder mapBinder +} + +// Close closes the statement +func (s MappedQueryStmt[T, Ts]) Close() error { + return s.query.Close() +} + +func (s MappedQueryStmt[T, Ts]) One(ctx context.Context, arg map[string]any) (T, error) { + args, err := s.binder.toArgs(arg) + if err != nil { + var t T + return t, err + } + + return s.query.One(ctx, args...) +} + +func (s MappedQueryStmt[T, Ts]) All(ctx context.Context, arg map[string]any) (Ts, error) { + args, err := s.binder.toArgs(arg) + if err != nil { + return nil, err + } + + return s.query.All(ctx, args...) +} + +func (s MappedQueryStmt[T, Ts]) Cursor(ctx context.Context, arg map[string]any) (scan.ICursor[T], error) { + args, err := s.binder.toArgs(arg) + if err != nil { + return nil, err + } + + return s.query.Cursor(ctx, args...) +}